diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 18ef03cc70607..ff7c9982dd77a 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -50,6 +50,42 @@ __spirv_JointMatrixMadINTEL( __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); +template +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_JointMatrixUUMadINTEL( + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, + __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); + +template +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_JointMatrixUSMadINTEL( + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, + __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); + +template +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_JointMatrixSUMadINTEL( + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, + __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); + #ifndef __SPIRV_BUILTIN_DECLARATIONS__ #error \ "SPIR-V built-ins are not available. Please set -fdeclare-spirv-builtins flag." diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 5343c3ccb301c..e81881e52f6a7 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -160,16 +160,27 @@ joint_matrix_store(Group sg, #endif // __SYCL_DEVICE_ONLY__ } -template -inline __SYCL_ALWAYS_INLINE joint_matrix +inline __SYCL_ALWAYS_INLINE joint_matrix joint_matrix_mad(Group sg, joint_matrix &mA, - joint_matrix &mB, - joint_matrix &mC) { + joint_matrix &mB, + joint_matrix &mC) { #ifdef __SYCL_DEVICE_ONLY__ - joint_matrix res(sg); - res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); + joint_matrix res(sg); + if constexpr (std::is_same::value && + std::is_same::value && + std::is_same::value) + res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_unsigned::value && std::is_unsigned::value) + res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_signed::value && std::is_unsigned::value) + res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_unsigned::value && std::is_signed::value) + res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else + res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); return res; #else (void)sg; diff --git a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp index 0d7d7e2ef4629..0150c420e3bd9 100644 --- a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include diff --git a/sycl/test/matrix/matrix-bf16-test.cpp b/sycl/test/matrix/matrix-bf16-test.cpp index 224846fd6f9e7..efa118a33bff7 100644 --- a/sycl/test/matrix/matrix-bf16-test.cpp +++ b/sycl/test/matrix/matrix-bf16-test.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include diff --git a/sycl/test/matrix/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/matrix-int8-test-SG-16.cpp index 6fead706ea137..0aace5f46c036 100644 --- a/sycl/test/matrix/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-int8-test-SG-16.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 9692f686e4859..263cc75c1097f 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include