1 file changed, +7 −1 lines changed

@@ -239,7 +239,10 @@ def _sample_draft_tokens(
         x = x - x.amax(dim=-1, keepdim=True)

         # --- temperature for drafter q ---
-        tau_d = float(getattr(self.opt_config, "draft_temperature", 1.0) or 1.0)
+        # Read from TARGET temperature (not draft_temperature)
+        tau_d = 1.0
+        if hasattr(self, '_current_sampling_metadata') and self._current_sampling_metadata is not None:
+            tau_d = float(getattr(self._current_sampling_metadata, 'temperature', 1.0))

         tau_q = tau_d + float(getattr(self.opt_config, "draft_q_temp_offset", 0.0))
         tau_max = float(getattr(self.opt_config, "draft_tau_max", 0.0))
@@ -324,6 +327,9 @@ def propose(
         sampling_metadata: SamplingMetadata,
         mm_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.Tensor:
+        # Store sampling_metadata so _sample_draft_tokens() can access target temperature
+        self._current_sampling_metadata = sampling_metadata
+
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]

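For context, here is a minimal self-contained sketch of the mechanism this diff introduces: propose() stashes the target sampling metadata, and _sample_draft_tokens() derives the drafter temperature from it, falling back to 1.0 when nothing was stored. Everything beyond the names that appear in the diff (the FakeSamplingMetadata and DrafterSketch classes, the scalar temperature field, the softmax/multinomial sampling) is an assumption for illustration; the real proposer also applies draft_tau_max clamping and other logic not shown here.

```python
import torch
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSamplingMetadata:
    # Stand-in for the real SamplingMetadata; temperature is a plain float here.
    temperature: float = 1.0


class DrafterSketch:
    """Sketch of how propose() stashes the target sampling metadata so that
    _sample_draft_tokens() can read the target temperature from it."""

    def __init__(self, draft_q_temp_offset: float = 0.0):
        self.draft_q_temp_offset = draft_q_temp_offset
        self._current_sampling_metadata: Optional[FakeSamplingMetadata] = None

    def propose(self, logits: torch.Tensor,
                sampling_metadata: Optional[FakeSamplingMetadata]) -> torch.Tensor:
        # Store the metadata, as in the diff, then draft one token per row.
        self._current_sampling_metadata = sampling_metadata
        return self._sample_draft_tokens(logits)

    def _sample_draft_tokens(self, logits: torch.Tensor) -> torch.Tensor:
        # Numerical stabilisation, mirroring the x - x.amax(...) line in the diff.
        x = logits - logits.amax(dim=-1, keepdim=True)

        # Drafter temperature: taken from the TARGET metadata when available,
        # otherwise defaulting to 1.0.
        tau_d = 1.0
        if self._current_sampling_metadata is not None:
            tau_d = float(getattr(self._current_sampling_metadata, "temperature", 1.0))

        # Same additive offset as draft_q_temp_offset in the diff; the small
        # floor guards against division by zero for greedy (temperature 0) requests.
        tau_q = tau_d + self.draft_q_temp_offset
        probs = torch.softmax(x / max(tau_q, 1e-6), dim=-1)
        return torch.multinomial(probs, num_samples=1)


# Usage: the target temperature (0.7) propagates to the drafter.
drafter = DrafterSketch(draft_q_temp_offset=0.0)
tokens = drafter.propose(torch.randn(2, 32), FakeSamplingMetadata(temperature=0.7))
```

The fallback to 1.0 keeps the drafter usable when propose() has not stored any metadata yet, which is what the hasattr/None guard in the diff is protecting against.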