|
9 | 9 |
|
10 | 10 | #include <float.h> |
11 | 11 | #include <algorithm> |
12 | | -#include <vector> |
13 | 12 |
|
14 | 13 | // ggml_compute_forward_dup |
15 | 14 |
|
@@ -7971,126 +7970,6 @@ void ggml_compute_forward_argsort( |
7971 | 7970 | } |
7972 | 7971 | } |
7973 | 7972 |
|
7974 | | -//------------------------------------------------------------------------------ |
7975 | | -// SparseK Attention (CPU) |
7976 | | -//------------------------------------------------------------------------------ |
7977 | | - |
7978 | | -static void ggml_compute_forward_sparsek_attn_f32( |
7979 | | - const struct ggml_compute_params * params, |
7980 | | - struct ggml_tensor * dst) { |
7981 | | - |
7982 | | - // Single-threaded baseline version |
7983 | | - if (params->ith != 0) return; |
7984 | | - |
7985 | | - const struct ggml_tensor * Q = dst->src[0]; |
7986 | | - const struct ggml_tensor * K = dst->src[1]; |
7987 | | - const struct ggml_tensor * V = dst->src[2]; |
7988 | | - |
7989 | | - GGML_ASSERT(Q && K && V); |
7990 | | - GGML_ASSERT(Q->type == GGML_TYPE_F32); |
7991 | | - GGML_ASSERT(K->type == GGML_TYPE_F32); |
7992 | | - GGML_ASSERT(V->type == GGML_TYPE_F32); |
7993 | | - GGML_ASSERT(dst->type == GGML_TYPE_F32); |
7994 | | - |
7995 | | - const int32_t k_top = ggml_get_op_params_i32(dst, 0); |
7996 | | - const int32_t win_local = ggml_get_op_params_i32(dst, 1); |
7997 | | - const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); |
7998 | | - GGML_UNUSED(win_local); |
7999 | | - GGML_UNUSED(stride_glb); |
8000 | | - |
8001 | | - // Tensor dimensions according to GGML layout: ne[0]=d, ne[1]=seq, ne[2]=head, ne[3]=batch |
8002 | | - const int64_t D = Q->ne[0]; |
8003 | | - const int64_t T = Q->ne[1]; |
8004 | | - const int64_t H = Q->ne[2]; |
8005 | | - const int64_t B = Q->ne[3]; |
8006 | | - |
8007 | | - // Temporary buffer for attention scores for one query row |
8008 | | - std::vector<float> attn_row(T, 0.0f); |
8009 | | - |
8010 | | - const float scale = 1.0f / sqrtf((float) D); |
8011 | | - |
8012 | | - // Loops over batch, head, and query token |
8013 | | - for (int64_t b = 0; b < B; ++b) { |
8014 | | - for (int64_t h = 0; h < H; ++h) { |
8015 | | - for (int64_t iq = 0; iq < T; ++iq) { |
8016 | | - |
8017 | | - // (1) Compute dot products Q·K within same (b,h) |
8018 | | - const char * qbase = (const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]; |
8019 | | - const float * qv = (const float *) qbase; |
8020 | | - |
8021 | | - for (int64_t j = 0; j < T; ++j) { |
8022 | | - const char * kbase = (const char *) K->data + b*K->nb[3] + h*K->nb[2] + j*K->nb[1]; |
8023 | | - const float * kv = (const float *) kbase; |
8024 | | - |
8025 | | - float dot = 0.0f; |
8026 | | - for (int64_t d = 0; d < D; ++d) { |
8027 | | - dot += qv[d] * kv[d]; |
8028 | | - } |
8029 | | - attn_row[j] = dot * scale; |
8030 | | - } |
8031 | | - |
8032 | | - // (2) Select top-k threshold using nth_element |
8033 | | - const int kk = std::max<int>(1, std::min<int>((int)T, k_top)); |
8034 | | - std::vector<float> tmp(attn_row.begin(), attn_row.end()); |
8035 | | - std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>()); |
8036 | | - const float thr = tmp[kk - 1]; |
8037 | | - |
8038 | | - for (int64_t j = 0; j < T; ++j) { |
8039 | | - if (attn_row[j] < thr) attn_row[j] = -INFINITY; |
8040 | | - } |
8041 | | - |
8042 | | - // (3) Numerically stable softmax on the masked row |
8043 | | - float maxv = -INFINITY; |
8044 | | - for (int64_t j = 0; j < T; ++j) { |
8045 | | - maxv = std::max(maxv, attn_row[j]); |
8046 | | - } |
8047 | | - float sum = 0.0f; |
8048 | | - for (int64_t j = 0; j < T; ++j) { |
8049 | | - float v = attn_row[j] - maxv; |
8050 | | - float e = expf(v); |
8051 | | - attn_row[j] = e; |
8052 | | - sum += e; |
8053 | | - } |
8054 | | - const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f; |
8055 | | - for (int64_t j = 0; j < T; ++j) { |
8056 | | - attn_row[j] *= inv_sum; |
8057 | | - } |
8058 | | - |
8059 | | - // (4) Compute output = A·V (weighted sum) |
8060 | | - float * y = (float *) ((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]); |
8061 | | - |
8062 | | - for (int64_t d = 0; d < D; ++d) { |
8063 | | - float acc = 0.0f; |
8064 | | - for (int64_t j = 0; j < T; ++j) { |
8065 | | - const float aij = attn_row[j]; |
8066 | | - if (aij == 0.0f) continue; // skip masked entries |
8067 | | - const char * vbase = (const char *) V->data + b*V->nb[3] + h*V->nb[2] + j*V->nb[1]; |
8068 | | - const float * vv = (const float *) vbase; |
8069 | | - acc += aij * vv[d]; |
8070 | | - } |
8071 | | - y[d] = acc; |
8072 | | - } |
8073 | | - } |
8074 | | - } |
8075 | | - } |
8076 | | - |
8077 | | - GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n", |
8078 | | - k_top, win_local, stride_glb); |
8079 | | -} |
8080 | | - |
8081 | | -void ggml_compute_forward_sparsek_attn( |
8082 | | - const struct ggml_compute_params * params, |
8083 | | - struct ggml_tensor * dst) { |
8084 | | - switch (dst->type) { |
8085 | | - case GGML_TYPE_F32: |
8086 | | - ggml_compute_forward_sparsek_attn_f32(params, dst); |
8087 | | - break; |
8088 | | - default: |
8089 | | - GGML_ASSERT(false && "sparsek_attn: unsupported dst type"); |
8090 | | - } |
8091 | | -} |
8092 | | - |
8093 | | - |
8094 | 7973 | // ggml_compute_forward_flash_attn_ext |
8095 | 7974 |
|
8096 | 7975 | static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( |
|
0 commit comments