
Commit 4dc1e6c

zhiyuan1i, ggerganov, slaren, pminev, and ykhrustalev authored and committed
Optimize RWKV6 Operator Naming and Implement Multi-core CPU/SYCL Acceleration (ggml-org#10133)
* rwkv6: rename to wkv6
* rwkv6: support avx2 avx512 armv8 armv9
* rwkv6: update cuda file name
* rwkv6: rename params
* wkv on sycl
* sycl: add some ops
* sycl: Enhance OP support judgment
* wkv6: drop armv9 and transfer to GGML style ggml-ci
* sync : ggml
* update the function to use appropriate types
* fix define error
* Update ggml/src/ggml-cpu.c
* add appropriate asserts
* move element-wise functions outside
* put the declaration outside the loop
* rewrite to be more inline with the common pattern for distributing threads
* use recommended way GGML_TENSOR_LOCALS

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Diego Devesa <[email protected]>
Co-authored-by: Plamen Minev <[email protected]>
Co-authored-by: Yuri Khrustalev <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>
1 parent 8414dfd commit 4dc1e6c

22 files changed: +2066 additions, −1116 deletions
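For reference, the recurrence that both the renamed CPU kernel and the CUDA kernel implement is the fused WKV6 update described in the ggml-cpu.c comments ("dst = r @ (time_faaaa * (k @ v) + state), state = time_decay * state + (k @ v)"). Written out per token $t$ and head, with $S$ the per-head recurrent state and $\mathrm{tf}$/$\mathrm{td}$ standing for time_faaaa/time_decay, this is only a notational restatement of the scalar loop below, not part of the commit:

$$
\mathrm{dst}_{t,j} \mathrel{+}= \sum_i r_{t,i}\left(\mathrm{tf}_i\,k_{t,i}\,v_{t,j} + S_{i,j}\right),
\qquad
S_{i,j} \leftarrow \mathrm{td}_{t,i}\,S_{i,j} + k_{t,i}\,v_{t,j}.
$$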

docs/backend/SYCL.md

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ Part2:
 
 |Chosen Device ID|Setting|
 |-|-|
-|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action|
+|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:0"` or no action|
 |1|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"`|
 |0 & 1|`export ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`|
 

ggml/include/ggml.h

Lines changed: 2 additions & 2 deletions
@@ -509,7 +509,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
-        GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_WKV6,
 
         GGML_OP_UNARY,
 

@@ -1819,7 +1819,7 @@ extern "C" {
             struct ggml_tensor * pw,
             struct ggml_tensor * ph);
 
-    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
             struct ggml_context * ctx,
             struct ggml_tensor * k,
             struct ggml_tensor * v,

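The hunk above shows only the start of the renamed declaration. Below is a minimal usage sketch; the remaining parameters (tf, td, state) are assumed from the tensors the CPU kernel reads via dst->src[3..5], and `build_wkv6` is a hypothetical wrapper, not part of this commit — consult ggml.h for the authoritative signature.

```c
#include "ggml.h"

// Hypothetical wrapper around the renamed op. The tf/td/state parameters are
// inferred from how the CPU kernel reads dst->src[3] (time_faaaa),
// dst->src[4] (time_decay) and dst->src[5] (recurrent state).
static struct ggml_tensor * build_wkv6(
        struct ggml_context * ctx,
        struct ggml_tensor  * k,      // keys, per token and channel
        struct ggml_tensor  * v,      // values, per token and channel
        struct ggml_tensor  * r,      // receptance ("r" in RWKV)
        struct ggml_tensor  * tf,     // time_faaaa: per-channel bonus, constant over tokens
        struct ggml_tensor  * td,     // time_decay: per-token, per-channel decay
        struct ggml_tensor  * state)  // recurrent state, one matrix per head per sequence
{
    // The result packs the per-token outputs followed by the updated state,
    // matching the "state = dst->data + C * T" layout used in ggml-cpu.c.
    return ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
}
```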
ggml/src/ggml-cpu.c

Lines changed: 160 additions & 48 deletions
@@ -11642,79 +11642,191 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6
 
-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
 
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
 
-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
         return;
     }
 
-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
 
     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
 
-    size_t t_stride = H * (C / H);
+    size_t t_stride = HEADS * head_size; // Same to C
 
-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
 
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
 
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
 
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
 
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
 
-                for (size_t j = 0; j < C / H; j ++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
 
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
                 }
             }
         }
-    }
+    #endif
 }
 
-static void ggml_compute_forward_rwkv_wkv(
+
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
             {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:

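The part of this change that enables multi-core execution is the head partitioning at the top of ggml_compute_forward_rwkv_wkv6_f32: thread 0 still zeroes the output (followed by a barrier), but the per-head work is then split across threads so each head is owned by exactly one thread. A minimal standalone sketch of that split follows; `partition_heads` is a hypothetical helper name, not in the commit.

```c
#include <stdint.h>

// Each of nth threads takes a contiguous slice [h_start, h_end) of the HEADS
// dimension. The slices cover all heads without overlap, so no locking is
// needed inside the per-token loop.
static void partition_heads(int ith, int nth, int64_t HEADS,
                            int64_t * h_start, int64_t * h_end) {
    *h_start = (HEADS * ith) / nth;
    *h_end   = (HEADS * (ith + 1)) / nth < HEADS ? (HEADS * (ith + 1)) / nth
                                                 : HEADS;
}
```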
ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 4 deletions
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/rwkv-wkv.cuh"
+#include "ggml-cuda/wkv6.cuh"
 
 #include <algorithm>
 #include <array>

@@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_RWKV_WKV:
-            ggml_cuda_op_rwkv_wkv(ctx, dst);
+        case GGML_OP_RWKV_WKV6:
+            ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);

@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE

ggml/src/ggml-cuda/rwkv-wkv.cuh

Lines changed: 0 additions & 5 deletions
This file was deleted.

ggml/src/ggml-cuda/rwkv-wkv.cu renamed to ggml/src/ggml-cuda/wkv6.cu

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 #include "common.cuh"
-#include "rwkv-wkv.cuh"
+#include "wkv6.cuh"
 
 static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
     const int tid = threadIdx.x;

@@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
     }
 }
 
-void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float * k_d = (const float *)dst->src[0]->data;
     const float * v_d = (const float *)dst->src[1]->data;
     const float * r_d = (const float *)dst->src[2]->data;

@@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
 
     rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
 }

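The assertions above pin down the launch geometry: one CUDA block per (sequence, head) pair and one thread per channel within a head, with the head size fixed at 64 for RWKV6. A small host-side sketch of that mapping follows; `wkv6_launch_dims` is a hypothetical helper for illustration, not part of the commit.

```c
#include <assert.h>
#include <stdint.h>

// Mirrors the launch in wkv6.cu: rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(...).
// B = number of sequences, C = channels, H = heads; C / H must equal
// CUDA_WKV_BLOCK_SIZE (64), the RWKV6 head size the kernel is written for.
static void wkv6_launch_dims(int64_t B, int64_t C, int64_t H,
                             int * num_blocks, int * threads_per_block) {
    assert(C % H == 0);
    assert(C / H == 64); // CUDA_WKV_BLOCK_SIZE
    *num_blocks        = (int)(B * H);  // one block per (sequence, head)
    *threads_per_block = (int)(C / H);  // one thread per channel in a head
}
```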
ggml/src/ggml-cuda/wkv6.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
