diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 3bf312f511033..d0b9513c215bf 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -86,6 +86,12 @@ __spirv_JointMatrixSUMadINTEL( __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); +template +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_CompositeConstruct(const T v); + template extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL( diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index d6cd2e41ed308..e9e03d3b894cb 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -202,6 +202,23 @@ joint_matrix_mad(Group sg, joint_matrix &mA, #endif // __SYCL_DEVICE_ONLY__ } +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_fill(Group sg, + joint_matrix &res, + const T v) { + // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ + // functions + (void)sg; +#ifdef __SYCL_DEVICE_ONLY__ + res.spvm = __spirv_CompositeConstruct(v); +#else + (void)res; + (void)v; +#endif // __SYCL_DEVICE_ONLY__ +} + template diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 4bd4a0aa16742..a1fce823e12a1 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,4 +1,5 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out +// XFAIL: * #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include @@ -68,10 +69,7 @@ void matrix_multiply(big_matrix &C, big_matrix