From 3c3093db75e61115b9198a933b6c5aabebd12884 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 May 2020 12:32:15 -0700 Subject: [PATCH 1/3] [AutoDiff] Clean up VJP basic block utilities. Add a common helper function `VJPEmitter::createTrampolineBasicBlock`. Change `VJPEmitter::buildPullbackValueStructValue` to take an original basic block instead of a terminator instruction. --- .../SILOptimizer/Differentiation/VJPEmitter.h | 18 ++- .../Differentiation/VJPEmitter.cpp | 113 +++++++----------- 2 files changed, 57 insertions(+), 74 deletions(-) diff --git a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h index d2f8f7fabb996..a0ec673b4d73b 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h @@ -117,9 +117,21 @@ class VJPEmitter final /// Get the lowered SIL type of the given nominal type declaration. SILType getNominalDeclLoweredType(NominalTypeDecl *nominal); - /// Build a pullback struct value for the original block corresponding to the - /// given terminator. - StructInst *buildPullbackValueStructValue(TermInst *termInst); + // Creates a trampoline block for given original terminator instruction, the + // pullback struct value for its parent block, and a successor basic block. + // + // The trampoline block has the same arguments as and branches to the remapped + // successor block, but drops the last predecessor enum argument. + // + // Used for cloning branching terminator instructions with specific + // requirements on successor block arguments, where an additional predecessor + // enum argument is not acceptable. + SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst, + StructInst *pbStructVal, + SILBasicBlock *succBB); + + /// Build a pullback struct value for the given original block. + StructInst *buildPullbackValueStructValue(SILBasicBlock *bb); /// Build a predecessor enum instance using the given builder for the given /// original predecessor/successor blocks and pullback struct value. diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index 082d6884f39b1..8486b7b2bec14 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -265,6 +265,31 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) { return vjpBB; } +SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst, + StructInst *pbStructVal, + SILBasicBlock *succBB) { + assert(llvm::find(termInst->getSuccessorBlocks(), succBB) != + termInst->getSuccessorBlocks().end() && + "Basic block is not a successor of terminator instruction"); + // Create the trampoline block. + auto *vjpSuccBB = getOpBasicBlock(succBB); + auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); + for (auto *arg : vjpSuccBB->getArguments().drop_back()) + trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); + // In the trampoline block, build predecessor enum value for VJP successor + // block and branch to it. + SILBuilder trampolineBuilder(trampolineBB); + auto *origBB = termInst->getParent(); + auto *succEnumVal = + buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbStructVal); + SmallVector forwardedArguments( + trampolineBB->getArguments().begin(), trampolineBB->getArguments().end()); + forwardedArguments.push_back(succEnumVal); + trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB, + forwardedArguments); + return trampolineBB; +} + void VJPEmitter::visit(SILInstruction *inst) { if (errorOccurred) return; @@ -290,10 +315,9 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) { return getLoweredType(nominalType); } -StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) { - assert(termInst->getFunction() == original); - auto loc = termInst->getFunction()->getLocation(); - auto *origBB = termInst->getParent(); +StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) { + assert(origBB->getParent() == original); + auto loc = origBB->getParent()->getLocation(); auto *vjpBB = BBMap[origBB]; auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB); auto structLoweredTy = getNominalDeclLoweredType(pbStruct); @@ -333,9 +357,11 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder, void VJPEmitter::visitReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); auto &builder = getBuilder(); - auto *pbStructVal = buildPullbackValueStructValue(ri); + + // Build pullback struct value for original block. + auto *origExit = ri->getParent(); + auto *pbStructVal = buildPullbackValueStructValue(origExit); // Get the value in the VJP corresponding to the original result. auto *origRetInst = cast(origExit->getTerminator()); @@ -390,7 +416,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { // Build pullback struct value for original block. // Build predecessor enum value for destination block. auto *origBB = bi->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(bi); + auto *pbStructVal = buildPullbackValueStructValue(origBB); auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, bi->getDestBB(), pbStructVal); @@ -407,85 +433,30 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) { // Build pullback struct value for original block. - // Build predecessor enum values for true/false blocks. - auto *origBB = cbi->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(cbi); - - // Creates a trampoline block for given original successor block. The - // trampoline block has the same arguments as the VJP successor block but - // drops the last predecessor enum argument. The generated `switch_enum` - // instruction branches to the trampoline block, and the trampoline block - // constructs a predecessor enum value and branches to the VJP successor - // block. - auto createTrampolineBasicBlock = - [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { - auto *vjpSuccBB = getOpBasicBlock(origSuccBB); - // Create the trampoline block. - auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); - for (auto *arg : vjpSuccBB->getArguments().drop_back()) - trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); - // Build predecessor enum value for successor block and branch to it. - SILBuilder trampolineBuilder(trampolineBB); - auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, - origSuccBB, pbStructVal); - SmallVector forwardedArguments( - trampolineBB->getArguments().begin(), - trampolineBB->getArguments().end()); - forwardedArguments.push_back(succEnumVal); - trampolineBuilder.createBranch(cbi->getLoc(), vjpSuccBB, - forwardedArguments); - return trampolineBB; - }; - + auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent()); // Create a new `cond_br` instruction. - getBuilder().createCondBranch(cbi->getLoc(), getOpValue(cbi->getCondition()), - createTrampolineBasicBlock(cbi->getTrueBB()), - createTrampolineBasicBlock(cbi->getFalseBB())); + getBuilder().createCondBranch( + cbi->getLoc(), getOpValue(cbi->getCondition()), + createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()), + createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB())); } void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { // Build pullback struct value for original block. - auto *origBB = sei->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(sei); - - // Creates a trampoline block for given original successor block. The - // trampoline block has the same arguments as the VJP successor block but - // drops the last predecessor enum argument. The generated `switch_enum` - // instruction branches to the trampoline block, and the trampoline block - // constructs a predecessor enum value and branches to the VJP successor - // block. - auto createTrampolineBasicBlock = - [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { - auto *vjpSuccBB = getOpBasicBlock(origSuccBB); - // Create the trampoline block. - auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); - for (auto *destArg : vjpSuccBB->getArguments().drop_back()) - trampolineBB->createPhiArgument(destArg->getType(), - destArg->getOwnershipKind()); - // Build predecessor enum value for successor block and branch to it. - SILBuilder trampolineBuilder(trampolineBB); - auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, - origSuccBB, pbStructVal); - SmallVector forwardedArguments( - trampolineBB->getArguments().begin(), - trampolineBB->getArguments().end()); - forwardedArguments.push_back(succEnumVal); - trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB, - forwardedArguments); - return trampolineBB; - }; + auto *pbStructVal = buildPullbackValueStructValue(sei->getParent()); // Create trampoline successor basic blocks. SmallVector, 4> caseBBs; for (unsigned i : range(sei->getNumCases())) { auto caseBB = sei->getCase(i); - auto *trampolineBB = createTrampolineBasicBlock(caseBB.second); + auto *trampolineBB = + createTrampolineBasicBlock(sei, pbStructVal, caseBB.second); caseBBs.push_back({caseBB.first, trampolineBB}); } // Create trampoline default basic block. SILBasicBlock *newDefaultBB = nullptr; if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) - newDefaultBB = createTrampolineBasicBlock(defaultBB); + newDefaultBB = createTrampolineBasicBlock(sei, pbStructVal, defaultBB); // Create a new `switch_enum` instruction. switch (sei->getKind()) { From 24de636822b7680db442b6fb161356d2b0034f6a Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 May 2020 12:41:55 -0700 Subject: [PATCH 2/3] [AutoDiff] Re-enable control_flow.swift test. This test was disabled in SR-12741 due to iphonesimulator-i386 failures. Enabling the test on other platforms is important to prevent regressions. --- test/AutoDiff/validation-test/control_flow.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/AutoDiff/validation-test/control_flow.swift b/test/AutoDiff/validation-test/control_flow.swift index d9a45e5b143ff..5a2315022e5f3 100644 --- a/test/AutoDiff/validation-test/control_flow.swift +++ b/test/AutoDiff/validation-test/control_flow.swift @@ -1,7 +1,9 @@ // RUN: %target-run-simple-swift // REQUIRES: executable_test -// REQUIRES: SR12741 +// FIXME(SR-12741): Enable test for all platforms after debugging +// iphonesimulator-i386-specific failures. +// REQUIRES: CPU=x86_64 import _Differentiation import StdlibUnittest From d5d076db6a2381fc9f205b978b0722457f794c34 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 May 2020 12:44:36 -0700 Subject: [PATCH 3/3] [AutoDiff] Support differentiation of branching cast instructions. Support differentiation of `is` and `as?` operators. These operators lower to branching cast SIL instructions, requiring control flow differentiation support. Resolves SR-12898. --- .../SILOptimizer/Differentiation/VJPEmitter.h | 6 + .../DifferentiableActivityAnalysis.cpp | 22 +++- .../Differentiation/VJPEmitter.cpp | 41 +++++++ .../Mandatory/Differentiation.cpp | 8 +- .../SILOptimizer/activity_analysis.swift | 110 ++++++++++++++++++ .../validation-test/control_flow.swift | 21 ++++ 6 files changed, 201 insertions(+), 7 deletions(-) diff --git a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h index a0ec673b4d73b..475ac7b3f7272 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h @@ -153,6 +153,12 @@ class VJPEmitter final void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai); + void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi); + + void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi); + + void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi); + // If an `apply` has active results or active inout arguments, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai); diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index 8358319ef16b2..c903dbeb6435c 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -193,12 +193,26 @@ void DifferentiableActivityInfo::propagateVaried( if (auto *destBBArg = cbi->getArgForOperand(operand)) setVariedAndPropagateToUsers(destBBArg, i); } - // Handle `switch_enum`. - else if (auto *sei = dyn_cast(inst)) { - if (isVaried(sei->getOperand(), i)) - for (auto *succBB : sei->getSuccessorBlocks()) + // Handle `checked_cast_addr_br`. + // Propagate variedness from source operand to destination operand, in + // addition to all successor block arguments. + else if (auto *ccabi = dyn_cast(inst)) { + if (isVaried(ccabi->getSrc(), i)) { + setVariedAndPropagateToUsers(ccabi->getDest(), i); + for (auto *succBB : ccabi->getSuccessorBlocks()) for (auto *arg : succBB->getArguments()) setVariedAndPropagateToUsers(arg, i); + } + } + // Handle all other terminators: if any operand is active, propagate + // variedness to all successor block arguments. This logic may be incorrect + // for some terminator instructions, so special cases must be defined above. + else if (auto *termInst = dyn_cast(inst)) { + for (auto &op : termInst->getAllOperands()) + if (isVaried(op.get(), i)) + for (auto *succBB : termInst->getSuccessorBlocks()) + for (auto *arg : succBB->getArguments()) + setVariedAndPropagateToUsers(arg, i); } // Handle everything else. else { diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index 8486b7b2bec14..ec6b82cc7fb17 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -481,6 +481,47 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { visitSwitchEnumInstBase(seai); } +void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent()); + // Create a new `checked_cast_branch` instruction. + getBuilder().createCheckedCastBranch( + ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()), + getOpType(ccbi->getTargetLoweredType()), + getOpASTType(ccbi->getTargetFormalType()), + createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()), + createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()), + ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); +} + +void VJPEmitter::visitCheckedCastValueBranchInst( + CheckedCastValueBranchInst *ccvbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent()); + // Create a new `checked_cast_value_branch` instruction. + getBuilder().createCheckedCastValueBranch( + ccvbi->getLoc(), getOpValue(ccvbi->getOperand()), + getOpASTType(ccvbi->getSourceFormalType()), + getOpType(ccvbi->getTargetLoweredType()), + getOpASTType(ccvbi->getTargetFormalType()), + createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()), + createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB())); +} + +void VJPEmitter::visitCheckedCastAddrBranchInst( + CheckedCastAddrBranchInst *ccabi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent()); + // Create a new `checked_cast_addr_branch` instruction. + getBuilder().createCheckedCastAddrBranch( + ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()), + getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()), + getOpASTType(ccabi->getTargetFormalType()), + createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()), + createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()), + ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); +} + void VJPEmitter::visitApplyInst(ApplyInst *ai) { // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 89588cd7b110d..3914cc3380e2f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -152,10 +152,12 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, // Diagnose unsupported branching terminators. for (auto &bb : *original) { auto *term = bb.getTerminator(); - // Supported terminators are: `br`, `cond_br`, `switch_enum`, - // `switch_enum_addr`. + // Check supported branching terminators. if (isa(term) || isa(term) || - isa(term) || isa(term)) + isa(term) || isa(term) || + isa(term) || + isa(term) || + isa(term)) continue; // If terminator is an unsupported branching terminator, emit an error. if (term->isBranch()) { diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 6c1fc31f23cf4..b4360484123a5 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -122,6 +122,116 @@ func TF_954(_ x: Float) -> Float { // CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float // CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float +//===----------------------------------------------------------------------===// +// Branching cast instructions +//===----------------------------------------------------------------------===// + +@differentiable +func checked_cast_branch(_ x: Float) -> Float { + // expected-warning @+1 {{'is' test is always true}} + if Int.self is Any.Type { + return x + x + } + return x * x +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_branch{{.*}} at (source=0 parameters=(0)) +// CHECK: bb0: +// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK: [NONE] %2 = metatype $@thin Int.Type +// CHECK: [NONE] %3 = metatype $@thick Int.Type +// CHECK: bb1: +// CHECK: [NONE] %5 = argument of bb1 : $@thick Any.Type +// CHECK: [NONE] %6 = integer_literal $Builtin.Int1, -1 +// CHECK: bb2: +// CHECK: [NONE] %8 = argument of bb2 : $@thick Int.Type +// CHECK: [NONE] %9 = integer_literal $Builtin.Int1, 0 +// CHECK: bb3: +// CHECK: [NONE] %11 = argument of bb3 : $Builtin.Int1 +// CHECK: [NONE] %12 = metatype $@thin Bool.Type +// CHECK: [NONE] // function_ref Bool.init(_builtinBooleanLiteral:) +// CHECK: [NONE] %14 = apply %13(%11, %12) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool +// CHECK: [NONE] %15 = struct_extract %14 : $Bool, #Bool._value +// CHECK: bb4: +// CHECK: [USEFUL] %17 = metatype $@thin Float.Type +// CHECK: [NONE] // function_ref static Float.+ infix(_:_:) +// CHECK: [ACTIVE] %19 = apply %18(%0, %0, %17) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK: bb5: +// CHECK: [USEFUL] %21 = metatype $@thin Float.Type +// CHECK: [NONE] // function_ref static Float.* infix(_:_:) +// CHECK: [ACTIVE] %23 = apply %22(%0, %0, %21) : $@convention(method) (Float, Float, @thin Float.Type) -> Float + +// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_branch{{.*}} : $@convention(thin) (Float) -> Float { +// CHECK: checked_cast_br %3 : $@thick Int.Type to Any.Type, bb1, bb2 +// CHECK: } + +@differentiable +func checked_cast_addr_nonactive_result(_ x: T) -> T { + if let _ = x as? Float { + // Do nothing with `y: Float?` value. + } + return x +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_addr_nonactive_result{{.*}} at (source=0 parameters=(0)) +// CHECK: bb0: +// CHECK: [ACTIVE] %0 = argument of bb0 : $*T +// CHECK: [ACTIVE] %1 = argument of bb0 : $*T +// CHECK: [VARIED] %3 = alloc_stack $T +// CHECK: [VARIED] %5 = alloc_stack $Float +// CHECK: bb1: +// CHECK: [VARIED] %7 = load [trivial] %5 : $*Float +// CHECK: [VARIED] %8 = enum $Optional, #Optional.some!enumelt, %7 : $Float +// CHECK: bb2: +// CHECK: [NONE] %11 = enum $Optional, #Optional.none!enumelt +// CHECK: bb3: +// CHECK: [VARIED] %14 = argument of bb3 : $Optional +// CHECK: bb4: +// CHECK: bb5: +// CHECK: [VARIED] %18 = argument of bb5 : $Float +// CHECK: bb6: +// CHECK: [NONE] %22 = tuple () + +// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_nonactive_result{{.*}} : $@convention(thin) (@in_guaranteed T) -> @out T { +// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2 +// CHECK: } + +// expected-error @+1 {{function is not differentiable}} +@differentiable +// expected-note @+1 {{when differentiating this function definition}} +func checked_cast_addr_active_result(x: T) -> T { + // expected-note @+1 {{differentiating enum values is not yet supported}} + if let y = x as? Float { + // Use `y: Float?` value in an active way. + return y as! T + } + return x +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_addr_active_result{{.*}} at (source=0 parameters=(0)) +// CHECK: bb0: +// CHECK: [ACTIVE] %0 = argument of bb0 : $*T +// CHECK: [ACTIVE] %1 = argument of bb0 : $*T +// CHECK: [ACTIVE] %3 = alloc_stack $T +// CHECK: [ACTIVE] %5 = alloc_stack $Float +// CHECK: bb1: +// CHECK: [ACTIVE] %7 = load [trivial] %5 : $*Float +// CHECK: [ACTIVE] %8 = enum $Optional, #Optional.some!enumelt, %7 : $Float +// CHECK: bb2: +// CHECK: [USEFUL] %11 = enum $Optional, #Optional.none!enumelt +// CHECK: bb3: +// CHECK: [ACTIVE] %14 = argument of bb3 : $Optional +// CHECK: bb4: +// CHECK: [ACTIVE] %16 = argument of bb4 : $Float +// CHECK: [ACTIVE] %19 = alloc_stack $Float +// CHECK: bb5: +// CHECK: bb6: +// CHECK: [NONE] %27 = tuple () + +// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_active_result{{.*}} : $@convention(thin) (@in_guaranteed T) -> @out T { +// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2 +// CHECK: } + //===----------------------------------------------------------------------===// // Array literal differentiation //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/validation-test/control_flow.swift b/test/AutoDiff/validation-test/control_flow.swift index 5a2315022e5f3..d2042a1135216 100644 --- a/test/AutoDiff/validation-test/control_flow.swift +++ b/test/AutoDiff/validation-test/control_flow.swift @@ -715,4 +715,25 @@ ControlFlowTests.test("Loops") { expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) })) } +ControlFlowTests.test("BranchingCastInstructions") { + // checked_cast_br + func typeCheckOperator(_ x: Float, _ metatype: T.Type) -> Float { + if metatype is Int.Type { + return x + x + } + return x * x + } + expectEqual((6, 2), valueWithGradient(at: 3, in: { typeCheckOperator($0, Int.self) })) + expectEqual((9, 6), valueWithGradient(at: 3, in: { typeCheckOperator($0, Float.self) })) + + // checked_cast_addr_br + func conditionalCast(_ x: T) -> T { + if let _ = x as? Float { + // Do nothing with `y: Float?` value. + } + return x + } + expectEqual((3, 1), valueWithGradient(at: Float(3), in: conditionalCast)) +} + runAllTests()