Skip to content

Commit be54faa

Browse files
committed
[AutoDiff] Derive Differentiable.zeroTangentVectorInitializer.
`Differentiable` conformance derivation now supports `Differentiable.zeroTangentVectorInitializer`. There are two potential cases: 1. Memberwise derivation: done when `TangentVector` can be initialized memberwise. 2. `{ TangentVector.zero }` derivation: done as a fallback. `zeroTangentVectorInitializer` is a closure that produces a zero tangent vector, capturing minimal necessary information from `self`. It is an instance property, unlike the static property `AdditiveArithmetic.zero`, and should be used by the differentiation transform for correctness. Remove `Differentiable.zeroTangentVectorInitializer` dummy default implementation. Update stdlib `Differentiable` conformances and tests. Clean up DerivedConformanceDifferentiable.cpp cruft. Resolves TF-1007. Progress towards TF-1008: differentiation correctness for projection operations.
1 parent b90e579 commit be54faa

18 files changed

+743
-276
lines changed

docs/DifferentiableProgramming.md

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,6 @@ public extension Differentiable where Self == TangentVector {
10791079
mutating func move(along direction: TangentVector) {
10801080
self += direction
10811081
}
1082-
1083-
@noDerivative
1084-
var zeroTangentVectorInitializer: () -> TangentVector {
1085-
{ .zero }
1086-
}
10871082
}
10881083
```
10891084

@@ -1144,8 +1139,8 @@ extension Array: Differentiable where Element: Differentiable {
11441139

11451140
@noDerivative
11461141
public var zeroTangentVectorInitializer: () -> TangentVector {
1147-
{ [count = self.count] in
1148-
TangentVector(Array(repeating: .zero, count: count))
1142+
{ [zeroInits = map(\.zeroTangentVectorInitializer)] in
1143+
TangentVector(zeroInits.map { $0() })
11491144
}
11501145
}
11511146
}
@@ -1238,8 +1233,15 @@ the same effective access level as their corresponding original properties.
12381233

12391234
A `move(along:)` method is synthesized with a body that calls `move(along:)` for
12401235
each pair of the original property and its corresponding property in
1241-
`TangentVector`. Similarly, `zeroTangentVector` is synthesized to return a
1242-
tangent vector that consists of each stored property's `zeroTangentVector`.
1236+
`TangentVector`.
1237+
1238+
Similarly, when memberwise derivation is possible,
1239+
`zeroTangentVectorInitializer` is synthesized to return a closure that captures
1240+
and calls each stored property's `zeroTangentVectorInitializer` closure.
1241+
When memberwise derivation is not possible (e.g. for custom user-defined
1242+
`TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a
1243+
`{ TangentVector.zero }` closure.
1244+
12431245
Here's an example:
12441246

