
Commit 202ad0b

export llama with lora
1 parent 753a88e commit 202ad0b

3 files changed: +100 -4 lines changed


examples/models/llama/attention.py

Lines changed: 59 additions & 2 deletions
@@ -160,6 +160,48 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+class LoRALinear(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        # Setup weight and bias
+        # self.wq = nn.Linear(
+        #     self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # )
+        linear_q = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear_q.weight
+        bias = linear_q.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
@@ -185,9 +227,19 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # self.wq = nn.Linear(
+        #     self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # )
+        self.wq = LoRALinear(
+            in_dim=self.dim,
+            out_dim=self.n_heads * self.head_dim,
+            rank=8,
+            alpha=16.0,
+            dropout=0.0,
+            use_bias=self.attention_qkv_bias,
         )
+
+        # breakpoint()
         self.wk = nn.Linear(
             self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
         )
@@ -238,6 +290,10 @@ def forward(

         # QKV
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
+
+        # q_per_kv = self.num_heads // self.num_kv_heads
+        # q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+
         # We need view_copy elimination
         q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -268,6 +324,7 @@ def forward(

         mask = self.mask[:seqlen, :seqlen]

+        # Somehow, kv become floats.
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
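For reference, LoRALinear.forward above computes a frozen full-rank projection plus a scaled low-rank correction: y = W x + (alpha / rank) * B(A(dropout(x))). Below is a minimal standalone sketch of that arithmetic with plain tensors; the dimensions and the zeroed B matrix are illustrative only (the commit leaves lora_b at nn.Linear's default initialization).

import torch
import torch.nn.functional as F

# Illustrative shapes; rank and alpha mirror the values hard-coded for wq above.
bsz, seqlen, in_dim, out_dim, rank, alpha = 1, 4, 64, 64, 8, 16.0

x = torch.randn(bsz, seqlen, in_dim)
W = torch.randn(out_dim, in_dim)   # frozen base weight (nn.Linear layout)
A = torch.randn(rank, in_dim)      # lora_a.weight
B = torch.zeros(out_dim, rank)     # lora_b.weight, zeroed here for illustration

base = F.linear(x, W)                                  # W x
lora = (alpha / rank) * F.linear(F.linear(x, A), B)    # (alpha / r) * B (A x)
y = base + lora

assert y.shape == (bsz, seqlen, out_dim)
# With B == 0 the adapter branch contributes nothing, so y equals the base
# projection; a trained adapter supplies non-zero A and B.
torch.testing.assert_close(y, base)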

examples/models/llama/export_llama_lib.py

Lines changed: 24 additions & 2 deletions
@@ -26,8 +26,9 @@

 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info
-
 from executorch.devtools.etrecord import generate_etrecord
+
+from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
@@ -455,6 +456,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Whether the checkpoin is pre-quantized with QAT or not.",
     )

+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Adapter path",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        default=None,
+        help="Adapter config path",
+    )
+
     parser.add_argument(
         "-lora",
         "--use_lora",
@@ -591,6 +604,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
@@ -602,6 +616,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -641,7 +656,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         logging.warning(
             f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
         )
-
+    breakpoint()
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())

     # We want to quantize (in the source transforms) the weights of the model
@@ -1020,6 +1035,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1067,6 +1083,7 @@ def _load_llama_model(
            model_class_name,
            checkpoint=checkpoint,
            checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
            params=params_path,
            use_kv_cache=use_kv_cache,
            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1081,6 +1098,11 @@ def _load_llama_model(
            args=args,
         )
     )
+    eg = torch.tensor([[13347]], dtype=torch.long)
+    ip = torch.tensor([0], dtype=torch.long)
+    fw = ForwardOptions(input_pos=ip)
+    # breakpoint()
+    # model.forward(eg, fw)

     return LLMEdgeManager(
         model=model,
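The new ForwardOptions import and the commented-out eg/ip/fw lines appear to stub out an eager smoke test of the freshly loaded model before export. Spelled out as a small hypothetical helper (the function name is ours, not the commit's; it assumes the eager model's forward accepts (tokens, ForwardOptions), as the diff implies, and the token id 13347 is simply the value in the diff):

import torch
from executorch.examples.models.llama.attention import ForwardOptions

def eager_smoke_test(model: torch.nn.Module) -> None:
    # Hypothetical helper mirroring the commented-out eg/ip/fw lines above:
    # run the eager model once so shape or dtype problems surface before export.
    tokens = torch.tensor([[13347]], dtype=torch.long)   # one prompt token, batch of 1
    input_pos = torch.tensor([0], dtype=torch.long)      # KV-cache write position
    opts = ForwardOptions(input_pos=input_pos)
    model.forward(tokens, opts)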

examples/models/llama/model.py

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,8 @@

 from executorch.examples.models.llama.model_args import ModelArgs

+from torchtune.models import convert_weights
+
 try:
     from .fairseq2 import convert_to_llama_checkpoint

@@ -45,6 +47,9 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

+        # Adapter file.
+        adapter_path = kwargs.get("adapter", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -96,6 +101,15 @@ def __init__(self, **kwargs):
         elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

+        # Load adapter.
+        if adapter_path:
+            print("Loading adapter from: ", adapter_path)
+            adapter = torch.load(adapter_path, map_location=device, mmap=True)
+            adapter = convert_weights.tune_to_meta(adapter)
+            # Convert from tune to meta.
+            # breakpoint()
+            checkpoint.update(adapter)
+
         # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
@@ -174,8 +188,10 @@ def __init__(self, **kwargs):
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+
         # Get checkpoint dtype.
         if checkpoint:
+            # breakpoint()
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
         else:
             self.model_.checkpoint_dtype = torch.float32
@@ -252,6 +268,7 @@ def __init__(self, **kwargs):
         # by default initialized to fp32. This is fine because every other supported type
         # losslessly converts to fp32, so we don't lose precision here.
         if checkpoint:
+            # breakpoint()
             missing, unexpected = self.model_.load_state_dict(
                 checkpoint,
                 strict=False,
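The adapter merge in model.py rests on plain dict.update semantics: after convert_weights.tune_to_meta renames the torchtune-style keys, adapter entries that collide with base-checkpoint keys overwrite them, and LoRA-only entries are appended so the later load_state_dict(strict=False) can populate the new lora_a / lora_b submodules. A small self-contained sketch of just that merge step (the key names and shapes are illustrative, not the exact converted names):

import torch

# Base checkpoint with one (illustrative) attention projection weight.
checkpoint = {
    "layers.0.attention.wq.weight": torch.zeros(64, 64),
}
# Converted adapter: new low-rank tensors for the same projection.
adapter = {
    "layers.0.attention.wq.lora_a.weight": torch.zeros(8, 64),
    "layers.0.attention.wq.lora_b.weight": torch.zeros(64, 8),
}

checkpoint.update(adapter)   # same mechanism as the diff's checkpoint.update(adapter)
assert len(checkpoint) == 3  # base weight kept, two LoRA tensors added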
