Skip to content

Emit reabstraction thunks for implicit conversions between T.TangentType and Optional<T>.TangentType #78076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,16 @@ class ASTContext final {
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results);

/// Given `Optional<T>.TangentVector` type, retrieve the
/// `Optional<T>.TangentVector.init` declaration.
ConstructorDecl *getOptionalTanInitDecl(CanType optionalTanType);

/// Optional<T>.TangentVector is a struct with a single
/// Optional<T.TangentVector> `value` property. This is an implementation
/// detail of OptionalDifferentiation.swift. Retrieve `VarDecl` corresponding
/// to this property.
VarDecl *getOptionalTanValueDecl(CanType optionalTanType);

/// Retrieve the next macro expansion discriminator within the given
/// name and context.
unsigned getNextMacroDiscriminator(MacroDiscriminatorContext context,
Expand Down
52 changes: 52 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ struct ASTContext::Implementation {
/// The declaration of Swift.Optional<T>.None.
EnumElementDecl *OptionalNoneDecl = nullptr;

/// The declaration of Optional<T>.TangentVector.init
ConstructorDecl *OptionalTanInitDecl = nullptr;

/// The declaration of Optional<T>.TangentVector.value
VarDecl *OptionalTanValueDecl = nullptr;

/// The declaration of Swift.Void.
TypeAliasDecl *VoidDecl = nullptr;

Expand Down Expand Up @@ -2245,6 +2251,52 @@ void ASTContext::loadObjCMethods(
}
}

ConstructorDecl *ASTContext::getOptionalTanInitDecl(CanType optionalTanType) {
if (!getImpl().OptionalTanInitDecl) {
auto *optionalTanDecl = optionalTanType.getNominalOrBoundGenericNominal();
// Look up the `Optional<T>.TangentVector.init` declaration.
auto initLookup =
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
ConstructorDecl *constructorDecl = nullptr;
for (auto *candidate : initLookup) {
auto candidateModule = candidate->getModuleContext();
if (candidateModule->getName() == Id_Differentiation ||
candidateModule->isStdlibModule()) {
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
constructorDecl = cast<ConstructorDecl>(candidate);
#ifdef NDEBUG
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future PRs: we're trying to get away from conditional asserts. Instead, use ASSERT and avoid #ifdef NDEBUG

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, wasn't aware of this, thanks!

break;
#endif
}
}
assert(constructorDecl && "No `Optional.TangentVector.init`");

getImpl().OptionalTanInitDecl = constructorDecl;
}

return getImpl().OptionalTanInitDecl;
}

VarDecl *ASTContext::getOptionalTanValueDecl(CanType optionalTanType) {
if (!getImpl().OptionalTanValueDecl) {
// TODO: Maybe it would be better to have getters / setters here that we
// can call and hide this implementation detail?
StructDecl *optStructDecl = optionalTanType.getStructOrBoundGenericStruct();
assert(optStructDecl && "Unexpected type of Optional.TangentVector");

ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
assert(properties.size() == 1 && "Unexpected type of Optional.TangentVector");
VarDecl *wrappedValueVar = properties[0];

assert(wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() ==
getOptionalDecl() && "Unexpected type of Optional.TangentVector");

getImpl().OptionalTanValueDecl = wrappedValueVar;
}

return getImpl().OptionalTanValueDecl;
}

void ASTContext::loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) {
Expand Down
21 changes: 20 additions & 1 deletion lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/ASTVisitor.h"
#include "swift/AST/Attr.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ClangModuleLoader.h"
#include "swift/AST/ForeignAsyncConvention.h"
#include "swift/AST/ForeignErrorConvention.h"
Expand Down Expand Up @@ -6083,7 +6084,7 @@ namespace {
}

void printAnyFunctionTypeCommonRec(AnyFunctionType *T, Label label,
StringRef name) {
StringRef name) {
printCommon(name, label);

if (T->hasExtInfo()) {
Expand All @@ -6098,6 +6099,24 @@ namespace {
printFlag(T->isAsync(), "async");
printFlag(T->isThrowing(), "throws");
printFlag(T->hasSendingResult(), "sending_result");
if (T->isDifferentiable()) {
switch (T->getDifferentiabilityKind()) {
default:
llvm_unreachable("unexpected differentiability kind");
case DifferentiabilityKind::Reverse:
printFlag("@differentiable(reverse)");
break;
case DifferentiabilityKind::Forward:
printFlag("@differentiable(_forward)");
break;
case DifferentiabilityKind::Linear:
printFlag("@differentiable(_linear)");
break;
case DifferentiabilityKind::Normal:
printFlag("@differentiable");
break;
}
}
}
if (Type globalActor = T->getGlobalActor()) {
printFieldQuoted(globalActor.getString(), Label::always("global_actor"));
Expand Down
19 changes: 19 additions & 0 deletions lib/SILGen/SILGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2521,6 +2521,25 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
CanSILFunctionType toType,
bool reorderSelf);

/// Emit conversion from T.TangentVector to Optional<T>.TangentVector.
ManagedValue
emitTangentVectorToOptionalTangentVector(SILLocation loc,
ManagedValue input,
CanType wrappedType, // `T`
CanType inputType, // `T.TangentVector`
CanType outputType, // `Optional<T>.TangentVector`
SGFContext ctxt);

/// Emit conversion from Optional<T>.TangentVector to T.TangentVector.
ManagedValue
emitOptionalTangentVectorToTangentVector(SILLocation loc,
ManagedValue input,
CanType wrappedType, // `T`
CanType inputType, // `Optional<T>.TangentVector`
CanType outputType, // `T.TangentVector`
SGFContext ctxt);


//===--------------------------------------------------------------------===//
// Back Deployment thunks
//===--------------------------------------------------------------------===//
Expand Down
110 changes: 110 additions & 0 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "silgen-poly"
#include "ArgumentSource.h"
#include "ExecutorBreadcrumb.h"
#include "FunctionInputGenerator.h"
#include "Initialization.h"
Expand Down Expand Up @@ -294,6 +295,67 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
});
}

