Skip to content

Commit 6522c4a

Browse files
authored
1 parent 76fcc39 commit 6522c4a

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,6 +1961,12 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
19611961
IRBuilder<> Builder(BB);
19621962
auto *Scalar = transValue(MTS->getScalar(), F, BB);
19631963
auto *Matrix = transValue(MTS->getMatrix(), F, BB);
1964+
1965+
if (MTS->getMatrix()->getType()->isTypeCooperativeMatrixKHR()) {
1966+
return mapValue(BV, transSPIRVBuiltinFromInst(
1967+
static_cast<SPIRVInstruction *>(BV), BB));
1968+
}
1969+
19641970
uint64_t ColNum = Matrix->getType()->getArrayNumElements();
19651971
auto *ColType = cast<ArrayType>(Matrix->getType())->getElementType();
19661972
auto VecSize = cast<FixedVectorType>(ColType)->getNumElements();

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5735,6 +5735,11 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
57355735
return BM->addCompositeConstructInst(transType(CI->getType()), Operands,
57365736
BB);
57375737
}
5738+
case OpMatrixTimesScalar: {
5739+
return BM->addMatrixTimesScalarInst(
5740+
transType(CI->getType()), transValue(CI->getArgOperand(0), BB)->getId(),
5741+
transValue(CI->getArgOperand(1), BB)->getId(), BB);
5742+
}
57385743
default: {
57395744
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
57405745
return BM->addUnaryInst(OC, transScavengedType(CI),

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ SPIRVType *SPIRVType::getScalarType() const {
142142
return getVectorComponentType();
143143
case OpTypeMatrix:
144144
return getMatrixColumnType()->getVectorComponentType();
145+
case OpTypeCooperativeMatrixKHR:
146+
return static_cast<const SPIRVTypeCooperativeMatrixKHR *>(this)
147+
->getCompType();
145148
case OpTypeInt:
146149
case OpTypeFloat:
147150
case OpTypeBool:
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix -o %t.spv
3+
; TODO: Validation is disabled till the moment the tools in CI are updated
4+
; R/UN: spirv-val %t.spv
5+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
6+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
7+
8+
; RUN: llvm-spirv -r %t.spv -o %t.bc
9+
; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM
10+
11+
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
12+
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixType:]]
13+
14+
; CHECK-SPIRV: CompositeConstruct [[#MatrixType]] [[#Matrix:]] [[#]] {{$}}
15+
; CHECK-SPIRV: Load [[#TypeFloat]] [[#Scalar:]]
16+
; CHECK-SPIRV: MatrixTimesScalar [[#MatrixType]] [[#]] [[#Matrix]] [[#Scalar]]
17+
18+
; CHECK-LLVM: %[[#Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
19+
; CHECK-LLVM: %[[#Scalar:]] = load float, ptr %scalar
20+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalarPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3f(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Matrix]], float %[[#Scalar]])
21+
22+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
23+
target triple = "spir64-unknown-unknown"
24+
25+
; Function Attrs: mustprogress uwtable
26+
define dso_local void @matrix_times_scalar(ptr %scalar) local_unnamed_addr #0 {
27+
entry:
28+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(float 0.000000e+00) #4
29+
%1 = load float, ptr %scalar, align 4
30+
%call = call noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalar(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, float %1)
31+
ret void
32+
}
33+
34+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(float noundef) local_unnamed_addr #2
35+
36+
declare noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalar(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef, float noundef) local_unnamed_addr #2
37+
38+
attributes #0 = { mustprogress uwtable "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
39+
attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
40+
attributes #2 = { "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
41+
attributes #3 = { nounwind }
42+
43+
!llvm.module.flags = !{!0, !1, !2, !3, !4}
44+
!llvm.ident = !{!5}
45+
46+
!0 = !{i32 7, !"Dwarf Version", i32 4}
47+
!1 = !{i32 1, !"wchar_size", i32 4}
48+
!2 = !{i32 8, !"PIC Level", i32 2}
49+
!3 = !{i32 7, !"PIE Level", i32 2}
50+
!4 = !{i32 7, !"uwtable", i32 2}
51+
!5 = !{!"clang version 16.0.0 (https://github.com/llvm/llvm-project.git 08d094a0e457360ad8b94b017d2dc277e697ca76)"}

0 commit comments

Comments
 (0)