Description
Kudos to @kovdan01 for the initial analysis of this issue.
It turns out that adjoints for active values in loops are just plain wrong. Consider the reproducer below: the gradient for `repeat_while_loop` is wrong, while the gradient for `while_loop` is correct. Moreover, if we replace the loop body with `result *= x`, the `repeat_while_loop` case starts working.
Why is this so?
The loop body for `repeat_while_loop` looks as follows (loop condition calculation removed for brevity):
```
  br bb1 // id: %10

bb1: // Preds: bb2 bb0
  %11 = metatype $@thin Float.Type // user: %16
  %12 = begin_access [read] [static] %2 // users: %14, %13
  %13 = load [trivial] %12 // user: %16
  end_access %12 // id: %14
  // function_ref static Float.* infix(_:_:)
  %15 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %16
  %16 = apply %15(%13, %0, %11) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %18
  %17 = begin_access [modify] [static] %2 // users: %18, %19
  store %16 to [trivial] %17 // id: %18
  end_access %17 // id: %19
  ...
  cond_br %39, bb2, bb3 // id: %40

bb2: // Preds: bb1
  br bb1

bb3: // Preds: bb1
  ...
```
The key thing here is the active value `%13` (which is essentially the `result` value), so we need to generate an adjoint for it. The AutoDiff code uses the notion of "adjoint for value X in basic block Y". This is fine for code without loops, but it is just plain wrong for values inside loops, where it should be "adjoint for value X in basic block Y on loop iteration Z". The values differ between loop iterations, so their adjoints should be distinct as well. Without this we end up with artificial adjoint accumulation (since a single adjoint value is shared between loop iterations) and wrong results.
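To make the per-iteration requirement concrete, here is the reverse-mode derivation for this loop written out by hand. The notation is introduced here for illustration and is a sketch, not what the compiler emits. Iteration $k = 1, \dots, n$ computes $a_k = r_{k-1}$ (the load, `%13`) and $r_k = a_k x$ (the product, `%16`), with $r_0 = x$; every iteration gets its own adjoints $\bar a_k$ and $\bar r_k$:

$$
\begin{aligned}
\bar r_n &= 1, \quad \bar x = 0, \\
\bar a_k &= \bar r_k \, x, \quad \bar x \mathrel{+}= \bar r_k \, a_k, \quad \bar r_{k-1} = \bar a_k \qquad (k = n, \dots, 1), \\
\bar x &\mathrel{+}= \bar r_0 .
\end{aligned}
$$

For $n = 2$ and $x = 2$ this gives $\bar r_2 = 1$, $\bar a_2 = 2$, $\bar x = a_2 = 4$, then $\bar r_1 = 2$, $\bar a_1 = 4$, $\bar x = 4 + 2 \cdot 2 = 8$, and finally $\bar r_0 = 4$, $\bar x = 12 = 3x^2$.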
So, when generating the pullback for this loop body, we need to ensure that the initial value of the adjoint of `%13` is zero on each iteration, and then perform the usual pullback cloning that involves adjoint value generation and accumulation. We don't do this, so we are essentially accumulating into the adjoint left over from the previous loop iteration.
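A minimal sketch of that effect, written as a hand-rolled pullback in plain Swift (no `_Differentiation` needed). The function `loopGradient` and the flag `resetAdjointEachIteration` are hypothetical names for illustration; this is a simplified model, not the code the compiler generates. Only the adjoint of the loaded value (`%13`) is shared across iterations here, so the wrong number it produces is not necessarily the one the compiler prints.

```swift
// Simplified model of the pullback of
//   var result = x; repeat { result = result * x } while ...
// running `iterations` times. Not compiler output.
func loopGradient(at x: Float, iterations: Int, resetAdjointEachIteration: Bool) -> Float {
    // Forward pass: record the value loaded from `result` (%13) on every iteration.
    var result = x
    var loaded: [Float] = []
    for _ in 0..<iterations {
        loaded.append(result)   // %13 = load
        result = result * x     // %16 = apply Float.*, then store
    }

    // Reverse pass over the iterations, last one first.
    var adjResult: Float = 1    // seed: d(result)/d(result)
    var adjX: Float = 0
    var adjLoaded: Float = 0    // adjoint of %13, declared once for all iterations
    for k in stride(from: iterations - 1, through: 0, by: -1) {
        if resetAdjointEachIteration { adjLoaded = 0 }  // per-iteration zeroing
        let adjProduct = adjResult      // adjoint of %16 (the stored value)
        adjResult = 0                   // the store overwrote `result`
        adjLoaded += adjProduct * x     // d(a * x)/da = x
        adjX += adjProduct * loaded[k]  // d(a * x)/dx = a
        adjResult += adjLoaded          // the load read `result`
    }
    // `var result = x`: the adjoint of the initial value flows back into x.
    return adjX + adjResult
}

print(loopGradient(at: 2, iterations: 2, resetAdjointEachIteration: true))   // 12.0
print(loopGradient(at: 2, iterations: 2, resetAdjointEachIteration: false))  // 14.0
```

Resetting `adjLoaded` at the top of each reverse iteration is the per-iteration zeroing described above; without it, the adjoint left over from the later iteration leaks into the earlier one.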
If things are so broken, why have we not noticed this before? I would say: coincidence.
For code like `result *= x` we do not have these extra active values: `Float.*=` takes the adjoint buffer as an `inout` argument and performs proper adjoint generation there.
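A sketch of why the `inout` form is immune, under the assumption that its pullback updates the adjoint of `result` in place. `multiplyAssignPullback` and `gradientUsingInoutPullback` are hypothetical helpers, not the derivative the standard library actually registers for `Float.*=`.

```swift
// Hypothetical pullback for one `result *= x` step. The adjoint of `result`
// is updated in place through `inout`, so there is no separate local adjoint
// that could survive into the next iteration.
func multiplyAssignPullback(adjResult: inout Float, oldResult: Float, x: Float) -> Float {
    let adjX = adjResult * oldResult  // d(result * x)/dx = old value of result
    adjResult *= x                    // d(result * x)/d(result) = x
    return adjX
}

// Gradient of `var result = x; repeat { result *= x } ...` run `iterations` times.
func gradientUsingInoutPullback(at x: Float, iterations: Int) -> Float {
    var result = x
    var oldResults: [Float] = []
    for _ in 0..<iterations {
        oldResults.append(result)
        result *= x
    }
    var adjResult: Float = 1
    var adjX: Float = 0
    for old in oldResults.reversed() {
        adjX += multiplyAssignPullback(adjResult: &adjResult, oldResult: old, x: x)
    }
    // `var result = x`: the remaining adjoint of result flows back into x.
    return adjX + adjResult
}

print(gradientUsingInoutPullback(at: 2, iterations: 2))  // 12.0
```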
The `while` / `for` case is more interesting. Here the code looks as follows:
```
  br bb1 // id: %10

bb1: // Preds: bb2 bb0
  ...
  cond_br %21, bb2, bb3 // id: %22

bb2: // Preds: bb1
  %23 = metatype $@thin Float.Type // user: %28
  %24 = begin_access [read] [static] %2 // users: %26, %25
  %25 = load [trivial] %24 // user: %28
  end_access %24 // id: %26
  // function_ref static Float.* infix(_:_:)
  %27 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
  %28 = apply %27(%25, %0, %23) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %30
  %29 = begin_access [modify] [static] %2 // users: %30, %31
  store %28 to [trivial] %29 // id: %30
  end_access %29 // id: %31
  ...
  br bb1 // id: %41

bb3: // Preds: bb1
  %42 = begin_access [read] [static] %2 // users: %44, %43
  %43 = load [trivial] %42 // user: %47
  end_access %42 // id: %44
```
So we have the loop header (`bb1`) first, then the loop body (`bb2`), and finally the code after the loop (`bb3`). Now, we have code that propagates adjoints of active values into predecessor BBs while traversing the function in reverse post-order. Here, we first visit `bb3`, then `bb1`. Inside `bb1` we realize that there are active values (`%25` and `%28`) in the predecessor `bb2`, so we take their adjoints in `bb1` and propagate them into `bb2`. Since no adjoints were defined for them before, they are zero-initialized and then propagated. And since `bb1` coincidentally is the loop header, we end up zero-initializing them on every loop iteration of the pullback: the pullback block for the loop header runs after each pullback of the loop body, right before the next one. Everything magically works.
The situation with the `repeat` loop is the reverse: there is no "loop header" BB in the usual sense; instead, a "loop footer" is fused into the loop body. So the adjoints for `%13` and `%16` are first zero-initialized in the pullback block corresponding to `bb3` and then propagated further to `bb1`. Hence there is no zero-initialization on each loop iteration, adjoint values are reused from the previous loop iteration, and wrong results are produced.
It seems to me that we need to perform explicit adjoint zeroing inside loop headers in the pullback cloner.
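A toy model of the block ordering described above, including the loop-header zeroing proposed here. The helper `adjointIsFreshAtEachUse` and the hard-coded block traces are assumptions made for illustration (traces for two iterations, matching the reproducer). The model only tracks whether an adjoint was zeroed since its previous use; it is not how the pullback cloner is implemented.

```swift
// Toy model, not the compiler's algorithm: the pullback visits pullback blocks
// in the reverse of the order the original blocks executed. For every pullback
// of `useBlock`, report whether the loop-body adjoint was zeroed since its
// previous use (true) or still holds a stale value (false).
func adjointIsFreshAtEachUse(trace: [String],
                             zeroingBlocks: Set<String>,
                             useBlock: String) -> [Bool] {
    var fresh = false
    var report: [Bool] = []
    for block in trace.reversed() {
        // Zeroing is modeled as happening on entry to the pullback block,
        // before any use inside the same block.
        if zeroingBlocks.contains(block) { fresh = true }
        if block == useBlock {
            report.append(fresh)
            fresh = false
        }
    }
    return report
}

// Original block traces for two loop iterations.
let whileTrace = ["bb0", "bb1", "bb2", "bb1", "bb2", "bb1", "bb3"]
let repeatTrace = ["bb0", "bb1", "bb2", "bb1", "bb3"]

// while: the adjoint of %25 is zeroed in the pullback of the header bb1.
print(adjointIsFreshAtEachUse(trace: whileTrace, zeroingBlocks: ["bb1"], useBlock: "bb2"))          // [true, true]
// repeat: the adjoint of %13 is zeroed only in the pullback of bb3.
print(adjointIsFreshAtEachUse(trace: repeatTrace, zeroingBlocks: ["bb3"], useBlock: "bb1"))         // [true, false]
// repeat with explicit zeroing in the loop header bb1.
print(adjointIsFreshAtEachUse(trace: repeatTrace, zeroingBlocks: ["bb3", "bb1"], useBlock: "bb1"))  // [true, true]
```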
Reproduction
```swift
import _Differentiation

func repeat_while_loop(_ x: Float) -> Float {
    var result = x
    var i = 0
    repeat {
        result = result * x
        i += 1
    } while i < 2
    return result
}

func while_loop(_ x: Float) -> Float {
    var result = x
    var i = 0
    while i < 2 {
        result = result * x
        i += 1
    }
    return result
}

print(valueWithGradient(at: 2, of: repeat_while_loop))
print(valueWithGradient(at: 2, of: while_loop))
```
Expected behavior
Correct gradient calculation in both cases.
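For reference, both loops execute the body exactly twice (`i` takes the values 0 and 1 inside the body), so both functions compute the same thing:

$$
f(x) = x \cdot x \cdot x = x^3, \qquad f'(x) = 3x^2,
$$

so at $x = 2$ both calls should report a value of 8 and a gradient of 12.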
Environment
Swift version 6.2-dev (LLVM e404f8897f17aff, Swift 5a68861)
Target: arm64-apple-macosx15.0