Skip to content

Commit e4d74cc

Browse files
committed
feat: add new GGUFValueType.OBJ virtual type(python)
The content of the OBJ type is actually a list of all key names of the object. * GGUFWriter: * Added `def add_kv(self, key: str, val: Any) -> None`: This will be added based on the val type * Added `def add_dict(self, key: str, val: dict) -> None`: add object(dict) values. It will recursively add all subkeys. * constants: * `GGUFValueType.get_type`: Added support for Numpy's integers and floating-point numbers, and selected the appropriate number of digits based on the size of the integer. * gguf_reader: * Added `ReaderField.get`: to return the value of the field * Unit test added.
1 parent a1d6df1 commit e4d74cc

File tree

5 files changed

+195
-6
lines changed

5 files changed

+195
-6
lines changed

gguf-py/gguf/constants.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from enum import Enum, IntEnum, auto
55
from typing import Any
6+
import numpy as np
67

78
#
89
# constants
@@ -510,19 +511,60 @@ class GGUFValueType(IntEnum):
510511
UINT64 = 10
511512
INT64 = 11
512513
FLOAT64 = 12
514+
OBJ = 13
513515

514516
@staticmethod
515517
def get_type(val: Any) -> GGUFValueType:
516518
if isinstance(val, (str, bytes, bytearray)):
517519
return GGUFValueType.STRING
518520
elif isinstance(val, list):
519521
return GGUFValueType.ARRAY
522+
elif isinstance(val, np.float32):
523+
return GGUFValueType.FLOAT32
524+
elif isinstance(val, np.float64):
525+
return GGUFValueType.FLOAT64
520526
elif isinstance(val, float):
521527
return GGUFValueType.FLOAT32
522528
elif isinstance(val, bool):
523529
return GGUFValueType.BOOL
524-
elif isinstance(val, int):
530+
elif isinstance(val, np.uint8):
531+
return GGUFValueType.UINT8
532+
elif isinstance(val, np.uint16):
533+
return GGUFValueType.UINT16
534+
elif isinstance(val, np.uint32):
535+
return GGUFValueType.UINT32
536+
elif isinstance(val, np.uint64):
537+
return GGUFValueType.UINT64
538+
elif isinstance(val, np.int8):
539+
return GGUFValueType.INT8
540+
elif isinstance(val, np.int16):
541+
return GGUFValueType.INT16
542+
elif isinstance(val, np.int32):
525543
return GGUFValueType.INT32
544+
elif isinstance(val, np.int64):
545+
return GGUFValueType.INT64
546+
elif isinstance(val, int):
547+
if val >=0 and val <= np.iinfo(np.uint8).max:
548+
return GGUFValueType.UINT8
549+
elif val >=0 and val <= np.iinfo(np.uint16).max:
550+
return GGUFValueType.UINT16
551+
elif val >=0 and val <= np.iinfo(np.uint32).max:
552+
return GGUFValueType.UINT32
553+
elif val >=0 and val <= np.iinfo(np.uint64).max:
554+
return GGUFValueType.UINT64
555+
elif val >=np.iinfo(np.int8).min and val <= np.iinfo(np.int8).max:
556+
return GGUFValueType.INT8
557+
elif val >=np.iinfo(np.int16).min and val <= np.iinfo(np.int16).max:
558+
return GGUFValueType.INT16
559+
elif val >=np.iinfo(np.int32).min and val <= np.iinfo(np.int32).max:
560+
return GGUFValueType.INT32
561+
elif val >=np.iinfo(np.int64).min and val <= np.iinfo(np.int64).max:
562+
return GGUFValueType.INT64
563+
else:
564+
print("The integer exceed limit:", val)
565+
sys.exit()
566+
elif isinstance(val, dict):
567+
return GGUFValueType.OBJ
526568
# TODO: need help with 64-bit types in Python
527569
else:
528570
print("Unknown type:", type(val))

gguf-py/gguf/gguf_reader.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ class ReaderField(NamedTuple):
4949

5050
types: list[GGUFValueType] = []
5151

52+
def get(self):
53+
result = None
54+
itype = self.types[0]
55+
if itype == GGUFValueType.ARRAY or itype == GGUFValueType.OBJ:
56+
itype = self.types[-1]
57+
if itype == GGUFValueType.STRING:
58+
result = [str(bytes(self.parts[idx]), encoding="utf-8") for idx in self.data]
59+
else:
60+
result = [pv for idx in self.data for pv in self.parts[idx].tolist()]
61+
elif itype == GGUFValueType.STRING:
62+
result = str(bytes(self.parts[-1]), encoding="utf-8")
63+
else:
64+
result = self.parts[-1].tolist()[0]
65+
66+
return result
67+
5268

