Skip to content

Commit 224455b

Browse files
ndgrigorianoleksandr-pavlyk
authored andcommitted
Implements logaddexp and hypot
1 parent 73a2b68 commit 224455b

File tree

7 files changed

+1140
-5
lines changed

7 files changed

+1140
-5
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
floor_divide,
106106
greater,
107107
greater_equal,
108+
hypot,
108109
imag,
109110
isfinite,
110111
isinf,
@@ -115,6 +116,7 @@
115116
log1p,
116117
log2,
117118
log10,
119+
logaddexp,
118120
logical_and,
119121
logical_not,
120122
logical_or,
@@ -222,6 +224,7 @@
222224
"floor_divide",
223225
"greater",
224226
"greater_equal",
227+
"hypot",
225228
"imag",
226229
"isfinite",
227230
"isinf",
@@ -241,6 +244,7 @@
241244
"not_equal",
242245
"positive",
243246
"pow",
247+
"logaddexp",
244248
"proj",
245249
"real",
246250
"sin",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,32 @@
661661
)
662662

663663
# B15: ==== LOGADDEXP (x1, x2)
664-
# FIXME: implement B15
664+
_logaddexp_docstring_ = """
665+
logaddexp(x1, x2, out=None, order='K')
666+
667+
Calculates the ratio for each element `x1_i` of the input array `x1` with
668+
the respective element `x2_i` of the input array `x2`.
669+
670+
Args:
671+
x1 (usm_ndarray):
672+
First input array, expected to have numeric data type.
673+
x2 (usm_ndarray):
674+
Second input array, also expected to have numeric data type.
675+
out ({None, usm_ndarray}, optional):
676+
Output array to populate.
677+
Array have the correct shape and the expected data type.
678+
order ("C","F","A","K", optional):
679+
Memory layout of the newly output array, if parameter `out` is `None`.
680+
Default: "K".
681+
Returns:
682+
usm_narray:
683+
An array containing the result of element-wise division. The data type
684+
of the returned array is determined by the Type Promotion Rules.
685+
"""
686+
687+
logaddexp = BinaryElementwiseFunc(
688+
"logaddexp", ti._logaddexp_result_type, ti._logaddexp, _logaddexp_docstring_
689+
)
665690

666691
# B16: ==== LOGICAL_AND (x1, x2)
667692
_logical_and_docstring_ = """
@@ -1094,12 +1119,40 @@
10941119
order ("C","F","A","K", optional):
10951120
Memory layout of the newly output array, if parameter `out` is `None`.
10961121
Default: "K".
1122+
Returns:
1123+
usm_narray:
1124+
An array containing the result of element-wise division. The data type
1125+
of the returned array is determined by the Type Promotion Rules.
1126+
"""
1127+
trunc = UnaryElementwiseFunc(
1128+
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
1129+
)
1130+
1131+
1132+
# B24: ==== HYPOT (x1, x2)
1133+
_hypot_docstring_ = """
1134+
hypot(x1, x2, out=None, order='K')
1135+
1136+
Calculates the ratio for each element `x1_i` of the input array `x1` with
1137+
the respective element `x2_i` of the input array `x2`.
1138+
1139+
Args:
1140+
x1 (usm_ndarray):
1141+
First input array, expected to have numeric data type.
1142+
x2 (usm_ndarray):
1143+
Second input array, also expected to have numeric data type.
1144+
out ({None, usm_ndarray}, optional):
1145+
Output array to populate.
1146+
Array have the correct shape and the expected data type.
1147+
order ("C","F","A","K", optional):
1148+
Memory layout of the newly output array, if parameter `out` is `None`.
1149+
Default: "K".
10971150
Returns:
10981151
usm_narray:
10991152
An array containing the element-wise truncated value of input array.
11001153
The returned array has the same data type as `x`.
11011154
"""
11021155

