Commit a2f79cc

yael-works authored and Gitty Burstein committed
Cleanup: remove old SparseK operator; keep only mask augmentation
1 parent 422c65a commit a2f79cc

6 files changed, +8 -257 lines changed
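
The commit message says the dedicated SPARSEK_ATTN operator goes away and only "mask augmentation" is kept. That mask-building code is not part of this diff, but the idea is to express the same sparsity pattern as an additive attention mask (0.0f for kept positions, -INFINITY for dropped ones) that the existing attention path consumes. Below is a minimal, hypothetical sketch of one such mask row, using the win_local/stride_global parameters the old operator exposed; the function name and layout are illustrative, not part of the ggml API.

// Hypothetical sketch only -- not code from this commit. Builds one additive
// mask row: 0.0f where attention is allowed, -INFINITY where it is masked out.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float> sparsek_mask_row(int64_t n_kv, int64_t iq,
                                           int32_t win_local, int32_t stride_global) {
    std::vector<float> mask(n_kv, -INFINITY);

    // local window around the query position iq
    const int64_t lo = std::max<int64_t>(0, iq - win_local);
    const int64_t hi = std::min<int64_t>(n_kv - 1, iq + win_local);
    for (int64_t j = lo; j <= hi; ++j) mask[j] = 0.0f;

    // strided "global" positions across the whole context
    if (stride_global > 0) {
        for (int64_t j = 0; j < n_kv; j += stride_global) mask[j] = 0.0f;
    }
    return mask;
}

With a mask of this shape fed to the regular attention path, the sparsity can be reproduced without a separate operator, which is presumably why everything below can be deleted.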

ggml/include/ggml.h

Lines changed: 1 addition & 21 deletions
@@ -530,7 +530,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
-        GGML_OP_SPARSEK_ATTN,
+
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
@@ -2233,26 +2233,6 @@ extern "C" {
     // n_head % ne32 == 0
     // ne3 % ne33 == 0
     //
-
-    GGML_API struct ggml_tensor * ggml_sparsek_attn(
-            struct ggml_context * ctx,
-            struct ggml_tensor * Q,
-            struct ggml_tensor * K,
-            struct ggml_tensor * V,
-            int32_t k_top,
-            int32_t win_local,
-            int32_t stride_global);
-
-    GGML_API void ggml_sparsek_attn_set_params(
-            struct ggml_tensor * a,
-            int32_t k_top,
-            int32_t win_local,
-            int32_t stride_global);
-
-    GGML_API int32_t ggml_sparsek_attn_get_param(
-            const struct ggml_tensor * a,
-            int index);
-
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor * q,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 4 deletions
@@ -1947,10 +1947,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
-        case GGML_OP_SPARSEK_ATTN:
-            {
-                ggml_compute_forward_sparsek_attn(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);

ggml/src/ggml-cpu/ops.cpp

Lines changed: 0 additions & 121 deletions
@@ -9,7 +9,6 @@
 
 #include <float.h>
 #include <algorithm>
-#include <vector>
 
 // ggml_compute_forward_dup
 
@@ -7971,126 +7970,6 @@ void ggml_compute_forward_argsort(
     }
 }
 
-//------------------------------------------------------------------------------
-// SparseK Attention (CPU)
-//------------------------------------------------------------------------------
-
-static void ggml_compute_forward_sparsek_attn_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    // Single-threaded baseline version
-    if (params->ith != 0) return;
-
-    const struct ggml_tensor * Q = dst->src[0];
-    const struct ggml_tensor * K = dst->src[1];
-    const struct ggml_tensor * V = dst->src[2];
-
-    GGML_ASSERT(Q && K && V);
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F32);
-    GGML_ASSERT(V->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
-    const int32_t win_local  = ggml_get_op_params_i32(dst, 1);
-    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);
-    GGML_UNUSED(win_local);
-    GGML_UNUSED(stride_glb);
-
-    // Tensor dimensions according to GGML layout: ne[0]=d, ne[1]=seq, ne[2]=head, ne[3]=batch
-    const int64_t D = Q->ne[0];
-    const int64_t T = Q->ne[1];
-    const int64_t H = Q->ne[2];
-    const int64_t B = Q->ne[3];
-
-    // Temporary buffer for attention scores for one query row
-    std::vector<float> attn_row(T, 0.0f);
-
-    const float scale = 1.0f / sqrtf((float) D);
-
-    // Loops over batch, head, and query token
-    for (int64_t b = 0; b < B; ++b) {
-        for (int64_t h = 0; h < H; ++h) {
-            for (int64_t iq = 0; iq < T; ++iq) {
-
-                // (1) Compute dot products Q·K within same (b,h)
-                const char * qbase = (const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1];
-                const float * qv = (const float *) qbase;
-
-                for (int64_t j = 0; j < T; ++j) {
-                    const char * kbase = (const char *) K->data + b*K->nb[3] + h*K->nb[2] + j*K->nb[1];
-                    const float * kv = (const float *) kbase;
-
-                    float dot = 0.0f;
-                    for (int64_t d = 0; d < D; ++d) {
-                        dot += qv[d] * kv[d];
-                    }
-                    attn_row[j] = dot * scale;
-                }
-
-                // (2) Select top-k threshold using nth_element
-                const int kk = std::max<int>(1, std::min<int>((int)T, k_top));
-                std::vector<float> tmp(attn_row.begin(), attn_row.end());
-                std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
-                const float thr = tmp[kk - 1];
-
-                for (int64_t j = 0; j < T; ++j) {
-                    if (attn_row[j] < thr) attn_row[j] = -INFINITY;
-                }
-
-                // (3) Numerically stable softmax on the masked row
-                float maxv = -INFINITY;
-                for (int64_t j = 0; j < T; ++j) {
-                    maxv = std::max(maxv, attn_row[j]);
-                }
-                float sum = 0.0f;
-                for (int64_t j = 0; j < T; ++j) {
-                    float v = attn_row[j] - maxv;
-                    float e = expf(v);
-                    attn_row[j] = e;
-                    sum += e;
-                }
-                const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
-                for (int64_t j = 0; j < T; ++j) {
-                    attn_row[j] *= inv_sum;
-                }
-
-                // (4) Compute output = A·V (weighted sum)
-                float * y = (float *) ((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-
-                for (int64_t d = 0; d < D; ++d) {
-                    float acc = 0.0f;
-                    for (int64_t j = 0; j < T; ++j) {
-                        const float aij = attn_row[j];
-                        if (aij == 0.0f) continue; // skip masked entries
-                        const char * vbase = (const char *) V->data + b*V->nb[3] + h*V->nb[2] + j*V->nb[1];
-                        const float * vv = (const float *) vbase;
-                        acc += aij * vv[d];
-                    }
-                    y[d] = acc;
-                }
-            }
-        }
-    }
-
-    GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
-                     k_top, win_local, stride_glb);
-}
-
-void ggml_compute_forward_sparsek_attn(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F32:
-            ggml_compute_forward_sparsek_attn_f32(params, dst);
-            break;
-        default:
-            GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
-    }
-}
-
-
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
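
For reference, the per-row math that the deleted kernel implemented, condensed into a standalone sketch over plain float arrays. The function and variable names are illustrative, not ggml API, and it assumes equal Q/K/V head sizes for brevity.

// Reference sketch of the deleted operator's per-row computation (illustrative only).
// q: [d], k and v: [t][d] row-major, out: [d]
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <vector>

static void sparsek_row_reference(const float * q, const float * k, const float * v,
                                  int64_t d, int64_t t, int32_t k_top, float * out) {
    const float scale = 1.0f / sqrtf((float) d);

    // (1) scaled dot products q . k_j
    std::vector<float> s(t);
    for (int64_t j = 0; j < t; ++j) {
        float dot = 0.0f;
        for (int64_t c = 0; c < d; ++c) dot += q[c] * k[j*d + c];
        s[j] = dot * scale;
    }

    // (2) top-k threshold via nth_element, mask everything below it
    const int kk = std::max(1, std::min<int>((int) t, k_top));
    std::vector<float> tmp(s);
    std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
    const float thr = tmp[kk - 1];
    for (int64_t j = 0; j < t; ++j) if (s[j] < thr) s[j] = -INFINITY;

    // (3) numerically stable softmax over the masked row
    const float maxv = *std::max_element(s.begin(), s.end());
    float sum = 0.0f;
    for (int64_t j = 0; j < t; ++j) { s[j] = expf(s[j] - maxv); sum += s[j]; }
    const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;

    // (4) weighted sum over V
    for (int64_t c = 0; c < d; ++c) {
        float acc = 0.0f;
        for (int64_t j = 0; j < t; ++j) acc += (s[j] * inv_sum) * v[j*d + c];
        out[c] = acc;
    }
}

Note that the deleted ggml version also carried win_local and stride_global parameters, but the CPU kernel left them unused (see the GGML_UNUSED calls in the diff above).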

ggml/src/ggml-cpu/ops.h

Lines changed: 0 additions & 2 deletions
@@ -86,8 +86,6 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params *
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,

ggml/src/ggml.c

Lines changed: 4 additions & 50 deletions
@@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
-    "SPARSEK_ATTN",
+
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
-    "sparsek_attn(x)",
+
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5063,52 +5063,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
-// ggml_sparsek_attn
-struct ggml_tensor * ggml_sparsek_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor * Q,
-        struct ggml_tensor * K,
-        struct ggml_tensor * V,
-        int32_t k_top,
-        int32_t win_local,
-        int32_t stride_global) {
-
-    GGML_ASSERT(ggml_can_mul_mat(K, Q));
-    GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]);
-
-    int64_t ne[4] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-
-    int32_t params_i32[3] = { k_top, win_local, stride_global };
-    ggml_set_op_params(result, params_i32, sizeof(params_i32));
-
-    result->op = GGML_OP_SPARSEK_ATTN;
-    result->src[0] = Q;
-    result->src[1] = K;
-    result->src[2] = V;
-
-    return result;
-}
-
-
-void ggml_sparsek_attn_set_params(struct ggml_tensor * a,
-                                  int32_t k_top,
-                                  int32_t win_local,
-                                  int32_t stride_global) {
-    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
-    ggml_set_op_params_i32(a, 0, k_top);
-    ggml_set_op_params_i32(a, 1, win_local);
-    ggml_set_op_params_i32(a, 2, stride_global);
-}
-
-int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) {
-    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
-    return ggml_get_op_params_i32(a, index);
-}
-
-
-
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
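
A note for context (not part of the diff): GGML_OP_NAME and GGML_OP_SYMBOL are string tables indexed by the ggml_op enum in ggml.h, so removing GGML_OP_SPARSEK_ATTN from the enum has to be mirrored by removing its entries here, and both static_asserts drop from 91 to 90. The assert is the guard that forces anyone who adds or removes an op to revisit these tables. A minimal sketch of the same pattern, with made-up names:

// Illustrative only (names are made up): string table indexed by an op enum,
// with a static_assert that must be updated whenever the enum changes size.
enum demo_op { DEMO_OP_ADD, DEMO_OP_MUL, DEMO_OP_COUNT };
static const char * DEMO_OP_NAME[DEMO_OP_COUNT] = { "ADD", "MUL" };
static_assert(DEMO_OP_COUNT == 2, "DEMO_OP_COUNT != 2");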

tests/test-backend-ops.cpp

Lines changed: 3 additions & 59 deletions
@@ -1780,7 +1780,6 @@ struct test_example : public test_case {
 };
 
 
-
 // GGML_OP_UNARY
 struct test_unary : public test_case {
     const ggml_unary_op op;
@@ -5558,46 +5557,7 @@ struct test_leaky_relu : public test_case {
     }
 };
 
-// GGML_OP_SPARSEK_ATTN
-struct test_sparsek_attn : public test_case {
-    const int64_t d_qk;
-    const int64_t d_v;
-    const int64_t n_head;
-    const int64_t n_tokens;
-    const int64_t batch;
-    const int32_t k_top;
-    const int32_t win_local;
-    const int32_t stride_global;
-
-    std::string vars() override {
-        return VARS_TO_STR9(d_qk, d_v, n_head, n_tokens, batch, k_top, win_local, stride_global, 0);
-    }
-
-    test_sparsek_attn(int64_t d_qk = 128, int64_t d_v = 128, int64_t n_head = 8,
-                      int64_t n_tokens = 256, int64_t batch = 4,
-                      int32_t k_top = 32, int32_t win_local = 64, int32_t stride_global = 128)
-        : d_qk(d_qk), d_v(d_v), n_head(n_head), n_tokens(n_tokens), batch(batch),
-          k_top(k_top), win_local(win_local), stride_global(stride_global) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t n_q = n_tokens;
-        ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_q, n_head, batch);
-        ggml_set_name(Q, "Q");
-        ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, batch);
-        ggml_set_name(K, "K");
-        ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v, n_tokens, n_head, batch);
-        ggml_set_name(V, "V");
-
-        ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, k_top, win_local, stride_global);
-        ggml_set_name(out, "SPARSEK_ATTN_out");
-
-        return out;
-    }
-};
-
-
-
-// GGML_OP_FLAsH_ATTN_EXT
+// GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
     const int64_t hsk; // K head size
     const int64_t hsv; // V head size
@@ -7381,7 +7341,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
                 if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
                 if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
-
+
                 for (bool mask : { true, false } ) {
                     for (bool sinks : { true, false } ) {
                         for (float max_bias : { 0.0f, 8.0f }) {
@@ -7420,23 +7380,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             }
         }
     }
-    // ---- SPARSEK_ATTN --------------------------------------------------
-    for (int64_t d_qk : {64, 128}) {
-        for (int64_t d_v : {64, 128}) {
-            for (int64_t n_head : {4, 8}) {
-                for (int64_t kv : {113, 512}) {
-                    for (int64_t b : {1, 4}) {
-                        for (int32_t k_top : {16, 32}) {
-                            for (int32_t win_local : {32, 64}) {
-                                test_cases.emplace_back(new test_sparsek_attn(
-                                    d_qk, d_v, n_head, kv, b, k_top, win_local, /*stride_global*/128));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
 
     test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
     test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
@@ -7497,6 +7440,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 // Test cases for performance evaluation: should be representative of real-world use cases
 static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     std::vector<std::unique_ptr<test_case>> test_cases;
+
     // Conv2d: K=CRS=NPQ=4096 matmul performance
     uint32_t iwh_idx = 0;
     uint32_t kwh_idx = 1;