From 9eca5bcc6e2f4d1373e45c449eccb5ad6b7aff48 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Tue, 14 Mar 2023 15:14:47 +0100 Subject: [PATCH 01/39] [SYCL][MARRAY] introduce sycl complex marray specialization --- .../ext/oneapi/experimental/sycl_complex.hpp | 424 +++++++++++++++++- 1 file changed, 421 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp b/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp index 0f8223286b852..a6c509b6d2d83 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp @@ -61,9 +61,7 @@ template struct __numeric_type { static const bool value = _IsNotSame::value; }; -template <> struct __numeric_type { - static const bool value = true; -}; +template <> struct __numeric_type { static const bool value = true; }; template ::value &&__numeric_type<_A2>::value @@ -983,6 +981,426 @@ tan(const complex<_Tp> &__x) { } // namespace oneapi } // namespace ext +template +class marray, NumElements> { +private: + using ComplexDataT = sycl::ext::oneapi::experimental::complex; + +public: + using value_type = ComplexDataT; + using reference = ComplexDataT &; + using const_reference = const ComplexDataT &; + using iterator = ComplexDataT *; + using const_iterator = const ComplexDataT *; + +private: + value_type MData[NumElements]; + +public: + constexpr marray() : MData{} {}; + + explicit constexpr marray(const ComplexDataT &arg) { + for (size_t i = 0; i < NumElements; ++i) { + MData[i] = arg; + } + } + + template + constexpr marray(const ArgTN &... args) : MData{args...} {}; + + constexpr marray(const marray &rhs) = default; + constexpr marray(marray &&rhs) = default; + + // Available only when: NumElements == 1 + template > + operator ComplexDataT() const { + return MData[0]; + } + + static constexpr std::size_t size() noexcept { return NumElements; } + + marray real() const { + sycl::marray rtn; + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = MData[i].real(); + } + return rtn; + } + + marray imag() const { + sycl::marray rtn; + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = MData[i].imag(); + } + return rtn; + } + + // subscript operator + reference operator[](std::size_t i) { return MData[i]; } + const_reference operator[](std::size_t i) const { return MData[i]; } + + marray &operator=(const marray &rhs) = default; + marray &operator=(const ComplexDataT &rhs) { + for (std::size_t i = 0; i < NumElements; ++i) { + MData[i] = rhs; + } + return *this; + } + + // iterator functions + iterator begin() { return MData; } + const_iterator begin() const { return MData; } + + iterator end() { return MData + NumElements; } + const_iterator end() const { return MData + NumElements; } + + // OP is: +, -, *, / +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs[i]; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const ComplexDataT &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs op rhs[i]; \ + } \ + return rtn; \ + } + + OP(+) + OP(-) + OP(*) + OP(/) + +#undef OP + + // OP is: % + friend marray operator%(const marray &lhs, const marray &rhs) = delete; + friend marray operator%(const marray &lhs, const ComplexDataT &rhs) = delete; + friend marray operator%(const ComplexDataT &lhs, const marray &rhs) = delete; + + // OP is: +=, -=, *=, /= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + lhs[i] op rhs[i]; \ + } \ + return lhs; \ + } \ + \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + lhs[i] op rhs; \ + } \ + return lhs; \ + } \ + friend marray &operator op(ComplexDataT &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + lhs[i] op rhs; \ + } \ + return lhs; \ + } + + OP(+=) + OP(-=) + OP(*=) + OP(/=) + +#undef OP + + // OP is: %= + friend marray &operator%=(marray &lhs, const marray &rhs) = delete; + friend marray &operator%=(marray &lhs, const ComplexDataT &rhs) = delete; + friend marray &operator%=(ComplexDataT &lhs, const marray &rhs) = delete; + +// OP is: ++, -- +#define OP(op) \ + friend marray operator op(marray &lhs, int) = delete; \ + friend marray &operator op(marray &rhs) = delete; + + OP(++) + OP(--) + +#undef OP + +// OP is: unary +, unary - +#define OP(op) \ + friend marray operator op( \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = op rhs[i]; \ + } \ + return rtn; \ + } + + OP(+) + OP(-) + +#undef OP + +// OP is: &, |, ^ +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) = \ + delete; + + OP(&) + OP(|) + OP(^) + +#undef OP + +// OP is: &=, |=, ^= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray &operator op(ComplexDataT &lhs, const marray &rhs) = delete; + + OP(&=) + OP(|=) + OP(^=) + +#undef OP + +// OP is: &&, || +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op( \ + const marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray operator op(const ComplexDataT &lhs, \ + const marray &rhs) = delete; + + OP(&&) + OP(||) + +#undef OP + +// OP is: <<, >> +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) = \ + delete; \ + friend marray operator op(const ComplexDataT &lhs, const marray &rhs) = \ + delete; + + OP(<<) + OP(>>) + +#undef OP + +// OP is: <<=, >>= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) = delete; + + OP(<<=) + OP(>>=) + +#undef OP + + // OP is: ==, != +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs[i]; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, \ + const ComplexDataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const ComplexDataT &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs op rhs[i]; \ + } \ + return rtn; \ + } + + OP(==) + OP(!=) + +#undef OP + + // OP is: <, >, <=, >= +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op( \ + const marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray operator op(const ComplexDataT &lhs, \ + const marray &rhs) = delete; + + OP(<); + OP(>); + OP(<=); + OP(>=); + +#undef OP + + friend marray operator~(const marray &v) = delete; + + friend marray operator!(const marray &v) = delete; +}; + +namespace ext { +namespace oneapi { +namespace experimental { + +// Math marray overloads + +#define MATH_OP_ONE_PARAM(math_func, rtn_type, arg_type) \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = sycl::ext::oneapi::experimental::math_func(x[i]); \ + } \ + return rtn; \ + } + +MATH_OP_ONE_PARAM(abs, T, complex); +MATH_OP_ONE_PARAM(acos, complex, complex); +MATH_OP_ONE_PARAM(asin, complex, complex); +MATH_OP_ONE_PARAM(atan, complex, complex); +MATH_OP_ONE_PARAM(acosh, complex, complex); +MATH_OP_ONE_PARAM(asinh, complex, complex); +MATH_OP_ONE_PARAM(atanh, complex, complex); +MATH_OP_ONE_PARAM(arg, T, complex); +MATH_OP_ONE_PARAM(conj, complex, complex); +MATH_OP_ONE_PARAM(cos, complex, complex); +MATH_OP_ONE_PARAM(cosh, complex, complex); +MATH_OP_ONE_PARAM(exp, complex, complex); +MATH_OP_ONE_PARAM(log, complex, complex); +MATH_OP_ONE_PARAM(log10, complex, complex); +MATH_OP_ONE_PARAM(norm, T, complex); +MATH_OP_ONE_PARAM(proj, complex, complex); +MATH_OP_ONE_PARAM(proj, complex, T); +MATH_OP_ONE_PARAM(sin, complex, complex); +MATH_OP_ONE_PARAM(sinh, complex, complex); +MATH_OP_ONE_PARAM(sqrt, complex, complex); +MATH_OP_ONE_PARAM(tan, complex, complex); +MATH_OP_ONE_PARAM(tanh, complex, complex); + +#undef MATH_OP_ONE_PARAM + +#define MATH_OP_TWO_PARAM(math_func, rtn_type, arg_type1, arg_type2) \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = sycl::ext::oneapi::experimental::math_func(x[i], y[i]); \ + } \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x, \ + const arg_type2 &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = sycl::ext::oneapi::experimental::math_func(x[i], y); \ + } \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const arg_type1 &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = math_func(x, y[i]); \ + } \ + return rtn; \ + } + +MATH_OP_TWO_PARAM(pow, complex, complex, T); +MATH_OP_TWO_PARAM(pow, complex, complex, complex); +MATH_OP_TWO_PARAM(pow, complex, T, complex); + +#undef MATH_OP_TWO_PARAM + +// Special definition as polar requires default argument + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, + const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = sycl::ext::oneapi::experimental::polar(rho[i], theta[i]); + } + return rtn; +} + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, const T &theta = 0) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = sycl::ext::oneapi::experimental::polar(rho[i], theta); + } + return rtn; +} + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const T &rho, const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = sycl::ext::oneapi::experimental::polar(rho, theta[i]); + } + return rtn; +} + +} // namespace experimental +} // namespace oneapi +} // namespace ext + } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl From c5ba5af6447f1019dda1e5223840ebf2a2080493 Mon Sep 17 00:00:00 2001 From: jle-quel Date: Tue, 14 Mar 2023 15:15:21 +0100 Subject: [PATCH 02/39] [SYCL][MARRAY] add marray specialization's tests --- sycl/test/extensions/test_marray_complex.cpp | 253 +++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 sycl/test/extensions/test_marray_complex.cpp diff --git a/sycl/test/extensions/test_marray_complex.cpp b/sycl/test/extensions/test_marray_complex.cpp new file mode 100644 index 0000000000000..0a4118bdd56f2 --- /dev/null +++ b/sycl/test/extensions/test_marray_complex.cpp @@ -0,0 +1,253 @@ +// RUN: %clangxx -fsycl -fsyntax-only %s + +#define SYCL_EXT_ONEAPI_COMPLEX + +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +// Helper to test each complex specilization +template