Skip to content

Commit c73d8f4

Browse files
[WIP] Add an initial Tornado inference implementation for Qwen3 with correct results
1 parent 354aad5 commit c73d8f4

File tree

12 files changed

+1175
-40
lines changed

12 files changed

+1175
-40
lines changed

src/main/java/com/example/inference/InferenceEngine.java

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,15 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
159159
// We're still processing the prompt tokens
160160
final int token = promptTokens.get(promptIndex);
161161

162+
//System.out.println("Token: " + token);
162163
model.forward(state, token, position);
163164

165+
// System.out.println("Token = " + token + " -> state.logits = { " +
166+
// state.logits.getFloat(0) + ", " +
167+
// state.logits.getFloat(1) + ", " +
168+
// state.logits.getFloat(2) + ", " +
169+
// state.logits.getFloat(3) + " }");
170+
164171
promptIndex++;
165172
if (promptIndex < promptTokens.size()) {
166173
continue;
@@ -177,13 +184,28 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
177184
inferenceStartNanos = System.nanoTime();
178185
}
179186

187+
//System.out.println("currentToken: " + currentToken);
180188
model.forward(state, currentToken, position);
181189

190+
// System.out.println("currentToken = " + currentToken + " -> state.logits = { " +
191+
// state.logits.getFloat(0) + ", " +
192+
// state.logits.getFloat(1) + ", " +
193+
// state.logits.getFloat(2) + ", " +
194+
// state.logits.getFloat(3) + " }");
195+
182196
}
183197

198+
// System.out.print("state.logits = { " +
199+
// state.logits.getFloat(0) + ", " +
200+
// state.logits.getFloat(1) + ", " +
201+
// state.logits.getFloat(2) + ", " +
202+
// state.logits.getFloat(3) + "}");
203+
184204
// Sample the next token
185205
nextToken = sampler.sampleToken(state.logits);
186206

