Skip to content

Commit b4967a4

Browse files
committed
Made changes as per PR review by @oleksandr-pavlyk
1 parent b4b4a7c commit b4967a4

File tree

6 files changed

+103
-77
lines changed

6 files changed

+103
-77
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -40,6 +40,37 @@ def _where_result_type(dt1, dt2, dev):
4040

4141

4242
def where(condition, x1, x2):
43+
"""where(condition, x1, x2)
44+
45+
Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen
46+
from `x1` or `x2` depending on `condition`.
47+
48+
Args:
49+
condition (usm_ndarray): When True yields from `x1`,
50+
and otherwise yields from `x2`.
51+
Must be compatible with `x1` and `x2` according
52+
to broadcasting rules.
53+
x1 (usm_ndarray): Array from which values are chosen when
54+
`condition` is True.
55+
Must be compatible with `condition` and `x2` according
56+
to broadcasting rules.
57+
x2 (usm_ndarray): Array from which values are chosen when
58+
`condition` is not True.
59+
Must be compatible with `condition` and `x2` according
60+
to broadcasting rules.
61+
62+
Returns:
63+
usm_ndarray:
64+
An array with elements from `x1` where `condition` is True,
65+
and elements from `x2` elsewhere.
66+
67+
The data type of the returned array is determined by applying
68+
the Type Promotion Rules to `x1` and `x2`.
69+
70+
The memory layout of the returned array is
71+
F-contiguous (column-major) when all inputs are F-contiguous,
72+
and C-contiguous (row-major) otherwise.
73+
"""
4374
if not isinstance(condition, dpt.usm_ndarray):
4475
raise TypeError(
4576
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}"
@@ -89,7 +120,7 @@ def where(condition, x1, x2):
89120

90121
deps = []
91122
wait_list = []
92-
if x1_dtype is not dst_dtype:
123+
if x1_dtype != dst_dtype:
93124
_x1 = dpt.empty_like(x1, dtype=dst_dtype)
94125
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
95126
src=x1, dst=_x1, sycl_queue=exec_q
@@ -98,7 +129,7 @@ def where(condition, x1, x2):
98129
deps.append(copy1_ev)
99130
wait_list.append(ht_copy1_ev)
100131

101-
if x2_dtype is not dst_dtype:
132+
if x2_dtype != dst_dtype:
102133
_x2 = dpt.empty_like(x2, dtype=dst_dtype)
103134
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
104135
src=x2, dst=_x2, sycl_queue=exec_q
@@ -140,7 +171,7 @@ def where(condition, x1, x2):
140171
sycl_queue=exec_q,
141172
depends=deps,
142173
)
143-
wait_list.append(hev)
144174
dpctl.SyclEvent.wait_for(wait_list)
175+
hev.wait()
145176

146177
return dst

dpctl/tensor/_type_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -84,7 +84,7 @@ def _all_data_types(_fp16, _fp64):
8484
]
8585

8686

