[LoRA training] update metadata use for lora alpha + README #11723

Merged · 6 commits · Jun 17, 2025
18 changes: 18 additions & 0 deletions examples/advanced_diffusion_training/README_flux.md
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.

### LoRA Rank and Alpha
Two key LoRA hyperparameters are the LoRA rank and the LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength (e.g., `lora_alpha=16`, `rank=16`).
  - `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model (e.g., `lora_alpha=8`, `rank=16`).
  - `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact, letting a lower-rank LoRA have a stronger effect (e.g., `lora_alpha=32`, `rank=16`).

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to twice the `rank` (e.g., `lora_alpha=32` for `rank=16`)
> to give the LoRA updates more influence without increasing the parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., `lora_alpha=8` for `rank=16`). Experimentation is often key to finding the optimal balance for your use case.
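
The effect of the ratio can be seen in a minimal sketch of the standard LoRA formulation, `W' = W + (lora_alpha / rank) * (B @ A)` (the shapes below are illustrative):

```python
import torch

rank, lora_alpha = 16, 32
W = torch.randn(64, 64)           # frozen base weight
A = torch.randn(rank, 64) * 0.01  # trainable down-projection
B = torch.zeros(64, rank)         # trainable up-projection (zero-initialized)

scaling = lora_alpha / rank       # 2.0 here, so the learned update is amplified
W_adapted = W + scaling * (B @ A)
```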


### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations to the prompts that describe them.
More recently, SOTA text-to-image diffusion models have replaced the UNet with a Diffusion Transformer (DiT). With this change, we may also want to explore
@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
@@ -281,3 +284,45 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult
            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
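
For reference, a standalone sketch of the same metadata check the test performs on a saved checkpoint (the file path below is illustrative):

```python
import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

# The checkpoint path is illustrative; point it at your own output directory.
with safe_open("output/pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

adapter_metadata = json.loads(metadata[LORA_ADAPTER_METADATA_KEY])
print(adapter_metadata["transformer.lora_alpha"], adapter_metadata["transformer.r"])
```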
@@ -55,6 +55,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _collate_lora_metadata,
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
help=("The dimension of the LoRA update matrices."),
)

parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)

parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

parser.add_argument(
@@ -1556,7 +1564,7 @@ def main(args):
    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        init_lora_weights="gaussian",
        target_modules=target_modules,
@@ -1565,7 +1573,7 @@
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1582,13 +1590,15 @@ def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_one_lora_layers_to_save = None

            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["transformer"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    if args.train_text_encoder:  # when --train_text_encoder_ti we don't save the layers
                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                        modules_to_save["text_encoder"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                    pass  # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
                else:
@@ -1601,6 +1611,7 @@ def save_model_hook(models, weights, output_dir):
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                **_collate_lora_metadata(modules_to_save),
            )
        if args.train_text_encoder_ti:
            embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2359,16 +2370,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        modules_to_save["transformer"] = transformer

        if args.train_text_encoder:
            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            modules_to_save["text_encoder"] = text_encoder_one
        else:
            text_encoder_lora_layers = None

@@ -2377,6 +2391,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=text_encoder_lora_layers,
            **_collate_lora_metadata(modules_to_save),
        )

        if args.train_text_encoder_ti:
17 changes: 17 additions & 0 deletions examples/dreambooth/README_flux.md
@@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \
--push_to_hub
```

### LoRA Rank and Alpha
Two key LoRA hyperparameters are the LoRA rank and the LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength (e.g., `lora_alpha=16`, `rank=16`).
  - `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model (e.g., `lora_alpha=8`, `rank=16`).
  - `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact, letting a lower-rank LoRA have a stronger effect (e.g., `lora_alpha=32`, `rank=16`).

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to twice the `rank` (e.g., `lora_alpha=32` for `rank=16`)
> to give the LoRA updates more influence without increasing the parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., `lora_alpha=8` for `rank=16`). Experimentation is often key to finding the optimal balance for your use case.
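
As a rough sketch of how these flags flow into the adapter configuration (mirroring the `LoraConfig` change in this PR; the `target_modules` listed here are only illustrative):

```python
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=16,                          # --rank
    lora_alpha=32,                 # --lora_alpha; effective scaling = 32 / 16 = 2.0
    lora_dropout=0.0,              # --lora_dropout
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative
)
```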

### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations to the prompts that describe them.
More recently, SOTA text-to-image diffusion models have replaced the UNet with a Diffusion Transformer (DiT). With this change, we may also want to explore