Skip to content

Commit 6a334f7

Browse files
committed
[AutoDiff] Add 'zeroTangentVector' property to 'Differentiable' protocol.
Zero tangent vector is necessary for optimizations on models with an array of parameters, especially for optimizers that iterates over parameters using key paths. The [current implementation](https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Optimizers/MomentumBased.swift) of some key-path-based optimizers is wrong in that it won't work with models that contain an array of parameters (tangent vectors like `infinityNorm` are initialized as `.zero`). An earlier version of these optimizer using the deprecated `AllDifferentiableVariables` property would give the correct results, but would be heavyweight and inefficient because they'd need to 1. add a constraint `TangentVector == AllDifferentiableVariables` to optimizers, and 2. make a copy of all parameters and resetting them to `.zero`. Since we are deprecating `AllDifferentiableVariables`, this is not the right direction. This problem also means that our `Differentiable` abstraction needs to provide a general mechanism of obtaining a zero tangent vector at a certain instance. Hence we add a `zeroTangentVector` property to the `Differentiable` protocol. Zero tangent vectors do not have a canonical mathematical definition, but makes sense for `Differentiable` in the standard library because Swift does not have dependent types and thus cannot have a `TangentVector` that depends on a point on a differentiable manifold. Manopt also has an API, `M.zerovec(x)`, that creates a zero tangent vector at a point (see their API doc [here](https://www.manopt.org/tutorial.html). Adding `zeroTangentVector` will make it possible to deprecate `AllDifferentiableVariables` completely, because currently some fast.ai notebooks depend on initializing parameter gradients using `AllDifferentiableVariables`. The new `Differentiable` protocol looks like the following. The [design overview](http://bit.ly/swift-autodiff) has been updated to reflect this change. ```swift protocol Differentiable { /// A type representing the differentiable value’s derivatives. /// Mathematically, this is equivalent to the tangent bundle of the /// differentiable manifold represented by the differentiable type. associatedtype TangentVector: Differentiable & AdditiveArithmetic /// Moves `self` along the given direction. In Riemannian geometry, /// this is equivalent to exponential map, which moves `self` on the /// geodesic surface along the given tangent vector. mutating func move(along direction: TangentVector) /// A tangent vector such that `move(along: zeroTangentVector)` /// will not modify `self`. /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases, /// but types whose tangent vectors depend on instance properties of /// `self` need to provide a different implementation. For example, an /// array’s zero tangent vector depends on the array’s `count`. var zeroTangentVector: TangentVector { get } } ```
1 parent 499c875 commit 6a334f7

File tree

8 files changed

+186
-10
lines changed

8 files changed

+186
-10
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ IDENTIFIER(AllDifferentiableVariables)
159159
IDENTIFIER(TangentVector)
160160
IDENTIFIER(allDifferentiableVariables)
161161
IDENTIFIER(move)
162+
IDENTIFIER(zeroTangentVector)
162163

163164
// Kinds of layout constraints
164165
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,105 @@ static ValueDecl *getUnderlyingAllDiffableVariables(DeclContext *DC,
380380
return allDiffableVarsDecl;
381381
}
382382

383+
// Return the underlying `zeroTangentVector` of a VarDecl `x`. If `x` conforms
384+
// to `Differentiable`, return `zeroTangentVector`. Otherwise, return
385+
// `x`.
386+
static ValueDecl *getUnderlyingZeroTangentVector(DeclContext *DC,
387+
VarDecl *varDecl) {
388+
auto *module = DC->getParentModule();
389+
auto &C = module->getASTContext();
390+
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
391+
auto *zeroTangentVectorReq =
392+
getProtocolRequirement(diffableProto, C.Id_zeroTangentVector);
393+
if (!varDecl->hasInterfaceType())
394+
C.getLazyResolver()->resolveDeclSignature(varDecl);
395+
auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
396+
auto confRef = module->lookupConformance(varType, diffableProto);
397+
if (!confRef)
398+
return varDecl;
399+
// Use protocol requirement as a default for abstract conformances.
400+
// If conformance is concrete, get concrete witness declaration instead.
401+
ValueDecl *zeroTangentVectorDecl = zeroTangentVectorReq;
402+
if (confRef->isConcrete())
403+
zeroTangentVectorDecl = confRef->getConcrete()->getWitnessDecl(
404+
zeroTangentVectorReq);
405+
return zeroTangentVectorDecl;
406+
}
407+
408+
// Get the effective memberwise initializer of the given nominal type, or create
409+
// it if it does not exist.
410+
static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer(
411+
TypeChecker &TC, NominalTypeDecl *nominal) {
412+
auto &C = nominal->getASTContext();
413+
if (auto *initDecl = nominal->getEffectiveMemberwiseInitializer())
414+
return initDecl;
415+
auto *initDecl = createImplicitConstructor(
416+
TC, nominal, ImplicitConstructorKind::Memberwise);
417+
nominal->addMember(initDecl);
418+
C.addSynthesizedDecl(initDecl);
419+
return initDecl;
420+
}
421+
422+
// Synthesize getter body for `zeroTangentVector` computed property.
423+
static std::pair<BraceStmt *, bool>
424+
derivedBody_zeroTangentVectorGetter(AbstractFunctionDecl *getterDecl, void *) {
425+
auto *parentDC = getterDecl->getParent();
426+
auto *nominal = parentDC->getSelfNominalTypeDecl();
427+
auto &C = nominal->getASTContext();
428+
429+
auto *tangentVectorStruct =
430+
getAssociatedStructDecl(parentDC, C.Id_TangentVector);
431+
auto *tangentVectorInitDecl =
432+
tangentVectorStruct->getEffectiveMemberwiseInitializer();
433+
assert(tangentVectorInitDecl &&
434+
"'TangentVector' memberwise initializer not found");
435+
auto *selfDecl = getterDecl->getImplicitSelfDecl();
436+
auto *selfDRE =
437+
new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
438+
439+
auto *initDRE = new (C) DeclRefExpr(tangentVectorInitDecl, DeclNameLoc(),
440+
/*Implicit*/ true);
441+
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
442+
443+
auto tangentVectorType = parentDC->mapTypeIntoContext(
444+
tangentVectorStruct->getDeclaredInterfaceType());
445+
Expr *baseExpr = TypeExpr::createImplicit(tangentVectorType, C);
446+
auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, baseExpr);
447+
initExpr->setThrows(false);
448+
initExpr->setImplicit();
449+
450+
SmallVector<Expr *, 2> members;
451+
SmallVector<Identifier, 2> memberNames;
452+
453+
llvm::DenseMap<Identifier, VarDecl *> diffPropertyMap;
454+
SmallVector<VarDecl *, 8> diffProperties;
455+
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
456+
for (auto *member : diffProperties)
457+
diffPropertyMap[member->getName()] = member;
458+
459+
for (auto initParam : *tangentVectorInitDecl->getParameters()) {
460+
auto member = diffPropertyMap[initParam->getName()];
461+
member->setInterfaceType(member->getValueInterfaceType());
462+
Expr *memberExpr = new (C) MemberRefExpr(selfDRE, SourceLoc(), member,
463+
DeclNameLoc(), /*Implicit*/ true);
464+
auto *memberZeroTangentVector =
465+
getUnderlyingZeroTangentVector(parentDC, member);
466+
auto *memberZeroTangentExpr =
467+
new (C) MemberRefExpr(memberExpr, SourceLoc(), memberZeroTangentVector,
468+
DeclNameLoc(), /*Implicit*/ true);
469+
members.push_back(memberZeroTangentExpr);
470+
memberNames.push_back(member->getName());
471+
}
472+
Expr *callExpr = CallExpr::createImplicit(C, initExpr, members, memberNames);
473+
474+
ASTNode returnStmt =
475+
new (C) ReturnStmt(SourceLoc(), callExpr, /*Implicit*/ true);
476+
auto *braceStmt =
477+
BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(),
478+
/*Implicit*/ true);
479+
return std::make_pair(braceStmt, false);
480+
}
481+
383482
// Synthesize getter body for `allDifferentiableVariables` computed property.
384483
static std::pair<BraceStmt *, bool>
385484
derivedBody_allDifferentiableVariablesGetter(AbstractFunctionDecl *getterDecl,
@@ -537,6 +636,41 @@ deriveDifferentiable_allDifferentiableVariables(DerivedConformance &derived) {
537636
return allDiffableVarsDecl;
538637
}
539638

639+
// Synthesize `zeroTangentVector` computed property declaration.
640+
static ValueDecl *
641+
deriveDifferentiable_zeroTangentVector(DerivedConformance &derived) {
642+
auto *parentDC = derived.getConformanceContext();
643+
auto &TC = derived.TC;
644+
auto &C = TC.Context;
645+
646+
// Get `TangentVector` struct.
647+
auto *tangentVectorStruct =
648+
getAssociatedStructDecl(parentDC, C.Id_TangentVector);
649+
// Make sure a memberwise initializer exists because the body synthesizer
650+
// needs it.
651+
getOrCreateEffectiveMemberwiseInitializer(TC, tangentVectorStruct);
652+
653+
auto returnInterfaceTy = tangentVectorStruct->getDeclaredInterfaceType();
654+
auto returnTy = parentDC->mapTypeIntoContext(returnInterfaceTy);
655+
656+
VarDecl *zeroTangentVectorDecl;
657+
PatternBindingDecl *pbDecl;
658+
std::tie(zeroTangentVectorDecl, pbDecl) = derived.declareDerivedProperty(
659+
C.Id_zeroTangentVector, returnInterfaceTy, returnTy,
660+
/*isStatic*/ false, /*isFinal*/ true);
661+
662+
auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
663+
zeroTangentVectorDecl, returnTy);
664+
getterDecl->setBodySynthesizer(&derivedBody_zeroTangentVectorGetter);
665+
derived.addMembersToConformanceContext(
666+
{getterDecl, zeroTangentVectorDecl, pbDecl});
667+
668+
addExpectedOpaqueAccessorsToStorage(zeroTangentVectorDecl, C);
669+
triggerAccessorSynthesis(TC, zeroTangentVectorDecl);
670+
671+
return zeroTangentVectorDecl;
672+
}
673+
540674
// Return associated `TangentVector` or `AllDifferentiableVariables` struct for
541675
// a nominal type, if it exists.
542676
// If not, synthesize the struct. Also return a Boolean value that indicates
@@ -1035,6 +1169,8 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
10351169
return deriveDifferentiable_move(*this);
10361170
if (requirement->getBaseName() == TC.Context.Id_allDifferentiableVariables)
10371171
return deriveDifferentiable_allDifferentiableVariables(*this);
1172+
if (requirement->getBaseName() == TC.Context.Id_zeroTangentVector)
1173+
return deriveDifferentiable_zeroTangentVector(*this);
10381174
TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement);
10391175
return nullptr;
10401176
}

