@@ -185,8 +185,15 @@ static ur_result_t USMDeviceAllocImpl(void **ResultPtr,
185
185
ZeDesc.pNext = &RelaxedDesc;
186
186
}
187
187
188
- ZE2UR_CALL (zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size, Alignment,
189
- Device->ZeDevice , ResultPtr));
188
+ ze_result_t ZeResult =
189
+ zeMemAllocDevice (Context->ZeContext , &ZeDesc, Size, Alignment,
190
+ Device->ZeDevice , ResultPtr);
191
+ if (ZeResult != ZE_RESULT_SUCCESS) {
192
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
193
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
194
+ }
195
+ return ze2urResult (ZeResult);
196
+ }
190
197
191
198
UR_ASSERT (Alignment == 0 ||
192
199
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -224,8 +231,15 @@ static ur_result_t USMSharedAllocImpl(void **ResultPtr,
224
231
ZeDevDesc.pNext = &RelaxedDesc;
225
232
}
226
233
227
- ZE2UR_CALL (zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc,
228
- Size, Alignment, Device->ZeDevice , ResultPtr));
234
+ ze_result_t ZeResult =
235
+ zeMemAllocShared (Context->ZeContext , &ZeDevDesc, &ZeHostDesc, Size,
236
+ Alignment, Device->ZeDevice , ResultPtr);
237
+ if (ZeResult != ZE_RESULT_SUCCESS) {
238
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
239
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
240
+ }
241
+ return ze2urResult (ZeResult);
242
+ }
229
243
230
244
UR_ASSERT (Alignment == 0 ||
231
245
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -252,8 +266,14 @@ static ur_result_t USMHostAllocImpl(void **ResultPtr,
252
266
// TODO: translate PI properties to Level Zero flags
253
267
ZeStruct<ze_host_mem_alloc_desc_t > ZeHostDesc;
254
268
ZeHostDesc.flags = 0 ;
255
- ZE2UR_CALL (zeMemAllocHost,
256
- (Context->ZeContext , &ZeHostDesc, Size, Alignment, ResultPtr));
269
+ ze_result_t ZeResult = zeMemAllocHost (Context->ZeContext , &ZeHostDesc, Size,
270
+ Alignment, ResultPtr);
271
+ if (ZeResult != ZE_RESULT_SUCCESS) {
272
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
273
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
274
+ }
275
+ return ze2urResult (ZeResult);
276
+ }
257
277
258
278
UR_ASSERT (Alignment == 0 ||
259
279
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -597,6 +617,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
597
617
ZE2UR_CALL (zeMemGetAddressRange, (Context->ZeContext , Ptr, nullptr , &Size));
598
618
return ReturnValue (Size);
599
619
}
620
+ case UR_USM_ALLOC_INFO_POOL: {
621
+ auto UMFPool = umfPoolByPtr (Ptr);
622
+ if (!UMFPool) {
623
+ return UR_RESULT_ERROR_INVALID_VALUE;
624
+ }
625
+
626
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
627
+
628
+ auto SearchMatchingPool =
629
+ [](std::unordered_map<ur_device_handle_t , umf::pool_unique_handle_t >
630
+ &PoolMap,
631
+ umf_memory_pool_handle_t UMFPool) {
632
+ for (auto &PoolPair : PoolMap) {
633
+ if (PoolPair.second .get () == UMFPool) {
634
+ return true ;
635
+ }
636
+ }
637
+ return false ;
638
+ };
639
+
640
+ for (auto &Pool : Context->UsmPoolHandles ) {
641
+ if (SearchMatchingPool (Pool->DeviceMemPools , UMFPool)) {
642
+ return ReturnValue (Pool);
643
+ }
644
+ if (SearchMatchingPool (Pool->SharedMemPools , UMFPool)) {
645
+ return ReturnValue (Pool);
646
+ }
647
+ if (Pool->HostMemPool .get () == UMFPool) {
648
+ return ReturnValue (Pool);
649
+ }
650
+ }
651
+
652
+ return UR_RESULT_ERROR_INVALID_VALUE;
653
+ }
600
654
default :
601
655
urPrint (" urUSMGetMemAllocInfo: unsupported ParamName\n " );
602
656
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -746,6 +800,7 @@ ur_result_t L0HostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
746
800
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
747
801
ur_usm_pool_desc_t *PoolDesc) {
748
802
803
+ this ->Context = Context;
749
804
zeroInit = static_cast <uint32_t >(PoolDesc->flags &
750
805
UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
751
806
@@ -829,6 +884,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
829
884
try {
830
885
*Pool = reinterpret_cast <ur_usm_pool_handle_t >(
831
886
new ur_usm_pool_handle_t_ (Context, PoolDesc));
887
+
888
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
889
+ Context->UsmPoolHandles .insert (Context->UsmPoolHandles .cend (), *Pool);
890
+
832
891
} catch (const UsmAllocationException &Ex) {
833
892
return Ex.getError ();
834
893
}
@@ -846,6 +905,8 @@ ur_result_t
846
905
urUSMPoolRelease (ur_usm_pool_handle_t Pool // /< [in] pointer to USM memory pool
847
906
) {
848
907
if (Pool->RefCount .decrementAndTest ()) {
908
+ std::shared_lock<ur_shared_mutex> ContextLock (Pool->Context ->Mutex );
909
+ Pool->Context ->UsmPoolHandles .remove (Pool);
849
910
delete Pool;
850
911
}
851
912
return UR_RESULT_SUCCESS;
@@ -859,13 +920,19 @@ ur_result_t urUSMPoolGetInfo(
859
920
// /< property
860
921
size_t *PropSizeRet // /< [out] size in bytes returned in pool property value
861
922
) {
862
- std::ignore = Pool;
863
- std::ignore = PropName;
864
- std::ignore = PropSize;
865
- std::ignore = PropValue;
866
- std::ignore = PropSizeRet;
867
- urPrint (" [UR][L0] %s function not implemented!\n " , __FUNCTION__);
868
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
923
+ UrReturnHelper ReturnValue (PropSize, PropValue, PropSizeRet);
924
+
925
+ switch (PropName) {
926
+ case UR_USM_POOL_INFO_REFERENCE_COUNT: {
927
+ return ReturnValue (Pool->RefCount .load ());
928
+ }
929
+ case UR_USM_POOL_INFO_CONTEXT: {
930
+ return ReturnValue (Pool->Context );
931
+ }
932
+ default : {
933
+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
934
+ }
935
+ }
869
936
}
870
937
871
938
// If indirect access tracking is not enabled then this functions just performs
0 commit comments