Skip to content

Commit 03bd85a

Browse files
Kush Rastogi authored and facebook-github-bot committed
Adding KV to Prefill IO (pytorch#9466)
Summary: Pull Request resolved: pytorch#9466 Differential Revision: D71567692
1 parent 6f6fa6a commit 03bd85a

File tree

5 files changed

+105
-36
lines changed

5 files changed

+105
-36
lines changed

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ int main(int argc, char** argv) {
7676
std::vector<char> buf;
7777
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
7878
std::ofstream fout(FLAGS_output_path.c_str());
79-
auto callback = [&](const std::string& piece) {
79+
80+
int32_t num_total_tokens = 0;
81+
82+
auto callback = [&](const std::string& piece, int32_t tokens_generated) {
83+
num_total_tokens += tokens_generated;
8084
for (const char c : piece) {
8185
buf.push_back(c);
8286
}
@@ -85,6 +89,7 @@ int main(int argc, char** argv) {
8589
for (int i = 0; i < FLAGS_num_iters; i++) {
8690
runner.generate(
8791
FLAGS_seq_len,
92+
num_total_tokens,
8893
FLAGS_prompt.c_str(),
8994
FLAGS_system_prompt.c_str(),
9095
callback);

examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,45 @@ void ShiftPointerIoMgr::prepare_prefill_io(
494494
}
495495
}
496496

497+
void ShiftPointerIoMgr::update_kv_to_prefill_io(
498+
int64_t pos,
499+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
500+
// update v_cache
501+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
502+
v_cache_in_[prefill_forward_name_];
503+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_out =
504+
v_cache_out_[prefill_forward_name_];
505+
for (int i = 0, v_cache_stride = head_dim_ * pos; i < v_cache_in.size();
506+
i++) {
507+
v_cache_in[i]->set_data(
508+
v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
509+
v_cache_out[i]->set_data(
510+
v_cache_out[i]->mutable_data<uint8_t>() + v_cache_stride);
511+
}
512+
513+
// update k_cache
514+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
515+
k_cache_in_[prefill_forward_name_];
516+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_out =
517+
k_cache_out_[prefill_forward_name_];
518+
519+
for (int i = 0, k_cache_stride = pos * sizeof(uint8_t); i < k_cache_in_.size();
520+
i++) {
521+
k_cache_in[i]->set_data(
522+
k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
523+
k_cache_out[i]->set_data(
524+
k_cache_out[i]->mutable_data<uint8_t>() + k_cache_stride);
525+
}
526+
527+
// Setting attention mask from context_len - prefill_ar_len - i to context_len
528+
IO* ptr = static_cast<IO*>(data_ptr_.get());
529+
for (int i = prefill_ar_len_; i < pos; i++) {
530+
for (int j = 0; j < prefill_ar_len_; j++) {
531+
ptr->prefill_attention_mask[j * context_len_ + context_len_ - prefill_ar_len_ - i] = 65535;
532+
}
533+
}
534+
}
535+
497536
void ShiftPointerIoMgr::update_prefill_to_kv_io(
498537
int64_t cur_token,
499538
int64_t pos,
@@ -664,33 +703,32 @@ void ShiftPointerIoMgr::update_prefill_io(
664703
}
665704

666705
void ShiftPointerIoMgr::fill_prefill_toks(
667-
int64_t start_pos,
706+
int64_t num_prev_tokens,
707+
int64_t prompt_pos,
668708
std::vector<uint64_t>& prompt_tokens) {
669709
IO* ptr = static_cast<IO*>(get_mutable_ptr());
670710
for (int i = 0; i < prefill_ar_len_; i++) {
671711
if (!is_bert_) {
672-
ptr->prefill_input_pos[i] = start_pos + i;
712+
ptr->prefill_input_pos[i] = num_prev_tokens + prompt_pos + i;
673713
}
674714

675-
if (start_pos + i < prompt_tokens.size()) {
715+
if (prompt_pos + i < prompt_tokens.size()) {
676716
// Support CPU 4-bit embedding, which requires int64 input.
677717
// However, for QNN embedding, only int32 input is needed.
678718
// Therefore, we need to cast to the correct type to write the data.
679719
if (use_int64_token_) {
680-
ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i];
720+
ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i];
681721
} else {
682722
int32_t* prefill_input_toks_ptr =
683723
reinterpret_cast<int32_t*>(ptr->prefill_input_toks.data());
684724
prefill_input_toks_ptr[i] =
685-
static_cast<int32_t>(prompt_tokens[start_pos + i]);
725+
static_cast<int32_t>(prompt_tokens[prompt_pos + i]);
686726
}
687727
}
688-
if (start_pos >= prefill_ar_len_) {
689-
for (int j = 0,
690-
offset = i * context_len_ +
691-
(context_len_ - prefill_ar_len_ - start_pos);
692-
j < prefill_ar_len_;
693-
++j) {
728+
if (num_prev_tokens + prompt_pos >= prefill_ar_len_) {
729+
int64_t start_offset = i * context_len_ +
730+
(context_len_ - num_prev_tokens - prompt_pos - prefill_ar_len_);
731+
for (int j = 0, offset = start_offset; j < prefill_ar_len_; ++j) {
694732
ptr->prefill_attention_mask[offset + j] = 65535;
695733
}
696734
}
@@ -1305,6 +1343,12 @@ void SmartMaskIoMgr::prepare_prefill_io(
13051343
}
13061344
}
13071345

1346+
void SmartMaskIoMgr::update_kv_to_prefill_io(
1347+
int64_t pos,
1348+
std::vector<std::vector<Tensor>>& output_tensors) {
1349+
//TODO: Fill In
1350+
}
1351+
13081352
void SmartMaskIoMgr::update_prefill_to_kv_io(
13091353
int64_t cur_token,
13101354
int64_t pos,
@@ -1396,29 +1440,30 @@ void SmartMaskIoMgr::update_prefill_io(
13961440
}
13971441

13981442
void SmartMaskIoMgr::fill_prefill_toks(
1399-
int64_t start_pos,
1443+
int64_t num_prev_tokens,
1444+
int64_t prompt_pos,
14001445
std::vector<uint64_t>& prompt_tokens) {
14011446
IO* ptr = static_cast<IO*>(get_mutable_ptr());
14021447
for (int i = 0; i < prefill_ar_len_; i++) {
14031448
if (!is_bert_) {
1404-
ptr->prefill_input_pos[i] = start_pos + i;
1449+
ptr->prefill_input_pos[i] = prompt_pos + i;
14051450
}
14061451

1407-
if (start_pos + i < prompt_tokens.size()) {
1452+
if (prompt_pos + i < prompt_tokens.size()) {
14081453
// Support CPU 4-bit embedding, which requires int64 input.
14091454
// However, for QNN embedding, only int32 input is needed.
14101455
// Therefore, we need to cast to the correct type to write the data.
14111456
if (use_int64_token_) {
1412-
ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i];
1457+
ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i];
14131458
} else {
14141459
int32_t* prefill_input_toks_ptr =
14151460
reinterpret_cast<int32_t*>(ptr->prefill_input_toks);
14161461
prefill_input_toks_ptr[i] =
1417-
static_cast<int32_t>(prompt_tokens[start_pos + i]);
1462+
static_cast<int32_t>(prompt_tokens[prompt_pos + i]);
14181463
}
14191464
}
1420-
if (start_pos >= prefill_ar_len_) {
1421-
for (int j = 0, offset = i * context_len_ + (start_pos - prefill_ar_len_);
1465+
if (prompt_pos >= prefill_ar_len_) {
1466+
for (int j = 0, offset = i * context_len_ + (prompt_pos - prefill_ar_len_);
14221467
j < prefill_ar_len_;
14231468
++j) {
14241469
ptr->prefill_attention_mask[offset + j] = 65535;

examples/qualcomm/oss_scripts/llama/runner/io_manager.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,13 @@ class IoMgrBase {
4848
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
4949
methods_meta) = 0;
5050
virtual void fill_prefill_toks(
51-
int64_t start_pos,
51+
int64_t num_prev_tokens,
52+
int64_t prompt_pos,
5253
std::vector<uint64_t>& prompt_tokens) = 0;
5354
virtual void fill_kv_tok_mask(int64_t pos, int64_t cur_token) = 0;
55+
virtual void update_kv_to_prefill_io(
56+
int64_t pos,
57+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
5458
virtual void update_prefill_to_kv_io(
5559
int64_t cur_token,
5660
int64_t pos,
@@ -118,9 +122,13 @@ class ShiftPointerIoMgr : public IoMgrBase {
118122
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
119123
methods_meta) override;
120124
void fill_prefill_toks(
121-
int64_t start_pos,
125+
int64_t num_prev_tokens,
126+
int64_t prompt_pos,
122127
std::vector<uint64_t>& prompt_tokens) override;
123128
void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
129+
void update_kv_to_prefill_io(
130+
int64_t pos,
131+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
124132
void update_prefill_to_kv_io(
125133
int64_t cur_token,
126134
int64_t pos,
@@ -226,9 +234,13 @@ class SmartMaskIoMgr : public IoMgrBase {
226234
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
227235
methods_meta) override;
228236
void fill_prefill_toks(
229-
int64_t start_pos,
237+
int64_t num_prev_tokens,
238+
int64_t prompt_pos,
230239
std::vector<uint64_t>& prompt_tokens) override;
231240
void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
241+
void update_kv_to_prefill_io(
242+
int64_t pos,
243+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
232244
void update_prefill_to_kv_io(
233245
int64_t cur_token,
234246
int64_t pos,

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,10 @@ void Runner::run_model_step(
276276

277277
Error Runner::generate(
278278
int32_t seq_len,
279+
int32_t num_prev_tokens,
279280
const std::string& prompt,
280281
const std::string& system_prompt,
281-
std::function<void(const std::string&)> token_callback,
282+
std::function<void(const std::string&, int32_t)> token_callback,
282283
std::function<void(const Stats&)> stats_callback) {
283284
std::unordered_map<std::string, std::vector<std::vector<Tensor>>>
284285
input_tensors, output_tensors;
@@ -327,14 +328,16 @@ Error Runner::generate(
327328
prompt_.append(
328329
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
329330
if (token_callback) {
330-
token_callback("<|begin_of_text|>");
331+
token_callback("<|begin_of_text|>", 0);
331332
}
332333
break;
333334
default:
334335
ET_CHECK_MSG(false, "unsupported llama version");
335336
break;
336337
}
337338

339+
ET_LOG(Info, "Number of Previous Tokens Prefill + Decode: %d", num_prev_tokens);
340+
338341
seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
339342
tokenizers::Result<std::vector<uint64_t>> encode_res =
340343
tokenizer_->encode(prompt_, n_bos_, 0);
@@ -349,7 +352,7 @@ Error Runner::generate(
349352

350353
int64_t pos = 0, prev_token, cur_token = prompt_tokens[0];
351354
if (token_callback) {
352-
token_callback(prompt_);
355+
token_callback(prompt_, num_prompt_tokens);
353356
}
354357
auto prefill_execute = [&](const std::string& method_name) {
355358
int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_);
@@ -361,7 +364,7 @@ Error Runner::generate(
361364
num_iters);
362365

363366
for (int i = 0; i < num_iters; i++) {
364-
io_mgr_->fill_prefill_toks(pos, prompt_tokens);
367+
io_mgr_->fill_prefill_toks(num_prev_tokens, pos, prompt_tokens);
365368
run_model_step(method_name, inputs[method_name]);
366369
io_mgr_->update_prefill_io(cur_token, pos, output_tensors[method_name]);
367370
pos += prefill_ar_len_;
@@ -377,10 +380,12 @@ Error Runner::generate(
377380
auto piece_res = tokenizer_->decode(prev_token, cur_token);
378381
ET_CHECK(piece_res.ok());
379382
if (token_callback) {
380-
token_callback(piece_res.get().c_str());
383+
ET_LOG(Info, "Prefill: %s", piece_res.get().c_str());
384+
token_callback(piece_res.get().c_str(), 1);
381385
}
382386

383-
pos = num_prompt_tokens;
387+
pos = num_prev_tokens + num_prompt_tokens;
388+
ET_LOG(Info, "Pos: %ld, Prompt Tokens: %ld", pos, num_prompt_tokens);
384389
stats_.first_token_ms = time_in_ms();
385390
stats_.prompt_eval_end_ms = time_in_ms();
386391
};
@@ -394,9 +399,9 @@ Error Runner::generate(
394399

395400
// hybrid mode will check these stats_ at prefill(prefill)
396401
if (eval_mode_ == EvalMode::kKVCached) {
397-
if (pos == num_prompt_tokens) {
402+
if (pos == num_prev_tokens + num_prompt_tokens) {
398403
stats_.first_token_ms = time_in_ms();
399-
} else if (pos == num_prompt_tokens - 1) {
404+
} else if (pos == num_prev_tokens + num_prompt_tokens - 1) {
400405
stats_.prompt_eval_end_ms = time_in_ms();
401406
}
402407
}
@@ -405,15 +410,15 @@ Error Runner::generate(
405410
cur_token = logitsToToken(logits_tensor, pos);
406411
stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;
407412

408-
if (pos < num_prompt_tokens - 1) {
413+
if (pos < num_prev_tokens + num_prompt_tokens - 1) {
409414
cur_token = prompt_tokens[pos + 1];
410415
}
411416
io_mgr_->update_kv_io(cur_token, ++pos, output_tensors[method_name]);
412417
auto piece_res = tokenizer_->decode(prev_token, cur_token);
413418
ET_CHECK(piece_res.ok());
414419

415420
if (token_callback && pos >= num_prompt_tokens) {
416-
token_callback(piece_res.get().c_str());
421+
token_callback(piece_res.get().c_str(), 1);
417422
}
418423

419424
if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
@@ -432,6 +437,7 @@ Error Runner::generate(
432437
io_mgr_->update_prefill_to_kv_io(
433438
cur_token, pos, output_tensors[kv_forward_name_]);
434439
kv_execute(kv_forward_name_);
440+
io_mgr_->update_kv_to_prefill_io(pos, output_tensors[prefill_forward_name_]);
435441
break;
436442
default:
437443
ET_CHECK_MSG(false, "Unsupported eval mode");
@@ -448,9 +454,9 @@ Error Runner::generate(
448454
if (stats_callback) {
449455
stats_callback(stats_);
450456
}
451-
io_mgr_->reset_io(
452-
get_methods_meta(prefill_forward_name_),
453-
get_methods_meta(kv_forward_name_));
457+
// io_mgr_->reset_io(
458+
// get_methods_meta(prefill_forward_name_),
459+
// get_methods_meta(kv_forward_name_));
454460
prompt_.clear();
455461
return Error::Ok;
456462
}

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ class Runner {
6767
executorch::runtime::Error load();
6868
executorch::runtime::Error generate(
6969
int32_t seq_len,
70+
int32_t num_prev_tokens,
7071
const std::string& prompt,
7172
const std::string& system_prompt,
72-
std::function<void(const std::string&)> token_callback = {},
73+
std::function<void(const std::string&, int32_t)> token_callback = {},
7374
std::function<void(const Stats&)> stats_callback = {});
7475
void stop();
7576
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>

0 commit comments

Comments (0)