From e3430216d00e9463961dd934671ba9563746c659 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 19 Oct 2024 12:05:13 +0200 Subject: [PATCH 1/5] [mlir][Transforms] Merge 1:1 and 1:N type converters --- .../Dialect/SparseTensor/Transforms/Passes.h | 2 +- .../mlir/Transforms/DialectConversion.h | 56 ++++++++++++++----- .../mlir/Transforms/OneToNTypeConversion.h | 45 +-------------- .../ArmSME/Transforms/VectorLegalization.cpp | 2 +- .../Transforms/Utils/DialectConversion.cpp | 24 ++++++-- .../Transforms/Utils/OneToNTypeConversion.cpp | 44 +++++---------- .../TestOneToNTypeConversionPass.cpp | 18 ++++-- 7 files changed, 93 insertions(+), 98 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index 6ccbc40bdd603..2e9c297f20182 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -150,7 +150,7 @@ std::unique_ptr createLowerForeachToSCFPass(); //===----------------------------------------------------------------------===// /// Type converter for iter_space and iterator. -struct SparseIterationTypeConverter : public OneToNTypeConverter { +struct SparseIterationTypeConverter : public TypeConverter { SparseIterationTypeConverter(); }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ff36160dd616..37da03bbe386e 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -173,7 +173,9 @@ class TypeConverter { /// conversion has finished. /// /// Note: Target materializations may optionally accept an additional Type - /// parameter, which is the original type of the SSA value. + /// parameter, which is the original type of the SSA value. Furthermore, `T` + /// can be a TypeRange; in that case, the function must return a + /// SmallVector. /// This method registers a materialization that will be called when /// converting (potentially multiple) block arguments that were the result of @@ -210,6 +212,9 @@ class TypeConverter { /// will be invoked with: outputType = "t3", inputs = "v2", // originalType = "t1". Note that the original type "t1" cannot be recovered /// from just "t3" and "v2"; that's why the originalType parameter exists. + /// + /// Note: During a 1:N conversion, the result types can be a TypeRange. In + /// that case the materialization produces a SmallVector. template >::template arg_t<1>> void addTargetMaterialization(FnT &&callback) { @@ -316,6 +321,11 @@ class TypeConverter { Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType = {}) const; + SmallVector materializeTargetConversion(OpBuilder &builder, + Location loc, + TypeRange resultType, + ValueRange inputs, + Type originalType = {}) const; /// Convert an attribute present `attr` from within the type `type` using /// the registered conversion functions. If no applicable conversion has been @@ -340,9 +350,9 @@ class TypeConverter { /// The signature of the callback used to materialize a target conversion. /// - /// Arguments: builder, result type, inputs, location, original type - using TargetMaterializationCallbackFn = - std::function; + /// Arguments: builder, result types, inputs, location, original type + using TargetMaterializationCallbackFn = std::function( + OpBuilder &, TypeRange, ValueRange, Location, Type)>; /// The signature of the callback used to convert a type attribute. using TypeAttributeConversionCallbackFn = @@ -409,22 +419,40 @@ class TypeConverter { /// callback. /// /// With callback of form: - /// `Value(OpBuilder &, T, ValueRange, Location, Type)` + /// - Value(OpBuilder &, T, ValueRange, Location, Type) + /// - SmallVector(OpBuilder &, TypeRange, ValueRange, Location, Type) template std::enable_if_t< std::is_invocable_v, TargetMaterializationCallbackFn> wrapTargetMaterialization(FnT &&callback) const { return [callback = std::forward(callback)]( - OpBuilder &builder, Type resultType, ValueRange inputs, - Location loc, Type originalType) -> Value { - if (T derivedType = dyn_cast(resultType)) - return callback(builder, derivedType, inputs, loc, originalType); - return Value(); + OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc, Type originalType) -> SmallVector { + SmallVector result; + if constexpr (std::is_same::value) { + // This is a 1:N target materialization. Return the produces values + // directly. + result = callback(builder, resultTypes, inputs, loc, originalType); + } else { + // This is a 1:1 target materialization. Invoke it only if the result + // type class of the callback matches the requested result type. + if (T derivedType = dyn_cast(resultTypes.front())) { + // 1:1 materializations produce single values, but we store 1:N + // target materialization functions in the type converter. Wrap the + // result value in a SmallVector. + std::optional val = + callback(builder, derivedType, inputs, loc, originalType); + if (val.has_value() && *val) + result.push_back(*val); + } + } + return result; }; } /// With callback of form: - /// `Value(OpBuilder &, T, ValueRange, Location)` + /// - Value(OpBuilder &, T, ValueRange, Location) + /// - SmallVector(OpBuilder &, TypeRange, ValueRange, Location) template std::enable_if_t< std::is_invocable_v, @@ -432,9 +460,9 @@ class TypeConverter { wrapTargetMaterialization(FnT &&callback) const { return wrapTargetMaterialization( [callback = std::forward(callback)]( - OpBuilder &builder, T resultType, ValueRange inputs, Location loc, - Type originalType) -> Value { - return callback(builder, resultType, inputs, loc); + OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc, + Type originalType) { + return callback(builder, resultTypes, inputs, loc); }); } diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h index c59a3a52f028f..7b4dd65cbff7b 100644 --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -33,49 +33,6 @@ namespace mlir { -/// Extends `TypeConverter` with 1:N target materializations. Such -/// materializations have to provide the "reverse" of 1:N type conversions, -/// i.e., they need to materialize N values with target types into one value -/// with a source type (which isn't possible in the base class currently). -class OneToNTypeConverter : public TypeConverter { -public: - /// Callback that expresses user-provided materialization logic from the given - /// value to N values of the given types. This is useful for expressing target - /// materializations for 1:N type conversions, which materialize one value in - /// a source type as N values in target types. - using OneToNMaterializationCallbackFn = - std::function>(OpBuilder &, TypeRange, - Value, Location)>; - - /// Creates the mapping of the given range of original types to target types - /// of the conversion and stores that mapping in the given (signature) - /// conversion. This function simply calls - /// `TypeConverter::convertSignatureArgs` and exists here with a different - /// name to reflect the broader semantic. - LogicalResult computeTypeMapping(TypeRange types, - SignatureConversion &result) const { - return convertSignatureArgs(types, result); - } - - /// Applies one of the user-provided 1:N target materializations. If several - /// exists, they are tried out in the reverse order in which they have been - /// added until the first one succeeds. If none succeeds, the functions - /// returns `std::nullopt`. - std::optional> - materializeTargetConversion(OpBuilder &builder, Location loc, - TypeRange resultTypes, Value input) const; - - /// Adds a 1:N target materialization to the converter. Such materializations - /// build IR that converts N values with target types into 1 value of the - /// source type. - void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { - oneToNTargetMaterializations.emplace_back(std::move(callback)); - } - -private: - SmallVector oneToNTargetMaterializations; -}; - /// Stores a 1:N mapping of types and provides several useful accessors. This /// class extends `SignatureConversion`, which already supports 1:N type /// mappings but lacks some accessors into the mapping as well as access to the @@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern { /// not fail if some ops or types remain unconverted (i.e., the conversion is /// only "partial"). LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns); /// Add a pattern to the given pattern list to convert the signature of a diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 4968c4fc463d0..e908a536e6fb2 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -921,7 +921,7 @@ struct VectorLegalizationPass : public arm_sme::impl::VectorLegalizationBase { void runOnOperation() override { auto *context = &getContext(); - OneToNTypeConverter converter; + TypeConverter converter; RewritePatternSet patterns(context); converter.addConversion([](Type type) { return type; }); converter.addConversion( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3cfcaa965f354..bf969e74e8bfe 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType) const { + SmallVector result = materializeTargetConversion( + builder, loc, TypeRange(resultType), inputs, originalType); + if (result.empty()) + return nullptr; + assert(result.size() == 1 && "requested 1:1 materialization, but callback " + "produced 1:N materialization"); + return result.front(); +} + +SmallVector TypeConverter::materializeTargetConversion( + OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs, + Type originalType) const { for (const TargetMaterializationCallbackFn &fn : - llvm::reverse(targetMaterializations)) - if (Value result = fn(builder, resultType, inputs, loc, originalType)) - return result; - return nullptr; + llvm::reverse(targetMaterializations)) { + SmallVector result = + fn(builder, resultTypes, inputs, loc, originalType); + if (result.empty()) + continue; + return result; + } + return {}; } std::optional diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp index 19e29d48623e0..c208716891ef1 100644 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -17,20 +17,6 @@ using namespace llvm; using namespace mlir; -std::optional> -OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, - Location loc, - TypeRange resultTypes, - Value input) const { - for (const OneToNMaterializationCallbackFn &fn : - llvm::reverse(oneToNTargetMaterializations)) { - if (std::optional> result = - fn(builder, resultTypes, input, loc)) - return *result; - } - return std::nullopt; -} - TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { TypeRange convertedTypes = getConvertedTypes(); if (auto mapping = getInputMapping(originalTypeNo)) @@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion( LogicalResult OneToNConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto *typeConverter = getTypeConverter(); + auto *typeConverter = getTypeConverter(); // Construct conversion mapping for results. Operation::result_type_range originalResultTypes = op->getResultTypes(); OneToNTypeMapping resultMapping(originalResultTypes); - if (failed(typeConverter->computeTypeMapping(originalResultTypes, - resultMapping))) + if (failed(typeConverter->convertSignatureArgs(originalResultTypes, + resultMapping))) return failure(); // Construct conversion mapping for operands. Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); OneToNTypeMapping operandMapping(originalOperandTypes); - if (failed(typeConverter->computeTypeMapping(originalOperandTypes, - operandMapping))) + if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, + operandMapping))) return failure(); // Cast operands to target types. @@ -318,7 +304,7 @@ namespace mlir { // inserted by this pass are annotated with a string attribute that also // documents which kind of the cast (source, argument, or target). LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns) { #ifndef NDEBUG // Remember existing unrealized casts. This data structure is only used in @@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, // Target materialization. assert(!areOperandTypesLegal && areResultsTypesLegal && operands.size() == 1 && "found unexpected target cast"); - std::optional> maybeResults = - typeConverter.materializeTargetConversion( - rewriter, castOp->getLoc(), resultTypes, operands.front()); - if (!maybeResults) { + materializedResults = typeConverter.materializeTargetConversion( + rewriter, castOp->getLoc(), resultTypes, operands.front()); + if (materializedResults.empty()) { emitError(castOp->getLoc()) << "failed to create target materialization"; return failure(); } - materializedResults = maybeResults.value(); } else { // Source and argument materializations. assert(areOperandTypesLegal && !areResultsTypesLegal && @@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern { const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const override { auto funcOp = cast(op); - auto *typeConverter = getTypeConverter(); + auto *typeConverter = getTypeConverter(); // Construct mapping for function arguments. OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes()); - if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(), - argumentMapping))) + if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(), + argumentMapping))) return failure(); // Construct mapping for function results. OneToNTypeMapping funcResultMapping(funcOp.getResultTypes()); - if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(), - funcResultMapping))) + if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(), + funcResultMapping))) return failure(); // Nothing to do if the op doesn't have any non-identity conversions for its diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp index 5c03ac12d1e58..b18dfd8bb22cb 100644 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter, /// /// This function has been copied (with small adaptions) from /// TestDecomposeCallGraphTypes.cpp. -static std::optional> -buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input, - Location loc) { +static SmallVector buildGetTupleElementOps(OpBuilder &builder, + TypeRange resultTypes, + ValueRange inputs, + Location loc) { + if (inputs.size() != 1) + return {}; + Value input = inputs.front(); + TupleType inputType = dyn_cast(input.getType()); if (!inputType) return {}; @@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() { auto *context = &getContext(); // Assemble type converter. - OneToNTypeConverter typeConverter; + TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( @@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() { typeConverter.addArgumentMaterialization(buildMakeTupleOp); typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildGetTupleElementOps); + // Test the other target materialization variant that takes the original type + // as additional argument. This materialization function always fails. + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc, Type originalType) -> SmallVector { return {}; }); // Assemble patterns. RewritePatternSet patterns(context); From c20a361c062481fc6cdb6027a5f63b1af83774ed Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 23 Oct 2024 09:36:45 -0700 Subject: [PATCH 2/5] Update mlir/include/mlir/Transforms/DialectConversion.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/include/mlir/Transforms/DialectConversion.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 37da03bbe386e..0638dfedd647b 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -434,18 +434,19 @@ class TypeConverter { // This is a 1:N target materialization. Return the produces values // directly. result = callback(builder, resultTypes, inputs, loc, originalType); - } else { + } else if constexpr (std::is_assignable::value) { // This is a 1:1 target materialization. Invoke it only if the result // type class of the callback matches the requested result type. if (T derivedType = dyn_cast(resultTypes.front())) { // 1:1 materializations produce single values, but we store 1:N // target materialization functions in the type converter. Wrap the // result value in a SmallVector. - std::optional val = - callback(builder, derivedType, inputs, loc, originalType); - if (val.has_value() && *val) - result.push_back(*val); + Value val = callback(builder, derivedType, inputs, loc, originalType); + if (val) + result.push_back(val); } + } else { + static_assert(false, "T must be a Type or a TypeRange"); } return result; }; From 74e3ea8cf38d1138911ccbbb038d2f898ad571fb Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 24 Oct 2024 10:39:34 -0700 Subject: [PATCH 3/5] Update mlir/include/mlir/Transforms/DialectConversion.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/include/mlir/Transforms/DialectConversion.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 0638dfedd647b..54aa8e90135c6 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -446,7 +446,7 @@ class TypeConverter { result.push_back(val); } } else { - static_assert(false, "T must be a Type or a TypeRange"); + static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange"); } return result; }; From 88bbf68b673b972709e6bf044bc786031d451022 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 24 Oct 2024 19:48:14 +0200 Subject: [PATCH 4/5] Improve asserts --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bf969e74e8bfe..3d0c81867e0cc 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2835,8 +2835,7 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder, builder, loc, TypeRange(resultType), inputs, originalType); if (result.empty()) return nullptr; - assert(result.size() == 1 && "requested 1:1 materialization, but callback " - "produced 1:N materialization"); + assert(result.size() == 1 && "expected single result"); return result.front(); } @@ -2849,6 +2848,9 @@ SmallVector TypeConverter::materializeTargetConversion( fn(builder, resultTypes, inputs, loc, originalType); if (result.empty()) continue; + assert(TypeRange(result) == resultTypes && + "callback produced incorrect number of values or values with " + "incorrect types"); return result; } return {}; From 43edbab25ec8e2d3a4050e09f43032b7076dc002 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 25 Oct 2024 19:51:40 +0200 Subject: [PATCH 5/5] address comments --- .../mlir/Transforms/DialectConversion.h | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 54aa8e90135c6..5e5957170e646 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -435,15 +435,20 @@ class TypeConverter { // directly. result = callback(builder, resultTypes, inputs, loc, originalType); } else if constexpr (std::is_assignable::value) { - // This is a 1:1 target materialization. Invoke it only if the result - // type class of the callback matches the requested result type. - if (T derivedType = dyn_cast(resultTypes.front())) { - // 1:1 materializations produce single values, but we store 1:N - // target materialization functions in the type converter. Wrap the - // result value in a SmallVector. - Value val = callback(builder, derivedType, inputs, loc, originalType); - if (val) - result.push_back(val); + // This is a 1:1 target materialization. Invoke the callback only if a + // single SSA value is requested. + if (resultTypes.size() == 1) { + // Invoke the callback only if the type class of the callback matches + // the requested result type. + if (T derivedType = dyn_cast(resultTypes.front())) { + // 1:1 materializations produce single values, but we store 1:N + // target materialization functions in the type converter. Wrap the + // result value in a SmallVector. + Value val = + callback(builder, derivedType, inputs, loc, originalType); + if (val) + result.push_back(val); + } } } else { static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");