12451247
```swift
@@ -1251,14 +1253,17 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
12511253
@noDerivative let helperVariable: T
12521254

12531255
// The compiler synthesizes:
1256+
//
12541257
// struct TangentVector: Differentiable, AdditiveArithmetic {
12551258
// var x: T.TangentVector
12561259
// var y: U.TangentVector
12571260
// }
1261+
//
12581262
// mutating func move(along direction: TangentVector) {
12591263
// x.move(along: direction.x)
12601264
// y.move(along: direction.y)
12611265
// }
1266+
//
12621267
// @noDerivative
12631268
// var zeroTangentVectorInitializer: () -> TangentVector {
12641269
// { [xTanInit = x.zeroTangentVectorInitializer,
@@ -1278,16 +1283,25 @@ properties are declared to conform to `AdditiveArithmetic`. There are no
12781283
`@noDerivative` stored properties.
12791284

12801285
In these cases, the compiler will make `TangentVector` be a type alias for Self.
1281-
Method `move(along:)` and property `zeroTangentVector` will not be synthesized
1282-
because a default implementation already exists.
1286+
Method `move(along:)` will not be synthesized because a default implementation
1287+
already exists.
12831288

12841289
```swift
12851290
struct Point<T: Real>: @memberwise Differentiable, @memberwise AdditiveArithmetic {
12861291
// `x` and `y` are the "differentiation properties".
12871292
var x, y: T
12881293

12891294
// The compiler synthesizes:
1295+
//
12901296
// typealias TangentVector = Self
1297+
//
1298+
// @noDerivative
1299+
// var zeroTangentVectorInitializer: () -> TangentVector {
1300+
// { [xTanInit = x.zeroTangentVectorInitializer,
1301+
// yTanInit = y.zeroTangentVectorInitializer] in
1302+
// TangentVector(x: xTanInit(), y: yTanInit())
1303+
// }
1304+
// }
12911305
}
12921306
```
12931307

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ IDENTIFIER(move)
223223
IDENTIFIER(pullback)
224224
IDENTIFIER(TangentVector)
225225
IDENTIFIER(zero)
226+
IDENTIFIER(zeroTangentVectorInitializer)
226227

227228
#undef IDENTIFIER
228229
#undef IDENTIFIER_

lib/Sema/CodeSynthesis.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,3 +1414,22 @@ void swift::addFixedLayoutAttr(NominalTypeDecl *nominal) {
14141414
// Add `@_fixed_layout` to the nominal.
14151415
nominal->getAttrs().add(new (C) FixedLayoutAttr(/*Implicit*/ true));
14161416
}
1417+
1418+
Expr *DiscriminatorFinder::walkToExprPost(Expr *E) {
1419+
auto *ACE = dyn_cast<AbstractClosureExpr>(E);
1420+
if (!ACE)
1421+
return E;
1422+
1423+
unsigned Discriminator = ACE->getDiscriminator();
1424+
assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator &&
1425+
"Existing closures should have valid discriminators");
1426+
if (Discriminator >= NextDiscriminator)
1427+
NextDiscriminator = Discriminator + 1;
1428+
return E;
1429+
}
1430+
1431+
unsigned DiscriminatorFinder::getNextDiscriminator() {
1432+
if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator)
1433+
llvm::report_fatal_error("Out of valid closure discriminators");
1434+
return NextDiscriminator++;
1435+
}

lib/Sema/CodeSynthesis.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#ifndef SWIFT_TYPECHECKING_CODESYNTHESIS_H
1919
#define SWIFT_TYPECHECKING_CODESYNTHESIS_H
2020

21+
#include "swift/AST/ASTWalker.h"
2122
#include "swift/AST/ForeignErrorConvention.h"
2223
#include "swift/Basic/ExternalUnion.h"
2324
#include "swift/Basic/LLVM.h"
@@ -75,6 +76,20 @@ bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal);
7576
/// Add `@_fixed_layout` attribute to the nominal type, if possible.
7677
void addFixedLayoutAttr(NominalTypeDecl *nominal);
7778

79+
/// Find available closure discriminators.
80+
///
81+
/// The parser typically takes care of assigning unique discriminators to
82+
/// closures, but the parser is unavailable during semantic analysis.
83+
class DiscriminatorFinder : public ASTWalker {
84+
unsigned NextDiscriminator = 0;
85+
86+
public:
87+
Expr *walkToExprPost(Expr *E) override;
88+
89+
// Get the next available closure discriminator.
90+
unsigned getNextDiscriminator();
91+
};
92+
7893
} // end namespace swift
7994

8095
#endif

lib/Sema/DebuggerTestingTransform.cpp

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
///
1616
//===----------------------------------------------------------------------===//
1717

18+
#include "CodeSynthesis.h"
1819
#include "swift/AST/ASTContext.h"
1920
#include "swift/AST/ASTNode.h"
2021
#include "swift/AST/ASTWalker.h"
@@ -33,35 +34,6 @@ using namespace swift;
3334

3435
namespace {
3536

36-
/// Find available closure discriminators.
37-
///
38-
/// The parser typically takes care of assigning unique discriminators to
39-
/// closures, but the parser is unavailable to this transform.
40-
class DiscriminatorFinder : public ASTWalker {
41-
unsigned NextDiscriminator = 0;
42-
43-
public:
44-
Expr *walkToExprPost(Expr *E) override {
45-
auto *ACE = dyn_cast<AbstractClosureExpr>(E);
46-
if (!ACE)
47-
return E;
48-
49-
unsigned Discriminator = ACE->getDiscriminator();
50-
assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator &&
51-
"Existing closures should have valid discriminators");
52-
if (Discriminator >= NextDiscriminator)
53-
NextDiscriminator = Discriminator + 1;
54-
return E;
55-
}
56-
57-
// Get the next available closure discriminator.
58-
unsigned getNextDiscriminator() {
59-
if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator)
60-
llvm::report_fatal_error("Out of valid closure discriminators");
61-
return NextDiscriminator++;
62-
}
63-
};
64-
6537
/// Instrument decls with sanity-checks which the debugger can evaluate.
6638
class DebuggerTestingTransform : public ASTWalker {
6739
ASTContext &Ctx;

0 commit comments

Comments
 (0)