Skip to content

Commit 3c4612d

Browse files
committed
Initial support for differentiation of throwing functions
1 parent a654df7 commit 3c4612d

File tree

13 files changed

+643
-62
lines changed

13 files changed

+643
-62
lines changed

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,12 @@ inline void createEntryArguments(SILFunction *f) {
277277
indResTy = indResTy.mapTypeOutOfContext();
278278
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
279279
}
280+
if (auto indErrorResTy = conv.getIndirectErrorResultType(f->getTypeExpansionContext())) {
281+
if (indErrorResTy.hasArchetype())
282+
indErrorResTy = indErrorResTy.mapTypeOutOfContext();
283+
createFunctionArgument(f->mapTypeIntoContext(indErrorResTy).getAddressType());
284+
}
285+
280286
for (auto paramTy : conv.getParameterSILTypes(f->getTypeExpansionContext())) {
281287
if (paramTy.hasArchetype())
282288
paramTy = paramTy.mapTypeOutOfContext();

lib/AST/Builtins.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2892,6 +2892,7 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
28922892
if (!autodiff::getBuiltinApplyDerivativeConfig(
28932893
OperationName, kind, arity, throws))
28942894
return nullptr;
2895+
// TODO: Support somehow typed throws
28952896
return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity,
28962897
throws, /*thrownType=*/Type());
28972898
}
@@ -2901,6 +2902,7 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
29012902
if (!autodiff::getBuiltinApplyTransposeConfig(
29022903
OperationName, arity, throws))
29032904
return nullptr;
2905+
// TODO: Support somehow typed throws
29042906
return getAutoDiffApplyTransposeFunction(Context, Id, arity, throws,
29052907
/*thrownType=*/Type());
29062908
}

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,11 +1191,8 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF,
11911191

11921192
static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
11931193
AutoDiffDerivativeFunctionKind kind, unsigned arity,
1194-
bool throws, SILGenFunction &SGF, SILLocation loc,
1195-
SubstitutionMap substitutions, ArrayRef<ManagedValue> args, SGFContext C) {
1196-
// FIXME(https://github.com/apple/swift/issues/54259): Support throwing functions.
1197-
assert(!throws && "Throwing functions are not yet supported");
1198-
1194+
SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
1195+
ArrayRef<ManagedValue> args, SGFContext C) {
11991196
auto origFnVal = args[0];
12001197
SmallVector<SILValue, 2> origFnArgVals;
12011198
for (auto& arg : args.drop_front(1))
@@ -1213,7 +1210,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12131210
origFnVal = SGF.B.createBeginBorrow(loc, origFnVal);
12141211
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
12151212
loc, kind, origFnVal.getValue());
1216-
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
1213+
SILType derivativeType = derivativeFn->getType();
1214+
auto derivativeFnType = derivativeType.castTo<SILFunctionType>();
12171215
assert(derivativeFnType->getNumResults() == 2);
12181216
assert(derivativeFnType->getNumParameters() == origFnArgVals.size());
12191217

@@ -1240,8 +1238,10 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12401238
applyArgs.push_back(SGF.B.createTupleElementAddr(loc, indResBuffer, 0));
12411239
for (auto origFnArgVal : origFnArgVals)
12421240
applyArgs.push_back(origFnArgVal);
1243-
auto differential = SGF.B.createApply(loc, derivativeFn, SubstitutionMap(),
1244-
applyArgs);
1241+
auto differential =
1242+
SGF.emitApplyWithRethrow(loc,
1243+
derivativeFn, derivativeType,
1244+
SubstitutionMap(), applyArgs);
12451245

12461246
derivativeFn = SILValue();
12471247

@@ -1253,8 +1253,10 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12531253
}
12541254

12551255
// Do the apply for the direct result case.
1256-
auto resultTuple = SGF.B.createApply(
1257-
loc, derivativeFn, SubstitutionMap(), origFnArgVals);
1256+
auto resultTuple =
1257+
SGF.emitApplyWithRethrow(loc,
1258+
derivativeFn, derivativeType,
1259+
SubstitutionMap(), origFnArgVals);
12581260

12591261
derivativeFn = SILValue();
12601262

@@ -1324,7 +1326,7 @@ static ManagedValue emitBuiltinApplyDerivative(
13241326
builtinName, kind, arity, throws);
13251327
assert(successfullyParsed);
13261328
return emitBuiltinAutoDiffApplyDerivativeFunction(
1327-
kind, arity, throws, SGF, loc, substitutions, args, C);
1329+
kind, arity, SGF, loc, substitutions, args, C);
13281330
}
13291331

13301332
static ManagedValue emitBuiltinApplyTranspose(

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,15 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
144144
heapAllocatedContext = true;
145145
decl->setInterfaceType(astCtx.TheRawPointerType);
146146
} else { // Otherwise the payload is the linear map tuple.
147-
auto *linearMapStructTy = getLinearMapTupleType(predBB);
147+
auto *linearMapTupleTy = getLinearMapTupleType(predBB);
148148
// Do not create entries for unreachable predecessors
149-
if (!linearMapStructTy)
149+
if (!linearMapTupleTy)
150150
continue;
151-
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
151+
152+
auto canLinearMapTupleTy = linearMapTupleTy->getCanonicalType();
152153
decl->setInterfaceType(
153-
canLinearMapStructTy->hasArchetype()
154-
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
154+
canLinearMapTupleTy->hasArchetype()
155+
? canLinearMapTupleTy->mapTypeOutOfContext() : canLinearMapTupleTy);
155156
}
156157
// Create enum element and enum case declarations.
157158
auto *paramList = ParameterList::create(astCtx, {decl});
@@ -183,6 +184,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
183184
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
184185
return activityInfo.isActive(res, config);
185186
});
187+
186188
bool hasActiveSemanticResultArgument = false;
187189
bool hasActiveArguments = false;
188190
auto numIndirectResults = fai.getNumIndirectSILResults();
@@ -311,10 +313,13 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
311313
params, silFnTy->getAllResultsInterfaceType().getASTType(), info);
312314
}
313315

314-
if (astFnTy->hasArchetype())
315-
return astFnTy->mapTypeOutOfContext();
316+
Type resultType = astFnTy->hasArchetype() ? astFnTy->mapTypeOutOfContext() : astFnTy;
317+
318+
319+
if (fai.getKind() == FullApplySiteKind::TryApplyInst)
320+
resultType = resultType->wrapInOptionalType();
316321

317-
return astFnTy;
322+
return resultType;
318323
}
319324

320325
void LinearMapInfo::generateDifferentiationDataStructures(

0 commit comments

Comments
 (0)