Skip to content

Commit 2ce17a2

Browse files
Merge pull request #1605 from IntelPython/backport-gh-1567
Backport gh-1567 to 0.16.x maintenance branch
2 parents d08793d + 45c1841 commit 2ce17a2

File tree

1 file changed

+5
-2
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+5
-2
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,10 +1365,13 @@ sycl::event _gemm_batch_nm_impl(sycl::queue &exec_q,
13651365
const std::uint32_t max_sg_size = krn.template get_info<
13661366
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
13671367

1368+
const size_t k_wg_sz = krn.template get_info<
1369+
sycl::info::kernel_device_specific::work_group_size>(dev);
1370+
13681371
// Limit work-group size
13691372
constexpr size_t wg_sz_limit(2048);
1370-
const size_t max_wg_sz = std::min<size_t>(
1371-
dev.get_info<sycl::info::device::max_work_group_size>(), wg_sz_limit);
1373+
const size_t max_wg_sz = std::min(wg_sz_limit, k_wg_sz);
1374+
13721375
const std::uint32_t max_subgroups_per_wg =
13731376
static_cast<std::uint32_t>(max_wg_sz / max_sg_size);
13741377

0 commit comments

Comments
 (0)