-
Notifications
You must be signed in to change notification settings - Fork 6.1k
Add to support Guess Mode for StableDiffusionControlnetPipeline #2998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4cb28c9
ab298ca
04a375b
0b377e7
a2b7e42
3350754
1828cdf
3a0898d
1819689
c703294
81fe000
e2c9a03
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks nice to me! Can we maybe add some additional code that makes sure that if There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Guess Mode seems to have been designed to obtain good results even without prompts, but it can also be used with prompts. As mentioned in the first post, this is due to the following factors:
I provided an example of using prompts in Guess Mode. This OpenPose human image is easy to understand. Since it is up to the user how they want to use it, I think it might be fine not to display any error messages or warnings. What do you think? |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,6 +118,7 @@ def forward( | |
timestep_cond: Optional[torch.Tensor] = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | ||
guess_mode: bool = False, | ||
return_dict: bool = True, | ||
) -> Union[ControlNetOutput, Tuple]: | ||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): | ||
|
@@ -131,6 +132,7 @@ def forward( | |
timestep_cond, | ||
attention_mask, | ||
cross_attention_kwargs, | ||
guess_mode, | ||
return_dict, | ||
) | ||
|
||
|
@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds): | |
) | ||
|
||
def prepare_image( | ||
self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance | ||
self, | ||
image, | ||
width, | ||
height, | ||
batch_size, | ||
num_images_per_prompt, | ||
device, | ||
dtype, | ||
do_classifier_free_guidance, | ||
guess_mode, | ||
): | ||
if not isinstance(image, torch.Tensor): | ||
if isinstance(image, PIL.Image.Image): | ||
|
@@ -664,7 +675,7 @@ def prepare_image( | |
|
||
image = image.to(device=device, dtype=dtype) | ||
|
||
if do_classifier_free_guidance: | ||
if do_classifier_free_guidance and not guess_mode: | ||
image = torch.cat([image] * 2) | ||
|
||
return image | ||
|
@@ -747,6 +758,7 @@ def __call__( | |
callback_steps: int = 1, | ||
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | ||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0, | ||
guess_mode: bool = False, | ||
): | ||
r""" | ||
Function invoked when calling the pipeline for generation. | ||
|
@@ -819,6 +831,10 @@ def __call__( | |
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added | ||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the | ||
corresponding scale as a list. | ||
guess_mode (`bool`, *optional*, defaults to `False`): | ||
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if | ||
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. | ||
|
||
Examples: | ||
|
||
Returns: | ||
|
@@ -883,6 +899,7 @@ def __call__( | |
device=device, | ||
dtype=self.controlnet.dtype, | ||
do_classifier_free_guidance=do_classifier_free_guidance, | ||
guess_mode=guess_mode, | ||
) | ||
elif isinstance(self.controlnet, MultiControlNetModel): | ||
images = [] | ||
|
@@ -897,6 +914,7 @@ def __call__( | |
device=device, | ||
dtype=self.controlnet.dtype, | ||
do_classifier_free_guidance=do_classifier_free_guidance, | ||
guess_mode=guess_mode, | ||
) | ||
|
||
images.append(image_) | ||
|
@@ -934,15 +952,31 @@ def __call__( | |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | ||
|
||
# controlnet(s) inference | ||
if guess_mode and do_classifier_free_guidance: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. in terms of controlnet inputs, i.e.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 1 and 2 both share the commonality that they perform ControlNet only with the As for the |
||
# Infer ControlNet only for the conditional batch. | ||
controlnet_latent_model_input = latents | ||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] | ||
else: | ||
controlnet_latent_model_input = latent_model_input | ||
controlnet_prompt_embeds = prompt_embeds | ||
|
||
down_block_res_samples, mid_block_res_sample = self.controlnet( | ||
latent_model_input, | ||
controlnet_latent_model_input, | ||
t, | ||
encoder_hidden_states=prompt_embeds, | ||
encoder_hidden_states=controlnet_prompt_embeds, | ||
controlnet_cond=image, | ||
conditioning_scale=controlnet_conditioning_scale, | ||
guess_mode=guess_mode, | ||
return_dict=False, | ||
) | ||
|
||
if guess_mode and do_classifier_free_guidance: | ||
# Inferred ControlNet only for the conditional batch. | ||
# To apply the output of ControlNet to both the unconditional and conditional batches, | ||
# add 0 to the unconditional batch to keep it unchanged. | ||
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] | ||
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) | ||
|
||
# predict the noise residual | ||
noise_pred = self.unet( | ||
latent_model_input, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
works for me!