From f15d21e9f351df56fb69c39d1231bfad4a3c9929 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Tue, 29 Oct 2024 22:20:07 +0000 Subject: [PATCH 1/3] [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices --- .../VectorToSPIRV/VectorToSPIRV.cpp | 48 ++++++++++--------- .../VectorToSPIRV/vector-to-spirv.mlir | 42 ++++++++++++++++ 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 6184225cb6285..ee8dccf025a0c 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -40,22 +41,9 @@ using namespace mlir; /// Returns the integer value from the first valid input element, assuming Value /// inputs are defined by a constant index ops and Attribute inputs are integer /// attributes. -static uint64_t getFirstIntValue(ValueRange values) { - return values[0].getDefiningOp().value(); -} -static uint64_t getFirstIntValue(ArrayRef attr) { - return cast(attr[0]).getInt(); -} static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } -static uint64_t getFirstIntValue(ArrayRef foldResults) { - auto attr = foldResults[0].dyn_cast(); - if (attr) - return getFirstIntValue(attr); - - return getFirstIntValue(ValueRange{foldResults[0].get()}); -} /// Returns the number of bits for the given scalar/vector type. static int getNumBits(Type type) { @@ -157,9 +145,6 @@ struct VectorExtractOpConvert final LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (extractOp.hasDynamicPosition()) - return failure(); - Type dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); @@ -169,9 +154,17 @@ struct VectorExtractOpConvert final return success(); } - int32_t id = getFirstIntValue(extractOp.getMixedPosition()); - rewriter.replaceOpWithNewOp( - extractOp, adaptor.getVector(), id); + std::optional id = + getConstantIntValue(extractOp.getMixedPosition()[0]); + + if (id.has_value()) + rewriter.replaceOpWithNewOp( + extractOp, dstType, adaptor.getVector(), + rewriter.getI32ArrayAttr(id.value())); + else + rewriter.replaceOpWithNewOp( + extractOp, dstType, adaptor.getVector(), + adaptor.getDynamicPosition()[0]); return success(); } }; @@ -249,9 +242,20 @@ struct VectorInsertOpConvert final return success(); } - int32_t id = getFirstIntValue(insertOp.getMixedPosition()); - rewriter.replaceOpWithNewOp( - insertOp, adaptor.getSource(), adaptor.getDest(), id); + std::optional id = + getConstantIntValue(insertOp.getMixedPosition()[0]); + + // rewriter.replaceOpWithNewOp( + // insertOp, adaptor.getSource(), adaptor.getDest(), id); + // return success(); + + if (id.has_value()) + rewriter.replaceOpWithNewOp( + insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); + else + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getDest(), adaptor.getSource(), + adaptor.getDynamicPosition()[0]); return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 25ec5d0159bd5..62210108aa73c 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { // ----- +// CHECK-LABEL: @extract_dynamic +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { + %0 = vector.extract %arg0[%id] : f32 from vector<4xf32> + return %0: f32 +} + +// CHECK-LABEL: @extract_dynamic_cst +// CHECK-SAME: %[[V:.*]]: vector<4xf32> +// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> +func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 { + %idx = arith.constant 1 : index + %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @insert // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32 // CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32> @@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- +// CHECK-LABEL: @insert_dynamic +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { + %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_dynamic_cst +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> +// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> +func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { + %idx = arith.constant 2 : index + %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + // CHECK-LABEL: @extract_element // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 // CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 From 7e7add4244da637808c781991e8df3f474c9f7d4 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Tue, 5 Nov 2024 09:01:28 +0000 Subject: [PATCH 2/3] Address comments --- .../VectorToSPIRV/VectorToSPIRV.cpp | 4 ---- .../VectorToSPIRV/vector-to-spirv.mlir | 22 +++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index ee8dccf025a0c..b6b5a1cf939e4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -245,10 +245,6 @@ struct VectorInsertOpConvert final std::optional id = getConstantIntValue(insertOp.getMixedPosition()[0]); - // rewriter.replaceOpWithNewOp( - // insertOp, adaptor.getSource(), adaptor.getDest(), id); - // return success(); - if (id.has_value()) rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 62210108aa73c..8796f153c4911 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -186,6 +186,17 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { // ----- +// CHECK-LABEL: @extract_size1_vector_dynamic +// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32> +// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: return %[[R]] +func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 { + %0 = vector.extract %arg0[%id] : f32 from vector<1xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @extract_dynamic // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 @@ -236,6 +247,17 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- +// CHECK-LABEL: @insert_size1_vector_dynamic +// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32 +// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] +// CHECK: return %[[R]] +func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> { + %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32> + return %1 : vector<1xf32> +} + +// ----- + // CHECK-LABEL: @insert_dynamic // CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 From 46795c8ca6674c4022861958a3541f30e91ff120 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Tue, 5 Nov 2024 23:09:03 +0000 Subject: [PATCH 3/3] address more comments --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index b6b5a1cf939e4..656b1cb3e99a1 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -154,10 +154,8 @@ struct VectorExtractOpConvert final return success(); } - std::optional id = - getConstantIntValue(extractOp.getMixedPosition()[0]); - - if (id.has_value()) + if (std::optional id = + getConstantIntValue(extractOp.getMixedPosition()[0])) rewriter.replaceOpWithNewOp( extractOp, dstType, adaptor.getVector(), rewriter.getI32ArrayAttr(id.value())); @@ -242,10 +240,8 @@ struct VectorInsertOpConvert final return success(); } - std::optional id = - getConstantIntValue(insertOp.getMixedPosition()[0]); - - if (id.has_value()) + if (std::optional id = + getConstantIntValue(insertOp.getMixedPosition()[0])) rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); else