diff --git a/sycl/plugins/level_zero/pi_level0.cpp b/sycl/plugins/level_zero/pi_level0.cpp index dc4ccfd988ca9..8f726e4951dcd 100644 --- a/sycl/plugins/level_zero/pi_level0.cpp +++ b/sycl/plugins/level_zero/pi_level0.cpp @@ -2142,7 +2142,11 @@ piEnqueueKernelLaunch(pi_queue Queue, pi_kernel Kernel, pi_uint32 WorkDim, assert(GlobalWorkSize[1] == (ZeThreadGroupDimensions.groupCountY * WG[1])); assert(GlobalWorkSize[2] == (ZeThreadGroupDimensions.groupCountZ * WG[2])); - ZE_CALL(zeKernelSetGroupSize(Kernel->ZeKernel, WG[0], WG[1], WG[2])); + ze_result_t res = ZE_CALL_NOCHECK(zeKernelSetGroupSize(Kernel->ZeKernel, WG[0], WG[1], WG[2])); + + if (res == ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION) { + return PI_INVALID_WORK_GROUP_SIZE; + } // Get a new command list to be used on this call ze_command_list_handle_t ZeCommandList = nullptr; diff --git a/sycl/source/detail/error_handling/enqueue_kernel.cpp b/sycl/source/detail/error_handling/enqueue_kernel.cpp index 1438e66a1e80b..cfa73bc6f0226 100644 --- a/sycl/source/detail/error_handling/enqueue_kernel.cpp +++ b/sycl/source/detail/error_handling/enqueue_kernel.cpp @@ -21,6 +21,54 @@ namespace detail { namespace enqueue_kernel_launch { +bool l0HandleInvalidWorkGroupSize(const device_impl &DeviceImpl, + pi_kernel Kernel, const NDRDescT &NDRDesc) { + const bool HasLocalSize = (NDRDesc.LocalSize[0] != 0); + + const plugin &Plugin = DeviceImpl.getPlugin(); + RT::PiDevice Device = DeviceImpl.getHandleRef(); + + size_t CompileWGSize[3] = {0}; + Plugin.call( + Kernel, Device, PI_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE, + sizeof(size_t) * 3, CompileWGSize, nullptr); + + if (CompileWGSize[0] != 0) { + // PI_INVALID_WORK_GROUP_SIZE if local_work_size is specified and does not + // match the required work-group size for kernel in the program source. + if (NDRDesc.LocalSize[0] != CompileWGSize[0] || + NDRDesc.LocalSize[1] != CompileWGSize[1] || + NDRDesc.LocalSize[2] != CompileWGSize[2]) + throw sycl::nd_range_error( + "Specified local size doesn't match the required work-group size " + "specified in the program source", + PI_INVALID_WORK_GROUP_SIZE); + } + + if (HasLocalSize) { + const bool NonUniformWGs = + (NDRDesc.LocalSize[0] != 0 && + NDRDesc.GlobalSize[0] % NDRDesc.LocalSize[0] != 0) || + (NDRDesc.LocalSize[1] != 0 && + NDRDesc.GlobalSize[1] % NDRDesc.LocalSize[1] != 0) || + (NDRDesc.LocalSize[2] != 0 && + NDRDesc.GlobalSize[2] % NDRDesc.LocalSize[2] != 0); + + // PI_INVALID_WORK_GROUP_SIZE if local_work_size is specified and + // number of workitems specified by global_work_size is not evenly + // divisible by size of work-group given by local_work_size + if (NonUniformWGs) + throw sycl::nd_range_error( + "Non-uniform work-groups are not supported by the target device", + PI_INVALID_WORK_GROUP_SIZE); + } + + // Fallback + constexpr pi_result Error = PI_INVALID_WORK_GROUP_SIZE; + throw sycl::runtime_error( + "Level0 API failed. Level0 API returns: " + codeToString(Error), Error); +} + bool oclHandleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel, const NDRDescT &NDRDesc) { const bool HasLocalSize = (NDRDesc.LocalSize[0] != 0); @@ -230,6 +278,10 @@ bool handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel, return oclHandleInvalidWorkGroupSize(DeviceImpl, Kernel, NDRDesc); } + if (PlatformName.find("Level-Zero") != std::string::npos) { + return l0HandleInvalidWorkGroupSize(DeviceImpl, Kernel, NDRDesc); + } + // Fallback constexpr pi_result Error = PI_INVALID_WORK_GROUP_SIZE; throw runtime_error( diff --git a/sycl/test/basic_tests/parallel_for_range_level0.cpp b/sycl/test/basic_tests/parallel_for_range_level0.cpp new file mode 100644 index 0000000000000..ed797cfdddbee --- /dev/null +++ b/sycl/test/basic_tests/parallel_for_range_level0.cpp @@ -0,0 +1,91 @@ +// XFAIL: cuda +// CUDA exposes broken hierarchical parallelism. + +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// RUN: %ACC_RUN_PLACEHOLDER %t.out + +#include + +#include + +using namespace cl::sycl; + +[[cl::reqd_work_group_size(4, 4, 4)]] void reqd_wg_size_helper() { + // do nothing +} + +int main() { + auto AsyncHandler = [](exception_list ES) { + for (auto &E : ES) { + std::rethrow_exception(E); + } + }; + + queue Q(AsyncHandler); + device D(Q.get_device()); + + string_class DeviceVendorName = D.get_info(); + auto DeviceType = D.get_info(); + + // parallel_for, (16, 16, 16) global, (8, 8, 8) local, reqd_wg_size(4, 4, 4) + // -> fail + try { + Q.submit([&](handler &CGH) { + CGH.parallel_for( + nd_range<3>(range<3>(16, 16, 16), range<3>(8, 8, 8)), + [=](nd_item<3>) { reqd_wg_size_helper(); }); + }); + Q.wait_and_throw(); + std::cerr << "Test case ReqdWGSizeNegativeA failed: no exception has been " + "thrown\n"; + return 1; // We shouldn't be here, exception is expected + } catch (nd_range_error &E) { + if (string_class(E.what()).find( + "Specified local size doesn't match the required work-group size " + "specified in the program source") == string_class::npos) { + std::cerr + << "Test case ReqdWGSizeNegativeA failed: unexpected exception: " + << E.what() << std::endl; + return 1; + } + } catch (runtime_error &E) { + std::cerr << "Test case ReqdWGSizeNegativeA failed: unexpected exception: " + << E.what() << std::endl; + return 1; + } catch (...) { + std::cerr << "Test case ReqdWGSizeNegativeA failed: something unexpected " + "has been caught" + << std::endl; + return 1; + } + + // Positive test-cases that should pass on any underlying OpenCL runtime + + // parallel_for, (8, 8, 8) global, (4, 4, 4) local, reqd_wg_size(4, 4, 4) -> + // pass + try { + Q.submit([&](handler &CGH) { + CGH.parallel_for( + nd_range<3>(range<3>(8, 8, 8), range<3>(4, 4, 4)), + [=](nd_item<3>) { reqd_wg_size_helper(); }); + }); + Q.wait_and_throw(); + } catch (nd_range_error &E) { + std::cerr << "Test case ReqdWGSizePositiveA failed: unexpected exception: " + << E.what() << std::endl; + return 1; + } catch (runtime_error &E) { + std::cerr << "Test case ReqdWGSizePositiveA failed: unexpected exception: " + << E.what() << std::endl; + return 1; + } catch (...) { + std::cerr << "Test case ReqdWGSizePositiveA failed: something unexpected " + "has been caught" + << std::endl; + return 1; + } + + return 0; +}