Skip to content

Commit c39915f

Browse files
authored
[mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. (#139340)
The revision adds isOneInteger helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner. For downstream users, you can update the code with the below script. ```bash sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` --------- Signed-off-by: hanhanW <[email protected]>
1 parent de3e8ff commit c39915f

22 files changed

+54
-76
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424

2525
namespace mlir {
2626

27-
/// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp
28-
/// with attribute with value `0`.
29-
bool isZeroIndex(OpFoldResult v);
27+
/// Return true if `v` is an IntegerAttr with value `0`.
28+
bool isZeroInteger(OpFoldResult v);
29+
30+
/// Return true if `v` is an IntegerAttr with value `1`.
31+
bool isOneInteger(OpFoldResult v);
3032

3133
/// Represents a range (offset, size, and stride) where each element of the
3234
/// triple may be dynamic or static.

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
897897
OpFoldResult offset =
898898
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
899899
.front();
900-
if (isConstantIntValue(offset, 0)) {
900+
if (isZeroInteger(offset)) {
901901
rewriter.replaceOp(op, src);
902902
return success();
903903
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4488,8 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
44884488

44894489
// Return true if we have a zero-value tile.
44904490
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4491-
return llvm::any_of(
4492-
tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
4491+
return llvm::any_of(tiles, isZeroInteger);
44934492
};
44944493

44954494
// Verify tiles. Do not allow zero tiles.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3401,10 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
34013401
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
34023402
SmallVector<OpFoldResult> steps = loop.getMixedStep();
34033403

3404-
if (llvm::all_of(
3405-
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3406-
llvm::all_of(
3407-
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3404+
if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
34083405
return loop;
34093406
}
34103407

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
441441
// If the `padOp` has a nofold attribute and all paddings are known to be 0,
442442
// explicitly insert a `linalg.copy`.
443443
if (padOp.getNofoldAttr() &&
444-
llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
445-
llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
444+
llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) &&
445+
llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {
446446
using bufferization::AllocTensorOp;
447447
Value allocated =
448448
rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
2424
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2525
#include "mlir/Dialect/Utils/IndexingUtils.h"
26+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2627
#include "mlir/IR/AffineExpr.h"
2728
#include "mlir/IR/AffineMap.h"
2829
#include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
376377

377378
SmallVector<Value> threadIds = forallOp.getInductionVars();
378379
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
379-
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
380+
numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
380381
int64_t nLoops = loopRanges.size();
381382
tiledOffsets.reserve(nLoops);
382383
tiledSizes.reserve(nLoops);
383384
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
384385
bool overflow = loopIdx >= numThreads.size();
385-
bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
386+
bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
386387
// Degenerate case: take the whole domain.
387388
if (overflow || isZero) {
388389
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
413414
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
414415
b, loc, i + j * m - n,
415416
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
416-
if (!isConstantIntValue(residualTileSize, 0)) {
417+
if (!isZeroInteger(residualTileSize)) {
417418
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
418419
b, loc, -i + m, {offsetPerThread, size});
419420
tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
655656
Operation *tiledOp = nullptr;
656657

657658
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
658-
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
659+
numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
659660
SmallVector<Value> materializedNonZeroNumThreads =
660661
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
661662

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ struct LinalgOpPartialReductionInterface
369369

370370
SmallVector<OpFoldResult> tiledShape;
371371
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
372-
if (isZeroIndex(tileSize)) {
372+
if (isZeroInteger(tileSize)) {
373373
tiledShape.push_back(dimSize);
374374
} else {
375375
tiledShape.push_back(tileSize);
@@ -732,7 +732,7 @@ struct PackOpTiling
732732
// iterated or inner dims are not tiled. Otherwise, it will generate a
733733
// sequence of non-trivial ops (for partial tiles).
734734
for (auto offset : offsets.take_back(numTiles))
735-
if (!isConstantIntValue(offset, 0))
735+
if (!isZeroInteger(offset))
736736
return failure();
737737

738738
for (auto iter :

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
5959
TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
6060

6161
void visitDimExpr(AffineDimExpr expr) {
62-
isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
62+
isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
6363
}
6464
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
6565
visit(expr.getLHS());
@@ -741,7 +741,7 @@ SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
741741
SmallVector<OpFoldResult> offsets;
742742
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
743743
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
744-
bool isTiled = !isZeroIndex(tileSizes[idx]);
744+
bool isTiled = !isZeroInteger(tileSizes[idx]);
745745
offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
746746
LLVM_DEBUG(llvm::dbgs()
747747
<< "computeTileOffsets: " << offsets.back() << "\n");
@@ -754,7 +754,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
754754
ArrayRef<OpFoldResult> sizeBounds) {
755755
SmallVector<OpFoldResult> sizes;
756756
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
757-
bool isTiled = !isZeroIndex(tileSizes[idx]);
757+
bool isTiled = !isZeroInteger(tileSizes[idx]);
758758
// Before composing, we need to make range a closed interval.
759759
OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
760760
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -810,7 +810,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
810810
bool omitPartialTileCheck) {
811811
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
812812
llvm::make_range(tileSizes.begin(), tileSizes.end()),
813-
[](OpFoldResult v) { return !isZeroIndex(v); })) &&
813+
[](OpFoldResult v) { return !isZeroInteger(v); })) &&
814814
"expected as many ivs as non-zero sizes");
815815

816816
// Construct (potentially temporary) mins and maxes on which to apply maps

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,9 +1894,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
18941894
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
18951895
// are 0.
18961896
if (auto prev = src.getDefiningOp<SubViewOp>())
1897-
if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1898-
return isConstantIntValue(val, 0);
1899-
}))
1897+
if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
19001898
return prev.getSource();
19011899

19021900
return nullptr;
@@ -3290,11 +3288,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
32903288
auto srcSizes = srcSubview.getMixedSizes();
32913289
auto sizes = getMixedSizes();
32923290
auto offsets = getMixedOffsets();
3293-
bool allOffsetsZero = llvm::all_of(
3294-
offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3291+
bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
32953292
auto strides = getMixedStrides();
3296-
bool allStridesOne = llvm::all_of(
3297-
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3293+
bool allStridesOne = llvm::all_of(strides, isOneInteger);
32983294
bool allSizesSame = llvm::equal(sizes, srcSizes);
32993295
if (allOffsetsZero && allStridesOne && allSizesSame &&
33003296
resultMemrefType == sourceMemrefType)

mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
251251
// to do.
252252
SmallVector<OpFoldResult> indices =
253253
getAsOpFoldResult(loadStoreLikeOp.getIndices());
254-
if (llvm::all_of(indices, [](const OpFoldResult &opFold) {
255-
return isConstantIntValue(opFold, 0);
256-
})) {
254+
if (llvm::all_of(indices, isZeroInteger)) {
257255
return rewriter.notifyMatchFailure(
258256
loadStoreLikeOp, "no computation to extract: offsets are 0s");
259257
}

0 commit comments

Comments
 (0)