Skip to content

Commit 1bc9159

Browse files
authored
[AutoDiff] Support differentiation of functions with multiple results in SIL. (#32629)
Reverse-mode differentiation now supports `apply` instructions with multiple active "semantic results" (formal results or `inout` parameters). The "cannot differentiate through multiple results" non-differentiability error is lifted. Resolves TF-983.
1 parent 5f95138 commit 1bc9159

File tree

7 files changed

+269
-125
lines changed

7 files changed

+269
-125
lines changed

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
105105
/// A set used to remember local allocations that were destroyed.
106106
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
107107

108-
/// The seed argument in the pullback function.
109-
SILArgument *seed = nullptr;
108+
/// The seed arguments of the pullback function.
109+
SmallVector<SILArgument *, 4> seeds;
110110

111111
llvm::BumpPtrAllocator allocator;
112112

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -315,15 +315,12 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
315315
/// Collects the semantic results of the given function type in
316316
/// `originalResults`. The semantic results are formal results followed by
317317
/// `inout` parameters, in type order.
318-
// TODO(TF-983): Generalize to support multiple `inout` parameters. The current
319-
// singular `inoutParam` and `isWrtInoutParameter` are hacky.
320318
static void
321319
getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
322-
Optional<SILParameterInfo> &inoutParam,
323-
bool &isWrtInoutParameter,
320+
IndexSubset *&inoutParameterIndices,
324321
SmallVectorImpl<SILResultInfo> &originalResults) {
325-
inoutParam = None;
326-
isWrtInoutParameter = false;
322+
auto &C = functionType->getASTContext();
323+
SmallVector<unsigned, 4> inoutParamIndices;
327324
// Collect original formal results.
328325
originalResults.append(functionType->getResults().begin(),
329326
functionType->getResults().end());
@@ -332,11 +329,12 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
332329
auto param = functionType->getParameters()[i];
333330
if (!param.isIndirectInOut())
334331
continue;
335-
inoutParam = param;
336-
isWrtInoutParameter = parameterIndices->contains(i);
332+
inoutParamIndices.push_back(i);
337333
originalResults.push_back(
338334
SILResultInfo(param.getInterfaceType(), ResultConvention::Indirect));
339335
}
336+
inoutParameterIndices =
337+
IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices);
340338
}
341339

342340
/// Returns the differential type for the given original function type,
@@ -402,11 +400,10 @@ static CanSILFunctionType getAutoDiffDifferentialType(
402400
SmallVector<Type, 4> substReplacements;
403401
SmallVector<ProtocolConformanceRef, 4> substConformances;
404402

405-
Optional<SILParameterInfo> inoutParam = None;
406-
bool isWrtInoutParameter = false;
403+
IndexSubset *inoutParamIndices;
407404
SmallVector<SILResultInfo, 2> originalResults;
408-
getSemanticResults(originalFnTy, parameterIndices, inoutParam,
409-
isWrtInoutParameter, originalResults);
405+
getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
406+
originalResults);
410407

