diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index fffaecaae3ad2..7bf7ba0b84d20 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -51,6 +51,54 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, StringRef suffix) { SmallString<256> TypeName; llvm::raw_svector_ostream OS(TypeName); + // If RD is spirv_JointMatrixINTEL type, mangle differently. + if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { + if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { + if (auto TemplateDecl = dyn_cast(RD)) { + ArrayRef TemplateArgs = + TemplateDecl->getTemplateArgs().asArray(); + OS << "spirv.JointMatrixINTEL."; + for (auto &TemplateArg : TemplateArgs) { + OS << "_"; + if (TemplateArg.getKind() == TemplateArgument::Type) { + llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); + if (TTy->isIntegerTy()) { + switch (TTy->getIntegerBitWidth()) { + case 8: + OS << "char"; + break; + case 16: + OS << "short"; + break; + case 32: + OS << "int"; + break; + case 64: + OS << "long"; + break; + default: + OS << "i" << TTy->getIntegerBitWidth(); + break; + } + } else if (TTy->isBFloatTy()) + OS << "bfloat16"; + else if (TTy->isStructTy()) { + StringRef LlvmTyName = TTy->getStructName(); + // Emit half/bfloat16 for sycl[::*]::{half,bfloat16} + if (LlvmTyName.startswith("class.sycl::") || + LlvmTyName.startswith("class.__sycl_internal::")) + LlvmTyName = LlvmTyName.rsplit("::").second; + OS << LlvmTyName; + } else + TTy->print(OS, false, true); + } else if (TemplateArg.getKind() == TemplateArgument::Integral) + OS << TemplateArg.getAsIntegral(); + } + Ty->setName(OS.str()); + return; + } + } + } OS << RD->getKindName() << '.'; // FIXME: We probably want to make more tweaks to the printing policy. For diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp new file mode 100644 index 0000000000000..a361518590519 --- /dev/null +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -0,0 +1,34 @@ +// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - -no-opaque-pointers | FileCheck %s +// Test that SPIR-V codegen generates the expected LLVM struct name for the +// JointMatrixINTEL type. +#include +#include + +namespace __spv { + template + struct __spirv_JointMatrixINTEL; +} + +// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 +void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 +void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 +void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} + +namespace sycl { + class half {}; + class bfloat16 {}; +} +typedef sycl::half my_half; + +// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 +void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 +void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 +void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index aa38ece1b8713..891d37aec696a 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } #include #if (SYCL_EXT_ONEAPI_MATRIX == 2)