convert_hf : faster lazy safetensors #8482

Merged · 3 commits · Jul 16, 2024

Changes from 2 commits

convert_hf_to_gguf.py (38 additions & 3 deletions)

@@ -148,9 +148,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_names_from_parts.update(model_part.keys())
 
                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                     yield name, data
 
         # only verify tensor name presence; it doesn't matter if they are not in the right files
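
The new branch hinges on `model_part.get_slice(name)` from the `safetensors` library: it returns a handle that exposes only metadata, and no tensor bytes are read until the slice is indexed. A minimal sketch of that API, with a hypothetical shard name:

```python
from safetensors import safe_open

# Hypothetical file name; any .safetensors shard behaves the same way.
with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
    for name in f.keys():
        st_slice = f.get_slice(name)  # metadata only, nothing read yet
        print(name, st_slice.get_dtype(), st_slice.get_shape())
        # full = st_slice[:]          # only here would the bytes be read
```
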
@@ -3424,6 +3431,27 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.float32: np.float32,
     }
 
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
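
The strings on the left are what `get_slice(...).get_dtype()` reports for each storage type; the table maps them back to torch dtypes before the lazy meta tensor is built. A trivial sketch, assuming the class above is in scope:

```python
import torch

# "BF16" is the string a bfloat16 safetensors slice reports via get_dtype().
assert LazyTorchTensor._dtype_str_map["BF16"] is torch.bfloat16
```
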
@@ -3437,6 +3465,13 @@ def numpy(self) -> gguf.LazyNumpyTensor:
     def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")
 
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape = st_slice.get_shape()
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[0][:])
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types  # unused
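
`from_safetensors_slice` never materializes the tensor up front: the wrapper keeps a zero-storage tensor on PyTorch's `meta` device for shape/dtype bookkeeping, plus a deferred `func` (`lambda s: s[0][:]`) that performs the actual read only when the value is consumed. A minimal sketch of the meta-device behaviour this leans on, in plain PyTorch with nothing project-specific:

```python
import torch

# Meta tensors carry shape and dtype but allocate no storage, so shape
# propagation through torch ops is essentially free.
meta = torch.empty(size=(4096, 11008), dtype=torch.bfloat16, device="meta")
out = meta.transpose(0, 1)
print(out.shape, out.dtype)  # torch.Size([11008, 4096]) torch.bfloat16
print(out.device)            # meta -- still no real data anywhere
```
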
gguf-py/gguf/tensor_mapping.py (6 additions & 8 deletions)

@@ -602,14 +602,12 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
             for tensor, keys in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
-                # TODO: make this configurable
-                n_experts = 160
-                for xid in range(n_experts):
-                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
-                    self.mapping[tensor_name] = (tensor, tensor_name)
-                    for key in keys:
-                        key = key.format(bid = bid, xid = xid)
-                        self.mapping[key] = (tensor, tensor_name)
+
+                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+                self.mapping[tensor_name] = (tensor, tensor_name)
+                for key in keys:
+                    key = key.format(bid = bid)
+                    self.mapping[key] = (tensor, tensor_name)
 
     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
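
The deleted loop formatted every per-block template 160 times with an expert id; since experts tensors are handled merged, the `{xid}` placeholder no longer appears in these templates, so the expansion was redundant work. A hedged sketch of what the simplified loop builds, using an illustrative one-entry subset (the real tables live in `gguf-py/gguf/constants.py` and `tensor_mapping.py`):

```python
# Illustrative subset only; the real tables are far larger.
TENSOR_NAMES = {"ATTN_Q": "blk.{bid}.attn_q"}
block_mappings_cfg = {"ATTN_Q": ("model.layers.{bid}.self_attn.q_proj",)}

mapping: dict[str, tuple[str, str]] = {}
for bid in range(2):  # n_blocks = 2 for brevity
    for tensor, keys in block_mappings_cfg.items():
        tensor_name = TENSOR_NAMES[tensor].format(bid=bid)
        mapping[tensor_name] = (tensor, tensor_name)
        for key in keys:
            mapping[key.format(bid=bid)] = (tensor, tensor_name)

# Both GGUF-side and HF-side names now resolve to the canonical name:
# {'blk.0.attn_q': ('ATTN_Q', 'blk.0.attn_q'),
#  'model.layers.0.self_attn.q_proj': ('ATTN_Q', 'blk.0.attn_q'), ...}
```
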