|
| 1 | +// RUN: %target-swift-frontend -Xllvm -sil-print-after=differentiation %s -emit-sil -o /dev/null 2>&1 | %FileCheck %s |
1 | 2 | // RUN: %target-run-simple-swift
|
2 | 3 | // NOTE(TF-813): verify that enabling forward-mode does not affect reverse-mode.
|
3 | 4 | // RUN: %target_run_simple_swift_forward_mode_differentiation
|
@@ -203,7 +204,7 @@ SimpleMathTests.test("StructMemberwiseInitializer") {
|
203 | 204 | let foo = Foo(stored: input)
|
204 | 205 | return foo.computed * foo.stored
|
205 | 206 | }
|
206 |
| - expectEqual(16, 𝛁product) |
| 207 | + expectEqual(48, 𝛁product) |
207 | 208 |
|
208 | 209 | struct Custom : AdditiveArithmetic, Differentiable {
|
209 | 210 | var x: Float
|
@@ -350,4 +351,37 @@ SimpleMathTests.test("ForceUnwrapping") {
|
350 | 351 | expectEqual((1, 2), forceUnwrap(Float(2)))
|
351 | 352 | }
|
352 | 353 |
|
| 354 | +// CHECK-LABEL: sil hidden [ossa] @AD__${{.*}}jumpTimesTwo{{.*}}pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__$s4nullyycfU18_12jumpTimesTwoL_5modelSfAAyycfU18_14SmallTestModelL_V_tF_bb0__PB__src_0_wrt_0) -> SmallTestModel.TangentVector { |
| 355 | +// CHECK: bb0([[DX:%.*]] : $Float, [[PB_STRUCT:%.*]] : {{.*}}): |
| 356 | +// CHECK: ([[PB0:%.*]], [[PB1:%.*]]) = destructure_struct [[PB_STRUCT]] |
| 357 | +// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float) |
| 358 | +// CHECK: ([[TMP0:%.*]], [[ADJ_CONCRETE:%.*]]) = destructure_tuple [[ADJ_TUPLE]] : $(Float, Float) |
| 359 | +// CHECK: [[TMP1:%.*]] = apply [[PB0]]([[TMP0]]) : $@callee_guaranteed (Float) -> SmallTestModel.TangentVector |
| 360 | +// CHECK: [[ADJ_STRUCT_FIELD:%.*]] = destructure_struct [[TMP1]] : $SmallTestModel.TangentVector |
| 361 | +// CHECK: [[TMP_RES:%.*]] = alloc_stack $Float |
| 362 | +// CHECK: [[TMP_ADJ_STRUCT_FIELD:%.*]] = alloc_stack $Float |
| 363 | +// CHECK: [[TMP_ADJ_CONCRETE:%.*]] = alloc_stack $Float |
| 364 | +// CHECK: store [[ADJ_STRUCT_FIELD]] to [trivial] [[TMP_ADJ_STRUCT_FIELD]] : $*Float |
| 365 | +// CHECK: store [[ADJ_CONCRETE]] to [trivial] [[TMP_ADJ_CONCRETE]] : $*Float |
| 366 | +// CHECK: [[PLUS_EQUAL:%.*]] = witness_method $Float, #AdditiveArithmetic."+" |
| 367 | +// CHECK: %{{.*}} = apply [[PLUS_EQUAL]]<Float>([[TMP_RES]], [[TMP_ADJ_CONCRETE]], [[TMP_ADJ_STRUCT_FIELD]], {{.*}}) |
| 368 | +// CHECK: [[RES:%.*]] = load [trivial] [[TMP_RES]] : $*Float |
| 369 | +// CHECK: [[RES_STRUCT:%.*]] = struct $SmallTestModel.TangentVector ([[RES]] : $Float) |
| 370 | +// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector |
| 371 | +// CHECK: } |
| 372 | + |
| 373 | +SimpleMathTests.test("Struct") { |
| 374 | + // TF-943: Test adjoint value accumulation for aggregate lhs and concrete rhs. |
| 375 | + struct SmallTestModel : Differentiable { |
| 376 | + public var jump: Float = 3.0 |
| 377 | + @differentiable public func callAsFunction() -> Float { return jump } |
| 378 | + } |
| 379 | + |
| 380 | + func jumpTimesTwo(model: SmallTestModel) -> Float{ |
| 381 | + return model() + model.jump |
| 382 | + } |
| 383 | + let grads = gradient(at: SmallTestModel(), in: jumpTimesTwo) |
| 384 | + expectEqual(2.0, grads.jump) |
| 385 | +} |
| 386 | + |
353 | 387 | runAllTests()
|
0 commit comments