From d0118d7c1f794a34181566c40c5cf77f82b2f10f Mon Sep 17 00:00:00 2001
From: lucylq
Date: Mon, 7 Apr 2025 15:10:38 -0700
Subject: [PATCH] refactor-attention

---
 examples/models/llama/llama_transformer.py | 23 ++++++-----------------
 examples/models/llama/model.py             | 27 ++++++++++++++++++++++++---
 2 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 5c8db7f208d..380bb5910db 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -12,10 +12,7 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import (
-    ATTENTION_REGISTRY,
-    ForwardOptions,
-)
+from executorch.examples.models.llama.attention import Attention, ForwardOptions
 
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
@@ -83,19 +80,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -117,7 +108,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +121,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 2c82841c573..884b0132213 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -15,9 +15,13 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer
-
+from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.llama_transformer import (
+    Transformer,
+    TransformerBlock,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor
 
 try:
@@ -174,7 +178,24 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+
+            # Construct attention layers.
+            rope = Rope(model_args)
+            if model_args.attention_type not in ATTENTION_REGISTRY:
+                raise ValueError(
+                    f"Unknown attention type: {model_args.attention_type}. "
+                    f"Available: {list(ATTENTION_REGISTRY.keys())}"
+                )
+            layers = torch.nn.ModuleList()
+            cls = ATTENTION_REGISTRY[model_args.attention_type]
+            for layer_id in range(model_args.n_layers):
+                attention = cls(model_args, layer_id, rope)
+                transformer_block = TransformerBlock(model_args, attention)
+                layers.append(transformer_block)
+
+            # Construct transformer model.
+            self.model_ = Transformer(model_args, layers, rope)
+
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
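
For reference, a minimal usage sketch of the refactored API (not part of the patch). It mirrors what model.py now does after this change; ModelArgs() defaults and a registry entry for model_args.attention_type are assumed:

    import torch

    from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
    from executorch.examples.models.llama.llama_transformer import (
        Transformer,
        TransformerBlock,
    )
    from executorch.examples.models.llama.model_args import ModelArgs
    from executorch.examples.models.llama.rope import Rope

    # Assumed: default arguments; real callers derive these from a checkpoint/config.
    model_args = ModelArgs()
    rope = Rope(model_args)

    # Resolve the attention implementation once, then inject it into each block.
    cls = ATTENTION_REGISTRY[model_args.attention_type]
    layers = torch.nn.ModuleList(
        TransformerBlock(model_args, cls(model_args, layer_id, rope))
        for layer_id in range(model_args.n_layers)
    )

    # Transformer no longer builds its own layers or Rope; both are passed in.
    model = Transformer(model_args, layers, rope)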