1103-
trunc = UnaryElementwiseFunc(
1104-
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
1156+
hypot = BinaryElementwiseFunc(
1157+
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
11051158
)
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
//=== HYPOT.hpp - Binary function HYPOT ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of HYPOT(x1, x2)
23+
/// function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace hypot
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct HypotFunctor
53+
{
54+
55+
using supports_sg_loadstore = std::negation<
56+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
57+
using supports_vec = std::negation<
58+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59+
60+
resT operator()(const argT1 &in1, const argT2 &in2)
61+
{
62+
return std::hypot(in1, in2);
63+
}
64+
65+
template <int vec_sz>
66+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
67+
const sycl::vec<argT2, vec_sz> &in2)
68+
{
69+
auto res = sycl::hypot(in1, in2);
70+
if constexpr (std::is_same_v<resT,
71+
typename decltype(res)::element_type>) {
72+
return res;
73+
}
74+
else {
75+
using dpctl::tensor::type_utils::vec_cast;
76+
77+
return vec_cast<resT, typename decltype(res)::element_type, vec_sz>(
78+
res);
79+
}
80+
}
81+
};
82+
83+
template <typename argT1,
84+
typename argT2,
85+
typename resT,
86+
unsigned int vec_sz = 4,
87+
unsigned int n_vecs = 2>
88+
using HypotContigFunctor =
89+
elementwise_common::BinaryContigFunctor<argT1,
90+
argT2,
91+
resT,
92+
HypotFunctor<argT1, argT2, resT>,
93+
vec_sz,
94+
n_vecs>;
95+
96+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
97+
using HypotStridedFunctor =
98+
elementwise_common::BinaryStridedFunctor<argT1,
99+
argT2,
100+
resT,
101+
IndexerT,
102+
HypotFunctor<argT1, argT2, resT>>;
103+
104+
template <typename T1, typename T2> struct HypotOutputType
105+
{
106+
using value_type = typename std::disjunction< // disjunction is C++17
107+
// feature, supported by DPC++
108+
td_ns::BinaryTypeMapResultEntry<T1,
109+
sycl::half,
110+
T2,
111+
sycl::half,
112+
sycl::half>,
113+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
114+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
115+
td_ns::DefaultResultEntry<void>>::result_type;
116+
};
117+
118+
template <typename argT1,
119+
typename argT2,
120+
typename resT,
121+
unsigned int vec_sz,
122+
unsigned int n_vecs>
123+
class hypot_contig_kernel;
124+
125+
template <typename argTy1, typename argTy2>
126+
sycl::event hypot_contig_impl(sycl::queue exec_q,
127+
size_t nelems,
128+
const char *arg1_p,
129+
py::ssize_t arg1_offset,
130+
const char *arg2_p,
131+
py::ssize_t arg2_offset,
132+
char *res_p,
133+
py::ssize_t res_offset,
134+
const std::vector<sycl::event> &depends = {})
135+
{
136+
return elementwise_common::binary_contig_impl<
137+
argTy1, argTy2, HypotOutputType, HypotContigFunctor,
138+
hypot_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
139+
arg2_offset, res_p, res_offset, depends);
140+
}
141+
142+
template <typename fnT, typename T1, typename T2> struct HypotContigFactory
143+
{
144+
fnT get()
145+
{
146+
if constexpr (std::is_same_v<
147+
typename HypotOutputType<T1, T2>::value_type, void>)
148+
{
149+
fnT fn = nullptr;
150+
return fn;
151+
}
152+
else {
153+
fnT fn = hypot_contig_impl<T1, T2>;
154+
return fn;
155+
}
156+
}
157+
};
158+
159+
template <typename fnT, typename T1, typename T2> struct HypotTypeMapFactory
160+
{
161+
/*! @brief get typeid for output type of std::hypot(T1 x, T2 y) */
162+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
163+
{
164+
using rT = typename HypotOutputType<T1, T2>::value_type;
165+
;
166+
return td_ns::GetTypeid<rT>{}.get();
167+
}
168+
};
169+
170+
template <typename T1, typename T2, typename resT, typename IndexerT>
171+
class hypot_strided_strided_kernel;
172+
173+
template <typename argTy1, typename argTy2>
174+
sycl::event
175+
hypot_strided_impl(sycl::queue exec_q,
176+
size_t nelems,
177+
int nd,
178+
const py::ssize_t *shape_and_strides,
179+
const char *arg1_p,
180+
py::ssize_t arg1_offset,
181+
const char *arg2_p,
182+
py::ssize_t arg2_offset,
183+
char *res_p,
184+
py::ssize_t res_offset,
185+
const std::vector<sycl::event> &depends,
186+
const std::vector<sycl::event> &additional_depends)
187+
{
188+
return elementwise_common::binary_strided_impl<
189+
argTy1, argTy2, HypotOutputType, HypotStridedFunctor,
190+
hypot_strided_strided_kernel>(
191+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
192+
arg2_offset, res_p, res_offset, depends, additional_depends);
193+
}
194+
195+
template <typename fnT, typename T1, typename T2> struct HypotStridedFactory
196+
{
197+
fnT get()
198+
{
199+
if constexpr (std::is_same_v<
200+
typename HypotOutputType<T1, T2>::value_type, void>)
201+
{
202+
fnT fn = nullptr;
203+
return fn;
204+
}
205+
else {
206+
fnT fn = hypot_strided_impl<T1, T2>;
207+
return fn;
208+
}
209+
}
210+
};
211+
212+
} // namespace hypot
213+
} // namespace kernels
214+
} // namespace tensor
215+
} // namespace dpctl

0 commit comments

Comments
 (0)