Skip to content

Commit 03cc9bc

Browse files
use simplification from ggml-org#7827
1 parent 666bb09 commit 03cc9bc

File tree

1 file changed

+16
-29
lines changed

1 file changed

+16
-29
lines changed

gguf-py/gguf/gguf_writer_split.py

Lines changed: 16 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ class GGUFWriterSplit(GGUFWriter):
6161
kv_data: KVTempData
6262
split_arguments: SplitArguments
6363
shards: list[Shard]
64-
shard_writers: list[GGUFWriter]
64+
shard_writers: list[tuple[GGUFWriter, os.PathLike[str]]]
6565

6666
def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
6767
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
@@ -115,17 +115,15 @@ def init_shards(self) -> None:
115115
logger.info("Dry run, not writing files")
116116
exit()
117117

118-
# we don't want to initialize GGUFWriters until now because they create files
119118
for i, shard in enumerate(self.shards):
120119
# add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
121-
writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file,
120+
writer = GGUFWriter(None, self.arch, use_temp_file=self.use_temp_file,
122121
endianess=self.endianess, add_architecture=(i == 0))
123122

124123
# only the first shard needs all the KV data
125124
if i == 0:
126125
for key, (value, etype) in self.kv_data.items():
127-
writer.add_key(key)
128-
writer.add_val(value, etype)
126+
writer.add_key_value(key, value, etype)
129127

130128
# add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE
131129
if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
@@ -141,22 +139,22 @@ def init_shards(self) -> None:
141139
except IndexError:
142140
break
143141

144-
self.shard_writers.append(writer)
142+
self.shard_writers.append((writer, shard.path))
145143

146-
def write_header_to_file(self) -> None:
144+
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
147145
if self.state is not WriterState.EMPTY:
148146
raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')
149147

150-
for writer in self.shard_writers:
151-
writer.write_header_to_file()
148+
for (writer, path) in self.shard_writers:
149+
writer.write_header_to_file(path)
152150

153151
self.state = WriterState.HEADER
154152

155153
def write_kv_data_to_file(self) -> None:
156154
if self.state is not WriterState.HEADER:
157155
raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')
158156

159-
for writer in self.shard_writers:
157+
for (writer, _) in self.shard_writers:
160158
writer.write_kv_data_to_file()
161159

162160
self.state = WriterState.KV_DATA
@@ -167,32 +165,21 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
167165

168166
running_total = self.total_tensors
169167
for i in range(len(self.shard_writers)):
170-
writer = self.shard_writers[i]
171-
is_metadata = writer.ti_data_count == 0
168+
writer = self.shard_writers[i][0]
169+
is_metadata = len(writer.tensors) == 0
172170
if is_metadata:
173171
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only")
174172
else:
175-
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)")
176-
running_total -= writer.ti_data_count
173+
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {len(writer.tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
174+
running_total -= len(writer.tensors)
177175
writer.write_tensors_to_file(progress=(progress and not is_metadata))
178176
del writer
179177

180178
self.state = WriterState.TI_DATA
181179

182-
# override add_key, add_val to handle kv data separately
183-
def add_key(self, key: str) -> None:
184-
self.recent_key = key
185-
186-
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
187-
if self.recent_key is None:
188-
raise ValueError("No key set for value")
189-
self.kv_data[self.recent_key] = (val, vtype)
190-
191-
# need to handle arrays separately
192-
def add_array(self, key: str, val: Sequence[Any]) -> None:
193-
if not isinstance(val, Sequence):
194-
raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
195-
self.kv_data[key] = (val, GGUFValueType.ARRAY)
180+
# override add_key_value to handle kv data separately
181+
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
182+
self.kv_data[key] = (val, vtype)
196183

197184
def add_tensor(
198185
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
@@ -218,7 +205,7 @@ def add_tensor(
218205
self.shards[-1].tensors.append((name, tensor, raw_dtype))
219206

220207
def close(self) -> None:
221-
for writer in self.shard_writers:
208+
for (writer, _) in self.shard_writers:
222209
writer.close()
223210

224211
@staticmethod

0 commit comments

Comments
 (0)