Skip to content

Commit d3ce6f4

Browse files
authored
Support revision in Flax text-to-image training (#2567)
Support revision in Flax text-to-image training.
1 parent ff91f15 commit d3ce6f4

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

examples/text_to_image/train_text_to_image_flax.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def parse_args():
4848
required=True,
4949
help="Path to pretrained model or model identifier from huggingface.co/models.",
5050
)
51+
parser.add_argument(
52+
"--revision",
53+
type=str,
54+
default=None,
55+
required=False,
56+
help="Revision of pretrained model identifier from huggingface.co/models.",
57+
)
5158
parser.add_argument(
5259
"--dataset_name",
5360
type=str,
@@ -386,15 +393,17 @@ def collate_fn(examples):
386393
weight_dtype = jnp.bfloat16
387394

388395
# Load models and create wrapper for stable diffusion
389-
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
396+
tokenizer = CLIPTokenizer.from_pretrained(
397+
args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
398+
)
390399
text_encoder = FlaxCLIPTextModel.from_pretrained(
391-
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
400+
args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
392401
)
393402
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
394-
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
403+
args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
395404
)
396405
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
397-
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
406+
args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
398407
)
399408

400409
# Optimization

0 commit comments

Comments
 (0)