Skip to content

Commit ba67835

Browse files
authored
Update attention layer (mlc-ai#1153)
Existing dlight optimization only works for NT matmul, but not NN. As a result, the new `nn.Module`-based implementation, which uses NN matmul, fails compilation at HEAD for now. This PR fixes this issue by tweaking `k` to the preferred layout. The following commands now work with the new compilation pipeline:

```bash
python -m mlc_chat.cli.compile --config llama2_7b --quantization q4f16_1 -o /tmp/1.so
python -m mlc_chat.cli.compile --config llama2_13b --quantization q4f16_1 -o /tmp/1.so
python -m mlc_chat.cli.compile --config llama2_70b --quantization q4f16_1 -o /tmp/1.so
```

Note that the quantization algorithm per se, `q4f16_1`, has not been implemented yet, meaning this code path is not yet ready for use so far.
1 parent 1a79a53 commit ba67835

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

python/mlc_chat/compiler/model/llama_model.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,16 @@ def forward( # pylint: disable=too-many-locals
9595

9696
self.k_cache.append(op.squeeze(k, axis=0))
9797
self.v_cache.append(op.squeeze(v, axis=0))
98-
k = op.reshape(self.k_cache.view(total_seq_len), (t, b, h_kv, d))
99-
v = op.reshape(self.v_cache.view(total_seq_len), (t, b, h_kv, d))
98+
k = op.reshape(self.k_cache.view(total_seq_len), (b, t, h_kv, d))
99+
v = op.reshape(self.v_cache.view(total_seq_len), (b, t, h_kv, d))
100100
if h_kv != h_q:
101101
k = k.repeat(h_q // h_kv, axis=2)
102102
v = v.repeat(h_q // h_kv, axis=2)
103-
attn_weights = op.matmul( # [b, h, s, t]
104-
q.permute_dims([0, 2, 1, 3]), # [b, h, s, d]
105-
k.permute_dims([1, 2, 3, 0]), # [b, h, d, t]
103+
q = q.permute_dims([0, 2, 1, 3]) # [b, h, s, d]
104+
k = k.permute_dims([0, 2, 1, 3]) # [b, h, t, d]
105+
v = v.permute_dims([0, 2, 1, 3]) # [b, h, t, d]
106+
attn_weights = op.matmul(
107+
q, k.permute_dims([0, 1, 3, 2]) # [b, h, s, d] x [b, h, d, t] = [b, h, s, t]
106108
) / math.sqrt(d)
107109
dtype = attn_weights.dtype
108110
attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask)
@@ -111,10 +113,7 @@ def forward( # pylint: disable=too-many-locals
111113
else:
112114
attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype)
113115
return self.o_proj(
114-
op.matmul( # [b, h, s, d]
115-
attn_weights, # [b, h, s, t]
116-
v.permute_dims([1, 2, 0, 3]), # [b, h, t, d]
117-
)
116+
op.matmul(attn_weights, v) # [b, h, s, t] x [b, h, t, d] = [b, h, s, d]
118117
.permute_dims([0, 2, 1, 3]) # [b, s, h, d]
119118
.reshape((b, s, h_q * d))
120119
)

python/mlc_chat/support/auto_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ def detect_config(config: Union[str, Path]) -> Path:
3535
)
3636

3737
if isinstance(config, str) and config in MODEL_PRESETS:
38+
logger.info("%s preset model: %s", FOUND, config)
3839
content = MODEL_PRESETS[config]
3940
temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with
4041
suffix=".json",
4142
delete=False,
4243
)
43-
logger.info("%s preset model configuration: %s", FOUND, temp_file.name)
44+
logger.info("Dumping config to: %s", temp_file.name)
4445
config_path = Path(temp_file.name)
4546
with config_path.open("w", encoding="utf-8") as config_file:
4647
json.dump(content, config_file, indent=2)

0 commit comments

Comments (0)