Skip to content

Commit 6960602

Browse files
committed
[AutoDiff] Fix adjoints for loop-local active values
Fixes #78264
1 parent 09d122a commit 6960602

File tree

4 files changed

+558
-37
lines changed

4 files changed

+558
-37
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 147 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,6 +2128,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
21282128

21292129
PullbackCloner::~PullbackCloner() { delete &impl; }
21302130

2131+
/// Returns the `Array` value produced by the apply `ai`: result 0 of the
/// `destructure_tuple` instruction that uses `ai`.
///
/// `ai` must be non-null (it is dereferenced unconditionally). The visible
/// callers pass the application of the allocate-uninitialized-array intrinsic.
/// Exactly one `destructure_tuple` user is expected; this is asserted in
/// builds where assertions are enabled.
static SILValue getArrayValue(ApplyInst *ai) {
2132+
SILValue arrayValue;
2133+
for (auto use : ai->getUses()) {
2134+
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
2135+
if (!dti)
2136+
continue;
2137+
assert(!arrayValue && "Array value already found");
2138+
// The first `destructure_tuple` result is the `Array` value.
2139+
arrayValue = dti->getResult(0);
2140+
// When NDEBUG is defined the asserts are compiled out, so stop at the first
// match. Otherwise keep iterating so that the assert above can detect a
// second `destructure_tuple` user.
#ifdef NDEBUG
2141+
break;
2142+
#endif
2143+
}
2144+
// A `destructure_tuple` user must have been found.
assert(arrayValue);
2145+
return arrayValue;
2146+
}
2147+
21312148
//--------------------------------------------------------------------------//
21322149
// Entry point
21332150
//--------------------------------------------------------------------------//
@@ -2472,6 +2489,133 @@ bool PullbackCloner::Implementation::run() {
24722489
// Visit original blocks in post-order and perform differentiation
24732490
// in corresponding pullback blocks. If errors occurred, back out.
24742491
else {
2492+
LLVM_DEBUG(getADDebugStream()
2493+
<< "Begin search for adjoints of loop-local active values\n");
2494+
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
2495+
loopLocalActiveValues;
2496+
for (auto *bb : originalBlocks) {
2497+
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
2498+
if (loop == nullptr)
2499+
continue;
2500+
SILBasicBlock *loopHeader = loop->getHeader();
2501+
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
2502+
LLVM_DEBUG(getADDebugStream()
2503+
<< "Original bb" << bb->getDebugID()
2504+
<< " belongs to a loop, original header bb"
2505+
<< loopHeader->getDebugID() << ", pullback header bb"
2506+
<< pbLoopHeader->getDebugID() << '\n');
2507+
builder.setInsertionPoint(pbLoopHeader);
2508+
auto bbActiveValuesIt = activeValues.find(bb);
2509+
if (bbActiveValuesIt == activeValues.end())
2510+
continue;
2511+
const auto &bbActiveValues = bbActiveValuesIt->second;
2512+
for (SILValue bbActiveValue : bbActiveValues) {
2513+
if (vjpCloner.getLoopInfo()->getLoopFor(
2514+
bbActiveValue->getParentBlock()) != loop) {
2515+
LLVM_DEBUG(
2516+
getADDebugStream()
2517+
<< "The following active value is NOT loop-local, skipping: "
2518+
<< bbActiveValue);
2519+
continue;
2520+
}
2521+
2522+
auto [_, wasInserted] =
2523+
loopLocalActiveValues[loop].insert(bbActiveValue);
2524+
LLVM_DEBUG(getADDebugStream()
2525+
<< "The following active value is loop-local, ");
2526+
if (!wasInserted) {
2527+
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
2528+
<< bbActiveValue);
2529+
continue;
2530+
}
2531+
2532+
if (getTangentValueCategory(bbActiveValue) ==
2533+
SILValueCategory::Object) {
2534+
LLVM_DEBUG(llvm::dbgs()
2535+
<< "zeroing its adjoint value in loop header: "
2536+
<< bbActiveValue);
2537+
setAdjointValue(bb, bbActiveValue,
2538+
makeZeroAdjointValue(getRemappedTangentType(
2539+
bbActiveValue->getType())));
2540+
continue;
2541+
}
2542+
2543+
assert(getTangentValueCategory(bbActiveValue) ==
2544+
SILValueCategory::Address);
2545+
2546+
// getAdjointProjection might call materializeAdjointDirect which
2547+
// writes to debug output, emit \n.
2548+
LLVM_DEBUG(llvm::dbgs()
2549+
<< "checking if it's adjoint is a projection\n");
2550+
2551+
if (!getAdjointProjection(bb, bbActiveValue)) {
2552+
LLVM_DEBUG(getADDebugStream()
2553+
<< "Adjoint for the following value is NOT a projection, "
2554+
"zeroing its adjoint buffer in loop header: "
2555+
<< bbActiveValue);
2556+
2557+
builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue),
2558+
IsInitialization);
2559+
2560+
continue;
2561+
}
2562+
2563+
LLVM_DEBUG(getADDebugStream()
2564+
<< "Adjoint for the following value is a projection, ");
2565+
2566+
// If Projection::isAddressProjection(v) is true for a value v, it
2567+
// is not added to active values list (see recordValueIfActive).
2568+
//
2569+
// Ensure that only the following value types conforming to
2570+
// getAdjointProjection but not conforming to
2571+
// Projection::isAddressProjection can go here.
2572+
//
2573+
// Instructions conforming to Projection::isAddressProjection and
2574+
// thus never corresponding to an active value do not need any
2575+
// handling, because only active values can have adjoints from
2576+
// previous iterations propagated via BB arguments.
2577+
do {
2578+
// Consider '%X = begin_access [modify] [static] %Y'.
2579+
// 1. If %Y is loop-local, its adjoint buffer will
2580+
// be zeroed, and we'll have zero adjoint projection to it.
2581+
// 2. Otherwise, we do not need to zero the projection buffer.
2582+
// Thus, we can just skip.
2583+
if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
2584+
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
2585+
break;
2586+
}
2587+
2588+
// Consider the following sequence:
2589+
// %1 = function_ref @allocUninitArray
2590+
// %2 = apply %1<Float>(%0)
2591+
// (%3, %4) = destructure_tuple %2
2592+
// %5 = mark_dependence %4 on %3
2593+
// %6 = pointer_to_address %5 to [strict] $*Float
2594+
// Since %6 is active, %3 (which is an array) must also be active.
2595+
// Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
2596+
// invariants hold and then skip.
2597+
if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress(
2598+
bbActiveValue)) {
2599+
#ifndef NDEBUG
2600+
assert(isa<PointerToAddressInst>(bbActiveValue));
2601+
2602+
SILValue arrayValue = getArrayValue(ai);
2603+
assert(llvm::find(bbActiveValues, arrayValue) !=
2604+
bbActiveValues.end());
2605+
assert(vjpCloner.getLoopInfo()->getLoopFor(
2606+
arrayValue->getParentBlock()) == loop);
2607+
#endif
2608+
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
2609+
break;
2610+
}
2611+
2612+
assert(false);
2613+
} while (false);
2614+
}
2615+
}
2616+
LLVM_DEBUG(getADDebugStream()
2617+
<< "End search for adjoints of loop-local active values\n");
2618+
24752619
for (auto *bb : originalBlocks) {
24762620
visitSILBasicBlock(bb);
24772621
if (errorOccurred)
@@ -3387,19 +3531,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
33873531
eltIndex = ili->getValue().getLimitedValue();
33883532
}
33893533
// Get the array adjoint value.
3390-
SILValue arrayAdjoint;
3391-
assert(ai && "Expected `array.uninitialized_intrinsic` application");
3392-
for (auto use : ai->getUses()) {
3393-
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
3394-
if (!dti)
3395-
continue;
3396-
assert(!arrayAdjoint && "Array adjoint already found");
3397-
// The first `destructure_tuple` result is the `Array` value.
3398-
auto arrayValue = dti->getResult(0);
3399-
arrayAdjoint = materializeAdjointDirect(
3400-
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
3401-
}
3402-
assert(arrayAdjoint && "Array does not have adjoint value");
3534+
SILValue arrayValue = getArrayValue(ai);
3535+
SILValue arrayAdjoint = materializeAdjointDirect(
3536+
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
34033537
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
34043538
auto *eltAdjBuffer =
34053539
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());

test/AutoDiff/SILOptimizer/pullback_generation.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,22 @@ func f4(a: NonTrivial) -> Float {
182182
}
183183

184184
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
185-
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
186-
// CHECK: %88 = alloc_stack $NonTrivial
185+
// CHECK: bb5(%[[#ARG0:]] : @owned $NonTrivial, %[[#]] : $Float, %[[#]] : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
186+
// CHECK: %[[#T0:]] = alloc_stack $NonTrivial
187187

188188
// Non-trivial value must be copied or there will be a
189189
// double consume when all owned parameters are destroyed
190190
// at the end of the basic block.
191-
// CHECK: %89 = copy_value %67 : $NonTrivial
192-
193-
// CHECK: store %89 to [init] %88 : $*NonTrivial
194-
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
195-
// CHECK: %92 = alloc_stack $Float
196-
// CHECK: store %86 to [trivial] %92 : $*Float
197-
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
198-
// CHECK: %95 = metatype $@thick Float.Type
199-
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200-
// CHECK: destroy_value %67 : $NonTrivial
191+
// CHECK: %[[#T1:]] = copy_value %[[#ARG0]] : $NonTrivial
192+
193+
// CHECK: store %[[#T1]] to [init] %[[#T0]] : $*NonTrivial
194+
// CHECK: %[[#T2:]] = struct_element_addr %[[#T0]] : $*NonTrivial, #NonTrivial.x
195+
// CHECK: %[[#T3:]] = alloc_stack $Float
196+
// CHECK: store %[[#T4:]] to [trivial] %[[#T3]] : $*Float
197+
// CHECK: %[[#T5:]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
198+
// CHECK: %[[#T6:]] = metatype $@thick Float.Type
199+
// CHECK: %[[#]] = apply %[[#T5]]<Float>(%[[#T2]], %[[#T3]], %[[#T6]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200+
// CHECK: destroy_value %[[#ARG0]] : $NonTrivial
201201

202202
@differentiable(reverse)
203203
func move_value(x: Float) -> Float {

0 commit comments

Comments
 (0)