diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 38c8734c47381..9d6ce653e285c 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2019,6 +2019,7 @@ def TileReductionUsingForallOp : // TODO: support mixed static-dynamic (see TileUsingForallOp). let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$reduction_dims, DefaultValuedAttr:$num_threads, DefaultValuedAttr:$tile_sizes, OptionalAttr:$mapping); @@ -2036,10 +2037,11 @@ def TileReductionUsingForallOp : let assemblyFormat = [{ $target + (`reduction_dims` `=` $reduction_dims^)? `by` (`num_threads` `=` $num_threads^)? - (`,` `tile_sizes` `=` $tile_sizes^)? - (`,` `mapping` `=` $mapping^)? + (`tile_sizes` `=` $tile_sizes^)? + (`mapping` `=` $mapping^)? attr-dict `:` functional-type(operands, results) }]; diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index b37fb55b67931..77c376fb9973a 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -156,7 +156,7 @@ SmallVector getMixedValues(ArrayRef staticValues, /// corresponding pair of arrays. This is the inverse function of /// `getMixedValues`. std::pair, SmallVector> -decomposeMixedValues(const SmallVectorImpl &mixedValues); +decomposeMixedValues(ArrayRef mixedValues); /// Helper to sort `values` according to matching `keys`. SmallVector diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 43a27e1cb6cdf..0de37338c95e4 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -367,15 +367,20 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface", [TilingInterface]> { let description = [{ Interface for allowing operations to expose information needed to - tile reductions using partial reduction followed by merge. This is - complementary to TilingInterface to tile reductions. + tile reductions using partial reduction followed by merge. This + extends the `TilingInterface` to allow splitting a reduction + dimension into a parallel dimension and a reduction dimension. + The materialized inter-tile loop could either be the reduction dimension + (i.e. `ReductionTilingStrategy::PartialReductionOuterReduction`) or + the parallel dimension (i.e. + `ReductionTilingStrategy::PartialReductionOuterParallel`). }]; let cppNamespace = "::mlir"; let methods = [ InterfaceMethod< /*desc=*/[{ Method to generate a tensor initalized with the identity value of the - operation reduction. The tensor shape is equal to operation result + reduction operator. The tensor shape is equal to operation result shape with new dimension for each non zero tile size. }], /*retType=*/"::mlir::FailureOr>", @@ -383,7 +388,7 @@ def PartialReductionOpInterface : /*args=*/(ins "::mlir::OpBuilder &":$b, "Location":$loc, - "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$tileSizes, "const ::mlir::SetVector &":$reductionDims), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -396,6 +401,11 @@ def PartialReductionOpInterface : reduction dimension are converted to parallel dimensions with a size less or equal to the tile size. 
This is meant to be used with `mergeReductions` method which will combine the partial reductions. + The method receives the `offsets` and `sizes` for all iteration space + dimensions, as well as the iteration number of the tiled reduction + dimensions (which is the induction variable of the inter-tile loop + for the reduction dimension divided by the step of the loop) in + `splitReductionIvs`. }], /*retType=*/"::mlir::FailureOr", /*methodName=*/"tileToPartialReduction", @@ -406,7 +416,8 @@ def PartialReductionOpInterface : "ValueRange":$init, "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets, "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes, - "const ::llvm::SetVector &":$reductionDims), + "const ::llvm::SetVector &":$reductionDims, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs), /*methodBody=*/"", /*defaultImplementation=*/[{ return failure(); @@ -436,15 +447,22 @@ def PartialReductionOpInterface : the tiled operation. This is same as TilingInterface:::getResultTilePosition, but determines the result tile position for partial reduction. + The method receives the `offsets` and `sizes` for all iteration space + dimensions, as well as the iteration number of the tiled reduction + dimensions (which is the induction variable of the inter-tile loop + for the reduction dimension divided by the tile size specified) in + `splitReductionIvs`. }], /*retType=*/"::llvm::LogicalResult", /*methodName=*/"getPartialResultTilePosition", /*args=*/(ins "::mlir::OpBuilder &":$b, "unsigned":$resultNumber, + "ReductionTilingStrategy":$tilingStrategy, "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets, "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes, + "const ::mlir::SetVector &":$reductionDims, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs, "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets, "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes), /*methodBody=*/"", diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index f2b7b34256847..2355edea2df6c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3022,6 +3022,7 @@ void transform::TileReductionUsingForallOp::build( build(builder, result, /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, /*target=*/target, + /*reduction_dims=*/{}, /*num_threads=*/staticNumThreadsAttr, /*tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); @@ -3036,23 +3037,45 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); SmallVector tileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())); - FailureOr result = - linalg::tileReductionUsingForall( - rewriter, cast(target.getOperation()), - numThreads, tileSizes, getMapping()); + + scf::SCFTilingOptions options; + options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + options.setReductionTilingStrategy( + ReductionTilingStrategy::PartialReductionOuterParallel); + if (!getNumThreads().empty()) { + options.setNumThreads(numThreads); + } else { + options.setTileSizes(tileSizes); + } + if (auto mapping = getMapping()) { + options.setMapping(mapping.value().getValue()); + } + SmallVector reductionDims = + extractFromIntegerArrayAttr(getReductionDims()); + if (reductionDims.empty()) { + for (auto [idx, iteratorType] : + llvm::enumerate(target.getIteratorTypesArray())) { + if (iteratorType == 
utils::IteratorType::reduction) + reductionDims.push_back(idx); + } + } + options.setReductionDims(reductionDims); + FailureOr result = scf::tileUsingSCF( + rewriter, cast(target.getOperation()), options); if (failed(result)) { auto diag = emitSilenceableError() << "could not tile reduction"; - diag.attachNote(target.getLoc()) << "target operation"; return diag; } + rewriter.replaceOp(target, result->replacements); + for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - for (auto parallelTiledOp : result->parallelTiledOps) + for (auto parallelTiledOp : result->tiledOps) results.push_back(parallelTiledOp); for (auto mergeOp : result->mergeOps) results.push_back(mergeOp); - results.push_back(result->loops); + results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index f649bc49a8fbd..19d484a3bb701 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -328,6 +328,17 @@ struct LinalgOpTilingInterface // External Model for implementing `PartialReductionInterface` for `LinalgOp`s. //===----------------------------------------------------------------------===// +/// In a given set vector, get the position of a particular element. +std::optional getPositionIn(const llvm::SetVector &reductionDims, + unsigned value) { + for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) { + if (reductionDim == value) { + return index; + } + } + return std::nullopt; +} + /// Return an AffineMaps to use for the `outs` operands of the linalg op /// generated for partial results. The new AffineMap is the AffineMap of the /// untiled op with reduction dimensions appended at end in order in which they @@ -348,28 +359,86 @@ getPartialResultAffineMaps(LinalgOp linalgOp, return partialReductionMaps; } -/// Return the slice of the `initValue` to use as input to the partial reduction -/// op generated. -static Operation *getInitSliceForOuterReduction( - OpBuilder &b, Location loc, Value initValue, ArrayRef offsets, +struct InitSliceInfo { + SmallVector resultShape; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +/// Return the result shape, offsets, sizes and strides of the slice of the +/// `initValue` to use as the destination of the partial reduction op generated +/// with outer reduction strategy. +static InitSliceInfo getInitSliceInfoForOuterReduction( + MLIRContext *context, ArrayRef offsets, ArrayRef sizes, const SetVector &reductionDims, - AffineMap partialReductionMap) { + ArrayRef splitReductionIvs, AffineMap partialReductionMap) { int64_t initRank = partialReductionMap.getNumResults(); SmallVector initOffsets, initSizes; - SmallVector initStrides(initRank, b.getIndexAttr(1)); + Attribute zero = IntegerAttr::get(IndexType::get(context), 0); + Attribute one = IntegerAttr::get(IndexType::get(context), 1); + SmallVector initStrides(initRank, one); for (AffineExpr dimExpr : partialReductionMap.getResults()) { unsigned dim = cast(dimExpr).getPosition(); if (reductionDims.contains(dim)) { - initOffsets.push_back(b.getIndexAttr(0)); + initOffsets.push_back(zero); } else { initOffsets.push_back(offsets[dim]); } initSizes.push_back(sizes[dim]); } - // TODO: Use SubsetExtractOpInterface here once available. 
- auto extractSlice = b.create( - loc, initValue, initOffsets, initSizes, initStrides); - return extractSlice; + SmallVector resultShape; + std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes); + return {resultShape, initOffsets, initSizes, initStrides}; +} + +/// Return the result shape, offsets, sizes and strides of the slice of the +/// `initValue` to use as destination of the partial reduction op generated with +/// outer parallel strategy. +static InitSliceInfo getInitSliceInfoForOuterParallel( + MLIRContext *context, ArrayRef offsets, + ArrayRef sizes, const SetVector &reductionDims, + ArrayRef splitReductionIvs, AffineMap partialReductionMap) { + int64_t initRank = partialReductionMap.getNumResults(); + SmallVector initOffsets, initSizes; + Attribute one = IntegerAttr::get(IndexType::get(context), 1); + SmallVector initStrides(initRank, one); + SmallVector resultShape; + for (AffineExpr dimExpr : partialReductionMap.getResults()) { + unsigned dim = cast(dimExpr).getPosition(); + if (std::optional dimPos = getPositionIn(reductionDims, dim)) { + initOffsets.push_back(splitReductionIvs[dimPos.value()]); + initSizes.push_back(one); + } else { + initOffsets.push_back(offsets[dim]); + initSizes.push_back(sizes[dim]); + resultShape.push_back(sizes[dim]); + } + } + SmallVector staticShapes; + std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape); + return {staticShapes, initOffsets, initSizes, initStrides}; +} + +/// Return the result shape, offsets, sizes and strides of the slice of the +/// `initValue` to use as destination of the partial reduction op. +static InitSliceInfo getInitSliceInfo(MLIRContext *context, + ReductionTilingStrategy strategy, + ArrayRef offsets, + ArrayRef sizes, + const SetVector &reductionDims, + ArrayRef splitReductionIvs, + AffineMap partialReductionMap) { + if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) { + return getInitSliceInfoForOuterReduction(context, offsets, sizes, + reductionDims, splitReductionIvs, + partialReductionMap); + } + assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel && + "unexpected ReductionTilingStrategy"); + return getInitSliceInfoForOuterParallel(context, offsets, sizes, + reductionDims, splitReductionIvs, + partialReductionMap); } /// External model implementation of PartialReductionInterface for @@ -390,21 +459,6 @@ struct LinalgOpPartialReductionInterface SmallVector partialResultMaps = getPartialResultAffineMaps(linalgOp, reductionDims); - // LinalgOp implements TilingInterface. 
- auto tilingInterfaceOp = cast(linalgOp.getOperation()); - SmallVector shape = - llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b), - [](Range x) { return x.size; }); - - SmallVector tiledShape; - for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { - if (isZeroInteger(tileSize)) { - tiledShape.push_back(dimSize); - } else { - tiledShape.push_back(tileSize); - } - } - SmallVector inits; for (auto [initIdx, result, partialMap] : llvm::enumerate(linalgOp->getResults(), partialResultMaps)) { @@ -424,7 +478,7 @@ struct LinalgOpPartialReductionInterface SmallVector partialResultShape; for (AffineExpr dimExpr : partialMap.getResults()) { auto dim = cast(dimExpr); - partialResultShape.push_back(tiledShape[dim.getPosition()]); + partialResultShape.push_back(sizes[dim.getPosition()]); } Type elType = getElementTypeOrSelf(result.getType()); @@ -444,13 +498,8 @@ struct LinalgOpPartialReductionInterface ReductionTilingStrategy tilingStrategy, ValueRange init, ArrayRef offsets, ArrayRef sizes, - const SetVector &reductionDims) const { - if (tilingStrategy != - ReductionTilingStrategy::PartialReductionOuterReduction) { - // TODO: Add support for `PartialReductionOuterParallel` strategy. - return op->emitOpError("unsupported partial reduction tiling with " - "`PartialReductionOuterParallel` strategy"); - } + const SetVector &reductionDims, + ArrayRef splitReductionIvs) const { OpBuilder::InsertionGuard guard(b); auto linalgOp = cast(op); @@ -459,7 +508,16 @@ struct LinalgOpPartialReductionInterface // Step 1. Extend init maps to have reduction dimension dims, since we // are converting them to parallel dimensions. - SmallVector newInitMaps = partialReductionMaps; + SmallVector newInitMaps; + if (tilingStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + newInitMaps = llvm::to_vector(partialReductionMaps); + } else { + newInitMaps = llvm::map_to_vector( + linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) { + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + } // Step 2a: Extract a slice of the input operands. SmallVector tiledInputs = makeTiledShapes( @@ -473,10 +531,17 @@ struct LinalgOpPartialReductionInterface SmallVector tiledInits; for (auto [partialReductionMap, valueToTile] : llvm::zip_equal(partialReductionMaps, init)) { - Operation *sliceOp = - getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes, - reductionDims, partialReductionMap); - tiledInits.push_back(sliceOp->getResult(0)); + InitSliceInfo sliceInfo = getInitSliceInfo( + b.getContext(), tilingStrategy, offsets, sizes, reductionDims, + splitReductionIvs, partialReductionMap); + auto valueToTileType = cast(valueToTile.getType()); + RankedTensorType sliceResultType = RankedTensorType::get( + sliceInfo.resultShape, valueToTileType.getElementType(), + valueToTileType.getEncoding()); + auto sliceOp = b.create( + loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes, + sliceInfo.strides); + tiledInits.push_back(sliceOp.getResult()); generatedSlices.push_back(sliceOp); } @@ -491,19 +556,31 @@ struct LinalgOpPartialReductionInterface // Step 3. Change the reduction dim iterator types. SmallVector newIteratorTypes = linalgOp.getIteratorTypesArray(); - for (int dim : reductionDims) - newIteratorTypes[dim] = utils::IteratorType::parallel; + if (tilingStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + for (int dim : reductionDims) + newIteratorTypes[dim] = utils::IteratorType::parallel; + } // Step 4. Create the new generic op. 
+ Operation *partialReductionOp; auto resultTypes = ValueRange(tiledInits).getTypes(); - auto genericOp = b.create(loc, resultTypes, tiledInputs, - tiledInits, newMaps, newIteratorTypes); - IRMapping mapping; - op->getRegion(0).cloneInto(&genericOp.getRegion(), - genericOp.getRegion().begin(), mapping); + if (tilingStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + auto genericOp = b.create( + loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes); + IRMapping mapping; + op->getRegion(0).cloneInto(&genericOp.getRegion(), + genericOp.getRegion().begin(), mapping); + partialReductionOp = genericOp.getOperation(); + } else { + SmallVector operands = std::move(tiledInputs); + llvm::append_range(operands, tiledInits); + partialReductionOp = mlir::clone(b, op, resultTypes, operands); + } return TilingResult{ - {genericOp.getOperation()}, - llvm::map_to_vector(genericOp->getResults(), + {partialReductionOp}, + llvm::map_to_vector(partialReductionOp->getResults(), [](OpResult r) -> Value { return r; }), generatedSlices}; } @@ -558,26 +635,19 @@ struct LinalgOpPartialReductionInterface LogicalResult getPartialResultTilePosition( Operation *op, OpBuilder &b, unsigned resultNumber, - ArrayRef offsets, ArrayRef sizes, - const SetVector &reductionDims, + ReductionTilingStrategy tilingStrategy, ArrayRef offsets, + ArrayRef sizes, const SetVector &reductionDims, + ArrayRef splitReductionIvs, SmallVector &resultOffsets, SmallVector &resultSizes) const { auto linalgOp = cast(op); SmallVector partialReductionMaps = getPartialResultAffineMaps(linalgOp, reductionDims); - - for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) { - unsigned dim = cast(dimExpr).getPosition(); - resultSizes.push_back(sizes[dim]); - - if (llvm::is_contained(reductionDims, dim)) { - // Reduction dims are reduced, and are always outputed in the same - // place. So use offset 0 for them. 
- resultOffsets.push_back(b.getIndexAttr(0)); - } else { - resultOffsets.push_back(offsets[dim]); - } - } + InitSliceInfo sliceInfo = getInitSliceInfo( + b.getContext(), tilingStrategy, offsets, sizes, reductionDims, + splitReductionIvs, partialReductionMaps[resultNumber]); + std::swap(resultOffsets, sliceInfo.offsets); + std::swap(resultSizes, sliceInfo.sizes); return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index e7c076024e67b..ddcae8481a5b4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -166,12 +166,11 @@ static LogicalResult checkTileSizes(TilingInterface op, assert((numThreads.empty() || (numThreads.size() == iterators.size())) && "when specified, expected number of threads to use for each loop"); - bool isParallelTiling = false, isReductionTiling = false; + bool isParallelTiling = false; for (auto [index, iterator, tileSize] : llvm::enumerate(iterators, tileSizes)) { if (!isConstantIntValue(tileSize, 0)) { isParallelTiling |= iterator == utils::IteratorType::parallel; - isReductionTiling |= iterator == utils::IteratorType::reduction; } if (loopType == scf::SCFTilingOptions::LoopType::ForallOp && @@ -199,15 +198,29 @@ static LogicalResult checkTileSizes(TilingInterface op, } } - if (isParallelTiling && isReductionTiling && - reductionStrategy != ReductionTilingStrategy::FullReduction) { - return op->emitOpError( - "combined parallel and reduction tiling is not supported with partial " - "reduction tiling strategies"); + if (reductionStrategy != ReductionTilingStrategy::FullReduction) { + if (isParallelTiling) { + return op->emitOpError("tiling parallel dimensions is not supported with " + "partial reduction tiling strategies"); + } } return success(); } +/// Get the reduction dims that are tiled. This accounts for reduction dims +/// that are specified as tiled, but the tile size is 0. +static SetVector +getSanitizedReductionDims(ArrayRef tileSizes, + const scf::SCFTilingOptions &options) { + SetVector reductionDims; + for (auto dim : options.reductionDims) { + if (isConstantIntValue(tileSizes[dim], 0)) + continue; + reductionDims.insert(dim); + } + return reductionDims; +} + /// Check if `stride` evenly divides the trip count `size - offset`. static bool tileDividesIterationDomain(Range loopRange) { std::optional offsetAsInt = getConstantIntValue(loopRange.offset); @@ -264,10 +277,12 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, /// `offset`s and `size`s of the tile of the iteration space that the /// innermost loop body of the generated tiled loops corresponds to. 
static std::tuple, SmallVector> -getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, +getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, + ReductionTilingStrategy strategy, ValueRange ivs, ArrayRef iterationDomain, ArrayRef tileSizes, - ArrayRef numThreads) { + ArrayRef numThreads, + const llvm::SetVector &reductionDims) { SmallVector offsets, sizes; int materializedLoopNum = 0; @@ -279,8 +294,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, offsetExpr = d0 + d1 * s0; residualTileSizeExpr = s1 - (d0 + d1 * s0); - for (auto [nt, tileSize, loopRange] : - llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { + for (auto [index, nt, tileSize, loopRange] : + llvm::enumerate(numThreads, tileSizes, iterationDomain)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. @@ -564,9 +579,10 @@ static LogicalResult generateLoopNestUsingForallOp( /// - `loops` is an in-out parameter into which the generated loops are /// populated. static LogicalResult generateLoopNest( - RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, - ArrayRef loopRanges, ArrayRef tileSizes, - ArrayRef numThreads, ValueRange destinationTensors, + RewriterBase &rewriter, Location loc, + scf::SCFTilingOptions::LoopType loopType, ArrayRef loopRanges, + ArrayRef tileSizes, ArrayRef numThreads, + ValueRange destinationTensors, ArrayRef mappingVector, YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { // If the tile sizes are all zero, no loops are generated. Just call the // callback function to handle untiled case. @@ -576,25 +592,26 @@ static LogicalResult generateLoopNest( return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, tiledResults, resultOffsets, resultSizes); } - if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { + if (loopType == scf::SCFTilingOptions::LoopType::ForOp) { return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, destinationTensors, tiledBodyFn, loops); } - if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { + if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) { return generateLoopNestUsingForallOp( - rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, + rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector, destinationTensors, tiledBodyFn, loops); } return rewriter.notifyMatchFailure(loc, "unhandled loop type"); } -static FailureOr> -createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, - ArrayRef tileSizes, - const scf::SCFTilingOptions &options) { +static FailureOr> createInitialTensorsForTiling( + RewriterBase &rewriter, TilingInterface op, + ReductionTilingStrategy reductionStrategy, ArrayRef iterationDomain, + ArrayRef numThreads, ArrayRef tileSizes, + const SetVector &reductionDims) { SmallVector initTensors; Location loc = op->getLoc(); - if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { + if (reductionStrategy == ReductionTilingStrategy::FullReduction) { if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors))) return failure(); return initTensors; @@ -602,20 +619,94 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, auto redOp = dyn_cast(op.getOperation()); if (!redOp) { - return rewriter.notifyMatchFailure( - op, "PartialReductionOuterReduction tiling strategy is only supported" - "for operations implementing PartialReductionOpInterface"); + return op->emitOpError( + 
"PartialReductionOuterReduction tiling strategy is only supported for " + "operations implementing PartialReductionOpInterface"); + } + SmallVector sizes(iterationDomain.size()); + AffineExpr s0, s1, s2; + bindSymbols(rewriter.getContext(), s0, s1, s2); + AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2)); + AffineExpr divExpr = s0.ceilDiv(s1); + for (auto [index, domain, tileSize] : + llvm::enumerate(iterationDomain, tileSizes)) { + if (!numThreads.empty()) { + // Untiled case. + if (isConstantIntValue(numThreads[index], 0)) { + sizes[index] = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), sizeExpr, + {domain.size, domain.offset, domain.stride}); + continue; + } + sizes[index] = numThreads[index]; + continue; + } + + // Non reduction dimensions/non-tiled dimensions. + if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) { + sizes[index] = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), sizeExpr, + {domain.size, domain.offset, domain.stride}); + continue; + } + + if (reductionStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + sizes[index] = tileSize; + continue; + } + + assert(reductionStrategy == + ReductionTilingStrategy::PartialReductionOuterParallel); + OpFoldResult normalizedRange = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), sizeExpr, + {domain.size, domain.offset, domain.stride}); + sizes[index] = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize}); + } + return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes, + reductionDims); +} + +/// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel` +/// the `PartialReductionOpInterface` methods need the index of the parallel +/// split reduction being executed. 
+static SmallVector +getSplitReductionIvs(RewriterBase &rewriter, Location loc, + ReductionTilingStrategy reductionStrategy, ValueRange ivs, + ArrayRef numThreads, + ArrayRef tileSizes, + const SetVector &reductionDims) { + SmallVector splitReductionIvs; + splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0)); + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + AffineExpr divExpr = s0.ceilDiv(s1); + int ivIndex = 0; + if (reductionStrategy == + ReductionTilingStrategy::PartialReductionOuterParallel) { + for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) { + if (!numThreads.empty()) { + splitReductionIvs[index] = ivs[ivIndex++]; + continue; + } + splitReductionIvs[index] = affine::makeComposedFoldedAffineApply( + rewriter, loc, divExpr, + ArrayRef{ivs[ivIndex++], tileSizes[reductionDim]}); + } } - return redOp.generateInitialTensorForPartialReduction( - rewriter, loc, tileSizes, options.reductionDims); + return splitReductionIvs; } static FailureOr getTiledImplementation(RewriterBase &rewriter, TilingInterface op, + ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef offsets, - ArrayRef sizes, - const scf::SCFTilingOptions &options) { - if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { + ArrayRef sizes, ValueRange ivs, + ArrayRef numThreads, + ArrayRef tileSizes, + const SetVector &reductionDims) { + if (reductionStrategy == ReductionTilingStrategy::FullReduction) { return op.getTiledImplementation(rewriter, offsets, sizes); } @@ -626,20 +717,25 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op, "supported for operations " "implementing PartialReductionOpInterface"); } - return redOp.tileToPartialReduction(rewriter, op.getLoc(), - options.reductionStrategy, regionIterArg, - offsets, sizes, options.reductionDims); + + SmallVector splitReductionIvs = + getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, + numThreads, tileSizes, reductionDims); + return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy, + regionIterArg, offsets, sizes, + reductionDims, splitReductionIvs); } -static LogicalResult -getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, - TilingInterface op, ArrayRef offsets, - ArrayRef sizes, - SmallVector &resultOffset, - SmallVector &resultSize, - const scf::SCFTilingOptions &options) { +static LogicalResult getResultTilePosition( + RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, + int64_t index, Value tiledResult, TilingInterface op, + ArrayRef offsets, ArrayRef sizes, + ValueRange ivs, ArrayRef numThreads, + ArrayRef tileSizes, const SetVector &reductionDims, + SmallVector &resultOffset, + SmallVector &resultSize) { - if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { + if (reductionStrategy == ReductionTilingStrategy::FullReduction) { return op.getResultTilePosition(rewriter, index, offsets, sizes, resultOffset, resultSize); } @@ -649,16 +745,20 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, op, "PartialReductionOuterReduction tiling strategy is only supported" "for operations implementing PartialReductionOpInterface"); } - return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes, - options.reductionDims, resultOffset, - resultSize); + SmallVector splitReductionIvs = + getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, + numThreads, tileSizes, reductionDims); + return 
redOp.getPartialResultTilePosition( + rewriter, index, reductionStrategy, offsets, sizes, reductionDims, + splitReductionIvs, resultOffset, resultSize); } static FailureOr mergeTilingResults(RewriterBase &rewriter, TilingInterface op, - ValueRange partialResults, - const scf::SCFTilingOptions &options) { - assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction && + ReductionTilingStrategy reductionStrategy, + const SetVector &reductionDims, + ValueRange partialResults) { + assert(reductionStrategy != ReductionTilingStrategy::FullReduction && "expected merge to be called for only partial reduction cases"); auto redOp = dyn_cast(op.getOperation()); @@ -669,7 +769,7 @@ mergeTilingResults(RewriterBase &rewriter, TilingInterface op, "implementing PartialReductionOpInterface"); } return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, - options.reductionDims); + reductionDims); } /// Append the specified additional `newInitOperands` operands to the @@ -911,6 +1011,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, return failure(); } + // Get the reduction dims + SetVector reductionDims = + getSanitizedReductionDims(tileSizes, options); + // 3. If there is an interchange specified, permute the iteration domain and // the tile sizes. SmallVector interchangeVector; @@ -938,7 +1042,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 4a. Compute the `offsets` and `sizes` to use for tiling. SmallVector offsets, sizes; std::tie(offsets, sizes) = getTileOffsetAndSizes( - rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); + rewriter, loc, options.reductionStrategy, ivs, iterationDomain, + tileSizes, numThreads, reductionDims); // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. @@ -966,8 +1071,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // 5c. Tile the cloned operation. - tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, - offsets, sizes, options); + tilingResult = getTiledImplementation( + rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets, + sizes, ivs, numThreads, tileSizes, reductionDims); if (failed(tilingResult)) { rewriter.eraseOp(clonedOp); return op.emitOpError("faild to tile operation"); @@ -982,9 +1088,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, llvm::enumerate(tilingResult->tiledValues)) { tiledResults.push_back(tiledValue); SmallVector resultOffset, resultSize; - if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, - sizes, resultOffset, resultSize, - options))) { + if (failed(getResultTilePosition( + rewriter, options.reductionStrategy, index, tiledValue, op, + offsets, sizes, ivs, numThreads, tileSizes, reductionDims, + resultOffset, resultSize))) { for (auto op : tilingResult->tiledOps) { rewriter.eraseOp(op); } @@ -999,8 +1106,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, }; // 6. Find the destination tensors to use for the operation. 
- FailureOr> maybeInits = - createInitialTensorsForTiling(rewriter, op, tileSizes, options); + FailureOr> maybeInits = createInitialTensorsForTiling( + rewriter, op, options.reductionStrategy, iterationDomain, numThreads, + tileSizes, reductionDims); if (failed(maybeInits)) { return rewriter.notifyMatchFailure( op, "unable to create initial tensors for tiling"); @@ -1009,8 +1117,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 7. Generate the tiled loops nest using the callback defined above. SmallVector loops; - if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, - tileSizes, numThreads, initTensors, + if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType, + iterationDomain, tileSizes, numThreads, + initTensors, options.mappingVector, innerYieldTiledValuesFn, loops))) return op.emitOpError("failed to generate tiling loops"); assert(succeeded(tilingResult) && @@ -1038,8 +1147,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // The results of the loop needs to be merged. - FailureOr mergeResult = - mergeTilingResults(rewriter, op, loopResults, options); + FailureOr mergeResult = mergeTilingResults( + rewriter, op, options.reductionStrategy, reductionDims, loopResults); if (failed(mergeResult)) { return rewriter.notifyMatchFailure( op, "Failed to merge partial results from tiling"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 04242cad9ecb6..72144ec71c5d2 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2315,13 +2315,13 @@ RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { - SmallVector staticOffsets, staticSizes, staticStrides; - SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets, - staticSizes, staticStrides); + SmallVector staticSizes; + std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes); + assert(static_cast(staticSizes.size()) == + sourceTensorType.getRank() && + "unexpected staticSizes not equal to rank of source"); + return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(), + sourceTensorType.getEncoding()); } /// If the rank is reduced (i.e. the desiredResultRank is smaller than the diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 8e3f796af54df..be01ff2fa3781 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -208,7 +208,7 @@ SmallVector getMixedValues(ArrayRef staticValues, /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. 
std::pair, SmallVector> -decomposeMixedValues(const SmallVectorImpl &mixedValues) { +decomposeMixedValues(ArrayRef mixedValues) { SmallVector staticValues; SmallVector dynamicValues; for (const auto &it : mixedValues) { diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir index 009ab17786696..075d02ab75ad1 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -112,7 +112,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + by num_threads = [0, 5] tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } } @@ -134,10 +134,9 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) // CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor to tensor -// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] -// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor -// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor to tensor -// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[TEMPEXT]] : tensor) { +// CHECK-DAG: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] +// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor +// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[ET]] : tensor) { // CHECK: arith.mulf // CHECK: arith.addf // CHECK: linalg.yield @@ -166,7 +165,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + by num_threads = [0, 0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } } @@ -187,11 +186,10 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) // CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor to tensor -// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] -// CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor -// 
CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor to tensor -// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor to tensor -// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor, tensor) outs(%[[TEMPEXT]] : tensor) -> tensor +// CHECK-DAG: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] +// CHECK-DAG: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor +// CHECK-DAG: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor to tensor +// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor, tensor) outs(%[[ET]] : tensor) -> tensor // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor into tensor // CHECK: } @@ -204,113 +202,9 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @reduction_tile_parallel_cyclic_dist( - %arg0: tensor, %out: tensor) -> tensor { - %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor) - outs(%out : tensor) { - ^bb0(%arg7: f32, %arg9: f32): - %1 = arith.mulf %arg7, %arg7 : f32 - %2 = arith.addf %1, %arg9 : f32 - linalg.yield %2 : f32 - } -> tensor - return %red : tensor -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - transform.yield - } -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)> - -// CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor -// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index -// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor -// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor) -> tensor -// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor) { -// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor to tensor -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] -// CHECK: %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor) { -// CHECK: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]] -// CHECK: %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor -// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor to 
tensor -// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor to tensor -// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[TEMPEXT]] : tensor) { -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield -// CHECK: } -> tensor -// CHECK: %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor into tensor -// CHECK: scf.yield %[[INS]] : tensor -// CHECK: } -// CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor into tensor -// CHECK: } -// CHECK: } -// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor) outs(%[[ARG1]] : tensor) dimensions = [1] -// CHECK: arith.addf -// CHECK: linalg.yield -// CHECK: } -// CHECK: return %[[R]] : tensor - -// ----- - -func.func @reduction_tile_parallel_cyclic_dist( - %arg0: tensor, %out: tensor) -> tensor { - %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor) - outs(%out : tensor) { - ^bb0(%arg7: f32, %arg9: f32): - %1 = arith.mulf %arg7, %arg7 : f32 - %2 = arith.addf %1, %arg9 : f32 - linalg.yield %2 : f32 - } -> tensor - return %red : tensor -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // CHECK: expecting fill - // CHECK-NEXT: linalg.fill - transform.print %1 {name = "expecting fill"} : !transform.any_op - // CHECK: expecting parallel reduction - // CHECK-NEXT: linalg.generic - // CHECK: iterator_types = ["parallel", "reduction"] - transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op - // CHECK: expecting parallel reduction - // CHECK-NEXT: linalg.reduce - // CHECK: iterator_types = ["parallel", "reduction"] - transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op - transform.yield - } -} - -// ----- - func.func @reduction_untiled_forall( %arg0: tensor, %out: tensor) -> tensor { - // expected-note @below {{target operation}} + // expected-error @below {{tiling parallel dimensions is not supported with partial reduction tiling strategies}} %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -329,9 +223,8 @@ module attributes {transform.with_named_sequence} { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{could not tile reduction}} %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [5], tile_sizes = [3], mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - transform.yield + by num_threads = [5] tile_sizes = [3] mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield } } @@ 
-643,3 +536,158 @@ module { // CHECK-SAME: outs(%[[INIT]] : // CHECK-SAME: dimensions = [1, 2] // CHECK: return %[[REDUCE]] + +// ----- + +func.func @reduction_tile_parallel_using_tile_sizes( + %arg0: tensor, %out: tensor) -> tensor { + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.mulf %arg7, %arg7 : f32 + %2 = arith.addf %1, %arg9 : f32 + linalg.yield %2 : f32 + } -> tensor + return %red : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 5)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)> +// CHECK: func @reduction_tile_parallel_using_tile_sizes(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK-DAG: %[[PARALLEL_DIM:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[PARALLEL_DIM]]) : tensor +// CHECK: %[[F:.*]] = linalg.fill +// CHECK-SAME: outs(%[[E]] : +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (%[[D1]]) step (5) shared_outs(%[[ARG3:.+]] = %[[F]]) +// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[D1]]] +// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] +// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [%[[D0]], %[[TS0]]] [1, 1] +// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[INCHUNK]] : +// CHECK-SAME: outs(%[[ET]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] +// CHECK-SAME: outs(%[[ARG1]] : +// CHECK: return %[[R]] : tensor +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 + by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// Check that only one of the reduction dimension can be tiled (in this case inner). 
+ +#map = affine_map<(d0, d1, d2) -> (d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0)> +module { + func.func @reduction_using_forall_tile_single_of_multiple_reduction_inner( + %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> { + %0 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "reduction"]} + ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %1 = arith.mulf %in, %in_0 : f32 + %2 = arith.addf %1, %out : f32 + linalg.yield %2 : f32 + } -> tensor<4096xf32> + return %0 : tensor<4096xf32> + } + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = + transform.structured.tile_reduction_using_forall %0 reduction_dims = [2] by tile_sizes = [0, 0, 64] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)> +// CHECK: func @reduction_using_forall_tile_single_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>) +// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32> +// CHECK: %[[F:.*]] = linalg.fill +// CHECK-SAME: outs(%[[E]] : +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (128) step (64) shared_outs(%[[ARG3:.+]] = %[[F]]) +// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [86, 64] [1, 1] +// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, 0, %[[IV]]] [4096, 86, 64] [1, 1, 1] +// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : +// CHECK-SAME: outs(%[[ET]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] +// CHECK-SAME: outs(%[[ARG2]] : +// CHECK: return %[[R]] + +// ----- + +// Check that specifying both reduction dimensions, but setting tile size to 0 for one of them behaves consistent with specifying single reduction dimension. 
+ +#map = affine_map<(d0, d1, d2) -> (d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0)> +module { + func.func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner( + %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> { + %0 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "reduction"]} + ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %1 = arith.mulf %in, %in_0 : f32 + %2 = arith.addf %1, %out : f32 + linalg.yield %2 : f32 + } -> tensor<4096xf32> + return %0 : tensor<4096xf32> + } + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = + transform.structured.tile_reduction_using_forall %0 reduction_dims = [1, 2] by tile_sizes = [0, 0, 64] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)> +// CHECK: func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>) +// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32> +// CHECK: %[[F:.*]] = linalg.fill +// CHECK-SAME: outs(%[[E]] : +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (128) step (64) shared_outs(%[[ARG3:.+]] = %[[F]]) +// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [86, 64] [1, 1] +// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, 0, %[[IV]]] [4096, 86, 64] [1, 1, 1] +// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : +// CHECK-SAME: outs(%[[ET]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] +// CHECK-SAME: outs(%[[ARG2]] : +// CHECK: return %[[R]]
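For quick reference outside of the lit tests above, here is a minimal sketch of the updated `transform.structured.tile_reduction_using_forall` assembly introduced by this patch: the `reduction_dims` clause is optional (all reduction iterators are tiled when it is omitted), and the `num_threads`/`tile_sizes`/`mapping` clauses are now space-separated rather than comma-separated. The payload shapes and handle names below are illustrative only, adapted from the tests above.

// Tile only the innermost reduction dimension (d2) of a linalg.generic by 64.
// The four results are handles to the fill op, the partial-reduction op, the
// merging linalg.reduce, and the generated scf.forall loop.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    %fill, %partial, %merge, %forall =
      transform.structured.tile_reduction_using_forall %0
        reduction_dims = [2] by tile_sizes = [0, 0, 64]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

With the `PartialReductionOuterParallel` strategy used here, the generated scf.forall iterates over the split reduction dimension, each iteration writes its partial result into a distinct slice of the expanded init tensor, and the trailing linalg.reduce merges the partial results.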