Skip to content

Commit da954f9

Browse files
Added special case for _full_usm_ndarray
Bitwise zero values, and 1-byte wide types now use memset, instead of using fill. ``` In [1]: import dpctl.tensor as dpt, dpctl.tensor._tensor_impl as ti In [2]: res = dpt.empty(10**6, dtype="i8") In [3]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait() 243 µs ± 22.6 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [4]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait() 229 µs ± 14 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [5]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait() 227 µs ± 23 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [6]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait() 233 µs ± 25.9 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [7]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait() 301 µs ± 54.1 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [8]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait() 236 µs ± 17.2 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [9]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait() 240 µs ± 35.2 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [10]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(1, dst=res, sycl_queue=res.sycl_queue)[0].wait() 243 µs ± 17.6 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [11]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(1, dst=res, sycl_queue=res.sycl_queue)[0].wait() 263 µs ± 39.9 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [12]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait() 239 µs ± 26.4 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) In [13]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait() 224 µs ± 18.1 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each) ```
1 parent bec95f9 commit da954f9

File tree

1 file changed

+58
-4
lines changed

1 file changed

+58
-4
lines changed

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,65 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
8080
{
8181
dstTy fill_v = py::cast<dstTy>(py_value);
8282

83-
using dpctl::tensor::kernels::constructors::full_contig_impl;
83+
sycl::event fill_ev;
8484

85-
sycl::event fill_ev =
86-
full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
85+
if constexpr (sizeof(dstTy) == sizeof(char)) {
86+
const auto memset_val = sycl::bit_cast<unsigned char>(fill_v);
87+
fill_ev = exec_q.submit([&](sycl::handler &cgh) {
88+
cgh.depends_on(depends);
89+
90+
cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
91+
nelems * sizeof(dstTy));
92+
});
93+
}
94+
else {
95+
bool is_zero = false;
96+
if constexpr (sizeof(dstTy) == 1) {
97+
is_zero = (std::uint8_t{0} == sycl::bit_cast<std::uint8_t>(fill_v));
98+
}
99+
else if constexpr (sizeof(dstTy) == 2) {
100+
is_zero =
101+
(std::uint16_t{0} == sycl::bit_cast<std::uint16_t>(fill_v));
102+
}
103+
else if constexpr (sizeof(dstTy) == 4) {
104+
is_zero =
105+
(std::uint32_t{0} == sycl::bit_cast<std::uint32_t>(fill_v));
106+
}
107+
else if constexpr (sizeof(dstTy) == 8) {
108+
is_zero =
109+
(std::uint64_t{0} == sycl::bit_cast<std::uint64_t>(fill_v));
110+
}
111+
else if constexpr (sizeof(dstTy) == 16) {
112+
struct UInt128
113+
{
114+
115+
constexpr UInt128() : v1{}, v2{} {}
116+
UInt128(const UInt128 &) = default;
117+
118+
operator bool() const { return bool(v1) && bool(v2); }
119+
120+
std::uint64_t v1;
121+
std::uint64_t v2;
122+
};
123+
is_zero = static_cast<bool>(sycl::bit_cast<UInt128>(fill_v));
124+
}
125+
126+
if (is_zero) {
127+
constexpr int memset_val = 0;
128+
fill_ev = exec_q.submit([&](sycl::handler &cgh) {
129+
cgh.depends_on(depends);
130+
131+
cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
132+
nelems * sizeof(dstTy));
133+
});
134+
}
135+
else {
136+
using dpctl::tensor::kernels::constructors::full_contig_impl;
137+
138+
fill_ev =
139+
full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
140+
}
141+
}
87142

88143
return fill_ev;
89144
}
@@ -126,7 +181,6 @@ usm_ndarray_full(const py::object &py_value,
126181
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
127182

128183
char *dst_data = dst.get_data();
129-
sycl::event full_event;
130184

131185
if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
132186
auto fn = full_contig_dispatch_vector[dst_typeid];

0 commit comments

Comments
 (0)