diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 4c566d0ab44ca..4168af53c6b0e 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -159,6 +159,7 @@ IDENTIFIER(AllDifferentiableVariables) IDENTIFIER(TangentVector) IDENTIFIER(allDifferentiableVariables) IDENTIFIER(move) +IDENTIFIER(zeroTangentVector) // Kinds of layout constraints IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout") diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 41436fbb521ec..4fe8d4d904304 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -380,6 +380,105 @@ static ValueDecl *getUnderlyingAllDiffableVariables(DeclContext *DC, return allDiffableVarsDecl; } +// Return the underlying `zeroTangentVector` of a VarDecl `x`. If `x` conforms +// to `Differentiable`, return `zeroTangentVector`. Otherwise, return +// `x`. +static ValueDecl *getUnderlyingZeroTangentVector(DeclContext *DC, + VarDecl *varDecl) { + auto *module = DC->getParentModule(); + auto &C = module->getASTContext(); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); + auto *zeroTangentVectorReq = + getProtocolRequirement(diffableProto, C.Id_zeroTangentVector); + if (!varDecl->hasInterfaceType()) + C.getLazyResolver()->resolveDeclSignature(varDecl); + auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType()); + auto confRef = module->lookupConformance(varType, diffableProto); + if (!confRef) + return varDecl; + // Use protocol requirement as a default for abstract conformances. + // If conformance is concrete, get concrete witness declaration instead. + ValueDecl *zeroTangentVectorDecl = zeroTangentVectorReq; + if (confRef->isConcrete()) + zeroTangentVectorDecl = confRef->getConcrete()->getWitnessDecl( + zeroTangentVectorReq); + return zeroTangentVectorDecl; +} + +// Get the effective memberwise initializer of the given nominal type, or create +// it if it does not exist. +static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer( + TypeChecker &TC, NominalTypeDecl *nominal) { + auto &C = nominal->getASTContext(); + if (auto *initDecl = nominal->getEffectiveMemberwiseInitializer()) + return initDecl; + auto *initDecl = createImplicitConstructor( + TC, nominal, ImplicitConstructorKind::Memberwise); + nominal->addMember(initDecl); + C.addSynthesizedDecl(initDecl); + return initDecl; +} + +// Synthesize getter body for `zeroTangentVector` computed property. +static std::pair +derivedBody_zeroTangentVectorGetter(AbstractFunctionDecl *getterDecl, void *) { + auto *parentDC = getterDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + auto *tangentVectorStruct = + getAssociatedStructDecl(parentDC, C.Id_TangentVector); + auto *tangentVectorInitDecl = + tangentVectorStruct->getEffectiveMemberwiseInitializer(); + assert(tangentVectorInitDecl && + "'TangentVector' memberwise initializer not found"); + auto *selfDecl = getterDecl->getImplicitSelfDecl(); + auto *selfDRE = + new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); + + auto *initDRE = new (C) DeclRefExpr(tangentVectorInitDecl, DeclNameLoc(), + /*Implicit*/ true); + initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + + auto tangentVectorType = parentDC->mapTypeIntoContext( + tangentVectorStruct->getDeclaredInterfaceType()); + Expr *baseExpr = TypeExpr::createImplicit(tangentVectorType, C); + auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, baseExpr); + initExpr->setThrows(false); + initExpr->setImplicit(); + + SmallVector members; + SmallVector memberNames; + + llvm::DenseMap diffPropertyMap; + SmallVector diffProperties; + getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); + for (auto *member : diffProperties) + diffPropertyMap[member->getName()] = member; + + for (auto initParam : *tangentVectorInitDecl->getParameters()) { + auto member = diffPropertyMap[initParam->getName()]; + member->setInterfaceType(member->getValueInterfaceType()); + Expr *memberExpr = new (C) MemberRefExpr(selfDRE, SourceLoc(), member, + DeclNameLoc(), /*Implicit*/ true); + auto *memberZeroTangentVector = + getUnderlyingZeroTangentVector(parentDC, member); + auto *memberZeroTangentExpr = + new (C) MemberRefExpr(memberExpr, SourceLoc(), memberZeroTangentVector, + DeclNameLoc(), /*Implicit*/ true); + members.push_back(memberZeroTangentExpr); + memberNames.push_back(member->getName()); + } + Expr *callExpr = CallExpr::createImplicit(C, initExpr, members, memberNames); + + ASTNode returnStmt = + new (C) ReturnStmt(SourceLoc(), callExpr, /*Implicit*/ true); + auto *braceStmt = + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), + /*Implicit*/ true); + return std::make_pair(braceStmt, false); +} + // Synthesize getter body for `allDifferentiableVariables` computed property. static std::pair derivedBody_allDifferentiableVariablesGetter(AbstractFunctionDecl *getterDecl, @@ -537,6 +636,41 @@ deriveDifferentiable_allDifferentiableVariables(DerivedConformance &derived) { return allDiffableVarsDecl; } +// Synthesize `zeroTangentVector` computed property declaration. +static ValueDecl * +deriveDifferentiable_zeroTangentVector(DerivedConformance &derived) { + auto *parentDC = derived.getConformanceContext(); + auto &TC = derived.TC; + auto &C = TC.Context; + + // Get `TangentVector` struct. + auto *tangentVectorStruct = + getAssociatedStructDecl(parentDC, C.Id_TangentVector); + // Make sure a memberwise initializer exists because the body synthesizer + // needs it. + getOrCreateEffectiveMemberwiseInitializer(TC, tangentVectorStruct); + + auto returnInterfaceTy = tangentVectorStruct->getDeclaredInterfaceType(); + auto returnTy = parentDC->mapTypeIntoContext(returnInterfaceTy); + + VarDecl *zeroTangentVectorDecl; + PatternBindingDecl *pbDecl; + std::tie(zeroTangentVectorDecl, pbDecl) = derived.declareDerivedProperty( + C.Id_zeroTangentVector, returnInterfaceTy, returnTy, + /*isStatic*/ false, /*isFinal*/ true); + + auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty( + zeroTangentVectorDecl, returnTy); + getterDecl->setBodySynthesizer(&derivedBody_zeroTangentVectorGetter); + derived.addMembersToConformanceContext( + {getterDecl, zeroTangentVectorDecl, pbDecl}); + + addExpectedOpaqueAccessorsToStorage(zeroTangentVectorDecl, C); + triggerAccessorSynthesis(TC, zeroTangentVectorDecl); + + return zeroTangentVectorDecl; +} + // Return associated `TangentVector` or `AllDifferentiableVariables` struct for // a nominal type, if it exists. // If not, synthesize the struct. Also return a Boolean value that indicates @@ -1035,6 +1169,8 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { return deriveDifferentiable_move(*this); if (requirement->getBaseName() == TC.Context.Id_allDifferentiableVariables) return deriveDifferentiable_allDifferentiableVariables(*this); + if (requirement->getBaseName() == TC.Context.Id_zeroTangentVector) + return deriveDifferentiable_zeroTangentVector(*this); TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement); return nullptr; } diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index cca10754df9ee..ffe029bb58e2d 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -290,6 +290,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal, if (name.isSimpleName(ctx.Id_allDifferentiableVariables)) return getRequirement(KnownProtocolKind::Differentiable); + // SWIFT_ENABLE_TENSORFLOW + // Differentiable.zeroTangentVector + if (name.isSimpleName(ctx.Id_zeroTangentVector)) + return getRequirement(KnownProtocolKind::Differentiable); + return nullptr; } diff --git a/stdlib/public/core/Array.swift b/stdlib/public/core/Array.swift index 81f66898aa5d2..6a9c2fa1a3b29 100644 --- a/stdlib/public/core/Array.swift +++ b/stdlib/public/core/Array.swift @@ -1982,6 +1982,11 @@ extension Array where Element : Differentiable { base[i].move(along: direction.base[i]) } } + + public var zeroTangentVector: TangentVector { + TangentVector(Array(repeating: .zero, + count: base.count)) + } } } @@ -2093,6 +2098,10 @@ extension Array : Differentiable where Element : Differentiable { view.move(along: direction) self = view.base } + + public var zeroTangentVector: TangentVector { + TangentVector(Array(repeating: .zero, count: count)) + } } extension Array where Element : Differentiable { diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index fa799a261362e..ba36ac317b89c 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -152,6 +152,9 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric { /// A type that mathematically represents a differentiable manifold whose /// tangent spaces are finite-dimensional. public protocol Differentiable { + /// A type representing a differentiable value’s derivatives. + /// Mathematically, this is equivalent to the tangent bundle of the + /// differentiable manifold represented by the differentiable type. associatedtype TangentVector: Differentiable & AdditiveArithmetic where TangentVector.TangentVector == TangentVector, AllDifferentiableVariables.AllDifferentiableVariables == @@ -163,10 +166,20 @@ public protocol Differentiable { /// All differentiable variables of this value. var allDifferentiableVariables: AllDifferentiableVariables { get set } - /// Moves `self` along the value space towards the given tangent vector. In - /// Riemannian geometry (mathematics), this represents an exponential map. + /// Moves `self` along the given direction. In Riemannian geometry, + /// this is equivalent to exponential map, which moves `self` on the + /// geodesic surface along the given tangent vector. mutating func move(along direction: TangentVector) + /// A tangent vector such that `move(along: zeroTangentVector)` will not + /// modify `self`. + /// + /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases, + /// but types whose tangent vectors depend on instance properties of + /// `self` need to provide a different implementation. For example, an + /// array’s zero tangent vector depends on the array’s `count`. + var zeroTangentVector: TangentVector { get } + @available(*, deprecated, message: "'CotangentVector' is now equal to 'TangentVector' and will be removed") typealias CotangentVector = TangentVector @@ -183,6 +196,9 @@ public extension Differentiable where TangentVector == Self { mutating func move(along direction: TangentVector) { self += direction } + var zeroTangentVector: TangentVector { + return .zero + } } /// Returns `x` like an identity function. When used in a context where `x` is @@ -725,6 +741,7 @@ internal protocol _AnyDerivativeBox { // `Differentiable` requirements. var _allDifferentiableVariables: _AnyDerivativeBox { get } mutating func _move(along direction: _AnyDerivativeBox) + var _zeroTangentVector: _AnyDerivativeBox { get } /// The underlying base value, type-erased to `Any`. var _typeErasedBase: Any { get } @@ -840,6 +857,10 @@ internal struct _ConcreteDerivativeBox : _AnyDerivativeBox } _base.move(along: directionBase) } + + var _zeroTangentVector: _AnyDerivativeBox { + return _ConcreteDerivativeBox(_base.zeroTangentVector) + } } /// A type-erased derivative value. @@ -940,6 +961,9 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic { } _box._move(along: direction._box) } + public var zeroTangentVector: TangentVector { + AnyDerivative(_box: _box._zeroTangentVector) + } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index f7eb885059c08..e202d7aa7f4cd 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1911,6 +1911,8 @@ extension ${Self} : Differentiable { public mutating func move(along direction: TangentVector) { self += direction } + + public var zeroTangentVector: TangentVector { .zero } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/core/SIMDVectorTypes.swift.gyb b/stdlib/public/core/SIMDVectorTypes.swift.gyb index 8f7f5fb71108f..fe3a94c592fe3 100644 --- a/stdlib/public/core/SIMDVectorTypes.swift.gyb +++ b/stdlib/public/core/SIMDVectorTypes.swift.gyb @@ -200,6 +200,7 @@ extension SIMD${n} : Differentiable public func tangentVector(from cotangent: TangentVector) -> TangentVector { return cotangent } + public var zeroTangentVector: TangentVector { .zero } } extension SIMD${n} diff --git a/test/Sema/struct_differentiable.swift b/test/Sema/struct_differentiable.swift index fab95016d0c42..2c32937884f9e 100644 --- a/test/Sema/struct_differentiable.swift +++ b/test/Sema/struct_differentiable.swift @@ -320,6 +320,7 @@ struct VectorSpaceTypeAlias : AdditiveArithmetic, Differentiable { var w: Float var b: Float typealias TangentVector = Simple + var zeroTangentVector: TangentVector { .zero } } // expected-error @+2 {{type 'VectorSpaceCustomStruct' does not conform to protocol 'Differentiable'}} // expected-note @+1 {{do you want to add protocol stubs?}} @@ -331,6 +332,7 @@ struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable { var b: Float.TangentVector typealias TangentVector = VectorSpaceCustomStruct.TangentVector } + var zeroTangentVector: TangentVector { .zero } } struct StaticNoDerivative : Differentiable { @@ -371,14 +373,10 @@ extension NoMemberwiseInitializerExtended: Differentiable // Test derived conformances in disallowed contexts. -// expected-error @+4 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}} -// expected-error @+3 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} -// expected-note @+2 {{do you want to add protocol stubs?}} -// expected-note @+1 {{do you want to add protocol stubs?}} +// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}} +// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} extension OtherFileNonconforming : Differentiable {} -// expected-error @+4 {{type 'GenericOtherFileNonconforming' does not conform to protocol 'Differentiable'}} -// expected-error @+3 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} -// expected-note @+2 {{do you want to add protocol stubs?}} -// expected-note @+1 {{do you want to add protocol stubs?}} +// expected-error @+2 {{type 'GenericOtherFileNonconforming' does not conform to protocol 'Differentiable'}} +// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} extension GenericOtherFileNonconforming : Differentiable {}