lib/Sema/DerivedConformances.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
290290
if (name.isSimpleName(ctx.Id_allDifferentiableVariables))
291291
return getRequirement(KnownProtocolKind::Differentiable);
292292

293+
// SWIFT_ENABLE_TENSORFLOW
294+
// Differentiable.zeroTangentVector
295+
if (name.isSimpleName(ctx.Id_zeroTangentVector))
296+
return getRequirement(KnownProtocolKind::Differentiable);
297+
293298
return nullptr;
294299
}
295300

stdlib/public/core/Array.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,11 @@ extension Array where Element : Differentiable {
19821982
base[i].move(along: direction.base[i])
19831983
}
19841984
}
1985+
1986+
public var zeroTangentVector: TangentVector {
1987+
TangentVector(Array<Element.TangentVector>(repeating: .zero,
1988+
count: base.count))
1989+
}
19851990
}
19861991
}
19871992

@@ -2093,6 +2098,10 @@ extension Array : Differentiable where Element : Differentiable {
20932098
view.move(along: direction)
20942099
self = view.base
20952100
}
2101+
2102+
public var zeroTangentVector: TangentVector {
2103+
TangentVector(Array<Element.TangentVector>(repeating: .zero, count: count))
2104+
}
20962105
}
20972106

20982107
extension Array where Element : Differentiable {

stdlib/public/core/AutoDiff.swift

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
152152
/// A type that mathematically represents a differentiable manifold whose
153153
/// tangent spaces are finite-dimensional.
154154
public protocol Differentiable {
155+
/// A type representing a differentiable value’s derivatives.
156+
/// Mathematically, this is equivalent to the tangent bundle of the
157+
/// differentiable manifold represented by the differentiable type.
155158
associatedtype TangentVector: Differentiable & AdditiveArithmetic
156159
where TangentVector.TangentVector == TangentVector,
157160
AllDifferentiableVariables.AllDifferentiableVariables ==
@@ -163,10 +166,20 @@ public protocol Differentiable {
163166
/// All differentiable variables of this value.
164167
var allDifferentiableVariables: AllDifferentiableVariables { get set }
165168

166-
/// Moves `self` along the value space towards the given tangent vector. In
167-
/// Riemannian geometry (mathematics), this represents an exponential map.
169+
/// Moves `self` along the given direction. In Riemannian geometry,
170+
/// this is equivalent to exponential map, which moves `self` on the
171+
/// geodesic surface along the given tangent vector.
168172
mutating func move(along direction: TangentVector)
169173

174+
/// A tangent vector such that `move(along: zeroTangentVector)` will not
175+
/// modify `self`.
176+
///
177+
/// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
178+
/// but types whose tangent vectors depend on instance properties of
179+
/// `self` need to provide a different implementation. For example, an
180+
/// array’s zero tangent vector depends on the array’s `count`.
181+
var zeroTangentVector: TangentVector { get }
182+
170183
@available(*, deprecated,
171184
message: "'CotangentVector' is now equal to 'TangentVector' and will be removed")
172185
typealias CotangentVector = TangentVector
@@ -183,6 +196,9 @@ public extension Differentiable where TangentVector == Self {
183196
mutating func move(along direction: TangentVector) {
184197
self += direction
185198
}
199+
var zeroTangentVector: TangentVector {
200+
return .zero
201+
}
186202
}
187203

188204
/// Returns `x` like an identity function. When used in a context where `x` is
@@ -725,6 +741,7 @@ internal protocol _AnyDerivativeBox {
725741
// `Differentiable` requirements.
726742
var _allDifferentiableVariables: _AnyDerivativeBox { get }
727743
mutating func _move(along direction: _AnyDerivativeBox)
744+
var _zeroTangentVector: _AnyDerivativeBox { get }
728745

729746
/// The underlying base value, type-erased to `Any`.
730747
var _typeErasedBase: Any { get }
@@ -840,6 +857,10 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
840857
}
841858
_base.move(along: directionBase)
842859
}
860+
861+
var _zeroTangentVector: _AnyDerivativeBox {
862+
return _ConcreteDerivativeBox(_base.zeroTangentVector)
863+
}
843864
}
844865

845866
/// A type-erased derivative value.
@@ -940,6 +961,9 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
940961
}
941962
_box._move(along: direction._box)
942963
}
964+
public var zeroTangentVector: TangentVector {
965+
AnyDerivative(_box: _box._zeroTangentVector)
966+
}
943967
}
944968

