Skip to content

Commit 7ffc449

Browse files
committed
Implement dedicated strided full kernel
1 parent cfba263 commit 7ffc449

File tree

2 files changed

+164
-7
lines changed

2 files changed

+164
-7
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ namespace constructors
4646

4747
// Forward declarations of SYCL kernel name types used by the constructor
// kernels below (one name type per kernel/template instantiation).
template <typename Ty> class linear_sequence_step_kernel;
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
template <typename Ty> class full_strided_kernel;
template <typename Ty> class eye_kernel;
5051

5152
using namespace dpctl::tensor::offset_utils;
@@ -252,6 +253,71 @@ sycl::event full_contig_impl(sycl::queue &q,
252253
return fill_ev;
253254
}
254255

256+
template <typename Ty, typename IndexerT> class FullStridedFunctor
257+
{
258+
private:
259+
Ty *p = nullptr;
260+
const Ty fill_v;
261+
const IndexerT indexer;
262+
263+
public:
264+
FullStridedFunctor(Ty *p_, const Ty &fill_v_, const IndexerT &indexer_)
265+
: p(p_), fill_v(fill_v_), indexer(indexer_)
266+
{
267+
}
268+
269+
void operator()(sycl::id<1> id) const
270+
{
271+
auto offset = indexer(id.get(0));
272+
p[offset] = fill_v;
273+
}
274+
};
275+
276+
/*!
277+
* @brief Function to submit kernel to fill given contiguous memory allocation
278+
* with specified value.
279+
*
280+
* @param exec_q Sycl queue to which kernel is submitted for execution.
281+
* @param nd Array dimensionality
282+
* @param nelems Length of the sequence
283+
* @param shape_strides Kernel accessible USM pointer to packed shape and
284+
* strides of array.
285+
* @param fill_v Value to fill the array with
286+
* @param dst_p Kernel accessible USM pointer to the start of array to be
287+
* populated.
288+
* @param depends List of events to wait for before starting computations, if
289+
* any.
290+
*
291+
* @return Event to wait on to ensure that computation completes.
292+
* @defgroup CtorKernels
293+
*/
294+
template <typename dstTy>
295+
sycl::event full_strided_impl(sycl::queue &q,
296+
int nd,
297+
size_t nelems,
298+
const ssize_t *shape_strides,
299+
dstTy fill_v,
300+
char *dst_p,
301+
const std::vector<sycl::event> &depends)
302+
{
303+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
304+
305+
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
306+
307+
using dpctl::tensor::offset_utils::StridedIndexer;
308+
const StridedIndexer strided_indexer(nd, 0, shape_strides);
309+
310+
sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
311+
cgh.depends_on(depends);
312+
cgh.parallel_for<full_strided_kernel<dstTy>>(
313+
sycl::range<1>{nelems},
314+
FullStridedFunctor<dstTy, decltype(strided_indexer)>(
315+
dst_tp, fill_v, strided_indexer));
316+
});
317+
318+
return fill_ev;
319+
}
320+
255321
/* ================ Eye ================== */
256322

257323
typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &,

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
6161
*
6262
* @param exec_q Sycl queue to which kernel is submitted for execution.
6363
* @param nelems Length of the sequence
64-
* @param py_value Python object representing the value to fill the array with.
64+
* @param py_value Python object representing the value to fill the array with.
6565
* Must be convertible to `dstTy`.
66-
* @param dst_p Kernel accessible USM pointer to the start of array to be
66+
* @param dst_p Kernel accessible USM pointer to the start of array to be
6767
* populated.
6868
* @param depends List of events to wait for before starting computations, if
6969
* any.
@@ -152,7 +152,62 @@ template <typename fnT, typename Ty> struct FullContigFactory
152152
}
153153
};
154154

155+
// Signature of a type-dispatched fill routine for strided arrays:
// (queue, nd, nelems, packed shape+strides USM pointer, Python fill value,
//  destination data pointer, dependency events) -> completion event.
typedef sycl::event (*full_strided_fn_ptr_t)(sycl::queue &,
                                             int,
                                             size_t,
                                             py::ssize_t *,
                                             const py::object &,
                                             char *,
                                             const std::vector<sycl::event> &);