411408
SmallVector<SILParameterInfo, 4> diffParams;
412409
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
@@ -430,7 +427,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
430427
}
431428
}
432429
SmallVector<SILResultInfo, 1> differentialResults;
433-
if (!inoutParam || !isWrtInoutParameter) {
430+
if (inoutParamIndices->isEmpty()) {
434431
for (auto resultIndex : resultIndices->getIndices()) {
435432
auto &result = originalResults[resultIndex];
436433
auto resultTan =
@@ -480,11 +477,10 @@ static CanSILFunctionType getAutoDiffPullbackType(
480477
SmallVector<Type, 4> substReplacements;
481478
SmallVector<ProtocolConformanceRef, 4> substConformances;
482479

483-
Optional<SILParameterInfo> inoutParam = None;
484-
bool isWrtInoutParameter = false;
480+
IndexSubset *inoutParamIndices;
485481
SmallVector<SILResultInfo, 2> originalResults;
486-
getSemanticResults(originalFnTy, parameterIndices, inoutParam,
487-
isWrtInoutParameter, originalResults);
482+
getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
483+
originalResults);
488484

489485
// Given a type, returns its formal SIL parameter info.
490486
auto getTangentParameterConventionForOriginalResult =
@@ -551,27 +547,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
551547
return conv;
552548
};
553549

550+
// Collect pullback parameters.
554551
SmallVector<SILParameterInfo, 1> pullbackParams;
555-
if (inoutParam) {
556-
auto paramTan = inoutParam->getInterfaceType()->getAutoDiffTangentSpace(
557-
lookupConformance);
558-
assert(paramTan && "Parameter type does not have a tangent space?");
559-
auto paramTanConvention = isWrtInoutParameter
560-
? inoutParam->getConvention()
561-
: ParameterConvention::Indirect_In_Guaranteed;
562-
auto paramTanType = paramTan->getCanonicalType();
563-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
564-
pullbackParams.push_back(
565-
SILParameterInfo(paramTanType, paramTanConvention));
566-
} else {
567-
auto gpIndex = substGenericParams.size();
568-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
569-
substGenericParams.push_back(gpType);
570-
substReplacements.push_back(paramTanType);
571-
pullbackParams.push_back({gpType, paramTanConvention});
572-
}
573-
} else {
574-
for (auto resultIndex : resultIndices->getIndices()) {
552+
for (auto resultIndex : resultIndices->getIndices()) {
553+
// Handle formal original result.
554+
if (resultIndex < originalFnTy->getNumResults()) {
575555
auto &origRes = originalResults[resultIndex];
576556
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
577557
lookupConformance);
@@ -590,12 +570,46 @@ static CanSILFunctionType getAutoDiffPullbackType(
590570
substReplacements.push_back(resultTanType);
591571
pullbackParams.push_back({gpType, paramTanConvention});
592572
}
573+
continue;
574+
}
575+
// Handle original `inout` parameter.
576+
auto inoutParamIndex = resultIndex - originalFnTy->getNumResults();
577+
auto inoutParamIt = std::next(
578+
originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
579+
auto paramIndex =
580+
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
581+
auto inoutParam = originalFnTy->getParameters()[paramIndex];
582+
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
583+
lookupConformance);
584+
assert(paramTan && "Parameter type does not have a tangent space?");
585+
// The pullback parameter convention depends on whether the original `inout`
586+
// paramater is a differentiability parameter.
587+
// - If yes, the pullback parameter convention is `@inout`.
588+
// - If no, the pullback parameter convention is `@in_guaranteed`.
589+
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
590+
auto paramTanConvention = isWrtInoutParameter
591+
? inoutParam.getConvention()
592+
: ParameterConvention::Indirect_In_Guaranteed;
593+
auto paramTanType = paramTan->getCanonicalType();
594+
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
595+
pullbackParams.push_back(
596+
SILParameterInfo(paramTanType, paramTanConvention));
597+
} else {
598+
auto gpIndex = substGenericParams.size();
599+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
600+
substGenericParams.push_back(gpType);
601+
substReplacements.push_back(paramTanType);
602+
pullbackParams.push_back({gpType, paramTanConvention});
593603
}
594604
}
605+
606+
// Collect pullback results.
595607
SmallVector<SILParameterInfo, 4> diffParams;
596608
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
597609
SmallVector<SILResultInfo, 8> pullbackResults;
598610
for (auto &param : diffParams) {
611+
// Skip `inout` parameters, which semantically behave as original results
612+
// and always appear as pullback parameters.
599613
if (param.isIndirectInOut())
600614
continue;
601615
auto paramTan =

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,44 +1007,47 @@ bool PullbackEmitter::run() {
10071007
}
10081008

10091009
auto *pullbackEntry = pullback.getEntryBlock();
1010-
// The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
1010+
// The pullback function has type:
1011+
// `(seed0, seed1, ..., exit_pb_struct) -> (d_arg0, ..., d_argn)`.
10111012
auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
1012-
assert(pbParamArgs.size() == 2);
1013-
seed = pbParamArgs[0];
1014-
// TODO(TF-983): Handle multiple original results.
1015-
assert(getIndices().results->getNumIndices() == 1);
1016-
auto origResult = origFormalResults[*getIndices().results->begin()];
1017-
1018-
// Assign adjoint for original result.
1013+
assert(getIndices().results->getNumIndices() == pbParamArgs.size() - 1 &&
1014+
pbParamArgs.size() >= 2);
1015+
// Assign adjoints for original result.
10191016
builder.setInsertionPoint(pullbackEntry,
10201017
getNextFunctionLocalAllocationInsertionPoint());
1021-
if (seed->getType().isAddress()) {
1022-
// If the pullback `seed` is an `inout` parameter, assign it directly as the
1023-
// adjoint buffer of the original result.
1024-
if (pullback.getLoweredFunctionType()
1025-
->getParameters()
1026-
.front()
1027-
.isIndirectInOut()) {
1028-
setAdjointBuffer(origExit, origResult, seed);
1029-
}
1030-
// Otherwise, assign a copy of `seed` as the adjoint buffer of the original
1031-
// result.
1032-
else {
1033-
auto *seedBufCopy = builder.createAllocStack(pbLoc, seed->getType());
1034-
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
1035-
IsInitialization);
1036-
functionLocalAllocations.push_back(seedBufCopy);
1037-
setAdjointBuffer(origExit, origResult, seedBufCopy);
1018+
unsigned seedIndex = 0;
1019+
for (auto resultIndex : getIndices().results->getIndices()) {
1020+
auto origResult = origFormalResults[resultIndex];
1021+
auto *seed = pbParamArgs[seedIndex];
1022+
if (seed->getType().isAddress()) {
1023+
// If the seed argument is an `inout` parameter, assign it directly as
1024+
// the adjoint buffer of the original result.
1025+
auto seedParamInfo =
1026+
pullback.getLoweredFunctionType()->getParameters()[seedIndex];
1027+
if (seedParamInfo.isIndirectInOut()) {
1028+
setAdjointBuffer(origExit, origResult, seed);
1029+
}
1030+
// Otherwise, assign a copy of the seed argument as the adjoint buffer of
1031+
// the original result.
1032+
else {
1033+
auto *seedBufCopy =
1034+
createFunctionLocalAllocation(seed->getType(), pbLoc);
1035+
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
1036+
IsInitialization);
1037+
setAdjointBuffer(origExit, origResult, seedBufCopy);
1038+
LLVM_DEBUG(getADDebugStream()
1039+
<< "Assigned seed buffer " << *seedBufCopy
1040+
<< " as the adjoint of original indirect result "
1041+
<< origResult);
1042+
}
1043+
} else {
1044+
addAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed),
1045+
pbLoc);
10381046
LLVM_DEBUG(getADDebugStream()
1039-
<< "Assigned seed buffer " << seedBufCopy
1040-
<< " as the adjoint of original indirect result "
1041-
<< origResult);
1047+
<< "Assigned seed " << *seed
1048+
<< " as the adjoint of original result " << origResult);
10421049
}
1043-
} else {
1044-
setAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed));
1045-
LLVM_DEBUG(getADDebugStream()
1046-
<< "Assigned seed " << *seed
1047-
<< " as the adjoint of original result " << origResult);
1050+
++seedIndex;
10481051
}
10491052

