@@ -2839,8 +2839,8 @@ static void llm_load_tensors(
             auto & layer = model.layers[i];
 
             layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
-            layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
-            layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+            layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+            layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
 
             layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 
@@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;  // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
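For context: the new `wqkv` shape follows from how the fused QKV projection is laid out under grouped-query attention. The Q block keeps `n_embd` columns, while K and V each shrink to `n_embd_gqa` columns, which (going by the surrounding hparams code) is `n_embd_head * n_head_kv`. When `n_head_kv == n_head` this reduces to the previous `3*n_embd`, so non-GQA MPT models load exactly as before. A minimal standalone sketch of that arithmetic, with hypothetical dimensions not taken from the patch:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical dimensions, only to illustrate the shape arithmetic.
    const int64_t n_embd      = 4096;            // model width
    const int64_t n_head      = 32;              // query heads
    const int64_t n_head_kv   = 8;               // KV heads (GQA); equals n_head for plain MHA
    const int64_t n_embd_head = n_embd / n_head; // per-head width
    const int64_t n_embd_gqa  = n_embd_head * n_head_kv;

    // Columns of the fused QKV weight: n_embd for Q, plus n_embd_gqa each for K and V.
    const int64_t qkv_cols = n_embd + 2*n_embd_gqa;

    // Without GQA (n_head_kv == n_head) this is exactly the old 3*n_embd layout.
    assert(n_head_kv != n_head || qkv_cols == 3*n_embd);
    return 0;
}
```

Dropping the `// == n_head for MPT` comment in `llm_build_mpt` is consistent with this change: `n_head_kv` may now legitimately differ from `n_head` for MPT-style checkpoints.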