Skip to content

Commit b4a8fc0

Browse files
committed
Correctly synthesize semantic member _modify accessor. Support
differentiation of _modify accessor for wrapped values. Fixes #55084
1 parent e48ceef commit b4a8fc0

File tree

5 files changed

+116
-19
lines changed

5 files changed

+116
-19
lines changed

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ bool isSemanticMemberAccessor(SILFunction *original) {
6161
auto *accessor = dyn_cast<AccessorDecl>(decl);
6262
if (!accessor)
6363
return false;
64-
// Currently, only getters and setters are supported.
65-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
64+
// Currently, only getters, setters and _modify accessors are supported.
6665
if (accessor->getAccessorKind() != AccessorKind::Get &&
67-
accessor->getAccessorKind() != AccessorKind::Set)
66+
accessor->getAccessorKind() != AccessorKind::Set &&
67+
accessor->getAccessorKind() != AccessorKind::Modify)
6868
return false;
6969
// Accessor must come from a `var` declaration.
7070
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905905
bool runForSemanticMemberAccessor();
906906
bool runForSemanticMemberGetter();
907907
bool runForSemanticMemberSetter();
908+
bool runForSemanticMemberModify();
908909

909910
/// If original result is non-varied, it will always have a zero derivative.
910911
/// Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2431,7 +2432,8 @@ bool PullbackCloner::Implementation::run() {
24312432

24322433
// If the original function is an accessor with special-case pullback
24332434
// generation logic, do special-case generation.
2434-
if (isSemanticMemberAccessor(&original)) {
2435+
bool isSemanticMemberAcc = isSemanticMemberAccessor(&original);
2436+
if (isSemanticMemberAcc) {
24352437
if (runForSemanticMemberAccessor())
24362438
return true;
24372439
}
@@ -2581,7 +2583,8 @@ bool PullbackCloner::Implementation::run() {
25812583
#endif
25822584

25832585
LLVM_DEBUG(getADDebugStream()
2584-
<< "Generated pullback for " << original.getName() << ":\n"
2586+
<< "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
2587+
<< " pullback for " << original.getName() << ":\n"
25852588
<< pullback);
25862589
return errorOccurred;
25872590
}
@@ -3043,7 +3046,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
30433046
return runForSemanticMemberGetter();
30443047
case AccessorKind::Set:
30453048
return runForSemanticMemberSetter();
3046-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3049+
case AccessorKind::Modify:
3050+
return runForSemanticMemberModify();
30473051
default:
30483052
llvm_unreachable("Unsupported accessor kind; inconsistent with "
30493053
"`isSemanticMemberAccessor`?");
@@ -3227,6 +3231,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
32273231
return false;
32283232
}
32293233

3234+
bool PullbackCloner::Implementation::runForSemanticMemberModify() {
3235+
auto &original = getOriginal();
3236+
auto &pullback = getPullback();
3237+
auto pbLoc = getPullback().getLocation();
3238+
3239+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
3240+
assert(accessor->getAccessorKind() == AccessorKind::Modify);
3241+
3242+
auto *origEntry = original.getEntryBlock();
3243+
// We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3244+
// plus resume and unwind BBs
3245+
auto *yi = cast<YieldInst>(origEntry->getTerminator());
3246+
auto *origResumeBB = yi->getResumeBB();
3247+
3248+
auto *pbEntry = pullback.getEntryBlock();
3249+
builder.setCurrentDebugScope(
3250+
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
3251+
builder.setInsertionPoint(pbEntry);
3252+
3253+
// Get _modify accessor argument values.
3254+
// Accessor type : $(inout Self) -> @yields @inout Argument
3255+
// Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3256+
// Normally pullbacks for semantic member accessors are single BB and
3257+
// therefore have empty linear map tuple, however, coroutines have a branching
3258+
// control flow due to possible coroutine abort, so we need to accommodate for
3259+
// this. We keep branch tracing enums in order not to special case in many
3260+
// other places. As there is no way to return to coroutine via abort exit, we
3261+
// essentially "linearize" a coroutine.
3262+
auto loweredFnTy = original.getLoweredFunctionType();
3263+
auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();
3264+
3265+
assert(loweredFnTy->getNumParameters() == 1 &&
3266+
loweredFnTy->getNumYields() == 1);
3267+
assert(pullbackLoweredFnTy->getNumParameters() == 2);
3268+
assert(pullbackLoweredFnTy->getNumYields() == 1);
3269+
3270+
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
3271+
3272+
SmallVector<SILValue, 8> origFormalResults;
3273+
collectAllFormalResultsInTypeOrder(original, origFormalResults);
3274+
3275+
assert(getConfig().resultIndices->getNumIndices() == 2 &&
3276+
"Modify accessor should have two semantic results");
3277+
3278+
auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())];
3279+
3280+
// Look up the corresponding field in the tangent space.
3281+
auto *origField = cast<VarDecl>(accessor->getStorage());
3282+
auto baseType = remapType(origSelf->getType()).getASTType();
3283+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
3284+
pbLoc, getInvoker());
3285+
if (!tanField) {
3286+
errorOccurred = true;
3287+
return true;
3288+
}
3289+
3290+
auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
3291+
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
3292+
// Modify accessors have inout yields and therefore should yield addresses.
3293+
assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
3294+
"Modify accessors should yield indirect");
3295+
3296+
// Yield the adjoint buffer and do everything else in the resume
3297+
// destination. Unwind destination is unreachable as the coroutine can never
3298+
// be aborted.
3299+
auto *unwindBB = getPullback().createBasicBlock();
3300+
auto *resumeBB = getPullbackBlock(origEntry);
3301+
builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
3302+
builder.setInsertionPoint(unwindBB);
3303+
builder.createUnreachable(SILLocation::invalid());
3304+
3305+
builder.setInsertionPoint(resumeBB);
3306+
addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);
3307+
3308+
return false;
3309+
}
3310+
32303311
//--------------------------------------------------------------------------//
32313312
// Adjoint buffer mapping
32323313
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,16 @@ class VJPCloner::Implementation final
460460
TypeSubstCloner::visitEndApplyInst(eai);
461461
return;
462462
}
463+
// If the original function is a semantic member accessor, do standard
464+
// cloning. Semantic member accessors have special pullback generation
465+
// logic, so all `end_apply` instructions can be directly cloned to the VJP.
466+
if (isSemanticMemberAccessor(original)) {
467+
LLVM_DEBUG(getADDebugStream()
468+
<< "Cloning `end_apply` in semantic member accessor:\n"
469+
<< *eai << '\n');
470+
TypeSubstCloner::visitEndApplyInst(eai);
471+
return;
472+
}
463473

