
metal : improve FA + improve MoE #12612


Merged: 10 commits, Mar 28, 2025
ggml/include/ggml.h: 10 changes (5 additions, 5 deletions)

@@ -1791,11 +1791,11 @@ extern "C" {

#define GGML_KQ_MASK_PAD 64

// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
// q: [n_embd_k, n_batch, n_head, 1]
// k: [n_embd_k, n_kv, n_head_kv, 1]
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
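For context, a minimal caller-side sketch of the new shape convention, in which the Q/K head size (DK) and the V head size (DV) no longer have to match. The function name, head sizes, head counts and the trailing scale/max_bias/logit_softcap arguments are illustrative assumptions, not code from this PR:

#include <math.h>
#include "ggml.h"

// hypothetical example: build a flash-attention node where K and V have different head sizes
static struct ggml_tensor * build_fattn_example(struct ggml_context * ctx, int n_batch, int n_kv) {
    const int DK = 192, DV = 128;          // illustrative head sizes, DK != DV
    const int n_head = 16, n_head_kv = 4;  // illustrative head counts (GQA)

    // shapes follow the updated comment on ggml_flash_attn_ext
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, DK, n_batch, n_head,    1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DK, n_kv,    n_head_kv, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DV, n_kv,    n_head_kv, 1);

    // the mask is padded in the batch dimension to a multiple of GGML_KQ_MASK_PAD
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), 1, 1);

    // the result has shape [DV, n_head, n_batch, 1] (permuted): only V's head size survives
    return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) DK), 0.0f, 0.0f);
}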
ggml/src/ggml-cpu/ggml-cpu.c: 50 changes (25 additions, 25 deletions)

@@ -12238,23 +12238,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int ith = params->ith;
const int nth = params->nth;

const int64_t D = neq0;
const int64_t N = neq1;
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;

GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);

// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));

GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);

GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16

if (v->type == GGML_TYPE_F16) {
memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
} else {
memset(VKQ32, 0, D*sizeof(float));
memset(VKQ32, 0, DV*sizeof(float));
}

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv2 = iq2 / rv2;

const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, D);
q_to_vec_dot(pq, Q_q, DK);

// online softmax / attention
// loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float s; // KQ value

const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);

s = s*scale; // scale KQ value

@@ -12380,45 +12380,45 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms);
ggml_vec_scale_f16(DV, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms);
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

v_to_float(v_data, V32, D);
v_to_float(v_data, V32, DV);

// V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs);
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
}

S = S*ms + vs; // scale and increment sum with partial sum
}

if (v->type == GGML_TYPE_F16) {
for (int64_t d = 0; d < D; ++d) {
for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}

// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(D, VKQ32, S_inv);
ggml_vec_scale_f32(DV, VKQ32, S_inv);

// dst indices
const int i1 = iq1;
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
size_t cur = 0;

if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {

switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D
const int64_t ne10 = node->src[1]->ne[0]; // DK
const int64_t ne20 = node->src[2]->ne[0]; // DV

cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
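The per-thread scratch in ggml_compute_forward_flash_attn_ext_f16 is now one DV-sized FP32 accumulator (VKQ32), one DV-sized temporary shared by the FP32 V buffer and the FP16 accumulator, and room for the query converted to the vec-dot type (sized by DK), matching the updated estimate in ggml_graph_plan of sizeof(float)*(DK + 2*DV) per thread instead of 3*D. With the illustrative DK = 192 and DV = 128, that is 4*(192 + 2*128) = 1792 bytes per thread before cache-line padding. For reference, here is a self-contained restatement of the online-softmax loop with separate K and V head sizes; the plain-float layout, the simple additive mask and the helper name flash_attn_row are assumptions for clarity, while the actual kernel works on ggml types through the vec_dot/vec_mad helpers:

#include <math.h>
#include <string.h>

// illustrative sketch, not the ggml kernel: single-query flash attention with DK != DV
static void flash_attn_row(
        const float * q,     // [DK]        query
        const float * k,     // [n_kv][DK]  keys
        const float * v,     // [n_kv][DV]  values
        const float * mask,  // [n_kv]      additive mask, or NULL
              float * out,   // [DV]        result
        int n_kv, int DK, int DV, float scale) {
    float S = 0.0f;      // running sum of exp(s - M)
    float M = -INFINITY; // running maximum of the scaled KQ values

    memset(out, 0, DV*sizeof(float));

    for (int ic = 0; ic < n_kv; ++ic) {
        const float mv = mask ? mask[ic] : 0.0f;
        if (mv == -INFINITY) {
            continue; // fully masked-out position
        }

        // s = scale*dot(q, k[ic]) + mask
        float s = 0.0f;
        for (int d = 0; d < DK; ++d) {
            s += k[ic*DK + d]*q[d];
        }
        s = s*scale + mv;

        const float Mold = M;

        float ms = 1.0f; // rescale factor for the accumulator
        float vs = 1.0f; // weight of the current value row

        if (s > M) {
            // new maximum: rescale what has been accumulated so far
            M  = s;
            ms = expf(Mold - M);
            for (int d = 0; d < DV; ++d) {
                out[d] *= ms; // V = V*exp(Mold - M)
            }
        } else {
            vs = expf(s - M);
        }

        for (int d = 0; d < DV; ++d) {
            out[d] += v[ic*DV + d]*vs; // V += v*exp(s - M)
        }

        S = S*ms + vs;
    }

    // final normalization: V /= S
    for (int d = 0; d < DV; ++d) {
        out[d] /= S;
    }
}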
ggml/src/ggml-cuda/ggml-cuda.cu: 7 changes (7 additions, 0 deletions)

@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[3] != 1) {
return false;
}
ggml/src/ggml-metal/ggml-metal-impl.h: 9 changes (6 additions, 3 deletions)

@@ -219,9 +219,12 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
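The flash-attention kernel arguments previously carried a single set of strides (nb_12_*) under the assumption that K and V have the same shape; they are now split into per-tensor strides for K (nb11/nb12/nb13) and V (nb21/nb22/nb23), which is what allows different K and V head sizes on Metal. A hedged host-side sketch of how these fields can be filled straight from each tensor's nb[] array; the struct and helper names are assumptions, and the corresponding ggml-metal.m changes are not shown in this hunk:

#include <stdint.h>
#include "ggml.h"

// hypothetical helper, not code from this PR
typedef struct {
    // ... preceding fields as in the struct above ...
    uint64_t nb11, nb12, nb13; // K byte strides
    uint64_t nb21, nb22, nb23; // V byte strides
    uint64_t nb31;             // mask byte stride (unchanged)
} fattn_args_sketch;

static void fattn_set_strides(
        fattn_args_sketch * args,
        const struct ggml_tensor * k,
        const struct ggml_tensor * v,
        const struct ggml_tensor * mask) {
    args->nb11 = k->nb[1]; args->nb12 = k->nb[2]; args->nb13 = k->nb[3];
    args->nb21 = v->nb[1]; args->nb22 = v->nb[2]; args->nb23 = v->nb[3];
    args->nb31 = mask ? mask->nb[1] : 0;
}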