@@ -1114,38 +1114,41 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
1114
1114
// %vjp' = convert_escape_to_noescape %vjp
1115
1115
// %y = differentiable_function(%orig', %jvp', %vjp')
1116
1116
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted ())) {
1117
- auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1118
- if (!DFI->hasExtractee (extractee))
1119
- return SILValue ();
1117
+ if (DFI->hasOneUse ()) {
1118
+ auto createConvertEscapeToNoEscape =
1119
+ [&](NormalDifferentiableFunctionTypeComponent extractee) {
1120
+ if (!DFI->hasExtractee (extractee))
1121
+ return SILValue ();
1120
1122
1121
- auto operand = DFI->getExtractee (extractee);
1122
- auto fnType = operand->getType ().castTo <SILFunctionType>();
1123
- auto noEscapeFnType =
1124
- fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1125
- auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1126
- return Builder.createConvertEscapeToNoEscape (
1127
- operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1128
- };
1123
+ auto operand = DFI->getExtractee (extractee);
1124
+ auto fnType = operand->getType ().castTo <SILFunctionType>();
1125
+ auto noEscapeFnType =
1126
+ fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1127
+ auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1128
+ return Builder.createConvertEscapeToNoEscape (
1129
+ operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1130
+ };
1129
1131
1130
- SILValue originalNoEscape =
1131
- createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1132
- SILValue convertedJVP = createConvertEscapeToNoEscape (
1133
- NormalDifferentiableFunctionTypeComponent::JVP);
1134
- SILValue convertedVJP = createConvertEscapeToNoEscape (
1135
- NormalDifferentiableFunctionTypeComponent::VJP);
1136
-
1137
- Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1138
- if (convertedJVP && convertedVJP)
1139
- derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1140
-
1141
- auto *newDFI = Builder.createDifferentiableFunction (
1142
- DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1143
- originalNoEscape, derivativeFunctions);
1144
- assert (newDFI->getType () == Cvt->getType () &&
1145
- " New `@differentiable` function instruction should have same type "
1146
- " as the old `convert_escape_to_no_escape` instruction" );
1147
- return newDFI;
1148
- }
1132
+ SILValue originalNoEscape =
1133
+ createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1134
+ SILValue convertedJVP = createConvertEscapeToNoEscape (
1135
+ NormalDifferentiableFunctionTypeComponent::JVP);
1136
+ SILValue convertedVJP = createConvertEscapeToNoEscape (
1137
+ NormalDifferentiableFunctionTypeComponent::VJP);
1138
+
1139
+ Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1140
+ if (convertedJVP && convertedVJP)
1141
+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1142
+
1143
+ auto *newDFI = Builder.createDifferentiableFunction (
1144
+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1145
+ originalNoEscape, derivativeFunctions);
1146
+ assert (newDFI->getType () == Cvt->getType () &&
1147
+ " New `@differentiable` function instruction should have same type "
1148
+ " as the old `convert_escape_to_no_escape` instruction" );
1149
+ return newDFI;
1150
+ }
1151
+ }
1149
1152
1150
1153
return nullptr ;
1151
1154
}
0 commit comments