@@ -16702,30 +16702,30 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16702
16702
struct ggml_tensor * dst) {
16703
16703
const size_t T = dst->src[1]->ne[3];
16704
16704
const size_t C = dst->ne[0];
16705
- const size_t H = dst->src[1]->ne[2];
16705
+ const size_t HEADS = dst->src[1]->ne[2];
16706
16706
const size_t n_seqs = dst->src[5]->ne[1];
16707
- const size_t head_size = C / H ;
16707
+ const size_t head_size = C / HEADS ;
16708
16708
16709
16709
float * dst_data = (float *) dst->data;
16710
16710
float * state = ((float *) dst->data) + C * T;
16711
16711
16712
- if ((size_t)params->ith >= H ) {
16712
+ if ((size_t)params->ith >= HEADS ) {
16713
16713
return;
16714
16714
}
16715
16715
16716
- size_t h_start = (H * params->ith) / params->nth;
16717
- size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H ) ?
16718
- (H * (size_t)(params->ith + 1)) / (size_t)params->nth : H ;
16716
+ size_t h_start = (HEADS * params->ith) / params->nth;
16717
+ size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS ) ?
16718
+ (HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS ;
16719
16719
16720
16720
float * k = (float *) dst->src[0]->data;
16721
16721
float * v = (float *) dst->src[1]->data;
16722
16722
float * r = (float *) dst->src[2]->data;
16723
16723
float * time_faaaa = (float *) dst->src[3]->data;
16724
16724
float * time_decay = (float *) dst->src[4]->data;
16725
16725
16726
- size_t t_stride = H * head_size;
16726
+ size_t t_stride = HEADS * head_size;
16727
16727
16728
- size_t h_stride = C / H ;
16728
+ size_t h_stride = C / HEADS ;
16729
16729
size_t h_stride_2d = head_size * head_size;
16730
16730
16731
16731
if (params->ith == 0) {
0 commit comments