Skip to content

Commit 429c368

Browse files
committed
Implements in-place addition
1 parent 69d420e commit 429c368

File tree

8 files changed

+794
-8
lines changed

8 files changed

+794
-8
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_empty_like_pair_orderK,
3232
_find_buf_dtype,
3333
_find_buf_dtype2,
34+
_find_inplace_dtype,
3435
_to_device_supported_dtype,
3536
)
3637

@@ -331,11 +332,19 @@ class BinaryElementwiseFunc:
331332
Class that implements binary element-wise functions.
332333
"""
333334

334-
def __init__(self, name, result_type_resolver_fn, binary_dp_impl_fn, docs):
335+
def __init__(
336+
self,
337+
name,
338+
result_type_resolver_fn,
339+
binary_dp_impl_fn,
340+
docs,
341+
binary_inplace_fn=None,
342+
):
335343
self.__name__ = "BinaryElementwiseFunc"
336344
self.name_ = name
337345
self.result_type_resolver_fn_ = result_type_resolver_fn
338346
self.binary_fn_ = binary_dp_impl_fn
347+
self.binary_inplace_fn_ = binary_inplace_fn
339348
self.__doc__ = docs
340349

341350
def __str__(self):
@@ -631,3 +640,113 @@ def __call__(self, o1, o2, out=None, order="K"):
631640
)
632641
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
633642
return out
643+
644+
def inplace(self, lhs, val):
645+
if self.binary_inplace_fn_ is None:
646+
raise ValueError(
647+
f"In-place operation not supported for ufunc '{self.name_}'"
648+
)
649+
if not isinstance(lhs, dpt.usm_ndarray):
650+
raise TypeError(
651+
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
652+
)
653+
q1, lhs_usm_type = _get_queue_usm_type(lhs)
654+
q2, val_usm_type = _get_queue_usm_type(val)
655+
if q2 is None:
656+
exec_q = q1
657+
usm_type = lhs_usm_type
658+
else:
659+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
660+
if exec_q is None:
661+
raise ExecutionPlacementError(
662+
"Execution placement can not be unambiguously inferred "
663+
"from input arguments."
664+
)
665+
usm_type = dpctl.utils.get_coerced_usm_type(
666+
(
667+
lhs_usm_type,
668+
val_usm_type,
669+
)
670+
)
671+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
672+
lhs_shape = _get_shape(lhs)
673+
val_shape = _get_shape(val)
674+
if not all(
675+
isinstance(s, (tuple, list))
676+
for s in (
677+
lhs_shape,
678+
val_shape,
679+
)
680+
):
681+
raise TypeError(
682+
"Shape of arguments can not be inferred. "
683+
"Arguments are expected to be "
684+
)
685+
try:
686+
res_shape = _broadcast_shape_impl(
687+
[
688+
lhs_shape,
689+
val_shape,
690+
]
691+
)
692+
except ValueError:
693+
raise ValueError(
694+
"operands could not be broadcast together with shapes "
695+
f"{lhs_shape} and {val_shape}"
696+
)
697+
if res_shape != lhs_shape:
698+
raise ValueError(
699+
f"output shape {lhs_shape} does not match "
700+
f"broadcast shape {res_shape}"
701+
)
702+
sycl_dev = exec_q.sycl_device
703+
lhs_dtype = lhs.dtype
704+
val_dtype = _get_dtype(val, sycl_dev)
705+
if not _validate_dtype(val_dtype):
706+
raise ValueError("Input operand of unsupported type")
707+
708+
lhs_dtype, val_dtype = _resolve_weak_types(
709+
lhs_dtype, val_dtype, sycl_dev
710+
)
711+
712+
buf_dt = _find_inplace_dtype(
713+
lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
714+
)
715+
716+
if buf_dt is None:
717+
raise TypeError(
718+
f"function '{self.name_}' does not support input types "
719+
f"({lhs_dtype}, {val_dtype}), "
720+
"and the inputs could not be safely coerced to any "
721+
"supported types according to the casting rule ''safe''."
722+
)
723+
724+
if isinstance(val, dpt.usm_ndarray):
725+
rhs = val
726+
else:
727+
rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
728+
729+
if buf_dt == val_dtype:
730+
rhs = dpt.broadcast_to(rhs, res_shape)
731+
ht_, _ = self.binary_inplace_fn_(
732+
lhs=lhs, rhs=rhs, sycl_queue=exec_q
733+
)
734+
ht_.wait()
735+
736+
else:
737+
buf = dpt.empty_like(rhs, dtype=buf_dt)
738+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
739+
src=rhs, dst=buf, sycl_queue=exec_q
740+
)
741+
742+
buf = dpt.broadcast_to(buf, res_shape)
743+
ht_, _ = self.binary_inplace_fn_(
744+
lhs=lhs,
745+
rhs=buf,
746+
sycl_queue=exec_q,
747+
depends=[copy_ev],
748+
)
749+
ht_copy_ev.wait()
750+
ht_.wait()
751+
752+
return lhs

dpctl/tensor/_elementwise_funcs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@
7373
returned array is determined by the Type Promotion Rules.
7474
"""
7575
add = BinaryElementwiseFunc(
76-
"add", ti._add_result_type, ti._add, _add_docstring_
76+
"add",
77+
ti._add_result_type,
78+
ti._add,
79+
_add_docstring_,
80+
binary_inplace_fn=ti._add_inplace,
7781
)
7882

7983
# U04: ===== ASIN (x)

dpctl/tensor/_type_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,27 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
294294
return None, None, None
295295

296296

