diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index f440dd7760248..e137b9d0c15d7 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -1124,6 +1124,16 @@ class ASTContext final { AbstractFunctionDecl *originalAFD, unsigned previousGeneration, llvm::SetVector &results); + /// Given `Optional.TangentVector` type, retrieve the + /// `Optional.TangentVector.init` declaration. + ConstructorDecl *getOptionalTanInitDecl(CanType optionalTanType); + + /// Optional.TangentVector is a struct with a single + /// Optional `value` property. This is an implementation + /// detail of OptionalDifferentiation.swift. Retrieve `VarDecl` corresponding + /// to this property. + VarDecl *getOptionalTanValueDecl(CanType optionalTanType); + /// Retrieve the next macro expansion discriminator within the given /// name and context. unsigned getNextMacroDiscriminator(MacroDiscriminatorContext context, diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index d087174ccd522..d013ca01be06d 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -343,6 +343,12 @@ struct ASTContext::Implementation { /// The declaration of Swift.Optional.None. EnumElementDecl *OptionalNoneDecl = nullptr; + /// The declaration of Optional.TangentVector.init + ConstructorDecl *OptionalTanInitDecl = nullptr; + + /// The declaration of Optional.TangentVector.value + VarDecl *OptionalTanValueDecl = nullptr; + /// The declaration of Swift.Void. TypeAliasDecl *VoidDecl = nullptr; @@ -2245,6 +2251,52 @@ void ASTContext::loadObjCMethods( } } +ConstructorDecl *ASTContext::getOptionalTanInitDecl(CanType optionalTanType) { + if (!getImpl().OptionalTanInitDecl) { + auto *optionalTanDecl = optionalTanType.getNominalOrBoundGenericNominal(); + // Look up the `Optional.TangentVector.init` declaration. + auto initLookup = + optionalTanDecl->lookupDirect(DeclBaseName::createConstructor()); + ConstructorDecl *constructorDecl = nullptr; + for (auto *candidate : initLookup) { + auto candidateModule = candidate->getModuleContext(); + if (candidateModule->getName() == Id_Differentiation || + candidateModule->isStdlibModule()) { + assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s"); + constructorDecl = cast(candidate); +#ifdef NDEBUG + break; +#endif + } + } + assert(constructorDecl && "No `Optional.TangentVector.init`"); + + getImpl().OptionalTanInitDecl = constructorDecl; + } + + return getImpl().OptionalTanInitDecl; +} + +VarDecl *ASTContext::getOptionalTanValueDecl(CanType optionalTanType) { + if (!getImpl().OptionalTanValueDecl) { + // TODO: Maybe it would be better to have getters / setters here that we + // can call and hide this implementation detail? + StructDecl *optStructDecl = optionalTanType.getStructOrBoundGenericStruct(); + assert(optStructDecl && "Unexpected type of Optional.TangentVector"); + + ArrayRef properties = optStructDecl->getStoredProperties(); + assert(properties.size() == 1 && "Unexpected type of Optional.TangentVector"); + VarDecl *wrappedValueVar = properties[0]; + + assert(wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() == + getOptionalDecl() && "Unexpected type of Optional.TangentVector"); + + getImpl().OptionalTanValueDecl = wrappedValueVar; + } + + return getImpl().OptionalTanValueDecl; +} + void ASTContext::loadDerivativeFunctionConfigurations( AbstractFunctionDecl *originalAFD, unsigned previousGeneration, llvm::SetVector &results) { diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index e47de00e07027..93e86eece36de 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -18,6 +18,7 @@ #include "swift/AST/ASTPrinter.h" #include "swift/AST/ASTVisitor.h" #include "swift/AST/Attr.h" +#include "swift/AST/AutoDiff.h" #include "swift/AST/ClangModuleLoader.h" #include "swift/AST/ForeignAsyncConvention.h" #include "swift/AST/ForeignErrorConvention.h" @@ -6083,7 +6084,7 @@ namespace { } void printAnyFunctionTypeCommonRec(AnyFunctionType *T, Label label, - StringRef name) { + StringRef name) { printCommon(name, label); if (T->hasExtInfo()) { @@ -6098,6 +6099,24 @@ namespace { printFlag(T->isAsync(), "async"); printFlag(T->isThrowing(), "throws"); printFlag(T->hasSendingResult(), "sending_result"); + if (T->isDifferentiable()) { + switch (T->getDifferentiabilityKind()) { + default: + llvm_unreachable("unexpected differentiability kind"); + case DifferentiabilityKind::Reverse: + printFlag("@differentiable(reverse)"); + break; + case DifferentiabilityKind::Forward: + printFlag("@differentiable(_forward)"); + break; + case DifferentiabilityKind::Linear: + printFlag("@differentiable(_linear)"); + break; + case DifferentiabilityKind::Normal: + printFlag("@differentiable"); + break; + } + } } if (Type globalActor = T->getGlobalActor()) { printFieldQuoted(globalActor.getString(), Label::always("global_actor")); diff --git a/lib/SILGen/SILGenFunction.h b/lib/SILGen/SILGenFunction.h index 0ee34ecd002d2..b9083a16071c0 100644 --- a/lib/SILGen/SILGenFunction.h +++ b/lib/SILGen/SILGenFunction.h @@ -2521,6 +2521,25 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction CanSILFunctionType toType, bool reorderSelf); + /// Emit conversion from T.TangentVector to Optional.TangentVector. + ManagedValue + emitTangentVectorToOptionalTangentVector(SILLocation loc, + ManagedValue input, + CanType wrappedType, // `T` + CanType inputType, // `T.TangentVector` + CanType outputType, // `Optional.TangentVector` + SGFContext ctxt); + + /// Emit conversion from Optional.TangentVector to T.TangentVector. + ManagedValue + emitOptionalTangentVectorToTangentVector(SILLocation loc, + ManagedValue input, + CanType wrappedType, // `T` + CanType inputType, // `Optional.TangentVector` + CanType outputType, // `T.TangentVector` + SGFContext ctxt); + + //===--------------------------------------------------------------------===// // Back Deployment thunks //===--------------------------------------------------------------------===// diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index ca89af6d83ec1..243789544930e 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -83,6 +83,7 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "silgen-poly" +#include "ArgumentSource.h" #include "ExecutorBreadcrumb.h" #include "FunctionInputGenerator.h" #include "Initialization.h" @@ -294,6 +295,67 @@ SILGenFunction::emitTransformExistential(SILLocation loc, }); } +// Convert T.TangentVector to Optional.TangentVector. +// Optional.TangentVector is a struct wrapping Optional +// So we just need to call appropriate .init on it. +ManagedValue SILGenFunction::emitTangentVectorToOptionalTangentVector( + SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType, + CanType outputType, SGFContext ctxt) { + // Look up the `Optional.TangentVector.init` declaration. + auto *constructorDecl = getASTContext().getOptionalTanInitDecl(outputType); + + // `Optional` + CanType optionalOfWrappedTanType = inputType.wrapInOptionalType(); + + const TypeLowering &optTL = getTypeLowering(optionalOfWrappedTanType); + auto optVal = emitInjectOptional( + loc, optTL, SGFContext(), [&](SGFContext objectCtxt) { return input; }); + + auto *diffProto = getASTContext().getProtocol(KnownProtocolKind::Differentiable); + auto diffConf = lookupConformance(wrappedType, diffProto); + assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); + ConcreteDeclRef initDecl( + constructorDecl, + SubstitutionMap::get(constructorDecl->getGenericSignature(), + {wrappedType}, {diffConf})); + PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)}); + args.add(loc, RValue(*this, {optVal}, optionalOfWrappedTanType)); + + auto result = emitApplyAllocatingInitializer(loc, initDecl, std::move(args), + Type(), ctxt); + return std::move(result).getScalarValue(); +} + +ManagedValue SILGenFunction::emitOptionalTangentVectorToTangentVector( + SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType, + CanType outputType, SGFContext ctxt) { + // Optional.TangentVector should be a struct with a single + // Optional `value` property. This is an implementation + // detail of OptionalDifferentiation.swift + // TODO: Maybe it would be better to have explicit getters / setters here that we can + // call and hide this implementation detail? + VarDecl *wrappedValueVar = getASTContext().getOptionalTanValueDecl(inputType); + // `Optional` + CanType optionalOfWrappedTanType = outputType.wrapInOptionalType(); + + FormalEvaluationScope scope(*this); + + auto sig = wrappedValueVar->getDeclContext()->getGenericSignatureOfContext(); + auto *diffProto = + getASTContext().getProtocol(KnownProtocolKind::Differentiable); + auto diffConf = lookupConformance(wrappedType, diffProto); + assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); + + auto wrappedVal = emitRValueForStorageLoad( + loc, input, inputType, /*super*/ false, wrappedValueVar, + PreparedArguments(), SubstitutionMap::get(sig, {wrappedType}, {diffConf}), + AccessSemantics::Ordinary, optionalOfWrappedTanType, SGFContext()); + + return emitCheckedGetOptionalValueFrom( + loc, std::move(wrappedVal).getScalarValue(), + /*isImplicitUnwrap*/ true, getTypeLowering(optionalOfWrappedTanType), ctxt); +} + /// Apply this transformation to an arbitrary value. RValue Transform::transform(RValue &&input, AbstractionPattern inputOrigType, @@ -675,6 +737,54 @@ ManagedValue Transform::transform(ManagedValue v, return std::move(result).getAsSingleValue(SGF, Loc); } + // - T.TangentVector to Optional.TangentVector + // Optional.TangentVector is a struct wrapping Optional + // So we just need to call appropriate .init on it. + // However, we might have T.TangentVector == T, so we need to calculate all + // required types first. + { + CanType optionalTy = isa(outputSubstType) + ? outputSubstType.getNominalParent() + : CanType(); // `Optional` + if (optionalTy && (bool)optionalTy.getOptionalObjectType()) { + CanType wrappedType = optionalTy.getOptionalObjectType(); // `T` + // Check that T.TangentVector is indeed inputSubstType (this also handles + // case when T == T.TangentVector). + // Also check that outputSubstType is an Optional.TangentVector. + auto inputTanSpace = + wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule()); + auto outputTanSpace = + optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule()); + if (inputTanSpace && outputTanSpace && + inputTanSpace->getCanonicalType() == inputSubstType && + outputTanSpace->getCanonicalType() == outputSubstType) + return SGF.emitTangentVectorToOptionalTangentVector( + Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt); + } + } + + // - Optional.TangentVector to T.TangentVector. + { + CanType optionalTy = isa(inputSubstType) + ? inputSubstType.getNominalParent() + : CanType(); // `Optional` + if (optionalTy && (bool)optionalTy.getOptionalObjectType()) { + CanType wrappedType = optionalTy.getOptionalObjectType(); // `T` + // Check that T.TangentVector is indeed outputSubstType (this also handles + // case when T == T.TangentVector) + // Also check that inputSubstType is an Optional.TangentVector + auto inputTanSpace = + optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule()); + auto outputTanSpace = + wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule()); + if (inputTanSpace && outputTanSpace && + inputTanSpace->getCanonicalType() == inputSubstType && + outputTanSpace->getCanonicalType() == outputSubstType) + return SGF.emitOptionalTangentVectorToTangentVector( + Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt); + } + } + // Should have handled the conversion in one of the cases above. v.dump(); llvm_unreachable("Unhandled transform?"); diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 1b765c3e7f319..0c4aa3622abdf 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -1852,27 +1852,9 @@ class PullbackCloner::Implementation final auto adjOpt = getAdjointValue(bb, ei); auto adjStruct = materializeAdjointDirect(adjOpt, loc); - StructDecl *adjStructDecl = - adjStruct->getType().getStructOrBoundGenericStruct(); - - VarDecl *adjOptVar = nullptr; - if (adjStructDecl) { - ArrayRef properties = adjStructDecl->getStoredProperties(); - adjOptVar = properties.size() == 1 ? properties[0] : nullptr; - } - - EnumDecl *adjOptDecl = - adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum() - : nullptr; - - // Optional.TangentVector should be a struct with a single - // Optional property. This is an implementation detail of - // OptionalDifferentiation.swift - // TODO: Maybe it would be better to have getters / setters here that we - // can call and hide this implementation detail? - if (!adjOptDecl || adjOptDecl != optionalEnumDecl) - llvm_unreachable("Unexpected type of Optional.TangentVector"); + VarDecl *adjOptVar = + getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType()); auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar); EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl(); @@ -1931,24 +1913,8 @@ class PullbackCloner::Implementation final } SILValue adjDest = getAdjointBuffer(bb, origEnum); - StructDecl *adjStructDecl = - adjDest->getType().getStructOrBoundGenericStruct(); - - VarDecl *adjOptVar = nullptr; - if (adjStructDecl) { - ArrayRef properties = adjStructDecl->getStoredProperties(); - adjOptVar = properties.size() == 1 ? properties[0] : nullptr; - } - - EnumDecl *adjOptDecl = - adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum() - : nullptr; - - // Optional.TangentVector should be a struct with a single - // Optional property. This is an implementation detail of - // OptionalDifferentiation.swift - if (!adjOptDecl || adjOptDecl != optionalEnumDecl) - llvm_unreachable("Unexpected type of Optional.TangentVector"); + VarDecl *adjOptVar = + getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType()); SILLocation loc = origData->getLoc(); StructElementAddrInst *adjOpt = @@ -2678,24 +2644,9 @@ AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint( auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType); // `Optional.TangentVector` auto optionalTanTy = getRemappedTangentType(optionalTy); - auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal(); // Look up the `Optional.TangentVector.init` declaration. - auto initLookup = - optionalTanDecl->lookupDirect(DeclBaseName::createConstructor()); - ConstructorDecl *constructorDecl = nullptr; - for (auto *candidate : initLookup) { - auto candidateModule = candidate->getModuleContext(); - if (candidateModule->getName() == - builder.getASTContext().Id_Differentiation || - candidateModule->isStdlibModule()) { - assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s"); - constructorDecl = cast(candidate); -#ifdef NDEBUG - break; -#endif - } - } - assert(constructorDecl && "No `Optional.TangentVector.init`"); + ConstructorDecl *constructorDecl = + getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType()); // Allocate a local buffer for the `Optional` adjoint value. auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy); diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 2867c515a7cea..23877eaeaedc1 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -52,6 +52,7 @@ #include "clang/Sema/TemplateDeduction.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/SaveAndRestore.h" @@ -7538,8 +7539,20 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, fromEI.intoBuilder() .withDifferentiabilityKind(toEI.getDifferentiabilityKind()) .build(); - fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult(), - newEI); + SmallVector params(fromFunc->getParams()); + assert(params.size() == toFunc->getParams().size() && + "unexpected @differentiable conversion"); + // Propagate @noDerivate from target function type + for (auto paramAndIndex : llvm::enumerate(toFunc->getParams())) { + if (!paramAndIndex.value().isNoDerivative()) + continue; + + auto ¶m = params[paramAndIndex.index()]; + param = + param.withFlags(param.getParameterFlags().withNoDerivative(true)); + } + + fromFunc = FunctionType::get(params, fromFunc->getResult(), newEI); switch (toEI.getDifferentiabilityKind()) { // TODO: Ban `Normal` and `Forward` cases. case DifferentiabilityKind::Normal: diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-77871-implicit-diff-optional-conversion.swift b/test/AutoDiff/compiler_crashers_fixed/issue-77871-implicit-diff-optional-conversion.swift new file mode 100644 index 0000000000000..474aed45cef19 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-77871-implicit-diff-optional-conversion.swift @@ -0,0 +1,18 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +// https://github.com/swiftlang/swift/issues/77871 +// Ensure we are correctl generating reabstraction thunks for Double <-> Optional +// conversion for derivatives: for differential and pullback we need +// to emit thunks to convert T.TangentVector <-> Optional.TangentVector. + +import _Differentiation + +@differentiable(reverse) +func testFunc(_ x: Double?) -> Double? { + x! * x! * x! +} +print(pullback(at: 1.0, of: testFunc)(.init(1.0)) == 3.0) + +func foo(_ fn: @escaping @differentiable(reverse) (T?) -> Double) { + let _: @differentiable(reverse) (T) -> Double = fn +}