#include "../helpers/memory_helpers.hpp"

- ur_mem_handle_t_::ur_mem_handle_t_(ur_context_handle_t hContext, size_t size)
-     : hContext(hContext), size(size) {}
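+ // Translate UR memory flags into the device access mode tracked by the handle;
+ // when no access flag is given, read-write access is assumed.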
+ static ur_mem_handle_t_::device_access_mode_t
+ getDeviceAccessMode(ur_mem_flags_t memFlag) {
+   if (memFlag & UR_MEM_FLAG_READ_WRITE) {
+     return ur_mem_handle_t_::device_access_mode_t::read_write;
+   } else if (memFlag & UR_MEM_FLAG_READ_ONLY) {
+     return ur_mem_handle_t_::device_access_mode_t::read_only;
+   } else if (memFlag & UR_MEM_FLAG_WRITE_ONLY) {
+     return ur_mem_handle_t_::device_access_mode_t::write_only;
+   } else {
+     return ur_mem_handle_t_::device_access_mode_t::read_write;
+   }
+ }
+
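+ // A requested access mode is compatible with an existing allocation when it
+ // matches exactly or the allocation already permits read-write access.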
+ static bool isAccessCompatible(ur_mem_handle_t_::device_access_mode_t requested,
+                                ur_mem_handle_t_::device_access_mode_t actual) {
+   return requested == actual ||
+          actual == ur_mem_handle_t_::device_access_mode_t::read_write;
+ }
+
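+ // The base handle now records its device access mode next to context and size.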
+ ur_mem_handle_t_::ur_mem_handle_t_(ur_context_handle_t hContext, size_t size,
+                                    device_access_mode_t accessMode)
+     : accessMode(accessMode), hContext(hContext), size(size) {}
+
+ size_t ur_mem_handle_t_::getSize() const { return size; }
+
+ ur_shared_mutex &ur_mem_handle_t_::getMutex() { return Mutex; }

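+ // A USM handle wraps an existing allocation, so it is registered with full
+ // read-write access.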
ur_usm_handle_t_::ur_usm_handle_t_(ur_context_handle_t hContext, size_t size,
                                   const void *ptr)
-     : ur_mem_handle_t_(hContext, size), ptr(const_cast<void *>(ptr)) {}
+     : ur_mem_handle_t_(hContext, size, device_access_mode_t::read_write),
+       ptr(const_cast<void *>(ptr)) {}

ur_usm_handle_t_::~ur_usm_handle_t_() {}

void *ur_usm_handle_t_::getDevicePtr(
-     ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+     ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
  std::ignore = hDevice;
  std::ignore = access;
@@ -34,9 +59,9 @@ void *ur_usm_handle_t_::getDevicePtr(
}

void *ur_usm_handle_t_::mapHostPtr(
-     access_mode_t access, size_t offset, size_t size,
+     ur_map_flags_t flags, size_t offset, size_t size,
    std::function<void(void *src, void *dst, size_t)>) {
-   std::ignore = access;
+   std::ignore = flags;
  std::ignore = offset;
  std::ignore = size;
  return ptr;
@@ -50,8 +75,8 @@ void ur_usm_handle_t_::unmapHostPtr(

ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
    ur_context_handle_t hContext, void *hostPtr, size_t size,
-     host_ptr_action_t hostPtrAction)
-     : ur_mem_handle_t_(hContext, size) {
+     host_ptr_action_t hostPtrAction, device_access_mode_t accessMode)
+     : ur_mem_handle_t_(hContext, size, accessMode) {
  bool hostPtrImported = false;
  if (hostPtrAction == host_ptr_action_t::import) {
    hostPtrImported =
@@ -83,8 +108,9 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
}

ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
-     ur_context_handle_t hContext, void *hostPtr, size_t size, bool ownHostPtr)
-     : ur_mem_handle_t_(hContext, size) {
+     ur_context_handle_t hContext, void *hostPtr, size_t size,
+     device_access_mode_t accessMode, bool ownHostPtr)
+     : ur_mem_handle_t_(hContext, size, accessMode) {
  this->ptr = usm_unique_ptr_t(hostPtr, [hContext, ownHostPtr](void *ptr) {
    if (!ownHostPtr) {
      return;
@@ -97,7 +123,7 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
}

void *ur_integrated_mem_handle_t::getDevicePtr(
-     ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+     ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
  std::ignore = hDevice;
  std::ignore = access;
@@ -108,9 +134,9 @@ void *ur_integrated_mem_handle_t::getDevicePtr(
}

void *ur_integrated_mem_handle_t::mapHostPtr(
-     access_mode_t access, size_t offset, size_t size,
+     ur_map_flags_t flags, size_t offset, size_t size,
    std::function<void(void *src, void *dst, size_t)> migrate) {
-   std::ignore = access;
+   std::ignore = flags;
  std::ignore = offset;
  std::ignore = size;
  std::ignore = migrate;
@@ -178,9 +204,10 @@ ur_discrete_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice, void *src,
  return UR_RESULT_SUCCESS;
}

- ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
-                                                    void *hostPtr, size_t size)
-     : ur_mem_handle_t_(hContext, size),
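+ // One allocation slot is reserved per device in the context; initially no
+ // device allocation is active.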
+ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
+     ur_context_handle_t hContext, void *hostPtr, size_t size,
+     device_access_mode_t accessMode)
+     : ur_mem_handle_t_(hContext, size, accessMode),
      deviceAllocations(hContext->getPlatform()->getNumDevices()),
      activeAllocationDevice(nullptr), hostAllocations() {
  if (hostPtr) {
@@ -189,12 +216,11 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
  }
}

- ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
-                                                    ur_device_handle_t hDevice,
-                                                    void *devicePtr, size_t size,
-                                                    void *writeBackMemory,
-                                                    bool ownZePtr)
-     : ur_mem_handle_t_(hContext, size),
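+ // Wraps an existing device allocation; when writeBackMemory is provided, the
+ // buffer contents are copied back to it on release.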
+ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
+     ur_context_handle_t hContext, ur_device_handle_t hDevice, void *devicePtr,
+     size_t size, device_access_mode_t accessMode, void *writeBackMemory,
+     bool ownZePtr)
+     : ur_mem_handle_t_(hContext, size, accessMode),
      deviceAllocations(hContext->getPlatform()->getNumDevices()),
      activeAllocationDevice(hDevice), writeBackPtr(writeBackMemory),
      hostAllocations() {
@@ -227,7 +253,7 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
}

void *ur_discrete_mem_handle_t::getDevicePtr(
-     ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+     ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
  TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::getDevicePtr");

@@ -265,19 +291,18 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
}

void *ur_discrete_mem_handle_t::mapHostPtr(
-     access_mode_t access, size_t offset, size_t size,
+     ur_map_flags_t flags, size_t offset, size_t size,
    std::function<void(void *src, void *dst, size_t)> migrate) {
  TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::mapHostPtr");
-
  // TODO: use async alloc?

  void *ptr;
  UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
      hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr));

-   hostAllocations.emplace_back(ptr, size, offset, access);
+   hostAllocations.emplace_back(ptr, size, offset, flags);

-   if (activeAllocationDevice && access != access_mode_t::write_only) {
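+   // Copy the current device contents into the host staging buffer only when
+   // the mapping requests read access.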
+   if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
    auto srcPtr =
        ur_cast<char *>(
            deviceAllocations[activeAllocationDevice->Id.value()].get()) +
@@ -301,10 +326,11 @@ void ur_discrete_mem_handle_t::unmapHostPtr(
        ur_cast<char *>(
            deviceAllocations[activeAllocationDevice->Id.value()].get()) +
        hostAllocation.offset;
-   } else if (hostAllocation.access != access_mode_t::write_invalidate) {
-     devicePtr = ur_cast<char *>(
-         getDevicePtr(hContext->getDevices()[0], access_mode_t::read_only,
-                      hostAllocation.offset, hostAllocation.size, migrate));
+   } else if (!(hostAllocation.flags &
+                UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
+     devicePtr = ur_cast<char *>(getDevicePtr(
+         hContext->getDevices()[0], device_access_mode_t::read_only,
+         hostAllocation.offset, hostAllocation.size, migrate));
  }

  if (devicePtr) {
@@ -332,6 +358,46 @@ static bool useHostBuffer(ur_context_handle_t hContext) {
         ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
}

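+ // Forward declarations so the sub-buffer can retain and release its parent.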
+ namespace ur::level_zero {
+ ur_result_t urMemRetain(ur_mem_handle_t hMem);
+ ur_result_t urMemRelease(ur_mem_handle_t hMem);
+ } // namespace ur::level_zero
+
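+ // A sub-buffer keeps a reference on its parent buffer and delegates every
+ // operation to it, shifted by the sub-buffer's offset.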
+ ur_mem_sub_buffer_t::ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset,
+                                          size_t size,
+                                          device_access_mode_t accessMode)
+     : ur_mem_handle_t_(hParent->getContext(), size, accessMode),
+       hParent(hParent), offset(offset), size(size) {
+   ur::level_zero::urMemRetain(hParent);
+ }
+
+ ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
+   ur::level_zero::urMemRelease(hParent);
+ }
+
+ void *ur_mem_sub_buffer_t::getDevicePtr(
+     ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
+     size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
+   return hParent->getDevicePtr(hDevice, access, offset + this->offset, size,
+                                migrate);
+ }
+
+ void *ur_mem_sub_buffer_t::mapHostPtr(
+     ur_map_flags_t flags, size_t offset, size_t size,
+     std::function<void(void *src, void *dst, size_t)> migrate) {
+   return hParent->mapHostPtr(flags, offset + this->offset, size, migrate);
+ }
+
+ void ur_mem_sub_buffer_t::unmapHostPtr(
+     void *pMappedPtr,
+     std::function<void(void *src, void *dst, size_t)> migrate) {
+   return hParent->unmapHostPtr(pMappedPtr, migrate);
+ }
+
+ size_t ur_mem_sub_buffer_t::getSize() const { return size; }
+
+ ur_shared_mutex &ur_mem_sub_buffer_t::getMutex() { return hParent->getMutex(); }
+

namespace ur::level_zero {
ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
                              ur_mem_flags_t flags, size_t size,
@@ -347,6 +413,7 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
  }

  void *hostPtr = pProperties ? pProperties->pHost : nullptr;
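+   // Derive the device access mode for this buffer from its creation flags.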
+   auto accessMode = getDeviceAccessMode(flags);

  if (useHostBuffer(hContext)) {
    // TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
@@ -355,10 +422,11 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
        flags & UR_MEM_FLAG_USE_HOST_POINTER
            ? ur_integrated_mem_handle_t::host_ptr_action_t::import
            : ur_integrated_mem_handle_t::host_ptr_action_t::copy;
-     *phBuffer =
-         new ur_integrated_mem_handle_t(hContext, hostPtr, size, hostPtrAction);
+     *phBuffer = new ur_integrated_mem_handle_t(hContext, hostPtr, size,
+                                                hostPtrAction, accessMode);
  } else {
-     *phBuffer = new ur_discrete_mem_handle_t(hContext, hostPtr, size);
+     *phBuffer =
+         new ur_discrete_mem_handle_t(hContext, hostPtr, size, accessMode);
  }

  return UR_RESULT_SUCCESS;
@@ -368,13 +436,21 @@ ur_result_t urMemBufferPartition(ur_mem_handle_t hBuffer, ur_mem_flags_t flags,
                                 ur_buffer_create_type_t bufferCreateType,
                                 const ur_buffer_region_t *pRegion,
                                 ur_mem_handle_t *phMem) {
-   std::ignore = hBuffer;
-   std::ignore = flags;
-   std::ignore = bufferCreateType;
-   std::ignore = pRegion;
-   std::ignore = phMem;
-   logger::error("{} function not implemented!", __FUNCTION__);
-   return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
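+   // Validate the requested region and access mode, then expose the region as a
+   // sub-buffer over the parent allocation.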
+   UR_ASSERT(bufferCreateType == UR_BUFFER_CREATE_TYPE_REGION,
+             UR_RESULT_ERROR_INVALID_ENUMERATION);
+   UR_ASSERT((pRegion->origin < hBuffer->getSize() &&
+              pRegion->size <= hBuffer->getSize()),
+             UR_RESULT_ERROR_INVALID_BUFFER_SIZE);
+
+   auto accessMode = getDeviceAccessMode(flags);
+
+   UR_ASSERT(isAccessCompatible(accessMode, hBuffer->getDeviceAccessMode()),
+             UR_RESULT_ERROR_INVALID_VALUE);
+
+   *phMem = new ur_mem_sub_buffer_t(hBuffer, pRegion->origin, pRegion->size,
+                                    accessMode);
+
+   return UR_RESULT_SUCCESS;
}

ur_result_t urMemBufferCreateWithNativeHandle(
@@ -407,21 +483,24 @@ ur_result_t urMemBufferCreateWithNativeHandle(
                UR_RESULT_ERROR_INVALID_CONTEXT);
  }

+   // assume read-write
+   auto accessMode = ur_mem_handle_t_::device_access_mode_t::read_write;
+
  if (useHostBuffer(hContext) && memoryAttrs.type == ZE_MEMORY_TYPE_HOST) {
-     *phMem =
-         new ur_integrated_mem_handle_t(hContext, ptr, size, ownNativeHandle);
+     *phMem = new ur_integrated_mem_handle_t(hContext, ptr, size, accessMode,
+                                             ownNativeHandle);
    // if useHostBuffer(hContext) is true but the allocation is on device, we'll
    // treat it as discrete memory
  } else {
    if (memoryAttrs.type == ZE_MEMORY_TYPE_HOST) {
      // For host allocation, we need to copy the data to a device buffer
      // and then copy it back on release
      *phMem = new ur_discrete_mem_handle_t(hContext, hDevice, nullptr, size,
-                                             ptr, ownNativeHandle);
+                                             accessMode, ptr, ownNativeHandle);
    } else {
      // For device/shared allocation, we can use it directly
-       *phMem = new ur_discrete_mem_handle_t(hContext, hDevice, ptr, size,
-                                             nullptr, ownNativeHandle);
+       *phMem = new ur_discrete_mem_handle_t(
+           hContext, hDevice, ptr, size, accessMode, nullptr, ownNativeHandle);
    }
  }

@@ -452,12 +531,12 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
}

ur_result_t urMemRetain(ur_mem_handle_t hMem) {
-   hMem->RefCount.increment();
+   hMem->getRefCount().increment();
  return UR_RESULT_SUCCESS;
}

ur_result_t urMemRelease(ur_mem_handle_t hMem) {
-   if (hMem->RefCount.decrementAndTest()) {
+   if (hMem->getRefCount().decrementAndTest()) {
    delete hMem;
  }
  return UR_RESULT_SUCCESS;
@@ -468,11 +547,11 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,
                                 ur_native_handle_t *phNativeMem) {
  std::ignore = hDevice;

-   std::scoped_lock<ur_shared_mutex> lock(hMem->Mutex);
+   std::scoped_lock<ur_shared_mutex> lock(hMem->getMutex());

-   auto ptr =
-       hMem->getDevicePtr(nullptr, ur_mem_handle_t_::access_mode_t::read_write,
-                          0, hMem->getSize(), nullptr);
+   auto ptr = hMem->getDevicePtr(
+       nullptr, ur_mem_handle_t_::device_access_mode_t::read_write, 0,
+       hMem->getSize(), nullptr);
  *phNativeMem = reinterpret_cast<ur_native_handle_t>(ptr);
  return UR_RESULT_SUCCESS;
}