Skip to content

Commit 22cd31a

Browse files
committed
[mlir][spirv] Add support for VectorAnyINTEL capability
Allow vector of any lengths between [2-2^32-1]. VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.
1 parent 34c6f20 commit 22cd31a

File tree

13 files changed

+124
-45
lines changed

13 files changed

+124
-45
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
41464146
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
41474147
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
41484148
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
4149-
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
4149+
// Remove the vector size restriction.
4150+
// Although the vector size can be upto (2^64-1), uint64,
4151+
// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
4152+
// for all practical cases.
4153+
// Also unsigned is used for the number elements for composite tyeps.
4154+
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
41504155
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
41514156
// Component type check is done in the type parser for the following SPIR-V
41524157
// dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
42064211
"Joint Matrix">;
42074212

42084213
class SPIRV_ScalarOrVectorOf<Type type> :
4209-
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
4214+
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
42104215

42114216
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
4212-
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
4217+
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
42134218
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
42144219

42154220
class SPIRV_MatrixOrCoopMatrixOf<Type type> :

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
184184
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
185185
return Type();
186186
}
187-
if (t.getNumElements() > 4) {
187+
// Number of elements should be between [2 - 2^32 -1],
188+
// since getNumElements() returns an unsigned, the upper limit check is
189+
// unnecessary.
190+
if (t.getNumElements() < 2) {
188191
parser.emitError(
189-
typeLoc, "vector length has to be less than or equal to 4 but found ")
192+
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
190193
<< t.getNumElements();
191194
return Type();
192195
}

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
101101
}
102102

103103
bool CompositeType::isValid(VectorType type) {
104-
return type.getRank() == 1 &&
105-
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
106-
llvm::isa<ScalarType>(type.getElementType());
104+
// Number of elements should be between [2 - 2^32 -1],
105+
// since getNumElements() returns an unsigned, the upper limit check is
106+
// unnecessary.
107+
return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
108+
type.getNumElements() >= 2;
107109
}
108110

109111
Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
171173
.Case<VectorType>([&](VectorType type) {
172174
auto vecSize = getNumElements();
173175
if (vecSize == 8 || vecSize == 16) {
174-
static const Capability caps[] = {Capability::Vector16};
175-
ArrayRef<Capability> ref(caps, std::size(caps));
176-
capabilities.push_back(ref);
176+
static constexpr Capability caps[] = {Capability::Vector16,
177+
Capability::VectorAnyINTEL};
178+
capabilities.push_back(caps);
179+
}
180+
// VectorAnyINTEL capability removes the vector size restriction and
181+
// allows the vector size to be up to (2^32-1).
182+
// Vector16 capability allows the vector size to be 8 and 16
183+
SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
184+
if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
185+
static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
186+
capabilities.push_back(caps);
177187
}
178188
return llvm::cast<ScalarType>(type.getElementType())
179189
.getCapabilities(capabilities, storage);

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ module attributes {
1111
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
1212
} {
1313

14-
func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
14+
func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
1515
// expected-error@+1 {{failed to legalize operation 'arith.subi'}}
16-
%1 = arith.subi %arg0, %arg0: vector<5xi32>
16+
%1 = arith.subi %arg0, %arg1: vector<5xi32>
1717
return
1818
}
1919

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
14071407
}
14081408

14091409
} // end module
1410+
1411+
// -----
1412+
1413+
//===----------------------------------------------------------------------===//
1414+
// VectorAnyINTEL support
1415+
//===----------------------------------------------------------------------===//
1416+
1417+
// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
1418+
// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
1419+
module attributes {
1420+
spirv.target_env = #spirv.target_env<
1421+
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
1422+
} {
1423+
1424+
// CHECK-LABEL: @any_vector
1425+
func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
1426+
// CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
1427+
%0 = arith.subi %arg0, %arg1: vector<16xi32>
1428+
return
1429+
}
1430+
1431+
// CHECK-LABEL: @max_vector
1432+
func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
1433+
// CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
1434+
%0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
1435+
return
1436+
}
1437+
1438+
1439+
// Check float vector types of any size.
1440+
// CHECK-LABEL: @float_vector58
1441+
func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
1442+
// CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
1443+
%0 = arith.addf %arg0, %arg0: vector<5xf16>
1444+
// CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
1445+
%1 = arith.mulf %arg1, %arg1: vector<8xf64>
1446+
return
1447+
}
1448+
1449+
} // end module

mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,21 @@ module attributes {
351351
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
352352
} {
353353

354-
// CHECK-NOT: spirv.func @large_vector
355-
func.func @large_vector(%arg0: vector<1024xi32>) { return }
354+
// CHECK-NOT: spirv.func @large_vector_unsupported
355+
func.func @large_vector_unsupported(%arg0: vector<1024xi32>) { return }
356+
357+
} // end module
358+
359+
360+
// -----
361+
362+
// Check that large vectors are supported with VectorAnyINTEL or VectorComputeINTEL.
363+
module attributes {
364+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
365+
} {
366+
367+
// CHECK: spirv.func @large_any_vector
368+
func.func @large_any_vector(%arg0: vector<1024xi32>) { return }
356369

357370
} // end module
358371

mlir/test/Dialect/SPIRV/IR/bit-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
137137
// -----
138138

139139
func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
140-
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
140+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
141141
%0 = spirv.BitwiseOr %arg0, %arg1 : f16
142142
return %0 : f16
143143
}
@@ -163,7 +163,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
163163
// -----
164164

165165
func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
166-
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
166+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
167167
%0 = spirv.BitwiseXor %arg0, %arg1 : f16
168168
return %0 : f16
169169
}
@@ -272,7 +272,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
272272
// -----
273273

274274
func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
275-
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
275+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
276276
%0 = spirv.BitwiseAnd %arg0, %arg1 : f16
277277
return %0 : f16
278278
}

mlir/test/Dialect/SPIRV/IR/gl-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
2727
// -----
2828

2929
func.func @exp(%arg0 : vector<5xf32>) -> () {
30-
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
30+
// CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32
3131
%2 = spirv.GL.Exp %arg0 : vector<5xf32>
3232
return
3333
}

mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
2121
// -----
2222

2323
spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
24-
// expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
24+
// expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}}
2525
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
2626
spirv.Return
2727
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
5757
// -----
5858

5959
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
60-
// expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
60+
// expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}}
6161
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
6262
spirv.Return
6363
}

mlir/test/Dialect/SPIRV/IR/logical-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1)
166166

167167
func.func @logicalUnary(%arg0 : i32)
168168
{
169-
// expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
169+
// expected-error @+1 {{'operand' must be bool or vector of bool values of length 2-4294967295, but got 'i32'}}
170170
%0 = spirv.LogicalNot %arg0 : i32
171171
return
172172
}

0 commit comments

Comments
 (0)