@@ -2128,6 +2128,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
2128
2128
2129
2129
PullbackCloner::~PullbackCloner () { delete &impl; }
2130
2130
2131
+ static SILValue getArrayValue (ApplyInst *ai) {
2132
+ SILValue arrayValue;
2133
+ for (auto use : ai->getUses ()) {
2134
+ auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ());
2135
+ if (!dti)
2136
+ continue ;
2137
+ assert (!arrayValue && " Array value already found" );
2138
+ // The first `destructure_tuple` result is the `Array` value.
2139
+ arrayValue = dti->getResult (0 );
2140
+ #ifdef NDEBUG
2141
+ break ;
2142
+ #endif
2143
+ }
2144
+ assert (arrayValue);
2145
+ return arrayValue;
2146
+ }
2147
+
2131
2148
// --------------------------------------------------------------------------//
2132
2149
// Entry point
2133
2150
// --------------------------------------------------------------------------//
@@ -2472,6 +2489,133 @@ bool PullbackCloner::Implementation::run() {
2472
2489
// Visit original blocks in post-order and perform differentiation
2473
2490
// in corresponding pullback blocks. If errors occurred, back out.
2474
2491
else {
2492
+ LLVM_DEBUG (getADDebugStream ()
2493
+ << " Begin search for adjoints of loop-local active values\n " );
2494
+ llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
2495
+ loopLocalActiveValues;
2496
+ for (auto *bb : originalBlocks) {
2497
+ const SILLoop *loop = vjpCloner.getLoopInfo ()->getLoopFor (bb);
2498
+ if (loop == nullptr )
2499
+ continue ;
2500
+ SILBasicBlock *loopHeader = loop->getHeader ();
2501
+ SILBasicBlock *pbLoopHeader = getPullbackBlock (loopHeader);
2502
+ LLVM_DEBUG (getADDebugStream ()
2503
+ << " Original bb" << bb->getDebugID ()
2504
+ << " belongs to a loop, original header bb"
2505
+ << loopHeader->getDebugID () << " , pullback header bb"
2506
+ << pbLoopHeader->getDebugID () << ' \n ' );
2507
+ builder.setInsertionPoint (pbLoopHeader);
2508
+ auto bbActiveValuesIt = activeValues.find (bb);
2509
+ if (bbActiveValuesIt == activeValues.end ())
2510
+ continue ;
2511
+ const auto &bbActiveValues = bbActiveValuesIt->second ;
2512
+ for (SILValue bbActiveValue : bbActiveValues) {
2513
+ if (vjpCloner.getLoopInfo ()->getLoopFor (
2514
+ bbActiveValue->getParentBlock ()) != loop) {
2515
+ LLVM_DEBUG (
2516
+ getADDebugStream ()
2517
+ << " The following active value is NOT loop-local, skipping: "
2518
+ << bbActiveValue);
2519
+ continue ;
2520
+ }
2521
+
2522
+ auto [_, wasInserted] =
2523
+ loopLocalActiveValues[loop].insert (bbActiveValue);
2524
+ LLVM_DEBUG (getADDebugStream ()
2525
+ << " The following active value is loop-local, " );
2526
+ if (!wasInserted) {
2527
+ LLVM_DEBUG (llvm::dbgs () << " but it was already processed, skipping: "
2528
+ << bbActiveValue);
2529
+ continue ;
2530
+ }
2531
+
2532
+ if (getTangentValueCategory (bbActiveValue) ==
2533
+ SILValueCategory::Object) {
2534
+ LLVM_DEBUG (llvm::dbgs ()
2535
+ << " zeroing its adjoint value in loop header: "
2536
+ << bbActiveValue);
2537
+ setAdjointValue (bb, bbActiveValue,
2538
+ makeZeroAdjointValue (getRemappedTangentType (
2539
+ bbActiveValue->getType ())));
2540
+ continue ;
2541
+ }
2542
+
2543
+ assert (getTangentValueCategory (bbActiveValue) ==
2544
+ SILValueCategory::Address);
2545
+
2546
+ // getAdjointProjection might call materializeAdjointDirect which
2547
+ // writes to debug output, emit \n.
2548
+ LLVM_DEBUG (llvm::dbgs ()
2549
+ << " checking if it's adjoint is a projection\n " );
2550
+
2551
+ if (!getAdjointProjection (bb, bbActiveValue)) {
2552
+ LLVM_DEBUG (getADDebugStream ()
2553
+ << " Adjoint for the following value is NOT a projection, "
2554
+ " zeroing its adjoint buffer in loop header: "
2555
+ << bbActiveValue);
2556
+
2557
+ builder.emitZeroIntoBuffer (pbLoc, getAdjointBuffer (bb, bbActiveValue),
2558
+ IsInitialization);
2559
+
2560
+ continue ;
2561
+ }
2562
+
2563
+ LLVM_DEBUG (getADDebugStream ()
2564
+ << " Adjoint for the following value is a projection, " );
2565
+
2566
+ // If Projection::isAddressProjection(v) is true for a value v, it
2567
+ // is not added to active values list (see recordValueIfActive).
2568
+ //
2569
+ // Ensure that only the following value types conforming to
2570
+ // getAdjointProjection but not conforming to
2571
+ // Projection::isAddressProjection can go here.
2572
+ //
2573
+ // Instructions conforming to Projection::isAddressProjection and
2574
+ // thus never corresponding to an active value do not need any
2575
+ // handling, because only active values can have adjoints from
2576
+ // previous iterations propagated via BB arguments.
2577
+ do {
2578
+ // Consider '%X = begin_access [modify] [static] %Y'.
2579
+ // 1. If %Y is loop-local, it's adjoint buffer will
2580
+ // be zeroed, and we'll have zero adjoint projection to it.
2581
+ // 2. Otherwise, we do not need to zero the projection buffer.
2582
+ // Thus, we can just skip.
2583
+ if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
2584
+ LLVM_DEBUG (llvm::dbgs () << " skipping: " << bbActiveValue);
2585
+ break ;
2586
+ }
2587
+
2588
+ // Consider the following sequence:
2589
+ // %1 = function_ref @allocUninitArray
2590
+ // %2 = apply %1<Float>(%0)
2591
+ // (%3, %4) = destructure_tuple %2
2592
+ // %5 = mark_dependence %4 on %3
2593
+ // %6 = pointer_to_address %6 to [strict] $*Float
2594
+ // Since %6 is active, %3 (which is an array) must also be active.
2595
+ // Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
2596
+ // invariants hold and then skip.
2597
+ if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress (
2598
+ bbActiveValue)) {
2599
+ #ifndef NDEBUG
2600
+ assert (isa<PointerToAddressInst>(bbActiveValue));
2601
+
2602
+ SILValue arrayValue = getArrayValue (ai);
2603
+ assert (llvm::find (bbActiveValues, arrayValue) !=
2604
+ bbActiveValues.end ());
2605
+ assert (vjpCloner.getLoopInfo ()->getLoopFor (
2606
+ arrayValue->getParentBlock ()) == loop);
2607
+ #endif
2608
+ LLVM_DEBUG (llvm::dbgs () << " skipping: " << bbActiveValue);
2609
+ break ;
2610
+ }
2611
+
2612
+ assert (false );
2613
+ } while (false );
2614
+ }
2615
+ }
2616
+ LLVM_DEBUG (getADDebugStream ()
2617
+ << " End search for adjoints of loop-local active values\n " );
2618
+
2475
2619
for (auto *bb : originalBlocks) {
2476
2620
visitSILBasicBlock (bb);
2477
2621
if (errorOccurred)
@@ -3387,19 +3531,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
3387
3531
eltIndex = ili->getValue ().getLimitedValue ();
3388
3532
}
3389
3533
// Get the array adjoint value.
3390
- SILValue arrayAdjoint;
3391
- assert (ai && " Expected `array.uninitialized_intrinsic` application" );
3392
- for (auto use : ai->getUses ()) {
3393
- auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ());
3394
- if (!dti)
3395
- continue ;
3396
- assert (!arrayAdjoint && " Array adjoint already found" );
3397
- // The first `destructure_tuple` result is the `Array` value.
3398
- auto arrayValue = dti->getResult (0 );
3399
- arrayAdjoint = materializeAdjointDirect (
3400
- getAdjointValue (origBB, arrayValue), definingInst->getLoc ());
3401
- }
3402
- assert (arrayAdjoint && " Array does not have adjoint value" );
3534
+ SILValue arrayValue = getArrayValue (ai);
3535
+ SILValue arrayAdjoint = materializeAdjointDirect (
3536
+ getAdjointValue (origBB, arrayValue), definingInst->getLoc ());
3403
3537
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
3404
3538
auto *eltAdjBuffer =
3405
3539
getArrayAdjointElementBuffer (arrayAdjoint, eltIndex, ai->getLoc ());
0 commit comments