@@ -61,9 +61,9 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
61
61
*
62
62
* @param exec_q Sycl queue to which kernel is submitted for execution.
63
63
* @param nelems Length of the sequence
64
- * @param py_value Python object representing the value to fill the array with.
64
+ * @param py_value Python object representing the value to fill the array with.
65
65
* Must be convertible to `dstTy`.
66
- * @param dst_p Kernel accessible USM pointer to the start of array to be
66
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
67
67
* populated.
68
68
* @param depends List of events to wait for before starting computations, if
69
69
* any.
@@ -152,7 +152,62 @@ template <typename fnT, typename Ty> struct FullContigFactory
152
152
}
153
153
};
154
154
155
+ typedef sycl::event (*full_strided_fn_ptr_t )(sycl::queue &,
156
+ int ,
157
+ size_t ,
158
+ py::ssize_t *,
159
+ const py::object &,
160
+ char *,
161
+ const std::vector<sycl::event> &);
162
+
163
+ /* !
164
+ * @brief Function to submit kernel to fill given strided memory allocation
165
+ * with specified value.
166
+ *
167
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
168
+ * @param nd Array dimensionality
169
+ * @param nelems Length of the sequence
170
+ * @param shape_strides Kernel accessible USM pointer to packed shape and
171
+ * strides of array.
172
+ * @param py_value Python object representing the value to fill the array with.
173
+ * Must be convertible to `dstTy`.
174
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
175
+ * populated.
176
+ * @param depends List of events to wait for before starting computations, if
177
+ * any.
178
+ *
179
+ * @return Event to wait on to ensure that computation completes.
180
+ * @defgroup CtorKernels
181
+ */
182
+ template <typename dstTy>
183
+ sycl::event full_strided_impl (sycl::queue &exec_q,
184
+ int nd,
185
+ size_t nelems,
186
+ py::ssize_t *shape_strides,
187
+ const py::object &py_value,
188
+ char *dst_p,
189
+ const std::vector<sycl::event> &depends)
190
+ {
191
+ dstTy fill_v = py::cast<dstTy>(py_value);
192
+
193
+ using dpctl::tensor::kernels::constructors::full_strided_impl;
194
+ sycl::event fill_ev = full_strided_impl<dstTy>(
195
+ exec_q, nd, nelems, shape_strides, fill_v, dst_p, depends);
196
+
197
+ return fill_ev;
198
+ }
199
+
200
+ template <typename fnT, typename Ty> struct FullStridedFactory
201
+ {
202
+ fnT get ()
203
+ {
204
+ fnT f = full_strided_impl<Ty>;
205
+ return f;
206
+ }
207
+ };
208
+
155
209
// Per-dtype dispatch tables, indexed by dpctl type-id; filled in by
// init_full_ctor_dispatch_vectors().
static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
static full_strided_fn_ptr_t full_strided_dispatch_vector[td_ns::num_types];
156
211
157
212
std::pair<sycl::event, sycl::event>
158
213
usm_ndarray_full (const py::object &py_value,
@@ -194,8 +249,42 @@ usm_ndarray_full(const py::object &py_value,
194
249
full_contig_event);
195
250
}
196
251
else {
197
- throw std::runtime_error (
198
- " Only population of contiguous usm_ndarray objects is supported." );
252
+ int nd = dst.get_ndim ();
253
+ auto const &dst_shape = dst.get_shape_vector ();
254
+ auto const &dst_strides = dst.get_strides_vector ();
255
+
256
+ auto fn = full_strided_dispatch_vector[dst_typeid];
257
+
258
+ std::vector<sycl::event> host_task_events;
259
+ host_task_events.reserve (2 );
260
+ using dpctl::tensor::offset_utils::device_allocate_and_pack;
261
+ const auto &ptr_size_event_tuple =
262
+ device_allocate_and_pack<py::ssize_t >(exec_q, host_task_events,
263
+ dst_shape, dst_strides);
264
+ py::ssize_t *shape_strides = std::get<0 >(ptr_size_event_tuple);
265
+ if (shape_strides == nullptr ) {
266
+ throw std::runtime_error (" Unable to allocate device memory" );
267
+ }
268
+ const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_tuple);
269
+
270
+ const sycl::event &full_strided_ev =
271
+ fn (exec_q, nd, dst_nelems, shape_strides, py_value, dst_data,
272
+ {copy_shape_ev});
273
+
274
+ // free shape_strides
275
+ const auto &ctx = exec_q.get_context ();
276
+ const auto &temporaries_cleanup_ev =
277
+ exec_q.submit ([&](sycl::handler &cgh) {
278
+ cgh.depends_on (full_strided_ev);
279
+ using dpctl::tensor::alloc_utils::sycl_free_noexcept;
280
+ cgh.host_task ([ctx, shape_strides]() {
281
+ sycl_free_noexcept (shape_strides, ctx);
282
+ });
283
+ });
284
+ host_task_events.push_back (temporaries_cleanup_ev);
285
+
286
+ return std::make_pair (keep_args_alive (exec_q, {dst}, host_task_events),
287
+ full_strided_ev);
199
288
}
200
289
}
201
290
@@ -204,10 +293,12 @@ void init_full_ctor_dispatch_vectors(void)
204
293
using namespace td_ns ;
205
294
206
295
DispatchVectorBuilder<full_contig_fn_ptr_t , FullContigFactory, num_types>
207
- dvb ;
208
- dvb .populate_dispatch_vector (full_contig_dispatch_vector);
296
+ dvb1 ;
297
+ dvb1 .populate_dispatch_vector (full_contig_dispatch_vector);
209
298
210
- return ;
299
+ DispatchVectorBuilder<full_strided_fn_ptr_t , FullStridedFactory, num_types>
300
+ dvb2;
301
+ dvb2.populate_dispatch_vector (full_strided_dispatch_vector);
211
302
}
212
303
213
304
} // namespace py_internal
0 commit comments