diff --git a/lib/IRGen/GenCall.cpp b/lib/IRGen/GenCall.cpp index e16a714a2e847..aaa1c11c17cb2 100644 --- a/lib/IRGen/GenCall.cpp +++ b/lib/IRGen/GenCall.cpp @@ -5006,7 +5006,9 @@ static void emitRetconCoroutineEntry( ArrayRef extraArguments, llvm::Constant *allocFn, llvm::Constant *deallocFn, ArrayRef finalArguments) { auto prototype = - IGF.IGM.getOpaquePtr(IGF.IGM.getAddrOfContinuationPrototype(fnType)); + IGF.IGM.getOpaquePtr( + IGF.IGM.getAddrOfContinuationPrototype(fnType, + fnType->getInvocationGenericSignature())); // Call the right 'llvm.coro.id.retcon' variant. SmallVector arguments; arguments.push_back( diff --git a/lib/IRGen/GenDecl.cpp b/lib/IRGen/GenDecl.cpp index e0569dbaa238e..ce2b068e99d5d 100644 --- a/lib/IRGen/GenDecl.cpp +++ b/lib/IRGen/GenDecl.cpp @@ -6280,13 +6280,14 @@ IRGenModule::getAddrOfDefaultAssociatedConformanceAccessor( } llvm::Function * -IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType) { +IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType, + CanGenericSignature sig) { LinkEntity entity = LinkEntity::forCoroutineContinuationPrototype(fnType); llvm::Function *&entry = GlobalFuncs[entity]; if (entry) return entry; - GenericContextScope scope(*this, fnType->getInvocationGenericSignature()); + GenericContextScope scope(*this, sig); auto signature = Signature::forCoroutineContinuation(*this, fnType); LinkInfo link = LinkInfo::get(*this, entity, NotForDefinition); entry = createFunction(*this, link, signature); diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index aa724db88b200..685d0a4ff364c 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -1492,7 +1492,8 @@ class CoroPartialApplicationForwarderEmission auto prototype = subIGF.IGM.getOpaquePtr( subIGF.IGM.getAddrOfContinuationPrototype( cast( - unsubstType->mapTypeOutOfContext()->getCanonicalType()))); + unsubstType->mapTypeOutOfContext()->getCanonicalType()), + origType->getInvocationGenericSignature())); // Use free as our allocator. diff --git a/lib/IRGen/IRGenModule.h b/lib/IRGen/IRGenModule.h index 21ac492717951..d2f902ceeae8f 100644 --- a/lib/IRGen/IRGenModule.h +++ b/lib/IRGen/IRGenModule.h @@ -1871,7 +1871,8 @@ private: \ void emitDynamicReplacementOriginalFunctionThunk(SILFunction *f); - llvm::Function *getAddrOfContinuationPrototype(CanSILFunctionType fnType); + llvm::Function *getAddrOfContinuationPrototype(CanSILFunctionType fnType, + CanGenericSignature sig); Address getAddrOfSILGlobalVariable(SILGlobalVariable *var, const TypeInfo &ti, ForDefinition_t forDefinition); diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index 34bdb94ff7a8a..e0cb55af5e02d 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -4778,7 +4778,11 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) { const auto &coroutine = getLoweredCoroutine(i->getTokenResult()); bool isAbort = ei == nullptr; - auto sig = Signature::forCoroutineContinuation(IGM, i->getOrigCalleeType()); + // Lower the return value in the callee's generic context. + auto origCalleeType = i->getOrigCalleeType(); + GenericContextScope scope(IGM, origCalleeType->getInvocationGenericSignature()); + + auto sig = Signature::forCoroutineContinuation(IGM, origCalleeType); // Cast the continuation pointer to the right function pointer type. auto continuation = coroutine.Continuation; @@ -4786,7 +4790,7 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) { sig.getType()->getPointerTo()); auto schemaAndEntity = - getCoroutineResumeFunctionPointerAuth(IGM, i->getOrigCalleeType()); + getCoroutineResumeFunctionPointerAuth(IGM, origCalleeType); auto pointerAuth = PointerAuthInfo::emit(*this, schemaAndEntity.first, coroutine.getBuffer().getAddress(), schemaAndEntity.second); @@ -4815,16 +4819,15 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) { if (!isAbort) { auto resultType = call->getType(); + Explosion e; if (!resultType->isVoidTy()) { - Explosion e; // FIXME: Do we need to handle ABI-related conversions here? // It seems we cannot have C function convention for coroutines, etc. extractScalarResults(*this, resultType, call, e); - - // NOTE: This inserts a new entry into the LoweredValues DenseMap, - // invalidating the reference held by `coroutine`. - setLoweredExplosion(ei, e); } + // NOTE: This inserts a new entry into the LoweredValues DenseMap, + // invalidating the reference held by `coroutine`. + setLoweredExplosion(ei, e); } } diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index fd4adfd979f0a..6fb6b07332056 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -61,10 +61,10 @@ bool isSemanticMemberAccessor(SILFunction *original) { auto *accessor = dyn_cast(decl); if (!accessor) return false; - // Currently, only getters and setters are supported. - // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors. + // Currently, only getters, setters and _modify accessors are supported. if (accessor->getAccessorKind() != AccessorKind::Get && - accessor->getAccessorKind() != AccessorKind::Set) + accessor->getAccessorKind() != AccessorKind::Set && + accessor->getAccessorKind() != AccessorKind::Modify) return false; // Accessor must come from a `var` declaration. auto *varDecl = dyn_cast(accessor->getStorage()); diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 31af1f8691379..02be267e1d6fc 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -905,6 +905,7 @@ class PullbackCloner::Implementation final bool runForSemanticMemberAccessor(); bool runForSemanticMemberGetter(); bool runForSemanticMemberSetter(); + bool runForSemanticMemberModify(); /// If original result is non-varied, it will always have a zero derivative. /// Skip full pullback generation and simply emit zero derivatives for wrt @@ -2452,7 +2453,8 @@ bool PullbackCloner::Implementation::run() { // If the original function is an accessor with special-case pullback // generation logic, do special-case generation. - if (isSemanticMemberAccessor(&original)) { + bool isSemanticMemberAcc = isSemanticMemberAccessor(&original); + if (isSemanticMemberAcc) { if (runForSemanticMemberAccessor()) return true; } @@ -2730,7 +2732,8 @@ bool PullbackCloner::Implementation::run() { #endif LLVM_DEBUG(getADDebugStream() - << "Generated pullback for " << original.getName() << ":\n" + << "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal") + << " pullback for " << original.getName() << ":\n" << pullback); return errorOccurred; } @@ -3205,7 +3208,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() { return runForSemanticMemberGetter(); case AccessorKind::Set: return runForSemanticMemberSetter(); - // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors. + case AccessorKind::Modify: + return runForSemanticMemberModify(); default: llvm_unreachable("Unsupported accessor kind; inconsistent with " "`isSemanticMemberAccessor`?"); @@ -3389,6 +3393,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() { return false; } +bool PullbackCloner::Implementation::runForSemanticMemberModify() { + auto &original = getOriginal(); + auto &pullback = getPullback(); + auto pbLoc = getPullback().getLocation(); + + auto *accessor = cast(original.getDeclContext()->getAsDecl()); + assert(accessor->getAccessorKind() == AccessorKind::Modify); + + auto *origEntry = original.getEntryBlock(); + // We assume that the accessor has a simple 3-BB structure with yield in the entry BB + // plus resume and unwind BBs + auto *yi = cast(origEntry->getTerminator()); + auto *origResumeBB = yi->getResumeBB(); + + auto *pbEntry = pullback.getEntryBlock(); + builder.setCurrentDebugScope( + remapScope(origEntry->getScopeOfFirstNonMetaInstruction())); + builder.setInsertionPoint(pbEntry); + + // Get _modify accessor argument values. + // Accessor type : $(inout Self) -> @yields @inout Argument + // Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument' + // Normally pullbacks for semantic member accessors are single BB and + // therefore have empty linear map tuple, however, coroutines have a branching + // control flow due to possible coroutine abort, so we need to accommodate for + // this. We keep branch tracing enums in order not to special case in many + // other places. As there is no way to return to coroutine via abort exit, we + // essentially "linearize" a coroutine. + auto loweredFnTy = original.getLoweredFunctionType(); + auto pullbackLoweredFnTy = pullback.getLoweredFunctionType(); + + assert(loweredFnTy->getNumParameters() == 1 && + loweredFnTy->getNumYields() == 1); + assert(pullbackLoweredFnTy->getNumParameters() == 2); + assert(pullbackLoweredFnTy->getNumYields() == 1); + + SILValue origSelf = original.getArgumentsWithoutIndirectResults().front(); + + SmallVector origFormalResults; + collectAllFormalResultsInTypeOrder(original, origFormalResults); + + assert(getConfig().resultIndices->getNumIndices() == 2 && + "Modify accessor should have two semantic results"); + + auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())]; + + // Look up the corresponding field in the tangent space. + auto *origField = cast(accessor->getStorage()); + auto baseType = remapType(origSelf->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, + pbLoc, getInvoker()); + if (!tanField) { + errorOccurred = true; + return true; + } + + auto adjSelf = getAdjointBuffer(origResumeBB, origSelf); + auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField); + // Modify accessors have inout yields and therefore should yield addresses. + assert(getTangentValueCategory(origYield) == SILValueCategory::Address && + "Modify accessors should yield indirect"); + + // Yield the adjoint buffer and do everything else in the resume + // destination. Unwind destination is unreachable as the coroutine can never + // be aborted. + auto *unwindBB = getPullback().createBasicBlock(); + auto *resumeBB = getPullbackBlock(origEntry); + builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB); + builder.setInsertionPoint(unwindBB); + builder.createUnreachable(SILLocation::invalid()); + + builder.setInsertionPoint(resumeBB); + addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc); + + return false; +} + //--------------------------------------------------------------------------// // Adjoint buffer mapping //--------------------------------------------------------------------------// diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index 7a11087553685..d10e888c79464 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -460,6 +460,16 @@ class VJPCloner::Implementation final TypeSubstCloner::visitEndApplyInst(eai); return; } + // If the original function is a semantic member accessor, do standard + // cloning. Semantic member accessors have special pullback generation + // logic, so all `end_apply` instructions can be directly cloned to the VJP. + if (isSemanticMemberAccessor(original)) { + LLVM_DEBUG(getADDebugStream() + << "Cloning `end_apply` in semantic member accessor:\n" + << *eai << '\n'); + TypeSubstCloner::visitEndApplyInst(eai); + return; + } Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope())); auto loc = eai->getLoc(); @@ -607,6 +617,16 @@ class VJPCloner::Implementation final TypeSubstCloner::visitBeginApplyInst(bai); return; } + // If the original function is a semantic member accessor, do standard + // cloning. Semantic member accessors have special pullback generation + // logic, so all `begin_apply` instructions can be directly cloned to the VJP. + if (isSemanticMemberAccessor(original)) { + LLVM_DEBUG(getADDebugStream() + << "Cloning `begin_apply` in semantic member accessor:\n" + << *bai << '\n'); + TypeSubstCloner::visitBeginApplyInst(bai); + return; + } Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope())); auto loc = bai->getLoc(); diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb index c84b785d61d30..624a3c3a4afc4 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb @@ -340,6 +340,29 @@ where { return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs }) } + + @usableFromInline + @derivative(of: *=) + static func _vjpMultiplyAssign(_ lhs: inout Self, _ rhs: Self) -> ( + value: Void, pullback: (inout Self) -> Self) + { + defer { lhs *= rhs } + return ((), { [lhs = lhs] v in + let drhs = lhs * v + v *= rhs + return drhs + }) + } + + @usableFromInline + @derivative(of: *=) + static func _jvpMultiplyAssign(_ lhs: inout Self, _ rhs: Self) -> ( + value: Void, differential: (inout Self, Self) -> Void) + { + let oldLhs = lhs + lhs *= rhs + return ((), { $0 = $0 * rhs + oldLhs * $1 }) + } } extension ${Self} diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index 6df2c5c5300ac..bc7c3bb71c22b 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -680,9 +680,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {} // accesses. struct Struct: Differentiable { - // expected-error @+4 {{expression is not differentiable}} - // expected-error @+3 {{expression is not differentiable}} - // expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}} + // expected-error @+2 {{expression is not differentiable}} // expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}} @DifferentiableWrapper @DifferentiableWrapper var x: Float = 10 diff --git a/test/AutoDiff/validation-test/property_wrappers.swift b/test/AutoDiff/validation-test/property_wrappers.swift index 4f102acc5d8c6..e28f6af997158 100644 --- a/test/AutoDiff/validation-test/property_wrappers.swift +++ b/test/AutoDiff/validation-test/property_wrappers.swift @@ -19,6 +19,7 @@ struct Wrapper { var wrappedValue: Value { // computed property get { value } set { value = newValue } + _modify { yield &value } } init(wrappedValue: Value) { @@ -46,8 +47,6 @@ PropertyWrapperTests.test("SimpleStruct") { expectEqual((.init(x: 60, y: 0, z: 20), 300), gradient(at: Struct(), 2, of: setter)) - // TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084). - /* func modify(_ s: Struct, _ x: Tracked) -> Tracked { var s = s s.x *= x * s.z @@ -55,7 +54,6 @@ PropertyWrapperTests.test("SimpleStruct") { } expectEqual((.init(x: 60, y: 0, z: 20), 300), gradient(at: Struct(), 2, of: modify)) - */ } struct GenericStruct { @@ -86,8 +84,6 @@ PropertyWrapperTests.test("GenericStruct") { expectEqual((.init(x: 60, y: 0, z: 20), 300), gradient(at: GenericStruct>(y: 20), 2, of: setter)) - // TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084). - /* func modify(_ s: GenericStruct, _ x: Tracked) -> Tracked { var s = s s.x *= x * s.z @@ -95,7 +91,6 @@ PropertyWrapperTests.test("GenericStruct") { } expectEqual((.init(x: 60, y: 0, z: 20), 300), gradient(at: GenericStruct>(y: 1), 2, of: modify)) - */ } // TF-1149: Test class with loadable type but address-only `TangentVector` type. @@ -131,16 +126,18 @@ PropertyWrapperTests.test("SimpleClass") { gradient(at: Class(), 2, of: setter)) */ - // TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084). - /* + // FIXME(TF-1175): Same issue as above func modify(_ c: Class, _ x: Tracked) -> Tracked { var c = c c.x *= x * c.z return c.x } + /* expectEqual((.init(x: 60, y: 0, z: 20), 300), gradient(at: Class(), 2, of: modify)) */ + expectEqual((.init(x: 1, y: 0, z: 0), 0), + gradient(at: Class(), 2, of: modify)) } // From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution @@ -157,12 +154,13 @@ enum Lazy { var wrappedValue: Value { // TODO(TF-1250): Replace with actual mutating getter implementation. - // Requires differentiation to support functions with multiple results. - get { + // Requires support for mutating semantic member accessor + /* mutating */ get { switch self { case .uninitialized(let initializer): let value = initializer() // NOTE: Actual implementation assigns to `self` here. + // self = .initialized(value) return value case .initialized(let value): return value