diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp index 3ff837beabd16..6d076fed8b87f 100644 --- a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp @@ -100,7 +100,7 @@ struct sub_group_mask { uint32_t mask = 0; if (pos.get(0) + insert_size < size()) mask |= (0xffffffff << (pos.get(0) + insert_size)); - if (pos.get(0) < size()) + if (pos.get(0) < size() && pos.get(0)) mask |= (0xffffffff >> (size() - pos.get(0))); Bits &= mask; Bits += insert_data; diff --git a/sycl/test/extensions/sub_group_mask.cpp b/sycl/test/extensions/sub_group_mask.cpp index cc2cdfa17c43c..1021c3a4a65a4 100644 --- a/sycl/test/extensions/sub_group_mask.cpp +++ b/sycl/test/extensions/sub_group_mask.cpp @@ -87,4 +87,26 @@ int main() { sycl::marray r3{-1}; b.extract_bits(r3, 14); assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]); + int ibits = 0b1010101010101010101010101010101; + b.insert_bits(ibits); + for (size_t i = 0; i < 32; i++) { + assert(b[i] != (bool)(i % 2)); + } + short sbits = 0b0111011101110111; + b.insert_bits(sbits, 7); + b.extract_bits(ibits); + assert(ibits == 0b1010101001110111011101111010101); + sbits = 0b1100001111000011; + b.insert_bits(sbits, 23); + b.extract_bits(ibits); + assert(ibits == 0b11100001101110111011101111010101); + int64_t lbits = -1; + b.extract_bits(lbits, 33); + assert(lbits == 0); + lbits = -1; + b.extract_bits(lbits, 5); + assert(lbits == 0b111000011011101110111011110); + lbits = -1; + b.insert_bits(lbits); + assert(b.all()); }