Skip to content

Commit d8fa6c4

Browse files
committed
[llvm][RISCV] Add RISCV vector tuple type to value types(MVT)
This patch handles the types(MVT) in `selectionDAG` for RISCV vector tuples. As described in previous patch handling llvm types, the MVTs also have 32 variants: ``` riscv_mf8x2, riscv_mf8x3, riscv_mf8x4, riscv_mf8x5, riscv_mf8x6, riscv_mf8x7, riscv_mf8x8, riscv_mf4x2, riscv_mf4x3, riscv_mf4x4, riscv_mf4x5, riscv_mf4x6, riscv_mf4x7, riscv_mf4x8, riscv_mf2x2, riscv_mf2x3, riscv_mf2x4, riscv_mf2x5, riscv_mf2x6, riscv_mf2x7, riscv_mf2x8, riscv_m1x2, riscv_m1x3, riscv_m1x4, riscv_m1x5, riscv_m1x6, riscv_m1x7, riscv_m1x8, riscv_m2x2, riscv_m2x3, riscv_m2x4, riscv_m4x2. ``` An intuitive way to model vector tuple type is using nested scalable vector, e.g. `nElts=NF, EltTy=nxv2i32`. However it's not compatible to what we've done to handle scalable vector in TargetLowering, so it would need more effort to change the code to handle this concept. Another approach is encoding the `LMUL` info in `sz` of `MVT`, e.g. `nElts=NF, sz=(LMUL*NF*RVVBitsPerBlock)`, this makes it much easier to handle and changes less code. This patch adopts the latter approach.
1 parent 600f240 commit d8fa6c4

File tree

6 files changed

+207
-27
lines changed

6 files changed

+207
-27
lines changed

llvm/include/llvm/CodeGen/ValueTypes.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ namespace llvm {
8787
return getExtendedVectorVT(Context, VT, EC);
8888
}
8989

90+
/// Returns the EVT that represents a vector tuple type.
91+
static EVT getRISCVVectorTupleVT(int Log2LMUL, unsigned NFields) {
92+
// Sz = NF * LMUL * BitsPerBlock
93+
unsigned Sz =
94+
NFields * (Log2LMUL < 0 ? (64 >> -Log2LMUL) : (64 << Log2LMUL));
95+
return MVT::getRISCVVectorTupleVT(Sz, NFields);
96+
}
97+
9098
/// Return a vector with the same number of elements as this vector, but
9199
/// with the element type converted to an integer type with the same
92100
/// bitwidth.
@@ -174,6 +182,9 @@ namespace llvm {
174182
return isSimple() ? V.isScalableVector() : isExtendedScalableVector();
175183
}
176184

185+
/// Return true if this is a vector value type.
186+
bool isRISCVVectorTuple() const { return V.isRISCVVectorTuple(); }
187+
177188
bool isFixedLengthVector() const {
178189
return isSimple() ? V.isFixedLengthVector()
179190
: isExtendedFixedLengthVector();

llvm/include/llvm/CodeGen/ValueTypes.td

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class ValueType<int size, int value> {
2323
bit isFP = false;
2424
bit isVector = false;
2525
bit isScalable = false;
26+
bit isRISCVVecTuple = false;
2627
// Indicates this VT should be included in the
2728
// [FIRST_VALUETYPE,LAST_VALUETYPE] range.
2829
bit isNormalValueType = true;
@@ -56,6 +57,14 @@ class VTScalableVec<int nelem, ValueType elt, int value>
5657
let isScalable = true;
5758
}
5859

60+
class VTVecTup<int size, int nf, ValueType dummy_elt, int value>
61+
: ValueType<size, value> {
62+
let nElem = nf;
63+
let isVector = true;
64+
let ElementType = dummy_elt;
65+
let isRISCVVecTuple = true;
66+
}
67+
5968
defset list<ValueType> ValueTypes = {
6069

6170
def OtherVT : ValueType<0, 1> { // "Other" value
@@ -273,20 +282,54 @@ def nxv2f64 : VTScalableVec<2, f64, 187>; // n x 2 x f64 vector value
273282
def nxv4f64 : VTScalableVec<4, f64, 188>; // n x 4 x f64 vector value
274283
def nxv8f64 : VTScalableVec<8, f64, 189>; // n x 8 x f64 vector value
275284

276-
def x86mmx : ValueType<64, 190>; // X86 MMX value
277-
def Glue : ValueType<0, 191>; // Pre-RA sched glue
278-
def isVoid : ValueType<0, 192>; // Produces no value
279-
def untyped : ValueType<8, 193> { // Produces an untyped value
285+
// Sz = NF * LMUL * BitsPerBlock
286+
def riscv_mf8x2 : VTVecTup<16, 2, i8, 190>; // RISCV vector tuple(lmul=1/8, nf=2)
287+
def riscv_mf8x3 : VTVecTup<24, 3, i8, 191>; // RISCV vector tuple(lmul=1/8, nf=3)
288+
def riscv_mf8x4 : VTVecTup<32, 4, i8, 192>; // RISCV vector tuple(lmul=1/8, nf=4)
289+
def riscv_mf8x5 : VTVecTup<40, 5, i8, 193>; // RISCV vector tuple(lmul=1/8, nf=5)
290+
def riscv_mf8x6 : VTVecTup<48, 6, i8, 194>; // RISCV vector tuple(lmul=1/8, nf=6)
291+
def riscv_mf8x7 : VTVecTup<56, 7, i8, 195>; // RISCV vector tuple(lmul=1/8, nf=7)
292+
def riscv_mf8x8 : VTVecTup<64, 8, i8, 196>; // RISCV vector tuple(lmul=1/8, nf=8)
293+
def riscv_mf4x2 : VTVecTup<32, 2, i8, 197>; // RISCV vector tuple(lmul=1/4, nf=2)
294+
def riscv_mf4x3 : VTVecTup<48, 3, i8, 198>; // RISCV vector tuple(lmul=1/4, nf=3)
295+
def riscv_mf4x4 : VTVecTup<64, 4, i8, 199>; // RISCV vector tuple(lmul=1/4, nf=4)
296+
def riscv_mf4x5 : VTVecTup<80, 5, i8, 200>; // RISCV vector tuple(lmul=1/4, nf=5)
297+
def riscv_mf4x6 : VTVecTup<96, 6, i8, 201>; // RISCV vector tuple(lmul=1/4, nf=6)
298+
def riscv_mf4x7 : VTVecTup<112, 7, i8, 202>; // RISCV vector tuple(lmul=1/4, nf=7)
299+
def riscv_mf4x8 : VTVecTup<128, 8, i8, 203>; // RISCV vector tuple(lmul=1/4, nf=8)
300+
def riscv_mf2x2 : VTVecTup<64, 2, i8, 204>; // RISCV vector tuple(lmul=1/2, nf=2)
301+
def riscv_mf2x3 : VTVecTup<96, 3, i8, 205>; // RISCV vector tuple(lmul=1/2, nf=3)
302+
def riscv_mf2x4 : VTVecTup<128, 4, i8, 206>; // RISCV vector tuple(lmul=1/2, nf=4)
303+
def riscv_mf2x5 : VTVecTup<160, 5, i8, 207>; // RISCV vector tuple(lmul=1/2, nf=5)
304+
def riscv_mf2x6 : VTVecTup<192, 6, i8, 208>; // RISCV vector tuple(lmul=1/2, nf=6)
305+
def riscv_mf2x7 : VTVecTup<224, 7, i8, 209>; // RISCV vector tuple(lmul=1/2, nf=7)
306+
def riscv_mf2x8 : VTVecTup<256, 8, i8, 210>; // RISCV vector tuple(lmul=1/2, nf=8)
307+
def riscv_m1x2 : VTVecTup<128, 2, i8, 211>; // RISCV vector tuple(lmul=1, nf=2)
308+
def riscv_m1x3 : VTVecTup<192, 3, i8, 212>; // RISCV vector tuple(lmul=1, nf=3)
309+
def riscv_m1x4 : VTVecTup<256, 4, i8, 213>; // RISCV vector tuple(lmul=1, nf=4)
310+
def riscv_m1x5 : VTVecTup<320, 5, i8, 214>; // RISCV vector tuple(lmul=1, nf=5)
311+
def riscv_m1x6 : VTVecTup<384, 6, i8, 215>; // RISCV vector tuple(lmul=1, nf=6)
312+
def riscv_m1x7 : VTVecTup<448, 7, i8, 216>; // RISCV vector tuple(lmul=1, nf=7)
313+
def riscv_m1x8 : VTVecTup<512, 8, i8, 217>; // RISCV vector tuple(lmul=1, nf=8)
314+
def riscv_m2x2 : VTVecTup<256, 2, i8, 218>; // RISCV vector tuple(lmul=2, nf=2)
315+
def riscv_m2x3 : VTVecTup<384, 3, i8, 219>; // RISCV vector tuple(lmul=2, nf=3)
316+
def riscv_m2x4 : VTVecTup<512, 4, i8, 220>; // RISCV vector tuple(lmul=2, nf=4)
317+
def riscv_m4x2 : VTVecTup<512, 2, i8, 221>; // RISCV vector tuple(lmul=4, nf=2)
318+
319+
def x86mmx : ValueType<64, 222>; // X86 MMX value
320+
def Glue : ValueType<0, 223>; // Pre-RA sched glue
321+
def isVoid : ValueType<0, 224>; // Produces no value
322+
def untyped : ValueType<8, 225> { // Produces an untyped value
280323
let LLVMName = "Untyped";
281324
}
282-
def funcref : ValueType<0, 194>; // WebAssembly's funcref type
283-
def externref : ValueType<0, 195>; // WebAssembly's externref type
284-
def exnref : ValueType<0, 196>; // WebAssembly's exnref type
285-
def x86amx : ValueType<8192, 197>; // X86 AMX value
286-
def i64x8 : ValueType<512, 198>; // 8 Consecutive GPRs (AArch64)
325+
def funcref : ValueType<0, 226>; // WebAssembly's funcref type
326+
def externref : ValueType<0, 227>; // WebAssembly's externref type
327+
def exnref : ValueType<0, 228>; // WebAssembly's exnref type
328+
def x86amx : ValueType<8192, 229>; // X86 AMX value
329+
def i64x8 : ValueType<512, 230>; // 8 Consecutive GPRs (AArch64)
287330
def aarch64svcount
288-
: ValueType<16, 199>; // AArch64 predicate-as-counter
289-
def spirvbuiltin : ValueType<0, 200>; // SPIR-V's builtin type
331+
: ValueType<16, 231>; // AArch64 predicate-as-counter
332+
def spirvbuiltin : ValueType<0, 232>; // SPIR-V's builtin type
290333

291334
let isNormalValueType = false in {
292335
def token : ValueType<0, 248>; // TokenTy

llvm/include/llvm/CodeGenTypes/MachineValueType.h

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/Support/ErrorHandling.h"
2121
#include "llvm/Support/MathExtras.h"
2222
#include "llvm/Support/TypeSize.h"
23+
#include "llvm/Support/Debug.h"
2324
#include <cassert>
2425
#include <cstdint>
2526

@@ -38,7 +39,8 @@ namespace llvm {
3839
// are considered extended value types.
3940
INVALID_SIMPLE_VALUE_TYPE = 0,
4041

41-
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) Ty = n,
42+
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
43+
Ty = n,
4244
#define GET_VT_RANGES
4345
#include "llvm/CodeGen/GenVT.inc"
4446
#undef GET_VT_ATTR
@@ -113,6 +115,13 @@ namespace llvm {
113115
SimpleTy <= MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
114116
}
115117

118+
/// Return true if this is a RISCV vector tuple type where the
119+
/// runtime length is machine dependent
120+
bool isRISCVVectorTuple() const {
121+
return (SimpleTy >= MVT::FIRST_RISCV_VECTOR_TUPLE_VALUETYPE &&
122+
SimpleTy <= MVT::LAST_RISCV_VECTOR_TUPLE_VALUETYPE);
123+
}
124+
116125
/// Return true if this is a custom target type that has a scalable size.
117126
bool isScalableTargetExtVT() const {
118127
return SimpleTy == MVT::aarch64svcount;
@@ -171,7 +180,7 @@ namespace llvm {
171180
/// Return true if this is an overloaded type for TableGen.
172181
bool isOverloaded() const {
173182
switch (SimpleTy) {
174-
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
183+
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
175184
case Ty: \
176185
return Any;
177186
#include "llvm/CodeGen/GenVT.inc"
@@ -254,7 +263,8 @@ namespace llvm {
254263
MVT getVectorElementType() const {
255264
assert(SimpleTy >= FIRST_VALUETYPE && SimpleTy <= LAST_VALUETYPE);
256265
static constexpr SimpleValueType EltTyTable[] = {
257-
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) EltTy,
266+
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
267+
EltTy,
258268
#include "llvm/CodeGen/GenVT.inc"
259269
#undef GET_VT_ATTR
260270
};
@@ -267,7 +277,8 @@ namespace llvm {
267277
unsigned getVectorMinNumElements() const {
268278
assert(SimpleTy >= FIRST_VALUETYPE && SimpleTy <= LAST_VALUETYPE);
269279
static constexpr uint16_t NElemTable[] = {
270-
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) NElem,
280+
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
281+
NElem,
271282
#include "llvm/CodeGen/GenVT.inc"
272283
#undef GET_VT_ATTR
273284
};
@@ -296,7 +307,7 @@ namespace llvm {
296307
/// base size.
297308
TypeSize getSizeInBits() const {
298309
static constexpr TypeSize SizeTable[] = {
299-
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
310+
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
300311
TypeSize(Sz, Sc || Ty == aarch64svcount /* FIXME: Not in the td. */),
301312
#include "llvm/CodeGen/GenVT.inc"
302313
#undef GET_VT_ATTR
@@ -418,7 +429,7 @@ namespace llvm {
418429
}
419430

420431
static MVT getFloatingPointVT(unsigned BitWidth) {
421-
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
432+
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
422433
if (FP == 3 && sz == BitWidth) \
423434
return Ty;
424435
#include "llvm/CodeGen/GenVT.inc"
@@ -428,7 +439,7 @@ namespace llvm {
428439
}
429440

430441
static MVT getIntegerVT(unsigned BitWidth) {
431-
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
442+
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
432443
if (Int == 3 && sz == BitWidth) \
433444
return Ty;
434445
#include "llvm/CodeGen/GenVT.inc"
@@ -438,8 +449,8 @@ namespace llvm {
438449
}
439450

440451
static MVT getVectorVT(MVT VT, unsigned NumElements) {
441-
#define GET_VT_VECATTR(Ty, Sc, nElem, ElTy) \
442-
if (!Sc && VT.SimpleTy == ElTy && NumElements == nElem) \
452+
#define GET_VT_VECATTR(Ty, Sc, Tup, nElem, ElTy) \
453+
if (!Sc && !Tup && VT.SimpleTy == ElTy && NumElements == nElem) \
443454
return Ty;
444455
#include "llvm/CodeGen/GenVT.inc"
445456
#undef GET_VT_VECATTR
@@ -448,7 +459,7 @@ namespace llvm {
448459
}
449460

450461
static MVT getScalableVectorVT(MVT VT, unsigned NumElements) {
451-
#define GET_VT_VECATTR(Ty, Sc, nElem, ElTy) \
462+
#define GET_VT_VECATTR(Ty, Sc, Tup, nElem, ElTy) \
452463
if (Sc && VT.SimpleTy == ElTy && NumElements == nElem) \
453464
return Ty;
454465
#include "llvm/CodeGen/GenVT.inc"
@@ -457,6 +468,17 @@ namespace llvm {
457468
return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
458469
}
459470

471+
static MVT getRISCVVectorTupleVT(unsigned Sz, unsigned NFields) {
472+
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, nElem, EltTy) \
473+
if (Tup && sz == Sz && nElem == NFields) \
474+
return Ty;
475+
#include "llvm/CodeGen/GenVT.inc"
476+
#undef GET_VT_ATTR
477+
478+
llvm::dbgs() << Sz << ' ' << NFields << '\n';
479+
llvm_unreachable("Invalid RISCV vector tuple type");
480+
}
481+
460482
static MVT getVectorVT(MVT VT, unsigned NumElements, bool IsScalable) {
461483
if (IsScalable)
462484
return getScalableVectorVT(VT, NumElements);

llvm/lib/CodeGen/ValueTypes.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
#include "llvm/CodeGen/ValueTypes.h"
1010
#include "llvm/ADT/StringExtras.h"
11+
#include "llvm/ADT/StringRef.h"
1112
#include "llvm/IR/DerivedTypes.h"
1213
#include "llvm/IR/Type.h"
1314
#include "llvm/Support/Debug.h"
15+
#include "llvm/TargetParser/RISCVTargetParser.h"
1416
#include "llvm/Support/ErrorHandling.h"
1517
#include "llvm/Support/TypeSize.h"
1618
#include "llvm/Support/WithColor.h"
@@ -161,6 +163,14 @@ TypeSize EVT::getExtendedSizeInBits() const {
161163
std::string EVT::getEVTString() const {
162164
switch (V.SimpleTy) {
163165
default:
166+
if (isRISCVVectorTuple()) {
167+
unsigned Sz = getSizeInBits();
168+
unsigned NF = getVectorMinNumElements();
169+
int Log2LMUL = Log2_64(Sz / NF) - 6;
170+
return "riscv_m" +
171+
((Log2LMUL < 0 ? "f" : "") + utostr(1 << std::abs(Log2LMUL))) +
172+
"x" + utostr(getVectorMinNumElements());
173+
}
164174
if (isVector())
165175
return (isScalableVector() ? "nxv" : "v") +
166176
utostr(getVectorElementCount().getKnownMinValue()) +
@@ -214,6 +224,70 @@ Type *EVT::getTypeForEVT(LLVMContext &Context) const {
214224
case MVT::i64x8: return IntegerType::get(Context, 512);
215225
case MVT::externref: return Type::getWasm_ExternrefTy(Context);
216226
case MVT::funcref: return Type::getWasm_FuncrefTy(Context);
227+
case MVT::riscv_mf8x2:
228+
return TargetExtType::get(Context, "riscv_mf8x2", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 2});
229+
case MVT::riscv_mf8x3:
230+
return TargetExtType::get(Context, "riscv_mf8x3", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 3});
231+
case MVT::riscv_mf8x4:
232+
return TargetExtType::get(Context, "riscv_mf8x4", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 4});
233+
case MVT::riscv_mf8x5:
234+
return TargetExtType::get(Context, "riscv_mf8x5", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 5});
235+
case MVT::riscv_mf8x6:
236+
return TargetExtType::get(Context, "riscv_mf8x6", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 6});
237+
case MVT::riscv_mf8x7:
238+
return TargetExtType::get(Context, "riscv_mf8x7", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 7});
239+
case MVT::riscv_mf8x8:
240+
return TargetExtType::get(Context, "riscv_mf8x8", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{0, 8});
241+
case MVT::riscv_mf4x2:
242+
return TargetExtType::get(Context, "riscv_mf4x2", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 2});
243+
case MVT::riscv_mf4x3:
244+
return TargetExtType::get(Context, "riscv_mf4x3", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 3});
245+
case MVT::riscv_mf4x4:
246+
return TargetExtType::get(Context, "riscv_mf4x4", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 4});
247+
case MVT::riscv_mf4x5:
248+
return TargetExtType::get(Context, "riscv_mf4x5", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 5});
249+
case MVT::riscv_mf4x6:
250+
return TargetExtType::get(Context, "riscv_mf4x6", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 6});
251+
case MVT::riscv_mf4x7:
252+
return TargetExtType::get(Context, "riscv_mf4x7", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 7});
253+
case MVT::riscv_mf4x8:
254+
return TargetExtType::get(Context, "riscv_mf4x8", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{1, 8});
255+
case MVT::riscv_mf2x2:
256+
return TargetExtType::get(Context, "riscv_mf2x2", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 2});
257+
case MVT::riscv_mf2x3:
258+
return TargetExtType::get(Context, "riscv_mf2x3", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 3});
259+
case MVT::riscv_mf2x4:
260+
return TargetExtType::get(Context, "riscv_mf2x4", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 4});
261+
case MVT::riscv_mf2x5:
262+
return TargetExtType::get(Context, "riscv_mf2x5", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 5});
263+
case MVT::riscv_mf2x6:
264+
return TargetExtType::get(Context, "riscv_mf2x6", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 6});
265+
case MVT::riscv_mf2x7:
266+
return TargetExtType::get(Context, "riscv_mf2x7", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 7});
267+
case MVT::riscv_mf2x8:
268+
return TargetExtType::get(Context, "riscv_mf2x8", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{2, 8});
269+
case MVT::riscv_m1x2:
270+
return TargetExtType::get(Context, "riscv_m1x2", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 2});
271+
case MVT::riscv_m1x3:
272+
return TargetExtType::get(Context, "riscv_m1x3", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 3});
273+
case MVT::riscv_m1x4:
274+
return TargetExtType::get(Context, "riscv_m1x4", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 4});
275+
case MVT::riscv_m1x5:
276+
return TargetExtType::get(Context, "riscv_m1x5", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 5});
277+
case MVT::riscv_m1x6:
278+
return TargetExtType::get(Context, "riscv_m1x6", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 6});
279+
case MVT::riscv_m1x7:
280+
return TargetExtType::get(Context, "riscv_m1x7", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 7});
281+
case MVT::riscv_m1x8:
282+
return TargetExtType::get(Context, "riscv_m1x8", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{3, 8});
283+
case MVT::riscv_m2x2:
284+
return TargetExtType::get(Context, "riscv_m2x2", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{4, 2});
285+
case MVT::riscv_m2x3:
286+
return TargetExtType::get(Context, "riscv_m2x3", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{4, 3});
287+
case MVT::riscv_m2x4:
288+
return TargetExtType::get(Context, "riscv_m2x4", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{4, 4});
289+
case MVT::riscv_m4x2:
290+
return TargetExtType::get(Context, "riscv_m4x2", {Type::getInt8Ty(Context), Type::getInt8Ty(Context)},{5, 2});
217291
case MVT::Metadata: return Type::getMetadataTy(Context);
218292
#define GET_VT_EVT(Ty, EVT) case MVT::Ty: return EVT;
219293
#include "llvm/CodeGen/GenVT.inc"
@@ -249,6 +323,28 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
249323
return MVT(MVT::aarch64svcount);
250324
else if (TargetExtTy->getName().starts_with("spirv."))
251325
return MVT(MVT::spirvbuiltin);
326+
if (TargetExtTy->getName().starts_with("riscv_m")) {
327+
StringRef Name = TargetExtTy->getName();
328+
if (Name.consume_front("riscv_m")) {
329+
bool IsFracLMUL = false;
330+
unsigned LMUL, NF;
331+
if (Name.consume_front("f"))
332+
IsFracLMUL = true;
333+
334+
LMUL = Name[0] - '0';
335+
Name = Name.drop_front(2);
336+
NF = Name[0] - '0';
337+
338+
llvm::dbgs() << LMUL << ' ' << NF << '\n';
339+
unsigned Sz = NF * RISCV::RVVBitsPerBlock;
340+
if (IsFracLMUL)
341+
Sz /= LMUL;
342+
else
343+
Sz *= LMUL;
344+
345+
return MVT::getRISCVVectorTupleVT(Sz , NF);
346+
}
347+
}
252348
if (HandleUnknown)
253349
return MVT(MVT::Other);
254350
llvm_unreachable("Unknown target ext type!");

llvm/utils/TableGen/Common/CodeGenTarget.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ StringRef llvm::getName(MVT::SimpleValueType T) {
6363
StringRef llvm::getEnumName(MVT::SimpleValueType T) {
6464
// clang-format off
6565
switch (T) {
66-
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
66+
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NElem, EltTy) \
6767
case MVT::Ty: return "MVT::" # Ty;
6868
#include "llvm/CodeGen/GenVT.inc"
6969
default: llvm_unreachable("ILLEGAL VALUE TYPE!");

0 commit comments

Comments
 (0)