207+
//System.out.println(", nextToken: " + nextToken);
208+
187209
// Output the token if echo is enabled
188210
if (echo) {
189211
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
@@ -249,6 +271,7 @@ public static List<Integer> generateTokensGPU(Model model, State state, int star
249271
// Main generation loop
250272
while (pos < actualMaxTokens) {
251273
// GPU Forward Pass - No conditional check since we know we're using GPU
274+
//System.out.println("currentToken: " + currentToken);
252275
FloatArray logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan);
253276

254277
// Process prompt tokens if still remaining
@@ -304,4 +327,116 @@ public static List<Integer> generateTokensGPU(Model model, State state, int star
304327

305328
return generatedTokens;
306329
}
330+
331+
// probably not needed TODO: check this when its working
332+
public static List<Integer> generateTokensGPUQwen3(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
333+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
334+
// Start timing the whole process
335+
long startNanos = System.nanoTime();
336+
long startGen = 0;
337+
long inferenceStartNanos = 0;
338+
339+
// Pre-validate the max tokens to avoid checking in the loop
340+
int actualMaxTokens = Math.min(maxTokens > 0 ? maxTokens : model.configuration().contextLength(), model.configuration().contextLength());
341+
342+
// Preallocate with expected capacity to avoid resizing
343+
List<Integer> generatedTokens = new ArrayList<>(Math.min(256, actualMaxTokens - promptTokens.size())); // Conservative estimate
344+
345+
// Initialize token variables
346+
int currentToken = state.latestToken; // BOS?
347+
int nextToken = 0;
348+
int promptIndex = 0;
349+
350+
// Use more efficient direct array access for prompt tokens if possible
351+
int[] promptTokenArray = null;
352+
if (promptTokens instanceof ArrayList) {
353+
// Try to extract the underlying array for faster access
354+
try {
355+
// This is a performance optimization that may not work on all JVMs
356+
promptTokenArray = promptTokens.stream().mapToInt(Integer::intValue).toArray();
357+
} catch (Exception e) {
358+
// Fall back to list access
359+
}
360+
}
361+
362+
for (int position = startPosition; position < maxTokens; ++position) {
363+
364+
// Handle token processing
365+
if (promptIndex < promptTokens.size()) {
366+
// We're still processing the prompt tokens
367+
final int token = promptTokens.get(promptIndex);
368+
369+
//System.out.println("Token: " + token);
370+
model.forward(state, token, position);
371+
372+
// System.out.println("Token = " + token + " -> state.wrapLogits = { " +
373+
// state.wrapLogits.get(0) + ", " +
374+
// state.wrapLogits.get(1) + ", " +
375+
// state.wrapLogits.get(2) + ", " +
376+
// state.wrapLogits.get(3) + " }");
377+
378+
promptIndex++;
379+
if (promptIndex < promptTokens.size()) {
380+
continue;
381+
}
382+
if (echo) {
383+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
384+
}
385+
// We have reached the last prompt token and computed the first response-token.
386+
startGen = System.nanoTime();
387+
position++; // The current logit belongs to the next position
388+
} else {
389+
// Mark the start of actual generation (after prompt processing)
390+
if (inferenceStartNanos == 0) {
391+
inferenceStartNanos = System.nanoTime();
392+
}
393+
394+
//System.out.println("currentToken: " + currentToken);
395+
model.forward(state, currentToken, position);
396+
397+
// System.out.println("currentToken = " + currentToken + " -> state.wrapLogits = { " +
398+
// state.wrapLogits.get(0) + ", " +
399+
// state.wrapLogits.get(1) + ", " +
400+
// state.wrapLogits.get(2) + ", " +
401+
// state.wrapLogits.get(3) + " }");
402+
403+
}
404+
405+
406+
// Sample the next token
407+
nextToken = sampler.sampleToken(state.wrapLogits);
408+
409+
//System.out.println(", nextToken: "+ nextToken);
410+
411+
// Output the token if echo is enabled
412+
if (echo) {
413+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
414+
}
415+
416+
// Track the generated token
417+
generatedTokens.add(nextToken);
418+
419+
// Notify via callback if provided
420+
if (onTokenGenerated != null) {
421+
onTokenGenerated.accept(nextToken);
422+
}
423+
424+
// Check for stop condition
425+
if (stopTokens.contains(nextToken)) {
426+
break;
427+
}
428+
429+
// Update for next iteration
430+
state.latestToken = currentToken = nextToken;
431+
}
432+
433+
// Calculate and print performance metrics
434+
long endNanos = System.nanoTime();
435+
double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
436+
int totalTokens = promptIndex + generatedTokens.size();
437+
438+
LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
439+
440+
return generatedTokens;
441+
}
307442
}

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,38 @@ public final class Qwen3State extends State {
1414
// Qwen3-specific field
1515
public final FloatTensor kq;
1616

17+
// Qwen3 temporary buffer for intermediate calculations, size adjusted for local workgroup size.
18+
public FloatArray tempQcur;
19+
public FloatArray tempKcur;
20+
21+
// dbg buffer
22+
public FloatArray dbgQ;
23+
public FloatArray dbgKeyCache;
24+
public FloatArray dbgValueCache;
25+
public FloatArray dbgX;
26+
public FloatArray dbgXb;
27+
1728
public Qwen3State(Configuration config, int batchsize) {
1829
super(config, batchsize);
1930
// Initialize Qwen3-specific field
2031
this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15);
32+
this.tempQcur = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
33+
this.tempKcur = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
34+
35+
// dbg buffers
36+
Qwen3Configuration qwen3config = (Qwen3Configuration) config;
37+
int nHeadKv = qwen3config.numberOfKeyValueHeads();
38+
int nEmbdHeadK = qwen3config.numberOfHeadsKey();
39+
int nEmbdKGqa = nEmbdHeadK * nHeadKv;
40+
int nEmbdHeadV = qwen3config.numberOfHeadsValue();
41+
int nEmbdVGqa = nEmbdHeadV * nHeadKv;
42+
int nEmbdGqa = nEmbdVGqa;
43+
44+
this.dbgQ = new FloatArray(nEmbdHeadK * qwen3config.numberOfHeads());
45+
this.dbgKeyCache = new FloatArray(qwen3config.contextLength() * nEmbdGqa * qwen3config.numberOfLayers());
46+
this.dbgValueCache = new FloatArray(qwen3config.contextLength() * nEmbdGqa * qwen3config.numberOfLayers());
47+
this.dbgX = new FloatArray(config.dim());
48+
this.dbgXb = new FloatArray(nEmbdHeadK * qwen3config.numberOfHeads());
2149
}
2250

2351
@Override
@@ -26,10 +54,15 @@ protected StateFields createStateFields(Configuration configuration) {
2654

2755
Qwen3Configuration config = (Qwen3Configuration) configuration;
2856

57+
//localSize = 128;
58+
2959
// Qwen3-specific calculations
3060
int nHeadKv = config.numberOfKeyValueHeads();
3161
int nEmbdHeadK = config.numberOfHeadsKey();
3262
int nEmbdKGqa = nEmbdHeadK * nHeadKv;
63+
int nEmbdHeadV = config.numberOfHeadsValue();
64+
int nEmbdVGqa = nEmbdHeadV * nHeadKv;
65+
int nEmbdGqa = nEmbdVGqa;
3366

3467
// Qwen3-specific allocation logic
3568
fields.x = ArrayFloatTensor.allocate(config.dim());
@@ -44,10 +77,9 @@ protected StateFields createStateFields(Configuration configuration) {
4477
fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
4578

4679
// Key-value cache with Qwen3 dimensions
47-
int kvDim = nEmbdKGqa;
48-
fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim))
80+
fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa))
4981
.limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
50-
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim))
82+
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa))
5183
.limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
5284

