@@ -2095,6 +2095,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
2095
2095
2096
2096
PullbackCloner::~PullbackCloner () { delete &impl; }
2097
2097
2098
+ static SILValue getArrayValue (ApplyInst *ai) {
2099
+ SILValue arrayValue;
2100
+ for (auto use : ai->getUses ()) {
2101
+ auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ());
2102
+ if (!dti)
2103
+ continue ;
2104
+ DEBUG_ASSERT (!arrayValue && " Array value already found" );
2105
+ // The first `destructure_tuple` result is the `Array` value.
2106
+ arrayValue = dti->getResult (0 );
2107
+ #ifndef DEBUG_ASSERT_enabled
2108
+ break ;
2109
+ #endif
2110
+ }
2111
+ ASSERT (arrayValue);
2112
+ return arrayValue;
2113
+ }
2114
+
2098
2115
// --------------------------------------------------------------------------//
2099
2116
// Entry point
2100
2117
// --------------------------------------------------------------------------//
@@ -2439,6 +2456,134 @@ bool PullbackCloner::Implementation::run() {
2439
2456
// Visit original blocks in post-order and perform differentiation
2440
2457
// in corresponding pullback blocks. If errors occurred, back out.
2441
2458
else {
2459
+ LLVM_DEBUG (getADDebugStream ()
2460
+ << " Begin search for adjoints of loop-local active values\n " );
2461
+ llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
2462
+ loopLocalActiveValues;
2463
+ for (auto *bb : originalBlocks) {
2464
+ const SILLoop *loop = vjpCloner.getLoopInfo ()->getLoopFor (bb);
2465
+ if (loop == nullptr )
2466
+ continue ;
2467
+ SILBasicBlock *loopHeader = loop->getHeader ();
2468
+ SILBasicBlock *pbLoopHeader = getPullbackBlock (loopHeader);
2469
+ LLVM_DEBUG (getADDebugStream ()
2470
+ << " Original bb" << bb->getDebugID ()
2471
+ << " belongs to a loop, original header bb"
2472
+ << loopHeader->getDebugID () << " , pullback header bb"
2473
+ << pbLoopHeader->getDebugID () << ' \n ' );
2474
+ builder.setInsertionPoint (pbLoopHeader);
2475
+ auto bbActiveValuesIt = activeValues.find (bb);
2476
+ if (bbActiveValuesIt == activeValues.end ())
2477
+ continue ;
2478
+ const auto &bbActiveValues = bbActiveValuesIt->second ;
2479
+ for (SILValue bbActiveValue : bbActiveValues) {
2480
+ if (vjpCloner.getLoopInfo ()->getLoopFor (
2481
+ bbActiveValue->getParentBlock ()) != loop) {
2482
+ LLVM_DEBUG (
2483
+ getADDebugStream ()
2484
+ << " The following active value is NOT loop-local, skipping: "
2485
+ << bbActiveValue);
2486
+ continue ;
2487
+ }
2488
+
2489
+ auto [_, wasInserted] =
2490
+ loopLocalActiveValues[loop].insert (bbActiveValue);
2491
+ LLVM_DEBUG (getADDebugStream ()
2492
+ << " The following active value is loop-local, " );
2493
+ if (!wasInserted) {
2494
+ LLVM_DEBUG (llvm::dbgs () << " but it was already processed, skipping: "
2495
+ << bbActiveValue);
2496
+ continue ;
2497
+ }
2498
+
2499
+ if (getTangentValueCategory (bbActiveValue) ==
2500
+ SILValueCategory::Object) {
2501
+ LLVM_DEBUG (llvm::dbgs ()
2502
+ << " zeroing its adjoint value in loop header: "
2503
+ << bbActiveValue);
2504
+ setAdjointValue (bb, bbActiveValue,
2505
+ makeZeroAdjointValue (getRemappedTangentType (
2506
+ bbActiveValue->getType ())));
2507
+ continue ;
2508
+ }
2509
+
2510
+ ASSERT (getTangentValueCategory (bbActiveValue) ==
2511
+ SILValueCategory::Address);
2512
+
2513
+ // getAdjointProjection might call materializeAdjointDirect which
2514
+ // writes to debug output, emit \n.
2515
+ LLVM_DEBUG (llvm::dbgs ()
2516
+ << " checking if it's adjoint is a projection\n " );
2517
+
2518
+ if (!getAdjointProjection (bb, bbActiveValue)) {
2519
+ LLVM_DEBUG (getADDebugStream ()
2520
+ << " Adjoint for the following value is NOT a projection, "
2521
+ " zeroing its adjoint buffer in loop header: "
2522
+ << bbActiveValue);
2523
+
2524
+ // All adjoint buffers are allocated in the pullback entry and
2525
+ // deallocated in the pullback exit. So, use IsNotInitialization to
2526
+ // emit destroy_addr before zeroing the buffer.
2527
+ ASSERT (bufferMap.contains ({bb, bbActiveValue}));
2528
+ builder.emitZeroIntoBuffer (pbLoc, getAdjointBuffer (bb, bbActiveValue),
2529
+ IsNotInitialization);
2530
+
2531
+ continue ;
2532
+ }
2533
+
2534
+ LLVM_DEBUG (getADDebugStream ()
2535
+ << " Adjoint for the following value is a projection, " );
2536
+
2537
+ // If Projection::isAddressProjection(v) is true for a value v, it
2538
+ // is not added to active values list (see recordValueIfActive).
2539
+ //
2540
+ // Ensure that only the following value types conforming to
2541
+ // getAdjointProjection but not conforming to
2542
+ // Projection::isAddressProjection can go here.
2543
+ //
2544
+ // Instructions conforming to Projection::isAddressProjection and
2545
+ // thus never corresponding to an active value do not need any
2546
+ // handling, because only active values can have adjoints from
2547
+ // previous iterations propagated via BB arguments.
2548
+ do {
2549
+ // Consider '%X = begin_access [modify] [static] %Y'.
2550
+ // 1. If %Y is loop-local, it's adjoint buffer will
2551
+ // be zeroed, and we'll have zero adjoint projection to it.
2552
+ // 2. Otherwise, we do not need to zero the projection buffer.
2553
+ // Thus, we can just skip.
2554
+ if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
2555
+ LLVM_DEBUG (llvm::dbgs () << " skipping: " << bbActiveValue);
2556
+ break ;
2557
+ }
2558
+
2559
+ // Consider the following sequence:
2560
+ // %1 = function_ref @allocUninitArray
2561
+ // %2 = apply %1<Float>(%0)
2562
+ // (%3, %4) = destructure_tuple %2
2563
+ // %5 = mark_dependence %4 on %3
2564
+ // %6 = pointer_to_address %6 to [strict] $*Float
2565
+ // Since %6 is active, %3 (which is an array) must also be active.
2566
+ // Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
2567
+ // invariants hold and then skip.
2568
+ if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress (
2569
+ bbActiveValue)) {
2570
+ ASSERT (isa<PointerToAddressInst>(bbActiveValue));
2571
+ SILValue arrayValue = getArrayValue (ai);
2572
+ ASSERT (llvm::find (bbActiveValues, arrayValue) !=
2573
+ bbActiveValues.end ());
2574
+ ASSERT (vjpCloner.getLoopInfo ()->getLoopFor (
2575
+ arrayValue->getParentBlock ()) == loop);
2576
+ LLVM_DEBUG (llvm::dbgs () << " skipping: " << bbActiveValue);
2577
+ break ;
2578
+ }
2579
+
2580
+ ASSERT (false );
2581
+ } while (false );
2582
+ }
2583
+ }
2584
+ LLVM_DEBUG (getADDebugStream ()
2585
+ << " End search for adjoints of loop-local active values\n " );
2586
+
2442
2587
for (auto *bb : originalBlocks) {
2443
2588
visitSILBasicBlock (bb);
2444
2589
if (errorOccurred)
@@ -3339,19 +3484,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
3339
3484
eltIndex = ili->getValue ().getLimitedValue ();
3340
3485
}
3341
3486
// Get the array adjoint value.
3342
- SILValue arrayAdjoint;
3343
- assert (ai && " Expected `array.uninitialized_intrinsic` application" );
3344
- for (auto use : ai->getUses ()) {
3345
- auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ());
3346
- if (!dti)
3347
- continue ;
3348
- assert (!arrayAdjoint && " Array adjoint already found" );
3349
- // The first `destructure_tuple` result is the `Array` value.
3350
- auto arrayValue = dti->getResult (0 );
3351
- arrayAdjoint = materializeAdjointDirect (
3352
- getAdjointValue (origBB, arrayValue), definingInst->getLoc ());
3353
- }
3354
- assert (arrayAdjoint && " Array does not have adjoint value" );
3487
+ SILValue arrayValue = getArrayValue (ai);
3488
+ SILValue arrayAdjoint = materializeAdjointDirect (
3489
+ getAdjointValue (origBB, arrayValue), definingInst->getLoc ());
3355
3490
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
3356
3491
auto *eltAdjBuffer =
3357
3492
getArrayAdjointElementBuffer (arrayAdjoint, eltIndex, ai->getLoc ());
0 commit comments