diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 0ffd3a600eb17..30d7e30ddcd23 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3637,13 +3637,35 @@ inline bool isGuaranteedParameter(ParameterConvention conv) { llvm_unreachable("bad convention kind"); } +/// The differentiability of a SIL function type parameter. +enum class SILParameterDifferentiability : unsigned { + /// Either differentiable or not applicable. + /// + /// - If the function type is not `@differentiable`, parameter + /// differentiability is not applicable. This case is the default value. + /// - If the function type is `@differentiable`, the function is + /// differentiable with respect to this parameter. + DifferentiableOrNotApplicable, + + /// Not differentiable: a `@noDerivative` parameter. + /// + /// May be applied only to parameters of `@differentiable` function types. + /// The function type is not differentiable with respect to this parameter. + NotDifferentiable, +}; + /// A parameter type and the rules for passing it. class SILParameterInfo { llvm::PointerIntPair TypeAndConvention; + SILParameterDifferentiability Differentiability : 1; + public: SILParameterInfo() = default;//: Ty(), Convention((ParameterConvention)0) {} - SILParameterInfo(CanType type, ParameterConvention conv) - : TypeAndConvention(type, conv) { + SILParameterInfo( + CanType type, ParameterConvention conv, + SILParameterDifferentiability differentiability = + SILParameterDifferentiability::DifferentiableOrNotApplicable) + : TypeAndConvention(type, conv), Differentiability(differentiability) { assert(type->isLegalSILType() && "SILParameterInfo has illegal SIL type"); } @@ -3698,6 +3720,16 @@ class SILParameterInfo { return isGuaranteedParameter(getConvention()); } + SILParameterDifferentiability getDifferentiability() const { + return Differentiability; + } + + SILParameterInfo getWithDifferentiability( + SILParameterDifferentiability differentiability) const { + return SILParameterInfo(getInterfaceType(), getConvention(), + differentiability); + } + /// The SIL storage type determines the ABI for arguments based purely on the /// formal parameter conventions. The actual SIL type for the argument values /// may differ in canonical SIL. In particular, opaque values require indirect @@ -3726,6 +3758,7 @@ class SILParameterInfo { void profile(llvm::FoldingSetNodeID &id) { id.AddPointer(getInterfaceType().getPointer()); id.AddInteger((unsigned)getConvention()); + id.AddInteger((unsigned)getDifferentiability()); } SWIFT_DEBUG_DUMP; @@ -3739,8 +3772,9 @@ class SILParameterInfo { } bool operator==(SILParameterInfo rhs) const { - return getInterfaceType() == rhs.getInterfaceType() - && getConvention() == rhs.getConvention(); + return getInterfaceType() == rhs.getInterfaceType() && + getConvention() == rhs.getConvention() && + getDifferentiability() == rhs.getDifferentiability(); } bool operator!=(SILParameterInfo rhs) const { return !(*this == rhs); @@ -4093,6 +4127,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, return ExtInfo(NoEscape ? (Bits | NoEscapeMask) : (Bits & ~NoEscapeMask), Other); } + ExtInfo + withDifferentiabilityKind(DifferentiabilityKind differentiability) const { + return ExtInfo( + (Bits & ~DifferentiabilityMask) | + ((unsigned)differentiability << DifferentiabilityMaskOffset), + Other); + } std::pair getFuncAttrKey() const { return std::make_pair(Bits, Other.ClangFunctionType); diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index b335fa071117b..ef6969430da40 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -3327,6 +3327,15 @@ SILFunctionType::SILFunctionType( "Cannot return an @noescape function type"); } } + + // Check that `@noDerivative` parameters only exist on `@differentiable` + // functions. + if (!ext.isDifferentiable()) + for (auto param : getParameters()) + assert(param.getDifferentiability() == + SILParameterDifferentiability::DifferentiableOrNotApplicable && + "non-`@differentiable` function should not have NotDifferentiable " + "parameter"); #endif } diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index 1fe355d76bc47..eeabdf4de8a65 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -4649,6 +4649,13 @@ void SILParameterInfo::print(raw_ostream &OS, const PrintOptions &Opts) const { } void SILParameterInfo::print(ASTPrinter &Printer, const PrintOptions &Opts) const { + switch (getDifferentiability()) { + case SILParameterDifferentiability::NotDifferentiable: + Printer << "@noDerivative "; + break; + default: + break; + } Printer << getStringForParameterConvention(getConvention()); getInterfaceType().print(Printer, Opts); } diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index df9384141bf26..bd2c0b9ff2ea9 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -754,8 +754,8 @@ class DestructureInputs { auto eltPattern = origType.getFunctionParamType(i); auto flags = params[i].getParameterFlags(); - visit(flags.getValueOwnership(), /*forSelf=*/false, - eltPattern, ty, silRepresentation); + visit(flags.getValueOwnership(), /*forSelf=*/false, eltPattern, ty, + silRepresentation, flags.isNoDerivative()); } // Process the self parameter. Note that we implicitly drop self @@ -776,7 +776,8 @@ class DestructureInputs { void visit(ValueOwnership ownership, bool forSelf, AbstractionPattern origType, CanType substType, - SILFunctionTypeRepresentation rep) { + SILFunctionTypeRepresentation rep, + bool isNonDifferentiable = false) { assert(!isa(substType)); // Tuples get handled specially, in some cases: @@ -829,9 +830,12 @@ class DestructureInputs { substTLConv); assert(!isIndirectFormalParameter(convention)); } - - Inputs.push_back(SILParameterInfo( - substTL.getLoweredType().getASTType(), convention)); + + SILParameterInfo param(substTL.getLoweredType().getASTType(), convention); + if (isNonDifferentiable) + param = param.getWithDifferentiability( + SILParameterDifferentiability::NotDifferentiable); + Inputs.push_back(param); maybeAddForeignParameters(); } @@ -1269,7 +1273,8 @@ static CanSILFunctionType getSILFunctionType( auto silExtInfo = SILFunctionType::ExtInfo() .withRepresentation(extInfo.getSILRepresentation()) .withIsPseudogeneric(pseudogeneric) - .withNoEscape(extInfo.isNoEscape()); + .withNoEscape(extInfo.isNoEscape()) + .withDifferentiabilityKind(extInfo.getDifferentiabilityKind()); // Build the substituted generic signature we extracted. bool impliedSignature = false; @@ -2734,7 +2739,7 @@ class SILTypeSubstituter : SILParameterInfo substInterface(SILParameterInfo orig) { return SILParameterInfo(visit(orig.getInterfaceType()), - orig.getConvention()); + orig.getConvention(), orig.getDifferentiability()); } /// Tuples need to have their component types substituted by these diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 3cc68bf07d601..ad3ab2f99e947 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -2948,6 +2948,8 @@ SILParameterInfo TypeResolver::resolveSILParameter( auto convention = DefaultParameterConvention; Type type; bool hadError = false; + auto differentiability = + SILParameterDifferentiability::DifferentiableOrNotApplicable; if (auto attrRepr = dyn_cast(repr)) { auto attrs = attrRepr->getAttrs(); @@ -2973,6 +2975,10 @@ SILParameterInfo TypeResolver::resolveSILParameter( checkFor(TypeAttrKind::TAK_owned, ParameterConvention::Direct_Owned); checkFor(TypeAttrKind::TAK_guaranteed, ParameterConvention::Direct_Guaranteed); + if (attrs.has(TAK_noDerivative)) { + attrs.clearAttribute(TAK_noDerivative); + differentiability = SILParameterDifferentiability::NotDifferentiable; + } type = resolveAttributedType(attrs, attrRepr->getTypeRepr(), options); } else { @@ -2989,7 +2995,8 @@ SILParameterInfo TypeResolver::resolveSILParameter( } if (hadError) type = ErrorType::get(Context); - return SILParameterInfo(type->getCanonicalType(), convention); + return SILParameterInfo(type->getCanonicalType(), convention, + differentiability); } bool TypeResolver::resolveSingleSILResult(TypeRepr *repr, diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index fc0abb60f83b7..64ab317c826ab 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -4503,6 +4503,21 @@ Optional getActualParameterConvention(uint8_t raw) { return None; } +/// Translate from the serialization SILParameterDifferentiability enumerators, +/// which are guaranteed to be stable, to the AST ones. +static Optional +getActualSILParameterDifferentiability(uint8_t raw) { + switch (serialization::SILParameterDifferentiability(raw)) { +#define CASE(ID) \ + case serialization::SILParameterDifferentiability::ID: \ + return swift::SILParameterDifferentiability::ID; + CASE(DifferentiableOrNotApplicable) + CASE(NotDifferentiable) +#undef CASE + } + return None; +} + /// Translate from the serialization ResultConvention enumerators, /// which are guaranteed to be stable, to the AST ones. static @@ -5144,15 +5159,26 @@ class TypeDeserializer { if (!calleeConvention.hasValue()) MF.fatal(); - auto processParameter = [&](TypeID typeID, uint64_t rawConvention) - -> llvm::Expected { + auto processParameter = + [&](TypeID typeID, uint64_t rawConvention, + uint64_t ramDifferentiability) -> llvm::Expected { auto convention = getActualParameterConvention(rawConvention); if (!convention) MF.fatal(); auto type = MF.getTypeChecked(typeID); if (!type) return type.takeError(); - return SILParameterInfo(type.get()->getCanonicalType(), *convention); + auto differentiability = + swift::SILParameterDifferentiability::DifferentiableOrNotApplicable; + if (diffKind != DifferentiabilityKind::NonDifferentiable) { + auto differentiabilityOpt = + getActualSILParameterDifferentiability(ramDifferentiability); + if (!differentiabilityOpt) + MF.fatal(); + differentiability = *differentiabilityOpt; + } + return SILParameterInfo(type.get()->getCanonicalType(), *convention, + differentiability); }; auto processYield = [&](TypeID typeID, uint64_t rawConvention) @@ -5191,7 +5217,10 @@ class TypeDeserializer { for (unsigned i = 0; i != numParams; ++i) { auto typeID = variableData[nextVariableDataIndex++]; auto rawConvention = variableData[nextVariableDataIndex++]; - auto param = processParameter(typeID, rawConvention); + uint64_t differentiability = 0; + if (diffKind != DifferentiabilityKind::NonDifferentiable) + differentiability = variableData[nextVariableDataIndex++]; + auto param = processParameter(typeID, rawConvention, differentiability); if (!param) return param.takeError(); allParams.push_back(param.get()); diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 1306adfb3341b..265ea709a7c89 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 533; // removed @_implicitly_synthesizes_nested_requirement +const uint16_t SWIFTMODULE_VERSION_MINOR = 534; // add SIL parameter differentiability /// A standard hash seed used for all string hashes in a serialized module. /// @@ -347,6 +347,13 @@ enum class ParameterConvention : uint8_t { }; using ParameterConventionField = BCFixed<4>; +// These IDs must \em not be renumbered or reordered without incrementing +// the module version. +enum class SILParameterDifferentiability : uint8_t { + DifferentiableOrNotApplicable, + NotDifferentiable, +}; + // These IDs must \em not be renumbered or reordered without incrementing // the module version. enum class ResultConvention : uint8_t { diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index f49d466d52bfb..ce49f72c4bccc 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -3766,6 +3766,17 @@ static uint8_t getRawStableParameterConvention(swift::ParameterConvention pc) { llvm_unreachable("bad parameter convention kind"); } +/// Translate from AST SILParameterDifferentiability enum to the Serialization +/// enum values, which are guaranteed to be stable. +static uint8_t +getRawSILParameterDifferentiability(swift::SILParameterDifferentiability pd) { + switch (pd) { + SIMPLE_CASE(SILParameterDifferentiability, DifferentiableOrNotApplicable) + SIMPLE_CASE(SILParameterDifferentiability, NotDifferentiable) + } + llvm_unreachable("bad parameter differentiability kind"); +} + /// Translate from the AST ResultConvention enum to the /// Serialization enum values, which are guaranteed to be stable. static uint8_t getRawStableResultConvention(swift::ResultConvention rc) { @@ -4075,6 +4086,9 @@ class Serializer::TypeSerializer : public TypeVisitor { variableData.push_back(S.addTypeRef(param.getInterfaceType())); unsigned conv = getRawStableParameterConvention(param.getConvention()); variableData.push_back(TypeID(conv)); + if (fnTy->isDifferentiable()) + variableData.push_back(TypeID( + getRawSILParameterDifferentiability(param.getDifferentiability()))); } for (auto yield : fnTy->getYields()) { variableData.push_back(S.addTypeRef(yield.getInterfaceType())); diff --git a/test/AutoDiff/SIL/Serialization/differentiation.swift b/test/AutoDiff/SIL/Serialization/differentiation.swift index e72bdd001b680..828b08ccb52b7 100644 --- a/test/AutoDiff/SIL/Serialization/differentiation.swift +++ b/test/AutoDiff/SIL/Serialization/differentiation.swift @@ -26,3 +26,23 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float): // CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float): // CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float // CHECK: } + +sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float { +bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float): + return %0 : $@differentiable (Float, @noDerivative Float) -> Float +} + +// CHECK-LABEL: sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float { +// CHECK: bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float): +// CHECK: return %0 : $@differentiable (Float, @noDerivative Float) -> Float +// CHECK: } + +sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float { +bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float): + return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float +} + +// CHECK-LABEL: sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float { +// CHECK: bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float): +// CHECK: return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float +// CHECK: } diff --git a/test/AutoDiff/SILGen/differentiable_function.swift b/test/AutoDiff/SILGen/differentiable_function.swift new file mode 100644 index 0000000000000..e298eb690a55e --- /dev/null +++ b/test/AutoDiff/SILGen/differentiable_function.swift @@ -0,0 +1,55 @@ +// RUN: %target-swift-frontend -emit-silgen -enable-experimental-differentiable-programming %s | %FileCheck %s + +// Test SILGen for `@differentiable` function typed values. + +import _Differentiation + +@_silgen_name("differentiable") +func differentiable(_ fn: @escaping @differentiable (Float) -> Float) + -> @differentiable (Float) -> Float { + return fn +} + +@_silgen_name("linear") +func linear(_ fn: @escaping @differentiable(linear) (Float) -> Float) + -> @differentiable(linear) (Float) -> Float { + return fn +} + +@_silgen_name("differentiable_noDerivative") +func differentiable_noDerivative( + _ fn: @escaping @differentiable (Float, @noDerivative Float) -> Float +) -> @differentiable (Float, @noDerivative Float) -> Float { + return fn +} + +@_silgen_name("linear_noDerivative") +func linear_noDerivative( + _ fn: @escaping @differentiable(linear) (Float, @noDerivative Float) -> Float +) -> @differentiable(linear) (Float, @noDerivative Float) -> Float { + return fn +} + +// CHECK-LABEL: sil hidden [ossa] @differentiable : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float) -> Float { +// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float): +// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float) -> Float +// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float) -> Float +// CHECK: } + +// CHECK-LABEL: sil hidden [ossa] @linear : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float) -> Float { +// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float): +// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK: } + +// CHECK-LABEL: sil hidden [ossa] @differentiable_noDerivative : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float { +// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float): +// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float +// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float +// CHECK: } + +// CHECK-LABEL: sil hidden [ossa] @linear_noDerivative : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float { +// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float): +// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float +// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float +// CHECK: }