Skip to content

Commit 1306591

Browse files
Apply optimizations to Kcur rmsnorm and rename some Qcur fields
1 parent 40733d0 commit 1306591

File tree

2 files changed

+45
-40
lines changed

2 files changed

+45
-40
lines changed

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public Qwen3State(Configuration config, int batchsize) {
3232
int nEmbdHead = qwen3config.numberOfHeads();
3333
this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15);
3434
this.tempQcur = new FloatArray(nEmbdHead);
35-
this.tempKcur = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
35+
this.tempKcur = new FloatArray(nEmbdHead);
3636

3737
// dbg buffers
3838
int nHeadKv = qwen3config.numberOfKeyValueHeads();

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
186186
// Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
187187
//rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
188188
unifiedLayer
189-
.task("reductionsOneBlock_Qcur",
189+
.task("rmsnormReduction_Qcur",
190190
Qwen3Kernels::rmsnormReductionWithOffset,
191191
context,
192192
state.tempQcur, // output
193193
state.wrapQ, // input
194194
state.localSize) // currently 128; should be a variable derived from the global nEmbdHead
195-
.task("reductionFinalNormalization_Qcur",
195+
.task("rmsnormFinalNormalization_Qcur",
196196
Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
197197
context,
198198
state.tempQcur, // output
199199
config.numberOfHeads(),
200200
nEmbdHead,
201201
config.rmsNormEps())
202-
.task("mapContext_Qcur",
202+
.task("rmsnormMapIndexInPlace_Qcur",
203203
Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
204204
context,
205205
state.wrapQ, // output
@@ -217,34 +217,28 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
217217
// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
218218
//
219219
// Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
220-
for (int i = 0; i < config.numberOfKeyValueHeads(); i++) {
221-
//rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
222-
int offset = i * nEmbdHead;
223-
unifiedLayer.task("reductionsOneBlock" + "_Kcur_" + i ,
224-
Qwen3Kernels::reductionOneBlockWithLayerWithOffset,
225-
context,
226-
state.tempKcur,
227-
state.wrapK,
228-
offset,
229-
nEmbdHead,
230-
config.rmsNormEps(),
231-
state.localSize)
232-
// maybe in CPU ?
233-
.task("reductionFinalNormalization" + "_Kcur_" + i ,
234-
TransformerComputeKernelsLayered::reductionFinalNormalization,
235-
context,
236-
state.tempKcur,
237-
nEmbdHead,
238-
config.rmsNormEps())
239-
.task("mapContext" + "_Kcur_" + i,
240-
Qwen3Kernels::mapIndexInPlace,
241-
context,
242-
state.wrapK,
243-
weights.rms_att_KNormLayered[layerIndex],
244-
offset,
245-
nEmbdHead,
246-
state.tempKcur);
247-
}
220+
//rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
221+
unifiedLayer
222+
.task("rmsnormReduction_Kcur",
223+
Qwen3Kernels::rmsnormReductionWithOffset,
224+
context,
225+
state.tempKcur, // output
226+
state.wrapK, // input
227+
state.localSize) // currently 128; should be a variable derived from the global nEmbdHead
228+
.task("rmsnormFinalNormalization_Kcur",
229+
Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
230+
context,
231+
state.tempKcur, // output
232+
config.numberOfKeyValueHeads(),
233+
nEmbdHead,
234+
config.rmsNormEps())
235+
.task("rmsnormMapIndexInPlace_Kcur",
236+
Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
237+
context,
238+
state.wrapK, // output
239+
weights.rms_att_KNormLayered[layerIndex],
240+
nEmbdHead,
241+
state.tempKcur);
248242
// dbg copy out
249243
//unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
250244
//unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
@@ -445,6 +439,7 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
445439
curWorker.setGlobalWork(nEmbdHead, 1, 1); // Set global work size to total dimension
446440
curWorker.setLocalWork(128, 1, 1); // Set local work size to 128 (standard efficient size)
447441

442+
// Qcur
448443
// config.numberOfHeads() = 16
449444
// nEmbdHead = 128
450445
// total = 2048
@@ -454,6 +449,16 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
454449
WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
455450
qCurWorker2.setLocalWork(1, 1, 1);
456451

452+
// Kcur
453+
// config.numberOfKeyValueHeads() = 8
454+
// nEmbdHead = 128
455+
// total = 1024
456+
WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead);
457+
kCurWorker.setLocalWork(nEmbdHead, 1, 1);
458+
459+
WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
460+
kCurWorker2.setLocalWork(1, 1, 1);
461+
457462
int h = config.numberOfHeads();
458463
int ic = nEmbdHead / 2;
459464
WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
@@ -493,15 +498,15 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
493498
gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker);
494499
gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker);
495500

496-
gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock_Qcur", qCurWorker);
497-
gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization_Qcur", qCurWorker2);
498-
gridScheduler.addWorkerGrid("layer_" + i + ".mapContext_Qcur", qCurWorker);
501+
// Qcur
502+
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
503+
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
504+
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
499505

500-
for (int j = 0; j < config.numberOfKeyValueHeads(); j++) {
501-
gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock" + "_Kcur_" + j, curWorker);
502-
//gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization" + "_Kcur_" + j, curWorker);
503-
gridScheduler.addWorkerGrid("layer_" + i + ".mapContext" + "_Kcur_" + j, curWorker);
504-
}
506+
// Kcur
507+
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
508+
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
509+
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
505510

506511
gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
507512
gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);

0 commit comments

Comments
 (0)