Merged
20 commits
9b68bf7
Initial commit of Mistral port
orionpapadakis May 27, 2025
89d5aa3
Refactor TornadoVM integration and extend Mistral configuration.
mikepapadim May 29, 2025
a2ec8cc
Refactor for Mistral integration with abstractions
orionpapadakis May 30, 2025
115f25a
Move sampler classes to dedicated package
orionpapadakis Jun 2, 2025
726eec2
Decouple LastRunMetrics class from Llama and reuse it for Mistral
orionpapadakis Jun 2, 2025
41a1733
Add comments for tokenizers
orionpapadakis Jun 2, 2025
1a18ad4
Decouple inference implementation from Model
orionpapadakis Jun 2, 2025
b6b693f
Fully integrate TornadoVM for Mistral
orionpapadakis Jun 2, 2025
420a119
Generalize interactive mode implementation for Llama and Mistral
orionpapadakis Jun 5, 2025
613062c
Generalize instruct mode implementation for Llama and Mistral
orionpapadakis Jun 5, 2025
2407938
Relocate ChatFormat classes to model.format package
orionpapadakis Jun 6, 2025
1c514ad
Clean up
orionpapadakis Jun 11, 2025
1640e90
Remove redundant
orionpapadakis Jun 12, 2025
733815b
Move ModelType enum to dedicated file
orionpapadakis Jun 12, 2025
340b35e
Merge createTokenizer methods into Tokenizer constructors
orionpapadakis Jun 12, 2025
8e63862
Move loadModel methods to dedicated model classes
orionpapadakis Jun 12, 2025
548b55b
Add support for --suffix option in llama-tornado python script
orionpapadakis Jun 12, 2025
72a2b8b
Generalize names and comments in llama-tornado python script
orionpapadakis Jun 12, 2025
3b4bb62
Apply a formatter pass
orionpapadakis Jun 12, 2025
2371a7c
Update README
orionpapadakis Jun 12, 2025
21 changes: 14 additions & 7 deletions README.md
@@ -164,11 +164,14 @@ Check models below.

## Download Model Files

Download `FP16` quantized .gguf files from:
Download `FP16` quantized `Llama-3` .gguf files from:
- https://huggingface.co/beehive-lab/Llama-3.2-1B-Instruct-GGUF-FP16
- https://huggingface.co/beehive-lab/Llama-3.2-3B-Instruct-GGUF-FP16
- https://huggingface.co/beehive-lab/Llama-3.2-8B-Instruct-GGUF-FP16

Download `FP16` quantized `Mistral` .gguf files from:
- https://huggingface.co/collections/beehive-lab/mistral-gpullama3java-684afabb206136d2e9cd47e0

Please be gentle with [huggingface.co](https://huggingface.co) servers:

**Note** FP16 models are first-class citizens for the current version.
@@ -181,6 +184,9 @@ wget https://huggingface.co/beehive-lab/Llama-3.2-3B-Instruct-GGUF-FP16/resolve/

# Llama 3 (8B) - FP16
wget https://huggingface.co/beehive-lab/Llama-3.2-8B-Instruct-GGUF-FP16/resolve/main/beehive-llama-3.2-8b-instruct-fp16.gguf

# Mistral (7B) - FP16
wget https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.fp16.gguf
```

**[Experimental]** You can download the Q8 and Q4 models used in the original implementation of Llama3.java, but for now they are dequantized to FP16 for TornadoVM support:
@@ -201,7 +207,7 @@ curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/

## Running `llama-tornado`

To execute Llama3 models with TornadoVM on GPUs use the `llama-tornado` script with the `--gpu` flag.
To execute Llama3 or Mistral models with TornadoVM on GPUs, use the `llama-tornado` script with the `--gpu` flag.

### Usage Examples

@@ -246,11 +252,11 @@ First, check your GPU specifications. If your GPU has high memory capacity, you

### GPU Memory Requirements by Model Size

| Model Size | Recommended GPU Memory |
|------------|------------------------|
| 1B models | 7GB (default) |
| 3B models | 15GB |
| 8B models | 20GB+ |
| Model Size | Recommended GPU Memory |
|-------------|------------------------|
| 1B models | 7GB (default) |
| 3-7B models | 15GB |
| 8B models | 20GB+ |
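
As a rough rule of thumb, FP16 weights take about 2 bytes per parameter, so a 7B model needs roughly 14 GB and an 8B model roughly 16 GB for the weights alone, before the KV cache and TornadoVM buffers are counted; the recommendations above leave headroom for exactly that overhead.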

**Note**: If you still encounter memory issues, try:

@@ -288,6 +294,7 @@ LLaMA Configuration:
Maximum number of tokens to generate (default: 512)
--stream STREAM Enable streaming output (default: True)
--echo ECHO Echo the input prompt (default: False)
--suffix SUFFIX Suffix for fill-in-the-middle request (Codestral) (default: None)

Mode Selection:
-i, --interactive Run in interactive/chat mode (default: False)
31 changes: 16 additions & 15 deletions llama-tornado
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
llama-tornado: GPU-accelerated LLaMA.java runner with TornadoVM
Run LLaMA models using either OpenCL or PTX backends.
llama-tornado: GPU-accelerated Java LLM runner with TornadoVM
Run LLM models using either OpenCL or PTX backends.
"""

import argparse
@@ -19,7 +19,7 @@ class Backend(Enum):
PTX = "ptx"

class LlamaRunner:
"""Main class for managing LLaMA model execution with GPU acceleration."""
"""Main class for managing LLM execution with GPU acceleration."""

def __init__(self):
self.java_home = os.environ.get('JAVA_HOME')
@@ -266,30 +266,31 @@ def create_parser() -> argparse.ArgumentParser:
"""Create and configure the argument parser."""
parser = argparse.ArgumentParser(
prog="llama-tornado",
description="GPU-accelerated LLaMA.java model runner using TornadoVM",
description="GPU-accelerated LLM runner using TornadoVM",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

# Required arguments
parser.add_argument("--model", dest="model_path", required=True,
help="Path to the LLaMA model file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)")
help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)")

# LLaMA arguments
llama_group = parser.add_argument_group("LLaMA Configuration")
llama_group.add_argument("--prompt", help="Input prompt for the model")
llama_group.add_argument("-sp", "--system-prompt", help="System prompt for the model")
llama_group.add_argument("--temperature", type=float, default=0.1,
# LLM arguments
llm_group = parser.add_argument_group("LLaMA Configuration")
llm_group.add_argument("--prompt", help="Input prompt for the model")
llm_group.add_argument("-sp", "--system-prompt", help="System prompt for the model")
llm_group.add_argument("--temperature", type=float, default=0.1,
help="Sampling temperature (0.0 to 2.0)")
llama_group.add_argument("--top-p", type=float, default=0.95,
llm_group.add_argument("--top-p", type=float, default=0.95,
help="Top-p sampling parameter")
llama_group.add_argument("--seed", type=int, default=None,
llm_group.add_argument("--seed", type=int, default=None,
help="Random seed (default: current timestamp)")
llama_group.add_argument("-n", "--max-tokens", type=int, default=512,
llm_group.add_argument("-n", "--max-tokens", type=int, default=512,
help="Maximum number of tokens to generate")
llama_group.add_argument("--stream", type=bool, default=True,
llm_group.add_argument("--stream", type=bool, default=True,
help="Enable streaming output")
llama_group.add_argument("--echo", type=bool, default=False,
llm_group.add_argument("--echo", type=bool, default=False,
help="Echo the input prompt")
llm_group.add_argument("--suffix", help="Suffix for fill-in-the-middle request (Codestral)")

# Mode selection
mode_group = parser.add_argument_group("Mode Selection")
163 changes: 9 additions & 154 deletions src/main/java/com/example/LlamaApp.java
@@ -1,25 +1,16 @@
package com.example;

import com.example.aot.AOT;
import com.example.auxiliary.ChatFormat;
import com.example.core.model.tensor.FloatTensor;
import com.example.inference.CategoricalSampler;
import com.example.inference.Sampler;
import com.example.inference.ToppSampler;
import com.example.inference.engine.impl.Llama;
import com.example.inference.engine.impl.Options;
import com.example.inference.sampler.CategoricalSampler;
import com.example.inference.sampler.Sampler;
import com.example.inference.sampler.ToppSampler;
import com.example.model.Model;
import com.example.loader.weights.ModelLoader;
import com.example.loader.weights.State;
import com.example.tornadovm.FloatArrayUtils;
import com.example.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import java.util.Set;
import java.util.function.IntConsumer;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;

@@ -115,156 +106,20 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
return sampler;
}

static void runInteractive(Llama model, Sampler sampler, Options options) {
State state = null;
List<Integer> conversationTokens = new ArrayList<>();
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
conversationTokens.add(chatFormat.beginOfText);
if (options.systemPrompt() != null) {
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
int startPosition = 0;
Scanner in = new Scanner(System.in);

// Initialize TornadoVM plan once at the beginning if GPU path is enabled
TornadoVMMasterPlan tornadoVMPlan = null;

try {
while (true) {
System.out.print("> ");
System.out.flush();
String userText = in.nextLine();
if (List.of("quit", "exit").contains(userText)) {
break;
}
if (state == null) {
state = model.createNewState();
}

if (USE_TORNADOVM && tornadoVMPlan == null) {
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
}

conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
Set<Integer> stopTokens = chatFormat.getStopTokens();

List<Integer> responseTokens;
IntConsumer tokenConsumer = token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
};

// Choose between GPU and CPU path based on configuration
if (USE_TORNADOVM) {
// GPU path using TornadoVM
responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
} else {
// CPU path
responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
options.echo(), tokenConsumer);
}

// Include stop token in the prompt history, but not in the response displayed to the user.
conversationTokens.addAll(responseTokens);
startPosition = conversationTokens.size();
Integer stopToken = null;
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
stopToken = responseTokens.getLast();
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}
if (stopToken == null) {
System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX");
break;
}
System.out.print("\n");

// Optionally print performance metrics after each response
if (SHOW_PERF_INTERACTIVE) {
Llama.LastRunMetrics.printMetrics();
}
}
} finally {
// Clean up TornadoVM resources when exiting the chat loop
if (USE_TORNADOVM && tornadoVMPlan != null) {
try {
tornadoVMPlan.freeTornadoExecutionPlan();
} catch (Exception e) {
System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
}
}
}
}

static void runInstructOnce(Llama model, Sampler sampler, Options options) {
State state = model.createNewState();
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
TornadoVMMasterPlan tornadoVMPlan = null;

List<Integer> promptTokens = new ArrayList<>();
promptTokens.add(chatFormat.beginOfText);
if (options.systemPrompt() != null) {
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
List<Integer> responseTokens;

// Define the token consumer
IntConsumer tokenConsumer = token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
};

Set<Integer> stopTokens = chatFormat.getStopTokens();
if (USE_TORNADOVM) {
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
// Call generateTokensGPU without the token consumer parameter
responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
} else {
// CPU path still uses the token consumer
responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
}

if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}

Llama.LastRunMetrics.printMetrics();

if (tornadoVMPlan != null) {
tornadoVMPlan.freeTornadoExecutionPlan();
}
}

public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
Llama model;
Model model;
if (USE_AOT) {
model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
} else {
model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
}
Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
assert model != null;
Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
if (options.interactive()) {
runInteractive(model, sampler, options);
model.runInteractive(sampler, options);
} else {
runInstructOnce(model, sampler, options);
model.runInstructOnce(sampler, options);
}
}
}
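
`main` above now drives everything through the `Model` abstraction instead of the concrete `Llama` class. The interface itself is not part of this diff; the sketch below is pieced together from the calls that are visible (`configuration().vocabularySize()`, `tokenizer()`, `createNewState()`, `runInteractive(...)`, `runInstructOnce(...)`), so the package placement and the `Configuration`/`Tokenizer` type names are assumptions, not the PR's actual declarations.

```java
package com.example.model;

import com.example.Options;
import com.example.inference.sampler.Sampler;
import com.example.loader.weights.State;

/**
 * Minimal sketch of the Model abstraction introduced by this PR (the real
 * interface may differ). It bundles the per-model pieces (configuration,
 * tokenizer, inference state) with the two run modes that LlamaApp.main
 * delegates to, so Llama and Mistral share a single driver.
 */
public interface Model {
    Configuration configuration();  // hyperparameters, e.g. vocabularySize()
    Tokenizer tokenizer();          // model-specific tokenizer
    State createNewState();         // fresh KV-cache / activation buffers

    // Formerly the static LlamaApp.runInteractive / runInstructOnce helpers;
    // each model now supplies its own ChatFormat, stop tokens and generation loop.
    void runInteractive(Sampler sampler, Options options);
    void runInstructOnce(Sampler sampler, Options options);
}
```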
src/main/java/com/example/{inference/engine/impl → }/Options.java
@@ -1,17 +1,17 @@
package com.example.inference.engine.impl;
package com.example;

import java.io.PrintStream;
import java.nio.file.Path;
import java.nio.file.Paths;

public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {

public static final int DEFAULT_MAX_TOKENS = 1024;

public Options {
require(modelPath != null, "Missing argument: --model <path> is required");
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"" );
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
@@ -33,7 +33,8 @@ static void printUsage(PrintStream out) {
out.println(" --interactive, --chat, -i run in chat mode");
out.println(" --instruct run in instruct (once) mode, default mode");
out.println(" --prompt, -p <string> input prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt (Llama models)");
out.println(" --suffix <string> suffix for fill-in-the-middle request (Codestral)");
out.println(" --temperature, -temp <float> temperature in [0,inf], default 0.1");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --seed <long> random seed, default System.nanoTime()");
@@ -46,6 +47,7 @@ public static Options parseOptions(String[] args) {
public static Options parseOptions(String[] args) {
String prompt = "Tell me a story with Java"; // Hardcoded for testing
String systemPrompt = null;
String suffix = null;
float temperature = 0.1f;
float topp = 0.95f;
Path modelPath = null;
@@ -80,6 +82,7 @@ public static Options parseOptions(String[] args) {
switch (optionName) {
case "--prompt", "-p" -> prompt = nextArg;
case "--system-prompt", "-sp" -> systemPrompt = nextArg;
case "--suffix" -> suffix = nextArg;
case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg);
case "--top-p" -> topp = Float.parseFloat(nextArg);
case "--model", "-m" -> modelPath = Paths.get(nextArg);
Expand All @@ -92,6 +95,6 @@ public static Options parseOptions(String[] args) {
}
}
}
return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo);
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
}
}
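
A quick, hypothetical usage sketch of the extended record, showing how the new flag travels from the command line into `Options`; the class name, model path, and prompt below are placeholders for illustration, not artifacts of this PR.

```java
import com.example.Options;

public class SuffixOptionExample {
    public static void main(String[] args) {
        // Illustrative arguments exercising the new --suffix flag (fill-in-the-middle).
        String[] cli = {
                "--model", "codestral-fp16.gguf",               // placeholder .gguf path
                "--prompt", "public static int fib(int n) {",   // code before the hole
                "--suffix", "}"                                 // code after the hole
        };
        Options options = Options.parseOptions(cli);
        System.out.println(options.suffix());      // "}"
        System.out.println(options.interactive()); // false: instruct (once) mode is the default
    }
}
```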
9 changes: 5 additions & 4 deletions src/main/java/com/example/aot/AOT.java
@@ -3,8 +3,9 @@
import com.example.auxiliary.Timer;
import com.example.core.model.GGUF;
import com.example.core.model.tensor.GGMLTensorEntry;
import com.example.inference.engine.impl.Llama;
import com.example.inference.engine.impl.Options;
import com.example.model.Model;
import com.example.Options;
import com.example.model.llama.Llama;
import com.example.loader.weights.ModelLoader;
import com.example.loader.weights.Weights;

@@ -45,7 +46,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
return new PartialModel(
path.getFileName().toString(),
ModelLoader.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false),
Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT
gguf.getTensorDataOffset(),
gguf.getTensorInfos()
);
@@ -60,7 +61,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
* The file name (base name) must match with the preloaded file name.
* No checksum/hash is checked for performance reasons.
*/
public static com.example.inference.engine.impl.Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
if (preLoaded == null) {
return null; // no pre-loaded model stored