
Commit 938943c

llama : move vocab, grammar and sampling into separate files (#8508)
* llama : move sampling code into llama-sampling ggml-ci
* llama : move grammar code into llama-grammar ggml-ci
* cont ggml-ci
* cont : pre-fetch rules
* cont ggml-ci
* llama : deprecate llama_sample_grammar
* llama : move tokenizers into llama-vocab ggml-ci
* make : update llama.cpp deps [no ci]
* llama : redirect external API to internal APIs ggml-ci
* llama : suffix the internal APIs with "_impl" ggml-ci
* llama : clean-up
1 parent 751fcfc commit 938943c

18 files changed: +3656 additions, -3103 deletions
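
The "redirect external API to internal APIs" and "_impl" bullets in the commit message describe the redirection pattern: the public functions declared in include/llama.h keep their existing signatures, while their definitions in src/llama.cpp now forward to internal functions in the new translation units, suffixed with "_impl". A rough sketch of that shape; both the internal signature and the ctx->sampling member below are assumptions for illustration, not lines taken from the commit:

    // src/llama-sampling.h (hypothetical excerpt): the internal API carries the "_impl" suffix
    void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);

    // src/llama.cpp: the public entry point keeps its old signature and simply forwards
    void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates, float temp) {
        llama_sample_temp_impl(&ctx->sampling, candidates, temp); // ctx->sampling is assumed here
    }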

Makefile

Lines changed: 31 additions & 1 deletion
@@ -876,6 +876,9 @@ OBJ_GGML += \
 
 OBJ_LLAMA = \
 	src/llama.o \
+	src/llama-vocab.o \
+	src/llama-grammar.o \
+	src/llama-sampling.o \
 	src/unicode.o \
 	src/unicode-data.o
 
@@ -1055,6 +1058,10 @@ src/unicode-data.o: \
 
 src/llama.o: \
 	src/llama.cpp \
+	src/llama-impl.h \
+	src/llama-vocab.h \
+	src/llama-grammar.h \
+	src/llama-sampling.h \
 	src/unicode.h \
 	include/llama.h \
 	ggml/include/ggml-cuda.h \
@@ -1064,6 +1071,29 @@ src/llama.o: \
 	ggml/include/ggml-backend.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+src/llama-vocab.o: \
+	src/llama-vocab.cpp \
+	src/llama-vocab.h \
+	src/llama-impl.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+src/llama-grammar.o: \
+	src/llama-grammar.cpp \
+	src/llama-grammar.h \
+	src/llama-impl.h \
+	src/llama-vocab.h \
+	src/llama-sampling.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+src/llama-sampling.o: \
+	src/llama-sampling.cpp \
+	src/llama-sampling.h \
+	src/llama-impl.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 $(LIB_LLAMA): \
 	$(OBJ_LLAMA) \
 	$(LIB_GGML)
@@ -1439,7 +1469,7 @@ run-benchmark-matmult: llama-benchmark-matmult
 .PHONY: run-benchmark-matmult swift
 
 tests/test-llama-grammar: tests/test-llama-grammar.cpp \
-	$(OBJ_GGML) $(OBJ_COMMON) src/unicode.o src/unicode-data.o
+	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

Package.swift

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ import PackageDescription
 
 var sources = [
     "src/llama.cpp",
+    "src/llama-vocab.cpp",
+    "src/llama-grammar.cpp",
+    "src/llama-sampling.cpp",
     "src/unicode.cpp",
     "src/unicode-data.cpp",
     "ggml/src/ggml.c",

common/sampling.cpp

Lines changed: 3 additions & 3 deletions
@@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
         llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
 
         // Apply grammar constraints to the single token
-        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
 
         // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
         bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // apply grammar checks before sampling logic
     if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
     }
 
     return cur_p;
@@ -455,6 +455,6 @@ void llama_sampling_accept(
     ctx_sampling->prev.push_back(id);
 
     if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
     }
 }
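
The three call-site updates above all follow the same renaming and argument reordering: the deprecated llama_sample_grammar(ctx, candidates, grammar) becomes llama_grammar_sample(grammar, ctx, candidates), and llama_grammar_accept_token now also takes the grammar first. A minimal before/after sketch for downstream callers, using variable names matching the surrounding code:

    // Before (now deprecated): context first, grammar last
    // llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);

    // After: grammar first, then context
    llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);

    // Accepting the sampled token follows the same reordering
    llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);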

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 10 additions & 5 deletions
@@ -16,20 +16,25 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
     auto decoded = decode_utf8(input_str, {});
     const auto & code_points = decoded.first;
 
+    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
+          llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+
     size_t pos = 0;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        auto prev_stacks = grammar->stacks;
-        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-        if (grammar->stacks.empty()) {
+        const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+
+        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+
+        if (cur_stacks.empty()) {
             error_pos = pos;
             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
-            grammar->stacks = prev_stacks;
+            cur_stacks = prev_stacks;
             return false;
         }
         ++pos;
     }
 
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : cur_stacks) {
         if (stack.empty()) {
             return true;
         }
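
Note that two different "empty" checks appear in this function and mean opposite things: cur_stacks becoming empty after llama_grammar_accept means no grammar rule can accept the next character (a parse failure), while an individual stack inside cur_stacks being empty means one parse has consumed its entire rule and the input seen so far is a complete match. A small sketch of a helper capturing the final check; the helper name is illustrative and not part of the API:

    // Returns true if at least one parse has completed, i.e. the input seen so far
    // is fully matched by the grammar.
    static bool grammar_is_accepting(const llama_grammar_stacks & cur_stacks) {
        for (const llama_grammar_stack & stack : cur_stacks) {
            if (stack.empty()) {
                return true; // an empty stack means nothing is left to match
            }
        }
        return false;
    }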

include/llama.h

Lines changed: 46 additions & 30 deletions
@@ -906,10 +906,10 @@ extern "C" {
     LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
 
     // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
+    LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
 
     // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
+    LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
 
     // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@@ -965,6 +965,10 @@ extern "C" {
                      bool   remove_special,
                      bool   unparse_special);
 
+    //
+    // Chat templates
+    //
+
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
    /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -1003,6 +1007,23 @@ extern "C" {
 
     LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
 
+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_grammar_sample(
+            const struct llama_grammar * grammar,
+            const struct llama_context * ctx,
+                llama_token_data_array * candidates);
+    LLAMA_API DEPRECATED(void llama_sample_grammar(
+                  struct llama_context * ctx,
+                llama_token_data_array * candidates,
+            const struct llama_grammar * grammar),
+        "use llama_grammar_sample instead");
+
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(
+            struct llama_grammar * grammar,
+            struct llama_context * ctx,
+                     llama_token   token);
+
     //
     // Sampling functions
     //
@@ -1084,12 +1105,6 @@ extern "C" {
            llama_token_data_array * candidates,
                             float   temp);
 
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_sample_grammar(
-              struct llama_context * ctx,
-            llama_token_data_array * candidates,
-        const struct llama_grammar * grammar);
-
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1127,12 +1142,6 @@ extern "C" {
              struct llama_context * ctx,
            llama_token_data_array * candidates);
 
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_context * ctx,
-            struct llama_grammar * grammar,
-                     llama_token   token);
-
     //
     // Model split
     //
@@ -1175,38 +1184,45 @@ extern "C" {
 
 struct ggml_tensor;
 
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
 struct llama_partial_utf8 {
     uint32_t value;    // bit value so far (unshifted)
     int      n_remain; // num bytes remaining; -1 indicates invalid sequence
 };
 
-struct llama_grammar {
-    const std::vector<std::vector<llama_grammar_element>>   rules;
-    std::vector<std::vector<const llama_grammar_element *>> stacks;
-
-    // buffer for partially generated UTF-8 sequence from accepted tokens
-    llama_partial_utf8 partial_utf8;
-};
-
 struct llama_grammar_candidate {
     size_t               index;
    const uint32_t     * code_points;
     llama_partial_utf8   partial_utf8;
 };
 
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 
 void llama_grammar_accept(
-        const std::vector<std::vector<llama_grammar_element>> & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr,
-        std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stack      & stack,
+        const llama_grammar_candidates & candidates);
 
 std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
-        llama_partial_utf8 partial_start);
+        llama_partial_utf8 partial_start);
 
 // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
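
Taken together, the new public declarations give grammar-constrained sampling the following shape: the grammar is passed as the first argument, and the chosen token is fed back into the grammar afterwards. A minimal sketch, assuming ctx, grammar, and a populated candidates array are already set up, and using llama_sample_token, the sampling entry point that already exists in this header:

    // Mask out candidates the grammar cannot accept (their logits are set to -INFINITY)
    llama_grammar_sample(grammar, ctx, &candidates);

    // Pick a token from the remaining candidates
    const llama_token id = llama_sample_token(ctx, &candidates);

    // Advance the grammar state with the token that was actually chosen
    llama_grammar_accept_token(grammar, ctx, id);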

src/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,9 @@ endif()
 add_library(llama
             ../include/llama.h
             llama.cpp
+            llama-vocab.cpp
+            llama-grammar.cpp
+            llama-sampling.cpp
             unicode.h
             unicode.cpp
             unicode-data.cpp
