From 7d81ad8e175414b6e274cd83260ec9cacfda1cc9 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 25 Nov 2020 23:24:23 -0800 Subject: [PATCH] [AutoDiff] Bump-pointer allocate pullback structs in loops. In derivatives of loops, no longer allocate boxes for indirect case payloads. Instead, use a custom pullback context in the runtime which contains a bump-pointer allocator. --- include/swift/AST/ASTContext.h | 3 + include/swift/AST/Builtins.def | 9 + include/swift/Runtime/RuntimeFunctions.def | 24 +++ .../SILOptimizer/Differentiation/Common.h | 10 ++ .../Differentiation/LinearMapInfo.h | 21 ++- .../SILOptimizer/Differentiation/VJPCloner.h | 2 + lib/AST/ASTMangler.cpp | 10 +- lib/AST/ASTVerifier.cpp | 8 +- lib/AST/Availability.cpp | 4 + lib/AST/Builtins.cpp | 28 +++ lib/IDE/CodeCompletion.cpp | 4 +- lib/IRGen/GenBuiltin.cpp | 22 +++ lib/IRGen/GenCall.cpp | 29 +++ lib/IRGen/GenCall.h | 7 + lib/IRGen/IRGenModule.cpp | 8 + lib/SIL/IR/OperandOwnership.cpp | 3 + lib/SIL/IR/ValueOwnership.cpp | 3 + lib/SIL/Utils/MemAccessUtils.cpp | 2 + lib/SILGen/SILGenBuiltin.cpp | 42 +++++ lib/SILOptimizer/Differentiation/Common.cpp | 29 +++ .../Differentiation/JVPCloner.cpp | 10 +- .../Differentiation/LinearMapInfo.cpp | 40 +++-- .../Differentiation/PullbackCloner.cpp | 67 ++++--- .../Differentiation/VJPCloner.cpp | 165 +++++++++++++----- .../AccessEnforcementReleaseSinking.cpp | 3 + lib/Sema/MiscDiagnostics.cpp | 4 +- lib/Serialization/Serialization.cpp | 4 +- stdlib/public/Differentiation/CMakeLists.txt | 2 + stdlib/public/SwiftShims/Visibility.h | 5 + stdlib/public/runtime/AutoDiffSupport.cpp | 72 ++++++++ stdlib/public/runtime/AutoDiffSupport.h | 56 ++++++ stdlib/public/runtime/CMakeLists.txt | 1 + test/AutoDiff/IRGen/runtime.swift | 24 +++ test/AutoDiff/SILGen/autodiff_builtins.swift | 23 +++ 34 files changed, 644 insertions(+), 100 deletions(-) create mode 100644 stdlib/public/runtime/AutoDiffSupport.cpp create mode 100644 stdlib/public/runtime/AutoDiffSupport.h create mode 
100644 test/AutoDiff/IRGen/runtime.swift diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 5e536c5aaaea1..1be39fbd6956f 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -713,6 +713,9 @@ class ASTContext final { /// Get the runtime availability of support for concurrency. AvailabilityContext getConcurrencyAvailability(); + /// Get the runtime availability of support for differentiation. + AvailabilityContext getDifferentiationAvailability(); + /// Get the runtime availability of features introduced in the Swift 5.2 /// compiler for the target platform. AvailabilityContext getSwift52Availability(); diff --git a/include/swift/AST/Builtins.def b/include/swift/AST/Builtins.def index 457b4c1792eed..6b774261ad4a1 100644 --- a/include/swift/AST/Builtins.def +++ b/include/swift/AST/Builtins.def @@ -752,6 +752,15 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskFuture, /// is a pure value and therefore we can consider it as readnone). 
BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special) +// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject +BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "n", Special) + +// autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer +BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special) + +// autoDiffAllocateSubcontext: (Builtin.NativeObject, Builtin.Word) -> Builtin.RawPointer +BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontext, "autoDiffAllocateSubcontext", "", Special) + #undef BUILTIN_MISC_OPERATION_WITH_SILGEN #undef BUILTIN_MISC_OPERATION diff --git a/include/swift/Runtime/RuntimeFunctions.def b/include/swift/Runtime/RuntimeFunctions.def index bc9278e9bab72..d3e6ce86da1ef 100644 --- a/include/swift/Runtime/RuntimeFunctions.def +++ b/include/swift/Runtime/RuntimeFunctions.def @@ -1518,6 +1518,30 @@ FUNCTION(TaskCreateFutureFunc, TaskContinuationFunctionPtrTy, SizeTy), ATTRS(NoUnwind, ArgMemOnly)) +// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(size_t); +FUNCTION(AutoDiffCreateLinearMapContext, + swift_autoDiffCreateLinearMapContext, SwiftCC, + DifferentiationAvailability, + RETURNS(RefCountedPtrTy), + ARGS(SizeTy), + ATTRS(NoUnwind, ArgMemOnly)) + +// void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *); +FUNCTION(AutoDiffProjectTopLevelSubcontext, + swift_autoDiffProjectTopLevelSubcontext, SwiftCC, + DifferentiationAvailability, + RETURNS(Int8PtrTy), + ARGS(RefCountedPtrTy), + ATTRS(NoUnwind, ArgMemOnly)) + +// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t); +FUNCTION(AutoDiffAllocateSubcontext, + swift_autoDiffAllocateSubcontext, SwiftCC, + DifferentiationAvailability, + RETURNS(Int8PtrTy), + ARGS(RefCountedPtrTy, SizeTy), + ATTRS(NoUnwind, ArgMemOnly)) + #undef RETURNS 
#undef ARGS #undef ATTRS diff --git a/include/swift/SILOptimizer/Differentiation/Common.h b/include/swift/SILOptimizer/Differentiation/Common.h index 90e3c26690928..04ebf1d86d5b9 100644 --- a/include/swift/SILOptimizer/Differentiation/Common.h +++ b/include/swift/SILOptimizer/Differentiation/Common.h @@ -192,6 +192,16 @@ void extractAllElements(SILValue value, SILBuilder &builder, void emitZeroIntoBuffer(SILBuilder &builder, CanType type, SILValue bufferAccess, SILLocation loc); +/// Emit a `Builtin.Word` value that represents the given type's memory layout +/// size. +SILValue emitMemoryLayoutSize( + SILBuilder &builder, SILLocation loc, CanType type); + +/// Emit a projection of the top-level subcontext from the context object. +SILValue emitProjectTopLevelSubcontext( + SILBuilder &builder, SILLocation loc, SILValue context, + SILType subcontextType); + //===----------------------------------------------------------------------===// // Utilities for looking up derivatives of functions //===----------------------------------------------------------------------===// diff --git a/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h b/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h index 4a643d304597f..5a392313a5cf1 100644 --- a/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h +++ b/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h @@ -63,6 +63,9 @@ class LinearMapInfo { /// Activity info of the original function. const DifferentiableActivityInfo &activityInfo; + /// The original function's loop info. + SILLoopInfo *loopInfo; + /// Differentiation indices of the function. const SILAutoDiffIndices indices; @@ -86,6 +89,9 @@ class LinearMapInfo { /// Mapping from linear map structs to their branching trace enum fields. llvm::DenseMap linearMapStructEnumFields; + /// Blocks in a loop. + llvm::SmallSetVector blocksInLoop; + /// A synthesized file unit. 
SynthesizedFileUnit &synthesizedFile; @@ -144,7 +150,8 @@ class LinearMapInfo { explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, SILFunction *original, SILFunction *derivative, SILAutoDiffIndices indices, - const DifferentiableActivityInfo &activityInfo); + const DifferentiableActivityInfo &activityInfo, + SILLoopInfo *loopInfo); /// Returns the linear map struct associated with the given original block. StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const { @@ -200,20 +207,28 @@ class LinearMapInfo { /// Returns the branching trace enum field for the linear map struct of the /// given original block. - VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) { + VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const { auto *linearMapStruct = getLinearMapStruct(origBB); return linearMapStructEnumFields.lookup(linearMapStruct); } /// Finds the linear map declaration in the pullback struct for the given /// `apply` instruction in the original function. 
- VarDecl *lookUpLinearMapDecl(ApplyInst *ai) { + VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const { assert(ai->getFunction() == original); auto lookup = linearMapFieldMap.find(ai); assert(lookup != linearMapFieldMap.end() && "No linear map field corresponding to the given `apply`"); return lookup->getSecond(); } + + bool hasLoops() const { + return !blocksInLoop.empty(); + } + + ArrayRef getBlocksInLoop() const { + return blocksInLoop.getArrayRef(); + } }; } // end namespace autodiff diff --git a/include/swift/SILOptimizer/Differentiation/VJPCloner.h b/include/swift/SILOptimizer/Differentiation/VJPCloner.h index ff8ce874021e6..d0b722e5fd29e 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPCloner.h +++ b/include/swift/SILOptimizer/Differentiation/VJPCloner.h @@ -21,6 +21,7 @@ #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" +#include "swift/SIL/LoopInfo.h" namespace swift { namespace autodiff { @@ -52,6 +53,7 @@ class VJPCloner final { const SILAutoDiffIndices getIndices() const; DifferentiationInvoker getInvoker() const; LinearMapInfo &getPullbackInfo() const; + SILLoopInfo *getLoopInfo() const; const DifferentiableActivityInfo &getActivityInfo() const; /// Performs VJP generation on the empty VJP function. Returns true if any diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index b6433344dc0e4..f4af7974f9db1 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -785,8 +785,8 @@ static StringRef getPrivateDiscriminatorIfNecessary(const ValueDecl *decl) { // Mangle non-local private declarations with a textual discriminator // based on their enclosing file. 
- auto topLevelContext = decl->getDeclContext()->getModuleScopeContext(); - auto fileUnit = cast(topLevelContext); + auto topLevelSubcontext = decl->getDeclContext()->getModuleScopeContext(); + auto fileUnit = cast(topLevelSubcontext); Identifier discriminator = fileUnit->getDiscriminatorForPrivateValue(decl); @@ -2900,9 +2900,9 @@ void ASTMangler::appendEntity(const ValueDecl *decl) { void ASTMangler::appendProtocolConformance(const ProtocolConformance *conformance) { GenericSignature contextSig; - auto topLevelContext = + auto topLevelSubcontext = conformance->getDeclContext()->getModuleScopeContext(); - Mod = topLevelContext->getParentModule(); + Mod = topLevelSubcontext->getParentModule(); auto conformingType = conformance->getType(); appendType(conformingType->getCanonicalType()); @@ -2910,7 +2910,7 @@ ASTMangler::appendProtocolConformance(const ProtocolConformance *conformance) { appendProtocolName(conformance->getProtocol()); bool needsModule = true; - if (auto *file = dyn_cast(topLevelContext)) { + if (auto *file = dyn_cast(topLevelSubcontext)) { if (file->getKind() == FileUnitKind::ClangModule || file->getKind() == FileUnitKind::DWARFModule) { if (conformance->getProtocol()->hasClangNode()) diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 6f502b96319c1..200a485eec0e4 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -229,7 +229,7 @@ class Verifier : public ASTWalker { typedef llvm::PointerIntPair ClosureDiscriminatorKey; llvm::DenseMap ClosureDiscriminators; - DeclContext *CanonicalTopLevelContext = nullptr; + DeclContext *CanonicalTopLevelSubcontext = nullptr; Verifier(PointerUnion M, DeclContext *DC) : M(M), @@ -898,9 +898,9 @@ class Verifier : public ASTWalker { DeclContext *getCanonicalDeclContext(DeclContext *DC) { // All we really need to do is use a single TopLevelCodeDecl. 
if (auto topLevel = dyn_cast(DC)) { - if (!CanonicalTopLevelContext) - CanonicalTopLevelContext = topLevel; - return CanonicalTopLevelContext; + if (!CanonicalTopLevelSubcontext) + CanonicalTopLevelSubcontext = topLevel; + return CanonicalTopLevelSubcontext; } // TODO: check for uniqueness of initializer contexts? diff --git a/lib/AST/Availability.cpp b/lib/AST/Availability.cpp index e042a3dda668f..1fee0851ffcf6 100644 --- a/lib/AST/Availability.cpp +++ b/lib/AST/Availability.cpp @@ -327,6 +327,10 @@ AvailabilityContext ASTContext::getConcurrencyAvailability() { return getSwiftFutureAvailability(); } +AvailabilityContext ASTContext::getDifferentiationAvailability() { + return getSwiftFutureAvailability(); +} + AvailabilityContext ASTContext::getSwift52Availability() { auto target = LangOpts.Target; diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index 6125b87112d9e..99ce7c5917168 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -1383,6 +1383,25 @@ static ValueDecl *getCreateAsyncTaskFuture(ASTContext &ctx, Identifier id) { return builder.build(id); } +static ValueDecl *getAutoDiffCreateLinearMapContext(ASTContext &ctx, + Identifier id) { + return getBuiltinFunction( + id, {BuiltinIntegerType::getWordType(ctx)}, ctx.TheNativeObjectType); +} + +static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx, + Identifier id) { + return getBuiltinFunction( + id, {ctx.TheNativeObjectType}, ctx.TheRawPointerType); +} + +static ValueDecl *getAutoDiffAllocateSubcontext(ASTContext &ctx, + Identifier id) { + return getBuiltinFunction( + id, {ctx.TheNativeObjectType, BuiltinIntegerType::getWordType(ctx)}, + ctx.TheRawPointerType); +} + static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) { auto int1Type = BuiltinIntegerType::get(1, Context); auto optionalRawPointerType = BoundGenericEnumType::get( @@ -2549,6 +2568,15 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { case 
BuiltinValueKind::TriggerFallbackDiagnostic: return getTriggerFallbackDiagnosticOperation(Context, Id); + + case BuiltinValueKind::AutoDiffCreateLinearMapContext: + return getAutoDiffCreateLinearMapContext(Context, Id); + + case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext: + return getAutoDiffProjectTopLevelSubcontext(Context, Id); + + case BuiltinValueKind::AutoDiffAllocateSubcontext: + return getAutoDiffAllocateSubcontext(Context, Id); } llvm_unreachable("bad builtin value!"); diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index 48433efc41bd6..a724b609c33b0 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -1677,7 +1677,7 @@ class CodeCompletionCallbacksImpl : public CodeCompletionCallbacks { } // end anonymous namespace namespace { -static bool isTopLevelContext(const DeclContext *DC) { +static bool isTopLevelSubcontext(const DeclContext *DC) { for (; DC && DC->isLocalContext(); DC = DC->getParent()) { switch (DC->getContextKind()) { case DeclContextKind::TopLevelCodeDecl: @@ -2139,7 +2139,7 @@ class CompletionLookup final : public swift::VisibleDeclConsumer { if (CurrDeclContext && D->getModuleContext() == CurrModule) { // Treat global variables from the same source file as local when // completing at top-level. 
- if (isa(D) && isTopLevelContext(CurrDeclContext) && + if (isa(D) && isTopLevelSubcontext(CurrDeclContext) && D->getDeclContext()->getParentSourceFile() == CurrDeclContext->getParentSourceFile()) { return SemanticContextKind::Local; diff --git a/lib/IRGen/GenBuiltin.cpp b/lib/IRGen/GenBuiltin.cpp index 58efe1a268d23..8ad1baba8efb3 100644 --- a/lib/IRGen/GenBuiltin.cpp +++ b/lib/IRGen/GenBuiltin.cpp @@ -1115,5 +1115,27 @@ if (Builtin.ID == BuiltinValueKind::id) { \ return; } + if (Builtin.ID == BuiltinValueKind::AutoDiffCreateLinearMapContext) { + auto topLevelSubcontextSize = args.claimNext(); + out.add(emitAutoDiffCreateLinearMapContext(IGF, topLevelSubcontextSize) + .getAddress()); + return; + } + + if (Builtin.ID == BuiltinValueKind::AutoDiffProjectTopLevelSubcontext) { + Address allocatorAddr(args.claimNext(), IGF.IGM.getPointerAlignment()); + out.add( + emitAutoDiffProjectTopLevelSubcontext(IGF, allocatorAddr).getAddress()); + return; + } + + if (Builtin.ID == BuiltinValueKind::AutoDiffAllocateSubcontext) { + Address allocatorAddr(args.claimNext(), IGF.IGM.getPointerAlignment()); + auto size = args.claimNext(); + out.add( + emitAutoDiffAllocateSubcontext(IGF, allocatorAddr, size).getAddress()); + return; + } + llvm_unreachable("IRGen unimplemented for this builtin!"); } diff --git a/lib/IRGen/GenCall.cpp b/lib/IRGen/GenCall.cpp index bc72e83a5b911..a2005e6ad1f96 100644 --- a/lib/IRGen/GenCall.cpp +++ b/lib/IRGen/GenCall.cpp @@ -4595,3 +4595,32 @@ IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) { PointerAuthInfo(), signature); return fnPtr; } + +Address irgen::emitAutoDiffCreateLinearMapContext( + IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize) { + auto *call = IGF.Builder.CreateCall( + IGF.IGM.getAutoDiffCreateLinearMapContextFn(), {topLevelSubcontextSize}); + call->setDoesNotThrow(); + call->setCallingConv(IGF.IGM.SwiftCC); + return Address(call, IGF.IGM.getPointerAlignment()); +} + +Address 
irgen::emitAutoDiffProjectTopLevelSubcontext( + IRGenFunction &IGF, Address context) { + auto *call = IGF.Builder.CreateCall( + IGF.IGM.getAutoDiffProjectTopLevelSubcontextFn(), + {context.getAddress()}); + call->setDoesNotThrow(); + call->setCallingConv(IGF.IGM.SwiftCC); + return Address(call, IGF.IGM.getPointerAlignment()); +} + +Address irgen::emitAutoDiffAllocateSubcontext( + IRGenFunction &IGF, Address context, llvm::Value *size) { + auto *call = IGF.Builder.CreateCall( + IGF.IGM.getAutoDiffAllocateSubcontextFn(), + {context.getAddress(), size}); + call->setDoesNotThrow(); + call->setCallingConv(IGF.IGM.SwiftCC); + return Address(call, IGF.IGM.getPointerAlignment()); +} diff --git a/lib/IRGen/GenCall.h b/lib/IRGen/GenCall.h index f97833fc77e5e..f3bbddf5e3c19 100644 --- a/lib/IRGen/GenCall.h +++ b/lib/IRGen/GenCall.h @@ -432,6 +432,13 @@ namespace irgen { void emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &layout, CanSILFunctionType fnType); + + Address emitAutoDiffCreateLinearMapContext( + IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize); + Address emitAutoDiffProjectTopLevelSubcontext( + IRGenFunction &IGF, Address context); + Address emitAutoDiffAllocateSubcontext( + IRGenFunction &IGF, Address context, llvm::Value *size); } // end namespace irgen } // end namespace swift diff --git a/lib/IRGen/IRGenModule.cpp b/lib/IRGen/IRGenModule.cpp index 11801186efed5..6ab1eced84c24 100644 --- a/lib/IRGen/IRGenModule.cpp +++ b/lib/IRGen/IRGenModule.cpp @@ -735,6 +735,14 @@ namespace RuntimeConstants { } return RuntimeAvailability::AlwaysAvailable; } + + RuntimeAvailability DifferentiationAvailability(ASTContext &context) { + auto featureAvailability = context.getDifferentiationAvailability(); + if (!isDeploymentAvailabilityContainedIn(context, featureAvailability)) { + return RuntimeAvailability::ConditionallyAvailable; + } + return RuntimeAvailability::AlwaysAvailable; + } } // namespace RuntimeConstants // We don't use enough attributes to justify 
generalizing the diff --git a/lib/SIL/IR/OperandOwnership.cpp b/lib/SIL/IR/OperandOwnership.cpp index d541489df2495..50d7b67922923 100644 --- a/lib/SIL/IR/OperandOwnership.cpp +++ b/lib/SIL/IR/OperandOwnership.cpp @@ -880,6 +880,9 @@ CONSTANT_OWNERSHIP_BUILTIN(Owned, LifetimeEnding, UnsafeGuaranteed) CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CancelAsyncTask) CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CreateAsyncTask) CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CreateAsyncTaskFuture) +CONSTANT_OWNERSHIP_BUILTIN(None, NonLifetimeEnding, AutoDiffCreateLinearMapContext) +CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, AutoDiffAllocateSubcontext) +CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, AutoDiffProjectTopLevelSubcontext) #undef CONSTANT_OWNERSHIP_BUILTIN diff --git a/lib/SIL/IR/ValueOwnership.cpp b/lib/SIL/IR/ValueOwnership.cpp index 1aaa9541a2d66..564b46ae3864d 100644 --- a/lib/SIL/IR/ValueOwnership.cpp +++ b/lib/SIL/IR/ValueOwnership.cpp @@ -545,6 +545,9 @@ CONSTANT_OWNERSHIP_BUILTIN(None, GetCurrentAsyncTask) CONSTANT_OWNERSHIP_BUILTIN(None, CancelAsyncTask) CONSTANT_OWNERSHIP_BUILTIN(Owned, CreateAsyncTask) CONSTANT_OWNERSHIP_BUILTIN(Owned, CreateAsyncTaskFuture) +CONSTANT_OWNERSHIP_BUILTIN(Owned, AutoDiffCreateLinearMapContext) +CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffProjectTopLevelSubcontext) +CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffAllocateSubcontext) #undef CONSTANT_OWNERSHIP_BUILTIN diff --git a/lib/SIL/Utils/MemAccessUtils.cpp b/lib/SIL/Utils/MemAccessUtils.cpp index e01519e9b1573..84d47c80aa822 100644 --- a/lib/SIL/Utils/MemAccessUtils.cpp +++ b/lib/SIL/Utils/MemAccessUtils.cpp @@ -1805,6 +1805,8 @@ static void visitBuiltinAddress(BuiltinInst *builtin, case BuiltinValueKind::CancelAsyncTask: case BuiltinValueKind::CreateAsyncTask: case BuiltinValueKind::CreateAsyncTaskFuture: + case BuiltinValueKind::AutoDiffCreateLinearMapContext: + case BuiltinValueKind::AutoDiffAllocateSubcontext: 
return; // General memory access to a pointer in first operand position. diff --git a/lib/SILGen/SILGenBuiltin.cpp b/lib/SILGen/SILGenBuiltin.cpp index 1d4484d8bec62..5ea705de57a00 100644 --- a/lib/SILGen/SILGenBuiltin.cpp +++ b/lib/SILGen/SILGenBuiltin.cpp @@ -1462,6 +1462,48 @@ static ManagedValue emitBuiltinCreateAsyncTaskFuture( return SGF.emitManagedRValueWithCleanup(apply); } +static ManagedValue emitBuiltinAutoDiffCreateLinearMapContext( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap subs, + ArrayRef args, SGFContext C) { + ASTContext &ctx = SGF.getASTContext(); + auto *builtinApply = SGF.B.createBuiltin( + loc, + ctx.getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffCreateLinearMapContext)), + SILType::getNativeObjectType(ctx), + subs, + /*args*/ {args[0].getValue()}); + return SGF.emitManagedRValueWithCleanup(builtinApply); +} + +static ManagedValue emitBuiltinAutoDiffProjectTopLevelSubcontext( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap subs, + ArrayRef args, SGFContext C) { + ASTContext &ctx = SGF.getASTContext(); + auto *builtinApply = SGF.B.createBuiltin( + loc, + ctx.getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext)), + SILType::getRawPointerType(ctx), + subs, + /*args*/ {args[0].borrow(SGF, loc).getValue()}); + return ManagedValue::forUnmanaged(builtinApply); +} + +static ManagedValue emitBuiltinAutoDiffAllocateSubcontext( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap subs, + ArrayRef args, SGFContext C) { + ASTContext &ctx = SGF.getASTContext(); + auto *builtinApply = SGF.B.createBuiltin( + loc, + ctx.getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)), + SILType::getRawPointerType(ctx), + subs, + /*args*/ {args[0].borrow(SGF, loc).getValue(), args[1].getValue()}); + return ManagedValue::forUnmanaged(builtinApply); +} + Optional SpecializedEmitter::forDecl(SILGenModule &SGM, SILDeclRef function) { // Only consider standalone declarations in the 
Builtin module. diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index 20cdd11a781fe..730c646b73de6 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -399,6 +399,35 @@ void emitZeroIntoBuffer(SILBuilder &builder, CanType type, builder.emitDestroyValueOperation(loc, getter); } +SILValue emitMemoryLayoutSize( + SILBuilder &builder, SILLocation loc, CanType type) { + auto &ctx = builder.getASTContext(); + auto id = ctx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof)); + auto *builtin = cast(getBuiltinValueDecl(ctx, id)); + auto metatypeTy = SILType::getPrimitiveObjectType( + CanMetatypeType::get(type, MetatypeRepresentation::Thin)); + auto metatypeVal = builder.createMetatype(loc, metatypeTy); + return builder.createBuiltin( + loc, id, SILType::getBuiltinWordType(ctx), + SubstitutionMap::get( + builtin->getGenericSignature(), ArrayRef{type}, {}), + {metatypeVal}); +} + +SILValue emitProjectTopLevelSubcontext( + SILBuilder &builder, SILLocation loc, SILValue context, + SILType subcontextType) { + assert(context.getOwnershipKind() == OwnershipKind::Guaranteed); + auto &ctx = builder.getASTContext(); + auto id = ctx.getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext)); + assert(context->getType() == SILType::getNativeObjectType(ctx)); + auto *subcontextAddr = builder.createBuiltin( + loc, id, SILType::getRawPointerType(ctx), SubstitutionMap(), {context}); + return builder.createPointerToAddress( + loc, subcontextAddr, subcontextType.getAddressType(), /*isStrict*/ true); +} + //===----------------------------------------------------------------------===// // Utilities for looking up derivatives of functions //===----------------------------------------------------------------------===// diff --git a/lib/SILOptimizer/Differentiation/JVPCloner.cpp b/lib/SILOptimizer/Differentiation/JVPCloner.cpp index 
b6ee6518ba75a..43d0d2c621dc0 100644 --- a/lib/SILOptimizer/Differentiation/JVPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/JVPCloner.cpp @@ -26,7 +26,9 @@ #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" +#include "swift/SIL/LoopInfo.h" #include "swift/SIL/TypeSubstCloner.h" +#include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "llvm/ADT/DenseMap.h" @@ -57,6 +59,9 @@ class JVPCloner::Implementation final /// Info from activity analysis on the original function. const DifferentiableActivityInfo &activityInfo; + /// The loop info. + SILLoopInfo *loopInfo; + /// The differential info. LinearMapInfo differentialInfo; @@ -1403,8 +1408,11 @@ JVPCloner::Implementation::Implementation(ADContext &context, invoker(invoker), activityInfo(getActivityInfo(context, original, witness->getSILAutoDiffIndices(), jvp)), + loopInfo(context.getPassManager().getAnalysis() + ->get(original)), differentialInfo(context, AutoDiffLinearMapKind::Differential, original, - jvp, witness->getSILAutoDiffIndices(), activityInfo), + jvp, witness->getSILAutoDiffIndices(), activityInfo, + loopInfo), differentialBuilder(SILBuilder( *createEmptyDifferential(context, witness, &differentialInfo))), diffLocalAllocBuilder(getDifferential()) { diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index 6de6cee60781e..a7c48fc386da1 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -23,7 +23,6 @@ #include "swift/AST/ParameterList.h" #include "swift/AST/SourceFile.h" #include "swift/SIL/LoopInfo.h" -#include "swift/SILOptimizer/Analysis/LoopAnalysis.h" namespace swift { namespace autodiff { @@ -56,9 +55,10 @@ static GenericParamList *cloneGenericParameters(ASTContext &ctx, 
LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, SILFunction *original, SILFunction *derivative, SILAutoDiffIndices indices, - const DifferentiableActivityInfo &activityInfo) + const DifferentiableActivityInfo &activityInfo, + SILLoopInfo *loopInfo) : kind(kind), original(original), derivative(derivative), - activityInfo(activityInfo), indices(indices), + activityInfo(activityInfo), loopInfo(loopInfo), indices(indices), synthesizedFile(context.getOrCreateSynthesizedFile(original)), typeConverter(context.getTypeConverter()) { generateDifferentiationDataStructures(context, derivative); @@ -146,21 +146,30 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB, file.addTopLevelDecl(branchingTraceDecl); // Add basic block enum cases. for (auto *predBB : originalBB->getPredecessorBlocks()) { - auto bbId = "bb" + std::to_string(predBB->getDebugID()); - auto *linearMapStruct = getLinearMapStruct(predBB); - assert(linearMapStruct); - auto linearMapStructTy = - linearMapStruct->getDeclaredInterfaceType()->getCanonicalType(); // Create dummy declaration representing enum case parameter. auto *decl = new (astCtx) ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl); decl->setSpecifier(ParamDecl::Specifier::Default); - if (linearMapStructTy->hasArchetype()) - decl->setInterfaceType(linearMapStructTy->mapTypeOutOfContext()); - else - decl->setInterfaceType(linearMapStructTy); + // If predecessor block is in a loop, its linear map struct will be + // indirectly referenced in memory owned by the context object. The payload + // is just a raw pointer. + if (loopInfo->getLoopFor(predBB)) { + blocksInLoop.insert(predBB); + decl->setInterfaceType(astCtx.TheRawPointerType); + } + // Otherwise the payload is the linear map struct. 
+ else { + auto *linearMapStruct = getLinearMapStruct(predBB); + assert(linearMapStruct); + auto linearMapStructTy = + linearMapStruct->getDeclaredInterfaceType()->getCanonicalType(); + decl->setInterfaceType( + linearMapStructTy->hasArchetype() + ? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy); + } // Create enum element and enum case declarations. auto *paramList = ParameterList::create(astCtx, {decl}); + auto bbId = "bb" + std::to_string(predBB->getDebugID()); auto *enumEltDecl = new (astCtx) EnumElementDecl( /*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)), paramList, loc, /*RawValueExpr*/ nullptr, branchingTraceDecl); @@ -173,10 +182,6 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB, // Record enum element declaration. branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl}); } - // If original block is in a loop, mark branching trace enum as indirect. - if (loopInfo->getLoopFor(originalBB)) - branchingTraceDecl->getAttrs().add(new (astCtx) - IndirectAttr(/*Implicit*/ true)); return branchingTraceDecl; } @@ -359,9 +364,6 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai) { void LinearMapInfo::generateDifferentiationDataStructures( ADContext &context, SILFunction *derivativeFn) { auto &astCtx = original->getASTContext(); - auto *loopAnalysis = context.getPassManager().getAnalysis(); - auto *loopInfo = loopAnalysis->get(original); - // Get the derivative function generic signature. CanGenericSignature derivativeFnGenSig = nullptr; if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment()) diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 2fcd7ea574f1a..f2889c63a191f 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -86,9 +86,6 @@ class PullbackCloner::Implementation final /// adjoint buffers. 
llvm::DenseMap, SILValue> bufferMap; - /// Mapping from pullback basic blocks to pullback struct arguments. - llvm::DenseMap pullbackStructArguments; - /// Mapping from pullback struct field declarations to pullback struct /// elements destructured from the linear map basic block argument. In the /// beginning of each pullback basic block, the block's pullback struct is @@ -130,6 +127,9 @@ class PullbackCloner::Implementation final /// The seed arguments of the pullback function. SmallVector seeds; + /// The `AutoDiffLinearMapContext` object, if any. + SILValue contextValue = nullptr; + llvm::BumpPtrAllocator allocator; bool errorOccurred = false; @@ -1895,11 +1895,28 @@ bool PullbackCloner::Implementation::run() { if (origBB == origExit) { assert(pullbackBB->isEntry()); createEntryArguments(&pullback); - auto *mainPullbackStruct = pullbackBB->getArguments().back(); - assert(mainPullbackStruct->getType() == pbStructLoweredType); - pullbackStructArguments[origBB] = mainPullbackStruct; - // Destructure the pullback struct to get the elements. builder.setInsertionPoint(pullbackBB); + // Obtain the context object, if any, and the top-level subcontext, i.e. + // the main pullback struct. + SILValue mainPullbackStruct; + if (getPullbackInfo().hasLoops()) { + // The last argument is the context object (`Builtin.NativeObject`). + contextValue = pullbackBB->getArguments().back(); + assert(contextValue->getType() == + SILType::getNativeObjectType(getASTContext())); + // Load the pullback struct. + auto subcontextAddr = emitProjectTopLevelSubcontext( + builder, pbLoc, contextValue, pbStructLoweredType); + mainPullbackStruct = builder.createLoad( + pbLoc, subcontextAddr, + pbStructLoweredType.isTrivial(getPullback()) ? + LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take); + } else { + // Obtain and destructure pullback struct elements. 
+ mainPullbackStruct = pullbackBB->getArguments().back(); + assert(mainPullbackStruct->getType() == pbStructLoweredType); + } + auto *dsi = builder.createDestructureStruct(pbLoc, mainPullbackStruct); initializePullbackStructElements(origBB, dsi->getResults()); continue; @@ -1938,7 +1955,6 @@ bool PullbackCloner::Implementation::run() { // Add a pullback struct argument. auto *pbStructArg = pullbackBB->createPhiArgument(pbStructLoweredType, OwnershipKind::Owned); - pullbackStructArguments[origBB] = pbStructArg; // Destructure the pullback struct to get the elements. builder.setInsertionPoint(pullbackBB); auto *dsi = builder.createDestructureStruct(pbLoc, pbStructArg); @@ -1969,7 +1985,7 @@ bool PullbackCloner::Implementation::run() { auto *pullbackEntry = pullback.getEntryBlock(); // The pullback function has type: - // `(seed0, seed1, ..., exit_pb_struct) -> (d_arg0, ..., d_argn)`. + // `(seed0, seed1, ..., exit_pb_struct|context_obj) -> (d_arg0, ..., d_argn)`. auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); assert(getIndices().results->getNumIndices() == pbParamArgs.size() - 1 && pbParamArgs.size() >= 2); @@ -2328,17 +2344,22 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor( } // Propagate pullback struct argument. 
SILBuilder pullbackTrampolineBBBuilder(pullbackTrampolineBB); - auto *predPBStructVal = pullbackTrampolineBB->getArguments().front(); - auto boxType = dyn_cast(predPBStructVal->getType().getASTType()); - if (!boxType) { - trampolineArguments.push_back(predPBStructVal); + auto *pullbackTrampolineBBArg = pullbackTrampolineBB->getArguments().front(); + if (vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) { + assert(pullbackTrampolineBBArg->getType() == + SILType::getRawPointerType(getASTContext())); + auto pbStructType = + remapType(getPullbackInfo().getLinearMapStructLoweredType(origPredBB)); + auto predPbStructAddr = pullbackTrampolineBBBuilder.createPointerToAddress( + loc, pullbackTrampolineBBArg, pbStructType.getAddressType(), + /*isStrict*/ true); + auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad( + loc, predPbStructAddr, + pbStructType.isTrivial(getPullback()) ? + LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take); + trampolineArguments.push_back(predPbStructVal); } else { - auto *projectBox = pullbackTrampolineBBBuilder.createProjectBox( - loc, predPBStructVal, /*index*/ 0); - auto loaded = pullbackTrampolineBBBuilder.emitLoadValueOperation( - loc, projectBox, LoadOwnershipQualifier::Copy); - pullbackTrampolineBBBuilder.emitDestroyValueOperation(loc, predPBStructVal); - trampolineArguments.push_back(loaded); + trampolineArguments.push_back(pullbackTrampolineBBArg); } // Branch from pullback trampoline block to pullback block. pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB, @@ -2535,7 +2556,7 @@ bool PullbackCloner::Implementation::runForSemanticMemberGetter() { // Get getter argument and result values. 
// Getter type: $(Self) -> Result - // Pullback type: $(Result', PB_Struct) -> Self' + // Pullback type: $(Result', PB_Struct|Context) -> Self' assert(original.getLoweredFunctionType()->getNumParameters() == 1); assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); assert(pullback.getLoweredFunctionType()->getNumResults() == 1); @@ -2547,8 +2568,10 @@ bool PullbackCloner::Implementation::runForSemanticMemberGetter() { "Getter should have one semantic result"); auto origResult = origFormalResults[*getIndices().results->begin()]; - auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType( - TypeExpansionContext::minimal()); + auto tangentVectorSILTy = pullback.getConventions().getResults().front() + .getSILStorageType(getModule(), + pullback.getLoweredFunctionType(), + TypeExpansionContext::minimal()); auto tangentVectorTy = tangentVectorSILTy.getASTType(); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index 09d26e3b99baa..23bd853e2f3a3 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -27,6 +27,7 @@ #include "swift/SIL/TerminatorUtils.h" #include "swift/SIL/TypeSubstCloner.h" +#include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" #include "swift/SILOptimizer/Utils/CFGOptUtils.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" @@ -64,6 +65,9 @@ class VJPCloner::Implementation final /// Info from activity analysis on the original function. const DifferentiableActivityInfo &activityInfo; + /// The loop info. + SILLoopInfo *loopInfo; + /// The linear map info. LinearMapInfo pullbackInfo; @@ -71,6 +75,16 @@ class VJPCloner::Implementation final /// predecessor enum argument). SmallPtrSet remappedBasicBlocks; + /// The `AutoDiffLinearMapContext` object. 
If null, no explicit context is + /// needed (no loops). + SILValue pullbackContextValue; + /// The unique, borrowed context object. This is valid until the exit block. + SILValue borrowedPullbackContextValue; + + /// The generic signature of the `Builtin.autoDiffAllocateSubcontext(_:_:)` + /// declaration. It is used for creating a builtin call. + GenericSignature builtinAutoDiffAllocateSubcontextGenericSignature; + bool errorOccurred = false; /// Mapping from original blocks to pullback values. Used to build pullback @@ -93,6 +107,31 @@ class VJPCloner::Implementation final /// Run VJP generation. Returns true on error. bool run(); + /// Initializes a context object if needed. + void emitLinearMapContextInitializationIfNeeded() { + if (!pullbackInfo.hasLoops()) + return; + // Get linear map struct size. + auto *returnBB = &*original->findReturnBB(); + auto pullbackStructType = + remapType(pullbackInfo.getLinearMapStructLoweredType(returnBB)); + Builder.setInsertionPoint(vjp->getEntryBlock()); + auto topLevelSubcontextSize = emitMemoryLayoutSize( + Builder, original->getLocation(), pullbackStructType.getASTType()); + // Create a context. + pullbackContextValue = Builder.createBuiltin( + original->getLocation(), + getASTContext().getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffCreateLinearMapContext)), + SILType::getNativeObjectType(getASTContext()), + SubstitutionMap(), {topLevelSubcontextSize}); + borrowedPullbackContextValue = Builder.createBeginBorrow( + original->getLocation(), pullbackContextValue); + LLVM_DEBUG(getADDebugStream() + << "Context object initialized because there are loops\n" + << *vjp->getEntryBlock() << '\n'); + } + /// Get the lowered SIL type of the given AST type. 
SILType getLoweredType(Type type) { auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); @@ -101,11 +140,17 @@ class VJPCloner::Implementation final return vjp->getLoweredType(pattern, type); } - /// Get the lowered SIL type of the given nominal type declaration. - SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { - auto nominalType = - getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); - return getLoweredType(nominalType); + GenericSignature getBuiltinAutoDiffAllocateSubcontextDecl() { + if (builtinAutoDiffAllocateSubcontextGenericSignature) + return builtinAutoDiffAllocateSubcontextGenericSignature; + auto &ctx = getASTContext(); + auto *decl = cast(getBuiltinValueDecl( + ctx, ctx.getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)))); + builtinAutoDiffAllocateSubcontextGenericSignature = + decl->getGenericSignature(); + assert(builtinAutoDiffAllocateSubcontextGenericSignature); + return builtinAutoDiffAllocateSubcontextGenericSignature; } // Creates a trampoline block for given original terminator instruction, the @@ -173,8 +218,6 @@ class VJPCloner::Implementation final void visitReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); - auto &builder = getBuilder(); - // Build pullback struct value for original block. auto *origExit = ri->getParent(); auto *pbStructVal = buildPullbackValueStructValue(ri); @@ -183,17 +226,35 @@ class VJPCloner::Implementation final auto *origRetInst = cast(origExit->getTerminator()); auto origResult = getOpValue(origRetInst->getOperand()); SmallVector origResults; - extractAllElements(origResult, builder, origResults); + extractAllElements(origResult, Builder, origResults); // Get and partially apply the pullback. auto vjpGenericEnv = vjp->getGenericEnvironment(); auto vjpSubstMap = vjpGenericEnv ? 
vjpGenericEnv->getForwardingSubstitutionMap() : vjp->getForwardingSubstitutionMap(); - auto *pullbackRef = builder.createFunctionRef(loc, pullback); - auto *pullbackPartialApply = - builder.createPartialApply(loc, pullbackRef, vjpSubstMap, {pbStructVal}, - ParameterConvention::Direct_Guaranteed); + auto *pullbackRef = Builder.createFunctionRef(loc, pullback); + + // Prepare partial application arguments. + SILValue partialApplyArg; + if (borrowedPullbackContextValue) { + // Initialize the top-level subcontext buffer with the top-level pullback + // struct. + auto addr = emitProjectTopLevelSubcontext( + Builder, loc, borrowedPullbackContextValue, pbStructVal->getType()); + Builder.createStore( + loc, pbStructVal, addr, + pbStructVal->getType().isTrivial(*pullback) ? + StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); + partialApplyArg = pullbackContextValue; + Builder.createEndBorrow(loc, borrowedPullbackContextValue); + } else { + partialApplyArg = pbStructVal; + } + + auto *pullbackPartialApply = Builder.createPartialApply( + loc, pullbackRef, vjpSubstMap, {partialApplyArg}, + ParameterConvention::Direct_Guaranteed); auto pullbackType = vjp->getLoweredFunctionType() ->getResults() .back() @@ -213,7 +274,7 @@ class VJPCloner::Implementation final } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp) .isCompatible()) { pullbackValue = - builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, + Builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, /*withoutActuallyEscaping*/ false); } else { llvm::report_fatal_error("Pullback value type is not ABI-compatible " @@ -224,8 +285,8 @@ class VJPCloner::Implementation final SmallVector directResults; directResults.append(origResults.begin(), origResults.end()); directResults.push_back(pullbackValue); - builder.createReturn(ri->getLoc(), - joinElements(directResults, builder, loc)); + Builder.createReturn(ri->getLoc(), + joinElements(directResults, Builder, 
loc)); } void visitBranchInst(BranchInst *bi) { @@ -700,8 +761,10 @@ VJPCloner::Implementation::Implementation(VJPCloner &cloner, ADContext &context, vjp(vjp), invoker(invoker), activityInfo(getActivityInfoHelper( context, original, witness->getSILAutoDiffIndices(), vjp)), + loopInfo(context.getPassManager().getAnalysis() + ->get(original)), pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, - witness->getSILAutoDiffIndices(), activityInfo) { + witness->getSILAutoDiffIndices(), activityInfo, loopInfo) { // Create empty pullback function. pullback = createEmptyPullback(); context.recordGeneratedFunction(pullback); @@ -728,6 +791,7 @@ const SILAutoDiffIndices VJPCloner::getIndices() const { } DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; } LinearMapInfo &VJPCloner::getPullbackInfo() const { return impl.pullbackInfo; } +SILLoopInfo *VJPCloner::getLoopInfo() const { return impl.loopInfo; } const DifferentiableActivityInfo &VJPCloner::getActivityInfo() const { return impl.activityInfo; } @@ -864,13 +928,21 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { pbParams.push_back(inoutParamTanParam); } - // Accept a pullback struct in the pullback parameter list. This is the - // returned pullback's closure context. - auto *origExit = &*original->findReturnBB(); - auto *pbStruct = pullbackInfo.getLinearMapStruct(origExit); - auto pbStructType = - pbStruct->getDeclaredInterfaceType()->getCanonicalType(witnessCanGenSig); - pbParams.push_back({pbStructType, ParameterConvention::Direct_Owned}); + if (pullbackInfo.hasLoops()) { + // Accept an `AutoDiffLinearMapContext` heap object if there are loops. + pbParams.push_back({ + getASTContext().TheNativeObjectType, + ParameterConvention::Direct_Guaranteed + }); + } else { + // Accept a pullback struct in the pullback parameter list. This is the + // returned pullback's closure context. 
+ auto *origExit = &*original->findReturnBB(); + auto *pbStruct = pullbackInfo.getLinearMapStruct(origExit); + auto pbStructType = + pbStruct->getDeclaredInterfaceType()->getCanonicalType(witnessCanGenSig); + pbParams.push_back({pbStructType, ParameterConvention::Direct_Owned}); + } // Add pullback results for the requested wrt parameters. for (auto i : indices.parameters->getIndices()) { @@ -946,8 +1018,8 @@ VJPCloner::Implementation::buildPullbackValueStructValue(TermInst *termInst) { auto loc = RegularLocation::getAutoGeneratedLocation(); auto origBB = termInst->getParent(); auto *vjpBB = BBMap[origBB]; - auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB); - auto structLoweredTy = getNominalDeclLoweredType(pbStruct); + auto structLoweredTy = + remapType(pullbackInfo.getLinearMapStructLoweredType(origBB)); auto bbPullbackValues = pullbackValues[origBB]; if (!origBB->isEntry()) { auto *predEnumArg = vjpBB->getArguments().back(); @@ -961,25 +1033,36 @@ EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue( SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, SILValue pbStructVal) { auto loc = RegularLocation::getAutoGeneratedLocation(); - auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB); - auto enumLoweredTy = getNominalDeclLoweredType(succEnum); + auto enumLoweredTy = + remapType(pullbackInfo.getBranchingTraceEnumLoweredType(succBB)); auto *enumEltDecl = pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB); auto enumEltType = getOpType(enumLoweredTy.getEnumElementType( enumEltDecl, getModule(), TypeExpansionContext::minimal())); - // If the enum element type does not have a box type (i.e. the enum case is - // not indirect), then directly create an enum. - auto boxType = dyn_cast(enumEltType.getASTType()); - if (!boxType) - return builder.createEnum(loc, pbStructVal, enumEltDecl, enumLoweredTy); - // Otherwise, box the pullback struct value and create an enum. 
- auto *newBox = builder.createAllocBox(loc, boxType); - builder.emitScopedBorrowOperation(loc, newBox, [&](SILValue borrowedBox) { - auto *projectBox = builder.createProjectBox(loc, newBox, /*index*/ 0); - builder.emitStoreValueOperation(loc, pbStructVal, projectBox, - StoreOwnershipQualifier::Init); - }); - return builder.createEnum(loc, newBox, enumEltDecl, enumLoweredTy); + // If the predecessor block is in a loop, its predecessor enum payload is a + // `Builtin.RawPointer`. + if (loopInfo->getLoopFor(predBB)) { + auto rawPtrType = SILType::getRawPointerType(getASTContext()); + assert(enumEltType == rawPtrType); + auto pbStructType = pbStructVal->getType(); + SILValue pbStructSize = + emitMemoryLayoutSize(Builder, loc, pbStructType.getASTType()); + auto rawBufferValue = builder.createBuiltin( + loc, + getASTContext().getIdentifier( + getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)), + rawPtrType, SubstitutionMap(), + {borrowedPullbackContextValue, pbStructSize}); + auto typedBufferValue = builder.createPointerToAddress( + loc, rawBufferValue, pbStructType.getAddressType(), + /*isStrict*/ true); + builder.createStore( + loc, pbStructVal, typedBufferValue, + pbStructType.isTrivial(*pullback) ? + StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); + return builder.createEnum(loc, rawBufferValue, enumEltDecl, enumLoweredTy); + } + return builder.createEnum(loc, pbStructVal, enumEltDecl, enumLoweredTy); } bool VJPCloner::Implementation::run() { @@ -991,6 +1074,8 @@ bool VJPCloner::Implementation::run() { auto *entry = vjp->createBasicBlock(); createEntryArguments(vjp); + emitLinearMapContextInitializationIfNeeded(); + // Clone. 
SmallVector entryArgs(entry->getArguments().begin(), entry->getArguments().end()); diff --git a/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp b/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp index f9c0eac952315..f3a8b7bf6e613 100644 --- a/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp +++ b/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp @@ -141,6 +141,7 @@ static bool isBarrier(SILInstruction *inst) { case BuiltinValueKind::COWBufferForReading: case BuiltinValueKind::IntInstrprofIncrement: case BuiltinValueKind::GetCurrentAsyncTask: + case BuiltinValueKind::AutoDiffCreateLinearMapContext: return false; // Handle some rare builtins that may be sensitive to object lifetime @@ -168,6 +169,8 @@ static bool isBarrier(SILInstruction *inst) { case BuiltinValueKind::CancelAsyncTask: case BuiltinValueKind::CreateAsyncTask: case BuiltinValueKind::CreateAsyncTaskFuture: + case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext: + case BuiltinValueKind::AutoDiffAllocateSubcontext: return true; } } diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index 1c8d2dd969578..3cb0bc4d6639d 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -653,9 +653,9 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC, .fixItInsert(DRE->getStartLoc(), "self."); } - DeclContext *topLevelContext = DC->getModuleScopeContext(); + DeclContext *topLevelSubcontext = DC->getModuleScopeContext(); auto descriptor = UnqualifiedLookupDescriptor( - DeclNameRef(VD->getBaseName()), topLevelContext, SourceLoc()); + DeclNameRef(VD->getBaseName()), topLevelSubcontext, SourceLoc()); auto lookup = evaluateOrDefault(Ctx.evaluator, UnqualifiedLookupRequest{descriptor}, {}); diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 5de3805dc2043..56d4d9d4f450a 100644 --- a/lib/Serialization/Serialization.cpp +++ 
b/lib/Serialization/Serialization.cpp @@ -2605,8 +2605,8 @@ class Serializer::DeclSerializer : public DeclVisitor { storage->hasPrivateAccessor())); if (shouldEmitFilenameForPrivate || shouldEmitPrivateDiscriminator) { - auto topLevelContext = value->getDeclContext()->getModuleScopeContext(); - if (auto *enclosingFile = dyn_cast(topLevelContext)) { + auto topLevelSubcontext = value->getDeclContext()->getModuleScopeContext(); + if (auto *enclosingFile = dyn_cast(topLevelSubcontext)) { if (shouldEmitPrivateDiscriminator) { Identifier discriminator = enclosingFile->getDiscriminatorForPrivateValue(value); diff --git a/stdlib/public/Differentiation/CMakeLists.txt b/stdlib/public/Differentiation/CMakeLists.txt index 48ab091a260c2..c37ae23cf5339 100644 --- a/stdlib/public/Differentiation/CMakeLists.txt +++ b/stdlib/public/Differentiation/CMakeLists.txt @@ -34,6 +34,8 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE SWIFT_MODULE_DEPENDS_HAIKU Glibc SWIFT_MODULE_DEPENDS_WINDOWS CRT + C_COMPILE_FLAGS + -Dswift_Differentiation_EXPORTS SWIFT_COMPILE_FLAGS ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS} -parse-stdlib diff --git a/stdlib/public/SwiftShims/Visibility.h b/stdlib/public/SwiftShims/Visibility.h index 4ddd34a12eaf7..deeec91d27dbd 100644 --- a/stdlib/public/SwiftShims/Visibility.h +++ b/stdlib/public/SwiftShims/Visibility.h @@ -178,6 +178,11 @@ #else #define SWIFT_IMAGE_EXPORTS_swift_Concurrency 0 #endif +#if defined(swift_Differentiation_EXPORTS) +#define SWIFT_IMAGE_EXPORTS_swift_Differentiation 1 +#else +#define SWIFT_IMAGE_EXPORTS_swift_Differentiation 0 +#endif #define SWIFT_EXPORT_FROM_ATTRIBUTE(LIBRARY) \ SWIFT_MACRO_IF(SWIFT_IMAGE_EXPORTS_##LIBRARY, \ diff --git a/stdlib/public/runtime/AutoDiffSupport.cpp b/stdlib/public/runtime/AutoDiffSupport.cpp new file mode 100644 index 0000000000000..77b3add32de3b --- /dev/null +++ b/stdlib/public/runtime/AutoDiffSupport.cpp @@ -0,0 +1,72 @@ +//===--- AutoDiffSupport.cpp 
----------------------------------*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#include "AutoDiffSupport.h" +#include "swift/ABI/Metadata.h" +#include "swift/Runtime/HeapObject.h" + +using namespace swift; +using namespace llvm; + +SWIFT_CC(swift) +static void destroyLinearMapContext(SWIFT_CONTEXT HeapObject *obj) { + free(obj); +} + +/// Heap metadata for a linear map context. +static FullMetadata linearMapContextHeapMetadata = { + { + { + &destroyLinearMapContext + }, + { + /*value witness table*/ nullptr + } + }, + { + MetadataKind::Opaque + } +}; + +AutoDiffLinearMapContext::AutoDiffLinearMapContext() + : HeapObject(&linearMapContextHeapMetadata) { +} + +void *AutoDiffLinearMapContext::projectTopLevelSubcontext() const { + auto offset = alignTo( + sizeof(AutoDiffLinearMapContext), alignof(AutoDiffLinearMapContext)); + return const_cast( + reinterpret_cast(this) + offset); +} + +void *AutoDiffLinearMapContext::allocate(size_t size) { + return allocator.Allocate(size, alignof(AutoDiffLinearMapContext)); +} + +AutoDiffLinearMapContext *swift::swift_autoDiffCreateLinearMapContext( + size_t topLevelLinearMapStructSize) { + auto allocationSize = alignTo( + sizeof(AutoDiffLinearMapContext), alignof(AutoDiffLinearMapContext)) + + topLevelLinearMapStructSize; + auto *buffer = (AutoDiffLinearMapContext *)malloc(allocationSize); + return new (buffer) AutoDiffLinearMapContext; +} + +void *swift::swift_autoDiffProjectTopLevelSubcontext( + AutoDiffLinearMapContext *allocator) { + return allocator->projectTopLevelSubcontext(); +} + +void 
*swift::swift_autoDiffAllocateSubcontext( + AutoDiffLinearMapContext *allocator, size_t size) { + return allocator->allocate(size); +} diff --git a/stdlib/public/runtime/AutoDiffSupport.h b/stdlib/public/runtime/AutoDiffSupport.h new file mode 100644 index 0000000000000..7df152779e5ee --- /dev/null +++ b/stdlib/public/runtime/AutoDiffSupport.h @@ -0,0 +1,56 @@ +//===--- AutoDiffSupport.h ------------------------------------*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_RUNTIME_AUTODIFF_SUPPORT_H +#define SWIFT_RUNTIME_AUTODIFF_SUPPORT_H + +#include "swift/Runtime/HeapObject.h" +#include "swift/Runtime/Config.h" +#include "llvm/Support/Allocator.h" + +namespace swift { + +/// A data structure responsible for efficiently allocating closure contexts for +/// linear maps such as pullbacks, including recursive branching trace enum +/// case payloads. +class AutoDiffLinearMapContext : public HeapObject { +private: + /// The underlying allocator. + // TODO: Use a custom allocator so that the initial slab can be + // tail-allocated. + llvm::BumpPtrAllocator allocator; + +public: + /// Creates a linear map context. + AutoDiffLinearMapContext(); + /// Returns the address of the tail-allocated top-level subcontext. + void *projectTopLevelSubcontext() const; + /// Allocates memory for a new subcontext. + void *allocate(size_t size); +}; + +/// Creates a linear map context with a tail-allocated top-level subcontext. 
+SWIFT_EXPORT_FROM(swift_Differentiation) SWIFT_CC(swift) +AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext( + size_t topLevelSubcontextSize); + +/// Returns the address of the tail-allocated top-level subcontext. +SWIFT_EXPORT_FROM(swift_Differentiation) SWIFT_CC(swift) +void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *); + +/// Allocates memory for a new subcontext. +SWIFT_EXPORT_FROM(swift_Differentiation) SWIFT_CC(swift) +void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t size); + +} + +#endif /* SWIFT_RUNTIME_AUTODIFF_SUPPORT_H */ diff --git a/stdlib/public/runtime/CMakeLists.txt b/stdlib/public/runtime/CMakeLists.txt index 56d5110ef94d8..fd8e29ae1a85e 100644 --- a/stdlib/public/runtime/CMakeLists.txt +++ b/stdlib/public/runtime/CMakeLists.txt @@ -28,6 +28,7 @@ set(swift_runtime_objc_sources set(swift_runtime_sources AnyHashableSupport.cpp Array.cpp + AutoDiffSupport.cpp BackDeployment.cpp Casting.cpp CompatibilityOverride.cpp diff --git a/test/AutoDiff/IRGen/runtime.swift b/test/AutoDiff/IRGen/runtime.swift new file mode 100644 index 0000000000000..d19edc7deffc9 --- /dev/null +++ b/test/AutoDiff/IRGen/runtime.swift @@ -0,0 +1,24 @@ +// RUN: %target-swift-frontend -parse-stdlib %s -emit-ir | %FileCheck %s + +import Swift +import _Differentiation + +struct ExamplePullbackStruct { + var pb0: (T.TangentVector) -> T.TangentVector +} + +@_silgen_name("test_context_builtins") +func test_context_builtins() { + let pbStruct = ExamplePullbackStruct(pb0: { $0 }) + let context = Builtin.autoDiffCreateLinearMapContext(Builtin.sizeof(type(of: pbStruct))) + let topLevelSubctxAddr = Builtin.autoDiffProjectTopLevelSubcontext(context) + UnsafeMutableRawPointer(topLevelSubctxAddr).storeBytes(of: pbStruct, as: type(of: pbStruct)) + let newBuffer = Builtin.autoDiffAllocateSubcontext(context, Builtin.sizeof(type(of: pbStruct))) + UnsafeMutableRawPointer(newBuffer).storeBytes(of: pbStruct, as: type(of: pbStruct)) +} + +// 
CHECK-LABEL: define{{.*}}@test_context_builtins() +// CHECK: entry: +// CHECK: [[CTX:%.*]] = call swiftcc %swift.refcounted* @swift_autoDiffCreateLinearMapContext({{i[0-9]+}} {{.*}}) +// CHECK: call swiftcc i8* @swift_autoDiffProjectTopLevelSubcontext(%swift.refcounted* [[CTX]]) +// CHECK: [[BUF:%.*]] = call swiftcc i8* @swift_autoDiffAllocateSubcontext(%swift.refcounted* [[CTX]], {{i[0-9]+}} {{.*}}) diff --git a/test/AutoDiff/SILGen/autodiff_builtins.swift b/test/AutoDiff/SILGen/autodiff_builtins.swift index f8f521c9d342f..f9806a1f7cf4d 100644 --- a/test/AutoDiff/SILGen/autodiff_builtins.swift +++ b/test/AutoDiff/SILGen/autodiff_builtins.swift @@ -152,3 +152,26 @@ func linearFunction_f_direct_arity1() -> @differentiable(linear) (Float) -> Floa // CHECK: [[THICK_ORIG2:%.*]] = thin_to_thick_function [[ORIG2]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float // CHECK: [[LINEAR:%.*]] = linear_function [parameters 0] [[THICK_ORIG1]] : $@callee_guaranteed (Float) -> Float with_transpose [[THICK_ORIG2]] : $@callee_guaranteed (Float) -> Float // CHECK: return [[LINEAR]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float + +struct ExamplePullbackStruct { + var pb0: (T.TangentVector) -> T.TangentVector +} + +@_silgen_name("test_context_builtins") +func test_context_builtins() { + let pbStruct = ExamplePullbackStruct(pb0: { $0 }) + let context = Builtin.autoDiffCreateLinearMapContext(Builtin.sizeof(type(of: pbStruct))) + let topLevelSubctxAddr = Builtin.autoDiffProjectTopLevelSubcontext(context) + UnsafeMutableRawPointer(topLevelSubctxAddr).storeBytes(of: pbStruct, as: type(of: pbStruct)) + let newBuffer = Builtin.autoDiffAllocateSubcontext(context, Builtin.sizeof(type(of: pbStruct))) + UnsafeMutableRawPointer(newBuffer).storeBytes(of: pbStruct, as: type(of: pbStruct)) +} + +// CHECK-LABEL: sil{{.*}}@test_context_builtins +// CHECK: bb0: +// CHECK: [[CTX:%.*]] = builtin "autoDiffCreateLinearMapContext"({{%.*}} : $Builtin.Word) : 
$Builtin.NativeObject +// CHECK: [[BORROWED_CTX:%.*]] = begin_borrow [[CTX]] : $Builtin.NativeObject +// CHECK: [[BUF:%.*]] = builtin "autoDiffProjectTopLevelSubcontext"([[BORROWED_CTX]] : $Builtin.NativeObject) : $Builtin.RawPointer +// CHECK: [[BORROWED_CTX:%.*]] = begin_borrow [[CTX]] : $Builtin.NativeObject +// CHECK: [[BUF:%.*]] = builtin "autoDiffAllocateSubcontext"([[BORROWED_CTX]] : $Builtin.NativeObject, {{.*}} : $Builtin.Word) : $Builtin.RawPointer +// CHECK: destroy_value [[CTX]]