From 1052d3c48c955e38581eb64025f8ce154f724221 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 18 Jun 2020 14:16:54 -0700 Subject: [PATCH 1/3] [AutoDiff] Simplify basic block active value collection. Make `recordValueIfActive` short-circuit, returning true on error. Remove booleans tracking whether diagnostics have been emitted. --- .../Differentiation/PullbackEmitter.cpp | 73 ++++++++++--------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp index f6e2c5be52732..50020859dbae5 100644 --- a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp @@ -797,6 +797,7 @@ bool PullbackEmitter::run() { "Functions without returns must have been diagnosed"); auto *origExit = &*origExitIt; + // Collect original formal results. SmallVector origFormalResults; collectAllFormalResultsInTypeOrder(original, origFormalResults); for (auto resultIndex : getIndices().results->getIndices()) { @@ -815,7 +816,7 @@ bool PullbackEmitter::run() { } } - // Get dominated active values in original blocks. + // Collect dominated active values in original basic blocks. // Adjoint values of dominated active values are passed as pullback block // arguments. DominanceOrder domOrder(original.getEntryBlock(), domInfo); @@ -829,15 +830,21 @@ bool PullbackEmitter::run() { auto &domBBActiveValues = activeValues[domNode->getBlock()]; bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end()); } - // Booleans tracking whether active-value-related errors have been emitted. - // This prevents duplicate diagnostics for the same active values. - bool diagnosedActiveEnumValue = false; - bool diagnosedActiveValueTangentValueCategoryIncompatible = false; - // Mark the activity of a value if it has not yet been visited. - auto markValueActivity = [&](SILValue v) { + // If `v` is active and has not been visited, records it as an active value + // in the original basic block. + // For active values unsupported by differentiation, emits a diagnostic and + // returns true. Otherwise, returns false. + auto recordValueIfActive = [&](SILValue v) -> bool { + // If value is not active, skip. + if (!getActivityInfo().isActive(v, getIndices())) + return false; + // If active value has already been visited, skip. if (visited.count(v)) - return; + return false; + // Mark active value as visited. visited.insert(v); + + // Diagnose unsupported active values. auto type = v->getType(); // Diagnose active values whose value category is incompatible with their // tangent types's value category. @@ -851,56 +858,54 @@ bool PullbackEmitter::run() { // $*A | $L | Yes (can create $*L adjoint buffer) // $L | $*A | No (cannot create $A adjoint value) // $*A | $*A | Yes (no mismatch) - if (!diagnosedActiveValueTangentValueCategoryIncompatible) { - if (auto tanSpace = getTangentSpace(remapType(type).getASTType())) { - auto tanASTType = tanSpace->getCanonicalType(); - auto &origTL = getTypeLowering(type.getASTType()); - auto &tanTL = getTypeLowering(tanASTType); - if (!origTL.isAddressOnly() && tanTL.isAddressOnly()) { - getContext().emitNondifferentiabilityError( - v, getInvoker(), - diag::autodiff_loadable_value_addressonly_tangent_unsupported, - type.getASTType(), tanASTType); - diagnosedActiveValueTangentValueCategoryIncompatible = true; - errorOccurred = true; - } + if (auto tanSpace = getTangentSpace(remapType(type).getASTType())) { + auto tanASTType = tanSpace->getCanonicalType(); + auto &origTL = getTypeLowering(type.getASTType()); + auto &tanTL = getTypeLowering(tanASTType); + if (!origTL.isAddressOnly() && tanTL.isAddressOnly()) { + getContext().emitNondifferentiabilityError( + v, getInvoker(), + diag::autodiff_loadable_value_addressonly_tangent_unsupported, + type.getASTType(), tanASTType); + errorOccurred = true; + return true; } } // Do not emit remaining activity-related diagnostics for semantic member // accessors, which have special-case pullback generation. if (isSemanticMemberAccessor(&original)) - return; + return false; // Diagnose active enum values. Differentiation of enum values requires // special adjoint value handling and is not yet supported. Diagnose // only the first active enum value to prevent too many diagnostics. - if (!diagnosedActiveEnumValue && type.getEnumOrBoundGenericEnum()) { + if (type.getEnumOrBoundGenericEnum()) { getContext().emitNondifferentiabilityError( v, getInvoker(), diag::autodiff_enums_unsupported); errorOccurred = true; - diagnosedActiveEnumValue = true; + return true; } // Skip address projections. // Address projections do not need their own adjoint buffers; they // become projections into their adjoint base buffer. if (Projection::isAddressProjection(v)) - return; + return false; + // Record active value. bbActiveValues.push_back(v); + return false; }; - // Visit bb arguments and all instruction operands/results. + // Record all active values in the basic block. for (auto *arg : bb->getArguments()) - if (getActivityInfo().isActive(arg, getIndices())) - markValueActivity(arg); + if (recordValueIfActive(arg)) + return true; for (auto &inst : *bb) { for (auto op : inst.getOperandValues()) - if (getActivityInfo().isActive(op, getIndices())) - markValueActivity(op); + if (recordValueIfActive(op)) + return true; for (auto result : inst.getResults()) - if (getActivityInfo().isActive(result, getIndices())) - markValueActivity(result); + if (recordValueIfActive(result)) + return true; } domOrder.pushChildren(bb); - if (errorOccurred) - return true; } // Create pullback blocks and arguments, visiting original blocks in From f163072b2a956cd12613bea70dbe323bc65ed5a3 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 22 Jun 2020 10:10:56 -0700 Subject: [PATCH 2/3] [AutoDiff] Add TangentStoredPropertyRequest. Add request that resolves the "tangent stored property" corresponding to an original stored property in a `Differentiable`-conforming type. Enables better non-differentiability differentiation transform diagnostics. --- include/swift/AST/ASTTypeIDZone.def | 1 + include/swift/AST/ASTTypeIDs.h | 9 +- include/swift/AST/AutoDiff.h | 94 +++++++++++++ include/swift/AST/TypeCheckRequests.h | 21 +++ include/swift/AST/TypeCheckerTypeIDZone.def | 2 + lib/AST/AutoDiff.cpp | 125 ++++++++++++++++++ lib/Sema/DerivedConformanceDifferentiable.cpp | 41 +++--- 7 files changed, 271 insertions(+), 22 deletions(-) diff --git a/include/swift/AST/ASTTypeIDZone.def b/include/swift/AST/ASTTypeIDZone.def index 5c1eaa91fe286..daf638f797480 100644 --- a/include/swift/AST/ASTTypeIDZone.def +++ b/include/swift/AST/ASTTypeIDZone.def @@ -27,6 +27,7 @@ SWIFT_TYPEID(PropertyWrapperTypeInfo) SWIFT_TYPEID(Requirement) SWIFT_TYPEID(ResilienceExpansion) SWIFT_TYPEID(FragileFunctionKind) +SWIFT_TYPEID(TangentPropertyInfo) SWIFT_TYPEID(Type) SWIFT_TYPEID(TypePair) SWIFT_TYPEID(TypeWitnessAndDecl) diff --git a/include/swift/AST/ASTTypeIDs.h b/include/swift/AST/ASTTypeIDs.h index 9359f95427d0e..b81fdc686b9ac 100644 --- a/include/swift/AST/ASTTypeIDs.h +++ b/include/swift/AST/ASTTypeIDs.h @@ -19,6 +19,7 @@ #include "swift/Basic/LLVM.h" #include "swift/Basic/TypeID.h" + namespace swift { class AbstractFunctionDecl; @@ -58,14 +59,14 @@ class Requirement; enum class ResilienceExpansion : unsigned; struct FragileFunctionKind; class SourceFile; +struct TangentPropertyInfo; class Type; -class ValueDecl; -class VarDecl; -class Witness; class TypeAliasDecl; -class Type; struct TypePair; struct TypeWitnessAndDecl; +class ValueDecl; +class VarDecl; +class Witness; enum class AncestryFlags : uint8_t; enum class ImplicitMemberAction : uint8_t; struct FingerprintAndMembers; diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 5b447027b20c5..18524c9933e9e 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -35,6 +35,7 @@ class AnyFunctionType; class SourceFile; class SILFunctionType; class TupleType; +class VarDecl; /// A function type differentiability kind. enum class DifferentiabilityKind : uint8_t { @@ -459,6 +460,99 @@ class DerivativeFunctionTypeError } }; +/// Describes the "tangent stored property" corresponding to an original stored +/// property in a `Differentiable`-conforming type. +/// +/// The tangent stored property is the stored property in the `TangentVector` +/// struct of the `Differentiable`-conforming type, with the same name as the +/// original stored property and with the original stored property's +/// `TangentVector` type. +struct TangentPropertyInfo { + struct Error { + enum class Kind { + /// The original property is `@noDerivative`. + NoDerivativeOriginalProperty, + /// The nominal parent type does not conform to `Differentiable`. + NominalParentNotDifferentiable, + /// The original property's type does not conform to `Differentiable`. + OriginalPropertyNotDifferentiable, + /// The parent `TangentVector` type is not a struct. + ParentTangentVectorNotStruct, + /// The parent `TangentVector` struct does not declare a stored property + /// with the same name as the original property. + TangentPropertyNotFound, + /// The tangent property's type is not equal to the original property's + /// `TangentVector` type. + TangentPropertyWrongType, + /// The tangent property is not a stored property. + TangentPropertyNotStored + }; + + /// The error kind. + Kind kind; + + private: + union Value { + Type type; + Value(Type type) : type(type) {} + Value() {} + } value; + + public: + Error(Kind kind) : kind(kind), value() { + assert(kind == Kind::NoDerivativeOriginalProperty || + kind == Kind::NominalParentNotDifferentiable || + kind == Kind::OriginalPropertyNotDifferentiable || + kind == Kind::ParentTangentVectorNotStruct || + kind == Kind::TangentPropertyNotFound || + kind == Kind::TangentPropertyNotStored); + }; + + Error(Kind kind, Type type) : kind(kind), value(type) { + assert(kind == Kind::TangentPropertyWrongType); + }; + + Type getType() const { + assert(kind == Kind::TangentPropertyWrongType); + return value.type; + } + + friend bool operator==(const Error &lhs, const Error &rhs); + }; + + /// The tangent stored property. + VarDecl *tangentProperty = nullptr; + + /// An optional error. + Optional error = None; + +private: + TangentPropertyInfo(VarDecl *tangentProperty, Optional error) + : tangentProperty(tangentProperty), error(error) {} + +public: + TangentPropertyInfo(VarDecl *tangentProperty) + : TangentPropertyInfo(tangentProperty, None) {} + + TangentPropertyInfo(Error::Kind errorKind) + : TangentPropertyInfo(nullptr, Error(errorKind)) {} + + TangentPropertyInfo(Error::Kind errorKind, Type errorType) + : TangentPropertyInfo(nullptr, Error(errorKind, errorType)) {} + + /// Returns `true` iff this tangent property info is valid. + bool isValid() const { return tangentProperty && !error; } + + explicit operator bool() const { return isValid(); } + + friend bool operator==(const TangentPropertyInfo &lhs, + const TangentPropertyInfo &rhs) { + return lhs.tangentProperty == rhs.tangentProperty && lhs.error == rhs.error; + } +}; + +void simple_display(llvm::raw_ostream &OS, TangentPropertyInfo info); + /// The key type used for uniquing `SILDifferentiabilityWitness` in /// `SILModule`: original function name, parameter indices, result indices, and /// derivative generic signature. diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 46d6b5ae0810a..bed14bfd6c62a 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -2191,6 +2191,27 @@ class DerivativeAttrOriginalDeclRequest bool isCached() const { return true; } }; +/// Resolves the "tangent stored property" corresponding to an original stored +/// property in a `Differentiable`-conforming type. +class TangentStoredPropertyRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + TangentPropertyInfo evaluate(Evaluator &evaluator, + VarDecl *originalField) const; + +public: + // Caching. + bool isCached() const { return true; } +}; + /// Checks whether a type eraser has a viable initializer. class TypeEraserHasViableInitRequest : public SimpleRequest(VarDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest, bool(AbstractFunctionDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest, diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index df1376a6cc41a..87c65561dcf6c 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -421,3 +421,128 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const { } } } + +bool swift::operator==(const TangentPropertyInfo::Error &lhs, + const TangentPropertyInfo::Error &rhs) { + if (lhs.kind != rhs.kind) + return false; + switch (lhs.kind) { + case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: + case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: + case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: + case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: + case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: + case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: + return true; + case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: + return lhs.getType()->isEqual(rhs.getType()); + } +} + +void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) { + os << "{ "; + os << "tangent property: " + << (info.tangentProperty ? info.tangentProperty->printRef() : "null"); + if (info.error) { + os << ", error: "; + switch (info.error->kind) { + case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: + os << "'@noDerivative' original property has no tangent property"; + break; + case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: + os << "nominal parent does not conform to 'Differentiable'"; + break; + case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: + os << "original property type does not conform to 'Differentiable'"; + break; + case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: + os << "'TangentVector' type is not a struct"; + break; + case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: + os << "'TangentVector' struct does not have stored property with the " + "same name as the original property"; + break; + case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: + os << "tangent property's type is not equal to the original property's " + "'TangentVector' type"; + break; + case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: + os << "'TangentVector' property '" << info.tangentProperty->getName() + << "' is not a stored property"; + break; + } + } + os << " }"; +} + +TangentPropertyInfo +TangentStoredPropertyRequest::evaluate(Evaluator &evaluator, + VarDecl *originalField) const { + assert(originalField->hasStorage() && originalField->isInstanceMember() && + "Expected stored property"); + auto *parentDC = originalField->getDeclContext(); + assert(parentDC->isTypeContext()); + auto parentType = parentDC->getDeclaredTypeInContext(); + auto *moduleDecl = originalField->getModuleContext(); + auto parentTan = parentType->getAutoDiffTangentSpace( + LookUpConformanceInModule(moduleDecl)); + // Error if parent nominal type does not conform to `Differentiable`. + if (!parentTan) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable); + } + // Error if original stored property is `@noDerivative`. + if (originalField->getAttrs().hasAttribute()) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty); + } + // Error if original property's type does not conform to `Differentiable`. + auto originalFieldTan = originalField->getType()->getAutoDiffTangentSpace( + LookUpConformanceInModule(moduleDecl)); + if (!originalFieldTan) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable); + } + auto parentTanType = parentTan->getType(); + auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct(); + // Error if parent `TangentVector` is not a struct. + if (!parentTanStruct) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct); + } + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If `TangentVector` is the original struct, then the tangent property is the + // original property. + if (parentTanStruct == parentDC->getSelfStructDecl()) { + tanField = originalField; + } + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + parentTanStruct->lookupDirect(originalField->getName()); + llvm::erase_if(tanFieldLookup, + [](ValueDecl *v) { return !isa(v); }); + // Error if tangent property could not be found. + if (tanFieldLookup.empty()) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::TangentPropertyNotFound); + } + tanField = cast(tanFieldLookup.front()); + } + // Error if tangent property's type is not equal to the original property's + // `TangentVector` type. + auto originalFieldTanType = originalFieldTan->getType(); + if (!originalFieldTanType->isEqual(tanField->getType())) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::TangentPropertyWrongType, + originalFieldTanType); + } + // Error if tangent property is not a stored property. + if (!tanField->hasStorage()) { + return TangentPropertyInfo( + TangentPropertyInfo::Error::Kind::TangentPropertyNotStored); + } + // Otherwise, tangent property is valid. + return TangentPropertyInfo(tanField); +} diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 78312f6779e86..f0e6e10b5406b 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -26,6 +26,7 @@ #include "swift/AST/PropertyWrappers.h" #include "swift/AST/ProtocolConformance.h" #include "swift/AST/Stmt.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/AST/Types.h" #include "DerivedConformances.h" @@ -646,22 +647,21 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { structDecl->setImplicit(); structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); - // Add members to `TangentVector` struct. + // Add stored properties to the `TangentVector` struct. for (auto *member : diffProperties) { - // Add this member's corresponding `TangentVector` type to the parent's - // `TangentVector` struct. - // Note: `newMember` is not marked as implicit here, because that - // incorrectly affects memberwise initializer synthesis. - auto *newMember = new (C) VarDecl( + // Add a tangent stored property to the `TangentVector` struct, with the + // name and `TangentVector` type of the original property. + auto *tangentProperty = new (C) VarDecl( member->isStatic(), member->getIntroducer(), member->isCaptureList(), /*NameLoc*/ SourceLoc(), member->getName(), structDecl); - + // Note: `tangentProperty` is not marked as implicit here, because that + // incorrectly affects memberwise initializer synthesis. auto memberContextualType = parentDC->mapTypeIntoContext(member->getValueInterfaceType()); auto memberTanType = getTangentVectorInterfaceType(memberContextualType, parentDC); - newMember->setInterfaceType(memberTanType); - Pattern *memberPattern = NamedPattern::createImplicit(C, newMember); + tangentProperty->setInterfaceType(memberTanType); + Pattern *memberPattern = NamedPattern::createImplicit(C, tangentProperty); memberPattern->setType(memberTanType); memberPattern = TypedPattern::createImplicit(C, memberPattern, memberTanType); @@ -669,16 +669,21 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { auto *memberBinding = PatternBindingDecl::createImplicit( C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr, structDecl); - structDecl->addMember(newMember); + structDecl->addMember(tangentProperty); structDecl->addMember(memberBinding); - newMember->copyFormalAccessFrom(member, /*sourceIsParentContext*/ true); - newMember->setSetterAccess(member->getFormalAccess()); - - // Now that this member is in the `TangentVector` type, it should be marked - // `@differentiable` so that the differentiation transform will synthesize - // derivative functions for it. We only add this to public stored - // properties, because their access outside the module will go through a - // call to the getter. + tangentProperty->copyFormalAccessFrom(member, + /*sourceIsParentContext*/ true); + tangentProperty->setSetterAccess(member->getFormalAccess()); + + // Cache the tangent property. + C.evaluator.cacheOutput(TangentStoredPropertyRequest{member}, + TangentPropertyInfo(tangentProperty)); + + // Now that the original property has a corresponding tangent property, it + // should be marked `@differentiable` so that the differentiation transform + // will synthesize derivative functions for its accessors. We only add this + // to public stored properties, because their access outside the module will + // go through accessor declarations. if (member->getEffectiveAccess() > AccessLevel::Internal && !member->getAttrs().hasAttribute()) { auto *getter = member->getSynthesizedAccessor(AccessorKind::Get); From c690ac87d68697edb4f2b2cce141c57dd6ba636f Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 22 Jun 2020 10:11:58 -0700 Subject: [PATCH 3/3] [AutoDiff] Improve invalid stored property projection diagnostics. Use TangentStoredPropertyRequest in differentiation transform. Improve non-differentiability diagnostics regarding invalid stored property projection instructions: `struct_extract`, `struct_element_addr`, `ref_element_addr`. Diagnose the following cases: - Original property's type does not conform to `Differentiable`. - Base type's `TangentVector` is not a struct. - Tangent property not found: base type's `TangentVector` does not have a stored property with the same name as the original property. - Tangent property's type is not equal to the original property's `TangentVector` type. - Tangent property is not a stored property. Resolves TF-969 and TF-970. --- include/swift/AST/DiagnosticsSIL.def | 21 ++- include/swift/SIL/SILInstruction.h | 7 + .../SILOptimizer/Differentiation/ADContext.h | 8 +- .../SILOptimizer/Differentiation/Common.h | 37 +++- lib/SILOptimizer/Differentiation/Common.cpp | 93 ++++++++++ .../Differentiation/JVPEmitter.cpp | 51 +---- .../Differentiation/PullbackEmitter.cpp | 125 +++---------- lib/SILOptimizer/Differentiation/Thunk.cpp | 1 - .../differentiation_diagnostics.swift | 174 +++++++++++++++++- .../forward_mode_diagnostics.swift | 170 +++++++++++++++++ 10 files changed, 536 insertions(+), 151 deletions(-) diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 34ca3c7b08c86..cbc3ad300a690 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -504,9 +504,26 @@ NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none, "properties", (Type, Type)) NOTE(autodiff_enums_unsupported,none, "differentiating enum values is not yet supported", ()) +NOTE(autodiff_stored_property_parent_not_differentiable,none, + "cannot differentiate access to property '%0.%1' because '%0' does not " + "conform to 'Differentiable'", (StringRef, StringRef)) +NOTE(autodiff_stored_property_not_differentiable,none, + "cannot differentiate access to property '%0.%1' because property type %2 " + "does not conform to 'Differentiable'", (StringRef, StringRef, Type)) +NOTE(autodiff_stored_property_tangent_not_struct,none, + "cannot differentiate access to property '%0.%1' because " + "'%0.TangentVector' is not a struct", (StringRef, StringRef)) NOTE(autodiff_stored_property_no_corresponding_tangent,none, - "property cannot be differentiated because '%0.TangentVector' does not " - "have a member named '%1'", (StringRef, StringRef)) + "cannot differentiate access to property '%0.%1' because " + "'%0.TangentVector' does not have a stored property named '%1'", + (StringRef, StringRef)) +NOTE(autodiff_tangent_property_wrong_type,none, + "cannot differentiate access to property '%0.%1' because " + "'%0.TangentVector.%1' does not have expected type %2", + (StringRef, StringRef, /*originalPropertyTanType*/ Type)) +NOTE(autodiff_tangent_property_not_stored,none, + "cannot differentiate access to property '%0.%1' because " + "'%0.TangentVector.%1' is not a stored property", (StringRef, StringRef)) NOTE(autodiff_coroutines_not_supported,none, "differentiation of coroutine calls is not yet supported", ()) NOTE(autodiff_cannot_differentiate_writes_to_global_variables,none, diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 39be3a320fe1e..0e811b94a46eb 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -5758,6 +5758,13 @@ class FieldIndexCacheBase : public SingleValueInstruction { return s; } + static bool classof(const SILNode *node) { + SILNodeKind kind = node->getKind(); + return kind == SILNodeKind::StructExtractInst || + kind == SILNodeKind::StructElementAddrInst || + kind == SILNodeKind::RefElementAddrInst; + } + private: unsigned cacheFieldIndex(); }; diff --git a/include/swift/SILOptimizer/Differentiation/ADContext.h b/include/swift/SILOptimizer/Differentiation/ADContext.h index 1b0426aed0f1e..a5bf5e86f1505 100644 --- a/include/swift/SILOptimizer/Differentiation/ADContext.h +++ b/include/swift/SILOptimizer/Differentiation/ADContext.h @@ -253,11 +253,9 @@ ADContext::emitNondifferentiabilityError(SILValue value, getADDebugStream() << "For value:\n" << value; getADDebugStream() << "With invoker:\n" << invoker << '\n'; }); - auto valueLoc = value.getLoc().getSourceLoc(); // If instruction does not have a valid location, use the function location // as a fallback. Improves diagnostics in some cases. - if (valueLoc.isInvalid()) - valueLoc = value->getFunction()->getLocation().getSourceLoc(); + auto valueLoc = getValidLocation(value).getSourceLoc(); return emitNondifferentiabilityError(valueLoc, invoker, diag, std::forward(args)...); } @@ -272,12 +270,10 @@ ADContext::emitNondifferentiabilityError(SILInstruction *inst, getADDebugStream() << "For instruction:\n" << *inst; getADDebugStream() << "With invoker:\n" << invoker << '\n'; }); - auto instLoc = inst->getLoc().getSourceLoc(); // If instruction does not have a valid location, use the function location // as a fallback. Improves diagnostics for `ref_element_addr` generated in // synthesized stored property getters. - if (instLoc.isInvalid()) - instLoc = inst->getFunction()->getLocation().getSourceLoc(); + auto instLoc = getValidLocation(inst).getSourceLoc(); return emitNondifferentiabilityError(instLoc, invoker, diag, std::forward(args)...); } diff --git a/include/swift/SILOptimizer/Differentiation/Common.h b/include/swift/SILOptimizer/Differentiation/Common.h index a4d8637b386d8..82bdb4a2a492d 100644 --- a/include/swift/SILOptimizer/Differentiation/Common.h +++ b/include/swift/SILOptimizer/Differentiation/Common.h @@ -17,6 +17,8 @@ #ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H +#include "swift/AST/DiagnosticsSIL.h" +#include "swift/AST/Expr.h" #include "swift/AST/SemanticAttrs.h" #include "swift/SIL/SILDifferentiabilityWitness.h" #include "swift/SIL/SILFunction.h" @@ -24,15 +26,18 @@ #include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/Analysis/ArraySemantic.h" #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" +#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" namespace swift { +namespace autodiff { + +class ADContext; + //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// -namespace autodiff { - /// Prints an "[AD] " prefix to `llvm::dbgs()` and returns the debug stream. /// This is being used to print short debug messages within the AD pass. raw_ostream &getADDebugStream(); @@ -136,6 +141,34 @@ template Inst *peerThroughFunctionConversions(SILValue value) { return nullptr; } +//===----------------------------------------------------------------------===// +// Diagnostic utilities +//===----------------------------------------------------------------------===// + +// Returns `v`'s location if it is valid. Otherwise, returns `v`'s function's +// location as as a fallback. Used for diagnostics. +SILLocation getValidLocation(SILValue v); + +// Returns `inst`'s location if it is valid. Otherwise, returns `inst`'s +// function's location as as a fallback. Used for diagnostics. +SILLocation getValidLocation(SILInstruction *inst); + +//===----------------------------------------------------------------------===// +// Tangent property lookup utilities +//===----------------------------------------------------------------------===// + +/// Returns the tangent stored property of `originalField`. On error, emits +/// diagnostic and returns nullptr. +VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, + SILLocation loc, + DifferentiationInvoker invoker); + +/// Returns the tangent stored property of the original stored property +/// referenced by `inst`. On error, emits diagnostic and returns nullptr. +VarDecl *getTangentStoredProperty(ADContext &context, + FieldIndexCacheBase *projectionInst, + DifferentiationInvoker invoker); + //===----------------------------------------------------------------------===// // Code emission utilities //===----------------------------------------------------------------------===// diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index be3f802c0e7d2..4184387fd1624 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -17,6 +17,8 @@ #define DEBUG_TYPE "differentiation" #include "swift/SILOptimizer/Differentiation/Common.h" +#include "swift/AST/TypeCheckRequests.h" +#include "swift/SILOptimizer/Differentiation/ADContext.h" namespace swift { namespace autodiff { @@ -244,6 +246,97 @@ void collectMinimalIndicesForFunctionCall( })); } +//===----------------------------------------------------------------------===// +// Diagnostic utilities +//===----------------------------------------------------------------------===// + +SILLocation getValidLocation(SILValue v) { + auto loc = v.getLoc(); + if (loc.isNull() || loc.getSourceLoc().isInvalid()) + loc = v->getFunction()->getLocation(); + return loc; +} + +SILLocation getValidLocation(SILInstruction *inst) { + auto loc = inst->getLoc(); + if (loc.isNull() || loc.getSourceLoc().isInvalid()) + loc = inst->getFunction()->getLocation(); + return loc; +} + +//===----------------------------------------------------------------------===// +// Tangent property lookup utilities +//===----------------------------------------------------------------------===// + +VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, + SILLocation loc, + DifferentiationInvoker invoker) { + auto &astCtx = context.getASTContext(); + auto tanFieldInfo = evaluateOrDefault( + astCtx.evaluator, TangentStoredPropertyRequest{originalField}, + TangentPropertyInfo(nullptr)); + // If no error, return the tangent property. + if (tanFieldInfo) + return tanFieldInfo.tangentProperty; + // Otherwise, diagnose error and return nullptr. + assert(tanFieldInfo.error); + auto *parentDC = originalField->getDeclContext(); + assert(parentDC->isTypeContext()); + auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr(); + auto fieldName = originalField->getNameStr(); + auto sourceLoc = loc.getSourceLoc(); + switch (tanFieldInfo.error->kind) { + case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: + llvm_unreachable( + "`@noDerivative` stored property accesses should not be " + "differentiated; activity analysis should not mark as varied"); + case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: + context.emitNondifferentiabilityError( + sourceLoc, invoker, + diag::autodiff_stored_property_parent_not_differentiable, + parentDeclName, fieldName); + break; + case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: + context.emitNondifferentiabilityError( + sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable, + parentDeclName, fieldName, originalField->getInterfaceType()); + break; + case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: + context.emitNondifferentiabilityError( + sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct, + parentDeclName, fieldName); + break; + case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: + context.emitNondifferentiabilityError( + sourceLoc, invoker, + diag::autodiff_stored_property_no_corresponding_tangent, parentDeclName, + fieldName); + break; + case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: + context.emitNondifferentiabilityError( + sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type, + parentDeclName, fieldName, tanFieldInfo.error->getType()); + break; + case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: + context.emitNondifferentiabilityError( + sourceLoc, invoker, diag::autodiff_tangent_property_not_stored, + parentDeclName, fieldName); + break; + } + return nullptr; +} + +VarDecl *getTangentStoredProperty(ADContext &context, + FieldIndexCacheBase *projectionInst, + DifferentiationInvoker invoker) { + assert(isa(projectionInst) || + isa(projectionInst) || + isa(projectionInst)); + auto loc = getValidLocation(projectionInst); + return getTangentStoredProperty(context, projectionInst->getField(), loc, + invoker); +} + //===----------------------------------------------------------------------===// // Code emission utilities //===----------------------------------------------------------------------===// diff --git a/lib/SILOptimizer/Differentiation/JVPEmitter.cpp b/lib/SILOptimizer/Differentiation/JVPEmitter.cpp index 8f53b37667113..9e0cd4f47a2ae 100644 --- a/lib/SILOptimizer/Differentiation/JVPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/JVPEmitter.cpp @@ -547,29 +547,12 @@ CLONE_AND_EMIT_TANGENT(StructExtract, sei) { assert(!sei->getField()->getAttrs().hasAttribute() && "`struct_extract` with `@noDerivative` field should not be " "differentiated; activity analysis should not marked as varied."); - auto diffBuilder = getDifferentialBuilder(); - ; - auto tangentVectorTy = getRemappedTangentType(sei->getOperand()->getType()); - auto *tangentVectorDecl = tangentVectorTy.getStructOrBoundGenericStruct(); - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == sei->getStructDecl()) - tanField = sei->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(sei->getField()->getName()); - if (tanFieldLookup.empty()) { - context.emitNondifferentiabilityError( - sei, invoker, diag::autodiff_stored_property_no_corresponding_tangent, - sei->getStructDecl()->getNameStr(), sei->getField()->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); + auto *tanField = getTangentStoredProperty(context, sei, invoker); + if (!tanField) { + errorOccurred = true; + return; } // Emit tangent `struct_extract`. auto tanStruct = @@ -590,32 +573,14 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { assert(!seai->getField()->getAttrs().hasAttribute() && "`struct_element_addr` with `@noDerivative` field should not be " "differentiated; activity analysis should not marked as varied."); - auto diffBuilder = getDifferentialBuilder(); auto *bb = seai->getParent(); - auto tangentVectorTy = getRemappedTangentType(seai->getOperand()->getType()); - auto *tangentVectorDecl = tangentVectorTy.getStructOrBoundGenericStruct(); - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == seai->getStructDecl()) - tanField = seai->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(seai->getField()->getName()); - if (tanFieldLookup.empty()) { - context.emitNondifferentiabilityError( - seai, invoker, - diag::autodiff_stored_property_no_corresponding_tangent, - seai->getStructDecl()->getNameStr(), seai->getField()->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); + auto *tanField = getTangentStoredProperty(context, seai, invoker); + if (!tanField) { + errorOccurred = true; + return; } - // Emit tangent `struct_element_addr`. auto tanOperand = getTangentBuffer(bb, seai->getOperand()); auto tangentInst = diff --git a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp index 50020859dbae5..ff4d804ae5260 100644 --- a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp @@ -24,6 +24,7 @@ #include "swift/AST/Expr.h" #include "swift/AST/PropertyWrappers.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/SIL/InstructionUtils.h" #include "swift/SIL/Projection.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" @@ -335,13 +336,8 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, assert(!seai->getField()->getAttrs().hasAttribute() && "`@noDerivative` struct projections should never be active"); auto adjSource = getAdjointBuffer(origBB, seai->getOperand()); - auto *tangentVectorDecl = - adjSource->getType().getStructOrBoundGenericStruct(); - // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct. - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(seai->getField()->getName()); - assert(tanFieldLookup.size() == 1); - auto *tanField = cast(tanFieldLookup.front()); + auto *tanField = getTangentStoredProperty(getContext(), seai, getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField); } // Handle `tuple_element_addr`. @@ -373,15 +369,8 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, // `TangentVector` struct. auto adjClass = materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc); - auto *tangentVectorDecl = - adjClass->getType().getStructOrBoundGenericStruct(); - // TODO(TF-970): Replace assertions below with diagnostics. - assert(tangentVectorDecl && "`TangentVector` of a class must be a struct"); - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(reai->getField()->getName()); - assert(tanFieldLookup.size() == 1 && - "Class `TangentVector` must have field of the same name"); - auto *tanField = cast(tanFieldLookup.front()); + auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); // Create a local allocation for the element adjoint buffer. auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); auto eltTanSILType = @@ -619,25 +608,19 @@ bool PullbackEmitter::runForSemanticMemberGetter() { "Getter should have one semantic result"); auto origResult = origFormalResults[*getIndices().results->begin()]; - // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct. auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType( TypeExpansionContext::minimal()); auto tangentVectorTy = tangentVectorSILTy.getASTType(); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); // Look up the corresponding field in the tangent space. - VarDecl *origField = cast(accessor->getStorage()); - VarDecl *tanField = nullptr; - auto tanFieldLookup = tangentVectorDecl->lookupDirect(origField->getName()); - if (tanFieldLookup.empty()) { - getContext().emitNondifferentiabilityError( - pbLoc.getSourceLoc(), getInvoker(), - diag::autodiff_stored_property_no_corresponding_tangent, - origSelf->getType().getASTType().getString(), origField->getNameStr()); + auto *origField = cast(accessor->getStorage()); + auto *tanField = + getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker()); + if (!tanField) { errorOccurred = true; return true; } - tanField = cast(tanFieldLookup.front()); // Switch based on the base tangent struct's value category. // TODO(TF-1255): Simplify using unified adjoint value data structure. @@ -736,27 +719,14 @@ bool PullbackEmitter::runForSemanticMemberSetter() { SILValue origArg = original.getArgumentsWithoutIndirectResults()[0]; SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1]; - // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct. - auto tangentVectorSILTy = pullback.getLoweredFunctionType() - ->getParameters()[0] - .getSILStorageInterfaceType(); - assert(tangentVectorSILTy.getCategory() == SILValueCategory::Address); - auto tangentVectorTy = tangentVectorSILTy.getASTType(); - auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); - // Look up the corresponding field in the tangent space. - VarDecl *origField = cast(accessor->getStorage()); - VarDecl *tanField = nullptr; - auto tanFieldLookup = tangentVectorDecl->lookupDirect(origField->getName()); - if (tanFieldLookup.empty()) { - getContext().emitNondifferentiabilityError( - pbLoc.getSourceLoc(), getInvoker(), - diag::autodiff_stored_property_no_corresponding_tangent, - origSelf->getType().getASTType().getString(), origField->getNameStr()); + auto *origField = cast(accessor->getStorage()); + auto *tanField = + getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker()); + if (!tanField) { errorOccurred = true; return true; } - tanField = cast(tanFieldLookup.front()); auto adjSelf = getAdjointBuffer(origEntry, origSelf); auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField); @@ -884,6 +854,13 @@ bool PullbackEmitter::run() { errorOccurred = true; return true; } + // Diagnose unsupported stored property projections. + if (auto *inst = dyn_cast(v)) { + if (!getTangentStoredProperty(getContext(), inst, getInvoker())) { + errorOccurred = true; + return true; + } + } // Skip address projections. // Address projections do not need their own adjoint buffers; they // become projections into their adjoint base buffer. @@ -1674,23 +1651,12 @@ void PullbackEmitter::visitStructInst(StructInst *si) { if (field->getAttrs().hasAttribute()) continue; // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - if (tangentVectorDecl == structDecl) - tanField = field; - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = tangentVectorDecl->lookupDirect(field->getName()); - if (tanFieldLookup.empty()) { - getContext().emitNondifferentiabilityError( - si, getInvoker(), - diag::autodiff_stored_property_no_corresponding_tangent, - tangentVectorDecl->getNameStr(), field->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); + auto *tanField = + getTangentStoredProperty(getContext(), field, loc, getInvoker()); + if (!tanField) { + errorOccurred = true; + return; } - assert(tanField); auto tanElt = dti->getResult(fieldIndex); addAdjointValue(bb, si->getFieldValue(field), makeConcreteAdjointValue(tanElt), si->getLoc()); @@ -1717,36 +1683,16 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) { } void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) { - assert(!sei->getField()->getAttrs().hasAttribute() && - "`struct_extract` with `@noDerivative` field should not be " - "differentiated; activity analysis should not marked as varied"); auto *bb = sei->getParent(); auto structTy = remapType(sei->getOperand()->getType()).getASTType(); auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); - // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct. assert(tangentVectorDecl); // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == sei->getStructDecl()) - tanField = sei->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(sei->getField()->getName()); - if (tanFieldLookup.empty()) { - getContext().emitNondifferentiabilityError( - sei, getInvoker(), - diag::autodiff_stored_property_no_corresponding_tangent, - sei->getStructDecl()->getNameStr(), sei->getField()->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); - } + auto *tanField = getTangentStoredProperty(getContext(), sei, getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); // Accumulate adjoint for the `struct_extract` operand. auto av = getAdjointValue(bb, sei); switch (av.getKind()) { @@ -1784,21 +1730,8 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) { assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); - // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct. - assert(tangentVectorDecl); - // Look up the corresponding field in the tangent space by name. - VarDecl *tanField = nullptr; - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(reai->getField()->getName()); - if (tanFieldLookup.empty()) { - getContext().emitNondifferentiabilityError( - reai, getInvoker(), - diag::autodiff_stored_property_no_corresponding_tangent, - reai->getClassDecl()->getNameStr(), reai->getField()->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); + auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); // Accumulate adjoint for the `ref_element_addr` operand. SmallVector eltVals; for (auto *field : tangentVectorDecl->getStoredProperties()) { diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 9d56ae56310b5..e5b15cc51e95f 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -237,7 +237,6 @@ CanSILFunctionType buildThunkType(SILFunction *fn, if (expectedType->hasErrorResult()) { auto errorResult = expectedType->getErrorResult(); interfaceErrorResult = errorResult.map(mapTypeOutOfContext); - ; } // The type of the thunk function. diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index 2aaca3f5f1d3c..db81b705c680c 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -470,6 +470,178 @@ func testInoutParameterAndFormalResult(_ x: Float) -> Float { return inoutParameterAndFormalResult(&x) } +//===----------------------------------------------------------------------===// +// Stored property access differentiation +//===----------------------------------------------------------------------===// + +// Test differentiation of invalid stored property access instructions: +// `struct_extract`, `struct_element_addr`, `ref_element_addr`. + +struct StructTangentVectorNotStruct: Differentiable { + var x: Float + + enum TangentVector: Differentiable, AdditiveArithmetic { + case x(Float) + typealias TangentVector = Self + static func ==(_: Self, _: Self) -> Bool { fatalError() } + static var zero: Self { fatalError() } + static func +(_: Self, _: Self) -> Self { fatalError() } + static func -(_: Self, _: Self) -> Self { fatalError() } + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_vector_not_struct") +func testStructTangentVectorNotStruct(_ s: StructTangentVectorNotStruct) -> Float { + // expected-note @+1 {{cannot differentiate access to property 'StructTangentVectorNotStruct.x' because 'StructTangentVectorNotStruct.TangentVector' is not a struct}} + return s.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_vector_not_struct +// CHECK: struct_extract {{%.*}} : $StructTangentVectorNotStruct, #StructTangentVectorNotStruct.x + +struct StructOriginalPropertyNotDifferentiable: Differentiable { + struct Nondiff { + var x: Float + } + var nondiff: Nondiff + + struct TangentVector: Differentiable & AdditiveArithmetic { + var nondiff: Float + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_original_property_not_differentiable") +func testStructOriginalPropertyNotDifferentiable(_ s: StructOriginalPropertyNotDifferentiable) -> Float { + // expected-note @+1 {{cannot differentiate access to property 'StructOriginalPropertyNotDifferentiable.nondiff' because property type 'StructOriginalPropertyNotDifferentiable.Nondiff' does not conform to 'Differentiable'}} + return s.nondiff.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_original_property_not_differentiable +// CHECK: struct_extract {{%.*}} : $StructOriginalPropertyNotDifferentiable, #StructOriginalPropertyNotDifferentiable.nondiff + +struct StructTangentVectorPropertyNotFound: Differentiable { + var x: Float + + struct TangentVector: Differentiable, AdditiveArithmetic { + var y: Float + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_property_not_found") +func testStructTangentPropertyNotFound(_ s: StructTangentVectorPropertyNotFound) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = s + // expected-note @+1 {{cannot differentiate access to property 'StructTangentVectorPropertyNotFound.x' because 'StructTangentVectorPropertyNotFound.TangentVector' does not have a stored property named 'x'}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_not_found +// CHECK: struct_element_addr {{%.*}} : $*StructTangentVectorPropertyNotFound, #StructTangentVectorPropertyNotFound.x + +struct StructTangentPropertyWrongType: Differentiable { + var x: Float + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Double + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_property_wrong_type") +func testStructTangentPropertyWrongType(_ s: StructTangentPropertyWrongType) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = s + // expected-note @+1 {{cannot differentiate access to property 'StructTangentPropertyWrongType.x' because 'StructTangentPropertyWrongType.TangentVector.x' does not have expected type 'Float.TangentVector' (aka 'Float')}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_wrong_type +// CHECK: struct_element_addr {{%.*}} : $*StructTangentPropertyWrongType, #StructTangentPropertyWrongType.x + +final class ClassTangentPropertyWrongType: Differentiable { + var x: Float = 0 + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Double + } + func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_class_tangent_property_wrong_type") +func testClassTangentPropertyWrongType(_ c: ClassTangentPropertyWrongType) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = c + // expected-note @+1 {{cannot differentiate access to property 'ClassTangentPropertyWrongType.x' because 'ClassTangentPropertyWrongType.TangentVector.x' does not have expected type 'Float.TangentVector' (aka 'Float')}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_wrong_type +// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyWrongType, #ClassTangentPropertyWrongType.x + +struct StructTangentPropertyNotStored: Differentiable { + var x: Float + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Float { 0 } + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_property_not_stored") +func testStructTangentPropertyNotStored(_ s: StructTangentPropertyNotStored) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = s + // expected-note @+1 {{cannot differentiate access to property 'StructTangentPropertyNotStored.x' because 'StructTangentPropertyNotStored.TangentVector.x' is not a stored property}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_not_stored +// CHECK: struct_element_addr {{%.*}} : $*StructTangentPropertyNotStored, #StructTangentPropertyNotStored.x + +final class ClassTangentPropertyNotStored: Differentiable { + var x: Float = 0 + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Float { 0 } + } + func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_class_tangent_property_not_stored") +func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = c + // expected-note @+1 {{cannot differentiate access to property 'ClassTangentPropertyNotStored.x' because 'ClassTangentPropertyNotStored.TangentVector.x' is not a stored property}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored +// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x + //===----------------------------------------------------------------------===// // Wrapped property differentiation //===----------------------------------------------------------------------===// @@ -508,7 +680,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {} struct Struct: Differentiable { // expected-error @+2 {{expression is not differentiable}} - // expected-note @+1 {{property cannot be differentiated because 'Struct.TangentVector' does not have a member named '_x'}} + // expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}} @DifferentiableWrapper @DifferentiableWrapper var x: Float = 10 @Wrapper var y: Float = 20 diff --git a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift index a2ab040928b92..f028db4baca4f 100644 --- a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift @@ -148,3 +148,173 @@ func testNoDerivativeParameter(_ f: @differentiable (Float, @noDerivative Float) return derivative(at: 2, 3) { (x, y) in f(x * x, y) } } */ + +//===----------------------------------------------------------------------===// +// Stored property access differentiation +//===----------------------------------------------------------------------===// + +// Test differentiation of invalid stored property access instructions: +// `struct_extract`, `struct_element_addr`, `ref_element_addr`. + +struct StructTangentVectorNotStruct: Differentiable { + var x: Float + + enum TangentVector: Differentiable, AdditiveArithmetic { + case x(Float) + typealias TangentVector = Self + static func ==(_: Self, _: Self) -> Bool { fatalError() } + static var zero: Self { fatalError() } + static func +(_: Self, _: Self) -> Self { fatalError() } + static func -(_: Self, _: Self) -> Self { fatalError() } + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_vector_not_struct") +func testStructTangentVectorNotStruct(_ s: StructTangentVectorNotStruct) -> Float { + // expected-note @+1 {{cannot differentiate access to property 'StructTangentVectorNotStruct.x' because 'StructTangentVectorNotStruct.TangentVector' is not a struct}} + return s.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_vector_not_struct +// CHECK: struct_extract {{%.*}} : $StructTangentVectorNotStruct, #StructTangentVectorNotStruct.x + +struct StructOriginalPropertyNotDifferentiable: Differentiable { + struct Nondiff { + var x: Float + } + var nondiff: Nondiff + + struct TangentVector: Differentiable & AdditiveArithmetic { + var nondiff: Float + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_original_property_not_differentiable") +func testStructOriginalPropertyNotDifferentiable(_ s: StructOriginalPropertyNotDifferentiable) -> Float { + // expected-note @+1 {{cannot differentiate access to property 'StructOriginalPropertyNotDifferentiable.nondiff' because property type 'StructOriginalPropertyNotDifferentiable.Nondiff' does not conform to 'Differentiable'}} + return s.nondiff.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_original_property_not_differentiable +// CHECK: struct_extract {{%.*}} : $StructOriginalPropertyNotDifferentiable, #StructOriginalPropertyNotDifferentiable.nondiff + +struct StructTangentVectorPropertyNotFound: Differentiable { + var x: Float + + struct TangentVector: Differentiable, AdditiveArithmetic { + var y: Float + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_property_not_found") +func testStructTangentPropertyNotFound(_ s: StructTangentVectorPropertyNotFound) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = s + // expected-note @+1 {{cannot differentiate access to property 'StructTangentVectorPropertyNotFound.x' because 'StructTangentVectorPropertyNotFound.TangentVector' does not have a stored property named 'x'}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_not_found +// CHECK: struct_element_addr {{%.*}} : $*StructTangentVectorPropertyNotFound, #StructTangentVectorPropertyNotFound.x + +struct StructTangentPropertyWrongType: Differentiable { + var x: Float + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Double + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_property_wrong_type") +func testStructTangentPropertyWrongType(_ s: StructTangentPropertyWrongType) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = s + // expected-note @+1 {{cannot differentiate access to property 'StructTangentPropertyWrongType.x' because 'StructTangentPropertyWrongType.TangentVector.x' does not have expected type 'Float.TangentVector' (aka 'Float')}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_wrong_type +// CHECK: struct_element_addr {{%.*}} : $*StructTangentPropertyWrongType, #StructTangentPropertyWrongType.x + +final class ClassTangentPropertyWrongType: Differentiable { + var x: Float = 0 + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Double + } + func move(along direction: TangentVector) {} +} + +// FIXME(TF-984): Forward-mode crash due to unset tangent buffer. +/* +@differentiable +@_silgen_name("test_class_tangent_property_wrong_type") +func testClassTangentPropertyWrongType(_ c: ClassTangentPropertyWrongType) -> Float { + var tmp = c + return tmp.x +} +*/ + +// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_wrong_type +// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyWrongType, #ClassTangentPropertyWrongType.x + +struct StructTangentPropertyNotStored: Differentiable { + var x: Float + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Float { 0 } + } + mutating func move(along direction: TangentVector) {} +} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+3 {{when differentiating this function definition}} +@differentiable +@_silgen_name("test_struct_tangent_property_not_stored") +func testStructTangentPropertyNotStored(_ s: StructTangentPropertyNotStored) -> Float { + // expected-warning @+1 {{variable 'tmp' was never mutated}} + var tmp = s + // expected-note @+1 {{cannot differentiate access to property 'StructTangentPropertyNotStored.x' because 'StructTangentPropertyNotStored.TangentVector.x' is not a stored property}} + return tmp.x +} + +// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_not_stored +// CHECK: struct_element_addr {{%.*}} : $*StructTangentPropertyNotStored, #StructTangentPropertyNotStored.x + +final class ClassTangentPropertyNotStored: Differentiable { + var x: Float = 0 + + struct TangentVector: Differentiable, AdditiveArithmetic { + var x: Float { 0 } + } + func move(along direction: TangentVector) {} +} + +// FIXME(TF-984): Forward-mode crash due to unset tangent buffer. +/* +@differentiable +@_silgen_name("test_class_tangent_property_not_stored") +func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Float { + var tmp = c + return tmp.x +} +*/ + +// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored +// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x