
ControlNet LoRA support more broadly across models #11733

@bghira

Description


Is your feature request related to a problem? Please describe.

In SimpleTuner, I'm currently extending ControlNet training to all implemented model architectures.

So far, Control LoRA seems to be the preferred way to handle these in Diffusers, but in my limited experimentation across architectures, ControlNet LoRA has been easier to train and converges in fewer steps.

The problem, however, is that LoraLoaderMixin and the related per-model variants have no controlnet_lora_layers input, and no handling for stripping the controlnet prefix at load time.
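For concreteness, the missing piece is essentially a prefix split at load time. The key names below are only illustrative of the convention, not taken from a real checkpoint:

# Illustrative only: a combined LoRA checkpoint carries both prefixes.
state_dict = {
    "transformer.transformer_blocks.0.attn1.to_q.lora_A.weight": "...",
    "controlnet.controlnet_blocks.0.attn1.to_q.lora_A.weight": "...",
}

# Existing loaders know how to strip "transformer.", but have nowhere to route the rest.
controlnet_sd = {k[len("controlnet."):]: v for k, v in state_dict.items() if k.startswith("controlnet.")}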

Describe the solution you'd like.

I would like the default implementation of new LoRA loader mixins to include ControlNet layer support.

Additionally, it would be excellent if the current LoRA loaders were all extended to allow saving ControlNet layers.

Describe alternatives you've considered.

Currently, we're just forking the LoRA loader mixins and adjusting them to include ControlNet model support.

Alternatively, I suppose we could have a ControlNet LoRA loader mixin that provides this functionality more generally, since it hardly changes between models; a rough sketch of that idea follows.
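A minimal sketch of what such a shared mixin could look like. The class name is hypothetical, it assumes the same imports as the example under "Additional context" below, and lora_state_dict / load_lora_into_transformer / load_lora_into_controlnet are assumed to exist per model as in that example:

class ControlNetLoraLoaderMixin(LoraBaseMixin):
    """Hypothetical shared mixin: per model, only the module/prefix names really change."""

    _lora_loadable_modules = ["transformer", "controlnet"]
    transformer_name = "transformer"
    controlnet_name = "controlnet"

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name=None, **kwargs):
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        # Route keys by prefix; anything without the controlnet prefix goes to the transformer.
        controlnet_prefix = f"{self.controlnet_name}."
        transformer_prefix = f"{self.transformer_name}."
        controlnet_sd = {
            k[len(controlnet_prefix):]: v for k, v in state_dict.items() if k.startswith(controlnet_prefix)
        }
        transformer_sd = {
            (k[len(transformer_prefix):] if k.startswith(transformer_prefix) else k): v
            for k, v in state_dict.items()
            if not k.startswith(controlnet_prefix)
        }

        if transformer_sd:
            self.load_lora_into_transformer(
                transformer_sd,
                transformer=getattr(self, self.transformer_name),
                adapter_name=adapter_name,
                _pipeline=self,
            )
        if controlnet_sd:
            self.load_lora_into_controlnet(
                controlnet_sd,
                controlnet=getattr(self, self.controlnet_name),
                adapter_name=adapter_name,
                _pipeline=self,
            )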

Additional context.

import os
from typing import Callable, Dict, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

# NOTE: _fetch_state_dict is a private diffusers helper; its import location may vary across versions.
from diffusers.loaders.lora_base import LoraBaseMixin, _fetch_state_dict
from diffusers.utils import USE_PEFT_BACKEND, logging

logger = logging.get_logger(__name__)


class PixArtSigmaControlNetLoraLoaderMixin(LoraBaseMixin):
    """
    Load LoRA layers into PixArt Sigma ControlNet models.
    """

    _lora_loadable_modules = ["transformer", "controlnet"]
    transformer_name = "transformer"
    controlnet_name = "controlnet"

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        controlnet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        """Save LoRA weights for both transformer and controlnet."""
        state_dict = {}

        # Pack transformer weights (only the non-replaced blocks)
        if transformer_lora_layers:
            transformer_state = cls.pack_weights(
                transformer_lora_layers, cls.transformer_name
            )
            state_dict.update(transformer_state)

        # Pack controlnet weights
        if controlnet_lora_layers:
            state_dict.update(controlnet_lora_layers)  # they're already packed

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name=None,
        **kwargs,
    ):
        """Load LoRA weights into transformer and controlnet."""
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        state_dict = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict, **kwargs
        )

        # Separate transformer and controlnet weights
        transformer_state_dict = {}
        controlnet_state_dict = {}

        for key, value in state_dict.items():
            if key.startswith("controlnet."):
                # Remove the "controlnet." prefix for loading into controlnet
                new_key = key[len("controlnet.") :]
                controlnet_state_dict[new_key] = value
            elif key.startswith("transformer."):
                # Remove the "transformer." prefix
                new_key = key[len("transformer.") :]
                transformer_state_dict[new_key] = value
            else:
                # Handle unprefixed keys based on content
                if "controlnet" in key:
                    controlnet_state_dict[key] = value
                else:
                    transformer_state_dict[key] = value

        # Load into transformer if there are transformer weights
        if transformer_state_dict:
            self.load_lora_into_transformer(
                transformer_state_dict,
                transformer=self.transformer.transformer,  # Access the base transformer
                adapter_name=adapter_name,
                _pipeline=self,
            )

        # Load into controlnet if there are controlnet weights
        if controlnet_state_dict:
            self.load_lora_into_controlnet(
                controlnet_state_dict,
                controlnet=self.transformer.controlnet,  # Access the controlnet through wrapper
                adapter_name=adapter_name,
                _pipeline=self,
            )

    @classmethod
    def load_lora_into_controlnet(
        cls,
        state_dict,
        controlnet,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
    ):
        """Load LoRA layers into the controlnet adapter."""
        logger.info("Loading controlnet LoRA layers.")

        # The controlnet should have a load_lora_adapter method similar to transformer
        if hasattr(controlnet, "load_lora_adapter"):
            out = controlnet.load_lora_adapter(
                state_dict,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
            print(f"output of loading: {out}")
        else:
            # Fallback: manually inject LoRA weights
            logger.warning(
                "Falling back to manual PEFT injection for loading; this path should not normally be needed."
            )
            from peft import inject_adapter_in_model, LoraConfig

            # Infer LoRA config from state dict
            lora_config = LoraConfig(
                r=16,  # You might want to infer this from the state dict
                lora_alpha=16,
                target_modules=[
                    "to_k",
                    "to_q",
                    "to_v",
                    "to_out.0",
                    "before_proj",
                    "after_proj",
                ],
            )

            inject_adapter_in_model(lora_config, controlnet, adapter_name=adapter_name)

            # Load the weights; strict=False because the state dict only contains the LoRA parameters
            incompatible_keys = controlnet.load_state_dict(state_dict, strict=False)
            if incompatible_keys.unexpected_keys:
                logger.warning(
                    f"Unexpected keys while loading controlnet LoRA weights: {incompatible_keys.unexpected_keys}"
                )

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        """
        Return the state dict for the LoRA weights.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted on the Hub.
                    - A path to a *directory* containing the model weights.
                    - A torch state dict.
        """
        # Load the main state dict first
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        return state_dict

This isn't production-ready, but it demonstrates some of the changes needed for ControlNet model support.
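For completeness, this is roughly how I'd expect the mixin to be wired into a pipeline and used. The subclass, checkpoint id, directory names, and the *_lora_layers variables below are placeholders, and the pipeline's transformer is assumed to be the combined transformer+controlnet wrapper that load_lora_weights above accesses:

import torch
from diffusers import PixArtSigmaPipeline


class PixArtSigmaControlNetPipeline(PixArtSigmaPipeline, PixArtSigmaControlNetLoraLoaderMixin):
    # Placeholder subclass: in practice self.transformer would be the wrapper that
    # exposes .transformer and .controlnet, as assumed by load_lora_weights above.
    pass


# Saving after training: both sets of LoRA layers go into a single checkpoint.
PixArtSigmaControlNetLoraLoaderMixin.save_lora_weights(
    save_directory="controlnet-lora-out",
    transformer_lora_layers=transformer_lora_layers,  # placeholder: collected from the trained transformer
    controlnet_lora_layers=controlnet_lora_layers,    # placeholder: keys already carry the "controlnet." prefix
)

# Loading at inference time.
pipe = PixArtSigmaControlNetPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.load_lora_weights("controlnet-lora-out", weight_name="pytorch_lora_weights.safetensors")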
