Skip to content

Commit 8cc570a

Browse files
committed
The `where` kernel for contiguous data now uses vec_cast to convert the loaded condition vector to a vector of bool
1 parent c59fedb commit 8cc570a

File tree

1 file changed

+19
-7
lines changed
  • dpctl/tensor/libtensor/include/kernels

1 file changed

+19
-7
lines changed

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ class WhereContigFunctor
8080
T *dst_data = reinterpret_cast<T *>(dst_cp);
8181
const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
8282

83-
using dpctl::tensor::type_utils::convert_impl;
84-
8583
using dpctl::tensor::type_utils::is_complex;
8684
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
8785
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -92,6 +90,7 @@ class WhereContigFunctor
9290
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
9391
offset += sgSize)
9492
{
93+
using dpctl::tensor::type_utils::convert_impl;
9594
bool check = convert_impl<bool, condT>(cond_data[offset]);
9695
dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
9796
}
@@ -115,7 +114,6 @@ class WhereContigFunctor
115114
using cond_ptrT =
116115
sycl::multi_ptr<const condT,
117116
sycl::access::address_space::global_space>;
118-
119117
sycl::vec<T, vec_sz> dst_vec;
120118
sycl::vec<T, vec_sz> x1_vec;
121119
sycl::vec<T, vec_sz> x2_vec;
@@ -127,18 +125,32 @@ class WhereContigFunctor
127125
x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
128126
x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
129127
cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
130-
128+
if constexpr (std::is_same_v<bool, condT>) {
129+
#pragma unroll
130+
for (std::uint8_t k = 0; k < vec_sz; ++k) {
131+
bool check = cond_vec[k];
132+
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
133+
}
134+
}
135+
else {
136+
using dpctl::tensor::type_utils::vec_cast;
137+
sycl::vec<bool, vec_sz> tmp =
138+
vec_cast<bool,
139+
typename decltype(cond_vec)::element_type,
140+
vec_sz>(cond_vec);
131141
#pragma unroll
132-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
133-
bool check = convert_impl<bool, condT>(cond_vec[k]);
134-
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
142+
for (std::uint8_t k = 0; k < vec_sz; ++k) {
143+
bool check = tmp[k];
144+
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
145+
}
135146
}
136147
sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
137148
}
138149
}
139150
else {
140151
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
141152
k += sgSize) {
153+
using dpctl::tensor::type_utils::convert_impl;
142154
bool check = convert_impl<bool, condT>(cond_data[k]);
143155
dst_data[k] = check ? x1_data[k] : x2_data[k];
144156
}

0 commit comments

Comments
 (0)