From 02409183931cd59abd12100a10989542d3231633 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 27 Apr 2023 12:09:51 -1000 Subject: [PATCH 1/5] refactor img2img VaeImageProcessor.postprocess --- src/diffusers/image_processor.py | 45 +++++++++--- .../pipeline_stable_diffusion_img2img.py | 67 +++++++++-------- tests/others/test_image_processor.py | 6 +- .../test_stable_diffusion_img2img.py | 73 +++++-------------- 4 files changed, 95 insertions(+), 96 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 4598e1b4288c..a8e293c998f5 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import Union +from typing import Union, Optional, List import numpy as np import PIL @@ -21,7 +21,7 @@ from PIL import Image from .configuration_utils import ConfigMixin, register_to_config -from .utils import CONFIG_NAME, PIL_INTERPOLATION +from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate class VaeImageProcessor(ConfigMixin): @@ -82,7 +82,7 @@ def numpy_to_pt(images): @staticmethod def pt_to_numpy(images): """ - Convert a numpy image to a pytorch tensor + Convert a pytorch tensor to a numpy image """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @@ -93,6 +93,13 @@ def normalize(images): Normalize an image array to [-1,1] """ return 2.0 * images - 1.0 + + @staticmethod + def denormalize(images): + """ + Denormalize an image array to [0,1] + """ + return (images / 2 + 0.5).clamp(0, 1) def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: """ @@ -165,17 +172,37 @@ def preprocess( def postprocess( self, - image, + image: torch.FloatTensor, output_type: str = "pil", - ): - if isinstance(image, torch.Tensor) and output_type == "pt": + do_normalize: Optional[Union[List[bool], bool]] = None, + ): + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + return image + + if not isinstance(do_normalize, list): + do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize] + + image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])]) + + if output_type == "pt": return image image = self.pt_to_numpy(image) if output_type == "np": return image - elif output_type == "pil": + + if output_type == "pil": return self.numpy_to_pil(image) - else: - raise ValueError(f"Unsupported output_type {output_type}.") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c26ddf06cadc..c9df2fffbae6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. 
import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -129,6 +130,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi """ _optional_components = ["safety_checker", "feature_extractor"] + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, @@ -205,6 +207,7 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -215,11 +218,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): @@ -442,18 +442,33 @@ def _encode_prompt( return prompt_embeds - def run_safety_checker(self, image, device, dtype): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype, output_type="pil"): + if self.safety_checker is None: + has_nsfw_concept = False + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -738,27 +753,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type not in ["latent", "pt", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated. 
Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) - output_type = "np" - - if output_type == "latent": - image = latents - has_nsfw_concept = None - + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - has_nsfw_concept = False - - image = self.image_processor.postprocess(image, output_type=output_type) + image = latents + has_nsfw_concept = False + + do_normalize = ( + [not has_nsfw for has_nsfw in has_nsfw_concept] + if isinstance(has_nsfw_concept, list) + else not has_nsfw_concept + ) + image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py index 4f0e2c5aecfd..c2cd6f4a04f4 100644 --- a/tests/others/test_image_processor.py +++ b/tests/others/test_image_processor.py @@ -42,7 +42,7 @@ def to_np(self, image): return image def test_vae_image_processor_pt(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_pt = self.dummy_sample input_np = self.to_np(input_pt) @@ -59,7 +59,7 @@ def test_vae_image_processor_pt(self): ), f"decoded output does not match input for output_type {output_type}" def test_vae_image_processor_np(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1) for output_type in ["pt", "np", "pil"]: @@ -72,7 +72,7 @@ def test_vae_image_processor_np(self): ), f"decoded output does not match input for output_type {output_type}" def test_vae_image_processor_pil(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1) input_pil = image_processor.numpy_to_pil(input_np) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 4262114c78eb..123f5464dfaa 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -35,18 +35,23 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class 
StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -96,33 +101,19 @@ def get_dummy_components(self): } return components - def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"): + def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - - if input_image_type == "pt": - input_image = image - elif input_image_type == "np": - input_image = image.cpu().numpy().transpose(0, 2, 3, 1) - elif input_image_type == "pil": - input_image = image.cpu().numpy().transpose(0, 2, 3, 1) - input_image = VaeImageProcessor.numpy_to_pil(input_image) - else: - raise ValueError(f"unsupported input_image_type {input_image_type}.") - - if output_type not in ["pt", "np", "pil"]: - raise ValueError(f"unsupported output_type {output_type}") - inputs = { "prompt": "A painting of a squirrel eating a burger", - "image": input_image, + "image": image, "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": output_type, + "output_type": "numpy", } return inputs @@ -130,11 +121,12 @@ def test_stable_diffusion_img2img_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] @@ -147,11 +139,12 @@ def test_stable_diffusion_img2img_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["image"] = inputs["image"] / 2 + 0.5 negative_prompt = "french fries" output = sd_pipe(**inputs, negative_prompt=negative_prompt) image = output.images @@ -166,13 +159,14 @@ def test_stable_diffusion_img2img_multiple_init_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = 
VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) inputs["prompt"] = [inputs["prompt"]] * 2 inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) + inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[-1, -3:, -3:, -1] @@ -188,11 +182,12 @@ def test_stable_diffusion_img2img_k_lms(self): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) + sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] @@ -217,36 +212,6 @@ def test_save_load_optional_components(self): def test_attention_slicing_forward_pass(self): return super().test_attention_slicing_forward_pass() - @skip_mps - def test_pt_np_pil_outputs_equivalent(self): - device = "cpu" - components = self.get_dummy_components() - sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0] - output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0] - output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0] - - assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4 - assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4 - - @skip_mps - def test_image_types_consistent(self): - device = "cpu" - components = self.get_dummy_components() - sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0] - output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0] - output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0] - - assert np.abs(output_pt - output_np).max() <= 1e-4 - assert np.abs(output_pil - output_np).max() <= 1e-2 - @slow @require_torch_gpu From 161db124c91e054d01cd196a1cf7cb8116a61e35 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 27 Apr 2023 15:53:04 -1000 Subject: [PATCH 2/5] make style --- src/diffusers/image_processor.py | 20 +++-- .../pipeline_stable_diffusion_img2img.py | 22 +++-- tests/pipelines/pipeline_params.py | 4 + tests/pipelines/test_pipelines_common.py | 85 ++++++++++++++++--- 4 files changed, 100 insertions(+), 31 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index a8e293c998f5..46ddc6cc9728 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -13,7 +13,7 @@ # limitations under the License. 
import warnings -from typing import Union, Optional, List +from typing import List, Optional, Union import numpy as np import PIL @@ -93,7 +93,7 @@ def normalize(images): Normalize an image array to [-1,1] """ return 2.0 * images - 1.0 - + @staticmethod def denormalize(images): """ @@ -174,8 +174,8 @@ def postprocess( self, image: torch.FloatTensor, output_type: str = "pil", - do_normalize: Optional[Union[List[bool], bool]] = None, - ): + do_denormalize: Optional[List[bool]] = None, + ): if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor" @@ -191,10 +191,12 @@ def postprocess( if output_type == "latent": return image - if not isinstance(do_normalize, list): - do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize] - - image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])]) + if do_denormalize is None: + do_denormalize = [self.config.do_normalize] * image.shape[0] + + image = torch.stack( + [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] + ) if output_type == "pt": return image @@ -203,6 +205,6 @@ def postprocess( if output_type == "np": return image - + if output_type == "pil": return self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c9df2fffbae6..5eb9458efa4f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -442,10 +442,9 @@ def _encode_prompt( return prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype, output_type="pil"): + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: - has_nsfw_concept = False + has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") @@ -457,7 +456,6 @@ def run_safety_checker(self, image, device, dtype, output_type="pil"): ) return image, has_nsfw_concept - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" @@ -758,14 +756,14 @@ def __call__( image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents - has_nsfw_concept = False - - do_normalize = ( - [not has_nsfw for has_nsfw in has_nsfw_concept] - if isinstance(has_nsfw_concept, list) - else not has_nsfw_concept - ) - image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize) + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py index a0ac6c641c0b..7c5ffa2ca24b 100644 --- a/tests/pipelines/pipeline_params.py +++ b/tests/pipelines/pipeline_params.py @@ -22,6 +22,10 @@ TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) +TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([]) + +IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"]) + IMAGE_VARIATION_PARAMS = frozenset( [ "image", diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0278092282ba..aedda7bae026 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -12,6 +12,7 @@ import diffusers from diffusers import DiffusionPipeline +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import require_torch, torch_device @@ -27,6 +28,78 @@ def to_np(tensor): return tensor +class PipelineLatentTesterMixin: + """ + This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. + It provides a set of common tests for PyTorch pipeline that has vae, e.g. + equivalence of different input and output types, etc. + """ + + @property + def image_params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `image_params` in the child test class. " + "`image_params` are tested for if all accepted input image types (i.e. 
`pt`, `pil`, `np`) produce the same results" + ) + + def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): + inputs = self.get_dummy_inputs(device, seed) + + def convert_pt_to_type(image, input_image_type): + if input_image_type == "pt": + input_image = image + elif input_image_type == "np": + input_image = VaeImageProcessor.pt_to_numpy(image) + elif input_image_type == "pil": + input_image = VaeImageProcessor.pt_to_numpy(image) + input_image = VaeImageProcessor.numpy_to_pil(input_image) + else: + raise ValueError(f"unsupported input_image_type {input_image_type}.") + return input_image + + for image_param in self.image_params: + if image_param in inputs.keys(): + inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type) + + inputs["output_type"] = output_type + + return inputs + + def test_pt_np_pil_outputs_equivalent(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0] + output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0] + output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0] + + max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() + self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generates different results from `output_type=='np'`") + + max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() + self.assertLess(max_diff, 1e-4, "`output_type=='pil'` generates different results from `output_type=='np'`") + + def test_pt_np_pil_inputs_equivalent(self): + if len(self.image_params) == 0: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0] + + max_diff = np.abs(out_input_pt - out_input_np).max() + self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generates different results from `input_type=='np'`") + max_diff = np.abs(out_input_pil - out_input_np).max() + self.assertLess(max_diff, 1e-2, "`input_type=='pil'` generates different results from `input_type=='np'`") + + @require_torch class PipelineTesterMixin: """ @@ -339,9 +412,6 @@ def test_components_function(self): @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_float16_inference(self): - self._test_float16_inference() - - def _test_float16_inference(self, expected_max_diff=1e-2): components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.to(torch_device) @@ -355,13 +425,10 @@ def _test_float16_inference(self, expected_max_diff=1e-2): output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0] max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() - self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") + self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.") @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): - self._test_save_load_float16() - 
- def _test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): if hasattr(module, "half"): @@ -390,9 +457,7 @@ def _test_save_load_float16(self, expected_max_diff=1e-2): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess( - max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." - ) + self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.") def test_save_load_optional_components(self): if not hasattr(self.pipeline_class, "_optional_components"): From 43068ef19014394970967320a334a73469f899ca Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 27 Apr 2023 15:58:11 -1000 Subject: [PATCH 3/5] remove copy from for init --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 5eb9458efa4f..5e9a0f9e350b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -130,7 +130,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi """ _optional_components = ["safety_checker", "feature_extractor"] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, From 4cd9ecf9f6c64d4215788740497749a468742553 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 27 Apr 2023 16:01:04 -1000 Subject: [PATCH 4/5] alt --- .../pipeline_alt_diffusion_img2img.py | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index dee4a91924f7..5df9bab3ae41 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. 
import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -202,6 +203,7 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -212,11 +214,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -436,17 +435,32 @@ def _encode_prompt( return prompt_embeds def run_safety_checker(self, image, device, dtype): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): @@ -730,27 +744,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type not in ["latent", "pt", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated. 
Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) - output_type = "np" - - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: image = latents has_nsfw_concept = None + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - has_nsfw_concept = False + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: From 91092f06af53773c6db7c04a511e74f4b74c3c34 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Apr 2023 19:19:57 -1000 Subject: [PATCH 5/5] Update src/diffusers/image_processor.py Co-authored-by: Sayak Paul --- src/diffusers/image_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 46ddc6cc9728..68782d1f5f79 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -178,7 +178,7 @@ def postprocess( ): if not isinstance(image, torch.Tensor): raise ValueError( - f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor" + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = (
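A minimal usage sketch of the refactored `VaeImageProcessor.postprocess` API introduced by this series; the tensor values and `has_nsfw_concept` flags below are illustrative assumptions, not taken from the patches:

import torch
from diffusers.image_processor import VaeImageProcessor

# postprocess now expects a pytorch tensor plus an optional per-image
# do_denormalize list; the pipelines above derive that list from the
# safety checker's has_nsfw_concept flags.
image_processor = VaeImageProcessor(do_normalize=True)

decoded = torch.rand(2, 3, 64, 64) * 2 - 1  # stand-in VAE output in [-1, 1]
decoded[1] = 0.0  # flagged images come back from the safety checker blacked out

has_nsfw_concept = [False, True]  # hypothetical safety-checker result
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

# Supported output types are "pil", "np", "pt", and "latent".
pil_images = image_processor.postprocess(decoded, output_type="pil", do_denormalize=do_denormalize)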