@@ -42,6 +42,18 @@ def check_nightly_binaries_date(package: str) -> None:
42
42
f"Expected torchaudio, torchvision to be less then { NIGHTLY_ALLOWED_DELTA } days. But they are from { date_ta_str } , { date_tv_str } respectively"
43
43
)
44
44
45
+ def check_cuda_version (version : str , dlibary : str ):
46
+ version = torch .ops .torchaudio .cuda_version ()
47
+ if version is not None and torch .version .cuda is not None :
48
+ version_str = str (version )
49
+ ta_version = f"{ version_str [:- 3 ]} .{ version_str [- 2 ]} "
50
+ t_version = torch .version .cuda .split ("." )
51
+ t_version = f"{ t_version [0 ]} .{ t_version [1 ]} "
52
+ if ta_version != t_version :
53
+ raise RuntimeError (
54
+ "Detected that PyTorch and {dlibary} were compiled with different CUDA versions. "
55
+ f"PyTorch has CUDA version { t_version } whereas { dlibary } has CUDA version { ta_version } . "
56
+ )
45
57
46
58
def smoke_test_cuda (package : str ) -> None :
47
59
if not torch .cuda .is_available () and is_cuda_system :
@@ -56,20 +68,14 @@ def smoke_test_cuda(package: str) -> None:
56
68
print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
57
69
print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
58
70
59
- if (package == 'all' ):
71
+ if (package == 'all' and is_cuda_system ):
60
72
import torchaudio
61
73
import torchvision
62
-
63
74
print (f"torchvision cuda: { torch .ops .torchvision ._cuda_version ()} " )
64
75
print (f"torchaudio cuda: { torch .ops .torchaudio .cuda_version ()} " )
65
- if (
66
- gpu_arch_ver != torch .ops .torchvision ._cuda_version () or
67
- gpu_arch_ver != torch .ops .torchaudio .cuda_version ()
68
- ):
69
- raise RuntimeError (
70
- f"Wrong CUDA version. Vision: { torch .ops .torchvision ._cuda_version ()} \
71
- Audio: { ttorch .ops .torchaudio .cuda_version ()} Expected: { gpu_arch_ver } "
72
- )
76
+ check_cuda_version (torch .ops .torchvision ._cuda_version (), "TorchVision" )
77
+ check_cuda_version (torch .ops .torchaudio .cuda_version (), "TorchAudio" )
78
+
73
79
74
80
def smoke_test_conv2d () -> None :
75
81
import torch .nn as nn
0 commit comments