diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 39bf2c3f9a1a7..a13417d2b81f7 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -303,7 +303,59 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, static CanSILFunctionType getAutoDiffDifferentialType(SILFunctionType *originalFnTy, IndexSubset *parameterIndices, unsigned resultIndex, - LookupConformanceFn lookupConformance) { + LookupConformanceFn lookupConformance, + TypeConverter &TC) { + // Given the tangent type and the corresponding original parameter's + // convention, returns the tangent parameter's convention. + auto getTangentParameterConvention = + [&](CanType tanType, + ParameterConvention origParamConv) -> ParameterConvention { + tanType = + tanType->getCanonicalType(originalFnTy->getSubstGenericSignature()); + AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(), + tanType); + auto &tl = + TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal()); + // When the tangent type is address only, we must ensure that the tangent + // parameter's convention is indirect. + if (tl.isAddressOnly() && !isIndirectFormalParameter(origParamConv)) { + switch (origParamConv) { + case ParameterConvention::Direct_Guaranteed: + return ParameterConvention::Indirect_In_Guaranteed; + case ParameterConvention::Direct_Owned: + case ParameterConvention::Direct_Unowned: + return ParameterConvention::Indirect_In; + default: + llvm_unreachable("unhandled parameter convention"); + } + } + return origParamConv; + }; + + // Given the tangent type and the corresponding original result's convention, + // returns the tangent result's convention. + auto getTangentResultConvention = + [&](CanType tanType, + ResultConvention origResConv) -> ResultConvention { + tanType = + tanType->getCanonicalType(originalFnTy->getSubstGenericSignature()); + AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(), + tanType); + auto &tl = + TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal()); + // When the tangent type is address only, we must ensure that the tangent + // result's convention is indirect. + if (tl.isAddressOnly() && !isIndirectFormalResult(origResConv)) { + switch (origResConv) { + case ResultConvention::Owned: + return ResultConvention::Indirect; + default: + llvm_unreachable("unhandled result convention"); + } + } + return origResConv; + }; + auto &ctx = originalFnTy->getASTContext(); SmallVector substGenericParams; SmallVector substRequirements; @@ -324,15 +376,17 @@ getAutoDiffDifferentialType(SILFunctionType *originalFnTy, param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance); assert(paramTan && "Parameter type does not have a tangent space?"); auto paramTanType = paramTan->getCanonicalType(); + auto paramConv = getTangentParameterConvention(paramTanType, + param.getConvention()); if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) { differentialParams.push_back( - {paramTan->getCanonicalType(), param.getConvention()}); + {paramTan->getCanonicalType(), paramConv}); } else { auto gpIndex = substGenericParams.size(); auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx); substGenericParams.push_back(gpType); substReplacements.push_back(paramTanType); - differentialParams.push_back({gpType, param.getConvention()}); + differentialParams.push_back({gpType, paramConv}); } } SmallVector differentialResults; @@ -342,15 +396,17 @@ getAutoDiffDifferentialType(SILFunctionType *originalFnTy, result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance); assert(resultTan && "Result type does not have a tangent space?"); auto resultTanType = resultTan->getCanonicalType(); + auto resultConv = getTangentResultConvention(resultTanType, + result.getConvention()); if (!resultTanType->hasArchetype() && !resultTanType->hasTypeParameter()) { differentialResults.push_back( - {resultTan->getCanonicalType(), result.getConvention()}); + {resultTan->getCanonicalType(), resultConv}); } else { auto gpIndex = substGenericParams.size(); auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx); substGenericParams.push_back(gpType); substReplacements.push_back(resultTanType); - differentialResults.push_back({gpType, result.getConvention()}); + differentialResults.push_back({gpType, resultConv}); } } SubstitutionMap substitutions; @@ -620,7 +676,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( case AutoDiffDerivativeFunctionKind::JVP: closureType = getAutoDiffDifferentialType(constrainedOriginalFnTy, parameterIndices, - resultIndex, lookupConformance); + resultIndex, lookupConformance, TC); break; case AutoDiffDerivativeFunctionKind::VJP: closureType = diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index 35ceca1af9cec..fc086878b6386 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -97,8 +97,12 @@ SILFunction *VJPEmitter::createEmptyPullback() { switch (origResConv) { case ResultConvention::Owned: case ResultConvention::Autoreleased: - conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned - : ParameterConvention::Direct_Guaranteed; + if (tl.isAddressOnly()) { + conv = ParameterConvention::Indirect_In_Guaranteed; + } else { + conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned + : ParameterConvention::Direct_Guaranteed; + } break; case ResultConvention::Unowned: case ResultConvention::UnownedInnerPointer: @@ -123,8 +127,12 @@ SILFunction *VJPEmitter::createEmptyPullback() { case ParameterConvention::Direct_Owned: case ParameterConvention::Direct_Guaranteed: case ParameterConvention::Direct_Unowned: - conv = - tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned; + if (tl.isAddressOnly()) { + conv = ResultConvention::Indirect; + } else { + conv = tl.isTrivial() ? ResultConvention::Unowned + : ResultConvention::Owned; + } break; case ParameterConvention::Indirect_In: case ParameterConvention::Indirect_Inout: diff --git a/test/AutoDiff/compiler_crashers/sr12641-silgen-immutable-address-use-verification-failure.swift b/test/AutoDiff/compiler_crashers/sr12641-silgen-immutable-address-use-verification-failure.swift deleted file mode 100644 index 434732ff061b8..0000000000000 --- a/test/AutoDiff/compiler_crashers/sr12641-silgen-immutable-address-use-verification-failure.swift +++ /dev/null @@ -1,40 +0,0 @@ -// RUN: not --crash %target-swift-frontend -emit-sil -verify %s -// REQUIRES: asserts - -// This crash does not occur on the tensorflow branch. -// UNSUPPORTED: tensorflow - -// SR-12641: SILGen verification error regarding `ImmutableAddressUseVerifier` and AutoDiff-generated code. - -import _Differentiation -import DifferentiationUnittest - -class Class: Differentiable { - var x: Tracked - init(_ x: Tracked) { - self.x = x - } -} - -func getter(_ c: Class) -> Tracked { - return c.x -} -_ = gradient(at: Class(10), in: getter) - -// Assertion failed: (conv.isIndirectConvention() && "Expect an indirect convention"), function isConsumingOrMutatingArgumentConvention, file swift/lib/SIL/Verifier/SILVerifier.cpp, line 453. -// Stack dump: -// ... -// 1. Swift version 5.3-dev (LLVM ca0260ddec, Swift b17e1b23fe) -// 2. While evaluating request SILGenWholeModuleRequest(SIL Generation for module main) -// 3. While verifying SIL function "@$s4main5ClassC13TangentVectorV23DifferentiationUnittest7TrackedVySfGIeggr_AeIIegnr_TR". -// for <":0:0>>0 swift 0x000000010d6b4138 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 40 -// 1 swift 0x000000010d6b30b8 llvm::sys::RunSignalHandlers() + 248 -// 2 swift 0x000000010d6b472d SignalHandler(int) + 285 -// 3 libsystem_platform.dylib 0x00007fff718335fd _sigtramp + 29 -// 4 libsystem_platform.dylib 000000000000000000 _sigtramp + 18446603338611739168 -// 5 libsystem_c.dylib 0x00007fff71709808 abort + 120 -// 6 libsystem_c.dylib 0x00007fff71708ac6 err + 0 -// 7 swift 0x000000010da31c23 (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingApplyUse(swift::Operand*) (.cold.4) + 35 -// 8 swift 0x0000000109c74d11 (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingApplyUse(swift::Operand*) + 289 -// 9 swift 0x0000000109c73c1d (anonymous namespace)::ImmutableAddressUseVerifier::isMutatingOrConsuming(swift::SILValue) + 157 -// 10 swift 0x0000000109c5cea9 (anonymous namespace)::SILVerifier::visitSILBasicBlock(swift::SILBasicBlock*) + 1161 diff --git a/test/AutoDiff/compiler_crashers_fixed/sr12641-silgen-immutable-address-use-verification-failure.swift b/test/AutoDiff/compiler_crashers_fixed/sr12641-silgen-immutable-address-use-verification-failure.swift new file mode 100644 index 0000000000000..aa7f0918f0afa --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/sr12641-silgen-immutable-address-use-verification-failure.swift @@ -0,0 +1,70 @@ +// RUN: %target-swift-frontend -enable-resilience -emit-sil -verify %s +// REQUIRES: asserts + +// SR-12641: SILGen verification error regarding `ImmutableAddressUseVerifier` and AutoDiff-generated code. + +import _Differentiation + +public struct Resilient: Differentiable { + var x: Float +} + +public class Class: Differentiable { + var x: Resilient + init(_ x: Resilient) { + self.x = x + } +} + +public func f(_ c: Class) -> Resilient { + return Resilient(x: 0) +} + +_ = pullback(at: Class(Resilient(x: 10)), in: f) + +// swift/lib/SIL/Verifier/SILVerifier.cpp:456: bool (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingArgumentConvention(swift::SILArgumentConvention): Assertion `conv.isIndirectConvention() && "Expect an indirect convention"' failed. +// Stack dump: +// ... +// 1. Swift version 5.3-dev (LLVM be43a34c3c, Swift 6d5b2f5220) +// 2. While evaluating request SILGenWholeModuleRequest(SIL Generation for module main) +// 3. While verifying SIL function "@$s4main5ClassC13TangentVectorVAA9ResilientVADVIeggr_AeHIegnr_TR". +// ... +// #8 0x00000000011e7a3e (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingApplyUse(swift::Operand*) +// #9 0x00000000011e6add (anonymous namespace)::ImmutableAddressUseVerifier::isMutatingOrConsuming(swift::SILValue) +// #10 0x00000000011ce0b4 (anonymous namespace)::SILVerifier::visitSILBasicBlock(swift::SILBasicBlock*) + +// Related crasher discovered while fixing SR-12641. + +class LoadableOriginal: Differentiable { + var x: T + init(_ x: T) { self.x = x } +} + +@differentiable +func loadableOriginal(_ loadable: LoadableOriginal) -> T { + return T.zero +} + +// swift/include/swift/SIL/TypeLowering.h:845: swift::SILType swift::Lowering::TypeConverter::getLoweredLoadableType(swift::Type, swift::TypeExpansionContext, swift::SILModule &): Assertion `(ti.isLoadable() || !SILModuleConventions(M).useLoweredAddresses()) && "unexpected address-only type"' failed. +// Stack dump: +// ... +// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Guaranteed Passes } on SIL for main.main) +// 3. While running pass #153 SILModuleTransform "Differentiation". +// 4. While processing // differentiability witness for loadableOriginal(_:) +// sil_differentiability_witness hidden [parameters 0] [results 0] @$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF : $@convention(thin) (@guaranteed LoadableOriginal) -> @out T { +// } +// +// on SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF". +// for 'loadableOriginal(_:)' +// 5. While generating VJP for SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF". +// for 'loadableOriginal(_:)' +// 6. While generating pullback for SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF". +// for 'loadableOriginal(_:)' +// ... +// #9 0x0000000000f83fbb swift::autodiff::PullbackEmitter::emitZeroDirect(swift::CanType, swift::SILLocation) +// #10 0x0000000000f8248b swift::autodiff::PullbackEmitter::emitZeroDerivativesForNonvariedResult(swift::SILValue) +// #11 0x0000000000f7fcae swift::autodiff::PullbackEmitter::run() +// #12 0x0000000000f3fba4 swift::autodiff::VJPEmitter::run() +// #13 0x0000000000eb1669 (anonymous namespace)::DifferentiationTransformer::canonicalizeDifferentiabilityWitness(swift::SILFunction*, swift::SILDifferentiabilityWitness*, swift::autodiff::DifferentiationInvoker, swift::IsSerialized_t) +// #14 0x0000000000eaea5e (anonymous namespace)::Differentiation::run()