87-
def is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
87+
def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
8888
"""
8989
Return True if data type `dt` is the
9090
maximal size inexact data type
@@ -106,7 +106,7 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool):
106106
if (
107107
from_.kind in "biu"
108108
and to_.kind in "fc"
109-
and is_maximal_inexact_type(to_, _fp16, _fp64)
109+
and _is_maximal_inexact_type(to_, _fp16, _fp64)
110110
):
111111
return True
112112

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.
@@ -57,31 +57,24 @@ class WhereContigFunctor
5757
{
5858
private:
5959
size_t nelems = 0;
60-
const char *x1_cp = nullptr;
61-
const char *x2_cp = nullptr;
62-
char *dst_cp = nullptr;
63-
const char *cond_cp = nullptr;
60+
const condT *cond_p = nullptr;
61+
const T *x1_p = nullptr;
62+
const T *x2_p = nullptr;
63+
T *dst_p = nullptr;
6464

6565
public:
6666
WhereContigFunctor(size_t nelems_,
67-
const char *cond_data_p,
68-
const char *x1_data_p,
69-
const char *x2_data_p,
70-
char *dst_data_p)
71-
: nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
72-
dst_cp(dst_data_p), cond_cp(cond_data_p)
67+
const condT *cond_p_,
68+
const T *x1_p_,
69+
const T *x2_p_,
70+
T *dst_p_)
71+
: nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_),
72+
dst_p(dst_p_)
7373
{
7474
}
7575

7676
void operator()(sycl::nd_item<1> ndit) const
7777
{
78-
const T *x1_data = reinterpret_cast<const T *>(x1_cp);
79-
const T *x2_data = reinterpret_cast<const T *>(x2_cp);
80-
T *dst_data = reinterpret_cast<T *>(dst_cp);
81-
const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
82-
83-
using dpctl::tensor::type_utils::convert_impl;
84-
8578
using dpctl::tensor::type_utils::is_complex;
8679
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
8780
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -92,8 +85,9 @@ class WhereContigFunctor
9285
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
9386
offset += sgSize)
9487
{
95-
bool check = convert_impl<bool, condT>(cond_data[offset]);
96-
dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
88+
using dpctl::tensor::type_utils::convert_impl;
89+
bool check = convert_impl<bool, condT>(cond_p[offset]);
90+
dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
9791
}
9892
}
9993
else {
@@ -115,7 +109,6 @@ class WhereContigFunctor
115109
using cond_ptrT =
116110
sycl::multi_ptr<const condT,
117111
sycl::access::address_space::global_space>;
118-
119112
sycl::vec<T, vec_sz> dst_vec;
120113
sycl::vec<T, vec_sz> x1_vec;
121114
sycl::vec<T, vec_sz> x2_vec;
@@ -124,23 +117,20 @@ class WhereContigFunctor
124117
#pragma unroll
125118
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
126119
auto idx = base + it * sgSize;
127-
x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
128-
x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
129-
cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
130-
120+
x1_vec = sg.load<vec_sz>(x_ptrT(&x1_p[idx]));
121+
x2_vec = sg.load<vec_sz>(x_ptrT(&x2_p[idx]));
122+
cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_p[idx]));
131123
#pragma unroll
132124
for (std::uint8_t k = 0; k < vec_sz; ++k) {
133-
bool check = convert_impl<bool, condT>(cond_vec[k]);
134-
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
125+
dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
135126
}
136-
sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
127+
sg.store<vec_sz>(dst_ptrT(&dst_p[idx]), dst_vec);
137128
}
138129
}
139130
else {
140131
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
141132
k += sgSize) {
142-
bool check = convert_impl<bool, condT>(cond_data[k]);
143-
dst_data[k] = check ? x1_data[k] : x2_data[k];
133+
dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
144134
}
145135
}
146136
}
@@ -159,12 +149,17 @@ typedef sycl::event (*where_contig_impl_fn_ptr_t)(
159149
template <typename T, typename condT>
160150
sycl::event where_contig_impl(sycl::queue q,
161151
size_t nelems,
162-
const char *cond_p,
163-
const char *x1_p,
164-
const char *x2_p,
165-
char *dst_p,
152+
const char *cond_cp,
153+
const char *x1_cp,
154+
const char *x2_cp,
155+
char *dst_cp,
166156
const std::vector<sycl::event> &depends)
167157
{
158+
const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
159+
const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
160+
const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
161+
T *dst_tp = reinterpret_cast<T *>(dst_cp);
162+
168163
sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
169164
cgh.depends_on(depends);
170165

@@ -178,8 +173,8 @@ sycl::event where_contig_impl(sycl::queue q,
178173

179174
cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
180175
sycl::nd_range<1>(gws_range, lws_range),
181-
WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p,
182-
x2_p, dst_p));
176+
WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_tp, x1_tp,
177+
x2_tp, dst_tp));
183178
});
184179

185180
return where_ev;
@@ -189,39 +184,34 @@ template <typename T, typename condT, typename IndexerT>
189184
class WhereStridedFunctor
190185
{
191186
private:
192-
const char *x1_cp = nullptr;
193-
const char *x2_cp = nullptr;
194-
char *dst_cp = nullptr;
195-
const char *cond_cp = nullptr;
187+
const T *x1_p = nullptr;
188+
const T *x2_p = nullptr;
189+
T *dst_p = nullptr;
190+
const condT *cond_p = nullptr;
196191
IndexerT indexer;
197192

198193
public:
199-
WhereStridedFunctor(const char *cond_data_p,
200-
const char *x1_data_p,
201-
const char *x2_data_p,
202-
char *dst_data_p,
194+
WhereStridedFunctor(const condT *cond_p_,
195+
const T *x1_p_,
196+
const T *x2_p_,
197+
T *dst_p_,
203198
IndexerT indexer_)
204-
: x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
205-
cond_cp(cond_data_p), indexer(indexer_)
199+
: x1_p(x1_p_), x2_p(x2_p_), dst_p(dst_p_), cond_p(cond_p_),
200+
indexer(indexer_)
206201
{
207202
}
208203

209204
void operator()(sycl::id<1> id) const
210205
{
211-
const T *x1_data = reinterpret_cast<const T *>(x1_cp);
212-
const T *x2_data = reinterpret_cast<const T *>(x2_cp);
213-
T *dst_data = reinterpret_cast<T *>(dst_cp);
214-
const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
215-
216206
size_t gid = id[0];
217207
auto offsets = indexer(static_cast<py::ssize_t>(gid));
218208

219209
using dpctl::tensor::type_utils::convert_impl;
220210
bool check =
221-
convert_impl<bool, condT>(cond_data[offsets.get_first_offset()]);
211+
convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);
222212

223-
dst_data[gid] = check ? x1_data[offsets.get_second_offset()]
224-
: x2_data[offsets.get_third_offset()];
213+
dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
214+
: x2_p[offsets.get_third_offset()];
225215
}
226216
};
227217

@@ -243,16 +233,21 @@ template <typename T, typename condT>
243233
sycl::event where_strided_impl(sycl::queue q,
244234
size_t nelems,
245235
int nd,
246-
const char *cond_p,
247-
const char *x1_p,
248-
const char *x2_p,
249-
char *dst_p,
236+
const char *cond_cp,
237+
const char *x1_cp,
238+
const char *x2_cp,
239+
char *dst_cp,
250240
const py::ssize_t *shape_strides,
251241
py::ssize_t x1_offset,
252242
py::ssize_t x2_offset,
253243
py::ssize_t cond_offset,
254244
const std::vector<sycl::event> &depends)
255245
{
246+
const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
247+
const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
248+
const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
249+
T *dst_tp = reinterpret_cast<T *>(dst_cp);
250+
256251
sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
257252
cgh.depends_on(depends);
258253

@@ -263,7 +258,7 @@ sycl::event where_strided_impl(sycl::queue q,
263258
where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
264259
sycl::range<1>(nelems),
265260
WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
266-
cond_p, x1_p, x2_p, dst_p, indexer));
261+
cond_tp, x1_tp, x2_tp, dst_tp, indexer));
267262
});
268263

269264
return where_ev;

dpctl/tensor/libtensor/source/where.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.
@@ -96,10 +96,10 @@ py_where(dpctl::tensor::usm_ndarray condition,
9696
bool shapes_equal(true);
9797
size_t nelems(1);
9898
for (int i = 0; i < nd; ++i) {
99-
nelems *= static_cast<size_t>(dst_shape[i]);
100-
shapes_equal = shapes_equal && (x1_shape[i] == dst_shape[i]) &&
101-
(x2_shape[i] == dst_shape[i]) &&
102-
(cond_shape[i] == dst_shape[i]);
99+
const auto &sh_i = dst_shape[i];
100+
nelems *= static_cast<size_t>(sh_i);
101+
shapes_equal = shapes_equal && (x1_shape[i] == sh_i) &&
102+
(x2_shape[i] == sh_i) && (cond_shape[i] == sh_i);
103103
}
104104

105105
if (!shapes_equal) {
@@ -127,7 +127,7 @@ py_where(dpctl::tensor::usm_ndarray condition,
127127
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
128128

129129
if (x1_typeid != x2_typeid || x1_typeid != dst_typeid) {
130-
throw py::value_error("Non-condition are not of same type.");
130+
throw py::value_error("Value arrays must have the same data type");
131131
}
132132

133133
// ensure that dst is sufficiently ample
@@ -166,8 +166,8 @@ py_where(dpctl::tensor::usm_ndarray condition,
166166

167167
auto where_ev = contig_fn(exec_q, nelems, cond_data, x1_data, x2_data,
168168
dst_data, depends);
169-
sycl::event ht_ev = dpctl::utils::keep_args_alive(
170-
exec_q, {x1, x2, dst, condition}, {where_ev});
169+
sycl::event ht_ev =
170+
keep_args_alive(exec_q, {x1, x2, dst, condition}, {where_ev});
171171

172172
return std::make_pair(ht_ev, where_ev);
173173
}

dpctl/tensor/libtensor/source/where.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def test_where_all_dtypes(dt):
121121

122122
# mask dtype changes
123123
cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q)
124-
x1 = dpt.asarray(0, dtype="f", sycl_queue=q)
125-
x2 = dpt.asarray(1, dtype="f", sycl_queue=q)
124+
x1 = dpt.asarray(0, dtype="f4", sycl_queue=q)
125+
x2 = dpt.asarray(1, dtype="f4", sycl_queue=q)
126126
res = dpt.where(cond, x1, x2)
127127

128128
res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype)

0 commit comments

Comments
 (0)