diff --git a/dpctl-capi/source/dpctl_sycl_device_interface.cpp b/dpctl-capi/source/dpctl_sycl_device_interface.cpp index d6c1af390d..80a0f983e0 100644 --- a/dpctl-capi/source/dpctl_sycl_device_interface.cpp +++ b/dpctl-capi/source/dpctl_sycl_device_interface.cpp @@ -29,6 +29,7 @@ #include "Support/CBindingWrapping.h" #include "dpctl_sycl_device_manager.h" #include /* SYCL headers */ +#include #include using namespace cl::sycl; @@ -577,8 +578,13 @@ DPCTLDevice_CreateSubDevicesEqually(__dpctl_keep const DPCTLSyclDeviceRef DRef, size_t count) { vector_class *Devices = nullptr; - auto D = unwrap(DRef); - if (D) { + if (DRef) { + if (count == 0) { + std::cerr << "Can not create sub-devices with zero compute units" + << '\n'; + return nullptr; + } + auto D = unwrap(DRef); try { auto subDevices = D->create_sub_devices< info::partition_property::partition_equally>(count); @@ -610,13 +616,29 @@ DPCTLDevice_CreateSubDevicesByCounts(__dpctl_keep const DPCTLSyclDeviceRef DRef, size_t ncounts) { vector_class *Devices = nullptr; - std::vector vcounts; + std::vector vcounts(ncounts); vcounts.assign(counts, counts + ncounts); - auto D = unwrap(DRef); - if (D) { + size_t min_elem = *std::min_element(vcounts.begin(), vcounts.end()); + if (min_elem == 0) { + std::cerr << "Can not create sub-devices with zero compute units" + << '\n'; + return nullptr; + } + if (DRef) { + auto D = unwrap(DRef); + vector_class::type> subDevices; try { - auto subDevices = D->create_sub_devices< + subDevices = D->create_sub_devices< info::partition_property::partition_by_counts>(vcounts); + } catch (feature_not_supported const &fnse) { + std::cerr << fnse.what() << '\n'; + return nullptr; + } catch (runtime_error const &re) { + // \todo log error + std::cerr << re.what() << '\n'; + return nullptr; + } + try { Devices = new vector_class(); for (const auto &sd : subDevices) { Devices->emplace_back(wrap(new device(sd))); @@ -625,10 +647,6 @@ DPCTLDevice_CreateSubDevicesByCounts(__dpctl_keep const DPCTLSyclDeviceRef DRef, delete Devices; std::cerr << ba.what() << '\n'; return nullptr; - } catch (feature_not_supported const &fnse) { - delete Devices; - std::cerr << fnse.what() << '\n'; - return nullptr; } catch (runtime_error const &re) { delete Devices; // \todo log error diff --git a/dpctl-capi/tests/test_sycl_device_subdevices.cpp b/dpctl-capi/tests/test_sycl_device_subdevices.cpp index 444f8f10dd..1979826bfd 100644 --- a/dpctl-capi/tests/test_sycl_device_subdevices.cpp +++ b/dpctl-capi/tests/test_sycl_device_subdevices.cpp @@ -92,6 +92,9 @@ TEST_P(TestDPCTLSyclDeviceInterface, ChkCreateSubDevicesEqually) EXPECT_NO_FATAL_FAILURE(DPCTLDevice_Delete(pDRef)); EXPECT_NO_FATAL_FAILURE(DPCTLDeviceVector_Delete(DVRef)); } + EXPECT_NO_FATAL_FAILURE( + DVRef = DPCTLDevice_CreateSubDevicesEqually(DRef, 0)); + EXPECT_TRUE(DVRef == nullptr); } } @@ -114,7 +117,12 @@ TEST_P(TestDPCTLSyclDeviceInterface, ChkCreateSubDevicesByCounts) if (DVRef) { EXPECT_TRUE(DPCTLDeviceVector_Size(DVRef) > 0); EXPECT_NO_FATAL_FAILURE(DPCTLDeviceVector_Delete(DVRef)); + DVRef = nullptr; } + counts[n - 1] = 0; + EXPECT_NO_FATAL_FAILURE( + DVRef = DPCTLDevice_CreateSubDevicesByCounts(DRef, counts, n)); + EXPECT_TRUE(DVRef == nullptr); } } diff --git a/dpctl/_sycl_device.pyx b/dpctl/_sycl_device.pyx index 90882c8bfd..8be701c20e 100644 --- a/dpctl/_sycl_device.pyx +++ b/dpctl/_sycl_device.pyx @@ -706,7 +706,12 @@ cdef class SyclDevice(_SyclDevice): the sub-devices. """ cdef DPCTLDeviceVectorRef DVRef = NULL - DVRef = DPCTLDevice_CreateSubDevicesEqually(self._device_ref, count) + if count > 0: + DVRef = DPCTLDevice_CreateSubDevicesEqually(self._device_ref, count) + else: + raise ValueError( + "Creating sub-devices with zero compute units is requested" + ) if DVRef is NULL: raise SubDeviceCreationError("Sub-devices were not created.") return _get_devices(DVRef) @@ -720,6 +725,7 @@ cdef class SyclDevice(_SyclDevice): """ cdef int ncounts = len(counts) cdef size_t *counts_buff = NULL + cdef size_t min_count = 1 cdef DPCTLDeviceVectorRef DVRef = NULL cdef int i @@ -734,10 +740,17 @@ cdef class SyclDevice(_SyclDevice): ) for i in range(ncounts): counts_buff[i] = counts[i] - DVRef = DPCTLDevice_CreateSubDevicesByCounts( - self._device_ref, counts_buff, ncounts - ) + if counts_buff[i] == 0: + min_count = 0 + if min_count: + DVRef = DPCTLDevice_CreateSubDevicesByCounts( + self._device_ref, counts_buff, ncounts + ) free(counts_buff) + if min_count == 0: + raise ValueError( + "Targeted sub-device execution units must positive" + ) if DVRef is NULL: raise SubDeviceCreationError("Sub-devices were not created.") return _get_devices(DVRef) diff --git a/dpctl/tests/test_sycl_device.py b/dpctl/tests/test_sycl_device.py index e55a5ee225..882bcebc6b 100644 --- a/dpctl/tests/test_sycl_device.py +++ b/dpctl/tests/test_sycl_device.py @@ -360,6 +360,13 @@ def check_create_sub_devices_equally(device): pytest.fail("create_sub_devices failed") +def check_create_sub_devices_equally_zeros(device): + try: + device.create_sub_devices(partition=0) + except TypeError: + pass + + def check_create_sub_devices_by_counts(device): try: n = device.max_compute_units / 2 @@ -372,6 +379,13 @@ def check_create_sub_devices_by_counts(device): pytest.fail("create_sub_devices failed") +def check_create_sub_devices_by_counts_zeros(device): + try: + device.create_sub_devices(partition=(0, 1)) + except TypeError: + pass + + def check_create_sub_devices_by_affinity_not_applicable(device): try: device.create_sub_devices(partition="not_applicable")