Skip to content

Commit a946ec9

Browse files
bgoguldan-zheng
authored andcommitted
[AutoDiff] Fix aggregate adjoint value accumulation bug. (#28477)
`PullbackEmitter::accumulateAdjointsDirect` did not handle the case where `lhs` is an `Aggregate` value and `rhs` is a `Concrete` value. Resolves TF-943.
1 parent b935486 commit a946ec9

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7740,8 +7740,9 @@ PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
77407740
// (x, y)
77417741
case AdjointValueKind::Aggregate:
77427742
switch (rhs.getKind()) {
7743-
// (x, y) + z => (x + z.0, y + z.1)
7743+
// (x, y) + z => (z.0 + x, z.1 + y)
77447744
case AdjointValueKind::Concrete:
7745+
return accumulateAdjointsDirect(rhs, lhs, loc);
77457746
// x + 0 => x
77467747
case AdjointValueKind::Zero:
77477748
return lhs;

test/AutoDiff/simple_math.swift

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// RUN: %target-swift-frontend -Xllvm -sil-print-after=differentiation %s -emit-sil -o /dev/null 2>&1 | %FileCheck %s
12
// RUN: %target-run-simple-swift
23
// NOTE(TF-813): verify that enabling forward-mode does not affect reverse-mode.
34
// RUN: %target_run_simple_swift_forward_mode_differentiation
@@ -203,7 +204,7 @@ SimpleMathTests.test("StructMemberwiseInitializer") {
203204
let foo = Foo(stored: input)
204205
return foo.computed * foo.stored
205206
}
206-
expectEqual(16, 𝛁product)
207+
expectEqual(48, 𝛁product)
207208

208209
struct Custom : AdditiveArithmetic, Differentiable {
209210
var x: Float
@@ -350,4 +351,37 @@ SimpleMathTests.test("ForceUnwrapping") {
350351
expectEqual((1, 2), forceUnwrap(Float(2)))
351352
}
352353

354+
// CHECK-LABEL: sil hidden [ossa] @AD__${{.*}}jumpTimesTwo{{.*}}pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__$s4nullyycfU18_12jumpTimesTwoL_5modelSfAAyycfU18_14SmallTestModelL_V_tF_bb0__PB__src_0_wrt_0) -> SmallTestModel.TangentVector {
355+
// CHECK: bb0([[DX:%.*]] : $Float, [[PB_STRUCT:%.*]] : {{.*}}):
356+
// CHECK: ([[PB0:%.*]], [[PB1:%.*]]) = destructure_struct [[PB_STRUCT]]
357+
// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float)
358+
// CHECK: ([[TMP0:%.*]], [[ADJ_CONCRETE:%.*]]) = destructure_tuple [[ADJ_TUPLE]] : $(Float, Float)
359+
// CHECK: [[TMP1:%.*]] = apply [[PB0]]([[TMP0]]) : $@callee_guaranteed (Float) -> SmallTestModel.TangentVector
360+
// CHECK: [[ADJ_STRUCT_FIELD:%.*]] = destructure_struct [[TMP1]] : $SmallTestModel.TangentVector
361+
// CHECK: [[TMP_RES:%.*]] = alloc_stack $Float
362+
// CHECK: [[TMP_ADJ_STRUCT_FIELD:%.*]] = alloc_stack $Float
363+
// CHECK: [[TMP_ADJ_CONCRETE:%.*]] = alloc_stack $Float
364+
// CHECK: store [[ADJ_STRUCT_FIELD]] to [trivial] [[TMP_ADJ_STRUCT_FIELD]] : $*Float
365+
// CHECK: store [[ADJ_CONCRETE]] to [trivial] [[TMP_ADJ_CONCRETE]] : $*Float
366+
// CHECK: [[PLUS_EQUAL:%.*]] = witness_method $Float, #AdditiveArithmetic."+"
367+
// CHECK: %{{.*}} = apply [[PLUS_EQUAL]]<Float>([[TMP_RES]], [[TMP_ADJ_CONCRETE]], [[TMP_ADJ_STRUCT_FIELD]], {{.*}})
368+
// CHECK: [[RES:%.*]] = load [trivial] [[TMP_RES]] : $*Float
369+
// CHECK: [[RES_STRUCT:%.*]] = struct $SmallTestModel.TangentVector ([[RES]] : $Float)
370+
// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
371+
// CHECK: }
372+
373+
SimpleMathTests.test("Struct") {
374+
// TF-943: Test adjoint value accumulation for aggregate lhs and concrete rhs.
375+
struct SmallTestModel : Differentiable {
376+
public var jump: Float = 3.0
377+
@differentiable public func callAsFunction() -> Float { return jump }
378+
}
379+
380+
func jumpTimesTwo(model: SmallTestModel) -> Float{
381+
return model() + model.jump
382+
}
383+
let grads = gradient(at: SmallTestModel(), in: jumpTimesTwo)
384+
expectEqual(2.0, grads.jump)
385+
}
386+
353387
runAllTests()

0 commit comments

Comments
 (0)