diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
index 5a4cfa41ca43..af859177c002 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
+++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
@@ -242,6 +242,42 @@ image.save("./multi_controlnet_output.png")
 
+### Guess Mode
+
+Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:
+
+> In this mode, the ControlNet encoder will try its best to recognize the content of the input control map (a depth map, edge map, scribbles, etc.) even if you remove all prompts.
+
+#### The core implementation:
+
+Guess Mode scales the output residuals from ControlNet by a fixed ratio that depends on the block depth. The output of the shallowest DownBlock is scaled by `0.1`; as the blocks get deeper, the scale grows exponentially until it reaches `1.0` at the output of the MidBlock.
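+
+As a minimal sketch of that schedule (assuming the 12 down-block residuals plus one mid-block residual of the Stable Diffusion ControlNet), the scales boil down to a single `torch.logspace` call, mirroring the change to `controlnet.py` below:
+
+```py
+import torch
+
+# 12 down-block residuals + 1 mid-block residual -> 13 exponentially spaced scales
+scales = torch.logspace(-1, 0, 13)  # from 10**-1 = 0.1 up to 10**0 = 1.0
+# tensor([0.1000, 0.1212, 0.1468, 0.1778, 0.2154, 0.2610, 0.3162, 0.3831,
+#         0.4642, 0.5623, 0.6813, 0.8254, 1.0000])
+```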
+
+Since this scaling is the entire mechanism, **Guess Mode has no impact on prompt conditioning**. While it is common to use it without specifying any prompt, you can still provide one if desired.
+
+#### Usage:
+
+Just pass `guess_mode=True` to the `pipe()` call. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+# any canny edge map works as the control image; this one comes from the diffusers test assets
+canny_image = load_image(
+    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to(
+    "cuda"
+)
+image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
+image.save("guess_mode_generated.png")
+```
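+
+Because prompt conditioning is unaffected, Guess Mode can also be combined with a prompt. A hypothetical variation of the call above (the prompt string is only an illustration):
+
+```py
+image = pipe("bird", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
+```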
+
+#### Output image comparison:
+
+Canny Control Example:
+
+|no guess_mode with prompt|guess_mode without prompt|
+|---|---|
+|||
+
+
 ## Available checkpoints
 
 ControlNet requires a *control image* in addition to the text-to-image *prompt*.
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index bb608ad82a7a..4f1ffe604578 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -456,6 +456,7 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -556,8 +557,14 @@
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # the mid block gets the last (largest) scale
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
 
         if not return_dict:
             return (down_block_res_samples, mid_block_res_sample)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index 12d21afbfeda..1ebd469f76b3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -118,6 +118,7 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
+                guess_mode,
                 return_dict,
             )
 
@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds):
             )
 
     def prepare_image(
-        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance,
+        guess_mode,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -664,7 +675,7 @@
         image = image.to(device=device, dtype=dtype)
 
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)
 
         return image
@@ -747,6 +758,7 @@
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+                if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+
         Examples:
 
         Returns:
@@ -883,6 +899,7 @@
                 device=device,
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -897,6 +914,7 @@
                     device=device,
                     dtype=self.controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )
 
                 images.append(image_)
@@ -934,15 +952,31 @@
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
+                    controlnet_latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )
 
+                if guess_mode and do_classifier_free_guidance:
+                    # ControlNet was inferred only for the conditional batch.
+                    # To apply the ControlNet output to both the unconditional and conditional batches,
+                    # concatenate zero residuals for the unconditional batch so it is left unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
index d556e6318f43..5e73692c8d87 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
@@ -553,6 +553,38 @@ def test_sequential_cpu_offloading(self):
         # make sure that less than 7 GB is allocated
         assert mem_bytes < 4 * 10**9
 
+    def test_canny_guess_mode(self):
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = ""
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        output = pipe(
+            prompt,
+            image,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=3,
+            guidance_scale=3.0,
+            guess_mode=True,
+        )
+
+        image = output.images[0]
+        assert image.shape == (768, 512, 3)
+
+        image_slice = image[-3:, -3:, -1]
+        expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
 
 @slow
 @require_torch_gpu