Skip to content

Commit 71e605a

Browse files
khaledDiptorup Deb
authored andcommitted
Fix in usm_ndarray_types.py
1 parent a8b10a3 commit 71e605a

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from numba.core.types.npytypes import Array
1313
from numba.np.numpy_support import from_dtype
1414

15+
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
1516
from numba_dpex.utils import address_space
1617

1718

@@ -31,15 +32,10 @@ def __init__(
3132
aligned=True,
3233
addrspace=address_space.GLOBAL,
3334
):
34-
if not isinstance(device, str):
35+
if queue is not None and device != "unknown":
3536
raise TypeError(
36-
"The device keyword arg should be a str object specifying "
37-
"a SYCL filter selector"
38-
)
39-
40-
if not isinstance(queue, dpctl.SyclQueue) and queue is not None:
41-
raise TypeError(
42-
"The queue keyword arg should be a dpctl.SyclQueue object or None"
37+
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
38+
"`device` and `sycl_queue` are exclusive keywords, i.e. use one or other."
4339
)
4440

4541
self.usm_type = usm_type
@@ -48,18 +44,28 @@ def __init__(
4844
if device == "unknown":
4945
device = None
5046

51-
if queue is not None and device is not None:
52-
raise TypeError(
53-
"'queue' and 'device' keywords can not be both specified"
54-
)
55-
5647
if queue is not None:
48+
if not isinstance(queue, dpctl.SyclQueue):
49+
raise TypeError(
50+
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
51+
"The queue keyword arg should be a dpctl.SyclQueue object or None."
52+
)
5753
self.queue = queue
5854
else:
5955
if device is None:
60-
device = dpctl.SyclDevice()
61-
62-
self.queue = dpctl.get_device_cached_queue(device)
56+
sycl_device = dpctl.SyclDevice()
57+
else:
58+
if not isinstance(device, str):
59+
raise TypeError(
60+
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
61+
"The device keyword arg should be a str object specifying "
62+
"a SYCL filter selector."
63+
)
64+
sycl_device = dpctl.SyclDevice(device)
65+
66+
self.queue = dpctl._sycl_queue_manager.get_device_cached_queue(
67+
sycl_device
68+
)
6369

6470
self.device = self.queue.sycl_device.filter_string
6571

0 commit comments

Comments
 (0)