diff --git a/include/swift/AST/ASTDemangler.h b/include/swift/AST/ASTDemangler.h index 5b553d8bae264..c9b12cb084674 100644 --- a/include/swift/AST/ASTDemangler.h +++ b/include/swift/AST/ASTDemangler.h @@ -150,7 +150,9 @@ class ASTBuilder { Type createImplFunctionType( Demangle::ImplParameterConvention calleeConvention, + Demangle::ImplCoroutineKind coroutineKind, ArrayRef> params, + ArrayRef> yields, ArrayRef> results, std::optional> errorResult, ImplFunctionTypeFlags flags); diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 90f386946ef8c..6f2b4fa5df739 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -607,6 +607,8 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none, "cannot differentiate through multiple results", ()) NOTE(autodiff_cannot_differentiate_through_inout_arguments,none, "cannot differentiate through 'inout' arguments", ()) +NOTE(autodiff_cannot_differentiate_through_direct_yield,none, + "cannot differentiate through a '_read' accessor", ()) NOTE(autodiff_enums_unsupported,none, "differentiating enum values is not yet supported", ()) NOTE(autodiff_stored_property_parent_not_differentiable,none, diff --git a/include/swift/AST/IndexSubset.h b/include/swift/AST/IndexSubset.h index 6b848b91ad177..1943666fc161d 100644 --- a/include/swift/AST/IndexSubset.h +++ b/include/swift/AST/IndexSubset.h @@ -108,8 +108,10 @@ class IndexSubset : public llvm::FoldingSetNode { static IndexSubset *get(ASTContext &ctx, unsigned capacity, ArrayRef indices) { SmallBitVector indicesBitVec(capacity, false); - for (auto index : indices) + for (auto index : indices) { + assert(index < capacity); indicesBitVec.set(index); + } return IndexSubset::get(ctx, indicesBitVec); } diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 182587fb2c5f9..e8281e8a58a72 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -5174,8 +5174,11 @@ class 
SILFunctionType final /// Returns the number of function potential semantic results: /// * Usual results /// * Inout parameters + /// * yields unsigned getNumAutoDiffSemanticResults() const { - return getNumResults() + getNumAutoDiffSemanticResultsParameters(); + return getNumResults() + + getNumAutoDiffSemanticResultsParameters() + + getNumYields(); } /// Get the generic signature that the component types are specified diff --git a/include/swift/Demangling/Demangle.h b/include/swift/Demangling/Demangle.h index fe44b2d203bc4..bd163422a9464 100644 --- a/include/swift/Demangling/Demangle.h +++ b/include/swift/Demangling/Demangle.h @@ -553,6 +553,7 @@ struct [[nodiscard]] ManglingError { UnknownEncoding, InvalidImplCalleeConvention, InvalidImplDifferentiability, + InvalidImplCoroutineKind, InvalidImplFunctionAttribute, InvalidImplParameterConvention, InvalidImplParameterTransferring, diff --git a/include/swift/Demangling/DemangleNodes.def b/include/swift/Demangling/DemangleNodes.def index 1f726f6f1e22c..8b5d18963ca5e 100644 --- a/include/swift/Demangling/DemangleNodes.def +++ b/include/swift/Demangling/DemangleNodes.def @@ -139,6 +139,7 @@ NODE(ImplFunctionAttribute) NODE(ImplFunctionConvention) NODE(ImplFunctionConventionName) NODE(ImplFunctionType) +NODE(ImplCoroutineKind) NODE(ImplInvocationSubstitutions) CONTEXT_NODE(ImplicitClosure) NODE(ImplParameter) diff --git a/include/swift/Demangling/TypeDecoder.h b/include/swift/Demangling/TypeDecoder.h index f9a35e5819cdb..db993e79d7914 100644 --- a/include/swift/Demangling/TypeDecoder.h +++ b/include/swift/Demangling/TypeDecoder.h @@ -48,6 +48,12 @@ enum class ImplMetatypeRepresentation { ObjC, }; +enum class ImplCoroutineKind { + None, + YieldOnce, + YieldMany, +}; + /// Describe a function parameter, parameterized on the type /// representation. 
template @@ -188,6 +194,9 @@ class ImplFunctionParam { BuiltType getType() const { return Type; } }; +template +using ImplFunctionYield = ImplFunctionParam; + enum class ImplResultConvention { Indirect, Owned, @@ -1023,9 +1032,11 @@ class TypeDecoder { case NodeKind::ImplFunctionType: { auto calleeConvention = ImplParameterConvention::Direct_Unowned; llvm::SmallVector, 8> parameters; + llvm::SmallVector, 8> yields; llvm::SmallVector, 8> results; llvm::SmallVector, 8> errorResults; ImplFunctionTypeFlags flags; + ImplCoroutineKind coroutineKind = ImplCoroutineKind::None; for (unsigned i = 0; i < Node->getNumChildren(); i++) { auto child = Node->getChild(i); @@ -1066,6 +1077,15 @@ class TypeDecoder { } else if (child->getText() == "@async") { flags = flags.withAsync(); } + } else if (child->getKind() == NodeKind::ImplCoroutineKind) { + if (!child->hasText()) + return MAKE_NODE_TYPE_ERROR0(child, "expected text"); + if (child->getText() == "yield_once") { + coroutineKind = ImplCoroutineKind::YieldOnce; + } else if (child->getText() == "yield_many") { + coroutineKind = ImplCoroutineKind::YieldMany; + } else + return MAKE_NODE_TYPE_ERROR0(child, "failed to decode coroutine kind"); } else if (child->getKind() == NodeKind::ImplDifferentiabilityKind) { ImplFunctionDifferentiabilityKind implDiffKind; switch ((MangledDifferentiabilityKind)child->getIndex()) { @@ -1088,10 +1108,14 @@ class TypeDecoder { if (decodeImplFunctionParam(child, depth + 1, parameters)) return MAKE_NODE_TYPE_ERROR0(child, "failed to decode function parameter"); + } else if (child->getKind() == NodeKind::ImplYield) { + if (decodeImplFunctionParam(child, depth + 1, yields)) + return MAKE_NODE_TYPE_ERROR0(child, + "failed to decode function yields"); } else if (child->getKind() == NodeKind::ImplResult) { if (decodeImplFunctionParam(child, depth + 1, results)) return MAKE_NODE_TYPE_ERROR0(child, - "failed to decode function parameter"); + "failed to decode function results"); } else if (child->getKind() == 
NodeKind::ImplErrorResult) { if (decodeImplFunctionPart(child, depth + 1, errorResults)) return MAKE_NODE_TYPE_ERROR0(child, @@ -1115,11 +1139,10 @@ class TypeDecoder { // TODO: Some cases not handled above, but *probably* they cannot // appear as the types of values in SIL (yet?): - // - functions with yield returns // - functions with generic signatures // - foreign error conventions - return Builder.createImplFunctionType(calleeConvention, - parameters, results, + return Builder.createImplFunctionType(calleeConvention, coroutineKind, + parameters, yields, results, errorResult, flags); } diff --git a/include/swift/RemoteInspection/TypeRefBuilder.h b/include/swift/RemoteInspection/TypeRefBuilder.h index 4af241bd50a7d..316f01223c755 100644 --- a/include/swift/RemoteInspection/TypeRefBuilder.h +++ b/include/swift/RemoteInspection/TypeRefBuilder.h @@ -1134,7 +1134,9 @@ class TypeRefBuilder { const FunctionTypeRef *createImplFunctionType( Demangle::ImplParameterConvention calleeConvention, + Demangle::ImplCoroutineKind coroutineKind, llvm::ArrayRef> params, + llvm::ArrayRef> yields, llvm::ArrayRef> results, std::optional> errorResult, ImplFunctionTypeFlags flags) { diff --git a/include/swift/SIL/SILFunctionConventions.h b/include/swift/SIL/SILFunctionConventions.h index 1890272278ca1..ed61f86e4ca4c 100644 --- a/include/swift/SIL/SILFunctionConventions.h +++ b/include/swift/SIL/SILFunctionConventions.h @@ -248,6 +248,14 @@ class SILFunctionConventions { idx < indirectResults + getNumIndirectSILErrorResults(); } + unsigned getNumAutoDiffSemanticResults() const { + return funcTy->getNumAutoDiffSemanticResults(); + } + + unsigned getNumAutoDiffSemanticResultParameters() const { + return funcTy->getNumAutoDiffSemanticResultsParameters(); + } + /// Are any SIL results passed as address-typed arguments? 
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; } bool hasIndirectSILErrorResults() const { return getNumIndirectSILErrorResults() != 0; } diff --git a/include/swift/SILOptimizer/Differentiation/ADContext.h b/include/swift/SILOptimizer/Differentiation/ADContext.h index b24b5487f0fb6..021a1ec2b8aa7 100644 --- a/include/swift/SILOptimizer/Differentiation/ADContext.h +++ b/include/swift/SILOptimizer/Differentiation/ADContext.h @@ -17,6 +17,7 @@ #ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H +#include "swift/SIL/ApplySite.h" #include "swift/SILOptimizer/Differentiation/Common.h" #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" @@ -51,6 +52,12 @@ struct NestedApplyInfo { /// The original pullback type before reabstraction. `None` if the pullback /// type is not reabstracted. std::optional originalPullbackType; + /// Index of `apply` pullback in nested pullback call + unsigned pullbackIdx = -1U; + /// Pullback value itself that is memoized in some cases (e.g. pullback is + /// called by `begin_apply`, but should be destroyed after `end_apply`). + SILValue pullback = SILValue(); + SILValue beginApplyToken = SILValue(); }; /// Per-module contextual information for the Differentiation pass. @@ -97,7 +104,7 @@ class ADContext { /// Mapping from original `apply` instructions to their corresponding /// `NestedApplyInfo`s. - llvm::DenseMap nestedApplyInfo; + llvm::DenseMap nestedApplyInfo; /// List of generated functions (JVPs, VJPs, pullbacks, and thunks). /// Saved for deletion during cleanup. 
@@ -185,7 +192,7 @@ class ADContext { invokers.insert({witness, DifferentiationInvoker(witness)}); } - llvm::DenseMap &getNestedApplyInfo() { + llvm::DenseMap &getNestedApplyInfo() { return nestedApplyInfo; } diff --git a/include/swift/SILOptimizer/Differentiation/Common.h b/include/swift/SILOptimizer/Differentiation/Common.h index 2af9cc4c6c754..d43bf3897cd6c 100644 --- a/include/swift/SILOptimizer/Differentiation/Common.h +++ b/include/swift/SILOptimizer/Differentiation/Common.h @@ -20,6 +20,7 @@ #include "swift/AST/DiagnosticsSIL.h" #include "swift/AST/Expr.h" #include "swift/AST/SemanticAttrs.h" +#include "swift/SIL/ApplySite.h" #include "swift/SIL/SILDifferentiabilityWitness.h" #include "swift/SIL/SILFunction.h" #include "swift/SIL/Projection.h" @@ -112,7 +113,7 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function, /// Given a function call site, gathers all of its actual results (both direct /// and indirect) in an order defined by its result type. void collectAllActualResultsInTypeOrder( - ApplyInst *ai, ArrayRef extractedDirectResults, + FullApplySite fai, ArrayRef extractedDirectResults, SmallVectorImpl &results); /// For an `apply` instruction with active results, compute: @@ -120,7 +121,7 @@ void collectAllActualResultsInTypeOrder( /// - The set of minimal parameter and result indices for differentiating the /// `apply` instruction. 
void collectMinimalIndicesForFunctionCall( - ApplyInst *ai, const AutoDiffConfig &parentConfig, + FullApplySite fai, const AutoDiffConfig &parentConfig, const DifferentiableActivityInfo &activityInfo, SmallVectorImpl &results, SmallVectorImpl ¶mIndices, SmallVectorImpl &resultIndices); diff --git a/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h b/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h index ca50cdca4343d..c5a3d8be9a105 100644 --- a/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h +++ b/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h @@ -77,9 +77,9 @@ class LinearMapInfo { /// For differentials: these are successor enums. llvm::DenseMap branchingTraceDecls; - /// Mapping from `apply` instructions in the original function to the + /// Mapping from `apply` / `begin_apply` instructions in the original function to the /// corresponding linear map tuple type index. - llvm::DenseMap linearMapIndexMap; + llvm::DenseMap linearMapIndexMap; /// Mapping from predecessor-successor basic block pairs in the original /// function to the corresponding branching trace enum case. @@ -112,9 +112,9 @@ class LinearMapInfo { void populateBranchingTraceDecl(SILBasicBlock *originalBB, SILLoopInfo *loopInfo); - /// Given an `apply` instruction, conditionally gets a linear map tuple field - /// AST type for its linear map function if it is active. - Type getLinearMapType(ADContext &context, ApplyInst *ai); + /// Given an `apply` / `begin_apply` instruction, conditionally gets a linear + /// map tuple field AST type for its linear map function if it is active. + Type getLinearMapType(ADContext &context, FullApplySite fai); /// Generates linear map struct and branching enum declarations for the given /// function. Linear map structs are populated with linear map fields and a @@ -180,18 +180,18 @@ class LinearMapInfo { } /// Finds the linear map index in the pullback tuple for the given - /// `apply` instruction in the original function. 
- unsigned lookUpLinearMapIndex(ApplyInst *ai) const { - assert(ai->getFunction() == original); - auto lookup = linearMapIndexMap.find(ai); + /// `apply` / `begin_apply` instruction in the original function. + unsigned lookUpLinearMapIndex(FullApplySite fas) const { + assert(fas->getFunction() == original); + auto lookup = linearMapIndexMap.find(fas); assert(lookup != linearMapIndexMap.end() && "No linear map field corresponding to the given `apply`"); return lookup->getSecond(); } - Type lookUpLinearMapType(ApplyInst *ai) const { - unsigned idx = lookUpLinearMapIndex(ai); - return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType(); + Type lookUpLinearMapType(FullApplySite fas) const { + unsigned idx = lookUpLinearMapIndex(fas); + return getLinearMapTupleType(fas->getParent())->getElement(idx).getType(); } bool hasHeapAllocatedContext() const { diff --git a/include/swift/SILOptimizer/Differentiation/Thunk.h b/include/swift/SILOptimizer/Differentiation/Thunk.h index 02aa54f56c714..923a107e87425 100644 --- a/include/swift/SILOptimizer/Differentiation/Thunk.h +++ b/include/swift/SILOptimizer/Differentiation/Thunk.h @@ -56,6 +56,11 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb, CanSILFunctionType fromType, CanSILFunctionType toType); +SILValue reabstractCoroutine( + SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc, + SILValue fn, CanSILFunctionType toType, + std::function remapSubstitutions); + /// Reabstracts the given function-typed value `fn` to the target type `toType`. /// Remaps substitutions using `remapSubstitutions`. 
SILValue reabstractFunction( diff --git a/lib/AST/ASTDemangler.cpp b/lib/AST/ASTDemangler.cpp index 0d618a5c2fb5b..2743b761f960f 100644 --- a/lib/AST/ASTDemangler.cpp +++ b/lib/AST/ASTDemangler.cpp @@ -571,17 +571,33 @@ getResultOptions(ImplResultInfoOptions implOptions) { return result; } +static SILCoroutineKind +getCoroutineKind(ImplCoroutineKind kind) { + switch (kind) { + case ImplCoroutineKind::None: + return SILCoroutineKind::None; + case ImplCoroutineKind::YieldOnce: + return SILCoroutineKind::YieldOnce; + case ImplCoroutineKind::YieldMany: + return SILCoroutineKind::YieldMany; + } + llvm_unreachable("unknown coroutine kind"); +} + Type ASTBuilder::createImplFunctionType( Demangle::ImplParameterConvention calleeConvention, + Demangle::ImplCoroutineKind coroutineKind, ArrayRef> params, + ArrayRef> yields, ArrayRef> results, std::optional> errorResult, ImplFunctionTypeFlags flags) { GenericSignature genericSig; - SILCoroutineKind funcCoroutineKind = SILCoroutineKind::None; ParameterConvention funcCalleeConvention = getParameterConvention(calleeConvention); + SILCoroutineKind funcCoroutineKind = + getCoroutineKind(coroutineKind); SILFunctionTypeRepresentation representation; switch (flags.getRepresentation()) { @@ -644,6 +660,13 @@ Type ASTBuilder::createImplFunctionType( funcParams.emplace_back(type, conv, options); } + for (const auto &yield : yields) { + auto type = yield.getType()->getCanonicalType(); + auto conv = getParameterConvention(yield.getConvention()); + auto options = *getParameterOptions(yield.getOptions()); + funcParams.emplace_back(type, conv, options); + } + for (const auto &result : results) { auto type = result.getType()->getCanonicalType(); auto conv = getResultConvention(result.getConvention()); diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 79de6ff824cfc..142fcd0e75ded 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -101,6 +101,10 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const 
{ llvm_unreachable("invalid derivative kind"); } +void AutoDiffConfig::dump() const { + print(llvm::errs()); +} + void AutoDiffConfig::print(llvm::raw_ostream &s) const { s << "(parameters="; parameterIndices->print(s); @@ -354,22 +358,30 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature( // Require differentiability results to conform to `Differentiable`. SmallVector originalResults; getSemanticResults(originalFnTy, diffParamIndices, originalResults); + unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults(); + unsigned firstYieldResultIndex = originalFnTy->getNumResults() + + originalFnTy->getNumAutoDiffSemanticResultsParameters(); for (unsigned resultIdx : diffResultIndices->getIndices()) { // Handle formal original result. - if (resultIdx < originalFnTy->getNumResults()) { + if (resultIdx < firstSemanticParamResultIdx) { auto resultType = originalResults[resultIdx].getInterfaceType(); addRequirement(resultType); - continue; + } else if (resultIdx < firstYieldResultIndex) { + // Handle original semantic result parameters. + auto resultParamIndex = resultIdx - originalFnTy->getNumResults(); + auto resultParamIt = std::next( + originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + resultParamIndex); + auto paramIndex = + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType()); + } else { + // Handle formal original yields. + assert(originalFnTy->isCoroutine()); + assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce); + auto yieldResultIndex = resultIdx - firstYieldResultIndex; + addRequirement(originalFnTy->getYields()[yieldResultIndex].getInterfaceType()); } - // Handle original semantic result parameters. 
- // FIXME: Constraint generic yields when we will start supporting them - auto resultParamIndex = resultIdx - originalFnTy->getNumResults(); - auto resultParamIt = std::next( - originalFnTy->getAutoDiffSemanticResultsParameters().begin(), - resultParamIndex); - auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); - addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType()); } return buildGenericSignature(ctx, derivativeGenSig, diff --git a/lib/Demangling/Demangler.cpp b/lib/Demangling/Demangler.cpp index e542e67495f57..4d84a18fe202c 100644 --- a/lib/Demangling/Demangler.cpp +++ b/lib/Demangling/Demangler.cpp @@ -2292,11 +2292,11 @@ NodePointer Demangler::demangleImplFunctionType() { const char *CoroAttr = nullptr; if (nextIf('A')) - CoroAttr = "@yield_once"; + CoroAttr = "yield_once"; else if (nextIf('G')) - CoroAttr = "@yield_many"; + CoroAttr = "yield_many"; if (CoroAttr) - type->addChild(createNode(Node::Kind::ImplFunctionAttribute, CoroAttr), *this); + type->addChild(createNode(Node::Kind::ImplCoroutineKind, CoroAttr), *this); if (nextIf('h')) { type->addChild(createNode(Node::Kind::ImplFunctionAttribute, "@Sendable"), diff --git a/lib/Demangling/NodePrinter.cpp b/lib/Demangling/NodePrinter.cpp index 9c2687b13edd7..0b65a1a28d75b 100644 --- a/lib/Demangling/NodePrinter.cpp +++ b/lib/Demangling/NodePrinter.cpp @@ -433,6 +433,7 @@ class NodePrinter { case Node::Kind::ImplFunctionConvention: case Node::Kind::ImplFunctionConventionName: case Node::Kind::ImplFunctionType: + case Node::Kind::ImplCoroutineKind: case Node::Kind::ImplInvocationSubstitutions: case Node::Kind::ImplPatternSubstitutions: case Node::Kind::ImplicitClosure: @@ -2759,6 +2760,13 @@ NodePointer NodePrinter::print(NodePointer Node, unsigned depth, return nullptr; case Node::Kind::ImplErasedIsolation: Printer << "@isolated(any)"; + return nullptr; + case Node::Kind::ImplCoroutineKind: + // Skip if text is empty. 
+ if (Node->getText().empty()) + return nullptr; + // Otherwise, print with leading @. + Printer << '@' << Node->getText(); return nullptr; case Node::Kind::ImplTransferringResult: Printer << "transferring"; diff --git a/lib/Demangling/OldRemangler.cpp b/lib/Demangling/OldRemangler.cpp index f0a028a39a7de..e07e5faedb4d2 100644 --- a/lib/Demangling/OldRemangler.cpp +++ b/lib/Demangling/OldRemangler.cpp @@ -1652,14 +1652,23 @@ ManglingError Remangler::mangleImplFunctionType(Node *node, unsigned depth) { return ManglingError::Success; } -ManglingError Remangler::mangleImplFunctionAttribute(Node *node, - unsigned depth) { +ManglingError Remangler::mangleImplCoroutineKind(Node *node, + unsigned depth) { StringRef text = node->getText(); - if (text == "@yield_once") { + if (text == "yield_once") { Buffer << "A"; - } else if (text == "@yield_many") { + } else if (text == "yield_many") { Buffer << "G"; - } else if (text == "@Sendable") { + } else { + return MANGLING_ERROR(ManglingError::InvalidImplCoroutineKind, node); + } + return ManglingError::Success; +} + +ManglingError Remangler::mangleImplFunctionAttribute(Node *node, + unsigned depth) { + StringRef text = node->getText(); + if (text == "@Sendable") { Buffer << "h"; } else if (text == "@async") { Buffer << "H"; diff --git a/lib/Demangling/Remangler.cpp b/lib/Demangling/Remangler.cpp index 7879bb7c6705d..b485f68a6d69e 100644 --- a/lib/Demangling/Remangler.cpp +++ b/lib/Demangling/Remangler.cpp @@ -1917,6 +1917,12 @@ ManglingError Remangler::mangleImplPatternSubstitutions(Node *node, return MANGLING_ERROR(ManglingError::UnsupportedNodeKind, node); } +ManglingError Remangler::mangleImplCoroutineKind(Node *node, + unsigned depth) { + // handled inline + return MANGLING_ERROR(ManglingError::UnsupportedNodeKind, node); +} + ManglingError Remangler::mangleImplFunctionType(Node *node, unsigned depth) { const char *PseudoGeneric = ""; Node *GenSig = nullptr; @@ -2015,10 +2021,21 @@ ManglingError 
Remangler::mangleImplFunctionType(Node *node, unsigned depth) { RETURN_IF_ERROR(mangleImplFunctionConvention(Child, depth + 1)); break; } + case Node::Kind::ImplCoroutineKind: { + char CoroAttr = llvm::StringSwitch(Child->getText()) + .Case("yield_once", 'A') + .Case("yield_many", 'G') + .Default(0); + + if (!CoroAttr) { + return MANGLING_ERROR(ManglingError::InvalidImplCoroutineKind, + Child); + } + Buffer << CoroAttr; + break; + } case Node::Kind::ImplFunctionAttribute: { char FuncAttr = llvm::StringSwitch(Child->getText()) - .Case("@yield_once", 'A') - .Case("@yield_many", 'G') .Case("@Sendable", 'h') .Case("@async", 'H') .Default(0); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 462b53f21fa43..9542becf2a541 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -258,6 +258,13 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { numSemanticResults += getNumAutoDiffSemanticResultsParameters(); + // Check yields. + for (auto yieldAndIndex : enumerate(getYields())) + if (!yieldAndIndex.value().hasOption( + SILParameterInfo::NotDifferentiable)) + resultIndices.push_back(numSemanticResults + yieldAndIndex.index()); + + numSemanticResults += getNumYields(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); } @@ -555,10 +562,14 @@ static CanSILFunctionType getAutoDiffDifferentialType( param.getConvention()); differentialParams.push_back({paramTanType, paramConv}); } + SmallVector differentialResults; + unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults(); + unsigned firstYieldResultIndex = originalFnTy->getNumResults() + + originalFnTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : resultIndices->getIndices()) { // Handle formal original result. 
- if (resultIndex < originalFnTy->getNumResults()) { + if (resultIndex < firstSemanticParamResultIdx) { auto &result = originalResults[resultIndex]; auto resultTanType = getAutoDiffTangentTypeForLinearMap( result.getInterfaceType(), lookupConformance, @@ -571,26 +582,38 @@ static CanSILFunctionType getAutoDiffDifferentialType( result.getConvention()); differentialResults.push_back({resultTanType, resultConv}); continue; - } - // Handle original semantic result parameters. - auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); - auto resultParamIt = std::next( + } else if (resultIndex < firstYieldResultIndex) { + // Handle original semantic result parameters. + auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); + auto resultParamIt = std::next( originalFnTy->getAutoDiffSemanticResultsParameters().begin(), resultParamIndex); - auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); - // If the original semantic result parameter is a differentiability - // parameter, then it already has a corresponding differential - // parameter. Skip adding a corresponding differential result. - if (parameterIndices->contains(paramIndex)) - continue; + auto paramIndex = + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + // If the original semantic result parameter is a differentiability + // parameter, then it already has a corresponding differential + // parameter. Skip adding a corresponding differential result. 
+ if (parameterIndices->contains(paramIndex)) + continue; - auto resultParam = originalFnTy->getParameters()[paramIndex]; - auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( - resultParam.getInterfaceType(), lookupConformance, + auto resultParam = originalFnTy->getParameters()[paramIndex]; + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, + substGenericParams, substReplacements, ctx); + differentialResults.emplace_back(resultParamTanType, + ResultConvention::Indirect); + } else { + assert(originalFnTy->isCoroutine()); + assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce); + auto yieldResultIndex = resultIndex - firstYieldResultIndex; + auto yieldResult = originalFnTy->getYields()[yieldResultIndex]; + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + yieldResult.getInterfaceType(), lookupConformance, substGenericParams, substReplacements, ctx); - differentialResults.emplace_back(resultParamTanType, - ResultConvention::Indirect); + ParameterConvention paramTanConvention = yieldResult.getConvention(); + assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout); + differentialParams.emplace_back(resultParamTanType, paramTanConvention); + } } SubstitutionMap substitutions; @@ -696,11 +719,15 @@ static CanSILFunctionType getAutoDiffPullbackType( return conv; }; - // Collect pullback parameters. + // Collect pullback parameters & yields SmallVector pullbackParams; + SmallVector pullbackYields; + unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults(); + unsigned firstYieldResultIndex = originalFnTy->getNumResults() + + originalFnTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : resultIndices->getIndices()) { // Handle formal original result. 
- if (resultIndex < originalFnTy->getNumResults()) { + if (resultIndex < firstSemanticParamResultIdx) { auto &origRes = originalResults[resultIndex]; auto resultTanType = getAutoDiffTangentTypeForLinearMap( origRes.getInterfaceType(), lookupConformance, @@ -712,28 +739,38 @@ static CanSILFunctionType getAutoDiffPullbackType( ->getCanonicalType(), origRes.getConvention()); pullbackParams.emplace_back(resultTanType, paramConv); - continue; + } else if (resultIndex < firstYieldResultIndex) { + // Handle original semantic result parameters. + auto resultParamIndex = resultIndex - firstSemanticParamResultIdx; + auto resultParamIt = std::next( + originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + resultParamIndex); + auto paramIndex = + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + auto resultParam = originalFnTy->getParameters()[paramIndex]; + // The pullback parameter convention depends on whether the original `inout` + // parameter is a differentiability parameter. + // - If yes, the pullback parameter convention is `@inout`. + // - If no, the pullback parameter convention is `@in_guaranteed`. 
+ auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, + substGenericParams, substReplacements, ctx); + ParameterConvention paramTanConvention = resultParam.getConvention(); + if (!parameterIndices->contains(paramIndex)) + paramTanConvention = ParameterConvention::Indirect_In_Guaranteed; + pullbackParams.emplace_back(resultParamTanType, paramTanConvention); + } else { + assert(originalFnTy->isCoroutine()); + assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce); + auto yieldResultIndex = resultIndex - firstYieldResultIndex; + auto yieldResult = originalFnTy->getYields()[yieldResultIndex]; + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + yieldResult.getInterfaceType(), lookupConformance, + substGenericParams, substReplacements, ctx); + ParameterConvention paramTanConvention = yieldResult.getConvention(); + assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout); + pullbackYields.emplace_back(resultParamTanType, paramTanConvention); } - // Handle original semantic result parameters. - auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); - auto resultParamIt = std::next( - originalFnTy->getAutoDiffSemanticResultsParameters().begin(), - resultParamIndex); - auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); - auto resultParam = originalFnTy->getParameters()[paramIndex]; - // The pullback parameter convention depends on whether the original `inout` - // parameter is a differentiability parameter. - // - If yes, the pullback parameter convention is `@inout`. - // - If no, the pullback parameter convention is `@in_guaranteed`. 
- auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( - resultParam.getInterfaceType(), lookupConformance, - substGenericParams, substReplacements, ctx); - ParameterConvention paramTanConvention = resultParam.getConvention(); - if (!parameterIndices->contains(paramIndex)) - paramTanConvention = ParameterConvention::Indirect_In_Guaranteed; - - pullbackParams.emplace_back(resultParamTanType, paramTanConvention); } // Collect pullback results. @@ -756,6 +793,7 @@ static CanSILFunctionType getAutoDiffPullbackType( param.getConvention()); pullbackResults.push_back({paramTanType, resultTanConvention}); } + SubstitutionMap substitutions; if (!substGenericParams.empty()) { auto genericSig = @@ -766,9 +804,9 @@ static CanSILFunctionType getAutoDiffPullbackType( llvm::ArrayRef(substConformances)); } return SILFunctionType::get( - GenericSignature(), SILFunctionType::ExtInfo(), - originalFnTy->getCoroutineKind(), ParameterConvention::Direct_Guaranteed, - pullbackParams, {}, pullbackResults, std::nullopt, substitutions, + GenericSignature(), SILFunctionType::ExtInfo(), originalFnTy->getCoroutineKind(), + ParameterConvention::Direct_Guaranteed, + pullbackParams, pullbackYields, pullbackResults, std::nullopt, substitutions, /*invocationSubstitutions*/ SubstitutionMap(), ctx); } @@ -920,10 +958,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( // Compute the derivative function results. SmallVector newResults; newResults.reserve(getNumResults() + 1); - for (auto &result : constrainedOriginalFnTy->getResults()) { + for (auto &result : constrainedOriginalFnTy->getResults()) newResults.push_back(result); - } - newResults.push_back({closureType, ResultConvention::Owned}); + newResults.emplace_back(closureType, ResultConvention::Owned); // Compute the derivative function ExtInfo. 
// If original function is `@convention(c)`, the derivative function should diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index e19df5d9b7ea5..3b9701e4b8a43 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -81,6 +81,8 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, s << val << '\n'; }); // Outputs are indirect result buffers and return values, count `m`. + // For the purposes of differentiation, we consider yields to be results as + // well collectAllFormalResultsInTypeOrder(function, outputValues); LLVM_DEBUG({ auto &s = getADDebugStream(); @@ -312,14 +314,17 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands( for (auto incomingValue : incomingValues) setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex); return; - } else if (bbArg->isTerminatorResult()) { + } + + if (bbArg->isTerminatorResult()) { if (TryApplyInst *tai = dyn_cast(bbArg->getTerminatorForResult())) { propagateUseful(tai, dependentVariableIndex); return; - } else - llvm::report_fatal_error("unknown terminator with result"); - } else - llvm::report_fatal_error("do not know how to handle this incoming bb argument"); + } + llvm::report_fatal_error("unknown terminator with result"); + } + + llvm::report_fatal_error("do not know how to handle this incoming bb argument"); } auto *inst = value->getDefiningInstruction(); diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index 4bdbc6e1e9b85..a1cd24887d27b 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -17,6 +17,7 @@ #include "swift/Basic/STLExtras.h" #define DEBUG_TYPE "differentiation" +#include "swift/SIL/ApplySite.h" #include "swift/SILOptimizer/Differentiation/Common.h" #include "swift/AST/TypeCheckRequests.h" 
#include "swift/SILOptimizer/Differentiation/ADContext.h" @@ -145,6 +146,20 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function, auto *argument = function.getArgumentsWithoutIndirectResults()[i]; results.push_back(argument); } + // Treat yields as semantic results. Note that we can only differentiate + // @yield_once with simple control flow, so we can assume that the function + // contains only a single `yield` instruction + auto yieldIt = + std::find_if(function.begin(), function.end(), + [](const SILBasicBlock &BB) -> bool { + const TermInst *TI = BB.getTerminator(); + return isa(TI); + }); + if (yieldIt != function.end()) { + auto *yieldInst = cast(yieldIt->getTerminator()); + for (auto yield : yieldInst->getOperandValues()) + results.push_back(yield); + } } void collectAllDirectResultsInTypeOrder(SILFunction &function, @@ -161,30 +176,30 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function, } void collectAllActualResultsInTypeOrder( - ApplyInst *ai, ArrayRef extractedDirectResults, + FullApplySite fai, ArrayRef extractedDirectResults, SmallVectorImpl &results) { - auto calleeConvs = ai->getSubstCalleeConv(); + auto calleeConvs = fai.getSubstCalleeConv(); unsigned indResIdx = 0, dirResIdx = 0; for (auto &resInfo : calleeConvs.getResults()) { results.push_back(resInfo.isFormalDirect() ? 
extractedDirectResults[dirResIdx++] - : ai->getIndirectSILResults()[indResIdx++]); + : fai.getIndirectSILResults()[indResIdx++]); } } void collectMinimalIndicesForFunctionCall( - ApplyInst *ai, const AutoDiffConfig &parentConfig, + FullApplySite ai, const AutoDiffConfig &parentConfig, const DifferentiableActivityInfo &activityInfo, SmallVectorImpl &results, SmallVectorImpl ¶mIndices, SmallVectorImpl &resultIndices) { - auto calleeFnTy = ai->getSubstCalleeType(); - auto calleeConvs = ai->getSubstCalleeConv(); + auto calleeFnTy = ai.getSubstCalleeType(); + auto calleeConvs = ai.getSubstCalleeConv(); // Parameter indices are indices (in the callee type signature) of parameter // arguments that are varied or are arguments. // Record all parameter indices in type order. unsigned currentParamIdx = 0; - for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) { + for (auto applyArg : ai.getArgumentsWithoutIndirectResults()) { if (activityInfo.isActive(applyArg, parentConfig)) paramIndices.push_back(currentParamIdx); ++currentParamIdx; @@ -196,7 +211,7 @@ void collectMinimalIndicesForFunctionCall( forEachApplyDirectResult(ai, [&](SILValue directResult) { directResults.push_back(directResult); }); - auto indirectResults = ai->getIndirectSILResults(); + auto indirectResults = ai.getIndirectSILResults(); // Record all results and result indices in type order. results.reserve(calleeFnTy->getNumResults()); unsigned dirResIdx = 0; @@ -225,10 +240,20 @@ void collectMinimalIndicesForFunctionCall( if (!param.isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); - results.push_back(ai->getArgument(idx)); + results.push_back(ai.getArgument(idx)); resultIndices.push_back(semanticResultParamResultIndex++); } + // Record all yields. While we do not have a way to represent direct yields + // (_read accessors) we run activity analysis for them. These will be + // diagnosed later. 
+ if (BeginApplyInst *bai = dyn_cast(*ai)) { + for (const auto &yieldAndIdx : enumerate(calleeConvs.getYields())) { + results.push_back(bai->getYieldedValues()[yieldAndIdx.index()]); + resultIndices.push_back(semanticResultParamResultIndex++); + } + } + // Make sure the function call has active results. #ifndef NDEBUG assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults()); diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index f6462f01aa5e9..8abb75d6dbd8d 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -16,6 +16,7 @@ #define DEBUG_TYPE "differentiation" +#include "swift/SIL/ApplySite.h" #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" @@ -169,11 +170,11 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB, } -Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { +Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) { SmallVector allResults; SmallVector activeParamIndices; SmallVector activeResultIndices; - collectMinimalIndicesForFunctionCall(ai, config, activityInfo, allResults, + collectMinimalIndicesForFunctionCall(fai, config, activityInfo, allResults, activeParamIndices, activeResultIndices); // Check if there are any active results or arguments. 
If not, skip @@ -183,12 +184,12 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { }); bool hasActiveSemanticResultArgument = false; bool hasActiveArguments = false; - auto numIndirectResults = ai->getNumIndirectResults(); - for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) { - auto arg = ai->getArgumentsWithoutIndirectResults()[argIdx]; + auto numIndirectResults = fai.getNumIndirectSILResults(); + for (auto argIdx : range(fai.getSubstCalleeConv().getNumParameters())) { + auto arg = fai.getArgumentsWithoutIndirectResults()[argIdx]; if (activityInfo.isActive(arg, config)) { hasActiveArguments = true; - auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( + auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg( numIndirectResults + argIdx); if (paramInfo.isAutoDiffSemanticResult()) hasActiveSemanticResultArgument = true; @@ -204,7 +205,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { // parameters from the function type. // - Otherwise, use the active parameters. IndexSubset *parameters; - auto origFnSubstTy = ai->getSubstCalleeType(); + auto origFnSubstTy = fai.getSubstCalleeType(); auto remappedOrigFnSubstTy = remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy)) .castTo() @@ -214,7 +215,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { } else { parameters = IndexSubset::get( original->getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); + fai.getArgumentsWithoutIndirectResults().size(), activeParamIndices); } // Compute differentiability results. auto *results = IndexSubset::get(original->getASTContext(), @@ -224,8 +225,8 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { AutoDiffConfig applyConfig(parameters, results); // Check for non-differentiable original function type. 
- auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType - origFnTy) { + auto checkNondifferentiableOriginalFunctionType = + [&](CanSILFunctionType origFnTy) { // Check non-differentiable arguments. for (auto paramIndex : applyConfig.parameterIndices->getIndices()) { auto remappedParamType = @@ -234,12 +235,22 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { return true; } // Check non-differentiable results. + unsigned firstSemanticParamResultIdx = origFnTy->getNumResults(); + unsigned firstYieldResultIndex = origFnTy->getNumResults() + + origFnTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : applyConfig.resultIndices->getIndices()) { SILType remappedResultType; - if (resultIndex >= origFnTy->getNumResults()) { - auto semanticResultArgIdx = resultIndex - origFnTy->getNumResults(); + if (resultIndex >= firstYieldResultIndex) { + auto yieldResultIdx = resultIndex - firstYieldResultIndex; + const auto& yield = origFnTy->getYields()[yieldResultIdx]; + // We do not have a good way to differentiate direct yields + if (!yield.isAutoDiffSemanticResult()) + return true; + remappedResultType = yield.getSILStorageInterfaceType(); + } else if (resultIndex >= firstSemanticParamResultIdx) { + auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx; auto semanticResultArg = - *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + *std::next(fai.getAutoDiffSemanticResultArguments().begin(), semanticResultArgIdx); remappedResultType = semanticResultArg->getType(); } else { @@ -263,8 +274,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { derivative->getModule().getSwiftModule())) ->getUnsubstitutedType(original->getModule()); - auto derivativeFnResultTypes = derivativeFnType->getAllResultsInterfaceType(); - auto linearMapSILType = derivativeFnResultTypes; + auto linearMapSILType = derivativeFnType->getAllResultsInterfaceType(); if (auto tupleType = 
linearMapSILType.getAs()) { linearMapSILType = SILType::getPrimitiveObjectType( tupleType.getElementType(tupleType->getElements().size() - 1)); @@ -366,21 +376,27 @@ void LinearMapInfo::generateDifferentiationDataStructures( // special-case pullback generation. Linear map tuples should be empty. } else { for (auto &inst : *origBB) { - if (auto *ai = dyn_cast(&inst)) { - // Add linear map field to struct for active `apply` instructions. + if (auto *ai = dyn_cast(&inst)) // Skip array literal intrinsic applications since array literal // initialization is linear and handled separately. - if (!shouldDifferentiateApplySite(ai) || - ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) + if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC) || + ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) continue; - if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) - continue; - LLVM_DEBUG(getADDebugStream() - << "Adding linear map tuple field for " << *ai); - if (Type linearMapType = getLinearMapType(context, ai)) { - linearMapIndexMap.insert({ai, linearTupleTypes.size()}); - linearTupleTypes.emplace_back(linearMapType); - } + + if (!isa(&inst)) + continue; + + FullApplySite fai(&inst); + // Add linear map field to struct for active apply sites instructions. + if (!shouldDifferentiateApplySite(fai)) + continue; + + LLVM_DEBUG(getADDebugStream() + << "Adding linear map tuple field for " << inst); + if (Type linearMapType = getLinearMapType(context, fai)) { + LLVM_DEBUG(getADDebugStream() << "Computed type: " << linearMapType << '\n'); + linearMapIndexMap.insert({fai, linearTupleTypes.size()}); + linearTupleTypes.emplace_back(linearMapType); } } } @@ -513,12 +529,18 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { // Should differentiate any allocation instruction that has an active result. 
if ((isa(inst) && hasActiveResults)) return true; + // Should differentiate end_apply if the corresponding begin_apply is + // differentiable + if (auto *eai = dyn_cast(inst)) + return shouldDifferentiateApplySite(eai->getBeginApply()); if (hasActiveOperands) { // Should differentiate any instruction that performs reference counting, // lifetime ending, access ending, or destroying on an active operand. if (isa(inst) || isa(inst) || isa(inst) || isa(inst) || - isa(inst) || isa(inst)) + isa(inst) || isa(inst) || + isa(inst) || + isa(inst)) return true; } diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 8fb4813d7518d..58c06b4e4869e 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -30,6 +30,7 @@ #include "swift/AST/Expr.h" #include "swift/AST/PropertyWrappers.h" #include "swift/AST/TypeCheckRequests.h" +#include "swift/SIL/ApplySite.h" #include "swift/SIL/InstructionUtils.h" #include "swift/SIL/Projection.h" #include "swift/SIL/TypeSubstCloner.h" @@ -184,11 +185,11 @@ class PullbackCloner::Implementation final /// Returns the pullback tuple element value corresponding to the given /// original block and apply inst. 
- SILValue getPullbackTupleElement(ApplyInst *ai) { - unsigned idx = getPullbackInfo().lookUpLinearMapIndex(ai); - assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) && + SILValue getPullbackTupleElement(FullApplySite fai) { + unsigned idx = getPullbackInfo().lookUpLinearMapIndex(fai); + assert((idx > 0 || (idx == 0 && fai.getParent()->isEntry())) && "impossible linear map index"); - auto values = pullbackTupleElements.lookup(ai->getParentBlock()); + auto values = pullbackTupleElements.lookup(fai.getParent()); assert(idx < values.size() && "pullback tuple element for this apply does not exist!"); return values[idx]; @@ -951,6 +952,7 @@ class PullbackCloner::Implementation final // `store` and `copy_addr` support. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) return; + auto loc = ai->getLoc(); auto *bb = ai->getParent(); // Handle `array.finalize_intrinsic` applications. @@ -965,33 +967,44 @@ class PullbackCloner::Implementation final addAdjointValue(bb, origArg, adjResult, loc); return; } + + buildPullbackCall(ai); + } + + void buildPullbackCall(FullApplySite fai) { + auto loc = fai->getLoc(); + auto *bb = fai->getParent(); + // Replace a call to a function with a call to its pullback. auto &nestedApplyInfo = getContext().getNestedApplyInfo(); - auto applyInfoLookup = nestedApplyInfo.find(ai); + auto applyInfoLookup = nestedApplyInfo.find(fai); // If no `NestedApplyInfo` was found, then this task doesn't need to be // differentiated. if (applyInfoLookup == nestedApplyInfo.end()) { // Must not be active. - assert(!getActivityInfo().isActive(ai, getConfig())); + // TODO: Do we need to check token result for begin_apply? + SILValue result = fai.getResult(); + assert(!result || !getActivityInfo().isActive(result, getConfig())); return; } - auto applyInfo = applyInfoLookup->getSecond(); + auto &applyInfo = applyInfoLookup->getSecond(); // Get the original result of the `apply` instruction. 
+ const auto &conv = fai.getSubstCalleeConv(); SmallVector origDirectResults; - forEachApplyDirectResult(ai, [&](SILValue directResult) { + forEachApplyDirectResult(fai, [&](SILValue directResult) { origDirectResults.push_back(directResult); }); SmallVector origAllResults; - collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); + collectAllActualResultsInTypeOrder(fai, origDirectResults, origAllResults); // Append semantic result arguments after original results. for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { - auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( - ai->getNumIndirectResults() + paramIdx); + unsigned argIdx = fai.getNumIndirectSILResults() + paramIdx; + auto paramInfo = conv.getParamInfoForSILArg(argIdx); if (!paramInfo.isAutoDiffSemanticResult()) continue; origAllResults.push_back( - ai->getArgumentsWithoutIndirectResults()[paramIdx]); + fai.getArgumentsWithoutIndirectResults()[paramIdx]); } // Get callee pullback arguments. @@ -999,7 +1012,7 @@ class PullbackCloner::Implementation final // Handle callee pullback indirect results. // Create local allocations for these and destroy them after the call. - auto pullback = getPullbackTupleElement(ai); + auto pullback = getPullbackTupleElement(fai); auto pullbackType = remapType(pullback->getType()).castTo(); @@ -1016,7 +1029,12 @@ class PullbackCloner::Implementation final } // Collect callee pullback formal arguments. + unsigned firstSemanticParamResultIdx = conv.getResults().size(); + unsigned firstYieldResultIndex = firstSemanticParamResultIdx + + conv.getNumAutoDiffSemanticResultParameters(); for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { + if (resultIndex >= firstYieldResultIndex) + continue; assert(resultIndex < origAllResults.size()); auto origResult = origAllResults[resultIndex]; // Get the seed (i.e. adjoint value of the original result). 
@@ -1034,22 +1052,40 @@ class PullbackCloner::Implementation final // If callee pullback was reabstracted in VJP, reabstract callee pullback. if (applyInfo.originalPullbackType) { + auto toType = *applyInfo.originalPullbackType; SILOptFunctionBuilder fb(getContext().getTransform()); - pullback = reabstractFunction( - builder, fb, loc, pullback, *applyInfo.originalPullbackType, + if (toType->isCoroutine()) + pullback = reabstractCoroutine( + builder, fb, loc, pullback, toType, + [this](SubstitutionMap subs) -> SubstitutionMap { + return this->remapSubstitutionMap(subs); + }); + else + pullback = reabstractFunction( + builder, fb, loc, pullback, toType, [this](SubstitutionMap subs) -> SubstitutionMap { return this->remapSubstitutionMap(subs); }); } // Call the callee pullback. - auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), - args); - builder.emitDestroyValueOperation(loc, pullback); - - // Extract all results from `pullbackCall`. + FullApplySite pullbackCall; SmallVector dirResults; - extractAllElements(pullbackCall, builder, dirResults); + if (actualPullbackType->isCoroutine()) { + pullbackCall = builder.createBeginApply(loc, pullback, SubstitutionMap(), + args); + // Record pullback and begin_apply token: the pullback will be consumed + // after end_apply. + applyInfo.pullback = pullback; + applyInfo.beginApplyToken = cast(pullbackCall)->getTokenResult(); + } else { + pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), + args); + builder.emitDestroyValueOperation(loc, pullback); + // Extract all results from `pullbackCall`. + extractAllElements(cast(pullbackCall), builder, dirResults); + } + // Get all results in type-defined order. SmallVector allResults; collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); @@ -1063,10 +1099,10 @@ class PullbackCloner::Implementation final // Accumulate adjoints for original differentiation parameters. 
auto allResultsIt = allResults.begin(); for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { - auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); + unsigned argIdx = fai.getNumIndirectSILResults() + i; + auto origArg = fai.getArgument(argIdx); // Skip adjoint accumulation for semantic results arguments. - auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( - ai->getNumIndirectResults() + i); + auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg(argIdx); if (paramInfo.isAutoDiffSemanticResult()) continue; auto tan = *allResultsIt++; @@ -1086,6 +1122,22 @@ class PullbackCloner::Implementation final } } } + + // Propagate adjoints for yields + if (actualPullbackType->isCoroutine()) { + auto originalYields = cast(fai)->getYieldedValues(); + auto pullbackYields = cast(pullbackCall)->getYieldedValues(); + assert(originalYields.size() == pullbackYields.size()); + + for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { + if (resultIndex < firstYieldResultIndex) + continue; + + auto yieldResultIndex = resultIndex - firstYieldResultIndex; + setAdjointBuffer(bb, originalYields[yieldResultIndex], pullbackYields[yieldResultIndex]); + } + } + // Destroy unused pullback direct results. Needed for pullback results from // VJPs extracted from `@differentiable` function callees, where the // `@differentiable` function's differentiation parameter indices are a @@ -1096,6 +1148,7 @@ class PullbackCloner::Implementation final continue; builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult); } + // Destroy and deallocate pullback indirect results. for (auto *alloc : llvm::reverse(pullbackIndirectResults)) { builder.emitDestroyAddrAndFold(loc, alloc); @@ -1103,15 +1156,55 @@ class PullbackCloner::Implementation final } } - void visitBeginApplyInst(BeginApplyInst *bai) { - // Diagnose `begin_apply` instructions. - // Coroutine differentiation is not yet supported. 
+ void visitAbortApplyInst(AbortApplyInst *aai) { + BeginApplyInst *bai = aai->getBeginApply(); + assert(getPullbackInfo().shouldDifferentiateApplySite(bai)); + + // abort_apply differentiation is not yet supported. getContext().emitNondifferentiabilityError( bai, getInvoker(), diag::autodiff_coroutines_not_supported); errorOccurred = true; - return; } + void visitEndApplyInst(EndApplyInst *eai) { + BeginApplyInst *bai = eai->getBeginApply(); + assert(getPullbackInfo().shouldDifferentiateApplySite(bai)); + + // Replace a call to a function with a call to its pullback. + auto &nestedApplyInfo = getContext().getNestedApplyInfo(); + auto applyInfoLookup = nestedApplyInfo.find(bai); + // If no `NestedApplyInfo` was found, then this task doesn't need to be + // differentiated. + if (applyInfoLookup == nestedApplyInfo.end()) { + // Must not be active. + assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig())); + assert(!getActivityInfo().isActive(eai, getConfig())); + return; + } + + buildPullbackCall(bai); + } + + void visitBeginApplyInst(BeginApplyInst *bai) { + assert(getPullbackInfo().shouldDifferentiateApplySite(bai)); + + auto &nestedApplyInfo = getContext().getNestedApplyInfo(); + auto applyInfoLookup = nestedApplyInfo.find(bai); + // If no `NestedApplyInfo` was found, then this task doesn't need to be + // differentiated. + if (applyInfoLookup == nestedApplyInfo.end()) { + // Must not be active. + assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig())); + return; + } + auto applyInfo = applyInfoLookup->getSecond(); + + auto loc = bai->getLoc(); + builder.createEndApply(loc, applyInfo.beginApplyToken, + SILType::getEmptyTupleType(getASTContext())); + builder.emitDestroyValueOperation(loc, applyInfo.pullback); + } + /// Handle `struct` instruction. /// Original: y = struct (x0, x1, x2, ...) 
/// Adjoint: adj[x0] += struct_extract adj[y], #x0 @@ -1882,6 +1975,7 @@ class PullbackCloner::Implementation final NO_ADJOINT(Return) NO_ADJOINT(Branch) NO_ADJOINT(CondBranch) + NO_ADJOINT(Yield) // Address projections. NO_ADJOINT(StructElementAddr) @@ -2047,6 +2141,10 @@ bool PullbackCloner::Implementation::run() { if (Projection::isAddressProjection(v)) return false; + // Co-routines borrow adjoint buffers for yields + if (isa_and_nonnull(v.getDefiningInstruction())) + return false; + // Check that active values are differentiable. Otherwise we may crash // later when tangent space is required, but not available. if (!getTangentSpace(remapType(type).getASTType())) { @@ -2215,8 +2313,9 @@ bool PullbackCloner::Implementation::run() { // The pullback function has type: // `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., )|context_obj) -> (d_arg0, ..., d_argn)`. + auto conv = getOriginal().getConventions(); auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); - assert(getConfig().resultIndices->getNumIndices() == pbParamArgs.size() - numVals && + assert(getConfig().resultIndices->getNumIndices() - conv.getNumYields() == pbParamArgs.size() - numVals && pbParamArgs.size() >= 1); // Assign adjoints for original result. 
builder.setCurrentDebugScope( @@ -2224,7 +2323,15 @@ bool PullbackCloner::Implementation::run() { builder.setInsertionPoint(pullbackEntry, getNextFunctionLocalAllocationInsertionPoint()); unsigned seedIndex = 0; + unsigned firstSemanticParamResultIdx = conv.getResults().size(); + unsigned firstYieldResultIndex = firstSemanticParamResultIdx + + conv.getNumAutoDiffSemanticResultParameters(); for (auto resultIndex : getConfig().resultIndices->getIndices()) { + // Yields seed buffers are only to be touched in yield BB and required + // special handling + if (resultIndex >= firstYieldResultIndex) + continue; + auto origResult = origFormalResults[resultIndex]; auto *seed = pbParamArgs[seedIndex]; if (seed->getType().isAddress()) { @@ -2235,6 +2342,10 @@ bool PullbackCloner::Implementation::run() { if (seedParamInfo.isIndirectInOut()) { setAdjointBuffer(originalExitBlock, origResult, seed); + LLVM_DEBUG(getADDebugStream() + << "Assigned seed buffer " << *seed + << " as the adjoint of original indirect result " + << origResult); } // Otherwise, assign a copy of the seed argument as the adjoint buffer of // the original result. @@ -2289,7 +2400,6 @@ bool PullbackCloner::Implementation::run() { // This vector will identify the locations where initialization is needed. SmallBitVector outputsToInitialize; - auto conv = getOriginal().getConventions(); auto origParams = getOriginal().getArgumentsWithoutIndirectResults(); // Materializes the return element corresponding to the parameter @@ -2718,6 +2828,27 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { if (bb->isEntry()) return; + // If the original block is a resume yield destination, then we need to yield + // the adjoint buffer and do everything else in the resume destination. Unwind + // destination is unreachable as the co-routine can never be aborted. 
+ if (auto *predBB = bb->getSinglePredecessorBlock()) { + if (auto *yield = dyn_cast(predBB->getTerminator())) { + auto *resumeBB = pbBB->split(builder.getInsertionPoint()); + auto *unwindBB = getPullback().createBasicBlock(); + + SmallVector adjYields; + for (auto yieldedVal : yield->getYieldedValues()) + adjYields.push_back(getAdjointBuffer(bb, yieldedVal)); + + builder.createYield(yield->getLoc(), adjYields, resumeBB, unwindBB); + builder.setInsertionPoint(unwindBB); + builder.createUnreachable(SILLocation::invalid()); + + pbBB = resumeBB; + builder.setInsertionPoint(pbBB); + } + } + // Otherwise, add a `switch_enum` terminator for non-exit // pullback blocks. // 1. Get the pullback struct pullback block argument. @@ -2819,7 +2950,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { origPredpullbackSuccBBMap[predBB] = pullbackSuccBB; auto *enumEltDecl = getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb); - pullbackSuccessorCases.push_back({enumEltDecl, pullbackSuccBB}); + pullbackSuccessorCases.emplace_back(enumEltDecl, pullbackSuccBB); } // Values are trampolined by only a subset of pullback successor blocks. // Other successors blocks should destroy the value. diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 116ea3a2d228a..163d7dbb170a6 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -326,6 +326,41 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb, return thunk; } +// FIXME: This is pretty rudimentary as of now as there is no proper AST type +// for coroutine and therefore we cannot e.g. store a coroutine into a tuple or +// do other things that are allowed with first-class function types. For now we +// have to unsafely bitcast coroutine to function type and vice versa. This +// function should be rethought when we will have proper AST coroutine types. 
+SILValue reabstractCoroutine( + SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc, + SILValue fn, CanSILFunctionType toType, + std::function remapSubstitutions) { + auto &module = *fn->getModule(); + auto fromType = fn->getType().getAs(); + auto unsubstFromType = fromType->getUnsubstitutedType(module); + auto unsubstToType = toType->getUnsubstitutedType(module); + + LLVM_DEBUG(auto &s = getADDebugStream() << "Converting coroutine\n"; + s << " From type: " << fromType << '\n'; + s << " To type: " << toType << '\n'; s << '\n'); + + if (fromType != unsubstFromType) + fn = builder.createConvertFunction( + loc, fn, SILType::getPrimitiveObjectType(unsubstFromType), + /*withoutActuallyEscaping*/ false); + + fn = builder.createConvertFunction(loc, fn, + SILType::getPrimitiveObjectType(unsubstToType), + /*withoutActuallyEscaping*/ false); + + if (toType != unsubstToType) + fn = builder.createConvertFunction(loc, fn, + SILType::getPrimitiveObjectType(toType), + /*withoutActuallyEscaping*/ false); + + return fn; +} + SILValue reabstractFunction( SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc, SILValue fn, CanSILFunctionType toType, diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index 86180ef170fcb..b329e6ea4dd97 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -151,6 +151,14 @@ class VJPCloner::Implementation final return vjp->getLoweredType(pattern, type); } + SILType getPullbackType() { + auto vjpFuncTy = vjp->getLoweredFunctionType(); + const auto &conv = vjp->getConventions(); + + return conv.getSILType(vjpFuncTy->getResults().back(), + vjp->getTypeExpansionContext()); + } + GenericSignature getBuiltinAutoDiffAllocateSubcontextDecl() { if (builtinAutoDiffAllocateSubcontextGenericSignature) return builtinAutoDiffAllocateSubcontextGenericSignature; @@ -268,10 +276,7 @@ class VJPCloner::Implementation final 
ParameterConvention::Direct_Guaranteed); } - auto pullbackType = vjp->mapTypeIntoContext( - vjp->getConventions().getSILType( - vjp->getLoweredFunctionType()->getResults().back(), - vjp->getTypeExpansionContext())); + auto pullbackType = vjp->mapTypeIntoContext(getPullbackType()); auto pullbackFnType = pullbackType.castTo(); auto pullbackSubstType = pullbackPartialApply->getType().castTo(); @@ -295,10 +300,38 @@ class VJPCloner::Implementation final SmallVector directResults; directResults.append(origResults.begin(), origResults.end()); directResults.push_back(pullbackValue); + Builder.createReturn(ri->getLoc(), joinElements(directResults, Builder, loc)); } + void visitUnwindInst(UnwindInst *ui) { + Builder.setCurrentDebugScope(getOpScope(ui->getDebugScope())); + auto loc = ui->getLoc(); + auto *origExit = ui->getParent(); + + // Consume unused pullback values + if (borrowedPullbackContextValue) { + auto *pbTupleVal = buildPullbackValueTupleValue(ui); + // Initialize the top-level subcontext buffer with the top-level pullback + // tuple. + auto addr = emitProjectTopLevelSubcontext( + Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType()); + Builder.createStore( + loc, pbTupleVal, addr, + pbTupleVal->getType().isTrivial(*pullback) ? + StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); + + Builder.createEndBorrow(loc, borrowedPullbackContextValue); + Builder.emitDestroyValueOperation(loc, pullbackContextValue); + } else { + for (SILValue val : getPullbackValues(origExit)) + Builder.emitDestroyValueOperation(loc, val); + } + + Builder.createUnwind(loc); + } + void visitBranchInst(BranchInst *bi) { Builder.setCurrentDebugScope(getOpScope(bi->getDebugScope())); // Build pullback struct value for original block. @@ -319,6 +352,19 @@ class VJPCloner::Implementation final args); } + void visitYieldInst(YieldInst *yi) { + Builder.setCurrentDebugScope(getOpScope(yi->getDebugScope())); + // Build pullback struct value for original block. 
+ auto *pbTupleVal = buildPullbackValueTupleValue(yi); + // Create a new `yield` instruction. Note that resume / unwind blocks cannot + // have arguments, so we're building trampolines with branch tracing enum + // values. + getBuilder().createYield( + yi->getLoc(), getOpValueArray<1>(yi->getOperandValues()), + createTrampolineBasicBlock(yi, pbTupleVal, yi->getResumeBB()), + createTrampolineBasicBlock(yi, pbTupleVal, yi->getUnwindBB())); + } + void visitCondBranchInst(CondBranchInst *cbi) { Builder.setCurrentDebugScope(getOpScope(cbi->getDebugScope())); // Build pullback struct value for original block. @@ -401,6 +447,285 @@ class VJPCloner::Implementation final ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); } + void visitEndApplyInst(EndApplyInst *eai) { + BeginApplyInst *bai = eai->getBeginApply(); + + // If callee should not be differentiated, do standard cloning. + if (!pullbackInfo.shouldDifferentiateApplySite(bai)) { + LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *bai << '\n'); + TypeSubstCloner::visitEndApplyInst(eai); + return; + } + + Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope())); + auto loc = eai->getLoc(); + auto &builder = getBuilder(); + auto token = getMappedValue(bai->getTokenResult()); + + LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *eai << '\n'); + + FullApplySite fai(token->getDefiningInstruction()); + auto vjpResult = builder.createEndApply(loc, token, fai.getType()); + LLVM_DEBUG(getADDebugStream() << "Created end_apply\n" << *vjpResult); + + builder.emitDestroyValueOperation(loc, fai.getCallee()); + + // Checkpoint the pullback. 
+ SmallVector vjpDirectResults; + extractAllElements(vjpResult, getBuilder(), vjpDirectResults); + ArrayRef originalDirectResults = + ArrayRef(vjpDirectResults).drop_back(1); + SILValue originalDirectResult = + joinElements(originalDirectResults, getBuilder(), loc); + SILValue pullback = vjpDirectResults.back(); + { + auto pullbackFnType = pullback->getType().castTo(); + auto pullbackUnsubstFnType = + pullbackFnType->getUnsubstitutedType(getModule()); + if (pullbackFnType != pullbackUnsubstFnType) { + pullback = builder.createConvertFunction( + loc, pullback, + SILType::getPrimitiveObjectType(pullbackUnsubstFnType), + /*withoutActuallyEscaping*/ false); + } + } + + // Store the original result to the value map. + mapValue(eai, originalDirectResult); + + auto pullbackType = pullbackInfo.lookUpLinearMapType(bai); + + // If actual pullback type does not match lowered pullback type, reabstract + // the pullback using a thunk. + auto actualPullbackType = + getOpType(pullback->getType()).getAs(); + auto loweredPullbackType = + getOpType(getLoweredType(pullbackType)).castTo(); + + auto applyInfoIt = context.getNestedApplyInfo().find(bai); + assert(applyInfoIt != context.getNestedApplyInfo().end()); + if (!loweredPullbackType->isEqual(actualPullbackType)) { + // Set non-reabstracted original pullback type in nested apply info. + applyInfoIt->second.originalPullbackType = actualPullbackType; + SILOptFunctionBuilder fb(context.getTransform()); + pullback = reabstractCoroutine( + getBuilder(), fb, loc, pullback, loweredPullbackType, + [this](SubstitutionMap subs) -> SubstitutionMap { + return this->getOpSubstitutionMap(subs); + }); + } + unsigned pullbackIdx = applyInfoIt->second.pullbackIdx; + pullbackValues[bai->getParent()][pullbackIdx] = pullback; + + // Some instructions that produce the callee may have been cloned. + // If the original callee did not have any users beyond this `begin_apply`, + // recursively kill the cloned callee.
+    // Test the instruction kind with `dyn_cast_or_null` instead of asserting
+    // it with `cast_or_null`: nothing here guarantees that the callee's
+    // defining instruction is a `SingleValueInstruction`, and in that case the
+    // cleanup should simply be skipped.
+    if (auto *origCallee = dyn_cast_or_null<SingleValueInstruction>(
+            bai->getCallee()->getDefiningInstruction()))
+      if (origCallee->hasOneUse())
+        recursivelyDeleteTriviallyDeadInstructions(
+            getOpValue(origCallee)->getDefiningInstruction());
+  }
+
+  // Check and diagnose non-differentiable original function type.
+  //
+  // Walks the parameter indices and then the result indices selected by
+  // `config`. On the first one whose type is not differentiable, emits a
+  // non-differentiability error (anchored at the offending argument of `fai`
+  // for parameters, at `origCallee` for results) and returns true. Returns
+  // false when nothing was diagnosed. This helper does not touch
+  // `errorOccurred`; callers are expected to set it themselves.
+  bool diagnoseNondifferentiableOriginalFunctionType(CanSILFunctionType originalFnTy,
+                                                     FullApplySite fai, SILValue origCallee,
+                                                     const AutoDiffConfig &config) const {
+    // Check and diagnose non-differentiable arguments.
+    for (auto paramIndex : config.parameterIndices->getIndices()) {
+      if (!originalFnTy->getParameters()[paramIndex]
+               .getSILStorageInterfaceType()
+               .isDifferentiable(getModule())) {
+        auto arg = fai.getArgumentsWithoutIndirectResults()[paramIndex];
+        // FIXME: This shouldn't be necessary and might indicate a bug in
+        // the transformation.
+        RegularLocation nonAutoGenLoc(arg.getLoc());
+        nonAutoGenLoc.markNonAutoGenerated();
+        auto startLoc = nonAutoGenLoc.getStartSourceLoc();
+        auto endLoc = nonAutoGenLoc.getEndSourceLoc();
+        context.emitNondifferentiabilityError(
+            arg, invoker, diag::autodiff_nondifferentiable_argument)
+            .fixItInsert(startLoc, "withoutDerivative(at: ")
+            .fixItInsertAfter(endLoc, ")");
+        return true;
+      }
+    }
+
+    // Check and diagnose non-differentiable results.
+ unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults(); + unsigned firstYieldResultIndex = originalFnTy->getNumResults() + + originalFnTy->getNumAutoDiffSemanticResultsParameters(); + + for (auto resultIndex : config.resultIndices->getIndices()) { + SILType remappedResultType; + if (resultIndex >= firstYieldResultIndex) { + auto yieldResultIdx = resultIndex - firstYieldResultIndex; + const auto& yield = originalFnTy->getYields()[yieldResultIdx]; + // We do not have a good way to differentiate direct yields + if (yield.isAutoDiffSemanticResult()) + remappedResultType = yield.getSILStorageInterfaceType(); + else { + context.emitNondifferentiabilityError( + origCallee, invoker, + diag::autodiff_cannot_differentiate_through_direct_yield); + return true; + } + } else if (resultIndex >= firstSemanticParamResultIdx) { + auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx; + auto semanticResultArg = + *std::next(fai.getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx); + remappedResultType = semanticResultArg->getType(); + } else { + remappedResultType = originalFnTy->getResults()[resultIndex] + .getSILStorageInterfaceType(); + } + + if (!remappedResultType || !remappedResultType.isDifferentiable(getModule())) { + auto startLoc = fai.getLoc().getStartSourceLoc(); + auto endLoc = fai.getLoc().getEndSourceLoc(); + context.emitNondifferentiabilityError( + origCallee, invoker, + diag::autodiff_nondifferentiable_result) + .fixItInsert(startLoc, "withoutDerivative(at: ") + .fixItInsertAfter(endLoc, ")"); + return true; + } + } + + return false; + } + + void visitBeginApplyInst(BeginApplyInst *bai) { + // If callee should not be differentiated, do standard cloning. 
+ if (!pullbackInfo.shouldDifferentiateApplySite(bai)) { + LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *bai << '\n'); + TypeSubstCloner::visitBeginApplyInst(bai); + return; + } + + Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope())); + auto loc = bai->getLoc(); + auto &builder = getBuilder(); + auto origCallee = getOpValue(bai->getCallee()); + auto originalFnTy = origCallee->getType().castTo(); + + LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *bai << '\n'); + + + SmallVector allResults; + SmallVector activeParamIndices; + SmallVector activeResultIndices; + collectMinimalIndicesForFunctionCall(bai, getConfig(), activityInfo, + allResults, activeParamIndices, + activeResultIndices); + assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); + assert(!activeResultIndices.empty() && "Result indices cannot be empty"); + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; + llvm::interleave( + activeParamIndices.begin(), activeParamIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "), results=("; llvm::interleave( + activeResultIndices.begin(), activeResultIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << ")\n";); + + // Form expected indices. + AutoDiffConfig config( + IndexSubset::get(getASTContext(), + bai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices), + IndexSubset::get(getASTContext(), + bai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), + activeResultIndices)); + + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, + bai, origCallee, config)) { + errorOccurred = true; + return; + } + + // Emit the VJP. + SILValue vjpValue; + + // If the original `begin_apply` instruction has a substitution map, then the + // applied function is specialized. + // In the VJP, specialization is also necessary for parity.
The original + // function operand is specialized with a remapped version of same + // substitution map using an argument-less `partial_apply`. + if (bai->getSubstitutionMap().empty()) { + origCallee = builder.emitCopyValueOperation(loc, origCallee); + } else { + auto substMap = getOpSubstitutionMap(bai->getSubstitutionMap()); + auto vjpPartialApply = getBuilder().createPartialApply( + bai->getLoc(), origCallee, substMap, {}, + ParameterConvention::Direct_Guaranteed); + origCallee = vjpPartialApply; + originalFnTy = origCallee->getType().castTo(); + + // Diagnose if new original function type is non-differentiable. + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, + bai, origCallee, config)) { + errorOccurred = true; + return; + } + } + + auto *diffFuncInst = + context.createDifferentiableFunction(getBuilder(), loc, + config.parameterIndices, config.resultIndices, + origCallee); + + // Record the `differentiable_function` instruction. + context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); + + builder.emitScopedBorrowOperation( + loc, diffFuncInst, + [&](SILValue borrowedADFunc) { + auto extractedVJP = + getBuilder().createDifferentiableFunctionExtract( + loc, NormalDifferentiableFunctionTypeComponent::VJP, + borrowedADFunc); + vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); + }); + builder.emitDestroyValueOperation(loc, diffFuncInst); + + // Record desired/actual VJP indices. + // Temporarily set original pullback type to `None`. + NestedApplyInfo info{config, /*originalPullbackType*/ std::nullopt}; + auto insertion = context.getNestedApplyInfo().try_emplace(bai, info); + auto &nestedApplyInfo = insertion.first->getSecond(); + nestedApplyInfo = info; + + // Call the VJP using the original parameters. 
+ SmallVector vjpArgs; + auto vjpFnTy = getOpType(vjpValue->getType()).castTo(); + auto numVJPArgs = + vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); + vjpArgs.reserve(numVJPArgs); + // Collect substituted arguments. + for (auto origArg : bai->getArguments()) + vjpArgs.push_back(getOpValue(origArg)); + + // Apply the VJP. + // The VJP should be specialized, so no substitution map is necessary. + auto *vjpCall = getBuilder().createBeginApply(loc, vjpValue, SubstitutionMap(), + vjpArgs, bai->getApplyOptions()); + LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); + // Note that vjpValue is destroyed after end_apply + + // Store all the results (yields and token) to the value map. + assert(bai->getNumResults() == vjpCall->getNumResults()); + for (unsigned i = 0; i < vjpCall->getNumResults(); ++i) + mapValue(bai->getResult(i), vjpCall->getResult(i)); + + // Reserve a slot for the pullback and record its index. The slot holds an + // empty `SILValue` until `visitEndApplyInst` stores the actual pullback. + // NOTE(review): assumes that store happens before this block's pullback + // values are read - confirm for `begin_apply`/`end_apply` pairs that span + // basic blocks. + nestedApplyInfo.pullbackIdx = pullbackValues[bai->getParent()].size(); + pullbackValues[bai->getParent()].push_back(SILValue()); + + // The rest of the cloning magic happens during `end_apply` cloning. + } + // If an `apply` has active results or active inout arguments, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai) { @@ -515,60 +840,11 @@ class VJPCloner::Implementation final } } - // Check and diagnose non-differentiable original function type. - auto diagnoseNondifferentiableOriginalFunctionType = - [&](CanSILFunctionType origFnTy) { - // Check and diagnose non-differentiable arguments. - for (auto paramIndex : config.parameterIndices->getIndices()) { - if (!originalFnTy->getParameters()[paramIndex] - .getSILStorageInterfaceType() - .isDifferentiable(getModule())) { - auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex]; - // FIXME: This shouldn't be necessary and might indicate a bug in - // the transformation.
- RegularLocation nonAutoGenLoc(arg.getLoc()); - nonAutoGenLoc.markNonAutoGenerated(); - auto startLoc = nonAutoGenLoc.getStartSourceLoc(); - auto endLoc = nonAutoGenLoc.getEndSourceLoc(); - context - .emitNondifferentiabilityError( - arg, invoker, diag::autodiff_nondifferentiable_argument) - .fixItInsert(startLoc, "withoutDerivative(at: ") - .fixItInsertAfter(endLoc, ")"); - errorOccurred = true; - return true; - } - } - // Check and diagnose non-differentiable results. - for (auto resultIndex : config.resultIndices->getIndices()) { - SILType remappedResultType; - if (resultIndex >= originalFnTy->getNumResults()) { - auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults(); - auto semanticResultArg = - *std::next(ai->getAutoDiffSemanticResultArguments().begin(), - semanticResultArgIdx); - remappedResultType = semanticResultArg->getType(); - } else { - remappedResultType = originalFnTy->getResults()[resultIndex] - .getSILStorageInterfaceType(); - } - if (!remappedResultType.isDifferentiable(getModule())) { - auto startLoc = ai->getLoc().getStartSourceLoc(); - auto endLoc = ai->getLoc().getEndSourceLoc(); - context - .emitNondifferentiabilityError( - origCallee, invoker, - diag::autodiff_nondifferentiable_result) - .fixItInsert(startLoc, "withoutDerivative(at: ") - .fixItInsertAfter(endLoc, ")"); - errorOccurred = true; - return true; - } - } - return false; - }; - if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, + ai, origCallee, config)) { + errorOccurred = true; return; + } // If VJP has not yet been found, emit an `differentiable_function` // instruction on the remapped original function operand and @@ -601,9 +877,13 @@ class VJPCloner::Implementation final ParameterConvention::Direct_Guaranteed); origCallee = vjpPartialApply; originalFnTy = origCallee->getType().castTo(); + // Diagnose if new original function type is non-differentiable. 
- if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, + ai, origCallee, config)) { + errorOccurred = true; return; + } } auto *diffFuncInst = context.createDifferentiableFunction( @@ -690,6 +970,7 @@ class VJPCloner::Implementation final return this->getOpSubstitutionMap(subs); }); } + nestedApplyInfo.pullbackIdx = pullbackValues[ai->getParent()].size(); pullbackValues[ai->getParent()].push_back(pullback); // Some instructions that produce the callee may have been cloned. @@ -893,6 +1174,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { // Results of the pullback are in the tangent space of the original // parameters. SmallVector pbParams; + SmallVector pbYields; SmallVector adjResults; auto origParams = origTy->getParameters(); auto config = witness->getConfig(); @@ -906,9 +1188,12 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { semanticResultParamIndices.push_back(i); } + unsigned firstSemanticParamResultIdx = origTy->getNumResults(); + unsigned firstYieldResultIndex = firstSemanticParamResultIdx + + origTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : config.resultIndices->getIndices()) { // Handle formal result. - if (resultIndex < origTy->getNumResults()) { + if (resultIndex < firstSemanticParamResultIdx) { auto origResult = origTy->getResults()[resultIndex]; origResult = origResult.getWithInterfaceType( origResult.getInterfaceType()->getReducedType(witnessCanGenSig)); @@ -919,36 +1204,51 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { ->getReducedType(witnessCanGenSig), origResult.getConvention()); pbParams.push_back(paramInfo); - continue; - } - - // Handle semantic result parameter. 
- unsigned paramIndex = 0; - unsigned resultParamIndex = 0; - for (auto i : range(origTy->getNumParameters())) { - auto origParam = origTy->getParameters()[i]; - if (!origParam.isAutoDiffSemanticResult()) { + } else if (resultIndex < firstYieldResultIndex) { + // Handle semantic result parameter. + unsigned paramIndex = 0; + unsigned resultParamIndex = 0; + for (auto i : range(origTy->getNumParameters())) { + auto origParam = origTy->getParameters()[i]; + if (!origParam.isAutoDiffSemanticResult()) { + ++paramIndex; + continue; + } + if (resultParamIndex == resultIndex - firstSemanticParamResultIdx) + break; ++paramIndex; - continue; + ++resultParamIndex; } - if (resultParamIndex == resultIndex - origTy->getNumResults()) - break; - ++paramIndex; - ++resultParamIndex; + auto resultParam = origParams[paramIndex]; + auto origResult = resultParam.getWithInterfaceType( + resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); + + auto resultParamTanConvention = resultParam.getConvention(); + if (!config.isWrtParameter(paramIndex)) + resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; + + pbParams.emplace_back(origResult.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getReducedType(witnessCanGenSig), + resultParamTanConvention); + } else { + assert(origTy->isCoroutine()); + assert(origTy->getCoroutineKind() == SILCoroutineKind::YieldOnce); + + auto yieldResultIndex = resultIndex - firstYieldResultIndex; + auto yieldResult = origTy->getYields()[yieldResultIndex]; + auto origYield = + yieldResult.getWithInterfaceType( + yieldResult.getInterfaceType()->getReducedType(witnessCanGenSig)); + assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout); + + pbYields.emplace_back(origYield.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getReducedType(witnessCanGenSig), + yieldResult.getConvention()); } - auto resultParam = origParams[paramIndex]; - auto 
origResult = resultParam.getWithInterfaceType( - resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); - - auto resultParamTanConvention = resultParam.getConvention(); - if (!config.isWrtParameter(paramIndex)) - resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; - - pbParams.emplace_back(origResult.getInterfaceType() - ->getAutoDiffTangentSpace(lookupConformance) - ->getType() - ->getReducedType(witnessCanGenSig), - resultParamTanConvention); } if (pullbackInfo.hasHeapAllocatedContext()) { @@ -958,7 +1258,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { ParameterConvention::Direct_Guaranteed }); } else { - // Accept a pullback struct in the pullback parameter list. This is the + // Accept a pullback tuple in the pullback parameter list. This is the // returned pullback's closure context. auto *origExit = &*original->findReturnBB(); auto pbTupleType = @@ -992,7 +1292,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { auto *pbGenericEnv = pbGenericSig.getGenericEnvironment(); auto pbType = SILFunctionType::get( pbGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(), - origTy->getCalleeConvention(), pbParams, {}, adjResults, std::nullopt, + origTy->getCalleeConvention(), pbParams, pbYields, adjResults, std::nullopt, origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(), original->getASTContext()); @@ -1115,8 +1415,10 @@ bool VJPCloner::Implementation::run() { emitLinearMapContextInitializationIfNeeded(); // Clone. - SmallVector entryArgs(entry->getArguments().begin(), - entry->getArguments().end()); + SmallVector entryArgs; + entryArgs.assign(entry->getArguments().begin(), + entry->getArguments().end()); + cloneFunctionBody(original, entry, entryArgs); // If errors occurred, back out. 
if (errorOccurred) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 51ff587a094ff..9faea438c6889 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -156,6 +156,19 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, isa(term) || isa(term) || isa(term)) continue; + + // We can differentiate only indirect yields. + if (auto *yi = dyn_cast(term)) { +#ifndef NDEBUG + for (const auto &val : yi->getAllOperands()) { + // This should be diagnosed earlier in VJPCloner. + assert(yi->getYieldInfoForOperand(val).isAutoDiffSemanticResult() && + "unsupported result"); + } +#endif + continue; + } + // If terminator is an unsupported branching terminator, emit an error. if (term->isBranch()) { context.emitNondifferentiabilityError( @@ -541,10 +554,20 @@ emitDerivativeFunctionReference( } } // Check and diagnose non-differentiable results. + unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults(); + unsigned firstYieldResultIndex = originalFnTy->getNumResults() + + originalFnTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : desiredResultIndices->getIndices()) { SILType resultType; - if (resultIndex >= originalFnTy->getNumResults()) { - auto semanticResultParamIdx = resultIndex - originalFnTy->getNumResults(); + if (resultIndex >= firstYieldResultIndex) { + auto yieldResultIndex = resultIndex - firstYieldResultIndex; + auto yield = originalFnTy->getYields()[yieldResultIndex]; + // We can only differentiate indirect yields. This should be diagnosed + // earlier in VJPCloner. 
+ assert(yield.isAutoDiffSemanticResult() && "unsupported result"); + resultType = yield.getSILStorageInterfaceType(); + } else if (resultIndex >= firstSemanticParamResultIdx) { + auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx; auto semanticResultParam = *std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(), semanticResultParamIdx); @@ -553,7 +576,7 @@ emitDerivativeFunctionReference( resultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); } - if (!resultType.isDifferentiable(context.getModule())) { + if (!resultType || !resultType.isDifferentiable(context.getModule())) { context.emitNondifferentiabilityError( original, invoker, diag::autodiff_nondifferentiable_result); return std::nullopt; diff --git a/stdlib/public/runtime/MetadataLookup.cpp b/stdlib/public/runtime/MetadataLookup.cpp index 9bc380c791748..f6ce2bb964e29 100644 --- a/stdlib/public/runtime/MetadataLookup.cpp +++ b/stdlib/public/runtime/MetadataLookup.cpp @@ -2023,7 +2023,9 @@ class DecodedMetadataBuilder { TypeLookupErrorOr createImplFunctionType( Demangle::ImplParameterConvention calleeConvention, + Demangle::ImplCoroutineKind coroutineKind, llvm::ArrayRef> params, + llvm::ArrayRef> yields, llvm::ArrayRef> results, std::optional> errorResult, ImplFunctionTypeFlags flags) { diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index e4b10d66fd7b4..f29e029f513b3 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -571,53 +571,81 @@ func testTryApply(_ x: Float) -> Float { // Coroutine differentiation (`begin_apply`) //===----------------------------------------------------------------------===// -struct HasCoroutineAccessors: Differentiable { +struct HasCoroutineReadAccessors: Differentiable { var stored: Float var computed: Float { // `_read` is a coroutine: `(Self) -> () -> ()`. 
_read { yield stored } + } +} + +struct HasCoroutineModifyAccessors: Differentiable { + var stored: Float + var computed: Float { + get { stored } // `_modify` is a coroutine: `(inout Self) -> () -> ()`. _modify { yield &stored } } } + + // expected-error @+1 {{function is not differentiable}} @differentiable(reverse) // expected-note @+1 {{when differentiating this function definition}} -func testAccessorCoroutines(_ x: HasCoroutineAccessors) -> HasCoroutineAccessors { +func testAccessorCoroutinesRead(_ x: HasCoroutineReadAccessors) -> Float { + // We do not support differentiation of _read accessors + // expected-note @+1 {{cannot differentiate through a '_read' accessor}} + return x.computed +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}testAccessorCoroutinesRead{{.*}} at parameter indices (0) and result indices (0) +// CHECK: [ACTIVE] %0 = argument of bb0 : $HasCoroutineReadAccessors +// CHECK: [NONE] // function_ref HasCoroutineReadAccessors.computed.read +// CHECK: %2 = function_ref @$s17activity_analysis25HasCoroutineReadAccessorsV8computedSfvr : $@yield_once @convention(method) (HasCoroutineReadAccessors) -> @yields Float +// CHECK: [ACTIVE] (**%3**, %4) = begin_apply %2(%0) : $@yield_once @convention(method) (HasCoroutineReadAccessors) -> @yields Float +// CHECK: [VARIED] (%3, **%4**) = begin_apply %2(%0) : $@yield_once @convention(method) (HasCoroutineReadAccessors) -> @yields Float +// CHECK: [VARIED] %5 = end_apply %4 as $() + +@differentiable(reverse) +func testAccessorCoroutinesModify(_ x: HasCoroutineModifyAccessors) -> Float { var x = x - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} - x.computed = x.computed - return x + x.computed *= x.computed + return x.computed } -// CHECK-LABEL: [AD] Activity info for ${{.*}}testAccessorCoroutines{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $HasCoroutineAccessors -// CHECK: [ACTIVE] %2 = alloc_stack [var_decl] 
$HasCoroutineAccessors, var, name "x" -// CHECK: [ACTIVE] %4 = begin_access [read] [static] %2 : $*HasCoroutineAccessors -// CHECK: [ACTIVE] %5 = load [trivial] %4 : $*HasCoroutineAccessors -// CHECK: [NONE] // function_ref HasCoroutineAccessors.computed.read -// CHECK: [ACTIVE] (**%7**, %8) = begin_apply %6(%5) : $@yield_once @convention(method) (HasCoroutineAccessors) -> @yields Float -// CHECK: [VARIED] (%7, **%8**) = begin_apply %6(%5) : $@yield_once @convention(method) (HasCoroutineAccessors) -> @yields Float -// CHECK: [ACTIVE] %9 = alloc_stack $Float -// CHECK: [ACTIVE] %11 = load [trivial] %9 : $*Float -// CHECK: [ACTIVE] %14 = begin_access [modify] [static] %2 : $*HasCoroutineAccessors -// CHECK: [NONE] // function_ref HasCoroutineAccessors.computed.modify -// CHECK: %15 = function_ref @${{.*}}21HasCoroutineAccessorsV8computedSfvM : $@yield_once @convention(method) (@inout HasCoroutineAccessors) -> @yields @inout Float -// CHECK: [ACTIVE] (**%16**, %17) = begin_apply %15(%14) : $@yield_once @convention(method) (@inout HasCoroutineAccessors) -> @yields @inout Float -// CHECK: [VARIED] (%16, **%17**) = begin_apply %15(%14) : $@yield_once @convention(method) (@inout HasCoroutineAccessors) -> @yields @inout Float -// CHECK: [ACTIVE] %22 = begin_access [read] [static] %2 : $*HasCoroutineAccessors -// CHECK: [ACTIVE] %23 = load [trivial] %22 : $*HasCoroutineAccessors +// CHECK-LABEL: [AD] Activity info for ${{.*}}testAccessorCoroutinesModify{{.*}} at parameter indices (0) and result indices (0) +// CHECK: [ACTIVE] %0 = argument of bb0 : $HasCoroutineModifyAccessors +// CHECK: [ACTIVE] %2 = alloc_stack [var_decl] $HasCoroutineModifyAccessors +// CHECK: [USEFUL] %4 = metatype $@thin Float.Type +// CHECK: [ACTIVE] %5 = begin_access [read] [static] %2 : $*HasCoroutineModifyAccessors +// CHECK: [ACTIVE] %6 = load [trivial] %5 : $*HasCoroutineModifyAccessors +// CHECK: [NONE] // function_ref HasCoroutineModifyAccessors.computed.getter +// CHECK: %7 = function_ref 
@$s17activity_analysis27HasCoroutineModifyAccessorsV8computedSfvg : $@convention(method) (HasCoroutineModifyAccessors) -> Float +// CHECK: [ACTIVE] %8 = apply %7(%6) : $@convention(method) (HasCoroutineModifyAccessors) -> Float +// CHECK: [ACTIVE] %10 = begin_access [modify] [static] %2 : $*HasCoroutineModifyAccessors +// CHECK: [NONE] // function_ref HasCoroutineModifyAccessors.computed.modify +// CHECK: %11 = function_ref @$s17activity_analysis27HasCoroutineModifyAccessorsV8computedSfvM : $@yield_once @convention(method) (@inout HasCoroutineModifyAccessors) -> @yields @inout Float +// CHECK: [ACTIVE] (**%12**, %13) = begin_apply %11(%10) : $@yield_once @convention(method) (@inout HasCoroutineModifyAccessors) -> @yields @inout Float +// CHECK: [VARIED] (%12, **%13**) = begin_apply %11(%10) : $@yield_once @convention(method) (@inout HasCoroutineModifyAccessors) -> @yields @inout Float +// CHECK: [NONE] // function_ref static Float.*= infix(_:_:) +// CHECK: %14 = function_ref @$sSf2meoiyySfz_SftFZ : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [NONE] %15 = apply %14(%12, %8, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [VARIED] %16 = end_apply %13 as $() +// CHECK: [ACTIVE] %18 = begin_access [read] [static] %2 : $*HasCoroutineModifyAccessors +// CHECK: [ACTIVE] %19 = load [trivial] %18 : $*HasCoroutineModifyAccessors +// CHECK: [NONE] // function_ref HasCoroutineModifyAccessors.computed.getter +// CHECK: %20 = function_ref @$s17activity_analysis27HasCoroutineModifyAccessorsV8computedSfvg : $@convention(method) (HasCoroutineModifyAccessors) -> Float +// CHECK: [ACTIVE] %21 = apply %20(%19) : $@convention(method) (HasCoroutineModifyAccessors) -> Float // TF-1078: Test `begin_apply` active `inout` argument. // `Array.subscript.modify` is the applied coroutine. 
-// expected-error @+1 {{function is not differentiable}} @differentiable(reverse) -// expected-note @+1 {{when differentiating this function definition}} func testBeginApplyActiveInoutArgument(array: [Float], x: Float) -> Float { var array = array // Array subscript assignment below calls `Array.subscript.modify`. - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} + // expected-error @+2 {{expression is not differentiable}} + // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} array[0] = x return array[0] } @@ -648,14 +676,13 @@ func testBeginApplyActiveInoutArgument(array: [Float], x: Float) -> Float { // TF-1115: Test `begin_apply` active `inout` argument with non-active initial result. -// expected-error @+1 {{function is not differentiable}} @differentiable(reverse) -// expected-note @+1 {{when differentiating this function definition}} func testBeginApplyActiveButInitiallyNonactiveInoutArgument(x: Float) -> Float { // `var array` is initially non-active. var array: [Float] = [0] // Array subscript assignment below calls `Array.subscript.modify`. - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} + // expected-error @+2 {{expression is not differentiable}} + // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} array[0] = x return array[0] } diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index e149bcb3116db..d8219c6e5e52e 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -672,8 +672,10 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {} // accesses. 
struct Struct: Differentiable { - // expected-error @+2 {{expression is not differentiable}} - // expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}} + // expected-error @+4 {{expression is not differentiable}} + // expected-error @+3 {{expression is not differentiable}} + // expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}} + // expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}} @DifferentiableWrapper @DifferentiableWrapper var x: Float = 10 @Wrapper var y: Float = 20 @@ -696,12 +698,9 @@ func projectedValueAccess(_ s: Struct) -> Float { // https://github.com/apple/swift/issues/55084 // Test `wrapperValue.modify` differentiation. -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable(reverse) func modify(_ s: Struct, _ x: Float) -> Float { var s = s - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} s.x *= x * s.z return s.x } @@ -806,45 +805,56 @@ public func f_54819( // Coroutines (SIL function yields, `begin_apply`) (not yet supported) //===----------------------------------------------------------------------===// -struct HasCoroutineAccessors: Differentiable { +struct HasReadAccessors: Differentiable { var stored: Float var computed: Float { // `_read` is a coroutine: `(Self) -> () -> ()`. _read { yield stored } + } +} + +struct HasModifyAccessors: Differentiable { + var stored: Float + var computed: Float { + get { stored } // `_modify` is a coroutine: `(inout Self) -> () -> ()`. 
_modify { yield &stored } } } + // expected-error @+2 {{function is not differentiable}} // expected-note @+2 {{when differentiating this function definition}} @differentiable(reverse) -func testAccessorCoroutines(_ x: HasCoroutineAccessors) -> HasCoroutineAccessors { +func testReadAccessorCoroutines(_ x: HasReadAccessors) -> Float { + // expected-note @+1 {{cannot differentiate through a '_read' accessor}} + return x.computed +} + +@differentiable(reverse) +func testModifyAccessorCoroutines(_ x: HasModifyAccessors) -> Float { var x = x - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} x.computed = x.computed - return x + return x.computed } // TF-1078: Diagnose `_modify` accessor application with active `inout` argument. -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable(reverse) func TF_1078(array: [Float], x: Float) -> Float { var array = array // Array subscript assignment below calls `Array.subscript.modify`. - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} + // expected-error @+2 {{expression is not differentiable}} + // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} array[0] = x return array[0] } // TF-1115: Diagnose `_modify` accessor application with initially non-active `inout` argument. -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable(reverse) func TF_1115(_ x: Float) -> Float { var array: [Float] = [0] // Array subscript assignment below calls `Array.subscript.modify`. 
- // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} + // expected-error @+2 {{expression is not differentiable}} + // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} array[0] = x return array[0] } @@ -862,13 +872,10 @@ extension Float { } } -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable(reverse) func TF_1115_modifyNonSelfProjection(x: Float) -> Float { var result: Float = 0 // Assignment below calls `Float.projection.modify`. - // expected-note @+1 {{differentiation of coroutine calls is not yet supported}} result.projection = x return result }