Skip to content

Commit 3dfaa1f

Browse files
compiladeteleprint-me
authored andcommitted
convert-hf : support direct Q8_0 conversion (ggml-org#7234)
* convert-hf : support q8_0 conversion * convert-hf : add missing ftype This was messing with the checksums otherwise. * convert-hf : add missing ftype to Baichuan and Xverse I didn't notice these on my first pass.
1 parent 7d85ea8 commit 3dfaa1f

File tree

5 files changed

+169
-58
lines changed

5 files changed

+169
-58
lines changed

convert-hf-to-gguf.py

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -250,23 +250,6 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
250250
return False
251251

252252
def write_tensors(self):
253-
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
254-
def np_fp32_to_bf16(n: np.ndarray):
255-
# force nan to quiet
256-
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
257-
# flush subnormals to zero
258-
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
259-
# round to nearest even
260-
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
261-
return n.astype(np.int16)
262-
263-
# Doing this row-wise is much, much faster than element-wise, hence the signature
264-
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
265-
if self.lazy:
266-
# TODO: find a way to implicitly wrap np.vectorize functions
267-
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
268-
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
269-
270253
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
271254

272255
for name, data_torch in self.get_tensors():
@@ -319,27 +302,31 @@ def np_fp32_to_bf16(n: np.ndarray):
319302
))
320303

321304
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
322-
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
305+
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
306+
data = gguf.quantize_bf16(data)
307+
assert data.dtype == np.int16
308+
data_qtype = gguf.GGMLQuantizationType.BF16
309+
310+
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
311+
data = gguf.quantize_q8_0(data)
312+
assert data.dtype == np.uint8
313+
data_qtype = gguf.GGMLQuantizationType.Q8_0
314+
315+
else: # default to float16 for quantized tensors
323316
if data_dtype != np.float16:
324317
data = data.astype(np.float16)
325318
data_qtype = gguf.GGMLQuantizationType.F16
326319

327-
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
328-
if data_dtype != np.float32:
329-
data = data.astype(np.float32)
330-
data = v_fp32_to_bf16(data.view(np.int32))
331-
assert data.dtype == np.int16
332-
data_qtype = gguf.GGMLQuantizationType.BF16
333-
334-
else: # by default, convert to float32
320+
if data_qtype is None: # by default, convert to float32
335321
if data_dtype != np.float32:
336322
data = data.astype(np.float32)
337323
data_qtype = gguf.GGMLQuantizationType.F32
338324

339-
assert data_qtype is not None
340-
325+
block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
341326
# reverse shape to make it similar to the internal ggml dimension order
342-
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
327+
shape_str = f"""{{{', '.join(str(n) for n in reversed(
328+
(*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
329+
)}}}"""
343330

344331
# n_dims is implicit in the shape
345332
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
@@ -881,6 +868,7 @@ def set_gguf_parameters(self):
881868
self.gguf_writer.add_head_count(head_count)
882869
self.gguf_writer.add_head_count_kv(head_count_kv)
883870
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
871+
self.gguf_writer.add_file_type(self.ftype)
884872

885873
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
886874
if self.hparams["rope_scaling"].get("type") == "linear":
@@ -1003,6 +991,7 @@ def set_gguf_parameters(self):
1003991
self.gguf_writer.add_head_count(head_count)
1004992
self.gguf_writer.add_head_count_kv(head_count_kv)
1005993
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
994+
self.gguf_writer.add_file_type(self.ftype)
1006995

1007996
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
1008997
if self.hparams["rope_scaling"].get("type") == "linear":
@@ -1237,6 +1226,7 @@ def set_gguf_parameters(self):
12371226
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
12381227
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
12391228
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
1229+
self.gguf_writer.add_file_type(self.ftype)
12401230

12411231
_q_norms: list[dict[str, Tensor]] | None = None
12421232
_k_norms: list[dict[str, Tensor]] | None = None
@@ -1613,6 +1603,7 @@ def set_gguf_parameters(self):
16131603
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
16141604
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
16151605
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
1606+
self.gguf_writer.add_file_type(self.ftype)
16161607

16171608

16181609
@Model.register("Qwen2ForCausalLM")
@@ -1850,6 +1841,7 @@ def set_gguf_parameters(self):
18501841
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
18511842
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
18521843
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
1844+
self.gguf_writer.add_file_type(self.ftype)
18531845

18541846
def shuffle_attn_q_weight(self, data_torch):
18551847
assert data_torch.size() == (5120, 5120)
@@ -2029,6 +2021,7 @@ def set_gguf_parameters(self):
20292021
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
20302022
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
20312023
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
2024+
self.gguf_writer.add_file_type(self.ftype)
20322025

20332026
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
20342027
num_heads = self.hparams["num_attention_heads"]
@@ -2437,25 +2430,15 @@ class LazyTorchTensor(gguf.LazyBase):
24372430
def numpy(self) -> gguf.LazyNumpyTensor:
24382431
dtype = self._dtype_map[self.dtype]
24392432
return gguf.LazyNumpyTensor(
2440-
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
2433+
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
24412434
lazy=self._lazy,
24422435
args=(self,),
24432436
func=(lambda s: s[0].numpy())
24442437
)
24452438

24462439
@classmethod
2447-
def eager_to_meta(cls, t: Tensor) -> Tensor:
2448-
if t.is_meta:
2449-
return t
2450-
return t.detach().to("meta")
2451-
2452-
@classmethod
2453-
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
2454-
m = m.detach()
2455-
if not m.is_meta:
2456-
m = m.to("meta")
2457-
m.dtype = dtype
2458-
return m
2440+
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
2441+
return torch.empty(size=shape, dtype=dtype, device="meta")
24592442

