Skip to content

Commit c675009

Browse files
committed
Correctly synthesize semantic member _modify accessor. Support
differentiation of _modify accessor for wrapped values. Fixes #55084
1 parent 4a9d021 commit c675009

File tree

5 files changed

+115
-19
lines changed

5 files changed

+115
-19
lines changed

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ bool isSemanticMemberAccessor(SILFunction *original) {
6161
auto *accessor = dyn_cast<AccessorDecl>(decl);
6262
if (!accessor)
6363
return false;
64-
// Currently, only getters and setters are supported.
65-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
64+
// Currently, only getters, setters and _modify accessors are supported.
6665
if (accessor->getAccessorKind() != AccessorKind::Get &&
67-
accessor->getAccessorKind() != AccessorKind::Set)
66+
accessor->getAccessorKind() != AccessorKind::Set &&
67+
accessor->getAccessorKind() != AccessorKind::Modify)
6868
return false;
6969
// Accessor must come from a `var` declaration.
7070
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905905
bool runForSemanticMemberAccessor();
906906
bool runForSemanticMemberGetter();
907907
bool runForSemanticMemberSetter();
908+
bool runForSemanticMemberModify();
908909

909910
/// If original result is non-varied, it will always have a zero derivative.
910911
/// Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2464,7 +2465,8 @@ bool PullbackCloner::Implementation::run() {
24642465

24652466
// If the original function is an accessor with special-case pullback
24662467
// generation logic, do special-case generation.
2467-
if (isSemanticMemberAccessor(&original)) {
2468+
bool isSemanticMemberAcc = isSemanticMemberAccessor(&original);
2469+
if (isSemanticMemberAcc) {
24682470
if (runForSemanticMemberAccessor())
24692471
return true;
24702472
}
@@ -2614,7 +2616,8 @@ bool PullbackCloner::Implementation::run() {
26142616
#endif
26152617

26162618
LLVM_DEBUG(getADDebugStream()
2617-
<< "Generated pullback for " << original.getName() << ":\n"
2619+
<< "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
2620+
<< " pullback for " << original.getName() << ":\n"
26182621
<< pullback);
26192622
return errorOccurred;
26202623
}
@@ -3091,7 +3094,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
30913094
return runForSemanticMemberGetter();
30923095
case AccessorKind::Set:
30933096
return runForSemanticMemberSetter();
3094-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3097+
case AccessorKind::Modify:
3098+
return runForSemanticMemberModify();
30953099
default:
30963100
llvm_unreachable("Unsupported accessor kind; inconsistent with "
30973101
"`isSemanticMemberAccessor`?");
@@ -3275,6 +3279,82 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
32753279
return false;
32763280
}
32773281

3282+
bool PullbackCloner::Implementation::runForSemanticMemberModify() {
3283+
auto &original = getOriginal();
3284+
auto &pullback = getPullback();
3285+
auto pbLoc = getPullback().getLocation();
3286+
3287+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
3288+
assert(accessor->getAccessorKind() == AccessorKind::Modify);
3289+
3290+
auto *origEntry = original.getEntryBlock();
3291+
// We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3292+
// plus resume and unwind BBs
3293+
auto *yi = cast<YieldInst>(origEntry->getTerminator());
3294+
auto *origResumeBB = yi->getResumeBB();
3295+
3296+
auto *pbEntry = pullback.getEntryBlock();
3297+
builder.setCurrentDebugScope(
3298+
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
3299+
builder.setInsertionPoint(pbEntry);
3300+
3301+
// Get _modify accessor argument values.
3302+
// Accessor type : $(inout Self) -> @yields @inout Argument
3303+
// Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3304+
// Normally pullbacks for semantic member accessors are single BB and therefore
3305+
// has empty linear map tuple, however, coroutines has a branching control flow
3306+
// due to possible coroutine abort, so we need to accommodate for this. We keep branch
3307+
// tracing enums in order not to special case in many other places. As there is no way
3308+
// to return to coroutine via abort exit, we essentially "linearize" a coroutine.
3309+
auto loweredFnTy = original.getLoweredFunctionType();
3310+
auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();
3311+
3312+
assert(loweredFnTy->getNumParameters() == 1 &&
3313+
loweredFnTy->getNumYields() == 1);
3314+
assert(pullbackLoweredFnTy->getNumParameters() == 2);
3315+
assert(pullbackLoweredFnTy->getNumYields() == 1);
3316+
3317+
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
3318+
3319+
SmallVector<SILValue, 8> origFormalResults;
3320+
collectAllFormalResultsInTypeOrder(original, origFormalResults);
3321+
3322+
assert(getConfig().resultIndices->getNumIndices() == 2 &&
3323+
"Modify accessor should have two semantic results");
3324+
3325+
auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())];
3326+
3327+
// Look up the corresponding field in the tangent space.
3328+
auto *origField = cast<VarDecl>(accessor->getStorage());
3329+
auto baseType = remapType(origSelf->getType()).getASTType();
3330+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
3331+
pbLoc, getInvoker());
3332+
if (!tanField) {
3333+
errorOccurred = true;
3334+
return true;
3335+
}
3336+
3337+
auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
3338+
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
3339+
// Modify accessors have inout yields and therefore should yield addresses.
3340+
assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
3341+
"Modify accessors should yield indirect");
3342+
3343+
// Yield the adjoint buffer and do everything else in the resume
3344+
// destination. Unwind destination is unreachable as the coroutine can never
3345+
// be aborted.
3346+
auto *unwindBB = getPullback().createBasicBlock();
3347+
auto *resumeBB = getPullbackBlock(origEntry);
3348+
builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
3349+
builder.setInsertionPoint(unwindBB);
3350+
builder.createUnreachable(SILLocation::invalid());
3351+
3352+
builder.setInsertionPoint(resumeBB);
3353+
addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);
3354+
3355+
return false;
3356+
}
3357+
32783358
//--------------------------------------------------------------------------//
32793359
// Adjoint buffer mapping
32803360
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,16 @@ class VJPCloner::Implementation final
457457
TypeSubstCloner::visitEndApplyInst(eai);
458458
return;
459459
}
460+
// If the original function is a semantic member accessor, do standard
461+
// cloning. Semantic member accessors have special pullback generation
462+
// logic, so all `end_apply` instructions can be directly cloned to the VJP.
463+
if (isSemanticMemberAccessor(original)) {
464+
LLVM_DEBUG(getADDebugStream()
465+
<< "Cloning `end_apply` in semantic member accessor:\n"
466+
<< *eai << '\n');
467+
TypeSubstCloner::visitEndApplyInst(eai);
468+
return;
469+
}
460470

