Skip to content

Commit 5d0bfed

Browse files
authored
[AutoDiff] Fix adjoints for loop-local active values (#78374)
Fixes #78264
1 parent e656fe5 commit 5d0bfed

File tree

4 files changed

+559
-37
lines changed

4 files changed

+559
-37
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 148 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,6 +2095,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
20952095

20962096
PullbackCloner::~PullbackCloner() { delete &impl; }
20972097

2098+
/// Returns the `Array` value associated with the apply `ai`: result 0 of the
/// `destructure_tuple` instruction found among the users of `ai`.
///
/// `ai` is dereferenced unconditionally, so callers must pass a non-null
/// instruction (presumably an `array.uninitialized_intrinsic` application, as
/// returned by `getAllocateUninitializedArrayIntrinsicElementAddress`).
/// Asserts that a `destructure_tuple` user was found.
static SILValue getArrayValue(ApplyInst *ai) {
2099+
SILValue arrayValue;
2100+
for (auto use : ai->getUses()) {
2101+
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
2102+
if (!dti)
2103+
continue;
2104+
DEBUG_ASSERT(!arrayValue && "Array value already found");
2105+
// The first `destructure_tuple` result is the `Array` value.
2106+
arrayValue = dti->getResult(0);
// If `DEBUG_ASSERT_enabled` is not defined, stop at the first match.
// Otherwise keep scanning the remaining uses so that the DEBUG_ASSERT above
// fires if a second `destructure_tuple` user exists.
2107+
#ifndef DEBUG_ASSERT_enabled
2108+
break;
2109+
#endif
2110+
}
2111+
ASSERT(arrayValue);
2112+
return arrayValue;
2113+
}
2114+
20982115
//--------------------------------------------------------------------------//
20992116
// Entry point
21002117
//--------------------------------------------------------------------------//
@@ -2439,6 +2456,134 @@ bool PullbackCloner::Implementation::run() {
24392456
// Visit original blocks in post-order and perform differentiation
24402457
// in corresponding pullback blocks. If errors occurred, back out.
24412458
else {
2459+
LLVM_DEBUG(getADDebugStream()
2460+
<< "Begin search for adjoints of loop-local active values\n");
2461+
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
2462+
loopLocalActiveValues;
2463+
for (auto *bb : originalBlocks) {
2464+
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
2465+
if (loop == nullptr)
2466+
continue;
2467+
SILBasicBlock *loopHeader = loop->getHeader();
2468+
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
2469+
LLVM_DEBUG(getADDebugStream()
2470+
<< "Original bb" << bb->getDebugID()
2471+
<< " belongs to a loop, original header bb"
2472+
<< loopHeader->getDebugID() << ", pullback header bb"
2473+
<< pbLoopHeader->getDebugID() << '\n');
2474+
builder.setInsertionPoint(pbLoopHeader);
2475+
auto bbActiveValuesIt = activeValues.find(bb);
2476+
if (bbActiveValuesIt == activeValues.end())
2477+
continue;
2478+
const auto &bbActiveValues = bbActiveValuesIt->second;
2479+
for (SILValue bbActiveValue : bbActiveValues) {
2480+
if (vjpCloner.getLoopInfo()->getLoopFor(
2481+
bbActiveValue->getParentBlock()) != loop) {
2482+
LLVM_DEBUG(
2483+
getADDebugStream()
2484+
<< "The following active value is NOT loop-local, skipping: "
2485+
<< bbActiveValue);
2486+
continue;
2487+
}
2488+
2489+
auto [_, wasInserted] =
2490+
loopLocalActiveValues[loop].insert(bbActiveValue);
2491+
LLVM_DEBUG(getADDebugStream()
2492+
<< "The following active value is loop-local, ");
2493+
if (!wasInserted) {
2494+
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
2495+
<< bbActiveValue);
2496+
continue;
2497+
}
2498+
2499+
if (getTangentValueCategory(bbActiveValue) ==
2500+
SILValueCategory::Object) {
2501+
LLVM_DEBUG(llvm::dbgs()
2502+
<< "zeroing its adjoint value in loop header: "
2503+
<< bbActiveValue);
2504+
setAdjointValue(bb, bbActiveValue,
2505+
makeZeroAdjointValue(getRemappedTangentType(
2506+
bbActiveValue->getType())));
2507+
continue;
2508+
}
2509+
2510+
ASSERT(getTangentValueCategory(bbActiveValue) ==
2511+
SILValueCategory::Address);
2512+
2513+
// getAdjointProjection might call materializeAdjointDirect which
2514+
// writes to debug output, emit \n.
2515+
LLVM_DEBUG(llvm::dbgs()
2516+
<< "checking if it's adjoint is a projection\n");
2517+
2518+
if (!getAdjointProjection(bb, bbActiveValue)) {
2519+
LLVM_DEBUG(getADDebugStream()
2520+
<< "Adjoint for the following value is NOT a projection, "
2521+
"zeroing its adjoint buffer in loop header: "
2522+
<< bbActiveValue);
2523+
2524+
// All adjoint buffers are allocated in the pullback entry and
2525+
// deallocated in the pullback exit. So, use IsNotInitialization to
2526+
// emit destroy_addr before zeroing the buffer.
2527+
ASSERT(bufferMap.contains({bb, bbActiveValue}));
2528+
builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue),
2529+
IsNotInitialization);
2530+
2531+
continue;
2532+
}
2533+
2534+
LLVM_DEBUG(getADDebugStream()
2535+
<< "Adjoint for the following value is a projection, ");
2536+
2537+
// If Projection::isAddressProjection(v) is true for a value v, it
2538+
// is not added to active values list (see recordValueIfActive).
2539+
//
2540+
// Ensure that only the following value types conforming to
2541+
// getAdjointProjection but not conforming to
2542+
// Projection::isAddressProjection can go here.
2543+
//
2544+
// Instructions conforming to Projection::isAddressProjection and
2545+
// thus never corresponding to an active value do not need any
2546+
// handling, because only active values can have adjoints from
2547+
// previous iterations propagated via BB arguments.
2548+
do {
2549+
// Consider '%X = begin_access [modify] [static] %Y'.
2550+
// 1. If %Y is loop-local, its adjoint buffer will
2551+
// be zeroed, and we'll have zero adjoint projection to it.
2552+
// 2. Otherwise, we do not need to zero the projection buffer.
2553+
// Thus, we can just skip.
2554+
if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
2555+
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
2556+
break;
2557+
}
2558+
2559+
// Consider the following sequence:
2560+
// %1 = function_ref @allocUninitArray
2561+
// %2 = apply %1<Float>(%0)
2562+
// (%3, %4) = destructure_tuple %2
2563+
// %5 = mark_dependence %4 on %3
2564+
// %6 = pointer_to_address %5 to [strict] $*Float
2565+
// Since %6 is active, %3 (which is an array) must also be active.
2566+
// Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
2567+
// invariants hold and then skip.
2568+
if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress(
2569+
bbActiveValue)) {
2570+
ASSERT(isa<PointerToAddressInst>(bbActiveValue));
2571+
SILValue arrayValue = getArrayValue(ai);
2572+
ASSERT(llvm::find(bbActiveValues, arrayValue) !=
2573+
bbActiveValues.end());
2574+
ASSERT(vjpCloner.getLoopInfo()->getLoopFor(
2575+
arrayValue->getParentBlock()) == loop);
2576+
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
2577+
break;
2578+
}
2579+
2580+
ASSERT(false);
2581+
} while (false);
2582+
}
2583+
}
2584+
LLVM_DEBUG(getADDebugStream()
2585+
<< "End search for adjoints of loop-local active values\n");
2586+
24422587
for (auto *bb : originalBlocks) {
24432588
visitSILBasicBlock(bb);
24442589
if (errorOccurred)
@@ -3339,19 +3484,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
33393484
eltIndex = ili->getValue().getLimitedValue();
33403485
}
33413486
// Get the array adjoint value.
3342-
SILValue arrayAdjoint;
3343-
assert(ai && "Expected `array.uninitialized_intrinsic` application");
3344-
for (auto use : ai->getUses()) {
3345-
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
3346-
if (!dti)
3347-
continue;
3348-
assert(!arrayAdjoint && "Array adjoint already found");
3349-
// The first `destructure_tuple` result is the `Array` value.
3350-
auto arrayValue = dti->getResult(0);
3351-
arrayAdjoint = materializeAdjointDirect(
3352-
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
3353-
}
3354-
assert(arrayAdjoint && "Array does not have adjoint value");
3487+
SILValue arrayValue = getArrayValue(ai);
3488+
SILValue arrayAdjoint = materializeAdjointDirect(
3489+
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
33553490
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
33563491
auto *eltAdjBuffer =
33573492
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());

test/AutoDiff/SILOptimizer/pullback_generation.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,22 @@ func f4(a: NonTrivial) -> Float {
182182
}
183183

184184
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
185-
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
186-
// CHECK: %88 = alloc_stack $NonTrivial
185+
// CHECK: bb5(%[[#ARG0:]] : @owned $NonTrivial, %[[#]] : $Float, %[[#]] : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
186+
// CHECK: %[[#T0:]] = alloc_stack $NonTrivial
187187

188188
// Non-trivial value must be copied or there will be a
189189
// double consume when all owned parameters are destroyed
190190
// at the end of the basic block.
191-
// CHECK: %89 = copy_value %67 : $NonTrivial
191+
// CHECK: %[[#T1:]] = copy_value %[[#ARG0]] : $NonTrivial
192192

193-
// CHECK: store %89 to [init] %88 : $*NonTrivial
194-
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
195-
// CHECK: %92 = alloc_stack $Float
196-
// CHECK: store %86 to [trivial] %92 : $*Float
197-
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
198-
// CHECK: %95 = metatype $@thick Float.Type
199-
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200-
// CHECK: destroy_value %67 : $NonTrivial
193+
// CHECK: store %[[#T1]] to [init] %[[#T0]] : $*NonTrivial
194+
// CHECK: %[[#T2:]] = struct_element_addr %[[#T0]] : $*NonTrivial, #NonTrivial.x
195+
// CHECK: %[[#T3:]] = alloc_stack $Float
196+
// CHECK: store %[[#T4:]] to [trivial] %[[#T3]] : $*Float
197+
// CHECK: %[[#T5:]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
198+
// CHECK: %[[#T6:]] = metatype $@thick Float.Type
199+
// CHECK: %[[#]] = apply %[[#T5]]<Float>(%[[#T2]], %[[#T3]], %[[#T6]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200+
// CHECK: destroy_value %[[#ARG0]] : $NonTrivial
201201

202202
@differentiable(reverse)
203203
func move_value(x: Float) -> Float {
@@ -218,5 +218,5 @@ func move_value(x: Float) -> Float {
218218
// CHECK: %[[#]] = apply %[[#T2]]<Float>(%[[#T1]], %[[#T3]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
219219
// CHECK: %[[#T4:]] = load [trivial] %[[#T1]]
220220
// CHECK: dealloc_stack %[[#T1]]
221-
// CHECK: bb4(%113 : $Builtin.RawPointer):
221+
// CHECK: bb4(%[[#]] : $Builtin.RawPointer):
222222
// CHECK: br bb5(%[[#]] : $Float, %[[#]] : $Float, %[[#T4]] : $Float, %[[#]] : $(predecessor: _AD__$s19pullback_generation10move_value1xS2f_tF_bb2__Pred__src_0_wrt_0))

0 commit comments

Comments
 (0)