diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index c90f2b6066619..a6583e4311a06 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1709,6 +1709,13 @@ class DifferentiableAttr final /// attribute's where clause requirements. This is set only if the attribute /// has a where clause. GenericSignature DerivativeGenericSignature; + /// The source location of the implicitly inherited protocol requirement + /// `@differentiable` attribute. Used for diagnostics, not serialized. + /// + /// This is set during conformance type-checking, only for implicit + /// `@differentiable` attributes created for non-public protocol witnesses of + /// protocol requirements with `@differentiable` attributes. + SourceLoc ImplicitlyInheritedDifferentiableAttrLocation; explicit DifferentiableAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, @@ -1771,6 +1778,14 @@ class DifferentiableAttr final DerivativeGenericSignature = derivativeGenSig; } + SourceLoc getImplicitlyInheritedDifferentiableAttrLocation() const { + return ImplicitlyInheritedDifferentiableAttrLocation; + } + void getImplicitlyInheritedDifferentiableAttrLocation(SourceLoc loc) { + assert(isImplicit()); + ImplicitlyInheritedDifferentiableAttrLocation = loc; + } + /// Get the derivative generic environment for the given `@differentiable` /// attribute and original function. GenericEnvironment * diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 72ea6de60fba0..dc9bc78598085 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2958,6 +2958,12 @@ ERROR(overriding_decl_missing_differentiable_attr,none, "overriding declaration is missing attribute '%0'", (StringRef)) NOTE(protocol_witness_missing_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) +NOTE(protocol_witness_missing_differentiable_attr_nonpublic_other_file,none, + "non-public %1 %2 must have explicit '%0' attribute to satisfy " + "requirement %3 %4 (in protocol %6) because it is declared in a different " + "file than the conformance of %5 to %6", + (StringRef, DescriptiveDeclKind, DeclName, DescriptiveDeclKind, DeclName, + Type, Type)) // @derivative ERROR(derivative_attr_expected_result_tuple,none, diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 4dd677b4ff10c..6c82cb4f776f8 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness, /// witness. /// - If requirement's `@differentiable` attributes are met, or if `result` is /// not viable, returns `result`. -/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`. +/// - Otherwise, returns a "missing `@differentiable` attribute" +/// `RequirementMatch`. // Note: the `result` argument is only necessary for using // `RequirementMatch::WitnessSubstitutions`. static RequirementMatch @@ -384,15 +385,50 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, } if (!foundExactConfig) { bool success = false; - if (supersetConfig) { - // If the witness has a "superset" derivative configuration, create an - // implicit `@differentiable` attribute with the exact requirement - // `@differentiable` attribute parameter indices. + // If no exact witness derivative configuration was found, check + // conditions for creating an implicit witness `@differentiable` attribute + // with the exact derivative configuration: + // - If the witness has a "superset" derivative configuration. + // - If the witness is less than public and is declared in the same file + // as the conformance. + // - `@differentiable` attributes are really only significant for public + // declarations: it improves usability to not require explicit + // `@differentiable` attributes for less-visible declarations. + bool createImplicitWitnessAttribute = + supersetConfig || witness->getFormalAccess() < AccessLevel::Public; + // If the witness has less-than-public visibility and is declared in a + // different file than the conformance, produce an error. + if (!supersetConfig && witness->getFormalAccess() < AccessLevel::Public && + dc->getModuleScopeContext() != + witness->getDeclContext()->getModuleScopeContext()) { + // FIXME(TF-1014): `@differentiable` attribute diagnostic does not + // appear if associated type inference is involved. + if (auto *vdWitness = dyn_cast(witness)) { + return RequirementMatch( + getStandinForAccessor(vdWitness, AccessorKind::Get), + MatchKind::MissingDifferentiableAttr, reqDiffAttr); + } else { + return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr, + reqDiffAttr); + } + } + if (createImplicitWitnessAttribute) { + auto derivativeGenSig = witnessAFD->getGenericSignature(); + if (supersetConfig) + derivativeGenSig = supersetConfig->derivativeGenericSignature; + // Use source location of the witness declaration as the source location + // of the implicit `@differentiable` attribute. auto *newAttr = DifferentiableAttr::create( - witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc, - reqDiffAttr->getRange(), reqDiffAttr->isLinear(), - reqDiffAttr->getParameterIndices(), - supersetConfig->derivativeGenericSignature); + witnessAFD, /*implicit*/ true, witness->getLoc(), witness->getLoc(), + reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(), + derivativeGenSig); + // If the implicit attribute is inherited from a protocol requirement's + // attribute, store the protocol requirement attribute's location for + // use in diagnostics. + if (witness->getFormalAccess() < AccessLevel::Public) { + newAttr->getImplicitlyInheritedDifferentiableAttrLocation( + reqDiffAttr->getLocation()); + } auto insertion = ctx.DifferentiableAttrs.try_emplace( {witnessAFD, newAttr->getParameterIndices()}, newAttr); // Valid `@differentiable` attributes are uniqued by original function @@ -418,9 +454,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, if (auto *vdWitness = dyn_cast(witness)) { return RequirementMatch( getStandinForAccessor(vdWitness, AccessorKind::Get), - MatchKind::DifferentiableConflict, reqDiffAttr); + MatchKind::MissingDifferentiableAttr, reqDiffAttr); } else { - return RequirementMatch(witness, MatchKind::DifferentiableConflict, + return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr); } } @@ -2318,14 +2354,15 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, case MatchKind::NonObjC: diags.diagnose(match.Witness, diag::protocol_witness_not_objc); break; - case MatchKind::DifferentiableConflict: { + case MatchKind::MissingDifferentiableAttr: { + auto *witness = match.Witness; // Emit a note and fix-it showing the missing requirement `@differentiable` // attribute. auto *reqAttr = cast(match.UnmetAttribute); assert(reqAttr); // Omit printing `wrt:` clause if attribute's differentiability // parameters match inferred differentiability parameters. - auto *original = cast(match.Witness); + auto *original = cast(witness); auto *whereClauseGenEnv = reqAttr->getDerivativeGenericEnvironment(original); auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters( @@ -2336,11 +2373,29 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, llvm::raw_string_ostream os(reqDiffAttrString); reqAttr->print(os, req, omitWrtClause); os.flush(); - diags - .diagnose(match.Witness, - diag::protocol_witness_missing_differentiable_attr, - reqDiffAttrString) - .fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' '); + // If the witness has less-than-public visibility and is declared in a + // different file than the conformance, emit a specialized diagnostic. + if (witness->getFormalAccess() < AccessLevel::Public && + conformance->getDeclContext()->getModuleScopeContext() != + witness->getDeclContext()->getModuleScopeContext()) { + diags + .diagnose( + witness, + diag:: + protocol_witness_missing_differentiable_attr_nonpublic_other_file, + reqDiffAttrString, witness->getDescriptiveKind(), + witness->getFullName(), req->getDescriptiveKind(), + req->getFullName(), conformance->getType(), + conformance->getProtocol()->getDeclaredInterfaceType()) + .fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' '); + } + // Otherwise, emit a general "missing attribute" diagnostic. + else { + diags + .diagnose(witness, diag::protocol_witness_missing_differentiable_attr, + reqDiffAttrString) + .fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' '); + } break; } } diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 9050e88102637..e64b2a3aa940c 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -209,9 +209,8 @@ enum class MatchKind : uint8_t { /// The witness is explicitly @nonobjc but the requirement is @objc. NonObjC, - /// The witness does not have a `@differentiable` attribute satisfying one - /// from the requirement. - DifferentiableConflict, + /// The witness is missing a `@differentiable` attribute from the requirement. + MissingDifferentiableAttr, }; /// Describes the kind of optional adjustment performed when @@ -362,7 +361,7 @@ struct RequirementMatch { : Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr), ReqEnv(None) { assert(!hasWitnessType() && "Should have witness type"); - assert(UnmetAttribute); + assert(hasUnmetAttribute() && "Should have unmet attribute"); } RequirementMatch(ValueDecl *witness, MatchKind kind, @@ -437,7 +436,7 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: - case MatchKind::DifferentiableConflict: + case MatchKind::MissingDifferentiableAttr: return false; } @@ -467,7 +466,7 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: - case MatchKind::DifferentiableConflict: + case MatchKind::MissingDifferentiableAttr: return false; } @@ -478,7 +477,9 @@ struct RequirementMatch { bool hasRequirement() { return Kind == MatchKind::MissingRequirement; } /// Determine whether this requirement match has an unmet attribute. - bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; } + bool hasUnmetAttribute() { + return Kind == MatchKind::MissingDifferentiableAttr; + } swift::Witness getWitness(ASTContext &ctx) const; }; diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 60d0d2727fd8c..91c512f9ed196 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -3,11 +3,11 @@ import _Differentiation // Dummy `Differentiable`-conforming type. -struct DummyTangentVector: Differentiable & AdditiveArithmetic { - static var zero: Self { Self() } - static func + (_: Self, _: Self) -> Self { Self() } - static func - (_: Self, _: Self) -> Self { Self() } - typealias TangentVector = Self +public struct DummyTangentVector: Differentiable & AdditiveArithmetic { + public static var zero: Self { Self() } + public static func + (_: Self, _: Self) -> Self { Self() } + public static func - (_: Self, _: Self) -> Self { Self() } + public typealias TangentVector = Self } @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} @@ -152,7 +152,10 @@ struct DifferentiableInstanceMethod: Differentiable { } // Test subscript methods. -struct SubscriptMethod { +struct SubscriptMethod: Differentiable { + typealias TangentVector = DummyTangentVector + mutating func move(along _: TangentVector) {} + @differentiable // ok subscript(implicitGetter x: Float) -> Float { return x @@ -167,14 +170,16 @@ struct SubscriptMethod { subscript(explicit x: Float) -> Float { @differentiable // ok get { return x } - @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} + // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} + @differentiable set {} } subscript(x: Float, y: Float) -> Float { @differentiable // ok get { return x + y } - @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} + // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} + @differentiable set {} } } @@ -232,58 +237,98 @@ protocol ProtocolRequirements: Differentiable { } protocol ProtocolRequirementsRefined: ProtocolRequirements { - // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{3-3=@differentiable }} + // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} func f1(_ x: Float) -> Float } -// expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}} -struct DiffAttrConformanceErrors: ProtocolRequirements { +// Test missing `@differentiable` attribute for internal protocol witnesses. +// No errors expected; internal `@differentiable` attributes are created. + +struct InternalDiffAttrConformance: ProtocolRequirements { typealias TangentVector = DummyTangentVector mutating func move(along _: TangentVector) {} var x: Float var y: Float - // FIXME(TF-284): Fix unexpected diagnostic. - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} - // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}} init(x: Float, y: Float) { self.x = x self.y = y } - // FIXME(TF-284): Fix unexpected diagnostic. - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} - // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}} init(x: Float, y: Int) { self.x = x self.y = Float(y) } - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} - // expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}} func amb(x: Float, y: Float) -> Float { return x } - // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{3-3=@differentiable(wrt: x) }} - // expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}} func amb(x: Float, y: Int) -> Float { return x } - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func f1(_ x: Float) -> Float { return x } - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} @differentiable(wrt: (self, x)) func f2(_ x: Float, _ y: Float) -> Float { return x + y } } +// Test missing `@differentiable` attribute for public protocol witnesses. Errors expected. + +// expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}} +public struct PublicDiffAttrConformance: ProtocolRequirements { + public typealias TangentVector = DummyTangentVector + public mutating func move(along _: TangentVector) {} + + var x: Float + var y: Float + + // FIXME(TF-284): Fix unexpected diagnostic. + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }} + // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}} + public init(x: Float, y: Float) { + self.x = x + self.y = y + } + + // FIXME(TF-284): Fix unexpected diagnostic. + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }} + // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}} + public init(x: Float, y: Int) { + self.x = x + self.y = Float(y) + } + + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }} + // expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}} + public func amb(x: Float, y: Float) -> Float { + return x + } + + // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{10-10=@differentiable(wrt: x) }} + // expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}} + public func amb(x: Float, y: Int) -> Float { + return x + } + + // expected-note @+1 {{candidate is missing attribute '@differentiable'}} + public func f1(_ x: Float) -> Float { + return x + } + + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} + @differentiable(wrt: (self, x)) + public func f2(_ x: Float, _ y: Float) -> Float { + return x + y + } +} + protocol ProtocolRequirementsWithDefault_NoConformingTypes { @differentiable func f1(_ x: Float) -> Float @@ -295,51 +340,38 @@ extension ProtocolRequirementsWithDefault_NoConformingTypes { } protocol ProtocolRequirementsWithDefault { - // expected-note @+2 {{protocol requires function 'f1'}} @differentiable func f1(_ x: Float) -> Float } extension ProtocolRequirementsWithDefault { - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func f1(_ x: Float) -> Float { x } } -// expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}} struct DiffAttrConformanceErrors2: ProtocolRequirementsWithDefault { - typealias TangentVector = DummyTangentVector - mutating func move(along _: TangentVector) {} - - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func f1(_ x: Float) -> Float { x } } protocol NotRefiningDiffable { @differentiable(wrt: x) - // expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}} func a(_ x: Float) -> Float } -// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}} struct CertainlyNotDiffableWrtSelf: NotRefiningDiffable { - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func a(_ x: Float) -> Float { return x * 5.0 } } - protocol TF285: Differentiable { @differentiable(wrt: (x, y)) @differentiable(wrt: x) - // expected-note @+1 {{protocol requires function 'foo(x:y:)' with type '(Float, Float) -> Float'; do you want to add a stub?}} func foo(x: Float, y: Float) -> Float } -// expected-error @+1 {{type 'TF285MissingOneDiffAttr' does not conform to protocol 'TF285'}} struct TF285MissingOneDiffAttr: TF285 { typealias TangentVector = DummyTangentVector mutating func move(along _: TangentVector) {} - // Requirement is missing an attribute. + // Requirement is missing the required `@differentiable(wrt: (x, y))` attribute. + // Since `TF285MissingOneDiffAttr.foo` is internal, the attribute is implicitly created. @differentiable(wrt: x) - // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}} {{3-3=@differentiable(wrt: (x, y)) }} func foo(x: Float, y: Float) -> Float { return x } @@ -363,9 +395,9 @@ struct TF_521 { extension TF_521: Differentiable where T: Differentiable { // expected-note @+1 {{possibly intended match 'TF_521.TangentVector' does not conform to 'AdditiveArithmetic'}} typealias TangentVector = TF_521 - typealias AllDifferentiableVariables = TF_521 } -let _: @differentiable (Float, Float) -> TF_521 = { r, i in + +let _: @differentiable(Float, Float) -> TF_521 = { r, i in TF_521(real: r, imaginary: i) } @@ -480,6 +512,18 @@ func two9(x: Float, y: Float) -> Float { return x + y } +// Inout 'wrt:' arguments. + +@differentiable(wrt: y) +func inout1(x: Float, y: inout Float) -> Void { + let _ = x + y +} +// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} +@differentiable(wrt: y) +func inout2(x: Float, y: inout Float) -> Float { + let _ = x + y +} + // Test refining protocol requirements with `@differentiable` attribute. public protocol Distribution { @@ -549,7 +593,7 @@ class Super: Differentiable { // expected-note @+1 2 {{overridden declaration is here}} func testMissingAttributes(_ x: Float) -> Float { x } - @differentiable + @differentiable(wrt: x) func testSuperclassDerivatives(_ x: Float) -> Float { x } // Test duplicate attributes with different derivative generic signatures. @@ -563,6 +607,10 @@ class Super: Differentiable { @differentiable func dynamicSelfResult() -> Self { self } + // expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}} + @differentiable + var testDynamicSelfProperty: Self { self } + // TODO(TF-632): Fix "'TangentVector' is not a member type of 'Self'" diagnostic. // The underlying error should appear instead: // "covariant 'Self' can only appear at the top level of method result type". @@ -573,8 +621,8 @@ class Super: Differentiable { } class Sub: Super { - // expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}} {{12-12=@differentiable(wrt: x) }} - // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{12-12=@differentiable }} + // expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}} + // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} override func testMissingAttributes(_ x: Float) -> Float { x } } @@ -626,15 +674,14 @@ extension InoutParameters { mutating func mutatingMethod(_ other: Self) -> Self {} } -// Test unsupported accessors: `set`, `_read`, `_modify`. +// Test accessors: `set`, `_read`, `_modify`. -struct UnsupportedAccessors: Differentiable { +struct Accessors: Differentiable { typealias TangentVector = DummyTangentVector mutating func move(along _: TangentVector) {} var stored: Float var computed: Float { - // `set` has an `inout` parameter: `(inout Self) -> (Float) -> ()`. // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} @differentiable set { stored = newValue }