Commit 39b46bc

[ExecuTorch][Llama] Change runner to enable chunked prefill
Pull Request resolved: #9785

This diff adds code to chunk prompts longer than max_seq_len, enabling prefill of a larger context.

ghstack-source-id: 275283535
Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/)
1 parent 2aa7748 commit 39b46bc
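
At a high level, the change makes TextPrefiller::prefill split a prompt longer than max_seq_len into max_seq_len-sized chunks and run the chunks through the decoder one after another, carrying the KV-cache position forward between them. A minimal standalone sketch of that splitting step (chunk_prompt is a hypothetical helper for illustration, not code from this diff):

#include <algorithm>
#include <cstdint>
#include <vector>

// Split prompt_tokens into ordered slices of at most max_seq_len tokens each.
// Purely illustrative; the real loop lives in TextPrefiller::prefill below.
std::vector<std::vector<uint64_t>> chunk_prompt(
    const std::vector<uint64_t>& prompt_tokens,
    size_t max_seq_len) {
  std::vector<std::vector<uint64_t>> chunks;
  for (size_t start = 0; start < prompt_tokens.size(); start += max_seq_len) {
    size_t end = std::min(prompt_tokens.size(), start + max_seq_len);
    chunks.emplace_back(
        prompt_tokens.begin() + start, prompt_tokens.begin() + end);
  }
  return chunks;
}

In the actual runner each chunk is fed to TextPrefiller::prefillChunk, which is expected to advance start_pos so the next chunk lands at the correct KV-cache offset.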

File tree

4 files changed, +68 -9 lines changed

examples/models/llama/runner/runner.cpp

Lines changed: 7 additions & 5 deletions
@@ -11,6 +11,7 @@
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
+#include <algorithm>
 #include <ctime>
 
 #include <executorch/extension/llm/runner/util.h>
@@ -140,7 +141,8 @@ Error Runner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      metadata_.at(kEnableDynamicShape),
+      metadata_.at(kMaxSeqLen));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
@@ -221,11 +223,11 @@ Error Runner::generate(
 
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxSeqLen),
+      num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
       ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
       num_prompt_tokens,
-      metadata_.at(kMaxSeqLen));
+      metadata_.at(kMaxContextLen));
   ET_CHECK_MSG(
       num_prompt_tokens < seq_len,
       "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
@@ -242,10 +244,10 @@ Error Runner::generate(
   }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
+  stats_.first_token_ms = llm::time_in_ms();
+  stats_.prompt_eval_end_ms = llm::time_in_ms();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
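
Two runner-side changes are worth calling out: the prompt-length guard now checks kMaxContextLen (the total context the KV cache can hold) instead of kMaxSeqLen (the per-forward limit), and first_token_ms / prompt_eval_end_ms are stamped only after prefill_res.get(), i.e. after every prefill chunk has run. A hedged sketch of what that implies for the reported prefill rate (inference_start_ms is assumed to be the llm::Stats field sampled before prefill starts; this is not the runner's actual reporting code):

#include <cstdint>

// Illustrative only: with the timestamps moved after prefill_res.get(), the
// prompt-eval window spans all chunks, so the computed rate averages over the
// whole (possibly chunked) prefill rather than a single forward.
double prefill_tokens_per_sec(
    int64_t num_prompt_tokens,
    int64_t inference_start_ms, // assumed llm::Stats field
    int64_t prompt_eval_end_ms) {
  return static_cast<double>(num_prompt_tokens) * 1000.0 /
      static_cast<double>(prompt_eval_end_ms - inference_start_ms);
}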

examples/models/llava/runner/llava_runner.cpp

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ Error LlavaRunner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       /*use_kv_cache=*/true,
-      /*enable_parallel_prefill=*/true);
+      /*enable_parallel_prefill=*/true,
+      /*max_seq_len=*/128);
 
   // Load the image prefiller
   image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());

extension/llm/runner/text_prefiller.cpp

Lines changed: 45 additions & 2 deletions
@@ -10,6 +10,7 @@
 // LLM.
 
 #include <executorch/extension/llm/runner/text_prefiller.h>
+#include <algorithm>
 
 namespace executorch {
 namespace extension {
@@ -18,10 +19,13 @@ namespace llm {
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    int64_t max_seq_len)
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
-      enable_parallel_prefill_(enable_parallel_prefill) {}
+      enable_parallel_prefill_(enable_parallel_prefill),
+      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {
+} // -1 because for some reason tracing results in this upperbound
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
@@ -30,6 +34,45 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
   }
+
+  // Check if we need to chunk the prompt tokens
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  // If prompt tokens exceed max_seq_len_, we need to chunk them
+  if (num_prompt_tokens > max_seq_len_) {
+    uint64_t cur_token = 0;
+    int num_tokens_to_process = 0;
+
+    while (num_tokens_to_process < num_prompt_tokens) {
+      auto num_tokens_to_prefill_with = std::min<int>(
+          num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+      std::vector<uint64_t> prompt_tokens_to_process(
+          num_tokens_to_prefill_with);
+      std::copy(
+          prompt_tokens.begin() + num_tokens_to_process,
+          prompt_tokens.begin() + num_tokens_to_process +
+              num_tokens_to_prefill_with,
+          prompt_tokens_to_process.begin());
+
+      // Process this chunk
+      auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+      ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
+      cur_token = chunk_result.get();
+
+      num_tokens_to_process += num_tokens_to_prefill_with;
+    }
+
+    return cur_token;
+  } else {
+    // If prompt tokens don't exceed max_seq_len_, process them directly
+    return prefillChunk(prompt_tokens, start_pos);
+  }
+}
+
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t& start_pos) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
   // When kv cache is not used, start pos is ignored
   int32_t num_prompt_tokens = prompt_tokens.size();
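
Note that the constructor stores max_seq_len - 1 in max_seq_len_ (per the comment about the traced upper bound), so with the default max_seq_len = 128 each chunk holds at most 127 tokens. A small worked check of the chunking arithmetic under those defaults (standalone sketch, not code from this diff):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t max_chunk = 128 - 1; // mirrors max_seq_len_ for the default of 128
  const int64_t num_prompt_tokens = 300;
  int64_t processed = 0;
  while (processed < num_prompt_tokens) {
    const int64_t chunk = std::min(num_prompt_tokens - processed, max_chunk);
    std::printf(
        "prefill chunk of %lld tokens starting at token %lld\n",
        static_cast<long long>(chunk),
        static_cast<long long>(processed));
    processed += chunk;
  }
  // Prints chunks of 127, 127, and 46 tokens starting at 0, 127, and 254.
  return 0;
}

Because start_pos is passed to prefillChunk by reference and is not adjusted in the loop, prefillChunk is relied on to advance it by each chunk's length so that successive chunks land at consecutive KV-cache positions.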

extension/llm/runner/text_prefiller.h

Lines changed: 14 additions & 1 deletion
@@ -22,7 +22,8 @@ class ET_EXPERIMENTAL TextPrefiller {
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache_,
-      bool enable_parallel_prefill);
+      bool enable_parallel_prefill,
+      int64_t max_seq_len = 128);
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -35,10 +36,22 @@ class ET_EXPERIMENTAL TextPrefiller {
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);
 
+  /**
+   * Helper method to prefill a chunk of tokens.
+   * @param prompt_tokens The chunk of text prompt tokens to process.
+   * @param start_pos The starting position in KV cache of the input in the LLM
+   * Module.
+   * @return The next token of the LLM Module after prefilling this chunk.
+   */
+  ::executorch::runtime::Result<uint64_t> prefillChunk(
+      std::vector<uint64_t>& prompt_tokens,
+      int64_t& start_pos);
+
  private:
   TextDecoderRunner* text_decoder_runner_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
+  int64_t max_seq_len_;
 };
 
 } // namespace llm
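
Since the header gives max_seq_len a default of 128, existing call sites keep compiling unchanged; callers that know the exported model's per-forward limit can pass it explicitly, as runner.cpp now does with metadata_.at(kMaxSeqLen). A minimal construction sketch (make_prefiller, decoder_runner, and model_max_seq_len are hypothetical names; only the TextPrefiller constructor shape comes from this diff):

#include <executorch/extension/llm/runner/text_prefiller.h>

#include <cstdint>
#include <memory>

namespace llm = ::executorch::extension::llm;

// Hypothetical wiring: build a TextPrefiller with an explicit max_seq_len
// taken from the exported model's metadata instead of the 128 default.
std::unique_ptr<llm::TextPrefiller> make_prefiller(
    llm::TextDecoderRunner* decoder_runner,
    int64_t model_max_seq_len) {
  return std::make_unique<llm::TextPrefiller>(
      decoder_runner,
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true,
      /*max_seq_len=*/model_max_seq_len);
}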
