36
36
#include " utils/type_utils.hpp"
37
37
38
38
#include " full_ctor.hpp"
39
+ #include " simplify_iteration_space.hpp"
40
+
39
41
40
42
namespace py = pybind11;
41
43
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -61,9 +63,9 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
61
63
*
62
64
* @param exec_q Sycl queue to which kernel is submitted for execution.
63
65
* @param nelems Length of the sequence
64
- * @param py_value Python object representing the value to fill the array with.
66
+ * @param py_value Python object representing the value to fill the array with.
65
67
* Must be convertible to `dstTy`.
66
- * @param dst_p Kernel accessible USM pointer to the start of array to be
68
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
67
69
* populated.
68
70
* @param depends List of events to wait for before starting computations, if
69
71
* any.
@@ -152,7 +154,66 @@ template <typename fnT, typename Ty> struct FullContigFactory
152
154
}
153
155
};
154
156
157
+ typedef sycl::event (*full_strided_fn_ptr_t )(sycl::queue &,
158
+ int ,
159
+ size_t ,
160
+ py::ssize_t *,
161
+ py::ssize_t ,
162
+ const py::object &,
163
+ char *,
164
+ const std::vector<sycl::event> &);
165
+
166
+ /* !
167
+ * @brief Function to submit kernel to fill given strided memory allocation
168
+ * with specified value.
169
+ *
170
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
171
+ * @param nd Array dimensionality
172
+ * @param nelems Length of the sequence
173
+ * @param shape_strides Kernel accessible USM pointer to packed shape and
174
+ * strides of array.
175
+ * @param dst_offset Displacement of first element of dst relative dst_p in
176
+ * elements
177
+ * @param py_value Python object representing the value to fill the array with.
178
+ * Must be convertible to `dstTy`.
179
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
180
+ * populated.
181
+ * @param depends List of events to wait for before starting computations, if
182
+ * any.
183
+ *
184
+ * @return Event to wait on to ensure that computation completes.
185
+ * @defgroup CtorKernels
186
+ */
187
+ template <typename dstTy>
188
+ sycl::event full_strided_impl (sycl::queue &exec_q,
189
+ int nd,
190
+ size_t nelems,
191
+ py::ssize_t *shape_strides,
192
+ py::ssize_t dst_offset,
193
+ const py::object &py_value,
194
+ char *dst_p,
195
+ const std::vector<sycl::event> &depends)
196
+ {
197
+ dstTy fill_v = py::cast<dstTy>(py_value);
198
+
199
+ using dpctl::tensor::kernels::constructors::full_strided_impl;
200
+ sycl::event fill_ev = full_strided_impl<dstTy>(
201
+ exec_q, nd, nelems, shape_strides, dst_offset, fill_v, dst_p, depends);
202
+
203
+ return fill_ev;
204
+ }
205
+
206
+ template <typename fnT, typename Ty> struct FullStridedFactory
207
+ {
208
+ fnT get ()
209
+ {
210
+ fnT f = full_strided_impl<Ty>;
211
+ return f;
212
+ }
213
+ };
214
+
155
215
static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
216
+ static full_strided_fn_ptr_t full_strided_dispatch_vector[td_ns::num_types];
156
217
157
218
std::pair<sycl::event, sycl::event>
158
219
usm_ndarray_full (const py::object &py_value,
@@ -194,8 +255,70 @@ usm_ndarray_full(const py::object &py_value,
194
255
full_contig_event);
195
256
}
196
257
else {
197
- throw std::runtime_error (
198
- " Only population of contiguous usm_ndarray objects is supported." );
258
+ using dpctl::tensor::py_internal::simplify_iteration_space_1;
259
+
260
+ int nd = dst.get_ndim ();
261
+ const py::ssize_t *dst_shape_ptr = dst.get_shape_raw ();
262
+ auto const &dst_strides = dst.get_strides_vector ();
263
+
264
+ using shT = std::vector<py::ssize_t >;
265
+ shT simplified_dst_shape;
266
+ shT simplified_dst_strides;
267
+ py::ssize_t dst_offset (0 );
268
+
269
+ simplify_iteration_space_1 (nd, dst_shape_ptr, dst_strides,
270
+ // output
271
+ simplified_dst_shape, simplified_dst_strides,
272
+ dst_offset);
273
+
274
+ // it's possible that this branch will never be taken
275
+ // need to look carefully at `simplify_iteration_space_1`
276
+ // to find cases
277
+ if (nd == 1 && simplified_dst_strides[0 ] == 1 ) {
278
+ auto fn = full_contig_dispatch_vector[dst_typeid];
279
+
280
+ const sycl::event &full_contig_event =
281
+ fn (exec_q, static_cast <size_t >(dst_nelems), py_value,
282
+ dst_data + dst_offset, depends);
283
+
284
+ return std::make_pair (
285
+ keep_args_alive (exec_q, {dst}, {full_contig_event}),
286
+ full_contig_event);
287
+ }
288
+
289
+ auto fn = full_strided_dispatch_vector[dst_typeid];
290
+
291
+ std::vector<sycl::event> host_task_events;
292
+ host_task_events.reserve (2 );
293
+ using dpctl::tensor::offset_utils::device_allocate_and_pack;
294
+ const auto &ptr_size_event_tuple =
295
+ device_allocate_and_pack<py::ssize_t >(exec_q, host_task_events,
296
+ simplified_dst_shape,
297
+ simplified_dst_strides);
298
+ py::ssize_t *shape_strides = std::get<0 >(ptr_size_event_tuple);
299
+ if (shape_strides == nullptr ) {
300
+ throw std::runtime_error (" Unable to allocate device memory" );
301
+ }
302
+ const sycl::event ©_shape_ev = std::get<2 >(ptr_size_event_tuple);
303
+
304
+ const sycl::event &full_strided_ev =
305
+ fn (exec_q, nd, dst_nelems, shape_strides, dst_offset, py_value,
306
+ dst_data, {copy_shape_ev});
307
+
308
+ // free shape_strides
309
+ const auto &ctx = exec_q.get_context ();
310
+ const auto &temporaries_cleanup_ev =
311
+ exec_q.submit ([&](sycl::handler &cgh) {
312
+ cgh.depends_on (full_strided_ev);
313
+ using dpctl::tensor::alloc_utils::sycl_free_noexcept;
314
+ cgh.host_task ([ctx, shape_strides]() {
315
+ sycl_free_noexcept (shape_strides, ctx);
316
+ });
317
+ });
318
+ host_task_events.push_back (temporaries_cleanup_ev);
319
+
320
+ return std::make_pair (keep_args_alive (exec_q, {dst}, host_task_events),
321
+ full_strided_ev);
199
322
}
200
323
}
201
324
@@ -204,10 +327,12 @@ void init_full_ctor_dispatch_vectors(void)
204
327
using namespace td_ns ;
205
328
206
329
DispatchVectorBuilder<full_contig_fn_ptr_t , FullContigFactory, num_types>
207
- dvb ;
208
- dvb .populate_dispatch_vector (full_contig_dispatch_vector);
330
+ dvb1 ;
331
+ dvb1 .populate_dispatch_vector (full_contig_dispatch_vector);
209
332
210
- return ;
333
+ DispatchVectorBuilder<full_strided_fn_ptr_t , FullStridedFactory, num_types>
334
+ dvb2;
335
+ dvb2.populate_dispatch_vector (full_strided_dispatch_vector);
211
336
}
212
337
213
338
} // namespace py_internal
0 commit comments