Skip to content

Commit 0735629

Browse files
authored
Merge pull request #39612 from Jumhyn/keypath-function-conversion
[ConstraintSystem] Allow function-function conversions for keypath literals
2 parents 17c1f4b + 0d79f45 commit 0735629

File tree

9 files changed

+333
-72
lines changed

9 files changed

+333
-72
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,12 @@ ERROR(expr_smart_keypath_application_type_mismatch,none,
697697
"key path of type %0 cannot be applied to a base of type %1",
698698
(Type, Type))
699699
ERROR(expr_keypath_root_type_mismatch, none,
700+
"key path root type %0 cannot be converted to contextual type %1",
701+
(Type, Type))
702+
ERROR(expr_keypath_type_mismatch, none,
703+
"key path of type %0 cannot be converted to contextual type %1",
704+
(Type, Type))
705+
ERROR(expr_keypath_application_root_type_mismatch, none,
700706
"key path with root type %0 cannot be applied to a base of type %1",
701707
(Type, Type))
702708
ERROR(expr_swift_keypath_anyobject_root,none,

include/swift/Sema/ConstraintSystem.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,6 +1737,11 @@ class Solution {
17371737
/// Retrieve the type of the \p ComponentIndex-th component in \p KP.
17381738
Type getType(const KeyPathExpr *KP, unsigned ComponentIndex) const;
17391739

1740+
TypeVariableType *getKeyPathRootType(const KeyPathExpr *keyPath) const;
1741+
1742+
TypeVariableType *
1743+
getKeyPathRootTypeIfAvailable(const KeyPathExpr *keyPath) const;
1744+
17401745
/// Retrieve the type of the given node as recorded in this solution
17411746
/// and resolve all of the type variables in contains to form a fully
17421747
/// "resolved" concrete type.

lib/Sema/CSApply.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4971,20 +4971,18 @@ namespace {
49714971
// Resolve each of the components.
49724972
bool didOptionalChain = false;
49734973
bool isFunctionType = false;
4974-
Type baseTy, leafTy;
4974+
auto baseTy = cs.simplifyType(solution.getKeyPathRootType(E));
4975+
Type leafTy;
49754976
Type exprType = cs.getType(E);
49764977
if (auto fnTy = exprType->getAs<FunctionType>()) {
4977-
baseTy = fnTy->getParams()[0].getParameterType();
49784978
leafTy = fnTy->getResult();
49794979
isFunctionType = true;
49804980
} else if (auto *existential = exprType->getAs<ExistentialType>()) {
49814981
auto layout = existential->getExistentialLayout();
49824982
auto keyPathTy = layout.explicitSuperclass->castTo<BoundGenericType>();
4983-
baseTy = keyPathTy->getGenericArgs()[0];
49844983
leafTy = keyPathTy->getGenericArgs()[1];
49854984
} else {
49864985
auto keyPathTy = exprType->castTo<BoundGenericType>();
4987-
baseTy = keyPathTy->getGenericArgs()[0];
49884986
leafTy = keyPathTy->getGenericArgs()[1];
49894987
}
49904988

@@ -5103,13 +5101,11 @@ namespace {
51035101
assert(!resolvedComponents.empty());
51045102
componentTy = resolvedComponents.back().getComponentType();
51055103
}
5106-
5104+
51075105
// Wrap a non-optional result if there was chaining involved.
51085106
if (didOptionalChain && componentTy &&
51095107
!componentTy->hasUnresolvedType() &&
51105108
!componentTy->getWithoutSpecifierType()->isEqual(leafTy)) {
5111-
assert(leafTy->getOptionalObjectType()->isEqual(
5112-
componentTy->getWithoutSpecifierType()));
51135109
auto component = KeyPathExpr::Component::forOptionalWrap(leafTy);
51145110
resolvedComponents.push_back(component);
51155111
componentTy = leafTy;
@@ -5122,11 +5118,6 @@ namespace {
51225118
// See whether there's an equivalent ObjC key path string we can produce
51235119
// for interop purposes.
51245120
checkAndSetObjCKeyPathString(E);
5125-
5126-
// The final component type ought to line up with the leaf type of the
5127-
// key path.
5128-
assert(!componentTy || componentTy->hasUnresolvedType()
5129-
|| componentTy->getWithoutSpecifierType()->isEqual(leafTy));
51305121

51315122
if (!isFunctionType)
51325123
return E;
@@ -9848,6 +9839,21 @@ Type Solution::getType(const KeyPathExpr *KP, unsigned I) const {
98489839
return keyPathComponentTypes.find(std::make_pair(KP, I))->second;
98499840
}
98509841

9842+
TypeVariableType *
9843+
Solution::getKeyPathRootType(const KeyPathExpr *keyPath) const {
9844+
auto result = getKeyPathRootTypeIfAvailable(keyPath);
9845+
assert(result);
9846+
return result;
9847+
}
9848+
9849+
TypeVariableType *
9850+
Solution::getKeyPathRootTypeIfAvailable(const KeyPathExpr *keyPath) const {
9851+
auto result = KeyPaths.find(keyPath);
9852+
if (result != KeyPaths.end())
9853+
return std::get<0>(result->second);
9854+
return nullptr;
9855+
}
9856+
98519857
Type Solution::getResolvedType(ASTNode node) const {
98529858
return simplifyType(getType(node));
98539859
}

lib/Sema/CSDiagnostics.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,9 +2497,13 @@ bool ContextualFailure::diagnoseAsError() {
24972497

24982498
if (path.empty()) {
24992499
if (auto *KPE = getAsExpr<KeyPathExpr>(anchor)) {
2500-
emitDiagnosticAt(KPE->getLoc(),
2501-
diag::expr_keypath_type_covert_to_contextual_type,
2502-
getFromType(), getToType());
2500+
Diag<Type, Type> diag;
2501+
if (auto ctxDiag = getDiagnosticFor(CTP, getToType())) {
2502+
diag = *ctxDiag;
2503+
} else {
2504+
diag = diag::expr_keypath_type_mismatch;
2505+
}
2506+
emitDiagnosticAt(KPE->getLoc(), diag, getFromType(), getToType());
25032507
return true;
25042508
}
25052509

@@ -2750,9 +2754,14 @@ bool ContextualFailure::diagnoseAsError() {
27502754
break;
27512755
}
27522756

2757+
case ConstraintLocator::FunctionResult:
27532758
case ConstraintLocator::KeyPathValue: {
2754-
diagnostic = diag::expr_keypath_value_covert_to_contextual_type;
2755-
break;
2759+
if (auto *KPE = getAsExpr<KeyPathExpr>(anchor)) {
2760+
diagnostic = diag::expr_keypath_value_covert_to_contextual_type;
2761+
break;
2762+
} else {
2763+
return false;
2764+
}
27562765
}
27572766

27582767
default:
@@ -8272,13 +8281,24 @@ bool CoercionAsForceCastFailure::diagnoseAsError() {
82728281

82738282
bool KeyPathRootTypeMismatchFailure::diagnoseAsError() {
82748283
auto locator = getLocator();
8284+
auto anchor = locator->getAnchor();
82758285
assert(locator->isKeyPathRoot() && "Expected a key path root");
8276-
8277-
auto baseType = getFromType();
8278-
auto rootType = getToType();
82798286

8280-
emitDiagnostic(diag::expr_keypath_root_type_mismatch,
8281-
rootType, baseType);
8287+
8288+
8289+
if (isExpr<KeyPathApplicationExpr>(anchor) || isExpr<SubscriptExpr>(anchor)) {
8290+
auto baseType = getFromType();
8291+
auto rootType = getToType();
8292+
8293+
emitDiagnostic(diag::expr_keypath_application_root_type_mismatch,
8294+
rootType, baseType);
8295+
} else {
8296+
auto rootType = getFromType();
8297+
auto expectedType = getToType();
8298+
8299+
emitDiagnostic(diag::expr_keypath_root_type_mismatch, rootType,
8300+
expectedType);
8301+
}
82828302
return true;
82838303
}
82848304

lib/Sema/CSSimplify.cpp

Lines changed: 71 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5222,6 +5222,13 @@ bool ConstraintSystem::repairFailures(
52225222
});
52235223
};
52245224

5225+
auto hasAnyRestriction = [&]() {
5226+
return llvm::any_of(conversionsOrFixes,
5227+
[](const RestrictionOrFix &correction) {
5228+
return bool(correction.getRestriction());
5229+
});
5230+
};
5231+
52255232
// Check whether this is a tuple with a single unlabeled element
52265233
// i.e. `(_: Int)` and return type of that element if so. Note that
52275234
// if the element is pack expansion type the tuple is significant.
@@ -5247,6 +5254,40 @@ bool ConstraintSystem::repairFailures(
52475254
return true;
52485255
}
52495256

5257+
auto maybeRepairKeyPathResultFailure = [&](KeyPathExpr *kpExpr) {
5258+
if (lhs->isPlaceholder() || rhs->isPlaceholder())
5259+
return true;
5260+
if (lhs->isTypeVariableOrMember() || rhs->isTypeVariableOrMember())
5261+
return false;
5262+
5263+
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality) ||
5264+
hasConversionOrRestriction(ConversionRestrictionKind::ValueToOptional))
5265+
return false;
5266+
5267+
auto i = kpExpr->getComponents().size() - 1;
5268+
auto lastCompLoc =
5269+
getConstraintLocator(kpExpr, LocatorPathElt::KeyPathComponent(i));
5270+
if (hasFixFor(lastCompLoc, FixKind::AllowTypeOrInstanceMember))
5271+
return true;
5272+
5273+
auto *keyPathLoc = getConstraintLocator(anchor);
5274+
5275+
if (hasFixFor(keyPathLoc))
5276+
return true;
5277+
5278+
if (auto contextualInfo = getContextualTypeInfo(anchor)) {
5279+
if (hasFixFor(getConstraintLocator(
5280+
keyPathLoc,
5281+
LocatorPathElt::ContextualType(contextualInfo->purpose))))
5282+
return true;
5283+
}
5284+
5285+
conversionsOrFixes.push_back(IgnoreContextualType::create(
5286+
*this, lhs, rhs,
5287+
getConstraintLocator(keyPathLoc, ConstraintLocator::KeyPathValue)));
5288+
return true;
5289+
};
5290+
52505291
if (path.empty()) {
52515292
if (!anchor)
52525293
return false;
@@ -5266,9 +5307,9 @@ bool ConstraintSystem::repairFailures(
52665307
// instance fix recorded.
52675308
if (auto *kpExpr = getAsExpr<KeyPathExpr>(anchor)) {
52685309
if (isKnownKeyPathType(lhs) && isKnownKeyPathType(rhs)) {
5269-
// If we have keypath capabilities for both sides and one of the bases
5270-
// is unresolved, it is too early to record fix.
5271-
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality))
5310+
// If we have a conversion happening here, we should let fix happen in
5311+
// simplifyRestrictedConstraint.
5312+
if (hasAnyRestriction())
52725313
return false;
52735314
}
52745315

@@ -5668,10 +5709,7 @@ bool ConstraintSystem::repairFailures(
56685709

56695710
// If there are any restrictions here we need to wait and let
56705711
// `simplifyRestrictedConstraintImpl` handle them.
5671-
if (llvm::any_of(conversionsOrFixes,
5672-
[](const RestrictionOrFix &correction) {
5673-
return bool(correction.getRestriction());
5674-
}))
5712+
if (hasAnyRestriction())
56755713
break;
56765714

56775715
if (auto *fix = fixPropertyWrapperFailure(
@@ -6090,10 +6128,7 @@ bool ConstraintSystem::repairFailures(
60906128

60916129
// If there are any restrictions here we need to wait and let
60926130
// `simplifyRestrictedConstraintImpl` handle them.
6093-
if (llvm::any_of(conversionsOrFixes,
6094-
[](const RestrictionOrFix &correction) {
6095-
return bool(correction.getRestriction());
6096-
}))
6131+
if (hasAnyRestriction())
60976132
break;
60986133

60996134
// `lhs` - is an result type and `rhs` is a contextual type.
@@ -6112,6 +6147,10 @@ bool ConstraintSystem::repairFailures(
61126147
return true;
61136148
}
61146149

6150+
if (auto *kpExpr = getAsExpr<KeyPathExpr>(anchor)) {
6151+
return maybeRepairKeyPathResultFailure(kpExpr);
6152+
}
6153+
61156154
auto *loc = getConstraintLocator(anchor, {path.begin(), path.end() - 1});
61166155
// If this is a mismatch between contextual type and (trailing)
61176156
// closure with explicitly specified result type let's record it
@@ -6683,37 +6722,9 @@ bool ConstraintSystem::repairFailures(
66836722
return true;
66846723
}
66856724
case ConstraintLocator::KeyPathValue: {
6686-
if (lhs->isPlaceholder() || rhs->isPlaceholder())
6687-
return true;
6688-
if (lhs->isTypeVariableOrMember() || rhs->isTypeVariableOrMember())
6689-
break;
6690-
6691-
if (hasConversionOrRestriction(ConversionRestrictionKind::DeepEquality) ||
6692-
hasConversionOrRestriction(ConversionRestrictionKind::ValueToOptional))
6693-
return false;
6694-
6695-
auto kpExpr = castToExpr<KeyPathExpr>(anchor);
6696-
auto i = kpExpr->getComponents().size() - 1;
6697-
auto lastCompLoc =
6698-
getConstraintLocator(kpExpr, LocatorPathElt::KeyPathComponent(i));
6699-
if (hasFixFor(lastCompLoc, FixKind::AllowTypeOrInstanceMember))
6725+
if (maybeRepairKeyPathResultFailure(getAsExpr<KeyPathExpr>(anchor)))
67006726
return true;
67016727

6702-
auto *keyPathLoc = getConstraintLocator(anchor);
6703-
6704-
if (hasFixFor(keyPathLoc))
6705-
return true;
6706-
6707-
if (auto contextualInfo = getContextualTypeInfo(anchor)) {
6708-
if (hasFixFor(getConstraintLocator(
6709-
keyPathLoc,
6710-
LocatorPathElt::ContextualType(contextualInfo->purpose))))
6711-
return true;
6712-
}
6713-
6714-
conversionsOrFixes.push_back(IgnoreContextualType::create(
6715-
*this, lhs, rhs,
6716-
getConstraintLocator(keyPathLoc, ConstraintLocator::KeyPathValue)));
67176728
break;
67186729
}
67196730
default:
@@ -12257,12 +12268,26 @@ ConstraintSystem::simplifyKeyPathConstraint(
1225712268

1225812269
if (auto fnTy = contextualTy->getAs<FunctionType>()) {
1225912270
assert(fnTy->getParams().size() == 1);
12260-
// Match up the root and value types to the function's param and return
12261-
// types. Note that we're using the type of the parameter as referenced
12262-
// from inside the function body as we'll be transforming the code into:
12263-
// { root in root[keyPath: kp] }.
12264-
contextualRootTy = fnTy->getParams()[0].getParameterType();
12265-
contextualValueTy = fnTy->getResult();
12271+
// Key paths may be converted to a function of compatible type. We will
12272+
// later form from this key path an implicit closure of the form
12273+
// `{ root in root[keyPath: kp] }` so any conversions that are valid with
12274+
// a source type of `(Root) -> Value` should be valid here too.
12275+
auto rootParam = AnyFunctionType::Param(rootTy);
12276+
auto kpFnTy = FunctionType::get(rootParam, valueTy, fnTy->getExtInfo());
12277+
12278+
// Note: because the keypath is applied to `root` as a parameter internal
12279+
// to the closure, we use the function parameter's "parameter type" rather
12280+
// than the raw type. This enables things like:
12281+
// ```
12282+
// let countKeyPath: (String...) -> Int = \.count
12283+
// ```
12284+
auto paramTy = fnTy->getParams()[0].getParameterType();
12285+
auto paramParam = AnyFunctionType::Param(paramTy);
12286+
auto paramFnTy = FunctionType::get(paramParam, fnTy->getResult(),
12287+
fnTy->getExtInfo());
12288+
12289+
return matchTypes(kpFnTy, paramFnTy, ConstraintKind::Conversion, subflags,
12290+
locator).isSuccess();
1226612291
}
1226712292

1226812293
assert(contextualRootTy && contextualValueTy);

test/Constraints/keypath.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ func testVariadicKeypathAsFunc() {
8181

8282
// These are not okay, the KeyPath should have a base that matches the
8383
// internal parameter type of the function, i.e (S...).
84-
let _: (S...) -> Int = \S.i // expected-error {{key path with root type 'S...' cannot be applied to a base of type 'S'}}
85-
takesVariadicFnWithGenericRet(\S.i) // expected-error {{key path with root type 'S...' cannot be applied to a base of type 'S'}}
84+
let _: (S...) -> Int = \S.i // expected-error {{key path root type 'S' cannot be converted to contextual type 'S...'}}
85+
takesVariadicFnWithGenericRet(\S.i) // expected-error {{key path root type 'S' cannot be converted to contextual type 'S...'}}
8686
}
8787

8888
// rdar://problem/54322807
@@ -231,7 +231,7 @@ func issue_65965() {
231231
let refKP: ReferenceWritableKeyPath<S, String>
232232
refKP = \.s
233233
// expected-error@-1 {{cannot convert key path type 'WritableKeyPath<S, String>' to contextual type 'ReferenceWritableKeyPath<S, String>'}}
234-
234+
235235
let writeKP: WritableKeyPath<S, String>
236236
writeKP = \.v
237237
// expected-error@-1 {{cannot convert key path type 'KeyPath<S, String>' to contextual type 'WritableKeyPath<S, String>'}}

0 commit comments

Comments
 (0)