From 01d4fef842270145ae1afdb9f5b0cf9c072f3748 Mon Sep 17 00:00:00 2001 From: Stax124 Date: Tue, 21 Mar 2023 17:01:37 +0100 Subject: [PATCH 1/5] Allow user to disable SafetyChecker and enable dtypes if loading models from .ckpt or .safetensors --- .../stable_diffusion/convert_from_ckpt.py | 62 +++++++++++++++---- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index ef4598433f82..bd377f19cc2f 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -18,19 +18,9 @@ from io import BytesIO from typing import Optional -import requests import torch -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) +import requests from diffusers import ( AutoencoderKL, ControlNetModel, @@ -55,6 +45,16 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from ...utils import is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING @@ -954,6 +954,23 @@ def stable_unclip_image_noising_components( return image_normalizer, image_noising_scheduler +def cast_to_dtype(module, dtype: torch.dtype): + """ + Converts the module to the given dtype + + Args: + module: + The module to convert + dtype: (`torch.dtype`) + The dtype to convert to (e.g. `torch.float16`) + """ + + if dtype == torch.float16: + module.half() + elif dtype == torch.bfloat16: + module.bfloat16() + + def convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ): @@ -989,6 +1006,8 @@ def download_from_original_stable_diffusion_ckpt( stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, controlnet: Optional[bool] = None, + torch_dtype: torch.dtype = torch.float32, + load_safety_checker: bool = True, ) -> StableDiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1028,6 +1047,10 @@ def download_from_original_stable_diffusion_ckpt( The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + torch_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The data type to use for the model. Defaults to `torch.float32`. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. """ if prediction_type == "v-prediction": prediction_type = "v_prediction" @@ -1159,6 +1182,7 @@ def download_from_original_stable_diffusion_ckpt( ) unet.load_state_dict(converted_unet_checkpoint) + cast_to_dtype(unet, dtype=torch_dtype) # Convert the VAE model. vae_config = create_vae_diffusers_config(original_config, image_size=image_size) @@ -1166,6 +1190,7 @@ def download_from_original_stable_diffusion_ckpt( vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) + cast_to_dtype(vae, dtype=torch_dtype) # Convert the text model. if model_type is None: @@ -1174,6 +1199,8 @@ def download_from_original_stable_diffusion_ckpt( if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) + cast_to_dtype(text_model, dtype=torch_dtype) + tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") if stable_unclip is None: @@ -1269,9 +1296,18 @@ def download_from_original_stable_diffusion_ckpt( ) elif model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint(checkpoint) + cast_to_dtype(text_model, dtype=torch_dtype) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + cast_to_dtype(safety_checker, dtype=torch_dtype) + + feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + else: + safety_checker = None + feature_extractor = None if controlnet: pipe = StableDiffusionControlNetPipeline( From 5d41b5383c6a0421a24b17f4520d71a1834c1dac Mon Sep 17 00:00:00 2001 From: Stax124 Date: Tue, 21 Mar 2023 17:49:37 +0100 Subject: [PATCH 2/5] Fix Import sorting (Ruff error) --- .../stable_diffusion/convert_from_ckpt.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index bd377f19cc2f..7c15a85f98ce 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -18,9 +18,19 @@ from io import BytesIO from typing import Optional +import requests import torch +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) -import requests from diffusers import ( AutoencoderKL, ControlNetModel, @@ -45,16 +55,6 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) from ...utils import is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING From 280cb2ae1f3e7bfea1c53634457d879399ee3d2d Mon Sep 17 00:00:00 2001 From: Stax124 Date: Thu, 23 Mar 2023 17:30:24 +0100 Subject: [PATCH 3/5] Get rid of the dtype convert method as it was implemented all along --- .../stable_diffusion/convert_from_ckpt.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 7c15a85f98ce..e8cfce23ac2a 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -954,23 +954,6 @@ def stable_unclip_image_noising_components( return image_normalizer, image_noising_scheduler -def cast_to_dtype(module, dtype: torch.dtype): - """ - Converts the module to the given dtype - - Args: - module: - The module to convert - dtype: (`torch.dtype`) - The dtype to convert to (e.g. `torch.float16`) - """ - - if dtype == torch.float16: - module.half() - elif dtype == torch.bfloat16: - module.bfloat16() - - def convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ): @@ -1006,7 +989,6 @@ def download_from_original_stable_diffusion_ckpt( stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, controlnet: Optional[bool] = None, - torch_dtype: torch.dtype = torch.float32, load_safety_checker: bool = True, ) -> StableDiffusionPipeline: """ @@ -1182,7 +1164,6 @@ def download_from_original_stable_diffusion_ckpt( ) unet.load_state_dict(converted_unet_checkpoint) - cast_to_dtype(unet, dtype=torch_dtype) # Convert the VAE model. vae_config = create_vae_diffusers_config(original_config, image_size=image_size) @@ -1190,7 +1171,6 @@ def download_from_original_stable_diffusion_ckpt( vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) - cast_to_dtype(vae, dtype=torch_dtype) # Convert the text model. if model_type is None: @@ -1199,8 +1179,6 @@ def download_from_original_stable_diffusion_ckpt( if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) - cast_to_dtype(text_model, dtype=torch_dtype) - tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") if stable_unclip is None: @@ -1296,14 +1274,10 @@ def download_from_original_stable_diffusion_ckpt( ) elif model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint(checkpoint) - cast_to_dtype(text_model, dtype=torch_dtype) - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if load_safety_checker: safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") - cast_to_dtype(safety_checker, dtype=torch_dtype) - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") else: safety_checker = None From fcf1963ab6f0551bb8a31c67e95a44eb221a1d6b Mon Sep 17 00:00:00 2001 From: Stax124 Date: Thu, 23 Mar 2023 17:37:30 +0100 Subject: [PATCH 4/5] Fix the docstring --- .../stable_diffusion/convert_from_ckpt.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index e8cfce23ac2a..b23c97ff9290 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -18,19 +18,9 @@ from io import BytesIO from typing import Optional -import requests import torch -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) +import requests from diffusers import ( AutoencoderKL, ControlNetModel, @@ -55,6 +45,16 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from ...utils import is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING @@ -1029,8 +1029,6 @@ def download_from_original_stable_diffusion_ckpt( The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. - torch_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The data type to use for the model. Defaults to `torch.float32`. load_safety_checker (`bool`, *optional*, defaults to `True`): Whether to load the safety checker or not. Defaults to `True`. """ From 75fd2d6236f4763f98fd66e06937a3f76d2f87d0 Mon Sep 17 00:00:00 2001 From: Stax124 Date: Thu, 23 Mar 2023 17:39:43 +0100 Subject: [PATCH 5/5] Fix ruff formatting --- .../stable_diffusion/convert_from_ckpt.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index b23c97ff9290..ad0dea383402 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -18,9 +18,19 @@ from io import BytesIO from typing import Optional +import requests import torch +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) -import requests from diffusers import ( AutoencoderKL, ControlNetModel, @@ -45,16 +55,6 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) from ...utils import is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING