Skip to content

Support differentiation of wrapped value modify accessors #78794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5006,7 +5006,9 @@ static void emitRetconCoroutineEntry(
ArrayRef<llvm::Value *> extraArguments, llvm::Constant *allocFn,
llvm::Constant *deallocFn, ArrayRef<llvm::Value *> finalArguments) {
auto prototype =
IGF.IGM.getOpaquePtr(IGF.IGM.getAddrOfContinuationPrototype(fnType));
IGF.IGM.getOpaquePtr(
IGF.IGM.getAddrOfContinuationPrototype(fnType,
fnType->getInvocationGenericSignature()));
// Call the right 'llvm.coro.id.retcon' variant.
SmallVector<llvm::Value *, 8> arguments;
arguments.push_back(
Expand Down
5 changes: 3 additions & 2 deletions lib/IRGen/GenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6280,13 +6280,14 @@ IRGenModule::getAddrOfDefaultAssociatedConformanceAccessor(
}

llvm::Function *
IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType) {
IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType,
CanGenericSignature sig) {
LinkEntity entity = LinkEntity::forCoroutineContinuationPrototype(fnType);

llvm::Function *&entry = GlobalFuncs[entity];
if (entry) return entry;

GenericContextScope scope(*this, fnType->getInvocationGenericSignature());
GenericContextScope scope(*this, sig);
auto signature = Signature::forCoroutineContinuation(*this, fnType);
LinkInfo link = LinkInfo::get(*this, entity, NotForDefinition);
entry = createFunction(*this, link, signature);
Expand Down
3 changes: 2 additions & 1 deletion lib/IRGen/GenFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1492,7 +1492,8 @@ class CoroPartialApplicationForwarderEmission
auto prototype = subIGF.IGM.getOpaquePtr(
subIGF.IGM.getAddrOfContinuationPrototype(
cast<SILFunctionType>(
unsubstType->mapTypeOutOfContext()->getCanonicalType())));
unsubstType->mapTypeOutOfContext()->getCanonicalType()),
origType->getInvocationGenericSignature()));


// Use free as our allocator.
Expand Down
3 changes: 2 additions & 1 deletion lib/IRGen/IRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -1871,7 +1871,8 @@ private: \

void emitDynamicReplacementOriginalFunctionThunk(SILFunction *f);

llvm::Function *getAddrOfContinuationPrototype(CanSILFunctionType fnType);
llvm::Function *getAddrOfContinuationPrototype(CanSILFunctionType fnType,
CanGenericSignature sig);
Address getAddrOfSILGlobalVariable(SILGlobalVariable *var,
const TypeInfo &ti,
ForDefinition_t forDefinition);
Expand Down
17 changes: 10 additions & 7 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4778,15 +4778,19 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) {
const auto &coroutine = getLoweredCoroutine(i->getTokenResult());
bool isAbort = ei == nullptr;

auto sig = Signature::forCoroutineContinuation(IGM, i->getOrigCalleeType());
// Lower the return value in the callee's generic context.
auto origCalleeType = i->getOrigCalleeType();
GenericContextScope scope(IGM, origCalleeType->getInvocationGenericSignature());

auto sig = Signature::forCoroutineContinuation(IGM, origCalleeType);

// Cast the continuation pointer to the right function pointer type.
auto continuation = coroutine.Continuation;
continuation = Builder.CreateBitCast(continuation,
sig.getType()->getPointerTo());

auto schemaAndEntity =
getCoroutineResumeFunctionPointerAuth(IGM, i->getOrigCalleeType());
getCoroutineResumeFunctionPointerAuth(IGM, origCalleeType);
auto pointerAuth = PointerAuthInfo::emit(*this, schemaAndEntity.first,
coroutine.getBuffer().getAddress(),
schemaAndEntity.second);
Expand Down Expand Up @@ -4815,16 +4819,15 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) {

if (!isAbort) {
auto resultType = call->getType();
Explosion e;
if (!resultType->isVoidTy()) {
Explosion e;
// FIXME: Do we need to handle ABI-related conversions here?
// It seems we cannot have C function convention for coroutines, etc.
extractScalarResults(*this, resultType, call, e);

// NOTE: This inserts a new entry into the LoweredValues DenseMap,
// invalidating the reference held by `coroutine`.
setLoweredExplosion(ei, e);
}
// NOTE: This inserts a new entry into the LoweredValues DenseMap,
// invalidating the reference held by `coroutine`.
setLoweredExplosion(ei, e);
}
}

Expand Down
6 changes: 3 additions & 3 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ bool isSemanticMemberAccessor(SILFunction *original) {
auto *accessor = dyn_cast<AccessorDecl>(decl);
if (!accessor)
return false;
// Currently, only getters and setters are supported.
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
// Currently, only getters, setters and _modify accessors are supported.
if (accessor->getAccessorKind() != AccessorKind::Get &&
accessor->getAccessorKind() != AccessorKind::Set)
accessor->getAccessorKind() != AccessorKind::Set &&
accessor->getAccessorKind() != AccessorKind::Modify)
return false;
// Accessor must come from a `var` declaration.
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
Expand Down
87 changes: 84 additions & 3 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
bool runForSemanticMemberAccessor();
bool runForSemanticMemberGetter();
bool runForSemanticMemberSetter();
bool runForSemanticMemberModify();

/// If original result is non-varied, it will always have a zero derivative.
/// Skip full pullback generation and simply emit zero derivatives for wrt
Expand Down Expand Up @@ -2452,7 +2453,8 @@ bool PullbackCloner::Implementation::run() {

// If the original function is an accessor with special-case pullback
// generation logic, do special-case generation.
if (isSemanticMemberAccessor(&original)) {
bool isSemanticMemberAcc = isSemanticMemberAccessor(&original);
if (isSemanticMemberAcc) {
if (runForSemanticMemberAccessor())
return true;
}
Expand Down Expand Up @@ -2730,7 +2732,8 @@ bool PullbackCloner::Implementation::run() {
#endif

LLVM_DEBUG(getADDebugStream()
<< "Generated pullback for " << original.getName() << ":\n"
<< "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
<< " pullback for " << original.getName() << ":\n"
<< pullback);
return errorOccurred;
}
Expand Down Expand Up @@ -3205,7 +3208,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
return runForSemanticMemberGetter();
case AccessorKind::Set:
return runForSemanticMemberSetter();
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
case AccessorKind::Modify:
return runForSemanticMemberModify();
default:
llvm_unreachable("Unsupported accessor kind; inconsistent with "
"`isSemanticMemberAccessor`?");
Expand Down Expand Up @@ -3389,6 +3393,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
return false;
}

bool PullbackCloner::Implementation::runForSemanticMemberModify() {
auto &original = getOriginal();
auto &pullback = getPullback();
auto pbLoc = getPullback().getLocation();

auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
assert(accessor->getAccessorKind() == AccessorKind::Modify);

auto *origEntry = original.getEntryBlock();
// We assume that the accessor has a simple 3-BB structure with yield in the entry BB
// plus resume and unwind BBs
auto *yi = cast<YieldInst>(origEntry->getTerminator());
auto *origResumeBB = yi->getResumeBB();

auto *pbEntry = pullback.getEntryBlock();
builder.setCurrentDebugScope(
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
builder.setInsertionPoint(pbEntry);

// Get _modify accessor argument values.
// Accessor type : $(inout Self) -> @yields @inout Argument
// Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
// Normally pullbacks for semantic member accessors are single BB and
// therefore have empty linear map tuple, however, coroutines have a branching
// control flow due to possible coroutine abort, so we need to accommodate for
// this. We keep branch tracing enums in order not to special case in many
// other places. As there is no way to return to coroutine via abort exit, we
// essentially "linearize" a coroutine.
auto loweredFnTy = original.getLoweredFunctionType();
auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();

assert(loweredFnTy->getNumParameters() == 1 &&
loweredFnTy->getNumYields() == 1);
assert(pullbackLoweredFnTy->getNumParameters() == 2);
assert(pullbackLoweredFnTy->getNumYields() == 1);

SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();

SmallVector<SILValue, 8> origFormalResults;
collectAllFormalResultsInTypeOrder(original, origFormalResults);

assert(getConfig().resultIndices->getNumIndices() == 2 &&
"Modify accessor should have two semantic results");

auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())];

// Look up the corresponding field in the tangent space.
auto *origField = cast<VarDecl>(accessor->getStorage());
auto baseType = remapType(origSelf->getType()).getASTType();
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
pbLoc, getInvoker());
if (!tanField) {
errorOccurred = true;
return true;
}

auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
// Modify accessors have inout yields and therefore should yield addresses.
assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
"Modify accessors should yield indirect");

// Yield the adjoint buffer and do everything else in the resume
// destination. Unwind destination is unreachable as the coroutine can never
// be aborted.
auto *unwindBB = getPullback().createBasicBlock();
auto *resumeBB = getPullbackBlock(origEntry);
builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
builder.setInsertionPoint(unwindBB);
builder.createUnreachable(SILLocation::invalid());

builder.setInsertionPoint(resumeBB);
addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);

return false;
}

