diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index e67ad8602ee8b..b47bb84294f2e 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -527,6 +527,12 @@ class BindingSet { void forEachLiteralRequirement( llvm::function_ref callback) const; + void forEachAdjacentVariable( + llvm::function_ref callback) const { + for (auto *typeVar : AdjacentVars) + callback(typeVar); + } + /// Return a literal requirement that has the most impact on the binding /// score. LiteralBindingKind getLiteralForScore() const; diff --git a/lib/Sema/CSOptimizer.cpp b/lib/Sema/CSOptimizer.cpp index a93ee36bf7423..ec805d72ecfc8 100644 --- a/lib/Sema/CSOptimizer.cpp +++ b/lib/Sema/CSOptimizer.cpp @@ -23,6 +23,7 @@ #include "swift/Sema/ConstraintSystem.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Support/SaveAndRestore.h" @@ -46,8 +47,97 @@ struct DisjunctionInfo { DisjunctionInfo() = default; DisjunctionInfo(double score, ArrayRef favoredChoices = {}) : Score(score), FavoredChoices(favoredChoices) {} + + bool allGeneric() const { + if (FavoredChoices.empty()) + return false; + + return llvm::all_of(FavoredChoices, [](Constraint *choice) { + if (auto *decl = getOverloadChoiceDecl(choice)) + return decl->getInterfaceType()->is(); + return false; + }); + } }; +static DeclContext *getDisjunctionDC(Constraint *disjunction) { + auto *choice = disjunction->getNestedConstraints()[0]; + switch (choice->getKind()) { + case ConstraintKind::BindOverload: + return choice->getOverloadUseDC(); + case ConstraintKind::ValueMember: + case ConstraintKind::UnresolvedValueMember: + case ConstraintKind::ValueWitness: + return choice->getMemberUseDC(); + default: + return nullptr; + } +} + +static std::optional isBeforeInBuffer(ASTContext &ctx, + Constraint *disjunctionA, + Constraint *disjunctionB) { + auto *anchorA = getAsExpr(disjunctionA->getLocator()->getAnchor()); + auto *anchorB = getAsExpr(disjunctionB->getLocator()->getAnchor()); + + auto locA = anchorA ? anchorA->getLoc() : SourceLoc(); + auto locB = anchorB ? anchorB->getLoc() : SourceLoc(); + + if (!locA || !locB) + return std::nullopt; + + return ctx.SourceMgr.isBeforeInBuffer(locA, locB); +} + +/// Determine whether the given disjunction appears in a context +/// transformed by a result builder. +static bool isInResultBuilderContext(ConstraintSystem &cs, + Constraint *disjunction) { + auto *DC = getDisjunctionDC(disjunction); + if (!DC) + return false; + + do { + auto fnContext = AnyFunctionRef::fromDeclContext(DC); + if (!fnContext) + return false; + + if (cs.getAppliedResultBuilderTransform(*fnContext)) + return true; + + } while ((DC = DC->getParent())); + + return false; +} + +/// Check whether this given disjunction appears in an operator context. +/// It could be a binary operator chain i.e. `1 + Double(x)`, an unary +/// operator i.e. `-Test(x)`, or a ternary `cond ? 1 + 2 : 3` +static bool isInOperatorContext(ConstraintSystem &cs, Constraint *disjunction) { + auto *curr = castToExpr(disjunction->getLocator()->getAnchor()); + + do { + switch (curr->getKind()) { + /// Something like `1 + arr.map { <> }` + case ExprKind::Closure: + return false; + + case ExprKind::Binary: + case ExprKind::PrefixUnary: + case ExprKind::PostfixUnary: + return true; + + case ExprKind::Ternary: + return true; + + default: + break; + } + } while ((curr = cs.getParentExpr(curr))); + + return false; +} + // TODO: both `isIntegerType` and `isFloatType` should be available on Type // as `isStdlib{Integer, Float}Type`. @@ -68,6 +158,12 @@ static bool isUnboundArrayType(Type type) { return false; } +static bool isUnboundDictionaryType(Type type) { + if (auto *UGT = type->getAs()) + return UGT->getDecl() == type->getASTContext().getDictionaryDecl(); + return false; +} + static bool isSupportedOperator(Constraint *disjunction) { if (!isOperatorDisjunction(disjunction)) return false; @@ -77,7 +173,7 @@ static bool isSupportedOperator(Constraint *disjunction) { auto name = decl->getBaseIdentifier(); if (name.isArithmeticOperator() || name.isStandardComparisonOperator() || - name.isBitwiseOperator()) { + name.isBitwiseOperator() || name.isNilCoalescingOperator()) { return true; } @@ -106,15 +202,58 @@ static bool isStandardComparisonOperator(ValueDecl *decl) { decl->getBaseIdentifier().isStandardComparisonOperator(); } +static bool isStandardComparisonOperatorDisjunction(Constraint *disjunction) { + auto *choice = disjunction->getNestedConstraints()[0]; + if (auto *decl = getOverloadChoiceDecl(choice)) + return isStandardComparisonOperator(decl); + return false; +} + static bool isArithmeticOperator(ValueDecl *decl) { return decl->isOperator() && decl->getBaseIdentifier().isArithmeticOperator(); } +/// Generic choices are supported only if they are not complex enough +/// that would they'd require solving to figure out whether they are a +/// potential match or not. +static bool isSupportedGenericOverloadChoice(ValueDecl *decl, + GenericFunctionType *choiceType) { + // Same type requirements cannot be handled because each + // candidate-parameter pair is (currently) considered in isolation. + if (llvm::any_of(choiceType->getRequirements(), [](const Requirement &req) { + switch (req.getKind()) { + case RequirementKind::SameType: + case RequirementKind::SameShape: + return true; + + case RequirementKind::Conformance: + case RequirementKind::Superclass: + case RequirementKind::Layout: + return false; + } + })) + return false; + + // If there are no same-type requirements, allow signatures + // that use only concrete types or generic parameters directly + // in their parameter positions i.e. `(T, Int)`. + + auto *paramList = getParameterList(decl); + if (!paramList) + return false; + + return llvm::all_of(paramList->getArray(), [](const ParamDecl *P) { + auto paramType = P->getInterfaceType(); + return paramType->is() || + !paramType->hasTypeParameter(); + }); +} + static bool isSupportedDisjunction(Constraint *disjunction) { auto choices = disjunction->getNestedConstraints(); - if (isSupportedOperator(disjunction)) - return true; + if (isOperatorDisjunction(disjunction)) + return isSupportedOperator(disjunction); if (auto *ctor = dyn_cast_or_null( getOverloadChoiceDecl(choices.front()))) { @@ -123,7 +262,7 @@ static bool isSupportedDisjunction(Constraint *disjunction) { } // Non-operator disjunctions are supported only if they don't - // have any generic choices. + // have any complex generic choices. return llvm::all_of(choices, [&](Constraint *choice) { if (choice->isDisabled()) return true; @@ -137,13 +276,81 @@ static bool isSupportedDisjunction(Constraint *disjunction) { if (decl->isImplicitlyUnwrappedOptional()) return false; - return decl->getInterfaceType()->is(); + auto choiceType = decl->getInterfaceType()->getAs(); + if (!choiceType || choiceType->hasError()) + return false; + + // Non-generic choices are always supported. + if (choiceType->is()) + return true; + + if (auto *genericFn = choiceType->getAs()) + return isSupportedGenericOverloadChoice(decl, genericFn); + + return false; } return false; }); } +/// Given the type variable that represents a result type of a +/// function call, check whether that call is to an initializer +/// and based on that deduce possible type for the result. +/// +/// @return A type and a flag that indicates whether there +/// are any viable failable overloads and empty pair if the +/// type variable isn't a result of an initializer call. +static llvm::PointerIntPair +inferTypeFromInitializerResultType(ConstraintSystem &cs, + TypeVariableType *typeVar, + ArrayRef disjunctions) { + assert(typeVar->getImpl().isFunctionResult()); + + auto *resultLoc = typeVar->getImpl().getLocator(); + auto *call = getAsExpr(resultLoc->getAnchor()); + if (!call) + return {}; + + auto *fn = call->getFn()->getSemanticsProvidingExpr(); + + Type instanceTy; + ConstraintLocator *ctorLocator = nullptr; + if (auto *typeExpr = getAsExpr(fn)) { + instanceTy = cs.getType(typeExpr)->getMetatypeInstanceType(); + ctorLocator = + cs.getConstraintLocator(call, {LocatorPathElt::ApplyFunction(), + LocatorPathElt::ConstructorMember()}); + } else if (auto *UDE = getAsExpr(fn)) { + if (!UDE->getName().getBaseName().isConstructor()) + return {}; + instanceTy = cs.getType(UDE->getBase())->getMetatypeInstanceType(); + ctorLocator = cs.getConstraintLocator(UDE, LocatorPathElt::Member()); + } + + if (!instanceTy || !ctorLocator) + return {}; + + auto initRef = + llvm::find_if(disjunctions, [&ctorLocator](Constraint *disjunction) { + return disjunction->getLocator() == ctorLocator; + }); + + if (initRef == disjunctions.end()) + return {}; + + bool hasFailable = + llvm::any_of((*initRef)->getNestedConstraints(), [](Constraint *choice) { + if (choice->isDisabled()) + return false; + auto *decl = + dyn_cast_or_null(getOverloadChoiceDecl(choice)); + return decl && decl->isFailable(); + }); + + return {instanceTy, hasFailable}; +} + NullablePtr getApplicableFnConstraint(ConstraintGraph &CG, Constraint *disjunction) { auto *boundVar = disjunction->getNestedConstraints()[0] @@ -281,6 +488,52 @@ static void findFavoredChoicesBasedOnArity( favoredChoice(choice); } +/// Preserves old behavior where, for unary calls, the solver would not previously +/// consider choices that didn't match on the number of parameters (regardless of +/// defaults and variadics) and only exact matches were favored. +static std::optional preserveFavoringOfUnlabeledUnaryArgument( + ConstraintSystem &cs, Constraint *disjunction, ArgumentList *argumentList) { + if (!argumentList->isUnlabeledUnary()) + return std::nullopt; + + auto ODRE = isOverloadedDeclRef(disjunction); + bool preserveFavoringOfUnlabeledUnaryArgument = + !ODRE || numOverloadChoicesMatchingOnArity(ODRE, argumentList) < 2; + + if (!preserveFavoringOfUnlabeledUnaryArgument) + return std::nullopt; + + auto *argument = + argumentList->getUnlabeledUnaryExpr()->getSemanticsProvidingExpr(); + // The hack operated on "favored" types and only declaration references, + // applications, and (dynamic) subscripts had them if they managed to + // get an overload choice selected during constraint generation. + if (!(isExpr(argument) || isExpr(argument) || + isExpr(argument) || + isExpr(argument))) + return {/*score=*/0}; + + auto argumentType = cs.getType(argument); + if (argumentType->hasTypeVariable() || argumentType->hasDependentMember()) + return {/*score=*/0}; + + SmallVector favoredChoices; + forEachDisjunctionChoice( + cs, disjunction, + [&argumentType, &favoredChoices](Constraint *choice, ValueDecl *decl, + FunctionType *overloadType) { + if (overloadType->getNumParams() != 1) + return; + + auto paramType = overloadType->getParams()[0].getPlainType(); + if (paramType->isEqual(argumentType)) + favoredChoices.push_back(choice); + }); + + return DisjunctionInfo(/*score=*/favoredChoices.empty() ? 0 : 1, + favoredChoices); +} + } // end anonymous namespace /// Given a set of disjunctions, attempt to determine @@ -387,6 +640,16 @@ static void determineBestChoicesInContext( } } + // Preserves old behavior where, for unary calls, the solver + // would not consider choices that didn't match on the number + // of parameters (regardless of defaults) and only exact + // matches were favored. + if (auto info = preserveFavoringOfUnlabeledUnaryArgument(cs, disjunction, + argumentList)) { + recordResult(disjunction, std::move(info.value())); + continue; + } + if (!isSupportedDisjunction(disjunction)) continue; @@ -417,41 +680,64 @@ static void determineBestChoicesInContext( llvm::TinyPtrVector resultTypes; + bool hasArgumentCandidates = false; for (unsigned i = 0, n = argFuncType->getNumParams(); i != n; ++i) { const auto ¶m = argFuncType->getParams()[i]; auto argType = cs.simplifyType(param.getPlainType()); + SmallVector optionals; + // i.e. `??` operator could produce an optional type + // so `test(<> ?? 0) could result in an optional + // argument that wraps a type variable. It should be possible + // to infer bindings from underlying type variable and restore + // optionality. + if (argType->hasTypeVariable()) { + if (auto *typeVar = argType->lookThroughAllOptionalTypes(optionals) + ->getAs()) + argType = typeVar; + } + SmallVector types; if (auto *typeVar = argType->getAs()) { auto bindingSet = cs.getBindingsFor(typeVar); + auto restoreOptionality = [](Type type, unsigned numOptionals) { + for (unsigned i = 0; i != numOptionals; ++i) + type = type->wrapInOptionalType(); + return type; + }; + for (const auto &binding : bindingSet.Bindings) { - types.push_back({binding.BindingType}); + auto type = restoreOptionality(binding.BindingType, optionals.size()); + types.push_back({type}); } for (const auto &literal : bindingSet.Literals) { if (literal.second.hasDefaultType()) { // Add primary default type - types.push_back( - {literal.second.getDefaultType(), /*fromLiteral=*/true}); + auto type = restoreOptionality(literal.second.getDefaultType(), + optionals.size()); + types.push_back({type, + /*fromLiteral=*/true}); } } - // Helps situations like `1 + {Double, CGFloat}(...)` by inferring - // a type for the second operand of `+` based on a type being constructed. - // - // Currently limited to Double and CGFloat only since we need to - // support implicit `Double<->CGFloat` conversion. - if (typeVar->getImpl().isFunctionResult() && - isOperatorDisjunction(disjunction)) { - auto resultLoc = typeVar->getImpl().getLocator(); - if (auto *call = getAsExpr(resultLoc->getAnchor())) { - if (auto *typeExpr = dyn_cast(call->getFn())) { - auto instanceTy = cs.getType(typeExpr)->getMetatypeInstanceType(); - if (instanceTy->isDouble() || instanceTy->isCGFloat()) - types.push_back({instanceTy, /*fromLiteral=*/false, - /*fromInitializerCall=*/true}); - } + // Help situations like `1 + {Double, CGFloat}(...)` by inferring + // a type for the second operand of `+` based on a type being + // constructed. + if (typeVar->getImpl().isFunctionResult()) { + auto binding = + inferTypeFromInitializerResultType(cs, typeVar, disjunctions); + + if (auto instanceTy = binding.getPointer()) { + types.push_back({instanceTy, + /*fromLiteral=*/false, + /*fromInitializerCall=*/true}); + + if (binding.getInt()) + types.push_back({instanceTy->wrapInOptionalType(), + /*fromLiteral=*/false, + /*fromInitializerCall=*/true}); } } } else { @@ -459,6 +745,7 @@ static void determineBestChoicesInContext( } argumentCandidates[i].append(types); + hasArgumentCandidates |= !types.empty(); } auto resultType = cs.simplifyType(argFuncType->getResult()); @@ -468,6 +755,19 @@ static void determineBestChoicesInContext( for (const auto &binding : bindingSet.Bindings) { resultTypes.push_back(binding.BindingType); } + + // Infer bindings for each side of a ternary condition. + bindingSet.forEachAdjacentVariable( + [&cs, &resultTypes](TypeVariableType *adjacentVar) { + auto *adjacentLoc = adjacentVar->getImpl().getLocator(); + // This is one of the sides of a ternary operator. + if (adjacentLoc->directlyAt()) { + auto adjacentBindings = cs.getBindingsFor(adjacentVar); + + for (const auto &binding : adjacentBindings.Bindings) + resultTypes.push_back(binding.BindingType); + } + }); } else { resultTypes.push_back(resultType); } @@ -476,7 +776,7 @@ static void determineBestChoicesInContext( // This information is going to be used later on when we need to decide how to // score a matching choice. bool onlyLiteralCandidates = - argFuncType->getNumParams() > 0 && + hasArgumentCandidates && llvm::none_of( indices(argFuncType->getParams()), [&](const unsigned argIdx) { auto &candidates = argumentCandidates[argIdx]; @@ -598,32 +898,21 @@ static void determineBestChoicesInContext( } } - // Match `[...]` to Array<...> and/or `ExpressibleByArrayLiteral` - // conforming types. - if (options.contains(MatchFlag::OnParam) && - options.contains(MatchFlag::Literal) && - isUnboundArrayType(candidateType)) { + if (options.contains(MatchFlag::ExactOnly)) { // If an exact match is requested favor only `[...]` to `Array<...>` // since everything else is going to increase to score. - if (options.contains(MatchFlag::ExactOnly)) - return paramType->isArrayType() ? 1 : 0; - - // Otherwise, check if the other side conforms to - // `ExpressibleByArrayLiteral` protocol (in some way). - // We want an overly optimistic result here to avoid - // under-favoring. - auto &ctx = cs.getASTContext(); - return checkConformanceWithoutContext( - paramType, - ctx.getProtocol( - KnownProtocolKind::ExpressibleByArrayLiteral), - /*allowMissing=*/true) - ? 0.3 - : 0; - } + if (options.contains(MatchFlag::Literal)) { + if (isUnboundArrayType(candidateType)) + return paramType->isArrayType() ? 0.3 : 0; + + if (isUnboundDictionaryType(candidateType)) + return cs.isDictionaryType(paramType) ? 0.3 : 0; + } - if (options.contains(MatchFlag::ExactOnly)) - return areEqual(candidateType, paramType) ? 1 : 0; + if (!areEqual(candidateType, paramType)) + return 0; + return options.contains(MatchFlag::Literal) ? 0.3 : 1; + } // Exact match between candidate and parameter types. if (areEqual(candidateType, paramType)) { @@ -631,20 +920,73 @@ static void determineBestChoicesInContext( } if (options.contains(MatchFlag::Literal)) { - // Integer and floating-point literals can match any parameter - // type that conforms to `ExpressibleBy{Integer, Float}Literal` - // protocol but since that would constitute a non-default binding - // the score has to be slightly lowered. - if (!paramType->hasTypeParameter()) { + if (paramType->hasTypeParameter() || + paramType->isAnyExistentialType()) { + // Attempt to match literal default to generic parameter. + // This helps to determine whether there are any generic + // overloads that are a possible match. + auto score = + scoreCandidateMatch(genericSig, choice, candidateType, + paramType, options - MatchFlag::Literal); + if (score == 0) + return 0; + + // Optional injection lowers the score for operators to match + // pre-optimizer behavior. + return choice->isOperator() && paramType->getOptionalObjectType() + ? 0.2 + : 0.3; + } else { + // Integer and floating-point literals can match any parameter + // type that conforms to `ExpressibleBy{Integer, Float}Literal` + // protocol. Since this assessment is done in isolation we don't + // lower the score even though this would be a non-default binding + // for a literal. if (candidateType->isInt() && TypeChecker::conformsToKnownProtocol( paramType, KnownProtocolKind::ExpressibleByIntegerLiteral)) - return paramType->isDouble() ? 0.2 : 0.3; + return 0.3; if (candidateType->isDouble() && TypeChecker::conformsToKnownProtocol( paramType, KnownProtocolKind::ExpressibleByFloatLiteral)) return 0.3; + + if (candidateType->isBool() && + TypeChecker::conformsToKnownProtocol( + paramType, KnownProtocolKind::ExpressibleByBooleanLiteral)) + return 0.3; + + if (candidateType->isString() && + (TypeChecker::conformsToKnownProtocol( + paramType, KnownProtocolKind::ExpressibleByStringLiteral) || + TypeChecker::conformsToKnownProtocol( + paramType, + KnownProtocolKind::ExpressibleByStringInterpolation))) + return 0.3; + + auto &ctx = cs.getASTContext(); + + // Check if the other side conforms to `ExpressibleByArrayLiteral` + // protocol (in some way). We want an overly optimistic result + // here to avoid under-favoring. + if (candidateType->isArray() && + checkConformanceWithoutContext( + paramType, + ctx.getProtocol(KnownProtocolKind::ExpressibleByArrayLiteral), + /*allowMissing=*/true)) + return 0.3; + + // Check if the other side conforms to + // `ExpressibleByDictionaryLiteral` protocol (in some way). + // We want an overly optimistic result here to avoid under-favoring. + if (candidateType->isDictionary() && + checkConformanceWithoutContext( + paramType, + ctx.getProtocol( + KnownProtocolKind::ExpressibleByDictionaryLiteral), + /*allowMissing=*/true)) + return 0.3; } return 0; @@ -666,7 +1008,8 @@ static void determineBestChoicesInContext( // Injection lowers the score slightly to comply with // old behavior where exact matches on operator parameter // types were always preferred. - return score == 1 && choice->isOperator() ? 0.9 : score; + return score > 0 && choice->isOperator() ? score.value() - 0.1 + : score; } // Optionality mismatch. @@ -722,6 +1065,7 @@ static void determineBestChoicesInContext( // dependent member type (i.e. `Self.T`), let's check conformances // only and lower the score. if (candidateType->hasTypeVariable() || + candidateType->hasUnboundGenericType() || paramType->is()) { return checkProtocolRequirementsOnly(); } @@ -813,17 +1157,6 @@ static void determineBestChoicesInContext( double bestScore = 0.0; SmallVector, 2> favoredChoices; - // Preserves old behavior where, for unary calls, the solver - // would not consider choices that didn't match on the number - // of parameters (regardless of defaults) and only exact - // matches were favored. - bool preserveFavoringOfUnlabeledUnaryArgument = false; - if (argumentList->isUnlabeledUnary()) { - auto ODRE = isOverloadedDeclRef(disjunction); - preserveFavoringOfUnlabeledUnaryArgument = - !ODRE || numOverloadChoicesMatchingOnArity(ODRE, argumentList) < 2; - } - forEachDisjunctionChoice( cs, disjunction, [&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) { @@ -841,19 +1174,16 @@ static void determineBestChoicesInContext( if (!matchings) return; - // If all of the arguments are literals, let's prioritize exact - // matches to filter out non-default literal bindings which otherwise - // could cause "over-favoring". - bool favorExactMatchesOnly = onlyLiteralCandidates; - - if (preserveFavoringOfUnlabeledUnaryArgument) { - // Old behavior completely disregarded the fact that some of - // the parameters could be defaulted. - if (overloadType->getNumParams() != 1) - return; + auto canUseContextualResultTypes = [&decl]() { + return decl->isOperator() && !isStandardComparisonOperator(decl); + }; - favorExactMatchesOnly = true; - } + // Require exact matches only if all of the arguments + // are literals and there are no usable contextual result + // types that could help narrow favored choices. + bool favorExactMatchesOnly = + onlyLiteralCandidates && + (!canUseContextualResultTypes() || resultTypes.empty()); // This is important for SIMD operators in particular because // a lot of their overloads have same-type requires to a concrete @@ -900,6 +1230,9 @@ static void determineBestChoicesInContext( auto paramType = param.getPlainType(); + if (paramFlags.isAutoClosure()) + paramType = paramType->castTo()->getResult(); + // FIXME: Let's skip matching function types for now // because they have special rules for e.g. Concurrency // (around @Sendable) and @convention(c). @@ -968,10 +1301,7 @@ static void determineBestChoicesInContext( continue; } - // Only established arguments could be considered mismatches, - // literal default types should be regarded as holes if they - // didn't match. - if (!candidate.fromLiteral && !candidate.type->hasTypeVariable()) + if (!candidate.type->hasTypeVariable()) mismatches.set(candidateIdx); } @@ -992,41 +1322,16 @@ static void determineBestChoicesInContext( // parameters. score /= (overloadType->getNumParams() - numDefaulted); - // Make sure that the score is uniform for all disjunction - // choices that match on literals only, this would make sure that - // in operator chains that consist purely of literals we'd - // always prefer outermost disjunction instead of innermost - // one. - // - // Preferring outer disjunction first works better in situations - // when contextual type for the whole chain becomes available at - // some point during solving at it would allow for faster pruning. - if (score > 0 && onlyLiteralCandidates && decl->isOperator()) - score = 0.1; - // If one of the result types matches exactly, that's a good // indication that overload choice should be favored. // - // If nothing is known about the arguments it's only safe to - // check result for operators (except to standard comparison - // ones that all have the same result type), regular - // functions/methods and especially initializers could end up - // with a lot of favored overloads because on the result type alone. - if (decl->isOperator() && !isStandardComparisonOperator(decl)) { + // It's only safe to match result types of operators + // because regular functions/methods/subscripts and + // especially initializers could end up with a lot of + // favored overloads because on the result type alone. + if (canUseContextualResultTypes() && + (score > 0 || !hasArgumentCandidates)) { if (llvm::any_of(resultTypes, [&](const Type candidateResultTy) { - // Avoid increasing weight based on CGFloat result type - // match because that could require narrowing conversion - // in the arguments and that is always detrimental. - // - // For example, `has_CGFloat_param(1.0 + 2.0)` should use - // `+(_: Double, _: Double) -> Double` instead of - // `+(_: CGFloat, _: CGFloat) -> CGFloat` which would match - // parameter of `has_CGFloat_param` exactly but use a - // narrowing conversion for both literals. - if (candidateResultTy->lookThroughAllOptionalTypes() - ->isCGFloat()) - return false; - return scoreCandidateMatch(genericSig, decl, overloadType->getResult(), candidateResultTy, @@ -1051,7 +1356,7 @@ static void determineBestChoicesInContext( [&resultTy](const auto ¶m) { return param.getPlainType()->isEqual(resultTy); })) - score += 0.1; + score += 0.01; } favoredChoices.push_back({choice, score}); @@ -1191,77 +1496,60 @@ selectBestBindingDisjunction(ConstraintSystem &cs, return firstBindDisjunction; } -/// Prioritize `build{Block, Expression, ...}` and any chained -/// members that are connected to individual builder elements -/// i.e. `ForEach(...) { ... }.padding(...)`, once `ForEach` -/// is resolved, `padding` should be prioritized because its -/// requirements can help prune the solution space before the -/// body is checked. -static Constraint * -selectDisjunctionInResultBuilderContext(ConstraintSystem &cs, - ArrayRef disjunctions) { - auto context = AnyFunctionRef::fromDeclContext(cs.DC); - if (!context) - return nullptr; - - if (!cs.getAppliedResultBuilderTransform(context.value())) - return nullptr; - - std::pair best{nullptr, 0}; - for (auto *disjunction : disjunctions) { - auto *member = - getAsExpr(disjunction->getLocator()->getAnchor()); - if (!member) - continue; +static std::optional isPreferable(ConstraintSystem &cs, + Constraint *disjunctionA, + const DisjunctionInfo &infoA, + Constraint *disjunctionB, + const DisjunctionInfo &infoB) { + // Disfavor standard comparison operators when other wise has some + // non-generic favored choices. Operators like `==` and `!=` + // tend to get multiple generic matches through conformance to `Equatable` + // and it's better to attempt them as a last resort. + { + if (isStandardComparisonOperatorDisjunction(disjunctionA) && + isOperatorDisjunction(disjunctionB)) { + if (infoA.allGeneric() && (infoB.Score > 0 && !infoB.allGeneric())) + return false; + } - // Attempt `build{Block, Expression, ...} first because they - // provide contextual information for the inner calls. - if (isResultBuilderMethodReference(cs.getASTContext(), member)) - return disjunction; - - Expr *curr = member; - bool disqualified = false; - // Walk up the parent expression chain and check whether this - // disjunction represents one of the members in a chain that - // leads up to `buildExpression` (if defined by the builder) - // or to a pattern binding for `$__builderN` (the walk won't - // find any argument position locations in that case). - while (auto parent = cs.getParentExpr(curr)) { - if (!(isExpr(parent) || isExpr(parent))) { - disqualified = true; - break; - } + if (isStandardComparisonOperatorDisjunction(disjunctionB) && + isOperatorDisjunction(disjunctionA)) { + if (infoB.allGeneric() && (infoA.Score > 0 && !infoA.allGeneric())) + return true; + } + } - if (auto *call = getAsExpr(parent)) { - // The current parent appears in an argument position. - if (call->getFn() != curr) { - // Allow expressions that appear in a argument position to - // `build{Expression, Block, ...} methods. - if (auto *UDE = getAsExpr(call->getFn())) { - disqualified = - !isResultBuilderMethodReference(cs.getASTContext(), UDE); - } else { - disqualified = true; - } - } - } + // If both sides are either operators or non-operators, there is + // no preference. + if (isOperatorDisjunction(disjunctionA) == + isOperatorDisjunction(disjunctionB)) + return std::nullopt; - if (disqualified) - break; + // Prefer outer disjunctions to inner ones. This would make sure that + // i.e. we don't select operators inside of a closure before outer + // members that precede it are selected. + { + auto *dcA = getDisjunctionDC(disjunctionA); + auto *dcB = getDisjunctionDC(disjunctionB); - curr = parent; + if (dcA && dcB && dcA != dcB) { + return isBeforeInBuffer(cs.getASTContext(), disjunctionA, disjunctionB); } + } - if (disqualified) - continue; - - if (auto depth = cs.getExprDepth(member)) { - if (!best.first || best.second > depth) - best = std::make_pair(disjunction, depth.value()); + // If disjunctions appear in the same declaration context and it + // happens to be a result builder, we need to prefer members + // over operators if member doesn't appear in an operator context + // itself. + if (isInResultBuilderContext(cs, disjunctionA)) { + if (isOperatorDisjunction(disjunctionA)) { + return isInOperatorContext(cs, disjunctionB); + } else { + return !isInOperatorContext(cs, disjunctionA); } } - return best.first; + return std::nullopt; } std::optional>> @@ -1278,11 +1566,6 @@ ConstraintSystem::selectDisjunction() { llvm::DenseMap favorings; determineBestChoicesInContext(*this, disjunctions, favorings); - if (auto *disjunction = - selectDisjunctionInResultBuilderContext(*this, disjunctions)) { - return std::make_pair(disjunction, favorings[disjunction].FavoredChoices); - } - // Pick the disjunction with the smallest number of favored, then active // choices. auto bestDisjunction = std::min_element( @@ -1291,19 +1574,28 @@ ConstraintSystem::selectDisjunction() { unsigned firstActive = first->countActiveNestedConstraints(); unsigned secondActive = second->countActiveNestedConstraints(); - auto &[firstScore, firstFavoredChoices] = favorings[first]; - auto &[secondScore, secondFavoredChoices] = favorings[second]; + if (firstActive == 1 || secondActive == 1) + return secondActive != 1; + + auto &firstInfo = favorings[first]; + auto &secondInfo = favorings[second]; + + // Determine whether `first` is better based on a non-score + // preference rule first. + if (auto preference = + isPreferable(*this, first, firstInfo, second, secondInfo)) + return preference.value(); // Rank based on scores only if both disjunctions are supported. - if (firstScore && secondScore) { + if (firstInfo.Score && secondInfo.Score) { // If both disjunctions have the same score they should be ranked // based on number of favored/active choices. - if (*firstScore != *secondScore) - return *firstScore > *secondScore; + if (*firstInfo.Score != *secondInfo.Score) + return *firstInfo.Score > *secondInfo.Score; } - unsigned numFirstFavored = firstFavoredChoices.size(); - unsigned numSecondFavored = secondFavoredChoices.size(); + unsigned numFirstFavored = firstInfo.FavoredChoices.size(); + unsigned numSecondFavored = secondInfo.FavoredChoices.size(); if (numFirstFavored == numSecondFavored) { if (firstActive != secondActive) diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 1e9d4e9bb4a25..1c73e827b6099 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -1687,6 +1687,8 @@ void DisjunctionChoiceProducer::partitionDisjunction( // end of the partitioning. SmallVector favored; SmallVector everythingElse; + // Disfavored choices are part of `everythingElse` but introduced at the end. + SmallVector disfavored; SmallVector simdOperators; SmallVector disabled; SmallVector unavailable; @@ -1719,6 +1721,11 @@ void DisjunctionChoiceProducer::partitionDisjunction( everythingElse.push_back(index); return true; } + + if (decl->getAttrs().hasAttribute()) { + disfavored.push_back(index); + return true; + } } return false; @@ -1762,6 +1769,9 @@ void DisjunctionChoiceProducer::partitionDisjunction( return true; }); + // Introduce disfavored choices at the end. + everythingElse.append(disfavored); + // Local function to create the next partition based on the options // passed in. PartitionAppendCallback appendPartition = diff --git a/test/Constraints/implicit_double_cgfloat_conversion.swift b/test/Constraints/implicit_double_cgfloat_conversion.swift index b34d8a88e566d..6ebeb313fe30f 100644 --- a/test/Constraints/implicit_double_cgfloat_conversion.swift +++ b/test/Constraints/implicit_double_cgfloat_conversion.swift @@ -363,6 +363,10 @@ do { } func test_cgfloat_operator_is_attempted_with_literal_arguments(v: CGFloat?) { + // Make sure that @autoclosure thunk calls CGFloat./ and not Double./ + // CHECK-LABEL: sil private [transparent] [ossa] @$s34implicit_double_cgfloat_conversion05test_C45_operator_is_attempted_with_literal_arguments1vy12CoreGraphics7CGFloatVSg_tFAFyKXEfu_ + // CHECK: [[CGFLOAT_DIV_OP:%.*]] = function_ref @$s12CoreGraphics7CGFloatV34implicit_double_cgfloat_conversionE1doiyA2C_ACtFZ : $@convention(method) (CGFloat, CGFloat, @thin CGFloat.Type) -> CGFloat + // CHECK-NEXT: {{.*}} = apply [[CGFLOAT_DIV_OP]]({{.*}}, %2) : $@convention(method) (CGFloat, CGFloat, @thin CGFloat.Type) -> CGFloat let ratio = v ?? (2.0 / 16.0) let _: CGFloat = ratio // Ok } diff --git a/test/Constraints/nil-coalescing-favoring.swift b/test/Constraints/nil-coalescing-favoring.swift index 1723900d5a58d..5dbdeb1f69b10 100644 --- a/test/Constraints/nil-coalescing-favoring.swift +++ b/test/Constraints/nil-coalescing-favoring.swift @@ -12,3 +12,19 @@ struct A { self.init(other ?? ._none) } } + +do { + class Super {} + class Sub: Super {} + + func flatMap(_: (Int) -> R?) -> R? {} + + func test() { + let dict: Dictionary + let sup: Super + + // CHECK: declref_expr type="(consuming Super?, @autoclosure () throws -> Super) throws -> Super" {{.*}} decl="Swift.(file).?? + let x = flatMap { dict[$0] } ?? sup // Ok + let _: Super = x + } +} diff --git a/test/Constraints/old_hack_related_ambiguities.swift b/test/Constraints/old_hack_related_ambiguities.swift index da8da45cc76dc..c9e268948d440 100644 --- a/test/Constraints/old_hack_related_ambiguities.swift +++ b/test/Constraints/old_hack_related_ambiguities.swift @@ -6,14 +6,18 @@ func entity(_: Int) -> Int { struct Test { func test(_ v: Int) -> Int { v } + // expected-note@-1 {{found this candidate}} func test(_ v: Int?) -> Int? { v } + // expected-note@-1 {{found this candidate}} } func test_ternary_literal(v: Test) -> Int? { - true ? v.test(0) : nil // Ok + // Literals don't have a favored type + true ? v.test(0) : nil // expected-error {{ambiguous use of 'test'}} } func test_ternary(v: Test) -> Int? { + // Because calls had favored types set if they were resolved during constraint generation. true ? v.test(entity(0)) : nil // Ok } @@ -159,12 +163,14 @@ do { var p: UnsafeMutableRawPointer { get { fatalError() } } func f(_ p: UnsafeMutableRawPointer) { + // The old hack (which is now removed) couldn't handle member references, only direct declaration references. guard let x = UnsafeMutablePointer(OpaquePointer(self.p)) else { return } _ = x guard let x = UnsafeMutablePointer(OpaquePointer(p)) else { + // expected-error@-1 {{initializer for conditional binding must have Optional type, not 'UnsafeMutablePointer'}} return } _ = x @@ -257,3 +263,15 @@ func test_non_default_literal_use(arg: Float) { let v = arg * 2.0 // shouldn't use `(Float, Double) -> Double` overload let _: Float = v // Ok } + +// This should be ambiguous without contextual type but was accepted before during to +// unlabeled unary argument favoring. +func test_variadic_static_member_is_preferred_over_partially_applied_instance_overload() { + struct Test { + func fn() {} + static func fn(_: Test...) {} + } + + let t: Test + Test.fn(t) // Ok +} diff --git a/test/Constraints/overload.swift b/test/Constraints/overload.swift index 59f918dddc257..0f2c73ff2e13e 100644 --- a/test/Constraints/overload.swift +++ b/test/Constraints/overload.swift @@ -349,3 +349,25 @@ do { } } } + +// Make sure that the solver properly handles mix of non-default integer and floating-point literals +do { + func test( + withInitialValue initialValue: Float, + increment: Float, + count: Int) -> [Float] {} + + func test( + withInitialValue initialValue: Double, + increment: Double, + count: Int) -> [Double] {} + + + func testDoubleVsFloat(count: Int) { + let returnedResult = test(withInitialValue: 0, + increment: 0.1, + count: count) + + let _: [Double] = returnedResult // Ok + } +} diff --git a/validation-test/Sema/SwiftUI/swiftui_multiple_chained_members_in_inner_closure.swift b/validation-test/Sema/SwiftUI/swiftui_multiple_chained_members_in_inner_closure.swift new file mode 100644 index 0000000000000..dbfc309569c85 --- /dev/null +++ b/validation-test/Sema/SwiftUI/swiftui_multiple_chained_members_in_inner_closure.swift @@ -0,0 +1,34 @@ +// RUN: %target-typecheck-verify-swift -target %target-cpu-apple-macosx12 -solver-scope-threshold=10000 + +// REQUIRES: OS=macosx +// REQUIRES: objc_interop + +import Foundation +import SwiftUI + +struct MyView: View { + public enum Style { + case focusRing(platterSize: CGSize, stroke: CGFloat, offset: CGFloat) + } + + var style: Style + var isFocused: Bool + var focusColor: Color + + var body: some View { + Group { + switch style { + case let .focusRing(platterSize: platterSize, stroke: focusRingStroke, offset: focusRingOffset): + Circle() + .overlay { + Circle() + .stroke(isFocused ? focusColor : Color.clear, lineWidth: focusRingStroke) + .frame( + width: platterSize.width + (2 * focusRingOffset) + focusRingStroke, + height: platterSize.height + (2 * focusRingOffset) + focusRingStroke + ) + } + } + } + } +} diff --git a/validation-test/Sema/implicit_cgfloat_double_conversion_correctness.swift b/validation-test/Sema/implicit_cgfloat_double_conversion_correctness.swift index 121812a279c67..4aeb205e7f129 100644 --- a/validation-test/Sema/implicit_cgfloat_double_conversion_correctness.swift +++ b/validation-test/Sema/implicit_cgfloat_double_conversion_correctness.swift @@ -47,3 +47,8 @@ func test_atan_ambiguity(points: (CGPoint, CGPoint)) { test = atan((points.1.y - points.0.y) / (points.1.x - points.0.x)) // Ok _ = test } + +func test_ambigity_with_generic_funcs(a: CGFloat, b: CGFloat) -> [CGFloat] { + let result = [round(abs(a - b) * 100) / 100.0] + return result +} diff --git a/validation-test/Sema/issue78371.swift b/validation-test/Sema/issue78371.swift index e986ad173e291..aaf39a8fb4fa2 100644 --- a/validation-test/Sema/issue78371.swift +++ b/validation-test/Sema/issue78371.swift @@ -16,7 +16,12 @@ extension Optional where Wrapped == Scalar { static func ==(_: Wrapped?, _: Wrapped?) -> Wrapped { } } +// FIXME: There is currently no way to fix this because even +// if the new overload of `==` is scored the same as concrete +// one it would be skipped because it's generic _if_ operator +// is attempted before initializer. + func test(a: Scalar) { let result = a == Scalar(0x07FD) - let _: Scalar = result // Ok + let _: Scalar = result // expected-error {{cannot convert value of type 'Bool' to specified type 'Scalar'}} } diff --git a/validation-test/Sema/type_checker_perf/fast/complex_swiftui_padding_conditions.swift b/validation-test/Sema/type_checker_perf/fast/complex_swiftui_padding_conditions.swift index df038a16ff4cc..4b96ac54851ea 100644 --- a/validation-test/Sema/type_checker_perf/fast/complex_swiftui_padding_conditions.swift +++ b/validation-test/Sema/type_checker_perf/fast/complex_swiftui_padding_conditions.swift @@ -1,4 +1,4 @@ -// RUN: %target-typecheck-verify-swift -target %target-cpu-apple-macosx10.15 -swift-version 5 +// RUN: %target-typecheck-verify-swift -target %target-cpu-apple-macosx10.15 -solver-scope-threshold=1000 // REQUIRES: OS=macosx import SwiftUI diff --git a/validation-test/Sema/type_checker_perf/fast/contextual_cgfloat_type_with_operator_chain_arguments.swift b/validation-test/Sema/type_checker_perf/fast/contextual_cgfloat_type_with_operator_chain_arguments.swift new file mode 100644 index 0000000000000..72fb47673c4ea --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/contextual_cgfloat_type_with_operator_chain_arguments.swift @@ -0,0 +1,17 @@ +// RUN: %target-typecheck-verify-swift -solver-scope-threshold=50 + +// REQUIRES: OS=macosx,no_asan +// REQUIRES: objc_interop + +import Foundation + +struct Size { + var width: CGFloat + var height: CGFloat +} + +func frame(width: CGFloat?, height: CGFloat?) {} + +func test(size: Size?) { + frame(width: ((size?.width ?? 0) * 1) + 1.0, height: ((size?.height ?? 0) * 1) + 1.0) +} diff --git a/validation-test/Sema/type_checker_perf/fast/nil_coalescing_dictionary_values.swift.gyb b/validation-test/Sema/type_checker_perf/fast/nil_coalescing_dictionary_values.swift.gyb new file mode 100644 index 0000000000000..4c15caad1276c --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/nil_coalescing_dictionary_values.swift.gyb @@ -0,0 +1,10 @@ +// RUN: %scale-test --begin 1 --end 10 --step 1 --select NumLeafScopes %s -Xfrontend=-typecheck +// REQUIRES: asserts, no_asan + +let x: Int? + +let _ = [ +%for i in range(0, N): + "%{i}" : x ?? 0, +%end +] diff --git a/validation-test/Sema/type_checker_perf/fast/operator_chains_separated_by_ternary.swift.gyb b/validation-test/Sema/type_checker_perf/fast/operator_chains_separated_by_ternary.swift.gyb new file mode 100644 index 0000000000000..9e3a298c66295 --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/operator_chains_separated_by_ternary.swift.gyb @@ -0,0 +1,18 @@ +// RUN: %scale-test --begin 1 --end 12 --step 1 --select NumLeafScopes %s +// REQUIRES: asserts,no_asan + +func compute(_: UInt32) { +} + +func test(cond: Bool) { + compute(cond + ? 1 +%for i in range(1, N): + + 1 +%end + : 1 +%for i in range(1, N): + * 1 +%end + ) +} diff --git a/validation-test/Sema/type_checker_perf/fast/operators_inside_closure.swift b/validation-test/Sema/type_checker_perf/fast/operators_inside_closure.swift new file mode 100644 index 0000000000000..28ea48f2c6bcf --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/operators_inside_closure.swift @@ -0,0 +1,9 @@ +// RUN: %target-typecheck-verify-swift -solver-scope-threshold=5000 + +// REQUIRES: tools-release,no_asan + +func test(_ ids: [UInt64]) { + _ = zip(ids[ids.indices.dropLast()], ids[ids.indices.dropFirst()]).map { pair in + ((pair.0 % 2 == 0) && (pair.1 % 2 == 1)) ? UInt64(pair.1 - pair.0) : 42 + } +} diff --git a/validation-test/Sema/type_checker_perf/fast/rdar144100160.swift b/validation-test/Sema/type_checker_perf/fast/rdar144100160.swift new file mode 100644 index 0000000000000..7e80fd5c3d6c8 --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/rdar144100160.swift @@ -0,0 +1,12 @@ +// RUN: %target-typecheck-verify-swift -solver-scope-threshold=100 +// REQUIRES: asserts,no_asan + +typealias TimeInterval = Double + +struct Date { + func addingTimeInterval(_: TimeInterval) -> Date { Date() } +} + +func test(date: Date) { + _ = date.addingTimeInterval(TimeInterval(60 * 60 * 24 * 6 + 12 * 60 + 12 + 1)) +} diff --git a/validation-test/Sema/type_checker_perf/fast/rdar47492691.swift b/validation-test/Sema/type_checker_perf/fast/rdar47492691.swift index 47fcdcb0771af..7a25cca734f79 100644 --- a/validation-test/Sema/type_checker_perf/fast/rdar47492691.swift +++ b/validation-test/Sema/type_checker_perf/fast/rdar47492691.swift @@ -7,3 +7,7 @@ import simd func test(foo: CGFloat, bar: CGFloat) { _ = CGRect(x: 0.0 + 1.0, y: 0.0 + foo, width: 3.0 - 1 - 1 - 1.0, height: bar) } + +func test_with_generic_func_and_literals(bounds: CGRect) { + _ = CGRect(x: 0, y: 0, width: 1, height: bounds.height - 2 + bounds.height / 2 + max(bounds.height / 2, bounds.height / 2)) +} diff --git a/validation-test/Sema/type_checker_perf/slow/nil_coalescing_dictionary_values.swift.gyb b/validation-test/Sema/type_checker_perf/slow/nil_coalescing_dictionary_values.swift.gyb deleted file mode 100644 index 878c0a9dc78d1..0000000000000 --- a/validation-test/Sema/type_checker_perf/slow/nil_coalescing_dictionary_values.swift.gyb +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: %scale-test --invert-result --begin 1 --end 8 --step 1 --select NumLeafScopes %s -Xfrontend=-typecheck -// REQUIRES: asserts, no_asan - -let x: Int? - -let _ = [ -%for i in range(0, N): - "k" : x ?? 0, -%end -]