Skip to content

Commit e8ac887

Browse files
Fznamznonvladimirlaz
authored andcommitted
[SYCL] Implement broadcasting vec::operator=
Signed-off-by: Mariya Podchishchaeva <[email protected]>
1 parent 3e11f59 commit e8ac887

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

sycl/include/CL/sycl/types.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,34 @@ template <typename Type, int NumElements> class vec {
372372
return *this;
373373
}
374374

375+
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
376+
explicit vec(const DataT &arg) {
377+
m_Data = (DataType)arg;
378+
}
379+
380+
template <typename Ty = DataT>
381+
typename std::enable_if<std::is_fundamental<Ty>::value, vec &>::type
382+
operator=(const DataT &Rhs) {
383+
m_Data = (DataType)Rhs;
384+
return *this;
385+
}
386+
#else
375387
explicit vec(const DataT &arg) {
376388
for (int i = 0; i < NumElements; ++i) {
377389
setValue(i, arg);
378390
}
379391
}
380392

393+
template <typename Ty = DataT>
394+
typename std::enable_if<std::is_fundamental<Ty>::value, vec &>::type
395+
operator=(const DataT &Rhs) {
396+
for (int i = 0; i < NumElements; ++i) {
397+
setValue(i, Rhs);
398+
}
399+
return *this;
400+
}
401+
#endif
402+
381403
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
382404
// Optimized naive constructors with NumElements of DataT values.
383405
// We don't expect compilers to optimize vararg recursive functions well.

sycl/test/basic_tests/vectors.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,13 @@ int main() {
5050
int64_t(vec_2.x());
5151
cl::sycl::int4(vec_2.x());
5252

53+
// Check broadcasting operator=
54+
cl::sycl::vec<float, 4> b_vec(1.0);
55+
b_vec = 0.5;
56+
assert(static_cast<float>(b_vec.x()) == static_cast<float>(0.5));
57+
assert(static_cast<float>(b_vec.y()) == static_cast<float>(0.5));
58+
assert(static_cast<float>(b_vec.z()) == static_cast<float>(0.5));
59+
assert(static_cast<float>(b_vec.w()) == static_cast<float>(0.5));
60+
5361
return 0;
5462
}

0 commit comments

Comments
 (0)