Description
Is your feature request related to a problem? Please describe.
In SimpleTuner, I'm currently extending ControlNet training to all implemented model architectures.
So far, Control LoRA seems to be the preferred way to handle these in Diffusers, but in my limited experimentation across architectures, ControlNet LoRA is easier to train and converges in fewer steps.
The problem, however, is that `LoraLoaderMixin` and the related variants have no `controlnet_lora_layers` input for saving, and no handling for stripping the `controlnet.` prefix at load time.
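To make the gap concrete, here is roughly the save call I would like to be able to write. This is only a sketch: `FluxLoraLoaderMixin` stands in for any of the existing loader mixins, the state dicts are dummies with made-up key names, and `controlnet_lora_layers` is exactly the argument that does not exist today.

```python
import torch

from diffusers.loaders import FluxLoraLoaderMixin

# Dummy LoRA state dicts standing in for the real adapter weights
# (in practice these would come from the PEFT-wrapped training models).
transformer_lora_layers = {"transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 64)}
controlnet_lora_layers = {"controlnet_blocks.0.lora_A.weight": torch.zeros(16, 64)}

FluxLoraLoaderMixin.save_lora_weights(
    save_directory="my-controlnet-lora",
    transformer_lora_layers=transformer_lora_layers,
    controlnet_lora_layers=controlnet_lora_layers,  # raises TypeError today: no such argument
)
```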
Describe the solution you'd like.
I would like the default implementation for new LoRA loader mixins to include ControlNet layer support.
Additionally, it would be excellent if the current LoRA loaders were all extended to allow saving ControlNet layers.
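In terms of checkpoint layout, I would expect the combined state dict to simply namespace the ControlNet adapter weights under a `controlnet.` prefix, mirroring the existing `transformer.` prefix. The key names below are illustrative only, not taken from a real checkpoint.

```python
# Illustrative layout of a combined LoRA checkpoint (key names are made up).
state_dict = {
    # existing convention for the base model
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": ...,
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": ...,
    # proposed convention for the ControlNet adapter
    "controlnet.controlnet_blocks.0.lora_A.weight": ...,
    "controlnet.controlnet_blocks.0.lora_B.weight": ...,
}
```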
Describe alternatives you've considered.
Currently, we're just forking the LoRA loader mixins and adjusting them to include ControlNet model support.
Alternatively, I suppose we could have a generic ControlNet LoRA loader mixin that provides this functionality for all models, since it hardly changes between them (see the sketch below).
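The piece that is essentially identical across architectures is the prefix handling. A rough sketch of what that shared helper could look like (hypothetical, not an existing Diffusers API):

```python
from typing import Dict

import torch


def split_lora_state_dict(
    state_dict: Dict[str, torch.Tensor],
    prefixes=("transformer", "controlnet"),
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Split a flat LoRA state dict into per-component dicts, stripping the
    component prefix so each sub-dict can be handed to that component's loader.

    Hypothetical helper, not an existing Diffusers API.
    """
    split = {prefix: {} for prefix in prefixes}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(f"{prefix}."):
                split[prefix][key[len(prefix) + 1 :]] = value
                break
    return split
```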
Additional context.
```python
import os
from typing import Callable, Dict, Union

import torch

# NOTE: these imports point at diffusers / huggingface_hub internals; the exact
# locations may differ between diffusers versions.
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.loaders.lora_base import LoraBaseMixin, _fetch_state_dict
from diffusers.utils import USE_PEFT_BACKEND, logging

logger = logging.get_logger(__name__)


class PixArtSigmaControlNetLoraLoaderMixin(LoraBaseMixin):
    """
    Load LoRA layers into PixArt Sigma ControlNet models.
    """

    _lora_loadable_modules = ["transformer", "controlnet"]
    transformer_name = "transformer"
    controlnet_name = "controlnet"
    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        controlnet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        """Save LoRA weights for both the transformer and the controlnet."""
        state_dict = {}

        # Pack transformer weights (only the non-replaced blocks).
        if transformer_lora_layers:
            transformer_state = cls.pack_weights(
                transformer_lora_layers, cls.transformer_name
            )
            state_dict.update(transformer_state)

        # Pack controlnet weights (assumed to already carry the "controlnet." prefix).
        if controlnet_lora_layers:
            state_dict.update(controlnet_lora_layers)

        # Save the combined checkpoint.
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name=None,
        **kwargs,
    ):
        """Load LoRA weights into the transformer and the controlnet."""
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        state_dict = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict, **kwargs
        )

        # Separate transformer and controlnet weights.
        transformer_state_dict = {}
        controlnet_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("controlnet."):
                # Strip the "controlnet." prefix before loading into the controlnet.
                new_key = key[len("controlnet.") :]
                controlnet_state_dict[new_key] = value
            elif key.startswith("transformer."):
                # Strip the "transformer." prefix.
                new_key = key[len("transformer.") :]
                transformer_state_dict[new_key] = value
            else:
                # Route unprefixed keys based on their content.
                if "controlnet" in key:
                    controlnet_state_dict[key] = value
                else:
                    transformer_state_dict[key] = value

        # Load into the transformer if there are transformer weights.
        if transformer_state_dict:
            self.load_lora_into_transformer(
                transformer_state_dict,
                transformer=self.transformer.transformer,  # access the base transformer through the wrapper
                adapter_name=adapter_name,
                _pipeline=self,
            )

        # Load into the controlnet if there are controlnet weights.
        if controlnet_state_dict:
            self.load_lora_into_controlnet(
                controlnet_state_dict,
                controlnet=self.transformer.controlnet,  # access the controlnet through the wrapper
                adapter_name=adapter_name,
                _pipeline=self,
            )
    @classmethod
    def load_lora_into_controlnet(
        cls,
        state_dict,
        controlnet,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
    ):
        """Load LoRA layers into the controlnet adapter."""
        logger.info("Loading controlnet LoRA layers.")

        # The controlnet is expected to expose a `load_lora_adapter` method, like the transformer.
        if hasattr(controlnet, "load_lora_adapter"):
            controlnet.load_lora_adapter(
                state_dict,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            # Fallback: manually inject LoRA modules via PEFT and load the weights.
            logger.warning(
                "Falling back to manual PEFT injection for the controlnet; this should not normally happen."
            )
            from peft import LoraConfig, inject_adapter_in_model

            # NOTE: rank, alpha, and target modules are hard-coded here; ideally they
            # would be inferred from the state dict.
            lora_config = LoraConfig(
                r=16,
                lora_alpha=16,
                target_modules=[
                    "to_k",
                    "to_q",
                    "to_v",
                    "to_out.0",
                    "before_proj",
                    "after_proj",
                ],
            )
            inject_adapter_in_model(lora_config, controlnet, adapter_name=adapter_name)

            # Load the LoRA weights; strict=False because the state dict only
            # contains the adapter parameters, not the base weights.
            incompatible_keys = controlnet.load_state_dict(state_dict, strict=False)
            if incompatible_keys.unexpected_keys:
                logger.warning(
                    f"Unexpected keys while loading controlnet LoRA weights: {incompatible_keys.unexpected_keys}"
                )
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        """
        Return the state dict for the LoRA weights.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                - A string, the *model id* of a pretrained model hosted on the Hub.
                - A path to a *directory* containing the model weights.
                - A torch state dict.
        """
        # Pop the download-related kwargs before fetching the state dict.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )
        return state_dict
```
This isn't production-ready, but it demonstrates some of the changes needed for ControlNet model support.
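For context, the way I would expect this to be wired up on the pipeline side is roughly the following; the pipeline class here is a placeholder, only the mixin above comes from our fork.

```python
from diffusers import DiffusionPipeline


# Placeholder pipeline; a real one would define its own components in __init__.
class MyPixArtSigmaControlNetPipeline(DiffusionPipeline, PixArtSigmaControlNetLoraLoaderMixin):
    pass


# A single call then restores both the transformer and controlnet LoRA layers:
# pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors")
```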