Skip to content

Commit cdcb0af

Browse files
committed
[mlir][NFC] Simplify constant checks with isZeroIndex and isOneIndex.
The revision adds isOneIndex helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner. Signed-off-by: hanhanW <[email protected]>
1 parent ce9a898 commit cdcb0af

File tree

18 files changed

+45
-57
lines changed

18 files changed

+45
-57
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ namespace mlir {
2828
/// with attribute with value `0`.
2929
bool isZeroIndex(OpFoldResult v);
3030

31+
/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp
32+
/// with attribute with value `1`.
33+
bool isOneIndex(OpFoldResult v);
34+
3135
/// Represents a range (offset, size, and stride) where each element of the
3236
/// triple may be dynamic or static.
3337
struct Range {

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
897897
OpFoldResult offset =
898898
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
899899
.front();
900-
if (isConstantIntValue(offset, 0)) {
900+
if (isZeroIndex(offset)) {
901901
rewriter.replaceOp(op, src);
902902
return success();
903903
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4426,8 +4426,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
44264426

44274427
// Return true if we have a zero-value tile.
44284428
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4429-
return llvm::any_of(
4430-
tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
4429+
return llvm::any_of(tiles, isZeroIndex);
44314430
};
44324431

44334432
// Verify tiles. Do not allow zero tiles.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3315,10 +3315,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
33153315
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
33163316
SmallVector<OpFoldResult> steps = loop.getMixedStep();
33173317

3318-
if (llvm::all_of(
3319-
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3320-
llvm::all_of(
3321-
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3318+
if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
33223319
return loop;
33233320
}
33243321

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
2424
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2525
#include "mlir/Dialect/Utils/IndexingUtils.h"
26+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2627
#include "mlir/IR/AffineExpr.h"
2728
#include "mlir/IR/AffineMap.h"
2829
#include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
376377

377378
SmallVector<Value> threadIds = forallOp.getInductionVars();
378379
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
379-
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
380+
numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
380381
int64_t nLoops = loopRanges.size();
381382
tiledOffsets.reserve(nLoops);
382383
tiledSizes.reserve(nLoops);
383384
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
384385
bool overflow = loopIdx >= numThreads.size();
385-
bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
386+
bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
386387
// Degenerate case: take the whole domain.
387388
if (overflow || isZero) {
388389
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
413414
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
414415
b, loc, i + j * m - n,
415416
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
416-
if (!isConstantIntValue(residualTileSize, 0)) {
417+
if (!isZeroIndex(residualTileSize)) {
417418
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
418419
b, loc, -i + m, {offsetPerThread, size});
419420
tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
655656
Operation *tiledOp = nullptr;
656657

657658
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
658-
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
659+
numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
659660
SmallVector<Value> materializedNonZeroNumThreads =
660661
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
661662

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ struct PackOpTiling
732732
// iterated or inner dims are not tiled. Otherwise, it will generate a
733733
// sequence of non-trivial ops (for partial tiles).
734734
for (auto offset : offsets.take_back(numTiles))
735-
if (!isConstantIntValue(offset, 0))
735+
if (!isZeroIndex(offset))
736736
return failure();
737737

738738
for (auto iter :

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
18891889
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
18901890
// are 0.
18911891
if (auto prev = src.getDefiningOp<SubViewOp>())
1892-
if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1893-
return isConstantIntValue(val, 0);
1894-
}))
1892+
if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
18951893
return prev.getSource();
18961894

18971895
return nullptr;
@@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
32853283
auto srcSizes = srcSubview.getMixedSizes();
32863284
auto sizes = getMixedSizes();
32873285
auto offsets = getMixedOffsets();
3288-
bool allOffsetsZero = llvm::all_of(
3289-
offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3286+
bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
32903287
auto strides = getMixedStrides();
3291-
bool allStridesOne = llvm::all_of(
3292-
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3288+
bool allStridesOne = llvm::all_of(strides, isOneIndex);
32933289
bool allSizesSame = llvm::equal(sizes, srcSizes);
32943290
if (allOffsetsZero && allStridesOne && allSizesSame &&
32953291
resultMemrefType == sourceMemrefType)

mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
251251
// to do.
252252
SmallVector<OpFoldResult> indices =
253253
getAsOpFoldResult(loadStoreLikeOp.getIndices());
254-
if (std::all_of(indices.begin(), indices.end(),
255-
[](const OpFoldResult &opFold) {
256-
return isConstantIntValue(opFold, 0);
257-
})) {
254+
if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
258255
return rewriter.notifyMatchFailure(
259256
loadStoreLikeOp, "no computation to extract: offsets are 0s");
260257
}

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
133133
tileSizes.resize(numLoops, zero);
134134
for (auto [index, range, nt] :
135135
llvm::enumerate(iterationDomain, numThreads)) {
136-
if (isConstantIntValue(nt, 0))
136+
if (isZeroIndex(nt))
137137
continue;
138138

139139
tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
265265

266266
// Non-tiled cases, set the offset and size to the
267267
// `loopRange.offset/size`.
268-
if (isConstantIntValue(nt, 0)) {
268+
if (isZeroIndex(nt)) {
269269
offsets.push_back(loopRange.offset);
270270
sizes.push_back(loopRange.size);
271271
continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
280280
{loopRange.offset, nt, tileSize, loopRange.size});
281281

282282
OpFoldResult size = tileSize;
283-
if (!isConstantIntValue(residualTileSize, 0)) {
283+
if (!isZeroIndex(residualTileSize)) {
284284
OpFoldResult sizeMinusOffsetPerThread =
285285
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
286286
{offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
316316

317317
// Non-tiled cases, set the offset and size to the
318318
// `loopRange.offset/size`.
319-
if (isConstantIntValue(tileSize, 0)) {
319+
if (isZeroIndex(tileSize)) {
320320
offsets.push_back(loopRange.offset);
321321
sizes.push_back(loopRange.size);
322322
continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
341341
SmallVector<OpFoldResult> lbs, ubs, steps;
342342
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343343
// No loop if the tile size is 0.
344-
if (isConstantIntValue(tileSize, 0))
344+
if (isZeroIndex(tileSize))
345345
continue;
346346
lbs.push_back(loopRange.offset);
347347
ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
495495
// Prune the zero numthreads.
496496
SmallVector<OpFoldResult> nonZeroNumThreads;
497497
for (auto nt : numThreads) {
498-
if (isConstantIntValue(nt, 0))
498+
if (isZeroIndex(nt))
499499
continue;
500500
nonZeroNumThreads.push_back(nt);
501501
}
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
12901290
sliceSizes = sliceOp.getMixedSizes();
12911291

12921292
// expect all strides of sliceOp being 1
1293-
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1294-
return !isConstantIntValue(ofr, 1);
1295-
}))
1293+
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
12961294
return failure();
12971295

12981296
unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
21142112
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
21152113

21162114
// 9. Check all insert stride is 1.
2117-
if (llvm::any_of(strides, [](OpFoldResult stride) {
2118-
return !isConstantIntValue(stride, 1);
2119-
})) {
2115+
if (!llvm::all_of(strides, isOneIndex)) {
21202116
return rewriter.notifyMatchFailure(
21212117
candidateSliceOp, "containingOp's result yield with stride");
21222118
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
768768
// If an `affine.apply` operation is generated for denormalization, the use
769769
// of `origLb` in those ops must not be replaced. These arent not generated
770770
// when `origLb == 0` and `origStep == 1`.
771-
if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
771+
if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
772772
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
773773
preservedUses.insert(preservedUse);
774774
}
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
785785
}
786786
Value denormalizedIv;
787787
SmallPtrSet<Operation *, 2> preserve;
788-
bool isStepOne = isConstantIntValue(origStep, 1);
789-
bool isZeroBased = isConstantIntValue(origLb, 0);
788+
bool isStepOne = isOneIndex(origStep);
789+
bool isZeroBased = isZeroIndex(origLb);
790790

791791
Value scaled = normalizedIv;
792792
if (!isStepOne) {

0 commit comments

Comments
 (0)