diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 54f738ba737..a36305359f3 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -185,6 +185,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
             self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
+        # self.wq = nn.Linear(
+        #     self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # )
         self.wq = nn.Linear(
             self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
         )
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index dc7f763fade..c6890c218e3 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -101,7 +101,7 @@
     "phi_4_mini",
     "smollm2",
 ]
-TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision", "llama3_2_lora"]
 HUGGING_FACE_REPO_IDS = {
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
@@ -209,6 +209,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Adapter path",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
@@ -585,6 +591,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
@@ -592,10 +599,12 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     # Convert dtype override string arg to actual type.
     dtype_override = DType[args.dtype_override]
 
+    # breakpoint() # 1, OK.
     edge_manager = _load_llama_model(
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -616,10 +625,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         args=args,
     )
-
     # At this point, the model is loaded in the default fp32.
 
     # Checkpoint dtype should be lower or equal precision to the dtype override.
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+
+    em1 = edge_manager.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(eager, em1)
+    # breakpoint() # 4, OK.
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
     if not (
         checkpoint_dtype == dtype_override.to_torch_dtype()
@@ -637,6 +652,10 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )
 
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    # edge_manager.model = edge_manager.model.to(dtype=torch.float32)
+    em2 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em2, eager)
+    # breakpoint() # 5, not OK, gets converted to bf16. OK if dtype is consistent.
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
@@ -649,7 +668,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
     )
-
+    # torch.allclose here as well.
+    em3 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em3, eager)
     return edge_manager
 
 
@@ -777,6 +798,9 @@ def _to_edge_and_lower_llama(  # noqa: C901
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
+    breakpoint()
+    # ^to_edge_res.pt
+    # allclose 1e-1 compared to pre-auto.
 
     # to_backend
     partitioners = []
@@ -911,7 +935,16 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
 
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+    b_e = builder_exported.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(b_e, eager)
+    # breakpoint()
+
     builder_exported.run_canonical_optimizations()
+    b_e2 = builder_exported.model.forward(eg, input_pos=ip)
+    torch.allclose(b_e2, eager)
     modelname = builder_exported.modelname
 
     if args.export_only:
@@ -932,6 +965,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             args,
         )
     else:
+        # breakpoint()
+        b_e3 = builder_exported.model.forward(eg, input_pos=ip)
+        torch.allclose(b_e3, eager)
         builder = _to_edge_and_lower_llama(
             builder_exported,
             modelname,
@@ -941,6 +977,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             quant_dtype,
             args,
         )
+        breakpoint()
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -1004,6 +1041,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1038,6 +1076,9 @@ def _load_llama_model(
         if modelname == "llama3_2_vision":
             module_name = "llama3_2_vision"
             model_class_name = "Llama3_2Decoder"
+        if modelname == "llama3_2_lora":
+            module_name = "llama3_2_lora"
+            model_class_name = "Llama3_2_Lora"
         else:
             raise ValueError(f"{modelname} is not a valid Llama model.")
     else:
@@ -1051,6 +1092,7 @@ def _load_llama_model(
             model_class_name,
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
             params=params_path,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1066,6 +1108,7 @@ def _load_llama_model(
         )
     )
 
+    # breakpoint() # 3. OK.
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -1093,7 +1136,7 @@ def _load_llama_model(
             model.max_seq_len,
             # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
             #  `Union[Tensor, Module]`.
-            model.max_context_len,
+            max_context_len,
             # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
             #  Module]`.
             model.n_layers,
@@ -1244,6 +1287,9 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    # transforms.append(
+    #     replace_rope_with_inference_rope()
+    # )
     return transforms
 
 
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 1bb7d277545..03048d393a5 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -9,12 +9,16 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
 from executorch.examples.models.llama.attention import KVCache, SDPA
 
+# from executorch.extension.llm.modules.attention import SDPA as TTSDPA
+
+from torchtune.modules.attention_utils import _MaskType
+
 
 class SDPACustom(torch.nn.Module):
     def __init__(
@@ -49,7 +53,7 @@ def forward(
             q,
             k,
             v,
-            input_pos[0].item(),
+            input_pos.item(),
             None,  # Attention mask
             0,  # dropout probability. Ignored by the code
             True,  # is_causal
@@ -60,11 +64,19 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
     for name, child in module.named_children():
         if isinstance(child, SDPA):
+            breakpoint()
             setattr(
                 module,
                 name,
                 SDPACustom(child.dim),
             )
+        # elif isinstance(child, TTSDPA):
+        #     # breakpoint()
+        #     setattr(
+        #         module,
+        #         name,
+        #         SDPAConverter(child.num_heads * child.head_dim),
+        #     )
         else:
             _replace_sdpa_with_custom_op(child)
 
@@ -76,6 +88,63 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module
+# Convert from torchtune SDPA to SDPACustom.
+class SDPAConverter(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.SDPA = SDPACustom(dim)
+
+    def forward(
+        self,
+        q: torch.Tensor,  # [b, s, n_h, h_d]
+        k: torch.Tensor,  # [b, s, n_kv, h_d]
+        v: torch.Tensor,  # [b, s, n_kv, h_d]
+        bsz: int,
+        seq_len: int,
+        mask: Optional[_MaskType] = None,
+    ):
+        # input_pos = 0
+        # Mask isn't used in SDPA?
+
+        # Make sure mask isn't None
+        # take the first row of the mask, number of 0s/Trues. Index of the first non-zero.
+        # assert mask is not None
+        if mask is not None:
+            attention_mask = mask.reshape(-1, max_seq_len)
+            first_row = attention_mask[0, :]
+            start_pos = torch.argmin(first_row).item() - 1
+        else:
+            start_pos = 0
+
+        ##
+        q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        input_dtype = q.dtype
+        q = q.to(dtype=torch.float)
+        k = k.to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
+
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            start_pos,
+            mask,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal
+        )
+        return output.view(bsz, seq_len, self.dim).to(dtype=input_dtype)
+        # return self.SDPA(start_pos, q, k, v, bsz, seq_len, mask)
+
+
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
         dim: int,
     ):
diff --git a/examples/models/llama3_2_lora/__init__.py b/examples/models/llama3_2_lora/__init__.py
new file mode 100644
index 00000000000..c0dea76977b
--- /dev/null
+++ b/examples/models/llama3_2_lora/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import Llama3_2_Lora
+
+__all__ = [
+    "Llama3_2_Lora",
+]
diff --git a/examples/models/llama3_2_lora/model.py b/examples/models/llama3_2_lora/model.py
new file mode 100644
index 00000000000..b1657e39e10
--- /dev/null
+++ b/examples/models/llama3_2_lora/model.py
@@ -0,0 +1,159 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import json
+import os
+from typing import Any, Dict
+
+import torch
+
+from executorch.examples.models.checkpoint import get_checkpoint_dtype
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope, RotaryEmbedding
+from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules.attention import (
+    replace_mha_with_inference_mha,
+    replace_rope_with_inference_rope,
+)
+
+from torchtune.models import convert_weights
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2._component_builders import lora_llama3_2
+
+
+class Llama3_2_Lora(EagerModelBase):
+    def __init__(self, **kwargs):
+        # Set member vars from kwargs.
+        self.max_seq_len = kwargs.get(
+            "max_seq_len", 8192
+        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
+        # self.encoder_max_seq_len = kwargs.get(
+        #     "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
+        # )  # Same as above.
+        self.generate_full_logits = kwargs.get("generate_full_logits", False)
+        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", True)
+        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+        self.use_kv_cache = kwargs.get("use_kv_cache", False)
+        self.verbose = kwargs.get("verbose", False)
+        self.args = kwargs.get("args", None)
+        self.dtype = kwargs.get("dtype", torch.float16)
+        self.use_checkpoint = False
+        self.max_context_len = kwargs.get("max_context_len", 8192)
+
+        # Single checkpoint file.
+        checkpoint_path = kwargs.get("checkpoint")
+
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
+        params_path = kwargs.get("params")
+        adapter_path = kwargs.get("adapter")
+
+        # self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)
+        # Load checkpoint and params.
+        device = "cpu"
+        if self.use_checkpoint:
+            checkpoint = torch.load(
+                checkpoint_path, map_location=device, weights_only=False, mmap=True
+            )
+            checkpoint = convert_weights.meta_to_tune(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
+        adapter = torch.load(
+            adapter_path, map_location="cpu", mmap=True, weights_only=False
+        )
+
+        checkpoint.update(adapter)
+
+        with open(params_path, "r") as f:
+            params = json.loads(f.read())
+
+        # Load model.
+        # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
+        # i.e. the model isn't fully initialized or something.
+        self.model_ = lora_llama3_2(
+            lora_attn_modules=[
+                "q_proj",
+            ],
+            apply_lora_to_mlp=False,
+            apply_lora_to_output=False,
+            # llama3_2 args
+            vocab_size=params["vocab_size"],
+            num_layers=params["n_layers"],
+            num_heads=params["n_heads"],
+            num_kv_heads=params["n_kv_heads"],
+            embed_dim=params["dim"],
+            max_seq_len=self.max_seq_len,  # 131072
+            # intermediate_dim=params["intermediate_dim"],  # 8192, calc is 4096
+            # LoRA args. TODO take in the adapter config.
+            lora_rank=8,
+            lora_alpha=16,
+        )
+        self.model_.requires_grad_(False)
+        for param_name, param_val in params.items():
+            setattr(self.model_, param_name, param_val)
+
+        setattr(self.model_, "enable_dynamic_shape", self.enable_dynamic_shape)
+        # Source transformation for MultiHeadAttention
+        self.model_ = replace_mha_with_inference_mha(self.model_)
+
+        model_args: ModelArgs = ModelArgs(
+            max_seq_len=self.max_seq_len,
+            max_context_len=self.max_context_len,
+            use_kv_cache=self.use_kv_cache,
+            generate_full_logits=self.generate_full_logits,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+            **params,
+        )
+        # Source transformation for RoPE
+        # self.model_ = replace_rope_with_inference_rope(self.model_, model_args)
+
+        setattr(self.model_, "checkpoint_dtype", self.dtype)
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")
+
+        self.model_.to(self.dtype)
+        eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+        ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+        # self.model_.forward(eg, input_pos=ip)
+        # breakpoint() # 2, OK.
+        self.model_.forward(eg, input_pos=ip)
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return self.model_
+
+    def get_example_inputs(self):
+        return (torch.tensor([[2, 3, 4]], dtype=torch.int64),)
+        # return (
+        #     torch.tensor([[2, 3, 4]], dtype=torch.long),
+        #     {"input_pos": torch.tensor([0], dtype=torch.long)},
+        # )
+        # return (torch.ones(1, self.n_tokens, dtype=torch.int64),)
+
+    # eg=torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    # ip=torch.tensor([[0, 1, 2]], dtype=torch.long)
+    def get_example_kwarg_inputs(self):
+        return {"input_pos": torch.tensor([[0, 1, 2]], dtype=torch.long)}
+
+    def get_dynamic_shapes(self):
+        dim = torch.export.Dim("token_dim", min=1, max=self.max_seq_len - 1)
+        return ({1: dim}, {1: dim})
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 7a2120f9e9b..c0a1b9c061a 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -1314,10 +1314,38 @@ def to_edge(
 
     edge_programs: Dict[str, ExportedProgram] = {}
 
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+
     for name, program in aten_programs.items():
         # Decompose to Core ATen
-        program = program.run_decompositions(_default_decomposition_table())
+        x1 = program.module()(eg, input_pos=ip)
+        preserve_ops = [
+            # torch.ops.aten.type_as.default,
+            # torch.ops.aten.to.dtype,
+            # torch.ops.aten.linear.default,
+            # torch.ops.aten.rms_norm.default,
+            # torch.ops.aten.transpose.int,
+            # torch.ops.aten.flatten.using_ints,
+            # torch.ops.aten.silu.default,
+            # torch.ops.aten.reshape.default,
+            # torch.ops.aten.stack.default,
+            torch.ops.aten.scaled_dot_product_attention.default,
+        ]
+
+        # breakpoint()
+        table = _default_decomposition_table()
+        for op in preserve_ops:
+            table.pop(op, None)
+
+        # breakpoint()
+        program = program.run_decompositions(table)
+        # x2 = program.module()(eg, input_pos=ip)
+        breakpoint()
+        # assert torch.allclose(x1, x2)
         edge_programs[name] = _generate_edge_program(name, config, program)
+    # also this does not allclose to run_decomp.
 
     return EdgeProgramManager(edge_programs, constant_methods, config)
 
diff --git a/export.py b/export.py
new file mode 100644
index 00000000000..f67a7dbdfb8
--- /dev/null
+++ b/export.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+import torch
+from executorch.exir import to_edge
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
+from torch import int64, long, no_grad, randint, Tensor, zeros
+from torch.export import export, ExportedProgram
+from torch.nn.attention import sdpa_kernel, SDPBackend
+from torchtune.models import convert_weights
+
+from torchtune.models.llama3_2._model_builders import lora_llama3_2_3b
+
+# from torchtune.modules.peft import get_adapter_params, set_trainable_params
+
+
+def export_llama3_lora(model, checkpoint_path, adapter_path) -> None:
+    example_args = (
+        torch.tensor(
+            [[1]], dtype=torch.long
+        ),  # tokens, with kv cache our input token length is always just 1 token.
+    )
+    example_kwargs = {
+        "input_pos": torch.tensor(
+            [0], dtype=torch.long
+        )  # start_pos, what token of output are we on.
+    }
+    breakpoint()
+    model.requires_grad_(False)
+    model = replace_mha_with_inference_mha(model)
+
+    # print("Loading checkpoint")
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=False
+    )
+    state_dict = convert_weights.meta_to_tune(checkpoint)
+    breakpoint()
+    state_dict.pop("output.weight")
+    adapter = torch.load(
+        adapter_path, map_location="cpu", mmap=True, weights_only=False
+    )
+
+    state_dict.update(adapter)
+
+    breakpoint()
+    # missing, unexpected = model.load_state_dict(
+    #     state_dict,
+    #     # strict=False,
+    #     # assign=True,
+    # )
+
+    # print("Missing: ", missing)
+    # print("Unexpected: ", unexpected)
+
+    eager_result = model.forward(
+        example_args[0],
+    )
+
+    breakpoint()
+    with sdpa_kernel([SDPBackend.MATH]):
+        print("Exporting to aten dialect")
+        aten_dialect: ExportedProgram = export(
+            model, args=example_args, kwargs=example_kwargs, strict=True
+        )
+
+    exported_result = aten_dialect.module()(
+        example_args[0], input_pos=example_kwargs["input_pos"]
+    )
+
+    print("Checking eager and exported results are close")
+    print(torch.allclose(eager_result, exported_result))
+
+    # print("EXPORTED_MODEL")
+    # with open("/data/users/lfq/executorch/exported_model.txt", "w") as f:
+    #     f.write(aten_dialect.graph_module.print_readable())
+
+    # print("EXPORTED_MODEL No decomp")
+    # with open("/data/users/lfq/executorch/exported_decomp.txt", "w") as f:
+    #     run_decomp = aten_dialect.run_decompositions()
+    #     f.write(run_decomp.graph_module.print_readable())
+
+    # 2. to_edge: Make optimizations for Edge devices.
+    print("Lowering to edge dialect")
+    edge_program = to_edge(aten_dialect)
+
+    edge_result = edge_program._edge_programs["forward"].module()(
+        example_args[0], input_pos=example_kwargs["input_pos"]
+    )
+
+    print("Checking eager and edge results are close")
+    print(torch.allclose(eager_result, edge_result))
+    breakpoint()
+
+    # 3. to_executorch: Convert the graph to an ExecuTorch program.
+    print("Exporting to executorch")
+    executorch_program = edge_program.to_executorch()
+
+    # 4. Save the compiled .pte program.
+ print("Saving to llama3_2_lora.pte") + with open("llama3_2_lora.pte", "wb") as file: + file.write(executorch_program.buffer) + + print("Done.") + + +def export_llama3(model, checkpoint_path) -> None: + pass + + +def main() -> None: + print("Main") + + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--checkpoint", + help="checkpoint path", + ) + + parser.add_argument( + "-a", + "--adapter", + help="adapter path", + ) + + args = parser.parse_args() + + lora_model = lora_llama3_2_3b( + lora_attn_modules=[ + "q_proj", + # "v_proj", + # "o_proj", + # "gate_proj", + # "down_proj", + # "up_proj", + ], + # apply_lora_to_mlp=False, + # lora_rank=8, + ) + lora_model.eval() + + # Export for inference. + export_llama3_lora(lora_model, args.checkpoint, args.adapter) + + # Is this something to do with the torchtune checkpoint? + # export_llama3(llama3_model, args.checkpoint) + + +if __name__ == "__main__": + main() diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py index 2679930178a..9f2c536acd5 100644 --- a/extension/export_util/utils.py +++ b/extension/export_util/utils.py @@ -66,6 +66,14 @@ def _core_aten_to_edge( constant_methods=edge_constant_methods, compile_config=edge_compile_config, ) + eg = torch.tensor([[2, 3, 4]], dtype=torch.int64) + ip = torch.tensor([[0, 1, 2]], dtype=torch.long) + em1 = edge_manager._edge_programs["forward"].module()(eg, input_pos=ip) + eager = torch.load("/data/users/lfq/executorch/eager_res.pt") + # breakpoint() + assert torch.allclose( + eager, em1 + ) # Fails to 100 lol. Passes if we do not decompose sdpa. if verbose: logging.info(f"Exported graph:\n{edge_manager.exported_program()}") return edge_manager @@ -82,6 +90,12 @@ def export_to_edge( strict=True, verbose=True, ) -> EdgeProgramManager: + eg = torch.tensor([[2, 3, 4]], dtype=torch.int64) + ip = torch.tensor([[0, 1, 2]], dtype=torch.long) + x = model.forward(eg, input_pos=ip) + eager = torch.load("/data/users/lfq/executorch/eager_res.pt") + torch.allclose(x, eager) + # breakpoint() # same core_aten_ep = _to_core_aten( model, example_inputs, @@ -90,6 +104,9 @@ def export_to_edge( strict=strict, verbose=verbose, ) + # breakpoint() # same + y = core_aten_ep.module()(eg, input_pos=ip) + torch.allclose(y, eager) return _core_aten_to_edge( core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose ) @@ -108,8 +125,10 @@ def export_to_exec_prog( ) -> ExecutorchProgramManager: m = model.eval() # pre-autograd export. 
+    breakpoint()
     m = export_for_training(m, example_inputs).module()
 
+    breakpoint()
     core_aten_ep = _to_core_aten(
         m,
         example_inputs,
@@ -118,11 +137,15 @@ def export_to_exec_prog(
         strict=strict,
     )
 
+    breakpoint()
     edge_m = _core_aten_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config
     )
+    breakpoint()
 
     exec_prog = edge_m.to_executorch(backend_config)
+
+    breakpoint()
     return exec_prog
 
 
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 751e2d16175..536fbd00a3a 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -29,6 +29,7 @@
 
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.passes import MemoryPlanningPass
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -164,6 +165,10 @@ def source_transform(
             transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A list of
                 source transforms.
         """
+        eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+        ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+        eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+        breakpoint()  # OK.
        for transform in transforms:
             self.model = transform(self.model)
         self.applied_source_transforms.extend(transforms)
@@ -171,6 +176,10 @@ def source_transform(
         if self.verbose:
             logging.info(f"Applied source transforms: {self.applied_source_transforms}")
             logging.info(f"Model after source transforms: {self.model}")
+
+        breakpoint()
+        x = self.model.forward(eg, input_pos=ip)
+        assert torch.allclose(x, eager)
         return self
 
 
@@ -184,7 +193,8 @@ def _get_dynamic_shape(self) -> Any:
             self.dynamic_shapes = ({1: dim},)
         elif self.enable_dynamic_shape:
             # Two input arguments: tokens and input_pos but input_pos is static shape
-            self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
+            # self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
+            self.dynamic_shapes = ({1: dim}, {0: 1})
         else:
             # Two input arguments: tokens and input_pos but both are of static shape
             self.dynamic_shapes = None
@@ -195,6 +205,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
             _check_ir_validity=False,
             _skip_type_promotion=bool(self.dtype == DType.fp16),
             _skip_dim_order=True,
+            _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
         )
         return edge_config
 
@@ -202,7 +213,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        # with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        with torch.no_grad():
             if hasattr(self.args, "qnn") and self.args.qnn:
                 # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
                # See issue: https://github.com/pytorch/executorch/issues/7373
@@ -229,12 +241,24 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
                 logging.info(f"inputs: {self.example_inputs}")
                 logging.info(f"kwargs: {self.example_kwarg_inputs}")
                 logging.info(f"dynamic shapes: {dynamic_shape}")
+                eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+                ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+
+                pre_ep = self.model.forward(eg, input_pos=ip)
+                eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+                assert torch.allclose(eager, pre_ep)
+
+                # breakpoint() # Bad here. OK without sdpa context.
                 exported_module = export_for_training(
                     self.model if not module else module,
                     self.example_inputs,
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
+                ep = self.model.forward(eg, input_pos=ip)
+                torch.allclose(eager, ep)
+                # breakpoint() # Same as above.
+
         return exported_module
 
     def export(self) -> "LLMEdgeManager":
@@ -244,7 +268,16 @@ def export(self) -> "LLMEdgeManager":
         The full torch.export() if called later on during to_edge() or
         to_edge_transform_and_lower().
         """
+        eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+        ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+        x = self.model.forward(eg, input_pos=ip)
+        eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+        torch.allclose(x, eager)
         exported_module = self._export()
+        y = exported_module.module()(eg, input_pos=ip)
+        torch.allclose(y, eager)
+        # breakpoint()
+
         # Need to store the graph module to record transformation passes.
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
@@ -405,10 +438,10 @@ def export_to_edge(self) -> "LLMEdgeManager":
         """
         dynamic_shape = self._get_dynamic_shape()
         edge_config = self._get_edge_config()
-
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        # with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        with torch.no_grad():
             if self.pre_autograd_graph_module is None:
                 # Run export() if it didn't run
                 self.export()
@@ -422,6 +455,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
             )
 
             with override_export_behaviour:
+                eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+                ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+                x = self.model.forward(eg, input_pos=ip)
+                eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+                # breakpoint()
+                assert torch.allclose(x, eager)
                 self.edge_manager = export_to_edge(
                     self.pre_autograd_graph_module,  # pyre-fixme[6]
                     self.example_inputs,
@@ -431,6 +470,9 @@ def export_to_edge(self) -> "LLMEdgeManager":
                     edge_compile_config=edge_config,
                     verbose=self.verbose,
                 )
+                y = self.edge_manager.exported_program().module()(eg, input_pos=ip)
+                assert torch.allclose(y, eager)
+                # breakpoint()
         return self
 
     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
@@ -495,6 +537,7 @@ def to_executorch(
                     # transpose followed by a regular op_mm.
                    ConvertToLinearPass(),
                    QuantFusionPass(),
+                    InitializedMutableBufferPass(["kv_cache_pos"]),
                 ]
                 if passes:
                     # pyre-fixme[6]: In call `list.extend`, for 1st positional argument,
diff --git a/extension/llm/modules/__init__.py b/extension/llm/modules/__init__.py
index 02e3c389f67..1ff085749f5 100644
--- a/extension/llm/modules/__init__.py
+++ b/extension/llm/modules/__init__.py
@@ -20,5 +20,6 @@
     "replace_tiled_token_positional_embedding",
     "MultiHeadAttention",
     "replace_mha_with_inference_mha",
+    "replace_rope_with_inference_rope",
     "KVCache",
 ]
diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index bb688c2b8c1..2098ddb610f 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -11,6 +11,8 @@
 
 import torch
 import torchtune.modules.attention as TorchTuneAttention
+
+from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
@@ -18,6 +20,12 @@
 
 logger = logging.getLogger(__name__)
 
+from executorch.examples.models.llama.rope import Rope, RotaryEmbedding
+
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
+from executorch.extension.llm.custom_ops import custom_ops
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
 
 class MultiHeadAttention(nn.Module):
     """
@@ -147,15 +155,17 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
 
-        self._sdpa = SDPA(
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            head_dim=self.head_dim,
-            attn_dropout=self.attn_dropout if self.training else 0.0,
-            is_causal=self.is_causal,
-            attention_fn=self._attention_call,
-            kv_cache=self.kv_cache,
-        )
+        # self._sdpa = SDPA(
+        #     num_kv_heads=self.num_kv_heads,
+        #     num_heads=self.num_heads,
+        #     head_dim=self.head_dim,
+        #     attn_dropout=self.attn_dropout if self.training else 0.0,
+        #     is_causal=self.is_causal,
+        #     attention_fn=self._attention_call,
+        #     kv_cache=self.kv_cache,
+        # )
+        dim = self.head_dim * self.num_heads
+        self._sdpa = SDPACustom(dim=dim)
 
         # this flag indicates whether to update the kv-cache during forward
        # passes. when disabled, we can have the cache setup but still
@@ -249,6 +259,7 @@ def forward(
             # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
 
+        # breakpoint()
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
 
@@ -312,7 +323,9 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.kv_cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        # output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        # breakpoint()
+        output = self._sdpa(input_pos[0][2], q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
 
 
@@ -413,3 +426,19 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def replace_rope_with_inference_rope(
+    module: torch.nn.Module, params: ModelArgs
+) -> None:
+    # Create a mapping of Llama3ScaledRoPE instances to their replacement Rope instances
+    rope_replacements = {}
+
+    for name, child in module.named_children():
+        print(name, child)
+        if isinstance(child, Llama3ScaledRoPE):
+            # breakpoint()
+            setattr(module, name, Rope(params))
+        else:
+            replace_rope_with_inference_rope(child, params)
+    return module
diff --git a/install_script.sh b/install_script.sh
new file mode 100644
index 00000000000..b2d01efdfd4
--- /dev/null
+++ b/install_script.sh
@@ -0,0 +1,58 @@
+./install_executorch.sh --pybind xnnpack
+
+examples/models/llama/install_requirements.sh
+
+cmake -DPYTHON_EXECUTABLE=python \
+    -DCMAKE_INSTALL_PREFIX=cmake-out \
+    -DEXECUTORCH_ENABLE_LOGGING=1 \
+    -DCMAKE_BUILD_TYPE=Debug \
+    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+    -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+    -Bcmake-out .
+
+cmake --build cmake-out -j16 --target install --config Debug
+
+cmake -DPYTHON_EXECUTABLE=python \
+    -DCMAKE_INSTALL_PREFIX=cmake-out \
+    -DCMAKE_BUILD_TYPE=Debug \
+    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+    -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+    -Bcmake-out/examples/models/llama \
+    examples/models/llama
+
+cmake --build cmake-out/examples/models/llama -j16 --config Debug
+
+cmake-out/examples/models/llama/llama_main --model_path=llama3_2_plain.pte --tokenizer_path=../llama3b/tokenizer.model --prompt="Hi"
+cmake-out/examples/models/llama/llama_main --model_path=llama3_2_lora.pte --tokenizer_path=../llama3b/tokenizer.model --prompt="Hi"
+
+# Other prompts
+# What happens if you eat watermelon seeds?
+
+-- export
+# No quantization
+# Set these paths to point to the downloaded files
+LLAMA_CHECKPOINT=../llama3b/consolidated.00.pth
+LLAMA_PARAMS=../llama3b/params.json
+
+python -m examples.models.llama.export_llama \
+    --model "llama3_2" \
+    --checkpoint "${LLAMA_CHECKPOINT:?}" \
+    --params "${LLAMA_PARAMS:?}" \
+    -kv \
+    --use_sdpa_with_kv_cache \
+    -d bf16 \
+    --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+    --output_name="llama3_2.pte"
+
+
+Notes:
+- We can't use llama_transformer, unless we update the model definition with additional layers.
+- When using custom export script, we need to take in the checkpoint+adapter files.
+- Check that eager works with consolidated.00.pth and adapter.pt?
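
The patch never shows how the new llama3_2_lora path is invoked end to end, so here is an untested sketch. It assumes the "llama3_2_lora" model name and the --adapter flag added in export_llama_lib.py compose with the same flags as the llama3_2 command in install_script.sh; the adapter location ../llama3b/adapter.pt is an assumed placeholder, not a path taken from the diff.

# Assumed LoRA export via export_llama (mirrors the llama3_2 command in install_script.sh):
python -m examples.models.llama.export_llama \
    --model "llama3_2_lora" \
    --checkpoint "${LLAMA_CHECKPOINT:?}" \
    --adapter ../llama3b/adapter.pt \
    --params "${LLAMA_PARAMS:?}" \
    -kv \
    --use_sdpa_with_kv_cache \
    -d bf16 \
    --output_name="llama3_2_lora.pte"

# Assumed invocation of the standalone export.py added by this patch (it takes -c/--checkpoint and -a/--adapter):
python export.py -c ../llama3b/consolidated.00.pth -a ../llama3b/adapter.pt

Either route should end with llama3_2_lora.pte, which the llama_main command in install_script.sh already expects.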