Skip to content

Commit c10cad8

Browse files
Merge pull request #1567 from IntelPython/fix-gemm-wg-size-computation
Fix RuntimeError when running strided gemm on CUDA devices
2 parents 4c69ea1 + 9373733 commit c10cad8

File tree

1 file changed

+5
-2
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+5
-2
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,10 +1272,13 @@ sycl::event _gemm_batch_nm_impl(sycl::queue &exec_q,
12721272
const std::uint32_t max_sg_size = krn.template get_info<
12731273
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
12741274

1275+
const size_t k_wg_sz = krn.template get_info<
1276+
sycl::info::kernel_device_specific::work_group_size>(dev);
1277+
12751278
// Limit work-group size
12761279
constexpr size_t wg_sz_limit(2048);
1277-
const size_t max_wg_sz = std::min<size_t>(
1278-
dev.get_info<sycl::info::device::max_work_group_size>(), wg_sz_limit);
1280+
const size_t max_wg_sz = std::min(wg_sz_limit, k_wg_sz);
1281+
12791282
const std::uint32_t max_subgroups_per_wg =
12801283
static_cast<std::uint32_t>(max_wg_sz / max_sg_size);
12811284

0 commit comments

Comments
 (0)