@@ -30,11 +30,8 @@ class OperatingSystem(Enum):
30
30
ENABLE = "enable"
31
31
DISABLE = "disable"
32
32
33
- # Mapping json to release matrix is here for now
34
- # TBD drive the mapping via:
35
- # 1. Scanning release matrix and picking 2 latest cuda versions and 1 latest rocm
36
- # 2. Possibility to override the scanning algorithm with arguments passed from workflow
37
- acc_arch_ver_map = {
33
+ # Mapping json to release matrix default values
34
+ acc_arch_ver_default = {
38
35
"nightly" : {
39
36
"accnone" : ("cpu" , "" ),
40
37
"cuda.x" : ("cuda" , "11.6" ),
@@ -49,6 +46,11 @@ class OperatingSystem(Enum):
49
46
}
50
47
}
51
48
49
+ # Initialize arch version to default values
50
+ # these default values will be overwritten by
51
+ # extracted values from the release marix
52
+ acc_arch_ver_map = acc_arch_ver_default
53
+
52
54
LIBTORCH_DWNL_INSTR = {
53
55
PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
54
56
CXX11_ABI : "Download here (cxx11 ABI):" ,
@@ -163,6 +165,26 @@ def gen_install_matrix(versions) -> Dict[str, str]:
163
165
result [key ] = "<br />" .join (lines )
164
166
return result
165
167
168
+ # This method is used for extracting two latest verisons of cuda and
169
+ # last verion of rocm. It will modify the acc_arch_ver_map object used
170
+ # to update getting started page.
171
+ def extract_arch_ver_map (release_matrix ):
172
+ def gen_ver_list (chan , gpu_arch_type ):
173
+ return {
174
+ x ["desired_cuda" ]: x ["gpu_arch_version" ]
175
+ for x in release_matrix [chan ]["linux" ]
176
+ if x ["gpu_arch_type" ] == gpu_arch_type
177
+ }
178
+
179
+ for chan in ("nightly" , "release" ):
180
+ cuda_ver_list = gen_ver_list (chan , "cuda" )
181
+ rocm_ver_list = gen_ver_list (chan , "rocm" )
182
+ cuda_list = sorted (cuda_ver_list .values ())[- 2 :]
183
+ acc_arch_ver_map [chan ]["rocm5.x" ] = ("rocm" , max (rocm_ver_list .values ()))
184
+ for cuda_ver , label in zip (cuda_list , ["cuda.x" , "cuda.y" ]):
185
+ acc_arch_ver_map [chan ][label ] = ("cuda" , cuda_ver )
186
+
187
+
166
188
def main ():
167
189
parser = argparse .ArgumentParser ()
168
190
parser .add_argument ('--autogenerate' , dest = 'autogenerate' , action = 'store_true' )
@@ -178,6 +200,7 @@ def main():
178
200
for osys in OperatingSystem :
179
201
release_matrix [val ][osys .value ] = read_matrix_for_os (osys , val )
180
202
203
+ extract_arch_ver_map (release_matrix )
181
204
for val in ("nightly" , "release" ):
182
205
update_versions (versions , release_matrix [val ], val )
183
206
0 commit comments