5385
// TornadoVM wrappers with Qwen3-specific sizes
@@ -61,8 +93,8 @@ protected StateFields createStateFields(Configuration configuration) {
6193
fields.wrapK = new FloatArray(nEmbdKGqa); // Different from Llama!
6294
fields.wrapV = new FloatArray(nEmbdKGqa); // Different from Llama!
6395

64-
fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
65-
fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
96+
fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
97+
fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
6698
fields.wrapValueCache.init(0.f);
6799
fields.wrapKeyCache.init(0.f);
68100
fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
@@ -73,6 +105,15 @@ protected StateFields createStateFields(Configuration configuration) {
73105
fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
74106
fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
75107

108+
System.out.println("nEmbdHeadK: " + nEmbdHeadK);
109+
System.out.println("nEmbdHeadV: " + nEmbdHeadV);
110+
System.out.println("nEmbdKGqa: " + nEmbdKGqa);
111+
System.out.println("nEmbdVGqa: " + nEmbdVGqa);
112+
System.out.println("nEmbdGqa: " + nEmbdGqa);
113+
System.out.println("wrapK.getSize(): " + fields.wrapK.getSize());
114+
System.out.println("wrapV.getSize(): " + fields.wrapV.getSize());
115+
System.out.println("wrapV.getSize(): " + fields.wrapV.getSize());
116+
76117
return fields;
77118
}
78119
}

src/main/java/com/example/inference/weights/tornado/Qwen3TornadoWeights.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
public class Qwen3TornadoWeights extends TornadoWeights {
88

99
//attnKNorm
10-
FloatArray[] rms_att_KNormLayered;
10+
public FloatArray[] rms_att_KNormLayered;
1111
//attnQNorm
12-
FloatArray[] rms_att_QNormLayered;
12+
public FloatArray[] rms_att_QNormLayered;
1313

1414
public Qwen3TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered,
1515
HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered,

src/main/java/com/example/model/Model.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ public interface Model {
5454
*/
5555
List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated);
5656

57+
List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
58+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan);
59+
5760
/**
5861
* Model agnostic default implementation for interactive mode.
5962
* @param sampler
@@ -113,7 +116,7 @@ default void runInteractive(Sampler sampler, Options options) {
113116
// Choose between GPU and CPU path based on configuration
114117
if (USE_TORNADOVM) {
115118
// GPU path using TornadoVM
116-
responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
119+
responseTokens = generateTokensGPU(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
117120
options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
118121
} else {
119122
// CPU path
@@ -193,8 +196,9 @@ default void runInstructOnce(Sampler sampler, Options options) {
193196
if (USE_TORNADOVM) {
194197
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
195198
// Call generateTokensGPU without the token consumer parameter
196-
responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,
197-
tornadoVMPlan);
199+
//responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,
200+
// tornadoVMPlan);
201+
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
198202
} else {
199203
responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
200204
}

src/main/java/com/example/model/llama/Llama.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import com.example.model.format.ChatFormat;
1212
import com.example.tokenizer.impl.LlamaTokenizer;
1313
import com.example.tokenizer.impl.Tokenizer;
14+
import com.example.tornadovm.TornadoVMMasterPlan;
1415

1516
import java.util.List;
1617
import java.util.Set;
@@ -64,5 +65,11 @@ public void forward(State state, int token, int position) {
6465
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
6566
return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
6667
}
68+
69+
@Override
70+
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
71+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
72+
return InferenceEngine.generateTokensGPU(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
73+
}
6774
}
6875

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import com.example.model.format.ChatFormat;
1212
import com.example.tokenizer.impl.MistralTokenizer;
1313
import com.example.tokenizer.impl.Tokenizer;
14+
import com.example.tornadovm.TornadoVMMasterPlan;
1415

1516
import java.util.List;
1617
import java.util.Set;
@@ -62,4 +63,10 @@ public List<Integer> generateTokens(State state, int startPosition, List<Integer
6263
return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
6364
}
6465

66+
@Override
67+
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
68+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
69+
return InferenceEngine.generateTokensGPU(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
70+
}
71+
6572
}

src/main/java/com/example/model/qwen3/Qwen3.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import com.example.model.format.ChatFormat;
1212
import com.example.tokenizer.impl.Qwen3Tokenizer;
1313
import com.example.tokenizer.impl.Tokenizer;
14+
import com.example.tornadovm.TornadoVMMasterPlan;
1415

1516
import java.util.List;
1617
import java.util.Set;
@@ -54,12 +55,22 @@ public State createNewState(int batchsize) {
5455

5556
@Override
5657
public void forward(State state, int token, int position) {
57-
InferenceCore.forwardJavaQwen3(this, state, token, position);
58+
if (plan == null) {
59+
InferenceCore.forwardJavaQwen3(this, state, token, position);
60+
} else {
61+
InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan());
62+
}
5863
}
5964

6065
@Override
6166
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
6267
return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
6368
}
6469

70+
@Override
71+
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
72+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
73+
return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
74+
}
75+
6576
}

0 commit comments

Comments
 (0)