297+
def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
298+
res_dt = query_fn(lhs_dtype, rhs_dtype)
299+
if res_dt and res_dt == lhs_dtype:
300+
return rhs_dtype
301+
302+
_fp16 = sycl_dev.has_aspect_fp16
303+
_fp64 = sycl_dev.has_aspect_fp64
304+
all_dts = _all_data_types(_fp16, _fp64)
305+
for buf_dt in all_dts:
306+
if _can_cast(rhs_dtype, buf_dt, _fp16, _fp64):
307+
res_dt = query_fn(lhs_dtype, buf_dt)
308+
if res_dt and res_dt == lhs_dtype:
309+
return buf_dt
310+
311+
return None
312+
313+
297314
__all__ = [
298315
"_find_buf_dtype",
299316
"_find_buf_dtype2",
317+
"_find_inplace_dtype",
300318
"_empty_like_orderK",
301319
"_empty_like_pair_orderK",
302320
"_to_device_supported_dtype",

dpctl/tensor/_usmarray.pyx

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,11 +1245,8 @@ cdef class usm_ndarray:
12451245
return _dispatch_binary_elementwise2(other, "logical_xor", self)
12461246

12471247
def __iadd__(self, other):
1248-
res = self.__add__(other)
1249-
if res is NotImplemented:
1250-
return res
1251-
self.__setitem__(Ellipsis, res)
1252-
return self
1248+
from ._elementwise_funcs import add
1249+
return add.inplace(self, other)
12531250

12541251
def __iand__(self, other):
12551252
res = self.__and__(other)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "utils/type_utils.hpp"
3535

3636
#include "kernels/elementwise_functions/common.hpp"
37+
#include "kernels/elementwise_functions/common_inplace.hpp"
3738
#include <pybind11/pybind11.h>
3839

3940
namespace dpctl
@@ -212,7 +213,6 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
212213
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
213214
{
214215
using rT = typename AddOutputType<T1, T2>::value_type;
215-
;
216216
return td_ns::GetTypeid<rT>{}.get();
217217
}
218218
};
@@ -364,6 +364,122 @@ struct AddContigRowContigMatrixBroadcastFactory
364364
}
365365
};
366366

367+
template <typename argT, typename resT> struct AddInplaceFunctor
368+
{
369+
370+
using supports_sg_loadstore = std::negation<
371+
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
372+
using supports_vec = std::negation<
373+
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
374+
375+
void operator()(resT &res, const argT &in)
376+
{
377+
res += in;
378+
}
379+
380+
template <int vec_sz>
381+
void operator()(sycl::vec<resT, vec_sz> &res,
382+
const sycl::vec<argT, vec_sz> &in)
383+
{
384+
res += in;
385+
}
386+
};
387+
388+
template <typename argT,
389+
typename resT,
390+
unsigned int vec_sz = 4,
391+
unsigned int n_vecs = 2>
392+
using AddInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor<
393+
argT,
394+
resT,
395+
AddInplaceFunctor<argT, resT>,
396+
vec_sz,
397+
n_vecs>;
398+
399+
template <typename argT, typename resT, typename IndexerT>
400+
using AddInplaceStridedFunctor =
401+
elementwise_common::BinaryInplaceStridedFunctor<
402+
argT,
403+
resT,
404+
IndexerT,
405+
AddInplaceFunctor<argT, resT>>;
406+
407+
template <typename argT,
408+
typename resT,
409+
unsigned int vec_sz,
410+
unsigned int n_vecs>
411+
class add_inplace_contig_kernel;
412+
413+
template <typename argTy, typename resTy>
414+
sycl::event
415+
add_inplace_contig_impl(sycl::queue exec_q,
416+
size_t nelems,
417+
const char *arg_p,
418+
py::ssize_t arg_offset,
419+
char *res_p,
420+
py::ssize_t res_offset,
421+
const std::vector<sycl::event> &depends = {})
422+
{
423+
return elementwise_common::binary_inplace_contig_impl<
424+
argTy, resTy, AddInplaceContigFunctor, add_inplace_contig_kernel>(
425+
exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends);
426+
}
427+
428+
template <typename fnT, typename T1, typename T2> struct AddInplaceContigFactory
429+
{
430+
fnT get()
431+
{
432+
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
433+
void>) {
434+
fnT fn = nullptr;
435+
return fn;
436+
}
437+
else {
438+
fnT fn = add_inplace_contig_impl<T1, T2>;
439+
return fn;
440+
}
441+
}
442+
};
443+
444+
template <typename resT, typename argT, typename IndexerT>
445+
class add_inplace_strided_kernel;
446+
447+
template <typename argTy, typename resTy>
448+
sycl::event
449+
add_inplace_strided_impl(sycl::queue exec_q,
450+
size_t nelems,
451+
int nd,
452+
const py::ssize_t *shape_and_strides,
453+
const char *arg_p,
454+
py::ssize_t arg_offset,
455+
char *res_p,
456+
py::ssize_t res_offset,
457+
const std::vector<sycl::event> &depends,
458+
const std::vector<sycl::event> &additional_depends)
459+
{
460+
return elementwise_common::binary_inplace_strided_impl<
461+
argTy, resTy, AddInplaceStridedFunctor, add_inplace_strided_kernel>(
462+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
463+
res_offset, depends, additional_depends);
464+
}
465+
466+
template <typename fnT, typename T1, typename T2>
467+
struct AddInplaceStridedFactory
468+
{
469+
fnT get()
470+
{
471+
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
472+
void>) {
473+
fnT fn = nullptr;
474+
return fn;
475+
}
476+
else {
477+
fnT fn = add_inplace_strided_impl<T1, T2>;
478+
return fn;
479+
}
480+
}
481+
};
482+
367483
} // namespace add
368484
} // namespace kernels
369485
} // namespace tensor

0 commit comments

Comments
 (0)