Skip to content

Commit d7925b7

Browse files
[mlir][PartialReductionTilingInterface] Generalize implementation of tileUsingSCF for ReductionTilingStrategy::PartialOuterReduction.
This is a precursor to generalizing the `tileUsingSCF` to handle `ReductionTilingStrategy::PartialOuterParallel` strategy. This change itself is generalizing/refactoring the current implementation that supports only `ReductionTilingStrategy::PartialOuterReduction`. Changes in this PR - Move the `ReductionTilingStrategy` enum out of `scf::SCFTilingOptions` and make them visible to `TilingInterface`. - `PartialTilingInterface` changes - Pass the `tilingStrategy` used for partial reduction to `tileToPartialReduction`. - Pass the reduction dimension along as `const llvm::SetVector<unsigned> &`. - Allow `scf::SCFTilingOptions` to set the reduction dimensions that are to be tiled. - Change `structured.tiled_reduction_using_for` to allow specification of the reduction dimensions to be partially tiled. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent f280d3b commit d7925b7

File tree

9 files changed

+433
-251
lines changed

9 files changed

+433
-251
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
18501850
- the result-combining op,
18511851
- the parent `for` op.
18521852

1853+
The `reduction_dims` can be used to specify the subset of reduction dimensions
1854+
of the operation to tile. If left unspecified, all reduction dimensions are
1855+
tiled.
1856+
18531857
#### Example:
18541858

18551859
```
@@ -1900,7 +1904,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
19001904

19011905
// TODO: support mixed static-dynamic (see TileUsingForallOp).
19021906
let arguments = (ins TransformHandleTypeInterface:$target,
1903-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
1907+
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
1908+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
19041909
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
19051910
TransformHandleTypeInterface:$split_op,
19061911
TransformHandleTypeInterface:$combining_op,
@@ -1913,6 +1918,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
19131918

19141919
let assemblyFormat = [{
19151920
$target
1921+
(`reduction_dims` `=` $reduction_dims^)?
19161922
`by` `tile_sizes` `=` $tile_sizes
19171923
attr-dict
19181924
`:` functional-type(operands, results)

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
8585
return *this;
8686
}
8787

88+
/// Specify mapping of loops to devices. This is only respected when the loop
89+
/// constructs support such a mapping (like `scf.forall`). Will be ignored
90+
/// when using loop constructs that dont support such a mapping (like
91+
/// `scf.for`)
92+
SmallVector<Attribute> mappingVector = {};
93+
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
94+
mappingVector = llvm::to_vector(mapping);
95+
return *this;
96+
}
97+
98+
//-------------------------------------------------------------------------//
99+
// Options related reduction tiling
100+
//-------------------------------------------------------------------------//
101+
88102
/// Specify how reduction dimensions should be tiled.
89-
///
90-
/// Tiling can be thought of as splitting a dimension into 2 and materializing
91-
/// the outer dimension as a loop:
92-
///
93-
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
94-
///
95-
/// For parallel dimensions, the split can only happen in one way, with both
96-
/// dimensions being parallel. For reduction dimensions however, there is a
97-
/// choice in how we split the reduction dimension. This enum exposes this
98-
/// choice.
99-
enum class ReductionTilingStrategy {
100-
// [reduction] -> [reduction1, reduction2]
101-
// -> loop[reduction1] { [reduction2] }
102-
FullReduction,
103-
// [reduction] -> [reduction1, parallel2]
104-
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
105-
PartialReductionOuterReduction,
106-
// [reduction] -> [parallel1, reduction2]
107-
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
108-
PartialReductionOuterParallel
109-
};
110103
ReductionTilingStrategy reductionStrategy =
111104
ReductionTilingStrategy::FullReduction;
112105
SCFTilingOptions &
@@ -115,13 +108,13 @@ struct SCFTilingOptions {
115108
return *this;
116109
}
117110

118-
/// Specify mapping of loops to devices. This is only respected when the loop
119-
/// constructs support such a mapping (like `scf.forall`). Will be ignored
120-
/// when using loop constructs that dont support such a mapping (like
121-
/// `scf.for`)
122-
SmallVector<Attribute> mappingVector = {};
123-
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
124-
mappingVector = llvm::to_vector(mapping);
111+
/// Specify the reduction dimensions to be tiled. Note that this needs to be
112+
/// specified. If left unspecified, then none of the reduction dimensions are
113+
/// tiled.
114+
SetVector<unsigned> reductionDims;
115+
SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
116+
reductionDims.clear();
117+
reductionDims.insert(dims.begin(), dims.end());
125118
return *this;
126119
}
127120
};

mlir/include/mlir/Interfaces/TilingInterface.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,27 @@ struct TilingResult {
3636
SmallVector<Operation *> generatedSlices;
3737
};
3838

39+
/// Tiling can be thought of as splitting a dimension into 2 and
40+
/// materializing the outer dimension as a loop:
41+
///
42+
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
43+
///
44+
/// For parallel dimensions, the split can only happen in one way, with both
45+
/// dimensions being parallel. For reduction dimensions however, there is a
46+
/// choice in how we split the reduction dimension. This enum exposes this
47+
/// choice.
48+
enum class ReductionTilingStrategy {
49+
// [reduction] -> [reduction1, reduction2]
50+
// -> loop[reduction1] { [reduction2] }
51+
FullReduction,
52+
// [reduction] -> [reduction1, parallel2]
53+
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
54+
PartialReductionOuterReduction,
55+
// [reduction] -> [parallel1, reduction2]
56+
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
57+
PartialReductionOuterParallel
58+
};
59+
3960
/// Container for the result of merge operation of tiling.
4061
/// - `mergeOps` contains operations created during the merge.
4162
/// - `replacements` contains the values that represents the result of the

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
384384
"::mlir::OpBuilder &":$b,
385385
"Location":$loc,
386386
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
387-
"::mlir::ArrayRef<int>":$reductionDim),
387+
"const ::mlir::SetVector<unsigned> &":$reductionDims),
388388
/*methodBody=*/"",
389389
/*defaultImplementation=*/[{
390390
return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
402402
/*args=*/(ins
403403
"::mlir::OpBuilder &":$b,
404404
"Location ":$loc,
405+
"::mlir::ReductionTilingStrategy":$tilingStrategy,
405406
"ValueRange":$init,
406407
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
407408
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
408-
"::mlir::ArrayRef<int>":$reductionDims),
409+
"const ::llvm::SetVector<unsigned> &":$reductionDims),
409410
/*methodBody=*/"",
410411
/*defaultImplementation=*/[{
411412
return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
423424
"::mlir::OpBuilder &":$b,
424425
"Location ":$loc,
425426
"ValueRange":$partialReduce,
426-
"::mlir::ArrayRef<int>":$reductionDim),
427+
"const ::mlir::SetVector<unsigned> &":$reductionDims),
427428
/*methodBody=*/"",
428429
/*defaultImplementation=*/[{
429430
return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
443444
"unsigned":$resultNumber,
444445
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
445446
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
447+
"const ::mlir::SetVector<unsigned> &":$reductionDims,
446448
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
447-
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
448-
"::mlir::ArrayRef<int>":$reductionDims),
449+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
449450
/*methodBody=*/"",
450451
/*defaultImplementation=*/[{
451452
return failure();

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2966,10 +2966,11 @@ void transform::TileReductionUsingForOp::build(
29662966
// TODO: support mixed static-dynamic (see TileUsingForallOp).
29672967
MLIRContext *ctx = builder.getContext();
29682968
auto opTy = transform::AnyOpType::get(ctx);
2969-
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2969+
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
29702970
build(builder, result,
29712971
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
29722972
/*target=*/target,
2973+
/*reduction_dims=*/nullptr,
29732974
/*tile_sizes=*/staticTileSizesAttr);
29742975
}
29752976

@@ -2985,12 +2986,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
29852986
target->getLoc(),
29862987
"Operation should implement PartialReductionOpInterface");
29872988
}
2988-
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2989-
rewriter, partialReductionOp,
2990-
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
29912989

2992-
if (failed(result))
2993-
return emitDefaultSilenceableFailure(target);
2990+
SmallVector<unsigned> reductionDims =
2991+
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2992+
if (reductionDims.empty()) {
2993+
for (auto [idx, iteratorType] :
2994+
llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
2995+
if (iteratorType == utils::IteratorType::reduction)
2996+
reductionDims.push_back(idx);
2997+
}
2998+
}
2999+
3000+
scf::SCFTilingOptions options;
3001+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3002+
options.setReductionTilingStrategy(
3003+
ReductionTilingStrategy::PartialReductionOuterReduction);
3004+
options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
3005+
options.setReductionDims(reductionDims);
3006+
FailureOr<scf::SCFTilingResult> result =
3007+
scf::tileUsingSCF(rewriter, partialReductionOp, options);
3008+
3009+
if (failed(result)) {
3010+
return emitSilenceableFailure(getLoc(),
3011+
"failed to tile using partial reduction");
3012+
}
29943013
rewriter.replaceOp(target, result->replacements);
29953014
for (Value initValue : result->initialValues)
29963015
results.push_back(initValue.getDefiningOp());

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
109109
}
110110

111111
FailureOr<StaticContinuousTileSizeSpecification>
112-
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
113-
unsigned dimension,
112+
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
114113
unsigned targetSize) {
115114

116115
assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
183182

184183
// Find the trip count of the iteration space dimension for which the tile
185184
// sizes are computed.
186-
Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
187-
loopRanges[dimension].size);
185+
Value loopRange =
186+
getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
188187
ContinuousTileSizeSpecification spec;
189188

190189
// Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
633632
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
634633
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
635634
"many elements as number of threads");
636-
int reductionDim = static_cast<int>(redDims.front());
637635

638636
if (redDims.front() >= numThreads.size())
639637
return b.notifyMatchFailure(
640638
op, "reduction dimension must be mapped to threads");
641639

642640
// 1. Create the inital tensor value.
641+
unsigned reductionDim = redDims.front();
642+
SetVector<unsigned> reductionDims;
643+
reductionDims.insert(reductionDim);
643644
FailureOr<SmallVector<Value>> maybeInitTensors =
644645
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
645-
reductionDim);
646+
reductionDims);
646647
if (failed(maybeInitTensors))
647648
return b.notifyMatchFailure(
648649
op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
780781
// 7. Merge the partial reductions.
781782
b.setInsertionPointAfter(forallOp);
782783
FailureOr<MergeResult> mergeResult =
783-
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
784+
op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
784785
if (failed(mergeResult)) {
785786
return failure();
786787
}

0 commit comments

Comments
 (0)