
Commit f154685

Merge branch 'concedo_experimentalMAIN'

2 parents cbdc1f3 + 94e0a06

21 files changed: +1642 −654 lines

convert.py

Lines changed: 46 additions & 23 deletions
@@ -234,14 +234,21 @@ def load(model_plus: 'ModelPlus') -> 'Params':
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+        self.vocabtype = vocabtype
+        if self.vocabtype == "bpe":
+            self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
+        else:
+            self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
         if fname_added_tokens is not None:
             added_tokens = json.load(open(fname_added_tokens))
         else:
             added_tokens = {}
-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+        if self.vocabtype == "bpe":
+            vocab_size: int = len(self.sentencepiece_tokenizer)
+        else:
+            vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
         expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
         actual_ids = sorted(added_tokens.values())
         if expected_ids != actual_ids:
@@ -255,22 +262,32 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) ->
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
-        for i in range(tokenizer.vocab_size()):
+        if self.vocabtype == "bpe":
+          from transformers.models.gpt2 import tokenization_gpt2
+          byte_encoder = tokenization_gpt2.bytes_to_unicode()
+          byte_decoder = {v: k for k, v in byte_encoder.items()}
+          for i, item in enumerate(tokenizer):
             text: bytes
-            if tokenizer.is_unknown(i):
-                text = " \u2047 ".encode("utf-8")
-            elif tokenizer.is_control(i):
-                text = b""
-            elif tokenizer.is_byte(i):
-                piece = tokenizer.id_to_piece(i)
-                if len(piece) != 6:
-                    raise Exception(f"Invalid token: {piece}")
-                byte_value = int(piece[3:-1], 16)
-                text = struct.pack("B", byte_value)
-            else:
-                text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            score: float = tokenizer.get_score(i)
+            text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+            score: float = -i
             yield text, score
+        else:
+          for i in range(tokenizer.vocab_size()):
+            text: bytes
+            if tokenizer.is_unknown(i):
+              text = " \u2047 ".encode("utf-8")
+            elif tokenizer.is_control(i):
+              text = b""
+            elif tokenizer.is_byte(i):
+              piece = tokenizer.id_to_piece(i)
+              if len(piece) != 6:
+                raise Exception(f"Invalid token: {piece}")
+              byte_value = int(piece[3:-1], 16)
+              text = struct.pack("B", byte_value)
+            else:
+              text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+            score: float = tokenizer.get_score(i)
+            yield text, score
 
     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -1196,14 +1213,18 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
     return {name: model[name] for name in TENSORS_LIST if name in model}
 
 
-def load_vocab(path: Path) -> SentencePieceVocab:
+def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
+    print(f"vocabtype: {vocabtype}")
     # Be extra-friendly and accept either a file or a directory. Also, if it's
     # a directory, it might be the model directory, and tokenizer.model might
     # be in the parent of that.
     if path.is_dir():
-        path2 = path / "tokenizer.model"
+        vocab_file = "tokenizer.model"
+        if vocabtype == 'bpe':
+            vocab_file = "vocab.json"
+        path2 = path / vocab_file
         # Use `.parent` instead of /.. to handle the symlink case better.
-        path3 = path.parent / "tokenizer.model"
+        path3 = path.parent / vocab_file
         if path2.exists():
             path = path2
         elif path3.exists():
@@ -1214,7 +1235,8 @@ def load_vocab(path: Path) -> SentencePieceVocab:
             "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+                              vocabtype)
 
 
 def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
@@ -1252,14 +1274,15 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("model", type=Path,
                         help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
+    parser.add_argument("--vocabtype", default='spm', choices=["spm", "bpe"], help="vocab format (default: spm)")
     args = parser.parse_args(args_in)
 
     vocab: Vocab
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
     elif args.vocab_only:
-        vocab = load_vocab(args.vocab_dir or args.model)
+        vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
         assert args.outfile, "need --outfile if using --vocab-only"
         outfile = args.outfile
         OutputFile.write_vocab_only(outfile, vocab)
@@ -1273,7 +1296,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
             vocab = model_plus.vocab
         else:
             vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
-            vocab = load_vocab(vocab_dir)
+            vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
         model = model_plus.model
         model = do_necessary_conversions(model, params)
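
Note: the new `bpe` branch of `sentencepiece_tokens()` relies on the GPT-2 byte-to-unicode trick: `bytes_to_unicode()` maps every raw byte to a printable character, and inverting that map turns a `vocab.json` entry back into raw bytes. A minimal standalone sketch follows (hypothetical, not part of the commit; it assumes the Hugging Face `transformers` package is installed, and `bpe_token_to_bytes` is an illustrative helper name):

# byte -> printable unicode character, as used by GPT-2 style BPE vocabularies
from transformers.models.gpt2 import tokenization_gpt2

byte_encoder = tokenization_gpt2.bytes_to_unicode()
# printable unicode character -> raw byte (the inverse map built in the diff above)
byte_decoder = {v: k for k, v in byte_encoder.items()}

def bpe_token_to_bytes(token: str) -> bytes:
    # Same conversion as the b''.join(...) expression in sentencepiece_tokens():
    # map each character of a vocab.json entry back to the byte it encodes.
    return bytes(byte_decoder[ch] for ch in token)

# "\u0120hello" ("Ġhello", where U+0120 encodes a leading space) -> b' hello'
print(bpe_token_to_bytes("\u0120hello"))

With the `--vocabtype` option added further down in this file, such a vocabulary is selected at conversion time, e.g. `python convert.py <model-dir> --vocabtype bpe`; the default remains `spm`.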

examples/baby-llama/baby-llama.cpp

Lines changed: 15 additions & 9 deletions
@@ -8,6 +8,12 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#ifdef LLAMA_DEFAULT_RMS_EPS
+static const float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS;
+#else
+static const float rms_norm_eps = 5e-6f;
+#endif
+
 float frand() {
     return (float)rand()/(float)RAND_MAX;
 }
@@ -562,7 +568,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -685,7 +691,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
             // cur = ffn_norm*cur
             // cur shape [n_embd,N,1,1]
@@ -729,7 +735,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
@@ -817,7 +823,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -981,7 +987,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
@@ -1034,7 +1040,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1104,7 +1110,7 @@ struct ggml_tensor * forward_lora(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -1251,7 +1257,7 @@ struct ggml_tensor * forward_lora(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
             // cur = ffn_norm*cur
             // cur shape [n_embd,N,1,1]
@@ -1295,7 +1301,7 @@ struct ggml_tensor * forward_lora(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]

examples/common.cpp

Lines changed: 11 additions & 0 deletions
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_gqa = std::stoi(argv[i]);
+        } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rms_norm_eps = std::stof(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -426,6 +432,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             exit(0);
         } else if (arg == "--random-prompt") {
             params.random_prompt = true;
+        } else if (arg == "--in-prefix-bos") {
+            params.input_prefix_bos = true;
         } else if (arg == "--in-prefix") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -511,6 +519,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        not supported with --interactive or other interactive options\n");
     fprintf(stdout, "  --prompt-cache-ro     if specified, uses the prompt cache but does not update it.\n");
     fprintf(stdout, "  --random-prompt       start with a randomized prompt.\n");
+    fprintf(stdout, "  --in-prefix-bos       prefix BOS to user inputs, preceding the `--in-prefix` string\n");
     fprintf(stdout, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
     fprintf(stdout, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
     fprintf(stdout, "  -f FNAME, --file FNAME\n");
@@ -519,6 +528,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, "  -eps N, --rms-norm-eps N  rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, "  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, "  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +625,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.n_ctx        = params.n_ctx;
     lparams.n_batch      = params.n_batch;
     lparams.n_gqa        = params.n_gqa;
+    lparams.rms_norm_eps = params.rms_norm_eps;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu     = params.main_gpu;
     lparams.tensor_split = params.tensor_split;

examples/common.h

Lines changed: 13 additions & 11 deletions
@@ -22,18 +22,19 @@
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-    uint32_t seed = -1; // RNG seed
+    uint32_t seed = -1; // RNG seed
     int32_t n_threads = get_num_physical_cores();
-    int32_t n_predict = -1; // new tokens to predict
-    int32_t n_ctx = 512; // context size
-    int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
-    int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
-    int32_t n_gpu_layers = 0; // number of layers to store in VRAM
-    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-    float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
-    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_ctx = 512; // context size
+    int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
+    int32_t n_keep = 0; // number of tokens to keep from initial prompt
+    int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_gpu_layers = 0; // number of layers to store in VRAM
+    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+    float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
+    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon
     float rope_freq_base = 10000.0f; // RoPE base frequency
     float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
 
@@ -81,6 +82,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
     bool multiline_input = false; // reverse the usage of `\`
 
+    bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool penalize_nl = true; // consider newlines as a repeatable token
     bool perplexity = false; // compute perplexity over the prompt

examples/main/main.cpp

Lines changed: 30 additions & 17 deletions
@@ -325,6 +325,10 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (params.input_prefix_bos) {
+        fprintf(stderr, "Input prefix with BOS\n");
+    }
+
     if (!params.input_prefix.empty()) {
         fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
     }
@@ -633,16 +637,6 @@ int main(int argc, char ** argv) {
                 last_n_tokens.push_back(id);
             }
 
-            // replace end of text token with newline token when in interactive mode
-            if (id == llama_token_eos() && params.interactive && !params.instruct) {
-                id = llama_token_newline.front();
-                if (params.antiprompt.size() != 0) {
-                    // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
-                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
-                }
-            }
-
             // add it to the context
             embd.push_back(id);
 
@@ -708,11 +702,34 @@ int main(int argc, char ** argv) {
             }
         }
 
+        // deal with end of text token in interactive mode
+        if (last_n_tokens.back() == llama_token_eos()) {
+            if (params.interactive) {
+                if (params.antiprompt.size() != 0) {
+                    // tokenize and inject first reverse prompt
+                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                    is_antiprompt = true;
+                }
+
+                is_interacting = true;
+                printf("\n");
+                console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                fflush(stdout);
+            } else if (params.instruct) {
+                is_interacting = true;
+            }
+        }
+
         if (n_past > 0 && is_interacting) {
             if (params.instruct) {
                 printf("\n> ");
             }
 
+            if (params.input_prefix_bos) {
+                embd_inp.push_back(llama_token_bos());
+            }
+
             std::string buffer;
             if (!params.input_prefix.empty()) {
                 buffer += params.input_prefix;
@@ -776,13 +793,9 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos()) {
-            if (params.instruct) {
-                is_interacting = true;
-            } else {
-                fprintf(stderr, " [end of text]\n");
-                break;
-            }
+        if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+            fprintf(stderr, " [end of text]\n");
+            break;
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