// Convert T.TangentVector to Optional<T>.TangentVector.
// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector>
// So we just need to call appropriate .init on it.
ManagedValue SILGenFunction::emitTangentVectorToOptionalTangentVector(
SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType,
CanType outputType, SGFContext ctxt) {
// Look up the `Optional<T>.TangentVector.init` declaration.
auto *constructorDecl = getASTContext().getOptionalTanInitDecl(outputType);

// `Optional<T.TangentVector>`
CanType optionalOfWrappedTanType = inputType.wrapInOptionalType();

const TypeLowering &optTL = getTypeLowering(optionalOfWrappedTanType);
auto optVal = emitInjectOptional(
loc, optTL, SGFContext(), [&](SGFContext objectCtxt) { return input; });

auto *diffProto = getASTContext().getProtocol(KnownProtocolKind::Differentiable);
auto diffConf = lookupConformance(wrappedType, diffProto);
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
ConcreteDeclRef initDecl(
constructorDecl,
SubstitutionMap::get(constructorDecl->getGenericSignature(),
{wrappedType}, {diffConf}));
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
args.add(loc, RValue(*this, {optVal}, optionalOfWrappedTanType));

auto result = emitApplyAllocatingInitializer(loc, initDecl, std::move(args),
Type(), ctxt);
return std::move(result).getScalarValue();
}

ManagedValue SILGenFunction::emitOptionalTangentVectorToTangentVector(
SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType,
CanType outputType, SGFContext ctxt) {
// Optional<T>.TangentVector should be a struct with a single
// Optional<T.TangentVector> `value` property. This is an implementation
// detail of OptionalDifferentiation.swift
// TODO: Maybe it would be better to have explicit getters / setters here that we can
// call and hide this implementation detail?
VarDecl *wrappedValueVar = getASTContext().getOptionalTanValueDecl(inputType);
// `Optional<T.TangentVector>`
CanType optionalOfWrappedTanType = outputType.wrapInOptionalType();

FormalEvaluationScope scope(*this);

auto sig = wrappedValueVar->getDeclContext()->getGenericSignatureOfContext();
auto *diffProto =
getASTContext().getProtocol(KnownProtocolKind::Differentiable);
auto diffConf = lookupConformance(wrappedType, diffProto);
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");

auto wrappedVal = emitRValueForStorageLoad(
loc, input, inputType, /*super*/ false, wrappedValueVar,
PreparedArguments(), SubstitutionMap::get(sig, {wrappedType}, {diffConf}),
AccessSemantics::Ordinary, optionalOfWrappedTanType, SGFContext());

return emitCheckedGetOptionalValueFrom(
loc, std::move(wrappedVal).getScalarValue(),
/*isImplicitUnwrap*/ true, getTypeLowering(optionalOfWrappedTanType), ctxt);
}

/// Apply this transformation to an arbitrary value.
RValue Transform::transform(RValue &&input,
AbstractionPattern inputOrigType,
Expand Down Expand Up @@ -675,6 +737,54 @@ ManagedValue Transform::transform(ManagedValue v,
return std::move(result).getAsSingleValue(SGF, Loc);
}

// - T.TangentVector to Optional<T>.TangentVector
// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector>
// So we just need to call appropriate .init on it.
// However, we might have T.TangentVector == T, so we need to calculate all
// required types first.
{
CanType optionalTy = isa<NominalType>(outputSubstType)
? outputSubstType.getNominalParent()
: CanType(); // `Optional<T>`
if (optionalTy && (bool)optionalTy.getOptionalObjectType()) {
CanType wrappedType = optionalTy.getOptionalObjectType(); // `T`
// Check that T.TangentVector is indeed inputSubstType (this also handles
// case when T == T.TangentVector).
// Also check that outputSubstType is an Optional<T>.TangentVector.
auto inputTanSpace =
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
auto outputTanSpace =
optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule());
if (inputTanSpace && outputTanSpace &&
inputTanSpace->getCanonicalType() == inputSubstType &&
outputTanSpace->getCanonicalType() == outputSubstType)
return SGF.emitTangentVectorToOptionalTangentVector(
Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt);
}
}

