@@ -171,7 +171,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
171171 }
172172
173173 // TODO: support data parallel case
174- if (input.input_params .q_seq_lens_vec [ 0 ] > 1 ) {
174+ if (check_is_prefill ( input.input_params .q_seq_lens_vec ) ) {
175175 return step_prefill (input);
176176 } else {
177177 return step_decode (input);
@@ -180,7 +180,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
180180
181181std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty (
182182 const ForwardInput& input) {
183- if (input.input_params .q_seq_lens_vec [ 0 ] > 1 ) {
183+ if (check_is_prefill ( input.input_params .q_seq_lens_vec ) ) {
184184 auto output = impl_->step (input);
185185 auto draft_output = draft_impl_->step (input);
186186 return output;
@@ -224,9 +224,10 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
224224 auto offset = input.input_params .num_sequences ;
225225 auto token_offset = prefill_input.token_ids .size (0 );
226226 if (token_offset > 0 ) {
227- prefill_input.input_params .mm_data = MMData (
228- MMType::EMBEDDING,
229- {{" embedding" , embeddings.narrow (0 , token_start_idx, token_offset)}});
227+ prefill_input.input_params .mm_data =
228+ MMData (MMType::EMBEDDING,
229+ {{" embedding" ,
230+ embeddings.narrow (0 , token_start_idx, token_offset).clone ()}});
230231 }
231232 if (next_tokens.defined ()) {
232233 auto & token_ids = prefill_input.token_ids ;
@@ -329,7 +330,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
329330 // final step
330331 prepare_validate_inputs (input, validate_input, true );
331332 } else {
332- prepare_draft_inputs (draft_input, next_step_input, 1 , device_);
333+ if (i == 0 ) {
334+ prepare_draft_inputs (input, next_step_input, 1 , device_);
335+ } else {
336+ prepare_draft_inputs (draft_input, next_step_input, 1 , device_);
337+ }
333338 }
334339 draft_outputs.push_back (std::move (future).get ().value ());
335340 // update input of next step
@@ -759,7 +764,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
759764void SpeculativeWorkerImpl::prepare_work_before_execute (
760765 const ForwardInput& input,
761766 ForwardInput& processed_input) {
762- if (input.input_params .q_seq_lens_vec [ 0 ] > 1 ) {
767+ if (check_is_prefill ( input.input_params .q_seq_lens_vec ) ) {
763768 WorkerImpl::prepare_work_before_execute (input, processed_input);
764769 } else {
765770 if (enable_schedule_overlap ()) {
0 commit comments