Skip to content

[AutoDiff] Fix SR-12641: Handle address-only types in derivative fn types #31496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,59 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
static CanSILFunctionType
getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
IndexSubset *parameterIndices, unsigned resultIndex,
LookupConformanceFn lookupConformance) {
LookupConformanceFn lookupConformance,
TypeConverter &TC) {
// Given the tangent type and the corresponding original parameter's
// convention, returns the tangent parameter's convention.
auto getTangentParameterConvention =
[&](CanType tanType,
ParameterConvention origParamConv) -> ParameterConvention {
tanType =
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
tanType);
auto &tl =
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
// When the tangent type is address only, we must ensure that the tangent
// parameter's convention is indirect.
if (tl.isAddressOnly() && !isIndirectFormalParameter(origParamConv)) {
switch (origParamConv) {
case ParameterConvention::Direct_Guaranteed:
return ParameterConvention::Indirect_In_Guaranteed;
case ParameterConvention::Direct_Owned:
case ParameterConvention::Direct_Unowned:
return ParameterConvention::Indirect_In;
default:
llvm_unreachable("unhandled parameter convention");
}
}
return origParamConv;
};

// Given the tangent type and the corresponding original result's convention,
// returns the tangent result's convention.
auto getTangentResultConvention =
[&](CanType tanType,
ResultConvention origResConv) -> ResultConvention {
tanType =
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
tanType);
auto &tl =
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
// When the tangent type is address only, we must ensure that the tangent
// result's convention is indirect.
if (tl.isAddressOnly() && !isIndirectFormalResult(origResConv)) {
switch (origResConv) {
case ResultConvention::Owned:
return ResultConvention::Indirect;
default:
llvm_unreachable("unhandled result convention");
}
}
return origResConv;
};

