Skip to content

[AutoDiff] Fix adjoints for loop-local active values #78374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 148 additions & 13 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)

PullbackCloner::~PullbackCloner() { delete &impl; }

/// Returns the `Array` value associated with the apply instruction \p ai.
///
/// Scans the uses of \p ai for a `destructure_tuple` user and returns that
/// instruction's first result, which is the `Array` value.
///
/// \param ai the apply instruction whose result is destructured. Must not be
///        null.
/// \returns the first result of the `destructure_tuple` user of \p ai.
///          Asserts if \p ai has no `destructure_tuple` user.
static SILValue getArrayValue(ApplyInst *ai) {
  // `ai` is dereferenced unconditionally below, so reject a null argument up
  // front with a descriptive message instead of crashing on the dereference.
  ASSERT(ai && "Expected `array.uninitialized_intrinsic` application");
  SILValue arrayValue;
  for (auto use : ai->getUses()) {
    auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
    if (!dti)
      continue;
    DEBUG_ASSERT(!arrayValue && "Array value already found");
    // The first `destructure_tuple` result is the `Array` value.
    arrayValue = dti->getResult(0);
    // When `DEBUG_ASSERT_enabled` is not defined, stop at the first match.
    // Otherwise keep scanning the remaining uses so that the `DEBUG_ASSERT`
    // above can catch a second `destructure_tuple` user.
#ifndef DEBUG_ASSERT_enabled
    break;
#endif
  }
  ASSERT(arrayValue && "Array value not found");
  return arrayValue;
}

//--------------------------------------------------------------------------//
// Entry point
//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -2439,6 +2456,134 @@ bool PullbackCloner::Implementation::run() {
// Visit original blocks in post-order and perform differentiation
// in corresponding pullback blocks. If errors occurred, back out.
else {
LLVM_DEBUG(getADDebugStream()
<< "Begin search for adjoints of loop-local active values\n");
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
loopLocalActiveValues;
for (auto *bb : originalBlocks) {
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
if (loop == nullptr)
continue;
SILBasicBlock *loopHeader = loop->getHeader();
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
LLVM_DEBUG(getADDebugStream()
<< "Original bb" << bb->getDebugID()
<< " belongs to a loop, original header bb"
<< loopHeader->getDebugID() << ", pullback header bb"
<< pbLoopHeader->getDebugID() << '\n');
builder.setInsertionPoint(pbLoopHeader);
auto bbActiveValuesIt = activeValues.find(bb);
if (bbActiveValuesIt == activeValues.end())
continue;
const auto &bbActiveValues = bbActiveValuesIt->second;
for (SILValue bbActiveValue : bbActiveValues) {
if (vjpCloner.getLoopInfo()->getLoopFor(
bbActiveValue->getParentBlock()) != loop) {
LLVM_DEBUG(
getADDebugStream()
<< "The following active value is NOT loop-local, skipping: "
<< bbActiveValue);
continue;
}

auto [_, wasInserted] =
loopLocalActiveValues[loop].insert(bbActiveValue);
LLVM_DEBUG(getADDebugStream()
<< "The following active value is loop-local, ");
if (!wasInserted) {
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
<< bbActiveValue);
continue;
}

if (getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Object) {
LLVM_DEBUG(llvm::dbgs()
<< "zeroing its adjoint value in loop header: "
<< bbActiveValue);
setAdjointValue(bb, bbActiveValue,
makeZeroAdjointValue(getRemappedTangentType(
bbActiveValue->getType())));
continue;
}

ASSERT(getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Address);

// getAdjointProjection might call materializeAdjointDirect which
// writes to debug output, emit \n.
LLVM_DEBUG(llvm::dbgs()
<< "checking if it's adjoint is a projection\n");

if (!getAdjointProjection(bb, bbActiveValue)) {
LLVM_DEBUG(getADDebugStream()
<< "Adjoint for the following value is NOT a projection, "
"zeroing its adjoint buffer in loop header: "
<< bbActiveValue);

// All adjoint buffers are allocated in the pullback entry and
// deallocated in the pullback exit. So, use IsNotInitialization to
// emit destroy_addr before zeroing the buffer.
ASSERT(bufferMap.contains({bb, bbActiveValue}));
builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue),
IsNotInitialization);

continue;
}

LLVM_DEBUG(getADDebugStream()
<< "Adjoint for the following value is a projection, ");

