Skip to content

llama-bench : add test measuring token generation rate at given prompt length #11126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 97 additions & 11 deletions examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<std::pair<int, int>> n_pg;
std::vector<std::pair<int, int>> n_gp;
std::vector<int> n_batch;
std::vector<int> n_ubatch;
std::vector<ggml_type> type_k;
Expand Down Expand Up @@ -192,6 +193,7 @@ static const cmd_params cmd_params_defaults = {
/* n_prompt */ { 512 },
/* n_gen */ { 128 },
/* n_pg */ {},
/* n_gp */ {},
/* n_batch */ { 2048 },
/* n_ubatch */ { 512 },
/* type_k */ { GGML_TYPE_F16 },
Expand Down Expand Up @@ -230,6 +232,8 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -pg <pp,tg> (default: %s)\n",
join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
printf(" -gp <pp,tg> (default: %s)\n",
join(transform_to_str(cmd_params_defaults.n_gp, pair_str), ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n",
join(cmd_params_defaults.n_batch, ",").c_str());
printf(" -ub, --ubatch-size <n> (default: %s)\n",
Expand Down Expand Up @@ -366,6 +370,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
} else if (arg == "-gp") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<std::string>(argv[i], ',');
if (p.size() != 2) {
invalid_param = true;
break;
}
params.n_gp.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -615,6 +630,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_pg.empty()) {
params.n_pg = cmd_params_defaults.n_pg;
}
if (params.n_gp.empty()) {
params.n_gp = cmd_params_defaults.n_gp;
}
if (params.n_batch.empty()) {
params.n_batch = cmd_params_defaults.n_batch;
}
Expand Down Expand Up @@ -670,7 +688,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
return params;
}

// Kind of benchmark scenario a test instance represents; determines which
// phase(s) are timed and how the result row is labeled (see the label
// formatting switch in struct test).
enum test_kind_type {
// measure mean prompt processing rate without token generation (labeled "pp<n_prompt>")
TEST_KIND_PP,
// measure mean token generation rate without prompt processing (labeled "tg<n_gen>")
TEST_KIND_TG,
// measure mean prompt processing and token generation rate combined (labeled "pp<n_prompt>+tg<n_gen>")
TEST_KIND_PG,
// measure mean token generation rate after processing prompt of given length;
// prompt processing time is excluded from the measurement (labeled "tg<n_gen>@pp<n_prompt>")
TEST_KIND_GP,
};

struct cmd_params_instance {
test_kind_type test_kind;
std::string model;
int n_prompt;
int n_gen;
Expand Down Expand Up @@ -757,6 +787,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
Expand Down Expand Up @@ -786,6 +817,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_TG,
/* .model = */ m,
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
Expand Down Expand Up @@ -815,6 +847,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PG,
/* .model = */ m,
/* .n_prompt = */ n_pg.first,
/* .n_gen = */ n_pg.second,
Expand All @@ -838,6 +871,36 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
};
instances.push_back(instance);
}

for (const auto & n_gp : params.n_gp) {
if (n_gp.first == 0 && n_gp.second == 0) {
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_GP,
/* .model = */ m,
/* .n_prompt = */ n_gp.first,
/* .n_gen = */ n_gp.second,
/* .n_batch = */ nb,
/* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .cpu_mask = */ cm,
/* .cpu_strict = */ cs,
/* .poll = */ pl,
/* .n_gpu_layers = */ nl,
/* .rpc_servers = */ rpc,
/* .split_mode = */ sm,
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
};
instances.push_back(instance);
}
}
// clang-format on

Expand All @@ -853,6 +916,7 @@ struct test {
std::string model_type;
uint64_t model_size;
uint64_t model_n_params;
test_kind_type test_kind;
int n_batch;
int n_ubatch;
int n_threads;
Expand All @@ -872,6 +936,7 @@ struct test {
int n_prompt;
int n_gen;
std::string test_time;
std::string test_label;
std::vector<uint64_t> samples_ns;

test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
Expand All @@ -881,6 +946,7 @@ struct test {
model_type = buf;
model_size = llama_model_size(lmodel);
model_n_params = llama_model_n_params(lmodel);
test_kind = inst.test_kind;
n_batch = inst.n_batch;
n_ubatch = inst.n_ubatch;
n_threads = inst.n_threads;
Expand All @@ -904,6 +970,26 @@ struct test {
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;

// prepare test label for printing
switch (test_kind) {
case TEST_KIND_PP:
snprintf(buf, sizeof(buf), "pp%d", n_prompt);
break;
case TEST_KIND_TG:
snprintf(buf, sizeof(buf), "tg%d", n_gen);
break;
case TEST_KIND_PG:
snprintf(buf, sizeof(buf), "pp%d+tg%d", n_prompt, n_gen);
break;
case TEST_KIND_GP:
snprintf(buf, sizeof(buf), "tg%d@pp%d", n_gen, n_prompt);
break;
default:
snprintf(buf, sizeof(buf), "unknown");
break;
}
Comment on lines +974 to +990
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This formatting should only be applied to the markdown printer. The other printers are intended to be used programmatically, so it should be a simple enum that can be parsed easily, without the token counts. The token counts can be obtained from the n_prompt and n_gen parameters already.

test_label = buf;

(void) ctx;
}

Expand All @@ -912,7 +998,7 @@ struct test {
uint64_t stdev_ns() const { return ::stdev(samples_ns); }

std::vector<double> get_ts() const {
int n_tokens = n_prompt + n_gen;
int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
std::vector<double> ts;
std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts),
[n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
Expand Down Expand Up @@ -942,7 +1028,7 @@ struct test {
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap",
"embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns",
"avg_ts", "stddev_ts",
"avg_ts", "stddev_ts", "test",
};
return fields;
}
Expand Down Expand Up @@ -1013,7 +1099,8 @@ struct test {
std::to_string(avg_ns()),
std::to_string(stdev_ns()),
std::to_string(avg_ts()),
std::to_string(stdev_ts()) };
std::to_string(stdev_ts()),
test_label };
return values;
}

Expand Down Expand Up @@ -1325,14 +1412,7 @@ struct markdown_printer : public printer {
} else if (field == "backend") {
value = test::get_backend();
} else if (field == "test") {
if (t.n_prompt > 0 && t.n_gen == 0) {
snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
} else if (t.n_gen > 0 && t.n_prompt == 0) {
snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
} else {
snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
}
value = buf;
value = t.test_label;
} else if (field == "t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
value = buf;
Expand Down Expand Up @@ -1597,6 +1677,12 @@ int main(int argc, char ** argv) {
}
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}

// we are not interested in prompt processing time in g@p test
if (t.test_kind == TEST_KIND_GP) {
t_start = get_time_ns();
}

if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
Expand Down
Loading