Skip to content

Commit ee42162

Browse files
Kh4L and NicolasHug authored
Add support for CUDA >= 12.9 (#757)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 8ab84dd commit ee42162

File tree

3 files changed

+56
-15
lines changed

3 files changed

+56
-15
lines changed

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ jobs:
6767
# For the actual release we should add that label and change this to
6868
# include more python versions.
6969
python-version: ['3.9']
70-
cuda-version: ['12.6', '12.8']
70+
# We test against 12.6 and 12.9 to avoid having too big of a CI matrix,
71+
# but for releases we should add 12.8.
72+
cuda-version: ['12.6', '12.9']
7173
# TODO: put back ffmpeg 5 https://github.com/pytorch/torchcodec/issues/325
7274
ffmpeg-version-for-tests: ['4.4.2', '6', '7']
7375

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,44 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
161161
device, nonNegativeDeviceIndex, type);
162162
#endif
163163
}
164+
165+
NppStreamContext createNppStreamContext(int deviceIndex) {
166+
// From 12.9, NPP recommends using a user-created NppStreamContext and using
167+
// the `_Ctx()` calls:
168+
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
169+
// And the nppGetStreamContext() helper is deprecated. We are explicitly
170+
// supposed to create the NppStreamContext manually from the CUDA device
171+
// properties:
172+
// https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
173+
174+
NppStreamContext nppCtx{};
175+
cudaDeviceProp prop{};
176+
cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
177+
TORCH_CHECK(
178+
err == cudaSuccess,
179+
"cudaGetDeviceProperties failed: ",
180+
cudaGetErrorString(err));
181+
182+
nppCtx.nCudaDeviceId = deviceIndex;
183+
nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
184+
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
185+
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
186+
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
187+
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
188+
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
189+
190+
// TODO when implementing the cache logic, move these out. See other TODO
191+
// below.
192+
nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
193+
err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
194+
TORCH_CHECK(
195+
err == cudaSuccess,
196+
"cudaStreamGetFlags failed: ",
197+
cudaGetErrorString(err));
198+
199+
return nppCtx;
200+
}
201+
164202
} // namespace
165203

166204
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
@@ -265,37 +303,37 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
265303
dst = allocateEmptyHWCTensor(height, width, device_);
266304
}
267305

268-
// Use the user-requested GPU for running the NPP kernel.
269-
c10::cuda::CUDAGuard deviceGuard(device_);
306+
// TODO cache the NppStreamContext! It currently gets re-created for every
307+
// single frame. The cache should be per-device, similar to the existing
308+
// hw_device_ctx cache. When implementing the cache logic, the
309+
// NppStreamContext hStream and nStreamFlags should not be part of the cache
310+
// because they may change across calls.
311+
NppStreamContext nppCtx = createNppStreamContext(
312+
static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
270313

271314
NppiSize oSizeROI = {width, height};
272315
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
273316

274317
NppStatus status;
318+
275319
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
276-
status = nppiNV12ToRGB_709CSC_8u_P2C3R(
320+
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
277321
input,
278322
avFrame->linesize[0],
279323
static_cast<Npp8u*>(dst.data_ptr()),
280324
dst.stride(0),
281-
oSizeROI);
325+
oSizeROI,
326+
nppCtx);
282327
} else {
283-
status = nppiNV12ToRGB_8u_P2C3R(
328+
status = nppiNV12ToRGB_8u_P2C3R_Ctx(
284329
input,
285330
avFrame->linesize[0],
286331
static_cast<Npp8u*>(dst.data_ptr()),
287332
dst.stride(0),
288-
oSizeROI);
333+
oSizeROI,
334+
nppCtx);
289335
}
290336
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
291-
292-
// Make the pytorch stream wait for the npp kernel to finish before using the
293-
// output.
294-
at::cuda::CUDAEvent nppDoneEvent;
295-
at::cuda::CUDAStream nppStreamWrapper =
296-
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
297-
nppDoneEvent.record(nppStreamWrapper);
298-
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
299337
}
300338

301339
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#pragma once
88

9+
#include <npp.h>
910
#include "src/torchcodec/_core/DeviceInterface.h"
1011

1112
namespace facebook::torchcodec {

0 commit comments

Comments (0)