Skip to content

Commit a23b6bf

Browse files
Merge master into add_dpctl.tensor.isdtype
2 parents 2c7c097 + 1a6bba0 commit a23b6bf

File tree

8 files changed

+142
-55
lines changed

8 files changed

+142
-55
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ set(CMAKE_CXX_STANDARD 17)
1313
set(CMAKE_CXX_STANDARD_REQUIRED True)
1414
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
1515
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
16+
# set_property(GLOBAL PROPERTY GLOBAL_DEPENDS_DEBUG_MODE 1)
1617

1718
# Option to generate code coverage report using llvm-cov and lcov.
1819
option(DPCTL_GENERATE_COVERAGE

dpctl/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ if(WIN32)
1313
"-Wunused-function "
1414
"-Wuninitialized "
1515
"-Wmissing-declarations "
16+
"-Wstrict-prototypes "
1617
"-Wno-unused-parameter "
1718
)
1819
string(CONCAT SDL_FLAGS
@@ -36,6 +37,8 @@ elseif(UNIX)
3637
"-Wunused-function "
3738
"-Wuninitialized "
3839
"-Wmissing-declarations "
40+
"-Wstrict-prototypes "
41+
"-Wno-unused-parameter "
3942
"-fdiagnostics-color=auto "
4043
)
4144
string(CONCAT SDL_FLAGS
@@ -46,7 +49,7 @@ elseif(UNIX)
4649
"-D_FORTIFY_SOURCE=2 "
4750
"-Wformat "
4851
"-Wformat-security "
49-
"-fno-strict-overflow "
52+
# "-fno-strict-overflow " # no-strict-overflow is implied by -fwrapv
5053
"-fno-delete-null-pointer-checks "
5154
"-fwrapv "
5255
)
@@ -137,9 +140,11 @@ set(CMAKE_INSTALL_RPATH "$ORIGIN")
137140

138141
function(build_dpctl_ext _trgt _src _dest)
139142
add_cython_target(${_trgt} ${_src} CXX OUTPUT_VAR _generated_src)
143+
set(_cythonize_trgt "${_trgt}_cythonize_pyx")
144+
add_custom_target(${_cythonize_trgt} DEPENDS ${_src})
140145
add_library(${_trgt} MODULE ${_generated_src})
141146
target_include_directories(${_trgt} PRIVATE ${NumPy_INCLUDE_DIR} ${DPCTL_INCLUDE_DIR})
142-
add_dependencies(${_trgt} _build_time_create_dpctl_include_copy)
147+
add_dependencies(${_trgt} _build_time_create_dpctl_include_copy ${_cythonize_trgt})
143148
if (DPCTL_GENERATE_COVERAGE)
144149
target_compile_definitions(${_trgt} PRIVATE CYTHON_TRACE=1 CYTHON_TRACE_NOGIL=1)
145150
target_compile_options(${_trgt} PRIVATE -fno-sycl-use-footer)

dpctl/_sycl_device.pyx

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ from libc.stdlib cimport free, malloc
110110
from ._sycl_platform cimport SyclPlatform
111111

112112
import collections
113+
import functools
113114
import warnings
114115

