diff --git a/dppl/_memory.pyx b/dppl/_memory.pyx index 22b9329073..6b14164f43 100644 --- a/dppl/_memory.pyx +++ b/dppl/_memory.pyx @@ -8,14 +8,14 @@ from cpython cimport Py_buffer cdef class Memory: cdef DPPLSyclUSMRef memory_ptr cdef Py_ssize_t nbytes - cdef SyclContext context + cdef SyclQueue queue cdef _cinit(self, Py_ssize_t nbytes, ptr_type, SyclQueue queue): cdef DPPLSyclUSMRef p self.memory_ptr = NULL self.nbytes = 0 - self.context = None + self.queue = None if (nbytes > 0): if queue is None: @@ -34,7 +34,7 @@ cdef class Memory: if (p): self.memory_ptr = p self.nbytes = nbytes - self.context = queue.get_sycl_context() + self.queue = queue else: raise RuntimeError("Null memory pointer returned") else: @@ -42,11 +42,11 @@ cdef class Memory: def __dealloc__(self): if (self.memory_ptr): - DPPLfree_with_context(self.memory_ptr, - self.context.get_context_ref()) + DPPLfree_with_queue(self.memory_ptr, + self.queue.get_queue_ref()) self.memory_ptr = NULL self.nbytes = 0 - self.context = None + self.queue = None cdef _getbuffer(self, Py_buffer *buffer, int flags): # memory_ptr is Ref which is pointer to SYCL type. For USM it is void*. @@ -68,16 +68,36 @@ cdef class Memory: property _context: def __get__(self): - return self.context + return self.queue.get_sycl_context() + + property _queue: + def __get__(self): + return self.queue def __repr__(self): return "" \ .format(self.nbytes, hex((self.memory_ptr))) - def _usm_type(self): + def _usm_type(self, context=None): cdef const char* kind - kind = DPPLUSM_GetPointerType(self.memory_ptr, - self.context.get_context_ref()) + cdef SyclContext ctx + cdef SyclQueue q + if context is None: + ctx = self._context + kind = DPPLUSM_GetPointerType(self.memory_ptr, + ctx.get_context_ref()) + elif isinstance(context, SyclContext): + ctx = (context) + kind = DPPLUSM_GetPointerType(self.memory_ptr, + ctx.get_context_ref()) + elif isinstance(context, SyclQueue): + q = (context) + ctx = q.get_sycl_context() + kind = DPPLUSM_GetPointerType(self.memory_ptr, + ctx.get_context_ref()) + else: + raise ValueError("sycl_context keyword can be either None, " + "or an instance of dppl.SyclConext") return kind.decode('UTF-8') diff --git a/dppl/tests/dppl_tests/test_sycl_usm.py b/dppl/tests/dppl_tests/test_sycl_usm.py index 15bb397a9f..94b14c50fb 100644 --- a/dppl/tests/dppl_tests/test_sycl_usm.py +++ b/dppl/tests/dppl_tests/test_sycl_usm.py @@ -48,14 +48,28 @@ def test_memory_cpu_context (self): # CPU context with dppl.device_context(dppl.device_type.cpu): - self.assertEqual(mobj._usm_type(), 'shared') + # type respective to the context in which + # memory was created + usm_type = mobj._usm_type() + self.assertEqual(usm_type, 'shared') + + current_queue = dppl.get_current_queue() + # type as view from current queue + usm_type = mobj._usm_type(context=current_queue) + # type can be unknown if current queue is + # not in the same SYCL context + self.assertTrue(usm_type in ['unknown', 'shared']) def test_memory_gpu_context (self): mobj = self._create_memory() # GPU context with dppl.device_context(dppl.device_type.gpu): - self.assertEqual(mobj._usm_type(), 'shared') + usm_type = mobj._usm_type() + self.assertEqual(usm_type, 'shared') + current_queue = dppl.get_current_queue() + usm_type = mobj._usm_type(context=current_queue) + self.assertTrue(usm_type in ['unknown', 'shared']) class TestMemoryUSMBase: