diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 6184225cb6285..656b1cb3e99a1 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -40,22 +41,9 @@ using namespace mlir; /// Returns the integer value from the first valid input element, assuming Value /// inputs are defined by a constant index ops and Attribute inputs are integer /// attributes. -static uint64_t getFirstIntValue(ValueRange values) { - return values[0].getDefiningOp().value(); -} -static uint64_t getFirstIntValue(ArrayRef attr) { - return cast(attr[0]).getInt(); -} static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } -static uint64_t getFirstIntValue(ArrayRef foldResults) { - auto attr = foldResults[0].dyn_cast(); - if (attr) - return getFirstIntValue(attr); - - return getFirstIntValue(ValueRange{foldResults[0].get()}); -} /// Returns the number of bits for the given scalar/vector type. static int getNumBits(Type type) { @@ -157,9 +145,6 @@ struct VectorExtractOpConvert final LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (extractOp.hasDynamicPosition()) - return failure(); - Type dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); @@ -169,9 +154,15 @@ struct VectorExtractOpConvert final return success(); } - int32_t id = getFirstIntValue(extractOp.getMixedPosition()); - rewriter.replaceOpWithNewOp( - extractOp, adaptor.getVector(), id); + if (std::optional id = + getConstantIntValue(extractOp.getMixedPosition()[0])) + rewriter.replaceOpWithNewOp( + extractOp, dstType, adaptor.getVector(), + rewriter.getI32ArrayAttr(id.value())); + else + rewriter.replaceOpWithNewOp( + extractOp, dstType, adaptor.getVector(), + adaptor.getDynamicPosition()[0]); return success(); } }; @@ -249,9 +240,14 @@ struct VectorInsertOpConvert final return success(); } - int32_t id = getFirstIntValue(insertOp.getMixedPosition()); - rewriter.replaceOpWithNewOp( - insertOp, adaptor.getSource(), adaptor.getDest(), id); + if (std::optional id = + getConstantIntValue(insertOp.getMixedPosition()[0])) + rewriter.replaceOpWithNewOp( + insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); + else + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getDest(), adaptor.getSource(), + adaptor.getDynamicPosition()[0]); return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 25ec5d0159bd5..8796f153c4911 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -186,6 +186,37 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { // ----- +// CHECK-LABEL: @extract_size1_vector_dynamic +// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32> +// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: return %[[R]] +func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 { + %0 = vector.extract %arg0[%id] : f32 from vector<1xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_dynamic +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { + %0 = vector.extract %arg0[%id] : f32 from vector<4xf32> + return %0: f32 +} + +// CHECK-LABEL: @extract_dynamic_cst +// CHECK-SAME: %[[V:.*]]: vector<4xf32> +// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> +func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 { + %idx = arith.constant 1 : index + %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @insert // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32 // CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32> @@ -216,6 +247,39 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- +// CHECK-LABEL: @insert_size1_vector_dynamic +// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32 +// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] +// CHECK: return %[[R]] +func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> { + %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32> + return %1 : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: @insert_dynamic +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { + %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_dynamic_cst +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> +// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> +func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { + %idx = arith.constant 2 : index + %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + // CHECK-LABEL: @extract_element // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 // CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32