464474
Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope()));
465475
auto loc = eai->getLoc();
@@ -607,6 +617,16 @@ class VJPCloner::Implementation final
607617
TypeSubstCloner::visitBeginApplyInst(bai);
608618
return;
609619
}
620+
// If the original function is a semantic member accessor, do standard
621+
// cloning. Semantic member accessors have special pullback generation
622+
// logic, so all `begin_apply` instructions can be directly cloned to the VJP.
623+
if (isSemanticMemberAccessor(original)) {
624+
LLVM_DEBUG(getADDebugStream()
625+
<< "Cloning `begin_apply` in semantic member accessor:\n"
626+
<< *bai << '\n');
627+
TypeSubstCloner::visitBeginApplyInst(bai);
628+
return;
629+
}
610630

611631
Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope()));
612632
auto loc = bai->getLoc();

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -669,9 +669,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {}
669669
// accesses.
670670

671671
struct Struct: Differentiable {
672-
// expected-error @+4 {{expression is not differentiable}}
673-
// expected-error @+3 {{expression is not differentiable}}
674-
// expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
672+
// expected-error @+2 {{expression is not differentiable}}
675673
// expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
676674
@DifferentiableWrapper @DifferentiableWrapper var x: Float = 10
677675

test/AutoDiff/validation-test/property_wrappers.swift

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct Wrapper<Value> {
1919
var wrappedValue: Value { // computed property
2020
get { value }
2121
set { value = newValue }
22+
_modify { yield &value }
2223
}
2324

2425
init(wrappedValue: Value) {
@@ -46,16 +47,13 @@ PropertyWrapperTests.test("SimpleStruct") {
4647
expectEqual((.init(x: 60, y: 0, z: 20), 300),
4748
gradient(at: Struct(), 2, of: setter))
4849

49-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
50-
/*
5150
func modify(_ s: Struct, _ x: Tracked<Float>) -> Tracked<Float> {
5251
var s = s
5352
s.x *= x * s.z
5453
return s.x
5554
}
5655
expectEqual((.init(x: 60, y: 0, z: 20), 300),
5756
gradient(at: Struct(), 2, of: modify))
58-
*/
5957
}
6058

6159
struct GenericStruct<T> {
@@ -86,16 +84,13 @@ PropertyWrapperTests.test("GenericStruct") {
8684
expectEqual((.init(x: 60, y: 0, z: 20), 300),
8785
gradient(at: GenericStruct<Tracked<Float>>(y: 20), 2, of: setter))
8886

89-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
90-
/*
9187
func modify<T>(_ s: GenericStruct<T>, _ x: Tracked<Float>) -> Tracked<Float> {
9288
var s = s
9389
s.x *= x * s.z
9490
return s.x
9591
}
9692
expectEqual((.init(x: 60, y: 0, z: 20), 300),
9793
gradient(at: GenericStruct<Tracked<Float>>(y: 1), 2, of: modify))
98-
*/
9994
}
10095

10196
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
@@ -131,16 +126,18 @@ PropertyWrapperTests.test("SimpleClass") {
131126
gradient(at: Class(), 2, of: setter))
132127
*/
133128

134-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
135-
/*
129+
// FIXME(TF-1175): Same issue as above
136130
func modify(_ c: Class, _ x: Tracked<Float>) -> Tracked<Float> {
137131
var c = c
138132
c.x *= x * c.z
139133
return c.x
140134
}
135+
/*
141136
expectEqual((.init(x: 60, y: 0, z: 20), 300),
142137
gradient(at: Class(), 2, of: modify))
143138
*/
139+
expectEqual((.init(x: 1, y: 0, z: 0), 0),
140+
gradient(at: Class(), 2, of: modify))
144141
}
145142

146143
// From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution
@@ -157,12 +154,13 @@ enum Lazy<Value> {
157154

158155
var wrappedValue: Value {
159156
// TODO(TF-1250): Replace with actual mutating getter implementation.
160-
// Requires differentiation to support functions with multiple results.
161-
get {
157+
// Requires support for mutating semantic member accessor
158+
/* mutating */ get {
162159
switch self {
163160
case .uninitialized(let initializer):
164161
let value = initializer()
165162
// NOTE: Actual implementation assigns to `self` here.
163+
// self = .initialized(value)
166164
return value
167165
case .initialized(let value):
168166
return value

0 commit comments

Comments
 (0)