diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 097caf23edfa5..12bd02050be03 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -836,7 +836,11 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Type elementPtrType = this->getElementPtrType(memRefType); - auto stream = adaptor.getAsyncDependencies().front(); + + auto nullPtr = rewriter.create(loc, llvmPointerType); + Value stream = adaptor.getAsyncDependencies().empty() + ? nullPtr + : adaptor.getAsyncDependencies().front(); auto isHostShared = rewriter.create( loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); @@ -855,7 +859,12 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( auto memRefDescriptor = this->createMemRefDescriptor( loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); - rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); + if (allocOp.getAsyncToken()) { + // Async alloc: make dependent ops use the same stream. + rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); + } else { + rewriter.replaceOp(allocOp, {memRefDescriptor}); + } return success(); } diff --git a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir index f365dcb02daf4..70450656b9df6 100644 --- a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir +++ b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir @@ -19,4 +19,22 @@ module attributes {gpu.container_module} { gpu.wait [%3] return } + + // CHECK-LABEL: llvm.func @alloc_sync + // CHECK-SAME: %[[size:.*]]: i64 + func.func @alloc_sync(%size : index) { + // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}[%[[size]]] + // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] + // CHECK: %[[nullptr:.*]] = llvm.mlir.zero + // CHECK: %[[isHostShared:.*]] = llvm.mlir.constant + // CHECK: llvm.call @mgpuMemAlloc(%[[size_bytes]], %[[nullptr]], %[[isHostShared]]) + %0 = gpu.alloc host_shared (%size) : memref + // CHECK: %[[stream:.*]] = llvm.call @mgpuStreamCreate() + %1 = gpu.wait async + %2 = gpu.dealloc async [%1] %0 : memref + // CHECK: llvm.call @mgpuStreamSynchronize(%[[stream]]) + // CHECK: llvm.call @mgpuStreamDestroy(%[[stream]]) + gpu.wait [%2] + return + } }