Commit cfba263

Merge pull request #1806 from IntelPython/fixme-async-memset
2 parents 4297fef + 640e706 commit cfba263

7 files changed, +298 -6 lines changed


dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@ set(_tensor_impl_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp

dpctl/tensor/_ctors.py

Lines changed: 5 additions & 2 deletions
@@ -945,8 +945,11 @@ def zeros(
         order=order,
         buffer_ctor_kwargs={"queue": sycl_queue},
     )
-    # FIXME: replace with asynchronous call to ti
-    res.usm_data.memset()
+    _manager = dpctl.utils.SequentialOrderManager[sycl_queue]
+    # populating new allocation, no dependent events
+    hev, zeros_ev = ti._zeros_usm_ndarray(res, sycl_queue)
+    _manager.add_event_pair(hev, zeros_ev)
+
     return res

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 58 additions & 4 deletions
@@ -80,10 +80,65 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
 {
     dstTy fill_v = py::cast<dstTy>(py_value);

-    using dpctl::tensor::kernels::constructors::full_contig_impl;
+    sycl::event fill_ev;

-    sycl::event fill_ev =
-        full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
+    if constexpr (sizeof(dstTy) == sizeof(char)) {
+        const auto memset_val = sycl::bit_cast<unsigned char>(fill_v);
+        fill_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
+                       nelems * sizeof(dstTy));
+        });
+    }
+    else {
+        bool is_zero = false;
+        if constexpr (sizeof(dstTy) == 1) {
+            is_zero = (std::uint8_t{0} == sycl::bit_cast<std::uint8_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 2) {
+            is_zero =
+                (std::uint16_t{0} == sycl::bit_cast<std::uint16_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 4) {
+            is_zero =
+                (std::uint32_t{0} == sycl::bit_cast<std::uint32_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 8) {
+            is_zero =
+                (std::uint64_t{0} == sycl::bit_cast<std::uint64_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 16) {
+            struct UInt128
+            {
+
+                constexpr UInt128() : v1{}, v2{} {}
+                UInt128(const UInt128 &) = default;
+
+                operator bool() const { return bool(!v1) && bool(!v2); }
+
+                std::uint64_t v1;
+                std::uint64_t v2;
+            };
+            is_zero = static_cast<bool>(sycl::bit_cast<UInt128>(fill_v));
+        }
+
+        if (is_zero) {
+            constexpr int memset_val = 0;
+            fill_ev = exec_q.submit([&](sycl::handler &cgh) {
+                cgh.depends_on(depends);
+
+                cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
+                           nelems * sizeof(dstTy));
+            });
+        }
+        else {
+            using dpctl::tensor::kernels::constructors::full_contig_impl;
+
+            fill_ev =
+                full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
+        }
+    }

     return fill_ev;
 }

@@ -126,7 +181,6 @@ usm_ndarray_full(const py::object &py_value,
     int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

     char *dst_data = dst.get_data();
-    sycl::event full_event;

     if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
         auto fn = full_contig_dispatch_vector[dst_typeid];
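The branch added above skips the fill kernel whenever the requested fill value has an all-zero byte pattern (or fits in a single byte) and falls back to a plain byte-wise memset. Below is a minimal standalone sketch of the zero-detection idea, using std::bit_cast where the PR uses sycl::bit_cast; the helper has_all_zero_bits and the whole program are illustrative, not part of dpctl.

#include <bit>
#include <complex>
#include <cstdint>
#include <iostream>

// Reinterpret the value's bytes as an unsigned integer of the same size and
// compare against zero; mirrors the sizeof-based dispatch in full_contig_impl.
template <typename T> bool has_all_zero_bits(const T &v)
{
    if constexpr (sizeof(T) == 1) {
        return std::bit_cast<std::uint8_t>(v) == 0u;
    }
    else if constexpr (sizeof(T) == 2) {
        return std::bit_cast<std::uint16_t>(v) == 0u;
    }
    else if constexpr (sizeof(T) == 4) {
        return std::bit_cast<std::uint32_t>(v) == 0u;
    }
    else if constexpr (sizeof(T) == 8) {
        return std::bit_cast<std::uint64_t>(v) == 0u;
    }
    else if constexpr (sizeof(T) == 16) {
        struct UInt128
        {
            std::uint64_t v1;
            std::uint64_t v2;
        };
        const auto u = std::bit_cast<UInt128>(v);
        return (u.v1 == 0u) && (u.v2 == 0u);
    }
    else {
        return false;
    }
}

int main()
{
    // 1: +0.0 is all-zero bytes, so memset(0) reproduces it exactly
    std::cout << has_all_zero_bits(0.0) << "\n";
    // 0: -0.0 sets the sign bit, so a memset would not be equivalent
    std::cout << has_all_zero_bits(-0.0) << "\n";
    // 1 and 0 for 16-byte complex values (assumes std::complex<double> is
    // trivially copyable, which holds for mainstream standard libraries)
    std::cout << has_all_zero_bits(std::complex<double>(0.0, 0.0)) << "\n";
    std::cout << has_all_zero_bits(std::complex<double>(0.0, 1.0)) << "\n";
}

Values such as -0.0 do not qualify as all-zero bytes, which is why the generic full_contig_impl kernel remains as the fallback path in the diff above.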

dpctl/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 10 additions & 0 deletions
@@ -54,6 +54,7 @@
 #include "utils/memory_overlap.hpp"
 #include "utils/strided_iters.hpp"
 #include "where.hpp"
+#include "zeros_ctor.hpp"

 namespace py = pybind11;

@@ -92,6 +93,10 @@ using dpctl::tensor::py_internal::usm_ndarray_linear_sequence_step;

 using dpctl::tensor::py_internal::usm_ndarray_full;

+/* ================ Zeros ================== */
+
+using dpctl::tensor::py_internal::usm_ndarray_zeros;
+
 /* ============== Advanced Indexing ============= */
 using dpctl::tensor::py_internal::usm_ndarray_put;
 using dpctl::tensor::py_internal::usm_ndarray_take;

@@ -142,6 +147,7 @@ void init_dispatch_vectors(void)
     init_copy_for_roll_dispatch_vectors();
     init_linear_sequences_dispatch_vectors();
     init_full_ctor_dispatch_vectors();
+    init_zeros_ctor_dispatch_vectors();
     init_eye_ctor_dispatch_vectors();
     init_triul_ctor_dispatch_vectors();

@@ -291,6 +297,10 @@ PYBIND11_MODULE(_tensor_impl, m)
           py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
           py::arg("depends") = py::list());

+    m.def("_zeros_usm_ndarray", &usm_ndarray_zeros,
+          "Populate usm_ndarray `dst` with zeros.", py::arg("dst"),
+          py::arg("sycl_queue"), py::arg("depends") = py::list());
+
     m.def("_full_usm_ndarray", &usm_ndarray_full,
           "Populate usm_ndarray `dst` with given fill_value.",
           py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"),
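The new _zeros_usm_ndarray entry follows the same pybind11 binding pattern as the surrounding definitions: a C++ function exposed under a leading-underscore name, keyword arguments declared with py::arg, and a depends argument defaulting to an empty Python list. A minimal self-contained sketch of that pattern, assuming nothing from dpctl (the module name _example_impl and the stub function are hypothetical):

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

namespace py = pybind11;

// Stand-in for a kernel-submitting function: it only reports how many
// dependency handles it received.
int submit_stub(int dst_token, const std::vector<int> &depends)
{
    return dst_token + static_cast<int>(depends.size());
}

PYBIND11_MODULE(_example_impl, m)
{
    // py::arg names become keyword arguments on the Python side; the empty
    // py::list() default mirrors `depends = py::list()` in the bindings above.
    m.def("_submit_stub", &submit_stub,
          "Illustrative binding with a defaulted `depends` argument.",
          py::arg("dst_token"), py::arg("depends") = py::list());
}

With pybind11/stl.h included, the empty py::list default is converted to an empty std::vector at call time, which is what lets Python callers omit depends entirely.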
dpctl/tensor/libtensor/source/zeros_ctor.cpp (new file)

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===--------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_impl extensions
+//===--------------------------------------------------------------------===//
+
+#include "dpctl4pybind11.hpp"
+#include <complex>
+#include <pybind11/complex.h>
+#include <pybind11/pybind11.h>
+#include <sycl/sycl.hpp>
+#include <utility>
+#include <vector>
+
+#include "kernels/constructors.hpp"
+#include "utils/output_validation.hpp"
+#include "utils/type_dispatch.hpp"
+#include "utils/type_utils.hpp"
+
+#include "zeros_ctor.hpp"
+
+namespace py = pybind11;
+namespace td_ns = dpctl::tensor::type_dispatch;
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace py_internal
+{
+
+using dpctl::utils::keep_args_alive;
+
+typedef sycl::event (*zeros_contig_fn_ptr_t)(sycl::queue &,
+                                             size_t,
+                                             char *,
+                                             const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to submit kernel to fill given contiguous memory allocation
+ * with zeros.
+ *
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
+ * @param nelems Length of the sequence
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
+ * populated.
+ * @param depends List of events to wait for before starting computations, if
+ * any.
+ *
+ * @return Event to wait on to ensure that computation completes.
+ * @defgroup CtorKernels
+ */
+template <typename dstTy>
+sycl::event zeros_contig_impl(sycl::queue &exec_q,
+                              size_t nelems,
+                              char *dst_p,
+                              const std::vector<sycl::event> &depends)
+{
+
+    constexpr int memset_val(0);
+    sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
+                   nelems * sizeof(dstTy));
+    });
+
+    return fill_ev;
+}
+
+template <typename fnT, typename Ty> struct ZerosContigFactory
+{
+    fnT get()
+    {
+        fnT f = zeros_contig_impl<Ty>;
+        return f;
+    }
+};
+
+static zeros_contig_fn_ptr_t zeros_contig_dispatch_vector[td_ns::num_types];
+
+std::pair<sycl::event, sycl::event>
+usm_ndarray_zeros(const dpctl::tensor::usm_ndarray &dst,
+                  sycl::queue &exec_q,
+                  const std::vector<sycl::event> &depends)
+{
+    py::ssize_t dst_nelems = dst.get_size();
+
+    if (dst_nelems == 0) {
+        // nothing to do
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with the allocation queue");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+
+    auto array_types = td_ns::usm_ndarray_types();
+    int dst_typenum = dst.get_typenum();
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    char *dst_data = dst.get_data();
+
+    if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
+        auto fn = zeros_contig_dispatch_vector[dst_typeid];
+
+        sycl::event zeros_contig_event =
+            fn(exec_q, static_cast<size_t>(dst_nelems), dst_data, depends);
+
+        return std::make_pair(
+            keep_args_alive(exec_q, {dst}, {zeros_contig_event}),
+            zeros_contig_event);
+    }
+    else {
+        throw std::runtime_error(
+            "Only population of contiguous usm_ndarray objects is supported.");
+    }
+}
+
+void init_zeros_ctor_dispatch_vectors(void)
+{
+    using namespace td_ns;
+
+    DispatchVectorBuilder<zeros_contig_fn_ptr_t, ZerosContigFactory, num_types>
+        dvb;
+    dvb.populate_dispatch_vector(zeros_contig_dispatch_vector);
+
+    return;
+}
+
+} // namespace py_internal
+} // namespace tensor
+} // namespace dpctl
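zeros_contig_impl above reduces the zero fill to a single byte-wise memset submitted on the execution queue, with the dispatch vector only supplying sizeof(dstTy). Below is a standalone SYCL sketch of that submission path, assuming a default device with USM device allocations; the program is illustrative and mirrors, rather than reproduces, the dpctl code.

#include <cstddef>
#include <iostream>
#include <sycl/sycl.hpp>
#include <vector>

int main()
{
    sycl::queue q{sycl::default_selector_v};

    constexpr std::size_t nelems = 1024;
    double *dst_p = sycl::malloc_device<double>(nelems, q);

    // Byte-wise zero fill; a freshly created allocation has no dependent
    // events, so depends_on is omitted here.
    sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
        cgh.memset(reinterpret_cast<void *>(dst_p), 0, nelems * sizeof(double));
    });

    // Copy back after the fill completes and spot-check the result.
    std::vector<double> host(nelems, -1.0);
    q.memcpy(host.data(), dst_p, nelems * sizeof(double), fill_ev).wait();
    std::cout << "host[0] = " << host[0] << ", host[" << nelems - 1
              << "] = " << host[nelems - 1] << "\n";

    sycl::free(dst_p, q);
    return 0;
}

Because every supported element type is zero-filled the same way, the per-type function produced by ZerosContigFactory differs only in the byte count it passes to memset.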
dpctl/tensor/libtensor/source/zeros_ctor.hpp (new file)

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===--------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_impl extensions
+//===--------------------------------------------------------------------===//
+
+#pragma once
+#include <sycl/sycl.hpp>
+#include <utility>
+#include <vector>
+
+#include "dpctl4pybind11.hpp"
+#include <pybind11/pybind11.h>
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace py_internal
+{
+
+extern std::pair<sycl::event, sycl::event>
+usm_ndarray_zeros(const dpctl::tensor::usm_ndarray &dst,
+                  sycl::queue &exec_q,
+                  const std::vector<sycl::event> &depends = {});
+
+extern void init_zeros_ctor_dispatch_vectors(void);
+
+} // namespace py_internal
+} // namespace tensor
+} // namespace dpctl
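The declaration above returns a pair of events: the second tracks the asynchronous device fill, while the first wraps the host task produced by keep_args_alive, which holds references to the Python arguments until the fill finishes. A plain-SYCL sketch of that two-event convention, with a shared_ptr standing in for Python reference counting (function and variable names are illustrative):

#include <cstddef>
#include <memory>
#include <sycl/sycl.hpp>
#include <utility>

// Returns {host_task_event, computation_event}: callers wait on the second to
// consume results, while the first guarantees `owner` outlives the async work.
std::pair<sycl::event, sycl::event>
zero_fill_with_keepalive(sycl::queue &q, double *usm_ptr, std::size_t n,
                         std::shared_ptr<void> owner)
{
    sycl::event comp_ev = q.memset(usm_ptr, 0, n * sizeof(double));

    sycl::event ht_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(comp_ev);
        // The lambda captures `owner` by value; the reference is released
        // only when the host task runs, i.e. after the fill completes.
        cgh.host_task([owner]() {});
    });

    return std::make_pair(ht_ev, comp_ev);
}

int main()
{
    sycl::queue q{sycl::default_selector_v};
    constexpr std::size_t n = 256;
    double *p = sycl::malloc_device<double>(n, q);

    auto owner = std::make_shared<int>(42); // stand-in for a Python object
    auto [ht_ev, comp_ev] = zero_fill_with_keepalive(q, p, n, owner);

    comp_ev.wait(); // data is ready
    ht_ev.wait();   // keep-alive host task has run
    sycl::free(p, q);
    return 0;
}

On the Python side this is the (hev, zeros_ev) pair that _ctors.py registers with SequentialOrderManager.add_event_pair in the change above.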

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 21 additions & 0 deletions
@@ -1682,6 +1682,27 @@ def test_full(dtype):
     assert np.array_equal(dpt.asnumpy(X), np.full(10, 4, dtype=dtype))


+def test_full_cmplx128():
+    q = get_queue_or_skip()
+    dtype = "c16"
+    skip_if_dtype_not_supported(dtype, q)
+    fill_v = 1 + 1j
+    X = dpt.full(tuple(), fill_value=fill_v, dtype=dtype, sycl_queue=q)
+    assert np.array_equal(
+        dpt.asnumpy(X), np.full(tuple(), fill_value=fill_v, dtype=dtype)
+    )
+    fill_v = 0 + 1j
+    X = dpt.full(tuple(), fill_value=fill_v, dtype=dtype, sycl_queue=q)
+    assert np.array_equal(
+        dpt.asnumpy(X), np.full(tuple(), fill_value=fill_v, dtype=dtype)
+    )
+    fill_v = 0 + 0j
+    X = dpt.full(tuple(), fill_value=fill_v, dtype=dtype, sycl_queue=q)
+    assert np.array_equal(
+        dpt.asnumpy(X), np.full(tuple(), fill_value=fill_v, dtype=dtype)
+    )
+
+
 def test_full_dtype_inference():
     try:
         X = dpt.full(10, 4)