Skip to content

[AutoDiff] Support differentiation of branching cast instructions. #32069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions include/swift/SILOptimizer/Differentiation/VJPEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,21 @@ class VJPEmitter final
/// Get the lowered SIL type of the given nominal type declaration.
SILType getNominalDeclLoweredType(NominalTypeDecl *nominal);

/// Build a pullback struct value for the original block corresponding to the
/// given terminator.
StructInst *buildPullbackValueStructValue(TermInst *termInst);
// Creates a trampoline block for given original terminator instruction, the
// pullback struct value for its parent block, and a successor basic block.
//
// The trampoline block has the same arguments as and branches to the remapped
// successor block, but drops the last predecessor enum argument.
//
// Used for cloning branching terminator instructions with specific
// requirements on successor block arguments, where an additional predecessor
// enum argument is not acceptable.
SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst,
StructInst *pbStructVal,
SILBasicBlock *succBB);

/// Build a pullback struct value for the given original block.
StructInst *buildPullbackValueStructValue(SILBasicBlock *bb);

/// Build a predecessor enum instance using the given builder for the given
/// original predecessor/successor blocks and pullback struct value.
Expand All @@ -141,6 +153,12 @@ class VJPEmitter final

void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai);

void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi);

void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi);

void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi);

// If an `apply` has active results or active inout arguments, replace it
// with an `apply` of its VJP.
void visitApplyInst(ApplyInst *ai);
Expand Down
22 changes: 18 additions & 4 deletions lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,26 @@ void DifferentiableActivityInfo::propagateVaried(
if (auto *destBBArg = cbi->getArgForOperand(operand))
setVariedAndPropagateToUsers(destBBArg, i);
}
// Handle `switch_enum`.
else if (auto *sei = dyn_cast<SwitchEnumInst>(inst)) {
if (isVaried(sei->getOperand(), i))
for (auto *succBB : sei->getSuccessorBlocks())
// Handle `checked_cast_addr_br`.
// Propagate variedness from source operand to destination operand, in
// addition to all successor block arguments.
else if (auto *ccabi = dyn_cast<CheckedCastAddrBranchInst>(inst)) {
if (isVaried(ccabi->getSrc(), i)) {
setVariedAndPropagateToUsers(ccabi->getDest(), i);
for (auto *succBB : ccabi->getSuccessorBlocks())
for (auto *arg : succBB->getArguments())
setVariedAndPropagateToUsers(arg, i);
}
}
// Handle all other terminators: if any operand is active, propagate
// variedness to all successor block arguments. This logic may be incorrect
// for some terminator instructions, so special cases must be defined above.
else if (auto *termInst = dyn_cast<TermInst>(inst)) {
for (auto &op : termInst->getAllOperands())
if (isVaried(op.get(), i))
for (auto *succBB : termInst->getSuccessorBlocks())
for (auto *arg : succBB->getArguments())
setVariedAndPropagateToUsers(arg, i);
}
// Handle everything else.
else {
Expand Down
154 changes: 83 additions & 71 deletions lib/SILOptimizer/Differentiation/VJPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,31 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) {
return vjpBB;
}

SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst,
StructInst *pbStructVal,
SILBasicBlock *succBB) {
assert(llvm::find(termInst->getSuccessorBlocks(), succBB) !=
termInst->getSuccessorBlocks().end() &&
"Basic block is not a successor of terminator instruction");
// Create the trampoline block.
auto *vjpSuccBB = getOpBasicBlock(succBB);
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
for (auto *arg : vjpSuccBB->getArguments().drop_back())
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
// In the trampoline block, build predecessor enum value for VJP successor
// block and branch to it.
SILBuilder trampolineBuilder(trampolineBB);
auto *origBB = termInst->getParent();
auto *succEnumVal =
buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbStructVal);
SmallVector<SILValue, 4> forwardedArguments(
trampolineBB->getArguments().begin(), trampolineBB->getArguments().end());
forwardedArguments.push_back(succEnumVal);
trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB,
forwardedArguments);
return trampolineBB;
}