// If Projection::isAddressProjection(v) is true for a value v, it
// is not added to active values list (see recordValueIfActive).
//
// Ensure that only the following value types conforming to
// getAdjointProjection but not conforming to
// Projection::isAddressProjection can go here.
//
// Instructions conforming to Projection::isAddressProjection and
// thus never corresponding to an active value do not need any
// handling, because only active values can have adjoints from
// previous iterations propagated via BB arguments.
do {
// Consider '%X = begin_access [modify] [static] %Y'.
// 1. If %Y is loop-local, its adjoint buffer will
// be zeroed, and we'll have zero adjoint projection to it.
// 2. Otherwise, we do not need to zero the projection buffer.
// Thus, we can just skip.
if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
break;
}

// Consider the following sequence:
// %1 = function_ref @allocUninitArray
// %2 = apply %1<Float>(%0)
// (%3, %4) = destructure_tuple %2
// %5 = mark_dependence %4 on %3
// %6 = pointer_to_address %5 to [strict] $*Float
// Since %6 is active, %3 (which is an array) must also be active.
// Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
// invariants hold and then skip.
if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress(
bbActiveValue)) {
ASSERT(isa<PointerToAddressInst>(bbActiveValue));
SILValue arrayValue = getArrayValue(ai);
ASSERT(llvm::find(bbActiveValues, arrayValue) !=
bbActiveValues.end());
ASSERT(vjpCloner.getLoopInfo()->getLoopFor(
arrayValue->getParentBlock()) == loop);
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
break;
}

ASSERT(false);
} while (false);
}
}
LLVM_DEBUG(getADDebugStream()
<< "End search for adjoints of loop-local active values\n");

for (auto *bb : originalBlocks) {
visitSILBasicBlock(bb);
if (errorOccurred)
Expand Down Expand Up @@ -3339,19 +3484,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
eltIndex = ili->getValue().getLimitedValue();
}
// Get the array adjoint value.
SILValue arrayAdjoint;
assert(ai && "Expected `array.uninitialized_intrinsic` application");
for (auto use : ai->getUses()) {
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
if (!dti)
continue;
assert(!arrayAdjoint && "Array adjoint already found");
// The first `destructure_tuple` result is the `Array` value.
auto arrayValue = dti->getResult(0);
arrayAdjoint = materializeAdjointDirect(
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
}
assert(arrayAdjoint && "Array does not have adjoint value");
SILValue arrayValue = getArrayValue(ai);
SILValue arrayAdjoint = materializeAdjointDirect(
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
auto *eltAdjBuffer =
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());
Expand Down
24 changes: 12 additions & 12 deletions test/AutoDiff/SILOptimizer/pullback_generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -182,22 +182,22 @@ func f4(a: NonTrivial) -> Float {
}

// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %88 = alloc_stack $NonTrivial
// CHECK: bb5(%[[#ARG0:]] : @owned $NonTrivial, %[[#]] : $Float, %[[#]] : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %[[#T0:]] = alloc_stack $NonTrivial

// Non-trivial value must be copied or there will be a
// double consume when all owned parameters are destroyed
// at the end of the basic block.
// CHECK: %89 = copy_value %67 : $NonTrivial
// CHECK: %[[#T1:]] = copy_value %[[#ARG0]] : $NonTrivial

// CHECK: store %89 to [init] %88 : $*NonTrivial
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
// CHECK: %92 = alloc_stack $Float
// CHECK: store %86 to [trivial] %92 : $*Float
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %95 = metatype $@thick Float.Type
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %67 : $NonTrivial
// CHECK: store %[[#T1]] to [init] %[[#T0]] : $*NonTrivial
// CHECK: %[[#T2:]] = struct_element_addr %[[#T0]] : $*NonTrivial, #NonTrivial.x
// CHECK: %[[#T3:]] = alloc_stack $Float
// CHECK: store %[[#T4:]] to [trivial] %[[#T3]] : $*Float
// CHECK: %[[#T5:]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %[[#T6:]] = metatype $@thick Float.Type
// CHECK: %[[#]] = apply %[[#T5]]<Float>(%[[#T2]], %[[#T3]], %[[#T6]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %[[#ARG0]] : $NonTrivial

@differentiable(reverse)
func move_value(x: Float) -> Float {
Expand All @@ -218,5 +218,5 @@ func move_value(x: Float) -> Float {
// CHECK: %[[#]] = apply %[[#T2]]<Float>(%[[#T1]], %[[#T3]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %[[#T4:]] = load [trivial] %[[#T1]]
// CHECK: dealloc_stack %[[#T1]]
// CHECK: bb4(%113 : $Builtin.RawPointer):
// CHECK: bb4(%[[#]] : $Builtin.RawPointer):
// CHECK: br bb5(%[[#]] : $Float, %[[#]] : $Float, %[[#T4]] : $Float, %[[#]] : $(predecessor: _AD__$s19pullback_generation10move_value1xS2f_tF_bb2__Pred__src_0_wrt_0))
Loading