@@ -102,7 +102,7 @@ def test_cuda_runtime_errors_captured() -> None:
102
102
if (cuda_exception_missed ):
103
103
raise RuntimeError ( f"Expected CUDA RuntimeError but have not received!" )
104
104
105
- def smoke_test_cuda (package : str ) -> None :
105
+ def smoke_test_cuda (package : str , runtime_error_check : str ) -> None :
106
106
if not torch .cuda .is_available () and is_cuda_system :
107
107
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
108
108
@@ -132,7 +132,8 @@ def smoke_test_cuda(package: str) -> None:
132
132
if (sys .platform == "linux" or sys .platform == "linux2" ) and sys .version_info < (3 , 11 , 0 ):
133
133
smoke_test_compile ()
134
134
135
- test_cuda_runtime_errors_captured ()
135
+ if (runtime_error_check == "enabled" )
136
+ test_cuda_runtime_errors_captured ()
136
137
137
138
138
139
def smoke_test_conv2d () -> None :
@@ -231,6 +232,13 @@ def main() -> None:
231
232
choices = ["all" , "torchonly" ],
232
233
default = "all" ,
233
234
)
235
+ parser .add_argument (
236
+ "--runtime-error-check" ,
237
+ help = "No Runtime Error check" ,
238
+ type = str ,
239
+ choices = ["enabled" , "disabled" ],
240
+ default = "enabled" ,
241
+ )
234
242
options = parser .parse_args ()
235
243
print (f"torch: { torch .__version__ } " )
236
244
check_version (options .package )
@@ -240,7 +248,7 @@ def main() -> None:
240
248
if options .package == "all" :
241
249
smoke_test_modules ()
242
250
243
- smoke_test_cuda (options .package )
251
+ smoke_test_cuda (options .package , options . runtime_error_check )
244
252
245
253
246
254
if __name__ == "__main__" :
0 commit comments