@@ -186,20 +186,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
186186 // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
187187 //rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
188188 unifiedLayer
189- .task ("reductionsOneBlock_Qcur " ,
189+ .task ("rmsnormReduction_Qcur" ,
190190 Qwen3Kernels ::rmsnormReductionWithOffset ,
191191 context ,
192192 state .tempQcur , // output
193193 state .wrapQ , // input
194194 state .localSize ) // currently 128, should be variable of global nEmbHead
195- .task ("reductionFinalNormalization_Qcur " ,
195+ .task ("rmsnormFinalNormalization_Qcur" ,
196196 Qwen3Kernels ::rmsnormFinalNormalizationWithParallelOffset ,
197197 context ,
198198 state .tempQcur , // output
199199 config .numberOfHeads (),
200200 nEmbdHead ,
201201 config .rmsNormEps ())
202- .task ("mapContext_Qcur " ,
202+ .task ("rmsnormMapIndexInPlace_Qcur" ,
203203 Qwen3Kernels ::rmsnormMapIndexInPlaceWithParallelOffset ,
204204 context ,
205205 state .wrapQ , // output
@@ -217,34 +217,28 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
217217// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
218218//
219219 // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
220- for (int i = 0 ; i < config .numberOfKeyValueHeads (); i ++) {
221- //rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
222- int offset = i * nEmbdHead ;
223- unifiedLayer .task ("reductionsOneBlock" + "_Kcur_" + i ,
224- Qwen3Kernels ::reductionOneBlockWithLayerWithOffset ,
225- context ,
226- state .tempKcur ,
227- state .wrapK ,
228- offset ,
229- nEmbdHead ,
230- config .rmsNormEps (),
231- state .localSize )
232- // maybe in CPU ?
233- .task ("reductionFinalNormalization" + "_Kcur_" + i ,
234- TransformerComputeKernelsLayered ::reductionFinalNormalization ,
235- context ,
236- state .tempKcur ,
237- nEmbdHead ,
238- config .rmsNormEps ())
239- .task ("mapContext" + "_Kcur_" + i ,
240- Qwen3Kernels ::mapIndexInPlace ,
241- context ,
242- state .wrapK ,
243- weights .rms_att_KNormLayered [layerIndex ],
244- offset ,
245- nEmbdHead ,
246- state .tempKcur );
247- }
220+ //rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
221+ unifiedLayer
222+ .task ("rmsnormReduction_Kcur" ,
223+ Qwen3Kernels ::rmsnormReductionWithOffset ,
224+ context ,
225+ state .tempKcur , // output
226+ state .wrapK , // input
227+ state .localSize ) // currently 128, should be variable of global nEmbdHead
228+ .task ("rmsnormFinalNormalization_Kcur" ,
229+ Qwen3Kernels ::rmsnormFinalNormalizationWithParallelOffset ,
230+ context ,
231+ state .tempKcur , // output
232+ config .numberOfKeyValueHeads (),
233+ nEmbdHead ,
234+ config .rmsNormEps ())
235+ .task ("rmsnormMapIndexInPlace_Kcur" ,
236+ Qwen3Kernels ::rmsnormMapIndexInPlaceWithParallelOffset ,
237+ context ,
238+ state .wrapK , // output
239+ weights .rms_att_KNormLayered [layerIndex ],
240+ nEmbdHead ,
241+ state .tempKcur );
248242 // dbg copy out
249243 //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
250244 //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
@@ -445,6 +439,7 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
445439 curWorker .setGlobalWork (nEmbdHead , 1 , 1 ); // Set global work size to total dimension
446440 curWorker .setLocalWork (128 , 1 , 1 ); // Set local work size to 256 (standard efficient size)
447441
442+ // Qcur
448443 // config.numberOfHeads() = 16
449444 // nEmbdHead = 128
450445 // total = 2048
@@ -454,6 +449,16 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
454449 WorkerGrid qCurWorker2 = new WorkerGrid1D (config .numberOfHeads ());
455450 qCurWorker2 .setLocalWork (1 , 1 , 1 );
456451
452+ // Kcur
453+ // config.numberOfKeyValueHeads() = 8
454+ // nEmbdHead = 128
455+ // total = 1024
456+ WorkerGrid kCurWorker = new WorkerGrid1D (config .numberOfKeyValueHeads () * nEmbdHead );
457+ kCurWorker .setLocalWork (nEmbdHead , 1 , 1 );
458+
459+ WorkerGrid kCurWorker2 = new WorkerGrid1D (config .numberOfKeyValueHeads ());
460+ kCurWorker2 .setLocalWork (1 , 1 , 1 );
461+
457462 int h = config .numberOfHeads ();
458463 int ic = nEmbdHead / 2 ;
459464 WorkerGrid ropeWorker = new WorkerGrid2D (h , ic );
@@ -493,15 +498,15 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
493498 gridScheduler .addWorkerGrid ("layer_" + i + ".kmatmul" , matmulKVRowMajorWorker );
494499 gridScheduler .addWorkerGrid ("layer_" + i + ".vmatmul" , matmulKVRowMajorWorker );
495500
496- gridScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlock_Qcur" , qCurWorker );
497- gridScheduler .addWorkerGrid ("layer_" + i + ".reductionFinalNormalization_Qcur" , qCurWorker2 );
498- gridScheduler .addWorkerGrid ("layer_" + i + ".mapContext_Qcur" , qCurWorker );
501+ // Qcur
502+ gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormReduction_Qcur" , qCurWorker );
503+ gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormFinalNormalization_Qcur" , qCurWorker2 );
504+ gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormMapIndexInPlace_Qcur" , qCurWorker );
499505
500- for (int j = 0 ; j < config .numberOfKeyValueHeads (); j ++) {
501- gridScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlock" + "_Kcur_" + j , curWorker );
502- //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization" + "_Kcur_" + j, curWorker);
503- gridScheduler .addWorkerGrid ("layer_" + i + ".mapContext" + "_Kcur_" + j , curWorker );
504- }
506+ // Kcur
507+ gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormReduction_Kcur" , kCurWorker );
508+ gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormFinalNormalization_Kcur" , kCurWorker2 );
509+ gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormMapIndexInPlace_Kcur" , kCurWorker );
505510
506511 gridScheduler .addWorkerGrid ("layer_" + i + ".ropeRotation" , ropeWorker );
507512 gridScheduler .addWorkerGrid ("layer_" + i + ".copyToCaches" , copyToCachesWorker );
0 commit comments