115116
__all__ = [
@@ -198,6 +199,37 @@ cdef void _init_helper(_SyclDevice device, DPCTLSyclDeviceRef DRef):
198199
device._max_work_item_sizes = DPCTLDevice_GetMaxWorkItemSizes3d(DRef)
199200

200201

202+
@functools.lru_cache(maxsize=None)
203+
def _cached_filter_string(d : SyclDevice):
204+
"""
205+
Internal utility to compute filter_string of input SyclDevice
206+
and cache the result with `functools.lru_cache`.
207+
208+
Args:
209+
d (dpctl.SyclDevice):
210+
A device for which to compute the filter string.
211+
Returns:
212+
out (str):
213+
Filter string that can be used to create input device,
214+
if the device is a root (unpartitioned) device.
215+
216+
Raises:
217+
ValueError: if the input device is a sub-device.
218+
"""
219+
cdef _backend_type BTy
220+
cdef _device_type DTy
221+
cdef int64_t relId = -1
222+
cdef SyclDevice cd = <SyclDevice> d
223+
relId = DPCTLDeviceMgr_GetRelativeId(cd._device_ref)
224+
if (relId == -1):
225+
raise ValueError("This SyclDevice is not a root device")
226+
BTy = DPCTLDevice_GetBackend(cd._device_ref)
227+
br_str = _backend_type_to_filter_string_part(BTy)
228+
DTy = DPCTLDevice_GetDeviceType(cd._device_ref)
229+
dt_str = _device_type_to_filter_string_part(DTy)
230+
return ":".join((br_str, dt_str, str(relId)))
231+
232+
201233
cdef class SyclDevice(_SyclDevice):
202234
""" SyclDevice(arg=None)
203235
A Python wrapper for the :sycl_device:`sycl::device <>` C++ class.
@@ -1360,14 +1392,14 @@ cdef class SyclDevice(_SyclDevice):
13601392

13611393
@property
13621394
def filter_string(self):
1363-
""" For a parent device, returns a fully specified filter selector
1364-
string``backend:device_type:relative_id`` selecting the device.
1395+
""" For a root device, returns a fully specified filter selector
1396+
string ``"backend:device_type:relative_id"`` selecting the device.
13651397
13661398
Returns:
13671399
str: A Python string representing a filter selector string.
13681400
13691401
Raises:
1370-
TypeError: If the device is a sub-devices.
1402+
TypeError: If the device is a sub-device.
13711403
13721404
:Example:
13731405
.. code-block:: python
@@ -1387,14 +1419,7 @@ cdef class SyclDevice(_SyclDevice):
13871419
cdef int64_t relId = -1
13881420
pDRef = DPCTLDevice_GetParentDevice(self._device_ref)
13891421
if (pDRef is NULL):
1390-
BTy = DPCTLDevice_GetBackend(self._device_ref)
1391-
DTy = DPCTLDevice_GetDeviceType(self._device_ref)
1392-
relId = DPCTLDeviceMgr_GetRelativeId(self._device_ref)
1393-
if (relId == -1):
1394-
raise TypeError("This SyclDevice is not a root device")
1395-
br_str = _backend_type_to_filter_string_part(BTy)
1396-
dt_str = _device_type_to_filter_string_part(DTy)
1397-
return ":".join((br_str, dt_str, str(relId)))
1422+
return _cached_filter_string(self)
13981423
else:
13991424
# this is a sub-device, free it, and raise an exception
14001425
DPCTLDevice_Delete(pDRef)

dpctl/tensor/_indexing_functions.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,18 @@
2626
from ._copy_utils import _extract_impl, _nonzero_impl
2727

2828

29-
def take(x, indices, /, *, axis=None, mode="clip"):
30-
"""take(x, indices, axis=None, mode="clip")
29+
def _get_indexing_mode(name):
30+
modes = {"wrap": 0, "clip": 1}
31+
try:
32+
return modes[name]
33+
except KeyError:
34+
raise ValueError(
35+
"`mode` must be `wrap` or `clip`." "Got `{}`.".format(name)
36+
)
37+
38+
39+
def take(x, indices, /, *, axis=None, mode="wrap"):
40+
"""take(x, indices, axis=None, mode="wrap")
3141
3242
Takes elements from array along a given axis.
3343
@@ -42,15 +52,15 @@ def take(x, indices, /, *, axis=None, mode="clip"):
4252
Default: `None`.
4353
mode:
4454
How out-of-bounds indices will be handled.
45-
"clip" - clamps indices to (-n <= i < n), then wraps
55+
"wrap" - clamps indices to (-n <= i < n), then wraps
4656
negative indices.
47-
"wrap" - wraps both negative and positive indices.
48-
Default: `"clip"`.
57+
"clip" - clips indices to (0 <= i < n)
58+
Default: `"wrap"`.
4959
5060
Returns:
5161
out: usm_ndarray
5262
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
53-
filled with elements .
63+
filled with elements from x.
5464
"""
5565
if not isinstance(x, dpt.usm_ndarray):
5666
raise TypeError(
@@ -80,11 +90,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
8090
[x.usm_type, indices.usm_type]
8191
)
8292

83-
modes = {"clip": 0, "wrap": 1}
84-
try:
85-
mode = modes[mode]
86-
except KeyError:
87-
raise ValueError("`mode` must be `clip` or `wrap`.")
93+
mode = _get_indexing_mode(mode)
8894

8995
x_ndim = x.ndim
9096
if axis is None:
@@ -114,8 +120,8 @@ def take(x, indices, /, *, axis=None, mode="clip"):
114120
return res
115121

116122

117-
def put(x, indices, vals, /, *, axis=None, mode="clip"):
118-
"""put(x, indices, vals, axis=None, mode="clip")
123+
def put(x, indices, vals, /, *, axis=None, mode="wrap"):
124+
"""put(x, indices, vals, axis=None, mode="wrap")
119125
120126
Puts values of an array into another array
121127
along a given axis.
@@ -134,10 +140,10 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
134140
Default: `None`.
135141
mode:
136142
How out-of-bounds indices will be handled.
137-
"clip" - clamps indices to (-axis_size <= i < axis_size),
138-
then wraps negative indices.
139-
"wrap" - wraps both negative and positive indices.
140-
Default: `"clip"`.
143+
"wrap" - clamps indices to (-n <= i < n), then wraps
144+
negative indices.
145+
"clip" - clips indices to (0 <= i < n)
146+
Default: `"wrap"`.
141147
"""
142148
if not isinstance(x, dpt.usm_ndarray):
143149
raise TypeError(
@@ -175,11 +181,8 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
175181
if exec_q is None:
176182
raise dpctl.utils.ExecutionPlacementError
177183
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
178-
modes = {"clip": 0, "wrap": 1}
179-
try:
180-
mode = modes[mode]
181-
except KeyError:
182-
raise ValueError("`mode` must be `clip` or `wrap`.")
184+
185+
mode = _get_indexing_mode(mode)
183186

184187
x_ndim = x.ndim
185188
if axis is None:

dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ namespace py = pybind11;
4646
template <typename ProjectorT, typename Ty, typename indT> class take_kernel;
4747
template <typename ProjectorT, typename Ty, typename indT> class put_kernel;
4848

49-
class ClipIndex
49+
class WrapIndex
5050
{
5151
public:
52-
ClipIndex() = default;
52+
WrapIndex() = default;
5353

5454
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
5555
{
@@ -60,16 +60,15 @@ class ClipIndex
6060
}
6161
};
6262

63-
class WrapIndex
63+
class ClipIndex
6464
{
6565
public:
66-
WrapIndex() = default;
66+
ClipIndex() = default;
6767

6868
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
6969
{
7070
max_item = std::max<py::ssize_t>(max_item, 1);
71-
ind = (ind < 0) ? (ind + max_item * ((-ind / max_item) + 1)) % max_item
72-
: ind % max_item;
71+
ind = std::clamp<py::ssize_t>(ind, 0, max_item - 1);
7372
return;
7473
}
7574
};

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
#include "integer_advanced_indexing.hpp"
4141

4242
#define INDEXING_MODES 2
43-
#define CLIP_MODE 0
44-
#define WRAP_MODE 1
43+
#define WRAP_MODE 0
44+
#define CLIP_MODE 1
4545

4646
namespace dpctl
4747
{
@@ -252,8 +252,8 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
252252
throw py::value_error("Axis cannot be negative.");
253253
}
254254

255-
if (mode != 0 && mode != 1) {
256-
throw py::value_error("Mode must be 0 or 1.");
255+
if (mode != 0 && mode != 1 && mode != 2) {
256+
throw py::value_error("Mode must be 0, 1, or 2.");
257257
}
258258

259259
const dpctl::tensor::usm_ndarray ind_rep = ind[0];
@@ -575,8 +575,8 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
575575
throw py::value_error("Axis cannot be negative.");
576576
}
577577

578-
if (mode != 0 && mode != 1) {
579-
throw py::value_error("Mode must be 0 or 1.");
578+
if (mode != 0 && mode != 1 && mode != 2) {
579+
throw py::value_error("Mode must be 0, 1, or 2.");
580580
}
581581

582582
if (!dst.is_writable()) {

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from helper import get_queue_or_skip, skip_if_dtype_not_supported
2121
from numpy.testing import assert_array_equal
2222

23+
import dpctl
2324
import dpctl.tensor as dpt
2425
from dpctl.utils import ExecutionPlacementError
2526

@@ -895,20 +896,21 @@ def test_integer_indexing_modes():
895896
q = get_queue_or_skip()
896897

897898
x = dpt.arange(5, sycl_queue=q)
899+
x_np = dpt.asnumpy(x)
900+
901+
# wrapping negative indices
902+
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
898903

899-
# wrapping
900-
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
901904
res = dpt.take(x, ind, mode="wrap")
902-
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="wrap")
905+
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")
903906

904907
assert (dpt.asnumpy(res) == expected_arr).all()
905908

906-
# clipping to -n<=i<n,
907-
# where n is the axis length
908-
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
909+
# clipping to 0 <= i < n (disabling negative indices)
910+
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
909911

910912
res = dpt.take(x, ind, mode="clip")
911-
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="raise")
913+
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")
912914

913915
assert (dpt.asnumpy(res) == expected_arr).all()
914916

@@ -939,6 +941,10 @@ def test_take_arg_validation():
939941
dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None)
940942
with pytest.raises(ValueError):
941943
dpt.take(x, dpt.reshape(ind0, (2, 2)))
944+
with pytest.raises(ValueError):
945+
dpt.take(x[0], ind0, axis=2)
946+
with pytest.raises(ValueError):
947+
dpt.take(x[:, dpt.newaxis, dpt.newaxis], ind0, axis=None)
942948

943949

944950
def test_put_arg_validation():
@@ -968,6 +974,10 @@ def test_put_arg_validation():
968974
dpt.put(x, ind0, val, mode=0)
969975
with pytest.raises(ValueError):
970976
dpt.put(x, dpt.reshape(ind0, (2, 2)), val)
977+
with pytest.raises(ValueError):
978+
dpt.put(x[0], ind0, val, axis=2)
979+
with pytest.raises(ValueError):
980+
dpt.put(x[:, dpt.newaxis, dpt.newaxis], ind0, val, axis=None)
971981

972982

973983
def test_advanced_indexing_compute_follows_data():
@@ -1269,3 +1279,43 @@ def test_nonzero_large():
12691279

12701280
m = dpt.full((30, 60, 80), True)
12711281
assert m[m].size == m.size
1282+
1283+
1284+
def test_extract_arg_validation():
1285+
get_queue_or_skip()
1286+
with pytest.raises(TypeError):
1287+
dpt.extract(None, None)
1288+
cond = dpt.ones(10, dtype="?")
1289+
with pytest.raises(TypeError):
1290+
dpt.extract(cond, None)
1291+
q1 = dpctl.SyclQueue()
1292+
with pytest.raises(ExecutionPlacementError):
1293+
dpt.extract(cond.to_device(q1), dpt.zeros_like(cond, dtype="u1"))
1294+
with pytest.raises(ValueError):
1295+
dpt.extract(dpt.ones((2, 3), dtype="?"), dpt.ones((3, 2), dtype="i1"))
1296+
1297+
1298+
def test_place_arg_validation():
1299+
get_queue_or_skip()
1300+
with pytest.raises(TypeError):
1301+
dpt.place(None, None, None)
1302+
arr = dpt.zeros(8, dtype="i1")
1303+
with pytest.raises(TypeError):
1304+
dpt.place(arr, None, None)
1305+
cond = dpt.ones(8, dtype="?")
1306+
with pytest.raises(TypeError):
1307+
dpt.place(arr, cond, None)
1308+
vals = dpt.ones_like(arr)
1309+
q1 = dpctl.SyclQueue()
1310+
with pytest.raises(ExecutionPlacementError):
1311+
dpt.place(arr.to_device(q1), cond, vals)
1312+
with pytest.raises(ValueError):
1313+
dpt.place(dpt.reshape(arr, (2, 2, 2)), cond, vals)
1314+
1315+
1316+
def test_nonzero_arg_validation():
1317+
get_queue_or_skip()
1318+
with pytest.raises(TypeError):
1319+
dpt.nonzero(list())
1320+
with pytest.raises(ValueError):
1321+
dpt.nonzero(dpt.asarray(1))

0 commit comments

Comments
 (0)