
Commit 0edf87b

lint fix
1 parent: 4a697ff

File tree: 2 files changed (+14, -4 lines)


python/mlc_chat/compiler/compile.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def _emit_metadata(metadata):
 
 
 def _attach_variable_bounds(mod, model_config):
-    tir_bound_map = dict()
+    tir_bound_map = {}
     tir_bound_map["seq_len"] = model_config.prefill_chunk_size
     if model_config.context_window_size != -1:
         tir_bound_map["total_seq_len"] = model_config.context_window_size

python/mlc_chat/compiler/model/mistral/mistral_model.py

Lines changed: 13 additions & 3 deletions
@@ -371,15 +371,21 @@ def _index(x: te.Tensor): # x[:-1,:]
             b, s, d = x.shape
             return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")
 
-        hidden_states = self.model(inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask)
+        hidden_states = self.model(
+            inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask
+        )
         hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
         logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits
 
     def prefill(
-        self, inputs: Tensor, total_seq_len: tir.Var, rolling_cache_len: tir.Var, kv_seq_len: tir.Var
+        self,
+        inputs: Tensor,
+        total_seq_len: tir.Var,
+        rolling_cache_len: tir.Var,
+        kv_seq_len: tir.Var,
     ):
         """
         Prefilling the prompt.
@@ -428,7 +434,11 @@ def _sliding_window_attention_mask(
         return self.forward(inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask)
 
     def decode(
-        self, inputs: Tensor, total_seq_len: tir.Var, rolling_cache_len: tir.Var, kv_seq_len: tir.Var
+        self,
+        inputs: Tensor,
+        total_seq_len: tir.Var,
+        rolling_cache_len: tir.Var,
+        kv_seq_len: tir.Var,
     ):
         """Decoding step."""
         batch_size, seq_len = inputs.shape
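The mistral_model.py hunks change no behavior either: one call and two method signatures that overran the line limit are rewrapped. The pattern (call arguments hugged onto a single indented line, one parameter per line with a trailing comma for the signatures) is what the black formatter produces at a 100-character limit; that black and that specific limit are in use is inferred from the wrap points, not stated in the diff. A sketch reproducing it, assuming black is installed:

import black

# Hypothetical snippet mirroring the over-long prefill signature above.
# Tensor and tir.Var are unresolved names here, which is fine: black only
# parses syntax, it never resolves names.
SRC = (
    "class Model:\n"
    "    def prefill(self, inputs: Tensor, total_seq_len: tir.Var, "
    "rolling_cache_len: tir.Var, kv_seq_len: tir.Var):\n"
    "        pass\n"
)

# At line_length=100 the def no longer fits, and the parameters do not fit
# on a single indented line either, so black emits one parameter per line
# with a trailing comma, matching the +384..+388 lines above.
print(black.format_str(SRC, mode=black.Mode(line_length=100)))

The trailing comma also keeps future diffs small: adding a parameter touches one line instead of rewrapping the whole signature.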
