+ // Copyright (C) 2024 Zhiyuan Li
+
+ #include <sycl/sycl.hpp>
+ #include "wkv6.hpp"
+
+ constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
+
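+ // Each work-group processes one (batch, head) pair; each work-item owns one
+ // output channel `tid` within that head. For every time step t the loop body
+ // below computes, per source channel j:
+ //     y        += r_t[j] * (tf[j] * k_t[j] * v_t[tid] + state[j])
+ //     state[j]  = state[j] * td_t[j] + k_t[j] * v_t[tid]
+ // i.e. the RWKV v6 WKV recurrence with per-step decay td and static bonus tf.
+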
+ // Helper function for the main kernel
+ static void rwkv_wkv_f32_kernel(
+         const int B, const int T, const int C, const int H,
+         const float * k, const float * v, const float * r,
+         const float * tf, const float * td, const float * s,
+         float * dst, const sycl::nd_item<3>& item_ct1, float * shared_mem) {
+
+     const int tid = item_ct1.get_local_id(2);
+     const int bid = item_ct1.get_group(2);
+
+     const int head_size = WKV_BLOCK_SIZE;
+     const int batch_i = bid / H;
+     const int head_i = bid % H;
+     const int state_size = C * head_size;
+     const int n_seq_tokens = T / B;
+
+     // Set up shared memory pointers
+     float * _k = shared_mem;
+     float * _r = _k + head_size;
+     float * _tf = _r + head_size;
+     float * _td = _tf + head_size;
+
+     // Local state array
+     float state[WKV_BLOCK_SIZE];
+
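+     // The state is stored as a head_size x head_size block per (batch, head),
+     // with state_size == C * head_size as the per-batch stride; each work-item
+     // caches the column selected by `tid` in registers.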
+     // Load initial state
+ #pragma unroll
+     for (int i = 0; i < head_size; i++) {
+         state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+     }
+
+     // Sync threads before shared memory operations
+     item_ct1.barrier(sycl::access::fence_space::local_space);
+
+     // Load time-mixing parameters
+     _tf[tid] = tf[head_i * head_size + tid];
+     item_ct1.barrier(sycl::access::fence_space::local_space);
+
+     // Main sequence processing loop
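+     // `t` walks this work-item's channel (head_i * head_size + tid) across
+     // the batch's n_seq_tokens tokens, advancing by C (one token) per step.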
+     for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+          t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+          t += C) {
+
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+
+         // Load current timestep data to shared memory
+         _k[tid] = k[t];
+         _r[tid] = r[t];
+         _td[tid] = td[t];
+
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+
+         const float _v = v[t];
+         float y = 0;
+
+         // Process in chunks of 4 for better vectorization
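+         // (head_size == 64 is a multiple of 4, so no remainder pass is needed)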
+ #pragma unroll
+         for (int j = 0; j < head_size; j += 4) {
+             // Load data in vec4 chunks
+             sycl::float4 k4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+             sycl::float4 r4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+             sycl::float4 tf4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+             sycl::float4 td4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+             sycl::float4 s4(state[j], state[j+1], state[j+2], state[j+3]);
+
+             // Compute key-value product
+             sycl::float4 kv4 = k4 * _v;
+
+             // Accumulate weighted sum
+             y += sycl::dot(r4, tf4 * kv4 + s4);
+
+             // Update state
+             s4 = s4 * td4 + kv4;
+
+             // Store updated state
+             state[j]   = s4.x();
+             state[j+1] = s4.y();
+             state[j+2] = s4.z();
+             state[j+3] = s4.w();
+         }
+
+         dst[t] = y;
+     }
+
+     // Save final state
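+     // The final state is appended to dst after the T * C output elements,
+     // in the same layout as the input state s.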
+ #pragma unroll
+     for (int i = 0; i < head_size; i++) {
+         dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+     }
+ }
+
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                             const ggml_tensor* src1, ggml_tensor* dst) {
+
+     const float * k_d  = (const float *) dst->src[0]->data;
+     const float * v_d  = (const float *) dst->src[1]->data;
+     const float * r_d  = (const float *) dst->src[2]->data;
+     const float * tf_d = (const float *) dst->src[3]->data;
+     const float * td_d = (const float *) dst->src[4]->data;
+     const float * s_d  = (const float *) dst->src[5]->data;
+     float * dst_d = (float *) dst->data;
+
+     const int64_t B = dst->src[5]->ne[1];
+     const int64_t T = dst->src[0]->ne[3];
+     const int64_t C = dst->ne[0];
+     const int64_t H = dst->src[0]->ne[2];
+
+     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+     GGML_ASSERT(C % H == 0);
+     GGML_ASSERT(C / H == WKV_BLOCK_SIZE);
+
+     dpct::queue_ptr stream = ctx.stream();
+
+     // Calculate execution configuration
+     // sycl::local_accessor takes a range in elements, not bytes, so size the
+     // scratch space as a float count: one head_size slot each for k, r, tf, td
+     const size_t shared_mem_size = WKV_BLOCK_SIZE * 4;
+     sycl::range<3> block_dims(1, 1, C / H);
+     sycl::range<3> grid_dims(1, 1, B * H);
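+     // One work-group of head_size (= C / H) work-items per (batch, head) pair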
+
+     // Submit kernel
+     stream->submit([&](sycl::handler& cgh) {
+         sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+         cgh.parallel_for(
+             sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) {
+                 rwkv_wkv_f32_kernel(
+                     B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                     item_ct1, shared_mem_acc.get_pointer()
+                 );
+             });
+     });
+ }