Skip to content

Commit 3f9a5bf

Browse files
Introduce ti._zeros_usm_ndarray(dst, sycl_queue)
This is akin to _full_usm_ndarray, but does not take fill_value, hence does not require castings. It dispatches straight to handler::memset.
1 parent 4297fef commit 3f9a5bf

File tree

4 files changed

+216
-0
lines changed

4 files changed

+216
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ set(_tensor_impl_sources
129129
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
130130
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
131131
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
132+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
132133
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
133134
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
134135
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp

dpctl/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include "utils/memory_overlap.hpp"
5555
#include "utils/strided_iters.hpp"
5656
#include "where.hpp"
57+
#include "zeros_ctor.hpp"
5758

5859
namespace py = pybind11;
5960

@@ -92,6 +93,10 @@ using dpctl::tensor::py_internal::usm_ndarray_linear_sequence_step;
9293

9394
using dpctl::tensor::py_internal::usm_ndarray_full;
9495

96+
/* ================ Zeros ================== */
97+
98+
using dpctl::tensor::py_internal::usm_ndarray_zeros;
99+
95100
/* ============== Advanced Indexing ============= */
96101
using dpctl::tensor::py_internal::usm_ndarray_put;
97102
using dpctl::tensor::py_internal::usm_ndarray_take;
@@ -142,6 +147,7 @@ void init_dispatch_vectors(void)
142147
init_copy_for_roll_dispatch_vectors();
143148
init_linear_sequences_dispatch_vectors();
144149
init_full_ctor_dispatch_vectors();
150+
init_zeros_ctor_dispatch_vectors();
145151
init_eye_ctor_dispatch_vectors();
146152
init_triul_ctor_dispatch_vectors();
147153

@@ -291,6 +297,10 @@ PYBIND11_MODULE(_tensor_impl, m)
291297
py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
292298
py::arg("depends") = py::list());
293299

300+
m.def("_zeros_usm_ndarray", &usm_ndarray_zeros,
301+
"Populate usm_ndarray `dst` with zeros.", py::arg("dst"),
302+
py::arg("sycl_queue"), py::arg("depends") = py::list());
303+
294304
m.def("_full_usm_ndarray", &usm_ndarray_full,
295305
"Populate usm_ndarray `dst` with given fill_value.",
296306
py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"),
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2024 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#include "dpctl4pybind11.hpp"
26+
#include <complex>
27+
#include <pybind11/complex.h>
28+
#include <pybind11/pybind11.h>
29+
#include <sycl/sycl.hpp>
30+
#include <utility>
31+
#include <vector>
32+
33+
#include "kernels/constructors.hpp"
34+
#include "utils/output_validation.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
38+
#include "zeros_ctor.hpp"
39+
40+
namespace py = pybind11;
41+
namespace td_ns = dpctl::tensor::type_dispatch;
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace py_internal
48+
{
49+
50+
using dpctl::utils::keep_args_alive;
51+
52+
typedef sycl::event (*zeros_contig_fn_ptr_t)(sycl::queue &,
53+
size_t,
54+
char *,
55+
const std::vector<sycl::event> &);
56+
57+
/*!
58+
* @brief Function to submit kernel to fill given contiguous memory allocation
59+
* with zeros.
60+
*
61+
* @param exec_q Sycl queue to which kernel is submitted for execution.
62+
* @param nelems Length of the sequence
63+
* @param dst_p Kernel accessible USM pointer to the start of array to be
64+
* populated.
65+
* @param depends List of events to wait for before starting computations, if
66+
* any.
67+
*
68+
* @return Event to wait on to ensure that computation completes.
69+
* @defgroup CtorKernels
70+
*/
71+
template <typename dstTy>
72+
sycl::event zeros_contig_impl(sycl::queue &exec_q,
73+
size_t nelems,
74+
char *dst_p,
75+
const std::vector<sycl::event> &depends)
76+
{
77+
78+
constexpr int memset_val(0);
79+
sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) {
80+
cgh.depends_on(depends);
81+
82+
cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
83+
nelems * sizeof(dstTy));
84+
});
85+
86+
return fill_ev;
87+
}
88+
89+
template <typename fnT, typename Ty> struct ZerosContigFactory
90+
{
91+
fnT get()
92+
{
93+
fnT f = zeros_contig_impl<Ty>;
94+
return f;
95+
}
96+
};
97+
98+
static zeros_contig_fn_ptr_t zeros_contig_dispatch_vector[td_ns::num_types];
99+
100+
std::pair<sycl::event, sycl::event>
101+
usm_ndarray_zeros(const dpctl::tensor::usm_ndarray &dst,
102+
sycl::queue &exec_q,
103+
const std::vector<sycl::event> &depends)
104+
{
105+
// start, end should be coercible into data type of dst
106+
107+
py::ssize_t dst_nelems = dst.get_size();
108+
109+
if (dst_nelems == 0) {
110+
// nothing to do
111+
return std::make_pair(sycl::event(), sycl::event());
112+
}
113+
114+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
115+
throw py::value_error(
116+
"Execution queue is not compatible with the allocation queue");
117+
}
118+
119+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
120+
121+
auto array_types = td_ns::usm_ndarray_types();
122+
int dst_typenum = dst.get_typenum();
123+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
124+
125+
char *dst_data = dst.get_data();
126+
127+
if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
128+
auto fn = zeros_contig_dispatch_vector[dst_typeid];
129+
130+
sycl::event zeros_contig_event =
131+
fn(exec_q, static_cast<size_t>(dst_nelems), dst_data, depends);
132+
133+
return std::make_pair(
134+
keep_args_alive(exec_q, {dst}, {zeros_contig_event}),
135+
zeros_contig_event);
136+
}
137+
else {
138+
throw std::runtime_error(
139+
"Only population of contiguous usm_ndarray objects is supported.");
140+
}
141+
}
142+
143+
void init_zeros_ctor_dispatch_vectors(void)
144+
{
145+
using namespace td_ns;
146+
147+
DispatchVectorBuilder<zeros_contig_fn_ptr_t, ZerosContigFactory, num_types>
148+
dvb;
149+
dvb.populate_dispatch_vector(zeros_contig_dispatch_vector);
150+
151+
return;
152+
}
153+
154+
} // namespace py_internal
155+
} // namespace tensor
156+
} // namespace dpctl
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2024 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <sycl/sycl.hpp>
27+
#include <utility>
28+
#include <vector>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
extern std::pair<sycl::event, sycl::event>
41+
usm_ndarray_zeros(const dpctl::tensor::usm_ndarray &dst,
42+
sycl::queue &exec_q,
43+
const std::vector<sycl::event> &depends = {});
44+
45+
extern void init_zeros_ctor_dispatch_vectors(void);
46+
47+
} // namespace py_internal
48+
} // namespace tensor
49+
} // namespace dpctl

0 commit comments

Comments
 (0)