auto &ctx = originalFnTy->getASTContext();
SmallVector<GenericTypeParamType *, 4> substGenericParams;
SmallVector<Requirement, 4> substRequirements;
Expand All @@ -324,15 +376,17 @@ getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto paramTanType = paramTan->getCanonicalType();
auto paramConv = getTangentParameterConvention(paramTanType,
param.getConvention());
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
differentialParams.push_back(
{paramTan->getCanonicalType(), param.getConvention()});
{paramTan->getCanonicalType(), paramConv});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
differentialParams.push_back({gpType, param.getConvention()});
differentialParams.push_back({gpType, paramConv});
}
}
SmallVector<SILResultInfo, 1> differentialResults;
Expand All @@ -342,15 +396,17 @@ getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(resultTan && "Result type does not have a tangent space?");
auto resultTanType = resultTan->getCanonicalType();
auto resultConv = getTangentResultConvention(resultTanType,
result.getConvention());
if (!resultTanType->hasArchetype() && !resultTanType->hasTypeParameter()) {
differentialResults.push_back(
{resultTan->getCanonicalType(), result.getConvention()});
{resultTan->getCanonicalType(), resultConv});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(resultTanType);
differentialResults.push_back({gpType, result.getConvention()});
differentialResults.push_back({gpType, resultConv});
}
}
SubstitutionMap substitutions;
Expand Down Expand Up @@ -620,7 +676,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
case AutoDiffDerivativeFunctionKind::JVP:
closureType =
getAutoDiffDifferentialType(constrainedOriginalFnTy, parameterIndices,
resultIndex, lookupConformance);
resultIndex, lookupConformance, TC);
break;
case AutoDiffDerivativeFunctionKind::VJP:
closureType =
Expand Down
16 changes: 12 additions & 4 deletions lib/SILOptimizer/Differentiation/VJPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,12 @@ SILFunction *VJPEmitter::createEmptyPullback() {
switch (origResConv) {
case ResultConvention::Owned:
case ResultConvention::Autoreleased:
conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
: ParameterConvention::Direct_Guaranteed;
if (tl.isAddressOnly()) {
conv = ParameterConvention::Indirect_In_Guaranteed;
} else {
conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
: ParameterConvention::Direct_Guaranteed;
}
break;
case ResultConvention::Unowned:
case ResultConvention::UnownedInnerPointer:
Expand All @@ -123,8 +127,12 @@ SILFunction *VJPEmitter::createEmptyPullback() {
case ParameterConvention::Direct_Owned:
case ParameterConvention::Direct_Guaranteed:
case ParameterConvention::Direct_Unowned:
conv =
tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned;
if (tl.isAddressOnly()) {
conv = ResultConvention::Indirect;
} else {
conv = tl.isTrivial() ? ResultConvention::Unowned
: ResultConvention::Owned;
}
break;
case ParameterConvention::Indirect_In:
case ParameterConvention::Indirect_Inout:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: %target-swift-frontend -enable-resilience -emit-sil -verify %s
// REQUIRES: asserts

// SR-12641: SILGen verification error regarding `ImmutableAddressUseVerifier` and AutoDiff-generated code.

import _Differentiation

public struct Resilient: Differentiable {
var x: Float
}

public class Class: Differentiable {
var x: Resilient
init(_ x: Resilient) {
self.x = x
}
}

public func f(_ c: Class) -> Resilient {
return Resilient(x: 0)
}

_ = pullback(at: Class(Resilient(x: 10)), in: f)

// swift/lib/SIL/Verifier/SILVerifier.cpp:456: bool (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingArgumentConvention(swift::SILArgumentConvention): Assertion `conv.isIndirectConvention() && "Expect an indirect convention"' failed.
// Stack dump:
// ...
// 1. Swift version 5.3-dev (LLVM be43a34c3c, Swift 6d5b2f5220)
// 2. While evaluating request SILGenWholeModuleRequest(SIL Generation for module main)
// 3. While verifying SIL function "@$s4main5ClassC13TangentVectorVAA9ResilientVADVIeggr_AeHIegnr_TR".
// ...
// #8 0x00000000011e7a3e (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingApplyUse(swift::Operand*)
// #9 0x00000000011e6add (anonymous namespace)::ImmutableAddressUseVerifier::isMutatingOrConsuming(swift::SILValue)
// #10 0x00000000011ce0b4 (anonymous namespace)::SILVerifier::visitSILBasicBlock(swift::SILBasicBlock*)

// Related crasher discovered while fixing SR-12641.

class LoadableOriginal<T: Differentiable>: Differentiable {
var x: T
init(_ x: T) { self.x = x }
}

@differentiable
func loadableOriginal<T: AdditiveArithmetic>(_ loadable: LoadableOriginal<T>) -> T {
return T.zero
}

// swift/include/swift/SIL/TypeLowering.h:845: swift::SILType swift::Lowering::TypeConverter::getLoweredLoadableType(swift::Type, swift::TypeExpansionContext, swift::SILModule &): Assertion `(ti.isLoadable() || !SILModuleConventions(M).useLoweredAddresses()) && "unexpected address-only type"' failed.
// Stack dump:
// ...
// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Guaranteed Passes } on SIL for main.main)
// 3. While running pass #153 SILModuleTransform "Differentiation".
// 4. While processing // differentiability witness for loadableOriginal<A>(_:)
// sil_differentiability_witness hidden [parameters 0] [results 0] <T where T : AdditiveArithmetic, T : Differentiable> @$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF : $@convention(thin) <T where T : Additive
// Arithmetic, T : Differentiable> (@guaranteed LoadableOriginal<T>) -> @out T {
// }
//
// on SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF".
// for 'loadableOriginal(_:)'
// 5. While generating VJP for SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF".
// for 'loadableOriginal(_:)'
// 6. While generating pullback for SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF".
// for 'loadableOriginal(_:)'
// ...
// #9 0x0000000000f83fbb swift::autodiff::PullbackEmitter::emitZeroDirect(swift::CanType, swift::SILLocation)
// #10 0x0000000000f8248b swift::autodiff::PullbackEmitter::emitZeroDerivativesForNonvariedResult(swift::SILValue)
// #11 0x0000000000f7fcae swift::autodiff::PullbackEmitter::run()
// #12 0x0000000000f3fba4 swift::autodiff::VJPEmitter::run()
// #13 0x0000000000eb1669 (anonymous namespace)::DifferentiationTransformer::canonicalizeDifferentiabilityWitness(swift::SILFunction*, swift::SILDifferentiabilityWitness*, swift::autodiff::DifferentiationInvoker, swift::IsSerialized_t)
// #14 0x0000000000eaea5e (anonymous namespace)::Differentiation::run()