Skip to content

Commit a48880d

Browse files
author
marcrasi
authored
[AutoDiff upstream] add more validation tests (#31190)
1 parent 52c3d70 commit a48880d

File tree

7 files changed

+274
-0
lines changed

7 files changed

+274
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
3+
import _Differentiation
4+
5+
func foo<T: Numeric>(_ x: T, _ y: T) -> T { x * y }
6+
7+
@derivative(of: foo)
8+
func foo_vjp<T: Numeric & Differentiable>(_ x: T, _ y: T) -> (
9+
value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)
10+
) {
11+
(foo(x, y), { _ in (.zero, .zero) })
12+
}
13+
14+
@differentiable
15+
func differentiate_foo_wrt_0(_ x: Float) -> Float {
16+
foo(x, 1)
17+
}
18+
19+
// CHECK-LABEL: sil hidden @{{.*}}differentiate_foo_wrt_0{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
20+
// CHECK: bb0
21+
// CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
22+
// CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
23+
// CHECK: [[FOO_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T where T : Numeric, T : Differentiable> @{{.*}}foo{{.*}} : $@convention(thin) <T where T : Numeric> (@in_guaranteed T, @in_guaranteed T) -> @out T
24+
// CHECK: [[FOO_JVP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_JVP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>)
25+
// CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_jvp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
26+
// CHECK: [[FOO_JVP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_JVP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
27+
// CHECK: [[FOO_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T where T : Numeric, T : Differentiable> @{{.*}}foo{{.*}} : $@convention(thin) <T where T : Numeric> (@in_guaranteed T, @in_guaranteed T) -> @out T
28+
// CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>)
29+
// CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
30+
// CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
31+
// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [parameters 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with_derivative {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
32+
// CHECK: }
33+
34+
func inoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
35+
_ x: T, _ y: inout U, _ z: V
36+
) {}
37+
38+
@derivative(of: inoutIndirect)
39+
func vjpInoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
40+
_ x: T, _ y: inout U, _ z: V
41+
) -> (
42+
value: Void, pullback: (inout U.TangentVector) -> (T.TangentVector, V.TangentVector)
43+
) {
44+
return ((), { dy in
45+
return (.zero, .zero)
46+
})
47+
}
48+
49+
@differentiable(wrt: x)
50+
@differentiable(wrt: y)
51+
@differentiable
52+
func inoutIndirectCaller<T: Differentiable, U: Differentiable, V: Differentiable>(
53+
_ x: T, _ y: U, _ z: V
54+
) -> U {
55+
var result = y
56+
inoutIndirect(x, &result, z)
57+
return result
58+
}
59+
60+
@differentiable(wrt: (x, z))
61+
func concreteInoutIndirectCaller(
62+
_ x: Float, _ y: Double, _ z: Float
63+
) -> Double {
64+
return inoutIndirectCaller(x, y, z)
65+
}
66+
67+
// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$sSdSfSdSfIegnrrr_SdS2fIegnrr_TR_src_0_wrt_0_2_pullback_index_subset_thunk : $@convention(thin) (@in_guaranteed Double, @guaranteed @callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)) -> (@out Float, @out Float) {
68+
// CHECK: bb0(%0 : $*Float, %1 : $*Float, %2 : $*Double, %3 : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)):
69+
// CHECK: %4 = alloc_stack $Double
70+
// CHECK: %5 = apply %3(%0, %4, %1, %2) : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)
71+
// CHECK: destroy_addr %4 : $*Double
72+
// CHECK: dealloc_stack %4 : $*Double
73+
// CHECK: %8 = tuple ()
74+
// CHECK: return %8 : $()
75+
// CHECK: }
76+
77+
// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_AeFIeglr_AbCRzAbCR_AbCR0_r1_lTR_src_0_wrt_0_1_pullback_index_subset_thunk : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> @out τ_0_0.TangentVector {
78+
// CHECK: bb0(%0 : $*τ_0_0.TangentVector, %1 : $*τ_0_1.TangentVector, %2 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)):
79+
// CHECK: %3 = alloc_stack $τ_0_2.TangentVector
80+
// CHECK: %4 = apply %2(%0, %3, %1) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)
81+
// CHECK: destroy_addr %3 : $*τ_0_2.TangentVector
82+
// CHECK: dealloc_stack %3 : $*τ_0_2.TangentVector
83+
// CHECK: %7 = tuple ()
84+
// CHECK: return %7 : $()
85+
// CHECK: }
86+
87+
// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$s13TangentVector16_Differentiation14DifferentiablePQy_AaDQzAaDQy0_Ieglrr_AEIegl_AbCRzAbCR_AbCR0_r1_lTR_src_0_wrt_1_pullback_index_subset_thunk : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> () {
88+
// CHECK: bb0(%0 : $*τ_0_1.TangentVector, %1 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)):
89+
// CHECK: %2 = alloc_stack $τ_0_0.TangentVector
90+
// CHECK: %3 = alloc_stack $τ_0_2.TangentVector
91+
// CHECK: %4 = apply %1(%2, %3, %0) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)
92+
// CHECK: destroy_addr %2 : $*τ_0_0.TangentVector
93+
// CHECK: destroy_addr %3 : $*τ_0_2.TangentVector
94+
// CHECK: dealloc_stack %3 : $*τ_0_2.TangentVector
95+
// CHECK: dealloc_stack %2 : $*τ_0_0.TangentVector
96+
// CHECK: %9 = tuple ()
97+
// CHECK: return %9 : $()
98+
// CHECK: }
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import StdlibUnittest
2+
import _Differentiation
3+
4+
import module1
5+
6+
var Tests = TestSuite("CrossModuleDerivativeAttr")
7+
8+
Tests.test("CrossFile") {
9+
let grad = gradient(at: 0, in: fCrossFile)
10+
expectEqual(10, grad)
11+
}
12+
13+
runAllTests()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
public func fCrossFile(_ x: Float) -> Float { x }
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import _Differentiation
2+
3+
@derivative(of: fCrossFile)
4+
public func vjpCrossFile(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
5+
(x, { 10 * $0 })
6+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift -working-directory %t -I%t -parse-as-library -emit-module -module-name module1 -emit-module-path %t/module1.swiftmodule -emit-library -static %S/Inputs/cross_module_derivative_attr/module1/module1.swift %S/Inputs/cross_module_derivative_attr/module1/module1_other_file.swift -Xfrontend -validate-tbd-against-ir=none
3+
// RUN: %target-build-swift -I%t -L%t %S/Inputs/cross_module_derivative_attr/main/main.swift -o %t/a.out -lm -lmodule1 -Xfrontend -validate-tbd-against-ir=none
4+
// RUN: %target-run %t/a.out
5+
// REQUIRES: executable_test
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
6+
import Darwin.C
7+
#else
8+
import Glibc
9+
#endif
10+
import DifferentiationUnittest
11+
12+
var CustomDerivativesTests = TestSuite("CustomDerivatives")
13+
14+
// Specify non-differentiable functions.
15+
// These will be wrapped in `differentiableFunction` and tested.
16+
17+
func unary(_ x: Tracked<Float>) -> Tracked<Float> {
18+
var x = x
19+
x *= 2
20+
return x
21+
}
22+
23+
func binary(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
24+
var x = x
25+
x *= y
26+
return x
27+
}
28+
29+
CustomDerivativesTests.testWithLeakChecking("differentiableFunction-unary") {
30+
let diffableUnary = differentiableFunction { x in
31+
(value: unary(x), pullback: { v in v * x * 2 })
32+
}
33+
expectEqual(20, gradient(at: 10, in: diffableUnary))
34+
// Test differentiation of @differentiable function.
35+
expectEqual(20, gradient(at: 10, in: { diffableUnary($0) }))
36+
expectEqual(40, gradient(at: 10, in: { diffableUnary($0) * 2 }))
37+
}
38+
39+
CustomDerivativesTests.testWithLeakChecking("differentiableFunction-binary") {
40+
let diffableBinary = differentiableFunction { (x, y) in
41+
(value: binary(x, y), pullback: { v in (v * y, v * x) })
42+
}
43+
expectEqual((10, 5), gradient(at: 5, 10, in: diffableBinary))
44+
// Test differentiation of @differentiable function.
45+
expectEqual((10, 5), gradient(at: 5, 10, in: { diffableBinary($0, $1) }))
46+
expectEqual((20, 10), gradient(at: 5, 10, in: { diffableBinary($0, $1) * 2 }))
47+
}
48+
49+
CustomDerivativesTests.testWithLeakChecking("SumOfGradPieces") {
50+
var grad: Tracked<Float> = 0
51+
func addToGrad(_ x: inout Tracked<Float>) { grad += x }
52+
_ = gradient(at: 4) { (x: Tracked<Float>) in
53+
x.withDerivative(addToGrad)
54+
* x.withDerivative(addToGrad)
55+
* x.withDerivative(addToGrad)
56+
}
57+
expectEqual(48, grad)
58+
}
59+
60+
CustomDerivativesTests.testWithLeakChecking("ModifyGradientOfSum") {
61+
expectEqual(30, gradient(at: 4) { (x: Tracked<Float>) in
62+
x.withDerivative { $0 *= 10 } + x.withDerivative { $0 *= 20 }
63+
})
64+
}
65+
66+
CustomDerivativesTests.testWithLeakChecking("WithoutDerivative") {
67+
expectEqual(0, gradient(at: Tracked<Float>(4)) { x in
68+
withoutDerivative(at: x) { x in
69+
Tracked<Float>(sinf(x.value) + cosf(x.value))
70+
}
71+
})
72+
}
73+
74+
runAllTests()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import _Differentiation
6+
7+
var SubsetParameterThunkTests = TestSuite("SubsetParameterThunks")
8+
9+
func inoutDirect(_ x: Float, _ y: inout Double, _ z: Float) {}
10+
11+
@derivative(of: inoutDirect)
12+
func vjpInoutDirect(_ x: Float, _ y: inout Double, _ z: Float) -> (
13+
value: Void, pullback: (inout Double) -> (Float, Float)
14+
) {
15+
return ((), { dy in
16+
dy = 3
17+
return (2, 4)
18+
})
19+
}
20+
21+
SubsetParameterThunkTests.test("InoutParametersDirect") {
22+
@differentiable(wrt: x)
23+
@differentiable(wrt: y)
24+
@differentiable(wrt: z)
25+
func inoutDirectCaller(_ x: Float, _ y: Double, _ z: Float) -> Double {
26+
var result = y
27+
inoutDirect(x, &result, z)
28+
return result
29+
}
30+
31+
let x: Float = 3
32+
let y: Double = 4
33+
let z: Float = 5
34+
expectEqual((2, 3, 4), gradient(at: x, y, z, in: inoutDirectCaller))
35+
expectEqual((3, 4), gradient(at: y, z, in: { y, z in inoutDirectCaller(x, y, z) }))
36+
expectEqual((2, 4), gradient(at: x, z, in: { x, z in inoutDirectCaller(x, y, z) }))
37+
expectEqual((2, 3), gradient(at: x, y, in: { x, y in inoutDirectCaller(x, y, z) }))
38+
}
39+
40+
func inoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
41+
_ x: T, _ y: inout U, _ z: V
42+
) {}
43+
44+
@derivative(of: inoutIndirect)
45+
func vjpInoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
46+
_ x: T, _ y: inout U, _ z: V
47+
) -> (
48+
value: Void, pullback: (inout U.TangentVector) -> (T.TangentVector, V.TangentVector)
49+
) {
50+
return ((), { dy in
51+
return (.zero, .zero)
52+
})
53+
}
54+
55+
SubsetParameterThunkTests.test("InoutParametersIndirect") {
56+
@differentiable(wrt: x)
57+
@differentiable(wrt: y)
58+
@differentiable(wrt: z)
59+
@differentiable
60+
func inoutIndirectCaller<T: Differentiable, U: Differentiable, V: Differentiable>(
61+
_ x: T, _ y: U, _ z: V
62+
) -> U {
63+
var result = y
64+
inoutIndirect(x, &result, z)
65+
return result
66+
}
67+
68+
let x: Float = 3
69+
let y: Double = 4
70+
let z: Float = 5
71+
expectEqual((0, 1, 0), gradient(at: x, y, z, in: inoutIndirectCaller))
72+
expectEqual((1, 0), gradient(at: y, z, in: { y, z in inoutIndirectCaller(x, y, z) }))
73+
expectEqual((0, 0), gradient(at: x, z, in: { x, z in inoutIndirectCaller(x, y, z) }))
74+
expectEqual((0, 1), gradient(at: x, y, in: { x, y in inoutIndirectCaller(x, y, z) }))
75+
}
76+
77+
runAllTests()

0 commit comments

Comments
 (0)