diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 21e74782f3db7..ff599f150d0a2 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -26,6 +26,7 @@ #include "swift/AST/ASTWalker.h" #include "swift/AST/ExistentialLayout.h" #include "swift/AST/Initializer.h" +#include "swift/AST/GenericEnvironment.h" #include "swift/AST/GenericSignature.h" #include "swift/AST/ParameterList.h" #include "swift/AST/ProtocolConformance.h" @@ -378,6 +379,66 @@ namespace { return base.getOldType(); } + // Returns None if the AST does not contain enough information to recover + // substitutions; this is different from an Optional(SubstitutionMap()), + // indicating a valid call to a non-generic operator. + Optional + getOperatorSubstitutions(ValueDecl *witness, Type refType) { + // We have to recover substitutions in this hacky way because + // the AST does not retain enough information to devirtualize + // calls like this. + auto witnessType = witness->getInterfaceType(); + + // Compute the substitutions. + auto *gft = witnessType->getAs(); + if (gft == nullptr) { + if (refType->isEqual(witnessType)) + return SubstitutionMap(); + return None; + } + + auto sig = gft->getGenericSignature(); + auto *env = sig->getGenericEnvironment(); + + witnessType = FunctionType::get(gft->getParams(), + gft->getResult(), + gft->getExtInfo()); + witnessType = env->mapTypeIntoContext(witnessType); + + TypeSubstitutionMap subs; + auto substType = witnessType->substituteBindingsTo( + refType, + [&](ArchetypeType *origType, CanType substType) -> CanType { + if (auto gpType = dyn_cast( + origType->getInterfaceType()->getCanonicalType())) + subs[gpType] = substType; + + return substType; + }); + + // If substitution failed, it means that the protocol requirement type + // and the witness type did not match up. The only time that this + // should happen is when the witness is defined in a base class and + // the actual call uses a derived class. For example, + // + // protocol P { func +(lhs: Self, rhs: Self) } + // class Base : P { func +(lhs: Base, rhs: Base) {} } + // class Derived : Base {} + // + // If we enter this code path with two operands of type Derived, + // we know we're calling the protocol requirement P.+, with a + // substituted type of (Derived, Derived) -> (). But the type of + // the witness is (Base, Base) -> (). Just bail out and make a + // witness method call in this rare case; SIL mandatory optimizations + // will likely devirtualize it anyway. + if (!substType) + return None; + + return SubstitutionMap::get(sig, + QueryTypeSubstitutionMap{subs}, + TypeChecker::LookUpConformance(cs.DC)); + } + public: /// Build a reference to the given declaration. Expr *buildDeclRef(SelectedOverload overload, DeclNameLoc loc, @@ -400,56 +461,53 @@ namespace { // Handle operator requirements found in protocols. if (auto proto = dyn_cast(decl->getDeclContext())) { - // If we don't have an archetype or existential, we have to call the - // witness. + // If we have a concrete conformance, build a call to the witness. + // // FIXME: This is awful. We should be able to handle this as a call to // the protocol requirement with Self == the concrete type, and SILGen // (or later) can devirtualize as appropriate. - if (!baseTy->is() && !baseTy->isAnyExistentialType()) { - auto conformance = - TypeChecker::conformsToProtocol( - baseTy, proto, cs.DC, - ConformanceCheckFlags::InExpression); - if (conformance.isConcrete()) { - if (auto witness = - conformance.getConcrete()->getWitnessDecl(decl)) { - // Hack up an AST that we can type-check (independently) to get - // it into the right form. - // FIXME: the hop through 'getDecl()' is because - // SpecializedProtocolConformance doesn't substitute into - // witnesses' ConcreteDeclRefs. - Type expectedFnType = simplifyType(overload.openedType); - assert(expectedFnType->isEqual( - fullType->castTo()->getResult()) && - "Cannot handle adjustments made to the opened type"); + auto conformance = + TypeChecker::conformsToProtocol( + baseTy, proto, cs.DC, + ConformanceCheckFlags::InExpression); + if (conformance.isConcrete()) { + if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) { + // The fullType was computed by substituting the protocol + // requirement so it always has a (Self) -> ... curried + // application. Strip it off if the witness was a top-level + // function. + Type refType; + if (witness->getDeclContext()->isTypeContext()) + refType = fullType; + else + refType = fullType->castTo()->getResult(); + + // Build the AST for the call to the witness. + auto subMap = getOperatorSubstitutions(witness, refType); + if (subMap) { + ConcreteDeclRef witnessRef(witness, *subMap); + auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc, + /*Implicit=*/false); + declRefExpr->setFunctionRefKind(choice.getFunctionRefKind()); + cs.setType(declRefExpr, refType); + Expr *refExpr; if (witness->getDeclContext()->isTypeContext()) { + // If the operator is a type member, add the implicit + // (Self) -> ... call. Expr *base = TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy, ctx); - refExpr = new (ctx) MemberRefExpr(base, SourceLoc(), witness, - loc, /*Implicit=*/true); + cs.setType(base, MetatypeType::get(baseTy)); + + refExpr = new (ctx) DotSyntaxCallExpr(declRefExpr, + SourceLoc(), base); + auto refType = fullType->castTo()->getResult(); + cs.setType(refExpr, refType); } else { - auto declRefExpr = new (ctx) DeclRefExpr(witness, loc, - /*Implicit=*/false); - declRefExpr->setFunctionRefKind(choice.getFunctionRefKind()); refExpr = declRefExpr; } - auto resultTy = TypeChecker::typeCheckExpression( - refExpr, cs.DC, TypeLoc::withoutLoc(expectedFnType), - CTP_CannotFail); - if (!resultTy) - return nullptr; - - cs.cacheExprTypes(refExpr); - - // Remove an outer function-conversion expression. This - // happens when we end up referring to a witness for a - // superclass conformance, and 'Self' differs. - if (auto fnConv = dyn_cast(refExpr)) - refExpr = fnConv->getSubExpr(); - return forceUnwrapIfExpected(refExpr, choice, locator); } } diff --git a/test/SILGen/protocol_operators.swift b/test/SILGen/protocol_operators.swift new file mode 100644 index 0000000000000..fd11a62cecf8f --- /dev/null +++ b/test/SILGen/protocol_operators.swift @@ -0,0 +1,60 @@ +// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s + +infix operator +++ + +protocol Twig { + static func +++(lhs: Self, rhs: Self) +} + +struct Branch : Twig { + @_implements(Twig, +++(_:_:)) + static func doIt(_: Branch, _: Branch) {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () { +// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> () +// CHECK: return +func useBranch(_ b: Branch) { + b +++ b +} + +class Stick : Twig { + static func +++(lhs: Stick, rhs: Stick) {} +} + +class Stuck : Stick, ExpressibleByIntegerLiteral { + typealias IntegerLiteralType = Int + + required init(integerLiteral: Int) {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () { +// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> () +// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> () +// CHECK: witness_method $Stuck, #Twig."+++"!1 : (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () +// CHECK: return +func useStick(_ a: Stuck, _ b: Stick) { + _ = a +++ b + _ = b +++ b + _ = a +++ 5 +} + +class Twine : Twig { + static func +++(lhs: Twine, rhs: Twine) {} +} + +class Rope : Twine, ExpressibleByIntegerLiteral { + typealias IntegerLiteralType = Int + + required init(integerLiteral: Int) {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () { +// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> () +// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> () +// CHECK: witness_method $Rope, #Twig."+++"!1 : (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () +func useRope(_ r: Rope, _ s: Rope) { + _ = r +++ s + _ = s +++ s + _ = r +++ 5 +}