Skip to content

Commit 37c7b2a

Browse files
committed
address reviewer's comments
1 parent a53090a commit 37c7b2a

File tree

2 files changed

+23
-33
lines changed

2 files changed

+23
-33
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,8 @@
959959
_round_docstring = """
960960
round(x, out=None, order='K')
961961
962-
Computes cosine for each element `x_i` for input array `x`.
962+
Rounds each element `x_i` of the input array `x` to
963+
the nearest integer-valued number.
963964
964965
Args:
965966
x (usm_ndarray):

dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -65,46 +65,35 @@ template <typename argT, typename resT> struct RoundFunctor
6565

6666
resT operator()(const argT &in)
6767
{
68+
6869
if constexpr (std::is_integral_v<argT>) {
6970
return in;
7071
}
7172
else if constexpr (is_complex<argT>::value) {
7273
using realT = typename argT::value_type;
74+
return resT{round_func<realT>(std::real(in)),
75+
round_func<realT>(std::imag(in))};
76+
}
77+
else {
78+
return round_func<argT>(in);
79+
}
80+
}
7381

74-
const realT x = std::real(in);
75-
const realT y = std::imag(in);
76-
realT x_round, y_round;
77-
if (std::abs(x - std::floor(x)) == std::abs(x - std::ceil(x))) {
78-
x_round = static_cast<int>(std::ceil(x)) % 2 == 0
79-
? std::ceil(x)
80-
: std::floor(x);
81-
}
82-
else {
83-
x_round = std::round(x);
84-
}
85-
if (std::abs(y - std::floor(y)) == std::abs(y - std::ceil(y))) {
86-
y_round = static_cast<int>(std::ceil(y)) % 2 == 0
87-
? std::ceil(y)
88-
: std::floor(y);
89-
}
90-
else {
91-
y_round = std::round(y);
92-
}
93-
return resT{x_round, y_round};
82+
private:
83+
template <typename T> T round_func(T input) const
84+
{
85+
if (input == 0) {
86+
return input;
87+
}
88+
else if (std::abs(input - std::floor(input)) ==
89+
std::abs(input - std::ceil(input)))
90+
{
91+
return static_cast<int>(std::ceil(input)) % 2 == 0
92+
? std::ceil(input)
93+
: std::floor(input);
9494
}
9595
else {
96-
if (in == 0) {
97-
return in;
98-
}
99-
else if (std::abs(in - std::floor(in)) ==
100-
std::abs(in - std::ceil(in))) {
101-
return static_cast<int>(std::ceil(in)) % 2 == 0
102-
? std::ceil(in)
103-
: std::floor(in);
104-
}
105-
else {
106-
return std::round(in);
107-
}
96+
return std::round(input);
10897
}
10998
}
11099
};

0 commit comments

Comments
 (0)