Skip to content

[AutoDiff] Add 'zeroTangentVector' property to 'Differentiable' protocol. #26521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ IDENTIFIER(AllDifferentiableVariables)
IDENTIFIER(TangentVector)
IDENTIFIER(allDifferentiableVariables)
IDENTIFIER(move)
IDENTIFIER(zeroTangentVector)

// Kinds of layout constraints
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")
Expand Down
136 changes: 136 additions & 0 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,105 @@ static ValueDecl *getUnderlyingAllDiffableVariables(DeclContext *DC,
return allDiffableVarsDecl;
}

// Return the underlying `zeroTangentVector` of a VarDecl `x`. If `x` conforms
// to `Differentiable`, return `zeroTangentVector`. Otherwise, return
// `x`.
static ValueDecl *getUnderlyingZeroTangentVector(DeclContext *DC,
VarDecl *varDecl) {
auto *module = DC->getParentModule();
auto &C = module->getASTContext();
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
auto *zeroTangentVectorReq =
getProtocolRequirement(diffableProto, C.Id_zeroTangentVector);
if (!varDecl->hasInterfaceType())
C.getLazyResolver()->resolveDeclSignature(varDecl);
auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
auto confRef = module->lookupConformance(varType, diffableProto);
if (!confRef)
return varDecl;
// Use protocol requirement as a default for abstract conformances.
// If conformance is concrete, get concrete witness declaration instead.
ValueDecl *zeroTangentVectorDecl = zeroTangentVectorReq;
if (confRef->isConcrete())
zeroTangentVectorDecl = confRef->getConcrete()->getWitnessDecl(
zeroTangentVectorReq);
return zeroTangentVectorDecl;
}

// Get the effective memberwise initializer of the given nominal type, or create
// it if it does not exist.
static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer(
TypeChecker &TC, NominalTypeDecl *nominal) {
auto &C = nominal->getASTContext();
if (auto *initDecl = nominal->getEffectiveMemberwiseInitializer())
return initDecl;
auto *initDecl = createImplicitConstructor(
TC, nominal, ImplicitConstructorKind::Memberwise);
nominal->addMember(initDecl);
C.addSynthesizedDecl(initDecl);
return initDecl;
}

// Synthesize getter body for `zeroTangentVector` computed property.
static std::pair<BraceStmt *, bool>
derivedBody_zeroTangentVectorGetter(AbstractFunctionDecl *getterDecl, void *) {
auto *parentDC = getterDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();

auto *tangentVectorStruct =
getAssociatedStructDecl(parentDC, C.Id_TangentVector);
auto *tangentVectorInitDecl =
tangentVectorStruct->getEffectiveMemberwiseInitializer();
assert(tangentVectorInitDecl &&
"'TangentVector' memberwise initializer not found");
auto *selfDecl = getterDecl->getImplicitSelfDecl();
auto *selfDRE =
new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);

auto *initDRE = new (C) DeclRefExpr(tangentVectorInitDecl, DeclNameLoc(),
/*Implicit*/ true);
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);

auto tangentVectorType = parentDC->mapTypeIntoContext(
tangentVectorStruct->getDeclaredInterfaceType());
Expr *baseExpr = TypeExpr::createImplicit(tangentVectorType, C);
auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, baseExpr);
initExpr->setThrows(false);
initExpr->setImplicit();

SmallVector<Expr *, 2> members;
SmallVector<Identifier, 2> memberNames;

llvm::DenseMap<Identifier, VarDecl *> diffPropertyMap;
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
for (auto *member : diffProperties)
diffPropertyMap[member->getName()] = member;

for (auto initParam : *tangentVectorInitDecl->getParameters()) {
auto member = diffPropertyMap[initParam->getName()];
member->setInterfaceType(member->getValueInterfaceType());
Expr *memberExpr = new (C) MemberRefExpr(selfDRE, SourceLoc(), member,
DeclNameLoc(), /*Implicit*/ true);
auto *memberZeroTangentVector =
getUnderlyingZeroTangentVector(parentDC, member);
auto *memberZeroTangentExpr =
new (C) MemberRefExpr(memberExpr, SourceLoc(), memberZeroTangentVector,
DeclNameLoc(), /*Implicit*/ true);
members.push_back(memberZeroTangentExpr);
memberNames.push_back(member->getName());
}
Expr *callExpr = CallExpr::createImplicit(C, initExpr, members, memberNames);

ASTNode returnStmt =
new (C) ReturnStmt(SourceLoc(), callExpr, /*Implicit*/ true);
auto *braceStmt =
BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(),
/*Implicit*/ true);
return std::make_pair(braceStmt, false);
}

// Synthesize getter body for `allDifferentiableVariables` computed property.
static std::pair<BraceStmt *, bool>
derivedBody_allDifferentiableVariablesGetter(AbstractFunctionDecl *getterDecl,
Expand Down Expand Up @@ -537,6 +636,41 @@ deriveDifferentiable_allDifferentiableVariables(DerivedConformance &derived) {
return allDiffableVarsDecl;
}

// Synthesize `zeroTangentVector` computed property declaration.
static ValueDecl *
deriveDifferentiable_zeroTangentVector(DerivedConformance &derived) {
auto *parentDC = derived.getConformanceContext();
auto &TC = derived.TC;
auto &C = TC.Context;

// Get `TangentVector` struct.
auto *tangentVectorStruct =
getAssociatedStructDecl(parentDC, C.Id_TangentVector);
// Make sure a memberwise initializer exists because the body synthesizer
// needs it.
getOrCreateEffectiveMemberwiseInitializer(TC, tangentVectorStruct);

auto returnInterfaceTy = tangentVectorStruct->getDeclaredInterfaceType();
auto returnTy = parentDC->mapTypeIntoContext(returnInterfaceTy);

VarDecl *zeroTangentVectorDecl;
PatternBindingDecl *pbDecl;
std::tie(zeroTangentVectorDecl, pbDecl) = derived.declareDerivedProperty(
C.Id_zeroTangentVector, returnInterfaceTy, returnTy,
/*isStatic*/ false, /*isFinal*/ true);

auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
zeroTangentVectorDecl, returnTy);
getterDecl->setBodySynthesizer(&derivedBody_zeroTangentVectorGetter);
derived.addMembersToConformanceContext(
{getterDecl, zeroTangentVectorDecl, pbDecl});

addExpectedOpaqueAccessorsToStorage(zeroTangentVectorDecl, C);
triggerAccessorSynthesis(TC, zeroTangentVectorDecl);

return zeroTangentVectorDecl;
}

// Return associated `TangentVector` or `AllDifferentiableVariables` struct for
// a nominal type, if it exists.
// If not, synthesize the struct. Also return a Boolean value that indicates
Expand Down Expand Up @@ -1035,6 +1169,8 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
return deriveDifferentiable_move(*this);
if (requirement->getBaseName() == TC.Context.Id_allDifferentiableVariables)
return deriveDifferentiable_allDifferentiableVariables(*this);
if (requirement->getBaseName() == TC.Context.Id_zeroTangentVector)
return deriveDifferentiable_zeroTangentVector(*this);
TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement);
return nullptr;
}
Expand Down
5 changes: 5 additions & 0 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
if (name.isSimpleName(ctx.Id_allDifferentiableVariables))
return getRequirement(KnownProtocolKind::Differentiable);

// SWIFT_ENABLE_TENSORFLOW
// Differentiable.zeroTangentVector
if (name.isSimpleName(ctx.Id_zeroTangentVector))
return getRequirement(KnownProtocolKind::Differentiable);

return nullptr;
}

Expand Down
9 changes: 9 additions & 0 deletions stdlib/public/core/Array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,11 @@ extension Array where Element : Differentiable {
base[i].move(along: direction.base[i])
}
}

public var zeroTangentVector: TangentVector {
TangentVector(Array<Element.TangentVector>(repeating: .zero,
count: base.count))
}
}
}

Expand Down Expand Up @@ -2093,6 +2098,10 @@ extension Array : Differentiable where Element : Differentiable {
view.move(along: direction)
self = view.base
}

public var zeroTangentVector: TangentVector {
TangentVector(Array<Element.TangentVector>(repeating: .zero, count: count))
}
}

extension Array where Element : Differentiable {
Expand Down
28 changes: 26 additions & 2 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol Differentiable {
/// A type representing a differentiable value’s derivatives.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Big 👍 on these doc comments by the way 🙂

/// Mathematically, this is equivalent to the tangent bundle of the
/// differentiable manifold represented by the differentiable type.
associatedtype TangentVector: Differentiable & AdditiveArithmetic
where TangentVector.TangentVector == TangentVector,
AllDifferentiableVariables.AllDifferentiableVariables ==
Expand All @@ -163,10 +166,20 @@ public protocol Differentiable {
/// All differentiable variables of this value.
var allDifferentiableVariables: AllDifferentiableVariables { get set }

/// Moves `self` along the value space towards the given tangent vector. In
/// Riemannian geometry (mathematics), this represents an exponential map.
/// Moves `self` along the given direction. In Riemannian geometry,
/// this is equivalent to exponential map, which moves `self` on the
/// geodesic surface along the given tangent vector.
mutating func move(along direction: TangentVector)

/// A tangent vector such that `move(along: zeroTangentVector)` will not
/// modify `self`.
///
/// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
/// but types whose tangent vectors depend on instance properties of
/// `self` need to provide a different implementation. For example, an
/// array’s zero tangent vector depends on the array’s `count`.
var zeroTangentVector: TangentVector { get }

@available(*, deprecated,
message: "'CotangentVector' is now equal to 'TangentVector' and will be removed")
typealias CotangentVector = TangentVector
Expand All @@ -183,6 +196,9 @@ public extension Differentiable where TangentVector == Self {
mutating func move(along direction: TangentVector) {
self += direction
}
var zeroTangentVector: TangentVector {
return .zero
}
}

/// Returns `x` like an identity function. When used in a context where `x` is
Expand Down Expand Up @@ -725,6 +741,7 @@ internal protocol _AnyDerivativeBox {
// `Differentiable` requirements.
var _allDifferentiableVariables: _AnyDerivativeBox { get }
mutating func _move(along direction: _AnyDerivativeBox)
var _zeroTangentVector: _AnyDerivativeBox { get }

/// The underlying base value, type-erased to `Any`.
var _typeErasedBase: Any { get }
Expand Down Expand Up @@ -840,6 +857,10 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
}
_base.move(along: directionBase)
}

var _zeroTangentVector: _AnyDerivativeBox {
return _ConcreteDerivativeBox(_base.zeroTangentVector)
}
}

/// A type-erased derivative value.
Expand Down Expand Up @@ -940,6 +961,9 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
}
_box._move(along: direction._box)
}
public var zeroTangentVector: TangentVector {
AnyDerivative(_box: _box._zeroTangentVector)
}
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions stdlib/public/core/FloatingPointTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,8 @@ extension ${Self} : Differentiable {
public mutating func move(along direction: TangentVector) {
self += direction
}

public var zeroTangentVector: TangentVector { .zero }
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions stdlib/public/core/SIMDVectorTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ extension SIMD${n} : Differentiable
public func tangentVector(from cotangent: TangentVector) -> TangentVector {
return cotangent
}
public var zeroTangentVector: TangentVector { .zero }
}

extension SIMD${n}
Expand Down
14 changes: 6 additions & 8 deletions test/Sema/struct_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ struct VectorSpaceTypeAlias : AdditiveArithmetic, Differentiable {
var w: Float
var b: Float
typealias TangentVector = Simple
var zeroTangentVector: TangentVector { .zero }
}
// expected-error @+2 {{type 'VectorSpaceCustomStruct' does not conform to protocol 'Differentiable'}}
// expected-note @+1 {{do you want to add protocol stubs?}}
Expand All @@ -331,6 +332,7 @@ struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable {
var b: Float.TangentVector
typealias TangentVector = VectorSpaceCustomStruct.TangentVector
}
var zeroTangentVector: TangentVector { .zero }
}

struct StaticNoDerivative : Differentiable {
Expand Down Expand Up @@ -371,14 +373,10 @@ extension NoMemberwiseInitializerExtended: Differentiable

// Test derived conformances in disallowed contexts.

// expected-error @+4 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}}
// expected-error @+3 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
// expected-note @+2 {{do you want to add protocol stubs?}}
// expected-note @+1 {{do you want to add protocol stubs?}}
// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}}
// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
extension OtherFileNonconforming : Differentiable {}

// expected-error @+4 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'Differentiable'}}
// expected-error @+3 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
// expected-note @+2 {{do you want to add protocol stubs?}}
// expected-note @+1 {{do you want to add protocol stubs?}}
// expected-error @+2 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'Differentiable'}}
// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
extension GenericOtherFileNonconforming : Differentiable {}