From f54daa07358bc55f1d5022df2ae62c9fa8b804cd Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Wed, 23 Aug 2023 15:42:19 -0600 Subject: [PATCH 1/4] Allow convert.py to convert to q8_0 Fix issue with bounded_parallel_map and greedy consuming iterator Display elapsed time during conversion --- convert.py | 117 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 91 insertions(+), 26 deletions(-) diff --git a/convert.py b/convert.py index b7c626d8473c5..8a888849f35ad 100755 --- a/convert.py +++ b/convert.py @@ -3,6 +3,7 @@ import gguf import argparse import concurrent.futures +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import copy import enum import faulthandler @@ -17,6 +18,7 @@ import signal import struct import sys +import time import zipfile import numpy as np @@ -50,7 +52,13 @@ class UnquantizedDataType: DT_I32 = UnquantizedDataType('I32') DT_BF16 = UnquantizedDataType('BF16') -DataType = Union[UnquantizedDataType] +@dataclass(frozen=True) +class QuantizedDataType: + name: str + +DT_Q8_0 = QuantizedDataType('Q8_0') + +DataType = Union[UnquantizedDataType, QuantizedDataType] DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { DT_BF16: np.dtype(np.uint16), @@ -73,8 +81,9 @@ class UnquantizedDataType: # TODO: rename to LLAMAFileType # TODO: move to `gguf.py` class GGMLFileType(enum.IntEnum): - AllF32 = 0 - MostlyF16 = 1 # except 1d tensors + AllF32 = 0 + MostlyF16 = 1 # except 1d tensors + MostlyQ8_0 = 7 # except 1d tensors def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType: if len(tensor.shape) == 1: @@ -84,6 +93,8 @@ def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType: return DT_F32 elif self == GGMLFileType.MostlyF16: return DT_F16 + elif self == GGMLFileType.MostlyQ8_0: + return DT_Q8_0 else: raise ValueError(self) @@ -391,7 +402,10 @@ def __init__(self, ndarray: NDArray) -> None: self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] def astype(self, data_type: DataType) -> Tensor: - dtype = DATA_TYPE_TO_NUMPY[data_type] + if data_type == DT_Q8_0: + dtype = DATA_TYPE_TO_NUMPY[DT_F32] + else: + dtype = DATA_TYPE_TO_NUMPY[data_type] if self.data_type == DT_BF16: self.ndarray = bf16_to_fp32(self.ndarray) return UnquantizedTensor(self.ndarray.astype(dtype)) @@ -455,7 +469,7 @@ class LazyTensor: def load(self) -> Tensor: ret = self._load() - assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description) + assert ret.data_type == self.data_type or (self.data_type is DT_Q8_0 and ret.data_type is DT_F32), (self.data_type, ret.data_type, self.description) return ret def astype(self, data_type: DataType) -> 'LazyTensor': @@ -699,23 +713,32 @@ def lazy_load_file(path: Path) -> ModelPlus: In = TypeVar('In') Out = TypeVar('Out') -def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]: +def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]: '''Parallel map, but with backpressure. If the caller doesn't call `next` fast enough, this will stop calling `func` at some point rather than letting results pile up in memory. 
Specifically, there is a max of one output value buffered per thread.''' - with concurrent.futures.ThreadPoolExecutor() as executor: + iterable = iter(iterable) + with factory(max_workers = max_workers) as executor: futures: List[concurrent.futures.Future[Out]] = [] - items_rev = list(iterable)[::-1] - for i in range(min(concurrency, len(items_rev))): - futures.append(executor.submit(func, items_rev.pop())) - while futures: + done = False + for i in range(concurrency): + try: + nexti = next(iterable) + except StopIteration: + break + futures.append(executor.submit(func, nexti)) + while not done or futures: result = futures.pop(0).result() - if items_rev: - futures.append(executor.submit(func, items_rev.pop())) + while len(futures) < concurrency: + try: + nexti = next(iterable) + except StopIteration: + done = True + break + futures.append(executor.submit(func, nexti)) yield result - def check_vocab_size(params: Params, vocab: Vocab) -> None: if params.n_vocab != vocab.vocab_size: assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab) @@ -732,6 +755,22 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." raise Exception(msg) +#### Mini Q8_0 quantization in Python +QK8_0 = 32 +BLOCK_Q8_0 = np.dtype([('d', ' None: @@ -777,9 +816,16 @@ def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = 1 for dim in tensor.shape: n_elements *= dim - data_type = DATA_TYPE_TO_NUMPY[tensor.data_type] - data_nbytes = n_elements * data_type.itemsize - self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes) + if tensor.data_type == DT_Q8_0: + assert n_elements > 0 and n_elements % QK8_0 == 0, f'Cannot quantize as Q8_0, {n_elements} not a multiple of block size {QK8_0}' + data_type= BLOCK_Q8_0 + raw_dtype = gguf.GGMLQuantizationType.Q8_0 + data_nbytes = n_elements + (n_elements // QK8_0) * 2 + else: + data_type = DATA_TYPE_TO_NUMPY[tensor.data_type] + data_nbytes = n_elements * data_type.itemsize + raw_dtype = None + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype) def write_meta(self) -> None: self.gguf.write_header_to_file() @@ -805,7 +851,19 @@ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None: of.close() @staticmethod - def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None: + def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray): + name, lazy_tensor = item + tensor = lazy_tensor.load().to_ggml() + return (lazy_tensor.data_type, tensor.ndarray) + + @staticmethod + def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray: + if item[0] == DT_Q8_0: + return quantize_array_q8_0(item[1]) + return item[1] + + @staticmethod + def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out) @@ -821,16 +879,19 @@ def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) - of.write_meta() of.write_tensor_info() - def do_item(item: Tuple[str, LazyTensor]) -> NDArray: - name, lazy_tensor = item - return lazy_tensor.load().to_ggml().ndarray - # tensor data - ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8) + ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, 
ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor) + else: + ndarrays = map(OutputFile.maybe_do_quant, ndarrays) + + start = time.time() for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + elapsed = time.time() - start size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) - print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}") + print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:6} | T+{int(elapsed):4}") of.gguf.write_tensor_data(ndarray) of.close() @@ -842,6 +903,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi return GGMLFileType.AllF32 if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)): return GGMLFileType.MostlyF16 + if output_type_str == "q8_0": + return GGMLFileType.MostlyQ8_0 name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} @@ -993,6 +1056,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path: namestr = { GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", + GGMLFileType.MostlyQ8_0:"q8_0", }[file_type] ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" if ret in model_paths: @@ -1016,7 +1080,7 @@ def main(args_in: Optional[List[str]] = None) -> None: parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") + parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") @@ -1043,6 +1107,7 @@ def main(args_in: Optional[List[str]] = None) -> None: params.ftype = { "f32": GGMLFileType.AllF32, "f16": GGMLFileType.MostlyF16, + "q8_0": GGMLFileType.MostlyQ8_0, }[args.outtype] print(f"params = {params}") @@ -1074,7 +1139,7 @@ def main(args_in: Optional[List[str]] = None) -> None: params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, params, model, vocab) + OutputFile.write_all(outfile, ftype, params, model, vocab) print(f"Wrote {outfile}") From 3efcbb8f59af2dd4212d63d2e0f64b05a6bcdd28 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Wed, 23 Aug 2023 19:36:55 -0600 Subject: [PATCH 2/4] Add --concurrency option Minor improvements to help text Clean up bounded_parallel_map function a bit --- convert.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/convert.py b/convert.py index 8a888849f35ad..794f28e41f0bb 100755 --- a/convert.py +++ b/convert.py @@ -39,6 +39,7 @@ ARCH=gguf.MODEL_ARCH.LLAMA NAMES=gguf.MODEL_TENSOR_NAMES[ARCH] +DEFAULT_CONCURRENCY = 8 # # data types # @@ -722,21 +723,21 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc with factory(max_workers = 
max_workers) as executor: futures: List[concurrent.futures.Future[Out]] = [] done = False - for i in range(concurrency): + for _ in range(concurrency): try: - nexti = next(iterable) + futures.append(executor.submit(func, next(iterable))) except StopIteration: + done = True break - futures.append(executor.submit(func, nexti)) - while not done or futures: + + while futures: result = futures.pop(0).result() - while len(futures) < concurrency: + while not done and len(futures) < concurrency: try: - nexti = next(iterable) + futures.append(executor.submit(func, next(iterable))) except StopIteration: done = True break - futures.append(executor.submit(func, nexti)) yield result def check_vocab_size(params: Params, vocab: Vocab) -> None: @@ -857,13 +858,13 @@ def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray): return (lazy_tensor.data_type, tensor.ndarray) @staticmethod - def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray: + def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray: if item[0] == DT_Q8_0: return quantize_array_q8_0(item[1]) return item[1] @staticmethod - def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None: + def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out) @@ -880,11 +881,11 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM of.write_tensor_info() # tensor data - ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8) + ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor) + ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor) else: - ndarrays = map(OutputFile.maybe_do_quant, ndarrays) + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays) start = time.time() for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): @@ -1080,12 +1081,13 @@ def main(args_in: Optional[List[str]] = None) -> None: parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)") + parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") + parser.add_argument("--concurrency", 
type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) args = parser.parse_args(args_in) if args.dump_single: @@ -1139,7 +1141,7 @@ def main(args_in: Optional[List[str]] = None) -> None: params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab) + OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency) print(f"Wrote {outfile}") From 8ee186c0aa6919f6ea48c70e386cdbe0cbc9e7ab Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Thu, 24 Aug 2023 12:13:00 -0600 Subject: [PATCH 3/4] Massive speed improvement thanks to Cebtenzzre --- convert.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/convert.py b/convert.py index 794f28e41f0bb..de650dd367a9a 100755 --- a/convert.py +++ b/convert.py @@ -764,13 +764,15 @@ def quantize_array_q8_0(arr): assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' n_blocks = arr.size // QK8_0 blocks = arr.reshape((n_blocks, QK8_0)) - return np.fromiter(map(quantize_block_q8_0, blocks), count = n_blocks, dtype = BLOCK_Q8_0) + return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = BLOCK_Q8_0) -def quantize_block_q8_0(blk, zero = np.float32(0), one = np.float32(1), onetwentyseven = np.float32(127), zero_chunk = (np.int8(0),) * QK8_0): - d = abs(blk).max() / onetwentyseven - if d == zero: - return (np.float16(d), zero_chunk) - return (np.float16(d), (blk * (one / d)).round()) +# Much faster implementation of block quantization contributed by @Cebtenzzre +def quantize_blocks_q8_0(blocks): + d = abs(blocks).max(axis = 1) / np.float32(127) + with np.errstate(divide = 'ignore'): + qs = (blocks / d[:, None]).round() + qs[d == 0] = 0 + yield from zip(np.float16(d), qs) class OutputFile: @@ -892,7 +894,7 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM elapsed = time.time() - start size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) - print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:6} | T+{int(elapsed):4}") + print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}") of.gguf.write_tensor_data(ndarray) of.close() From 5f23d41faaf35ed652fb9a3f6b5c22ef4c457940 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sat, 26 Aug 2023 09:23:06 -0600 Subject: [PATCH 4/4] Refactor types --- convert.py | 177 +++++++++++++++++++++++++---------------------------- 1 file changed, 84 insertions(+), 93 deletions(-) diff --git a/convert.py b/convert.py index de650dd367a9a..e66b233db1427 100755 --- a/convert.py +++ b/convert.py @@ -25,7 +25,7 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union) +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union) from sentencepiece import SentencePieceProcessor # type: ignore if TYPE_CHECKING: @@ -45,31 +45,64 @@ # @dataclass(frozen=True) -class UnquantizedDataType: +class DataType: name: str + dtype: 'np.dtype[Any]' + valid_conversions: List[str] -DT_F16 = UnquantizedDataType('F16') -DT_F32 = UnquantizedDataType('F32') -DT_I32 = UnquantizedDataType('I32') -DT_BF16 = 
UnquantizedDataType('BF16') + def elements_to_bytes(self, n_elements: int) -> int: + return n_elements * self.dtype.itemsize @dataclass(frozen=True) -class QuantizedDataType: - name: str +class UnquantizedDataType(DataType): + pass -DT_Q8_0 = QuantizedDataType('Q8_0') +DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) +DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) +DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) +DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) -DataType = Union[UnquantizedDataType, QuantizedDataType] +@dataclass(frozen=True) +class QuantizedDataType(DataType): + block_size: int + quantized_dtype: 'np.dtype[Any]' + ggml_type: gguf.GGMLQuantizationType -DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { - DT_BF16: np.dtype(np.uint16), - DT_F16: np.dtype(np.float16), - DT_F32: np.dtype(np.float32), - DT_I32: np.dtype(np.int32), -} + def quantize(self, arr: NDArray) -> NDArray: + raise NotImplementedError(f'Quantization for {self.name} not implemented') -NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \ - {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()} + def elements_to_bytes(self, n_elements: int) -> int: + assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + return self.quantized_dtype.itemsize * (n_elements // self.block_size) + +@dataclass(frozen=True) +class Q8_0QuantizedDataType(QuantizedDataType): + # Mini Q8_0 quantization in Python! + def quantize(self, arr: NDArray) -> NDArray: + assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' + assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + n_blocks = arr.size // self.block_size + blocks = arr.reshape((n_blocks, self.block_size)) + # Much faster implementation of block quantization contributed by @Cebtenzzre + def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]: + d = abs(blocks).max(axis = 1) / np.float32(127) + with np.errstate(divide = 'ignore'): + qs = (blocks / d[:, None]).round() + qs[d == 0] = 0 + yield from zip(d, qs) + return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) + +DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', + dtype = np.dtype(np.float32), valid_conversions = [], + ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, + quantized_dtype = np.dtype([('d', ' DataType: - if len(tensor.shape) == 1: - # 1D tensors are always F32. - return DT_F32 - elif self == GGMLFileType.AllF32: - return DT_F32 - elif self == GGMLFileType.MostlyF16: - return DT_F16 - elif self == GGMLFileType.MostlyQ8_0: - return DT_Q8_0 - else: + dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self) + if dt is None: raise ValueError(self) + # 1D tensors are always F32. 
+ return dt if len(tensor.shape) > 1 else DT_F32 +GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = { + GGMLFileType.AllF32 : DT_F32, + GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.MostlyQ8_0: DT_Q8_0, +} # # hparams loading @@ -403,10 +435,7 @@ def __init__(self, ndarray: NDArray) -> None: self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] def astype(self, data_type: DataType) -> Tensor: - if data_type == DT_Q8_0: - dtype = DATA_TYPE_TO_NUMPY[DT_F32] - else: - dtype = DATA_TYPE_TO_NUMPY[data_type] + dtype = data_type.dtype if self.data_type == DT_BF16: self.ndarray = bf16_to_fp32(self.ndarray) return UnquantizedTensor(self.ndarray.astype(dtype)) @@ -445,22 +474,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv GGMLCompatibleTensor = Union[UnquantizedTensor] -class DeferredPermutedTensor(Tensor): - def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None: - self.base = base - self.n_head = n_head - self.data_type = self.base.data_type - - def astype(self, data_type: DataType) -> Tensor: - return self.base.astype(data_type).permute(self.n_head, self.n_head_kv) - - def to_ggml(self) -> GGMLCompatibleTensor: - return self.base.to_ggml().permute(self.n_head, self.n_head_kv) - - def permute(self, n_head: int, n_head_kv: int) -> Tensor: - raise Exception("shouldn't permute twice") - - @dataclass class LazyTensor: _load: Callable[[], Tensor] @@ -470,7 +483,9 @@ class LazyTensor: def load(self) -> Tensor: ret = self._load() - assert ret.data_type == self.data_type or (self.data_type is DT_Q8_0 and ret.data_type is DT_F32), (self.data_type, ret.data_type, self.description) + # Should be okay if it maps to the same numpy type? + assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ + (self.data_type, ret.data_type, self.description) return ret def astype(self, data_type: DataType) -> 'LazyTensor': @@ -481,8 +496,8 @@ def load() -> Tensor: return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') def validate_conversion_to(self, data_type: DataType) -> None: - if data_type == self.data_type: - return + if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: + raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') LazyModel = Dict[str, LazyTensor] @@ -608,9 +623,7 @@ def persistent_load(self, pid: Any) -> Any: info = self.zip_file.getinfo(filename) def load(offset: int, elm_count: int) -> NDArray: - dtype = DATA_TYPE_TO_NUMPY.get(data_type) - if dtype is None: - raise Exception("tensor stored in unsupported format") + dtype = data_type.dtype fp = self.zip_file.open(info) fp.seek(offset * dtype.itemsize) size = elm_count * dtype.itemsize @@ -674,7 +687,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: def convert(info: Dict[str, Any]) -> LazyTensor: data_type = SAFETENSORS_DATA_TYPES[info['dtype']] - numpy_dtype = DATA_TYPE_TO_NUMPY[data_type] + numpy_dtype = data_type.dtype shape: List[int] = info['shape'] begin, end = info['data_offsets'] assert 0 <= begin <= end <= len(byte_buf) @@ -719,6 +732,9 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc fast enough, this will stop calling `func` at some point rather than letting results pile up in memory. Specifically, there is a max of one output value buffered per thread.''' + if concurrency < 2: + yield from map(func, iterable) + # Not reached. 
iterable = iter(iterable) with factory(max_workers = max_workers) as executor: futures: List[concurrent.futures.Future[Out]] = [] @@ -756,24 +772,6 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." raise Exception(msg) -#### Mini Q8_0 quantization in Python -QK8_0 = 32 -BLOCK_Q8_0 = np.dtype([('d', ' None: @@ -816,18 +814,10 @@ def add_meta_vocab(self, vocab: Vocab) -> None: self.gguf.add_token_types(toktypes) def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: - n_elements = 1 - for dim in tensor.shape: - n_elements *= dim - if tensor.data_type == DT_Q8_0: - assert n_elements > 0 and n_elements % QK8_0 == 0, f'Cannot quantize as Q8_0, {n_elements} not a multiple of block size {QK8_0}' - data_type= BLOCK_Q8_0 - raw_dtype = gguf.GGMLQuantizationType.Q8_0 - data_nbytes = n_elements + (n_elements // QK8_0) * 2 - else: - data_type = DATA_TYPE_TO_NUMPY[tensor.data_type] - data_nbytes = n_elements * data_type.itemsize - raw_dtype = None + n_elements = int(np.prod(tensor.shape)) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + data_nbytes = tensor.data_type.elements_to_bytes(n_elements) self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype) def write_meta(self) -> None: @@ -854,16 +844,17 @@ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None: of.close() @staticmethod - def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray): + def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]: name, lazy_tensor = item tensor = lazy_tensor.load().to_ggml() return (lazy_tensor.data_type, tensor.ndarray) @staticmethod def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray: - if item[0] == DT_Q8_0: - return quantize_array_q8_0(item[1]) - return item[1] + dt, arr = item + if not isinstance(dt, QuantizedDataType): + return arr + return dt.quantize(arr) @staticmethod def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None: @@ -883,11 +874,11 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM of.write_tensor_info() # tensor data - ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor) + ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor) else: - ndarrays = map(OutputFile.maybe_do_quantize, ndarrays) + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) start = time.time() for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): @@ -954,7 +945,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel: print(f"skipping tensor {name_new}") continue else: - print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}") + print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") out[name_new] = lazy_tensor return out
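
For reference, the Q8_0 format these patches write groups every QK8_0 = 32 float32 values into one block: a float16 scale d = max(|x|) / 127 followed by 32 int8 quants qs = round(x / d), so an n-element tensor takes n + (n // 32) * 2 bytes (e.g. a 4096 x 4096 f32 tensor shrinks from 64 MiB to 17 MiB). Below is a minimal standalone sketch of that block quantization in the same NumPy style as patches 3 and 4; the function name quantize_q8_0 is illustrative and not part of convert.py.

import numpy as np

QK8_0 = 32  # elements per Q8_0 block
# One block: float16 scale 'd' followed by QK8_0 int8 quants 'qs'.
BLOCK_Q8_0 = np.dtype([('d', '<f2'), ('qs', 'i1', (QK8_0,))])

def quantize_q8_0(arr: np.ndarray) -> np.ndarray:
    # Same preconditions the patch asserts: float32 input, a whole number of blocks.
    assert arr.dtype == np.float32 and arr.size != 0 and arr.size % QK8_0 == 0
    blocks = arr.reshape((-1, QK8_0))
    d = np.abs(blocks).max(axis=1) / np.float32(127)
    with np.errstate(divide='ignore', invalid='ignore'):
        qs = (blocks / d[:, None]).round()
    qs[d == 0] = 0  # an all-zero block would otherwise divide by zero
    out = np.empty(blocks.shape[0], dtype=BLOCK_Q8_0)
    out['d'] = d.astype(np.float16)
    out['qs'] = qs.astype(np.int8)
    return out

After the final patch this logic lives in Q8_0QuantizedDataType.quantize and, when --outtype q8_0 is selected, runs across a ProcessPoolExecutor via bounded_parallel_map, with the worker count controlled by the --concurrency option added in patch 2.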
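
The other recurring piece is the reworked bounded_parallel_map, which now pulls from its input lazily instead of materializing the whole list: it keeps at most `concurrency` futures in flight and submits a new item only as results are consumed, so tensors are not loaded faster than they can be written out. A compact sketch of the same backpressure pattern, assuming a ThreadPoolExecutor (the patch's default factory); bounded_map is an illustrative name, not the patch's function.

import concurrent.futures
import itertools
from typing import Callable, Iterable, Iterator, TypeVar

In = TypeVar('In')
Out = TypeVar('Out')

def bounded_map(func: Callable[[In], Out], items: Iterable[In], concurrency: int) -> Iterator[Out]:
    it = iter(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Prime the pipeline with at most `concurrency` submissions.
        futures = [executor.submit(func, x) for x in itertools.islice(it, concurrency)]
        while futures:
            result = futures.pop(0).result()  # oldest first, preserving input order
            # Top back up to `concurrency` in-flight items before handing back a result.
            for x in itertools.islice(it, concurrency - len(futures)):
                futures.append(executor.submit(func, x))
            yield result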