diff --git a/sycl/plugins/unified_runtime/CMakeLists.txt b/sycl/plugins/unified_runtime/CMakeLists.txt index 167d5d8c7769..73d9d5a950da 100644 --- a/sycl/plugins/unified_runtime/CMakeLists.txt +++ b/sycl/plugins/unified_runtime/CMakeLists.txt @@ -57,13 +57,13 @@ if(SYCL_PI_UR_USE_FETCH_CONTENT) include(FetchContent) set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git") - # commit e69ed21468e04ed6e832accf162422ed11736446 - # Merge: 20fa0b5f 7fd9dafd + # commit 69a56ea6d1369a6bde5fce97c85fc7dbda49252f + # Merge: b25bb64d b78f541d # Author: Kenneth Benzie (Benie) - # Date: Fri Dec 8 12:18:51 2023 +0000 - # Merge pull request #962 from jandres742/fixwaitbarrierwithevent - # [UR][L0] Correctly wait on barrier on urEnqueueEventsWaitWithBarrier - set(UNIFIED_RUNTIME_TAG e69ed21468e04ed6e832accf162422ed11736446) + # Date: Mon Dec 11 12:30:24 2023 +0000 + # Merge pull request #1123 from aarongreig/aaron/usmLocationProps + # [OpenCL] Add ur_usm_alloc_location_desc struct and handle it in the CL adapter. + set(UNIFIED_RUNTIME_TAG 69a56ea6d1369a6bde5fce97c85fc7dbda49252f) if(SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO) set(UNIFIED_RUNTIME_REPO "${SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO}") diff --git a/sycl/plugins/unified_runtime/pi2ur.hpp b/sycl/plugins/unified_runtime/pi2ur.hpp index 01da9137e440..6e36b21c31b8 100644 --- a/sycl/plugins/unified_runtime/pi2ur.hpp +++ b/sycl/plugins/unified_runtime/pi2ur.hpp @@ -2697,12 +2697,28 @@ inline pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, inline pi_result piextUSMHostAlloc(void **ResultPtr, pi_context Context, pi_usm_mem_properties *Properties, size_t Size, pi_uint32 Alignment) { + ur_usm_desc_t USMDesc{}; + USMDesc.align = Alignment; + + ur_usm_alloc_location_desc_t UsmLocationDesc{}; + UsmLocationDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC; + + if (Properties) { + uint32_t Next = 0; + while (Properties[Next]) { + if (Properties[Next] == PI_MEM_USM_ALLOC_BUFFER_LOCATION) { + UsmLocationDesc.location = static_cast(Properties[Next + 1]); + USMDesc.pNext = &UsmLocationDesc; + } else { + return PI_ERROR_INVALID_VALUE; + } + Next += 2; + } + } - std::ignore = Properties; ur_context_handle_t UrContext = reinterpret_cast(Context); - ur_usm_desc_t USMDesc{}; - USMDesc.align = Alignment; + ur_usm_pool_handle_t Pool{}; HANDLE_ERRORS(urUSMHostAlloc(UrContext, &USMDesc, Pool, Size, ResultPtr)); return PI_SUCCESS; @@ -3131,14 +3147,29 @@ inline pi_result piextUSMDeviceAlloc(void **ResultPtr, pi_context Context, pi_device Device, pi_usm_mem_properties *Properties, size_t Size, pi_uint32 Alignment) { - - std::ignore = Properties; ur_context_handle_t UrContext = reinterpret_cast(Context); auto UrDevice = reinterpret_cast(Device); ur_usm_desc_t USMDesc{}; USMDesc.align = Alignment; + + ur_usm_alloc_location_desc_t UsmLocDesc{}; + UsmLocDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC; + + if (Properties) { + uint32_t Next = 0; + while (Properties[Next]) { + if (Properties[Next] == PI_MEM_USM_ALLOC_BUFFER_LOCATION) { + UsmLocDesc.location = static_cast(Properties[Next + 1]); + USMDesc.pNext = &UsmLocDesc; + } else { + return PI_ERROR_INVALID_VALUE; + } + Next += 2; + } + } + ur_usm_pool_handle_t Pool{}; HANDLE_ERRORS( urUSMDeviceAlloc(UrContext, UrDevice, &USMDesc, Pool, Size, ResultPtr)); @@ -3171,42 +3202,58 @@ inline pi_result piextUSMSharedAlloc(void **ResultPtr, pi_context Context, pi_device Device, pi_usm_mem_properties *Properties, size_t Size, pi_uint32 Alignment) { - - std::ignore = Properties; - if (Properties && *Properties != 0) { - PI_ASSERT(*(Properties) == PI_MEM_ALLOC_FLAGS && *(Properties + 2) == 0, - PI_ERROR_INVALID_VALUE); - } - ur_context_handle_t UrContext = reinterpret_cast(Context); auto UrDevice = reinterpret_cast(Device); ur_usm_desc_t USMDesc{}; + USMDesc.align = Alignment; ur_usm_device_desc_t UsmDeviceDesc{}; UsmDeviceDesc.stype = UR_STRUCTURE_TYPE_USM_DEVICE_DESC; ur_usm_host_desc_t UsmHostDesc{}; UsmHostDesc.stype = UR_STRUCTURE_TYPE_USM_HOST_DESC; + ur_usm_alloc_location_desc_t UsmLocationDesc{}; + UsmLocationDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC; + + // One properties bitfield can correspond to a host_desc and a device_desc + // struct, since having `0` values in these is harmless we can set up this + // pNext chain in advance. + USMDesc.pNext = &UsmDeviceDesc; + UsmDeviceDesc.pNext = &UsmHostDesc; + if (Properties) { - if (Properties[0] == PI_MEM_ALLOC_FLAGS) { - if (Properties[1] == PI_MEM_ALLOC_WRTITE_COMBINED) { - UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED; - } - if (Properties[1] == PI_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE) { - UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT; + uint32_t Next = 0; + while (Properties[Next]) { + switch (Properties[Next]) { + case PI_MEM_ALLOC_FLAGS: { + if (Properties[Next + 1] & PI_MEM_ALLOC_WRTITE_COMBINED) { + UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED; + } + if (Properties[Next + 1] & PI_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE) { + UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT; + } + if (Properties[Next + 1] & PI_MEM_ALLOC_INITIAL_PLACEMENT_HOST) { + UsmHostDesc.flags |= UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT; + } + if (Properties[Next + 1] & PI_MEM_ALLOC_DEVICE_READ_ONLY) { + UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_DEVICE_READ_ONLY; + } + break; } - if (Properties[1] == PI_MEM_ALLOC_INITIAL_PLACEMENT_HOST) { - UsmHostDesc.flags |= UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT; + case PI_MEM_USM_ALLOC_BUFFER_LOCATION: { + UsmLocationDesc.location = static_cast(Properties[Next + 1]); + // We wait until we've seen a BUFFER_LOCATION property to tack this + // onto the end of the chain, a `0` here might be valid as far as we + // know so we must exclude it unless we've been given a value. + UsmHostDesc.pNext = &UsmLocationDesc; + break; } - if (Properties[1] == PI_MEM_ALLOC_DEVICE_READ_ONLY) { - UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_DEVICE_READ_ONLY; + default: + return PI_ERROR_INVALID_VALUE; } + Next += 2; } } - UsmDeviceDesc.pNext = &UsmHostDesc; - USMDesc.pNext = &UsmDeviceDesc; - - USMDesc.align = Alignment; ur_usm_pool_handle_t Pool{}; HANDLE_ERRORS(