diff --git a/sycl/include/CL/sycl/access/access.hpp b/sycl/include/CL/sycl/access/access.hpp index 10101d02435f5..9187a972bd6fb 100644 --- a/sycl/include/CL/sycl/access/access.hpp +++ b/sycl/include/CL/sycl/access/access.hpp @@ -45,7 +45,9 @@ enum class address_space : int { private_space = 0, global_space, constant_space, - local_space + local_space, + global_device_space, + global_host_space }; } // namespace access @@ -103,11 +105,15 @@ constexpr bool modeWritesNewData(access::mode m) { #ifdef __SYCL_DEVICE_ONLY__ #define __OPENCL_GLOBAL_AS__ __attribute__((opencl_global)) +#define __OPENCL_GLOBAL_DEVICE_AS__ __attribute__((opencl_global_device)) +#define __OPENCL_GLOBAL_HOST_AS__ __attribute__((opencl_global_host)) #define __OPENCL_LOCAL_AS__ __attribute__((opencl_local)) #define __OPENCL_CONSTANT_AS__ __attribute__((opencl_constant)) #define __OPENCL_PRIVATE_AS__ __attribute__((opencl_private)) #else #define __OPENCL_GLOBAL_AS__ +#define __OPENCL_GLOBAL_DEVICE_AS__ +#define __OPENCL_GLOBAL_HOST_AS__ #define __OPENCL_LOCAL_AS__ #define __OPENCL_CONSTANT_AS__ #define __OPENCL_PRIVATE_AS__ @@ -141,6 +147,16 @@ struct PtrValueType { using type = __OPENCL_GLOBAL_AS__ ElementType; }; +template +struct PtrValueType { + using type = __OPENCL_GLOBAL_DEVICE_AS__ ElementType; +}; + +template +struct PtrValueType { + using type = __OPENCL_GLOBAL_HOST_AS__ ElementType; +}; + template struct PtrValueType { // Current implementation of address spaces handling leads to possibility @@ -171,6 +187,14 @@ struct remove_AS<__OPENCL_GLOBAL_AS__ T> { typedef T type; }; +template struct remove_AS<__OPENCL_GLOBAL_DEVICE_AS__ T> { + typedef T type; +}; + +template struct remove_AS<__OPENCL_GLOBAL_HOST_AS__ T> { + typedef T type; +}; + template struct remove_AS<__OPENCL_PRIVATE_AS__ T> { typedef T type; @@ -188,6 +212,8 @@ struct remove_AS<__OPENCL_CONSTANT_AS__ T> { #endif #undef __OPENCL_GLOBAL_AS__ +#undef __OPENCL_GLOBAL_DEVICE_AS__ +#undef __OPENCL_GLOBAL_HOST_AS__ #undef __OPENCL_LOCAL_AS__ #undef __OPENCL_CONSTANT_AS__ #undef __OPENCL_PRIVATE_AS__ diff --git a/sycl/include/CL/sycl/atomic.hpp b/sycl/include/CL/sycl/atomic.hpp index da9daa465efdd..6c0be13b2c523 100644 --- a/sycl/include/CL/sycl/atomic.hpp +++ b/sycl/include/CL/sycl/atomic.hpp @@ -46,8 +46,10 @@ template struct IsValidAtomicType { }; template struct IsValidAtomicAddressSpace { - static constexpr bool value = (AS == access::address_space::global_space || - AS == access::address_space::local_space); + static constexpr bool value = + (AS == access::address_space::global_space || + AS == access::address_space::local_space || + AS == access::address_space::global_device_space); }; // Type trait to translate a cl::sycl::access::address_space to @@ -56,6 +58,10 @@ template struct GetSpirvMemoryScope {}; template <> struct GetSpirvMemoryScope { static constexpr auto scope = __spv::Scope::Device; }; +template <> +struct GetSpirvMemoryScope { + static constexpr auto scope = __spv::Scope::Device; +}; template <> struct GetSpirvMemoryScope { static constexpr auto scope = __spv::Scope::Workgroup; }; @@ -168,12 +174,12 @@ template class atomic { static_assert(detail::IsValidAtomicType::value, - "Invalid SYCL atomic type. Valid types are: int, " - "unsigned int, long, unsigned long, long long, unsigned " + "Invalid SYCL atomic type. Valid types are: int, " + "unsigned int, long, unsigned long, long long, unsigned " "long long, float"); static_assert(detail::IsValidAtomicAddressSpace::value, - "Invalid SYCL atomic address_space. Valid address spaces are: " - "global_space, local_space"); + "Invalid SYCL atomic address_space. Valid address spaces are: " + "global_space, local_space, global_device_space"); static constexpr auto SpirvScope = detail::GetSpirvMemoryScope::scope; diff --git a/sycl/include/CL/sycl/detail/generic_type_lists.hpp b/sycl/include/CL/sycl/detail/generic_type_lists.hpp index 191b52765c524..9965ea66eee9a 100644 --- a/sycl/include/CL/sycl/detail/generic_type_lists.hpp +++ b/sycl/include/CL/sycl/detail/generic_type_lists.hpp @@ -361,21 +361,25 @@ using nan_list = type_list; +using all_address_space_list = address_space_list< + access::address_space::local_space, access::address_space::global_space, + access::address_space::private_space, access::address_space::constant_space, + access::address_space::global_device_space, + access::address_space::global_host_space>; using nonconst_address_space_list = address_space_list; + access::address_space::private_space, + access::address_space::global_device_space, + access::address_space::global_host_space>; using nonlocal_address_space_list = address_space_list; + access::address_space::constant_space, + access::address_space::global_device_space, + access::address_space::global_host_space>; } // namespace gvl } // namespace detail } // namespace sycl diff --git a/sycl/include/CL/sycl/multi_ptr.hpp b/sycl/include/CL/sycl/multi_ptr.hpp index 4495c654ecb3a..1a59113d9fc18 100644 --- a/sycl/include/CL/sycl/multi_ptr.hpp +++ b/sycl/include/CL/sycl/multi_ptr.hpp @@ -108,17 +108,18 @@ template class multi_ptr { return reinterpret_cast(m_Pointer)[index]; } - // Only if Space == global_space + // Only if Space == global_space || global_device_space template ::type> + (Space == access::address_space::global_space || + Space == access::address_space::global_device_space)>::type> multi_ptr(accessor Accessor) { - m_Pointer = (pointer_t)(Accessor.get_pointer().m_Pointer); + m_Pointer = (pointer_t)(Accessor.get_pointer().get()); } // Only if Space == local_space @@ -152,14 +153,17 @@ template class multi_ptr { // 2. from multi_ptr to multi_ptr - // Only if Space == global_space and element type is const - template < - int dimensions, access::mode Mode, access::placeholder isPlaceholder, - access::address_space _Space = Space, typename ET = ElementType, - typename = typename std::enable_if< - _Space == Space && Space == access::address_space::global_space && - std::is_const::value && - std::is_same::value>::type> + // Only if Space == global_space || global_device_space and element type is + // const + template ::value && + std::is_same::value>::type> multi_ptr(accessor::type, dimensions, Mode, access::target::global_buffer, isPlaceholder> Accessor) @@ -345,12 +349,13 @@ template class multi_ptr { return *this; } - // Only if Space == global_space + // Only if Space == global_space || global_device_space template ::type> + (Space == access::address_space::global_space || + Space == access::address_space::global_device_space)>::type> multi_ptr( accessor @@ -466,12 +471,13 @@ class multi_ptr { return *this; } - // Only if Space == global_space + // Only if Space == global_space || global_device_space template ::type> + (Space == access::address_space::global_space || + Space == access::address_space::global_device_space)>::type> multi_ptr( accessor diff --git a/sycl/include/CL/sycl/pointers.hpp b/sycl/include/CL/sycl/pointers.hpp index 9f91ba70ee6b7..efec74e0fd3a6 100644 --- a/sycl/include/CL/sycl/pointers.hpp +++ b/sycl/include/CL/sycl/pointers.hpp @@ -19,6 +19,14 @@ template class multi_ptr; template using global_ptr = multi_ptr; +template +using device_ptr = + multi_ptr; + +template +using host_ptr = + multi_ptr; + template using local_ptr = multi_ptr; diff --git a/sycl/test/check_device_code/usm_pointers.cpp b/sycl/test/check_device_code/usm_pointers.cpp new file mode 100644 index 0000000000000..aa0a0ed58045d --- /dev/null +++ b/sycl/test/check_device_code/usm_pointers.cpp @@ -0,0 +1,41 @@ +// RUN: %clangxx -fsycl-device-only -Xclang -fsycl-is-device -emit-llvm %s -S -o %t.ll -I %sycl_include -Wno-sycl-strict -Xclang -verify-ignore-unexpected=note,warning +// RUN: FileCheck %s --input-file %t.ll +// +// Check the address space of the pointer in multi_ptr class +// +// CHECK: %[[DEVPTR_T:.*]] = type { i8 addrspace(5)* } +// CHECK: %[[HOSTPTR_T:.*]] = type { i8 addrspace(6)* } +// +// CHECK-LABEL: define {{.*}} spir_func i8 addrspace(4)* @{{.*}}multi_ptr{{.*}} +// CHECK: %m_Pointer = getelementptr inbounds %[[DEVPTR_T]] +// CHECK-NEXT: %[[DEVLOAD:[0-9]+]] = load i8 addrspace(5)*, i8 addrspace(5)* addrspace(4)* %m_Pointer +// CHECK-NEXT: %[[DEVCAST:[0-9]+]] = addrspacecast i8 addrspace(5)* %[[DEVLOAD]] to i8 addrspace(4)* +// ret i8 addrspace(4)* %[[DEVCAST]] +// +// CHECK-LABEL: define {{.*}} spir_func i8 addrspace(4)* @{{.*}}multi_ptr{{.*}} +// CHECK: %m_Pointer = getelementptr inbounds %[[HOSTPTR_T]] +// CHECK-NEXT: %[[HOSTLOAD:[0-9]+]] = load i8 addrspace(6)*, i8 addrspace(6)* addrspace(4)* %m_Pointer +// CHECK-NEXT: %[[HOSTCAST:[0-9]+]] = addrspacecast i8 addrspace(6)* %[[HOSTLOAD]] to i8 addrspace(4)* +// ret i8 addrspace(4)* %[[HOSTCAST]] + +#include + +using namespace cl::sycl; + +int main() { + cl::sycl::queue queue; + { + queue.submit([&](cl::sycl::handler &cgh) { + cgh.single_task([=]() { + void *Ptr = nullptr; + device_ptr DevPtr(Ptr); + host_ptr HostPtr(Ptr); + global_ptr GlobPtr = global_ptr(DevPtr); + GlobPtr = global_ptr(HostPtr); + }); + }); + queue.wait(); + } + + return 0; +} diff --git a/sycl/test/multi_ptr/multi_ptr.cpp b/sycl/test/multi_ptr/multi_ptr.cpp index c2e44f461e1b7..9ebb33046a459 100644 --- a/sycl/test/multi_ptr/multi_ptr.cpp +++ b/sycl/test/multi_ptr/multi_ptr.cpp @@ -82,6 +82,7 @@ template void testMultPtr() { auto local_ptr = make_ptr( localAccessor.get_pointer()); + // General conversions in multi_ptr class T *RawPtr = nullptr; global_ptr ptr_4(RawPtr); ptr_4 = RawPtr; @@ -92,6 +93,12 @@ template void testMultPtr() { ptr_6 = (void *)RawPtr; + // Explicit conversions for device_ptr/host_ptr to global_ptr + device_ptr ptr_7((void *)RawPtr); + global_ptr ptr_8 = global_ptr(ptr_7); + host_ptr ptr_9((void *)RawPtr); + global_ptr ptr_10 = global_ptr(ptr_9); + innerFunc(wiID.get(0), ptr_1, ptr_2, local_ptr); }); }); @@ -109,12 +116,14 @@ void testMultPtrArrowOperator() { point data_1[1] = {1}; point data_2[1] = {2}; point data_3[1] = {3}; + point data_4[1] = {4}; { range<1> numOfItems{1}; buffer, 1> bufferData_1(data_1, numOfItems); buffer, 1> bufferData_2(data_2, numOfItems); buffer, 1> bufferData_3(data_3, numOfItems); + buffer, 1> bufferData_4(data_4, numOfItems); queue myQueue; myQueue.submit([&](handler &cgh) { accessor, 1, access::mode::read, access::target::global_buffer, @@ -126,6 +135,9 @@ void testMultPtrArrowOperator() { accessor, 1, access::mode::read_write, access::target::local, access::placeholder::false_t> accessorData_3(1, cgh); + accessor, 1, access::mode::read, access::target::global_buffer, + access::placeholder::false_t> + accessorData_4(bufferData_4, cgh); cgh.single_task>([=]() { auto ptr_1 = make_ptr, access::address_space::global_space>( @@ -134,10 +146,13 @@ void testMultPtrArrowOperator() { accessorData_2.get_pointer()); auto ptr_3 = make_ptr, access::address_space::local_space>( accessorData_3.get_pointer()); + auto ptr_4 = make_ptr, access::address_space::global_device_space>( + accessorData_4.get_pointer()); auto x1 = ptr_1->x; auto x2 = ptr_2->x; auto x3 = ptr_3->x; + auto x4 = ptr_4->x; static_assert(std::is_same::value, "Expected decltype(ptr_1->x) == T"); @@ -145,6 +160,8 @@ void testMultPtrArrowOperator() { "Expected decltype(ptr_2->x) == T"); static_assert(std::is_same::value, "Expected decltype(ptr_3->x) == T"); + static_assert(std::is_same::value, + "Expected decltype(ptr_4->x) == T"); }); }); }