From 1a169180d4c7efba4025ccf7d274476981bf67cd Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Thu, 4 May 2023 15:26:19 +0530 Subject: [PATCH 1/8] Add movq --- .../convert_kandinsky_movq_to_diffusers.py | 465 ++++++++++++++++++ src/diffusers/models/attention.py | 45 +- src/diffusers/models/resnet.py | 14 +- src/diffusers/models/unet_2d_blocks.py | 35 +- src/diffusers/models/vae.py | 40 +- src/diffusers/models/vq_model.py | 12 +- src/diffusers/pipelines/kandinsky/__init__.py | 2 +- 7 files changed, 585 insertions(+), 28 deletions(-) create mode 100644 scripts/convert_kandinsky_movq_to_diffusers.py diff --git a/scripts/convert_kandinsky_movq_to_diffusers.py b/scripts/convert_kandinsky_movq_to_diffusers.py new file mode 100644 index 000000000000..4a40b10ec61a --- /dev/null +++ b/scripts/convert_kandinsky_movq_to_diffusers.py @@ -0,0 +1,465 @@ +import argparse +import tempfile +import os +import torch +from accelerate import load_checkpoint_and_dispatch +from diffusers.models.vq_model import VQModel + +""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget https://huggingface.co/ai-forever/Kandinsky_2.1/blob/main/movq_fp16.ckpt +``` + +Convert the model: +```sh +python scripts/convert_kandinsky_movq_to_diffusers.py --movq_checkpoint_path D:/Kandinsky-2/weights/2_1/movq_final.ckpt --dump_path ./kandinsky_model --debug movq +``` +"""# movq + +movq_ORIGINAL_PREFIX = "model" + +# Uses default arguments +movq_CONFIG = {} + + +def movq_model_from_original_config(): + movq = VQModel( + in_channels=3, + out_channels=3, + latent_channels=4, + use_spatial_norm=True, + down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"), + up_block_types=("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + num_vq_embeddings=16384, + block_out_channels=(128, 256, 256, 512), + vq_embed_dim=4, + layers_per_block=2 + ) + return movq + + +def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. 
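As a hypothetical aid (not part of this patch), a helper like the following could be run on the finished remapped dict to surface any renaming mistakes before the strict `load_state_dict` call:

```py
def report_key_mismatches(model, diffusers_checkpoint):
    """Compare remapped checkpoint keys against the diffusers model's own state_dict keys."""
    expected = set(model.state_dict().keys())
    converted = set(diffusers_checkpoint.keys())
    missing = sorted(expected - converted)     # keys the model needs but the conversion did not produce
    unexpected = sorted(converted - expected)  # keys produced that the model does not expect
    if missing:
        print("missing keys:", missing)
    if unexpected:
        print("unexpected keys:", unexpected)
    return not missing and not unexpected
```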
+ diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. 
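In the decoder, each `norm1`/`norm2` entry expands into three parameter groups (`norm_layer`, `conv_y`, `conv_b`) because MoVQ conditions its decoder normalization on the quantized latent; that is what the `*_spatial_norm` helpers remap. A minimal sketch of the computation those keys feed, matching the `SpatialNorm` module added to `attention.py` in this patch (shapes are illustrative only):

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

f = torch.randn(1, 128, 64, 64)  # decoder feature map (hypothetical shape)
zq = torch.randn(1, 4, 32, 32)   # quantized latent passed in as the conditioning signal

norm_layer = nn.GroupNorm(num_groups=32, num_channels=128, eps=1e-6, affine=True)
conv_y = nn.Conv2d(4, 128, kernel_size=1)
conv_b = nn.Conv2d(4, 128, kernel_size=1)

# SpatialNorm: resize zq to the feature resolution, then scale/shift the group-normed features
zq_up = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
out = norm_layer(f) * conv_y(zq_up) + conv_b(zq_up)
```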
+ diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"], + "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"], + "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"], + "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"], + "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"], + "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + 
rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + +def vqvae_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + + +def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + +def vqvae_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, 
diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.norm.norm_layer.weight": checkpoint[f"{attention_prefix}.norm.norm_layer.weight"], + f"{diffusers_attention_prefix}.norm.norm_layer.bias": checkpoint[f"{attention_prefix}.norm.norm_layer.bias"], + f"{diffusers_attention_prefix}.norm.conv_y.weight": checkpoint[f"{attention_prefix}.norm.conv_y.weight"], + f"{diffusers_attention_prefix}.norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"], + f"{diffusers_attention_prefix}.norm.conv_b.weight": checkpoint[f"{attention_prefix}.norm.conv_b.weight"], + f"{diffusers_attention_prefix}.norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + + + + +def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) + + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) + + + + for keys in diffusers_checkpoint.keys(): + print(keys) + + return diffusers_checkpoint + + + +def movq(*, args, checkpoint_map_location): + print("loading movq") + + movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location) + movq_model = movq_model_from_original_config() + + movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint( + movq_model, movq_checkpoint + ) + + del movq_checkpoint + load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True) + + print("done loading movq") + + return movq_model + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile(delete=False) as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + os.remove(file.name) + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default="./kandinsky_model", type=str, required=False, help="Path to the output model.") + + parser.add_argument( + "--movq_checkpoint_path", + 
default="D:/Kandinsky-2/weights/2_1/movq_final.ckpt", + type=str, + required=False, + help="Path to the movq checkpoint to convert.", + ) + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + + movq_model = movq(args=args, checkpoint_map_location=checkpoint_map_location) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8e537c6f3680..4863905b99ce 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -55,12 +55,18 @@ def __init__( norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, + use_spatial_norm: bool = False, + temb_channels: Optional[int] = None, ): super().__init__() self.channels = channels + self.use_spatial_norm = use_spatial_norm self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) + if use_spatial_norm: + self.norm = SpatialNorm(channels, temb_channels) + else: + self.norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) # define q,k,v as linear layers self.query = nn.Linear(channels, channels) @@ -126,12 +132,15 @@ def set_use_memory_efficient_attention_xformers( self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self._attention_op = attention_op - def forward(self, hidden_states): + def forward(self, hidden_states, zq=None): residual = hidden_states batch, channel, height, width = hidden_states.shape # norm - hidden_states = self.group_norm(hidden_states) + if self.use_spatial_norm: + hidden_states = self.norm(hidden_states, zq=zq) + else: + hidden_states = self.norm(hidden_states) hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) @@ -539,3 +548,33 @@ def forward(self, x, emb): x = F.group_norm(x, self.num_groups, eps=self.eps) x = x * (1 + scale) + shift return x + + +class SpatialNorm(nn.Module): + def __init__( + self, + f_channels, + zq_channels, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=False, + ): + super().__init__() + self.norm_layer = norm_layer(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + self.add_conv = add_conv + if self.add_conv: + self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + if self.add_conv: + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d9d539959c09..83bec9a52593 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,7 +20,7 @@ 
import torch.nn as nn import torch.nn.functional as F -from .attention import AdaGroupNorm +from .attention import AdaGroupNorm, SpatialNorm class Upsample1D(nn.Module): @@ -460,7 +460,7 @@ def __init__( eps=1e-6, non_linearity="swish", skip_time_act=False, - time_embedding_norm="default", # default, scale_shift, ada_group + time_embedding_norm="default", # default, scale_shift, ada_group, spatial kernel=None, output_scale_factor=1.0, use_in_shortcut=None, @@ -487,6 +487,8 @@ def __init__( if self.time_embedding_norm == "ada_group": self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) @@ -497,7 +499,7 @@ def __init__( self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) - elif self.time_embedding_norm == "ada_group": + elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") @@ -506,6 +508,8 @@ def __init__( if self.time_embedding_norm == "ada_group": self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm2 = SpatialNorm(out_channels, temb_channels) else: self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) @@ -551,7 +555,7 @@ def __init__( def forward(self, input_tensor, temb): hidden_states = input_tensor - if self.time_embedding_norm == "ada_group": + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) @@ -579,7 +583,7 @@ def forward(self, input_tensor, temb): if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - if self.time_embedding_norm == "ada_group": + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) else: hidden_states = self.norm2(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 439c5c34b601..4c3254eaae46 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -348,6 +348,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels ) elif up_block_type == "AttnUpDecoderBlock2D": return AttnUpDecoderBlock2D( @@ -360,6 +361,7 @@ def get_up_block( resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( @@ -402,10 +404,12 @@ def __init__( add_attention: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, + use_spatial_norm=False, ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention + self.use_spatial_norm = use_spatial_norm # there is always at least one resnet resnets = [ @@ -433,6 +437,8 @@ def __init__( rescale_output_factor=output_scale_factor, 
eps=resnet_eps, norm_num_groups=resnet_groups, + use_spatial_norm=use_spatial_norm, + temb_channels=temb_channels ) ) else: @@ -460,7 +466,10 @@ def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - hidden_states = attn(hidden_states) + if self.use_spatial_norm: + hidden_states = attn(hidden_states, temb) + else: + hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1956,6 +1965,7 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + temb_channels=None ): super().__init__() resnets = [] @@ -1967,7 +1977,7 @@ def __init__( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, - temb_channels=None, + temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, @@ -1985,9 +1995,9 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2011,10 +2021,12 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, + temb_channels=None ): super().__init__() resnets = [] attentions = [] + self.use_spatial_norm = resnet_time_scale_shift == "spatial" for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels @@ -2023,7 +2035,7 @@ def __init__( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, - temb_channels=None, + temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, @@ -2040,6 +2052,8 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + use_spatial_norm=self.use_spatial_norm, + temb_channels=temb_channels ) ) @@ -2051,10 +2065,15 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, zq): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) + if self.use_spatial_norm: + hidden_states = resnet(hidden_states, temb=zq) + hidden_states = attn(hidden_states, zq) + else: + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states, zq) + if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 400c3030af90..5e7353736b53 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -20,7 +20,7 @@ from ..utils import BaseOutput, randn_tensor from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block - +from .attention import SpatialNorm @dataclass class DecoderOutput(BaseOutput): @@ -149,10 +149,14 @@ def __init__( layers_per_block=2, norm_num_groups=32, act_fn="silu", + use_spatial_norm=False, + temb_channels=None ): super().__init__() self.layers_per_block = layers_per_block + self.use_spatial_norm = use_spatial_norm + self.conv_in = nn.Conv2d( in_channels, block_out_channels[-1], @@ -163,6 +167,10 @@ def __init__( self.mid_block = None self.up_blocks = nn.ModuleList([]) + resnet_time_scale_shift = "default" + + if self.use_spatial_norm: + resnet_time_scale_shift = "spatial" # mid self.mid_block = UNetMidBlock2D( @@ -170,10 +178,11 @@ 
def __init__( resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, - resnet_time_scale_shift="default", + resnet_time_scale_shift=resnet_time_scale_shift, attn_num_head_channels=None, resnet_groups=norm_num_groups, - temb_channels=None, + temb_channels=temb_channels, + use_spatial_norm=use_spatial_norm, ) # up @@ -196,19 +205,23 @@ def __init__( resnet_act_fn=act_fn, resnet_groups=norm_num_groups, attn_num_head_channels=None, - temb_channels=None, + temb_channels=temb_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + if use_spatial_norm: + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) self.gradient_checkpointing = False - def forward(self, z): + def forward(self, z, zq=None): sample = z sample = self.conv_in(sample) @@ -230,15 +243,24 @@ def custom_forward(*inputs): sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) else: # middle - sample = self.mid_block(sample) + if self.use_spatial_norm: + sample = self.mid_block(sample, zq) + else: + sample = self.mid_block(sample) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample) + if self.use_spatial_norm: + sample = up_block(sample, zq) + else: + sample = up_block(sample) # post-process - sample = self.conv_norm_out(sample) + if self.use_spatial_norm: + sample = self.conv_norm_out(sample, zq) + else: + sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 65f734dccb2d..ee2a7d203bcf 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -82,9 +82,12 @@ def __init__( norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, + use_spatial_norm: bool = False ): super().__init__() + self.use_spatial_norm = use_spatial_norm + # pass init params to Encoder self.encoder = Encoder( in_channels=in_channels, @@ -112,6 +115,8 @@ def __init__( layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, + use_spatial_norm=use_spatial_norm, + temb_channels=latent_channels, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: @@ -131,8 +136,11 @@ def decode( quant, emb_loss, info = self.quantize(h) else: quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) + quant2 = self.post_quant_conv(quant) + if self.use_spatial_norm: + dec = self.decoder(quant2, quant) + else : + dec = self.decoder(quant2) if not return_dict: return (dec,) diff --git a/src/diffusers/pipelines/kandinsky/__init__.py b/src/diffusers/pipelines/kandinsky/__init__.py index 7996ed2d581f..fb4746bd6087 100644 --- a/src/diffusers/pipelines/kandinsky/__init__.py +++ b/src/diffusers/pipelines/kandinsky/__init__.py @@ -13,4 +13,4 @@ print("to-do") # from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline else: - from .pipeline_kandinsky_prior import KandinskyPipeline + from .pipeline_kandinsky import KandinskyPipeline From 
e0582c18c6280384bc1fe9bc76e1952b6f59c119 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Thu, 4 May 2023 15:54:23 +0530 Subject: [PATCH 2/8] Merge decoder conversion script with others --- .../convert_kandinsky_movq_to_diffusers.py | 465 ------------------ scripts/convert_kandinsky_to_diffusers.py | 418 +++++++++++++++- 2 files changed, 414 insertions(+), 469 deletions(-) delete mode 100644 scripts/convert_kandinsky_movq_to_diffusers.py diff --git a/scripts/convert_kandinsky_movq_to_diffusers.py b/scripts/convert_kandinsky_movq_to_diffusers.py deleted file mode 100644 index 4a40b10ec61a..000000000000 --- a/scripts/convert_kandinsky_movq_to_diffusers.py +++ /dev/null @@ -1,465 +0,0 @@ -import argparse -import tempfile -import os -import torch -from accelerate import load_checkpoint_and_dispatch -from diffusers.models.vq_model import VQModel - -""" -Example - From the diffusers root directory: - -Download weights: -```sh -$ wget https://huggingface.co/ai-forever/Kandinsky_2.1/blob/main/movq_fp16.ckpt -``` - -Convert the model: -```sh -python scripts/convert_kandinsky_movq_to_diffusers.py --movq_checkpoint_path D:/Kandinsky-2/weights/2_1/movq_final.ckpt --dump_path ./kandinsky_model --debug movq -``` -"""# movq - -movq_ORIGINAL_PREFIX = "model" - -# Uses default arguments -movq_CONFIG = {} - - -def movq_model_from_original_config(): - movq = VQModel( - in_channels=3, - out_channels=3, - latent_channels=4, - use_spatial_norm=True, - down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"), - up_block_types=("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), - num_vq_embeddings=16384, - block_out_channels=(128, 256, 256, 512), - vq_embed_dim=4, - layers_per_block=2 - ) - return movq - - -def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): - diffusers_checkpoint = {} - - # conv_in - diffusers_checkpoint.update( - { - "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], - "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], - } - ) - - # down_blocks - for down_block_idx, down_block in enumerate(model.encoder.down_blocks): - diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" - down_block_prefix = f"encoder.down.{down_block_idx}" - - # resnets - for resnet_idx, resnet in enumerate(down_block.resnets): - diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" - resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" - - diffusers_checkpoint.update( - vqvae_resnet_to_diffusers_checkpoint( - resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix - ) - ) - - # downsample - - # do not include the downsample when on the last down block - # There is no downsample on the last down block - if down_block_idx != len(model.encoder.down_blocks) - 1: - # There's a single downsample in the original checkpoint but a list of downsamples - # in the diffusers model. 
- diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" - downsample_prefix = f"{down_block_prefix}.downsample.conv" - diffusers_checkpoint.update( - { - f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], - f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], - } - ) - - # attentions - - if hasattr(down_block, "attentions"): - for attention_idx, _ in enumerate(down_block.attentions): - diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" - attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" - diffusers_checkpoint.update( - vqvae_attention_to_diffusers_checkpoint( - checkpoint, - diffusers_attention_prefix=diffusers_attention_prefix, - attention_prefix=attention_prefix, - ) - ) - - # mid block - - # mid block attentions - - # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder - diffusers_attention_prefix = "encoder.mid_block.attentions.0" - attention_prefix = "encoder.mid.attn_1" - diffusers_checkpoint.update( - vqvae_attention_to_diffusers_checkpoint( - checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix - ) - ) - - # mid block resnets - - for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): - diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" - - # the hardcoded prefixes to `block_` are 1 and 2 - orig_resnet_idx = diffusers_resnet_idx + 1 - # There are two hardcoded resnets in the middle of the VQ-diffusion encoder - resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" - - diffusers_checkpoint.update( - vqvae_resnet_to_diffusers_checkpoint( - resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix - ) - ) - - diffusers_checkpoint.update( - { - # conv_norm_out - "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], - "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], - # conv_out - "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], - "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], - } - ) - - return diffusers_checkpoint - - -def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): - diffusers_checkpoint = {} - - # conv in - diffusers_checkpoint.update( - { - "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], - "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], - } - ) - - # up_blocks - - for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): - # up_blocks are stored in reverse order in the VQ-diffusion checkpoint - orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx - - diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" - up_block_prefix = f"decoder.up.{orig_up_block_idx}" - - # resnets - for resnet_idx, resnet in enumerate(up_block.resnets): - diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" - resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" - - diffusers_checkpoint.update( - vqvae_resnet_to_diffusers_checkpoint_spatial_norm( - resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix - ) - ) - - # upsample - - # there is no up sample on the last up block - if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: - # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples - # in the diffusers model. 
- diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" - downsample_prefix = f"{up_block_prefix}.upsample.conv" - diffusers_checkpoint.update( - { - f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], - f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], - } - ) - - # attentions - - if hasattr(up_block, "attentions"): - for attention_idx, _ in enumerate(up_block.attentions): - diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" - attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" - diffusers_checkpoint.update( - vqvae_attention_to_diffusers_checkpoint_spatial_norm( - checkpoint, - diffusers_attention_prefix=diffusers_attention_prefix, - attention_prefix=attention_prefix, - ) - ) - - # mid block - - # mid block attentions - - # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder - diffusers_attention_prefix = "decoder.mid_block.attentions.0" - attention_prefix = "decoder.mid.attn_1" - diffusers_checkpoint.update( - vqvae_attention_to_diffusers_checkpoint_spatial_norm( - checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix - ) - ) - - # mid block resnets - - for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): - diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" - - # the hardcoded prefixes to `block_` are 1 and 2 - orig_resnet_idx = diffusers_resnet_idx + 1 - # There are two hardcoded resnets in the middle of the VQ-diffusion decoder - resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" - - diffusers_checkpoint.update( - vqvae_resnet_to_diffusers_checkpoint_spatial_norm( - resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix - ) - ) - - diffusers_checkpoint.update( - { - # conv_norm_out - "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"], - "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"], - "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"], - "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"], - "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"], - "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"], - # conv_out - "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], - "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], - } - ) - - return diffusers_checkpoint - - -def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): - rv = { - # norm1 - f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], - f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], - # conv1 - f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], - f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], - # norm2 - f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], - f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], - # conv2 - f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], - f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], - } - - if resnet.conv_shortcut is not None: - 
rv.update( - { - f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], - f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], - } - ) - - return rv - -def vqvae_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): - rv = { - # norm1 - f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"], - f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"], - f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"], - f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"], - f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"], - f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"], - # conv1 - f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], - f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], - # norm2 - f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"], - f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"], - f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"], - f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"], - f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"], - f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"], - # conv2 - f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], - f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], - } - - if resnet.conv_shortcut is not None: - rv.update( - { - f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], - f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], - } - ) - - return rv - - - -def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): - return { - # norm - f"{diffusers_attention_prefix}.norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], - f"{diffusers_attention_prefix}.norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], - # query - f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], - # key - f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], - # value - f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], - # proj_attn - f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ - :, :, 0, 0 - ], - f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], - } - -def vqvae_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, 
diffusers_attention_prefix, attention_prefix): - return { - # norm - f"{diffusers_attention_prefix}.norm.norm_layer.weight": checkpoint[f"{attention_prefix}.norm.norm_layer.weight"], - f"{diffusers_attention_prefix}.norm.norm_layer.bias": checkpoint[f"{attention_prefix}.norm.norm_layer.bias"], - f"{diffusers_attention_prefix}.norm.conv_y.weight": checkpoint[f"{attention_prefix}.norm.conv_y.weight"], - f"{diffusers_attention_prefix}.norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"], - f"{diffusers_attention_prefix}.norm.conv_b.weight": checkpoint[f"{attention_prefix}.norm.conv_b.weight"], - f"{diffusers_attention_prefix}.norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"], - # query - f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], - # key - f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], - # value - f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], - # proj_attn - f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ - :, :, 0, 0 - ], - f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], - } - - - - - -def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): - diffusers_checkpoint = {} - diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) - - - # quant_conv - - diffusers_checkpoint.update( - { - "quant_conv.weight": checkpoint["quant_conv.weight"], - "quant_conv.bias": checkpoint["quant_conv.bias"], - } - ) - - # quantize - diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]}) - - # post_quant_conv - diffusers_checkpoint.update( - { - "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], - "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], - } - ) - - # decoder - diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) - - - - for keys in diffusers_checkpoint.keys(): - print(keys) - - return diffusers_checkpoint - - - -def movq(*, args, checkpoint_map_location): - print("loading movq") - - movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location) - movq_model = movq_model_from_original_config() - - movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint( - movq_model, movq_checkpoint - ) - - del movq_checkpoint - load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True) - - print("done loading movq") - - return movq_model - -def load_checkpoint_to_model(checkpoint, model, strict=False): - with tempfile.NamedTemporaryFile(delete=False) as file: - torch.save(checkpoint, file.name) - del checkpoint - if strict: - model.load_state_dict(torch.load(file.name), strict=True) - else: - load_checkpoint_and_dispatch(model, file.name, device_map="auto") - os.remove(file.name) - - - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument("--dump_path", default="./kandinsky_model", type=str, required=False, help="Path to the output model.") - - parser.add_argument( - "--movq_checkpoint_path", - 
default="D:/Kandinsky-2/weights/2_1/movq_final.ckpt", - type=str, - required=False, - help="Path to the movq checkpoint to convert.", - ) - parser.add_argument( - "--checkpoint_load_device", - default="cpu", - type=str, - required=False, - help="The device passed to `map_location` when loading checkpoints.", - ) - - parser.add_argument( - "--debug", - default=None, - type=str, - required=False, - help="Only run a specific stage of the convert script. Used for debugging", - ) - - args = parser.parse_args() - - print(f"loading checkpoints to {args.checkpoint_load_device}") - - checkpoint_map_location = torch.device(args.checkpoint_load_device) - - - movq_model = movq(args=args, checkpoint_map_location=checkpoint_map_location) diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py index 00941d3d2a3b..560fc316040a 100644 --- a/scripts/convert_kandinsky_to_diffusers.py +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -1,9 +1,11 @@ import argparse import tempfile +import os import torch from accelerate import load_checkpoint_and_dispatch from diffusers.models.prior_transformer import PriorTransformer +from diffusers.models.vq_model import VQModel from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel from diffusers import UNet2DConditionModel @@ -753,15 +755,413 @@ def text2img(*, args, checkpoint_map_location): return unet_model, text_proj_model +# movq + +MOVQ_CONFIG ={ + "in_channels":3, + "out_channels":3, + "latent_channels":4, + "use_spatial_norm":True, + "down_block_types":("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"), + "up_block_types":("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + "num_vq_embeddings":16384, + "block_out_channels":(128, 256, 256, 512), + "vq_embed_dim":4, + "layers_per_block":2 + } + + +def movq_model_from_original_config(): + movq = VQModel(**MOVQ_CONFIG ) + return movq + +def movq_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. 
+ diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def movq_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. 
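For reference, these decoder weights feed the decode path changed in `vq_model.py` (patch 1): with `use_spatial_norm=True`, the quantized latent is passed to the decoder a second time as the `zq` conditioning for every `SpatialNorm`. A rough usage sketch (shapes are hypothetical):

```py
import torch

movq = movq_model_from_original_config()
latents = torch.randn(1, 4, 32, 32)      # hypothetical latent shape
with torch.no_grad():
    image = movq.decode(latents).sample  # internally: decoder(post_quant_conv(quant), quant)
print(image.shape)                       # e.g. torch.Size([1, 3, 256, 256])
```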
+ diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"], + "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"], + "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"], + "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"], + "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"], + "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( 
+ { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + +def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + + +def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + +def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, 
attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.norm.norm_layer.weight": checkpoint[f"{attention_prefix}.norm.norm_layer.weight"], + f"{diffusers_attention_prefix}.norm.norm_layer.bias": checkpoint[f"{attention_prefix}.norm.norm_layer.bias"], + f"{diffusers_attention_prefix}.norm.conv_y.weight": checkpoint[f"{attention_prefix}.norm.conv_y.weight"], + f"{diffusers_attention_prefix}.norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"], + f"{diffusers_attention_prefix}.norm.conv_b.weight": checkpoint[f"{attention_prefix}.norm.conv_b.weight"], + f"{diffusers_attention_prefix}.norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + + + + +def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + diffusers_checkpoint.update(movq_encoder_to_diffusers_checkpoint(model, checkpoint)) + + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint)) + + + + for keys in diffusers_checkpoint.keys(): + print(keys) + + return diffusers_checkpoint + + + + + +def movq(*, args, checkpoint_map_location): + print("loading movq") + + movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location) + + movq_model = movq_model_from_original_config() + + movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint( + movq_model, movq_checkpoint + ) + + del movq_checkpoint + + load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True) + + print("done loading movq") + + return movq_model + def load_checkpoint_to_model(checkpoint, model, strict=False): - with tempfile.NamedTemporaryFile() as file: + with tempfile.NamedTemporaryFile(delete=False) as file: torch.save(checkpoint, file.name) del checkpoint if strict: model.load_state_dict(torch.load(file.name), strict=True) else: load_checkpoint_and_dispatch(model, file.name, device_map="auto") + os.remove(file.name) @@ -774,17 +1174,24 @@ def load_checkpoint_to_model(checkpoint, model, strict=False): "--prior_checkpoint_path", default=None, type=str, - required=True, + required=False, help="Path to the prior checkpoint to convert.", ) parser.add_argument( - 
"--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert." + "--clip_stat_path", default=None, type=str, required=False, help="Path to the clip stats checkpoint to convert." ) parser.add_argument( "--text2img_checkpoint_path", default=None, type=str, - required=True, + required=False, + help="Path to the text2img checkpoint to convert.", + ) + parser.add_argument( + "--movq_checkpoint_path", + default=None, + type=str, + required=False, help="Path to the text2img checkpoint to convert.", ) parser.add_argument( @@ -821,5 +1228,8 @@ def load_checkpoint_to_model(checkpoint, model, strict=False): unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location) unet_model.save_pretrained(f"{args.dump_path}/unet") text_proj_model.save_pretrained(f"{args.dump_path}/text_proj") + elif args.debug == 'decoder': + decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location) + decoder.save_pretrained(f"{args.dump_path}/decoder") else: raise ValueError(f"unknown debug value : {args.debug}") \ No newline at end of file From 86f21457c6550959a3d28365b3177e7da0b0e98d Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Sun, 7 May 2023 00:11:07 +0530 Subject: [PATCH 3/8] Refactoring --- scripts/convert_kandinsky_to_diffusers.py | 4 +-- src/diffusers/models/attention.py | 27 +++++++------------ src/diffusers/models/unet_2d_blocks.py | 27 ++++++------------- src/diffusers/models/vae.py | 33 ++++++++--------------- src/diffusers/models/vq_model.py | 13 +++------ 5 files changed, 34 insertions(+), 70 deletions(-) diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py index 560fc316040a..1904015a2527 100644 --- a/scripts/convert_kandinsky_to_diffusers.py +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -761,13 +761,13 @@ def text2img(*, args, checkpoint_map_location): "in_channels":3, "out_channels":3, "latent_channels":4, - "use_spatial_norm":True, "down_block_types":("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"), "up_block_types":("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), "num_vq_embeddings":16384, "block_out_channels":(128, 256, 256, 512), "vq_embed_dim":4, - "layers_per_block":2 + "layers_per_block":2, + "norm_type":"spatial" } diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 4863905b99ce..3e631a63601e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -55,15 +55,14 @@ def __init__( norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, - use_spatial_norm: bool = False, + norm_type: str = "default", # default, spatial temb_channels: Optional[int] = None, ): super().__init__() self.channels = channels - self.use_spatial_norm = use_spatial_norm self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - if use_spatial_norm: + if norm_type == "spatial": self.norm = SpatialNorm(channels, temb_channels) else: self.norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) @@ -137,10 +136,10 @@ def forward(self, hidden_states, zq=None): batch, channel, height, width = hidden_states.shape # norm - if self.use_spatial_norm: - hidden_states = self.norm(hidden_states, zq=zq) - else: + if zq is None: hidden_states = self.norm(hidden_states) + else: + hidden_states = self.norm(hidden_states, zq=zq) hidden_states = 
hidden_states.view(batch, channel, height * width).transpose(1, 2) @@ -551,30 +550,22 @@ def forward(self, x, emb): class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 + """ def __init__( self, f_channels, zq_channels, - norm_layer=nn.GroupNorm, - freeze_norm_layer=False, - add_conv=False, ): super().__init__() - self.norm_layer = norm_layer(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) - if freeze_norm_layer: - for p in self.norm_layer.parameters: - p.requires_grad = False - self.add_conv = add_conv - if self.add_conv: - self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1) + self.norm_layer = nn.GroupNorm(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f, zq): f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") - if self.add_conv: - zq = self.conv(zq) norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 4c3254eaae46..540f6c380dc2 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -404,13 +404,10 @@ def __init__( add_attention: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, - use_spatial_norm=False, ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention - self.use_spatial_norm = use_spatial_norm - # there is always at least one resnet resnets = [ ResnetBlock2D( @@ -437,7 +434,7 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, - use_spatial_norm=use_spatial_norm, + norm_type=resnet_time_scale_shift, temb_channels=temb_channels ) ) @@ -466,10 +463,8 @@ def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - if self.use_spatial_norm: - hidden_states = attn(hidden_states, temb) - else: - hidden_states = attn(hidden_states) + hidden_states = attn(hidden_states, temb) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -2026,7 +2021,6 @@ def __init__( super().__init__() resnets = [] attentions = [] - self.use_spatial_norm = resnet_time_scale_shift == "spatial" for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels @@ -2052,8 +2046,8 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, - use_spatial_norm=self.use_spatial_norm, - temb_channels=temb_channels + temb_channels=temb_channels, + norm_type=resnet_time_scale_shift ) ) @@ -2065,15 +2059,10 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, zq): + def forward(self, hidden_states, temb=None): for resnet, attn in zip(self.resnets, self.attentions): - if self.use_spatial_norm: - hidden_states = resnet(hidden_states, temb=zq) - hidden_states = attn(hidden_states, zq) - else: - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states, zq) - + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb) if self.upsamplers is not None: for upsampler in 
self.upsamplers: diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 5e7353736b53..776203042e9b 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -149,14 +149,11 @@ def __init__( layers_per_block=2, norm_num_groups=32, act_fn="silu", - use_spatial_norm=False, - temb_channels=None + norm_type="default", # default, spatial ): super().__init__() self.layers_per_block = layers_per_block - self.use_spatial_norm = use_spatial_norm - self.conv_in = nn.Conv2d( in_channels, block_out_channels[-1], @@ -167,10 +164,9 @@ def __init__( self.mid_block = None self.up_blocks = nn.ModuleList([]) - resnet_time_scale_shift = "default" - if self.use_spatial_norm: - resnet_time_scale_shift = "spatial" + + temb_channels = in_channels if norm_type == "spatial" else None # mid self.mid_block = UNetMidBlock2D( @@ -178,11 +174,10 @@ def __init__( resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, - resnet_time_scale_shift=resnet_time_scale_shift, + resnet_time_scale_shift=norm_type, attn_num_head_channels=None, resnet_groups=norm_num_groups, temb_channels=temb_channels, - use_spatial_norm=use_spatial_norm, ) # up @@ -206,13 +201,13 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=temb_channels, - resnet_time_scale_shift=resnet_time_scale_shift, + resnet_time_scale_shift=norm_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out - if use_spatial_norm: + if norm_type == "spatial": self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) @@ -243,24 +238,18 @@ def custom_forward(*inputs): sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) else: # middle - if self.use_spatial_norm: - sample = self.mid_block(sample, zq) - else: - sample = self.mid_block(sample) + sample = self.mid_block(sample, zq) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - if self.use_spatial_norm: - sample = up_block(sample, zq) - else: - sample = up_block(sample) + sample = up_block(sample, zq) # post-process - if self.use_spatial_norm: - sample = self.conv_norm_out(sample, zq) - else: + if zq is None: sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, zq) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index ee2a7d203bcf..040447ba82c8 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -82,11 +82,10 @@ def __init__( norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, - use_spatial_norm: bool = False + norm_type: str = "default" ): super().__init__() - self.use_spatial_norm = use_spatial_norm # pass init params to Encoder self.encoder = Encoder( @@ -115,8 +114,7 @@ def __init__( layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, - use_spatial_norm=use_spatial_norm, - temb_channels=latent_channels, + norm_type=norm_type, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: @@ -136,11 +134,8 @@ def decode( quant, emb_loss, info = self.quantize(h) else: quant = h - quant2 = self.post_quant_conv(quant) - if self.use_spatial_norm: - dec = self.decoder(quant2, quant) - else : - dec = self.decoder(quant2) + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, 
quant if self.config.norm_type == "spatial" else None) if not return_dict: return (dec,) From d4859aad748c6ccb40558768ea7c8e2022d8659f Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Wed, 10 May 2023 15:26:58 +0530 Subject: [PATCH 4/8] Use new attention processor --- scripts/convert_kandinsky_to_diffusers.py | 16 ++-- src/diffusers/models/attention.py | 38 +------- src/diffusers/models/unet_2d_blocks.py | 102 +++++++++++++++++----- 3 files changed, 94 insertions(+), 62 deletions(-) diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py index 1904015a2527..56b23316e9b6 100644 --- a/scripts/convert_kandinsky_to_diffusers.py +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -1075,19 +1075,19 @@ def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers f"{diffusers_attention_prefix}.norm.conv_b.weight": checkpoint[f"{attention_prefix}.norm.conv_b.weight"], f"{diffusers_attention_prefix}.norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"], # query - f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + f"{diffusers_attention_prefix}.attention.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.attention.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"], # key - f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + f"{diffusers_attention_prefix}.attention.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.attention.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"], # value - f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], - f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + f"{diffusers_attention_prefix}.attention.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.attention.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"], # proj_attn - f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + f"{diffusers_attention_prefix}.attention.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ :, :, 0, 0 ], - f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + f"{diffusers_attention_prefix}.attention.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], } diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3e631a63601e..7de00e4f0045 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -55,17 +55,12 @@ def __init__( norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, - norm_type: str = "default", # default, spatial - temb_channels: Optional[int] = None, ): super().__init__() self.channels = channels self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - if norm_type == "spatial": - self.norm = SpatialNorm(channels, temb_channels) - else: - self.norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) # define 
q,k,v as linear layers self.query = nn.Linear(channels, channels) @@ -131,15 +126,12 @@ def set_use_memory_efficient_attention_xformers( self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self._attention_op = attention_op - def forward(self, hidden_states, zq=None): + def forward(self, hidden_states): residual = hidden_states batch, channel, height, width = hidden_states.shape # norm - if zq is None: - hidden_states = self.norm(hidden_states) - else: - hidden_states = self.norm(hidden_states, zq=zq) + hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) @@ -546,26 +538,4 @@ def forward(self, x, emb): x = F.group_norm(x, self.num_groups, eps=self.eps) x = x * (1 + scale) + shift - return x - - -class SpatialNorm(nn.Module): - """ - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 - """ - def __init__( - self, - f_channels, - zq_channels, - ): - super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) - self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f, zq): - f_size = f.shape[-2:] - zq = F.interpolate(zq, size=f_size, mode="nearest") - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f + return x \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 540f6c380dc2..91ba2eaaedc0 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -427,17 +427,25 @@ def __init__( for _ in range(num_layers): if self.add_attention: - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - norm_type=resnet_time_scale_shift, - temb_channels=temb_channels + if resnet_time_scale_shift == "spatial": + attentions.append( + MOVQAttention( + in_channels, + temb_channels, + attn_num_head_channels + )) + else: + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + norm_type=resnet_time_scale_shift, + temb_channels=temb_channels + ) ) - ) else: attentions.append(None) @@ -1946,6 +1954,30 @@ def custom_forward(*inputs): return hidden_states +class MOVQAttention(nn.Module): + def __init__(self, query_dim, temb_channels, attn_num_head_channels): + super().__init__() + + self.norm = SpatialNorm(query_dim, temb_channels) + num_heads = query_dim // attn_num_head_channels if attn_num_head_channels is not None else 1 + dim_head = attn_num_head_channels if attn_num_head_channels is not None else query_dim + self.attention = Attention( + query_dim=query_dim, + heads=num_heads, + dim_head=dim_head, + bias=True + ) + + def forward(self, hidden_states, temb): + residual = hidden_states + hidden_states = self.norm(hidden_states, temb).view(hidden_states.shape[0], hidden_states.shape[1], -1) + hidden_states = self.attention(hidden_states.transpose(1, 2), None, None).transpose(1, 2) + hidden_states = hidden_states.view(residual.shape) + hidden_states = hidden_states + residual + return hidden_states + + + class UpDecoderBlock2D(nn.Module): def __init__( self, 
@@ -2039,17 +2071,26 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - temb_channels=temb_channels, - norm_type=resnet_time_scale_shift + if resnet_time_scale_shift == "spatial": + attentions.append( + MOVQAttention( + out_channels, + temb_channels=temb_channels, + attn_num_head_channels=attn_num_head_channels, + ) + ) + else: + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + temb_channels=temb_channels, + norm_type=resnet_time_scale_shift + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -2815,3 +2856,24 @@ def forward( hidden_states = attn_output + hidden_states return hidden_states + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 + """ + def __init__( + self, + f_channels, + zq_channels, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f From b6e986121dac9284dae724ae91eac318e8ea3735 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Thu, 11 May 2023 00:58:10 +0530 Subject: [PATCH 5/8] Fix minor bug in Attention Block --- src/diffusers/models/unet_2d_blocks.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 91ba2eaaedc0..434b65ff77a5 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -442,8 +442,6 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, - norm_type=resnet_time_scale_shift, - temb_channels=temb_channels ) ) else: @@ -2087,8 +2085,6 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, - temb_channels=temb_channels, - norm_type=resnet_time_scale_shift ) ) From 7de123706730c94a9b42fbd98c946d8b6f4337a0 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Fri, 12 May 2023 15:38:33 +0530 Subject: [PATCH 6/8] Reposition Spatial Norm --- src/diffusers/models/attention.py | 23 ++++++++++++++++++++++- src/diffusers/models/unet_2d_blocks.py | 22 +--------------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 4c72aa955bdd..fc988ffabe77 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -540,4 +540,25 @@ def forward(self, x, emb): x = F.group_norm(x, self.num_groups, eps=self.eps) x = x * (1 + scale) + shift - return x \ No newline at end of file + return x + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 + """ + def __init__( + self, + f_channels, + zq_channels, + ): + super().__init__() + self.norm_layer = 
nn.GroupNorm(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 025761ddcddb..5b560f82b81f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from .attention import AdaGroupNorm, AttentionBlock +from .attention import AdaGroupNorm, AttentionBlock, SpatialNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D @@ -2856,23 +2856,3 @@ def forward( return hidden_states -class SpatialNorm(nn.Module): - """ - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 - """ - def __init__( - self, - f_channels, - zq_channels, - ): - super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True) - self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f, zq): - f_size = f.shape[-2:] - zq = F.interpolate(zq, size=f_size, mode="nearest") - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f From cf3cbcba84936784ef887a92ab67986863ed8d3c Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Fri, 12 May 2023 16:58:18 +0530 Subject: [PATCH 7/8] Add decoder to text2img pipeline --- .../pipelines/kandinsky/pipeline_kandinsky.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py index d988f38506ea..984cf74cd07b 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -19,7 +19,7 @@ XLMRobertaTokenizerFast, ) -from ...models import UNet2DConditionModel +from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( @@ -30,6 +30,7 @@ ) from .text_encoder import MultilingualCLIP from .text_proj import KandinskyTextProjModel +from PIL import Image logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,6 +45,21 @@ def get_new_h_w(h, w): new_w += 1 return new_h * 8, new_w * 8 +def process_images(batch): + scaled = ( + ((batch + 1) * 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + .to("cpu") + .permute(0, 2, 3, 1) + .numpy() + ) + images = [] + for i in range(scaled.shape[0]): + images.append(Image.fromarray(scaled[i])) + return images + class KandinskyPipeline(DiffusionPipeline): """ @@ -63,6 +79,8 @@ class KandinskyPipeline(DiffusionPipeline): Conditional U-Net architecture to denoise the image embedding. 
text_proj ([`KandinskyTextProjModel`]): Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`VQModel`]): + Decoder to generate the image from the latents. """ def __init__( @@ -72,6 +90,7 @@ def __init__( text_proj: KandinskyTextProjModel, unet: UNet2DConditionModel, scheduler: UnCLIPScheduler, + decoder: VQModel ): super().__init__() @@ -94,6 +113,13 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents + def get_image(self, latents): + images = self.decoder.decode(latents, force_not_quantize=True)["sample"] + images = process_images(images) + return images + + + def _encode_prompt( self, prompt, @@ -371,4 +397,5 @@ def __call__( _, latents = latents.chunk(2) - return latents + images = self.get_image(latents) + return images From 9b557659ab88ec71ce76c69276ab0d7401bcd8e3 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Fri, 12 May 2023 18:04:58 +0530 Subject: [PATCH 8/8] Register decoder --- src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py index 984cf74cd07b..b0d8b4b429a1 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -100,6 +100,7 @@ def __init__( text_proj=text_proj, unet=unet, scheduler=scheduler, + decoder=decoder, ) def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
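
Quick smoke test for the converted MoVQ decoder. This is a minimal sketch rather than part of the patches above: it assumes the weights were already converted and written out by the script's `--debug decoder` branch (i.e. saved with `save_pretrained` under the dump path), and the `./kandinsky_model/decoder` path, the 96x96 latent size, and the fp32 cast below are illustrative assumptions, not values taken from the patches. It simply mirrors what the pipeline's `get_image` and `process_images` do.

```py
import torch

from diffusers.models.vq_model import VQModel

# Load the converted decoder; cast to fp32 so the check also runs on CPU
# even if the original checkpoint was stored in fp16.
decoder = VQModel.from_pretrained("./kandinsky_model/decoder").to(torch.float32)
decoder.eval()

# Dummy latents with the configured latent_channels=4; in the pipeline these
# come from the denoised unet output.
latents = torch.randn(1, 4, 96, 96)

with torch.no_grad():
    # Same call the pipeline's `get_image` makes.
    images = decoder.decode(latents, force_not_quantize=True)["sample"]

# Same post-processing as `process_images`: [-1, 1] floats -> uint8 HWC arrays.
images = ((images + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).to("cpu").permute(0, 2, 3, 1).numpy()
print(images.shape)  # with three upsampling stages this should be (1, 768, 768, 3)
```

If the shape and value range look right, the decoder can then be registered on the pipeline exactly as in the last two patches and exercised end to end.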