@@ -61,7 +61,7 @@ class GGUFWriterSplit(GGUFWriter):
61
61
kv_data : KVTempData
62
62
split_arguments : SplitArguments
63
63
shards : list [Shard ]
64
- shard_writers : list [GGUFWriter ]
64
+ shard_writers : list [tuple [ GGUFWriter , os . PathLike [ str ]] ]
65
65
66
66
def __init__ (self , path : os .PathLike [str ] | str , arch : str , split_arguments : SplitArguments ,
67
67
use_temp_file : bool = True , endianess : GGUFEndian = GGUFEndian .LITTLE
@@ -115,17 +115,15 @@ def init_shards(self) -> None:
115
115
logger .info ("Dry run, not writing files" )
116
116
exit ()
117
117
118
- # we don't want to initialize GGUFWriters until now because they create files
119
118
for i , shard in enumerate (self .shards ):
120
119
# add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
121
- writer = GGUFWriter (shard . path , self .arch , use_temp_file = self .use_temp_file ,
120
+ writer = GGUFWriter (None , self .arch , use_temp_file = self .use_temp_file ,
122
121
endianess = self .endianess , add_architecture = (i == 0 ))
123
122
124
123
# only the first shard needs all the KV data
125
124
if i == 0 :
126
125
for key , (value , etype ) in self .kv_data .items ():
127
- writer .add_key (key )
128
- writer .add_val (value , etype )
126
+ writer .add_key_value (key , value , etype )
129
127
130
128
# add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE
131
129
if self .split_arguments .split_style != SplitStyle .NONE or self .split_arguments .small_first_shard :
@@ -141,22 +139,22 @@ def init_shards(self) -> None:
141
139
except IndexError :
142
140
break
143
141
144
- self .shard_writers .append (writer )
142
+ self .shard_writers .append (( writer , shard . path ) )
145
143
146
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
    """Write the header of every shard writer to that shard's own file.

    Args:
        path: Ignored. Each shard writer is given the path that was stored
            next to it in ``self.shard_writers``. The parameter is presumably
            kept so the signature matches ``GGUFWriter.write_header_to_file``
            (confirm against the base class).

    Raises:
        ValueError: If this writer is not in the ``EMPTY`` state.
    """
    if self.state is not WriterState.EMPTY:
        raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')

    # The loop variable is deliberately not called ``path``: reusing that name
    # would shadow the parameter and hide the fact that the argument is unused.
    for writer, shard_path in self.shard_writers:
        writer.write_header_to_file(shard_path)

    self.state = WriterState.HEADER
154
152
155
153
def write_kv_data_to_file(self) -> None:
    """Ask every shard writer to write its key/value data.

    Raises:
        ValueError: If this writer is not in the ``HEADER`` state.
    """
    if self.state is WriterState.HEADER:
        # Only the writer half of each (writer, path) pair is needed here.
        for shard_writer, _shard_path in self.shard_writers:
            shard_writer.write_kv_data_to_file()
        self.state = WriterState.KV_DATA
        return

    raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')
@@ -167,32 +165,21 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
167
165
168
166
running_total = self .total_tensors
169
167
for i in range (len (self .shard_writers )):
170
- writer = self .shard_writers [i ]
171
- is_metadata = writer .ti_data_count == 0
168
+ writer = self .shard_writers [i ][ 0 ]
169
+ is_metadata = len ( writer .tensors ) == 0
172
170
if is_metadata :
173
171
logger .info (f"Writing to shard { i + 1 } /{ len (self .shards )} with metadata only" )
174
172
else :
175
- logger .info (f"Writing to shard { i + 1 } /{ len (self .shards )} with { writer .ti_data_count } /{ running_total } remaining tensors (of { self .total_tensors } total)" )
176
- running_total -= writer .ti_data_count
173
+ logger .info (f"Writing to shard { i + 1 } /{ len (self .shards )} with { len ( writer .tensors ) } /{ running_total } remaining tensors (of { self .total_tensors } total)" )
174
+ running_total -= len ( writer .tensors )
177
175
writer .write_tensors_to_file (progress = (progress and not is_metadata ))
178
176
del writer
179
177
180
178
self .state = WriterState .TI_DATA
181
179
182
- # override add_key, add_val to handle kv data separately
183
- def add_key (self , key : str ) -> None :
184
- self .recent_key = key
185
-
186
- def add_val (self , val : Any , vtype : GGUFValueType | None = None , add_vtype : bool = True ) -> None :
187
- if self .recent_key is None :
188
- raise ValueError ("No key set for value" )
189
- self .kv_data [self .recent_key ] = (val , vtype )
190
-
191
- # need to handle arrays separately
192
- def add_array (self , key : str , val : Sequence [Any ]) -> None :
193
- if not isinstance (val , Sequence ):
194
- raise ValueError (f'Expected a sequence for { key } , got { type (val )} ' )
195
- self .kv_data [key ] = (val , GGUFValueType .ARRAY )
180
# Overrides add_key_value so that KV data is collected in self.kv_data
# instead of being handled immediately.
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
    """Record ``(val, vtype)`` under ``key`` in ``self.kv_data``.

    An existing entry for the same key is replaced.
    """
    pending_entry = (val, vtype)
    self.kv_data[key] = pending_entry
196
183
197
184
def add_tensor (
198
185
self , name : str , tensor : np .ndarray [Any , Any ], raw_shape : Sequence [int ] | None = None ,
@@ -218,7 +205,7 @@ def add_tensor(
218
205
self .shards [- 1 ].tensors .append ((name , tensor , raw_dtype ))
219
206
220
207
def close(self) -> None:
    """Close every shard writer held in ``self.shard_writers``."""
    # Each entry is a (writer, path) pair; the path is not needed to close.
    for shard_writer, _shard_path in self.shard_writers:
        shard_writer.close()
223
210
224
211
@staticmethod
0 commit comments