945969
//===----------------------------------------------------------------------===//

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,6 +1911,8 @@ extension ${Self} : Differentiable {
19111911
public mutating func move(along direction: TangentVector) {
19121912
self += direction
19131913
}
1914+
1915+
public var zeroTangentVector: TangentVector { .zero }
19141916
}
19151917

19161918
//===----------------------------------------------------------------------===//

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ extension SIMD${n} : Differentiable
200200
public func tangentVector(from cotangent: TangentVector) -> TangentVector {
201201
return cotangent
202202
}
203+
public var zeroTangentVector: TangentVector { .zero }
203204
}
204205

205206
extension SIMD${n}

test/Sema/struct_differentiable.swift

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ struct VectorSpaceTypeAlias : AdditiveArithmetic, Differentiable {
320320
var w: Float
321321
var b: Float
322322
typealias TangentVector = Simple
323+
var zeroTangentVector: TangentVector { .zero }
323324
}
324325
// expected-error @+2 {{type 'VectorSpaceCustomStruct' does not conform to protocol 'Differentiable'}}
325326
// expected-note @+1 {{do you want to add protocol stubs?}}
@@ -331,6 +332,7 @@ struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable {
331332
var b: Float.TangentVector
332333
typealias TangentVector = VectorSpaceCustomStruct.TangentVector
333334
}
335+
var zeroTangentVector: TangentVector { .zero }
334336
}
335337

336338
struct StaticNoDerivative : Differentiable {
@@ -371,14 +373,10 @@ extension NoMemberwiseInitializerExtended: Differentiable
371373

372374
// Test derived conformances in disallowed contexts.
373375

374-
// expected-error @+4 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}}
375-
// expected-error @+3 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
376-
// expected-note @+2 {{do you want to add protocol stubs?}}
377-
// expected-note @+1 {{do you want to add protocol stubs?}}
376+
// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}}
377+
// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
378378
extension OtherFileNonconforming : Differentiable {}
379379

380-
// expected-error @+4 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'Differentiable'}}
381-
// expected-error @+3 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
382-
// expected-note @+2 {{do you want to add protocol stubs?}}
383-
// expected-note @+1 {{do you want to add protocol stubs?}}
380+
// expected-error @+2 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'Differentiable'}}
381+
// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}}
384382
extension GenericOtherFileNonconforming : Differentiable {}

0 commit comments

Comments
 (0)