Skip to content

Commit 056be43

Browse files
MollySophiaarthw
authored andcommitted
RWKV v6: RWKV_WKV op CUDA implementation (ggml-org#9454)
* ggml: CUDA unary op EXP Signed-off-by: Molly Sophia <[email protected]> * ggml: rwkv_wkv op CUDA impl Signed-off-by: Molly Sophia <[email protected]> --------- Signed-off-by: Molly Sophia <[email protected]>
1 parent 6b0d342 commit 056be43

File tree

6 files changed

+168
-0
lines changed

6 files changed

+168
-0
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "ggml-cuda/tsembd.cuh"
3535
#include "ggml-cuda/unary.cuh"
3636
#include "ggml-cuda/upscale.cuh"
37+
#include "ggml-cuda/rwkv-wkv.cuh"
3738

3839
#include <algorithm>
3940
#include <array>
@@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22432244
case GGML_UNARY_OP_HARDSWISH:
22442245
ggml_cuda_op_hardswish(ctx, dst);
22452246
break;
2247+
case GGML_UNARY_OP_EXP:
2248+
ggml_cuda_op_exp(ctx, dst);
2249+
break;
22462250
default:
22472251
return false;
22482252
}
@@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23452349
case GGML_OP_CROSS_ENTROPY_LOSS:
23462350
ggml_cuda_cross_entropy_loss(ctx, dst);
23472351
break;
2352+
case GGML_OP_RWKV_WKV:
2353+
ggml_cuda_op_rwkv_wkv(ctx, dst);
23482354
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
23492355
ggml_cuda_cross_entropy_loss_back(ctx, dst);
23502356
break;
@@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28062812
case GGML_UNARY_OP_HARDSWISH:
28072813
case GGML_UNARY_OP_GELU_QUICK:
28082814
case GGML_UNARY_OP_TANH:
2815+
case GGML_UNARY_OP_EXP:
28092816
return ggml_is_contiguous(op->src[0]);
28102817
default:
28112818
return false;
@@ -2967,6 +2974,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
29672974
case GGML_OP_ARANGE:
29682975
case GGML_OP_TIMESTEP_EMBEDDING:
29692976
case GGML_OP_LEAKY_RELU:
2977+
case GGML_OP_RWKV_WKV:
29702978
return true;
29712979
case GGML_OP_FLASH_ATTN_EXT:
29722980
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

ggml/src/ggml-cuda/rwkv-wkv.cu

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include "common.cuh"
2+
#include "rwkv-wkv.cuh"
3+
4+
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
5+
const int tid = threadIdx.x;
6+
const int bid = blockIdx.x;
7+
8+
const int head_size = CUDA_WKV_BLOCK_SIZE;
9+
const int batch_i = bid / H;
10+
const int head_i = bid % H;
11+
const int state_size = C * head_size;
12+
const int n_seq_tokens = T / B;
13+
14+
float state[head_size];
15+
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
16+
17+
#pragma unroll
18+
for (int i = 0; i < head_size; i++) {
19+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
20+
}
21+
22+
__syncthreads();
23+
_tf[tid] = tf[head_i * head_size + tid];
24+
__syncthreads();
25+
26+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
27+
__syncthreads();
28+
_k[tid] = k[t];
29+
_r[tid] = r[t];
30+
_td[tid] = td[t];
31+
__syncthreads();
32+
33+
const float _v = v[t];
34+
float y = 0;
35+
for (int j = 0; j < head_size; j += 4) {
36+
const float4& k = (float4&)(_k[j]);
37+
const float4& r = (float4&)(_r[j]);
38+
const float4& tf = (float4&)(_tf[j]);
39+
const float4& td = (float4&)(_td[j]);
40+
float4& s = (float4&)(state[j]);
41+
float4 kv;
42+
43+
kv.x = k.x * _v;
44+
kv.y = k.y * _v;
45+
kv.z = k.z * _v;
46+
kv.w = k.w * _v;
47+
48+
y += r.x * (tf.x * kv.x + s.x);
49+
y += r.y * (tf.y * kv.y + s.y);
50+
y += r.z * (tf.z * kv.z + s.z);
51+
y += r.w * (tf.w * kv.w + s.w);
52+
53+
s.x = s.x * td.x + kv.x;
54+
s.y = s.y * td.y + kv.y;
55+
s.z = s.z * td.z + kv.z;
56+
s.w = s.w * td.w + kv.w;
57+
}
58+
dst[t] = y;
59+
}
60+
61+
#pragma unroll
62+
for (int i = 0; i < head_size; i++) {
63+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
64+
}
65+
}
66+
67+
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
68+
const float * k_d = (const float *)dst->src[0]->data;
69+
const float * v_d = (const float *)dst->src[1]->data;
70+
const float * r_d = (const float *)dst->src[2]->data;
71+
const float * tf_d = (const float *)dst->src[3]->data;
72+
const float * td_d = (const float *)dst->src[4]->data;
73+
const float * s_d = (const float *)dst->src[5]->data;
74+
75+
const int64_t B = dst->src[5]->ne[1];
76+
const int64_t T = dst->src[0]->ne[3];
77+
const int64_t C = dst->ne[0];
78+
const int64_t H = dst->src[0]->ne[2];
79+
80+
float * dst_d = (float *)dst->data;
81+
82+
cudaStream_t stream = ctx.stream();
83+
84+
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
85+
GGML_ASSERT(C % H == 0);
86+
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
87+
88+
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
89+
}