void VJPEmitter::visit(SILInstruction *inst) {
if (errorOccurred)
return;
Expand All @@ -290,10 +315,9 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) {
return getLoweredType(nominalType);
}

StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) {
assert(termInst->getFunction() == original);
auto loc = termInst->getFunction()->getLocation();
auto *origBB = termInst->getParent();
StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) {
assert(origBB->getParent() == original);
auto loc = origBB->getParent()->getLocation();
auto *vjpBB = BBMap[origBB];
auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB);
auto structLoweredTy = getNominalDeclLoweredType(pbStruct);
Expand Down Expand Up @@ -333,9 +357,11 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder,

void VJPEmitter::visitReturnInst(ReturnInst *ri) {
auto loc = ri->getOperand().getLoc();
auto *origExit = ri->getParent();
auto &builder = getBuilder();
auto *pbStructVal = buildPullbackValueStructValue(ri);

// Build pullback struct value for original block.
auto *origExit = ri->getParent();
auto *pbStructVal = buildPullbackValueStructValue(origExit);

// Get the value in the VJP corresponding to the original result.
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
Expand Down Expand Up @@ -390,7 +416,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
// Build pullback struct value for original block.
// Build predecessor enum value for destination block.
auto *origBB = bi->getParent();
auto *pbStructVal = buildPullbackValueStructValue(bi);
auto *pbStructVal = buildPullbackValueStructValue(origBB);
auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB,
bi->getDestBB(), pbStructVal);

Expand All @@ -407,85 +433,30 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {

void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) {
// Build pullback struct value for original block.
// Build predecessor enum values for true/false blocks.
auto *origBB = cbi->getParent();
auto *pbStructVal = buildPullbackValueStructValue(cbi);

// Creates a trampoline block for given original successor block. The
// trampoline block has the same arguments as the VJP successor block but
// drops the last predecessor enum argument. The generated `switch_enum`
// instruction branches to the trampoline block, and the trampoline block
// constructs a predecessor enum value and branches to the VJP successor
// block.
auto createTrampolineBasicBlock =
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
// Create the trampoline block.
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
for (auto *arg : vjpSuccBB->getArguments().drop_back())
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
// Build predecessor enum value for successor block and branch to it.
SILBuilder trampolineBuilder(trampolineBB);
auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB,
origSuccBB, pbStructVal);
SmallVector<SILValue, 4> forwardedArguments(
trampolineBB->getArguments().begin(),
trampolineBB->getArguments().end());
forwardedArguments.push_back(succEnumVal);
trampolineBuilder.createBranch(cbi->getLoc(), vjpSuccBB,
forwardedArguments);
return trampolineBB;
};

auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent());
// Create a new `cond_br` instruction.
getBuilder().createCondBranch(cbi->getLoc(), getOpValue(cbi->getCondition()),
createTrampolineBasicBlock(cbi->getTrueBB()),
createTrampolineBasicBlock(cbi->getFalseBB()));
getBuilder().createCondBranch(
cbi->getLoc(), getOpValue(cbi->getCondition()),
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()),
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB()));
}

void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) {
// Build pullback struct value for original block.
auto *origBB = sei->getParent();
auto *pbStructVal = buildPullbackValueStructValue(sei);

// Creates a trampoline block for given original successor block. The
// trampoline block has the same arguments as the VJP successor block but
// drops the last predecessor enum argument. The generated `switch_enum`
// instruction branches to the trampoline block, and the trampoline block
// constructs a predecessor enum value and branches to the VJP successor
// block.
auto createTrampolineBasicBlock =
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
// Create the trampoline block.
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
for (auto *destArg : vjpSuccBB->getArguments().drop_back())
trampolineBB->createPhiArgument(destArg->getType(),
destArg->getOwnershipKind());
// Build predecessor enum value for successor block and branch to it.
SILBuilder trampolineBuilder(trampolineBB);
auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB,
origSuccBB, pbStructVal);
SmallVector<SILValue, 4> forwardedArguments(
trampolineBB->getArguments().begin(),
trampolineBB->getArguments().end());
forwardedArguments.push_back(succEnumVal);
trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB,
forwardedArguments);
return trampolineBB;
};
auto *pbStructVal = buildPullbackValueStructValue(sei->getParent());

