diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index e2eb5e9eecc18..cfaecc9be7188 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -2052,7 +2052,7 @@ bool TypeVarBindingProducer::computeNext() { } } - if (NumTries == 0) { + if (newBindings.empty()) { // Add defaultable constraints (if any). for (auto *constraint : DelayedDefaults) { if (constraint->getKind() == ConstraintKind::FallbackType) { @@ -2065,6 +2065,9 @@ bool TypeVarBindingProducer::computeNext() { addNewBinding(getDefaultBinding(constraint)); } + + // Drop all of the default since we have converted them into bindings. + DelayedDefaults.clear(); } if (newBindings.empty()) diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index 6eb17238b7b63..9ff2046aa8e26 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -395,3 +395,43 @@ TEST_F(SemaTest, TestNoDoubleVoidClosureResultInference) { verifyInference(closureResultWithoutVoid, 3); } + +TEST_F(SemaTest, TestSupertypeInferenceWithDefaults) { + ConstraintSystemOptions options; + ConstraintSystem cs(DC, options); + + auto *genericArg = cs.createTypeVariable( + cs.getConstraintLocator({}, ConstraintLocator::GenericArgument), + /*options=*/0); + + // KeyPath i.e. \.utf8.count or something similar + auto keyPath = + BoundGenericType::get(Context.getKeyPathDecl(), /*parent=*/Type(), + {getStdlibType("String"), getStdlibType("Int")}); + + cs.addConstraint(ConstraintKind::Conversion, keyPath, genericArg, + cs.getConstraintLocator({})); + + cs.addConstraint(ConstraintKind::Defaultable, genericArg, Context.TheAnyType, + cs.getConstraintLocator({})); + + auto bindings = cs.getBindingsFor(genericArg); + TypeVarBindingProducer producer(bindings); + + llvm::SmallVector inferredTypes; + while (auto binding = producer()) { + ASSERT_TRUE(binding.has_value()); + inferredTypes.push_back(binding->getType()); + } + + // The inference should produce 4 types: KeyPath, + // PartialKeyPath, AnyKeyPath and Any - in that order. + + ASSERT_EQ(inferredTypes.size(), 4); + ASSERT_TRUE(inferredTypes[0]->isEqual(keyPath)); + ASSERT_TRUE(inferredTypes[1]->isEqual( + BoundGenericType::get(Context.getPartialKeyPathDecl(), + /*parent=*/Type(), {getStdlibType("String")}))); + ASSERT_TRUE(inferredTypes[2]->isEqual(getStdlibType("AnyKeyPath"))); + ASSERT_TRUE(inferredTypes[3]->isEqual(Context.TheAnyType)); +}