diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 4f1ffe604578..3ffbb04eb222 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -119,6 +119,7 @@ def __init__( projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, ): super().__init__() @@ -566,6 +567,12 @@ def forward( down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample *= conditioning_scale + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + if not return_dict: return (down_block_res_samples, mid_block_res_sample)