24602443
@classmethod
24612444
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2486,8 +2469,8 @@ def parse_args() -> argparse.Namespace:
24862469
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
24872470
)
24882471
parser.add_argument(
2489-
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
2490-
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
2472+
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
2473+
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
24912474
)
24922475
parser.add_argument(
24932476
"--bigendian", action="store_true",
@@ -2545,6 +2528,7 @@ def main() -> None:
25452528
"f32": gguf.LlamaFileType.ALL_F32,
25462529
"f16": gguf.LlamaFileType.MOSTLY_F16,
25472530
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
2531+
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
25482532
"auto": gguf.LlamaFileType.GUESSED,
25492533
}
25502534

gguf-py/gguf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .lazy import *
33
from .gguf_reader import *
44
from .gguf_writer import *
5+
from .quants import *
56
from .tensor_mapping import *
67
from .vocab import *

gguf-py/gguf/gguf_writer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414

1515
from .constants import (
16+
GGML_QUANT_SIZES,
1617
GGUF_DEFAULT_ALIGNMENT,
1718
GGUF_MAGIC,
1819
GGUF_VERSION,
@@ -195,7 +196,7 @@ def ggml_pad(x: int, n: int) -> int:
195196
return ((x + n - 1) // n) * n
196197

197198
def add_tensor_info(
198-
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
199+
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
199200
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
200201
) -> None:
201202
if self.state is not WriterState.EMPTY:
@@ -208,10 +209,6 @@ def add_tensor_info(
208209
encoded_name = name.encode("utf-8")
209210
self.ti_data += self._pack("Q", len(encoded_name))
210211
self.ti_data += encoded_name
211-
n_dims = len(tensor_shape)
212-
self.ti_data += self._pack("I", n_dims)
213-
for i in range(n_dims):
214-
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
215212
if raw_dtype is None:
216213
if tensor_dtype == np.float16:
217214
dtype = GGMLQuantizationType.F16
@@ -231,6 +228,15 @@ def add_tensor_info(
231228
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
232229
else:
233230
dtype = raw_dtype
231+
if tensor_dtype == np.uint8:
232+
block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
233+
if tensor_shape[-1] % type_size != 0:
234+
raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
235+
tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
236+
n_dims = len(tensor_shape)
237+
self.ti_data += self._pack("I", n_dims)
238+
for i in range(n_dims):
239+
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
234240
self.ti_data += self._pack("I", dtype)
235241
self.ti_data += self._pack("Q", self.offset_tensor)
236242
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)

gguf-py/gguf/lazy.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import deque
77

88
import numpy as np
9+
from numpy._typing import _Shape
910
from numpy.typing import DTypeLike
1011

1112

@@ -110,7 +111,7 @@ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
110111
return o
111112

112113
@classmethod
113-
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
114+
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
114115
def wrapped_fn(*args, **kwargs):
115116
if kwargs is None:
116117
kwargs = {}
@@ -130,9 +131,14 @@ def wrapped_fn(*args, **kwargs):
130131
res = args[0]
131132
assert isinstance(res, cls)
132133
res = res._meta
133-
# allow operations to override the dtype
134+
# allow operations to override the dtype and shape
134135
if meta_noop is not True:
135-
res = cls.meta_with_dtype(res, meta_noop)
136+
if isinstance(meta_noop, tuple):
137+
dtype, shape = meta_noop
138+
assert callable(shape)
139+
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
140+
else:
141+
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
136142

137143
if isinstance(res, cls._tensor_type):
138144
def collect_replace(t: LazyBase):
@@ -168,7 +174,12 @@ def already_eager_to_eager(_t: LazyBase) -> Any:
168174
while _t._data is None:
169175
lt = _t._lazy.popleft()
170176
if lt._data is not None:
171-
raise ValueError(f"{lt} did not belong in the lazy queue")
177+
# Lazy tensor did not belong in the lazy queue.
178+
# Weirdly only happens with Bloom models...
179+
# likely because tensors aren't unique in the queue.
180+
# The final output is still the same as in eager mode,
181+
# so it's safe to ignore this.
182+
continue
172183
assert lt._func is not None
173184
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
174185
lt._data = lt._func(lt._args)
@@ -183,12 +194,12 @@ def already_eager_to_eager(_t: LazyBase) -> Any:
183194

184195
@classmethod
185196
def eager_to_meta(cls, t: Any) -> Any:
186-
return cls.meta_with_dtype(t, t.dtype)
197+
return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
187198

188199
# must be overridden, meta tensor init is backend-specific
189200
@classmethod
190201
@abstractmethod
191-
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
202+
def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
192203

193204
@classmethod
194205
def from_eager(cls, t: Any) -> Any:
@@ -205,15 +216,15 @@ class LazyNumpyTensor(LazyBase):
205216
_tensor_type = np.ndarray
206217

207218
@classmethod
208-
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
219+
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
209220
# The initial idea was to use np.nan as the fill value,
210221
# but non-float types like np.int16 can't use that.
211222
# So zero it is.
212223
cheat = np.zeros(1, dtype)
213-
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
224+
return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
214225

215226
def astype(self, dtype, *args, **kwargs):
216-
meta = type(self).meta_with_dtype(self._meta, dtype)
227+
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
217228
full_args = (self, dtype,) + args
218229
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
219230
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))

0 commit comments

Comments
 (0)