Skip to content

Commit cee2e0a

Browse files
Add support float and complex<float> for not_equal
1 parent b5c9438 commit cee2e0a

File tree

1 file changed

+17
-1
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+17
-1
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,19 @@ template <typename argT1, typename argT2, typename resT> struct NotEqualFunctor
6262

6363
resT operator()(const argT1 &in1, const argT2 &in2)
6464
{
65-
return (in1 != in2);
65+
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
66+
std::is_same_v<argT2, float>)
67+
{
68+
return (std::real(in1) != in2 || std::imag(in1) != 0.0f);
69+
}
70+
else if constexpr (std::is_same_v<argT1, float> &&
71+
std::is_same_v<argT2, std::complex<float>>)
72+
{
73+
return (in1 != std::real(in2) || std::imag(in2) != 0.0f);
74+
}
75+
else {
76+
return (in1 != in2);
77+
}
6678
}
6779

6880
template <int vec_sz>
@@ -146,6 +158,10 @@ template <typename T1, typename T2> struct NotEqualOutputType
146158
T2,
147159
std::complex<double>,
148160
bool>,
161+
td_ns::
162+
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
163+
td_ns::
164+
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
149165
td_ns::DefaultResultEntry<void>>::result_type;
150166
};
151167

0 commit comments

Comments
 (0)