//--------------------------------------------------------------------------//
// Adjoint buffer mapping
//--------------------------------------------------------------------------//
Expand Down
20 changes: 20 additions & 0 deletions lib/SILOptimizer/Differentiation/VJPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@ class VJPCloner::Implementation final
TypeSubstCloner::visitEndApplyInst(eai);
return;
}
// If the original function is a semantic member accessor, do standard
// cloning. Semantic member accessors have special pullback generation
// logic, so all `end_apply` instructions can be directly cloned to the VJP.
if (isSemanticMemberAccessor(original)) {
LLVM_DEBUG(getADDebugStream()
<< "Cloning `end_apply` in semantic member accessor:\n"
<< *eai << '\n');
TypeSubstCloner::visitEndApplyInst(eai);
return;
}

Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope()));
auto loc = eai->getLoc();
Expand Down Expand Up @@ -607,6 +617,16 @@ class VJPCloner::Implementation final
TypeSubstCloner::visitBeginApplyInst(bai);
return;
}
// If the original function is a semantic member accessor, do standard
// cloning. Semantic member accessors have special pullback generation
// logic, so all `begin_apply` instructions can be directly cloned to the VJP.
if (isSemanticMemberAccessor(original)) {
LLVM_DEBUG(getADDebugStream()
<< "Cloning `begin_apply` in semantic member accessor:\n"
<< *bai << '\n');
TypeSubstCloner::visitBeginApplyInst(bai);
return;
}

Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope()));
auto loc = bai->getLoc();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,29 @@ where
{
return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs })
}

@usableFromInline
@derivative(of: *=)
static func _vjpMultiplyAssign(_ lhs: inout Self, _ rhs: Self) -> (
value: Void, pullback: (inout Self) -> Self)
{
defer { lhs *= rhs }
return ((), { [lhs = lhs] v in
let drhs = lhs * v
v *= rhs
return drhs
})
}

@usableFromInline
@derivative(of: *=)
static func _jvpMultiplyAssign(_ lhs: inout Self, _ rhs: Self) -> (
value: Void, differential: (inout Self, Self) -> Void)
{
let oldLhs = lhs
lhs *= rhs
return ((), { $0 = $0 * rhs + oldLhs * $1 })
}
}

extension ${Self}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,9 +680,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {}
// accesses.

struct Struct: Differentiable {
// expected-error @+4 {{expression is not differentiable}}
// expected-error @+3 {{expression is not differentiable}}
// expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
@DifferentiableWrapper @DifferentiableWrapper var x: Float = 10

Expand Down
18 changes: 8 additions & 10 deletions test/AutoDiff/validation-test/property_wrappers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct Wrapper<Value> {
var wrappedValue: Value { // computed property
get { value }
set { value = newValue }
_modify { yield &value }
}

init(wrappedValue: Value) {
Expand Down Expand Up @@ -46,16 +47,13 @@ PropertyWrapperTests.test("SimpleStruct") {
expectEqual((.init(x: 60, y: 0, z: 20), 300),
gradient(at: Struct(), 2, of: setter))

// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
/*
func modify(_ s: Struct, _ x: Tracked<Float>) -> Tracked<Float> {
var s = s
s.x *= x * s.z
return s.x
}
expectEqual((.init(x: 60, y: 0, z: 20), 300),
gradient(at: Struct(), 2, of: modify))
*/
}

struct GenericStruct<T> {
Expand Down Expand Up @@ -86,16 +84,13 @@ PropertyWrapperTests.test("GenericStruct") {
expectEqual((.init(x: 60, y: 0, z: 20), 300),
gradient(at: GenericStruct<Tracked<Float>>(y: 20), 2, of: setter))

// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
/*
func modify<T>(_ s: GenericStruct<T>, _ x: Tracked<Float>) -> Tracked<Float> {
var s = s
s.x *= x * s.z
return s.x
}
expectEqual((.init(x: 60, y: 0, z: 20), 300),
gradient(at: GenericStruct<Tracked<Float>>(y: 1), 2, of: modify))
*/
}

// TF-1149: Test class with loadable type but address-only `TangentVector` type.
Expand Down Expand Up @@ -131,16 +126,18 @@ PropertyWrapperTests.test("SimpleClass") {
gradient(at: Class(), 2, of: setter))
*/

// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
/*
// FIXME(TF-1175): Same issue as above
func modify(_ c: Class, _ x: Tracked<Float>) -> Tracked<Float> {
var c = c
c.x *= x * c.z
return c.x
}
/*
expectEqual((.init(x: 60, y: 0, z: 20), 300),
gradient(at: Class(), 2, of: modify))
*/
expectEqual((.init(x: 1, y: 0, z: 0), 0),
gradient(at: Class(), 2, of: modify))
}

// From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution
Expand All @@ -157,12 +154,13 @@ enum Lazy<Value> {

var wrappedValue: Value {
// TODO(TF-1250): Replace with actual mutating getter implementation.
// Requires differentiation to support functions with multiple results.
get {
// Requires support for mutating semantic member accessor
/* mutating */ get {
switch self {
case .uninitialized(let initializer):
let value = initializer()
// NOTE: Actual implementation assigns to `self` here.
// self = .initialized(value)
return value
case .initialized(let value):
return value
Expand Down