diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h index c8167014b5300..e481ef85562b3 100644 --- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h +++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h @@ -118,6 +118,31 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron { /// we explicitly introduce them here. using IntegerPolyhedron::addBound; + /// Returns a non-negative constant bound on the extent (upper bound - lower + /// bound) of the specified variable if it is found to be a constant; returns + /// std::nullopt if it's not a constant. This method treats symbolic + /// variables specially, i.e., it looks for constant differences between + /// affine expressions involving only the symbolic variables. 'lb', if + /// provided, is set to the lower bound map associated with the constant + /// difference, and similarly, `ub` to the upper bound. Note that 'lb', 'ub' + /// are purely symbolic and will correspond to the symbolic variables of the + /// constaint set. + // Egs: 0 <= i <= 15, return 16. + // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) + // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. + // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = + // ceil(s0 - 7 / 8) = floor(s0 / 8)). + /// The difference between this method and + /// IntegerRelation::getConstantBoundOnDimSize is that unlike the latter, this + /// makes use of affine expressions and maps in its inference and provides + /// output with affine maps; it thus handles local variables by detecting them + /// as affine functions of the symbols when possible. + std::optional + getConstantBoundOnDimSize(MLIRContext *context, unsigned pos, + AffineMap *lb = nullptr, AffineMap *ub = nullptr, + unsigned *minLbPos = nullptr, + unsigned *minUbPos = nullptr) const; + /// Returns the constraint system as an integer set. Returns a null integer /// set if the system has no constraints, or if an integer set couldn't be /// constructed as a result of a local variable's explicit representation not diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index fa29ac23af607..85ff6da1b0b98 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -152,6 +152,13 @@ class IntegerRelation { /// intersection with no simplification of any sort attempted. void append(const IntegerRelation &other); + /// Finds an equality that equates the specified variable to a constant. + /// Returns the position of the equality row. If 'symbolic' is set to true, + /// symbols are also treated like a constant, i.e., an affine function of the + /// symbols is also treated like a constant. Returns -1 if such an equality + /// could not be found. + int findEqualityToConstant(unsigned pos, bool symbolic = false) const; + /// Return the intersection of the two relations. /// If there are locals, they will be merged. IntegerRelation intersect(IntegerRelation other) const; diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h index de9ed6a683c24..0dd8de4f70039 100644 --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -500,6 +500,8 @@ struct MemRefRegion { /// to slice operands (which correspond to symbols). /// If 'addMemRefDimBounds' is true, constant upper/lower bounds /// [0, memref.getDimSize(i)) are added for each MemRef dimension 'i'. + /// If `dropLocalVars` is true, all local variables in `cst` are projected + /// out. /// /// For example, the memref region for this operation at loopDepth = 1 will /// be: @@ -513,9 +515,14 @@ struct MemRefRegion { /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineValueConstraints symbolic in %i. /// + /// If `dropOuterIVs` is true, project out any IVs other than those among + /// `loopDepth` surrounding IVs, which would be symbols. If `dropOuterIVs` + /// is false, the IVs would be turned into local variables instead of being + /// projected out. LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState = nullptr, - bool addMemRefDimBounds = true); + bool addMemRefDimBounds = true, + bool dropLocalVars = true, bool dropOuterIVs = true); FlatAffineValueConstraints *getConstraints() { return &cst; } const FlatAffineValueConstraints *getConstraints() const { return &cst; } @@ -530,31 +537,18 @@ struct MemRefRegion { /// corresponding dimension-wise bounds major to minor. The number of elements /// and all the dimension-wise bounds are guaranteed to be non-negative. We /// use int64_t instead of uint64_t since index types can be at most - /// int64_t. `lbs` are set to the lower bounds for each of the rank - /// dimensions, and lbDivisors contains the corresponding denominators for - /// floorDivs. + /// int64_t. `lbs` are set to the lower bound maps for each of the rank + /// dimensions where each of these maps is purely symbolic in the constraints + /// set's symbols. std::optional getConstantBoundingSizeAndShape( SmallVectorImpl *shape = nullptr, - std::vector> *lbs = nullptr, - SmallVectorImpl *lbDivisors = nullptr) const; + SmallVectorImpl *lbs = nullptr) const; /// Gets the lower and upper bound map for the dimensional variable at /// `pos`. void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const; - /// A wrapper around FlatAffineValueConstraints::getConstantBoundOnDimSize(). - /// 'pos' corresponds to the position of the memref shape's dimension (major - /// to minor) which matches 1:1 with the dimensional variable positions in - /// 'cst'. - std::optional - getConstantBoundOnDimSize(unsigned pos, - SmallVectorImpl *lb = nullptr, - int64_t *lbFloorDivisor = nullptr) const { - assert(pos < getRank() && "invalid position"); - return cst.getConstantBoundOnDimSize64(pos, lb); - } - /// Returns the size of this MemRefRegion in bytes. std::optional getRegionSize(); diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 8c179cb2a38ba..6ad39a3a91293 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -581,6 +581,49 @@ std::pair FlatLinearConstraints::getLowerAndUpperBound( return {lbMap, ubMap}; } +/// Express the pos^th identifier of `cst` as an affine expression in +/// terms of other identifiers, if they are available in `exprs`, using the +/// equality at position `idx` in `cs`t. Populates `exprs` with such an +/// expression if possible, and return true. Returns false otherwise. +static bool detectAsExpr(const FlatLinearConstraints &cst, unsigned pos, + unsigned idx, MLIRContext *context, + SmallVectorImpl &exprs) { + // Initialize with a `0` expression. + auto expr = getAffineConstantExpr(0, context); + + // Traverse `idx`th equality and construct the possible affine expression in + // terms of known identifiers. + unsigned j, e; + for (j = 0, e = cst.getNumVars(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = cst.atEq64(idx, j); + if (c == 0) + continue; + // If any of the involved IDs hasn't been found yet, we can't proceed. + if (!exprs[j]) + break; + expr = expr + exprs[j] * c; + } + if (j < e) + // Can't construct expression as it depends on a yet uncomputed + // identifier. + return false; + + // Add constant term to AffineExpr. + expr = expr + cst.atEq64(idx, cst.getNumVars()); + int64_t vPos = cst.atEq64(idx, pos); + assert(vPos != 0 && "expected non-zero here"); + if (vPos > 0) + expr = (-expr).floorDiv(vPos); + else + // vPos < 0. + expr = expr.floorDiv(-vPos); + // Successfully constructed expression. + exprs[pos] = expr; + return true; +} + /// Compute a representation of `num` identifiers starting at `offset` in `cst` /// as affine expressions involving other known identifiers. Each identifier's /// expression (in terms of known identifiers) is populated into `memo`. @@ -636,41 +679,13 @@ static void computeUnknownVars(const FlatLinearConstraints &cst, // Detect a variable as an expression of other variables. std::optional idx; - if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true))) { + if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true))) continue; - } - // Build AffineExpr solving for variable 'pos' in terms of all others. - auto expr = getAffineConstantExpr(0, context); - unsigned j, e; - for (j = 0, e = cst.getNumVars(); j < e; ++j) { - if (j == pos) - continue; - int64_t c = cst.atEq64(*idx, j); - if (c == 0) - continue; - // If any of the involved IDs hasn't been found yet, we can't proceed. - if (!memo[j]) - break; - expr = expr + memo[j] * c; - } - if (j < e) - // Can't construct expression as it depends on a yet uncomputed - // variable. + if (detectAsExpr(cst, pos, *idx, context, memo)) { + changed = true; continue; - - // Add constant term to AffineExpr. - expr = expr + cst.atEq64(*idx, cst.getNumVars()); - int64_t vPos = cst.atEq64(*idx, pos); - assert(vPos != 0 && "expected non-zero here"); - if (vPos > 0) - expr = (-expr).floorDiv(vPos); - else - // vPos < 0. - expr = expr.floorDiv(-vPos); - // Successfully constructed expression. - memo[pos] = expr; - changed = true; + } } // This loop is guaranteed to reach a fixed point - since once an // variable's explicit form is computed (in memo[pos]), it's not updated @@ -891,6 +906,185 @@ FlatLinearConstraints::computeLocalVars(SmallVectorImpl &memo, llvm::all_of(localExprs, [](AffineExpr expr) { return expr; })); } +/// Given an equality or inequality (`isEquality` used to disambiguate) of `cst` +/// at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the +/// `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If +/// the equality/inequality contains any unknown id, return None. Otherwise +/// return sum as `AffineExpr`. +static std::optional getAsExpr(const FlatLinearConstraints &cst, + unsigned pos, MLIRContext *context, + ArrayRef exprs, + unsigned idx, bool isEquality) { + // Initialize with a `0` expression. + auto expr = getAffineConstantExpr(0, context); + + SmallVector row = + isEquality ? cst.getEquality64(idx) : cst.getInequality64(idx); + + // Traverse `idx`th equality and construct the possible affine expression in + // terms of known identifiers. + unsigned j, e; + for (j = 0, e = cst.getNumVars(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = row[j]; + if (c == 0) + continue; + // If any of the involved IDs hasn't been found yet, we can't proceed. + if (!exprs[j]) + break; + expr = expr + exprs[j] * c; + } + if (j < e) + // Can't construct expression as it depends on a yet uncomputed + // identifier. + return std::nullopt; + + // Add constant term to AffineExpr. + expr = expr + row[cst.getNumVars()]; + return expr; +} + +std::optional FlatLinearConstraints::getConstantBoundOnDimSize( + MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub, + unsigned *minLbPos, unsigned *minUbPos) const { + + assert(pos < getNumDimVars() && "Invalid identifier position"); + + auto freeOfUnknownLocalVars = [&](ArrayRef cst, + ArrayRef whiteListCols) { + for (int i = getNumDimAndSymbolVars(), e = cst.size() - 1; i < e; ++i) { + if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant()) + continue; + if (cst[i] != 0) + return false; + } + return true; + }; + + // Detect the necesary local variables first. + SmallVector memo(getNumVars(), AffineExpr()); + (void)computeLocalVars(memo, context); + + // Find an equality for 'pos'^th identifier that equates it to some function + // of the symbolic identifiers (+ constant). + int eqPos = findEqualityToConstant(pos, /*symbolic=*/true); + // If the equality involves a local var that can not be expressed as a + // symbolic or constant affine expression, we bail out. + if (eqPos != -1 && freeOfUnknownLocalVars(getEquality64(eqPos), memo)) { + // This identifier can only take a single value. + if (lb && detectAsExpr(*this, pos, eqPos, context, memo)) { + AffineExpr equalityExpr = + simplifyAffineExpr(memo[pos], 0, getNumSymbolVars()); + *lb = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), equalityExpr); + if (ub) + *ub = *lb; + } + if (minLbPos) + *minLbPos = eqPos; + if (minUbPos) + *minUbPos = eqPos; + return 1; + } + + // Positions of constraints that are lower/upper bounds on the variable. + SmallVector lbIndices, ubIndices; + + // Note inequalities that give lower and upper bounds. + getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, + /*eqIndices=*/nullptr, /*offset=*/0, + /*num=*/getNumDimVars()); + + std::optional minDiff = std::nullopt; + unsigned minLbPosition = 0, minUbPosition = 0; + AffineExpr minLbExpr, minUbExpr; + + // Traverse each lower bound and upper bound pair, to compute the difference + // between them. + for (unsigned ubPos : ubIndices) { + // Construct sum of all ids other than `pos`th in the given upper bound row. + std::optional maybeUbExpr = + getAsExpr(*this, pos, context, memo, ubPos, /*isEquality=*/false); + if (!maybeUbExpr.has_value() || !(*maybeUbExpr).isSymbolicOrConstant()) + continue; + + // Canonical form of an inequality that constrains the upper bound on + // an id `x_i` is of the form: + // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1. + // Therefore the upper bound on `x_i` will be + // `( + // sum(c_j*x_j) where j != i + // + + // c_0 + // ) + // / + // -(c_i)`. Divison here is a floorDiv. + AffineExpr ubExpr = maybeUbExpr->floorDiv(-atIneq64(ubPos, pos)); + assert(-atIneq64(ubPos, pos) > 0 && "invalid upper bound index"); + + // Go over each lower bound. + for (unsigned lbPos : lbIndices) { + // Construct sum of all ids other than `pos`th in the given lower bound + // row. + std::optional maybeLbExpr = + getAsExpr(*this, pos, context, memo, lbPos, /*isEquality=*/false); + if (!maybeLbExpr.has_value() || !(*maybeLbExpr).isSymbolicOrConstant()) + continue; + + // Canonical form of an inequality that is constraining the lower bound + // on an id `x_i is of the form: + // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1. + // Therefore upperBound on `x_i` will be + // `-( + // sum(c_j*x_j) where j != i + // + + // c_0 + // ) + // / + // c_i`. Divison here is a ceilDiv. + int64_t divisor = atIneq64(lbPos, pos); + // We convert the `ceilDiv` for floordiv with the formula: + // `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`, + // since uniformly keeping divisons as `floorDiv` helps their + // simplification. + AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1).floorDiv(divisor); + assert(atIneq64(lbPos, pos) > 0 && "invalid lower bound index"); + + AffineExpr difference = + simplifyAffineExpr(ubExpr - lbExpr + 1, 0, getNumSymbolVars()); + // If the difference is not constant, ignore the lower bound - upper bound + // pair. + auto constantDiff = dyn_cast(difference); + if (!constantDiff) + continue; + + int64_t diffValue = constantDiff.getValue(); + // This bound is non-negative by definition. + diffValue = std::max(diffValue, 0); + if (!minDiff || diffValue < *minDiff) { + minDiff = diffValue; + minLbPosition = lbPos; + minUbPosition = ubPos; + minLbExpr = lbExpr; + minUbExpr = ubExpr; + } + } + } + + // Populate outputs where available and needed. + if (lb && minDiff) { + *lb = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), minLbExpr); + } + if (ub) + *ub = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), minUbExpr); + if (minLbPos) + *minLbPos = minLbPosition; + if (minUbPos) + *minUbPos = minUbPosition; + + return minDiff; +} + IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const { if (getNumConstraints() == 0) // Return universal set (always true): 0 == 0. diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 5de3fd920e4e0..097cb9c2201aa 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1521,25 +1521,19 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2)); } -/// Finds an equality that equates the specified variable to a constant. -/// Returns the position of the equality row. If 'symbolic' is set to true, -/// symbols are also treated like a constant, i.e., an affine function of the -/// symbols is also treated like a constant. Returns -1 if such an equality -/// could not be found. -static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, - bool symbolic = false) { - assert(pos < cst.getNumVars() && "invalid position"); - for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { - DynamicAPInt v = cst.atEq(r, pos); +int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const { + assert(pos < getNumVars() && "invalid position"); + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + DynamicAPInt v = atEq(r, pos); if (v * v != 1) continue; unsigned c; - unsigned f = symbolic ? cst.getNumDimVars() : cst.getNumVars(); + unsigned f = symbolic ? getNumDimVars() : getNumVars(); // This checks for zeros in all positions other than 'pos' in [0, f) for (c = 0; c < f; c++) { if (c == pos) continue; - if (cst.atEq(r, c) != 0) { + if (atEq(r, c) != 0) { // Dependent on another variable. break; } @@ -1554,7 +1548,7 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, LogicalResult IntegerRelation::constantFoldVar(unsigned pos) { assert(pos < getNumVars() && "invalid position"); int rowIdx; - if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) + if ((rowIdx = findEqualityToConstant(pos)) == -1) return failure(); // atEq(rowIdx, pos) is either -1 or 1. @@ -1593,12 +1587,13 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( // Find an equality for 'pos'^th variable that equates it to some function // of the symbolic variables (+ constant). - int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true); + int eqPos = findEqualityToConstant(pos, /*symbolic=*/true); if (eqPos != -1) { auto eq = getEquality(eqPos); - // If the equality involves a local var, punt for now. - // TODO: this can be handled in the future by using the explicit - // representation of the local vars. + // If the equality involves a local var, we do not handle it. + // FlatLinearConstraints can instead be used to detect the local variable as + // an affine function (potentially div/mod) of other variables and use + // affine expressions/maps to represent output. if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1, [](const DynamicAPInt &coeff) { return coeff == 0; })) return std::nullopt; @@ -1719,7 +1714,7 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { projectOut(0, pos); projectOut(1, getNumVars() - 1); // Check if there's an equality equating the '0'^th variable to a constant. - int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); + int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false); if (eqRowIdx != -1) // atEq(rowIdx, 0) is either -1 or 1. return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index ba6f045cff408..69a7e0b790181 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -1069,9 +1069,9 @@ unsigned MemRefRegion::getRank() const { } std::optional MemRefRegion::getConstantBoundingSizeAndShape( - SmallVectorImpl *shape, std::vector> *lbs, - SmallVectorImpl *lbDivisors) const { + SmallVectorImpl *shape, SmallVectorImpl *lbs) const { auto memRefType = cast(memref.getType()); + MLIRContext *context = memref.getContext(); unsigned rank = memRefType.getRank(); if (shape) shape->reserve(rank); @@ -1083,7 +1083,7 @@ std::optional MemRefRegion::getConstantBoundingSizeAndShape( // over-approximation from projection or union bounding box. We may not add // this on the region itself since they might just be redundant constraints // that will need non-trivials means to eliminate. - FlatAffineValueConstraints cstWithShapeBounds(cst); + FlatLinearValueConstraints cstWithShapeBounds(cst); for (unsigned r = 0; r < rank; r++) { cstWithShapeBounds.addBound(BoundType::LB, r, 0); int64_t dimSize = memRefType.getDimSize(r); @@ -1092,39 +1092,34 @@ std::optional MemRefRegion::getConstantBoundingSizeAndShape( cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1); } - // Find a constant upper bound on the extent of this memref region along each - // dimension. + // Find a constant upper bound on the extent of this memref region along + // each dimension. int64_t numElements = 1; int64_t diffConstant; - int64_t lbDivisor; for (unsigned d = 0; d < rank; d++) { - SmallVector lb; + AffineMap lb; std::optional diff = - cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor); + cstWithShapeBounds.getConstantBoundOnDimSize(context, d, &lb); if (diff.has_value()) { diffConstant = *diff; - assert(diffConstant >= 0 && "Dim size bound can't be negative"); - assert(lbDivisor > 0); + assert(diffConstant >= 0 && "dim size bound cannot be negative"); } else { // If no constant bound is found, then it can always be bound by the // memref's dim size if the latter has a constant size along this dim. auto dimSize = memRefType.getDimSize(d); - if (dimSize == ShapedType::kDynamic) + if (ShapedType::isDynamic(dimSize)) return std::nullopt; diffConstant = dimSize; // Lower bound becomes 0. - lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0); - lbDivisor = 1; + lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(), + /*result=*/getAffineConstantExpr(0, context)); } numElements *= diffConstant; - if (lbs) { + // Populate outputs if available. + if (lbs) lbs->push_back(lb); - assert(lbDivisors && "both lbs and lbDivisor or none"); - lbDivisors->push_back(lbDivisor); - } - if (shape) { + if (shape) shape->push_back(diffConstant); - } } return numElements; } @@ -1172,7 +1167,8 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // (dma_start, dma_wait). LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState, - bool addMemRefDimBounds) { + bool addMemRefDimBounds, bool dropLocalVars, + bool dropOuterIvs) { assert((isa(op)) && "affine read/write op expected"); @@ -1286,15 +1282,25 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, enclosingIVs.resize(loopDepth); SmallVector vars; cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars); - for (Value var : vars) { - if ((isAffineInductionVar(var)) && !llvm::is_contained(enclosingIVs, var)) { - cst.projectOut(var); + for (auto en : llvm::enumerate(vars)) { + if ((isAffineInductionVar(en.value())) && + !llvm::is_contained(enclosingIVs, en.value())) { + if (dropOuterIvs) { + cst.projectOut(en.value()); + } else { + unsigned varPosition; + cst.findVar(en.value(), &varPosition); + auto varKind = cst.getVarKindAt(varPosition); + varPosition -= cst.getNumDimVars(); + cst.convertToLocal(varKind, varPosition, varPosition + 1); + } } } // Project out any local variables (these would have been added for any - // mod/divs). - cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars()); + // mod/divs) if specified. + if (dropLocalVars) + cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars()); // Constant fold any symbolic variables. cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(), diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 1c4793626a152..21cd0e2e49c2a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -349,17 +349,19 @@ static Value createPrivateMemRef(AffineForOp forOp, // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. MemRefRegion region(srcStoreOp->getLoc()); - bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth)); + bool validRegion = succeeded( + region.compute(srcStoreOp, dstLoopDepth, /*sliceState=*/nullptr, + /*addMemRefDimBounds=*/true, /*dropLocalVars=*/false)); + (void)validRegion; assert(validRegion && "unexpected memref region failure"); SmallVector newShape; - std::vector> lbs; - SmallVector lbDivisors; + SmallVector lbs; lbs.reserve(rank); // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed // by 'srcStoreOpInst' at depth 'dstLoopDepth'. std::optional numElements = - region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); + region.getConstantBoundingSizeAndShape(&newShape, &lbs); assert(numElements && "non-constant number of elts in local buffer"); const FlatAffineValueConstraints *cst = region.getConstraints(); @@ -367,22 +369,21 @@ static Value createPrivateMemRef(AffineForOp forOp, // on; this would correspond to loop IVs surrounding the level at which the // slice is being materialized. SmallVector outerIVs; - cst->getValues(rank, cst->getNumVars(), &outerIVs); + cst->getValues(rank, cst->getNumDimAndSymbolVars(), &outerIVs); // Build 'rank' AffineExprs from MemRefRegion 'lbs' SmallVector offsets; offsets.reserve(rank); - for (unsigned d = 0; d < rank; ++d) { - assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); - AffineExpr offset = top.getAffineConstantExpr(0); - for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { - offset = offset + lbs[d][j] * top.getAffineDimExpr(j); - } - assert(lbDivisors[d] > 0); - offset = - (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); - offsets.push_back(offset); + // Outer IVs are considered symbols during memref region computation. Replace + // them uniformly with dims so that valid IR is guaranteed. + SmallVector replacements; + for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j) + replacements.push_back(mlir::getAffineDimExpr(j, forOp.getContext())); + for (unsigned d = 0; d < rank; ++d) { + assert(lbs[d].getNumResults() == 1 && + "invalid private memref bound calculation"); + offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements)); } // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index c7e412b2b0fd9..5c94ec2985c3d 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1995,16 +1995,22 @@ static LogicalResult generateCopy( SmallVector fastBufferShape; // Compute the extents of the buffer. - std::vector> lbs; - SmallVector lbDivisors; + SmallVector lbs; lbs.reserve(rank); - std::optional numElements = region.getConstantBoundingSizeAndShape( - &fastBufferShape, &lbs, &lbDivisors); + std::optional numElements = + region.getConstantBoundingSizeAndShape(&fastBufferShape, &lbs); if (!numElements) { LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); return failure(); } + if (llvm::any_of(lbs, [](AffineMap lb) { return lb.getNumResults() > 1; })) { + // This can be supported in the future if needed. + LLVM_DEBUG(llvm::dbgs() + << "Max lower bound for memref region start not supported\n"); + return failure(); + } + if (*numElements == 0) { LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n"); return success(); @@ -2028,7 +2034,7 @@ static LogicalResult generateCopy( SmallVector regionSymbols; cst->getValues(rank, cst->getNumVars(), ®ionSymbols); - // Construct the index expressions for the fast memory buffer. The index + // Construct the access expression for the fast memory buffer. The access // expression for a particular dimension of the fast buffer is obtained by // subtracting out the lower bound on the original memref's data region // along the corresponding dimension. @@ -2037,19 +2043,13 @@ static LogicalResult generateCopy( SmallVector fastBufOffsets; fastBufOffsets.reserve(rank); for (unsigned d = 0; d < rank; d++) { - assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); - - AffineExpr offset = top.getAffineConstantExpr(0); - for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) - offset = offset + lbs[d][j] * top.getAffineDimExpr(j); - assert(lbDivisors[d] > 0); - offset = - (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); + assert(lbs[d].getNumSymbols() == cst->getNumCols() - rank - 1 && + "incorrect bound size"); // Set copy start location for this dimension in the lower memory space // memref. - if (auto caf = dyn_cast(offset)) { - auto indexVal = caf.getValue(); + if (auto caf = lbs[d].isSingleConstant()) { + auto indexVal = lbs[d].getSingleConstantResult(); if (indexVal == 0) { memIndices.push_back(zeroIndex); } else { @@ -2059,16 +2059,23 @@ static LogicalResult generateCopy( } else { // The coordinate for the start location is just the lower bound along the // corresponding dimension on the memory region (stored in 'offset'). - auto map = AffineMap::get( - cst->getNumDimVars() + cst->getNumSymbolVars() - rank, 0, offset); - memIndices.push_back(b.create(loc, map, regionSymbols)); + // Remap all inputs of the map to dimensions uniformly since in the + // generate IR we need valid affine symbols as opposed to "symbols" for + // the purpose of the memref region. + SmallVector symReplacements(lbs[d].getNumSymbols()); + for (unsigned i = 0, e = lbs[d].getNumSymbols(); i < e; ++i) + symReplacements[i] = top.getAffineDimExpr(i); + lbs[d] = lbs[d].replaceDimsAndSymbols( + /*dimReplacements=*/{}, symReplacements, lbs[d].getNumSymbols(), + /*numResultSyms=*/0); + memIndices.push_back(b.create(loc, lbs[d], regionSymbols)); } // The fast buffer is copied into at location zero; addressing is relative. bufIndices.push_back(zeroIndex); // Record the offsets since they are needed to remap the memory accesses of // the original memref further below. - fastBufOffsets.push_back(offset); + fastBufOffsets.push_back(lbs[d].getResult(0)); } // The faster memory space buffer. @@ -2596,10 +2603,11 @@ static AffineIfOp createSeparationCondition(MutableArrayRef loops, cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolVars() - 1); unsigned fullTileLbPos, fullTileUbPos; - if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr, - /*boundFloorDivisor=*/nullptr, - /*ub=*/nullptr, &fullTileLbPos, - &fullTileUbPos)) { + if (!((IntegerRelation)cst) + .getConstantBoundOnDimSize(0, /*lb=*/nullptr, + /*boundFloorDivisor=*/nullptr, + /*ub=*/nullptr, &fullTileLbPos, + &fullTileUbPos)) { LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n"); return nullptr; } @@ -2669,9 +2677,10 @@ createFullTiles(MutableArrayRef inputNest, // pair of with a constant difference. cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - 1); unsigned lbPos, ubPos; - if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr, - /*boundFloorDivisor=*/nullptr, - /*ub=*/nullptr, &lbPos, &ubPos) || + if (!((IntegerRelation)cst) + .getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr, + /*boundFloorDivisor=*/nullptr, + /*ub=*/nullptr, &lbPos, &ubPos) || lbPos == ubPos) { LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / " "equalities not yet handled\n"); diff --git a/mlir/test/Dialect/Affine/dma-generate.mlir b/mlir/test/Dialect/Affine/dma-generate.mlir index b38bf896e78cf..0438499468696 100644 --- a/mlir/test/Dialect/Affine/dma-generate.mlir +++ b/mlir/test/Dialect/Affine/dma-generate.mlir @@ -485,9 +485,6 @@ func.func @test_read_write_region_union() { // This should create a buffer of size 2 affine.for %arg2. -#map_lb = affine_map<(d0) -> (d0)> -#map_ub = affine_map<(d0) -> (d0 + 3)> -#map_acc = affine_map<(d0) -> (d0 floordiv 8)> // CHECK-LABEL: func @test_analysis_util func.func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, %arg2: memref<2xf32>) -> (memref<144x9xf32>, memref<2xf32>) { %c0 = arith.constant 0 : index @@ -495,13 +492,11 @@ func.func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf %1 = memref.alloc() : memref<144x4xf32> %2 = arith.constant 0.0 : f32 affine.for %i8 = 0 to 9 step 3 { - affine.for %i9 = #map_lb(%i8) to #map_ub(%i8) { + affine.for %i9 = affine_map<(d0) -> (d0)>(%i8) to affine_map<(d0) -> (d0 + 3)>(%i8) { affine.for %i17 = 0 to 64 { - %23 = affine.apply #map_acc(%i9) - %25 = affine.load %arg2[%23] : memref<2xf32> - %26 = affine.apply #map_lb(%i17) - %27 = affine.load %0[%26, %c0] : memref<64x1xf32> - affine.store %27, %arg2[%23] : memref<2xf32> + %25 = affine.load %arg2[%i9 floordiv 8] : memref<2xf32> + %27 = affine.load %0[%i17, %c0] : memref<64x1xf32> + affine.store %27, %arg2[%i17] : memref<2xf32> } } } @@ -509,7 +504,7 @@ func.func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf } // CHECK: affine.for %{{.*}} = 0 to 9 step 3 { // CHECK: [[BUF:%[0-9a-zA-Z_]+]] = memref.alloc() : memref<2xf32, 2> -// CHECK: affine.dma_start %{{.*}}[%{{.*}} floordiv 8], [[BUF]] +// CHECK: affine.dma_start %{{.*}}[%c0{{.*}}], [[BUF]] // CHECK: affine.dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xi32> // CHECK: affine.for %{{.*}} = diff --git a/mlir/test/Dialect/Affine/loop-fusion-3.mlir b/mlir/test/Dialect/Affine/loop-fusion-3.mlir index 2116cfb9f884a..70d6c82105543 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-3.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-3.mlir @@ -521,10 +521,7 @@ func.func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) return } -// TODO: The size of the private memref is not properly computed in the presence -// of the 'mod' operation. It should be memref<1xf32> instead of -// memref<128xf32>: https://bugs.llvm.org/show_bug.cgi?id=46973 -// MAXIMAL: memref.alloc() : memref<128xf32> +// MAXIMAL: memref.alloc() : memref<1xf32> // MAXIMAL: affine.for // MAXIMAL-NEXT: affine.for // MAXIMAL-NOT: affine.for @@ -553,7 +550,7 @@ func.func @should_fuse_multi_store_producer_and_privatize_memfefs() { %0 = affine.load %b[%arg0] : memref<10xf32> } - // All the memrefs should be privatized except '%c', which is not involved in + // All the memrefs should be privatized except '%c', which is not involved in // the producer-consumer fusion. // CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> @@ -584,7 +581,7 @@ func.func @should_fuse_multi_store_producer_with_escaping_memrefs_and_remove_src %0 = affine.load %b[%i2] : memref<10xf32> } - // Producer loop '%i0' should be removed after fusion since fusion is maximal. + // Producer loop '%i0' should be removed after fusion since fusion is maximal. // No memref should be privatized since they escape the function, and the // producer is removed after fusion. // CHECK: affine.for %{{.*}} = 0 to 10 { @@ -769,7 +766,7 @@ func.func @should_not_fuse_defining_node_has_transitive_dependence_from_source_l %2 = arith.divf %0, %1 : f32 } - // When loops '%i0' and '%i2' are evaluated first, they should not be + // When loops '%i0' and '%i2' are evaluated first, they should not be // fused. The defining node of '%0' in loop '%i2' has transitive dependence // from loop '%i0'. After that, loops '%i0' and '%i1' are evaluated, and they // will be fused as usual. diff --git a/mlir/test/Dialect/Affine/loop-fusion.mlir b/mlir/test/Dialect/Affine/loop-fusion.mlir index dcd2e1cdb275a..1ea42517988c3 100644 --- a/mlir/test/Dialect/Affine/loop-fusion.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion.mlir @@ -748,7 +748,7 @@ func.func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // // CHECK-LABEL: func @R6_to_R2_reshape -// CHECK: memref.alloc() : memref<1x2x3x3x16x1xi32> +// CHECK: memref.alloc() : memref<1x1x1x1x1x1xi32> // CHECK: memref.alloc() : memref<1x1xi32> // CHECK: memref.alloc() : memref<64x9xi32> // CHECK-NEXT: affine.for %{{.*}} = 0 to 64 { @@ -759,7 +759,7 @@ func.func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}}) // CHECK-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}}) // CHECK-NEXT: "foo"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (index, index, index, index, index, index) -> i32 -// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0, 0, 0, 0, 0] : memref<1x1x1x1x1x1xi32> // CHECK-NEXT: affine.apply [[$MAP11]](%{{.*}}, %{{.*}}) // CHECK-NEXT: affine.apply [[$MAP12]](%{{.*}}) // CHECK-NEXT: affine.apply [[$MAP13]](%{{.*}}) @@ -767,7 +767,7 @@ func.func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: affine.apply [[$MAP15]](%{{.*}}) // CHECK-NEXT: affine.apply [[$MAP16]](%{{.*}}) // CHECK-NEXT: affine.apply [[$MAP17]](%{{.*}}) -// CHECK-NEXT: affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: affine.load %{{.*}}[0, 0, 0, 0, 0, 0] : memref<1x1x1x1x1x1xi32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xi32> // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xi32> // CHECK-NEXT: arith.muli %{{.*}}, %{{.*}} : i32