diff --git a/docs/DifferentiableProgramming.md b/docs/DifferentiableProgramming.md index c40dd3da310ff..0cec90aa47712 100644 --- a/docs/DifferentiableProgramming.md +++ b/docs/DifferentiableProgramming.md @@ -1079,11 +1079,6 @@ public extension Differentiable where Self == TangentVector { mutating func move(along direction: TangentVector) { self += direction } - - @noDerivative - var zeroTangentVectorInitializer: () -> TangentVector { - { .zero } - } } ``` @@ -1144,8 +1139,8 @@ extension Array: Differentiable where Element: Differentiable { @noDerivative public var zeroTangentVectorInitializer: () -> TangentVector { - { [count = self.count] in - TangentVector(Array(repeating: .zero, count: count)) + { [zeroInits = map(\.zeroTangentVectorInitializer)] in + TangentVector(zeroInits.map { $0() }) } } } @@ -1238,8 +1233,15 @@ the same effective access level as their corresponding original properties. A `move(along:)` method is synthesized with a body that calls `move(along:)` for each pair of the original property and its corresponding property in -`TangentVector`. Similarly, `zeroTangentVector` is synthesized to return a -tangent vector that consists of each stored property's `zeroTangentVector`. +`TangentVector`. + +Similarly, when memberwise derivation is possible, +`zeroTangentVectorInitializer` is synthesized to return a closure that captures +and calls each stored property's `zeroTangentVectorInitializer` closure. +When memberwise derivation is not possible (e.g. for custom user-defined +`TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a +`{ TangentVector.zero }` closure. + Here's an example: ```swift @@ -1251,14 +1253,17 @@ struct Foo: @memberwise Differentiable { @noDerivative let helperVariable: T // The compiler synthesizes: + // // struct TangentVector: Differentiable, AdditiveArithmetic { // var x: T.TangentVector // var y: U.TangentVector // } + // // mutating func move(along direction: TangentVector) { // x.move(along: direction.x) // y.move(along: direction.y) // } + // // @noDerivative // var zeroTangentVectorInitializer: () -> TangentVector { // { [xTanInit = x.zeroTangentVectorInitializer, @@ -1278,8 +1283,8 @@ properties are declared to conform to `AdditiveArithmetic`. There are no `@noDerivative` stored properties. In these cases, the compiler will make `TangentVector` be a type alias for Self. -Method `move(along:)` and property `zeroTangentVector` will not be synthesized -because a default implementation already exists. +Method `move(along:)` will not be synthesized because a default implementation +already exists. ```swift struct Point: @memberwise Differentiable, @memberwise AdditiveArithmetic { @@ -1287,7 +1292,16 @@ struct Point: @memberwise Differentiable, @memberwise AdditiveArithmeti var x, y: T // The compiler synthesizes: + // // typealias TangentVector = Self + // + // @noDerivative + // var zeroTangentVectorInitializer: () -> TangentVector { + // { [xTanInit = x.zeroTangentVectorInitializer, + // yTanInit = y.zeroTangentVectorInitializer] in + // TangentVector(x: xTanInit(), y: yTanInit()) + // } + // } } ``` diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 7dd491d870d22..24c9f9c3d979c 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -223,6 +223,7 @@ IDENTIFIER(move) IDENTIFIER(pullback) IDENTIFIER(TangentVector) IDENTIFIER(zero) +IDENTIFIER(zeroTangentVectorInitializer) #undef IDENTIFIER #undef IDENTIFIER_ diff --git a/lib/Sema/CodeSynthesis.cpp b/lib/Sema/CodeSynthesis.cpp index 225de9b9751f8..325132ea9fbaa 100644 --- a/lib/Sema/CodeSynthesis.cpp +++ b/lib/Sema/CodeSynthesis.cpp @@ -1414,3 +1414,22 @@ void swift::addFixedLayoutAttr(NominalTypeDecl *nominal) { // Add `@_fixed_layout` to the nominal. nominal->getAttrs().add(new (C) FixedLayoutAttr(/*Implicit*/ true)); } + +Expr *DiscriminatorFinder::walkToExprPost(Expr *E) { + auto *ACE = dyn_cast(E); + if (!ACE) + return E; + + unsigned Discriminator = ACE->getDiscriminator(); + assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator && + "Existing closures should have valid discriminators"); + if (Discriminator >= NextDiscriminator) + NextDiscriminator = Discriminator + 1; + return E; +} + +unsigned DiscriminatorFinder::getNextDiscriminator() { + if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator) + llvm::report_fatal_error("Out of valid closure discriminators"); + return NextDiscriminator++; +} diff --git a/lib/Sema/CodeSynthesis.h b/lib/Sema/CodeSynthesis.h index a73bb87903e9c..95db8aab697f0 100644 --- a/lib/Sema/CodeSynthesis.h +++ b/lib/Sema/CodeSynthesis.h @@ -18,6 +18,7 @@ #ifndef SWIFT_TYPECHECKING_CODESYNTHESIS_H #define SWIFT_TYPECHECKING_CODESYNTHESIS_H +#include "swift/AST/ASTWalker.h" #include "swift/AST/ForeignErrorConvention.h" #include "swift/Basic/ExternalUnion.h" #include "swift/Basic/LLVM.h" @@ -75,6 +76,20 @@ bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal); /// Add `@_fixed_layout` attribute to the nominal type, if possible. void addFixedLayoutAttr(NominalTypeDecl *nominal); +/// Find available closure discriminators. +/// +/// The parser typically takes care of assigning unique discriminators to +/// closures, but the parser is unavailable during semantic analysis. +class DiscriminatorFinder : public ASTWalker { + unsigned NextDiscriminator = 0; + +public: + Expr *walkToExprPost(Expr *E) override; + + // Get the next available closure discriminator. + unsigned getNextDiscriminator(); +}; + } // end namespace swift #endif diff --git a/lib/Sema/DebuggerTestingTransform.cpp b/lib/Sema/DebuggerTestingTransform.cpp index 2e5719cfcef87..6f9e1c2ed073e 100644 --- a/lib/Sema/DebuggerTestingTransform.cpp +++ b/lib/Sema/DebuggerTestingTransform.cpp @@ -15,6 +15,7 @@ /// //===----------------------------------------------------------------------===// +#include "CodeSynthesis.h" #include "swift/AST/ASTContext.h" #include "swift/AST/ASTNode.h" #include "swift/AST/ASTWalker.h" @@ -33,35 +34,6 @@ using namespace swift; namespace { -/// Find available closure discriminators. -/// -/// The parser typically takes care of assigning unique discriminators to -/// closures, but the parser is unavailable to this transform. -class DiscriminatorFinder : public ASTWalker { - unsigned NextDiscriminator = 0; - -public: - Expr *walkToExprPost(Expr *E) override { - auto *ACE = dyn_cast(E); - if (!ACE) - return E; - - unsigned Discriminator = ACE->getDiscriminator(); - assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator && - "Existing closures should have valid discriminators"); - if (Discriminator >= NextDiscriminator) - NextDiscriminator = Discriminator + 1; - return E; - } - - // Get the next available closure discriminator. - unsigned getNextDiscriminator() { - if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator) - llvm::report_fatal_error("Out of valid closure discriminators"); - return NextDiscriminator++; - } -}; - /// Instrument decls with sanity-checks which the debugger can evaluate. class DebuggerTestingTransform : public ASTWalker { ASTContext &Ctx; diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 6254ea6cb95f6..78312f6779e86 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -17,15 +17,14 @@ #include "CodeSynthesis.h" #include "TypeChecker.h" -#include "DerivedConformances.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/Decl.h" #include "swift/AST/Expr.h" #include "swift/AST/Module.h" #include "swift/AST/ParameterList.h" #include "swift/AST/Pattern.h" -#include "swift/AST/ProtocolConformance.h" #include "swift/AST/PropertyWrappers.h" +#include "swift/AST/ProtocolConformance.h" #include "swift/AST/Stmt.h" #include "swift/AST/Types.h" #include "DerivedConformances.h" @@ -36,7 +35,8 @@ using namespace swift; /// differentiation, except the ones tagged `@noDerivative`. static void getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC, - SmallVectorImpl &result) { + SmallVectorImpl &result, + bool includeLetProperties = false) { auto &C = nominal->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); for (auto *vd : nominal->getStoredProperties()) { @@ -52,9 +52,10 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC, // Skip stored properties with `@noDerivative` attribute. if (vd->getAttrs().hasAttribute()) continue; - // Skip `let` stored properties. `mutating func move(along:)` cannot be - // synthesized to update these properties. - if (vd->isLet()) + // Skip `let` stored properties if requested. + // `mutating func move(along:)` cannot be synthesized to update `let` + // properties. + if (!includeLetProperties && vd->isLet()) continue; if (vd->getInterfaceType()->hasError()) continue; @@ -77,107 +78,150 @@ static StructDecl *convertToStructDecl(ValueDecl *v) { typeDecl->getDeclaredInterfaceType()->getAnyNominal()); } -/// Get the `Differentiable` protocol `TangentVector` associated type for the -/// given `VarDecl`. -/// TODO: Generalize and move function to shared place for use with other -/// derived conformances. -static Type getTangentVectorType(VarDecl *decl, DeclContext *DC) { - auto &C = decl->getASTContext(); +/// Get the `Differentiable` protocol `TangentVector` associated type witness +/// for the given interface type and declaration context. +static Type getTangentVectorInterfaceType(Type contextualType, + DeclContext *DC) { + auto &C = contextualType->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto varType = DC->mapTypeIntoContext(decl->getValueInterfaceType()); - auto conf = TypeChecker::conformsToProtocol(varType, diffableProto, DC); + assert(diffableProto && "`Differentiable` protocol not found"); + auto conf = + TypeChecker::conformsToProtocol(contextualType, diffableProto, DC); + assert(conf && "Contextual type must conform to `Differentiable`"); if (!conf) return nullptr; - Type tangentType = conf.getTypeWitnessByName(varType, C.Id_TangentVector); - return tangentType; + auto tanType = conf.getTypeWitnessByName(contextualType, C.Id_TangentVector); + return tanType->hasArchetype() ? tanType->mapTypeOutOfContext() : tanType; } -// Get the `Differentiable` protocol associated `TangentVector` struct for the -// given nominal `DeclContext`. Asserts that the `TangentVector` struct type -// exists. -static StructDecl *getTangentVectorStructDecl(DeclContext *DC) { - assert(DC->getSelfNominalTypeDecl() && "Must be a nominal `DeclContext`"); - auto &C = DC->getASTContext(); +/// Returns true iff the given nominal type declaration can derive +/// `TangentVector` as `Self` in the given conformance context. +static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal, + DeclContext *DC) { + // `Self` must not be a class declaraiton. + if (nominal->getSelfClassDecl()) + return false; + + auto nominalTypeInContext = + DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); + auto &C = nominal->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - assert(diffableProto && "`Differentiable` protocol not found"); - auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(), - diffableProto, DC); - assert(conf && "Nominal must conform to `Differentiable`"); - auto assocType = - conf.getTypeWitnessByName(DC->getSelfTypeInContext(), C.Id_TangentVector); - assert(assocType && "`Differentiable.TangentVector` type not found"); - auto *structDecl = dyn_cast(assocType->getAnyNominal()); - assert(structDecl && "Associated type must be a struct type"); - return structDecl; + auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + // `Self` must conform to `AdditiveArithmetic`. + if (!TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, DC)) + return false; + for (auto *field : nominal->getStoredProperties()) { + // `Self` must not have any `@noDerivative` stored properties. + if (field->getAttrs().hasAttribute()) + return false; + // `Self` must have all stored properties satisfy `Self == TangentVector`. + auto fieldType = DC->mapTypeIntoContext(field->getValueInterfaceType()); + auto conf = TypeChecker::conformsToProtocol(fieldType, diffableProto, DC); + if (!conf) + return false; + auto tangentType = conf.getTypeWitnessByName(fieldType, C.Id_TangentVector); + if (!fieldType->isEqual(tangentType)) + return false; + } + return true; +} + +// Synthesizable `Differentiable` protocol requirements. +enum class DifferentiableRequirement { + // associatedtype TangentVector + TangentVector, + // mutating func move(along direction: TangentVector) + MoveAlong, + // var zeroTangentVectorInitializer: () -> TangentVector + ZeroTangentVectorInitializer, +}; + +static DifferentiableRequirement +getDifferentiableRequirementKind(ValueDecl *requirement) { + auto &C = requirement->getASTContext(); + if (requirement->getBaseName() == C.Id_TangentVector) + return DifferentiableRequirement::TangentVector; + if (requirement->getBaseName() == C.Id_move) + return DifferentiableRequirement::MoveAlong; + if (requirement->getBaseName() == C.Id_zeroTangentVectorInitializer) + return DifferentiableRequirement::ZeroTangentVectorInitializer; + llvm_unreachable("Invalid `Differentiable` protocol requirement"); } bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, - DeclContext *DC) { + DeclContext *DC, + ValueDecl *requirement) { // Experimental differentiable programming must be enabled. if (auto *SF = DC->getParentSourceFile()) if (!isDifferentiableProgrammingEnabled(*SF)) return false; - // Nominal type must be a struct or class. (No stored properties is okay.) - if (!isa(nominal) && !isa(nominal)) - return false; - auto &C = nominal->getASTContext(); - auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); - // Nominal type must not customize `TangentVector` to anything other than - // `Self`. Otherwise, synthesis is semantically unsupported. - auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); - auto nominalTypeInContext = - DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); + auto reqKind = getDifferentiableRequirementKind(requirement); - auto isValidAssocTypeCandidate = [&](ValueDecl *v) -> StructDecl * { + auto &C = nominal->getASTContext(); + // If there are any `TangentVector` type witness candidates, check whether + // there exists only a single valid candidate. + bool canUseTangentVectorAsSelf = canDeriveTangentVectorAsSelf(nominal, DC); + auto isValidTangentVectorCandidate = [&](ValueDecl *v) -> bool { + // If the requirement is `var zeroTangentVectorInitializer` and + // the candidate is a type declaration that conforms to + // `AdditiveArithmetic`, return true. + if (reqKind == DifferentiableRequirement::ZeroTangentVectorInitializer) { + if (auto *tangentVectorTypeDecl = dyn_cast(v)) { + auto tangentType = DC->mapTypeIntoContext( + tangentVectorTypeDecl->getDeclaredInterfaceType()); + auto *addArithProto = + C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + if (TypeChecker::conformsToProtocol(tangentType, addArithProto, DC)) + return true; + } + } // Valid candidate must be a struct or a typealias to a struct. auto *structDecl = convertToStructDecl(v); if (!structDecl) - return nullptr; + return false; // Valid candidate must either: // 1. Be implicit (previously synthesized). if (structDecl->isImplicit()) - return structDecl; - // 2. Equal nominal's implicit parent. - // This can occur during mutually recursive constraints. Example: - // `X == X.TangentVector`. - if (nominal->isImplicit() && structDecl == nominal->getDeclContext() && - TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(), - diffableProto, DC)) - return structDecl; - // 3. Equal nominal and conform to `AdditiveArithmetic`. - if (structDecl == nominal) { - // Check conformance to `AdditiveArithmetic`. - if (TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, - DC)) - return structDecl; - } + return true; + // 2. Equal nominal, when the nominal can derive `TangentVector` as `Self`. + // Nominal type must not customize `TangentVector` to anything other than + // `Self`. Otherwise, synthesis is semantically unsupported. + if (structDecl == nominal && canUseTangentVectorAsSelf) + return true; // Otherwise, candidate is invalid. - return nullptr; + return false; }; - - auto invalidTangentDecls = llvm::partition( - tangentDecls, [&](ValueDecl *v) { return isValidAssocTypeCandidate(v); }); - - auto validTangentDeclCount = - std::distance(tangentDecls.begin(), invalidTangentDecls); - auto invalidTangentDeclCount = - std::distance(invalidTangentDecls, tangentDecls.end()); - - // There cannot be any invalid `TangentVector` types. + auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); // There can be at most one valid `TangentVector` type. - if (invalidTangentDeclCount != 0 || validTangentDeclCount > 1) + if (tangentDecls.size() > 1) return false; + // There cannot be any invalid `TangentVector` types. + if (tangentDecls.size() == 1) { + auto *tangentDecl = tangentDecls.front(); + if (!isValidTangentVectorCandidate(tangentDecl)) + return false; + } + bool hasValidTangentDecl = !tangentDecls.empty(); + + // Check requirement-specific derivation conditions. + if (reqKind == DifferentiableRequirement::ZeroTangentVectorInitializer) { + // If there is a valid `TangentVector` type witness (conforming to + // `AdditiveArithmetic`), return true. + if (hasValidTangentDecl) + return true; + // Otherwise, fallback on `TangentVector` struct derivation conditions. + } - // All stored properties not marked with `@noDerivative`: - // - Must conform to `Differentiable`. - // - Must not have any `let` stored properties with an initial value. - // - This restriction may be lifted later with support for "true" memberwise - // initializers that initialize all stored properties, including initial - // value information. + // Check `TangentVector` struct derivation conditions. + // Nominal type must be a struct or class. (No stored properties is okay.) + if (!isa(nominal) && !isa(nominal)) + return false; + // If there are no `TangentVector` candidates, derivation is possible if all + // differentiation stored properties conform to `Differentiable`. SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, DC, diffProperties); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); return llvm::all_of(diffProperties, [&](VarDecl *v) { if (v->getInterfaceType()->hasError()) return false; @@ -186,18 +230,16 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, }); } -/// Synthesize body for a `Differentiable` method requirement. +/// Synthesize body for `move(along:)`. static std::pair -deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, - Identifier methodName, - Identifier methodParamLabel) { +deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { + auto &C = funcDecl->getASTContext(); auto *parentDC = funcDecl->getParent(); auto *nominal = parentDC->getSelfNominalTypeDecl(); - auto &C = nominal->getASTContext(); - // Get method protocol requirement. + // Get `Differentiable.move(along:)` protocol requirement. auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto *methodReq = getProtocolRequirement(diffProto, methodName); + auto *requirement = getProtocolRequirement(diffProto, C.Id_move); // Get references to `self` and parameter declarations. auto *selfDecl = funcDecl->getImplicitSelfDecl(); @@ -210,9 +252,8 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); - // Create call expression applying a member method to a parameter member. - // Format: `.method(.)`. - // Example: `x.move(along: direction.x)`. + // Create call expression applying a member `move(along:)` method to a + // parameter member: `self..move(along: direction.)`. auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * { auto *module = nominal->getModuleContext(); auto memberType = @@ -220,27 +261,24 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto confRef = module->lookupConformance(memberType, diffProto); assert(confRef && "Member does not conform to `Differentiable`"); - // Get member type's method, e.g. `Member.move(along:)`. - // Use protocol requirement declaration for the method by default: this - // will be dynamically dispatched. - ValueDecl *memberMethodDecl = methodReq; - // If conformance reference is concrete, then use concrete witness - // declaration for the operator. + // Get member type's requirement witness: `.move(along:)`. + ValueDecl *memberWitnessDecl = requirement; if (confRef.isConcrete()) - memberMethodDecl = confRef.getConcrete()->getWitnessDecl(methodReq); - assert(memberMethodDecl && "Member method declaration must exist"); - auto *memberMethodDRE = - new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true); + if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement)) + memberWitnessDecl = witness; + assert(memberWitnessDecl && "Member witness declaration must exist"); + auto *memberMethodDRE = new (C) + DeclRefExpr(memberWitnessDecl, DeclNameLoc(), /*Implicit*/ true); memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply); - // Create reference to member method: `x.move(along:)`. + // Create reference to member method: `self..move(along:)`. Expr *memberExpr = new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); auto *memberMethodExpr = new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr); - // Create reference to parameter member: `direction.x`. + // Create reference to parameter member: `direction.`. VarDecl *paramMember = nullptr; auto *paramNominal = paramDecl->getType()->getAnyNominal(); assert(paramNominal && "Parameter should have a nominal type"); @@ -255,14 +293,14 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto *paramMemberExpr = new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(), /*Implicit*/ true); - // Create expression: `x.move(along: direction.x)`. + // Create expression: `self..move(along: direction.)`. return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr}, - {methodParamLabel}); + {C.Id_along}); }; - // Create array of member method call expressions. - llvm::SmallVector memberMethodCallExprs; - llvm::SmallVector memberNames; + // Collect member `move(along:)` method call expressions. + SmallVector memberMethodCallExprs; + SmallVector memberNames; for (auto *member : diffProperties) { memberMethodCallExprs.push_back(createMemberMethodCallExpr(member)); memberNames.push_back(member->getName()); @@ -272,11 +310,229 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, return std::pair(braceStmt, false); } -/// Synthesize body for `move(along:)`. +/// Synthesize body for `var zeroTangentVectorInitializer` getter. static std::pair -deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { +deriveBodyDifferentiable_zeroTangentVectorInitializer( + AbstractFunctionDecl *funcDecl, void *) { auto &C = funcDecl->getASTContext(); - return deriveBodyDifferentiable_method(funcDecl, C.Id_move, C.Id_along); + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + + // Get method protocol requirement. + auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); + auto *requirement = + getProtocolRequirement(diffProto, C.Id_zeroTangentVectorInitializer); + + auto nominalType = + parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); + auto conf = TypeChecker::conformsToProtocol(nominalType, diffProto, parentDC); + auto tangentType = conf.getTypeWitnessByName(nominalType, C.Id_TangentVector); + auto *tangentTypeExpr = TypeExpr::createImplicit(tangentType, C); + + // Get differentiation properties. + SmallVector diffProperties; + getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties, + /*includeLetProperties*/ true); + + // Check whether memberwise derivation of `zeroTangentVectorInitializer` is + // possible. + bool canPerformMemberwiseDerivation = [&]() -> bool { + // Memberwise derivation is possible only for struct `TangentVector` types. + auto *tangentTypeDecl = tangentType->getAnyNominal(); + if (!tangentTypeDecl || !tangentTypeDecl->getSelfStructDecl()) + return false; + // Get effective memberwise initializer. + auto *memberwiseInitDecl = + tangentTypeDecl->getEffectiveMemberwiseInitializer(); + // Return false if number of memberwise initializer parameters does not + // equal number of differentiation properties. + if (memberwiseInitDecl->getParameters()->size() != diffProperties.size()) + return false; + // Iterate over all initializer parameters and differentiation properties. + for (auto pair : llvm::zip(memberwiseInitDecl->getParameters()->getArray(), + diffProperties)) { + auto *initParam = std::get<0>(pair); + auto *diffProp = std::get<1>(pair); + // Return false if parameter label does not equal property name. + if (initParam->getParameterName() != diffProp->getName()) + return false; + auto diffPropContextualType = + parentDC->mapTypeIntoContext(diffProp->getValueInterfaceType()); + auto diffPropTangentType = + getTangentVectorInterfaceType(diffPropContextualType, parentDC); + // Return false if parameter type does not equal property tangent type. + if (!initParam->getValueInterfaceType()->isEqual(diffPropTangentType)) + return false; + } + return true; + }(); + + // If memberwise derivation is not possible, synthesize + // `{ TangentVector.zero }` as a fallback. + if (!canPerformMemberwiseDerivation) { + auto *module = nominal->getModuleContext(); + auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + auto confRef = module->lookupConformance(tangentType, addArithProto); + assert(confRef && + "`TangentVector` does not conform to `AdditiveArithmetic`"); + auto *zeroDecl = getProtocolRequirement(addArithProto, C.Id_zero); + // If conformance reference is concrete, then use concrete witness + // declaration for the operator. + if (confRef.isConcrete()) + if (auto *witnessDecl = confRef.getConcrete()->getWitnessDecl(zeroDecl)) + zeroDecl = witnessDecl; + assert(zeroDecl && "Member method declaration must exist"); + auto *zeroExpr = + new (C) MemberRefExpr(tangentTypeExpr, SourceLoc(), zeroDecl, + DeclNameLoc(), /*Implicit*/ true); + + // Create closure expression. + DiscriminatorFinder DF; + for (Decl *D : parentDC->getParentSourceFile()->getTopLevelDecls()) + D->walk(DF); + auto discriminator = DF.getNextDiscriminator(); + auto resultTy = funcDecl->getMethodInterfaceType() + ->castTo() + ->getResult(); + + auto *closureParams = ParameterList::createEmpty(C); + auto *closure = new (C) ClosureExpr( + SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(), + SourceLoc(), SourceLoc(), TypeExpr::createImplicit(resultTy, C), + discriminator, funcDecl); + closure->setImplicit(); + auto *closureReturn = new (C) ReturnStmt(SourceLoc(), zeroExpr, true); + auto *closureBody = + BraceStmt::create(C, SourceLoc(), {closureReturn}, SourceLoc(), true); + closure->setBody(closureBody, /*isSingleExpression=*/true); + + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), closure, true); + auto *braceStmt = + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true); + return std::pair(braceStmt, false); + } + + // Otherwise, perform memberwise derivation. + // Get effective memberwise initializer: `Nominal.init(...)`. + auto *tangentTypeDecl = tangentType->getAnyNominal(); + auto *memberwiseInitDecl = + tangentTypeDecl->getEffectiveMemberwiseInitializer(); + assert(memberwiseInitDecl && "Memberwise initializer must exist"); + auto *initDRE = + new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); + initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr); + + // Get references to `self` and parameter declarations. + auto *selfDecl = funcDecl->getImplicitSelfDecl(); + + // Create `self..zeroTangentVectorInitializer` capture list entry. + auto createMemberZeroTanInitCaptureListEntry = + [&](VarDecl *member) -> CaptureListEntry { + // Create `_zeroTangentVectorInitializer` capture var declaration. + auto memberCaptureName = C.getIdentifier(std::string(member->getNameStr()) + + "_zeroTangentVectorInitializer"); + auto *memberZeroTanInitCaptureDecl = new (C) VarDecl( + /*isStatic*/ false, VarDecl::Introducer::Let, /*isCaptureList*/ true, + SourceLoc(), memberCaptureName, funcDecl); + memberZeroTanInitCaptureDecl->setImplicit(); + auto *memberZeroTanInitPattern = + NamedPattern::createImplicit(C, memberZeroTanInitCaptureDecl); + + auto *module = nominal->getModuleContext(); + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto confRef = module->lookupConformance(memberType, diffProto); + assert(confRef && "Member does not conform to `Differentiable`"); + + // Get member type's `zeroTangentVectorInitializer` requirement witness. + ValueDecl *memberWitnessDecl = requirement; + if (confRef.isConcrete()) + if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement)) + memberWitnessDecl = witness; + assert(memberWitnessDecl && "Member witness declaration must exist"); + + // .zeroTangentVectorInitializer + auto *selfDRE = + new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); + auto *memberExpr = + new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), + /*Implicit*/ true); + auto *memberZeroTangentVectorInitExpr = + new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl, + DeclNameLoc(), /*Implicit*/ true); + auto *memberZeroTanInitPBD = PatternBindingDecl::createImplicit( + C, StaticSpellingKind::None, memberZeroTanInitPattern, + memberZeroTangentVectorInitExpr, funcDecl); + CaptureListEntry captureEntry(memberZeroTanInitCaptureDecl, + memberZeroTanInitPBD); + return captureEntry; + }; + + // Create `_zeroTangentVectorInitializer()` call expression. + auto createMemberZeroTanInitCallExpr = + [&](CaptureListEntry memberZeroTanInitEntry) -> Expr * { + // _zeroTangentVectorInitializer + auto *memberZeroTanInitDRE = new (C) DeclRefExpr( + memberZeroTanInitEntry.Var, DeclNameLoc(), /*Implicit*/ true); + // _zeroTangentVectorInitializer() + auto *memberZeroTangentVector = + CallExpr::createImplicit(C, memberZeroTanInitDRE, {}, {}); + return memberZeroTangentVector; + }; + + // Collect member zero tangent vector expressions. + SmallVector memberNames; + SmallVector memberZeroTanExprs; + SmallVector memberZeroTanInitCaptures; + for (auto *member : diffProperties) { + memberNames.push_back(member->getName()); + auto memberZeroTanInitCapture = + createMemberZeroTanInitCaptureListEntry(member); + memberZeroTanInitCaptures.push_back(memberZeroTanInitCapture); + memberZeroTanExprs.push_back( + createMemberZeroTanInitCallExpr(memberZeroTanInitCapture)); + } + + // Create `zeroTangentVectorInitializer` closure body: + // `TangentVector(x: x_zeroTangentVectorInitializer(), ...)`. + auto *callExpr = + CallExpr::createImplicit(C, initExpr, memberZeroTanExprs, memberNames); + + // Create closure expression: + // `{ TangentVector(x: x_zeroTangentVectorInitializer(), ...) }`. + DiscriminatorFinder DF; + for (Decl *D : parentDC->getParentSourceFile()->getTopLevelDecls()) + D->walk(DF); + auto discriminator = DF.getNextDiscriminator(); + auto resultTy = funcDecl->getMethodInterfaceType() + ->castTo() + ->getResult(); + auto *closureParams = ParameterList::createEmpty(C); + auto *closure = new (C) ClosureExpr( + SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(), + SourceLoc(), SourceLoc(), TypeExpr::createImplicit(resultTy, C), + discriminator, funcDecl); + closure->setImplicit(); + auto *closureReturn = new (C) ReturnStmt(SourceLoc(), callExpr, true); + auto *closureBody = + BraceStmt::create(C, SourceLoc(), {closureReturn}, SourceLoc(), true); + closure->setBody(closureBody, /*isSingleExpression=*/true); + + // Create capture list expression: + // ``` + // { [x_zeroTangentVectorInitializer = x.zeroTangentVectorInitializer, ...] in + // TangentVector(x: x_zeroTangentVectorInitializer(), ...) + // } + // ``` + auto *captureList = + CaptureListExpr::create(C, memberZeroTanInitCaptures, closure); + captureList->setImplicit(); + + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), captureList, true); + auto *braceStmt = + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true); + return std::pair(braceStmt, false); } /// Synthesize function declaration for a `Differentiable` method requirement. @@ -316,15 +572,41 @@ static ValueDecl *deriveDifferentiable_method( static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) { auto &C = derived.Context; auto *parentDC = derived.getConformanceContext(); - - auto *tangentDecl = getTangentVectorStructDecl(parentDC); - auto tangentType = tangentDecl->getDeclaredInterfaceType(); - + auto tangentType = + getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC); return deriveDifferentiable_method( derived, C.Id_move, C.Id_along, C.Id_direction, tangentType, C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr}); } +/// Synthesize the `zeroTangentVectorInitializer` computed property declaration. +static ValueDecl * +deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) { + auto &C = derived.Context; + auto *parentDC = derived.getConformanceContext(); + + auto tangentType = + getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC); + auto returnType = FunctionType::get({}, tangentType); + + VarDecl *propDecl; + PatternBindingDecl *pbDecl; + std::tie(propDecl, pbDecl) = derived.declareDerivedProperty( + C.Id_zeroTangentVectorInitializer, returnType, returnType, + /*isStatic*/ false, /*isFinal*/ true); + + // Define the getter. + auto *getterDecl = + derived.addGetterToReadOnlyDerivedProperty(propDecl, returnType); + // Add an implicit `@noDerivative` attribute. + // `zeroTangentVectorInitializer` getter calls should never be differentiated. + getterDecl->getAttrs().add(new (C) NoDerivativeAttr(/*Implicit*/ true)); + getterDecl->setBodySynthesizer( + &deriveBodyDifferentiable_zeroTangentVectorInitializer); + derived.addMembersToConformanceContext({propDecl, pbDecl}); + return propDecl; +} + /// Return associated `TangentVector` struct for a nominal type, if it exists. /// If not, synthesize the struct. static StructDecl * @@ -368,24 +650,22 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { for (auto *member : diffProperties) { // Add this member's corresponding `TangentVector` type to the parent's // `TangentVector` struct. + // Note: `newMember` is not marked as implicit here, because that + // incorrectly affects memberwise initializer synthesis. auto *newMember = new (C) VarDecl( member->isStatic(), member->getIntroducer(), member->isCaptureList(), /*NameLoc*/ SourceLoc(), member->getName(), structDecl); - // NOTE: `newMember` is not marked as implicit here, because that affects - // memberwise initializer synthesis. - - auto memberAssocType = getTangentVectorType(member, parentDC); - auto memberAssocInterfaceType = memberAssocType->hasArchetype() - ? memberAssocType->mapTypeOutOfContext() - : memberAssocType; - auto memberAssocContextualType = - parentDC->mapTypeIntoContext(memberAssocInterfaceType); - newMember->setInterfaceType(memberAssocInterfaceType); + + auto memberContextualType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto memberTanType = + getTangentVectorInterfaceType(memberContextualType, parentDC); + newMember->setInterfaceType(memberTanType); Pattern *memberPattern = NamedPattern::createImplicit(C, newMember); - memberPattern->setType(memberAssocContextualType); - memberPattern = TypedPattern::createImplicit(C, memberPattern, - memberAssocContextualType); - memberPattern->setType(memberAssocContextualType); + memberPattern->setType(memberTanType); + memberPattern = + TypedPattern::createImplicit(C, memberPattern, memberTanType); + memberPattern->setType(memberTanType); auto *memberBinding = PatternBindingDecl::createImplicit( C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr, structDecl); @@ -582,13 +862,6 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) { addAssociatedTypeAliasDecl(C.Id_TangentVector, tangentStruct, tangentStruct, C); - // Sanity checks for synthesized struct. - assert(DerivedConformance::canDeriveAdditiveArithmetic(tangentStruct, - parentDC) && - "Should be able to derive `AdditiveArithmetic`"); - assert(DerivedConformance::canDeriveDifferentiable(tangentStruct, parentDC) && - "Should be able to derive `Differentiable`"); - // Return the `TangentVector` struct type. return parentDC->mapTypeIntoContext( tangentStruct->getDeclaredInterfaceType()); @@ -599,82 +872,75 @@ static Type deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) { auto *parentDC = derived.getConformanceContext(); auto *nominal = derived.Nominal; - auto &C = nominal->getASTContext(); - - // Get all stored properties for differentation. - SmallVector diffProperties; - getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); - // If any member has an invalid `TangentVector` type, return nullptr. - for (auto *member : diffProperties) - if (!getTangentVectorType(member, parentDC)) - return nullptr; - - // Prevent re-synthesis during repeated calls. - // FIXME: Investigate why this is necessary to prevent duplicate synthesis. - auto lookup = nominal->lookupDirect(C.Id_TangentVector); - if (lookup.size() == 1) - if (auto *structDecl = convertToStructDecl(lookup.front())) - if (structDecl->isImplicit()) - return structDecl->getDeclaredInterfaceType(); - - // Check whether at least one `@noDerivative` stored property exists. - unsigned numStoredProperties = - std::distance(nominal->getStoredProperties().begin(), - nominal->getStoredProperties().end()); - bool hasNoDerivativeStoredProp = diffProperties.size() != numStoredProperties; - - // Check conditions for returning `Self`. - // - `Self` is not a class type. - // - No `@noDerivative` stored properties exist. - // - All stored properties must have `TangentVector` type equal to `Self`. - // - Parent type must also conform to `AdditiveArithmetic`. - bool allMembersAssocTypeEqualsSelf = - llvm::all_of(diffProperties, [&](VarDecl *member) { - auto memberAssocType = getTangentVectorType(member, parentDC); - return member->getType()->isEqual(memberAssocType); - }); - - auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); - auto nominalConformsToAddArith = TypeChecker::conformsToProtocol( - parentDC->getSelfTypeInContext(), addArithProto, parentDC); - - // Return `Self` if conditions are met. - if (!hasNoDerivativeStoredProp && !nominal->getSelfClassDecl() && - allMembersAssocTypeEqualsSelf && nominalConformsToAddArith) { - auto selfType = parentDC->getSelfTypeInContext(); - auto *aliasDecl = - new (C) TypeAliasDecl(SourceLoc(), SourceLoc(), C.Id_TangentVector, - SourceLoc(), {}, parentDC); - aliasDecl->setUnderlyingType(selfType); - aliasDecl->setImplicit(); - aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); - derived.addMembersToConformanceContext({aliasDecl}); - return selfType; - } + // If nominal type can derive `TangentVector` as the contextual `Self` type, + // return it. + if (canDeriveTangentVectorAsSelf(nominal, parentDC)) + return parentDC->getSelfTypeInContext(); // Otherwise, get or synthesize `TangentVector` struct type. return getOrSynthesizeTangentVectorStructType(derived); } ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { + // Diagnose unknown requirements. + if (requirement->getBaseName() != Context.Id_move && + requirement->getBaseName() != Context.Id_zeroTangentVectorInitializer) { + Context.Diags.diagnose(requirement->getLoc(), + diag::broken_differentiable_requirement); + return nullptr; + } // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) return nullptr; - if (requirement->getBaseName() == Context.Id_move) - return deriveDifferentiable_move(*this); - Context.Diags.diagnose(requirement->getLoc(), - diag::broken_differentiable_requirement); + + // Start an error diagnostic before attempting derivation. + // If derivation succeeds, cancel the diagnostic. + DiagnosticTransaction diagnosticTransaction(Context.Diags); + ConformanceDecl->diagnose(diag::type_does_not_conform, + Nominal->getDeclaredType(), getProtocolType()); + requirement->diagnose(diag::no_witnesses, + getProtocolRequirementKind(requirement), + requirement->getName(), getProtocolType(), + /*AddFixIt=*/false); + + // If derivation is possible, cancel the diagnostic and perform derivation. + if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) { + diagnosticTransaction.abort(); + if (requirement->getBaseName() == Context.Id_move) + return deriveDifferentiable_move(*this); + if (requirement->getBaseName() == Context.Id_zeroTangentVectorInitializer) + return deriveDifferentiable_zeroTangentVectorInitializer(*this); + } + + // Otheriwse, return nullptr. return nullptr; } Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { + // Diagnose unknown requirements. + if (requirement->getBaseName() != Context.Id_TangentVector) { + Context.Diags.diagnose(requirement->getLoc(), + diag::broken_differentiable_requirement); + return nullptr; + } // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) return nullptr; - if (requirement->getBaseName() == Context.Id_TangentVector) + + // Start an error diagnostic before attempting derivation. + // If derivation succeeds, cancel the diagnostic. + DiagnosticTransaction diagnosticTransaction(Context.Diags); + ConformanceDecl->diagnose(diag::type_does_not_conform, + Nominal->getDeclaredType(), getProtocolType()); + requirement->diagnose(diag::no_witnesses_type, requirement->getName()); + + // If derivation is possible, cancel the diagnostic and perform derivation. + if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) { + diagnosticTransaction.abort(); return deriveDifferentiable_TangentVectorStruct(*this); - Context.Diags.diagnose(requirement->getLoc(), - diag::broken_differentiable_requirement); + } + + // Otherwise, return nullptr. return nullptr; } diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 771fcf01fd4be..46f6ad101ce1d 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -74,8 +74,11 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC, if (*derivableKind == KnownDerivableProtocolKind::AdditiveArithmetic) return canDeriveAdditiveArithmetic(Nominal, DC); + // Eagerly return true here. Actual synthesis conditions are checked in + // `DerivedConformance::deriveDifferentiable`: they are complicated and depend + // on the requirement being derived. if (*derivableKind == KnownDerivableProtocolKind::Differentiable) - return canDeriveDifferentiable(Nominal, DC); + return true; if (auto *enumDecl = dyn_cast(Nominal)) { switch (*derivableKind) { @@ -227,6 +230,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal, if (name.isSimpleName(ctx.Id_intValue)) return getRequirement(KnownProtocolKind::CodingKey); + // Differentiable.zeroTangentVectorInitializer + if (name.isSimpleName(ctx.Id_zeroTangentVectorInitializer)) + return getRequirement(KnownProtocolKind::Differentiable); + // AdditiveArithmetic.zero if (name.isSimpleName(ctx.Id_zero)) return getRequirement(KnownProtocolKind::AdditiveArithmetic); diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h index 0073f8edbbfd3..1f6296ee84fcb 100644 --- a/lib/Sema/DerivedConformances.h +++ b/lib/Sema/DerivedConformances.h @@ -107,10 +107,12 @@ class DerivedConformance { /// \returns the derived member, which will also be added to the type. ValueDecl *deriveAdditiveArithmetic(ValueDecl *requirement); - /// Determine if a Differentiable requirement can be derived for a type. + /// Determine if a Differentiable requirement can be derived for a nominal + /// type. /// /// \returns True if the requirement can be derived. - static bool canDeriveDifferentiable(NominalTypeDecl *type, DeclContext *DC); + static bool canDeriveDifferentiable(NominalTypeDecl *type, DeclContext *DC, + ValueDecl *requirement); /// Derive a Differentiable requirement for a nominal type. /// diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index b0c09e87d9d8b..e2a906f9f2f18 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -2039,19 +2039,17 @@ static Type getRequirementTypeForDisplay(ModuleDecl *module, return substType(type, /*result*/false); } -/// Retrieve the kind of requirement described by the given declaration, -/// for use in some diagnostics. -static diag::RequirementKind getRequirementKind(ValueDecl *VD) { - if (isa(VD)) - return diag::RequirementKind::Constructor; +diag::RequirementKind +swift::getProtocolRequirementKind(ValueDecl *Requirement) { + assert(Requirement->isProtocolRequirement()); - if (isa(VD)) + if (isa(Requirement)) + return diag::RequirementKind::Constructor; + if (isa(Requirement)) return diag::RequirementKind::Func; - - if (isa(VD)) + if (isa(Requirement)) return diag::RequirementKind::Var; - - assert(isa(VD) && "Unhandled requirement kind"); + assert(isa(Requirement) && "Unhandled requirement kind"); return diag::RequirementKind::Subscript; } @@ -2254,7 +2252,7 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, case MatchKind::KindConflict: diags.diagnose(match.Witness, diag::protocol_witness_kind_conflict, - getRequirementKind(req)); + getProtocolRequirementKind(req)); break; case MatchKind::WitnessInvalid: @@ -3053,13 +3051,14 @@ diagnoseMissingWitnesses(MissingWitnessDiagnosisKind Kind) { // If the protocol member decl is in the same file of the stub, // we can directly associate the fixit with the note issued to the // requirement. - Diags.diagnose(VD, diag::no_witnesses, getRequirementKind(VD), - VD->getName(), RequirementType, true) + Diags + .diagnose(VD, diag::no_witnesses, getProtocolRequirementKind(VD), + VD->getName(), RequirementType, true) .fixItInsertAfter(FixitLocation, FixIt); } else { // Otherwise, we have to issue another note to carry the fixit, // because editor may assume the fixit is in the same file with the note. - Diags.diagnose(VD, diag::no_witnesses, getRequirementKind(VD), + Diags.diagnose(VD, diag::no_witnesses, getProtocolRequirementKind(VD), VD->getName(), RequirementType, false); if (EditorMode) { Diags.diagnose(ComplainLoc, diag::missing_witnesses_general) @@ -3067,7 +3066,7 @@ diagnoseMissingWitnesses(MissingWitnessDiagnosisKind Kind) { } } } else { - Diags.diagnose(VD, diag::no_witnesses, getRequirementKind(VD), + Diags.diagnose(VD, diag::no_witnesses, getProtocolRequirementKind(VD), VD->getName(), RequirementType, true); } } @@ -3425,11 +3424,8 @@ ConformanceChecker::resolveWitnessViaLookup(ValueDecl *requirement) { auto &diags = DC->getASTContext().Diags; diags.diagnose(getLocForDiagnosingWitness(conformance, witness), - diagKind, - getRequirementKind(requirement), - witness->getName(), - isSetter, - requiredAccess, + diagKind, getProtocolRequirementKind(requirement), + witness->getName(), isSetter, requiredAccess, protoAccessScope.accessLevelForDiagnostics(), proto->getName()); if (auto *decl = dyn_cast(witness)) { @@ -3619,9 +3615,8 @@ ConformanceChecker::resolveWitnessViaLookup(ValueDecl *requirement) { diagnosticMessage = diag::ambiguous_witnesses_wrong_name; } diags.diagnose(requirement, diagnosticMessage, - getRequirementKind(requirement), - requirement->getName(), - reqType); + getProtocolRequirementKind(requirement), + requirement->getName(), reqType); // Diagnose each of the matches. for (const auto &match : matches) diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 588ac6d1254b2..a4e1f7974f571 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1316,6 +1316,12 @@ class EncodedDiagnosticMessage { const StringRef Message; }; +/// Returns the protocol requirement kind of the given declaration. +/// Used in diagnostics. +/// +/// Asserts that the given declaration is a protocol requirement. +diag::RequirementKind getProtocolRequirementKind(ValueDecl *Requirement); + /// Returns true if the given method is an valid implementation of a /// @dynamicCallable attribute requirement. The method is given to be defined /// as one of the following: `dynamicallyCall(withArguments:)` or diff --git a/stdlib/public/Differentiation/AnyDifferentiable.swift b/stdlib/public/Differentiation/AnyDifferentiable.swift index 421212f90037f..3a116104256c0 100644 --- a/stdlib/public/Differentiation/AnyDifferentiable.swift +++ b/stdlib/public/Differentiation/AnyDifferentiable.swift @@ -24,6 +24,7 @@ import Swift internal protocol _AnyDifferentiableBox { // `Differentiable` requirements. mutating func _move(along direction: AnyDerivative) + var _zeroTangentVectorInitializer: () -> AnyDerivative { get } /// The underlying base value, type-erased to `Any`. var _typeErasedBase: Any { get } @@ -59,6 +60,10 @@ internal struct _ConcreteDifferentiableBox: _AnyDifferentiabl } _base.move(along: directionBase) } + + var _zeroTangentVectorInitializer: () -> AnyDerivative { + { AnyDerivative(_base.zeroTangentVector) } + } } public struct AnyDifferentiable: Differentiable { @@ -103,6 +108,10 @@ public struct AnyDifferentiable: Differentiable { public mutating func move(along direction: TangentVector) { _box._move(along: direction) } + + public var zeroTangentVectorInitializer: () -> TangentVector { + _box._zeroTangentVectorInitializer + } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index cd1f20e308798..fbaef9c34fe80 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -168,8 +168,8 @@ extension Array: Differentiable where Element: Differentiable { /// A closure that produces a `TangentVector` of zeros with the same /// `count` as `self`. public var zeroTangentVectorInitializer: () -> TangentVector { - { [count = self.count] in - TangentVector(.init(repeating: .zero, count: count)) + { [zeroInits = map(\.zeroTangentVectorInitializer)] in + TangentVector(zeroInits.map { $0() }) } } } diff --git a/stdlib/public/Differentiation/Differentiable.swift b/stdlib/public/Differentiation/Differentiable.swift index 077144e40f1e9..1341e034811ef 100644 --- a/stdlib/public/Differentiation/Differentiable.swift +++ b/stdlib/public/Differentiation/Differentiable.swift @@ -80,21 +80,6 @@ public extension Differentiable where TangentVector == Self { } public extension Differentiable { - // This is a temporary solution enabling the addition of - // `zeroTangentVectorInitializer` without implementing derived conformances. - // This property will produce incorrect results when tangent vectors depend - // on instance-specific information from `self`. - // TODO: Implement derived conformances and remove this default - // implementation. - @available(*, deprecated, message: """ - `zeroTangentVectorInitializer` derivation has not been implemented; this \ - default implementation is not correct when tangent vectors depend on \ - instance-specific information from `self` and should not be used - """) - var zeroTangentVectorInitializer: () -> TangentVector { - { TangentVector.zero } - } - /// A tangent vector initialized using `zeroTangentVectorInitializer`. /// `move(along: zeroTangentVector)` should not modify `self`. var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() } diff --git a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb index e4d4ab68350a1..c4ea3fb01f06c 100644 --- a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb @@ -42,6 +42,11 @@ extension ${Self}: Differentiable { public mutating func move(along direction: TangentVector) { self += direction } + + @inlinable + public var zeroTangentVectorInitializer: () -> TangentVector { + { 0 } + } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb index d70b7201d722d..a60ab8dd461ed 100644 --- a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb @@ -31,6 +31,11 @@ where Scalar.TangentVector: BinaryFloatingPoint { public typealias TangentVector = SIMD${n} + + @inlinable + public var zeroTangentVectorInitializer: () -> TangentVector { + { .init(repeating: 0) } + } } //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift index 3d8acb8d2979c..90244a1973c87 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift @@ -4,6 +4,7 @@ import a extension Struct: Differentiable { public struct TangentVector: Differentiable & AdditiveArithmetic {} public mutating func move(along _: TangentVector) {} + public var zeroTangentVectorInitializer: () -> TangentVector { { .zero } } @usableFromInline @derivative(of: method, wrt: x) diff --git a/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift b/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift new file mode 100644 index 0000000000000..98ac63b275aaf --- /dev/null +++ b/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift @@ -0,0 +1,249 @@ +// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s + +// Check `Differentiable.zeroTangentVectorInitializer` derivation. +// +// There are two cases: +// 1. Memberwise derivation. +// +// var zeroTangentVectorInitializer: () -> TangentVector { +// { [xZeroTanInit = x.zeroTangentVectorInitializer, +// yZeroTanInit = y.zeroTangentVectorInitializer, ...] in +// return TangentVector(x: xZeroTanInit(), y: yZeroTanInit(), ...) +// } +// } +// +// 2. `{ TangentVector.zero }` fallback derivation. +// +// var zeroTangentVectorInitializer: () -> TangentVector { +// { TangentVector.zero } +// } + +import _Differentiation + +// - MARK: Structs + +struct MemberwiseTangentVectorStruct: Differentiable { + var x: Float + var y: Double + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +struct SelfTangentVectorStruct: Differentiable & AdditiveArithmetic { + var x: Float + var y: Double + typealias TangentVector = Self + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +struct CustomTangentVectorStruct: Differentiable { + var x: T + var y: U + + typealias TangentVector = T.TangentVector + mutating func move(along direction: TangentVector) {} + + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +// - MARK: Classes + +class MemberwiseTangentVectorClass: Differentiable { + var x: Float = 0.0 + var y: Double = 0.0 + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +final class SelfTangentVectorClass: Differentiable & AdditiveArithmetic { + var x: Float = 0.0 + var y: Double = 0.0 + typealias TangentVector = SelfTangentVectorClass + + static func ==(lhs: SelfTangentVectorClass, rhs: SelfTangentVectorClass) -> Bool { fatalError() } + static var zero: Self { fatalError() } + static func +(lhs: SelfTangentVectorClass, rhs: SelfTangentVectorClass) -> Self { fatalError() } + static func -(lhs: SelfTangentVectorClass, rhs: SelfTangentVectorClass) -> Self { fatalError() } + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +class CustomTangentVectorClass: Differentiable { + var x: T + var y: U + + init(x: T, y: U) { + self.x = x + self.y = y + } + + typealias TangentVector = T.TangentVector + func move(along direction: TangentVector) {} + + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +// - MARK: Enums + +enum SelfTangentVectorEnum: Differentiable & AdditiveArithmetic { + case a([Float]) + case b([Float], Float) + case c + + typealias TangentVector = SelfTangentVectorEnum + + static func ==(lhs: Self, rhs: Self) -> Bool { fatalError() } + static var zero: Self { fatalError() } + static func +(lhs: Self, rhs: Self) -> Self { fatalError() } + static func -(lhs: Self, rhs: Self) -> Self { fatalError() } + + // TODO(TF-1012): Implement memberwise `zeroTangentVectorInitializer` synthesis for enums. + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +enum CustomTangentVectorEnum: Differentiable { + case a(T) + + typealias TangentVector = T.TangentVector + mutating func move(along direction: TangentVector) {} + + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +// CHECK-LABEL: // MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0Vycvg : $@convention(method) (MemberwiseTangentVectorStruct) -> @owned @callee_guaranteed () -> MemberwiseTangentVectorStruct.TangentVector { +// CHECK: bb0([[SELF:%.*]] : $MemberwiseTangentVectorStruct): +// CHECK: [[X_PROP:%.*]] = struct_extract [[SELF]] : $MemberwiseTangentVectorStruct, #MemberwiseTangentVectorStruct.x +// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float +// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]]) +// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $MemberwiseTangentVectorStruct, #MemberwiseTangentVectorStruct.y +// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double +// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]]) +// CHECK: // function_ref closure #1 in MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0VycvgAFycfU_ +// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]] +// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]] +// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]]) +// CHECK: return [[ZERO_INIT]] : $@callee_guaranteed () -> MemberwiseTangentVectorStruct.TangentVector +// CHECK: } + +// CHECK-LABEL: // SelfTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvg : $@convention(method) (SelfTangentVectorStruct) -> @owned @callee_guaranteed () -> SelfTangentVectorStruct { +// CHECK: bb0([[SELF:%.*]] : $SelfTangentVectorStruct): +// CHECK: [[X_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.x +// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float +// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]]) +// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.y +// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double +// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]]) +// CHECK: // function_ref closure #2 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU0_ +// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]] +// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]] +// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]]) +// CHECK: return [[ZERO_INIT]] : $@callee_guaranteed () -> SelfTangentVectorStruct +// CHECK: } + +// CHECK-LABEL: // CustomTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0Qzycvg : $@convention(method) (@in_guaranteed CustomTangentVectorStruct) -> @owned @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for { +// CHECK: bb0([[SELF:%.*]] : $*CustomTangentVectorStruct): +// CHECK: // function_ref closure #3 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU1_ +// CHECK: } + +// CHECK-LABEL: // MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0Vycvg : $@convention(method) (@guaranteed MemberwiseTangentVectorClass) -> @owned @callee_guaranteed () -> MemberwiseTangentVectorClass.TangentVector { +// CHECK: bb0([[SELF:%.*]] : @guaranteed $MemberwiseTangentVectorClass): +// CHECK: [[X_PROP_METHOD:%.*]] = class_method [[SELF]] : $MemberwiseTangentVectorClass, #MemberwiseTangentVectorClass.x!getter +// CHECK: [[X_PROP:%.*]] = apply [[X_PROP_METHOD]]([[SELF]]) +// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float +// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]]) +// CHECK: [[Y_PROP_METHOD:%.*]] = class_method [[SELF]] : $MemberwiseTangentVectorClass, #MemberwiseTangentVectorClass.y!getter +// CHECK: [[Y_PROP:%.*]] = apply [[Y_PROP_METHOD]]([[SELF]]) +// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double +// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]]) +// CHECK: // function_ref closure #4 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU2_ +// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]] +// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]] +// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]]) +// CHECK: return [[ZERO_INIT]] : $@callee_guaranteed () -> MemberwiseTangentVectorClass.TangentVector +// CHECK: } + +// CHECK-LABEL: // SelfTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvg : $@convention(method) (@guaranteed SelfTangentVectorClass) -> @owned @callee_guaranteed () -> @owned SelfTangentVectorClass { +// CHECK: bb0([[SELF:%.*]] : @guaranteed $SelfTangentVectorClass): +// CHECK: // function_ref closure #5 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU3_ +// CHECK: } + +// CHECK-LABEL: // CustomTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0Qzycvg : $@convention(method) (@guaranteed CustomTangentVectorClass) -> @owned @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for { +// CHECK: bb0(%0 : @guaranteed $CustomTangentVectorClass): +// CHECK: // function_ref closure #6 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU4_ +// CHECK: } + +// CHECK-LABEL: // SelfTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvg : $@convention(method) (@guaranteed SelfTangentVectorEnum) -> @owned @callee_guaranteed () -> @owned SelfTangentVectorEnum { +// CHECK: bb0([[SELF:%.*]] : @guaranteed $SelfTangentVectorEnum): +// CHECK: // function_ref closure #7 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU5_ +// CHECK: } + +// CHECK-LABEL: // CustomTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0Qzycvg : $@convention(method) (@in_guaranteed CustomTangentVectorEnum) -> @owned @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for { +// CHECK: bb0([[SELF:%.*]] : $*CustomTangentVectorEnum): +// CHECK: // function_ref closure #8 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU6_ +// CHECK: } + +// CHECK-LABEL: // closure #1 in MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0VycvgAFycfU_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> MemberwiseTangentVectorStruct.TangentVector { +// CHECK: // function_ref MemberwiseTangentVectorStruct.TangentVector.init(x:y:) +// CHECK-NOT: // function_ref static {{.*}}.zero.getter +// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #2 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU0_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> SelfTangentVectorStruct { +// CHECK: // function_ref SelfTangentVectorStruct.init(x:y:) +// CHECK-NOT: // function_ref static {{.*}}.zero.getter +// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #3 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU1_ : $@convention(thin) () -> @out T.TangentVector { +// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #4 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU2_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> MemberwiseTangentVectorClass.TangentVector { +// CHECK: // function_ref MemberwiseTangentVectorClass.TangentVector.init(x:y:) +// CHECK-NOT: // function_ref static {{.*}}.zero.getter +// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #5 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU3_ : $@convention(thin) () -> @owned SelfTangentVectorClass { +// CHECK: // function_ref static SelfTangentVectorClass.zero.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0B0ACXDvgZ : $@convention(method) (@thick SelfTangentVectorClass.Type) -> @owned SelfTangentVectorClass +// CHECK: } + +// CHECK-LABEL: // closure #6 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU4_ : $@convention(thin) () -> @out T.TangentVector { +// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter +// CHECK: } + +// TODO(TF-1012): Implement memberwise `zeroTangentVectorInitializer` synthesis for enums. +// CHECK-LABEL: // closure #7 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU5_ : $@convention(thin) () -> @owned SelfTangentVectorEnum { +// CHECK: // function_ref static SelfTangentVectorEnum.zero.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0B0ACvgZ : $@convention(method) (@thin SelfTangentVectorEnum.Type) -> @owned SelfTangentVectorEnum +// CHECK: } + +// CHECK-LABEL: // closure #8 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @$s39derived_zero_tangent_vector_initializer23CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU6_ : $@convention(thin) () -> @out T.TangentVector { +// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter +// CHECK: } diff --git a/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift b/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift new file mode 100644 index 0000000000000..05b8c8bb514fb --- /dev/null +++ b/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift @@ -0,0 +1,59 @@ +// RUN: %target-run-simple-swift +// REQUIRES: executable_test + +import _Differentiation +import StdlibUnittest + +var ZeroTangentVectorTests = TestSuite("zeroTangentVectorInitializer") + +struct Generic: Differentiable { + var x: T + var y: U +} + +struct Nested: Differentiable { + var generic: Generic +} + +ZeroTangentVectorTests.test("Derivation") { + typealias G = Generic<[Float], [[Float]]> + + let generic = G(x: [1, 2, 3], y: [[4, 5, 6], [], [2]]) + let genericZero = G.TangentVector(x: [0, 0, 0], y: [[0, 0, 0], [], [0]]) + expectEqual(generic.zeroTangentVector, genericZero) + + let nested = Nested(generic: generic) + let nestedZero = Nested.TangentVector(generic: genericZero) + expectEqual(nested.zeroTangentVector, nestedZero) +} + +// Test differentiation correctness involving projection operations and +// per-instance zeros. +ZeroTangentVectorTests.test("DifferentiationCorrectness") { + struct Struct: Differentiable { + var x, y: [Float] + } + func concatenated(_ lhs: Struct, _ rhs: Struct) -> Struct { + return Struct(x: lhs.x + rhs.x, y: lhs.y + rhs.y) + } + func test(_ s: Struct) -> [Float] { + let result = concatenated(s, s).withDerivative { dresult in + // FIXME(TF-1008): Fix incorrect derivative values for + // "projection operation" operands when differentiation transform uses + // `Differentiable.zeroTangentVectorInitializer`. + // Actual: TangentVector(x: [1.0, 1.0, 1.0], y: []) + // Expected: TangentVector(x: [1.0, 1.0, 1.0], y: [1.0, 1.0, 1.0]) + expectEqual(dresult, Struct.TangentVector(x: [1, 1, 1], y: [1, 1, 1])) + } + return result.x + } + let s = Struct(x: [1, 2, 3], y: [1, 2, 3]) + let pb = pullback(at: s, in: test) + // FIXME(TF-1008): Remove `expectCrash` when differentiation transform uses + // `Differentiable.zeroTangentVectorInitializer`. + expectCrash { + _ = pb([1, 1, 1]) + } +} + +runAllTests()