@@ -34,10 +34,10 @@ public TransformerComputeKernelsLayered() {
      * @param localMemSize Size of local memory allocation (must match work group size)
      */
     public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
-        int gid = context.globalIdx; // 0-1024
-        int lid = context.localIdx; // 0-256
-        int groupId = context.groupIdx; // 0-4
-        int groupSize = context.localGroupSizeX; // 256
+        int gid = context.globalIdx;
+        int lid = context.localIdx;
+        int groupId = context.groupIdx;
+        int groupSize = context.localGroupSizeX;

         // Allocate local memory with the provided size
         float[] localX = context.allocateFloatLocalArray(localMemSize);
@@ -115,8 +115,7 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
      * @param layer Current transformer layer index
      * @param contextLength Maximum sequence length
      */
-    public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue,
-            IntArray positioNlayer, int kvDim, int layer, int contextLength) {
+    public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {

         int position = positioNlayer.get(0);
         int loff = layer * contextLength * kvDim;
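Here `loff` selects this layer's slab inside the flat KV cache, and `position * kvDim` selects the slot for the current token. A plain-Java sketch of the copy that `copyToCache` performs, with `float[]` standing in for `FloatArray` and the layout inferred from the offsets above:

```java
// Reference for copyToCache's indexing: copy the current token's key and value
// vectors into this layer's region of the caches. Layout inferred from loff above.
static void copyToCacheReference(float[] destKeyCache, float[] srcKey,
                                 float[] destValueCache, float[] srcValue,
                                 int position, int kvDim, int layer, int contextLength) {
    int loff = layer * contextLength * kvDim; // start of this layer's cache slab
    int base = loff + position * kvDim;       // slot for the current position
    for (int i = 0; i < kvDim; i++) {
        destKeyCache[base + i] = srcKey[i];
        destValueCache[base + i] = srcValue[i];
    }
}
```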
@@ -195,14 +194,8 @@ public static void ropeRotation(KernelContext context, IntArray positionHolder,
      * @param layer Current transformer layer
      * @param contextLength Maximum context length
      */
-    public static void processHeadsParallel(
-            FloatArray q,
-            FloatArray key_cache,
-            FloatArray value_cache,
-            FloatArray xb,
-            int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
-            IntArray positionHolder,
-            FloatArray wrapAtt, int layer, int contextLength) {
+    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
+            IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {

         int pos = positionHolder.get(0);
         int loff = layer * contextLength * kvDim;
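`processHeadsParallel` spreads attention heads across the device; per head, the work is standard scaled dot-product attention over the cached keys and values. A sequential sketch for a single head, assuming the grouped-query layout implied by `kvMul` (several query heads share one KV head); the method name and the `att` scratch buffer are illustrative:

```java
// One attention head, sequentially: scores against cached keys, softmax,
// then a weighted sum of cached values. Assumes the GQA layout implied by kvMul.
static void attendOneHead(float[] q, float[] keyCache, float[] valueCache, float[] xb,
                          float[] att, int h, int pos, int headSize, int kvDim,
                          int kvMul, int loff) {
    int qOff = h * headSize;
    int kvHead = h / kvMul; // KV head shared by this query head
    float scale = 1f / (float) Math.sqrt(headSize);
    for (int t = 0; t <= pos; t++) {
        int kOff = loff + t * kvDim + kvHead * headSize;
        float score = 0f;
        for (int i = 0; i < headSize; i++) {
            score += q[qOff + i] * keyCache[kOff + i];
        }
        att[t] = score * scale;
    }
    float max = Float.NEGATIVE_INFINITY; // softmax over att[0..pos]
    for (int t = 0; t <= pos; t++) {
        max = Math.max(max, att[t]);
    }
    float sum = 0f;
    for (int t = 0; t <= pos; t++) {
        att[t] = (float) Math.exp(att[t] - max);
        sum += att[t];
    }
    for (int i = 0; i < headSize; i++) { // weighted sum of cached values into xb
        float acc = 0f;
        for (int t = 0; t <= pos; t++) {
            acc += (att[t] / sum) * valueCache[loff + t * kvDim + kvHead * headSize + i];
        }
        xb[qOff + i] = acc;
    }
}
```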
@@ -663,8 +656,7 @@ public static void matrixVectorGeneric(
      * @param d Output dimension
      * @param localWorkGroupSize Work group size
      */
-    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w,
-            int n, int d, int localWorkGroupSize) {
+    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
         int rowId = context.groupIdx;
         int localId = context.localIdx;
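Functionally, this kernel computes a row-major matrix-vector product whose result is accumulated into `hb` (the residual connection), with one work group per output row. A plain-Java reference, with `w` assumed flattened row-major and `float[]` in place of `HalfFloatArray`:

```java
// Reference for matrixVectorGenericWithResidual: hb[row] += dot(w[row, :], x).
// w is assumed flattened row-major; on the GPU one work group handles one row.
static void matVecWithResidualReference(float[] x, float[] hb, float[] w, int n, int d) {
    for (int row = 0; row < d; row++) {
        float dot = 0f;
        for (int j = 0; j < n; j++) {
            dot += w[row * n + j] * x[j];
        }
        hb[row] += dot; // residual: accumulate into the existing activation
    }
}
```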
@@ -794,8 +786,8 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
     }

     public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) {
-        int rowId = context.groupIdx; // 0-dim
-        int localId = context.localIdx; // 0-32
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;

         // Allocate local memory for reduction
         float[] localSum = context.allocateFloatLocalArray(localSize);
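After each lane has accumulated its strided partial dot product into `localSum`, a kernel like this one typically finishes with a work-group reduction; that step is outside the excerpt, so the sketch below shows the usual halving (tree) reduction over the local buffer. On the GPU, each outer iteration would be separated by a local barrier:

```java
// Tree reduction over a work group's local buffer (localSize assumed a power of two).
// On the device, the inner loop runs as one lane per lid with a barrier per stride.
static float treeReduceReference(float[] localSum, int localSize) {
    for (int stride = localSize / 2; stride > 0; stride /= 2) {
        for (int lid = 0; lid < stride; lid++) {
            localSum[lid] += localSum[lid + stride];
        }
    }
    return localSum[0]; // the row's dot product
}
```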