     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -174,6 +184,38 @@ def permute(w, heads):
         )
         quantizer.add_custom_quant_annotations(custom_annotations)
 
+        if args.range_setting == "mse_weight":
+            weight_dtype = (
+                torch.int4
+                if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+                else torch.int8
+            )
+            per_channel_q_config = quantizer.default_quant_config.quant_config
+            weight_qspec = QuantizationSpec(
+                dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+                quant_min=(
+                    -7
+                    if weight_dtype == torch.int4
+                    else torch.iinfo(weight_dtype).min + 1
+                ),
+                quant_max=(
+                    7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max
+                ),
+                qscheme=torch.per_channel_symmetric,
+                ch_axis=0,
+                observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+                    **{"steps": 200, "use_mse": True}
+                ),
+            )
+            quantizer.default_quant_config.per_channel_quant_config = (
+                QuantizationConfig(
+                    input_activation=per_channel_q_config.input_activation,
+                    output_activation=per_channel_q_config.output_activation,
+                    weight=weight_qspec,
+                    bias=_derived_bias_quant_spec,
+                )
+            )
+
         model.has_quant_io = True
 
         with torch.no_grad():
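
Note: PerChannelParamObserver.with_args(**{"steps": 200, "use_mse": True}) selects weight ranges by search rather than plain min/max, and the ±7 bounds keep the int4 grid symmetric. A minimal sketch of the idea, as my own illustration rather than the observer's actual implementation: shrink each channel's abs-max over `steps` candidates and keep the scale with the lowest quantization MSE.

    import torch

    def mse_symmetric_scale(w, steps=200, qmax=7):
        # Hypothetical helper: per-output-channel symmetric MSE range search.
        w2d = w.flatten(1)                            # [out_ch, k]
        amax = w2d.abs().amax(dim=1)                  # per-channel abs max
        best_err = torch.full_like(amax, float("inf"))
        best_scale = (amax / qmax).clamp(min=1e-8)
        for i in range(1, steps + 1):
            # Candidate range: a shrunken fraction of the observed abs-max.
            scale = (amax * i / steps).clamp(min=1e-8) / qmax
            q = torch.clamp((w2d / scale[:, None]).round(), -qmax, qmax)
            err = ((q * scale[:, None] - w2d) ** 2).mean(dim=1)
            better = err < best_err
            best_err = torch.where(better, err, best_err)
            best_scale = torch.where(better, scale, best_scale)
        return best_scale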
@@ -198,6 +240,29 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )
 
+        calibrate_with_wikitext = True
+        if calibrate_with_wikitext:
+            from datasets import load_dataset
+
+            dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
+            for i in range(1000):
+                sample = dataset["train"][i]["text"]
+                tokens = tokenizer.encode(
+                    sample, bos=True, eos=False, allowed_special="all"
+                )
+                if len(tokens) > 100:  # assuming max_seq_len > 100
+                    prompt = tokenizer.decode(tokens[1 : args.max_seq_len - 1])
+                    calibrate(
+                        inputs,
+                        prompt,
+                        model,
+                        tokenizer=tokenizer,
+                        ar_len=args.prefill_ar_len,
+                        max_seq_len=args.max_seq_len,
+                        kv_updater=None,
+                        use_i64_token=use_i64_token,
+                    )
+
         model = convert_pt2e(model)
 
         model = WrappedLlamaModel(
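
Note: the wikitext loop above only gathers calibration data; the quantization itself follows the standard pt2e PTQ flow already imported in this file. A minimal sketch under that assumption (ptq_quantize and its arguments are illustrative names, not part of this diff):

    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    def ptq_quantize(exported_module, quantizer, calibration_batches):
        prepared = prepare_pt2e(exported_module, quantizer)  # insert observers
        for batch in calibration_batches:
            prepared(*batch)                                 # observers record ranges
        return convert_pt2e(prepared)                        # fold ranges into q/dq ops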
@@ -245,6 +310,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, run PTQ quantization. Default is 16-bit activations and 4-bit weights. Supported values: 8a8w, 16a4w, 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Range setting method to use (e.g. mse_weight). If not specified, min-max is used for weights and activations.",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="Number of examples per task (only use this for testing). If <1, limit is a fraction of the total number of examples.",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
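
Note: with the hard-coded args.limit = 0.1 and args.ptq = "8a8w" overrides removed below, those values now come from the command line. A hypothetical invocation (the script name is an assumption, not shown in this diff):

    python eval_llama_qnn.py --ptq 8a8w --range_setting mse_weight --limit 0.1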
@@ -257,15 +339,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)