@@ -315,15 +315,12 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
315
315
// / Collects the semantic results of the given function type in
316
316
// / `originalResults`. The semantic results are formal results followed by
317
317
// / `inout` parameters, in type order.
318
- // TODO(TF-983): Generalize to support multiple `inout` parameters. The current
319
- // singular `inoutParam` and `isWrtInoutParameter` are hacky.
320
318
static void
321
319
getSemanticResults (SILFunctionType *functionType, IndexSubset *parameterIndices,
322
- Optional<SILParameterInfo> &inoutParam,
323
- bool &isWrtInoutParameter,
320
+ IndexSubset *&inoutParameterIndices,
324
321
SmallVectorImpl<SILResultInfo> &originalResults) {
325
- inoutParam = None ;
326
- isWrtInoutParameter = false ;
322
+ auto &C = functionType-> getASTContext () ;
323
+ SmallVector< unsigned , 4 > inoutParamIndices ;
327
324
// Collect original formal results.
328
325
originalResults.append (functionType->getResults ().begin (),
329
326
functionType->getResults ().end ());
@@ -332,11 +329,12 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
332
329
auto param = functionType->getParameters ()[i];
333
330
if (!param.isIndirectInOut ())
334
331
continue ;
335
- inoutParam = param;
336
- isWrtInoutParameter = parameterIndices->contains (i);
332
+ inoutParamIndices.push_back (i);
337
333
originalResults.push_back (
338
334
SILResultInfo (param.getInterfaceType (), ResultConvention::Indirect));
339
335
}
336
+ inoutParameterIndices =
337
+ IndexSubset::get (C, parameterIndices->getCapacity (), inoutParamIndices);
340
338
}
341
339
342
340
// / Returns the differential type for the given original function type,
@@ -402,11 +400,10 @@ static CanSILFunctionType getAutoDiffDifferentialType(
402
400
SmallVector<Type, 4 > substReplacements;
403
401
SmallVector<ProtocolConformanceRef, 4 > substConformances;
404
402
405
- Optional<SILParameterInfo> inoutParam = None;
406
- bool isWrtInoutParameter = false ;
403
+ IndexSubset *inoutParamIndices;
407
404
SmallVector<SILResultInfo, 2 > originalResults;
408
- getSemanticResults (originalFnTy, parameterIndices, inoutParam ,
409
- isWrtInoutParameter, originalResults);
405
+ getSemanticResults (originalFnTy, parameterIndices, inoutParamIndices ,
406
+ originalResults);
410
407
411
408
SmallVector<SILParameterInfo, 4 > diffParams;
412
409
getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
@@ -430,7 +427,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
430
427
}
431
428
}
432
429
SmallVector<SILResultInfo, 1 > differentialResults;
433
- if (!inoutParam || !isWrtInoutParameter ) {
430
+ if (inoutParamIndices-> isEmpty () ) {
434
431
for (auto resultIndex : resultIndices->getIndices ()) {
435
432
auto &result = originalResults[resultIndex];
436
433
auto resultTan =
@@ -480,11 +477,10 @@ static CanSILFunctionType getAutoDiffPullbackType(
480
477
SmallVector<Type, 4 > substReplacements;
481
478
SmallVector<ProtocolConformanceRef, 4 > substConformances;
482
479
483
- Optional<SILParameterInfo> inoutParam = None;
484
- bool isWrtInoutParameter = false ;
480
+ IndexSubset *inoutParamIndices;
485
481
SmallVector<SILResultInfo, 2 > originalResults;
486
- getSemanticResults (originalFnTy, parameterIndices, inoutParam ,
487
- isWrtInoutParameter, originalResults);
482
+ getSemanticResults (originalFnTy, parameterIndices, inoutParamIndices ,
483
+ originalResults);
488
484
489
485
// Given a type, returns its formal SIL parameter info.
490
486
auto getTangentParameterConventionForOriginalResult =
@@ -551,27 +547,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
551
547
return conv;
552
548
};
553
549
550
+ // Collect pullback parameters.
554
551
SmallVector<SILParameterInfo, 1 > pullbackParams;
555
- if (inoutParam) {
556
- auto paramTan = inoutParam->getInterfaceType ()->getAutoDiffTangentSpace (
557
- lookupConformance);
558
- assert (paramTan && " Parameter type does not have a tangent space?" );
559
- auto paramTanConvention = isWrtInoutParameter
560
- ? inoutParam->getConvention ()
561
- : ParameterConvention::Indirect_In_Guaranteed;
562
- auto paramTanType = paramTan->getCanonicalType ();
563
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
564
- pullbackParams.push_back (
565
- SILParameterInfo (paramTanType, paramTanConvention));
566
- } else {
567
- auto gpIndex = substGenericParams.size ();
568
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
569
- substGenericParams.push_back (gpType);
570
- substReplacements.push_back (paramTanType);
571
- pullbackParams.push_back ({gpType, paramTanConvention});
572
- }
573
- } else {
574
- for (auto resultIndex : resultIndices->getIndices ()) {
552
+ for (auto resultIndex : resultIndices->getIndices ()) {
553
+ // Handle formal original result.
554
+ if (resultIndex < originalFnTy->getNumResults ()) {
575
555
auto &origRes = originalResults[resultIndex];
576
556
auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
577
557
lookupConformance);
@@ -590,12 +570,46 @@ static CanSILFunctionType getAutoDiffPullbackType(
590
570
substReplacements.push_back (resultTanType);
591
571
pullbackParams.push_back ({gpType, paramTanConvention});
592
572
}
573
+ continue ;
574
+ }
575
+ // Handle original `inout` parameter.
576
+ auto inoutParamIndex = resultIndex - originalFnTy->getNumResults ();
577
+ auto inoutParamIt = std::next (
578
+ originalFnTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
579
+ auto paramIndex =
580
+ std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
581
+ auto inoutParam = originalFnTy->getParameters ()[paramIndex];
582
+ auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
583
+ lookupConformance);
584
+ assert (paramTan && " Parameter type does not have a tangent space?" );
585
+ // The pullback parameter convention depends on whether the original `inout`
586
+ // paramater is a differentiability parameter.
587
+ // - If yes, the pullback parameter convention is `@inout`.
588
+ // - If no, the pullback parameter convention is `@in_guaranteed`.
589
+ bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
590
+ auto paramTanConvention = isWrtInoutParameter
591
+ ? inoutParam.getConvention ()
592
+ : ParameterConvention::Indirect_In_Guaranteed;
593
+ auto paramTanType = paramTan->getCanonicalType ();
594
+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
595
+ pullbackParams.push_back (
596
+ SILParameterInfo (paramTanType, paramTanConvention));
597
+ } else {
598
+ auto gpIndex = substGenericParams.size ();
599
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
600
+ substGenericParams.push_back (gpType);
601
+ substReplacements.push_back (paramTanType);
602
+ pullbackParams.push_back ({gpType, paramTanConvention});
593
603
}
594
604
}
605
+
606
+ // Collect pullback results.
595
607
SmallVector<SILParameterInfo, 4 > diffParams;
596
608
getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
597
609
SmallVector<SILResultInfo, 8 > pullbackResults;
598
610
for (auto ¶m : diffParams) {
611
+ // Skip `inout` parameters, which semantically behave as original results
612
+ // and always appear as pullback parameters.
599
613
if (param.isIndirectInOut ())
600
614
continue ;
601
615
auto paramTan =
0 commit comments