@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905
905
bool runForSemanticMemberAccessor ();
906
906
bool runForSemanticMemberGetter ();
907
907
bool runForSemanticMemberSetter ();
908
+ bool runForSemanticMemberModify ();
908
909
909
910
// / If original result is non-varied, it will always have a zero derivative.
910
911
// / Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2431,7 +2432,8 @@ bool PullbackCloner::Implementation::run() {
2431
2432
2432
2433
// If the original function is an accessor with special-case pullback
2433
2434
// generation logic, do special-case generation.
2434
- if (isSemanticMemberAccessor (&original)) {
2435
+ bool isSemanticMemberAcc = isSemanticMemberAccessor (&original);
2436
+ if (isSemanticMemberAcc) {
2435
2437
if (runForSemanticMemberAccessor ())
2436
2438
return true ;
2437
2439
}
@@ -2581,7 +2583,8 @@ bool PullbackCloner::Implementation::run() {
2581
2583
#endif
2582
2584
2583
2585
LLVM_DEBUG (getADDebugStream ()
2584
- << " Generated pullback for " << original.getName () << " :\n "
2586
+ << " Generated " << (isSemanticMemberAcc ? " semantic member accessor" : " normal" )
2587
+ << " pullback for " << original.getName () << " :\n "
2585
2588
<< pullback);
2586
2589
return errorOccurred;
2587
2590
}
@@ -3043,7 +3046,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
3043
3046
return runForSemanticMemberGetter ();
3044
3047
case AccessorKind::Set:
3045
3048
return runForSemanticMemberSetter ();
3046
- // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3049
+ case AccessorKind::Modify:
3050
+ return runForSemanticMemberModify ();
3047
3051
default :
3048
3052
llvm_unreachable (" Unsupported accessor kind; inconsistent with "
3049
3053
" `isSemanticMemberAccessor`?" );
@@ -3227,6 +3231,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
3227
3231
return false ;
3228
3232
}
3229
3233
3234
+ bool PullbackCloner::Implementation::runForSemanticMemberModify () {
3235
+ auto &original = getOriginal ();
3236
+ auto &pullback = getPullback ();
3237
+ auto pbLoc = getPullback ().getLocation ();
3238
+
3239
+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
3240
+ assert (accessor->getAccessorKind () == AccessorKind::Modify);
3241
+
3242
+ auto *origEntry = original.getEntryBlock ();
3243
+ // We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3244
+ // plus resume and unwind BBs
3245
+ auto *yi = cast<YieldInst>(origEntry->getTerminator ());
3246
+ auto *origResumeBB = yi->getResumeBB ();
3247
+
3248
+ auto *pbEntry = pullback.getEntryBlock ();
3249
+ builder.setCurrentDebugScope (
3250
+ remapScope (origEntry->getScopeOfFirstNonMetaInstruction ()));
3251
+ builder.setInsertionPoint (pbEntry);
3252
+
3253
+ // Get _modify accessor argument values.
3254
+ // Accessor type : $(inout Self) -> @yields @inout Argument
3255
+ // Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3256
+ // Normally pullbacks for semantic member accessors are single BB and
3257
+ // therefore have empty linear map tuple, however, coroutines have a branching
3258
+ // control flow due to possible coroutine abort, so we need to accommodate for
3259
+ // this. We keep branch tracing enums in order not to special case in many
3260
+ // other places. As there is no way to return to coroutine via abort exit, we
3261
+ // essentially "linearize" a coroutine.
3262
+ auto loweredFnTy = original.getLoweredFunctionType ();
3263
+ auto pullbackLoweredFnTy = pullback.getLoweredFunctionType ();
3264
+
3265
+ assert (loweredFnTy->getNumParameters () == 1 &&
3266
+ loweredFnTy->getNumYields () == 1 );
3267
+ assert (pullbackLoweredFnTy->getNumParameters () == 2 );
3268
+ assert (pullbackLoweredFnTy->getNumYields () == 1 );
3269
+
3270
+ SILValue origSelf = original.getArgumentsWithoutIndirectResults ().front ();
3271
+
3272
+ SmallVector<SILValue, 8 > origFormalResults;
3273
+ collectAllFormalResultsInTypeOrder (original, origFormalResults);
3274
+
3275
+ assert (getConfig ().resultIndices ->getNumIndices () == 2 &&
3276
+ " Modify accessor should have two semantic results" );
3277
+
3278
+ auto origYield = origFormalResults[*std::next (getConfig ().resultIndices ->begin ())];
3279
+
3280
+ // Look up the corresponding field in the tangent space.
3281
+ auto *origField = cast<VarDecl>(accessor->getStorage ());
3282
+ auto baseType = remapType (origSelf->getType ()).getASTType ();
3283
+ auto *tanField = getTangentStoredProperty (getContext (), origField, baseType,
3284
+ pbLoc, getInvoker ());
3285
+ if (!tanField) {
3286
+ errorOccurred = true ;
3287
+ return true ;
3288
+ }
3289
+
3290
+ auto adjSelf = getAdjointBuffer (origResumeBB, origSelf);
3291
+ auto *adjSelfElt = builder.createStructElementAddr (pbLoc, adjSelf, tanField);
3292
+ // Modify accessors have inout yields and therefore should yield addresses.
3293
+ assert (getTangentValueCategory (origYield) == SILValueCategory::Address &&
3294
+ " Modify accessors should yield indirect" );
3295
+
3296
+ // Yield the adjoint buffer and do everything else in the resume
3297
+ // destination. Unwind destination is unreachable as the coroutine can never
3298
+ // be aborted.
3299
+ auto *unwindBB = getPullback ().createBasicBlock ();
3300
+ auto *resumeBB = getPullbackBlock (origEntry);
3301
+ builder.createYield (yi->getLoc (), {adjSelfElt}, resumeBB, unwindBB);
3302
+ builder.setInsertionPoint (unwindBB);
3303
+ builder.createUnreachable (SILLocation::invalid ());
3304
+
3305
+ builder.setInsertionPoint (resumeBB);
3306
+ addToAdjointBuffer (origEntry, origSelf, adjSelf, pbLoc);
3307
+
3308
+ return false ;
3309
+ }
3310
+
3230
3311
// --------------------------------------------------------------------------//
3231
3312
// Adjoint buffer mapping
3232
3313
// --------------------------------------------------------------------------//
0 commit comments