diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index 4ff2ed1fd6..e2d060f1c3 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -278,6 +278,8 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h": cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef) cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef) cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms() + cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext( + const DPCTLSyclPlatformRef) cdef extern from "syclinterface/dpctl_sycl_context_interface.h": diff --git a/dpctl/_sycl_device.pyx b/dpctl/_sycl_device.pyx index 373745f22f..a7daa7c3e5 100644 --- a/dpctl/_sycl_device.pyx +++ b/dpctl/_sycl_device.pyx @@ -50,6 +50,7 @@ from ._backend cimport ( # noqa: E211 DPCTLDevice_GetMaxWriteImageArgs, DPCTLDevice_GetName, DPCTLDevice_GetParentDevice, + DPCTLDevice_GetPlatform, DPCTLDevice_GetPreferredVectorWidthChar, DPCTLDevice_GetPreferredVectorWidthDouble, DPCTLDevice_GetPreferredVectorWidthFloat, @@ -80,6 +81,7 @@ from ._backend cimport ( # noqa: E211 DPCTLSize_t_Array_Delete, DPCTLSyclDeviceRef, DPCTLSyclDeviceSelectorRef, + DPCTLSyclPlatformRef, _aspect_type, _backend_type, _device_type, @@ -91,6 +93,8 @@ from .enum_types import backend_type, device_type from libc.stdint cimport int64_t, uint32_t from libc.stdlib cimport free, malloc +from ._sycl_platform cimport SyclPlatform + import collections import warnings @@ -639,6 +643,22 @@ cdef class SyclDevice(_SyclDevice): self._device_ref ) + @property + def sycl_platform(self): + """ Returns the platform associated with this device. + + Returns: + :class:`dpctl.SyclPlatform`: The platform associated with this + device. + """ + cdef DPCTLSyclPlatformRef PRef = ( + DPCTLDevice_GetPlatform(self._device_ref) + ) + if (PRef == NULL): + raise RuntimeError("Could not get platform for device.") + else: + return SyclPlatform._create(PRef) + @property def preferred_vector_width_char(self): """ Returns the preferred native vector width size for built-in scalar diff --git a/dpctl/_sycl_platform.pyx b/dpctl/_sycl_platform.pyx index 6e3fad07c8..ce211139a7 100644 --- a/dpctl/_sycl_platform.pyx +++ b/dpctl/_sycl_platform.pyx @@ -30,6 +30,7 @@ from ._backend cimport ( # noqa: E211 DPCTLPlatform_CreateFromSelector, DPCTLPlatform_Delete, DPCTLPlatform_GetBackend, + DPCTLPlatform_GetDefaultContext, DPCTLPlatform_GetName, DPCTLPlatform_GetPlatforms, DPCTLPlatform_GetVendor, @@ -40,6 +41,7 @@ from ._backend cimport ( # noqa: E211 DPCTLPlatformVector_GetAt, DPCTLPlatformVector_Size, DPCTLPlatformVectorRef, + DPCTLSyclContextRef, DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef, _backend_type, @@ -47,8 +49,11 @@ from ._backend cimport ( # noqa: E211 import warnings +from ._sycl_context import SyclContextCreationError from .enum_types import backend_type +from ._sycl_context cimport SyclContext + __all__ = [ "get_platforms", "lsplatform", @@ -236,10 +241,10 @@ cdef class SyclPlatform(_SyclPlatform): @property def backend(self): - """Returns the backend_type enum value for this device + """Returns the backend_type enum value for this platform Returns: - backend_type: The backend for the device. + backend_type: The backend for the platform. """ cdef _backend_type BTy = ( DPCTLPlatform_GetBackend(self._platform_ref) @@ -255,6 +260,22 @@ cdef class SyclPlatform(_SyclPlatform): else: raise ValueError("Unknown backend type.") + @property + def default_context(self): + """Returns the default platform context for this platform + + Returns: + SyclContext: The default context for the platform. + """ + cdef DPCTLSyclContextRef CRef = ( + DPCTLPlatform_GetDefaultContext(self._platform_ref) + ) + + if (CRef == NULL): + raise + else: + return SyclContext._create(CRef) + def lsplatform(verbosity=0): """ diff --git a/dpctl/tests/test_sycl_device.py b/dpctl/tests/test_sycl_device.py index cf38f687d9..c694d33dab 100644 --- a/dpctl/tests/test_sycl_device.py +++ b/dpctl/tests/test_sycl_device.py @@ -496,6 +496,11 @@ def check_profiling_timer_resolution(device): assert isinstance(resol, int) and resol > 0 +def check_platform(device): + p = device.sycl_platform + assert isinstance(p, dpctl.SyclPlatform) + + list_of_checks = [ check_get_max_compute_units, check_get_max_work_item_dims, @@ -552,6 +557,8 @@ def check_profiling_timer_resolution(device): check_repr, check_get_global_mem_size, check_get_local_mem_size, + check_profiling_timer_resolution, + check_platform, ] diff --git a/dpctl/tests/test_sycl_platform.py b/dpctl/tests/test_sycl_platform.py index 3bc230cdb9..e846e8ecdb 100644 --- a/dpctl/tests/test_sycl_platform.py +++ b/dpctl/tests/test_sycl_platform.py @@ -87,6 +87,11 @@ def check_repr(platform): assert r != "" +def check_default_context(platform): + r = platform.default_context + assert type(r) is dpctl.SyclContext + + list_of_checks = [ check_name, check_vendor, diff --git a/libsyclinterface/include/dpctl_sycl_platform_interface.h b/libsyclinterface/include/dpctl_sycl_platform_interface.h index 1c01dcfb69..1d2238e652 100644 --- a/libsyclinterface/include/dpctl_sycl_platform_interface.h +++ b/libsyclinterface/include/dpctl_sycl_platform_interface.h @@ -142,4 +142,16 @@ DPCTLPlatform_GetVersion(__dpctl_keep const DPCTLSyclPlatformRef PRef); DPCTL_API __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms(void); +/*! + * @brief Returns a DPCTLSyclContextRef for default platform context. + * + * @param PRef Opaque pointer to a sycl::platform + * @return A DPCTLSyclContextRef value for the default platform associated + * with this platform. + * @ingroup PlatformInterface + */ +DPCTL_API +__dpctl_give DPCTLSyclContextRef +DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef); + DPCTL_C_EXTERN_C_END diff --git a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp index 36c9bd16b7..4ab9771b99 100644 --- a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp @@ -41,6 +41,7 @@ using namespace cl::sycl; namespace { DEFINE_SIMPLE_CONVERSION_FUNCTIONS(platform, DPCTLSyclPlatformRef); +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef); DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device_selector, DPCTLSyclDeviceSelectorRef); DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector, DPCTLPlatformVectorRef); @@ -202,3 +203,19 @@ __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms() // the wrap function is defined inside dpctl_vector_templ.cpp return wrap(Platforms); } + +__dpctl_give DPCTLSyclContextRef +DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef) +{ + auto P = unwrap(PRef); + if (P) { + auto default_ctx = P->ext_oneapi_get_default_context(); + return wrap(new context(default_ctx)); + } + else { + error_handler( + "Default platform cannot be obtained up for a NULL platform.", + __FILE__, __func__, __LINE__); + return nullptr; + } +} diff --git a/libsyclinterface/tests/test_sycl_platform_interface.cpp b/libsyclinterface/tests/test_sycl_platform_interface.cpp index 594d4856e2..1fe9c80117 100644 --- a/libsyclinterface/tests/test_sycl_platform_interface.cpp +++ b/libsyclinterface/tests/test_sycl_platform_interface.cpp @@ -25,6 +25,7 @@ //===----------------------------------------------------------------------===// #include "Support/CBindingWrapping.h" +#include "dpctl_sycl_context_interface.h" #include "dpctl_sycl_device_selector_interface.h" #include "dpctl_sycl_platform_interface.h" #include "dpctl_sycl_platform_manager.h" @@ -82,6 +83,16 @@ void check_platform_backend(__dpctl_keep const DPCTLSyclPlatformRef PRef) }()); } +void check_platform_default_context( + __dpctl_keep const DPCTLSyclPlatformRef PRef) +{ + DPCTLSyclContextRef CRef = nullptr; + EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(PRef)); + EXPECT_TRUE(CRef != nullptr); + + EXPECT_NO_FATAL_FAILURE(DPCTLContext_Delete(CRef)); +} + } // namespace struct TestDPCTLSyclPlatformInterface @@ -167,6 +178,14 @@ TEST_F(TestDPCTLSyclPlatformNull, ChkGetVersion) ASSERT_TRUE(version == nullptr); } +TEST_F(TestDPCTLSyclPlatformNull, ChkGetDefaultConext) +{ + DPCTLSyclContextRef CRef = nullptr; + + EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(NullPRef)); + EXPECT_TRUE(CRef == nullptr); +} + struct TestDPCTLSyclDefaultPlatform : public ::testing::Test { DPCTLSyclPlatformRef PRef = nullptr; @@ -207,6 +226,11 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkGetBackend) check_platform_backend(PRef); } +TEST_P(TestDPCTLSyclPlatformInterface, ChkGetDefaultContext) +{ + check_platform_default_context(PRef); +} + TEST_P(TestDPCTLSyclPlatformInterface, ChkCopy) { DPCTLSyclPlatformRef Copied_PRef = nullptr; diff --git a/libsyclinterface/tests/test_sycl_queue_interface.cpp b/libsyclinterface/tests/test_sycl_queue_interface.cpp index f03a320fd3..4e09110f32 100644 --- a/libsyclinterface/tests/test_sycl_queue_interface.cpp +++ b/libsyclinterface/tests/test_sycl_queue_interface.cpp @@ -446,7 +446,6 @@ TEST_P(TestDPCTLQueueMemberFunctions, CheckMemset) ASSERT_NO_FATAL_FAILURE(DPCTLfree_with_queue(p, QRef)); - bool equal = true; for (size_t i = 0; i < nbytes; ++i) { ASSERT_TRUE(host_arr[i] == val); }