diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 82a5c3f76f61d..713b7ca8129b8 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -729,34 +729,6 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { return structDecl; } -/// Add a typealias declaration with the given name and underlying target -/// struct type to the given source nominal declaration context. -static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC, - StructDecl *target, - ASTContext &Context) { - auto *nominal = sourceDC->getSelfNominalTypeDecl(); - assert(nominal && "Expected `DeclContext` to be a nominal type"); - auto lookup = nominal->lookupDirect(name); - assert(lookup.size() < 2 && - "Expected at most one associated type named member"); - // If implicit type declaration with the given name already exists in source - // struct, return it. - if (lookup.size() == 1) { - auto existingTypeDecl = dyn_cast(lookup.front()); - assert(existingTypeDecl && existingTypeDecl->isImplicit() && - "Expected lookup result to be an implicit type declaration"); - return; - } - // Otherwise, create a new typealias. - auto *aliasDecl = new (Context) - TypeAliasDecl(SourceLoc(), SourceLoc(), name, SourceLoc(), {}, sourceDC); - aliasDecl->setUnderlyingType(target->getDeclaredInterfaceType()); - aliasDecl->setImplicit(); - aliasDecl->setGenericSignature(sourceDC->getGenericSignatureOfContext()); - cast(sourceDC->getAsDecl())->addMember(aliasDecl); - aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); -}; - /// Diagnose stored properties in the nominal that do not have an explicit /// `@noDerivative` attribute, but either: /// - Do not conform to `Differentiable`. @@ -842,7 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context, } /// Get or synthesize `TangentVector` struct type. -static Type +static std::pair getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) { auto *parentDC = derived.getConformanceContext(); auto *nominal = derived.Nominal; @@ -852,20 +824,20 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) { auto *tangentStruct = getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector); if (!tangentStruct) - return nullptr; + return std::make_pair(nullptr, nullptr); + // Check and emit warnings for implicit `@noDerivative` members. checkAndDiagnoseImplicitNoDerivative(C, nominal, parentDC); - // Add `TangentVector` typealias for `TangentVector` struct. - addAssociatedTypeAliasDecl(C.Id_TangentVector, tangentStruct, tangentStruct, - C); // Return the `TangentVector` struct type. - return parentDC->mapTypeIntoContext( - tangentStruct->getDeclaredInterfaceType()); + return std::make_pair( + parentDC->mapTypeIntoContext( + tangentStruct->getDeclaredInterfaceType()), + tangentStruct); } /// Synthesize the `TangentVector` struct type. -static Type +static std::pair deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) { auto *parentDC = derived.getConformanceContext(); auto *nominal = derived.Nominal; @@ -873,7 +845,7 @@ deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) { // If nominal type can derive `TangentVector` as the contextual `Self` type, // return it. if (canDeriveTangentVectorAsSelf(nominal, parentDC)) - return parentDC->getSelfTypeInContext(); + return std::make_pair(parentDC->getSelfTypeInContext(), nullptr); // Otherwise, get or synthesize `TangentVector` struct type. return getOrSynthesizeTangentVectorStructType(derived); @@ -914,16 +886,17 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { return nullptr; } -Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { +std::pair +DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { // Diagnose unknown requirements. if (requirement->getBaseName() != Context.Id_TangentVector) { Context.Diags.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement); - return nullptr; + return std::make_pair(nullptr, nullptr); } // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) - return nullptr; + return std::make_pair(nullptr, nullptr); // Start an error diagnostic before attempting derivation. // If derivation succeeds, cancel the diagnostic. @@ -939,5 +912,5 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { } // Otherwise, return nullptr. - return nullptr; + return std::make_pair(nullptr, nullptr); } diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h index 1f6296ee84fcb..678c9fc208b3f 100644 --- a/lib/Sema/DerivedConformances.h +++ b/lib/Sema/DerivedConformances.h @@ -122,7 +122,8 @@ class DerivedConformance { /// Derive a Differentiable type witness for a nominal type. /// /// \returns the derived member, which will also be added to the type. - Type deriveDifferentiable(AssociatedTypeDecl *assocType); + std::pair + deriveDifferentiable(AssociatedTypeDecl *assocType); /// Derive a CaseIterable requirement for an enum if it has no associated /// values for any of its cases. diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index c7f2082bfd3a6..f641356ad70d8 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -5579,15 +5579,16 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC, llvm_unreachable("unknown derivable protocol kind"); } -Type TypeChecker::deriveTypeWitness(DeclContext *DC, - NominalTypeDecl *TypeDecl, - AssociatedTypeDecl *AssocType) { +std::pair +TypeChecker::deriveTypeWitness(DeclContext *DC, + NominalTypeDecl *TypeDecl, + AssociatedTypeDecl *AssocType) { auto *protocol = cast(AssocType->getDeclContext()); auto knownKind = protocol->getKnownProtocolKind(); if (!knownKind) - return nullptr; + return std::make_pair(nullptr, nullptr); auto Decl = DC->getInnermostDeclarationDeclContext(); @@ -5595,13 +5596,13 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC, protocol); switch (*knownKind) { case KnownProtocolKind::RawRepresentable: - return derived.deriveRawRepresentable(AssocType); + return std::make_pair(derived.deriveRawRepresentable(AssocType), nullptr); case KnownProtocolKind::CaseIterable: - return derived.deriveCaseIterable(AssocType); + return std::make_pair(derived.deriveCaseIterable(AssocType), nullptr); case KnownProtocolKind::Differentiable: return derived.deriveDifferentiable(AssocType); default: - return nullptr; + return std::make_pair(nullptr, nullptr); } } diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 5149fbf6c5a8f..297fb90ec37f9 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -822,15 +822,13 @@ class AssociatedTypeInference { /// Compute the "derived" type witness for an associated type that is /// known to the compiler. - Type computeDerivedTypeWitness(AssociatedTypeDecl *assocType); + std::pair + computeDerivedTypeWitness(AssociatedTypeDecl *assocType); /// Compute a type witness without using a specific potential witness, /// e.g., using a fixed type (from a refined protocol), default type /// on an associated type, or deriving the type. - /// - /// \param allowDerived Whether to allow "derived" type witnesses. - Type computeAbstractTypeWitness(AssociatedTypeDecl *assocType, - bool allowDerived); + Type computeAbstractTypeWitness(AssociatedTypeDecl *assocType); /// Substitute the current type witnesses into the given interface type. Type substCurrentTypeWitnesses(Type type); diff --git a/lib/Sema/TypeCheckProtocolInference.cpp b/lib/Sema/TypeCheckProtocolInference.cpp index 9e4a214187f22..81b4aad879314 100644 --- a/lib/Sema/TypeCheckProtocolInference.cpp +++ b/lib/Sema/TypeCheckProtocolInference.cpp @@ -868,38 +868,37 @@ Type AssociatedTypeInference::computeDefaultTypeWitness( return defaultType; } -Type AssociatedTypeInference::computeDerivedTypeWitness( +std::pair +AssociatedTypeInference::computeDerivedTypeWitness( AssociatedTypeDecl *assocType) { if (adoptee->hasError()) - return Type(); + return std::make_pair(Type(), nullptr); // Can we derive conformances for this protocol and adoptee? NominalTypeDecl *derivingTypeDecl = adoptee->getAnyNominal(); if (!DerivedConformance::derivesProtocolConformance(dc, derivingTypeDecl, proto)) - return Type(); + return std::make_pair(Type(), nullptr); // Try to derive the type witness. - Type derivedType = - TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType); - if (!derivedType) - return Type(); + auto result = TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType); + if (!result.first) + return std::make_pair(Type(), nullptr); - // Make sure that the derived type is sane. - if (checkTypeWitness(derivedType, assocType, conformance)) { + // Make sure that the derived type satisfies requirements. + if (checkTypeWitness(result.first, assocType, conformance)) { /// FIXME: Diagnose based on this. failedDerivedAssocType = assocType; - failedDerivedWitness = derivedType; - return Type(); + failedDerivedWitness = result.first; + return std::make_pair(Type(), nullptr); } - return derivedType; + return result; } Type AssociatedTypeInference::computeAbstractTypeWitness( - AssociatedTypeDecl *assocType, - bool allowDerived) { + AssociatedTypeDecl *assocType) { // We don't have a type witness for this associated type, so go // looking for more options. if (Type concreteType = computeFixedTypeWitness(assocType)) @@ -909,12 +908,6 @@ AssociatedTypeInference::computeAbstractTypeWitness( if (Type defaultType = computeDefaultTypeWitness(assocType)) return defaultType; - // If we can derive a type witness, do so. - if (allowDerived) { - if (Type derivedType = computeDerivedTypeWitness(assocType)) - return derivedType; - } - // If there is a generic parameter of the named type, use that. if (auto genericSig = dc->getGenericSignatureOfContext()) { for (auto gp : genericSig->getInnermostGenericParams()) { @@ -1197,8 +1190,7 @@ void AssociatedTypeInference::findSolutionsRec( // Try to compute the type without the aid of a specific potential // witness. - if (Type type = computeAbstractTypeWitness(assocType, - /*allowDerived=*/true)) { + if (Type type = computeAbstractTypeWitness(assocType)) { if (type->hasError()) { recordMissing(); return; @@ -1880,10 +1872,23 @@ auto AssociatedTypeInference::solve(ConformanceChecker &checker) continue; case ResolveWitnessResult::Missing: - // Note that we haven't resolved this associated type yet. - unresolvedAssocTypes.insert(assocType); + // We did not find the witness via name lookup. Try to derive + // it below. break; } + + // Finally, try to derive the witness if we know how. + auto derivedType = computeDerivedTypeWitness(assocType); + if (derivedType.first) { + checker.recordTypeWitness(assocType, + derivedType.first->mapTypeOutOfContext(), + derivedType.second); + continue; + } + + // We failed to derive the witness. We're going to go on to try + // to infer it from potential value witnesses next. + unresolvedAssocTypes.insert(assocType); } // Result variable to use for returns so that we get NRVO. diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 734ffaba59e94..f4c1c6ca61cf4 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -911,8 +911,9 @@ ValueDecl *deriveProtocolRequirement(DeclContext *DC, /// Derive an implicit type witness for the given associated type in /// the conformance of the given nominal type to some known /// protocol. -Type deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal, - AssociatedTypeDecl *assocType); +std::pair +deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal, + AssociatedTypeDecl *assocType); /// \name Name lookup /// diff --git a/test/Sema/enum_raw_representable.swift b/test/Sema/enum_raw_representable.swift index 59323fbb591d3..b856e3dbf387f 100644 --- a/test/Sema/enum_raw_representable.swift +++ b/test/Sema/enum_raw_representable.swift @@ -46,7 +46,8 @@ var doubles: [Double] = serialize([Bar.a, .b, .c]) var foos: [Foo] = deserialize([1, 2, 3]) var bars: [Bar] = deserialize([1.2, 3.4, 5.6]) -// Infer RawValue from witnesses. +// We reject enums where the raw type stated in the inheritance clause does not +// match the types of the witnesses. enum Color : Int { case red case blue @@ -56,11 +57,13 @@ enum Color : Int { } var rawValue: Double { + // expected-error@-1 {{invalid redeclaration of synthesized implementation for protocol requirement 'rawValue'}} return 1.0 } } var colorRaw: Color.RawValue = 7.5 +// expected-error@-1 {{cannot convert value of type 'Double' to specified type 'Color.RawValue' (aka 'Int')}} // Mismatched case types diff --git a/test/Sema/enum_raw_representable_circularity.swift b/test/Sema/enum_raw_representable_circularity.swift new file mode 100644 index 0000000000000..9fb76a0714d44 --- /dev/null +++ b/test/Sema/enum_raw_representable_circularity.swift @@ -0,0 +1,13 @@ +// RUN: %target-typecheck-verify-swift + +// This used to fail with "reference to invalid associated type 'RawValue' of type 'E'" +_ = E(rawValue: 123) + +enum E : Int { + case a = 123 + + init?(rawValue: RawValue) { + self = .a + } +} +