Skip to content

[CS] Clean up pack expansion environment handling a little #79529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/swift/Sema/CSTrail.def
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ LOCATOR_CHANGE(RecordedAppliedDisjunction, AppliedDisjunctions)
LOCATOR_CHANGE(RecordedMatchCallArgumentResult, argumentMatchingChoices)
LOCATOR_CHANGE(RecordedOpenedTypes, OpenedTypes)
LOCATOR_CHANGE(RecordedOpenedExistentialType, OpenedExistentialTypes)
LOCATOR_CHANGE(RecordedPackExpansionEnvironment, PackExpansionEnvironments)
LOCATOR_CHANGE(RecordedDefaultedConstraint, DefaultedConstraints)
LOCATOR_CHANGE(ResolvedOverload, ResolvedOverloads)
LOCATOR_CHANGE(RecordedArgumentList, ArgumentLists)
Expand Down Expand Up @@ -95,7 +94,8 @@ CHANGE(AddedConversionRestriction)
CHANGE(AddedFix)
CHANGE(AddedFixedRequirement)
CHANGE(RecordedOpenedPackExpansionType)
CHANGE(RecordedPackEnvironment)
CHANGE(RecordedPackElementExpansion)
CHANGE(RecordedPackExpansionEnvironment)
CHANGE(RecordedNodeType)
CHANGE(RecordedKeyPathComponentType)
CHANGE(RecordedResultBuilderTransform)
Expand All @@ -118,4 +118,4 @@ LAST_CHANGE(RetractedBinding)
#undef BINDING_RELATION_CHANGE
#undef SCORE_CHANGE
#undef LAST_CHANGE
#undef CHANGE
#undef CHANGE
9 changes: 7 additions & 2 deletions include/swift/Sema/CSTrail.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class SolverTrail {
ConstraintFix *TheFix;
ConstraintLocator *TheLocator;
PackExpansionType *TheExpansion;
PackExpansionExpr *TheExpansionExpr;
PackElementExpr *TheElement;
Expr *TheExpr;
Stmt *TheStmt;
Expand Down Expand Up @@ -213,7 +214,11 @@ class SolverTrail {

/// Create a change that recorded a mapping from a pack element expression
/// to its parent expansion expression.
static Change RecordedPackEnvironment(PackElementExpr *packElement);
static Change RecordedPackElementExpansion(PackElementExpr *packElement);

/// Create a change that records the GenericEnvironment for a given
/// PackExpansionExpr.
static Change RecordedPackExpansionEnvironment(PackExpansionExpr *expr);

/// Create a change that recorded an assignment of a type to an AST node.
static Change RecordedNodeType(ASTNode node, Type oldType);
Expand Down Expand Up @@ -307,4 +312,4 @@ class SolverTrail {
} // namespace constraints
} // namespace swift

#endif // SWIFT_SEMA_CSTRAIL_H
#endif // SWIFT_SEMA_CSTRAIL_H
65 changes: 44 additions & 21 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1560,14 +1560,13 @@ class Solution {
llvm::DenseMap<PackExpansionType *, TypeVariableType *>
OpenedPackExpansionTypes;

/// The pack expansion environment that can open pack elements for
/// a given locator.
llvm::DenseMap<ConstraintLocator *, std::pair<UUID, Type>>
/// The generic environment that can open pack elements for a given
/// pack expansion.
llvm::DenseMap<PackExpansionExpr *, GenericEnvironment *>
PackExpansionEnvironments;

/// The pack expansion environment that can open a given pack element.
llvm::DenseMap<PackElementExpr *, PackExpansionExpr *>
PackEnvironments;
/// The pack expansion expression for a given pack element.
llvm::DenseMap<PackElementExpr *, PackExpansionExpr *> PackElementExpansions;

/// The locators of \c Defaultable constraints whose defaults were used.
llvm::DenseSet<ConstraintLocator *> DefaultedConstraints;
Expand Down Expand Up @@ -1811,6 +1810,11 @@ class Solution {
return Type();
}

/// Retrieve the generic environment for the opened element of a given pack
/// expansion, or \c nullptr if no environment was recorded.
GenericEnvironment *
getPackExpansionEnvironment(PackExpansionExpr *expr) const;

/// For a given locator describing a function argument conversion, or a
/// constraint within an argument conversion, returns information about the
/// application of the argument to its parameter. If the locator is not
Expand Down Expand Up @@ -2408,11 +2412,11 @@ class ConstraintSystem {
llvm::SmallDenseMap<PackExpansionType *, TypeVariableType *, 4>
OpenedPackExpansionTypes;

llvm::SmallDenseMap<ConstraintLocator *, std::pair<UUID, Type>, 4>
llvm::SmallDenseMap<PackExpansionExpr *, GenericEnvironment *, 4>
PackExpansionEnvironments;

llvm::SmallDenseMap<PackElementExpr *, PackExpansionExpr *, 2>
PackEnvironments;
PackElementExpansions;

llvm::SmallVector<GenericEnvironment *, 4> PackElementGenericEnvironments;

Expand Down Expand Up @@ -3371,25 +3375,39 @@ class ConstraintSystem {
void recordOpenedExistentialType(ConstraintLocator *locator,
OpenedArchetypeType *opened);

/// Get the opened element generic environment for the given locator.
GenericEnvironment *getPackElementEnvironment(ConstraintLocator *locator,
CanType shapeClass);
/// Retrieve the generic environment for the opened element of a given pack
/// expansion, or \c nullptr if no environment was recorded yet.
GenericEnvironment *
getPackExpansionEnvironment(PackExpansionExpr *expr) const;

/// Create a new opened element generic environment for the given pack
/// expansion.
GenericEnvironment *
createPackExpansionEnvironment(PackExpansionExpr *expr,
CanGenericTypeParamType shapeParam);

/// Update PackExpansionEnvironments and record a change in the trail.
void recordPackExpansionEnvironment(ConstraintLocator *locator,
std::pair<UUID, Type> uuidAndShape);
void recordPackExpansionEnvironment(PackExpansionExpr *expr,
GenericEnvironment *env);

/// Get the opened element generic environment for the given pack element.
PackExpansionExpr *getPackEnvironment(PackElementExpr *packElement) const;
/// Undo the above change.
void removePackExpansionEnvironment(PackExpansionExpr *expr) {
bool erased = PackExpansionEnvironments.erase(expr);
ASSERT(erased);
}

/// Associate an opened element generic environment to a pack element,
/// and record a change in the trail.
void addPackEnvironment(PackElementExpr *packElement,
PackExpansionExpr *packExpansion);
/// Get the pack expansion expr for the given pack element.
PackExpansionExpr *
getPackElementExpansion(PackElementExpr *packElement) const;

/// Associate a pack element with a given pack expansion, and record the
/// change in the trail.
void recordPackElementExpansion(PackElementExpr *packElement,
PackExpansionExpr *packExpansion);

/// Undo the above change.
void removePackEnvironment(PackElementExpr *packElement) {
bool erased = PackEnvironments.erase(packElement);
void removePackElementExpansion(PackElementExpr *packElement) {
bool erased = PackElementExpansions.erase(packElement);
ASSERT(erased);
}

Expand Down Expand Up @@ -5024,6 +5042,11 @@ class ConstraintSystem {
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Attempt to match a pack element type with the fully resolved pattern type
/// for the pack expansion.
SolutionKind matchPackElementType(Type elementType, Type patternType,
ConstraintLocatorBuilder locator);

/// Attempt to simplify a PackElementOf constraint.
///
/// Solving this constraint is delayed until the element type is fully
Expand Down
10 changes: 2 additions & 8 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3942,21 +3942,15 @@ namespace {
}

Expr *visitPackExpansionExpr(PackExpansionExpr *expr) {
simplifyExprType(expr);

// Set the opened pack element environment for this pack expansion.
auto expansionTy = cs.getType(expr)->castTo<PackExpansionType>();
auto *locator = cs.getConstraintLocator(expr);
auto *environment = cs.getPackElementEnvironment(locator,
expansionTy->getCountType()->getCanonicalType());

// Assert that we have an opened element environment, otherwise we'll get
// an ASTVerifier crash when pack archetypes or element archetypes appear
// inside the pack expansion expression.
auto *environment = solution.getPackExpansionEnvironment(expr);
assert(environment);
expr->setGenericEnvironment(environment);

return expr;
return simplifyExprType(expr);
}

Expr *visitPackElementExpr(PackElementExpr *expr) {
Expand Down
21 changes: 10 additions & 11 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,10 +902,10 @@ TypeVarRefCollector::walkToExprPre(Expr *expr) {
}

if (auto *packElement = getAsExpr<PackElementExpr>(expr)) {
// If environment hasn't been established yet, it means that pack expansion
// If expansion hasn't been established yet, it means that pack expansion
// appears inside of this closure.
if (auto *outerEnvironment = CS.getPackEnvironment(packElement))
inferTypeVars(outerEnvironment);
if (auto *outerExpansion = CS.getPackElementExpansion(packElement))
inferTypeVars(outerExpansion);
}

return Action::Continue(expr);
Expand Down Expand Up @@ -1214,9 +1214,8 @@ namespace {
SmallVector<ASTNode, 2> expandedPacks;
collectExpandedPacks(expr, expandedPacks);
for (auto pack : expandedPacks) {
if (auto *elementExpr = getAsExpr<PackElementExpr>(pack)) {
CS.addPackEnvironment(elementExpr, expr);
}
if (auto *elementExpr = getAsExpr<PackElementExpr>(pack))
CS.recordPackElementExpansion(elementExpr, expr);
}

auto *patternLoc = CS.getConstraintLocator(
Expand Down Expand Up @@ -3239,15 +3238,15 @@ namespace {

Type visitPackElementExpr(PackElementExpr *expr) {
auto packType = CS.getType(expr->getPackRefExpr());
auto *packEnvironment = CS.getPackEnvironment(expr);
auto *packExpansion = CS.getPackElementExpansion(expr);
auto elementType = openPackElement(
packType, CS.getConstraintLocator(expr), packEnvironment);
if (packEnvironment) {
packType, CS.getConstraintLocator(expr), packExpansion);
if (packExpansion) {
auto expansionType =
CS.getType(packEnvironment)->castTo<PackExpansionType>();
CS.getType(packExpansion)->castTo<PackExpansionType>();
CS.addConstraint(ConstraintKind::ShapeOf, expansionType->getCountType(),
packType,
CS.getConstraintLocator(packEnvironment,
CS.getConstraintLocator(packExpansion,
ConstraintLocator::PackShape));
} else {
CS.recordFix(AllowInvalidPackReference::create(
Expand Down
99 changes: 59 additions & 40 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9626,6 +9626,63 @@ ConstraintSystem::simplifyBindTupleOfFunctionParamsConstraint(
return SolutionKind::Solved;
}

ConstraintSystem::SolutionKind
ConstraintSystem::matchPackElementType(Type elementType, Type patternType,
ConstraintLocatorBuilder locator) {
auto tryFix = [&](llvm::function_ref<ConstraintFix *(void)> fix) {
if (!shouldAttemptFixes())
return SolutionKind::Error;

if (recordFix(fix()))
return SolutionKind::Error;

recordAnyTypeVarAsPotentialHole(elementType);
return SolutionKind::Solved;
};

auto *loc = getConstraintLocator(locator);
ASSERT(loc->directlyAt<PackExpansionExpr>());
auto *packExpansion = castToExpr<PackExpansionExpr>(loc->getAnchor());

ASSERT(!patternType->hasTypeVariable());
auto shapeClass = patternType->getReducedShape();

// `each` was applied to a concrete type.
if (!shapeClass->is<PackArchetypeType>()) {
return tryFix([&]() {
return AllowInvalidPackElement::create(*this, patternType, loc);
});
}

auto shapeParam = CanGenericTypeParamType(cast<GenericTypeParamType>(
shapeClass->mapTypeOutOfContext()->getCanonicalType()));

auto *genericEnv = getPackExpansionEnvironment(packExpansion);
if (genericEnv) {
if (shapeParam != genericEnv->getOpenedElementShapeClass()) {
return tryFix([&]() {
auto envShape = genericEnv->mapTypeIntoContext(
genericEnv->getOpenedElementShapeClass());
if (auto *pack = dyn_cast<PackType>(envShape))
envShape = pack->unwrapSingletonPackExpansion()->getPatternType();

return SkipSameShapeRequirement::create(
*this, envShape, shapeClass,
getConstraintLocator(loc, ConstraintLocator::PackShape));
});
}
} else {
genericEnv = createPackExpansionEnvironment(packExpansion, shapeParam);
}

auto expectedElementTy =
genericEnv->mapContextualPackTypeIntoElementContext(patternType);
assert(!expectedElementTy->is<PackType>());

addConstraint(ConstraintKind::Equal, elementType, expectedElementTy, locator);
return SolutionKind::Solved;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyPackElementOfConstraint(Type first, Type second,
TypeMatchOptions flags,
Expand Down Expand Up @@ -9660,46 +9717,8 @@ ConstraintSystem::simplifyPackElementOfConstraint(Type first, Type second,
}

// Let's try to resolve element type based on the pattern type.
if (!patternType->hasTypeVariable()) {
auto *loc = getConstraintLocator(locator);
auto shapeClass = patternType->getReducedShape();
auto *elementEnv = getPackElementEnvironment(loc, shapeClass);

// Without an opened element environment, we cannot derive the
// element binding.
if (!elementEnv) {
if (!shouldAttemptFixes())
return SolutionKind::Error;

// `each` was applied to a concrete type.
if (!shapeClass->is<PackArchetypeType>()) {
if (recordFix(AllowInvalidPackElement::create(*this, patternType, loc)))
return SolutionKind::Error;
} else {
auto envShape = PackExpansionEnvironments.find(loc);
if (envShape == PackExpansionEnvironments.end()) {
return SolutionKind::Error;
}
auto *fix = SkipSameShapeRequirement::create(
*this, envShape->second.second, shapeClass,
getConstraintLocator(loc, ConstraintLocator::PackShape));
if (recordFix(fix)) {
return SolutionKind::Error;
}
}

recordAnyTypeVarAsPotentialHole(elementType);
return SolutionKind::Solved;
}

auto expectedElementTy =
elementEnv->mapContextualPackTypeIntoElementContext(patternType);
assert(!expectedElementTy->is<PackType>());

addConstraint(ConstraintKind::Equal, elementType, expectedElementTy,
locator);
return SolutionKind::Solved;
}
if (!patternType->hasTypeVariable())
return matchPackElementType(elementType, patternType, locator);

// Otherwise we are inferred or checking pattern type.

Expand Down
12 changes: 6 additions & 6 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ Solution ConstraintSystem::finalize() {
solution.PackExpansionEnvironments.insert(env);
}

for (const auto &packEnv : PackEnvironments)
solution.PackEnvironments.insert(packEnv);
for (const auto &packEnv : PackElementExpansions)
solution.PackElementExpansions.insert(packEnv);

for (const auto &synthesized : SynthesizedConformances) {
solution.SynthesizedConformances.insert(synthesized);
Expand Down Expand Up @@ -355,10 +355,10 @@ void ConstraintSystem::replaySolution(const Solution &solution,
recordPackExpansionEnvironment(expansion.first, expansion.second);
}

// Register the solutions's pack environments.
for (auto &packEnvironment : solution.PackEnvironments) {
if (PackEnvironments.count(packEnvironment.first) == 0)
addPackEnvironment(packEnvironment.first, packEnvironment.second);
// Register the solutions's pack expansions.
for (auto &packEnvironment : solution.PackElementExpansions) {
if (PackElementExpansions.count(packEnvironment.first) == 0)
recordPackElementExpansion(packEnvironment.first, packEnvironment.second);
}

// Register the defaulted type variables.
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/CSSyntacticElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ class TypeVariableRefFinder : public ASTWalker {
// that reference pack elements have to bring expansion's shape
// type in scope to make sure that the shapes match.
if (auto *packElement = getAsExpr<PackElementExpr>(expr)) {
if (auto *outerEnvironment = CS.getPackEnvironment(packElement)) {
auto *expansionTy = CS.simplifyType(CS.getType(outerEnvironment))
if (auto *outerExpansion = CS.getPackElementExpansion(packElement)) {
auto *expansionTy = CS.simplifyType(CS.getType(outerExpansion))
->castTo<PackExpansionType>();
expansionTy->getCountType()->getTypeVariables(ReferencedVars);
}
Expand Down
Loading