Skip to content

Commit 0dc80bd

Browse files
author
Andrew Savonichev
committed
[SYCL] Return __constant pointers from 'constant' multi_ptr
We cannot return a generic pointer like in 'global', 'local' and 'private' cases, because 'constant' and 'generic' address spaces do not overlap. This patch returns __constant pointer as is, and assumes that a user have to deal with __constant pointer limitations (ie. such pointers cannot be casted to plain (default) pointers). Signed-off-by: Andrew Savonichev <[email protected]>
1 parent 9be9768 commit 0dc80bd

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

sycl/include/CL/sycl/multi_ptr.hpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,37 @@ template <typename ElementType, access::address_space Space> class multi_ptr {
6969
m_Pointer = nullptr;
7070
return *this;
7171
}
72-
ElementType &operator*() const {
73-
return *(reinterpret_cast<ElementType *>(m_Pointer));
72+
73+
#ifdef __SYCL_ENABLE_INFER_AS__
74+
using ReturnPtr =
75+
typename std::conditional<Space == access::address_space::constant_space,
76+
pointer_t, ElementType *>::type;
77+
using ReturnRef =
78+
typename std::conditional<Space == access::address_space::constant_space,
79+
reference_t, ElementType &>::type;
80+
using ReturnConstRef =
81+
typename std::conditional<Space == access::address_space::constant_space,
82+
const_reference_t, const ElementType &>::type;
83+
#else
84+
using ReturnPtr = ElementType *;
85+
using ReturnRef = ElementType &;
86+
using ReturnConstRef = const ElementType &;
87+
#endif
88+
89+
ReturnRef operator*() const {
90+
return *reinterpret_cast<ReturnPtr>(m_Pointer);
7491
}
75-
ElementType *operator->() const {
76-
return reinterpret_cast<ElementType *>(m_Pointer);
92+
93+
ReturnPtr operator->() const {
94+
return reinterpret_cast<ReturnPtr>(m_Pointer);
7795
}
78-
ElementType &operator[](difference_type index) {
79-
return *(reinterpret_cast<ElementType *>(m_Pointer + index));
96+
97+
ReturnRef operator[](difference_type index) {
98+
return reinterpret_cast<ReturnPtr>(m_Pointer)[index];
8099
}
81-
ElementType operator[](difference_type index) const {
82-
return *(reinterpret_cast<ElementType *>(m_Pointer + index));
100+
101+
ReturnConstRef operator[](difference_type index) const {
102+
return reinterpret_cast<ReturnPtr>(m_Pointer)[index];
83103
}
84104

85105
// Only if Space == global_space
@@ -181,9 +201,7 @@ template <typename ElementType, access::address_space Space> class multi_ptr {
181201
pointer_t get() const { return m_Pointer; }
182202

183203
// Implicit conversion to the underlying pointer type
184-
operator ElementType *() const {
185-
return reinterpret_cast<ElementType *>(m_Pointer);
186-
}
204+
operator ReturnPtr() const { return reinterpret_cast<ReturnPtr>(m_Pointer); }
187205

188206
// Implicit conversion to a multi_ptr<void>
189207
// Only available when ElementType is not const-qualified

0 commit comments

Comments
 (0)