Skip to content

Commit 402359b

Browse files
committed
Fix 70B not working
more bandaids but it works fr this time :D
1 parent 795ae22 commit 402359b

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

convert.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
ADDED_TOKENS_FILE = 'added_tokens.json'
5151
FAST_TOKENIZER_FILE = 'tokenizer.json'
52+
is_llama3_model = True
5253

5354
#
5455
# data types
@@ -821,6 +822,9 @@ def convert(name: str) -> LazyTensor:
821822
else:
822823
# split by rows
823824
axis = 0
825+
global is_llama3_model
826+
if name.startswith('tok_embeddings.') and is_llama3_model:
827+
axis = 0
824828
concatenated_shape = list(lazy_tensors[0].shape)
825829
concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors)
826830

@@ -1194,6 +1198,12 @@ def add_meta_vocab(self, vocab: Vocab) -> None:
11941198
tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)
11951199

11961200
# Add extracted token information for model conversion
1201+
# Tokenizer for LLaMA 3
1202+
# Source: trust me bro
1203+
global is_llama3_model
1204+
if is_llama3_model:
1205+
self.gguf.add_tokenizer_model("gpt2")
1206+
self.gguf.add_tokenizer_pre("llama-bpe")
11971207
self.gguf.add_token_list(tokens)
11981208
self.gguf.add_token_scores(scores)
11991209
self.gguf.add_token_types(toktypes)
@@ -1662,10 +1672,14 @@ def main(args_in: list[str] | None = None) -> None:
16621672
}[args.outtype]
16631673

16641674
logger.info(f"params = {params}")
1665-
1666-
1667-
import convert_llama_weights_to_hf
1668-
convert_llama_weights_to_hf.write_tokenizer(args.model, os.path.join(args.model, "tokenizer.model"), 3)
1675+
#TODO: add more bandaids for llama 3 detection
1676+
try:
1677+
global is_llama3_model
1678+
import convert_llama_weights_to_hf
1679+
convert_llama_weights_to_hf.write_tokenizer(args.model, os.path.join(args.model, "tokenizer.model"), 3)
1680+
is_llama3_model = True
1681+
except:
1682+
pass
16691683

16701684

16711685
model_parent_path = model_plus.paths[0].parent

0 commit comments

Comments
 (0)