Skip to content

Commit 585f621

Browse files
[Stable Diffusion] Allow users to disable Safety checker if loading model from checkpoint (#2768)
* Allow user to disable SafetyChecker and enable dtypes if loading models from .ckpt or .safetensors * Fix Import sorting (Ruff error) * Get rid of the dtype convert method as it was implemented all along * Fix the docstring * Fix ruff formatting --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent c0afca2 commit 585f621

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,7 @@ def download_from_original_stable_diffusion_ckpt(
989989
stable_unclip_prior: Optional[str] = None,
990990
clip_stats_path: Optional[str] = None,
991991
controlnet: Optional[bool] = None,
992+
load_safety_checker: bool = True,
992993
) -> StableDiffusionPipeline:
993994
"""
994995
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1028,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt(
10281029
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
10291030
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
10301031
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
1032+
load_safety_checker (`bool`, *optional*, defaults to `True`):
1033+
Whether to load the safety checker or not. Defaults to `True`.
10311034
"""
10321035
if prediction_type == "v-prediction":
10331036
prediction_type = "v_prediction"
@@ -1270,8 +1273,13 @@ def download_from_original_stable_diffusion_ckpt(
12701273
elif model_type == "FrozenCLIPEmbedder":
12711274
text_model = convert_ldm_clip_checkpoint(checkpoint)
12721275
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
1273-
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
1274-
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
1276+
1277+
if load_safety_checker:
1278+
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
1279+
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
1280+
else:
1281+
safety_checker = None
1282+
feature_extractor = None
12751283

12761284
if controlnet:
12771285
pipe = StableDiffusionControlNetPipeline(

0 commit comments

Comments
 (0)