Skip to content

Commit 7c39f2d

Browse files
committed
ggml: rwkv_wkv op CUDA impl
Signed-off-by: Molly Sophia <[email protected]>
1 parent 0b174ab commit 7c39f2d

File tree

4 files changed

+134
-0
lines changed

4 files changed

+134
-0
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "ggml-cuda/tsembd.cuh"
3333
#include "ggml-cuda/unary.cuh"
3434
#include "ggml-cuda/upscale.cuh"
35+
#include "ggml-cuda/rwkv-wkv.cuh"
3536

3637
#include <algorithm>
3738
#include <array>
@@ -2327,6 +2328,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23272328
case GGML_OP_CROSS_ENTROPY_LOSS:
23282329
ggml_cuda_cross_entropy_loss(ctx, dst);
23292330
break;
2331+
case GGML_OP_RWKV_WKV:
2332+
ggml_cuda_op_rwkv_wkv(ctx, dst);
2333+
break;
23302334
default:
23312335
return false;
23322336
}
@@ -2926,6 +2930,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
29262930
case GGML_OP_ARANGE:
29272931
case GGML_OP_TIMESTEP_EMBEDDING:
29282932
case GGML_OP_LEAKY_RELU:
2933+
case GGML_OP_RWKV_WKV:
29292934
return true;
29302935
case GGML_OP_FLASH_ATTN_EXT:
29312936
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

ggml/src/ggml-cuda/rwkv-wkv.cu

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include "common.cuh"
2+
#include "rwkv-wkv.cuh"
3+
4+
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
5+
const int tid = threadIdx.x;
6+
const int bid = blockIdx.x;
7+
8+
const int head_size = CUDA_WKV_BLOCK_SIZE;
9+
const int batch_i = bid / H;
10+
const int head_i = bid % H;
11+
const int state_size = C * head_size;
12+
const int n_seq_tokens = T / B;
13+
14+
float state[head_size];
15+
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
16+
17+
#pragma unroll
18+
for (int i = 0; i < head_size; i++) {
19+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
20+
}
21+
22+
__syncthreads();
23+
_tf[tid] = tf[head_i * head_size + tid];
24+
__syncthreads();
25+
26+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
27+
__syncthreads();
28+
_k[tid] = k[t];
29+
_r[tid] = r[t];
30+
_td[tid] = td[t];
31+
__syncthreads();
32+
33+
const float _v = v[t];
34+
float y = 0;
35+
for (int j = 0; j < head_size; j += 4) {
36+
const float4& k = (float4&)(_k[j]);
37+
const float4& r = (float4&)(_r[j]);
38+
const float4& tf = (float4&)(_tf[j]);
39+
const float4& td = (float4&)(_td[j]);
40+
float4& s = (float4&)(state[j]);
41+
float4 kv;
42+
43+
kv.x = k.x * _v;
44+
kv.y = k.y * _v;
45+
kv.z = k.z * _v;
46+
kv.w = k.w * _v;
47+
48+
y += r.x * (tf.x * kv.x + s.x);
49+
y += r.y * (tf.y * kv.y + s.y);
50+
y += r.z * (tf.z * kv.z + s.z);
51+
y += r.w * (tf.w * kv.w + s.w);
52+
53+
s.x = s.x * td.x + kv.x;
54+
s.y = s.y * td.y + kv.y;
55+
s.z = s.z * td.z + kv.z;
56+
s.w = s.w * td.w + kv.w;
57+
}
58+
dst[t] = y;
59+
}
60+
61+
#pragma unroll
62+
for (int i = 0; i < head_size; i++) {
63+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
64+
}
65+
}
66+
67+
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
68+
const float * k_d = (const float *)dst->src[0]->data;
69+
const float * v_d = (const float *)dst->src[1]->data;
70+
const float * r_d = (const float *)dst->src[2]->data;
71+
const float * tf_d = (const float *)dst->src[3]->data;
72+
const float * td_d = (const float *)dst->src[4]->data;
73+
const float * s_d = (const float *)dst->src[5]->data;
74+
75+
const int64_t B = dst->src[5]->ne[1];
76+
const int64_t T = dst->src[0]->ne[3];
77+
const int64_t C = dst->ne[0];
78+
const int64_t H = dst->src[0]->ne[2];
79+
80+
float * dst_d = (float *)dst->data;
81+
82+
cudaStream_t stream = ctx.stream();
83+
84+
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
85+
GGML_ASSERT(C % H == 0);
86+
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
87+
88+
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
89+
}

ggml/src/ggml-cuda/rwkv-wkv.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_WKV_BLOCK_SIZE 64
4+
5+
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,36 @@ struct test_ssm_scan : public test_case {
15531553
}
15541554
};
15551555

1556+
// GGML_OP_RWKV_WKV
1557+
struct test_rwkv_wkv : public test_case {
1558+
const ggml_type type;
1559+
1560+
const int64_t head_count;
1561+
const int64_t head_size;
1562+
const int64_t n_seq_tokens;
1563+
const int64_t n_seqs;
1564+
1565+
std::string vars() override {
1566+
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
1567+
}
1568+
1569+
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
1570+
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
1571+
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1572+
1573+
ggml_tensor * build_graph(ggml_context * ctx) override {
1574+
const int64_t n_tokens = n_seq_tokens * n_seqs;
1575+
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1576+
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
1577+
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1578+
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
1579+
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1580+
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
1581+
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
1582+
return out;
1583+
}
1584+
};
1585+
15561586
// GGML_OP_MUL_MAT
15571587
struct test_mul_mat : public test_case {
15581588
const ggml_type type_a;
@@ -3257,6 +3287,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
32573287

32583288
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
32593289

3290+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
3291+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
3292+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
3293+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
3294+
32603295
#if 1
32613296
for (ggml_type type_a : base_types) {
32623297
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {

0 commit comments

Comments
 (0)