ggml/src/ggml-cuda/rwkv-wkv.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_WKV_BLOCK_SIZE 64
4+
5+
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/unary.cu

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
9595
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
9696
}
9797

98+
static __global__ void exp_f32(const float * x, float * dst, const int k) {
99+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
100+
101+
if (i >= k) {
102+
return;
103+
}
104+
dst[i] = expf(x[i]);
105+
}
106+
98107
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
99108
const int i = blockDim.x*blockIdx.x + threadIdx.x;
100109
if (i >= k) {
@@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
189198
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
190199
}
191200

201+
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
202+
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
203+
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
204+
}
205+
192206
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
193207
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
194208
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
354368
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
355369
}
356370

371+
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
372+
const ggml_tensor * src0 = dst->src[0];
373+
const float * src0_d = (const float *)src0->data;
374+
float * dst_d = (float *)dst->data;
375+
cudaStream_t stream = ctx.stream();
376+
377+
GGML_ASSERT(ggml_is_contiguous(src0));
378+
379+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
380+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
381+
382+
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
383+
}
384+
357385
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
358386
const ggml_tensor * src0 = dst->src[0];
359387
const float * src0_d = (const float *)src0->data;

ggml/src/ggml-cuda/unary.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define CUDA_RELU_BLOCK_SIZE 256
99
#define CUDA_SIGMOID_BLOCK_SIZE 256
1010
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
11+
#define CUDA_EXP_BLOCK_SIZE 256
1112
#define CUDA_HARDSWISH_BLOCK_SIZE 256
1213
#define CUDA_SQR_BLOCK_SIZE 256
1314
#define CUDA_SQRT_BLOCK_SIZE 256
@@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3233

3334
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3435

36+
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
37+
3538
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3639

3740
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
15431543
}
15441544
};
15451545

1546+
// GGML_OP_RWKV_WKV
1547+
struct test_rwkv_wkv : public test_case {
1548+
const ggml_type type;
1549+
1550+
const int64_t head_count;
1551+
const int64_t head_size;
1552+
const int64_t n_seq_tokens;
1553+
const int64_t n_seqs;
1554+
1555+
std::string vars() override {
1556+
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
1557+
}
1558+
1559+
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
1560+
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
1561+
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1562+
1563+
ggml_tensor * build_graph(ggml_context * ctx) override {
1564+
const int64_t n_tokens = n_seq_tokens * n_seqs;
1565+
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1566+
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
1567+
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1568+
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
1569+
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
1570+
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
1571+
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
1572+
return out;
1573+
}
1574+
};
1575+
15461576
// GGML_OP_MUL_MAT
15471577
struct test_mul_mat : public test_case {
15481578
const ggml_type type_a;
@@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
33373367

33383368
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
33393369

3370+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
3371+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
3372+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
3373+
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
3374+
33403375
#if 1
33413376
for (ggml_type type_a : base_types) {
33423377
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {

0 commit comments

Comments
 (0)