Skip to content

Commit 24e0c89

Browse files
committed
Implement dedicated strided full kernel
1 parent cfba263 commit 24e0c89

File tree

2 files changed

+201
-7
lines changed

2 files changed

+201
-7
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ namespace constructors
4646

4747
template <typename Ty> class linear_sequence_step_kernel;
4848
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
49+
template <typename Ty> class full_strided_kernel;
4950
template <typename Ty> class eye_kernel;
5051

5152
using namespace dpctl::tensor::offset_utils;
@@ -252,6 +253,74 @@ sycl::event full_contig_impl(sycl::queue &q,
252253
return fill_ev;
253254
}
254255

256+
template <typename Ty, typename IndexerT> class FullStridedFunctor
257+
{
258+
private:
259+
Ty *p = nullptr;
260+
const Ty fill_v;
261+
const IndexerT indexer;
262+
263+
public:
264+
FullStridedFunctor(Ty *p_, const Ty &fill_v_, const IndexerT &indexer_)
265+
: p(p_), fill_v(fill_v_), indexer(indexer_)
266+
{
267+
}
268+
269+
void operator()(sycl::id<1> id) const
270+
{
271+
auto offset = indexer(id.get(0));
272+
p[offset] = fill_v;
273+
}
274+
};
275+
276+
/*!
277+
* @brief Function to submit kernel to fill given contiguous memory allocation
278+
* with specified value.
279+
*
280+
* @param exec_q Sycl queue to which kernel is submitted for execution.
281+
* @param nd Array dimensionality
282+
* @param nelems Length of the sequence
283+
* @param shape_strides Kernel accessible USM pointer to packed shape and
284+
* strides of array.
285+
* @param offset Displacement of first element of dst relative dst_p in
286+
* elements
287+
* @param fill_v Value to fill the array with
288+
* @param dst_p Kernel accessible USM pointer to the start of array to be
289+
* populated.
290+
* @param depends List of events to wait for before starting computations, if
291+
* any.
292+
*
293+
* @return Event to wait on to ensure that computation completes.
294+
* @defgroup CtorKernels
295+
*/
296+
template <typename dstTy>
297+
sycl::event full_strided_impl(sycl::queue &q,
298+
int nd,
299+
size_t nelems,
300+
const ssize_t *shape_strides,
301+
const ssize_t offset,
302+
dstTy fill_v,
303+
char *dst_p,
304+
const std::vector<sycl::event> &depends)
305+
{
306+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
307+
308+
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
309+
310+
using dpctl::tensor::offset_utils::StridedIndexer;
311+
const StridedIndexer strided_indexer(nd, offset, shape_strides);
312+
313+
sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
314+
cgh.depends_on(depends);
315+
cgh.parallel_for<full_strided_kernel<dstTy>>(
316+
sycl::range<1>{nelems},
317+
FullStridedFunctor<dstTy, decltype(strided_indexer)>(
318+
dst_tp, fill_v, strided_indexer));
319+
});
320+
321+
return fill_ev;
322+
}
323+
255324
/* ================ Eye ================== */
256325

257326
typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &,

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
#include "utils/type_utils.hpp"
3737

3838
#include "full_ctor.hpp"
39+
#include "simplify_iteration_space.hpp"
40+
3941

4042
namespace py = pybind11;
4143
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -61,9 +63,9 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
6163
*
6264
* @param exec_q Sycl queue to which kernel is submitted for execution.
6365
* @param nelems Length of the sequence
64-
* @param py_value Python object representing the value to fill the array with.
66+
* @param py_value Python object representing the value to fill the array with.
6567
* Must be convertible to `dstTy`.
66-
* @param dst_p Kernel accessible USM pointer to the start of array to be
68+
* @param dst_p Kernel accessible USM pointer to the start of array to be
6769
* populated.
6870
* @param depends List of events to wait for before starting computations, if
6971
* any.
@@ -152,7 +154,66 @@ template <typename fnT, typename Ty> struct FullContigFactory
152154
}
153155
};
154156

157+
// Pointer to a type-specific strided fill implementation. Arguments, in
// order: execution queue, array dimensionality, number of elements, packed
// shape and strides, offset of the first element, Python fill value,
// destination pointer, and events to wait for.
using full_strided_fn_ptr_t =
    sycl::event (*)(sycl::queue &,
                    int,
                    size_t,
                    py::ssize_t *,
                    py::ssize_t,
                    const py::object &,
                    char *,
                    const std::vector<sycl::event> &);
