diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 7b36d2eed96a..0b0ce0be547f 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -558,7 +558,7 @@ def forward( mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling - if guess_mode: + if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 5e8e68823b34..f8c3856c015e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -918,6 +918,13 @@ def __call__( if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) + global_pool_conditions = ( + self.controlnet.config.global_pool_conditions + if isinstance(self.controlnet, ControlNetModel) + else self.controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index 279df4a32b29..c7bb49771e98 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -622,6 +622,37 @@ def test_stable_diffusion_compile(self): assert np.abs(expected_image - image).max() < 1e-1 + def test_v11_shuffle_global_pool_conditions(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "New York" + image = load_image( + "https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png" + ) + + output = pipe( + prompt, + image, + generator=generator, + output_type="np", + num_inference_steps=3, + guidance_scale=7.0, + ) + + image = output.images[0] + assert image.shape == (512, 640, 3) + + image_slice = image[-3:, -3:, -1] + expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu