@@ -11642,79 +11642,191 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6
 
-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
 
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
 
-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
         return;
     }
 
-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
 
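The h_start/h_end arithmetic gives each thread a contiguous, near-even slice of the heads (threads with ith >= HEADS get an empty slice and return early above). A minimal standalone sketch of the same integer partitioning, with a hypothetical helper name:

    // Worker ith of nth processes heads [h_start, h_end).
    // Example: heads = 6, nth = 4 yields [0,1), [1,3), [3,4), [4,6).
    static void wkv_head_range(int heads, int ith, int nth,
                               int * h_start, int * h_end) {
        *h_start = (heads * ith) / nth;
        const int end = (heads * (ith + 1)) / nth;
        *h_end = end < heads ? end : heads;
    }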
     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
 
-    size_t t_stride = H * (C / H);
+    size_t t_stride = HEADS * head_size; // same as C
 
-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
 
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[5]->data + state_offset;
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
 
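Because every thread now accumulates into dst_data with +=, the buffer must be fully zeroed before any thread writes to it; one thread performs the memset and ggml_barrier holds the rest until it finishes. The shape of the pattern in isolation (a sketch; shared_out and n are placeholders):

    if (ith == 0) {
        memset(shared_out, 0, n * sizeof(float));  // single writer initializes
    }
    ggml_barrier(params->threadpool);              // all threads wait here
    // ... every thread may now safely accumulate into shared_out ...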
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
 
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
 
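The GGML_F32X_* aliases retarget one loop body to whichever of ggml's SIMD widths the build supports (AVX 8-wide, AVX-512 16-wide, NEON 4-wide). One detail worth knowing to read the calls below: ggml's FMA helper takes the addend first, i.e. GGML_F32X_FMA(a, b, c) computes a + b * c. A sketch of roughly what the 8-wide variant expands to on x86 (an approximation for illustration; the authoritative definitions live in ggml.c's own SIMD macro section):

    #include <immintrin.h>

    typedef __m256 f32x8;

    static inline f32x8 f32x8_fma(f32x8 a, f32x8 b, f32x8 c) {
    #if defined(__FMA__)
        return _mm256_fmadd_ps(b, c, a);               // a + b*c, fused
    #else
        return _mm256_add_ps(_mm256_mul_ps(b, c), a);  // unfused fallback
    #endif
    }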
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load WKV_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
 
-                for (size_t j = 0; j < C / H; j++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    // Handle any tail elements (head_size is normally a
+                    // multiple of WKV_VECTOR_SIZE, so this loop rarely runs).
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
 
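Both branches compute the same per-element update; written out in index notation for token t, head h, state row i, column j (this is exactly what the two GGML_F32X_FMA calls above and the scalar lines below implement):

    kv               = k[t,h,i] * v[t,h,j]
    dst[t,h,j]      += r[t,h,i] * (time_faaaa[h,i] * kv + state_prev[h,i,j])
    state_cur[h,i,j] = time_decay[t,h,i] * state_prev[h,i,j] + kv

Note that time_faaaa is indexed per head only, while time_decay varies per token — the RWKV v6 change called out in the comment in the scalar branch.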
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+    #else
+        // Fused operations, applied recurrently token by token:
+        //   dst   = r @ (time_faaaa * (k @ v) + state),
+        //   state = time_decay * state + (k @ v)
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
                 }
             }
         }
-    }
+    #endif
 }
 
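A note on how sequences are handled above: the index arithmetic assumes T is a multiple of n_seqs, with each sequence's tokens stored contiguously. Each sequence owns a C * head_size block of state, and state_prev points into src[5] (the incoming state) only for a sequence's first token; afterwards it aliases state_cur, so the recurrence reads the state written for the previous token in place. A worked example with hypothetical sizes:

    // T = 8 tokens, n_seqs = 2  =>  T / n_seqs = 4 tokens per sequence
    // t = 0..3 : seq 0, state_offset = 0
    // t = 4..7 : seq 1, state_offset = head_size * C
    // t = 0, 4 : state_prev = src[5] data + state_offset (incoming state)
    // otherwise: state_prev = state_cur (state carried forward in place)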
-static void ggml_compute_forward_rwkv_wkv(
+
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
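For reference, the operand layout the kernel reads (taken directly from the loads and ne[] accesses in the hunk above; the role names follow RWKV terminology):

    // dst->src[0] : k            keys
    // dst->src[1] : v            values (ne[2] = HEADS, ne[3] = T)
    // dst->src[2] : r            receptance
    // dst->src[3] : time_faaaa   per-head bonus, indexed [h, i]
    // dst->src[4] : time_decay   per-token decay, indexed [t, h, i]
    // dst->src[5] : incoming state (ne[1] = n_seqs)
    // dst         : C * T output floats, followed by the updated states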
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
             {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: