File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
dpctl/tensor/libtensor/include/kernels/linalg_functions Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -1272,10 +1272,13 @@ sycl::event _gemm_batch_nm_impl(sycl::queue &exec_q,
1272
1272
const std::uint32_t max_sg_size = krn.template get_info <
1273
1273
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
1274
1274
1275
+ const size_t k_wg_sz = krn.template get_info <
1276
+ sycl::info::kernel_device_specific::work_group_size>(dev);
1277
+
1275
1278
// Limit work-group size
1276
1279
constexpr size_t wg_sz_limit (2048 );
1277
- const size_t max_wg_sz = std::min< size_t >(
1278
- dev. get_info <sycl::info::device::max_work_group_size>(), wg_sz_limit);
1280
+ const size_t max_wg_sz = std::min (wg_sz_limit, k_wg_sz);
1281
+
1279
1282
const std::uint32_t max_subgroups_per_wg =
1280
1283
static_cast <std::uint32_t >(max_wg_sz / max_sg_size);
1281
1284
You can’t perform that action at this time.
0 commit comments