Skip to content

Sema: Clean up handling of protocol operators with concrete operands #28788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 96 additions & 38 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "swift/AST/ASTWalker.h"
#include "swift/AST/ExistentialLayout.h"
#include "swift/AST/Initializer.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/ProtocolConformance.h"
Expand Down Expand Up @@ -378,6 +379,66 @@ namespace {
return base.getOldType();
}

// Returns None if the AST does not contain enough information to recover
// substitutions; this is different from an Optional(SubstitutionMap()),
// indicating a valid call to a non-generic operator.
Optional<SubstitutionMap>
getOperatorSubstitutions(ValueDecl *witness, Type refType) {
// We have to recover substitutions in this hacky way because
// the AST does not retain enough information to devirtualize
// calls like this.
auto witnessType = witness->getInterfaceType();

// Compute the substitutions.
auto *gft = witnessType->getAs<GenericFunctionType>();
if (gft == nullptr) {
if (refType->isEqual(witnessType))
return SubstitutionMap();
return None;
}

auto sig = gft->getGenericSignature();
auto *env = sig->getGenericEnvironment();

witnessType = FunctionType::get(gft->getParams(),
gft->getResult(),
gft->getExtInfo());
witnessType = env->mapTypeIntoContext(witnessType);

TypeSubstitutionMap subs;
auto substType = witnessType->substituteBindingsTo(
refType,
[&](ArchetypeType *origType, CanType substType) -> CanType {
if (auto gpType = dyn_cast<GenericTypeParamType>(
origType->getInterfaceType()->getCanonicalType()))
subs[gpType] = substType;

return substType;
});

// If substitution failed, it means that the protocol requirement type
// and the witness type did not match up. The only time that this
// should happen is when the witness is defined in a base class and
// the actual call uses a derived class. For example,
//
// protocol P { func +(lhs: Self, rhs: Self) }
// class Base : P { func +(lhs: Base, rhs: Base) {} }
// class Derived : Base {}
//
// If we enter this code path with two operands of type Derived,
// we know we're calling the protocol requirement P.+, with a
// substituted type of (Derived, Derived) -> (). But the type of
// the witness is (Base, Base) -> (). Just bail out and make a
// witness method call in this rare case; SIL mandatory optimizations
// will likely devirtualize it anyway.
if (!substType)
return None;

return SubstitutionMap::get(sig,
QueryTypeSubstitutionMap{subs},
TypeChecker::LookUpConformance(cs.DC));
}

public:
/// Build a reference to the given declaration.
Expr *buildDeclRef(SelectedOverload overload, DeclNameLoc loc,
Expand All @@ -400,56 +461,53 @@ namespace {

// Handle operator requirements found in protocols.
if (auto proto = dyn_cast<ProtocolDecl>(decl->getDeclContext())) {
// If we don't have an archetype or existential, we have to call the
// witness.
// If we have a concrete conformance, build a call to the witness.
//
// FIXME: This is awful. We should be able to handle this as a call to
// the protocol requirement with Self == the concrete type, and SILGen
// (or later) can devirtualize as appropriate.
if (!baseTy->is<ArchetypeType>() && !baseTy->isAnyExistentialType()) {
auto conformance =
TypeChecker::conformsToProtocol(
baseTy, proto, cs.DC,
ConformanceCheckFlags::InExpression);
if (conformance.isConcrete()) {
if (auto witness =
conformance.getConcrete()->getWitnessDecl(decl)) {
// Hack up an AST that we can type-check (independently) to get
// it into the right form.
// FIXME: the hop through 'getDecl()' is because
// SpecializedProtocolConformance doesn't substitute into
// witnesses' ConcreteDeclRefs.
Type expectedFnType = simplifyType(overload.openedType);
assert(expectedFnType->isEqual(
fullType->castTo<AnyFunctionType>()->getResult()) &&
"Cannot handle adjustments made to the opened type");
auto conformance =
TypeChecker::conformsToProtocol(
baseTy, proto, cs.DC,
ConformanceCheckFlags::InExpression);
if (conformance.isConcrete()) {
if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) {
// The fullType was computed by substituting the protocol
// requirement so it always has a (Self) -> ... curried
// application. Strip it off if the witness was a top-level
// function.
Type refType;
if (witness->getDeclContext()->isTypeContext())
refType = fullType;
else
refType = fullType->castTo<AnyFunctionType>()->getResult();

// Build the AST for the call to the witness.
auto subMap = getOperatorSubstitutions(witness, refType);
if (subMap) {
ConcreteDeclRef witnessRef(witness, *subMap);
auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc,
/*Implicit=*/false);
declRefExpr->setFunctionRefKind(choice.getFunctionRefKind());
cs.setType(declRefExpr, refType);

Expr *refExpr;
if (witness->getDeclContext()->isTypeContext()) {
// If the operator is a type member, add the implicit
// (Self) -> ... call.
Expr *base =
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy,
ctx);
refExpr = new (ctx) MemberRefExpr(base, SourceLoc(), witness,
loc, /*Implicit=*/true);
cs.setType(base, MetatypeType::get(baseTy));

refExpr = new (ctx) DotSyntaxCallExpr(declRefExpr,
SourceLoc(), base);
auto refType = fullType->castTo<FunctionType>()->getResult();
cs.setType(refExpr, refType);
} else {
auto declRefExpr = new (ctx) DeclRefExpr(witness, loc,
/*Implicit=*/false);
declRefExpr->setFunctionRefKind(choice.getFunctionRefKind());
refExpr = declRefExpr;
}

auto resultTy = TypeChecker::typeCheckExpression(
refExpr, cs.DC, TypeLoc::withoutLoc(expectedFnType),
CTP_CannotFail);
if (!resultTy)
return nullptr;

cs.cacheExprTypes(refExpr);

// Remove an outer function-conversion expression. This
// happens when we end up referring to a witness for a
// superclass conformance, and 'Self' differs.
if (auto fnConv = dyn_cast<FunctionConversionExpr>(refExpr))
refExpr = fnConv->getSubExpr();

return forceUnwrapIfExpected(refExpr, choice, locator);
}
}
Expand Down
60 changes: 60 additions & 0 deletions test/SILGen/protocol_operators.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s

infix operator +++

protocol Twig {
static func +++(lhs: Self, rhs: Self)
}

struct Branch : Twig {
@_implements(Twig, +++(_:_:))
static func doIt(_: Branch, _: Branch) {}
}

// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> ()
// CHECK: return
func useBranch(_ b: Branch) {
b +++ b
}

class Stick : Twig {
static func +++(lhs: Stick, rhs: Stick) {}
}

class Stuck : Stick, ExpressibleByIntegerLiteral {
typealias IntegerLiteralType = Int

required init(integerLiteral: Int) {}
}

// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// CHECK: witness_method $Stuck, #Twig."+++"!1 : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: return
func useStick(_ a: Stuck, _ b: Stick) {
_ = a +++ b
_ = b +++ b
_ = a +++ 5
}

class Twine<X> : Twig {
static func +++(lhs: Twine, rhs: Twine) {}
}

class Rope : Twine<Int>, ExpressibleByIntegerLiteral {
typealias IntegerLiteralType = Int

required init(integerLiteral: Int) {}
}

// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
// CHECK: witness_method $Rope, #Twig."+++"!1 : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
func useRope(_ r: Rope, _ s: Rope) {
_ = r +++ s
_ = s +++ s
_ = r +++ 5
}