
Commit 2c04e58

Multi Vector Textual Inversion (#3144)
* Multi Vector
* Improve
* fix multi token
* improve test
* make style
* Update examples/test_examples.py
* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* update
* Finish
* Apply suggestions from code review

---------

Co-authored-by: Suraj Patil <[email protected]>
1 parent 391cfcd commit 2c04e58

File tree: 5 files changed, +107 -13 lines

- docs/source/en/training/text_inversion.mdx
- examples/research_projects/mulit_token_textual_inversion/README.md
- examples/test_examples.py
- examples/textual_inversion/README.md
- examples/textual_inversion/textual_inversion.py

docs/source/en/training/text_inversion.mdx

Lines changed: 12 additions & 0 deletions

````diff
@@ -122,6 +122,18 @@ accelerate launch textual_inversion.py \
   --lr_warmup_steps=0 \
   --output_dir="textual_inversion_cat"
 ```
+
+<Tip>
+
+💡 If you want to increase the trainable capacity, you can associate your placeholder token, *e.g.* `<cat-toy>`, with
+multiple embedding vectors. This can help the model better capture the style of more complex images.
+To enable training multiple embedding vectors, simply pass:
+
+```bash
+--num_vectors=5
+```
+
+</Tip>
 </pt>
 <jax>
 If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️
````
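To make the new Tip concrete: below is a minimal sketch of what `--num_vectors=5` does under the hood, mirroring the tokenizer logic this commit adds to `textual_inversion.py` further down. The model id and placeholder token are illustrative assumptions, not part of this commit.

```python
from transformers import CLIPTokenizer

# Illustrative tokenizer; Stable Diffusion v1 uses this CLIP text encoder's tokenizer.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

placeholder_token = "<cat-toy>"  # assumed example token
num_vectors = 5

# The script expands one placeholder into "<cat-toy>", "<cat-toy>_1", ..., "<cat-toy>_4",
# each of which gets its own trainable embedding vector.
placeholder_tokens = [placeholder_token] + [f"{placeholder_token}_{i}" for i in range(1, num_vectors)]

num_added = tokenizer.add_tokens(placeholder_tokens)
assert num_added == num_vectors  # every token must be new to the vocabulary

# New tokens are appended at the end of the vocabulary, so their ids are contiguous.
print(tokenizer.convert_tokens_to_ids(placeholder_tokens))
```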

examples/research_projects/mulit_token_textual_inversion/README.md

Lines changed: 4 additions & 1 deletion

```diff
@@ -1,4 +1,7 @@
-## Multi Token Textual Inversion
+## [Deprecated] Multi Token Textual Inversion
+
+**IMPORTANT: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the official textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
+
 The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issues and PRs as well as @patrickvonplaten.
 
 We add multi token support to textual inversion. I added
```

examples/test_examples.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -105,6 +105,10 @@ def test_textual_inversion(self):
                 --learnable_property object
                 --placeholder_token <cat-toy>
                 --initializer_token a
+                --validation_prompt <cat-toy>
+                --validation_steps 1
+                --save_steps 1
+                --num_vectors 2
                 --resolution 64
                 --train_batch_size 1
                 --gradient_accumulation_steps 1
```

examples/textual_inversion/README.md

Lines changed: 12 additions & 1 deletion

````diff
@@ -36,7 +36,6 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
 accelerate config
 ```
 
-
 ### Cat toy example
 
 First, let's login so that we can upload the checkpoint to the Hub during training:
@@ -83,6 +82,18 @@ accelerate launch textual_inversion.py \
 
 A full training run takes ~1 hour on one V100 GPU.
 
+**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618),
+only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
+However, one can also add multiple embedding vectors for the placeholder token
+to increase the number of fine-tunable parameters. This can help the model to learn
+more complex details. To use multiple embedding vectors, set `--num_vectors`
+to a number larger than one, *e.g.*:
+```
+--num_vectors 5
+```
+
+The saved textual inversion vectors will then be larger in size compared to the default case.
+
 ### Inference
 
 Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
````
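For context, here is a hedged sketch of the inference side for a multi-vector embedding file, based on the `save_progress` format in the script below. The model id, file path, and prompt are assumptions, not part of this commit.

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumed base model and output path for illustration.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# save_progress writes {placeholder_token: tensor of shape [num_vectors, hidden_dim]}.
learned_embeds = torch.load("textual_inversion_cat/learned_embeds.bin")
placeholder_token, embeds = next(iter(learned_embeds.items()))

# Recreate the dummy tokens the training script used: "<cat-toy>", "<cat-toy>_1", ...
tokens = [placeholder_token] + [f"{placeholder_token}_{i}" for i in range(1, embeds.shape[0])]
pipe.tokenizer.add_tokens(tokens)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))

token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
with torch.no_grad():
    for token_id, embed in zip(token_ids, embeds):
        pipe.text_encoder.get_input_embeddings().weight[token_id] = embed

# To use every learned vector, include all dummy tokens in the prompt.
prompt = f"A backpack in the style of {' '.join(tokens)}"
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```

Recent diffusers releases also expose `pipe.load_textual_inversion`, which can load such embedding files directly.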

examples/textual_inversion/textual_inversion.py

Lines changed: 75 additions & 11 deletions

```diff
@@ -82,6 +82,34 @@
 logger = get_logger(__name__)
 
 
+def save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- textual_inversion
+inference: true
+---
+"""
+    model_card = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaptation weights for {base_model}. You can find some example images in the following. \n
+{img_str}
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
 def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
```
```diff
@@ -94,6 +122,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
         tokenizer=tokenizer,
         unet=unet,
         vae=vae,
+        safety_checker=None,
         revision=args.revision,
         torch_dtype=weight_dtype,
     )
@@ -124,11 +153,16 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
 
     del pipeline
     torch.cuda.empty_cache()
+    return images
 
 
-def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
     logger.info("Saving embeddings")
-    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+    learned_embeds = (
+        accelerator.unwrap_model(text_encoder)
+        .get_input_embeddings()
+        .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+    )
     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
     torch.save(learned_embeds_dict, save_path)
 
```
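The `min(...) : max(...) + 1` slice relies on the placeholder ids being contiguous, which holds because all tokens are appended to the end of the vocabulary in a single `add_tokens` call (see the hunk further down). A small sketch of the resulting file contents; the token name and hidden size are illustrative (768 is the CLIP text-encoder width in Stable Diffusion v1):

```python
import torch

# Inspect what save_progress wrote for, e.g., --num_vectors 5.
learned_embeds_dict = torch.load("learned_embeds.bin")
embeds = learned_embeds_dict["<cat-toy>"]  # assumed placeholder token
print(embeds.shape)  # torch.Size([5, 768]) -- one row per learned vector
```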
```diff
@@ -144,9 +178,15 @@ def parse_args():
     parser.add_argument(
         "--only_save_embeds",
         action="store_true",
-        default=False,
+        default=True,
         help="Save only the embeddings for the new concept.",
     )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        default=1,
+        help="How many textual inversion vectors shall be used to learn the concept.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -581,8 +621,19 @@ def main():
     )
 
     # Add the placeholder token in tokenizer
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
+    placeholder_tokens = [args.placeholder_token]
+
+    if args.num_vectors < 1:
+        raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+    # add dummy tokens for multi-vector
+    additional_tokens = []
+    for i in range(1, args.num_vectors):
+        additional_tokens.append(f"{args.placeholder_token}_{i}")
+    placeholder_tokens += additional_tokens
+
+    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+    if num_added_tokens != args.num_vectors:
         raise ValueError(
             f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
             " `placeholder_token` that is not already in the tokenizer."
```
```diff
@@ -595,14 +646,16 @@
         raise ValueError("The initializer token must be a single token.")
 
     initializer_token_id = token_ids[0]
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+    with torch.no_grad():
+        for token_id in placeholder_token_ids:
+            token_embeds[token_id] = token_embeds[initializer_token_id].clone()
 
     # Freeze vae and unet
     vae.requires_grad_(False)
@@ -810,19 +863,22 @@
                 optimizer.zero_grad()
 
                 # Let's make sure we don't update any embedding weights besides the newly added token
-                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
+                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+
                 with torch.no_grad():
                     accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                         index_no_updates
                     ] = orig_embeds_params[index_no_updates]
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                images = []
                 progress_bar.update(1)
                 global_step += 1
                 if global_step % args.save_steps == 0:
                     save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
-                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+                    save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
 
                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
```
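The boolean mask above implements "train everything, then restore the frozen rows": the optimizer touches the full embedding matrix, and every row outside the placeholder slice is then copied back from `orig_embeds_params`, so only the new vectors effectively train. A self-contained toy sketch of the same trick; sizes and ids are made up:

```python
import torch

vocab_size, dim = 10, 4
embedding = torch.nn.Embedding(vocab_size, dim)
orig_weights = embedding.weight.detach().clone()

placeholder_ids = [8, 9]  # pretend the new tokens sit at the end of the vocab
index_no_updates = torch.ones((vocab_size,), dtype=torch.bool)
index_no_updates[min(placeholder_ids) : max(placeholder_ids) + 1] = False

optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)
loss = embedding(torch.arange(vocab_size)).sum()  # dummy loss touching every row
loss.backward()
optimizer.step()  # this updates *all* rows

# Undo the update for every row except the placeholder slice.
with torch.no_grad():
    embedding.weight[index_no_updates] = orig_weights[index_no_updates]

assert torch.equal(embedding.weight[:8], orig_weights[:8])      # frozen rows restored
assert not torch.equal(embedding.weight[8:], orig_weights[8:])  # new vectors trained
```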
```diff
@@ -831,7 +887,9 @@
                         logger.info(f"Saved state to {save_path}")
 
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
-                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
+                        images = log_validation(
+                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+                        )
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
```
```diff
@@ -858,9 +916,15 @@
         pipeline.save_pretrained(args.output_dir)
         # Save the newly trained embeddings
         save_path = os.path.join(args.output_dir, "learned_embeds.bin")
-        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+        save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
 
         if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                repo_folder=args.output_dir,
+            )
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
```
