diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 529dd4094507f..da0222bc94237 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -3042,13 +3042,13 @@ func.func @vector_store_index_scalable(%memref : memref<200x100xindex>, %i : ind // ----- -func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) { +func.func @vector_store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) { %val = arith.constant dense<11.0> : vector vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector return } -// CHECK-LABEL: func @vector_store_op_0d +// CHECK-LABEL: func @vector_store_0d // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector to vector<1xf32> // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 @@ -3057,13 +3057,13 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in // ----- -func.func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { +func.func @masked_load(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { %c0 = arith.constant 0: index %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> return %0 : vector<16xf32> } -// CHECK-LABEL: func @masked_load_op +// CHECK-LABEL: func @masked_load // CHECK: %[[CO:.*]] = arith.constant 0 : index // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 @@ -3072,23 +3072,48 @@ func.func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vec // ----- -func.func @masked_load_op_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> { +func.func @masked_load_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> { + %c0 = arith.constant 0: index + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32> + return %0 : vector<[16]xf32> +} + +// CHECK-LABEL: func @masked_load_scalable +// CHECK: %[[CO:.*]] = arith.constant 0 : index +// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64 +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<[16]xi1>, vector<[16]xf32>) -> vector<[16]xf32> +// CHECK: return %[[L]] : vector<[16]xf32> + +// ----- + +func.func @masked_load_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> { %c0 = arith.constant 0: index %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xindex> into vector<16xindex> return %0 : vector<16xindex> } -// CHECK-LABEL: func @masked_load_op_index +// CHECK-LABEL: func @masked_load_index // CHECK: %{{.*}} = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xi64>) -> vector<16xi64> // ----- -func.func @masked_store_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) { +func.func @masked_load_index_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> { + %c0 = arith.constant 0: index + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex> + return %0 : vector<[16]xindex> +} +// CHECK-LABEL: func @masked_load_index_scalable +// CHECK: %{{.*}} = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<[16]xi1>, vector<[16]xi64>) -> vector<[16]xi64> + +// ----- + +func.func @masked_store(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) { %c0 = arith.constant 0: index vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> return } -// CHECK-LABEL: func @masked_store_op +// CHECK-LABEL: func @masked_store // CHECK: %[[CO:.*]] = arith.constant 0 : index // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 @@ -3096,76 +3121,126 @@ func.func @masked_store_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: ve // ----- -func.func @masked_store_op_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) { +func.func @masked_store_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) { + %c0 = arith.constant 0: index + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xf32> + return +} + +// CHECK-LABEL: func @masked_store_scalable +// CHECK: %[[CO:.*]] = arith.constant 0 : index +// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64 +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[16]xf32>, vector<[16]xi1> into !llvm.ptr + +// ----- + +func.func @masked_store_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) { %c0 = arith.constant 0: index vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xindex> return } -// CHECK-LABEL: func @masked_store_op_index +// CHECK-LABEL: func @masked_store_index // CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xi64>, vector<16xi1> into !llvm.ptr // ----- -func.func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { +func.func @masked_store_index_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) { + %c0 = arith.constant 0: index + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xindex> + return +} +// CHECK-LABEL: func @masked_store_index_scalable +// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<[16]xi64>, vector<[16]xi1> into !llvm.ptr + +// ----- + +func.func @gather(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { %0 = arith.constant 0: index %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> return %1 : vector<3xf32> } -// CHECK-LABEL: func @gather_op +// CHECK-LABEL: func @gather // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> // CHECK: return %[[G]] : vector<3xf32> // ----- -func.func @gather_op_scalable(%arg0: memref, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> { +func.func @gather_scalable(%arg0: memref, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> { %0 = arith.constant 0: index %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32> return %1 : vector<[3]xf32> } -// CHECK-LABEL: func @gather_op_scalable +// CHECK-LABEL: func @gather_scalable // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec, f32 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32> // CHECK: return %[[G]] : vector<[3]xf32> // ----- -func.func @gather_op_global_memory(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { +func.func @gather_global_memory(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { %0 = arith.constant 0: index %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> return %1 : vector<3xf32> } -// CHECK-LABEL: func @gather_op +// CHECK-LABEL: func @gather_global_memory // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> // CHECK: return %[[G]] : vector<3xf32> // ----- +func.func @gather_global_memory_scalable(%arg0: memref, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> { + %0 = arith.constant 0: index + %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32> + return %1 : vector<[3]xf32> +} + +// CHECK-LABEL: func @gather_global_memory_scalable +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec>, f32 +// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32> +// CHECK: return %[[G]] : vector<[3]xf32> + +// ----- + -func.func @gather_op_index(%arg0: memref, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> { +func.func @gather_index(%arg0: memref, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> { %0 = arith.constant 0: index %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex> return %1 : vector<3xindex> } -// CHECK-LABEL: func @gather_op_index +// CHECK-LABEL: func @gather_index // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64> // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex> // ----- -func.func @gather_op_multi_dims(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> { +func.func @gather_index_scalable(%arg0: memref, %arg1: vector<[3]xindex>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xindex>) -> vector<[3]xindex> { + %0 = arith.constant 0: index + %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xindex> into vector<[3]xindex> + return %1 : vector<[3]xindex> +} + +// CHECK-LABEL: func @gather_index_scalable +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec, i64 +// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64> +// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex> + +// ----- + +func.func @gather_2d_from_1d(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> { %0 = arith.constant 0: index %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32> return %1 : vector<2x3xf32> } -// CHECK-LABEL: func @gather_op_multi_dims +// CHECK-LABEL: func @gather_2d_from_1d // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>> // CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>> @@ -3182,40 +3257,94 @@ func.func @gather_op_multi_dims(%arg0: memref, %arg1: vector<2x3xi32>, %a // ----- -func.func @gather_op_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { +func.func @gather_2d_from_1d_scalable(%arg0: memref, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> { + %0 = arith.constant 0: index + %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32> + return %1 : vector<2x[3]xf32> +} + +// CHECK-LABEL: func @gather_2d_from_1d_scalable +// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>> +// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>> +// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>> +// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec, f32 +// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32> +// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>> +// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>> +// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>> +// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>> +// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec, f32 +// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32> +// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>> + +// ----- + +func.func @gather_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { %0 = arith.constant 0: index %1 = vector.constant_mask [1, 2] : vector<2x3xi1> %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32> return %2 : vector<2x3xf32> } -// CHECK-LABEL: func @gather_op_with_mask +// CHECK-LABEL: func @gather_with_mask // CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> // CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> // ----- -func.func @gather_op_with_zero_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { +func.func @gather_with_mask_scalable(%arg0: memref, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> { + %0 = arith.constant 0: index + // vector.constant_mask only supports 'none set' or 'all set' scalable + // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed + // width vectors above. + %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1> + %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32> + return %2 : vector<2x[3]xf32> +} + +// CHECK-LABEL: func @gather_with_mask_scalable +// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32> +// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32> + + +// ----- + +func.func @gather_with_zero_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { %0 = arith.constant 0: index %1 = vector.constant_mask [0, 0] : vector<2x3xi1> %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32> return %2 : vector<2x3xf32> } -// CHECK-LABEL: func @gather_op_with_zero_mask +// CHECK-LABEL: func @gather_with_zero_mask // CHECK-SAME: (%{{.*}}: memref, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>) // CHECK-NOT: %{{.*}} = llvm.intr.masked.gather // CHECK: return %[[S]] : vector<2x3xf32> // ----- -func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> { +func.func @gather_with_zero_mask_scalable(%arg0: memref, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> { + %0 = arith.constant 0: index + %1 = vector.constant_mask [0, 0] : vector<2x[3]xi1> + %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32> + return %2 : vector<2x[3]xf32> +} + +// CHECK-LABEL: func @gather_with_zero_mask_scalable +// CHECK-SAME: (%{{.*}}: memref, %{{.*}}: vector<2x[3]xi32>, %[[S:.*]]: vector<2x[3]xf32>) +// CHECK-NOT: %{{.*}} = llvm.intr.masked.gather +// CHECK: return %[[S]] : vector<2x[3]xf32> + +// ----- + +func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> { %0 = arith.constant 3 : index %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> return %1 : vector<4xf32> } -// CHECK-LABEL: func @gather_2d_op +// CHECK-LABEL: func @gather_1d_from_2d // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32> @@ -3223,55 +3352,94 @@ func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vec // ----- -func.func @scatter_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { +func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]xi32>, %arg2: vector<[4]xi1>, %arg3: vector<[4]xf32>) -> vector<[4]xf32> { + %0 = arith.constant 3 : index + %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x?xf32>, vector<[4]xi32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> + return %1 : vector<[4]xf32> +} + +// CHECK-LABEL: func @gather_1d_from_2d_scalable +// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec, f32 +// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32> +// CHECK: return %[[G]] : vector<[4]xf32> + +// ----- + +func.func @scatter(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { %0 = arith.constant 0: index vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> return } -// CHECK-LABEL: func @scatter_op +// CHECK-LABEL: func @scatter // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr> // ----- -func.func @scatter_op_scalable(%arg0: memref, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) { +func.func @scatter_scalable(%arg0: memref, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) { %0 = arith.constant 0: index vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> return } -// CHECK-LABEL: func @scatter_op_scalable +// CHECK-LABEL: func @scatter_scalable // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec, f32 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec // ----- -func.func @scatter_op_index(%arg0: memref, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) { +func.func @scatter_index(%arg0: memref, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) { %0 = arith.constant 0: index vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xindex>, vector<3xi1>, vector<3xindex> return } -// CHECK-LABEL: func @scatter_op_index +// CHECK-LABEL: func @scatter_index // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr> // ----- -func.func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) { +func.func @scatter_index_scalable(%arg0: memref, %arg1: vector<[3]xindex>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xindex>) { + %0 = arith.constant 0: index + vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xindex> + return +} + +// CHECK-LABEL: func @scatter_index_scalable +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec, i64 +// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec + +// ----- + +func.func @scatter_1d_into_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) { %0 = arith.constant 3 : index vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> return } -// CHECK-LABEL: func @scatter_2d_op +// CHECK-LABEL: func @scatter_1d_into_2d // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr> // ----- +func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]xi32>, %arg2: vector<[4]xi1>, %arg3: vector<[4]xf32>) { + %0 = arith.constant 3 : index + vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x?xf32>, vector<[4]xi32>, vector<[4]xi1>, vector<[4]xf32> + return +} + +// CHECK-LABEL: func @scatter_1d_into_2d_scalable +// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec, f32 +// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.vec + +// ----- + func.func @expand_load_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> { %c0 = arith.constant 0: index %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> into vector<11xf32>