Skip to content

[AutoDiff] Derive Differentiable.zeroTangentVectorInitializer. #31823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions docs/DifferentiableProgramming.md
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,6 @@ public extension Differentiable where Self == TangentVector {
mutating func move(along direction: TangentVector) {
self += direction
}

@noDerivative
var zeroTangentVectorInitializer: () -> TangentVector {
{ .zero }
}
}
```

Expand Down Expand Up @@ -1144,8 +1139,8 @@ extension Array: Differentiable where Element: Differentiable {

@noDerivative
public var zeroTangentVectorInitializer: () -> TangentVector {
{ [count = self.count] in
TangentVector(Array(repeating: .zero, count: count))
{ [zeroInits = map(\.zeroTangentVectorInitializer)] in
TangentVector(zeroInits.map { $0() })
}
}
}
Expand Down Expand Up @@ -1238,8 +1233,15 @@ the same effective access level as their corresponding original properties.

A `move(along:)` method is synthesized with a body that calls `move(along:)` for
each pair of the original property and its corresponding property in
`TangentVector`. Similarly, `zeroTangentVector` is synthesized to return a
tangent vector that consists of each stored property's `zeroTangentVector`.
`TangentVector`.

Similarly, when memberwise derivation is possible,
`zeroTangentVectorInitializer` is synthesized to return a closure that captures
and calls each stored property's `zeroTangentVectorInitializer` closure.
When memberwise derivation is not possible (e.g. for custom user-defined
`TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a
`{ TangentVector.zero }` closure.

Here's an example:

```swift
Expand All @@ -1251,14 +1253,17 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
@noDerivative let helperVariable: T

// The compiler synthesizes:
//
// struct TangentVector: Differentiable, AdditiveArithmetic {
// var x: T.TangentVector
// var y: U.TangentVector
// }
//
// mutating func move(along direction: TangentVector) {
// x.move(along: direction.x)
// y.move(along: direction.y)
// }
//
// @noDerivative
// var zeroTangentVectorInitializer: () -> TangentVector {
// { [xTanInit = x.zeroTangentVectorInitializer,
Expand All @@ -1278,16 +1283,25 @@ properties are declared to conform to `AdditiveArithmetic`. There are no
`@noDerivative` stored properties.

In these cases, the compiler will make `TangentVector` be a type alias for Self.
Method `move(along:)` and property `zeroTangentVector` will not be synthesized
because a default implementation already exists.
Method `move(along:)` will not be synthesized because a default implementation
already exists.

```swift
struct Point<T: Real>: @memberwise Differentiable, @memberwise AdditiveArithmetic {
// `x` and `y` are the "differentiation properties".
var x, y: T

// The compiler synthesizes:
//
// typealias TangentVector = Self
//
// @noDerivative
// var zeroTangentVectorInitializer: () -> TangentVector {
// { [xTanInit = x.zeroTangentVectorInitializer,
// yTanInit = y.zeroTangentVectorInitializer] in
// TangentVector(x: xTanInit(), y: yTanInit())
// }
// }
}
```

Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ IDENTIFIER(move)
IDENTIFIER(pullback)
IDENTIFIER(TangentVector)
IDENTIFIER(zero)
IDENTIFIER(zeroTangentVectorInitializer)

#undef IDENTIFIER
#undef IDENTIFIER_
Expand Down
19 changes: 19 additions & 0 deletions lib/Sema/CodeSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1414,3 +1414,22 @@ void swift::addFixedLayoutAttr(NominalTypeDecl *nominal) {
// Add `@_fixed_layout` to the nominal.
nominal->getAttrs().add(new (C) FixedLayoutAttr(/*Implicit*/ true));
}

Expr *DiscriminatorFinder::walkToExprPost(Expr *E) {
auto *ACE = dyn_cast<AbstractClosureExpr>(E);
if (!ACE)
return E;

unsigned Discriminator = ACE->getDiscriminator();
assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator &&
"Existing closures should have valid discriminators");
if (Discriminator >= NextDiscriminator)
NextDiscriminator = Discriminator + 1;
return E;
}

unsigned DiscriminatorFinder::getNextDiscriminator() {
if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator)
llvm::report_fatal_error("Out of valid closure discriminators");
return NextDiscriminator++;
}
15 changes: 15 additions & 0 deletions lib/Sema/CodeSynthesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifndef SWIFT_TYPECHECKING_CODESYNTHESIS_H
#define SWIFT_TYPECHECKING_CODESYNTHESIS_H

#include "swift/AST/ASTWalker.h"
#include "swift/AST/ForeignErrorConvention.h"
#include "swift/Basic/ExternalUnion.h"
#include "swift/Basic/LLVM.h"
Expand Down Expand Up @@ -75,6 +76,20 @@ bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal);
/// Add `@_fixed_layout` attribute to the nominal type, if possible.
void addFixedLayoutAttr(NominalTypeDecl *nominal);

/// Find available closure discriminators.
///
/// The parser typically takes care of assigning unique discriminators to
/// closures, but the parser is unavailable during semantic analysis.
class DiscriminatorFinder : public ASTWalker {
unsigned NextDiscriminator = 0;

public:
Expr *walkToExprPost(Expr *E) override;

// Get the next available closure discriminator.
unsigned getNextDiscriminator();
};

} // end namespace swift

#endif
30 changes: 1 addition & 29 deletions lib/Sema/DebuggerTestingTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
///
//===----------------------------------------------------------------------===//

#include "CodeSynthesis.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTNode.h"
#include "swift/AST/ASTWalker.h"
Expand All @@ -33,35 +34,6 @@ using namespace swift;

namespace {

/// Find available closure discriminators.
///
/// The parser typically takes care of assigning unique discriminators to
/// closures, but the parser is unavailable to this transform.
class DiscriminatorFinder : public ASTWalker {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I moved DiscriminatorFinder to a shared location lib/Sema/CodeSynthesis.h.

unsigned NextDiscriminator = 0;

public:
Expr *walkToExprPost(Expr *E) override {
auto *ACE = dyn_cast<AbstractClosureExpr>(E);
if (!ACE)
return E;

unsigned Discriminator = ACE->getDiscriminator();
assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator &&
"Existing closures should have valid discriminators");
if (Discriminator >= NextDiscriminator)
NextDiscriminator = Discriminator + 1;
return E;
}

// Get the next available closure discriminator.
unsigned getNextDiscriminator() {
if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator)
llvm::report_fatal_error("Out of valid closure discriminators");
return NextDiscriminator++;
}
};

/// Instrument decls with sanity-checks which the debugger can evaluate.
class DebuggerTestingTransform : public ASTWalker {
ASTContext &Ctx;
Expand Down
Loading