@@ -2,7 +2,7 @@
//
// Data Parallel Control (dpctl)
//
-// Copyright 2020-2022 Intel Corporation
+// Copyright 2020-2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -57,31 +57,24 @@ class WhereContigFunctor
{
private:
    size_t nelems = 0;
-    const char *x1_cp = nullptr;
-    const char *x2_cp = nullptr;
-    char *dst_cp = nullptr;
-    const char *cond_cp = nullptr;
+    const condT *cond_p = nullptr;
+    const T *x1_p = nullptr;
+    const T *x2_p = nullptr;
+    T *dst_p = nullptr;

public:
    WhereContigFunctor(size_t nelems_,
-                       const char *cond_data_p,
-                       const char *x1_data_p,
-                       const char *x2_data_p,
-                       char *dst_data_p)
-        : nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
-          dst_cp(dst_data_p), cond_cp(cond_data_p)
+                       const condT *cond_p_,
+                       const T *x1_p_,
+                       const T *x2_p_,
+                       T *dst_p_)
+        : nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_),
+          dst_p(dst_p_)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
-        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
-        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
-        T *dst_data = reinterpret_cast<T *>(dst_cp);
-        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
-
-        using dpctl::tensor::type_utils::convert_impl;
-
        using dpctl::tensor::type_utils::is_complex;
        if constexpr (is_complex<condT>::value || is_complex<T>::value) {
            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -92,8 +85,9 @@ class WhereContigFunctor
                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
                 offset += sgSize)
            {
-                bool check = convert_impl<bool, condT>(cond_data[offset]);
-                dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
+                using dpctl::tensor::type_utils::convert_impl;
+                bool check = convert_impl<bool, condT>(cond_p[offset]);
+                dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
            }
        }
        else {
@@ -115,7 +109,6 @@ class WhereContigFunctor
            using cond_ptrT =
                sycl::multi_ptr<const condT,
                                sycl::access::address_space::global_space>;
-
            sycl::vec<T, vec_sz> dst_vec;
            sycl::vec<T, vec_sz> x1_vec;
            sycl::vec<T, vec_sz> x2_vec;
@@ -124,23 +117,20 @@ class WhereContigFunctor
#pragma unroll
                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
                    auto idx = base + it * sgSize;
-                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
-                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
-                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
-
+                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_p[idx]));
+                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_p[idx]));
+                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_p[idx]));
#pragma unroll
                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
-                        bool check = convert_impl<bool, condT>(cond_vec[k]);
-                        dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
+                        dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
                    }
-                    sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
+                    sg.store<vec_sz>(dst_ptrT(&dst_p[idx]), dst_vec);
                }
            }
            else {
                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
                     k += sgSize) {
-                    bool check = convert_impl<bool, condT>(cond_data[k]);
-                    dst_data[k] = check ? x1_data[k] : x2_data[k];
+                    dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
                }
            }
        }
@@ -159,12 +149,17 @@ typedef sycl::event (*where_contig_impl_fn_ptr_t)(
template <typename T, typename condT>
sycl::event where_contig_impl(sycl::queue q,
                              size_t nelems,
-                              const char *cond_p,
-                              const char *x1_p,
-                              const char *x2_p,
-                              char *dst_p,
+                              const char *cond_cp,
+                              const char *x1_cp,
+                              const char *x2_cp,
+                              char *dst_cp,
                              const std::vector<sycl::event> &depends)
{
+    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
+    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
+    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
+    T *dst_tp = reinterpret_cast<T *>(dst_cp);
+
    sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

@@ -178,8 +173,8 @@ sycl::event where_contig_impl(sycl::queue q,

        cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
            sycl::nd_range<1>(gws_range, lws_range),
-            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p,
-                                                         x2_p, dst_p));
+            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_tp, x1_tp,
+                                                         x2_tp, dst_tp));
    });

    return where_ev;
@@ -189,39 +184,34 @@ template <typename T, typename condT, typename IndexerT>
class WhereStridedFunctor
{
private:
-    const char *x1_cp = nullptr;
-    const char *x2_cp = nullptr;
-    char *dst_cp = nullptr;
-    const char *cond_cp = nullptr;
+    const T *x1_p = nullptr;
+    const T *x2_p = nullptr;
+    T *dst_p = nullptr;
+    const condT *cond_p = nullptr;
    IndexerT indexer;

public:
-    WhereStridedFunctor(const char *cond_data_p,
-                        const char *x1_data_p,
-                        const char *x2_data_p,
-                        char *dst_data_p,
+    WhereStridedFunctor(const condT *cond_p_,
+                        const T *x1_p_,
+                        const T *x2_p_,
+                        T *dst_p_,
                        IndexerT indexer_)
-        : x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
-          cond_cp(cond_data_p), indexer(indexer_)
+        : x1_p(x1_p_), x2_p(x2_p_), dst_p(dst_p_), cond_p(cond_p_),
+          indexer(indexer_)
    {
    }

    void operator()(sycl::id<1> id) const
    {
-        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
-        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
-        T *dst_data = reinterpret_cast<T *>(dst_cp);
-        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
-
        size_t gid = id[0];
        auto offsets = indexer(static_cast<py::ssize_t>(gid));

        using dpctl::tensor::type_utils::convert_impl;
        bool check =
-            convert_impl<bool, condT>(cond_data[offsets.get_first_offset()]);
+            convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);

-        dst_data[gid] = check ? x1_data[offsets.get_second_offset()]
-                              : x2_data[offsets.get_third_offset()];
+        dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
+                           : x2_p[offsets.get_third_offset()];
    }
};

@@ -243,16 +233,21 @@ template <typename T, typename condT>
sycl::event where_strided_impl(sycl::queue q,
                               size_t nelems,
                               int nd,
-                               const char *cond_p,
-                               const char *x1_p,
-                               const char *x2_p,
-                               char *dst_p,
+                               const char *cond_cp,
+                               const char *x1_cp,
+                               const char *x2_cp,
+                               char *dst_cp,
                               const py::ssize_t *shape_strides,
                               py::ssize_t x1_offset,
                               py::ssize_t x2_offset,
                               py::ssize_t cond_offset,
                               const std::vector<sycl::event> &depends)
{
+    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
+    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
+    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
+    T *dst_tp = reinterpret_cast<T *>(dst_cp);
+
    sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

@@ -263,7 +258,7 @@ sycl::event where_strided_impl(sycl::queue q,
            where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
            sycl::range<1>(nelems),
            WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
-                cond_p, x1_p, x2_p, dst_p, indexer));
+                cond_tp, x1_tp, x2_tp, dst_tp, indexer));
    });

    return where_ev;
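
For orientation, here is a minimal stand-alone sketch (plain C++, no SYCL) of the pattern this diff applies: the char* to typed-pointer reinterpret_casts move out of the device functor's operator() and into the host-side impl function, so the functor only ever holds typed pointers and the cast happens once per kernel launch instead of once per work item's call path. The names WhereFunctor and where_impl below are illustrative stand-ins, not dpctl APIs, and the serial loop stands in for q.submit / parallel_for.

#include <cstddef>
#include <iostream>

// Functor holds typed pointers; no casting inside the "kernel" body.
template <typename T, typename condT> class WhereFunctor
{
private:
    std::size_t nelems = 0;
    const condT *cond_p = nullptr;
    const T *x1_p = nullptr;
    const T *x2_p = nullptr;
    T *dst_p = nullptr;

public:
    WhereFunctor(std::size_t nelems_,
                 const condT *cond_p_,
                 const T *x1_p_,
                 const T *x2_p_,
                 T *dst_p_)
        : nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_),
          dst_p(dst_p_)
    {
    }

    void operator()(std::size_t i) const
    {
        dst_p[i] = cond_p[i] ? x1_p[i] : x2_p[i];
    }
};

// Type-erased entry point, as a dispatch table would store it:
// cast the char* arguments once, then hand typed pointers to the functor.
template <typename T, typename condT>
void where_impl(std::size_t nelems,
                const char *cond_cp,
                const char *x1_cp,
                const char *x2_cp,
                char *dst_cp)
{
    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
    T *dst_tp = reinterpret_cast<T *>(dst_cp);

    WhereFunctor<T, condT> fn(nelems, cond_tp, x1_tp, x2_tp, dst_tp);
    for (std::size_t i = 0; i < nelems; ++i) {
        fn(i); // stands in for submitting the functor to a SYCL queue
    }
}

int main()
{
    float x1[] = {1.0f, 2.0f, 3.0f};
    float x2[] = {10.0f, 20.0f, 30.0f};
    bool cond[] = {true, false, true};
    float dst[3];

    where_impl<float, bool>(3, reinterpret_cast<const char *>(cond),
                            reinterpret_cast<const char *>(x1),
                            reinterpret_cast<const char *>(x2),
                            reinterpret_cast<char *>(dst));

    for (float v : dst) {
        std::cout << v << ' '; // prints: 1 20 3
    }
    std::cout << '\n';
}

One design note the sketch makes visible: because the casts live behind the type-erased where_impl signature, the functor's members can carry their real element types, which lets the compiler check the ternary select and lets the vectorized SYCL path in the actual diff load and store sycl::vec values without per-element reinterpretation.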