@@ -75,7 +75,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
7575 state .tempFFN .init (0.0f );
7676 state .tempLogits .init (0.0f );
7777 state .wrapLogits .init (0.0f );
78- state .tempQcur .init (0.0f );
7978 state .tempKcur .init (0.0f );
8079
8180// state.dbgQ.init(0.0f);
@@ -185,28 +184,28 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
185184// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
186185
187186 // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
188- for ( int i = 0 ; i < config .numberOfHeads (); i ++) {
189- //rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
190- int offset = i * nEmbdHead ;
191- unifiedLayer . task ( "reductionsOneBlock" + "_Qcur_" + i ,
192- Qwen3Kernels :: reductionOneBlockWithLayerWithOffset ,
193- context ,
194- state .tempQcur , // output
195- state .wrapQ , // input
196- offset , nEmbdHead ,
197- config . rmsNormEps (), state . localSize )
198- . task ( "reductionFinalNormalization" + "_Qcur_" + i ,
199- TransformerComputeKernelsLayered :: reductionFinalNormalization , context ,
200- state . tempQcur , // output
201- nEmbdHead ,
202- config .rmsNormEps ())
203- .task ("mapContext" + "_Qcur_" + i ,
204- Qwen3Kernels ::mapIndexInPlace , context ,
205- state . wrapQ , // output
206- weights . rms_att_QNormLayered [ layerIndex ],
207- offset , nEmbdHead ,
208- state . tempQcur );
209- }
187+ //rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
188+ unifiedLayer
189+ . task ( "reductionsOneBlock_Qcur" ,
190+ Qwen3Kernels :: rmsnormReductionWithOffset ,
191+ context ,
192+ state . tempQcur , // output
193+ state .wrapQ , // input
194+ state .localSize ) // currently 128; should be a variable derived from nEmbdHead
195+ . task ( "reductionFinalNormalization_Qcur" ,
196+ Qwen3Kernels :: rmsnormFinalNormalizationWithParallelOffset ,
197+ context ,
198+ state . tempQcur , // output
199+ config . numberOfHeads (),
200+ nEmbdHead ,
201+ config .rmsNormEps ())
202+ .task ("mapContext_Qcur" ,
203+ Qwen3Kernels ::rmsnormMapIndexInPlaceWithParallelOffset ,
204+ context ,
205+ state . wrapQ , // output
206+ weights . rms_att_QNormLayered [ layerIndex ] ,
207+ nEmbdHead ,
208+ state . tempQcur );
210209
211210// unifiedLayer.task("dbg_copy_out_wrapQ",
212211// Qwen3Kernels::dbgCopy,
@@ -446,6 +445,15 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
446445 curWorker .setGlobalWork (nEmbdHead , 1 , 1 ); // Set global work size to total dimension
447446 curWorker .setLocalWork (128 , 1 , 1 ); // Set local work size to 128 (standard efficient size)
448447
448+ // config.numberOfHeads() = 16
449+ // nEmbdHead = 128
450+ // total = 2048
451+ WorkerGrid qCurWorker = new WorkerGrid1D (config .numberOfHeads () * nEmbdHead );
452+ qCurWorker .setLocalWork (nEmbdHead , 1 , 1 );
453+
454+ WorkerGrid qCurWorker2 = new WorkerGrid1D (config .numberOfHeads ());
455+ qCurWorker2 .setLocalWork (1 , 1 , 1 );
456+
449457 int h = config .numberOfHeads ();
450458 int ic = nEmbdHead / 2 ;
451459 WorkerGrid ropeWorker = new WorkerGrid2D (h , ic );
@@ -485,24 +493,10 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
485493 gridScheduler .addWorkerGrid ("layer_" + i + ".kmatmul" , matmulKVRowMajorWorker );
486494 gridScheduler .addWorkerGrid ("layer_" + i + ".vmatmul" , matmulKVRowMajorWorker );
487495
488- // //int size = nEmbdHead;
489- // for (int j = 0; j < config.numberOfHeads(); j++) {
490- //// int offset = j * nEmbdHead;
491- //// WorkerGrid qRmsReductionWorker = new WorkerGrid1D(size);
492- //// qRmsReductionWorker.setLocalWork(state.localSize, 1, 1);
493- // gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock" + "_Qcur_" + j, curWorker);
494- // //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization" + "_Qcur_" + j, curWorker);
495- // gridScheduler.addWorkerGrid("layer_" + i + ".mapContext" + "_Qcur_" + j, curWorker);
496- // }
497- // Create separate WorkerGrid for each head
498- for (int j = 0 ; j < config .numberOfHeads (); j ++) {
499- WorkerGrid headWorker = new WorkerGrid1D (nEmbdHead ); // nEmbdHead = 128
500- headWorker .setGlobalWork (nEmbdHead , 1 , 1 ); // Set global work size to total dimension
501- headWorker .setLocalWork (128 , 1 , 1 );
502-
503- gridScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlock" + "_Qcur_" + j , headWorker );
504- gridScheduler .addWorkerGrid ("layer_" + i + ".mapContext" + "_Qcur_" + j , headWorker );
505- }
496+ gridScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlock_Qcur" , qCurWorker );
497+ gridScheduler .addWorkerGrid ("layer_" + i + ".reductionFinalNormalization_Qcur" , qCurWorker2 );
498+ gridScheduler .addWorkerGrid ("layer_" + i + ".mapContext_Qcur" , qCurWorker );
499+
506500 for (int j = 0 ; j < config .numberOfKeyValueHeads (); j ++) {
507501 gridScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlock" + "_Kcur_" + j , curWorker );
508502 //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization" + "_Kcur_" + j, curWorker);
0 commit comments