Skip to content

Commit 0a89228

Browse files
authored
Merge pull request #28788 from slavapestov/csapply-protocol-operator-cleanup
Sema: Clean up handling of protocol operators with concrete operands
2 parents 448a14e + 594044a commit 0a89228

File tree

2 files changed

+156
-38
lines changed

2 files changed

+156
-38
lines changed

lib/Sema/CSApply.cpp

Lines changed: 96 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/ASTWalker.h"
2727
#include "swift/AST/ExistentialLayout.h"
2828
#include "swift/AST/Initializer.h"
29+
#include "swift/AST/GenericEnvironment.h"
2930
#include "swift/AST/GenericSignature.h"
3031
#include "swift/AST/ParameterList.h"
3132
#include "swift/AST/ProtocolConformance.h"
@@ -378,6 +379,66 @@ namespace {
378379
return base.getOldType();
379380
}
380381

382+
// Returns None if the AST does not contain enough information to recover
383+
// substitutions; this is different from an Optional(SubstitutionMap()),
384+
// indicating a valid call to a non-generic operator.
385+
Optional<SubstitutionMap>
386+
getOperatorSubstitutions(ValueDecl *witness, Type refType) {
387+
// We have to recover substitutions in this hacky way because
388+
// the AST does not retain enough information to devirtualize
389+
// calls like this.
390+
auto witnessType = witness->getInterfaceType();
391+
392+
// Compute the substitutions.
393+
auto *gft = witnessType->getAs<GenericFunctionType>();
394+
if (gft == nullptr) {
395+
if (refType->isEqual(witnessType))
396+
return SubstitutionMap();
397+
return None;
398+
}
399+
400+
auto sig = gft->getGenericSignature();
401+
auto *env = sig->getGenericEnvironment();
402+
403+
witnessType = FunctionType::get(gft->getParams(),
404+
gft->getResult(),
405+
gft->getExtInfo());
406+
witnessType = env->mapTypeIntoContext(witnessType);
407+
408+
TypeSubstitutionMap subs;
409+
auto substType = witnessType->substituteBindingsTo(
410+
refType,
411+
[&](ArchetypeType *origType, CanType substType) -> CanType {
412+
if (auto gpType = dyn_cast<GenericTypeParamType>(
413+
origType->getInterfaceType()->getCanonicalType()))
414+
subs[gpType] = substType;
415+
416+
return substType;
417+
});
418+
419+
// If substitution failed, it means that the protocol requirement type
420+
// and the witness type did not match up. The only time that this
421+
// should happen is when the witness is defined in a base class and
422+
// the actual call uses a derived class. For example,
423+
//
424+
// protocol P { func +(lhs: Self, rhs: Self) }
425+
// class Base : P { func +(lhs: Base, rhs: Base) {} }
426+
// class Derived : Base {}
427+
//
428+
// If we enter this code path with two operands of type Derived,
429+
// we know we're calling the protocol requirement P.+, with a
430+
// substituted type of (Derived, Derived) -> (). But the type of
431+
// the witness is (Base, Base) -> (). Just bail out and make a
432+
// witness method call in this rare case; SIL mandatory optimizations
433+
// will likely devirtualize it anyway.
434+
if (!substType)
435+
return None;
436+
437+
return SubstitutionMap::get(sig,
438+
QueryTypeSubstitutionMap{subs},
439+
TypeChecker::LookUpConformance(cs.DC));
440+
}
441+
381442
public:
382443
/// Build a reference to the given declaration.
383444
Expr *buildDeclRef(SelectedOverload overload, DeclNameLoc loc,
@@ -400,56 +461,53 @@ namespace {
400461

401462
// Handle operator requirements found in protocols.
402463
if (auto proto = dyn_cast<ProtocolDecl>(decl->getDeclContext())) {
403-
// If we don't have an archetype or existential, we have to call the
404-
// witness.
464+
// If we have a concrete conformance, build a call to the witness.
465+
//
405466
// FIXME: This is awful. We should be able to handle this as a call to
406467
// the protocol requirement with Self == the concrete type, and SILGen
407468
// (or later) can devirtualize as appropriate.
408-
if (!baseTy->is<ArchetypeType>() && !baseTy->isAnyExistentialType()) {
409-
auto conformance =
410-
TypeChecker::conformsToProtocol(
411-
baseTy, proto, cs.DC,
412-
ConformanceCheckFlags::InExpression);
413-
if (conformance.isConcrete()) {
414-
if (auto witness =
415-
conformance.getConcrete()->getWitnessDecl(decl)) {
416-
// Hack up an AST that we can type-check (independently) to get
417-
// it into the right form.
418-
// FIXME: the hop through 'getDecl()' is because
419-
// SpecializedProtocolConformance doesn't substitute into
420-
// witnesses' ConcreteDeclRefs.
421-
Type expectedFnType = simplifyType(overload.openedType);
422-
assert(expectedFnType->isEqual(
423-
fullType->castTo<AnyFunctionType>()->getResult()) &&
424-
"Cannot handle adjustments made to the opened type");
469+
auto conformance =
470+
TypeChecker::conformsToProtocol(
471+
baseTy, proto, cs.DC,
472+
ConformanceCheckFlags::InExpression);
473+
if (conformance.isConcrete()) {
474+
if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) {
475+
// The fullType was computed by substituting the protocol
476+
// requirement so it always has a (Self) -> ... curried
477+
// application. Strip it off if the witness was a top-level
478+
// function.
479+
Type refType;
480+
if (witness->getDeclContext()->isTypeContext())
481+
refType = fullType;
482+
else
483+
refType = fullType->castTo<AnyFunctionType>()->getResult();
484+
485+
// Build the AST for the call to the witness.
486+
auto subMap = getOperatorSubstitutions(witness, refType);
487+
if (subMap) {
488+
ConcreteDeclRef witnessRef(witness, *subMap);
489+
auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc,
490+
/*Implicit=*/false);
491+
declRefExpr->setFunctionRefKind(choice.getFunctionRefKind());
492+
cs.setType(declRefExpr, refType);
493+
425494
Expr *refExpr;
426495
if (witness->getDeclContext()->isTypeContext()) {
496+
// If the operator is a type member, add the implicit
497+
// (Self) -> ... call.
427498
Expr *base =
428499
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy,
429500
ctx);
430-
refExpr = new (ctx) MemberRefExpr(base, SourceLoc(), witness,
431-
loc, /*Implicit=*/true);
501+
cs.setType(base, MetatypeType::get(baseTy));
502+
503+
refExpr = new (ctx) DotSyntaxCallExpr(declRefExpr,
504+
SourceLoc(), base);
505+
auto refType = fullType->castTo<FunctionType>()->getResult();
506+
cs.setType(refExpr, refType);
432507
} else {
433-
auto declRefExpr = new (ctx) DeclRefExpr(witness, loc,
434-
/*Implicit=*/false);
435-
declRefExpr->setFunctionRefKind(choice.getFunctionRefKind());
436508
refExpr = declRefExpr;
437509
}
438510

439-
auto resultTy = TypeChecker::typeCheckExpression(
440-
refExpr, cs.DC, TypeLoc::withoutLoc(expectedFnType),
441-
CTP_CannotFail);
442-
if (!resultTy)
443-
return nullptr;
444-
445-
cs.cacheExprTypes(refExpr);
446-
447-
// Remove an outer function-conversion expression. This
448-
// happens when we end up referring to a witness for a
449-
// superclass conformance, and 'Self' differs.
450-
if (auto fnConv = dyn_cast<FunctionConversionExpr>(refExpr))
451-
refExpr = fnConv->getSubExpr();
452-
453511
return forceUnwrapIfExpected(refExpr, choice, locator);
454512
}
455513
}

test/SILGen/protocol_operators.swift

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
2+
3+
infix operator +++
4+
5+
protocol Twig {
6+
static func +++(lhs: Self, rhs: Self)
7+
}
8+
9+
struct Branch : Twig {
10+
@_implements(Twig, +++(_:_:))
11+
static func doIt(_: Branch, _: Branch) {}
12+
}
13+
14+
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
15+
// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> ()
16+
// CHECK: return
17+
func useBranch(_ b: Branch) {
18+
b +++ b
19+
}
20+
21+
class Stick : Twig {
22+
static func +++(lhs: Stick, rhs: Stick) {}
23+
}
24+
25+
class Stuck : Stick, ExpressibleByIntegerLiteral {
26+
typealias IntegerLiteralType = Int
27+
28+
required init(integerLiteral: Int) {}
29+
}
30+
31+
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
32+
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
33+
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
34+
// CHECK: witness_method $Stuck, #Twig."+++"!1 : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
35+
// CHECK: return
36+
func useStick(_ a: Stuck, _ b: Stick) {
37+
_ = a +++ b
38+
_ = b +++ b
39+
_ = a +++ 5
40+
}
41+
42+
class Twine<X> : Twig {
43+
static func +++(lhs: Twine, rhs: Twine) {}
44+
}
45+
46+
class Rope : Twine<Int>, ExpressibleByIntegerLiteral {
47+
typealias IntegerLiteralType = Int
48+
49+
required init(integerLiteral: Int) {}
50+
}
51+
52+
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
53+
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
54+
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
55+
// CHECK: witness_method $Rope, #Twig."+++"!1 : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
56+
func useRope(_ r: Rope, _ s: Rope) {
57+
_ = r +++ s
58+
_ = s +++ s
59+
_ = r +++ 5
60+
}

0 commit comments

Comments
 (0)