@@ -159,8 +159,15 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
                 // We're still processing the prompt tokens
                 final int token = promptTokens.get(promptIndex);
 
+                //System.out.println("Token: " + token);
                 model.forward(state, token, position);
 
+                // System.out.println("Token = " + token + " -> state.logits = { " +
+                //         state.logits.getFloat(0) + ", " +
+                //         state.logits.getFloat(1) + ", " +
+                //         state.logits.getFloat(2) + ", " +
+                //         state.logits.getFloat(3) + " }");
+
                 promptIndex++;
                 if (promptIndex < promptTokens.size()) {
                     continue;
@@ -177,13 +184,28 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
                     inferenceStartNanos = System.nanoTime();
                 }
 
+                //System.out.println("currentToken: " + currentToken);
                 model.forward(state, currentToken, position);
 
+                // System.out.println("currentToken = " + currentToken + " -> state.logits = { " +
+                //         state.logits.getFloat(0) + ", " +
+                //         state.logits.getFloat(1) + ", " +
+                //         state.logits.getFloat(2) + ", " +
+                //         state.logits.getFloat(3) + " }");
+
             }
 
+            // System.out.print("state.logits = { " +
+            //         state.logits.getFloat(0) + ", " +
+            //         state.logits.getFloat(1) + ", " +
+            //         state.logits.getFloat(2) + ", " +
+            //         state.logits.getFloat(3) + " }");
+
             // Sample the next token
             nextToken = sampler.sampleToken(state.logits);
 
+            //System.out.println(", nextToken: " + nextToken);
+
             // Output the token if echo is enabled
             if (echo) {
                 System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
@@ -249,6 +271,7 @@ public static List<Integer> generateTokensGPU(Model model, State state, int star
         // Main generation loop
         while (pos < actualMaxTokens) {
             // GPU forward pass - no conditional check since we know we're using GPU
+            //System.out.println("currentToken: " + currentToken);
             FloatArray logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan);
 
             // Process prompt tokens if still remaining
@@ -304,4 +327,116 @@ public static List<Integer> generateTokensGPU(Model model, State state, int star
 
         return generatedTokens;
     }
+
+    // Probably not needed. TODO: verify this once the GPU path is working.
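+    // NOTE: despite the "GPU" name, this variant still calls model.forward(...) like
+    // generateTokensQwen3 above, but samples from state.wrapLogits rather than
+    // state.logits; generateTokensGPU instead calls InferenceCore.forwardTornadoVM.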
+    public static List<Integer> generateTokensGPUQwen3(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        // Start timing the whole process
+        long startNanos = System.nanoTime();
+        long startGen = 0;
+        long inferenceStartNanos = 0;
+
+        // Pre-validate the max tokens to avoid checking in the loop
+        int actualMaxTokens = Math.min(maxTokens > 0 ? maxTokens : model.configuration().contextLength(), model.configuration().contextLength());
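+        // (a non-positive maxTokens falls back to the model's full context length)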
+
+        // Preallocate with expected capacity to avoid resizing; clamped non-negative so a
+        // prompt longer than the token budget cannot produce an illegal capacity
+        List<Integer> generatedTokens = new ArrayList<>(Math.max(0, Math.min(256, actualMaxTokens - promptTokens.size()))); // Conservative estimate
+
+        // Initialize token variables
+        int currentToken = state.latestToken; // BOS?
+        int nextToken = 0;
+        int promptIndex = 0;
+
+        // Use more efficient direct array access for prompt tokens if possible
+        int[] promptTokenArray = null;
+        if (promptTokens instanceof ArrayList) {
+            // Try to extract the underlying array for faster access
+            try {
+                // This is a performance optimization that may not work on all JVMs
+                promptTokenArray = promptTokens.stream().mapToInt(Integer::intValue).toArray();
+            } catch (Exception e) {
+                // Fall back to list access
+            }
+        }
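+        // NOTE: promptTokenArray is never read below; the loop still calls
+        // promptTokens.get(promptIndex), so this copy is currently dead code.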
+
+        for (int position = startPosition; position < actualMaxTokens; ++position) {
+
+            // Handle token processing
+            if (promptIndex < promptTokens.size()) {
+                // We're still processing the prompt tokens
+                final int token = promptTokens.get(promptIndex);
+
+                //System.out.println("Token: " + token);
+                model.forward(state, token, position);
+
+                // System.out.println("Token = " + token + " -> state.wrapLogits = { " +
+                //         state.wrapLogits.get(0) + ", " +
+                //         state.wrapLogits.get(1) + ", " +
+                //         state.wrapLogits.get(2) + ", " +
+                //         state.wrapLogits.get(3) + " }");
+
+                promptIndex++;
+                if (promptIndex < promptTokens.size()) {
+                    continue;
+                }
+                if (echo) {
+                    // Echo the prompt token just processed; nextToken has not been sampled yet here
+                    System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(token))));
+                }
+                // We have reached the last prompt token and computed the first response token.
+                startGen = System.nanoTime();
+                position++; // The current logits belong to the next position
+            } else {
+                // Mark the start of actual generation (after prompt processing)
+                if (inferenceStartNanos == 0) {
+                    inferenceStartNanos = System.nanoTime();
+                }
+
+                //System.out.println("currentToken: " + currentToken);
+                model.forward(state, currentToken, position);
+
+                // System.out.println("currentToken = " + currentToken + " -> state.wrapLogits = { " +
+                //         state.wrapLogits.get(0) + ", " +
+                //         state.wrapLogits.get(1) + ", " +
+                //         state.wrapLogits.get(2) + ", " +
+                //         state.wrapLogits.get(3) + " }");
+
+            }
+
+            // Sample the next token
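+            // Note: this path samples from state.wrapLogits (presumably the TornadoVM/GPU
+            // buffer), whereas generateTokensQwen3 samples from state.logits.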
+            nextToken = sampler.sampleToken(state.wrapLogits);
+
+            //System.out.println(", nextToken: " + nextToken);
+
+            // Output the token if echo is enabled
+            if (echo) {
+                System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+            }
+
+            // Track the generated token
+            generatedTokens.add(nextToken);
+
+            // Notify via callback if provided
+            if (onTokenGenerated != null) {
+                onTokenGenerated.accept(nextToken);
+            }
+
+            // Check for stop condition
+            if (stopTokens.contains(nextToken)) {
+                break;
+            }
+
+            // Update for next iteration
+            state.latestToken = currentToken = nextToken;
+        }
+
+        // Calculate and print performance metrics
+        long endNanos = System.nanoTime();
+        double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
+        int totalTokens = promptIndex + generatedTokens.size();
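+        // NOTE: startGen and inferenceStartNanos are captured above but unused here;
+        // only the wall-clock total from startNanos is reported.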
+
+        LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
+
+        return generatedTokens;
+    }
 }