Add support for Guess Mode to StableDiffusionControlNetPipeline #2998


Merged
12 commits merged on Apr 14, 2023
36 changes: 36 additions & 0 deletions docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
@@ -242,6 +242,42 @@ image.save("./multi_controlnet_output.png")

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/multi_controlnet_output.png" width=600/>

### Guess Mode

Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:

>In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.

#### The core implementation:

Guess Mode adjusts the scale of the output residuals from ControlNet by a fixed ratio that depends on the block depth: the shallowest DownBlock corresponds to `0.1`, the scale increases exponentially with depth, and the output of the MidBlock is scaled by `1.0`.

Because this is the entirety of the implementation, **it has no impact on prompt conditioning**. While it is common to use Guess Mode without specifying any prompts, you can still provide a prompt if desired.
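To make the schedule concrete, here is a minimal sketch of how such a scale list can be computed with `torch.logspace` (assuming 12 down-block residuals, as in the standard Stable Diffusion UNet; the actual implementation derives the count from the residuals it receives):

```py
import torch

# Minimal sketch of the Guess Mode scale schedule, not the pipeline API itself.
# Assumption: 12 down-block residuals plus 1 mid-block residual (standard SD UNet).
num_down_block_residuals = 12
scales = torch.logspace(-1, 0, num_down_block_residuals + 1)  # 13 values from 0.1 to 1.0

down_block_scales = scales[:-1]  # multiplied onto the DownBlock residuals, shallow -> deep
mid_block_scale = scales[-1]     # 1.0, multiplied onto the MidBlock residual

print(down_block_scales)
print(mid_block_scale)
```

The first value (`0.1`) is applied to the shallowest DownBlock residual and the last value (`1.0`) to the MidBlock residual, matching the description above.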

#### Usage:

Just pass `guess_mode=True` when calling the pipeline. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import torch

# Load a Canny edge control image (the bird example from the diffusers test assets)
canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# An empty prompt: Guess Mode is designed to work even without text conditioning
image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
image.save("guess_mode_generated.png")
```

#### Output image comparison:
Canny Control Example

|without guess_mode (with prompt)|with guess_mode (no prompt)|
|---|---|
|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"/></a>|



## Available checkpoints

ControlNet requires a *control image* in addition to the text-to-image *prompt*.
11 changes: 9 additions & 2 deletions src/diffusers/models/controlnet.py
@@ -456,6 +456,7 @@ def forward(
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
# check channel order
@@ -556,8 +557,14 @@ def forward(
mid_block_res_sample = self.controlnet_mid_block(sample)

# 6. scaling
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale
if guess_mode:
> **Contributor:** works for me!

    scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
    scales *= conditioning_scale
    down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
    mid_block_res_sample *= scales[-1]  # last one
else:
    down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
    mid_block_res_sample *= conditioning_scale

if not return_dict:
    return (down_block_res_samples, mid_block_res_sample)
**Contributor:**

This looks nice to me! Can we maybe add some additional code that makes sure that if `guess_mode` is `True` then `prompt` has to be `None`, so that users don't accidentally use both `guess_mode=True` and a prompt?

**takuma104 (Contributor Author), Apr 12, 2023:**

Guess Mode seems to have been designed to obtain good results even without prompts, but it can also be used with prompts. As mentioned in the first post, this is because of the following:

> Since the core implementation is just this, it does not have any impact on prompt conditioning. While it is common to use it without specifying any prompts, it is also possible to provide prompts if desired.

I provided an example of using prompts in Guess Mode; the OpenPose human image makes it easy to understand. Since it is up to the user how they want to use it, I think it is fine not to display any error messages or warnings. What do you think?
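For illustration, a minimal sketch of combining a prompt with `guess_mode=True` (the OpenPose checkpoint, the local pose-image path, and the prompt are assumptions; any ControlNet/control-image pair works the same way):

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import torch

# Assumed checkpoint and control image, for illustration only.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
).to("cuda")

pose_image = load_image("./pose.png")  # hypothetical local OpenPose control image

# guess_mode=True does not disable prompt conditioning, so a prompt can still be supplied.
image = pipe(
    "a ballet dancer on a stage", image=pose_image, guess_mode=True, guidance_scale=3.0
).images[0]
image.save("guess_mode_with_prompt.png")
```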

@@ -118,6 +118,7 @@ def forward(
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@
timestep_cond,
attention_mask,
cross_attention_kwargs,
guess_mode,
return_dict,
)

@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds):
)

def prepare_image(
self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance,
guess_mode,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
@@ -664,7 +675,7 @@ def prepare_image(

image = image.to(device=device, dtype=dtype)

if do_classifier_free_guidance:
if do_classifier_free_guidance and not guess_mode:
    image = torch.cat([image] * 2)

return image
@@ -747,6 +758,7 @@ def __call__(
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@ def __call__(
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.

Examples:

Returns:
@@ -883,6 +899,7 @@ def __call__(
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(self.controlnet, MultiControlNetModel):
images = []
@@ -897,6 +914,7 @@
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)

images.append(image_)
@@ -934,15 +952,31 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
**Collaborator:**

In terms of the ControlNet inputs, i.e. latents, prompt_embeds and image, is there any difference between these 2 configurations?

  1. not do_classifier_free_guidance
  2. do_classifier_free_guidance and guess_mode

**Contributor Author:**

1 and 2 both share the commonality that ControlNet is run only with the conditional part of the prompt. Specifically, in case 2, the latents and prompt_embeds arrive concatenated (unconditional and conditional) because of the processing above, so they have to be deliberately split and only the conditional half extracted. Hence this implementation.

As for the image, it is completely unrelated to the prompt passed to the UNet, so there is no need to split the processing under any condition.

    # Infer ControlNet only for the conditional batch.
    controlnet_latent_model_input = latents
    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
    controlnet_latent_model_input = latent_model_input
    controlnet_prompt_embeds = prompt_embeds

down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
controlnet_latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
guess_mode=guess_mode,
return_dict=False,
)

if guess_mode and do_classifier_free_guidance:
    # Inferred ControlNet only for the conditional batch.
    # To apply the output of ControlNet to both the unconditional and conditional batches,
    # add 0 to the unconditional batch to keep it unchanged.
    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
@@ -553,6 +553,38 @@ def test_sequential_cpu_offloading(self):
# make sure that less than 4 GB is allocated
assert mem_bytes < 4 * 10**9

def test_canny_guess_mode(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = ""
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

output = pipe(
prompt,
image,
generator=generator,
output_type="np",
num_inference_steps=3,
guidance_scale=3.0,
guess_mode=True,
)

image = output.images[0]
assert image.shape == (768, 512, 3)

image_slice = image[-3:, -3:, -1]
expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@slow
@require_torch_gpu