Skip to content

Commit 36083dc

Browse files
authored
Use Longest Common Prefix (LCP) instead of LCS
1 parent f116411 commit 36083dc

File tree

4 files changed

+25
-67
lines changed

4 files changed

+25
-67
lines changed

common/common.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,12 +1460,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
14601460
params.chat_template = argv[i];
14611461
return true;
14621462
}
1463-
if (arg == "--lcs-similarity") {
1463+
if (arg == "--lcp-similarity") {
14641464
if (++i >= argc) {
14651465
invalid_param = true;
14661466
return true;
14671467
}
1468-
params.lcs_similarity = std::stof(argv[i]);
1468+
params.lcp_similarity = std::stof(argv[i]);
14691469
return true;
14701470
}
14711471
if (arg == "-pps") {
@@ -1839,8 +1839,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
18391839
"set custom jinja chat template (default: template taken from model's metadata)\n"
18401840
"only commonly used templates are accepted:\n"
18411841
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
1842-
options.push_back({ "server", " --lcs-similarity SIMILARITY",
1843-
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcs_similarity });
1842+
options.push_back({ "server", " --lcp-similarity SIMILARITY",
1843+
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcp_similarity });
18441844

18451845
#ifndef LOG_DISABLE_LOGS
18461846
options.push_back({ "logging" });

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ struct gpt_params {
202202

203203
std::string slot_save_path;
204204

205-
float lcs_similarity = 0.0f;
205+
float lcp_similarity = 0.0f;
206206

207207
// batched-bench params
208208
bool is_pp_shared = false;

examples/server/server.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,8 @@ struct server_context {
647647

648648
server_metrics metrics;
649649

650-
// Longest Common Substring similarity for slot selection
651-
float lcs_similarity = 0.0f;
650+
// Longest Common Prefix similarity for slot selection
651+
float lcp_similarity = 0.0f;
652652

653653
~server_context() {
654654
if (ctx) {
@@ -812,8 +812,8 @@ struct server_context {
812812
server_slot * ret = nullptr;
813813

814814
// find the slot that has at least n% prompt similarity
815-
if (ret == nullptr && lcs_similarity != 0.0f && !prompt.empty()) {
816-
int max_lcs_len = 0;
815+
if (ret == nullptr && lcp_similarity != 0.0f && !prompt.empty()) {
816+
int max_lcp_len = 0;
817817
float similarity = 0;
818818

819819
for (server_slot & slot : slots) {
@@ -833,23 +833,23 @@ struct server_context {
833833
// length of the current slot's prompt
834834
int slot_prompt_len = slot_prompt.size();
835835

836-
// length of the longest common substring between the current slot's prompt and the input prompt
837-
int lcs_len = lcs_length(slot_prompt, prompt);
836+
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
837+
int lcp_len = common_part(slot_prompt, prompt);
838838

839839
// fraction of the common substring length compared to the current slot's prompt length
840-
similarity = static_cast<float>(lcs_len) / slot_prompt_len;
840+
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
841841

842842
// select the current slot if the criteria match
843-
if (lcs_len > max_lcs_len && similarity > lcs_similarity) {
844-
max_lcs_len = lcs_len;
843+
if (lcp_len > max_lcp_len && similarity > lcp_similarity) {
844+
max_lcp_len = lcp_len;
845845
ret = &slot;
846846
}
847847
}
848848

849849
if (ret != nullptr) {
850-
LOG_VERBOSE("selected slot by lcs similarity", {
850+
LOG_VERBOSE("selected slot by lcp similarity", {
851851
{"id_slot", ret->id},
852-
{"max_lcs_len", max_lcs_len},
852+
{"max_lcp_len", max_lcp_len},
853853
{"similarity", similarity},
854854
});
855855
}
@@ -2568,8 +2568,8 @@ int main(int argc, char ** argv) {
25682568
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
25692569
}
25702570

2571-
// Longest Common Substring similarity for slot selection
2572-
ctx_server.lcs_similarity = params.lcs_similarity;
2571+
// Longest Common Prefix similarity for slot selection
2572+
ctx_server.lcp_similarity = params.lcp_similarity;
25732573

25742574
// load the model
25752575
if (!ctx_server.load_model(params)) {

examples/server/utils.hpp

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
253253
return i;
254254
}
255255

256+
static size_t common_part(const std::string & a, const std::string & b) {
257+
size_t i;
258+
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
259+
260+
return i;
261+
}
262+
256263
static bool ends_with(const std::string & str, const std::string & suffix) {
257264
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
258265
}
@@ -646,52 +653,3 @@ static json format_error_response(const std::string & message, const enum error_
646653
{"type", type_str},
647654
};
648655
}
649-
650-
static int lcs_length(const std::string & str1, const std::string & str2) {
651-
// check for empty strings
652-
if (str1.empty() || str2.empty()) {
653-
return 0;
654-
}
655-
656-
// get the lengths of the input strings
657-
int str1_len = str1.size();
658-
int str2_len = str2.size();
659-
660-
// initialize the maximum length of the longest common subsequence (LCS)
661-
int max_length = 0;
662-
663-
// use two rows instead of a 2D matrix to optimize space
664-
std::vector<int> prev_row(str2_len + 1, 0);
665-
std::vector<int> curr_row(str2_len + 1, 0);
666-
667-
// iterate through the characters of str1
668-
for (int i = 1; i <= str1_len; i++) {
669-
// iterate through the characters of str2
670-
for (int j = 1; j <= str2_len; j++) {
671-
// if characters at the current positions match
672-
if (str1[i - 1] == str2[j - 1]) {
673-
// if it's the first character of either string, set LCS length to 1
674-
if (i == 1 || j == 1) {
675-
curr_row[j] = 1;
676-
} else {
677-
// increment LCS length by 1 compared to the previous character
678-
curr_row[j] = prev_row[j - 1] + 1;
679-
}
680-
681-
// update max_length if necessary
682-
if (curr_row[j] > max_length) {
683-
max_length = curr_row[j];
684-
}
685-
} else {
686-
// reset LCS length if characters don't match
687-
curr_row[j] = 0;
688-
}
689-
}
690-
691-
// update the previous row for the next iteration
692-
prev_row = curr_row;
693-
}
694-
695-
// return the maximum length of the LCS
696-
return max_length;
697-
}

0 commit comments

Comments
 (0)