convert_hf : faster lazy safetensors #8482

Merged · 3 commits · Jul 16, 2024
48 changes: 41 additions & 7 deletions convert_hf_to_gguf.py
@@ -148,9 +148,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_names_from_parts.update(model_part.keys())

for name in model_part.keys():
data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
if self.lazy:
data = LazyTorchTensor.from_eager(data)
if self.is_safetensors:
if self.lazy:
data = model_part.get_slice(name)
data = LazyTorchTensor.from_safetensors_slice(data)
else:
data = model_part.get_tensor(name)
else:
data = model_part[name]
if self.lazy:
data = LazyTorchTensor.from_eager(data)
yield name, data

# only verify tensor name presence; it doesn't matter if they are not in the right files
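
As a point of reference, a minimal sketch of the safetensors API this lazy path leans on (my reading: `get_slice` is cheap because it only describes the tensor, and data is read from disk only when the slice is indexed). The file name below is a placeholder:

```python
# Minimal sketch, assuming the safetensors Python API; the file name is a placeholder.
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt") as part:
    for name in part.keys():
        sl = part.get_slice(name)                # no tensor data is read yet
        print(name, sl.get_dtype(), sl.get_shape())
        # data = sl[:]                           # indexing is what actually reads from disk
        # head = sl[:4]                          # e.g. only the first rows
```
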
@@ -3424,19 +3431,46 @@ class LazyTorchTensor(gguf.LazyBase):
torch.float32: np.float32,
}

# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
_dtype_str_map: dict[str, torch.dtype] = {
"F64": torch.float64,
"F32": torch.float32,
"BF16": torch.bfloat16,
"F16": torch.float16,
# "U64": torch.uint64,
"I64": torch.int64,
# "U32": torch.uint32,
"I32": torch.int32,
# "U16": torch.uint16,
"I16": torch.int16,
"U8": torch.uint8,
"I8": torch.int8,
"BOOL": torch.bool,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}

def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyNumpyTensor(
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
lazy=self._lazy,
args=(self,),
func=(lambda s: s[0].numpy())
func=(lambda s: s.numpy())
)

@classmethod
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
return torch.empty(size=shape, dtype=dtype, device="meta")

@classmethod
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
dtype = cls._dtype_str_map[st_slice.get_dtype()]
shape: tuple[int, ...] = tuple(st_slice.get_shape())
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
del types # unused
@@ -3447,7 +3481,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
if func is torch.Tensor.numpy:
return args[0].numpy()

return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
return cls._wrap_fn(func)(*args, **kwargs)


def parse_args() -> argparse.Namespace:
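
To tie the pieces above together, a hedged usage sketch of `from_safetensors_slice`; this assumes running alongside `convert_hf_to_gguf.py` so the `LazyTorchTensor` class defined above is available, and the file and tensor names are placeholders:

```python
# Illustrative only; file/tensor names are placeholders, and LazyTorchTensor is the
# class defined above in convert_hf_to_gguf.py.
import torch
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as part:
    sl = part.get_slice("token_embd.weight")          # hypothetical tensor name
    lazy = LazyTorchTensor.from_safetensors_slice(sl)

    # The lazy tensor knows its dtype/shape from the slice metadata alone ...
    print(lazy.dtype, lazy.shape)

    # ... further ops are deferred through __torch_function__ ...
    lazy_f32 = lazy.to(torch.float32)

    # ... and the file is only read when the value is forced eager.
    eager = LazyTorchTensor.to_eager(lazy_f32)
```
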
63 changes: 18 additions & 45 deletions gguf-py/gguf/lazy.py
Expand Up @@ -3,7 +3,6 @@

import logging
from typing import Any, Callable
from collections import deque

import numpy as np
from numpy.typing import DTypeLike
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
_tensor_type: type
_meta: Any
_data: Any | None
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
_args: tuple
_func: Callable[[tuple], Any] | None
_kwargs: dict[str, Any]
_func: Callable[[Any], Any] | None

def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
super().__init__()
self._meta = meta
self._data = data
self._lazy = lazy if lazy is not None else deque()
self._args = args
self._kwargs = kwargs if kwargs is not None else {}
self._func = func
assert self._func is not None or self._data is not None
if self._data is None:
self._lazy.append(self)

def __init_subclass__(cls) -> None:
if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ def wrapped_fn(*args, **kwargs):
args = ((use_self,) if use_self is not None else ()) + args

meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
# TODO: maybe handle tensors in kwargs too

if isinstance(meta_noop, bool) and not meta_noop:
try:
@@ -140,23 +138,7 @@ def wrapped_fn(*args, **kwargs):
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

if isinstance(res, cls._tensor_type):
class CollectSharedLazy:
# emulating a static variable
shared_lazy: None | deque[LazyBase] = None

@staticmethod
def collect_replace(t: LazyBase):
if CollectSharedLazy.shared_lazy is None:
CollectSharedLazy.shared_lazy = t._lazy
else:
CollectSharedLazy.shared_lazy.extend(t._lazy)
t._lazy = CollectSharedLazy.shared_lazy

LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)

shared_lazy = CollectSharedLazy.shared_lazy

return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
else:
del res # not needed
# non-tensor return likely relies on the contents of the args
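
Side note on the meta pass used just above (as I understand it): calling the real function on `device="meta"` tensors yields the result's shape and dtype without allocating or computing any data, and that result is what seeds the deferred node's `meta`. A tiny standalone illustration, assuming PyTorch:

```python
# Shape/dtype propagation on the meta device; no real data is allocated or computed.
import torch

a = torch.empty((4096, 4096), dtype=torch.bfloat16, device="meta")
b = torch.empty((4096, 11008), dtype=torch.bfloat16, device="meta")
out = a @ b
print(out.shape, out.dtype, out.device)   # torch.Size([4096, 11008]) torch.bfloat16 meta
```
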
@@ -168,26 +150,18 @@ def collect_replace(t: LazyBase):
@classmethod
def to_eager(cls, t: Any) -> Any:
def simple_to_eager(_t: LazyBase) -> Any:
def already_eager_to_eager(_t: LazyBase) -> Any:
assert _t._data is not None
if _t._data is not None:
return _t._data

while _t._data is None:
lt = _t._lazy.popleft()
if lt._data is not None:
# Lazy tensor did not belong in the lazy queue.
# Weirdly only happens with Bloom models...
# likely because tensors aren't unique in the queue.
# The final output is still the same as in eager mode,
# so it's safe to ignore this.
continue
assert lt._func is not None
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
# sanity check
assert lt._data is not None
assert lt._data.dtype == lt._meta.dtype
assert lt._data.shape == lt._meta.shape
# NOTE: there's a recursion limit in Python (usually 1000)

assert _t._func is not None
_t._args = cls._recurse_apply(_t._args, simple_to_eager)
_t._data = _t._func(*_t._args, **_t._kwargs)
# sanity check
assert _t._data is not None
assert _t._data.dtype == _t._meta.dtype
assert _t._data.shape == _t._meta.shape

return _t._data
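
For intuition, a stripped-down analogue of the simplified scheme (not the actual gguf classes): each lazy object just remembers its function and arguments, and `to_eager` walks that chain recursively instead of draining a shared queue, which is where the recursion-limit note above comes from.

```python
# Stripped-down analogue of the new lazy evaluation, not the real LazyBase.
from __future__ import annotations
from typing import Any, Callable


class TinyLazy:
    def __init__(self, value: Any | None = None,
                 func: Callable[..., Any] | None = None,
                 args: tuple = (), kwargs: dict | None = None):
        self.value = value            # eager data, if already known
        self.func = func              # otherwise, the deferred computation
        self.args = args
        self.kwargs = kwargs or {}

    def to_eager(self) -> Any:
        if self.value is None:
            # Recursively force lazy arguments, then apply the stored function;
            # chain depth is bounded by Python's recursion limit (usually 1000).
            args = tuple(a.to_eager() if isinstance(a, TinyLazy) else a for a in self.args)
            self.value = self.func(*args, **self.kwargs)
        return self.value


a = TinyLazy(value=3)
b = TinyLazy(func=lambda x, y: x * y, args=(a, 7))
print(b.to_eager())   # 21; nothing ran until this call
```
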

Expand All @@ -206,7 +180,7 @@ def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
@classmethod
def from_eager(cls, t: Any) -> Any:
if type(t) is cls:
# already eager
# already lazy
return t
elif isinstance(t, cls._tensor_type):
return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) ->
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
full_args = (self, dtype,) + args
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))

def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
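
A hedged usage sketch for the numpy side (assuming `gguf-py` is importable; the output file name is a placeholder): `astype` stays deferred, and `tofile` is what forces evaluation.

```python
# Illustrative only; assumes gguf-py is on the path, and the output file name is a placeholder.
import numpy as np
from gguf.lazy import LazyNumpyTensor

arr = np.arange(8, dtype=np.float32)
lazy = LazyNumpyTensor.from_eager(arr)
lazy16 = lazy.astype(np.float16)     # deferred: no float16 data exists yet
lazy16.tofile("tensor.bin")          # to_eager() runs here, then the real ndarray.tofile()
```
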
14 changes: 6 additions & 8 deletions gguf-py/gguf/tensor_mapping.py
@@ -602,14 +602,12 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
for tensor, keys in self.block_mappings_cfg.items():
if tensor not in MODEL_TENSORS[arch]:
continue
# TODO: make this configurable
n_experts = 160
for xid in range(n_experts):
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid, xid = xid)
self.mapping[key] = (tensor, tensor_name)

tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid)
self.mapping[key] = (tensor, tensor_name)

def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
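
For reference, a hedged illustration of what one formatted block-mapping entry ends up looking like; the key and name templates below are placeholders modeled on the usual `TENSOR_NAMES` / `block_mappings_cfg` conventions, not copied from this diff:

```python
# Illustrative only: the template strings are placeholders, not taken from this diff.
bid = 0
tensor_name = "blk.{bid}.attn_q".format(bid=bid)
keys = ("model.layers.{bid}.self_attn.q_proj",)

mapping: dict[str, tuple[str, str]] = {tensor_name: ("ATTN_Q", tensor_name)}
for key in keys:
    mapping[key.format(bid=bid)] = ("ATTN_Q", tensor_name)

print(mapping)
# {'blk.0.attn_q': ('ATTN_Q', 'blk.0.attn_q'),
#  'model.layers.0.self_attn.q_proj': ('ATTN_Q', 'blk.0.attn_q')}
```
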