99
1010public class Qwen3Kernels {
1111
12- //public static void dbgCopy(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
13- public static void dbgCopy (FloatArray srcBuffer , FloatArray dstBuffer , IntArray positioNlayer , int layer ) {
14- //int position = positioNlayer.get(0);
15- //if (position == 1) {
12+ /**
13+ * Explicit copy-out kernel, useful for debugging.
14+ * With this kernel we can store the values of an array in a tmp buffer at a point of interest.
15+ * At the end of the task graph we copy out the tmp buffer to inspect the array values as they were at that point.
16+ * @param srcBuffer the array we want to inspect.
17+ * @param dstBuffer the tmp buffer.
18+ */
19+ public static void dbgCopy (FloatArray srcBuffer , FloatArray dstBuffer ) {
1620 for (@ Parallel int i = 0 ; i < srcBuffer .getSize (); i ++) {
1721 dstBuffer .set (i , srcBuffer .get (i ));
1822 }
19- //}
2023 }
2124
22- public static void rmsnormReductionWithOffset (
25+ /**
26+ * RmsNorm with parallel offset:
27+ * The following three kernels together implement rmsnorm over offset ranges in parallel, for the qCur and Kcur rmsnorm calculations.
28+ *
29+ * Step 1: Reduction.
30+ * This kernel squares each input value and reduces the squares in parallel within each work group.
31+ */
32+ public static void rmsnormReductionWithParallelOffset (
2333 KernelContext context ,
2434 FloatArray output ,
2535 FloatArray x ,
2636 int localMemSize ) {
2737
28- // global size: 0 - (config.numberOfHeads() * nEmbdHead)
29- // local size : 0 - nEmbdHead
3038 int gid = context .globalIdx ;
3139 int lid = context .localIdx ;
3240 int groupId = context .groupIdx ;
@@ -36,13 +44,8 @@ public static void rmsnormReductionWithOffset(
3644 float [] localX = context .allocateFloatLocalArray (localMemSize );
3745
3846 // Load input value and compute square
39- //int globalReadIndex = gid + offset;
40- //if (gid < size && globalReadIndex < x.getSize()) {
41- localX [lid ] = x .get (gid );
42- localX [lid ] = localX [lid ] * localX [lid ];
43- //} else {
44- // localX[lid] = 0.0f;
45- //}
47+ localX [lid ] = x .get (gid );
48+ localX [lid ] = localX [lid ] * localX [lid ];
4649
4750 // Perform parallel reduction within the work group
4851 for (int stride = (groupSize / 2 ); stride > 0 ; stride /= 2 ) {
@@ -59,7 +62,11 @@ public static void rmsnormReductionWithOffset(
5962 }
6063 }
6164
62- // Second kernel - Combines partial sums and computes final normalization
65+ /**
66+ * RmsNorm with parallel offset:
67+ *
68+ * Step 2: Combines partial reduction outputs and computes final normalization.
69+ */
6370 public static void rmsnormFinalNormalizationWithParallelOffset (
6471 KernelContext context ,
6572 FloatArray output , // size should be related to offsetIndex
@@ -72,12 +79,7 @@ public static void rmsnormFinalNormalizationWithParallelOffset(
7279 // Only the index threads need to perform this calculation
7380 if (gid < offsetIndex ) {
7481 // Combine partial sums from all workgroups
75- float ss = 0.0f ;
76- //for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
77- // for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
78- // ss += output.get(i);
79- // }
80- ss = output .get (gid );
82+ float ss = output .get (gid );
8183
8284 ss /= size ;
8385 ss += ermsNorm ;
@@ -87,36 +89,28 @@ public static void rmsnormFinalNormalizationWithParallelOffset(
8789 }
8890 }
8991
92+ /**
93+ * RmsNorm with parallel offset:
94+ *
95+ * Step 3: Scales each element of out in place by its work group's value from ss and by the matching weight (weights[gid % size]).
96+ */
9097 public static void rmsnormMapIndexInPlaceWithParallelOffset (
9198 KernelContext context ,
92- FloatArray out , // Q
99+ FloatArray out ,
93100 FloatArray weights ,
94101 int size ,
95- FloatArray ss // tempQcur1
96- ) {
102+ FloatArray ss ) {
97103
98- int gid = context .globalIdx ; // 0 - size
99- //int index = offset + gid;
104+ int gid = context .globalIdx ;
100105 int groupId = context .groupIdx ;
101106
102107 float finalss = ss .get (groupId );
103- //out.set(index, weights.get(index % size) * (finalss * x.get(index)));
104- //out.set(index, weights.get(index) * (finalss * x.get(index)));
105- //if (index < offset + size) {
108+
106109 if (gid < out .getSize ()) { // TODO: check if redundant
107110 float a = weights .get (gid % size );
108111 float b = finalss * out .get (gid );
109112 out .set (gid , a * b );
110113 }
111-
112- //old gid, index:
113- // int gid = context.globalIdx; // 0 - size
114- // int index = offset + gid;
115- // context.globalBarrier();
116- // // reset ss
117- // if (gid < ss.getSize()) {
118- // ss.set(gid, 0.0f);
119- // }
120114 }
121115
122116 /**
@@ -162,92 +156,12 @@ public static void rmsnormWithParallelOffset(
162156 }
163157 }
164158
165- public static void reductionOneBlockWithLayerWithOffset (
166- KernelContext context ,
167- FloatArray output ,
168- FloatArray x ,
169- int offset ,
170- int size ,
171- float ermsNorm ,
172- int localMemSize ) {
173-
174- int gid = context .globalIdx ; // 0 - nEmbHead = 128
175- int lid = context .localIdx ; // 0 - state.localsize [
176- int groupId = context .groupIdx ;
177- int groupSize = context .localGroupSizeX ;
178-
179- // Allocate local memory with the provided size
180- float [] localX = context .allocateFloatLocalArray (localMemSize );
181-
182- // Load input value and compute square
183- int globalReadIndex = gid + offset ;
184- if (gid < size && globalReadIndex < x .getSize ()) {
185- localX [lid ] = x .get (globalReadIndex );
186- localX [lid ] = localX [lid ] * localX [lid ];
187- } else {
188- localX [lid ] = 0.0f ;
189- }
190-
191- // Perform parallel reduction within the work group
192- for (int stride = (groupSize / 2 ); stride > 0 ; stride /= 2 ) {
193- context .localBarrier ();
194- if (lid < stride ) {
195- localX [lid ] += localX [lid + stride ];
196- }
197- }
198-
199- // Each workgroup stores its partial sum in a different location
200- if (lid == 0 ) {
201- // Store the partial sum from each workgroup
202- output .set (groupId + 1 , localX [0 ]);
203- }
204-
205- // // Only the first thread in the first workgroup computes the final normalization factor
206- // if (gid == 0) {
207- // // Combine partial sums from all workgroups
208- // float ss = 0.0f;
209- // for (int i = 1; i <= (size / localMemSize); i++) { // Assuming 8 workgroups
210- // ss += output.get(i);
211- // }
212- //
213- // ss /= size;
214- // ss += ermsNorm;
215- // ss = 1.0f / TornadoMath.sqrt(ss);
216- // output.set(0, ss); // Store the final scale factor
217- // }
218- }
219-
220- /**
221- * Normalize and scale (in-place) of rmsnorm operation.
222- */
223- public static void mapIndexInPlace (KernelContext context , FloatArray out , /*FloatArray x,*/ FloatArray weights , int offset , int size , FloatArray ss ) {
224- int gid = context .globalIdx ; // 0 - size
225- int index = offset + gid ;
226-
227- float finalss = ss .get (0 );
228- //out.set(index, weights.get(index % size) * (finalss * x.get(index)));
229- //out.set(index, weights.get(index) * (finalss * x.get(index)));
230- //if (index < offset + size) {
231- if (index < out .getSize ()) { // TODO: check if redundant
232- float a = weights .get (index % size );
233- float b = finalss * out .get (index );
234- out .set (index , a * b );
235- }
236-
237- context .globalBarrier ();
238- // reset ss
239- if (gid < ss .getSize ()) {
240- ss .set (gid , 0.0f );
241- }
242- }
243-
244159 public static void ropeRotation (KernelContext context ,
245160 IntArray position ,
246161 FloatArray q ,
247162 FloatArray k ,
248163 int numberOfKeyValueHeads ,
249164 int nEmbdHead ) {
250- //System.out.println("ropeRotationSplit");
251165 int h = context .globalIdx ;
252166 int ic = context .globalIdy ;
253167
@@ -256,7 +170,6 @@ public static void ropeRotation(KernelContext context,
256170 int nComplEmbdHead = nEmbdHead / 2 ;
257171
258172 // Compute RoPE frequencies for Qwen3
259- //float freq = 1.0f / TornadoMath.pow(10000.0f, (2.0f * ic) / (float) nEmbdHead);
260173 float theta = 1000000.0f ;
261174 int i = ic * 2 ; // match i in precompute (see RoPE.precomputeFreqsCis)
262175 float freq = 1.0f / TornadoMath .pow (theta , (float )i / (float )nEmbdHead );
@@ -290,13 +203,11 @@ public static void processHeadsParallel(
290203 int nEmbdHeadV , /* = config.numberOfHeadsValue(), replace headSize in lines: 266, 268, 273 */
291204 int nEmbdGqa , /* kvDim */
292205 int gqa , /* kvMul */
293- int seqLen ,
294206 IntArray positionHolder ,
295207 FloatArray wrapAtt ,
296208 int layer , int contextLength ) {
297209
298210 int pos = positionHolder .get (0 );
299- //int loff = layer * contextLength * kvDim;
300211 int loff = layer * contextLength * nEmbdGqa ;
301212
302213 // Parallelize computation across attention heads
@@ -332,22 +243,16 @@ private static void processHeadTornado(
332243
333244 // Base index for this head's attention weights
334245 int headOffset = h * (pos + 1 );
335- //int headOffset = h * contextLength;
336246
337247 // STEP 1: Calculate attention scores for all timesteps
338248 for (int t = 0 ; t <= pos ; t ++) {
339- //int kvHeadIdx = h / kvMul;
340249 int kvHeadIdx = h / gqa ;
341- //int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
342250 int keyOffset = (int ) (loff + t * nEmbdGqa + kvHeadIdx * nEmbdHeadK ); // line 255
343251
344252 float score = 0.0f ;
345- //for (int i = 0; i < headSize; i++) {
346253 for (int i = 0 ; i < nEmbdHeadK ; i ++) {
347- //score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i);
348254 score += allQ .get (h * nEmbdHeadK + i ) * key_cache .get (keyOffset + i ); // line 255
349255 }
350- //score = score / TornadoMath.sqrt(headSize);
351256 score = score / TornadoMath .sqrt (nEmbdHead ); // line 257
352257
353258 // Store in attention buffer
@@ -380,28 +285,24 @@ private static void processHeadTornado(
380285 }
381286
382287 // STEP 5: Compute weighted sum of values for each dimension
383- //for (int i = 0; i < headSize; i++) {
384288 for (int i = 0 ; i < nEmbdHeadV ; i ++) {
385289 float weightedSum = 0.0f ;
386290 for (int t = 0 ; t <= pos ; t ++) {
387- //int kvHeadIdx = h / kvMul;
388291 int kvHeadIdx = h / gqa ;
389- //int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
390292 int valueOffset = (int ) (loff + t * nEmbdGqa + kvHeadIdx * nEmbdHeadV ); //line 273
391293 weightedSum += wrapAtt .get (headOffset + t ) * value_cache .get (valueOffset + i );
392294 }
393- //allXb.set(h * headSize + i, weightedSum);
394295 allXb .set (h * nEmbdHeadV + i , weightedSum ); // offset from line 266
395296 }
396297 }
397298
398299 public static void matrixVectorGenericWithResidual (
399300 KernelContext context ,
400- FloatArray v , // vector = [2048]
401- FloatArray out , // out = [1024]
402- HalfFloatArray m , // matrix = [2048, 1024]
403- int dim1 , // dim1 = 2048, vectorSize
404- int dim0 , // dim0 = 1024, outputSize
301+ FloatArray v ,
302+ FloatArray out ,
303+ HalfFloatArray m ,
304+ int dim1 ,
305+ int dim0 ,
405306 int localWorkGroupSize ) {
406307
407308 // One row per workgroup (not per thread)
@@ -431,8 +332,8 @@ public static float matrixVectorRowMajorOptimized(
431332 int dim1 ,
432333 int dim0
433334 ) {
434- int rowId = context .groupIdx ; // 0-dim
435- int localId = context .localIdx ; // 0-32
335+ int rowId = context .groupIdx ;
336+ int localId = context .localIdx ;
436337
437338 // Allocate local memory for reduction
438339 float [] localSum = context .allocateFloatLocalArray (localSize );
@@ -444,48 +345,6 @@ public static float matrixVectorRowMajorOptimized(
444345 for (int j = localId ; j < dim1 ; j += localSize ) {
445346 int matrixIdx = rowOffset + j ;
446347 partialSum += m .get (matrixIdx ).getFloat32 () * v .get (j );
447- //partialSum += w.get(rowOffset + j).getFloat32() * x.get(j);
448- }
449-
450- // Store partial sum in local memory
451- localSum [localId ] = partialSum ;
452- context .localBarrier ();
453-
454- // Parallel reduction within workgroup
455- for (int stride = localSize / 2 ; stride > 0 ; stride >>= 1 ) {
456- if (localId < stride ) {
457- localSum [localId ] += localSum [localId + stride ];
458- }
459- context .localBarrier ();
460- }
461-
462- return localSum [0 ];
463- }
464-
465- public static float matrixVectorRowMajorOptimized2 (
466- KernelContext context ,
467- int localSize ,
468- FloatArray v , // input vector [2048]
469- HalfFloatArray m , // matrix [2048, 1024]
470- int vectorSize , // 2048
471- int outputSize ,
472- int rowId // which output row we're computing (0-1023)
473- ) {
474- int localId = context .localIdx ; // 0 to localSize-1
475-
476- // Allocate local memory for reduction
477- float [] localSum = context .allocateFloatLocalArray (localSize );
478-
479- // For matrix [2048, 1024], if we want row 'rowId' of the OUTPUT,
480- // we need to compute dot product of INPUT vector with COLUMN 'rowId' of the matrix
481- // Matrix element [i][j] is at index i * outputSize + j
482- // We want column 'rowId', so elements are at: 0*outputSize + rowId, 1*outputSize + rowId, etc.
483-
484- // Each thread calculates partial dot product
485- float partialSum = 0.0f ;
486- for (int i = localId ; i < vectorSize ; i += localSize ) {
487- int matrixIdx = i * outputSize + rowId ; // Column-wise access for row rowId
488- partialSum += m .get (matrixIdx ).getFloat32 () * v .get (i );
489348 }
490349
491350 // Store partial sum in local memory
0 commit comments