|
| 1 | +// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s |
| 2 | + |
| 3 | +import _Differentiation |
| 4 | + |
| 5 | +func foo<T: Numeric>(_ x: T, _ y: T) -> T { x * y } |
| 6 | + |
| 7 | +@derivative(of: foo) |
| 8 | +func foo_vjp<T: Numeric & Differentiable>(_ x: T, _ y: T) -> ( |
| 9 | + value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector) |
| 10 | +) { |
| 11 | + (foo(x, y), { _ in (.zero, .zero) }) |
| 12 | +} |
| 13 | + |
| 14 | +@differentiable |
| 15 | +func differentiate_foo_wrt_0(_ x: Float) -> Float { |
| 16 | + foo(x, 1) |
| 17 | +} |
| 18 | + |
| 19 | +// CHECK-LABEL: sil hidden @{{.*}}differentiate_foo_wrt_0{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { |
| 20 | +// CHECK: bb0 |
| 21 | +// CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 |
| 22 | +// CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 |
| 23 | +// CHECK: [[FOO_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T where T : Numeric, T : Differentiable> @{{.*}}foo{{.*}} : $@convention(thin) <T where T : Numeric> (@in_guaranteed T, @in_guaranteed T) -> @out T |
| 24 | +// CHECK: [[FOO_JVP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_JVP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) |
| 25 | +// CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_jvp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) |
| 26 | +// CHECK: [[FOO_JVP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_JVP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) |
| 27 | +// CHECK: [[FOO_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T where T : Numeric, T : Differentiable> @{{.*}}foo{{.*}} : $@convention(thin) <T where T : Numeric> (@in_guaranteed T, @in_guaranteed T) -> @out T |
| 28 | +// CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) |
| 29 | +// CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) |
| 30 | +// CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) |
| 31 | +// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [parameters 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with_derivative {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} |
| 32 | +// CHECK: } |
| 33 | + |
| 34 | +func inoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>( |
| 35 | + _ x: T, _ y: inout U, _ z: V |
| 36 | +) {} |
| 37 | + |
| 38 | +@derivative(of: inoutIndirect) |
| 39 | +func vjpInoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>( |
| 40 | + _ x: T, _ y: inout U, _ z: V |
| 41 | +) -> ( |
| 42 | + value: Void, pullback: (inout U.TangentVector) -> (T.TangentVector, V.TangentVector) |
| 43 | +) { |
| 44 | + return ((), { dy in |
| 45 | + return (.zero, .zero) |
| 46 | + }) |
| 47 | +} |
| 48 | + |
| 49 | +@differentiable(wrt: x) |
| 50 | +@differentiable(wrt: y) |
| 51 | +@differentiable |
| 52 | +func inoutIndirectCaller<T: Differentiable, U: Differentiable, V: Differentiable>( |
| 53 | + _ x: T, _ y: U, _ z: V |
| 54 | +) -> U { |
| 55 | + var result = y |
| 56 | + inoutIndirect(x, &result, z) |
| 57 | + return result |
| 58 | +} |
| 59 | + |
| 60 | +@differentiable(wrt: (x, z)) |
| 61 | +func concreteInoutIndirectCaller( |
| 62 | + _ x: Float, _ y: Double, _ z: Float |
| 63 | +) -> Double { |
| 64 | + return inoutIndirectCaller(x, y, z) |
| 65 | +} |
| 66 | + |
| 67 | +// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$sSdSfSdSfIegnrrr_SdS2fIegnrr_TR_src_0_wrt_0_2_pullback_index_subset_thunk : $@convention(thin) (@in_guaranteed Double, @guaranteed @callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)) -> (@out Float, @out Float) { |
| 68 | +// CHECK: bb0(%0 : $*Float, %1 : $*Float, %2 : $*Double, %3 : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)): |
| 69 | +// CHECK: %4 = alloc_stack $Double |
| 70 | +// CHECK: %5 = apply %3(%0, %4, %1, %2) : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float) |
| 71 | +// CHECK: destroy_addr %4 : $*Double |
| 72 | +// CHECK: dealloc_stack %4 : $*Double |
| 73 | +// CHECK: %8 = tuple () |
| 74 | +// CHECK: return %8 : $() |
| 75 | +// CHECK: } |
| 76 | + |
| 77 | +// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_AeFIeglr_AbCRzAbCR_AbCR0_r1_lTR_src_0_wrt_0_1_pullback_index_subset_thunk : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> @out τ_0_0.TangentVector { |
| 78 | +// CHECK: bb0(%0 : $*τ_0_0.TangentVector, %1 : $*τ_0_1.TangentVector, %2 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)): |
| 79 | +// CHECK: %3 = alloc_stack $τ_0_2.TangentVector |
| 80 | +// CHECK: %4 = apply %2(%0, %3, %1) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector) |
| 81 | +// CHECK: destroy_addr %3 : $*τ_0_2.TangentVector |
| 82 | +// CHECK: dealloc_stack %3 : $*τ_0_2.TangentVector |
| 83 | +// CHECK: %7 = tuple () |
| 84 | +// CHECK: return %7 : $() |
| 85 | +// CHECK: } |
| 86 | + |
| 87 | +// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_AEIegl_AbCRzAbCR_AbCR0_r1_lTR_src_0_wrt_1_pullback_index_subset_thunk : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> () { |
| 88 | +// CHECK: bb0(%0 : $*τ_0_1.TangentVector, %1 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)): |
| 89 | +// CHECK: %2 = alloc_stack $τ_0_0.TangentVector |
| 90 | +// CHECK: %3 = alloc_stack $τ_0_2.TangentVector |
| 91 | +// CHECK: %4 = apply %1(%2, %3, %0) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector) |
| 92 | +// CHECK: destroy_addr %2 : $*τ_0_0.TangentVector |
| 93 | +// CHECK: destroy_addr %3 : $*τ_0_2.TangentVector |
| 94 | +// CHECK: dealloc_stack %3 : $*τ_0_2.TangentVector |
| 95 | +// CHECK: dealloc_stack %2 : $*τ_0_0.TangentVector |
| 96 | +// CHECK: %9 = tuple () |
| 97 | +// CHECK: return %9 : $() |
| 98 | +// CHECK: } |
0 commit comments