162+
163+
/*!
164+
* @brief Function to submit kernel to fill given strided memory allocation
165+
* with specified value.
166+
*
167+
* @param exec_q Sycl queue to which kernel is submitted for execution.
168+
* @param nd Array dimensionality
169+
* @param nelems Length of the sequence
170+
* @param shape_strides Kernel accessible USM pointer to packed shape and
171+
* strides of array.
172+
* @param py_value Python object representing the value to fill the array with.
173+
* Must be convertible to `dstTy`.
174+
* @param dst_p Kernel accessible USM pointer to the start of array to be
175+
* populated.
176+
* @param depends List of events to wait for before starting computations, if
177+
* any.
178+
*
179+
* @return Event to wait on to ensure that computation completes.
180+
* @defgroup CtorKernels
181+
*/
182+
template <typename dstTy>
183+
sycl::event full_strided_impl(sycl::queue &exec_q,
184+
int nd,
185+
size_t nelems,
186+
py::ssize_t *shape_strides,
187+
const py::object &py_value,
188+
char *dst_p,
189+
const std::vector<sycl::event> &depends)
190+
{
191+
dstTy fill_v = py::cast<dstTy>(py_value);
192+
193+
using dpctl::tensor::kernels::constructors::full_strided_impl;
194+
sycl::event fill_ev = full_strided_impl<dstTy>(
195+
exec_q, nd, nelems, shape_strides, fill_v, dst_p, depends);
196+
197+
return fill_ev;
198+
}
199+
200+
template <typename fnT, typename Ty> struct FullStridedFactory
201+
{
202+
fnT get()
203+
{
204+
fnT f = full_strided_impl<Ty>;
205+
return f;
206+
}
207+
};
208+
155209
static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
210+
static full_strided_fn_ptr_t full_strided_dispatch_vector[td_ns::num_types];
156211

157212
std::pair<sycl::event, sycl::event>
158213
usm_ndarray_full(const py::object &py_value,
@@ -194,8 +249,42 @@ usm_ndarray_full(const py::object &py_value,
194249
full_contig_event);
195250
}
196251
else {
197-
throw std::runtime_error(
198-
"Only population of contiguous usm_ndarray objects is supported.");
252+
int nd = dst.get_ndim();
253+
auto const &dst_shape = dst.get_shape_vector();
254+
auto const &dst_strides = dst.get_strides_vector();
255+
256+
auto fn = full_strided_dispatch_vector[dst_typeid];
257+
258+
std::vector<sycl::event> host_task_events;
259+
host_task_events.reserve(2);
260+
using dpctl::tensor::offset_utils::device_allocate_and_pack;
261+
const auto &ptr_size_event_tuple =
262+
device_allocate_and_pack<py::ssize_t>(exec_q, host_task_events,
263+
dst_shape, dst_strides);
264+
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
265+
if (shape_strides == nullptr) {
266+
throw std::runtime_error("Unable to allocate device memory");
267+
}
268+
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_tuple);
269+
270+
const sycl::event &full_strided_ev =
271+
fn(exec_q, nd, dst_nelems, shape_strides, py_value, dst_data,
272+
{copy_shape_ev});
273+
274+
// free shape_strides
275+
const auto &ctx = exec_q.get_context();
276+
const auto &temporaries_cleanup_ev =
277+
exec_q.submit([&](sycl::handler &cgh) {
278+
cgh.depends_on(full_strided_ev);
279+
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
280+
cgh.host_task([ctx, shape_strides]() {
281+
sycl_free_noexcept(shape_strides, ctx);
282+
});
283+
});
284+
host_task_events.push_back(temporaries_cleanup_ev);
285+
286+
return std::make_pair(keep_args_alive(exec_q, {dst}, host_task_events),
287+
full_strided_ev);
199288
}
200289
}
201290

@@ -204,10 +293,12 @@ void init_full_ctor_dispatch_vectors(void)
204293
using namespace td_ns;
205294

206295
DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
207-
dvb;
208-
dvb.populate_dispatch_vector(full_contig_dispatch_vector);
296+
dvb1;
297+
dvb1.populate_dispatch_vector(full_contig_dispatch_vector);
209298

210-
return;
299+
DispatchVectorBuilder<full_strided_fn_ptr_t, FullStridedFactory, num_types>
300+
dvb2;
301+
dvb2.populate_dispatch_vector(full_strided_dispatch_vector);
211302
}
212303

213304
} // namespace py_internal

0 commit comments

Comments
 (0)