@@ -161,6 +161,44 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
       device, nonNegativeDeviceIndex, type);
 #endif
 }
+
+NppStreamContext createNppStreamContext(int deviceIndex) {
+  // From 12.9, NPP recommends using a user-created NppStreamContext and using
+  // the `_Ctx()` calls:
+  // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
+  // And the nppGetStreamContext() helper is deprecated. We are explicitly
+  // supposed to create the NppStreamContext manually from the CUDA device
+  // properties:
+  // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
+
+  NppStreamContext nppCtx{};
+  cudaDeviceProp prop{};
+  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaGetDeviceProperties failed: ",
+      cudaGetErrorString(err));
+
+  nppCtx.nCudaDeviceId = deviceIndex;
+  nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
+  nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
+  nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
+  nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
+  nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
+  nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
+
+  // TODO when implementing the cache logic, move these out. See other TODO
+  // below.
+  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
+  err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));
+
+  return nppCtx;
+}
+
 } // namespace
 
 CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
@@ -265,37 +303,37 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     dst = allocateEmptyHWCTensor(height, width, device_);
   }
 
-  // Use the user-requested GPU for running the NPP kernel.
-  c10::cuda::CUDAGuard deviceGuard(device_);
+  // TODO cache the NppStreamContext! It currently gets re-created for every
+  // single frame. The cache should be per-device, similar to the existing
+  // hw_device_ctx cache. When implementing the cache logic, the
+  // NppStreamContext hStream and nStreamFlags should not be part of the cache
+  // because they may change across calls.
+  NppStreamContext nppCtx = createNppStreamContext(
+      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
 
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
 
   NppStatus status;
+
   if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
-    status = nppiNV12ToRGB_709CSC_8u_P2C3R(
+    status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
         input,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
-        oSizeROI);
+        oSizeROI,
+        nppCtx);
   } else {
-    status = nppiNV12ToRGB_8u_P2C3R(
+    status = nppiNV12ToRGB_8u_P2C3R_Ctx(
         input,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
-        oSizeROI);
+        oSizeROI,
+        nppCtx);
   }
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
-
-  // Make the pytorch stream wait for the npp kernel to finish before using the
-  // output.
-  at::cuda::CUDAEvent nppDoneEvent;
-  at::cuda::CUDAStream nppStreamWrapper =
-      c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
-  nppDoneEvent.record(nppStreamWrapper);
-  nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
 }
 
 // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
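Note: with the `_Ctx()` variants, the conversion kernel is presumably launched on the stream supplied in `nppCtx.hStream` (here, the current PyTorch stream) rather than on NPP's internal global stream, which would be why the event-based handoff from `nppGetStream()` at the end of the hunk can be dropped.

A minimal sketch of the per-device cache that the two TODOs describe might look as follows. This is an illustration only, not part of the commit: the helper name `getCachedNppStreamContext` and the map-plus-mutex storage are hypothetical, it assumes the surrounding file's headers (npp.h, ATen CUDA), and it assumes `createNppStreamContext()` would be reduced to filling only the device-property fields, per the TODO about moving the stream fields out.

    #include <map>
    #include <mutex>

    // Hypothetical per-device cache. Device-property fields are computed once
    // per device; stream-dependent fields are refreshed on every call because
    // the current stream (and its flags) may change across calls.
    NppStreamContext getCachedNppStreamContext(const torch::Device& device) {
      static std::mutex cacheMutex;
      static std::map<int, NppStreamContext> cache;

      int deviceIndex = static_cast<int>(getFFMPEGCompatibleDeviceIndex(device));

      std::lock_guard<std::mutex> lock(cacheMutex);
      auto it = cache.find(deviceIndex);
      if (it == cache.end()) {
        it = cache.emplace(deviceIndex, createNppStreamContext(deviceIndex)).first;
      }

      // Not cached: hStream and nStreamFlags, per the TODO.
      NppStreamContext nppCtx = it->second;
      nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
      cudaError_t err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
      TORCH_CHECK(
          err == cudaSuccess,
          "cudaStreamGetFlags failed: ",
          cudaGetErrorString(err));
      return nppCtx;
    }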