diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
index 5a4cfa41ca43..af859177c002 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
+++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
@@ -242,6 +242,42 @@ image.save("./multi_controlnet_output.png")
+### Guess Mode
+
+Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:
+
+>In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.
+
+#### The core implementation:
+
+Guess Mode adjusts the scale of the output residuals from ControlNet depending on the block depth: the shallowest DownBlock is scaled by `0.1`, the scale grows exponentially with depth, and the output of the MidBlock is scaled by `1.0`.
+
+Since the core implementation is only this scaling, **it has no impact on prompt conditioning**. Guess Mode is typically used without a prompt, but you can still provide one if desired.
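+
+To make this concrete, here is a minimal sketch of how such scales can be computed (assuming the 12 down-block residuals of the standard Stable Diffusion ControlNet):
+
+```py
+import torch
+
+# 12 down-block residuals plus the mid-block residual -> 13 scales from 0.1 to 1.0
+num_down_block_res_samples = 12  # assumption: standard SD v1.5 ControlNet
+scales = torch.logspace(-1, 0, num_down_block_res_samples + 1)
+print(scales[0], scales[-1])  # tensor(0.1000) tensor(1.)
+```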
+
+#### Usage:
+
+Simply pass `guess_mode=True` when calling the pipeline. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+# load an example canny edge control image
+canny_image = load_image(
+    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to(
+    "cuda"
+)
+image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
+image.save("guess_mode_generated.png")
+```
+
+#### Output image comparison:
+Canny Control Example
+
+|no guess_mode with prompt|guess_mode without prompt|
+|---|---|
+|(example output image)|(example output image)|
+
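+The comparison above can be reproduced along these lines, reusing `pipe` and `canny_image` from the usage example (the prompt is only illustrative):
+
+```py
+# without guess mode, conditioned on a prompt
+image_with_prompt = pipe("bird", image=canny_image).images[0]
+
+# with guess mode and an empty prompt
+image_guess_mode = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
+```
+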
## Available checkpoints
ControlNet requires a *control image* in addition to the text-to-image *prompt*.
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index bb608ad82a7a..4f1ffe604578 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -456,6 +456,7 @@ def forward(
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
# check channel order
@@ -556,8 +557,14 @@ def forward(
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample *= conditioning_scale
+ if guess_mode:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0
+ scales *= conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample *= scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample *= conditioning_scale
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index 12d21afbfeda..1ebd469f76b3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -118,6 +118,7 @@ def forward(
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ def forward(
timestep_cond,
attention_mask,
cross_attention_kwargs,
+ guess_mode,
return_dict,
)
@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds):
)
def prepare_image(
- self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance,
+ guess_mode,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
@@ -664,7 +675,7 @@ def prepare_image(
image = image.to(device=device, dtype=dtype)
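+ # in guess mode, ControlNet runs only on the conditional batch, so the control image is not duplicated for classifier-free guidance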
- if do_classifier_free_guidance:
+ if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
@@ -747,6 +758,7 @@ def __call__(
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@ def __call__(
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+
Examples:
Returns:
@@ -883,6 +899,7 @@ def __call__(
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
)
elif isinstance(self.controlnet, MultiControlNetModel):
images = []
@@ -897,6 +914,7 @@ def __call__(
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
)
images.append(image_)
@@ -934,15 +952,31 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ controlnet_latent_model_input = latents
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ controlnet_latent_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
- latent_model_input,
+ controlnet_latent_model_input,
t,
- encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
+ guess_mode=guess_mode,
return_dict=False,
)
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
index d556e6318f43..5e73692c8d87 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
@@ -553,6 +553,38 @@ def test_sequential_cpu_offloading(self):
# make sure that less than 7 GB is allocated
assert mem_bytes < 4 * 10**9
+ def test_canny_guess_mode(self):
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = ""
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ )
+
+ output = pipe(
+ prompt,
+ image,
+ generator=generator,
+ output_type="np",
+ num_inference_steps=3,
+ guidance_scale=3.0,
+ guess_mode=True,
+ )
+
+ image = output.images[0]
+ assert image.shape == (768, 512, 3)
+
+ image_slice = image[-3:, -3:, -1]
+ expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
@slow
@require_torch_gpu