diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 4598e1b4288c..a8e293c998f5 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import Union +from typing import Union, Optional, List import numpy as np import PIL @@ -21,7 +21,7 @@ from PIL import Image from .configuration_utils import ConfigMixin, register_to_config -from .utils import CONFIG_NAME, PIL_INTERPOLATION +from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate class VaeImageProcessor(ConfigMixin): @@ -82,7 +82,7 @@ def numpy_to_pt(images): @staticmethod def pt_to_numpy(images): """ - Convert a numpy image to a pytorch tensor + Convert a pytorch tensor to a numpy image """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @@ -93,6 +93,13 @@ def normalize(images): Normalize an image array to [-1,1] """ return 2.0 * images - 1.0 + + @staticmethod + def denormalize(images): + """ + Denormalize an image array to [0,1] + """ + return (images / 2 + 0.5).clamp(0, 1) def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: """ @@ -165,17 +172,37 @@ def preprocess( def postprocess( self, - image, + image: torch.FloatTensor, output_type: str = "pil", - ): - if isinstance(image, torch.Tensor) and output_type == "pt": + do_normalize: Optional[Union[List[bool], bool]] = None, + ): + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + return image + + if not isinstance(do_normalize, list): + do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize] + + image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])]) + + if output_type == "pt": return image image = self.pt_to_numpy(image) if output_type == "np": return image - elif output_type == "pil": + + if output_type == "pil": return self.numpy_to_pil(image) - else: - raise ValueError(f"Unsupported output_type {output_type}.") diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index ff9474ffd43a..6c85fec94037 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -13,6 +13,7 @@ # limitations under the License. 
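# Illustrative sketch (not part of the patch) of how the reworked VaeImageProcessor
# behaves, assuming a decoded image tensor in [-1, 1] of shape (batch, 3, H, W);
# the tensor values and vae_scale_factor below are made up for the example.
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
decoded = torch.rand(2, 3, 512, 512) * 2 - 1  # stand-in for vae.decode(latents).sample

pil_images = processor.postprocess(decoded, output_type="pil")  # list of PIL.Image.Image
np_images = processor.postprocess(decoded, output_type="np")    # float32 array, shape (2, 512, 512, 3), values in [0, 1]
pt_images = processor.postprocess(decoded, output_type="pt")    # denormalized torch tensor in [0, 1]
# per-image control over denormalization, e.g. leave the second image untouched
mixed = processor.postprocess(decoded, output_type="pt", do_normalize=[True, False])
# anything outside {"latent", "pt", "np", "pil"} now warns via `deprecate` and falls back to "np"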
import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -22,6 +23,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -174,6 +176,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): @@ -425,17 +428,25 @@ def _encode_prompt( return prompt_embeds - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -699,24 +710,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index dee4a91924f7..05d948c2d134 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. 
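# Illustrative sketch (not part of the patch) of the user-facing effect of the refactor:
# pipelines now route their output through VaeImageProcessor, so `output_type` also accepts
# "pt" (and "latent" where supported) in addition to "pil"/"np". The model id, prompt, and
# device/dtype choices below are placeholders.
import torch
from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16).to("cuda")
out = pipe("black hole in deep space, highly detailed", output_type="pt")
images = out.images  # torch tensor in [0, 1], shape (batch, 3, H, W), instead of a numpy array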
import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -202,6 +203,7 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -212,11 +214,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -435,18 +434,30 @@ def _encode_prompt( return prompt_embeds - def run_safety_checker(self, image, device, dtype): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): @@ -730,27 +741,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type not in ["latent", "pt", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated. 
Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) - output_type = "np" - - if output_type == "latent": - image = latents - has_nsfw_concept = None + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - else: - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - has_nsfw_concept = False + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index ca0a90a5b5ca..d059e74676dd 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -22,9 +23,10 @@ from diffusers.utils import is_accelerate_available +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging, randn_tensor +from ...utils import deprecate, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -184,6 +186,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -225,14 +228,15 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -255,6 +259,11 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -560,15 +569,22 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 11. Post-processing - image = self.decode_latents(latents) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker( + image, device, image_embeddings.dtype, output_type=output_type + ) - # 13. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 3d5374875d12..7a8897207e8c 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -1,10 +1,12 @@ import inspect +import warnings from itertools import repeat from typing import Callable, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -129,10 +131,28 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be 
removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -680,21 +700,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) - else: - has_nsfw_concept = None + image, has_nsfw_concept = self.run_safety_checker( + image, self.device, text_embeddings.dtype, output_type=output_type + ) - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index e2accb6d2d2a..17754ba37e9c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -24,6 +25,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler @@ -220,6 +222,8 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -503,18 +507,24 @@ def prepare_extra_step_kwargs(self, generator, eta): return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, 
latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -769,15 +779,20 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 8db19c2b9109..fa4c1903cb2a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -1,3 +1,4 @@ +import inspect from logging import getLogger from typing import Any, Callable, List, Optional, Union @@ -6,9 +7,9 @@ import torch from ...schedulers import DDPMScheduler +from ...utils import deprecate, randn_tensor from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ..pipeline_utils import ImagePipelineOutput -from . import StableDiffusionUpscalePipeline +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = getLogger(__name__) @@ -45,7 +46,7 @@ def preprocess(image): return image -class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): +class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): def __init__( self, vae: OnnxRuntimeModel, @@ -56,7 +57,32 @@ def __init__( scheduler: Any, max_noise_level: int = 350, ): - super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level) + if hasattr(vae, "config"): + # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" + " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" + " incorrect results in future versions. 
If you have downloaded this checkpoint from the Hugging" + " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + ) + self.register_to_config(max_noise_level=max_noise_level) def __call__( self, @@ -378,3 +404,132 @@ def _encode_prompt( prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds]) return prompt_embeds + + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." 
+ ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7347d70c4023..f8d758fc248a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -13,6 +13,7 @@ # limitations under the License. 
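# Illustrative sketch (not part of the patch) of the signature-inspection idiom that the
# ONNX upscale pipeline now carries itself in prepare_extra_step_kwargs: `eta` and
# `generator` are only forwarded to schedulers whose step() signature accepts them.
# DummyScheduler is a made-up stand-in for a real scheduler.
import inspect

class DummyScheduler:
    def step(self, model_output, timestep, sample, generator=None):
        return sample

scheduler = DummyScheduler()
extra_step_kwargs = {}
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
    extra_step_kwargs["eta"] = 0.0
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
    extra_step_kwargs["generator"] = None
# -> {'generator': None}: this scheduler takes a generator but has no eta argument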
import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -20,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -177,6 +179,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): @@ -429,16 +432,25 @@ def _encode_prompt( return prompt_embeds def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = False + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -702,24 +714,15 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) else: - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. 
Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = latents + has_nsfw_concept = False + + do_normalize = [not has_nsfw for has_nsfw in has_nsfw_concept] if isinstance(has_nsfw_concept, list) else not has_nsfw_concept + image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index fba2a4e32f88..54b428f1293c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -14,6 +14,7 @@ import inspect import math +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -21,11 +22,18 @@ from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -228,6 +236,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -441,18 +450,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -971,15 +986,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 322f2232fc8a..ae822e0559a2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -15,6 +15,7 @@ import inspect import os +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -23,6 +24,7 @@ from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.controlnet import ControlNetOutput @@ -229,6 +231,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -465,18 +468,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -1011,24 +1020,12 @@ def __call__( self.controlnet.to("cpu") torch.cuda.empty_cache() - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index c4f9ae59a4e9..010cd9afc4f0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -14,6 +14,7 @@ import contextlib import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -23,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -128,6 +130,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -313,18 +316,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated 
and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -693,12 +702,10 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index d543593fdbf5..cf9d791a92a8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import PIL @@ -21,6 +22,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor @@ -118,6 +120,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -182,18 +185,24 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -398,16 +407,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) - - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image, has_nsfw_concept = self.run_safety_checker( + image, device, image_embeddings.dtype, output_type=output_type + ) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c26ddf06cadc..c9df2fffbae6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -129,6 +130,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi """ _optional_components = ["safety_checker", "feature_extractor"] + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, @@ -205,6 +207,7 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -215,11 +218,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): @@ -442,18 +442,33 @@ def _encode_prompt( return prompt_embeds - def run_safety_checker(self, image, device, dtype): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None: + has_nsfw_concept = False + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, 
return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -738,27 +753,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type not in ["latent", "pt", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) - output_type = "np" - - if output_type == "latent": - image = latents - has_nsfw_concept = None - + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - has_nsfw_concept = False - - image = self.image_processor.postprocess(image, output_type=output_type) + image = latents + has_nsfw_concept = False + + do_normalize = ( + [not has_nsfw for has_nsfw in has_nsfw_concept] + if isinstance(has_nsfw_concept, list) + else not has_nsfw_concept + ) + image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index fb2e5dc424e3..ba2ee1c5b736 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -13,6 +13,7 @@ # limitations under the License. 
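# Illustrative sketch (not part of the patch) of why `do_normalize` is derived from the
# safety-checker result in the refactored __call__: flagged images are zeroed out by the
# checker, and running them through denormalize (x / 2 + 0.5) would turn them gray, so
# denormalization is skipped for exactly those images. The flags below are made up.
has_nsfw_concept = [False, True]
do_normalize = [not flagged for flagged in has_nsfw_concept]  # -> [True, False]
# image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)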
import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -22,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -270,6 +272,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -494,14 +497,15 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -524,6 +528,11 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -894,15 +903,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 11. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 13. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 3ad1d5e92273..8c7250dc72e2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -22,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -209,6 +211,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -433,18 +436,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -718,15 +727,12 @@ def __call__( # use original latents corresponding to unmasked portions of the image latents = (init_latents_orig * mask) + (latents * (1 - mask)) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 11. 
Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 12. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 49944cdcd636..f65ab3f272a5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -20,6 +21,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -136,6 +138,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() @@ -384,15 +387,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 11. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 12. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -625,14 +625,15 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -655,6 +656,11 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 99aca66db809..7ef0caafdee9 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -13,12 +13,14 @@ # limitations under the License. 
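Every pipeline touched from here on converges on the same three-step tail in __call__: decode the latents with the VAE, run the now tensor-based safety checker, then hand the result to the image processor. A minimal sketch of that shared flow, not part of the patch itself, reusing the names from the hunks above (latents, device, prompt_embeds, output_type) and ignoring the output_type == "latent" short-circuit:

    # sketch only: assumes self.image_processor was set up in __init__ as in this PR
    image = self.vae.decode(latents / self.vae.config.scaling_factor).sample        # tensor, roughly in [-1, 1]
    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
    image = self.image_processor.postprocess(image, output_type=output_type)        # "pt", "np" or "pil"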
import importlib +import warnings from typing import Callable, List, Optional, Union import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import get_sigmas_karras +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler @@ -111,6 +113,7 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) model = ModelWrapper(unet, scheduler.alphas_cumprod) if scheduler.config.prediction_type == "v_prediction": @@ -345,18 +348,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -590,16 +599,12 @@ def model_fn(x, t): # 8. Run k-diffusion solver latents = self.sampler(model_fn, latents, sigmas) - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 822bd49ce31c..0571a130dbc4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
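For reference, postprocess is the piece that replaces the per-pipeline numpy/PIL conversion helpers: it accepts a torch tensor only, denormalizes it back to [0, 1] when do_normalize is set (the processor's default), and converts to the requested output format. A hedged standalone sketch, with a random tensor standing in for a real VAE output:

    import torch
    from diffusers.image_processor import VaeImageProcessor

    image_processor = VaeImageProcessor(vae_scale_factor=8)
    decoded = torch.rand(1, 3, 64, 64) * 2 - 1                              # stand-in for vae.decode(...).sample
    pil_images = image_processor.postprocess(decoded, output_type="pil")    # list of PIL.Image
    np_batch = image_processor.postprocess(decoded, output_type="np")       # NHWC numpy array in [0, 1]
    pt_batch = image_processor.postprocess(decoded, output_type="pt")       # denormalized torch tensor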
+import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -20,6 +21,7 @@ import torch.nn.functional as F from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import is_accelerate_available, logging, randn_tensor @@ -91,6 +93,8 @@ def __init__( unet=unet, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -220,6 +224,11 @@ def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_p # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -505,12 +514,10 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index b7ded03d529b..abb7052563cd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -13,11 +13,13 @@ import copy import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler @@ -129,6 +131,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.with_to_k = with_to_k @@ -372,18 +375,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = 
self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -767,24 +776,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 392b2a72a76f..037936cf17ba 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -12,15 +12,23 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, PNDMScheduler -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -123,6 +131,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -336,18 +345,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -659,15 +674,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 6444ec7c8506..082508019e6d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -13,6 +13,7 @@ # limitations under the License. 
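Throughout these pipelines, decode_latents is only deprecated, not removed. Code that called it directly can switch to the VAE plus image processor pair and get the same NHWC numpy output, assuming the processor's default do_normalize=True; a hedged before/after sketch with a hypothetical pipe and latents:

    # before (now emits a FutureWarning): numpy array in [0, 1]
    image = pipe.decode_latents(latents)

    # after: equivalent result through the new components
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
    image = pipe.image_processor.postprocess(image, output_type="np")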
import inspect +import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union @@ -28,6 +29,7 @@ CLIPTokenizer, ) +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention @@ -358,6 +360,7 @@ def __init__( inverse_scheduler=inverse_scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -577,18 +580,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -1045,24 +1054,20 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - # 11. Post-process the latents. - edited_image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 12. Run the safety checker. - edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) - - # 13. Convert to PIL. - if output_type == "pil": - edited_image = self.numpy_to_pil(edited_image) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (edited_image, has_nsfw_concept) + return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) @torch.no_grad() @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) @@ -1251,16 +1256,15 @@ def invert( inverted_latents = latents.detach().clone() # 8. 
Post-processing - image = self.decode_latents(latents.detach()) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 9. Convert to PIL. - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (inverted_latents, image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index ebac58e18f62..17e88b2a482e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -13,16 +13,24 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -140,6 +148,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -353,18 +362,24 @@ def _encode_prompt( return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None or output_type == "latent": + has_nsfw_concept = False + else: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -682,15 +697,12 @@ def get_map_size(module, input, output): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type) - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 693208b18cdd..e7663cba87a0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -20,6 +21,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers @@ -114,6 +116,8 @@ def __init__( low_res_scheduler=low_res_scheduler, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(max_noise_level=max_noise_level) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -345,6 +349,11 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -669,18 +678,17 @@ def __call__( callback(i, t, latents) # 10. Post-processing - # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - image = self.decode_latents(latents.float()) + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image = self.image_processor.postprocess(image, output_type=output_type) + # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 11. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index fafb8d1d2800..e87d728dfde0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -13,17 +13,25 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -136,6 +144,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -474,6 +483,11 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -916,17 +930,15 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 14. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 15. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 22b7280f3679..666cd54f50ca 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. 
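Because the latent short-circuit is now wired through run_safety_checker and postprocess alike, callers can stay in latent space and convert once at the end. A hedged end-to-end sketch (model id and prompt are placeholders; it assumes output_type="latent" returns the raw latents in .images, as these hunks intend):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    latents = pipe("a photo of an astronaut", output_type="latent").images    # no VAE decode, no safety check
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]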
import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import PIL @@ -21,6 +22,7 @@ from diffusers.utils.import_utils import is_accelerate_available +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding @@ -138,6 +140,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -429,6 +432,11 @@ def _encode_image( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -813,16 +821,15 @@ def __call__( callback(i, t, latents) # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 87e7b3e6c9eb..dd5f3745ab2b 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -361,7 +361,6 @@ def run_safety_checker(self, image, device, dtype, enable_safety_guidance): flagged_images = None return image, has_nsfw_concept, flagged_images - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 661a1bd3cf73..c9eb10703cb8 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -13,6 +13,7 @@ # limitations under the License. 
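The test updates at the end of this diff flip do_normalize to True and shift the dummy inputs by / 2 + 0.5, which keeps the tensors the pipelines see after the new 2 * x - 1 normalization identical to what they saw before. The underlying round trip, sketched on a toy tensor (a hedged example, not from the patch):

    import torch
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(do_resize=False, do_normalize=True)
    sample = torch.rand(1, 3, 32, 32)                                  # image-space input in [0, 1]
    normalized = processor.preprocess(sample)                          # 2 * x - 1  ->  [-1, 1]
    restored = processor.postprocess(normalized, output_type="pt")     # back to [0, 1]
    assert torch.allclose(restored, sample, atol=1e-6)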
import inspect +import warnings from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -26,6 +27,7 @@ CLIPVisionModelWithProjection, ) +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor @@ -88,6 +90,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None and ( "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention @@ -329,6 +332,11 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -572,12 +580,10 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index e3a2ee370362..e4c5b6ed7e1e 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -21,6 +22,7 @@ import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor @@ -71,6 +73,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -189,6 +192,11 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -414,13 +422,10 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 26b9be2bfa76..869a04044d29 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -13,12 +13,14 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor @@ -76,6 +78,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None: self._swap_unet_attention_blocks() @@ -246,6 +249,11 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -488,12 +496,10 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample - # 10. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py index 4f0e2c5aecfd..c2cd6f4a04f4 100644 --- a/tests/others/test_image_processor.py +++ b/tests/others/test_image_processor.py @@ -42,7 +42,7 @@ def to_np(self, image): return image def test_vae_image_processor_pt(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_pt = self.dummy_sample input_np = self.to_np(input_pt) @@ -59,7 +59,7 @@ def test_vae_image_processor_pt(self): ), f"decoded output does not match input for output_type {output_type}" def test_vae_image_processor_np(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1) for output_type in ["pt", "np", "pil"]: @@ -72,7 +72,7 @@ def test_vae_image_processor_np(self): ), f"decoded output does not match input for output_type {output_type}" def test_vae_image_processor_pil(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1) input_pil = image_processor.numpy_to_pil(input_np) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 4d19621f0c2c..60eb17e76c0a 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -28,17 +28,18 @@ from diffusers.utils import slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = AltDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py index 144107ec1c97..1f96d8954156 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -123,6 +123,7 @@ def test_stable_diffusion_img2img_default_case(self): tokenizer.model_max_length = 77 init_image = self.dummy_image.to(device) + init_image = init_image / 2 + 0.5 # make sure here that pndm scheduler skips prk alt_pipe = AltDiffusionImg2ImgPipeline( @@ -134,7 +135,7 @@ def test_stable_diffusion_img2img_default_case(self): safety_checker=None, feature_extractor=self.dummy_extractor, ) - alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False) + 
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True) alt_pipe = alt_pipe.to(device) alt_pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index 17feba59e8e4..b0f5e74bd8eb 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -28,16 +28,17 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class PaintByExamplePipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = PaintByExamplePipeline params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: update the image_params once refactored VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py index a0ac6c641c0b..7c5ffa2ca24b 100644 --- a/tests/pipelines/pipeline_params.py +++ b/tests/pipelines/pipeline_params.py @@ -22,6 +22,10 @@ TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) +TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([]) + +IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"]) + IMAGE_VARIATION_PARAMS = frozenset( [ "image", diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py index 05b72ab6a0fd..52d3b03e5220 100644 --- a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py @@ -26,13 +26,13 @@ from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = CycleDiffusionPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - { "negative_prompt", @@ -42,6 +42,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): } required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"}) + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index fcfcd84c5d48..5897f4b6498a 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -41,17 +41,18 @@ from diffusers.utils.testing_utils import 
CaptureLogger, require_torch_gpu from ...models.test_models_unet_2d_condition import create_lora_layers -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index 70b3652fce77..666b09692861 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -34,13 +34,14 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -class StableDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 3bfa5810428a..fbdfc75faa84 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -33,16 +33,21 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImageVariationPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionImageVariationPipeline params = IMAGE_VARIATION_PARAMS batch_params = IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 4262114c78eb..123f5464dfaa 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ 
b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -35,18 +35,23 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -96,33 +101,19 @@ def get_dummy_components(self): } return components - def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"): + def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - - if input_image_type == "pt": - input_image = image - elif input_image_type == "np": - input_image = image.cpu().numpy().transpose(0, 2, 3, 1) - elif input_image_type == "pil": - input_image = image.cpu().numpy().transpose(0, 2, 3, 1) - input_image = VaeImageProcessor.numpy_to_pil(input_image) - else: - raise ValueError(f"unsupported input_image_type {input_image_type}.") - - if output_type not in ["pt", "np", "pil"]: - raise ValueError(f"unsupported output_type {output_type}") - inputs = { "prompt": "A painting of a squirrel eating a burger", - "image": input_image, + "image": image, "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": output_type, + "output_type": "numpy", } return inputs @@ -130,11 +121,12 @@ def test_stable_diffusion_img2img_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] @@ -147,11 +139,12 @@ def test_stable_diffusion_img2img_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = 
VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["image"] = inputs["image"] / 2 + 0.5 negative_prompt = "french fries" output = sd_pipe(**inputs, negative_prompt=negative_prompt) image = output.images @@ -166,13 +159,14 @@ def test_stable_diffusion_img2img_multiple_init_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) inputs["prompt"] = [inputs["prompt"]] * 2 inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) + inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[-1, -3:, -3:, -1] @@ -188,11 +182,12 @@ def test_stable_diffusion_img2img_k_lms(self): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] @@ -217,36 +212,6 @@ def test_save_load_optional_components(self): def test_attention_slicing_forward_pass(self): return super().test_attention_slicing_forward_pass() - @skip_mps - def test_pt_np_pil_outputs_equivalent(self): - device = "cpu" - components = self.get_dummy_components() - sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0] - output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0] - output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0] - - assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4 - assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4 - - @skip_mps - def test_image_types_consistent(self): - device = "cpu" - components = self.get_dummy_components() - sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0] - output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0] - output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0] - - assert np.abs(output_pt - output_np).max() <= 1e-4 - assert np.abs(output_pil - output_np).max() <= 1e-2 - @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index 8915f524d972..08dc1b2844dc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ 
b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -35,16 +35,21 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionInstructPix2PixPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py index bafad63ec2db..b1bed4b3cf25 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py @@ -31,18 +31,19 @@ from diffusers.utils import slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False @skip_mps -class StableDiffusionModelEditingPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionModelEditingPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 3ead4fe55bab..82e42b095f5d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -32,18 +32,19 @@ from diffusers.utils import slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False @skip_mps -class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPanoramaPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = 
TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 661926daaa3e..af64a23c4003 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -36,17 +36,20 @@ from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False @skip_mps -class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPix2PixZeroPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess @classmethod def setUpClass(cls): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py index 73859bdbf7d8..ad0d50df3ce5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py @@ -29,17 +29,18 @@ from diffusers.utils import slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusionSAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionSAGPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_cpu_offload = False def get_dummy_components(self): diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 846e251f3ce2..60cf9c7982e9 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -29,16 +29,19 @@ from diffusers.utils import load_numpy, skip_mps, slow from diffusers.utils.testing_utils import require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @skip_mps -class 
StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionAttendAndExcitePipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionAttendAndExcitePipeline test_attention_slicing = False params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 7a5e02a42af4..7b63583eef77 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -52,19 +52,22 @@ from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False @skip_mps -class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionDepth2ImgPipeline test_save_load_optional_components = False params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) @@ -132,7 +135,7 @@ def get_dummy_components(self): backbone_config=backbone_config, backbone_featmap_shape=[1, 384, 24, 24], ) - depth_estimator = DPTForDepthEstimation(depth_estimator_config) + depth_estimator = DPTForDepthEstimation(depth_estimator_config).eval() feature_extractor = DPTFeatureExtractor.from_pretrained( "hf-internal-testing/tiny-random-DPTForDepthEstimation" ) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 2fa8b9045f43..843a6146dac9 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -27,16 +27,19 @@ from diffusers.utils.testing_utils import require_torch_gpu, slow from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) 
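Note on the pattern repeated across these fast-test suites: pipelines whose image handling already goes through VaeImageProcessor opt into the new input/output equivalence checks by setting image_params (e.g. TEXT_TO_IMAGE_IMAGE_PARAMS), while pipelines that still do their own preprocessing keep an empty frozenset until they are refactored, as the TO-DO comments indicate. As a rough, illustrative sketch of what the mixin then does with those parameters (the tensor shape and variable names below are placeholders, not taken from the diff), the same `pt` dummy image is re-expressed in the other accepted input types via the processor's static helpers and the pipeline outputs are expected to agree:

import torch
from diffusers.image_processor import VaeImageProcessor

# Dummy image as the tests provide it: a float tensor in [0, 1], shape (batch, channels, height, width).
image_pt = torch.rand(1, 3, 32, 32)

# The same image re-expressed as the other accepted input types.
image_np = VaeImageProcessor.pt_to_numpy(image_pt)    # numpy array, shape (batch, height, width, channels)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)  # list of PIL.Image.Image

# A test then calls the pipeline once per input type for every name listed in
# `image_params` and asserts the outputs agree within a small tolerance
# (1e-4 for pt vs. np inputs, 1e-2 once the lossy PIL round-trip is involved).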
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index aff1c1cdbde9..910f3de81325 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -32,13 +32,13 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionLatentUpscalePipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - { "height", @@ -49,6 +49,9 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittes } required_optional_params = PipelineTesterMixin.required_optional_params - {"num_images_per_prompt"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess test_cpu_offload = True @property diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 891323d22fe0..b0e65692e8b5 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -15,14 +15,15 @@ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, slow, torch_device -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference -class StableUnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableUnCLIPPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false test_xformers_attention = False diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 69e3225ced52..450e0af8dcdc 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -29,15 +29,19 @@ from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..test_pipelines_common import ( + PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference, ) -class StableUnCLIPImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, 
unittest.TestCase): pipeline_class = StableUnCLIPImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): embedder_hidden_size = 32 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d0712bdec8f6..aedda7bae026 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -12,6 +12,7 @@ import diffusers from diffusers import DiffusionPipeline +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import require_torch, torch_device @@ -27,6 +28,78 @@ def to_np(tensor): return tensor +class PipelineLatentTesterMixin: + """ + This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. + It provides a set of common tests for PyTorch pipelines that have a VAE, e.g. + checking the equivalence of different input and output types. + """ + + @property + def image_params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `image_params` in the child test class. " + "`image_params` lists the image inputs that are tested to make sure all accepted input image types (i.e. `pt`, `np`, `pil`) produce the same results" + ) + + def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): + inputs = self.get_dummy_inputs(device, seed) + + def convert_pt_to_type(image, input_image_type): + if input_image_type == "pt": + input_image = image + elif input_image_type == "np": + input_image = VaeImageProcessor.pt_to_numpy(image) + elif input_image_type == "pil": + input_image = VaeImageProcessor.pt_to_numpy(image) + input_image = VaeImageProcessor.numpy_to_pil(input_image) + else: + raise ValueError(f"unsupported input_image_type {input_image_type}.") + return input_image + + for image_param in self.image_params: + if image_param in inputs.keys(): + inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type) + + inputs["output_type"] = output_type + + return inputs + + def test_pt_np_pil_outputs_equivalent(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0] + output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0] + output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0] + + max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() + self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generates different results from `output_type=='np'`") + + max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() + self.assertLess(max_diff, 1e-4, "`output_type=='pil'` generates different results from `output_type=='np'`") + + def test_pt_np_pil_inputs_equivalent(self): + if len(self.image_params) == 0: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + 
out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0] + + max_diff = np.abs(out_input_pt - out_input_np).max() + self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generates different results from `input_type=='np'`") + max_diff = np.abs(out_input_pil - out_input_np).max() + self.assertLess(max_diff, 1e-2, "`input_type=='pil'` generates different results from `input_type=='np'`") + + @require_torch class PipelineTesterMixin: """