Skip to content

Commit 40733d0

Browse files
Optimize Qcur rmsnorm
1 parent 7d4b30f commit 40733d0

File tree

3 files changed

+138
-43
lines changed

3 files changed

+138
-43
lines changed

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ public final class Qwen3State extends State {
2828
public Qwen3State(Configuration config, int batchsize) {
2929
super(config, batchsize);
3030
// Initialize Qwen3-specific field
31+
Qwen3Configuration qwen3config = (Qwen3Configuration) config;
32+
int nEmbdHead = qwen3config.numberOfHeads();
3133
this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15);
32-
this.tempQcur = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
34+
this.tempQcur = new FloatArray(nEmbdHead);
3335
this.tempKcur = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
3436

3537
// dbg buffers
36-
Qwen3Configuration qwen3config = (Qwen3Configuration) config;
3738
int nHeadKv = qwen3config.numberOfKeyValueHeads();
3839
int nEmbdHeadK = qwen3config.numberOfHeadsKey();
3940
int nEmbdKGqa = nEmbdHeadK * nHeadKv;

src/main/java/com/example/tornadovm/Qwen3Kernels.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,106 @@ public static void dbgCopy(FloatArray srcBuffer, FloatArray dstBuffer, IntArray
1919
//}
2020
}
2121

22+
public static void rmsnormReductionWithOffset(
23+
KernelContext context,
24+
FloatArray output,
25+
FloatArray x,
26+
int localMemSize) {
27+
28+
// global size: 0 - (config.numberOfHeads() * nEmbdHead)
29+
// local size : 0 - nEmbdHead
30+
int gid = context.globalIdx;
31+
int lid = context.localIdx;
32+
int groupId = context.groupIdx;
33+
int groupSize = context.localGroupSizeX;
34+
35+
// Allocate local memory with the provided size
36+
float[] localX = context.allocateFloatLocalArray(localMemSize);
37+
38+
// Load input value and compute square
39+
//int globalReadIndex = gid + offset;
40+
//if (gid < size && globalReadIndex < x.getSize()) {
41+
localX[lid] = x.get(gid);
42+
localX[lid] = localX[lid] * localX[lid];
43+
//} else {
44+
// localX[lid] = 0.0f;
45+
//}
46+
47+
// Perform parallel reduction within the work group
48+
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
49+
context.localBarrier();
50+
if (lid < stride) {
51+
localX[lid] += localX[lid + stride];
52+
}
53+
}
54+
55+
// Each workgroup stores its partial sum in a different location
56+
if (lid == 0) {
57+
// Store the partial sum from each workgroup
58+
output.set(groupId, localX[0]);
59+
}
60+
}
61+
62+
// Second kernel - Combines partial sums and computes final normalization
63+
public static void rmsnormFinalNormalizationWithParallelOffset(
64+
KernelContext context,
65+
FloatArray output, // size should be related to offsetIndex
66+
int offsetIndex, // = config.numberOfHeads()
67+
int size,
68+
float ermsNorm) {
69+
70+
int gid = context.globalIdx;
71+
72+
// Only the index threads need to perform this calculation
73+
if (gid < offsetIndex) {
74+
// Combine partial sums from all workgroups
75+
float ss = 0.0f;
76+
//for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
77+
// for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
78+
// ss += output.get(i);
79+
// }
80+
ss = output.get(gid);
81+
82+
ss /= size;
83+
ss += ermsNorm;
84+
ss = 1.0f / TornadoMath.sqrt(ss);
85+
// in place
86+
output.set(gid, ss); // Store the final scale factor
87+
}
88+
}
89+
90+
public static void rmsnormMapIndexInPlaceWithParallelOffset(
91+
KernelContext context,
92+
FloatArray out, // Q
93+
FloatArray weights,
94+
int size,
95+
FloatArray ss // tempQcur1
96+
) {
97+
98+
int gid = context.globalIdx; // 0 - size
99+
//int index = offset + gid;
100+
int groupId = context.groupIdx;
101+
102+
float finalss = ss.get(groupId);
103+
//out.set(index, weights.get(index % size) * (finalss * x.get(index)));
104+
//out.set(index, weights.get(index) * (finalss * x.get(index)));
105+
//if (index < offset + size) {
106+
if (gid < out.getSize()) { // TODO: check if redundant
107+
float a = weights.get(gid % size);
108+
float b = finalss * out.get(gid);
109+
out.set(gid, a * b);
110+
}
111+
112+
//old gid, index:
113+
// int gid = context.globalIdx; // 0 - size
114+
// int index = offset + gid;
115+
// context.globalBarrier();
116+
// // reset ss
117+
// if (gid < ss.getSize()) {
118+
// ss.set(gid, 0.0f);
119+
// }
120+
}
121+
22122
public static void reductionOneBlockWithLayerWithOffset(
23123
KernelContext context,
24124
FloatArray output,

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
7575
state.tempFFN.init(0.0f);
7676
state.tempLogits.init(0.0f);
7777
state.wrapLogits.init(0.0f);
78-
state.tempQcur.init(0.0f);
7978
state.tempKcur.init(0.0f);
8079

8180
// state.dbgQ.init(0.0f);
@@ -185,28 +184,28 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
185184
// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
186185

187186
// Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
188-
for (int i = 0; i < config.numberOfHeads(); i++) {
189-
//rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
190-
int offset = i * nEmbdHead;
191-
unifiedLayer.task("reductionsOneBlock" + "_Qcur_" + i ,
192-
Qwen3Kernels::reductionOneBlockWithLayerWithOffset,
193-
context,
194-
state.tempQcur, // output
195-
state.wrapQ, // input
196-
offset, nEmbdHead,
197-
config.rmsNormEps(), state.localSize)
198-
.task("reductionFinalNormalization" + "_Qcur_" + i ,
199-
TransformerComputeKernelsLayered::reductionFinalNormalization, context,
200-
state.tempQcur, // output
201-
nEmbdHead,
202-
config.rmsNormEps())
203-
.task("mapContext" + "_Qcur_" + i,
204-
Qwen3Kernels::mapIndexInPlace, context,
205-
state.wrapQ, // output
206-
weights.rms_att_QNormLayered[layerIndex],
207-
offset, nEmbdHead,
208-
state.tempQcur);
209-
}
187+
//rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
188+
unifiedLayer
189+
.task("reductionsOneBlock_Qcur",
190+
Qwen3Kernels::rmsnormReductionWithOffset,
191+
context,
192+
state.tempQcur, // output
193+
state.wrapQ, // input
194+
state.localSize) // currently 128, should be variable of global nEmbHead
195+
.task("reductionFinalNormalization_Qcur",
196+
Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
197+
context,
198+
state.tempQcur, // output
199+
config.numberOfHeads(),
200+
nEmbdHead,
201+
config.rmsNormEps())
202+
.task("mapContext_Qcur",
203+
Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
204+
context,
205+
state.wrapQ, // output
206+
weights.rms_att_QNormLayered[layerIndex],
207+
nEmbdHead,
208+
state.tempQcur);
210209

211210
// unifiedLayer.task("dbg_copy_out_wrapQ",
212211
// Qwen3Kernels::dbgCopy,
@@ -446,6 +445,15 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
446445
curWorker.setGlobalWork(nEmbdHead, 1, 1); // Set global work size to total dimension
447446
curWorker.setLocalWork(128, 1, 1); // Set local work size to 256 (standard efficient size)
448447

448+
// config.numberOfHeads() = 16
449+
// nEmbdHead = 128
450+
// total = 2048
451+
WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead);
452+
qCurWorker.setLocalWork(nEmbdHead, 1, 1);
453+
454+
WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
455+
qCurWorker2.setLocalWork(1, 1, 1);
456+
449457
int h = config.numberOfHeads();
450458
int ic = nEmbdHead / 2;
451459
WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
@@ -485,24 +493,10 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
485493
gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker);
486494
gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker);
487495

488-
// //int size = nEmbdHead;
489-
// for (int j = 0; j < config.numberOfHeads(); j++) {
490-
//// int offset = j * nEmbdHead;
491-
//// WorkerGrid qRmsReductionWorker = new WorkerGrid1D(size);
492-
//// qRmsReductionWorker.setLocalWork(state.localSize, 1, 1);
493-
// gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock" + "_Qcur_" + j, curWorker);
494-
// //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization" + "_Qcur_" + j, curWorker);
495-
// gridScheduler.addWorkerGrid("layer_" + i + ".mapContext" + "_Qcur_" + j, curWorker);
496-
// }
497-
// Create separate WorkerGrid for each head
498-
for (int j = 0; j < config.numberOfHeads(); j++) {
499-
WorkerGrid headWorker = new WorkerGrid1D(nEmbdHead); // nEmbdHead = 128
500-
headWorker.setGlobalWork(nEmbdHead, 1, 1); // Set global work size to total dimension
501-
headWorker.setLocalWork(128, 1, 1);
502-
503-
gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock" + "_Qcur_" + j, headWorker);
504-
gridScheduler.addWorkerGrid("layer_" + i + ".mapContext" + "_Qcur_" + j, headWorker);
505-
}
496+
gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock_Qcur", qCurWorker);
497+
gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization_Qcur", qCurWorker2);
498+
gridScheduler.addWorkerGrid("layer_" + i + ".mapContext_Qcur", qCurWorker);
499+
506500
for (int j = 0; j < config.numberOfKeyValueHeads(); j++) {
507501
gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock" + "_Kcur_" + j, curWorker);
508502
//gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization" + "_Kcur_" + j, curWorker);

0 commit comments

Comments
 (0)