Commit b1f8af1

convert.py: Outfile default name change and additional metadata support (#4858)
* convert.py: Outfile default name change and additional metadata support
* convert.py: don't stringify Metadata load method output
* convert.py: typo fix
* convert.py: fix metadata format to sync with LLM_KV_NAMES in llama.cpp
1 parent e586ee4 commit b1f8af1
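
The new --metadata flag introduced below reads a JSON file whose keys follow the LLM_KV_NAMES mapping in llama.cpp. As an illustration, a script like the following could produce such a file; the general.* keys are taken from the diff, while every value (and the file name metadata.json) is hypothetical:

# Sketch only: writes a metadata file that the new Metadata.load() below can read.
# All values here are made-up placeholders.
import json

model_card = {
    "general.name": "TinyLLaMA",
    "general.author": "Example Org",
    "general.version": "v0.1",
    "general.url": "https://example.com/tinyllama",
    "general.description": "A small example model.",
    "general.license": "Apache-2.0",
    "general.source.url": "https://example.com/tinyllama/releases",
    "general.source.huggingface.repository": "example-org/tinyllama",
}

with open("metadata.json", "w") as f:
    json.dump(model_card, f, indent=4)

The resulting file is then passed to the converter via --metadata metadata.json.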

convert.py

Lines changed: 155 additions & 25 deletions
@@ -24,7 +24,7 @@
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Optional
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
@@ -344,10 +344,47 @@ def load(model_plus: ModelPlus) -> Params:
         return params
 
 
+@dataclass
+class Metadata:
+    name: Optional[str] = None
+    author: Optional[str] = None
+    version: Optional[str] = None
+    url: Optional[str] = None
+    description: Optional[str] = None
+    licence: Optional[str] = None
+    source_url: Optional[str] = None
+    source_hf_repo: Optional[str] = None
+
+    @staticmethod
+    def load(metadata_path: Optional[Path]) -> Metadata:
+        if metadata_path is None or not metadata_path.exists():
+            return Metadata()
+
+        with open(metadata_path, 'r') as file:
+            data = json.load(file)
+
+        # Create a new Metadata instance
+        metadata = Metadata()
+
+        # Assign values to Metadata attributes if they exist in the JSON file
+        # This is based on the LLM_KV_NAMES mapping in llama.cpp
+        metadata.name = data.get("general.name")
+        metadata.author = data.get("general.author")
+        metadata.version = data.get("general.version")
+        metadata.url = data.get("general.url")
+        metadata.description = data.get("general.description")
+        metadata.licence = data.get("general.license")
+        metadata.source_url = data.get("general.source.url")
+        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
+
+        return metadata
+
+
 #
 # vocab
 #
 
+
 @runtime_checkable
 class BaseVocab(Protocol):
     tokenizer_model: ClassVar[str]
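
A quick sanity check of the loader above, assuming convert.py is importable from the repository root (a sketch, not part of the commit; the path is made up):

# Hypothetical round-trip check of Metadata.load()
from pathlib import Path

from convert import Metadata

md = Metadata.load(Path("metadata.json"))
print(md.name, md.author, md.licence)  # attributes stay None for absent keys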
@@ -1066,21 +1103,42 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
-    def add_meta_arch(self, params: Params) -> None:
+    def add_meta_model(self, params: Params, metadata: Metadata) -> None:
+        # Metadata About The Model And Its Provenance
         name = "LLaMA"
-
-        # TODO: better logic to determine model name
-        if params.n_ctx == 4096:
-            name = "LLaMA v2"
+        if metadata is not None and metadata.name is not None:
+            name = metadata.name
         elif params.path_model is not None:
-            name = str(params.path_model.parent).split('/')[-1]
-
-        self.gguf.add_name                (name)
-        self.gguf.add_vocab_size          (params.n_vocab)
-        self.gguf.add_context_length      (params.n_ctx)
-        self.gguf.add_embedding_length    (params.n_embd)
-        self.gguf.add_block_count         (params.n_layer)
-        self.gguf.add_feed_forward_length (params.n_ff)
+            name = str(params.path_model.parent).split("/")[-1]
+        elif params.n_ctx == 4096:
+            # Heuristic detection of LLaMA v2 model
+            name = "LLaMA v2"
+
+        self.gguf.add_name(name)
+
+        if metadata is not None:
+            if metadata.author is not None:
+                self.gguf.add_author(metadata.author)
+            if metadata.version is not None:
+                self.gguf.add_version(metadata.version)
+            if metadata.url is not None:
+                self.gguf.add_url(metadata.url)
+            if metadata.description is not None:
+                self.gguf.add_description(metadata.description)
+            if metadata.licence is not None:
+                self.gguf.add_licence(metadata.licence)
+            if metadata.source_url is not None:
+                self.gguf.add_source_url(metadata.source_url)
+            if metadata.source_hf_repo is not None:
+                self.gguf.add_source_hf_repo(metadata.source_hf_repo)
+
+    def add_meta_arch(self, params: Params) -> None:
+        # Metadata About The Neural Architecture Itself
+        self.gguf.add_vocab_size(params.n_vocab)
+        self.gguf.add_context_length(params.n_ctx)
+        self.gguf.add_embedding_length(params.n_embd)
+        self.gguf.add_block_count(params.n_layer)
+        self.gguf.add_feed_forward_length(params.n_ff)
         self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
         self.gguf.add_head_count          (params.n_head)
         self.gguf.add_head_count_kv       (params.n_head_kv)
@@ -1183,13 +1241,14 @@ def close(self) -> None:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Optional[Metadata] = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
         # meta data
+        of.add_meta_model(params, metadata)
         of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
         of.add_meta_special_vocab(svocab)
@@ -1216,12 +1275,14 @@ def write_all(
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
+        metadata: Optional[Metadata] = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
         # meta data
+        of.add_meta_model(params, metadata)
         of.add_meta_arch(params)
         if isinstance(vocab, Vocab):
             of.add_meta_vocab(vocab)
@@ -1257,6 +1318,37 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
     raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
+def model_parameter_count(model: LazyModel) -> int:
+    total_model_parameters = 0
+    for lazy_tensor in model.values():
+        sum_weights_in_tensor = 1
+        for dim in lazy_tensor.shape:
+            sum_weights_in_tensor *= dim
+        total_model_parameters += sum_weights_in_tensor
+    return total_model_parameters
+
+
+def model_parameter_count_rounded_notation(model_params_count: int) -> str:
+    if model_params_count > 1e12:
+        # Trillions Of Parameters
+        scaled_model_params = model_params_count * 1e-12
+        scale_suffix = "T"
+    elif model_params_count > 1e9:
+        # Billions Of Parameters
+        scaled_model_params = model_params_count * 1e-9
+        scale_suffix = "B"
+    elif model_params_count > 1e6:
+        # Millions Of Parameters
+        scaled_model_params = model_params_count * 1e-6
+        scale_suffix = "M"
+    else:
+        # Thousands Of Parameters
+        scaled_model_params = model_params_count * 1e-3
+        scale_suffix = "K"
+
+    return f"{round(scaled_model_params)}{scale_suffix}"
+
+
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
     return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
             for (name, tensor) in model.items()}
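
To make the rounding convention concrete, a standalone illustration (the parameter count is invented, roughly a 7B-class LLaMA). Note the strict > comparisons: a model with exactly 1e9 parameters would fall into the "M" branch.

# Mirrors the suffix logic of model_parameter_count_rounded_notation above;
# the count is hypothetical.
count = 6_738_415_616              # ~7B-class model
print(f"{round(count * 1e-9)}B")   # -> 7B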
@@ -1436,13 +1528,35 @@ def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) ->
         return vocab, special_vocab
 
 
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
-    namestr = {
-        GGMLFileType.AllF32:    "f32",
-        GGMLFileType.MostlyF16: "f16",
-        GGMLFileType.MostlyQ8_0:"q8_0",
+def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
+    quantization = {
+        GGMLFileType.AllF32:    "F32",
+        GGMLFileType.MostlyF16: "F16",
+        GGMLFileType.MostlyQ8_0: "Q8_0",
     }[file_type]
-    ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
+
+    parameters = model_parameter_count_rounded_notation(model_params_count)
+
+    expert_count = ""
+    if params.n_experts is not None:
+        expert_count = f"{params.n_experts}x"
+
+    version = ""
+    if metadata is not None and metadata.version is not None:
+        version = f"-{metadata.version}"
+
+    name = "ggml-model"
+    if metadata is not None and metadata.name is not None:
+        name = metadata.name
+    elif params.path_model is not None:
+        name = params.path_model.name
+
+    return f"{name}{version}-{expert_count}{parameters}-{quantization}"
+
+
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
+    default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
+    ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
         logger.error(
             f"Error: Default output path ({ret}) would overwrite the input. "
@@ -1480,17 +1594,30 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--pad-vocab",    action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
     parser.add_argument("--verbose",      action="store_true", help="increase output verbosity")
+    parser.add_argument("--metadata",     type=Path,           help="Specify the path for a metadata file")
+    parser.add_argument("--get-outfile",  action="store_true", help="get calculated default outfile name")
 
     args = parser.parse_args(args_in)
 
     if args.verbose:
         logging.basicConfig(level=logging.DEBUG)
-    elif args.dump_single or args.dump:
+    elif args.dump_single or args.dump or args.get_outfile:
         # Avoid printing anything besides the dump output
         logging.basicConfig(level=logging.WARNING)
     else:
         logging.basicConfig(level=logging.INFO)
 
+    metadata = Metadata.load(args.metadata)
+
+    if args.get_outfile:
+        model_plus = load_some_model(args.model)
+        params = Params.load(model_plus)
+        model = convert_model_names(model_plus.model, params, args.skip_unknown)
+        model_params_count = model_parameter_count(model_plus.model)
+        ftype = pick_output_type(model, args.outtype)
+        print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}")  # noqa: NP100
+        return
+
     if args.no_vocab and args.vocab_only:
         raise ValueError("--vocab-only does not make sense with --no-vocab")
 
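Example invocations of the two new flags (the model directory is hypothetical; --outtype already existed). Note that --get-outfile prints the conventional name without the .gguf extension:

python convert.py models/tinyllama --outtype q8_0 --get-outfile
python convert.py models/tinyllama --outtype q8_0 --metadata metadata.json
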
@@ -1504,6 +1631,9 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
 
+    model_params_count = model_parameter_count(model_plus.model)
+    logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
+
     if args.dump:
         do_dump_model(model_plus)
         return
@@ -1557,7 +1687,7 @@ def main(args_in: list[str] | None = None) -> None:
             f_norm_eps = 1e-5,
         )
         OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
-                                    endianess=endianess, pad_vocab=args.pad_vocab)
+                                    endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
         logger.info(f"Wrote {outfile}")
         return
 
@@ -1570,13 +1700,13 @@ def main(args_in: list[str] | None = None) -> None:
     model = convert_model_names(model, params, args.skip_unknown)
     ftype = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, ftype)
-    outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
 
     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")
 
     OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
-                         concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+                         concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
     logger.info(f"Wrote {outfile}")