@@ -371,15 +371,21 @@ def _index(x: te.Tensor):  # x[:-1,:]
             b, s, d = x.shape
             return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")
 
-        hidden_states = self.model(inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask)
+        hidden_states = self.model(
+            inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask
+        )
         hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
         logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits
 
     def prefill(
-        self, inputs: Tensor, total_seq_len: tir.Var, rolling_cache_len: tir.Var, kv_seq_len: tir.Var
+        self,
+        inputs: Tensor,
+        total_seq_len: tir.Var,
+        rolling_cache_len: tir.Var,
+        kv_seq_len: tir.Var,
     ):
        """
        Prefilling the prompt.
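
For readers skimming the hunk above: the `_index` helper keeps only the last position of the hidden states before the `lm_head` projection, so logits are computed for a single token rather than the whole sequence. A minimal NumPy sketch of the same selection (illustrative only; the real code builds a TVM `te.compute` op, and the shapes below are made up):

import numpy as np

def index_last(x: np.ndarray) -> np.ndarray:
    # Mirrors te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k]):
    # keep only the final sequence position, preserving a length-1 axis.
    b, s, d = x.shape
    return x[:, s - 1 : s, :]

hidden = np.zeros((1, 7, 64), dtype="float32")  # hypothetical (batch, seq, dim)
assert index_last(hidden).shape == (1, 1, 64)
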
@@ -428,7 +434,11 @@ def _sliding_window_attention_mask(
         return self.forward(inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask)
 
     def decode(
-        self, inputs: Tensor, total_seq_len: tir.Var, rolling_cache_len: tir.Var, kv_seq_len: tir.Var
+        self,
+        inputs: Tensor,
+        total_seq_len: tir.Var,
+        rolling_cache_len: tir.Var,
+        kv_seq_len: tir.Var,
     ):
         """Decoding step."""
         batch_size, seq_len = inputs.shape
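
The second hunk sits just below `_sliding_window_attention_mask`, which by its name builds the Mistral-style mask implied by the rolling KV cache: position i may attend to position j only when j <= i (causal) and i - j is within the window. A minimal NumPy sketch under that assumption; the function name, signature, and fill constant here are hypothetical, not the diff's actual implementation:

import numpy as np

def sliding_window_mask(seq_len: int, window: int) -> np.ndarray:
    # mask[i, j] == 0 where attention is allowed: causal (j <= i) and
    # within the rolling window (i - j < window); a large negative
    # value elsewhere, so masked scores vanish after softmax.
    neg_inf = float(np.finfo("float32").min)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    allowed = (j <= i) & (i - j < window)
    return np.where(allowed, 0.0, neg_inf).astype("float32")

mask = sliding_window_mask(seq_len=5, window=3)  # row 4 masks columns 0-1
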