Skip to content

Commit 7d4b30f

Browse files
Use optimized tornado kernel for Attention
1 parent f059f51 commit 7d4b30f

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -304,22 +304,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
304304
// state.positionHolder,
305305
// layerIndex);
306306

307-
unifiedLayer.task("parallel-attention", Qwen3Kernels::processHeadsParallel,
308-
state.wrapQ,
309-
state.wrapKeyCache,
310-
state.wrapValueCache,
311-
state.wrapXb, // out
312-
config.numberOfHeads(),
313-
nEmbdHead,
314-
nEmbdHeadK,
315-
nEmbdHeadV,
316-
nEmbdGqa,
317-
gqa,
318-
config.vocabularySize(),
319-
state.positionHolder,
320-
state.wrapAtt, // out
321-
layerIndex,
322-
config.contextLength());
307+
unifiedLayer.task("parallel-attention",
308+
TransformerComputeKernelsLayered::processHeadsFlashAttention,
309+
context,
310+
state.wrapQ,
311+
state.wrapKeyCache,
312+
state.wrapValueCache,
313+
state.wrapXb, // out
314+
config.numberOfHeads(),
315+
nEmbdHead,
316+
nEmbdGqa,
317+
gqa,
318+
state.positionHolder,
319+
layerIndex,
320+
config.contextLength());
323321

324322
// unifiedLayer.task("dbg_copy_out_x",
325323
// Qwen3Kernels::dbgCopy,

0 commit comments

Comments
 (0)