165+
166+
/*!
167+
* @brief Function to submit kernel to fill given strided memory allocation
168+
* with specified value.
169+
*
170+
* @param exec_q Sycl queue to which kernel is submitted for execution.
171+
* @param nd Array dimensionality
172+
* @param nelems Length of the sequence
173+
* @param shape_strides Kernel accessible USM pointer to packed shape and
174+
* strides of array.
175+
* @param dst_offset Displacement of first element of dst relative dst_p in
176+
* elements
177+
* @param py_value Python object representing the value to fill the array with.
178+
* Must be convertible to `dstTy`.
179+
* @param dst_p Kernel accessible USM pointer to the start of array to be
180+
* populated.
181+
* @param depends List of events to wait for before starting computations, if
182+
* any.
183+
*
184+
* @return Event to wait on to ensure that computation completes.
185+
* @defgroup CtorKernels
186+
*/
187+
template <typename dstTy>
188+
sycl::event full_strided_impl(sycl::queue &exec_q,
189+
int nd,
190+
size_t nelems,
191+
py::ssize_t *shape_strides,
192+
py::ssize_t dst_offset,
193+
const py::object &py_value,
194+
char *dst_p,
195+
const std::vector<sycl::event> &depends)
196+
{
197+
dstTy fill_v = py::cast<dstTy>(py_value);
198+
199+
using dpctl::tensor::kernels::constructors::full_strided_impl;
200+
sycl::event fill_ev = full_strided_impl<dstTy>(
201+
exec_q, nd, nelems, shape_strides, dst_offset, fill_v, dst_p, depends);
202+
203+
return fill_ev;
204+
}
205+
206+
template <typename fnT, typename Ty> struct FullStridedFactory
207+
{
208+
fnT get()
209+
{
210+
fnT f = full_strided_impl<Ty>;
211+
return f;
212+
}
213+
};
214+
155215
static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
216+
static full_strided_fn_ptr_t full_strided_dispatch_vector[td_ns::num_types];
156217

157218
std::pair<sycl::event, sycl::event>
158219
usm_ndarray_full(const py::object &py_value,
@@ -194,8 +255,70 @@ usm_ndarray_full(const py::object &py_value,
194255
full_contig_event);
195256
}
196257
else {
197-
throw std::runtime_error(
198-
"Only population of contiguous usm_ndarray objects is supported.");
258+
using dpctl::tensor::py_internal::simplify_iteration_space_1;
259+
260+
int nd = dst.get_ndim();
261+
const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
262+
auto const &dst_strides = dst.get_strides_vector();
263+
264+
using shT = std::vector<py::ssize_t>;
265+
shT simplified_dst_shape;
266+
shT simplified_dst_strides;
267+
py::ssize_t dst_offset(0);
268+
269+
simplify_iteration_space_1(nd, dst_shape_ptr, dst_strides,
270+
// output
271+
simplified_dst_shape, simplified_dst_strides,
272+
dst_offset);
273+
274+
// it's possible that this branch will never be taken
275+
// need to look carefully at `simplify_iteration_space_1`
276+
// to find cases
277+
if (nd == 1 && simplified_dst_strides[0] == 1) {
278+
auto fn = full_contig_dispatch_vector[dst_typeid];
279+
280+
const sycl::event &full_contig_event =
281+
fn(exec_q, static_cast<size_t>(dst_nelems), py_value,
282+
dst_data + dst_offset, depends);
283+
284+
return std::make_pair(
285+
keep_args_alive(exec_q, {dst}, {full_contig_event}),
286+
full_contig_event);
287+
}
288+
289+
auto fn = full_strided_dispatch_vector[dst_typeid];
290+
291+
std::vector<sycl::event> host_task_events;
292+
host_task_events.reserve(2);
293+
using dpctl::tensor::offset_utils::device_allocate_and_pack;
294+
const auto &ptr_size_event_tuple =
295+
device_allocate_and_pack<py::ssize_t>(exec_q, host_task_events,
296+
simplified_dst_shape,
297+
simplified_dst_strides);
298+
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
299+
if (shape_strides == nullptr) {
300+
throw std::runtime_error("Unable to allocate device memory");
301+
}
302+
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_tuple);
303+
304+
const sycl::event &full_strided_ev =
305+
fn(exec_q, nd, dst_nelems, shape_strides, dst_offset, py_value,
306+
dst_data, {copy_shape_ev});
307+
308+
// free shape_strides
309+
const auto &ctx = exec_q.get_context();
310+
const auto &temporaries_cleanup_ev =
311+
exec_q.submit([&](sycl::handler &cgh) {
312+
cgh.depends_on(full_strided_ev);
313+
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
314+
cgh.host_task([ctx, shape_strides]() {
315+
sycl_free_noexcept(shape_strides, ctx);
316+
});
317+
});
318+
host_task_events.push_back(temporaries_cleanup_ev);
319+
320+
return std::make_pair(keep_args_alive(exec_q, {dst}, host_task_events),
321+
full_strided_ev);
199322
}
200323
}
201324

@@ -204,10 +327,12 @@ void init_full_ctor_dispatch_vectors(void)
204327
using namespace td_ns;
205328

206329
DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
207-
dvb;
208-
dvb.populate_dispatch_vector(full_contig_dispatch_vector);
330+
dvb1;
331+
dvb1.populate_dispatch_vector(full_contig_dispatch_vector);
209332

210-
return;
333+
DispatchVectorBuilder<full_strided_fn_ptr_t, FullStridedFactory, num_types>
334+
dvb2;
335+
dvb2.populate_dispatch_vector(full_strided_dispatch_vector);
211336
}
212337

213338
} // namespace py_internal

0 commit comments

Comments
 (0)