10501053
// If the original function is an accessor with special-case pullback
@@ -1573,8 +1576,7 @@ void PullbackEmitter::visitApplyInst(ApplyInst *ai) {
15731576
args.push_back(alloc);
15741577
}
15751578

1576-
// Get formal callee pullback arguments.
1577-
assert(applyInfo.indices.results->getNumIndices() == 1);
1579+
// Collect callee pullback formal arguments.
15781580
for (auto resultIndex : applyInfo.indices.results->getIndices()) {
15791581
assert(resultIndex < origAllResults.size());
15801582
auto origResult = origAllResults[resultIndex];

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -155,32 +155,18 @@ SILFunction *VJPEmitter::createEmptyPullback() {
155155
auto origParams = origTy->getParameters();
156156
auto indices = witness->getSILAutoDiffIndices();
157157

158-
// Add pullback parameter for the seed.
159-
Optional<SILParameterInfo> inoutParam;
160-
bool isWrtInoutParam = false;
158+
// Add pullback parameters based on original result indices.
159+
SmallVector<unsigned, 4> inoutParamIndices;
161160
for (auto i : range(origTy->getNumParameters())) {
162161
auto origParam = origParams[i];
163162
if (!origParam.isIndirectInOut())
164163
continue;
165-
isWrtInoutParam = indices.parameters->contains(i);
166-
inoutParam = origParam;
164+
inoutParamIndices.push_back(i);
167165
}
168-
if (inoutParam) {
169-
auto origResult = inoutParam->getWithInterfaceType(
170-
inoutParam->getInterfaceType()->getCanonicalType(witnessCanGenSig));
171-
auto inoutParamTanConvention =
172-
isWrtInoutParam ? inoutParam->getConvention()
173-
: ParameterConvention::Indirect_In_Guaranteed;
174-
SILParameterInfo inoutParamTanParam(
175-
origResult.getInterfaceType()
176-
->getAutoDiffTangentSpace(lookupConformance)
177-
->getType()
178-
->getCanonicalType(witnessCanGenSig),
179-
inoutParamTanConvention);
180-
pbParams.push_back(inoutParamTanParam);
181-
} else {
182-
for (auto i : indices.results->getIndices()) {
183-
auto origResult = origTy->getResults()[i];
166+
for (auto resultIndex : indices.results->getIndices()) {
167+
// Handle formal result.
168+
if (resultIndex < origTy->getNumResults()) {
169+
auto origResult = origTy->getResults()[resultIndex];
184170
origResult = origResult.getWithInterfaceType(
185171
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
186172
pbParams.push_back(getTangentParameterInfoForOriginalResult(
@@ -189,7 +175,36 @@ SILFunction *VJPEmitter::createEmptyPullback() {
189175
->getType()
190176
->getCanonicalType(witnessCanGenSig),
191177
origResult.getConvention()));
178+
continue;
179+
}
180+
// Handle `inout` parameter.
181+
unsigned paramIndex = 0;
182+
unsigned inoutParamIndex = 0;
183+
for (auto i : range(origTy->getNumParameters())) {
184+
auto origParam = origTy->getParameters()[i];
185+
if (!origParam.isIndirectMutating()) {
186+
++paramIndex;
187+
continue;
188+
}
189+
if (inoutParamIndex == resultIndex - origTy->getNumResults())
190+
break;
191+
++paramIndex;
192+
++inoutParamIndex;
192193
}
194+
auto inoutParam = origParams[paramIndex];
195+
auto origResult = inoutParam.getWithInterfaceType(
196+
inoutParam.getInterfaceType()->getCanonicalType(witnessCanGenSig));
197+
auto inoutParamTanConvention =
198+
indices.isWrtParameter(paramIndex)
199+
? inoutParam.getConvention()
200+
: ParameterConvention::Indirect_In_Guaranteed;
201+
SILParameterInfo inoutParamTanParam(
202+
origResult.getInterfaceType()
203+
->getAutoDiffTangentSpace(lookupConformance)
204+
->getType()
205+
->getCanonicalType(witnessCanGenSig),
206+
inoutParamTanConvention);
207+
pbParams.push_back(inoutParamTanParam);
193208
}
194209

195210
// Accept a pullback struct in the pullback parameter list. This is the
@@ -587,15 +602,6 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
587602
activeResultIndices.begin(), activeResultIndices.end(),
588603
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
589604
s << "}\n";);
590-
// Diagnose multiple active results.
591-
// TODO(TF-983): Support multiple active results.
592-
if (activeResultIndices.size() > 1) {
593-
context.emitNondifferentiabilityError(
594-
ai, invoker,
595-
diag::autodiff_cannot_differentiate_through_multiple_results);
596-
errorOccurred = true;
597-
return;
598-
}
599605

600606
// Form expected indices.
601607
auto numSemanticResults =

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,8 @@ func multipleResults(_ x: Float) -> (Float, Float) {
346346
return (x, x)
347347
}
348348

349-
// TODO(TF-983): Support differentiation of multiple results.
350-
// expected-error @+2 {{function is not differentiable}}
351-
// expected-note @+2 {{when differentiating this function definition}}
352349
@differentiable
353350
func usesMultipleResults(_ x: Float) -> Float {
354-
// expected-note @+1 {{cannot differentiate through multiple results}}
355351
let tuple = multipleResults(x)
356352
return tuple.0 + tuple.1
357353
}
@@ -440,27 +436,19 @@ func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) {
440436
nonactive = result.0
441437
}
442438

443-
// TODO(TF-983): Support differentiation of multiple results.
444439
func twoInoutParameters(_ x: inout Float, _ y: inout Float) {}
445-
// expected-error @+2 {{function is not differentiable}}
446-
// expected-note @+2 {{when differentiating this function definition}}
447440
@differentiable
448441
func testTwoInoutParameters(_ x: Float, _ y: Float) -> Float {
449442
var x = x
450443
var y = y
451-
// expected-note @+1 {{cannot differentiate through multiple results}}
452444
twoInoutParameters(&x, &y)
453445
return x
454446
}
455447

456-
// TODO(TF-983): Support differentiation of multiple results.
457448
func inoutParameterAndFormalResult(_ x: inout Float) -> Float { x }
458-
// expected-error @+2 {{function is not differentiable}}
459-
// expected-note @+2 {{when differentiating this function definition}}
460449
@differentiable
461450
func testInoutParameterAndFormalResult(_ x: Float) -> Float {
462451
var x = x
463-
// expected-note @+1 {{cannot differentiate through multiple results}}
464452
return inoutParameterAndFormalResult(&x)
465453
}
466454

0 commit comments

Comments
 (0)