Skip to content

Commit 02f9a1d

Browse files
committed
rwkv6: rename params
1 parent 01d49f3 commit 02f9a1d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ggml/src/ggml.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16702,30 +16702,30 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1670216702
struct ggml_tensor * dst) {
1670316703
const size_t T = dst->src[1]->ne[3];
1670416704
const size_t C = dst->ne[0];
16705-
const size_t H = dst->src[1]->ne[2];
16705+
const size_t HEADS = dst->src[1]->ne[2];
1670616706
const size_t n_seqs = dst->src[5]->ne[1];
16707-
const size_t head_size = C / H;
16707+
const size_t head_size = C / HEADS;
1670816708

1670916709
float * dst_data = (float *) dst->data;
1671016710
float * state = ((float *) dst->data) + C * T;
1671116711

16712-
if ((size_t)params->ith >= H) {
16712+
if ((size_t)params->ith >= HEADS) {
1671316713
return;
1671416714
}
1671516715

16716-
size_t h_start = (H * params->ith) / params->nth;
16717-
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
16718-
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
16716+
size_t h_start = (HEADS * params->ith) / params->nth;
16717+
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
16718+
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
1671916719

1672016720
float * k = (float *) dst->src[0]->data;
1672116721
float * v = (float *) dst->src[1]->data;
1672216722
float * r = (float *) dst->src[2]->data;
1672316723
float * time_faaaa = (float *) dst->src[3]->data;
1672416724
float * time_decay = (float *) dst->src[4]->data;
1672516725

16726-
size_t t_stride = H * head_size;
16726+
size_t t_stride = HEADS * head_size;
1672716727

16728-
size_t h_stride = C / H;
16728+
size_t h_stride = C / HEADS;
1672916729
size_t h_stride_2d = head_size * head_size;
1673016730

1673116731
if (params->ith == 0) {

0 commit comments

Comments
 (0)