@@ -48,6 +48,13 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -386,15 +393,17 @@ def collate_fn(examples):
         weight_dtype = jnp.bfloat16

     # Load models and create wrapper for stable diffusion
-    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+    tokenizer = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
+    )
     text_encoder = FlaxCLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
+        args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
     )
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
+        args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
     )
     unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
+        args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
     )

     # Optimization