// - Optional<T>.TangentVector to T.TangentVector.
{
CanType optionalTy = isa<NominalType>(inputSubstType)
? inputSubstType.getNominalParent()
: CanType(); // `Optional<T>`
if (optionalTy && (bool)optionalTy.getOptionalObjectType()) {
CanType wrappedType = optionalTy.getOptionalObjectType(); // `T`
// Check that T.TangentVector is indeed outputSubstType (this also handles
// case when T == T.TangentVector)
// Also check that inputSubstType is an Optional<T>.TangentVector
auto inputTanSpace =
optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule());
auto outputTanSpace =
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
if (inputTanSpace && outputTanSpace &&
inputTanSpace->getCanonicalType() == inputSubstType &&
outputTanSpace->getCanonicalType() == outputSubstType)
return SGF.emitOptionalTangentVectorToTangentVector(
Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt);
}
}

// Should have handled the conversion in one of the cases above.
v.dump();
llvm_unreachable("Unhandled transform?");
Expand Down
61 changes: 6 additions & 55 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1852,27 +1852,9 @@ class PullbackCloner::Implementation final

auto adjOpt = getAdjointValue(bb, ei);
auto adjStruct = materializeAdjointDirect(adjOpt, loc);
StructDecl *adjStructDecl =
adjStruct->getType().getStructOrBoundGenericStruct();

VarDecl *adjOptVar = nullptr;
if (adjStructDecl) {
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
}

EnumDecl *adjOptDecl =
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
: nullptr;

// Optional<T>.TangentVector should be a struct with a single
// Optional<T.TangentVector> property. This is an implementation detail of
// OptionalDifferentiation.swift
// TODO: Maybe it would be better to have getters / setters here that we
// can call and hide this implementation detail?
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
llvm_unreachable("Unexpected type of Optional.TangentVector");

VarDecl *adjOptVar =
getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType());
auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar);

EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
Expand Down Expand Up @@ -1931,24 +1913,8 @@ class PullbackCloner::Implementation final
}

SILValue adjDest = getAdjointBuffer(bb, origEnum);
StructDecl *adjStructDecl =
adjDest->getType().getStructOrBoundGenericStruct();

VarDecl *adjOptVar = nullptr;
if (adjStructDecl) {
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
}

EnumDecl *adjOptDecl =
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
: nullptr;

// Optional<T>.TangentVector should be a struct with a single
// Optional<T.TangentVector> property. This is an implementation detail of
// OptionalDifferentiation.swift
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
llvm_unreachable("Unexpected type of Optional.TangentVector");
VarDecl *adjOptVar =
getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType());

SILLocation loc = origData->getLoc();
StructElementAddrInst *adjOpt =
Expand Down Expand Up @@ -2678,24 +2644,9 @@ AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
// `Optional<T>.TangentVector`
auto optionalTanTy = getRemappedTangentType(optionalTy);
auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal();
// Look up the `Optional<T>.TangentVector.init` declaration.
auto initLookup =
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
ConstructorDecl *constructorDecl = nullptr;
for (auto *candidate : initLookup) {
auto candidateModule = candidate->getModuleContext();
if (candidateModule->getName() ==
builder.getASTContext().Id_Differentiation ||
candidateModule->isStdlibModule()) {
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
constructorDecl = cast<ConstructorDecl>(candidate);
#ifdef NDEBUG
break;
#endif
}
}
assert(constructorDecl && "No `Optional.TangentVector.init`");
ConstructorDecl *constructorDecl =
getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType());

// Allocate a local buffer for the `Optional` adjoint value.
auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);
Expand Down
17 changes: 15 additions & 2 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "clang/Sema/TemplateDeduction.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/SaveAndRestore.h"
Expand Down Expand Up @@ -7538,8 +7539,20 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
fromEI.intoBuilder()
.withDifferentiabilityKind(toEI.getDifferentiabilityKind())
.build();
fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult(),
newEI);
SmallVector<AnyFunctionType::Param, 4> params(fromFunc->getParams());
assert(params.size() == toFunc->getParams().size() &&
"unexpected @differentiable conversion");
// Propagate @noDerivate from target function type
for (auto paramAndIndex : llvm::enumerate(toFunc->getParams())) {
if (!paramAndIndex.value().isNoDerivative())
continue;

auto &param = params[paramAndIndex.index()];
param =
param.withFlags(param.getParameterFlags().withNoDerivative(true));
}

fromFunc = FunctionType::get(params, fromFunc->getResult(), newEI);
switch (toEI.getDifferentiabilityKind()) {
// TODO: Ban `Normal` and `Forward` cases.
case DifferentiabilityKind::Normal:
Expand Down
Loading