Skip to content

Commit ea8b42e

Browse files
DongheJin and yq33victor
authored and committed
bugfix: fix coredump issue when both prefixcache and mtp are enabled. (#377)
* bugfix: fix coredump issue when both prefixcache and mtp are enabled.
* bugfix: fix coredump caused by incorrect token replacement.
1 parent d7ec230 commit ea8b42e

File tree

7 files changed

+27
-10
lines changed

7 files changed

+27
-10
lines changed

xllm/core/framework/model/model_args.h

File mode changed from 100755 to 100644.

xllm/core/framework/request/sequence_kv_state.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
5858
if (blocks.empty()) {
5959
return;
6060
}
61-
6261
// The number of matched blocks may be fewer than the number of blocks held by
6362
// the sequence itself. In this case, try to replace the blocks computed by
6463
// the sequence with blocks from the prefix_cache and release the computed
@@ -86,6 +85,10 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
8685
CHECK_GT(block_size, 0);
8786
num_shared_tokens =
8887
((current_total_num_tokens - 1) / block_size) * block_size;
88+
if (num_owned_shared_blocks_ > 0) {
89+
num_owned_shared_blocks_--;
90+
blocks_.pop_back();
91+
}
8992
}
9093
CHECK_LT(num_shared_tokens, current_total_num_tokens);
9194
// update the kv cache position

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(const ForwardInput& input) {
174174
// should be in same prefill stage, so, to judge empty_kv_cache,
175175
// just use micro batch 0 here
176176
if (options_.enable_speculative_decode() && !is_spec_draft_) {
177-
if (input.input_params.q_seq_lens_vec[0] > 1) {
177+
if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
178178
output.sample_output.embeddings = hidden_states;
179179
} else if (sampling_params.sample_idxes.defined()) {
180180
// auto sample_idxes =

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
171171
}
172172

173173
// TODO: support data parallel case
174-
if (input.input_params.q_seq_lens_vec[0] > 1) {
174+
if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
175175
return step_prefill(input);
176176
} else {
177177
return step_decode(input);
@@ -180,7 +180,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
180180

181181
std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty(
182182
const ForwardInput& input) {
183-
if (input.input_params.q_seq_lens_vec[0] > 1) {
183+
if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
184184
auto output = impl_->step(input);
185185
auto draft_output = draft_impl_->step(input);
186186
return output;
@@ -224,9 +224,10 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
224224
auto offset = input.input_params.num_sequences;
225225
auto token_offset = prefill_input.token_ids.size(0);
226226
if (token_offset > 0) {
227-
prefill_input.input_params.mm_data = MMData(
228-
MMType::EMBEDDING,
229-
{{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}});
227+
prefill_input.input_params.mm_data =
228+
MMData(MMType::EMBEDDING,
229+
{{"embedding",
230+
embeddings.narrow(0, token_start_idx, token_offset).clone()}});
230231
}
231232
if (next_tokens.defined()) {
232233
auto& token_ids = prefill_input.token_ids;
@@ -329,7 +330,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
329330
// final step
330331
prepare_validate_inputs(input, validate_input, true);
331332
} else {
332-
prepare_draft_inputs(draft_input, next_step_input, 1, device_);
333+
if (i == 0) {
334+
prepare_draft_inputs(input, next_step_input, 1, device_);
335+
} else {
336+
prepare_draft_inputs(draft_input, next_step_input, 1, device_);
337+
}
333338
}
334339
draft_outputs.push_back(std::move(future).get().value());
335340
// update input of next step
@@ -759,7 +764,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
759764
void SpeculativeWorkerImpl::prepare_work_before_execute(
760765
const ForwardInput& input,
761766
ForwardInput& processed_input) {
762-
if (input.input_params.q_seq_lens_vec[0] > 1) {
767+
if (check_is_prefill(input.input_params.q_seq_lens_vec)) {
763768
WorkerImpl::prepare_work_before_execute(input, processed_input);
764769
} else {
765770
if (enable_schedule_overlap()) {

xllm/core/runtime/worker_impl.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,5 +1039,13 @@ AlignedTensorCreater::AlignedTensorCreater(
10391039
LOG(INFO) << "Page aligned: "
10401040
<< ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
10411041
}
1042+
bool WorkerImpl::check_is_prefill(const std::vector<int>& q_seq_lens_vec) {
1043+
for (auto q_len : q_seq_lens_vec) {
1044+
if (q_len > 1) {
1045+
return true;
1046+
}
1047+
}
1048+
return false;
1049+
}
10421050

10431051
} // namespace xllm

xllm/core/runtime/worker_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ class WorkerImpl {
165165

166166
torch::ScalarType dtype() const { return dtype_; }
167167

168+
bool check_is_prefill(const std::vector<int>& q_seq_lens_vec);
169+
168170
int32_t hidden_size() const {
169171
return context_.get_model_args().hidden_size();
170172
}

xllm/core/scheduler/continuous_scheduler.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options)
9393
} else {
9494
min_speculative_tokens_required_ = options_.num_speculative_tokens();
9595
}
96-
9796
}
9897

9998
ContinuousScheduler::~ContinuousScheduler() { running_requests_.clear(); }

0 commit comments

Comments
 (0)