@@ -2,7 +2,6 @@
 
 import com.example.auxiliary.Tuple2;
 import com.example.inference.state.Qwen3State;
-import com.example.inference.state.State;
 import com.example.inference.weights.tornado.Qwen3TornadoWeights;
 import com.example.model.Model;
 import com.example.model.qwen3.Qwen3Configuration;
@@ -109,8 +108,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         config.dim(),
                         config.rmsNormEps(),
                         state.localSize)
-                //.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context,
-                //state.temp, config.dim(), config.rmsNormEps())
                 .task("mapContext",
                         TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
                         context,
@@ -119,16 +116,9 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         weights.rms_att_weightLayered[layerIndex],
                         state.temp);
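
The reductionsOneBlock/mapContext pair above computes RMSNorm in two phases: a block-wise reduction of squared activations into state.temp, then a map that rescales each element by the inverse RMS and the per-layer weight. A minimal single-threaded sketch of the math these kernels preserve (plain arrays and names are illustrative, not the kernel API):

```java
// Reference semantics for the two-phase RMSNorm (sketch only; the device
// kernels split the reduction across workgroups of state.localSize threads).
final class RmsNormSketch {
    static void rmsNorm(float[] out, float[] x, float[] weight, int dim, float eps) {
        // Phase 1 ("reductionsOneBlock"): sum of squares.
        float ss = 0f;
        for (int i = 0; i < dim; i++) {
            ss += x[i] * x[i];
        }
        // Phase 2 ("mapContext"): scale by inverse RMS and the layer weight.
        float invRms = (float) (1.0 / Math.sqrt(ss / dim + eps));
        for (int i = 0; i < dim; i++) {
            out[i] = weight[i] * (invRms * x[i]);
        }
    }

    public static void main(String[] args) {
        float[] x = { 1f, 2f, 3f, 4f };
        float[] w = { 1f, 1f, 1f, 1f };
        float[] out = new float[x.length];
        rmsNorm(out, x, w, x.length, 1e-6f);
        System.out.println(java.util.Arrays.toString(out));
    }
}
```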
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
-
-        // // dbg copy out
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.temp);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
-
         int qDim0 = nEmbdHeadK * config.numberOfHeads();
         int kvDim0 = nEmbdGqa;
         int qkvDim1 = config.dim();
-        //qkvMatmuls = new TaskGraph("qkvMatmuls_layer_" + layerIndex);
         unifiedLayer.task("qmatmul",
                 TransformerComputeKernelsLayered::matrixVectorGeneric,
                 context,
@@ -157,11 +147,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 kvDim0,
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
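
qmatmul/kmatmul/vmatmul are plain matrix-vector products whose shapes encode grouped-query attention: Q projects dim to nEmbdHeadK * numberOfHeads, while K and V project to the narrower nEmbdGqa. A sketch of the shape arithmetic, using the figures quoted in comments elsewhere in this diff (16 heads, 8 KV heads, 128 per head, dim = 1024); the concrete values are illustrative:

```java
// Shape arithmetic behind qDim0/kvDim0/qkvDim1 (illustrative values).
final class QkvShapes {
    static float[] matVec(float[][] w, float[] x) {
        float[] out = new float[w.length];
        for (int r = 0; r < w.length; r++) {
            float acc = 0f;
            for (int c = 0; c < x.length; c++) {
                acc += w[r][c] * x[c];
            }
            out[r] = acc;
        }
        return out;
    }

    public static void main(String[] args) {
        int dim = 1024;                      // qkvDim1 = config.dim()
        int nHeads = 16, nKvHeads = 8, nEmbdHeadK = 128;
        int qDim0 = nEmbdHeadK * nHeads;     // 2048 output rows for Q
        int kvDim0 = nEmbdHeadK * nKvHeads;  // nEmbdGqa = 1024 rows for K and V

        float[] q = matVec(new float[qDim0][dim], new float[dim]);
        float[] k = matVec(new float[kvDim0][dim], new float[dim]);
        System.out.println(q.length + " / " + k.length); // 2048 / 1024
    }
}
```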
 
-        // dbg copy out
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
-
         // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
         //rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
         unifiedLayer
@@ -173,23 +158,14 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.localSize, // currently 128; should be derived from nEmbdHead
                         nEmbdHead, // for normalization
                         config.rmsNormEps()) // for normalization
-                // .task("rmsnormFinalNormalization_Qcur",
-                //         Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
-                //         context,
-                //         state.tempQcur, // output
-                //         config.numberOfHeads(),
-                //         nEmbdHead,
-                //         config.rmsNormEps())
                 .task("rmsnormMapIndexInPlace_Qcur",
                         Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
                         context,
                         state.wrapQ, // output
                         weights.rms_att_QNormLayered[layerIndex],
                         nEmbdHead,
                         state.tempQcur);
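
The rmsnormReduction_Qcur/rmsnormMapIndexInPlace_Qcur pair applies RMSNorm per head, at offset h * nEmbdHead, as the commented-out reference line above describes; the Kcur chain below is the same with numberOfKeyValueHeads. A scalar sketch of that per-head loop (illustrative names, not the kernel signature):

```java
// Per-head RMSNorm over the projected Q vector (scalar reference; on-device,
// each head maps to one workgroup of nEmbdHead threads, see qCurWorker below).
final class PerHeadNorm {
    static void rmsNormPerHead(float[] q, float[] normWeight, int nHeads, int nEmbdHead, float eps) {
        for (int h = 0; h < nHeads; h++) {
            int off = h * nEmbdHead;
            float ss = 0f;
            for (int i = 0; i < nEmbdHead; i++) {
                ss += q[off + i] * q[off + i];
            }
            float invRms = (float) (1.0 / Math.sqrt(ss / nEmbdHead + eps));
            for (int i = 0; i < nEmbdHead; i++) {
                q[off + i] = normWeight[i] * (invRms * q[off + i]);
            }
        }
    }
}
```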
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        //
+
         // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
         //rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
         unifiedLayer
@@ -201,24 +177,13 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.localSize, // currently 128; should be derived from nEmbdHead
                         nEmbdHead, // for normalization
                         config.rmsNormEps()) // for normalization
-                // .task("rmsnormFinalNormalization_Kcur",
-                //         Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
-                //         context,
-                //         state.tempKcur, // output
-                //         config.numberOfKeyValueHeads(),
-                //         nEmbdHead,
-                //         config.rmsNormEps())
                 .task("rmsnormMapIndexInPlace_Kcur",
                         Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
                         context,
                         state.wrapK, // output
                         weights.rms_att_KNormLayered[layerIndex],
                         nEmbdHead,
                         state.tempKcur);
-        // dbg copy out
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
 
         // rope rotation task graph
         unifiedLayer.task("ropeRotation",
@@ -230,10 +195,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.numberOfKeyValueHeads(),
                 nEmbdHead);
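
ropeRotation applies rotary position embeddings to Q and K; the 2D rope grid below assigns one thread per (head, pair) coordinate, with the second axis running over nEmbdHead / 2 rotated pairs. A reference sketch of the rotation for one head, assuming a GPT-NeoX-style half-split pairing and a ropeTheta base from the model config (the kernel's exact pair layout may differ):

```java
// RoPE over one head (reference math; half-split pairing is an assumption).
final class RopeSketch {
    static void ropeOneHead(float[] v, int headOffset, int nEmbdHead, int position, double ropeTheta) {
        int half = nEmbdHead / 2;
        for (int ic = 0; ic < half; ic++) {            // matches the grid's second axis
            double freq = 1.0 / Math.pow(ropeTheta, (2.0 * ic) / nEmbdHead);
            double angle = position * freq;
            float cos = (float) Math.cos(angle);
            float sin = (float) Math.sin(angle);
            float x0 = v[headOffset + ic];
            float x1 = v[headOffset + ic + half];
            v[headOffset + ic]        = x0 * cos - x1 * sin; // rotate the pair
            v[headOffset + ic + half] = x0 * sin + x1 * cos;
        }
    }
}
```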
 
-        // dbg copy out
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-
         unifiedLayer.task("copyToCaches",
                 TransformerComputeKernelsLayered::copyToCache,
                 state.wrapKeyCache, // out
@@ -245,7 +206,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 layerIndex,
                 config.contextLength());
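
copyToCaches appends the current token's K/V rows into this layer's slice of the caches. A sketch of the index arithmetic, assuming a flat [layer][position][nEmbdGqa] layout (consistent with the layerIndex and contextLength() arguments, but the actual layout is whatever the kernel implements):

```java
// Cache write for one position under an assumed flat layout.
final class CacheSketch {
    static void copyToCache(float[] cache, float[] kv, int layerIndex, int position,
                            int contextLength, int nEmbdGqa) {
        int dst = (layerIndex * contextLength + position) * nEmbdGqa;
        System.arraycopy(kv, 0, cache, dst, nEmbdGqa);
    }
}
```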
 
-        // global size = numberOfHeads * 8 = 16 * 8 = 128
         unifiedLayer.task("parallel-attention",
                 TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
                 context,
@@ -261,7 +221,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 layerIndex,
                 config.contextLength());
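
processHeadsFlashAttentionOpt runs one 32-lane workgroup per head (see parallelAttentionWorker below) and keeps the softmax online, flash-attention style. The math it must preserve is ordinary causal scaled-dot-product attention; here is an unoptimized single-head reference over one layer's cache slice, with the usual GQA mapping kvHead = head / (nHeads / nKvHeads) assumed:

```java
// Unoptimized causal attention for one head (reference only; the device
// kernel tiles this and computes the softmax online without a scores array).
final class AttentionRef {
    static void attendOneHead(float[] out, float[] q, float[] keyCache, float[] valueCache,
                              int head, int kvHead, int nEmbdHead, int nEmbdGqa, int pos) {
        int qOff = head * nEmbdHead;
        float scale = (float) (1.0 / Math.sqrt(nEmbdHead));

        // Scores against all cached keys up to the current position (causal mask).
        float[] scores = new float[pos + 1];
        float max = Float.NEGATIVE_INFINITY;
        for (int t = 0; t <= pos; t++) {
            int kOff = t * nEmbdGqa + kvHead * nEmbdHead;
            float s = 0f;
            for (int i = 0; i < nEmbdHead; i++) {
                s += q[qOff + i] * keyCache[kOff + i];
            }
            scores[t] = s * scale;
            max = Math.max(max, scores[t]);
        }

        // Softmax, then weighted sum of cached values.
        float sum = 0f;
        for (int t = 0; t <= pos; t++) {
            scores[t] = (float) Math.exp(scores[t] - max);
            sum += scores[t];
        }
        for (int i = 0; i < nEmbdHead; i++) {
            float acc = 0f;
            for (int t = 0; t <= pos; t++) {
                int vOff = t * nEmbdGqa + kvHead * nEmbdHead;
                acc += (scores[t] / sum) * valueCache[vOff + i];
            }
            out[qOff + i] = acc;
        }
    }
}
```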
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
         unifiedLayer.task("matmul1", Qwen3Kernels::matrixVectorGenericWithResidual,
                 context,
                 state.wrapXb, // vector
@@ -271,7 +230,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.dim(), // dim0 = 1024
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX);
         unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                 context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
@@ -283,7 +241,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
                 state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                //.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
                 .persistOnDevice(
                         state.wrapX
                 );
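
fused_ffn_w1_w3 evaluates both FFN input projections in one pass and gates them, and projectionTwo projects back down with a residual add: the usual SwiGLU block, x += w2 · (silu(w1 · xb) ⊙ (w3 · xb)). A dense reference sketch of that math (plain arrays stand in for the device buffers):

```java
// SwiGLU FFN reference: hb = silu(w1·xb) ⊙ (w3·xb); x += w2·hb (residual).
final class FfnSketch {
    static void ffnSwiGlu(float[] x, float[] xb, float[][] w1, float[][] w2, float[][] w3,
                          int dim, int hiddenDim) {
        float[] hb = new float[hiddenDim];
        for (int r = 0; r < hiddenDim; r++) {
            float a = dot(w1[r], xb, dim);
            float b = dot(w3[r], xb, dim);
            float silu = a / (1f + (float) Math.exp(-a)); // SiLU gate
            hb[r] = silu * b;
        }
        for (int r = 0; r < dim; r++) {
            x[r] += dot(w2[r], hb, hiddenDim);            // residual add
        }
    }

    static float dot(float[] w, float[] v, int n) {
        float s = 0f;
        for (int i = 0; i < n; i++) {
            s += w[i] * v[i];
        }
        return s;
    }
}
```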
@@ -295,14 +252,12 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
                         state.wrapX
                 )
-                //.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION,
                         state.tempLogits,
                         state.wrapLogits
                 )
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION,
                         context,
-                        //state.wrapLogits,
                         weights.wclsHalfFloat,
                         weights.rms_final_weight_as_floatArray
                 )
@@ -313,13 +268,8 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         config.dim(),
                         config.rmsNormEps(),
                         state.localSize)
-                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
-                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
-                // .task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits,
-                //         config.dim(), config.rmsNormEps())
                 .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
                         weights.rms_final_weight_as_floatArray, state.tempLogits);
-                //.transferToHost(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
         logits = configureQuantizedMatrixVectorFinalWeight(logits);
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
         taskGraphs.add(logits.snapshot());
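
The logits graph shows the data-movement policy in miniature: wrapX is consumed on-device from the last layer graph, per-token buffers move with EVERY_EXECUTION, and the classifier weights ride up once with FIRST_EXECUTION. A self-contained example of that split, assuming the TornadoVM 1.x API (buffer sizes and the task itself are illustrative):

```java
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

// Transfer-mode split in isolation: weights upload once, x moves per call.
public class TransferModeSketch {

    public static void scale(FloatArray weights, FloatArray x, FloatArray out) {
        for (@Parallel int i = 0; i < out.getSize(); i++) {
            out.set(i, weights.get(i) * x.get(i));
        }
    }

    public static void main(String[] args) throws Exception {
        FloatArray weights = new FloatArray(1024); // constant across tokens
        FloatArray x = new FloatArray(1024);       // refreshed every token
        FloatArray out = new FloatArray(1024);
        weights.init(0.5f);
        x.init(2.0f);

        TaskGraph tg = new TaskGraph("logitsSketch")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, weights)
                .transferToDevice(DataTransferMode.EVERY_EXECUTION, x)
                .task("scale", TransferModeSketch::scale, weights, x, out)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

        ImmutableTaskGraph itg = tg.snapshot();
        try (TornadoExecutionPlan plan = new TornadoExecutionPlan(itg)) {
            plan.execute(); // first call: weights and x both uploaded
            plan.execute(); // later calls: only x crosses the bus
        }
    }
}
```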
@@ -357,25 +307,13 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
         curWorker.setLocalWork(128, 1, 1); // Set local work size to 128 (standard efficient size)
 
         // Qcur
-        // config.numberOfHeads() = 16
-        // nEmbdHead = 128
-        // total = 2048
         WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead);
         qCurWorker.setLocalWork(nEmbdHead, 1, 1);
 
-        // WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
-        // qCurWorker2.setLocalWork(1, 1, 1);
-
         // Kcur
-        // config.numberOfKeyValueHeads() = 8
-        // nEmbdHead = 128
-        // total = 1024
         WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead);
         kCurWorker.setLocalWork(nEmbdHead, 1, 1);
 
-        // WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
-        // kCurWorker2.setLocalWork(1, 1, 1);
-
         int h = config.numberOfHeads();
         int ic = nEmbdHead / 2;
         WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
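
The sizing rule for the per-head norm grids: one workgroup of nEmbdHead threads per head, so each head's RMSNorm reduction stays inside a single workgroup; the rope grid is 2D with one thread per (head, pair) coordinate. With the figures from the deleted comments (16 heads, 8 KV heads, nEmbdHead = 128), that is 2048 and 1024 global threads respectively. A standalone recap of the arithmetic (illustrative values):

```java
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.WorkerGrid2D;

// Grid sizing recap (values from the deleted comments; illustrative only).
public class GridSizingSketch {
    public static void main(String[] args) {
        int nHeads = 16, nKvHeads = 8, nEmbdHead = 128;

        WorkerGrid qCur = new WorkerGrid1D(nHeads * nEmbdHead);   // global 2048
        qCur.setLocalWork(nEmbdHead, 1, 1);                       // one group per head

        WorkerGrid kCur = new WorkerGrid1D(nKvHeads * nEmbdHead); // global 1024
        kCur.setLocalWork(nEmbdHead, 1, 1);

        // One thread per (head, rotated pair): second axis is nEmbdHead / 2.
        WorkerGrid rope = new WorkerGrid2D(nHeads, nEmbdHead / 2);
        System.out.println(java.util.Arrays.toString(rope.getGlobalWork()));
    }
}
```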
@@ -384,13 +322,12 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 
         WorkerGrid copyToCachesWorker = new WorkerGrid1D(nEmbdGqa);
         copyToCachesWorker.setGlobalWork(nEmbdGqa, 1, 1);
-        copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches)
+        copyToCachesWorker.setLocalWork(128, 1, 1);
 
         // Parallel attention worker configuration
-        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); // qwen ok
-        // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4
+        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
         parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 32, 1, 1);
-        parallelAttentionWorker.setLocalWork(32, 1, 1); // Set local work size to 4 (for parallel attention)
+        parallelAttentionWorker.setLocalWork(32, 1, 1);
 
         int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
         WorkerGrid matmul1Worker = new WorkerGrid1D(matmul1Global);
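
matmul1Worker (and the FFN workers that follow) use the row-per-workgroup convention the matrixVectorGeneric kernels expect: global work = output rows × LOCAL_WORK_GROUP_SIZE_ALLOC, local work = LOCAL_WORK_GROUP_SIZE_ALLOC, so each workgroup cooperatively reduces one dot product. A small helper capturing the pattern (hypothetical name, same arithmetic as above):

```java
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

// Row-per-workgroup sizing: each of `rows` outputs is reduced by one
// workgroup of `localSize` lanes (hypothetical helper).
public final class MatVecGrids {
    static WorkerGrid rowPerGroupGrid(int rows, int localSize) {
        WorkerGrid grid = new WorkerGrid1D(rows * localSize);
        grid.setLocalWork(localSize, 1, 1);
        return grid; // e.g. rows = config.dim(), localSize = LOCAL_WORK_GROUP_SIZE_ALLOC
    }
}
```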
@@ -408,7 +345,6 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
         gridScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
         for (int i = 0; i < config.numberOfLayers(); i++) {
             gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
 
             gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker);
@@ -417,20 +353,17 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 
             // Qcur
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
 
             // Kcur
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
 
             gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
             gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalizationFFN", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker);
             gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);