// Create trampoline successor basic blocks.
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
for (unsigned i : range(sei->getNumCases())) {
auto caseBB = sei->getCase(i);
auto *trampolineBB = createTrampolineBasicBlock(caseBB.second);
auto *trampolineBB =
createTrampolineBasicBlock(sei, pbStructVal, caseBB.second);
caseBBs.push_back({caseBB.first, trampolineBB});
}
// Create trampoline default basic block.
SILBasicBlock *newDefaultBB = nullptr;
if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull())
newDefaultBB = createTrampolineBasicBlock(defaultBB);
newDefaultBB = createTrampolineBasicBlock(sei, pbStructVal, defaultBB);

// Create a new `switch_enum` instruction.
switch (sei->getKind()) {
Expand All @@ -510,6 +481,47 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
visitSwitchEnumInstBase(seai);
}

void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
// Build pullback struct value for original block.
auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent());
// Create a new `checked_cast_branch` instruction.
getBuilder().createCheckedCastBranch(
ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()),
getOpType(ccbi->getTargetLoweredType()),
getOpASTType(ccbi->getTargetFormalType()),
createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()),
createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()),
ccbi->getTrueBBCount(), ccbi->getFalseBBCount());
}

void VJPEmitter::visitCheckedCastValueBranchInst(
CheckedCastValueBranchInst *ccvbi) {
// Build pullback struct value for original block.
auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent());
// Create a new `checked_cast_value_branch` instruction.
getBuilder().createCheckedCastValueBranch(
ccvbi->getLoc(), getOpValue(ccvbi->getOperand()),
getOpASTType(ccvbi->getSourceFormalType()),
getOpType(ccvbi->getTargetLoweredType()),
getOpASTType(ccvbi->getTargetFormalType()),
createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()),
createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB()));
}

void VJPEmitter::visitCheckedCastAddrBranchInst(
CheckedCastAddrBranchInst *ccabi) {
// Build pullback struct value for original block.
auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent());
// Create a new `checked_cast_addr_branch` instruction.
getBuilder().createCheckedCastAddrBranch(
ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()),
getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()),
getOpASTType(ccabi->getTargetFormalType()),
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()),
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()),
ccabi->getTrueBBCount(), ccabi->getFalseBBCount());
Copy link
Contributor Author

@dan-zheng dan-zheng May 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: these newly added VJPEmitter visitors significantly duplicate code from SILCloner.

We can potentially just inherit SILCloner visitors by baking createTrampolineBasicBlock logic into VJPEmitter::remapBasicBlock. This changes the meaning of VJPEmitter::getOpBasicBlock used by other visitors though, so I'm not sure it would work.

}

void VJPEmitter::visitApplyInst(ApplyInst *ai) {
// If callee should not be differentiated, do standard cloning.
if (!pullbackInfo.shouldDifferentiateApplySite(ai)) {
Expand Down
8 changes: 5 additions & 3 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
// Diagnose unsupported branching terminators.
for (auto &bb : *original) {
auto *term = bb.getTerminator();
// Supported terminators are: `br`, `cond_br`, `switch_enum`,
// `switch_enum_addr`.
// Check supported branching terminators.
if (isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term))
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) ||
isa<CheckedCastBranchInst>(term) ||
isa<CheckedCastValueBranchInst>(term) ||
isa<CheckedCastAddrBranchInst>(term))
continue;
// If terminator is an unsupported branching terminator, emit an error.
if (term->isBranch()) {
Expand Down
Loading