Skip to content

Commit 28993dd

Browse files
Use oneapi extension for complexes for remaining elementwise functions
Used functions from sycl::ext::oneapi::experimental context to implement evaluation on data of complex type.
1 parent 41d940e commit 28993dd

File tree

21 files changed

+105
-33
lines changed

21 files changed

+105
-33
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/math_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/math_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
30+
#include <sycl/sycl.hpp>
3031
#include <type_traits>
3132

3233
#include "utils/offset_utils.hpp"
@@ -49,6 +50,7 @@ namespace multiply
4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
5152
namespace tu_ns = dpctl::tensor::type_utils;
53+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5254

5355
template <typename argT1, typename argT2, typename resT> struct MultiplyFunctor
5456
{
@@ -60,7 +62,18 @@ template <typename argT1, typename argT2, typename resT> struct MultiplyFunctor
6062

6163
resT operator()(const argT1 &in1, const argT2 &in2) const
6264
{
63-
return in1 * in2;
65+
if constexpr (tu_ns::is_complex<argT1>::value &&
66+
tu_ns::is_complex<argT2>::value)
67+
{
68+
using realT1 = typename argT1::value_type;
69+
using realT2 = typename argT2::value_type;
70+
71+
return exprm_ns::complex<realT1>(in1) *
72+
exprm_ns::complex<realT2>(in2);
73+
}
74+
else {
75+
return in1 * in2;
76+
}
6477
}
6578

6679
template <int vec_sz>

dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/offset_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
3029
#include <limits>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3132
#include <type_traits>
3233

3334
#include "utils/offset_utils.hpp"
@@ -49,6 +50,7 @@ namespace pow
4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
5152
namespace tu_ns = dpctl::tensor::type_utils;
53+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5254

5355
template <typename argT1, typename argT2, typename resT> struct PowFunctor
5456
{
@@ -83,6 +85,15 @@ template <typename argT1, typename argT2, typename resT> struct PowFunctor
8385
}
8486
return res;
8587
}
88+
else if constexpr (tu_ns::is_complex<argT1>::value &&
89+
tu_ns::is_complex<argT2>::value)
90+
{
91+
using realT1 = typename argT1::value_type;
92+
using realT2 = typename argT2::value_type;
93+
94+
return exprm_ns::pow(exprm_ns::complex<realT1>(in1),
95+
exprm_ns::complex<realT2>(in2));
96+
}
8697
else {
8798
return std::pow(in1, in2);
8899
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <complex>
3029
#include <cstddef>
3130
#include <cstdint>
3231
#include <limits>
32+
#include <sycl/sycl.hpp>
3333
#include <type_traits>
3434

3535
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <complex>
3029
#include <cstddef>
3130
#include <cstdint>
31+
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

3434
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

0 commit comments

Comments
 (0)