461471
Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope()));
462472
auto loc = eai->getLoc();
@@ -604,6 +614,16 @@ class VJPCloner::Implementation final
604614
TypeSubstCloner::visitBeginApplyInst(bai);
605615
return;
606616
}
617+
// If the original function is a semantic member accessor, do standard
618+
// cloning. Semantic member accessors have special pullback generation
619+
// logic, so all `begin_apply` instructions can be directly cloned to the VJP.
620+
if (isSemanticMemberAccessor(original)) {
621+
LLVM_DEBUG(getADDebugStream()
622+
<< "Cloning `begin_apply` in semantic member accessor:\n"
623+
<< *bai << '\n');
624+
TypeSubstCloner::visitBeginApplyInst(bai);
625+
return;
626+
}
607627

608628
Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope()));
609629
auto loc = bai->getLoc();

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -669,9 +669,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {}
669669
// accesses.
670670

671671
struct Struct: Differentiable {
672-
// expected-error @+4 {{expression is not differentiable}}
673-
// expected-error @+3 {{expression is not differentiable}}
674-
// expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
672+
// expected-error @+2 {{expression is not differentiable}}
675673
// expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
676674
@DifferentiableWrapper @DifferentiableWrapper var x: Float = 10
677675

test/AutoDiff/validation-test/property_wrappers.swift

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct Wrapper<Value> {
1919
var wrappedValue: Value { // computed property
2020
get { value }
2121
set { value = newValue }
22+
_modify { yield &value }
2223
}
2324

2425
init(wrappedValue: Value) {
@@ -46,16 +47,13 @@ PropertyWrapperTests.test("SimpleStruct") {
4647
expectEqual((.init(x: 60, y: 0, z: 20), 300),
4748
gradient(at: Struct(), 2, of: setter))
4849

49-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
50-
/*
5150
func modify(_ s: Struct, _ x: Tracked<Float>) -> Tracked<Float> {
5251
var s = s
5352
s.x *= x * s.z
5453
return s.x
5554
}
5655
expectEqual((.init(x: 60, y: 0, z: 20), 300),
5756
gradient(at: Struct(), 2, of: modify))
58-
*/
5957
}
6058

6159
struct GenericStruct<T> {
@@ -86,16 +84,13 @@ PropertyWrapperTests.test("GenericStruct") {
8684
expectEqual((.init(x: 60, y: 0, z: 20), 300),
8785
gradient(at: GenericStruct<Tracked<Float>>(y: 20), 2, of: setter))
8886

89-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
90-
/*
9187
func modify<T>(_ s: GenericStruct<T>, _ x: Tracked<Float>) -> Tracked<Float> {
9288
var s = s
9389
s.x *= x * s.z
9490
return s.x
9591
}
9692
expectEqual((.init(x: 60, y: 0, z: 20), 300),
9793
gradient(at: GenericStruct<Tracked<Float>>(y: 1), 2, of: modify))
98-
*/
9994
}
10095

10196
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
@@ -131,16 +126,18 @@ PropertyWrapperTests.test("SimpleClass") {
131126
gradient(at: Class(), 2, of: setter))
132127
*/
133128

134-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
135-
/*
129+
// FIXME(TF-1175): Same issue as above
136130
func modify(_ c: Class, _ x: Tracked<Float>) -> Tracked<Float> {
137131
var c = c
138132
c.x *= x * c.z
139133
return c.x
140134
}
135+
/*
141136
expectEqual((.init(x: 60, y: 0, z: 20), 300),
142137
gradient(at: Class(), 2, of: modify))
143138
*/
139+
expectEqual((.init(x: 1, y: 0, z: 0), 0),
140+
gradient(at: Class(), 2, of: modify))
144141
}
145142

146143
// From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution
@@ -157,12 +154,13 @@ enum Lazy<Value> {
157154

158155
var wrappedValue: Value {
159156
// TODO(TF-1250): Replace with actual mutating getter implementation.
160-
// Requires differentiation to support functions with multiple results.
161-
get {
157+
// Requires support for mutating semantic member accessor
158+
/* mutating */ get {
162159
switch self {
163160
case .uninitialized(let initializer):
164161
let value = initializer()
165162
// NOTE: Actual implementation assigns to `self` here.
163+
// self = .initialized(value)
166164
return value
167165
case .initialized(let value):
168166
return value

0 commit comments

Comments
 (0)