From f74665b7b9d45f1cef00a2152153e37626c4c1f7 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 17 May 2023 11:58:17 +0000
Subject: [PATCH 1/2] Correct from_ckpt

---
 src/diffusers/loaders.py                          |  2 +-
 .../stable_diffusion/convert_from_ckpt.py         | 24 +++++++++++--------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index a1f0d8ec2a52..e50bc31a5c63 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1326,7 +1326,7 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
 
-        if from_safetensors and use_safetensors is True:
+        if from_safetensors and use_safetensors is False:
             raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
 
         # TODO: For now we only support stable diffusion
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 5961636dd197..48abbfc2c794 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
 
-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
 
-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
 
-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
@@ -204,8 +204,12 @@ def assign_to_checkpoint(
             new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
-            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+        shape = old_checkpoint[path["old"]].shape
+        if is_attn_weight and len(shape) == 3:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif is_attn_weight and len(shape) == 4:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
 

From 048a4a5d08d872ce7edf04419795fe824a9d3e0c Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 17 May 2023 12:05:05 +0000
Subject: [PATCH 2/2] make style

---
 src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 48abbfc2c794..42e8ae7cafd2 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -207,7 +207,7 @@ def assign_to_checkpoint(
         is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
         shape = old_checkpoint[path["old"]].shape
         if is_attn_weight and len(shape) == 3:
-            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
         else:
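Note on the conversion logic in the patch above: with the VAE attention projections renamed from q/k/v/proj_out to the to_q/to_k/to_v/to_out.0 Linear layers used by diffusers' Attention block, assign_to_checkpoint now squeezes weights that older checkpoints store as 1x1 convolutions (rank-3 or rank-4 tensors) down to 2-D matrices. The following is a minimal sketch of why dropping the trailing kernel dimensions is lossless; the channel count and random tensors are hypothetical and do not appear in the patch itself.

    import torch
    import torch.nn.functional as F

    # Hypothetical setup: a 1x1-conv attention projection as stored in old VAE
    # checkpoints, versus the Linear weight diffusers' Attention block expects.
    channels = 512
    conv_weight = torch.randn(channels, channels, 1, 1)  # rank-4, shape (out, in, 1, 1)
    linear_weight = conv_weight[:, :, 0, 0]              # the [:, :, 0, 0] slice from the patch

    x = torch.randn(1, channels)
    out_conv = F.conv2d(x.view(1, channels, 1, 1), conv_weight).view(1, channels)
    out_linear = F.linear(x, linear_weight)

    # A 1x1 convolution over a 1x1 spatial map is exactly a matrix multiply,
    # so the sliced weight produces the same output.
    assert torch.allclose(out_conv, out_linear, atol=1e-5)

The rank-3 branch ([:, :, 0]) is the analogous squeeze for projections saved with a single trailing kernel dimension, matching the existing "conv 1D to linear" comment in assign_to_checkpoint.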