Skip to content

Commit 47daec1

Browse files
committed
[L0 v2] implement urMemBufferPartition
Change access_mode_t to device_access_mode_t. Previously it was used to specify both host and device access modes, which was confusing. Now it only specifies the device access mode.
1 parent d8b57aa commit 47daec1

File tree

6 files changed

+213
-123
lines changed

6 files changed

+213
-123
lines changed

source/adapters/level_zero/v2/kernel.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,21 +394,21 @@ ur_result_t urKernelSetArgPointer(
394394
return hKernel->setArgPointer(argIndex, pProperties, pArgValue);
395395
}
396396

397-
static ur_mem_handle_t_::access_mode_t memAccessFromKernelProperties(
397+
static ur_mem_handle_t_::device_access_mode_t memAccessFromKernelProperties(
398398
const ur_kernel_arg_mem_obj_properties_t *pProperties) {
399399
if (pProperties) {
400400
switch (pProperties->memoryAccess) {
401401
case UR_MEM_FLAG_READ_WRITE:
402-
return ur_mem_handle_t_::access_mode_t::read_write;
402+
return ur_mem_handle_t_::device_access_mode_t::read_write;
403403
case UR_MEM_FLAG_WRITE_ONLY:
404-
return ur_mem_handle_t_::access_mode_t::write_only;
404+
return ur_mem_handle_t_::device_access_mode_t::write_only;
405405
case UR_MEM_FLAG_READ_ONLY:
406-
return ur_mem_handle_t_::access_mode_t::read_only;
406+
return ur_mem_handle_t_::device_access_mode_t::read_only;
407407
default:
408-
return ur_mem_handle_t_::access_mode_t::read_write;
408+
return ur_mem_handle_t_::device_access_mode_t::read_write;
409409
}
410410
}
411-
return ur_mem_handle_t_::access_mode_t::read_write;
411+
return ur_mem_handle_t_::device_access_mode_t::read_write;
412412
}
413413

414414
ur_result_t

source/adapters/level_zero/v2/kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct ur_kernel_handle_t_ : _ur_object {
105105

106106
struct pending_memory_allocation_t {
107107
ur_mem_handle_t hMem;
108-
ur_mem_handle_t_::access_mode_t mode;
108+
ur_mem_handle_t_::device_access_mode_t mode;
109109
uint32_t argIndex;
110110
};
111111

source/adapters/level_zero/v2/memory.cpp

Lines changed: 131 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,42 @@
1313

1414
#include "../helpers/memory_helpers.hpp"
1515

16-
ur_mem_handle_t_::ur_mem_handle_t_(ur_context_handle_t hContext, size_t size)
17-
: hContext(hContext), size(size) {}
16+
static ur_mem_handle_t_::device_access_mode_t
17+
getDeviceAccessMode(ur_mem_flags_t memFlag) {
18+
if (memFlag & UR_MEM_FLAG_READ_WRITE) {
19+
return ur_mem_handle_t_::device_access_mode_t::read_write;
20+
} else if (memFlag & UR_MEM_FLAG_READ_ONLY) {
21+
return ur_mem_handle_t_::device_access_mode_t::read_only;
22+
} else if (memFlag & UR_MEM_FLAG_WRITE_ONLY) {
23+
return ur_mem_handle_t_::device_access_mode_t::write_only;
24+
} else {
25+
return ur_mem_handle_t_::device_access_mode_t::read_write;
26+
}
27+
}
28+
29+
static bool isAccessCompatible(ur_mem_handle_t_::device_access_mode_t requested,
30+
ur_mem_handle_t_::device_access_mode_t actual) {
31+
return requested == actual ||
32+
actual == ur_mem_handle_t_::device_access_mode_t::read_write;
33+
}
34+
35+
ur_mem_handle_t_::ur_mem_handle_t_(ur_context_handle_t hContext, size_t size,
36+
device_access_mode_t accessMode)
37+
: accessMode(accessMode), hContext(hContext), size(size) {}
38+
39+
size_t ur_mem_handle_t_::getSize() const { return size; }
40+
41+
ur_shared_mutex &ur_mem_handle_t_::getMutex() { return Mutex; }
1842

1943
ur_usm_handle_t_::ur_usm_handle_t_(ur_context_handle_t hContext, size_t size,
2044
const void *ptr)
21-
: ur_mem_handle_t_(hContext, size), ptr(const_cast<void *>(ptr)) {}
45+
: ur_mem_handle_t_(hContext, size, device_access_mode_t::read_write),
46+
ptr(const_cast<void *>(ptr)) {}
2247

2348
ur_usm_handle_t_::~ur_usm_handle_t_() {}
2449

2550
void *ur_usm_handle_t_::getDevicePtr(
26-
ur_device_handle_t hDevice, access_mode_t access, size_t offset,
51+
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
2752
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
2853
std::ignore = hDevice;
2954
std::ignore = access;
@@ -34,9 +59,9 @@ void *ur_usm_handle_t_::getDevicePtr(
3459
}
3560

3661
void *ur_usm_handle_t_::mapHostPtr(
37-
access_mode_t access, size_t offset, size_t size,
62+
ur_map_flags_t flags, size_t offset, size_t size,
3863
std::function<void(void *src, void *dst, size_t)>) {
39-
std::ignore = access;
64+
std::ignore = flags;
4065
std::ignore = offset;
4166
std::ignore = size;
4267
return ptr;
@@ -50,8 +75,8 @@ void ur_usm_handle_t_::unmapHostPtr(
5075

5176
ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
5277
ur_context_handle_t hContext, void *hostPtr, size_t size,
53-
host_ptr_action_t hostPtrAction)
54-
: ur_mem_handle_t_(hContext, size) {
78+
host_ptr_action_t hostPtrAction, device_access_mode_t accessMode)
79+
: ur_mem_handle_t_(hContext, size, accessMode) {
5580
bool hostPtrImported = false;
5681
if (hostPtrAction == host_ptr_action_t::import) {
5782
hostPtrImported =
@@ -83,8 +108,9 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
83108
}
84109

85110
ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
86-
ur_context_handle_t hContext, void *hostPtr, size_t size, bool ownHostPtr)
87-
: ur_mem_handle_t_(hContext, size) {
111+
ur_context_handle_t hContext, void *hostPtr, size_t size,
112+
device_access_mode_t accessMode, bool ownHostPtr)
113+
: ur_mem_handle_t_(hContext, size, accessMode) {
88114
this->ptr = usm_unique_ptr_t(hostPtr, [hContext, ownHostPtr](void *ptr) {
89115
if (!ownHostPtr) {
90116
return;
@@ -97,7 +123,7 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
97123
}
98124

99125
void *ur_integrated_mem_handle_t::getDevicePtr(
100-
ur_device_handle_t hDevice, access_mode_t access, size_t offset,
126+
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
101127
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
102128
std::ignore = hDevice;
103129
std::ignore = access;
@@ -108,9 +134,9 @@ void *ur_integrated_mem_handle_t::getDevicePtr(
108134
}
109135

110136
void *ur_integrated_mem_handle_t::mapHostPtr(
111-
access_mode_t access, size_t offset, size_t size,
137+
ur_map_flags_t flags, size_t offset, size_t size,
112138
std::function<void(void *src, void *dst, size_t)> migrate) {
113-
std::ignore = access;
139+
std::ignore = flags;
114140
std::ignore = offset;
115141
std::ignore = size;
116142
std::ignore = migrate;
@@ -178,9 +204,10 @@ ur_discrete_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice, void *src,
178204
return UR_RESULT_SUCCESS;
179205
}
180206

181-
ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
182-
void *hostPtr, size_t size)
183-
: ur_mem_handle_t_(hContext, size),
207+
ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
208+
ur_context_handle_t hContext, void *hostPtr, size_t size,
209+
device_access_mode_t accessMode)
210+
: ur_mem_handle_t_(hContext, size, accessMode),
184211
deviceAllocations(hContext->getPlatform()->getNumDevices()),
185212
activeAllocationDevice(nullptr), hostAllocations() {
186213
if (hostPtr) {
@@ -189,12 +216,11 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
189216
}
190217
}
191218

192-
ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
193-
ur_device_handle_t hDevice,
194-
void *devicePtr, size_t size,
195-
void *writeBackMemory,
196-
bool ownZePtr)
197-
: ur_mem_handle_t_(hContext, size),
219+
ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
220+
ur_context_handle_t hContext, ur_device_handle_t hDevice, void *devicePtr,
221+
size_t size, device_access_mode_t accessMode, void *writeBackMemory,
222+
bool ownZePtr)
223+
: ur_mem_handle_t_(hContext, size, accessMode),
198224
deviceAllocations(hContext->getPlatform()->getNumDevices()),
199225
activeAllocationDevice(hDevice), writeBackPtr(writeBackMemory),
200226
hostAllocations() {
@@ -227,7 +253,7 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
227253
}
228254

229255
void *ur_discrete_mem_handle_t::getDevicePtr(
230-
ur_device_handle_t hDevice, access_mode_t access, size_t offset,
256+
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
231257
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
232258
TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::getDevicePtr");
233259

@@ -265,19 +291,18 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
265291
}
266292

267293
void *ur_discrete_mem_handle_t::mapHostPtr(
268-
access_mode_t access, size_t offset, size_t size,
294+
ur_map_flags_t flags, size_t offset, size_t size,
269295
std::function<void(void *src, void *dst, size_t)> migrate) {
270296
TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::mapHostPtr");
271-
272297
// TODO: use async alloc?
273298

274299
void *ptr;
275300
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
276301
hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr));
277302

278-
hostAllocations.emplace_back(ptr, size, offset, access);
303+
hostAllocations.emplace_back(ptr, size, offset, flags);
279304

280-
if (activeAllocationDevice && access != access_mode_t::write_only) {
305+
if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
281306
auto srcPtr =
282307
ur_cast<char *>(
283308
deviceAllocations[activeAllocationDevice->Id.value()].get()) +
@@ -301,10 +326,11 @@ void ur_discrete_mem_handle_t::unmapHostPtr(
301326
ur_cast<char *>(
302327
deviceAllocations[activeAllocationDevice->Id.value()].get()) +
303328
hostAllocation.offset;
304-
} else if (hostAllocation.access != access_mode_t::write_invalidate) {
305-
devicePtr = ur_cast<char *>(
306-
getDevicePtr(hContext->getDevices()[0], access_mode_t::read_only,
307-
hostAllocation.offset, hostAllocation.size, migrate));
329+
} else if (!(hostAllocation.flags &
330+
UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
331+
devicePtr = ur_cast<char *>(getDevicePtr(
332+
hContext->getDevices()[0], device_access_mode_t::read_only,
333+
hostAllocation.offset, hostAllocation.size, migrate));
308334
}
309335

310336
if (devicePtr) {
@@ -332,6 +358,46 @@ static bool useHostBuffer(ur_context_handle_t hContext) {
332358
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
333359
}
334360

361+
namespace ur::level_zero {
362+
ur_result_t urMemRetain(ur_mem_handle_t hMem);
363+
ur_result_t urMemRelease(ur_mem_handle_t hMem);
364+
} // namespace ur::level_zero
365+
366+
ur_mem_sub_buffer_t::ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset,
367+
size_t size,
368+
device_access_mode_t accessMode)
369+
: ur_mem_handle_t_(hParent->getContext(), size, accessMode),
370+
hParent(hParent), offset(offset), size(size) {
371+
ur::level_zero::urMemRetain(hParent);
372+
}
373+
374+
ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
375+
ur::level_zero::urMemRelease(hParent);
376+
}
377+
378+
void *ur_mem_sub_buffer_t::getDevicePtr(
379+
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
380+
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
381+
return hParent->getDevicePtr(hDevice, access, offset + this->offset, size,
382+
migrate);
383+
}
384+
385+
void *ur_mem_sub_buffer_t::mapHostPtr(
386+
ur_map_flags_t flags, size_t offset, size_t size,
387+
std::function<void(void *src, void *dst, size_t)> migrate) {
388+
return hParent->mapHostPtr(flags, offset + this->offset, size, migrate);
389+
}
390+
391+
void ur_mem_sub_buffer_t::unmapHostPtr(
392+
void *pMappedPtr,
393+
std::function<void(void *src, void *dst, size_t)> migrate) {
394+
return hParent->unmapHostPtr(pMappedPtr, migrate);
395+
}
396+
397+
size_t ur_mem_sub_buffer_t::getSize() const { return size; }
398+
399+
ur_shared_mutex &ur_mem_sub_buffer_t::getMutex() { return hParent->getMutex(); }
400+
335401
namespace ur::level_zero {
336402
ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
337403
ur_mem_flags_t flags, size_t size,
@@ -347,6 +413,7 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
347413
}
348414

349415
void *hostPtr = pProperties ? pProperties->pHost : nullptr;
416+
auto accessMode = getDeviceAccessMode(flags);
350417

351418
if (useHostBuffer(hContext)) {
352419
// TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
@@ -355,10 +422,11 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
355422
flags & UR_MEM_FLAG_USE_HOST_POINTER
356423
? ur_integrated_mem_handle_t::host_ptr_action_t::import
357424
: ur_integrated_mem_handle_t::host_ptr_action_t::copy;
358-
*phBuffer =
359-
new ur_integrated_mem_handle_t(hContext, hostPtr, size, hostPtrAction);
425+
*phBuffer = new ur_integrated_mem_handle_t(hContext, hostPtr, size,
426+
hostPtrAction, accessMode);
360427
} else {
361-
*phBuffer = new ur_discrete_mem_handle_t(hContext, hostPtr, size);
428+
*phBuffer =
429+
new ur_discrete_mem_handle_t(hContext, hostPtr, size, accessMode);
362430
}
363431

364432
return UR_RESULT_SUCCESS;
@@ -368,13 +436,21 @@ ur_result_t urMemBufferPartition(ur_mem_handle_t hBuffer, ur_mem_flags_t flags,
368436
ur_buffer_create_type_t bufferCreateType,
369437
const ur_buffer_region_t *pRegion,
370438
ur_mem_handle_t *phMem) {
371-
std::ignore = hBuffer;
372-
std::ignore = flags;
373-
std::ignore = bufferCreateType;
374-
std::ignore = pRegion;
375-
std::ignore = phMem;
376-
logger::error("{} function not implemented!", __FUNCTION__);
377-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
439+
UR_ASSERT(bufferCreateType == UR_BUFFER_CREATE_TYPE_REGION,
440+
UR_RESULT_ERROR_INVALID_ENUMERATION);
441+
UR_ASSERT((pRegion->origin < hBuffer->getSize() &&
442+
pRegion->size <= hBuffer->getSize()),
443+
UR_RESULT_ERROR_INVALID_BUFFER_SIZE);
444+
445+
auto accessMode = getDeviceAccessMode(flags);
446+
447+
UR_ASSERT(isAccessCompatible(accessMode, hBuffer->getDeviceAccessMode()),
448+
UR_RESULT_ERROR_INVALID_VALUE);
449+
450+
*phMem = new ur_mem_sub_buffer_t(hBuffer, pRegion->origin, pRegion->size,
451+
accessMode);
452+
453+
return UR_RESULT_SUCCESS;
378454
}
379455

380456
ur_result_t urMemBufferCreateWithNativeHandle(
@@ -407,21 +483,24 @@ ur_result_t urMemBufferCreateWithNativeHandle(
407483
UR_RESULT_ERROR_INVALID_CONTEXT);
408484
}
409485

486+
// assume read-write
487+
auto accessMode = ur_mem_handle_t_::device_access_mode_t::read_write;
488+
410489
if (useHostBuffer(hContext) && memoryAttrs.type == ZE_MEMORY_TYPE_HOST) {
411-
*phMem =
412-
new ur_integrated_mem_handle_t(hContext, ptr, size, ownNativeHandle);
490+
*phMem = new ur_integrated_mem_handle_t(hContext, ptr, size, accessMode,
491+
ownNativeHandle);
413492
// if useHostBuffer(hContext) is true but the allocation is on device, we'll
414493
// treat it as discrete memory
415494
} else {
416495
if (memoryAttrs.type == ZE_MEMORY_TYPE_HOST) {
417496
// For host allocation, we need to copy the data to a device buffer
418497
// and then copy it back on release
419498
*phMem = new ur_discrete_mem_handle_t(hContext, hDevice, nullptr, size,
420-
ptr, ownNativeHandle);
499+
accessMode, ptr, ownNativeHandle);
421500
} else {
422501
// For device/shared allocation, we can use it directly
423-
*phMem = new ur_discrete_mem_handle_t(hContext, hDevice, ptr, size,
424-
nullptr, ownNativeHandle);
502+
*phMem = new ur_discrete_mem_handle_t(
503+
hContext, hDevice, ptr, size, accessMode, nullptr, ownNativeHandle);
425504
}
426505
}
427506

@@ -452,12 +531,12 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
452531
}
453532

454533
ur_result_t urMemRetain(ur_mem_handle_t hMem) {
455-
hMem->RefCount.increment();
534+
hMem->getRefCount().increment();
456535
return UR_RESULT_SUCCESS;
457536
}
458537

459538
ur_result_t urMemRelease(ur_mem_handle_t hMem) {
460-
if (hMem->RefCount.decrementAndTest()) {
539+
if (hMem->getRefCount().decrementAndTest()) {
461540
delete hMem;
462541
}
463542
return UR_RESULT_SUCCESS;
@@ -468,11 +547,11 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,
468547
ur_native_handle_t *phNativeMem) {
469548
std::ignore = hDevice;
470549

471-
std::scoped_lock<ur_shared_mutex> lock(hMem->Mutex);
550+
std::scoped_lock<ur_shared_mutex> lock(hMem->getMutex());
472551

473-
auto ptr =
474-
hMem->getDevicePtr(nullptr, ur_mem_handle_t_::access_mode_t::read_write,
475-
0, hMem->getSize(), nullptr);
552+
auto ptr = hMem->getDevicePtr(
553+
nullptr, ur_mem_handle_t_::device_access_mode_t::read_write, 0,
554+
hMem->getSize(), nullptr);
476555
*phNativeMem = reinterpret_cast<ur_native_handle_t>(ptr);
477556
return UR_RESULT_SUCCESS;
478557
}

0 commit comments

Comments
 (0)