5369
class ReaderTensor(NamedTuple):
5470
name: str
@@ -165,7 +181,7 @@ def _get_field_parts(
165181
val = self._get(offs, nptype)
166182
return int(val.nbytes), [val], [0], types
167183
# Handle arrays.
168-
if gtype == GGUFValueType.ARRAY:
184+
if gtype == GGUFValueType.ARRAY or gtype == GGUFValueType.OBJ:
169185
raw_itype = self._get(offs, np.uint32)
170186
offs += int(raw_itype.nbytes)
171187
alen = self._get(offs, np.uint64)

gguf-py/gguf/gguf_writer.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,48 @@ def add_array(self, key: str, val: Sequence[Any]) -> None:
158158
self.add_key(key)
159159
self.add_val(val, GGUFValueType.ARRAY)
160160

161+
def add_kv(self, key: str, val: Any) -> None:
162+
vtype=GGUFValueType.get_type(val)
163+
if vtype == GGUFValueType.OBJ:
164+
self.add_dict(key, val)
165+
elif vtype == GGUFValueType.ARRAY:
166+
self.add_array(key, val)
167+
elif vtype == GGUFValueType.STRING:
168+
self.add_string(key, val)
169+
elif vtype == GGUFValueType.BOOL:
170+
self.add_bool(key, val)
171+
elif vtype == GGUFValueType.INT8:
172+
self.add_int8(key, val)
173+
elif vtype == GGUFValueType.INT16:
174+
self.add_int16(key, val)
175+
elif vtype == GGUFValueType.INT32:
176+
self.add_int32(key, val)
177+
elif vtype == GGUFValueType.INT64:
178+
self.add_int64(key, val)
179+
elif vtype == GGUFValueType.UINT8:
180+
self.add_uint8(key, val)
181+
elif vtype == GGUFValueType.UINT16:
182+
self.add_uint16(key, val)
183+
elif vtype == GGUFValueType.UINT32:
184+
self.add_uint32(key, val)
185+
elif vtype == GGUFValueType.UINT64:
186+
self.add_uint64(key, val)
187+
elif vtype == GGUFValueType.FLOAT32:
188+
self.add_float32(key, val)
189+
elif vtype == GGUFValueType.FLOAT64:
190+
self.add_float64(key, val)
191+
else:
192+
raise ValueError(f"Unsupported type: {type(val)}")
193+
194+
def add_dict(self, key: str, val: dict) -> None:
195+
if not isinstance(val, dict):
196+
raise ValueError("Value must be a dict type")
197+
198+
self.add_key(key)
199+
self.add_val(val, GGUFValueType.OBJ)
200+
for k, v in val.items():
201+
self.add_kv(key + "." + k, v)
202+
161203
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
162204
if vtype is None:
163205
vtype = GGUFValueType.get_type(val)
@@ -181,6 +223,8 @@ def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool
181223
self.kv_data += self._pack("Q", len(val))
182224
for item in val:
183225
self.add_val(item, add_vtype=False)
226+
elif vtype == GGUFValueType.OBJ and isinstance(val, dict) and val:
227+
self.add_val(list(val.keys()), GGUFValueType.ARRAY, False)
184228
else:
185229
raise ValueError("Invalid GGUF metadata value type or value")
186230

gguf-py/tests/test_constants.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import sys
2+
from pathlib import Path
3+
import numpy as np
4+
import unittest
5+
6+
# Necessary to load the local gguf package
7+
sys.path.insert(0, str(Path(__file__).parent.parent))
8+
9+
from gguf.constants import GGUFValueType
10+
11+
class TestGGUFValueType(unittest.TestCase):
12+
13+
def test_get_type(self):
14+
self.assertEqual(GGUFValueType.get_type("test"), GGUFValueType.STRING)
15+
self.assertEqual(GGUFValueType.get_type([1, 2, 3]), GGUFValueType.ARRAY)
16+
self.assertEqual(GGUFValueType.get_type(1.0), GGUFValueType.FLOAT32)
17+
self.assertEqual(GGUFValueType.get_type(True), GGUFValueType.BOOL)
18+
self.assertEqual(GGUFValueType.get_type(b"test"), GGUFValueType.STRING)
19+
self.assertEqual(GGUFValueType.get_type(np.uint8(1)), GGUFValueType.UINT8)
20+
self.assertEqual(GGUFValueType.get_type(np.uint16(1)), GGUFValueType.UINT16)
21+
self.assertEqual(GGUFValueType.get_type(np.uint32(1)), GGUFValueType.UINT32)
22+
self.assertEqual(GGUFValueType.get_type(np.uint64(1)), GGUFValueType.UINT64)
23+
self.assertEqual(GGUFValueType.get_type(np.int8(-1)), GGUFValueType.INT8)
24+
self.assertEqual(GGUFValueType.get_type(np.int16(-1)), GGUFValueType.INT16)
25+
self.assertEqual(GGUFValueType.get_type(np.int32(-1)), GGUFValueType.INT32)
26+
self.assertEqual(GGUFValueType.get_type(np.int64(-1)), GGUFValueType.INT64)
27+
self.assertEqual(GGUFValueType.get_type(np.float32(1.0)), GGUFValueType.FLOAT32)
28+
self.assertEqual(GGUFValueType.get_type(np.float64(1.0)), GGUFValueType.FLOAT64)
29+
self.assertEqual(GGUFValueType.get_type({"k": 12}), GGUFValueType.OBJ)
30+
31+
if __name__ == '__main__':
32+
unittest.main()

gguf-py/tests/test_gguf.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,62 @@
1-
import gguf # noqa: F401
1+
import sys
2+
from pathlib import Path
3+
import numpy as np
4+
import unittest
25

3-
# TODO: add tests
6+
# Necessary to load the local gguf package
7+
sys.path.insert(0, str(Path(__file__).parent.parent))
48

9+
from gguf import GGUFWriter, GGUFReader, GGUFValueType
510

6-
def test_write_gguf() -> None:
7-
pass
11+
class TestGGUFReaderWriter(unittest.TestCase):
12+
13+
def test_rw(self) -> None:
14+
# Example usage with a file
15+
gguf_writer = GGUFWriter("test_writer.gguf", "llama")
16+
17+
# gguf_writer.add_architecture()
18+
gguf_writer.add_block_count(12)
19+
gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer
20+
gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float
21+
gguf_writer.add_kv("uint8", 1)
22+
gguf_writer.add_kv("nint8", np.int8(1))
23+
gguf_writer.add_dict("dict1", {"key1": 2, "key2": "hi", "obj": {"k": 1}})
24+
gguf_writer.add_custom_alignment(64)
25+
26+
tensor1 = np.ones((32,), dtype=np.float32) * 100.0
27+
tensor2 = np.ones((64,), dtype=np.float32) * 101.0
28+
tensor3 = np.ones((96,), dtype=np.float32) * 102.0
29+
30+
gguf_writer.add_tensor("tensor1", tensor1)
31+
gguf_writer.add_tensor("tensor2", tensor2)
32+
gguf_writer.add_tensor("tensor3", tensor3)
33+
34+
gguf_writer.write_header_to_file()
35+
gguf_writer.write_kv_data_to_file()
36+
gguf_writer.write_tensors_to_file()
37+
38+
gguf_writer.close()
39+
40+
gguf_reader = GGUFReader("test_writer.gguf")
41+
self.assertEqual(gguf_reader.alignment, 64)
42+
v = gguf_reader.get_field("uint8")
43+
self.assertEqual(v.get(), 1)
44+
self.assertEqual(v.types[0], GGUFValueType.UINT8)
45+
v = gguf_reader.get_field("nint8")
46+
self.assertEqual(v.get(), 1)
47+
self.assertEqual(v.types[0], GGUFValueType.INT8)
48+
v = gguf_reader.get_field("dict1")
49+
self.assertIsNotNone(v)
50+
self.assertListEqual(v.get(), ['key1', 'key2', 'obj'])
51+
v = gguf_reader.get_field("dict1.key1")
52+
self.assertEqual(v.get(), 2)
53+
v = gguf_reader.get_field("dict1.key2")
54+
self.assertEqual(v.get(), "hi")
55+
v = gguf_reader.get_field("dict1.obj")
56+
self.assertListEqual(v.get(), ['k'])
57+
v = gguf_reader.get_field("dict1.obj.k")
58+
self.assertEqual(v.get(), 1)
59+
60+
61+
if __name__ == '__main__':
62+
unittest.main()

0 commit comments

Comments
 (0)