From 63c8c005d7b2ecaa85e77725113241f562adb73b Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Thu, 28 May 2020 15:55:03 -0700 Subject: [PATCH 1/8] test: add a ODR violation check for the static standard library The static version of the standard library was leaking symbols in the `llvm::` namespace which would result in ODR violations were the artifact linking against `LLVMSupport` (via another dependency). In particular, `llvm::SmallVector` and `llvm::StringSwitch` symbols were being leaked. This adds a test case specifically for the static variant of the library. The dynamic variant of the library is already tested in a separate test. --- test/stdlib/llvm-support-odr-violation-static.test-sh | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 test/stdlib/llvm-support-odr-violation-static.test-sh diff --git a/test/stdlib/llvm-support-odr-violation-static.test-sh b/test/stdlib/llvm-support-odr-violation-static.test-sh new file mode 100644 index 0000000000000..ca8bcc92ef490 --- /dev/null +++ b/test/stdlib/llvm-support-odr-violation-static.test-sh @@ -0,0 +1,5 @@ +// RUN: %llvm-nm --defined-only -C %target-static-stdlib-path/libswiftCore.a | %FileCheck --allow-empty %s +// CHECK-NOT: [^:]llvm:: + +// REQUIRES: OS=linux-gnu +// REQUIRES: static_stdlib From fa2d4dc6e62b4891ad50d0e74c1826de87cf9d39 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Fri, 29 May 2020 15:28:30 -0700 Subject: [PATCH 2/8] [AST] DoStmt always contains a BraceStmt --- include/swift/AST/Stmt.h | 8 ++++---- lib/AST/ASTWalker.cpp | 2 +- lib/Sema/TypeCheckStmt.cpp | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index a9bece28862f5..17e013dbbcfc0 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -539,11 +539,11 @@ class LabeledStmt : public Stmt { /// DoStmt - do statement, without any trailing clauses. class DoStmt : public LabeledStmt { SourceLoc DoLoc; - Stmt *Body; + BraceStmt *Body; public: DoStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc, - Stmt *body, Optional implicit = None) + BraceStmt *body, Optional implicit = None) : LabeledStmt(StmtKind::Do, getDefaultImplicitFlag(implicit, doLoc), labelInfo), DoLoc(doLoc), Body(body) {} @@ -553,8 +553,8 @@ class DoStmt : public LabeledStmt { SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(DoLoc); } SourceLoc getEndLoc() const { return Body->getEndLoc(); } - Stmt *getBody() const { return Body; } - void setBody(Stmt *s) { Body = s; } + BraceStmt *getBody() const { return Body; } + void setBody(BraceStmt *s) { Body = s; } static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Do; } }; diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index 8d5fb75f0acef..e3340143015ec 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -1493,7 +1493,7 @@ Stmt *Traversal::visitGuardStmt(GuardStmt *US) { } Stmt *Traversal::visitDoStmt(DoStmt *DS) { - if (Stmt *S2 = doIt(DS->getBody())) + if (BraceStmt *S2 = cast_or_null(doIt(DS->getBody()))) DS->setBody(S2); else return nullptr; diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 2a0f34f76ef6e..6bd303691dbb2 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -658,7 +658,7 @@ class StmtChecker : public StmtVisitor { Stmt *visitDoStmt(DoStmt *DS) { AddLabeledStmt loopNest(*this, DS); - Stmt *S = DS->getBody(); + BraceStmt *S = DS->getBody(); typeCheckStmt(S); DS->setBody(S); return DS; From 2fc2b157538f6ace81e1ed3c7d7d08e175852727 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Fri, 29 May 2020 15:28:58 -0700 Subject: [PATCH 3/8] [Function builders] Align buildDo() implementation with the pitch. Rather than passing the result of buildBlock() into buildDo(), follow the (better) design from the function builders pitch by passing in the components from the block directly into buildDo(). This means that buildDo() will need to take separate parameters for each component, but allows buildDo() to treat the components separately. --- lib/Sema/BuilderTransform.cpp | 14 +++++++------- test/Constraints/function_builder.swift | 20 ++++++++++++++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index b5b4b88c4c266..9b2bc350e03b3 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -293,6 +293,10 @@ class BuilderClosureVisitor } VarDecl *visitBraceStmt(BraceStmt *braceStmt) { + return visitBraceStmt(braceStmt, ctx.Id_buildBlock); + } + + VarDecl *visitBraceStmt(BraceStmt *braceStmt, Identifier builderFunction) { SmallVector expressions; auto addChild = [&](VarDecl *childVar) { if (!childVar) @@ -359,7 +363,7 @@ class BuilderClosureVisitor // Call Builder.buildBlock(... args ...) auto call = buildCallIfWanted(braceStmt->getStartLoc(), - ctx.Id_buildBlock, expressions, + builderFunction, expressions, /*argLabels=*/{ }); if (!call) return nullptr; @@ -380,17 +384,13 @@ class BuilderClosureVisitor return nullptr; } - auto childVar = visit(doStmt->getBody()); + auto childVar = visitBraceStmt(doStmt->getBody(), ctx.Id_buildDo); if (!childVar) return nullptr; auto childRef = buildVarRef(childVar, doStmt->getEndLoc()); - auto call = buildCallIfWanted(doStmt->getStartLoc(), ctx.Id_buildDo, - childRef, /*argLabels=*/{ }); - if (!call) - return nullptr; - return captureExpr(call, /*oneWay=*/true, doStmt); + return captureExpr(childRef, /*oneWay=*/true, doStmt); } CONTROL_FLOW_STMT(Yield) diff --git a/test/Constraints/function_builder.swift b/test/Constraints/function_builder.swift index e7f1f6795a8f9..eb5e6857c55c1 100644 --- a/test/Constraints/function_builder.swift +++ b/test/Constraints/function_builder.swift @@ -6,6 +6,10 @@ enum Either { case second(U) } +struct Do { + var value: T +} + @_functionBuilder struct TupleBuilder { static func buildBlock(_ t1: T1) -> (T1) { @@ -32,7 +36,19 @@ struct TupleBuilder { return (t1, t2, t3, t4, t5) } - static func buildDo(_ value: T) -> T { return value } + static func buildDo(_ t1: T1) -> Do<(T1)> { + .init(value: t1) + } + + static func buildDo(_ t1: T1, _ t2: T2) -> Do<(T1, T2)> { + .init(value: (t1, t2)) + } + + static func buildDo(_ t1: T1, _ t2: T2, _ t3: T3) + -> Do<(T1, T2, T3)> { + .init(value: (t1, t2, t3)) + } + static func buildIf(_ value: T?) -> T? { return value } static func buildEither(first value: T) -> Either { @@ -49,7 +65,7 @@ func tuplify(_ cond: Bool, @TupleBuilder body: (Bool) -> T) { print(body(cond)) } -// CHECK: (17, 3.14159, "Hello, DSL", (["nested", "do"], 6), Optional((2.71828, ["if", "stmt"]))) +// CHECK: (17, 3.14159, "Hello, DSL", main.Do<(Swift.Array, Swift.Int)>(value: (["nested", "do"], 6)), Optional((2.71828, ["if", "stmt"]))) let name = "dsl" tuplify(true) { 17 From 014466de5ab32254252491ae84106ab77a9ab22e Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Fri, 29 May 2020 20:59:44 -0700 Subject: [PATCH 4/8] [Function builders] Update test case --- test/ModuleInterface/function_builders.swift | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/ModuleInterface/function_builders.swift b/test/ModuleInterface/function_builders.swift index e471241774eda..7657e10665769 100644 --- a/test/ModuleInterface/function_builders.swift +++ b/test/ModuleInterface/function_builders.swift @@ -26,7 +26,15 @@ public struct TupleBuilder { return (t1, t2, t3, t4, t5) } - public static func buildDo(_ value: T) -> T { return value } + public static func buildDo(_ t1: T1, _ t2: T2) -> (T1, T2) { + return (t1, t2) + } + + public static func buildDo(_ t1: T1, _ t2: T2, _ t3: T3) + -> (T1, T2, T3) { + return (t1, t2, t3) + } + public static func buildIf(_ value: T?) -> T? { return value } } From 5942d0635c844f1d59d6d933a0059414687d97fe Mon Sep 17 00:00:00 2001 From: Ben Rimmington Date: Sat, 30 May 2020 15:38:34 +0100 Subject: [PATCH 5/8] Update README.md to fix CentOS 8 build status icon --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 033d95fac68e2..c219cecac3ab6 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ | **Ubuntu 16.04** | x86_64 | [![Build Status](https://ci.swift.org/job/oss-swift-incremental-RA-linux-ubuntu-16_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-incremental-RA-linux-ubuntu-16_04)|[![Build Status](https://ci.swift.org/job/oss-swift-package-linux-ubuntu-16_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-linux-ubuntu-16_04)| | **Ubuntu 18.04** | x86_64 | [![Build Status](https://ci.swift.org/job/oss-swift-incremental-RA-linux-ubuntu-18_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-incremental-RA-linux-ubuntu-18_04)|[![Build Status](https://ci.swift.org/job/oss-swift-package-linux-ubuntu-18_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-linux-ubuntu-18_04)| | **Ubuntu 20.04** | x86_64 | [![Build Status](https://ci.swift.org/job/oss-swift-package-ubuntu-20_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-ubuntu-20_04)|[![Build Status](https://ci.swift.org/job/oss-swift-package-ubuntu-20_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-ubuntu-20_04)| -| **CentOS 8** | x86_64 | [![Build Status](https://ci.swift.org/job/oss-swift-package-centos-8/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-centos-8)|[![Build Status](https://ci.swift.org/job/oss-swift-package-linux-ubuntu-18_04/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-centos-8)| +| **CentOS 8** | x86_64 | [![Build Status](https://ci.swift.org/job/oss-swift-package-centos-8/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-centos-8)|[![Build Status](https://ci.swift.org/job/oss-swift-package-centos-8/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-centos-8)| | **Amazon Linux 2** | x86_64 | [![Build Status](https://ci.swift.org/job/oss-swift-package-amazon-linux-2/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-amazon-linux-2)|[![Build Status](https://ci.swift.org/job/oss-swift-package-amazon-linux-2/lastCompletedBuild/badge/icon)](https://ci.swift.org/job/oss-swift-package-amazon-linux-2)| **Swift Community-Hosted CI Platforms** From a6bb9742fec376151a23341dbbf45166eda548d9 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 30 May 2020 11:51:49 -0700 Subject: [PATCH 6/8] [AutoDiff] NFC: garden tests. (#32102) Remove extraneous code. Clarify test name: test/AutoDiff/SIL/Serialization/differentiable_function_type.swift. --- ...ntiable_function.swift => differentiable_function_type.swift} | 0 .../SILOptimizer/differentiation_control_flow_diagnostics.swift | 1 - test/AutoDiff/validation-test/simple_math.swift | 1 - 3 files changed, 2 deletions(-) rename test/AutoDiff/SIL/Serialization/{differentiable_function.swift => differentiable_function_type.swift} (100%) diff --git a/test/AutoDiff/SIL/Serialization/differentiable_function.swift b/test/AutoDiff/SIL/Serialization/differentiable_function_type.swift similarity index 100% rename from test/AutoDiff/SIL/Serialization/differentiable_function.swift rename to test/AutoDiff/SIL/Serialization/differentiable_function_type.swift diff --git a/test/AutoDiff/SILOptimizer/differentiation_control_flow_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_control_flow_diagnostics.swift index d4adb606194f8..6bc8cda64005f 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_control_flow_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_control_flow_diagnostics.swift @@ -121,7 +121,6 @@ enum Tree : Differentiable & AdditiveArithmetic { case branch(Float, Float) typealias TangentVector = Self - typealias AllDifferentiableVariables = Self static var zero: Self { .leaf(0) } // expected-error @+1 {{function is not differentiable}} diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index e66b3a358203c..d1b323f7f2c78 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -230,7 +230,6 @@ SimpleMathTests.test("StructMemberwiseInitializer") { // Custom initializer with `@differentiable`. @differentiable init(x: Float) { - print(x) self.x = x } } From 756788eb95e993176229bcfc95af84b262971687 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 30 May 2020 21:37:50 -0700 Subject: [PATCH 7/8] [AutoDiff upstream] Add forward-mode differentiation runtime tests. (#32106) Forward-mode differentiation development isn't currently prioritized, but upstreaming tests allows us to prevent regressions. --- .../validation-test/forward_mode.swift | 1254 +++++++++++++++++ 1 file changed, 1254 insertions(+) create mode 100644 test/AutoDiff/validation-test/forward_mode.swift diff --git a/test/AutoDiff/validation-test/forward_mode.swift b/test/AutoDiff/validation-test/forward_mode.swift new file mode 100644 index 0000000000000..14bd1e38d00e7 --- /dev/null +++ b/test/AutoDiff/validation-test/forward_mode.swift @@ -0,0 +1,1254 @@ +// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation) +// REQUIRES: executable_test + +import StdlibUnittest +import DifferentiationUnittest + +var ForwardModeTests = TestSuite("ForwardModeDifferentiation") + +//===----------------------------------------------------------------------===// +// Basic tests. +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("Identity") { + func func_to_diff(x: Float) -> Float { + return x + } + let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff) + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.test("Unary") { + func func_to_diff(x: Float) -> Float { + return x * x + } + let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff) + expectEqual(16, y) + expectEqual(8, differential(1)) +} + +ForwardModeTests.test("Binary") { + func func_to_diff(x: Float, y: Float) -> Float { + return x * y + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: func_to_diff) + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +ForwardModeTests.test("BinaryWithLets") { + func func_to_diff(x: Float, y: Float) -> Float { + let a = x + y + let b = a + return b * -y + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: func_to_diff) + expectEqual(-45, y) + expectEqual(-19, differential(1, 1)) +} + +ForwardModeTests.test("SubsetParametersDiff") { + func func_to_diff1(x: Int, y: Float, z: Int) -> Float { + return y + } + let (y1, differential1) = valueWithDifferential(at: 5) { y in + func_to_diff1(x: 0, y: y, z: 0) + } + expectEqual(5, y1) + expectEqual(1, differential1(1)) + + func func_to_diff2(x: Float, y: Int, z: Int) -> Float { + return 2 * x + } + let (y2, differential2) = valueWithDifferential(at: 6) { x in + func_to_diff2(x: x, y: 0, z: 0) + } + expectEqual(12, y2) + expectEqual(2, differential2(1)) + + func func_to_diff3(x: Int, y: Int, z: Float) -> Float { + return 3 * z + } + let (y3, differential3) = valueWithDifferential(at: 7) { z in + func_to_diff3(x: 0, y: 0, z: z) + } + expectEqual(21, y3) + expectEqual(3, differential3(1)) +} + +//===----------------------------------------------------------------------===// +// Functions with variables +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("UnaryWithVars") { + func unary(x: Float) -> Float { + var a = x + a = x + var b = a + 2 + b = b - 1 + let c: Float = 3 + var d = a + b + c - 1 + d = d + d + return d + } + + let (y, differential) = valueWithDifferential(at: 4, in: unary) + expectEqual(22, y) + expectEqual(4, differential(1)) +} + +//===----------------------------------------------------------------------===// +// Functions with basic struct +//===----------------------------------------------------------------------===// + +struct A: Differentiable & AdditiveArithmetic { + var x: Float +} + +ForwardModeTests.test("StructInit") { + func structInit(x: Float) -> A { + return A(x: 2 * x) + } + + let (y, differential) = valueWithDifferential(at: 4, in: structInit) + expectEqual(A(x: 8), y) + expectEqual(A(x: 2), differential(1)) +} + +ForwardModeTests.test("StructExtract") { + func structExtract(x: A) -> Float { + return 2 * x.x + } + + let (y, differential) = valueWithDifferential( + at: A(x: 4), + in: structExtract) + expectEqual(8, y) + expectEqual(2, differential(A(x: 1))) +} + +ForwardModeTests.test("LocalStructVariable") { + func structExtract(x: A) -> A { + let a = A(x: 2 * x.x) // 2x + var b = A(x: a.x + 2) // 2x + 2 + b = A(x: b.x + a.x) // 2x + 2 + 2x = 4x + 2 + return b + } + + let (y, differential) = valueWithDifferential( + at: A(x: 4), + in: structExtract) + expectEqual(A(x: 18), y) + expectEqual(A(x: 4), differential(A(x: 1))) +} + +//===----------------------------------------------------------------------===// +// Functions with methods +//===----------------------------------------------------------------------===// + +extension A { + func noParamMethodA() -> A { + return A(x: 2 * x) + } + + func noParamMethodx() -> Float { + return 2 * x + } + + static func *(lhs: A, rhs: A) -> A { + return A(x: lhs.x * rhs.x) + } + + func complexBinaryMethod(u: A, v: Float) -> A { + var b: A = u * A(x: 2) // A(x: u * 2) + b.x = b.x * v // A(x: u * 2 * v) + let c = b.x + 1 // u * 2 * v + 1 + + // A(x: u * 2 * v + 1 + u * 2 * v) = A(x: x * (4uv + 1)) + return A(x: x * (c + b.x)) + } +} + +ForwardModeTests.test("noParamMethodA") { + let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in + x.noParamMethodA() + } + expectEqual(A(x: 8), y) + expectEqual(A(x: 2), differential(A(x: 1))) +} + +ForwardModeTests.test("noParamMethodx") { + let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in + x.noParamMethodx() + } + expectEqual(8, y) + expectEqual(2, differential(A(x: 1))) +} + +ForwardModeTests.test("complexBinaryMethod") { + let (y, differential) = valueWithDifferential(at: A(x: 4), A(x: 5), 3) { + (x, y, z) in + // derivative = A(x: 4uv + 4xv + 4ux + 1) = 4*5*3 + 4*4*3 + 4*5*4 + 1 = 189 + x.complexBinaryMethod(u: y, v: z) + } + expectEqual(A(x: 244), y) + expectEqual(A(x: 189), differential(A(x: 1), A(x: 1), 1)) +} + +//===----------------------------------------------------------------------===// +// Tracked struct +//===----------------------------------------------------------------------===// + +ForwardModeTests.testWithLeakChecking("TrackedIdentity") { + func identity(x: Tracked) -> Tracked { + return x + } + let (y, differential) = valueWithDifferential(at: 4, in: identity) + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.testWithLeakChecking("TrackedAddition") { + func add(x: Tracked, y: Tracked) -> Tracked { + return x + y + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: add) + expectEqual(9, y) + expectEqual(2, differential(1, 1)) +} + +ForwardModeTests.testWithLeakChecking("TrackedDivision") { + func divide(x: Tracked, y: Tracked) -> Tracked { + return x / y + } + let (y, differential) = valueWithDifferential(at: 10, 5, in: divide) + expectEqual(2, y) + expectEqual(-0.2, differential(1, 1)) +} + +ForwardModeTests.testWithLeakChecking("TrackedMultipleMultiplication") { + func add(x: Tracked, y: Tracked) -> Tracked { + return x * y * x + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: add) + expectEqual(80, y) + // 2yx+xx + expectEqual(56, differential(1, 1)) +} + +ForwardModeTests.testWithLeakChecking("TrackedWithLets") { + func add(x: Tracked, y: Tracked) -> Tracked { + let a = x + y + let b = a * a // (x+y)^2 + let c = b / x + y // (x+y)^2/x+y + return c + } + // (3x^2+2xy-y^2)/x^2+1 + let (y, differential) = valueWithDifferential(at: 4, 5, in: add) + expectEqual(25.25, y) + expectEqual(4.9375, differential(1, 1)) +} + +//===----------------------------------------------------------------------===// +// Tuples +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("TupleLet") { + do { + func tupleLet(_ x: Float) -> Float { + let tuple = (2 * x, x) + return tuple.0 + } + let (value, derivative) = valueWithDerivative(at: 4, in: tupleLet) + expectEqual(8, value) + expectEqual(2, derivative) + } +} + +ForwardModeTests.test("TupleVar") { + do { + func tupleVar(_ x: Float) -> Float { + var tuple = (2 * x, x) + return tuple.0 + } + let (value, derivative) = valueWithDerivative(at: 4, in: tupleVar) + expectEqual(8, value) + expectEqual(2, derivative) + } + + do { + // TF-964: Test tuple with non-tuple-typed adjoint value. + func TF_964(_ x: Float) -> Float { + var tuple = (2 * x, 1) + return tuple.0 + } + let (value, derivative) = valueWithDerivative(at: 4, in: TF_964) + expectEqual(8, value) + expectEqual(2, derivative) + } +} + +ForwardModeTests.test("TupleMutation") { + func foo(_ x: Float) -> Float { + var tuple = (x, x) + tuple.0 = tuple.0 * x + return x * tuple.0 + } + expectEqual(27, derivative(at: 3, in: foo)) + + func fifthPower(_ x: Float) -> Float { + var tuple = (x, x) + tuple.0 = tuple.0 * x + tuple.1 = tuple.0 * x + return tuple.0 * tuple.1 + } + expectEqual(405, derivative(at: 3, in: fifthPower)) + + func nested(_ x: Float) -> Float { + var tuple = ((x, x), x) + tuple.0.0 = tuple.0.0 * x + tuple.0.1 = tuple.0.0 * x + return tuple.0.0 * tuple.0.1 + } + expectEqual(405, derivative(at: 3, in: nested)) + + func generic(_ x: T) -> T { + var tuple = (x, x) + return tuple.0 + } + expectEqual(1, derivative(at: 3.0, in: generic)) + + // FIXME(TF-1033): Fix forward-mode ownership error for tuple with non-active + // initial values. + /* + func genericInitialNonactive( + _ x: T + ) -> T { + var tuple = (T.zero, T.zero) + tuple.0 = x + tuple.1 = x + return tuple.0 + } + expectEqual(1, derivative(at: 3.0, in: genericInitialNonactive)) + */ +} + +// Tests TF-321. +ForwardModeTests.test("TupleNonDifferentiableElements") { + // TF-964: Test tuple with non-tuple-typed adjoint value. + func tupleLet(_ x: Tracked) -> Tracked { + let tuple = (2 * x, 1) + return tuple.0 + } + expectEqual((8, 2), valueWithDerivative(at: 4, in: tupleLet)) + + func tupleVar(_ x: Tracked) -> Tracked { + var tuple = (x, 1) + tuple.0 = x + tuple.1 = 1 + return tuple.0 + } + expectEqual((3, 1), valueWithDerivative(at: 3, in: tupleVar)) + + @differentiable + func nested(_ x: Tracked) -> Tracked { + // Convoluted function computing `x * x`. + var tuple: (Int, (Int, Tracked), Tracked) = (1, (1, 0), 0) + tuple.0 = 1 + tuple.1.0 = 1 + tuple.1.1 = x + tuple.2 = x + return tuple.1.1 * tuple.2 + } + // FIXME(SR-12911): Fix runtime segfault. + // expectEqual((16, 8), valueWithDerivative(at: 4, in: nested)) + + struct Wrapper { + @differentiable(where T : Differentiable) + func baz(_ x: T) -> T { + var tuple = (1, 1, x, 1) + tuple.0 = 1 + tuple.2 = x + tuple.3 = 1 + return tuple.2 + } + } + func wrapper(_ x: Tracked) -> Tracked { + let w = Wrapper>() + return w.baz(x) + } + expectEqual((3, 1), valueWithDerivative(at: 3, in: wrapper)) +} + +//===----------------------------------------------------------------------===// +// Generics +//===----------------------------------------------------------------------===// + +struct Tensor + : AdditiveArithmetic, Differentiable { + // NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`) + // until differentiation has indirect passing support. + var value: Float + init(_ value: Float) { self.value = value } +} + +ForwardModeTests.test("GenericIdentity") { + func identity(_ x: T) -> T { + return x + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + identity(x) + } + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.test("GenericTensorIdentity") { + func identity( + _ x: Tensor) -> Tensor { + return x + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + identity(Tensor(x)) + } + expectEqual(Tensor(4), y) + expectEqual(Tensor(1), differential(1)) +} + +ForwardModeTests.test("GenericTensorPlus") { + func plus(_ x: Tensor) -> Float { + return x.value + x.value + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + plus(Tensor(x)) + } + expectEqual(8, y) + expectEqual(2, differential(1)) +} + +ForwardModeTests.test("GenericTensorBinaryInput") { + func binary( + _ x: Tensor, _ y: Tensor) -> Float { + return x.value * y.value + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + binary(Tensor(x), Tensor(y)) + } + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +ForwardModeTests.test("GenericTensorWithLets") { + func binary( + _ x: Tensor, _ y: Tensor) -> Float { + let a = Tensor(x.value) + let b = Tensor(y.value) + return a.value * b.value + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + binary(Tensor(x), Tensor(y)) + } + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +ForwardModeTests.test("GenericTensorWithVars") { + func binary( + _ x: Tensor, _ y: Tensor) -> Float { + var a = Tensor(x.value) + var b = Tensor(y.value) + b = a + a = Tensor(y.value) + return a.value * b.value + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + binary(Tensor(x), Tensor(y)) + } + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +// Test case where associated derivative function's requirements are met. +extension Tensor where Scalar : Numeric { + @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint) + func mean() -> Tensor { + return self + } + + @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint) + func variance() -> Tensor { + return mean() // ok + } +} +_ = differential(at: Tensor(1), in: { $0.variance() }) + +// Tests TF-508: differentiation requirements with dependent member types. +protocol TF_508_Proto { + associatedtype Scalar +} +extension TF_508_Proto where Scalar : FloatingPoint { + @differentiable( + where Self : Differentiable, Scalar : Differentiable, + // Conformance requirement with dependent member type. + Self.TangentVector : TF_508_Proto + ) + static func +(lhs: Self, rhs: Self) -> Self { + return lhs + } + + @differentiable( + where Self : Differentiable, Scalar : Differentiable, + // Same-type requirement with dependent member type. + Self.TangentVector == Float + ) + static func -(lhs: Self, rhs: Self) -> Self { + return lhs + } +} +extension TF_508_Proto where Self : Differentiable, + Scalar : FloatingPoint & Differentiable, + Self.TangentVector : TF_508_Proto { + @derivative(of: +) + static func jvpAdd(lhs: Self, rhs: Self) + -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector) { + return (lhs, { (dlhs, drhs) in dlhs }) + } +} +extension TF_508_Proto where Self : Differentiable, + Scalar : FloatingPoint & Differentiable, + Self.TangentVector == Float { + @derivative(of: -) + static func jvpSubtract(lhs: Self, rhs: Self) + -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector) { + return (lhs, { (dlhs, drhs) in dlhs }) + } +} + +struct TF_508_Struct + : TF_508_Proto, AdditiveArithmetic {} +extension TF_508_Struct : Differentiable where Scalar : Differentiable { + typealias TangentVector = TF_508_Struct +} + +// func TF_508() { +// let x = TF_508_Struct() +// // Test conformance requirement with dependent member type. +// _ = differential(at: x, in: { +// (x: TF_508_Struct) -> TF_508_Struct in +// return x + x +// }) +// // Test same-type requirement with dependent member type. +// _ = differential(at: x, in: { +// (x: TF_508_Struct) -> TF_508_Struct in +// return x - x +// }) +// } + +// TF-523 +struct TF_523_Struct : Differentiable & AdditiveArithmetic { + var a: Float = 1 + typealias TangentVector = TF_523_Struct + typealias AllDifferentiableVariables = TF_523_Struct +} + +@differentiable +func TF_523_f(_ x: TF_523_Struct) -> Float { + return x.a * 2 +} + +// TF-534: Thunk substitution map remapping. +protocol TF_534_Layer : Differentiable { + associatedtype Input : Differentiable + associatedtype Output : Differentiable + + @differentiable + func callAsFunction(_ input: Input) -> Output +} +struct TF_534_Tensor : Differentiable {} + +func TF_534( + _ model: inout Model, inputs: Model.Input +) -> TF_534_Tensor where Model.Output == TF_534_Tensor { + return valueWithDifferential(at: model) { model -> Model.Output in + return model(inputs) + }.0 +} + +// TODO: uncomment once control flow is supported in forward mode. +// TF-652: Test VJPEmitter substitution map generic signature. +// The substitution map should have the VJP's generic signature, not the +// original function's. +// struct TF_652 {} +// extension TF_652 : Differentiable where Scalar : FloatingPoint {} + +// @differentiable(wrt: x where Scalar: FloatingPoint) +// func test(x: TF_652) -> TF_652 { +// for _ in 0..<10 { +// let _ = x +// } +// return x +// } + +//===----------------------------------------------------------------------===// +// Tracked Generic. +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("GenericTrackedIdentity") { + func identity(_ x: Tracked) -> Tracked { + return x + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + identity(Tracked(x)) + } + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.test("GenericTrackedBinaryAdd") { + func add(_ x: Tracked, _ y: Tracked) -> Tracked + where T: Differentiable, T == T.TangentVector { + return x + y + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + add(Tracked(x), Tracked(y)) + } + expectEqual(9, y) + expectEqual(2, differential(1, 1)) +} + +ForwardModeTests.test("GenericTrackedBinaryLets") { + func add(_ x: Tracked, _ y: Tracked) -> Tracked + where T: Differentiable & SignedNumeric, + T == T.TangentVector, + T == T.Magnitude { + let a = x * y // xy + let b = a + a // 2xy + return b + b // 4xy + } + // 4y + 4x + let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in + add(Tracked(x), Tracked(y)) + } + expectEqual(80, y) + expectEqual(36, differential(1, 1)) +} + +ForwardModeTests.test("GenericTrackedBinaryVars") { + func add(_ x: Tracked, _ y: Tracked) -> Tracked + where T: Differentiable & SignedNumeric, + T == T.TangentVector, + T == T.Magnitude { + var a = x * y // xy + a = a + a // 2xy + var b = x + b = a + return b + b // 4xy + } + // 4y + 4x + let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in + add(Tracked(x), Tracked(y)) + } + expectEqual(80, y) + expectEqual(36, differential(1, 1)) +} + +ForwardModeTests.testWithLeakChecking("TrackedDifferentiableFuncType") { + func valAndDeriv( + f: @escaping @differentiable (Tracked) -> Tracked + ) -> (Tracked, Tracked) { + let (y, diff) = valueWithDifferential(at: 5, in: f) + return (y, diff(1)) + } + + func func1(_ x: Tracked) -> Tracked { + let a = x + x // 2x + let b = a + a // 4x + return b * b // 16x^2 + } + let (val1, dv1) = valAndDeriv(f: func1) + expectEqual(400, val1) + expectEqual(160, dv1) +} + +//===----------------------------------------------------------------------===// +// Classes +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("Final") { + final class Final : Differentiable { + func method(_ x: Float) -> Float { + return x * x + } + } + + for i in -5...5 { + expectEqual( + Float(i) * 2, + derivative(at: Float(i)) { x in Final().method(x) }) + } +} + +ForwardModeTests.test("Simple") { + class Super { + @differentiable(wrt: x) + func f(_ x: Float) -> Float { + return 2 * x + } + @derivative(of: f) + final func jvpf(_ x: Float) -> (value: Float, differential: (Float) -> Float) { + return (f(x), { v in 2 * v }) + } + @derivative(of: f) + final func vjpf(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + return (f(x), { v in 2 * v }) + } + } + + class SubOverride : Super { + @differentiable(wrt: x) + override func f(_ x: Float) -> Float { + return 3 * x + } + } + + class SubOverrideCustomDerivatives : Super { + @differentiable(wrt: x) + override func f(_ x: Float) -> Float { + return 3 * x + } + @derivative(of: f) + final func jvpf2(_ x: Float) -> (value: Float, differential: (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + @derivative(of: f) + final func vjpf2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + } + + func classValueWithDerivative(_ c: Super) -> (Float, Float) { + return valueWithDerivative(at: 1) { c.f($0) } + } + + expectEqual((2, 2), classValueWithDerivative(Super())) + expectEqual((3, 3), classValueWithDerivative(SubOverride())) + expectEqual((3, 3), classValueWithDerivative(SubOverrideCustomDerivatives())) +} + +ForwardModeTests.test("SimpleWrtSelf") { + class Super : Differentiable { + var base: Float + // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial. + var _nontrivial: [Float] = [] + + // FIXME(SR-12175): Fix forward-mode differentiation crash. + // @differentiable + required init(base: Float) { + self.base = base + } + + @differentiable(wrt: (self, x)) + func f(_ x: Float) -> Float { + return base * x + } + @derivative(of: f) + final func jvpf(_ x: Float) -> (value: Float, differential: (TangentVector, Float) -> Float) { + return (f(x), { (dself, dx) in dself.base * dx }) + } + @derivative(of: f) + final func vjpf(_ x: Float) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) { + let base = self.base + return (f(x), { v in + (TangentVector(base: v * x, _nontrivial: []), base * v) + }) + } + } + + class SubOverride : Super { + @differentiable(wrt: (self, x)) + override func f(_ x: Float) -> Float { + return 3 * x + } + } + + class SubOverrideCustomDerivatives : Super { + @differentiable(wrt: (self, x)) + @differentiable(wrt: x) + override func f(_ x: Float) -> Float { + return 3 * x + } + @derivative(of: f, wrt: x) + final func jvpf2(_ x: Float) -> (value: Float, differential: (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + @derivative(of: f, wrt: x) + final func vjpf2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + } + + // FIXME(SR-12175): Fix forward-mode differentiation crash. + // let v = Super.TangentVector(base: 100, _nontrivial: []) + // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) + // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) + // expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v)) + + // `valueWithDerivative` is not used because the derivative requires `Super` + // to conform to `FloatingPoint`. + func classDifferential( + _ c: Super + ) -> (Float, (Super.TangentVector, Float) -> Float) { + return valueWithDifferential(at: c, 10) { (c: Super, x: Float) in c.f(x) } + } + + let (y1, diff1) = classDifferential(Super(base: 5)) + expectEqual(50, y1) + let c1 = Super.TangentVector(base: 1, _nontrivial: []) + expectEqual(1, diff1(c1, 1)) + let (y2, diff2) = classDifferential(SubOverride(base: 5)) + expectEqual(30, y2) + let c2 = SubOverride.TangentVector(base: 1, _nontrivial: []) + expectEqual(3, diff2(c2, 1)) + let (y3, diff3) = classDifferential(SubOverrideCustomDerivatives(base: 5)) + expectEqual(30, y3) + let c3 = SubOverrideCustomDerivatives.TangentVector(base: 1, _nontrivial: []) + expectEqual(3, diff3(c3, 1)) +} + +//===----------------------------------------------------------------------===// +// Protocols +//===----------------------------------------------------------------------===// + +protocol Prot : Differentiable { + @differentiable(wrt: x) + func foo(x: Float) -> Float +} +ForwardModeTests.test("Simple Protocol") { + struct Linear: Prot, AdditiveArithmetic { + typealias TangentVector = Linear + + let m: Float + let b: Float + + @differentiable(wrt: x) + func foo(x: Float) -> Float { + return m * x + b + } + } + + func genericFoo(_ t: T, _ x: Float) -> Float { + t.foo(x: x) + } + let inst = Linear(m: 5, b: -2) + let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) } + expectEqual(23, y1) + expectEqual(5, diff1(1)) +} + +protocol DiffReq : Differentiable { + @differentiable(wrt: (self, x)) + func f(_ x: Float) -> Float +} + +extension DiffReq where TangentVector : AdditiveArithmetic { + @inline(never) // Prevent specialization, to test all witness code. + func derivF(at x: Float) -> Float { + return (valueWithDifferential(at: x) { x in self.f(x) }).1(1) + } +} + +struct Quadratic : DiffReq, AdditiveArithmetic { + typealias TangentVector = Quadratic + + @differentiable + let a: Float + + @differentiable + let b: Float + + @differentiable + let c: Float + + init(_ a: Float, _ b: Float, _ c: Float) { + self.a = a + self.b = b + self.c = c + } + + @differentiable(wrt: (self, x)) + func f(_ x: Float) -> Float { + return a * x * x + b * x + c + } +} + +ForwardModeTests.test("ProtocolFunc") { + expectEqual(12, Quadratic(11, 12, 13).derivF(at: 0)) + expectEqual(2 * 11 + 12, Quadratic(11, 12, 13).derivF(at: 1)) + expectEqual(2 * 11 * 2 + 12, Quadratic(11, 12, 13).derivF(at: 2)) +} + +// MARK: Constructor, accessor, and subscript requirements. + +protocol FunctionsOfX: Differentiable { + @differentiable + init(x: Float) + + @differentiable + var x: Float { get } + + @differentiable + var y: Float { get } + + @differentiable + var z: Float { get } + + @differentiable + subscript() -> Float { get } +} + +struct TestFunctionsOfX: FunctionsOfX { + @differentiable + init(x: Float) { + self.x = x + self.y = x * x + } + + /// x = x + var x: Float + + /// y = x * x + var y: Float + + /// z = x * x + x + var z: Float { + return y + x + } + + @differentiable + subscript() -> Float { + return z + } +} + +@inline(never) // Prevent specialization, to test all witness code. +func derivatives(at x: Float, in: F.Type) + -> (Float, Float, Float, Float) +{ + let dxdx = derivative(at: x) { x in F(x: x).x } + let dydx = derivative(at: x) { x in F(x: x).y } + let dzdx = derivative(at: x) { x in F(x: x).z } + let dsubscriptdx = derivative(at: x) { x in F(x: x)[] } + return (dxdx, dydx, dzdx, dsubscriptdx) +} + +ForwardModeTests.test("constructor, accessor, subscript") { + expectEqual( + (1.0, 4.0, 5.0, 5.0), + derivatives(at: 2.0, in: TestFunctionsOfX.self)) +} + +// MARK: - Test witness method SIL type computation. + +protocol P : Differentiable { + @differentiable(wrt: (x, y)) + func foo(_ x: Float, _ y: Double) -> Float +} +struct S : P { + @differentiable(wrt: (x, y)) + func foo(_ x: Float, _ y: Double) -> Float { + return x + } +} + +// MARK: - Overridden protocol method adding differentiable attribute. + +public protocol Distribution { + associatedtype Value + func logProbability(of value: Value) -> Float +} + +public protocol DifferentiableDistribution: Differentiable, Distribution { + @differentiable(wrt: self) + func logProbability(of value: Value) -> Float +} + +struct Foo: DifferentiableDistribution { + @differentiable(wrt: self) + func logProbability(of value: Float) -> Float { + .zero + } +} + +@differentiable +func blah(_ x: T) -> Float where T.Value: AdditiveArithmetic { + x.logProbability(of: .zero) +} + +// Adding a more general `@differentiable` attribute. +public protocol DoubleDifferentiableDistribution: DifferentiableDistribution + where Value: Differentiable { + @differentiable(wrt: self) + @differentiable(wrt: (self, value)) + func logProbability(of value: Value) -> Float +} + +@differentiable +func blah2(_ x: T, _ value: T.Value) -> Float + where T.Value: AdditiveArithmetic { + x.logProbability(of: value) +} + +protocol DifferentiableFoo { + associatedtype T: Differentiable + @differentiable(wrt: x) + func foo(_ x: T) -> Float +} + +protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo { + @differentiable(wrt: (self, x)) + func foo(_ x: T) -> Float +} + +struct MoreDifferentiableFooStruct: MoreDifferentiableFoo { + @differentiable(wrt: (self, x)) + func foo(_ x: Float) -> Float { + x + } +} + +//===----------------------------------------------------------------------===// +// Simple Math +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("Arithmetics") { + func foo1(x: Float, y: Float) -> Float { + return x * y + } + expectEqual(7, derivative(at: 3, 4, in: foo1)) + func foo2(x: Float, y: Float) -> Float { + return -x * y + } + expectEqual(-7, derivative(at: 3, 4, in: foo2)) + func foo3(x: Float, y: Float) -> Float { + return -x + y + } + expectEqual(0, derivative(at: 3, 4, in: foo3)) +} + +ForwardModeTests.test("Fanout") { + func foo1(x: Float) -> Float { + x - x + } + expectEqual(0, derivative(at: 100, in: foo1)) + func foo2(x: Float) -> Float { + x + x + } + expectEqual(2, derivative(at: 100, in: foo2)) + func foo3(x: Float, y: Float) -> Float { + x + x + x * y + } + expectEqual(7, derivative(at: 3, 2, in: foo3)) +} + +ForwardModeTests.test("FunctionCall") { + func foo(_ x: Float, _ y: Float) -> Float { + return 3 * x + { $0 * 3 }(3) * y + } + expectEqual(12, derivative(at: 3, 4, in: foo)) + expectEqual(3, derivative(at: 3) { x in foo(x, 4) }) +} + +ForwardModeTests.test("ResultSelection") { + func foo(_ x: Float, _ y: Float) -> (Float, Float) { + return (x + 1, y + 2) + } + expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).0 })) + expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).1 })) +} + +ForwardModeTests.test("CaptureLocal") { + let z: Float = 10 + func foo(_ x: Float) -> Float { + return z * x + } + expectEqual(10, derivative(at: 0, in: foo)) +} + +var globalVar: Float = 10 +ForwardModeTests.test("CaptureGlobal") { + func foo(x: Float) -> Float { + globalVar += 20 + return globalVar * x + } + expectEqual(30, derivative(at: 0, in: foo)) +} + +ForwardModeTests.test("Mutation") { + func fourthPower(x: Float) -> Float { + var a = x + a = a * x + a = a * x + return a * x + } + expectEqual(4 * 27, derivative(at: 3, in: fourthPower)) +} + +// Tests TF-21. +ForwardModeTests.test("StructMemberwiseInitializer") { + struct Foo : AdditiveArithmetic, Differentiable { + var stored: Float + var computed: Float { + return stored * stored + } + } + + let derivFoo = differential(at: Float(4), in: { input -> Foo in + let foo = Foo(stored: input) + let foo2 = foo + foo + return Foo(stored: foo2.stored) + })(1) + expectEqual(Foo.TangentVector(stored: 2), derivFoo) + + let computed = derivative(at: Float(4)) { input -> Float in + let foo = Foo(stored: input) + return foo.computed + } + expectEqual(8, computed) + + let derivProduct = derivative(at: Float(4)) { input -> Float in + let foo = Foo(stored: input) + return foo.computed * foo.stored + } + expectEqual(48, derivProduct) + + struct Custom : AdditiveArithmetic, Differentiable { + var x: Float + + // Custom initializer with `@differentiable`. + @differentiable + init(x: Float) { + self.x = x + } + } + + let derivCustom = differential(at: Float(4), in: { input -> Custom in + let foo = Custom(x: input) + return foo + foo + })(1) + expectEqual(Custom.TangentVector(x: 2), derivCustom) +} + +// Tests TF-319: struct with non-differentiable constant stored property. +ForwardModeTests.test("StructConstantStoredProperty") { + struct TF_319 : Differentiable { + var x: Float + @noDerivative let constant = Float(2) + + @differentiable + init(x: Float) { + self.x = x + } + + @differentiable(wrt: (self, input)) + func applied(to input: Float) -> Float { + return x * constant * input + } + } + func testStructInit(to input: Float) -> Float { + let model = TF_319(x: 10) + return model.applied(to: input) + } + expectEqual(6, derivative(at: 10, in: { TF_319(x: $0).applied(to: 3) })) + expectEqual(20, derivative(at: 3, in: testStructInit)) +} + +ForwardModeTests.test("StructMutation") { + struct Point : AdditiveArithmetic, Differentiable { + var x: Float + var y: Float + var z: Float + } + + func double(_ input: Float) -> Point { + let point = Point(x: input, y: input, z: input) + return point + point + } + expectEqual(Point(x: 2, y: 2, z: 2), differential(at: 4, in: double)(1)) + + func fifthPower(_ input: Float) -> Float { + var point = Point(x: input, y: input, z: input) + point.x = point.x * input + point.y = point.x * input + return point.x * point.y + } + expectEqual(405, derivative(at: 3, in: fifthPower)) + + func mix(_ input: Float) -> Float { + var tuple = (point: Point(x: input, y: input, z: input), float: input) + tuple.point.x = tuple.point.x * tuple.float + tuple.point.y = tuple.point.x * input + return tuple.point.x * tuple.point.y + } + expectEqual(405, derivative(at: 3, in: mix)) + + // Test TF-282. + struct Add : Differentiable { + var bias: Float + func applied(to input: Float) -> Float { + var tmp = input + tmp = tmp + bias + return tmp + } + } + expectEqual(1, derivative(at: 1) { m in Add(bias: m).applied(to: 1) }) +} + +ForwardModeTests.test("StructGeneric") { + struct Generic : AdditiveArithmetic, Differentiable { + var x: T + var y: T + var z: T + } + + let deriv = differential(at: Float(3), in: { input -> Generic in + var generic = Generic(x: input, y: input, z: input) + return generic + })(1) + expectEqual(Generic.TangentVector(x: 1, y: 1, z: 1), deriv) + + func fifthPower(_ input: Float) -> Float { + var generic = Generic(x: input, y: input, z: input) + generic.x = generic.x * input + generic.y = generic.x * input + return generic.x * generic.y + } + expectEqual(405, derivative(at: 3, in: fifthPower)) +} + +ForwardModeTests.test("SubsetIndices") { + func deriv(_ lossFunction: @differentiable (Float, Float) -> Float) -> Float { + return derivative(at: 1) { x in lossFunction(x * x, 10.0) } + } + expectEqual(2, deriv { x, y in x + y }) + + func derivWRTNonDiff(_ lossFunction: @differentiable (Float, @noDerivative Int) -> Float) -> Float { + return derivative(at: 2) { x in lossFunction(x * x, 10) } + } + expectEqual(4, derivWRTNonDiff { x, y in x + Float(y) }) +} + +ForwardModeTests.test("ForceUnwrapping") { + func forceUnwrap(_ t: T) -> Float where T == T.TangentVector { + derivative(at: t, Float(3)) { (x, y) in + (x as! Float) * y + } + } + expectEqual(5, forceUnwrap(Float(2))) +} + +runAllTests() From dd645d88d14126af9747fd768f98581cdc729e37 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 30 May 2020 23:53:13 -0700 Subject: [PATCH 8/8] [AutoDiff] Enable .swiftinterface verification for _Differentiation. (#32110) Enable .swiftinterface verification for the _Differentiation module to prevent regressions. --- validation-test/ParseableInterface/verify_all_overlays.py | 1 - 1 file changed, 1 deletion(-) diff --git a/validation-test/ParseableInterface/verify_all_overlays.py b/validation-test/ParseableInterface/verify_all_overlays.py index e3668a6e16892..f17a690c9b46e 100755 --- a/validation-test/ParseableInterface/verify_all_overlays.py +++ b/validation-test/ParseableInterface/verify_all_overlays.py @@ -50,7 +50,6 @@ continue if module_name in [ - "_Differentiation", "DifferentiationUnittest", "Swift", "SwiftLang",