diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index 9832d239c83b0..1bf872e638ffb 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -3025,6 +3025,34 @@ AllMembersRequest::evaluate( return evaluateMembersRequest(idc, MembersRequestKind::All); } +static bool isTypeInferredByTypealias(TypeAliasDecl *typealias, + NominalTypeDecl *nominal) { + if (!nominal->isGeneric()){ + return false; + } + + auto nominalGenericArguments = nominal->getDeclaredInterfaceType() + ->getAs() + ->getGenericArgs(); + auto typealiasGenericArguments = typealias->getUnderlyingType() + ->getAs() + ->getGenericArgs(); + + for (size_t i = 0; i < nominalGenericArguments.size(); i++) { + auto nominalBoundGenericType = nominalGenericArguments[i]; + auto typealiasBoundGenericType = typealiasGenericArguments[i]; + if (nominalBoundGenericType->isEqual(typealiasBoundGenericType)) { + continue; + } + + if (typealiasBoundGenericType->hasTypeParameter()) { + return false; + } + } + + return true; +} + bool TypeChecker::isPassThroughTypealias(TypeAliasDecl *typealias, NominalTypeDecl *nominal) { // Pass-through only makes sense when the typealias refers to a nominal @@ -3045,17 +3073,39 @@ bool TypeChecker::isPassThroughTypealias(TypeAliasDecl *typealias, // If neither is generic, we're done: it's a pass-through alias. if (!nominalSig) return true; - // Check that the type parameters are the same the whole way through. auto nominalGenericParams = nominalSig.getGenericParams(); auto typealiasGenericParams = typealiasSig.getGenericParams(); - if (nominalGenericParams.size() != typealiasGenericParams.size()) - return false; - if (!std::equal(nominalGenericParams.begin(), nominalGenericParams.end(), - typealiasGenericParams.begin(), - [](GenericTypeParamType *gp1, GenericTypeParamType *gp2) { - return gp1->isEqual(gp2); - })) - return false; + + if (nominalGenericParams.size() != typealiasGenericParams.size()) { + + unsigned nominalMaxDepth = nominalGenericParams.back()->getDepth(); + unsigned typealiasMaxDepth = typealiasGenericParams.back()->getDepth(); + unsigned maxDepth = std::max(nominalMaxDepth, typealiasMaxDepth); + + while (!nominalGenericParams.empty() && + nominalGenericParams.back()->getDepth() == maxDepth) { + nominalGenericParams = nominalGenericParams.drop_back(); + } + + while (!typealiasGenericParams.empty() && + typealiasGenericParams.back()->getDepth() == maxDepth) { + typealiasGenericParams = typealiasGenericParams.drop_back(); + } + + if (nominalGenericParams.size() != typealiasGenericParams.size()) { + return false; + } + + if (!std::equal(nominalGenericParams.begin(), nominalGenericParams.end(), + typealiasGenericParams.begin(), + [](GenericTypeParamType *gp1, GenericTypeParamType *gp2) { + return gp1->isEqual(gp2); + })) { + return false; + } + + return isTypeInferredByTypealias(typealias, nominal); + } // If neither is generic at this level, we have a pass-through typealias. if (!typealias->isGeneric()) return true; diff --git a/test/decl/ext/specialize.swift b/test/decl/ext/specialize.swift index 84c1faadd7b00..ded791e96c6f8 100644 --- a/test/decl/ext/specialize.swift +++ b/test/decl/ext/specialize.swift @@ -32,6 +32,106 @@ extension IntFoo where U == Int { Foo(x: "test", y: 1).hello() + +struct Field { + let tag: Tag + let value: Value +} + +typealias IntField = Field + +extension IntField { + func adding(_ value: Int) -> Self { + Field(tag: tag, value: self.value + value) + } +} + +struct S10 {} +typealias InferredSpecializedNestedTypes = S10 +extension InferredSpecializedNestedTypes { + func returnTuple(value: Y) -> (Int, [Int]?) { + return value + } +} + + +struct S2 { + let x: X + let y: Y + let z: Z +} + +typealias A2 = S2 + +extension A2 { + func test() { + let int: Int + let _: X = int // expected-error {{cannot convert value of type 'Int' to specified type 'X'}} + } +} + + +struct S4 { + // Generic parameters: + // Depth: 0 0 1 + struct Nested { + let c: C + } +} + +struct S5 { + // Generic parameters: + // Depth: 0 1 1 + typealias Alias = S4.Nested where A == Int +} + +extension S5.Alias{ + func test() { + let int: Int + let _: A = int // expected-error {{cannot convert value of type 'Int' to specified type 'A'}} + } +} + + +struct S11 { + struct Inner {} +} + +struct S12 { + struct Inner {} + typealias A1 = S11.Inner + typealias A2 = S12.Inner where T == Int +} + +extension S12.A1 { + func foo1() { + let int: Int + let _: T = int + } +} + +extension S12.A2 { + func foo2() { + let int: Int + let _: T = int + } +} + + +struct S13 { + struct Inner {} +} +struct S14 { + typealias A = S13.Inner +} +extension S14.A { + func test() { + let int: Int + let _: U = int // error + } +} + + struct MyType { var a : TyA, b : TyB }