@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905
905
bool runForSemanticMemberAccessor ();
906
906
bool runForSemanticMemberGetter ();
907
907
bool runForSemanticMemberSetter ();
908
+ bool runForSemanticMemberModify ();
908
909
909
910
// / If original result is non-varied, it will always have a zero derivative.
910
911
// / Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2464,7 +2465,8 @@ bool PullbackCloner::Implementation::run() {
2464
2465
2465
2466
// If the original function is an accessor with special-case pullback
2466
2467
// generation logic, do special-case generation.
2467
- if (isSemanticMemberAccessor (&original)) {
2468
+ bool isSemanticMemberAcc = isSemanticMemberAccessor (&original);
2469
+ if (isSemanticMemberAcc) {
2468
2470
if (runForSemanticMemberAccessor ())
2469
2471
return true ;
2470
2472
}
@@ -2614,7 +2616,8 @@ bool PullbackCloner::Implementation::run() {
2614
2616
#endif
2615
2617
2616
2618
LLVM_DEBUG (getADDebugStream ()
2617
- << " Generated pullback for " << original.getName () << " :\n "
2619
+ << " Generated " << (isSemanticMemberAcc ? " semantic member accessor" : " normal" )
2620
+ << " pullback for " << original.getName () << " :\n "
2618
2621
<< pullback);
2619
2622
return errorOccurred;
2620
2623
}
@@ -3091,7 +3094,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
3091
3094
return runForSemanticMemberGetter ();
3092
3095
case AccessorKind::Set:
3093
3096
return runForSemanticMemberSetter ();
3094
- // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3097
+ case AccessorKind::Modify:
3098
+ return runForSemanticMemberModify ();
3095
3099
default :
3096
3100
llvm_unreachable (" Unsupported accessor kind; inconsistent with "
3097
3101
" `isSemanticMemberAccessor`?" );
@@ -3275,6 +3279,82 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
3275
3279
return false ;
3276
3280
}
3277
3281
3282
+ bool PullbackCloner::Implementation::runForSemanticMemberModify () {
3283
+ auto &original = getOriginal ();
3284
+ auto &pullback = getPullback ();
3285
+ auto pbLoc = getPullback ().getLocation ();
3286
+
3287
+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
3288
+ assert (accessor->getAccessorKind () == AccessorKind::Modify);
3289
+
3290
+ auto *origEntry = original.getEntryBlock ();
3291
+ // We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3292
+ // plus resume and unwind BBs
3293
+ auto *yi = cast<YieldInst>(origEntry->getTerminator ());
3294
+ auto *origResumeBB = yi->getResumeBB ();
3295
+
3296
+ auto *pbEntry = pullback.getEntryBlock ();
3297
+ builder.setCurrentDebugScope (
3298
+ remapScope (origEntry->getScopeOfFirstNonMetaInstruction ()));
3299
+ builder.setInsertionPoint (pbEntry);
3300
+
3301
+ // Get _modify accessor argument values.
3302
+ // Accessor type : $(inout Self) -> @yields @inout Argument
3303
+ // Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3304
+ // Normally pullbacks for semantic member accessors are single BB and therefore
3305
+ // has empty linear map tuple, however, coroutines has a branching control flow
3306
+ // due to possible coroutine abort, so we need to accommodate for this. We keep branch
3307
+ // tracing enums in order not to special case in many other places. As there is no way
3308
+ // to return to coroutine via abort exit, we essentially "linearize" a coroutine.
3309
+ auto loweredFnTy = original.getLoweredFunctionType ();
3310
+ auto pullbackLoweredFnTy = pullback.getLoweredFunctionType ();
3311
+
3312
+ assert (loweredFnTy->getNumParameters () == 1 &&
3313
+ loweredFnTy->getNumYields () == 1 );
3314
+ assert (pullbackLoweredFnTy->getNumParameters () == 2 );
3315
+ assert (pullbackLoweredFnTy->getNumYields () == 1 );
3316
+
3317
+ SILValue origSelf = original.getArgumentsWithoutIndirectResults ().front ();
3318
+
3319
+ SmallVector<SILValue, 8 > origFormalResults;
3320
+ collectAllFormalResultsInTypeOrder (original, origFormalResults);
3321
+
3322
+ assert (getConfig ().resultIndices ->getNumIndices () == 2 &&
3323
+ " Modify accessor should have two semantic results" );
3324
+
3325
+ auto origYield = origFormalResults[*std::next (getConfig ().resultIndices ->begin ())];
3326
+
3327
+ // Look up the corresponding field in the tangent space.
3328
+ auto *origField = cast<VarDecl>(accessor->getStorage ());
3329
+ auto baseType = remapType (origSelf->getType ()).getASTType ();
3330
+ auto *tanField = getTangentStoredProperty (getContext (), origField, baseType,
3331
+ pbLoc, getInvoker ());
3332
+ if (!tanField) {
3333
+ errorOccurred = true ;
3334
+ return true ;
3335
+ }
3336
+
3337
+ auto adjSelf = getAdjointBuffer (origResumeBB, origSelf);
3338
+ auto *adjSelfElt = builder.createStructElementAddr (pbLoc, adjSelf, tanField);
3339
+ // Modify accessors have inout yields and therefore should yield addresses.
3340
+ assert (getTangentValueCategory (origYield) == SILValueCategory::Address &&
3341
+ " Modify accessors should yield indirect" );
3342
+
3343
+ // Yield the adjoint buffer and do everything else in the resume
3344
+ // destination. Unwind destination is unreachable as the coroutine can never
3345
+ // be aborted.
3346
+ auto *unwindBB = getPullback ().createBasicBlock ();
3347
+ auto *resumeBB = getPullbackBlock (origEntry);
3348
+ builder.createYield (yi->getLoc (), {adjSelfElt}, resumeBB, unwindBB);
3349
+ builder.setInsertionPoint (unwindBB);
3350
+ builder.createUnreachable (SILLocation::invalid ());
3351
+
3352
+ builder.setInsertionPoint (resumeBB);
3353
+ addToAdjointBuffer (origEntry, origSelf, adjSelf, pbLoc);
3354
+
3355
+ return false ;
3356
+ }
3357
+
3278
3358
// --------------------------------------------------------------------------//
3279
3359
// Adjoint buffer mapping
3280
3360
// --------------------------------------------------------------------------//
0 commit comments