
Adjoints for active values in loops are just wrong #78264

@asl

Description


Kudos to @kovdan01 for initial analysis of this issue.

It turns out that adjoints for active values in loops are just plain wrong. Consider the reproducer below. As one can see, the gradient for repeat_while_loop is wrong, while the gradient for while_loop is correct. Moreover, if we replace the code in the loop with result *= x, the repeat_while_loop case starts working.

Why is this so?

The loop body for repeat_while_loop looks as follows (loop condition calculation removed for brevity):

  br bb1                                          // id: %10

bb1:                                              // Preds: bb2 bb0
  %11 = metatype $@thin Float.Type                // user: %16
  %12 = begin_access [read] [static] %2           // users: %14, %13
  %13 = load [trivial] %12                        // user: %16
  end_access %12                                  // id: %14
  // function_ref static Float.* infix(_:_:)
  %15 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %16
  %16 = apply %15(%13, %0, %11) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %18
  %17 = begin_access [modify] [static] %2         // users: %18, %19
  store %16 to [trivial] %17                      // id: %18
  end_access %17                                  // id: %19
...
  cond_br %39, bb2, bb3                           // id: %40

bb2:                                              // Preds: bb1
  br bb1
bb3:                                              // Preds: bb1
...

The key thing here is the active value %13 (which is essentially the current value of result), so we need to generate an adjoint for it. The AutoDiff code uses the notion of "adjoint for value X in basic block Y". This is fine for code without loops, and just plain wrong for values inside loops, where it should be "adjoint for value X in basic block Y on loop iteration Z". The values are different on different loop iterations, so their adjoints should be distinct as well. Without this, we end up with artificial adjoint accumulation (since a single adjoint value is shared between loop iterations) and wrong results.

So, when generating the pullback for this loop body, we need to ensure that the initial value of the adjoint of %13 is zero on each iteration, and then perform the usual pullback cloning, which involves adjoint value generation and accumulation. We don't do this, so we are essentially accumulating into the adjoint left over from the previous loop iteration.
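A worked reverse-mode derivation for the reproducer below (both functions compute x^3, evaluated at x = 2) makes the per-iteration view concrete. Here r_k is the value of result at the start of iteration k and \bar r_k is its adjoint on that iteration; each \bar r_k is computed fresh, not accumulated onto the previous iteration's value. The adjoint of x, by contrast, legitimately accumulates across iterations, because x is a single value defined outside the loop.

```latex
\begin{aligned}
&r_0 = x,\quad r_{k+1} = r_k \, x \ (k = 0, 1),\quad f(x) = r_2 = x^3 \\
&\bar r_2 = 1,\quad \bar r_k = \bar r_{k+1} \, x,\quad \bar x \mathrel{+}= \bar r_{k+1} \, r_k \quad (k = 1, 0),\quad \bar x \mathrel{+}= \bar r_0 \\
&x = 2:\ r = (2, 4, 8) \\
&k = 1:\ \bar x = 1 \cdot 4 = 4,\quad \bar r_1 = 1 \cdot 2 = 2 \\
&k = 0:\ \bar x = 4 + 2 \cdot 2 = 8,\quad \bar r_0 = 2 \cdot 2 = 4 \\
&\bar x = 8 + 4 = 12 = 3x^2 \big|_{x=2}
\end{aligned}
```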

Sure, if things are so broken, why have we not noticed this before? I would say: coincidence.

For code like result *= x we do not have these extra active values: Float.*= takes the adjoint buffer as an inout argument and performs proper adjoint generation there.
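A toy model in plain Swift (no _Differentiation) of why the in-place form is immune. mulAssignPullback and inPlaceLoopGradient are hypothetical helpers, not the standard library's actual derivative of *=. The pullback of the in-place multiplication updates the adjoint of result through an inout parameter, so there are no per-iteration temporaries whose adjoints could survive into the next iteration.

```swift
// Hypothetical toy pullback of `lhs *= rhs` (not the stdlib's real derivative).
// The adjoint of `lhs` is updated in place; the returned value is the
// contribution to the adjoint of `rhs`.
func mulAssignPullback(_ adjLHS: inout Float, lhsBefore: Float, rhs: Float) -> Float {
    let adjRHS = adjLHS * lhsBefore   // d(lhs * rhs) / d(rhs) = lhs
    adjLHS = adjLHS * rhs             // d(lhs * rhs) / d(lhs) = rhs
    return adjRHS
}

// Gradient of: var result = x; repeat `iterations` times { result *= x }; return result
func inPlaceLoopGradient(at x: Float, iterations: Int) -> Float {
    // Forward pass: record `result` before each in-place multiplication.
    var tape: [Float] = []
    var result = x
    for _ in 0..<iterations {
        tape.append(result)
        result *= x
    }

    // Reverse pass: only the adjoint of the `result` buffer is threaded through.
    var adjResult: Float = 1          // seeded by `return result`
    var adjX: Float = 0
    for lhsBefore in tape.reversed() {
        adjX += mulAssignPullback(&adjResult, lhsBefore: lhsBefore, rhs: x)
    }
    adjX += adjResult                 // pullback of `var result = x`
    return adjX
}

print(inPlaceLoopGradient(at: 2, iterations: 2))  // 12.0, i.e. 3x^2 at x = 2
```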

The while / for case is more interesting. Here the code looks as follows:

  br bb1                                          // id: %10

bb1:                                              // Preds: bb2 bb0
...
  cond_br %21, bb2, bb3                           // id: %22

bb2:                                              // Preds: bb1
  %23 = metatype $@thin Float.Type                // user: %28
  %24 = begin_access [read] [static] %2           // users: %26, %25
  %25 = load [trivial] %24                        // user: %28
  end_access %24                                  // id: %26
  // function_ref static Float.* infix(_:_:)
  %27 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
  %28 = apply %27(%25, %0, %23) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %30
  %29 = begin_access [modify] [static] %2         // users: %30, %31
  store %28 to [trivial] %29                      // id: %30
  end_access %29                                  // id: %31
...
  br bb1                                          // id: %41

bb3:                                              // Preds: bb1
  %42 = begin_access [read] [static] %2           // users: %44, %43
  %43 = load [trivial] %42                        // user: %47
  end_access %42                                  // id: %44

So we have the loop header (bb1) first, then the loop body (bb2), and finally the code after the loop (bb3). We have code that propagates adjoints of active values into predecessor BBs while traversing the function in reverse post-order. Here we first visit bb3, then bb1. Inside bb1 we notice that there are active values (%25 and %28) in the predecessor bb2, so we take their adjoints in bb1 and propagate them into bb2. Since no adjoints were defined for them before, they are zero-initialized and then propagated further. And since bb1 happens to be the loop header, we end up zero-initializing them on every loop iteration, because in the pullback the block corresponding to the loop header runs between the pullbacks of consecutive loop-body iterations. Everything magically works.

The situation with the repeat loop is the reverse: there is no "loop header" BB in the usual sense; instead, there is a "loop footer" fused into the loop body. So the adjoints for %13 and %16 are first zero-initialized in the pullback block corresponding to bb3 and then propagated to bb1. There is no zero-initialization on each loop iteration, so adjoint values are reused from the previous loop iteration and wrong results are produced.

It seems to me that we need to perform explicit adjoint zeroing inside loop headers in the pullback cloner.
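The block-ordering argument can be replayed with a toy model in plain Swift. replayPullback is a hypothetical helper that walks a forward block trace in reverse and zeroes the body's temporary adjoints only when the pullback of zeroBlock runs. The traces assume two iterations and reuse the block names from the SIL listings above. This is a simplified model of where zero-initialization happens, not the pullback cloner's actual output; the 18.0 it gives for the repeat case is the toy model's wrong answer, not necessarily the value the compiler prints.

```swift
// Toy model of pullback block ordering (not compiler output).
// The temporary adjoints of the loaded value and of the product are zeroed
// only when the pullback of `zeroBlock` runs.
func replayPullback(forwardTrace: [String], bodyBlock: String,
                    zeroBlock: String, x: Float) -> Float {
    // Forward pass: record `result` each time the body `result = result * x` runs.
    var tape: [Float] = []
    var result = x
    for block in forwardTrace where block == bodyBlock {
        tape.append(result)
        result = result * x
    }

    var adjResult: Float = 1      // adjoint of the `result` buffer, seeded by `return result`
    var adjX: Float = 0
    var adjLoaded: Float = 0      // adjoint of the loaded value (%13 / %25)
    var adjProduct: Float = 0     // adjoint of the product (%16 / %28)
    for block in forwardTrace.reversed() {
        if block == zeroBlock {
            adjLoaded = 0
            adjProduct = 0
        }
        if block == bodyBlock {
            let loaded = tape.removeLast()
            adjProduct += adjResult       // pullback of the store
            adjResult = 0
            adjLoaded += adjProduct * x   // pullback of the multiplication
            adjX += adjProduct * loaded
            adjResult += adjLoaded        // pullback of the load
        }
    }
    return adjX + adjResult               // pullback of `var result = x`
}

// Forward block traces for two loop iterations.
let whileTrace = ["bb0", "bb1", "bb2", "bb1", "bb2", "bb1", "bb3"]
let repeatTrace = ["bb0", "bb1", "bb2", "bb1", "bb3"]

print(replayPullback(forwardTrace: whileTrace, bodyBlock: "bb2", zeroBlock: "bb1", x: 2))   // 12.0
print(replayPullback(forwardTrace: repeatTrace, bodyBlock: "bb1", zeroBlock: "bb3", x: 2))  // 18.0 (toy model)
print(replayPullback(forwardTrace: repeatTrace, bodyBlock: "bb1", zeroBlock: "bb1", x: 2))  // 12.0
```

The last call models the proposed fix: for the repeat loop, bb1 is both the loop header and the loop body, so zeroing there resets the temporaries before every body pullback.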

Reproduction

import _Differentiation

func repeat_while_loop(_ x: Float) -> Float {
    var result = x
    var i = 0
    repeat {
      result = result * x
      i += 1
    } while i < 2
    return result
}

func while_loop(_ x: Float) -> Float {
    var result = x
    var i = 0
    while i < 2 {
      result = result * x
      i += 1
    }
    return result
}

// Both functions compute x^3, so the expected result is value 8.0, gradient 12.0 (3x^2 at x = 2).
print(valueWithGradient(at: 2, of: repeat_while_loop))
print(valueWithGradient(at: 2, of: while_loop))

Expected behavior

Correct gradient calculation for both cases

Environment

Swift version 6.2-dev (LLVM e404f8897f17aff, Swift 5a68861)
Target: arm64-apple-macosx15.0


Metadata

Assignees: No one assigned

Labels: AutoDiff, bug
