File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
dpctl/tensor/libtensor/include/kernels/linalg_functions Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -1365,10 +1365,13 @@ sycl::event _gemm_batch_nm_impl(sycl::queue &exec_q,
1365
1365
const std::uint32_t max_sg_size = krn.template get_info <
1366
1366
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
1367
1367
1368
+ const size_t k_wg_sz = krn.template get_info <
1369
+ sycl::info::kernel_device_specific::work_group_size>(dev);
1370
+
1368
1371
// Limit work-group size
1369
1372
constexpr size_t wg_sz_limit (2048 );
1370
- const size_t max_wg_sz = std::min< size_t >(
1371
- dev. get_info <sycl::info::device::max_work_group_size>(), wg_sz_limit);
1373
+ const size_t max_wg_sz = std::min (wg_sz_limit, k_wg_sz);
1374
+
1372
1375
const std::uint32_t max_subgroups_per_wg =
1373
1376
static_cast <std::uint32_t >(max_wg_sz / max_sg_size);
1374
1377
You can’t perform that action at this time.
0 commit comments