Skip to content

Commit 19790be

Browse files
committed
Ensure we can fold apply of a differentiable_function_inst. Also fixes one small potential issue while there.
Fixes #65489
1 parent af05904 commit 19790be

File tree

3 files changed

+88
-30
lines changed

3 files changed

+88
-30
lines changed

lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,13 @@ SILInstruction *SILCombiner::visitApplyInst(ApplyInst *AI) {
15001500
}
15011501
}
15021502

1503+
// (apply (differentiable_function f)) to (apply f)
1504+
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(AI->getCallee())) {
1505+
return cloneFullApplySiteReplacingCallee(AI, DFI->getOperand(0),
1506+
Builder.getBuilderContext())
1507+
.getInstruction();
1508+
}
1509+
15031510
// (apply (thin_to_thick_function f)) to (apply f)
15041511
if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(AI->getCallee())) {
15051512
// We currently don't remove any possible retain associated with the thick

lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,38 +1114,41 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
11141114
// %vjp' = convert_escape_to_noescape %vjp
11151115
// %y = differentiable_function(%orig', %jvp', %vjp')
11161116
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted())) {
1117-
auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1118-
if (!DFI->hasExtractee(extractee))
1119-
return SILValue();
1117+
if (DFI->hasOneUse()) {
1118+
auto createConvertEscapeToNoEscape =
1119+
[&](NormalDifferentiableFunctionTypeComponent extractee) {
1120+
if (!DFI->hasExtractee(extractee))
1121+
return SILValue();
11201122

1121-
auto operand = DFI->getExtractee(extractee);
1122-
auto fnType = operand->getType().castTo<SILFunctionType>();
1123-
auto noEscapeFnType =
1124-
fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape());
1125-
auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType);
1126-
return Builder.createConvertEscapeToNoEscape(
1127-
operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0);
1128-
};
1123+
auto operand = DFI->getExtractee(extractee);
1124+
auto fnType = operand->getType().castTo<SILFunctionType>();
1125+
auto noEscapeFnType =
1126+
fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape());
1127+
auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType);
1128+
return Builder.createConvertEscapeToNoEscape(
1129+
operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0);
1130+
};
11291131

1130-
SILValue originalNoEscape =
1131-
createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original);
1132-
SILValue convertedJVP = createConvertEscapeToNoEscape(
1133-
NormalDifferentiableFunctionTypeComponent::JVP);
1134-
SILValue convertedVJP = createConvertEscapeToNoEscape(
1135-
NormalDifferentiableFunctionTypeComponent::VJP);
1136-
1137-
Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1138-
if (convertedJVP && convertedVJP)
1139-
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1140-
1141-
auto *newDFI = Builder.createDifferentiableFunction(
1142-
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1143-
originalNoEscape, derivativeFunctions);
1144-
assert(newDFI->getType() == Cvt->getType() &&
1145-
"New `@differentiable` function instruction should have same type "
1146-
"as the old `convert_escape_to_no_escape` instruction");
1147-
return newDFI;
1148-
}
1132+
SILValue originalNoEscape =
1133+
createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original);
1134+
SILValue convertedJVP = createConvertEscapeToNoEscape(
1135+
NormalDifferentiableFunctionTypeComponent::JVP);
1136+
SILValue convertedVJP = createConvertEscapeToNoEscape(
1137+
NormalDifferentiableFunctionTypeComponent::VJP);
1138+
1139+
Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1140+
if (convertedJVP && convertedVJP)
1141+
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1142+
1143+
auto *newDFI = Builder.createDifferentiableFunction(
1144+
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1145+
originalNoEscape, derivativeFunctions);
1146+
assert(newDFI->getType() == Cvt->getType() &&
1147+
"New `@differentiable` function instruction should have same type "
1148+
"as the old `convert_escape_to_no_escape` instruction");
1149+
return newDFI;
1150+
}
1151+
}
11491152

11501153
return nullptr;
11511154
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %target-swift-frontend -emit-sil -O %s | %FileCheck %s
2+
// REQUIRES: swift_in_compiler
3+
4+
import _Differentiation
5+
6+
@differentiable(reverse)
7+
@_silgen_name("test_f")
8+
// Check that (differentiable) closure apply is optimized out
9+
// CHECK-LABEL: test_f : $@convention(thin) (@guaranteed Array<Double>) -> Double
10+
// CHECK-NOT: differentiable_function [parameters 0] [results 0]
11+
func f(array: [Double]) -> Double {
12+
var array = array
13+
array.update(at: 1,
14+
byCalling: {
15+
(element: inout Double) in
16+
let initialElement = element;
17+
element *= initialElement
18+
}
19+
)
20+
21+
return 0.0
22+
}
23+
24+
public func valueWithPullback<T>(at x: T, of f: @differentiable(reverse) (inout T) -> Void) -> (value: Void, pullback: (inout T.TangentVector) -> Void) {fatalError()}
25+
public func pullback<T>(at x: T, of f: @differentiable(reverse) (inout T) -> Void) -> (inout T.TangentVector) -> Void {return valueWithPullback(at: x, of: f).pullback}
26+
27+
public extension Array {
28+
@differentiable(reverse)
29+
mutating func update(at index: Int,
30+
byCalling closure: @differentiable(reverse) (inout Element) -> Void) where Element: Differentiable {
31+
closure(&self[index])
32+
}
33+
}
34+
35+
public extension Array where Element: Differentiable {
36+
@derivative(of: update(at:byCalling:))
37+
mutating func vjpUpdate(at index: Int, byCalling closure: @differentiable(reverse) (inout Element) -> Void) -> (value: Void, pullback: (inout Self.TangentVector) -> Void) {
38+
let closurePullback = pullback(at: self[index], of: closure)
39+
return (value: (), pullback: { closurePullback(&$0.base[index]) })
40+
}
41+
}
42+
43+
public struct D<I: Equatable, D> {
44+
public subscript(_ index: I) -> D? {
45+
get {fatalError()}
46+
set {fatalError()}
47+
}
48+
}

0 commit comments

Comments
 (0)