
Commit 74f00b7

rohansjoshi authored and facebook-github-bot committed
Added mse range setting
Summary: Added an option to use the MSE range setting algorithm. The algorithm does a linear grid search over scales and selects the one that minimizes mean squared error (see the line_search method in the class PerChannelParamObserver). It is applied when quantizing weights per channel. Accuracy is still poor, but somewhat better than with MinMax. On the wikitext task, with grid size 200:

| Model Name | max_seq_len | ptq | word_perplexity |
|------------|-------------|-----|-----------------|
| Llama 3.2-1B Instruct | 128 | 16a4w | 2367107 |
| Llama 3.2-1B Instruct | 128 | 16a4w_block | 5523977 |
| Llama 3.2-1B Instruct | 128 | 8a8w | 501663 |

Reviewed By: cccclai

Differential Revision: D77055545
1 parent 18e4240 · commit 74f00b7
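To make the algorithm concrete, here is a minimal sketch of an MSE-driven linear grid search over per-channel scales, assuming a symmetric scheme. This is not the actual `PerChannelParamObserver.line_search` implementation from this commit; the function name, the exact search schedule, and the epsilon guard are illustrative assumptions.

```python
import torch

def mse_line_search(w: torch.Tensor, qmin: int, qmax: int, steps: int = 200) -> torch.Tensor:
    """For each output channel (dim 0), pick the scale minimizing the mean
    squared quantization error over a linear grid of candidate scales."""
    absmax = w.abs().amax(dim=1).clamp_min(1e-8)  # per-channel max |w|
    best_scale = absmax / qmax                    # MinMax starting point
    best_err = torch.full_like(absmax, float("inf"))
    for i in range(1, steps + 1):
        # candidate: a linear fraction of the MinMax scale
        scale = (i / steps) * absmax / qmax
        q = torch.clamp(torch.round(w / scale[:, None]), qmin, qmax)
        err = ((q * scale[:, None] - w) ** 2).mean(dim=1)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_scale = torch.where(better, scale, best_scale)
    return best_scale

# e.g. for the 4-bit symmetric range configured in the diff: mse_line_search(weight, -7, 7)
```

With `steps=200`, matching the grid size reported in the table above, each channel tries 200 shrink factors of its MinMax scale; shrinking trades clipping error at the tails against rounding error in the bulk of the distribution, which is what the MSE objective balances.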

File tree

1 file changed: +82 -6 lines changed

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 82 additions & 6 deletions
```diff
@@ -20,6 +20,14 @@
     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
```
```diff
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
```
```diff
@@ -174,6 +184,38 @@ def permute(w, heads):
     )
     quantizer.add_custom_quant_annotations(custom_annotations)
 
+    if args.range_setting == "mse_weight":
+        weight_dtype = (
+            torch.int4
+            if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+            else torch.int8
+        )
+        per_channel_q_config = quantizer.default_quant_config.quant_config
+        weight_qspec = QuantizationSpec(
+            dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+            quant_min=(
+                -7
+                if weight_dtype == torch.int4
+                else torch.iinfo(weight_dtype).min + 1
+            ),
+            quant_max=(
+                7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max
+            ),
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+                **{"steps": 200, "use_mse": True}
+            ),
+        )
+        quantizer.default_quant_config.per_channel_quant_config = (
+            QuantizationConfig(
+                input_activation=per_channel_q_config.input_activation,
+                output_activation=per_channel_q_config.output_activation,
+                weight=weight_qspec,
+                bias=_derived_bias_quant_spec,
+            )
+        )
+
     model.has_quant_io = True
 
     with torch.no_grad():
```
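A note on the bounds chosen above: for the symmetric per-channel scheme the integer grid is clipped to be symmetric about zero, dropping the most negative code. A quick illustration of what those expressions evaluate to (this snippet is an aside, not part of the diff):

```python
import torch

# 8-bit weights use [-127, 127] instead of the asymmetric [-128, 127]:
assert torch.iinfo(torch.int8).min + 1 == -127
assert torch.iinfo(torch.int8).max == 127

# 4-bit weights are carried as int8 with explicit bounds, giving [-7, 7]
# rather than [-8, 7], so that -quant_min * scale == quant_max * scale.
```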
```diff
@@ -198,6 +240,29 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )
 
+    calibrate_with_wikitext = True
+    if calibrate_with_wikitext:
+        from datasets import load_dataset
+
+        dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
+        for i in range(1000):
+            sample = dataset["train"][i]["text"]
+            tokens = tokenizer.encode(
+                sample, bos=True, eos=False, allowed_special="all"
+            )
+            if len(tokens) > 100:  # assuming max_seq_len > 100
+                prompt = tokenizer.decode(tokens[1 : args.max_seq_len - 1])
+                calibrate(
+                    inputs,
+                    prompt,
+                    model,
+                    tokenizer=tokenizer,
+                    ar_len=args.prefill_ar_len,
+                    max_seq_len=args.max_seq_len,
+                    kv_updater=None,
+                    use_i64_token=use_i64_token,
+                )
+
     model = convert_pt2e(model)
 
     model = WrappedLlamaModel(
```
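One detail worth noting in the calibration loop: `encode(..., bos=True)` prepends a BOS token, and the slice `tokens[1 : args.max_seq_len - 1]` drops it and truncates before decoding back into a prompt (presumably because calibrate re-adds special tokens itself). A toy illustration with hypothetical token ids:

```python
# Hypothetical ids; real values depend on the Llama 3.2 tokenizer.
tokens = [128000, 791, 4062, 19705, 39935]   # BOS first, then text tokens
max_seq_len = 4
prompt_tokens = tokens[1 : max_seq_len - 1]  # -> [791, 4062]
```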
```diff
@@ -245,6 +310,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
```
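Together with `--ptq` and `--range_setting`, these flags replace values that were previously hardcoded in `main()` (removed in the final hunk below). Since `--limit` is declared with `type=str`, downstream code has to interpret it; a hypothetical helper mirroring the semantics in the help text (values below 1 act as a fraction of the dataset, otherwise an absolute count) could look like the following. Neither the helper name nor its use appears in the diff.

```python
def parse_limit(limit: str | None) -> float | int | None:
    """Hypothetical: turn the --limit string into a fraction (< 1) or an
    absolute example count, matching the semantics in the help text."""
    if limit is None:
        return None
    value = float(limit)
    return value if value < 1 else int(value)

assert parse_limit("0.1") == 0.1   # 10% of the examples per task
assert parse_limit("200") == 200   # exactly 200 examples
```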
```diff
@@ -257,15 +339,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)
 
 
```