@@ -217,12 +217,93 @@ static int accelerator_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type,
217
217
return 0 ;
218
218
}
219
219
220
+ static int accelerator_cuda_check_mpool (CUdeviceptr dbuf , CUmemorytype * mem_type ,
221
+ int * dev_id )
222
+ {
223
+ #if OPAL_CUDA_VMM_SUPPORT
224
+ static int device_count = -1 ;
225
+ static int mpool_supported = -1 ;
226
+ CUresult result ;
227
+ CUmemoryPool mpool ;
228
+ CUmemAccess_flags flags ;
229
+ CUmemLocation location ;
230
+
231
+ if (device_count == -1 ) {
232
+ result = cuDeviceGetCount (& device_count );
233
+ if (result != CUDA_SUCCESS ) {
234
+ return 0 ;
235
+ }
236
+ }
237
+
238
+ if (mpool_supported == -1 ) {
239
+ /* assume uniformity of devices */
240
+ result = cuDeviceGetAttribute (& mpool_supported ,
241
+ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , 0 );
242
+ if (result != CUDA_SUCCESS ) {
243
+ return 0 ;
244
+ }
245
+ }
246
+
247
+ if (mpool_supported == 0 ) {
248
+ return 0 ;
249
+ }
250
+
251
+ result = cuPointerGetAttribute (& mpool , CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ,
252
+ dbuf );
253
+ if (CUDA_SUCCESS != result ) {
254
+ return 0 ;
255
+ }
256
+
257
+ /* check if device has access */
258
+ for (int i = 0 ; i < device_count ; i ++ ) {
259
+ location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
260
+ location .id = i ;
261
+ result = cuMemPoolGetAccess (& flags , mpool , & location );
262
+ if ((CUDA_SUCCESS == result ) &&
263
+ (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
264
+ * mem_type = CU_MEMORYTYPE_DEVICE ;
265
+ * dev_id = i ;
266
+ return 1 ;
267
+ }
268
+ }
269
+
270
+ /* host must have access as device access possibility is exhausted */
271
+ * mem_type = CU_MEMORYTYPE_HOST ;
272
+ * dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
273
+ return 0 ;
274
+ #endif
275
+
276
+ return 0 ;
277
+ }
278
+
279
+ static int accelerator_cuda_get_primary_context (CUdevice dev_id , CUcontext * pctx )
280
+ {
281
+ CUresult result ;
282
+ unsigned int flags ;
283
+ int active ;
284
+
285
+ result = cuDevicePrimaryCtxGetState (dev_id , & flags , & active );
286
+ if (CUDA_SUCCESS != result ) {
287
+ return OPAL_ERROR ;
288
+ }
289
+
290
+ if (active ) {
291
+ result = cuDevicePrimaryCtxRetain (pctx , dev_id );
292
+ return OPAL_SUCCESS ;
293
+ }
294
+
295
+ return OPAL_ERROR ;
296
+ }
297
+
220
298
static int accelerator_cuda_check_addr (const void * addr , int * dev_id , uint64_t * flags )
221
299
{
222
300
CUresult result ;
223
301
int is_vmm = 0 ;
302
+ int is_mpool_ptr = 0 ;
224
303
int vmm_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
304
+ int mpool_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
225
305
CUmemorytype vmm_mem_type = 0 ;
306
+ CUmemorytype mpool_mem_type = 0 ;
226
307
CUmemorytype mem_type = 0 ;
227
308
CUdeviceptr dbuf = (CUdeviceptr ) addr ;
228
309
CUcontext ctx = NULL , mem_ctx = NULL ;
@@ -235,6 +316,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
235
316
* flags = 0 ;
236
317
237
318
is_vmm = accelerator_cuda_check_vmm (dbuf , & vmm_mem_type , & vmm_dev_id );
319
+ is_mpool_ptr = accelerator_cuda_check_mpool (dbuf , & mpool_mem_type , & mpool_dev_id );
238
320
239
321
#if OPAL_CUDA_GET_ATTRIBUTES
240
322
uint32_t is_managed = 0 ;
@@ -268,6 +350,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
268
350
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
269
351
mem_type = CU_MEMORYTYPE_DEVICE ;
270
352
* dev_id = vmm_dev_id ;
353
+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
354
+ mem_type = CU_MEMORYTYPE_DEVICE ;
355
+ * dev_id = mpool_dev_id ;
271
356
} else {
272
357
/* Host memory, nothing to do here */
273
358
return 0 ;
@@ -278,6 +363,8 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
278
363
} else {
279
364
if (is_vmm ) {
280
365
* dev_id = vmm_dev_id ;
366
+ } else if (is_mpool_ptr ) {
367
+ * dev_id = mpool_dev_id ;
281
368
} else {
282
369
/* query the device from the context */
283
370
* dev_id = accelerator_cuda_get_device_id (mem_ctx );
@@ -296,13 +383,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
296
383
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
297
384
mem_type = CU_MEMORYTYPE_DEVICE ;
298
385
* dev_id = vmm_dev_id ;
386
+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
387
+ mem_type = CU_MEMORYTYPE_DEVICE ;
388
+ * dev_id = mpool_dev_id ;
299
389
} else {
300
390
/* Host memory, nothing to do here */
301
391
return 0 ;
302
392
}
303
393
} else {
304
394
if (is_vmm ) {
305
395
* dev_id = vmm_dev_id ;
396
+ } else if (is_mpool_ptr ) {
397
+ * dev_id = mpool_dev_id ;
306
398
} else {
307
399
result = cuPointerGetAttribute (& mem_ctx ,
308
400
CU_POINTER_ATTRIBUTE_CONTEXT , dbuf );
@@ -336,14 +428,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
336
428
return OPAL_ERROR ;
337
429
}
338
430
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
339
- if (is_vmm ) {
340
- /* This function is expected to set context if pointer is device
341
- * accessible but VMM allocations have NULL context associated
342
- * which cannot be set against the calling thread */
343
- opal_output (0 ,
344
- "CUDA: unable to set context with the given pointer"
345
- "ptr=%p aborting..." , addr );
346
- return OPAL_ERROR ;
431
+ if (is_vmm || is_mpool_ptr ) {
432
+ if (OPAL_SUCCESS ==
433
+ accelerator_cuda_get_primary_context (
434
+ is_vmm ? vmm_dev_id : mpool_dev_id , & mem_ctx )) {
435
+ /* As VMM/mempool allocations have no context associated
436
+ * with them, check if device primary context can be set */
437
+ } else {
438
+ opal_output (0 ,
439
+ "CUDA: unable to set ctx with the given pointer"
440
+ "ptr=%p aborting..." , addr );
441
+ return OPAL_ERROR ;
442
+ }
347
443
}
348
444
349
445
result = cuCtxSetCurrent (mem_ctx );
0 commit comments