diff --git a/README.md b/README.md index b77322aa..f4c1274e 100644 --- a/README.md +++ b/README.md @@ -164,11 +164,14 @@ Check models below. ## Download Model Files -Download `FP16` quantized .gguf files from: +Download `FP16` quantized `Llama-3` .gguf files from: - https://huggingface.co/beehive-lab/Llama-3.2-1B-Instruct-GGUF-FP16 - https://huggingface.co/beehive-lab/Llama-3.2-3B-Instruct-GGUF-FP16 - https://huggingface.co/beehive-lab/Llama-3.2-8B-Instruct-GGUF-FP16 +Download `FP16` quantized `Mistral` .gguf files from: +- https://huggingface.co/collections/beehive-lab/mistral-gpullama3java-684afabb206136d2e9cd47e0 + Please be gentle with [huggingface.co](https://huggingface.co) servers: **Note** FP16 models are first-class citizens for the current version. @@ -181,6 +184,9 @@ wget https://huggingface.co/beehive-lab/Llama-3.2-3B-Instruct-GGUF-FP16/resolve/ # Llama 3 (8B) - FP16 wget https://huggingface.co/beehive-lab/Llama-3.2-8B-Instruct-GGUF-FP16/resolve/main/beehive-llama-3.2-8b-instruct-fp16.gguf + +# Mistral (7B) - FP16 +wget https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.fp16.gguf ``` **[Experimental]** you can download the Q8 and Q4 used in the original implementation of Llama3.java, but for now are going to be dequanted to FP16 for TornadoVM support: @@ -201,7 +207,7 @@ curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/ ## Running `llama-tornado` -To execute Llama3 models with TornadoVM on GPUs use the `llama-tornado` script with the `--gpu` flag. +To execute Llama3 or Mistral models with TornadoVM on GPUs, use the `llama-tornado` script with the `--gpu` flag. ### Usage Examples @@ -246,11 +252,11 @@ First, check your GPU specifications. If your GPU has high memory capacity, you ### GPU Memory Requirements by Model Size -| Model Size | Recommended GPU Memory | -|------------|------------------------| -| 1B models | 7GB (default) | -| 3B models | 15GB | -| 8B models | 20GB+ | +| Model Size | Recommended GPU Memory | +|-------------|------------------------| +| 1B models | 7GB (default) | +| 3-7B models | 15GB | +| 8B models | 20GB+ | **Note**: If you still encounter memory issues, try: @@ -288,6 +294,7 @@ LLaMA Configuration: Maximum number of tokens to generate (default: 512) --stream STREAM Enable streaming output (default: True) --echo ECHO Echo the input prompt (default: False) + --suffix SUFFIX Suffix for fill-in-the-middle request (Codestral) (default: None) Mode Selection: -i, --interactive Run in interactive/chat mode (default: False) diff --git a/llama-tornado b/llama-tornado index cf9c13c6..675fe204 100755 --- a/llama-tornado +++ b/llama-tornado @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ -llama-tornado: GPU-accelerated LLaMA.java runner with TornadoVM -Run LLaMA models using either OpenCL or PTX backends. +llama-tornado: GPU-accelerated Java LLM runner with TornadoVM +Run LLMs using either OpenCL or PTX backends.
""" import argparse @@ -19,7 +19,7 @@ class Backend(Enum): PTX = "ptx" class LlamaRunner: - """Main class for managing LLaMA model execution with GPU acceleration.""" + """Main class for managing LLM execution with GPU acceleration.""" def __init__(self): self.java_home = os.environ.get('JAVA_HOME') @@ -266,30 +266,31 @@ def create_parser() -> argparse.ArgumentParser: """Create and configure the argument parser.""" parser = argparse.ArgumentParser( prog="llama-tornado", - description="GPU-accelerated LLaMA.java model runner using TornadoVM", + description="GPU-accelerated LLM runner using TornadoVM", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) # Required arguments parser.add_argument("--model", dest="model_path", required=True, - help="Path to the LLaMA model file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)") + help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)") - # LLaMA arguments - llama_group = parser.add_argument_group("LLaMA Configuration") - llama_group.add_argument("--prompt", help="Input prompt for the model") - llama_group.add_argument("-sp", "--system-prompt", help="System prompt for the model") - llama_group.add_argument("--temperature", type=float, default=0.1, + # LLM arguments + llm_group = parser.add_argument_group("LLaMA Configuration") + llm_group.add_argument("--prompt", help="Input prompt for the model") + llm_group.add_argument("-sp", "--system-prompt", help="System prompt for the model") + llm_group.add_argument("--temperature", type=float, default=0.1, help="Sampling temperature (0.0 to 2.0)") - llama_group.add_argument("--top-p", type=float, default=0.95, + llm_group.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling parameter") - llama_group.add_argument("--seed", type=int, default=None, + llm_group.add_argument("--seed", type=int, default=None, help="Random seed (default: current timestamp)") - llama_group.add_argument("-n", "--max-tokens", type=int, default=512, + llm_group.add_argument("-n", "--max-tokens", type=int, default=512, help="Maximum number of tokens to generate") - llama_group.add_argument("--stream", type=bool, default=True, + llm_group.add_argument("--stream", type=bool, default=True, help="Enable streaming output") - llama_group.add_argument("--echo", type=bool, default=False, + llm_group.add_argument("--echo", type=bool, default=False, help="Echo the input prompt") + llm_group.add_argument("--suffix", help="Suffix for fill-in-the-middle request (Codestral)") # Mode selection mode_group = parser.add_argument_group("Mode Selection") diff --git a/src/main/java/com/example/LlamaApp.java b/src/main/java/com/example/LlamaApp.java index 8ae277f5..826b35c0 100644 --- a/src/main/java/com/example/LlamaApp.java +++ b/src/main/java/com/example/LlamaApp.java @@ -1,25 +1,16 @@ package com.example; import com.example.aot.AOT; -import com.example.auxiliary.ChatFormat; import com.example.core.model.tensor.FloatTensor; -import com.example.inference.CategoricalSampler; -import com.example.inference.Sampler; -import com.example.inference.ToppSampler; -import com.example.inference.engine.impl.Llama; -import com.example.inference.engine.impl.Options; +import com.example.inference.sampler.CategoricalSampler; +import com.example.inference.sampler.Sampler; +import com.example.inference.sampler.ToppSampler; +import com.example.model.Model; import com.example.loader.weights.ModelLoader; -import com.example.loader.weights.State; import com.example.tornadovm.FloatArrayUtils; -import com.example.tornadovm.TornadoVMMasterPlan; 
import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Scanner; -import java.util.Set; -import java.util.function.IntConsumer; import java.util.random.RandomGenerator; import java.util.random.RandomGeneratorFactory; @@ -115,156 +106,20 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp, return sampler; } - static void runInteractive(Llama model, Sampler sampler, Options options) { - State state = null; - List conversationTokens = new ArrayList<>(); - ChatFormat chatFormat = new ChatFormat(model.tokenizer()); - conversationTokens.add(chatFormat.beginOfText); - if (options.systemPrompt() != null) { - conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); - } - int startPosition = 0; - Scanner in = new Scanner(System.in); - - // Initialize TornadoVM plan once at the beginning if GPU path is enabled - TornadoVMMasterPlan tornadoVMPlan = null; - - try { - while (true) { - System.out.print("> "); - System.out.flush(); - String userText = in.nextLine(); - if (List.of("quit", "exit").contains(userText)) { - break; - } - if (state == null) { - state = model.createNewState(); - } - - if (USE_TORNADOVM && tornadoVMPlan == null) { - tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); - } - - conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); - conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - Set stopTokens = chatFormat.getStopTokens(); - - List responseTokens; - IntConsumer tokenConsumer = token -> { - if (options.stream()) { - if (!model.tokenizer().isSpecialToken(token)) { - System.out.print(model.tokenizer().decode(List.of(token))); - } - } - }; - - // Choose between GPU and CPU path based on configuration - if (USE_TORNADOVM) { - // GPU path using TornadoVM - responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), - sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan); - } else { - // CPU path - responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, - options.echo(), tokenConsumer); - } - - // Include stop token in the prompt history, but not in the response displayed to the user. 
- conversationTokens.addAll(responseTokens); - startPosition = conversationTokens.size(); - Integer stopToken = null; - if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { - stopToken = responseTokens.getLast(); - responseTokens.removeLast(); - } - if (!options.stream()) { - String responseText = model.tokenizer().decode(responseTokens); - System.out.println(responseText); - } - if (stopToken == null) { - System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX"); - break; - } - System.out.print("\n"); - - // Optionally print performance metrics after each response - if (SHOW_PERF_INTERACTIVE) { - Llama.LastRunMetrics.printMetrics(); - } - } - } finally { - // Clean up TornadoVM resources when exiting the chat loop - if (USE_TORNADOVM && tornadoVMPlan != null) { - try { - tornadoVMPlan.freeTornadoExecutionPlan(); - } catch (Exception e) { - System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage()); - } - } - } - } - - static void runInstructOnce(Llama model, Sampler sampler, Options options) { - State state = model.createNewState(); - ChatFormat chatFormat = new ChatFormat(model.tokenizer()); - TornadoVMMasterPlan tornadoVMPlan = null; - - List promptTokens = new ArrayList<>(); - promptTokens.add(chatFormat.beginOfText); - if (options.systemPrompt() != null) { - promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); - } - promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt()))); - promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - List responseTokens; - - // Define the token consumer - IntConsumer tokenConsumer = token -> { - if (options.stream()) { - if (!model.tokenizer().isSpecialToken(token)) { - System.out.print(model.tokenizer().decode(List.of(token))); - } - } - }; - - Set stopTokens = chatFormat.getStopTokens(); - if (USE_TORNADOVM) { - tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); - // Call generateTokensGPU without the token consumer parameter - responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? 
tokenConsumer : null, tornadoVMPlan); - } else { - // CPU path still uses the token consumer - responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer); - } - - if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { - responseTokens.removeLast(); - } - if (!options.stream()) { - String responseText = model.tokenizer().decode(responseTokens); - System.out.println(responseText); - } - - Llama.LastRunMetrics.printMetrics(); - - if (tornadoVMPlan != null) { - tornadoVMPlan.freeTornadoExecutionPlan(); - } - } - public static void main(String[] args) throws IOException { Options options = Options.parseOptions(args); - Llama model; + Model model; if (USE_AOT) { model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); } else { model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true); } - Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed()); + assert model != null; + Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); if (options.interactive()) { - runInteractive(model, sampler, options); + model.runInteractive(sampler, options); } else { - runInstructOnce(model, sampler, options); + model.runInstructOnce(sampler, options); } } } diff --git a/src/main/java/com/example/inference/engine/impl/Options.java b/src/main/java/com/example/Options.java similarity index 91% rename from src/main/java/com/example/inference/engine/impl/Options.java rename to src/main/java/com/example/Options.java index 7ee13194..284e754a 100644 --- a/src/main/java/com/example/inference/engine/impl/Options.java +++ b/src/main/java/com/example/Options.java @@ -1,17 +1,17 @@ -package com.example.inference.engine.impl; +package com.example; import java.io.PrintStream; import java.nio.file.Path; import java.nio.file.Paths; -public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive, +public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) { public static final int DEFAULT_MAX_TOKENS = 1024; public Options { require(modelPath != null, "Missing argument: --model is required"); - require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"" ); + require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. 
--prompt \"Why is the sky blue?\""); require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); } @@ -33,7 +33,8 @@ static void printUsage(PrintStream out) { out.println(" --interactive, --chat, -i run in chat mode"); out.println(" --instruct run in instruct (once) mode, default mode"); out.println(" --prompt, -p input prompt"); - out.println(" --system-prompt, -sp (optional) system prompt"); + out.println(" --system-prompt, -sp (optional) system prompt (Llama models)"); + out.println(" --suffix suffix for fill-in-the-middle request (Codestral)"); out.println(" --temperature, -temp temperature in [0,inf], default 0.1"); out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95"); out.println(" --seed random seed, default System.nanoTime()"); @@ -46,6 +47,7 @@ static void printUsage(PrintStream out) { public static Options parseOptions(String[] args) { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; + String suffix = null; float temperature = 0.1f; float topp = 0.95f; Path modelPath = null; @@ -80,6 +82,7 @@ public static Options parseOptions(String[] args) { switch (optionName) { case "--prompt", "-p" -> prompt = nextArg; case "--system-prompt", "-sp" -> systemPrompt = nextArg; + case "--suffix" -> suffix = nextArg; case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg); case "--top-p" -> topp = Float.parseFloat(nextArg); case "--model", "-m" -> modelPath = Paths.get(nextArg); @@ -92,6 +95,6 @@ public static Options parseOptions(String[] args) { } } } - return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo); } } diff --git a/src/main/java/com/example/aot/AOT.java b/src/main/java/com/example/aot/AOT.java index 11713a8a..837a56b4 100644 --- a/src/main/java/com/example/aot/AOT.java +++ b/src/main/java/com/example/aot/AOT.java @@ -3,8 +3,9 @@ import com.example.auxiliary.Timer; import com.example.core.model.GGUF; import com.example.core.model.tensor.GGMLTensorEntry; -import com.example.inference.engine.impl.Llama; -import com.example.inference.engine.impl.Options; +import com.example.model.Model; +import com.example.Options; +import com.example.model.llama.Llama; import com.example.loader.weights.ModelLoader; import com.example.loader.weights.Weights; @@ -45,7 +46,7 @@ private static PartialModel preLoadGGUF(String modelPath) { try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) { return new PartialModel( path.getFileName().toString(), - ModelLoader.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), + Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT gguf.getTensorDataOffset(), gguf.getTensorInfos() ); @@ -60,7 +61,7 @@ private static PartialModel preLoadGGUF(String modelPath) { * The file name (base name) must match with the preloaded file name. * No checksum/hash is checked for performance reasons. 
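The `--suffix` value added above is only parsed and stored on the `Options` record in this hunk; the fill-in-the-middle prompt construction itself is not shown here. A minimal sketch of the new parsing path, with illustrative argument values:

```java
// Hedged sketch: exercises the --suffix flag added in this patch (values are illustrative).
String[] args = {
        "--model", "codestral.gguf",       // hypothetical Codestral-style gguf
        "--prompt", "def fibonacci(n):",   // text before the hole
        "--suffix", "    return result"    // text after the hole
};
Options options = Options.parseOptions(args);
System.out.println(options.suffix());     // prints "    return result"
```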
*/ - public static com.example.inference.engine.impl.Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException { + public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException { AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF; if (preLoaded == null) { return null; // no pre-loaded model stored diff --git a/src/main/java/com/example/auxiliary/ChatFormat.java b/src/main/java/com/example/auxiliary/ChatFormat.java deleted file mode 100644 index dd4aa825..00000000 --- a/src/main/java/com/example/auxiliary/ChatFormat.java +++ /dev/null @@ -1,83 +0,0 @@ -package com.example.auxiliary; - -import com.example.tokenizer.impl.Tokenizer; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; - -public class ChatFormat { - - final Tokenizer tokenizer; - public final int beginOfText; - final int endHeader; - final int startHeader; - final int endOfTurn; - final int endOfText; - final int endOfMessage; - final Set stopTokens; - - public ChatFormat(Tokenizer tokenizer) { - this.tokenizer = tokenizer; - Map specialTokens = this.tokenizer.getSpecialTokens(); - this.beginOfText = specialTokens.get("<|begin_of_text|>"); - this.startHeader = specialTokens.get("<|start_header_id|>"); - this.endHeader = specialTokens.get("<|end_header_id|>"); - this.endOfTurn = specialTokens.get("<|eot_id|>"); - this.endOfText = specialTokens.get("<|end_of_text|>"); - this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 - this.stopTokens = Set.of(endOfText, endOfTurn); - } - - public Tokenizer getTokenizer() { - return tokenizer; - } - - public Set getStopTokens() { - return stopTokens; - } - - public List encodeHeader(ChatFormat.Message message) { - List tokens = new ArrayList<>(); - tokens.add(startHeader); - tokens.addAll(this.tokenizer.encodeAsList(message.role().name())); - tokens.add(endHeader); - tokens.addAll(this.tokenizer.encodeAsList("\n")); - return tokens; - } - - public List encodeMessage(ChatFormat.Message message) { - List tokens = this.encodeHeader(message); - tokens.addAll(this.tokenizer.encodeAsList(message.content().strip())); - tokens.add(endOfTurn); - return tokens; - } - - public List encodeDialogPrompt(boolean appendAssistantTurn, List dialog) { - List tokens = new ArrayList<>(); - tokens.add(beginOfText); - for (ChatFormat.Message message : dialog) { - tokens.addAll(this.encodeMessage(message)); - } - if (appendAssistantTurn) { - // Add the start of an assistant message for the model to complete. - tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - } - return tokens; - } - - public record Message(ChatFormat.Role role, String content) { - } - - public record Role(String name) { - public static ChatFormat.Role SYSTEM = new ChatFormat.Role("system"); - public static ChatFormat.Role USER = new ChatFormat.Role("user"); - public static ChatFormat.Role ASSISTANT = new ChatFormat.Role("assistant"); - - @Override - public String toString() { - return name; - } - } -} \ No newline at end of file diff --git a/src/main/java/com/example/auxiliary/LastRunMetrics.java b/src/main/java/com/example/auxiliary/LastRunMetrics.java new file mode 100644 index 00000000..f74e554f --- /dev/null +++ b/src/main/java/com/example/auxiliary/LastRunMetrics.java @@ -0,0 +1,33 @@ +package com.example.auxiliary; + +/** + * Record to store metrics from the last model run. 
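A minimal sketch of the intended call pattern for this record, assuming the call sites used elsewhere in this patch (generation code records the totals, the driver prints throughput afterwards):

```java
// Hedged usage sketch with illustrative values.
int totalTokens = 256;
double totalSeconds = 4.2;
LastRunMetrics.setMetrics(totalTokens, totalSeconds);   // typically called at the end of a generation run
LastRunMetrics.printMetrics();                          // prints "achieved tok/s: 60.95. Tokens: 256, seconds: 4.20" to stderr
```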
+ * @param totalTokens The total number of tokens processed + * @param totalSeconds The total time in seconds + */ +public record LastRunMetrics(int totalTokens, double totalSeconds) { + /** + * Singleton instance to store the latest metrics + */ + private static LastRunMetrics latestMetrics; + + /** + * Sets the metrics for the latest run + * + * @param tokens The total number of tokens processed + * @param seconds The total time in seconds + */ + public static void setMetrics(int tokens, double seconds) { + latestMetrics = new LastRunMetrics(tokens, seconds); + } + + /** + * Prints the metrics from the latest run to stderr + */ + public static void printMetrics() { + if (latestMetrics != null) { + double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds(); + System.err.printf("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds()); + } + } +} diff --git a/src/main/java/com/example/inference/InferenceCore.java b/src/main/java/com/example/inference/InferenceCore.java new file mode 100644 index 00000000..81e432c2 --- /dev/null +++ b/src/main/java/com/example/inference/InferenceCore.java @@ -0,0 +1,198 @@ +package com.example.inference; + +import com.example.auxiliary.Parallel; +import com.example.core.model.tensor.FloatTensor; +import com.example.loader.weights.State; +import com.example.loader.weights.Weights; +import com.example.model.Configuration; +import com.example.model.Model; +import com.example.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.lang.foreign.MemorySegment; +import java.nio.FloatBuffer; + +/** + * Low-level operations for model inference. + * + *

+ * Provides core computational operations: RMS normalization and forward passes + * through model layers. Supports both CPU and GPU implementations. + */ + +public final class InferenceCore { + + private InferenceCore() { + // prevent instantiation + } + + public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { + // calculate sum of squares + float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi); + ss /= size; + ss += rmsNormEps; + ss = (float) (1.0 / Math.sqrt(ss)); + // normalize and scale + final float finalss = ss; // for the lambda + out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index))); + } + + public static FloatTensor forwardJava(Model model, State state, int token, int position) { + // a few convenience variables + final Configuration config = model.configuration(); + final Weights weights = model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // copy the token embedding into x + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // forward all the layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // attention rmsnorm + rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps()); + + // qkv matmuls for this position + + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE relative positional encoding: complex-valued rotate q and k in each head + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key) + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // save key,value at this time step (position) to our kv cache + //int loff = l * config.seq_len * kvDim; + // kv cache layer offset for convenience + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + int curLayer = l; + + // multihead attention. 
iterate over all heads + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + // get the query vector for this head + // float* q = s.q + h * headSize; + int qOffset = h * headSize; + + // attention scores for this head + // float* att = s.att + h * config.seq_len; + int attOffset = h * config.contextLength(); + + // iterate over all timesteps, including the current one + for (int t = 0; t <= position; t++) { + // get the key vector for this head and at this timestep + // float* k = s.key_cache + loff + t * dim + h * headSize; + int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; + // calculate the attention score as the dot product of q and k + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + // save the score to the attention buffer + state.att.setFloat(attOffset + t, score); + } + + // softmax the scores to get attention weights, from 0..position inclusively + state.att.softmaxInPlace(attOffset, position + 1); + + // weighted sum of the values, store back into xb + // float* xb = s.xb + h * headSize; + int xbOffset = h * headSize; + // memset(xb, 0, headSize * sizeof(float)); + state.xb.fillInPlace(xbOffset, headSize, 0f); + + for (int t = 0; t <= position; t++) { + // get the value vector for this head and at this timestep + // float* v = s.value_cache + loff + t * dim + h * headSize; + int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; + // get the attention weight for this timestep + float a = state.att.getFloat(attOffset + t); + // accumulate the weighted value into xb + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // final matmul to get the output of the attention + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + + // residual connection back into x + state.x.addInPlace(state.xb2); + + // ffn rmsnorm + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps()); + + // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + // first calculate self.w1(x) and self.w3(x) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + + // SwiGLU non-linearity + // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + + // elementwise multiply with w3(x) + state.hb.multiplyInPlace(state.hb2); + + // final matmul to get the output of the ffn + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // residual connection + state.x.addInPlace(state.xb); + } + + rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps()); + + weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); + + return state.logits; + } + + /** + * Performs the initial embedding lookup and triggers the TornadoVM accelerated forward pass for an LLM token. + * + *
This method handles the first phase of processing a token through the transformer model:
+ *   1. Copies the token embedding from the model's embedding table to the state's buffer
+ *   2. Delegates the transformer layer processing to TornadoVM through the master plan
+ *
+ *
The token embedding lookup happens on the CPU using {@link MemorySegment} operations, + * while the subsequent transformer layers processing is offloaded to the accelerator through + * TornadoVM for improved performance. + * + * @param model + * The Llama model containing weights and configuration parameters + * @param state + * The current execution state holding input/output tensors and temporary buffers + * @param token + * The input token ID to process + * @param position + * The position of this token in the sequence context window + * @param tornadoVMMasterPlan + * The execution plan for TornadoVM acceleration + * @return FloatTensor containing the output logits for token prediction + */ + public static FloatArray forwardTornadoVM(Model model, State state, int token, int position, TornadoVMMasterPlan tornadoVMMasterPlan) { + final Configuration configuration = model.configuration(); + final Weights weights = model.weights(); + + MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES); + + return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); + } + +} diff --git a/src/main/java/com/example/inference/InferenceEngine.java b/src/main/java/com/example/inference/InferenceEngine.java new file mode 100644 index 00000000..814b1ae9 --- /dev/null +++ b/src/main/java/com/example/inference/InferenceEngine.java @@ -0,0 +1,213 @@ +package com.example.inference; + +import com.example.auxiliary.LastRunMetrics; +import com.example.inference.sampler.Sampler; +import com.example.loader.weights.State; +import com.example.model.Configuration; +import com.example.model.Model; +import com.example.tokenizer.impl.Tokenizer; +import com.example.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Main entry point for LLM token generation. + * + *

+ * Orchestrates the complete inference process: ingests prompt tokens, then generates + * new tokens until a stop condition is met. Supports both CPU and GPU execution. + */ +public final class InferenceEngine { + + private InferenceEngine() { + // prevent instantiation + } + + /** + * LLM generation entry point, ingests prompt tokens and generates new tokens. + * + *

+ * All prompt tokens are ingested first, then inference starts, until a stop token is found. + * The returned tokens only include generated/inferred tokens. + * + * @param model model to run inference (including weights, configuration, tokenizer ...) + * @param state state of the model e.g. key/value caches ... this is mutated by this call + * @param startPosition start prompt ingestion + inference at this position in the context e.g. useful if state was kept across calls (chained generation). 0 implies run with no previous context. + * @param promptTokens prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left in the context + * @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion + * @param maxTokens maximum number of tokens (can go up to {@link Configuration#contextLength context length} + * if this value is negative or greater than {@link Configuration#contextLength context length} + * @param sampler {@link Sampler strategy} used to select tokens + * @param echo debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr} + * @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens + * @return list of generated/inferred tokens, including the stop token, if any e.g. does not include any token from the prompt + */ + public static List generateTokens(Model model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + // Start timing the whole process + long startNanos = System.nanoTime(); + long inferenceStartNanos = 0; + + Object logits; + // Validate and adjust maxTokens if necessary + if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) { + maxTokens = model.configuration().contextLength(); + } + + // Storage for generated tokens + List generatedTokens = new ArrayList<>(); + + // Initialize token variables + int currentToken = state.latestToken; + int nextToken; + int promptIndex = 0; + int pos = startPosition; + + while (pos < maxTokens) { + + logits = InferenceCore.forwardJava(model, state, currentToken, pos); + + // Handle token processing + if (promptIndex < promptTokens.size()) { + // We're still processing the prompt tokens + nextToken = promptTokens.get(promptIndex++); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + // Mark the start of actual generation (after prompt processing) + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + // Sample the next token + nextToken = sampler.sampleToken(logits); + + // Output the token if echo is enabled + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + // Track the generated token + generatedTokens.add(nextToken); + + // Notify via callback if provided + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + // Check for stop condition + if (stopTokens.contains(nextToken)) { + break; + } + } + + // Update for next iteration + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + // Calculate and print performance metrics + long endNanos = System.nanoTime(); + double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + 
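For review context, a hedged sketch of how a caller is expected to drive this method; `promptTokens`, `stopTokens`, `sampler` and `options` are assumed to come from the surrounding call site, and `createNewState()` is assumed to exist on `Model` as it did on the old `Llama` record:

```java
// Illustrative call only; mirrors the shape of the deleted runInstructOnce path.
State state = model.createNewState();
List<Integer> response = InferenceEngine.generateTokens(
        model, state,
        0,                       // start position: no previous context
        promptTokens,            // encoded prompt, including chat-format headers
        stopTokens,              // e.g. end-of-turn / end-of-text token ids
        options.maxTokens(), sampler, options.echo(),
        token -> System.out.print(model.tokenizer().decode(List.of(token))));
```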
+ LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); + + return generatedTokens; + } + + public static List generateTokensGPU(Model model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + // === Setup and Initialization === + long startNanos = System.nanoTime(); + long inferenceStartNanos = 0; + + // Pre-validate the max tokens to avoid checking in the loop + int actualMaxTokens = Math.min(maxTokens > 0 ? maxTokens : model.configuration().contextLength(), model.configuration().contextLength()); + + // Preallocate with expected capacity to avoid resizing + List generatedTokens = new ArrayList<>(Math.min(256, actualMaxTokens - promptTokens.size())); // Conservative estimate + + // === Token Generation Loop === + int currentToken = state.latestToken; + int nextToken; + int promptIndex = 0; + int pos = startPosition; + + // Use more efficient direct array access for prompt tokens if possible + int[] promptTokenArray = null; + if (promptTokens instanceof ArrayList) { + // Try to extract the underlying array for faster access + try { + // This is a performance optimization that may not work on all JVMs + promptTokenArray = promptTokens.stream().mapToInt(Integer::intValue).toArray(); + } catch (Exception e) { + // Fall back to list access + } + } + + // Main generation loop + while (pos < actualMaxTokens) { + // GPU Forward Pass - No conditional check since we know we're using GPU + FloatArray logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); + + // Process prompt tokens if still remaining + if (promptIndex < promptTokens.size()) { + // Get next prompt token (using array access if available) + nextToken = promptTokenArray != null ? 
promptTokenArray[promptIndex++] : promptTokens.get(promptIndex++); + + if (echo) { + // Decode and output token + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + // Mark first inference token + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + // Sample next token - use GPU sampling if available + nextToken = sampler.sampleToken(logits); + + // Add token consumer support + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + // Output if needed + if (echo && onTokenGenerated == null) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + // Store token + generatedTokens.add(nextToken); + + // Check stop condition + if (stopTokens.contains(nextToken)) { + break; + } + } + + // Update for next iteration + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + // === Performance Metrics === + long endNanos = System.nanoTime(); + double totalSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + + // Set metrics for tokens achieved + LastRunMetrics.setMetrics(totalTokens, totalSeconds); + + return generatedTokens; + } +} diff --git a/src/main/java/com/example/inference/engine/impl/Configuration.java b/src/main/java/com/example/inference/engine/impl/Configuration.java deleted file mode 100644 index 31042f38..00000000 --- a/src/main/java/com/example/inference/engine/impl/Configuration.java +++ /dev/null @@ -1,94 +0,0 @@ -package com.example.inference.engine.impl; - -public final class Configuration { - /** Transformer embedding dimension */ - public final int dim; - - /** Hidden dimension size for feed-forward network layers */ - public final int hiddenDim; - - /** Number of transformer layers in the model */ - public final int numberOfLayers; - - /** Number of attention heads for queries */ - public final int numberOfHeads; - - /** Number of key/value heads (can be fewer than query heads in multi-query attention) */ - public final int numberOfKeyValueHeads; - - /** Size of the vocabulary (token set) */ - public final int vocabularySize; - - /** Maximum sequence length the model can process */ - public final int contextLength; - - /** Epsilon value for RMSNorm layers (stabilizes normalization) */ - public final float rmsNormEps; - - /** Base value for RoPE (Rotary Position Embedding) calculations */ - public final float ropeTheta; - - /** Size of each attention head (derived from dim / numberOfHeads) */ - public final int headSize; - - /** Key/value dimension (derived from dim * numberOfKeyValueHeads / numberOfHeads) */ - public final int kvDim; - - /** Multiplier for key/value sharing in multi-query attention */ - public final int kvMul; - - /** - - /** - * Constructs a new Configuration with the specified parameters. 
- * - * @param dim Transformer embedding dimension - * @param hiddenDim Hidden dimension for feed-forward layers - * @param numberOfLayers Number of transformer layers - * @param numberOfHeads Number of attention heads - * @param numberOfKeyValueHeads Number of key/value heads - * @param vocabularySize Size of the vocabulary - * @param contextLength Maximum sequence length - * @param rmsNormEps Epsilon for RMSNorm - * @param ropeTheta Base value for RoPE calculations - */ - public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) { - this.dim = dim; - this.hiddenDim = hiddenDim; - this.numberOfLayers = numberOfLayers; - this.numberOfHeads = numberOfHeads; - this.numberOfKeyValueHeads = numberOfKeyValueHeads; - this.vocabularySize = vocabularySize; - this.contextLength = contextLength; - this.rmsNormEps = rmsNormEps; - this.ropeTheta = ropeTheta; - this.headSize = dim / numberOfHeads; - this.kvDim = dim * numberOfKeyValueHeads / numberOfHeads; - this.kvMul = numberOfHeads / numberOfKeyValueHeads; - } - - /** - * Creates a new Configuration with a different context length. - * - * @param newContextLength The new context length to use - * @return A new Configuration instance with updated context length, - * or the current instance if newContextLength is negative - */ - public Configuration withContextLength(int newContextLength) { - if (newContextLength < 0) { - return this; // no change - } - return new Configuration( - this.dim, - this.hiddenDim, - this.numberOfLayers, - this.numberOfHeads, - this.numberOfKeyValueHeads, - this.vocabularySize, - newContextLength, - this.rmsNormEps, - this.ropeTheta - ); - } -} - diff --git a/src/main/java/com/example/inference/engine/impl/Llama.java b/src/main/java/com/example/inference/engine/impl/Llama.java deleted file mode 100644 index c0ab30bb..00000000 --- a/src/main/java/com/example/inference/engine/impl/Llama.java +++ /dev/null @@ -1,409 +0,0 @@ -package com.example.inference.engine.impl; - -import com.example.auxiliary.Parallel; -import com.example.core.model.tensor.FloatTensor; -import com.example.inference.Sampler; -import com.example.loader.weights.State; -import com.example.loader.weights.Weights; -import com.example.tokenizer.impl.Tokenizer; -import com.example.tornadovm.TornadoVMMasterPlan; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; - -import java.lang.foreign.MemorySegment; -import java.nio.FloatBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; -import java.util.function.IntConsumer; - -public record Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) { - private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16); - - public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { - // calculate sum of squares - float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi); - ss /= size; - ss += rmsNormEps; - ss = (float) (1.0 / Math.sqrt(ss)); - // normalize and scale - final float finalss = ss; // for the lambda - out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index))); - } - - public static FloatTensor forwardJava(Llama model, State state, int token, int position) { - // a few convenience variables - Configuration config = model.configuration(); - Weights weights = model.weights(); - int dim = config.dim; - int headSize = 
config.headSize; - int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; - int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery - float sqrtHeadSize = (float) Math.sqrt(headSize); - - // copy the token embedding into x - weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); - - // forward all the layers - for (int l = 0; l < config.numberOfLayers; l++) { - // attention rmsnorm - rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps); - - // qkv matmuls for this position - - weights.wq[l].matmul(state.xb, state.q, dim, dim); - weights.wk[l].matmul(state.xb, state.k, kvDim, dim); - weights.wv[l].matmul(state.xb, state.v, kvDim, dim); - - // RoPE relative positional encoding: complex-valued rotate q and k in each head - for (int i = 0; i < dim; i += 2) { - int head_dim = i % headSize; - float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2)); - float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2)); - int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only - for (int v = 0; v < rotn; v++) { - FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key) - float v0 = vec.getFloat(i); - float v1 = vec.getFloat(i + 1); - vec.setFloat(i, v0 * fcr - v1 * fci); - vec.setFloat(i + 1, v0 * fci + v1 * fcr); - } - } - - // save key,value at this time step (position) to our kv cache - //int loff = l * config.seq_len * kvDim; - // kv cache layer offset for convenience - state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); - state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); - - int curLayer = l; - - // multihead attention. 
iterate over all heads - Parallel.parallelFor(0, config.numberOfHeads, h -> { - // get the query vector for this head - // float* q = s.q + h * headSize; - int qOffset = h * headSize; - - // attention scores for this head - // float* att = s.att + h * config.seq_len; - int attOffset = h * config.contextLength; - - // iterate over all timesteps, including the current one - for (int t = 0; t <= position; t++) { - // get the key vector for this head and at this timestep - // float* k = s.key_cache + loff + t * dim + h * headSize; - int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; - // calculate the attention score as the dot product of q and k - float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); - score /= sqrtHeadSize; - // save the score to the attention buffer - state.att.setFloat(attOffset + t, score); - } - - // softmax the scores to get attention weights, from 0..position inclusively - state.att.softmaxInPlace(attOffset, position + 1); - - // weighted sum of the values, store back into xb - // float* xb = s.xb + h * headSize; - int xbOffset = h * headSize; - // memset(xb, 0, headSize * sizeof(float)); - state.xb.fillInPlace(xbOffset, headSize, 0f); - - for (int t = 0; t <= position; t++) { - // get the value vector for this head and at this timestep - // float* v = s.value_cache + loff + t * dim + h * headSize; - int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; - // get the attention weight for this timestep - float a = state.att.getFloat(attOffset + t); - // accumulate the weighted value into xb - state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); - } - }); - - // final matmul to get the output of the attention - weights.wo[l].matmul(state.xb, state.xb2, dim, dim); - - // residual connection back into x - state.x.addInPlace(state.xb2); - - // ffn rmsnorm - rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps); - - // System.out.println("x " + weights.w1.toString() + " " + weights.w2.toString() + " " + weights.w3.toString()); - // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) - // first calculate self.w1(x) and self.w3(x) - weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim, dim); - weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim, dim); - - // SwiGLU non-linearity - // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid - state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); - - // elementwise multiply with w3(x) - state.hb.multiplyInPlace(state.hb2); - - // final matmul to get the output of the ffn - weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim); - - // residual connection - state.x.addInPlace(state.xb); - } - - rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps); - - weights.wcls.matmul(state.x, state.logits, config.vocabularySize, dim); - - return state.logits; - } - - /** - * Performs the initial embedding lookup and triggers the TornadoVM accelerated forward pass for an LLM token. - * - *
This method handles the first phase of processing a token through the transformer model:
- *   1. Copies the token embedding from the model's embedding table to the state's buffer
- *   2. Delegates the transformer layer processing to TornadoVM through the master plan
- *
- *
The token embedding lookup happens on the CPU using {@link MemorySegment} operations, - * while the subsequent transformer layers processing is offloaded to the accelerator through - * TornadoVM for improved performance. - * - * @param model - * The Llama model containing weights and configuration parameters - * @param state - * The current execution state holding input/output tensors and temporary buffers - * @param token - * The input token ID to process - * @param position - * The position of this token in the sequence context window - * @param tornadoVMMasterPlan - * The execution plan for TornadoVM acceleration - * @return FloatTensor containing the output logits for token prediction - */ - public static FloatArray forwardTornadoVM( // - Llama model, // - State state, // - int token, // - int position, // - TornadoVMMasterPlan tornadoVMMasterPlan) { // - - MemorySegment.copy(model.weights.tokenEmbeddingTable.getSegment(), token * model.configuration.dim * Float.BYTES, state.wrapX.getSegment(), 0, model.configuration.dim * Float.BYTES); - - return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); - } - - public static List generateTokensGPU(Llama model, State state, - int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, - TornadoVMMasterPlan tornadoVMPlan) { - // === Setup and Initialization === - long startNanos = System.nanoTime(); - long inferenceStartNanos = 0; - - // Pre-validate the max tokens to avoid checking in the loop - int actualMaxTokens = Math.min(maxTokens > 0 ? maxTokens : model.configuration().contextLength, model.configuration().contextLength); - - // Preallocate with expected capacity to avoid resizing - List generatedTokens = new ArrayList<>(Math.min(256, actualMaxTokens - promptTokens.size())); // Conservative estimate - - // === Token Generation Loop === - int currentToken = state.latestToken; - int nextToken; - int promptIndex = 0; - int pos = startPosition; - - // Use more efficient direct array access for prompt tokens if possible - int[] promptTokenArray = null; - if (promptTokens instanceof ArrayList) { - // Try to extract the underlying array for faster access - try { - // This is a performance optimization that may not work on all JVMs - promptTokenArray = promptTokens.stream().mapToInt(Integer::intValue).toArray(); - } catch (Exception e) { - // Fall back to list access - } - } - - // Main generation loop - while (pos < actualMaxTokens) { - // GPU Forward Pass - No conditional check since we know we're using GPU - FloatArray logits = forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); - - // Process prompt tokens if still remaining - if (promptIndex < promptTokens.size()) { - // Get next prompt token (using array access if available) - nextToken = promptTokenArray != null ? 
promptTokenArray[promptIndex++] : promptTokens.get(promptIndex++); - - if (echo) { - // Decode and output token - System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); - } - } else { - // Mark first inference token - if (inferenceStartNanos == 0) { - inferenceStartNanos = System.nanoTime(); - } - - // Sample next token - use GPU sampling if available - nextToken = sampler.sampleToken(logits); - - // Add token consumer support - if (onTokenGenerated != null) { - onTokenGenerated.accept(nextToken); - } - - // Output if needed - if (echo && onTokenGenerated == null) { - System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); - } - - // Store token - generatedTokens.add(nextToken); - - // Check stop condition - if (stopTokens.contains(nextToken)) { - break; - } - } - - // Update for next iteration - currentToken = nextToken; - state.latestToken = currentToken; - pos++; - } - - // === Performance Metrics === - long endNanos = System.nanoTime(); - double totalSeconds = (endNanos - startNanos) / 1_000_000_000.0; - int totalTokens = promptIndex + generatedTokens.size(); - - // Set metrics for tokens achieved - LastRunMetrics.setMetrics(totalTokens, totalSeconds); - - return generatedTokens; - } - - public static List generateTokens(Llama model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, - IntConsumer onTokenGenerated) { - // Initialize TornadoVM plan if enabled - - // Start timing the whole process - long startNanos = System.nanoTime(); - long inferenceStartNanos = 0; - - Object logits; - // Validate and adjust maxTokens if necessary - if (maxTokens < 0 || model.configuration().contextLength < maxTokens) { - maxTokens = model.configuration().contextLength; - } - - // Storage for generated tokens - List generatedTokens = new ArrayList<>(); - - // Initialize token variables - int currentToken = state.latestToken; - int nextToken; - int promptIndex = 0; - int pos = startPosition; - - while (pos < maxTokens) { - - logits = forwardJava(model, state, currentToken, pos); - - // Handle token processing - if (promptIndex < promptTokens.size()) { - // We're still processing the prompt tokens - nextToken = promptTokens.get(promptIndex++); - if (echo) { - System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); - } - } else { - // Mark the start of actual generation (after prompt processing) - if (inferenceStartNanos == 0) { - inferenceStartNanos = System.nanoTime(); - } - - // Sample the next token - nextToken = sampler.sampleToken(logits); - - // Output the token if echo is enabled - if (echo) { - System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); - } - - // Track the generated token - generatedTokens.add(nextToken); - - // Notify via callback if provided - if (onTokenGenerated != null) { - onTokenGenerated.accept(nextToken); - } - - // Check for stop condition - if (stopTokens.contains(nextToken)) { - break; - } - } - - // Update for next iteration - currentToken = nextToken; - state.latestToken = currentToken; - pos++; - } - - // Calculate and print performance metrics - long endNanos = System.nanoTime(); - double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; - int totalTokens = promptIndex + generatedTokens.size(); - - LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); - - return generatedTokens; - } - - - public State 
createNewState() { - State state = new State(configuration(), -1); - state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); - return state; - } - - public State createNewState(int batchsize) { - State state = new State(configuration(), batchsize); - state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); - return state; - } - - /** - * Record to store metrics from the last model run. - * @param totalTokens The total number of tokens processed - * @param totalSeconds The total time in seconds - */ - public record LastRunMetrics(int totalTokens, double totalSeconds) { - /** - * Singleton instance to store the latest metrics - */ - private static LastRunMetrics latestMetrics; - - /** - * Sets the metrics for the latest run - * - * @param tokens The total number of tokens processed - * @param seconds The total time in seconds - */ - public static void setMetrics(int tokens, double seconds) { - latestMetrics = new LastRunMetrics(tokens, seconds); - } - - /** - * Prints the metrics from the latest run to stderr - */ - public static void printMetrics() { - if (latestMetrics != null) { - double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds(); - System.err.printf("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds()); - } - } - } - -} - diff --git a/src/main/java/com/example/inference/CategoricalSampler.java b/src/main/java/com/example/inference/sampler/CategoricalSampler.java similarity index 98% rename from src/main/java/com/example/inference/CategoricalSampler.java rename to src/main/java/com/example/inference/sampler/CategoricalSampler.java index 83d90ed5..acbc6a47 100644 --- a/src/main/java/com/example/inference/CategoricalSampler.java +++ b/src/main/java/com/example/inference/sampler/CategoricalSampler.java @@ -1,4 +1,4 @@ -package com.example.inference; +package com.example.inference.sampler; import com.example.core.model.tensor.FloatTensor; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/com/example/inference/Sampler.java b/src/main/java/com/example/inference/sampler/Sampler.java similarity index 97% rename from src/main/java/com/example/inference/Sampler.java rename to src/main/java/com/example/inference/sampler/Sampler.java index 5f83656c..4cb7c42e 100644 --- a/src/main/java/com/example/inference/Sampler.java +++ b/src/main/java/com/example/inference/sampler/Sampler.java @@ -1,4 +1,4 @@ -package com.example.inference; +package com.example.inference.sampler; import com.example.core.model.tensor.FloatTensor; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/com/example/inference/ToppSampler.java b/src/main/java/com/example/inference/sampler/ToppSampler.java similarity index 99% rename from src/main/java/com/example/inference/ToppSampler.java rename to src/main/java/com/example/inference/sampler/ToppSampler.java index a7ee8a1e..d08ded48 100644 --- a/src/main/java/com/example/inference/ToppSampler.java +++ b/src/main/java/com/example/inference/sampler/ToppSampler.java @@ -1,4 +1,4 @@ -package com.example.inference; +package com.example.inference.sampler; import com.example.core.model.tensor.FloatTensor; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/com/example/loader/weights/ModelLoader.java b/src/main/java/com/example/loader/weights/ModelLoader.java index 8b98e184..353600e3 100644 --- 
a/src/main/java/com/example/loader/weights/ModelLoader.java +++ b/src/main/java/com/example/loader/weights/ModelLoader.java @@ -1,7 +1,6 @@ package com.example.loader.weights; import com.example.LlamaApp; -import com.example.auxiliary.Timer; import com.example.core.model.GGMLType; import com.example.core.model.GGUF; import com.example.core.model.tensor.F16FloatTensor; @@ -10,11 +9,10 @@ import com.example.core.model.tensor.Q4_0FloatTensor; import com.example.core.model.tensor.Q8_0FloatTensor; import com.example.core.types.Pair; -import com.example.inference.engine.impl.Configuration; -import com.example.inference.engine.impl.Llama; +import com.example.model.Configuration; +import com.example.model.Model; +import com.example.model.ModelType; import com.example.inference.operation.RoPE; -import com.example.tokenizer.impl.Tokenizer; -import com.example.tokenizer.vocabulary.Vocabulary; import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; @@ -26,45 +24,55 @@ import java.nio.channels.FileChannel; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.Arrays; -import java.util.List; import java.util.Map; import java.util.function.IntFunction; -import java.util.stream.Collectors; -import java.util.stream.IntStream; public final class ModelLoader { private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; + private static final String TOKENIZER_MISTRAL_MODEL = "llama"; + + private static ModelType detectModelType(Map metadata) { + String name = (String) metadata.get("general.name"); + String tokenizerModel = (String) metadata.get("tokenizer.ggml.model"); + Integer vocabSize = (Integer) metadata.get("llama.vocab_size"); + + // Check by name first + if (name != null) { + String lowerName = name.toLowerCase(); + if (lowerName.contains("mistral")) { + return ModelType.MISTRAL; + } else if (lowerName.contains("llama")) { + return ModelType.LLAMA_3; + } + } - private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; - - public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException { - GGUF gguf = GGUF.loadModel(ggufPath); - FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); - return loadModel(fileChannel, gguf, contextLength, loadWeights); - } - - public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException { - try (var ignored = Timer.log("Load LlaMa model")) { - Map metadata = gguf.getMetadata(); - Vocabulary vocabulary = Vocabulary.loadVocabulary(metadata); - Tokenizer tokenizer = createTokenizer(metadata, vocabulary); - - Configuration config = new Configuration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"), - (int) metadata.get("llama.attention.head_count"), - - metadata.containsKey("llama.attention.head_count_kv") ? 
(int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"), - - vocabulary.size(), (int) metadata.get("llama.context_length"), (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength); + // Check by tokenizer model + if (TOKENIZER_MISTRAL_MODEL.equals(tokenizerModel)) { + return ModelType.MISTRAL; + } else if (TOKENIZER_LLAMA_3_MODEL.equals(tokenizerModel)) { + return ModelType.LLAMA_3; + } - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config); + // Check by vocabulary size as fallback + if (vocabSize != null) { + if (vocabSize == 32768) { + return ModelType.MISTRAL; + } else if (vocabSize == 128256) { + return ModelType.LLAMA_3; } - return new Llama(config, tokenizer, weights); } + + return ModelType.UNKNOWN; + } + + public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException { + // initial load of metadata from gguf file + GGUF gguf = GGUF.loadModel(ggufPath); + FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); + // detect model type + ModelType modelType = detectModelType(gguf.getMetadata()); + // model type-specific load + return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights); } public static Weights loadWeights(Map tensorEntries, Configuration config) { @@ -75,9 +83,9 @@ public static Weights loadWeights(Map tensorEntries, Co 8192 // oldContextLength ); - Pair ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, // Maximum sequence length the model can process - config.headSize, // Dimension of each attention head - config.ropeTheta, // Base frequency parameter (typically 10000.0) + Pair ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength(), // Maximum sequence length the model can process + config.headSize(), // Dimension of each attention head + config.ropeTheta(), // Base frequency parameter (typically 10000.0) ropeScaling, // Whether to apply frequency scaling (determined by model type) ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling) ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies @@ -100,15 +108,15 @@ private static Weights createTornadoVMWeights(Map tenso GGMLTensorEntry outputWeight) { return new Weights( // Load directly to TornadoVM format - loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." 
+ i + ".ffn_down.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), + loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()); } @@ -117,35 +125,18 @@ private static Weights createTornadoVMWeights(Map tenso */ private static Weights createStandardWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), toFloatBuffer(tensorEntries.get("output_norm.weight")), + return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_gate.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), toFloatBuffer(tensorEntries.get("output_norm.weight")), FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType()); } - private static Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { - String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); - List> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) - .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); - - int allTokens = vocabulary.size(); - int baseTokens = 128000; // assume all tokens after the base ones are special. - int reservedSpecialTokens = allTokens - baseTokens; - List specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); - - assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); - - Map specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); - - return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens); - } - public static FloatTensor loadQuantized(GGMLTensorEntry entry) { GGMLType ggmlType = entry.ggmlType(); return switch (ggmlType) { diff --git a/src/main/java/com/example/loader/weights/State.java b/src/main/java/com/example/loader/weights/State.java index 29c2f510..12f968a3 100644 --- a/src/main/java/com/example/loader/weights/State.java +++ b/src/main/java/com/example/loader/weights/State.java @@ -2,7 +2,7 @@ import com.example.core.model.tensor.ArrayFloatTensor; import com.example.core.model.tensor.FloatTensor; -import com.example.inference.engine.impl.Configuration; +import com.example.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; @@ -58,46 +58,46 @@ public final class State { public State(Configuration config, int batchsize) { this.batchsize = -1; - this.x = ArrayFloatTensor.allocate(config.dim); - this.xb = ArrayFloatTensor.allocate(config.dim); - this.xb2 = ArrayFloatTensor.allocate(config.dim); - this.hb = ArrayFloatTensor.allocate(config.hiddenDim); - this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim); - this.q = ArrayFloatTensor.allocate(config.dim); - this.k = ArrayFloatTensor.allocate(config.dim); - this.v = ArrayFloatTensor.allocate(config.dim); - this.att = ArrayFloatTensor.allocate(config.numberOfHeads, config.contextLength); - this.logits = ArrayFloatTensor.allocate(config.vocabularySize); - int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; - this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new); - this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new); - - this.wrapX = new FloatArray(config.dim); - this.wrapXb = new FloatArray(config.dim); - this.wrapXb2 = new FloatArray(config.dim); - this.wrapHb = new FloatArray(config.hiddenDim); - this.wrapHb2 = new FloatArray(config.hiddenDim); - - this.wrapLogits = new FloatArray(config.vocabularySize); - this.wrapQ = new FloatArray(config.dim); - 
this.wrapK = new FloatArray(config.dim); - this.wrapV = new FloatArray(config.dim); + this.x = ArrayFloatTensor.allocate(config.dim()); + this.xb = ArrayFloatTensor.allocate(config.dim()); + this.xb2 = ArrayFloatTensor.allocate(config.dim()); + this.hb = ArrayFloatTensor.allocate(config.hiddenDim()); + this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); + this.q = ArrayFloatTensor.allocate(config.dim()); + this.k = ArrayFloatTensor.allocate(config.dim()); + this.v = ArrayFloatTensor.allocate(config.dim()); + this.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); + this.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + + this.wrapX = new FloatArray(config.dim()); + this.wrapXb = new FloatArray(config.dim()); + this.wrapXb2 = new FloatArray(config.dim()); + this.wrapHb = new FloatArray(config.hiddenDim()); + this.wrapHb2 = new FloatArray(config.hiddenDim()); + + this.wrapLogits = new FloatArray(config.vocabularySize()); + this.wrapQ = new FloatArray(config.dim()); + this.wrapK = new FloatArray(config.dim()); + this.wrapV = new FloatArray(config.dim()); // dim vs kvdim - this.wrapKeyCache = new FloatArray(config.contextLength * kvDim * config.numberOfLayers); - this.wrapValueCache = new FloatArray(config.contextLength * kvDim * config.numberOfLayers); + this.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + this.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); this.wrapValueCache.init(0.f); this.wrapKeyCache.init(0.f); - this.wrapAtt = new FloatArray(config.numberOfHeads * config.contextLength); + this.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength()); this.positionHolder = new IntArray(1); this.latestToken = -1; // this.localSize = 256; // You need at least 9 elements: 1 for the final result + 8 for the workgroup partial sums - this.temp = new FloatArray(1 + ((config.dim + localSize-1) / localSize)); - this.tempFFN = new FloatArray(1 + ((config.dim + localSize-1) / localSize)); - this.tempLogits = new FloatArray(1 + ((config.dim + localSize-1) / localSize)); + this.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + this.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + this.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); } @Override diff --git a/src/main/java/com/example/model/Configuration.java b/src/main/java/com/example/model/Configuration.java new file mode 100644 index 00000000..6ff558c1 --- /dev/null +++ b/src/main/java/com/example/model/Configuration.java @@ -0,0 +1,38 @@ +package com.example.model; + +public interface Configuration { + + /** Transformer embedding dimension */ + int dim(); + + /** Hidden dimension size for feed-forward network layers */ + int hiddenDim(); + + /** Number of transformer layers in the model */ + int numberOfLayers(); + + /** Number of attention heads for queries */ + int numberOfHeads(); + + /** Number of key/value heads (can be fewer than query heads in multi-query attention) */ + int numberOfKeyValueHeads(); + + /** Size of 
the vocabulary (token set) */ + int vocabularySize(); + + /** Maximum sequence length the model can process */ + int contextLength(); + + /** Epsilon value for RMSNorm layers (stabilizes normalization) */ + float rmsNormEps(); + + /** Base value for RoPE (Rotary Position Embedding) calculations */ + float ropeTheta(); + + /** Size of each attention head (derived from dim / numberOfHeads) */ + int headSize(); + + int kvDim(); + + int kvMul(); +} diff --git a/src/main/java/com/example/model/Model.java b/src/main/java/com/example/model/Model.java new file mode 100644 index 00000000..07799562 --- /dev/null +++ b/src/main/java/com/example/model/Model.java @@ -0,0 +1,188 @@ +package com.example.model; + +import com.example.auxiliary.LastRunMetrics; +import com.example.model.format.ChatFormat; +import com.example.inference.InferenceEngine; +import com.example.inference.sampler.Sampler; +import com.example.Options; +import com.example.loader.weights.State; +import com.example.loader.weights.Weights; +import com.example.tokenizer.impl.Tokenizer; +import com.example.tornadovm.TornadoVMMasterPlan; + +import java.util.ArrayList; +import java.util.List; +import java.util.Scanner; +import java.util.Set; +import java.util.function.IntConsumer; + +import static com.example.LlamaApp.SHOW_PERF_INTERACTIVE; +import static com.example.LlamaApp.USE_TORNADOVM; + +public interface Model { + Configuration configuration(); + + Tokenizer tokenizer(); + + Weights weights(); + + ModelType getModelType(); + + State createNewState(); + + State createNewState(int batchsize); + + /** + * Model agnostic default implementation for interactive mode. + * @param sampler + * @param options + */ + default void runInteractive(Sampler sampler, Options options) { + State state = null; + List conversationTokens = new ArrayList<>(); + + ChatFormat chatFormat = ChatFormat.create(tokenizer()); + conversationTokens.add(chatFormat.getBeginOfText()); + + if (options.systemPrompt() != null) { + conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); + } + + int startPosition = 0; + Scanner in = new Scanner(System.in); + + // Initialize TornadoVM plan once at the beginning if GPU path is enabled + TornadoVMMasterPlan tornadoVMPlan = null; + + try { + while (true) { + System.out.print("> "); + System.out.flush(); + String userText = in.nextLine(); + if (List.of("quit", "exit").contains(userText)) { + break; + } + if (state == null) { + // State allocation can take some time for large context sizes, + // allocate the model state only after printing the user '>' prompt. 
+ state = createNewState(); + } + + if (USE_TORNADOVM && tornadoVMPlan == null) { + tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this); + } + + conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); + conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + Set stopTokens = chatFormat.getStopTokens(); + + List responseTokens; + IntConsumer tokenConsumer = token -> { + if (options.stream()) { + if (tokenizer().shouldDisplayToken(token)) { + System.out.print(tokenizer().decode(List.of(token))); + } + } + }; + + // Choose between GPU and CPU path based on configuration + if (USE_TORNADOVM) { + // GPU path using TornadoVM + responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, + options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan); + } else { + // CPU path + responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), + sampler, options.echo(), tokenConsumer); + } + + // Include stop token in the prompt history, but not in the response displayed to the user. + conversationTokens.addAll(responseTokens); + startPosition = conversationTokens.size(); + Integer stopToken = null; + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + stopToken = responseTokens.getLast(); + responseTokens.removeLast(); + } + if (!options.stream()) { + String responseText = tokenizer().decode(responseTokens); + System.out.println(responseText); + } + if (stopToken == null) { + System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX"); + break; + } + System.out.print("\n"); + + // Optionally print performance metrics after each response + if (SHOW_PERF_INTERACTIVE) { + LastRunMetrics.printMetrics(); + } + } + } finally { + // Clean up TornadoVM resources when exiting the chat loop + if (USE_TORNADOVM && tornadoVMPlan != null) { + try { + tornadoVMPlan.freeTornadoExecutionPlan(); + } catch (Exception e) { + System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage()); + } + } + } + } + + /** + * Model agnostic default implementation for instruct mode. 
+ * @param sampler + * @param options + */ + default void runInstructOnce(Sampler sampler, Options options) { + State state = createNewState(); + ChatFormat chatFormat = ChatFormat.create(tokenizer()); + TornadoVMMasterPlan tornadoVMPlan = null; + + List promptTokens = new ArrayList<>(); + promptTokens.add(chatFormat.getBeginOfText()); + + if (options.systemPrompt() != null) { + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); + } + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt()))); + promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + List responseTokens; + + IntConsumer tokenConsumer = token -> { + if (options.stream()) { + if (tokenizer().shouldDisplayToken(token)) { + System.out.print(tokenizer().decode(List.of(token))); + } + } + }; + + Set stopTokens = chatFormat.getStopTokens(); + + if (USE_TORNADOVM) { + tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this); + // Call generateTokensGPU without the token consumer parameter + responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, + tornadoVMPlan); + } else { + responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer); + } + + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + responseTokens.removeLast(); + } + if (!options.stream()) { + String responseText = tokenizer().decode(responseTokens); + System.out.println(responseText); + } + + LastRunMetrics.printMetrics(); + + if (tornadoVMPlan != null) { + tornadoVMPlan.freeTornadoExecutionPlan(); + } + } +} diff --git a/src/main/java/com/example/model/ModelType.java b/src/main/java/com/example/model/ModelType.java new file mode 100644 index 00000000..04268243 --- /dev/null +++ b/src/main/java/com/example/model/ModelType.java @@ -0,0 +1,33 @@ +package com.example.model; + +import com.example.core.model.GGUF; +import com.example.model.llama.Llama; +import com.example.model.mistral.Mistral; + +import java.nio.channels.FileChannel; + +public enum ModelType { + LLAMA_3 { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + return Llama.loadModel(fileChannel, gguf, contextLength, loadWeights); + } + }, + + MISTRAL { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + return Mistral.loadModel(fileChannel, gguf, contextLength, loadWeights); + } + }, + + UNKNOWN { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + throw new UnsupportedOperationException("Cannot load unknown model type"); + } + }; + + // Abstract method that each enum constant must implement + public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights); +} diff --git a/src/main/java/com/example/model/format/ChatFormat.java b/src/main/java/com/example/model/format/ChatFormat.java new file mode 100644 index 00000000..b114c7dd --- /dev/null +++ b/src/main/java/com/example/model/format/ChatFormat.java @@ -0,0 +1,64 @@ +package com.example.model.format; + +import com.example.tokenizer.impl.LlamaTokenizer; +import 
com.example.tokenizer.impl.MistralTokenizer; + +import java.util.List; +import java.util.Set; + +public interface ChatFormat { + + static ChatFormat create(Object tokenizer) { + if (tokenizer instanceof LlamaTokenizer llamaTokenizer) { + return new LlamaChatFormat(llamaTokenizer); + } else if (tokenizer instanceof MistralTokenizer mistralTokenizer) { + return new MistralChatFormat(mistralTokenizer); + } else { + throw new IllegalArgumentException("Unsupported tokenizer type: " + tokenizer.getClass().getName()); + } + } + + List encodeHeader(Message message); + + List encodeMessage(Message message); + + int getBeginOfText(); + + Set getStopTokens(); + + /** + * Represents a single message in a LLM chat session. + * + * Each message is associated with a specific role (system, user, or assistant) + * and contains the textual content of that message. + * + * @param role the participant who issued the message (SYSTEM, USER, or ASSISTANT). + * @param content the textual content of the message + */ + record Message(Role role, String content) { + } + + /** + * Represents the role of a participant in a LLM chat conversation + * + * There are three standard roles: + *

+ * <ul>
+ *   <li>SYSTEM - sets the behavior and context of the assistant at the start of the conversation.</li>
+ *   <li>USER - represents input from the human user.</li>
+ *   <li>ASSISTANT - represents output from the AI assistant.</li>
+ * </ul>
+ * + * @param name the string representation of the role + */ + record Role(String name) { + public static Role SYSTEM = new Role("system"); + public static Role USER = new Role("user"); + public static Role ASSISTANT = new Role("assistant"); + + @Override + public String toString() { + return name; + } + } + +} \ No newline at end of file diff --git a/src/main/java/com/example/model/format/LlamaChatFormat.java b/src/main/java/com/example/model/format/LlamaChatFormat.java new file mode 100644 index 00000000..03914b44 --- /dev/null +++ b/src/main/java/com/example/model/format/LlamaChatFormat.java @@ -0,0 +1,73 @@ +package com.example.model.format; + +import com.example.tokenizer.impl.LlamaTokenizer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class LlamaChatFormat implements ChatFormat { + + protected final LlamaTokenizer tokenizer; + protected final int beginOfText; + protected final int endHeader; + protected final int startHeader; + protected final int endOfTurn; + protected final int endOfText; + protected final int endOfMessage; + protected final Set stopTokens; + + public LlamaChatFormat(LlamaTokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = tokenizer.getSpecialTokens(); + this.beginOfText = specialTokens.get("<|begin_of_text|>"); + this.startHeader = specialTokens.get("<|start_header_id|>"); + this.endHeader = specialTokens.get("<|end_header_id|>"); + this.endOfTurn = specialTokens.get("<|eot_id|>"); + this.endOfText = specialTokens.get("<|end_of_text|>"); + this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.stopTokens = Set.of(endOfText, endOfTurn); + } + + @Override + public int getBeginOfText() { + return beginOfText; + } + + @Override + public Set getStopTokens() { + return stopTokens; + } + + @Override + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(tokenizer.encodeAsList(message.role().name())); + tokens.add(endHeader); + tokens.addAll(tokenizer.encodeAsList("\n")); + return tokens; + } + + @Override + public List encodeMessage(Message message) { + List tokens = encodeHeader(message); + tokens.addAll(tokenizer.encodeAsList(message.content().strip())); + tokens.add(endOfTurn); + return tokens; + } + + public List encodeDialogPrompt(boolean appendAssistantTurn, List dialog) { + List tokens = new ArrayList<>(); + tokens.add(beginOfText); + for (Message message : dialog) { + tokens.addAll(encodeMessage(message)); + } + if (appendAssistantTurn) { + // Add the start of an assistant message for the model to complete. 
+ tokens.addAll(encodeHeader(new Message(ChatFormat.Role.ASSISTANT, ""))); + } + return tokens; + } +} \ No newline at end of file diff --git a/src/main/java/com/example/model/format/MistralChatFormat.java b/src/main/java/com/example/model/format/MistralChatFormat.java new file mode 100644 index 00000000..d0d57a3a --- /dev/null +++ b/src/main/java/com/example/model/format/MistralChatFormat.java @@ -0,0 +1,80 @@ +package com.example.model.format; + +import com.example.tokenizer.impl.MistralTokenizer; + +import java.util.*; + +public class MistralChatFormat implements ChatFormat { + + protected final MistralTokenizer tokenizer; + protected final int unknownToken; + protected final int beginOfText; + protected final int endOfText; + protected final int beginOfInstruction; + protected final int endOfInstruction; + protected final int toolCalls; + protected final int beginOfAvailableTools; + protected final int endOfAvailableTools; + protected final int beginOfToolResults; + protected final int endOfToolResults; + protected final int prefix; + protected final int middle; + protected final int suffix; + + public MistralChatFormat(MistralTokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = tokenizer.getSpecialTokens(); + this.unknownToken = specialTokens.get(""); + this.beginOfText = specialTokens.get(""); + this.endOfText = specialTokens.get(""); + this.beginOfInstruction = specialTokens.get("[INST]"); + this.endOfInstruction = specialTokens.get("[/INST]"); + this.toolCalls = specialTokens.get("[TOOL_CALLS]"); + this.beginOfAvailableTools = specialTokens.get("[AVAILABLE_TOOLS]"); + this.endOfAvailableTools = specialTokens.get("[/AVAILABLE_TOOLS]"); + this.beginOfToolResults = specialTokens.get("[TOOL_RESULTS]"); + this.endOfToolResults = specialTokens.get("[/TOOL_RESULTS]"); + // Only Codestral supports FIM tokens. + this.prefix = specialTokens.getOrDefault("[PREFIX]", unknownToken); + this.suffix = specialTokens.getOrDefault("[SUFFIX]", unknownToken); + this.middle = specialTokens.getOrDefault("[MIDDLE]", unknownToken); + } + + @Override + public int getBeginOfText() { + return beginOfText; + } + + @Override + public Set getStopTokens() { + return Set.of(endOfText); + } + + @Override + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + tokens.add(beginOfInstruction); + tokens.addAll(tokenizer.encodeAsList(message.role().name())); + tokens.add(endOfInstruction); + return tokens; + } + + @Override + public List encodeMessage(Message message) { + List tokens = encodeHeader(message); + tokens.addAll(tokenizer.encodeAsList(message.content().strip())); + tokens.add(endOfInstruction); + return tokens; + } + + public List encodeFillInTheMiddle(String prefix, String suffix) { + List tokens = new ArrayList<>(); + // dummy - empty string set to comply with encode method signature. 
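+ // Codestral-style fill-in-the-middle layout: the [SUFFIX] marker and suffix text are emitted first,
+ // then the [PREFIX] marker and prefix text, so the model generates the missing middle after the prefix.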
+ final Set EMPTY_STRING_SET = Collections.emptySet(); + tokens.add(this.suffix); + tokens.addAll(tokenizer.encode(suffix, EMPTY_STRING_SET)); + tokens.add(this.prefix); + tokens.addAll(tokenizer.encode(prefix, EMPTY_STRING_SET)); + return tokens; + } +} diff --git a/src/main/java/com/example/model/llama/Llama.java b/src/main/java/com/example/model/llama/Llama.java new file mode 100644 index 00000000..c8326a1f --- /dev/null +++ b/src/main/java/com/example/model/llama/Llama.java @@ -0,0 +1,84 @@ +package com.example.model.llama; + +import com.example.auxiliary.Timer; +import com.example.core.model.GGUF; +import com.example.core.model.tensor.GGMLTensorEntry; +import com.example.model.Model; +import com.example.loader.weights.State; +import com.example.loader.weights.Weights; +import com.example.model.ModelType; +import com.example.tokenizer.impl.LlamaTokenizer; +import com.example.tokenizer.impl.Tokenizer; +import com.example.tokenizer.vocabulary.Vocabulary; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.util.Map; + +import static com.example.loader.weights.ModelLoader.loadWeights; + +public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model { + private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16); + + /* For explicit use */ + private LlamaTokenizer getAsLlamaTokenizer() { + return (LlamaTokenizer) tokenizer; + } + + @Override + public ModelType getModelType() { + return ModelType.LLAMA_3; + } + + @Override + public State createNewState() { + State state = new State(configuration(), -1); + state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); + return state; + } + + @Override + public State createNewState(int batchsize) { + State state = new State(configuration(), batchsize); + state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); + return state; + } + + // @formatter:off + public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + try (var ignored = Timer.log("Load LlaMa model")) { + Map metadata = gguf.getMetadata(); + + Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata); + Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary); + + LlamaConfiguration config = new LlamaConfiguration( + (int) metadata.get("llama.embedding_length"), + (int) metadata.get("llama.feed_forward_length"), + (int) metadata.get("llama.block_count"), + (int) metadata.get("llama.attention.head_count"), + + metadata.containsKey("llama.attention.head_count_kv") ? 
+ (int) metadata.get("llama.attention.head_count_kv") : + (int) metadata.get("llama.attention.head_count"), + + vocabulary.size(), + (int) metadata.get("llama.context_length"), + (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) + ).withContextLength(contextLength); + + Weights weights = null; + if (loadWeights) { + Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + weights = loadWeights(tensorEntries, config); + } + return new Llama(config, tokenizer, weights); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + // @formatter:on + +} + diff --git a/src/main/java/com/example/model/llama/LlamaConfiguration.java b/src/main/java/com/example/model/llama/LlamaConfiguration.java new file mode 100644 index 00000000..53195d11 --- /dev/null +++ b/src/main/java/com/example/model/llama/LlamaConfiguration.java @@ -0,0 +1,48 @@ +package com.example.model.llama; + +import com.example.model.Configuration; + +public record LlamaConfiguration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) + implements Configuration { + + public int headSize() { + return dim / numberOfHeads; + } + + /** Key/value dimension (derived from dim * numberOfKeyValueHeads / numberOfHeads) */ + public int kvDim() { + return dim * numberOfKeyValueHeads / numberOfHeads; + } + + /** Multiplier for key/value sharing in multi-query attention */ + public int kvMul() { + return numberOfHeads / numberOfKeyValueHeads; + } + + /** + * Creates a new Configuration with a different context length. + * + * @param newContextLength The new context length to use + * @return A new Configuration instance with updated context length, + * or the current instance if newContextLength is negative + */ + // @formatter:off + public LlamaConfiguration withContextLength(int newContextLength) { + if (newContextLength < 0) { + return this; // no change + } + return new LlamaConfiguration( + this.dim, + this.hiddenDim, + this.numberOfLayers, + this.numberOfHeads, + this.numberOfKeyValueHeads, + this.vocabularySize, + newContextLength, + this.rmsNormEps, + this.ropeTheta + ); + } + // @formatter:on +} + diff --git a/src/main/java/com/example/model/mistral/Mistral.java b/src/main/java/com/example/model/mistral/Mistral.java new file mode 100644 index 00000000..c1f9d09c --- /dev/null +++ b/src/main/java/com/example/model/mistral/Mistral.java @@ -0,0 +1,86 @@ +package com.example.model.mistral; + +import com.example.auxiliary.Timer; +import com.example.core.model.GGUF; +import com.example.core.model.tensor.GGMLTensorEntry; +import com.example.model.Model; +import com.example.loader.weights.State; +import com.example.loader.weights.Weights; +import com.example.model.ModelType; +import com.example.tokenizer.impl.MistralTokenizer; +import com.example.tokenizer.impl.Tokenizer; +import com.example.tokenizer.vocabulary.Vocabulary; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.util.Map; + +import static com.example.loader.weights.ModelLoader.loadWeights; + +public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model { + + /* For explicit use */ + private MistralTokenizer getAsMistralTokenizer() { + return (MistralTokenizer) tokenizer; + } + + @Override + public ModelType getModelType() { + return 
ModelType.MISTRAL; + } + + public State createNewState() { + State state = new State(configuration(), -1); + state.latestToken = tokenizer.getSpecialTokens().get(""); + return state; + } + + public State createNewState(int batchsize) { + State state = new State(configuration(), batchsize); + state.latestToken = tokenizer.getSpecialTokens().get(""); + return state; + } + + // @formatter:off + public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + try (var ignored = Timer.log("Load Mistral model")) { + Map metadata = gguf.getMetadata(); + + Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata); + Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary); + + int modelContextLength = (int) metadata.get("llama.context_length"); + if (contextLength < 0 || modelContextLength < contextLength) { + contextLength = modelContextLength; + } + + MistralConfiguration config = new MistralConfiguration( + (int) metadata.get("llama.embedding_length"), + (int) metadata.get("llama.feed_forward_length"), + (int) metadata.get("llama.block_count"), + (int) metadata.get("llama.attention.head_count"), + + metadata.containsKey("llama.attention.head_count_kv") + ? (int) metadata.get("llama.attention.head_count_kv") + : (int) metadata.get("llama.attention.head_count"), + + vocabulary.size(), + contextLength, + false, + (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) + ); + + Weights weights = null; + if (loadWeights) { + Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + weights = loadWeights(tensorEntries, config); + } + return new Mistral(config, tokenizer, weights); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + // @formatter:on + +} diff --git a/src/main/java/com/example/model/mistral/MistralConfiguration.java b/src/main/java/com/example/model/mistral/MistralConfiguration.java new file mode 100644 index 00000000..d2c14d42 --- /dev/null +++ b/src/main/java/com/example/model/mistral/MistralConfiguration.java @@ -0,0 +1,20 @@ +package com.example.model.mistral; + +import com.example.model.Configuration; + +public record MistralConfiguration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, boolean sharedWeights, + float rmsNormEps, float ropeTheta) implements Configuration { + + public int kvDim() { + return dim * numberOfKeyValueHeads / numberOfHeads; + } + + public int kvMul() { + return numberOfHeads / numberOfKeyValueHeads; + } + + public int headSize() { + return dim / numberOfHeads; + } +} + diff --git a/src/main/java/com/example/tokenizer/impl/LlamaTokenizer.java b/src/main/java/com/example/tokenizer/impl/LlamaTokenizer.java new file mode 100644 index 00000000..e429bb6b --- /dev/null +++ b/src/main/java/com/example/tokenizer/impl/LlamaTokenizer.java @@ -0,0 +1,278 @@ +package com.example.tokenizer.impl; + +import com.example.core.types.Pair; +import com.example.tokenizer.vocabulary.Vocabulary; + +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * GPT-2-style BPE tokenizer (even though it's called "llama") with an explicit merges list. + *

+ * BPE (Byte Pair Encoding): + * A sub-word tokenization algorithm that iteratively merges the most frequent pairs of symbols in a corpus to build a vocabulary of common character sequences. + *
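As an aside, the greedy merge loop described above can be illustrated with a tiny, self-contained sketch that uses an explicit merges table. The token ids and merge ranks below are invented for illustration only; the real tokenizer builds its table from `tokenizer.ggml.merges`, as shown later in this file.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Minimal sketch of greedy BPE merging with an explicit merges list.
// Token ids and ranks are invented; the real tokenizer derives them from the GGUF vocabulary.
public class BpeMergeSketch {

    record Pair(int first, int second) { }

    public static void main(String[] args) {
        // Hypothetical merge table: adjacent pair of token ids -> id of the merged token.
        // Insertion order stands in for the learned merge ranks (earlier = higher priority).
        Map<Pair, Integer> merges = new LinkedHashMap<>();
        merges.put(new Pair(1, 2), 4); // e.g. 'l' + 'o'  -> 'lo'
        merges.put(new Pair(4, 3), 5); // e.g. 'lo' + 'w' -> 'low'
        List<Pair> ranks = new ArrayList<>(merges.keySet());

        List<Integer> ids = new ArrayList<>(List.of(1, 2, 3)); // 'l', 'o', 'w'
        while (ids.size() >= 2) {
            // pick the adjacent pair with the lowest merge rank
            Pair best = null;
            for (int i = 0; i + 1 < ids.size(); i++) {
                Pair p = new Pair(ids.get(i), ids.get(i + 1));
                if (merges.containsKey(p) && (best == null || ranks.indexOf(p) < ranks.indexOf(best))) {
                    best = p;
                }
            }
            if (best == null) {
                break; // no mergeable pair left
            }
            // replace the first occurrence of that pair with its merged token id
            for (int i = 0; i + 1 < ids.size(); i++) {
                if (ids.get(i) == best.first() && ids.get(i + 1) == best.second()) {
                    ids.set(i, merges.get(best));
                    ids.remove(i + 1);
                    break;
                }
            }
        }
        System.out.println(ids); // [5] -- the whole chunk collapses to the 'low' token
    }
}
```

Unlike this simplified sketch, the actual implementation below replaces every occurrence of the winning pair before rescanning.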

+ * GPT-2-style tokenization: + * Applies BPE at the byte level, ensuring all UTF-8 inputs are representable and using tokens that preserve leading spaces (e.g., 'Ġthe'). + *
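The byte-level trick mentioned above can be made concrete with a short, runnable sketch; it reuses the same `bytesToUnicode()` construction that appears further down in this file and shows a leading space surviving as the visible 'Ġ' character.

```java
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Sketch of the GPT-2 byte-to-unicode mapping: every UTF-8 byte gets a printable stand-in character.
public class ByteToUnicodeSketch {

    // Same construction as bytesToUnicode() later in this file: printable bytes map to themselves,
    // the rest are shifted into the 256+ range so nothing collides with whitespace/control characters.
    static Map<Integer, Integer> bytesToUnicode() {
        List<Integer> bs = new ArrayList<>();
        IntStream.rangeClosed('!', '~').forEach(bs::add);
        IntStream.rangeClosed('¡', '¬').forEach(bs::add);
        IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
        List<Integer> cs = new ArrayList<>(bs);
        int n = 0;
        for (int b = 0; b < 256; ++b) {
            if (!bs.contains(b)) {
                bs.add(b);
                cs.add(256 + n);
                n++;
            }
        }
        return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get));
    }

    public static void main(String[] args) {
        Map<Integer, Integer> byteEncoder = bytesToUnicode();
        StringBuilder sb = new StringBuilder();
        for (byte b : " the".getBytes(StandardCharsets.UTF_8)) {
            sb.appendCodePoint(byteEncoder.get(Byte.toUnsignedInt(b)));
        }
        System.out.println(sb); // "Ġthe" -- the leading space is preserved as the visible 'Ġ'
    }
}
```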

+ * Explicit merges list: + * A fixed sequence of learned merge rules that deterministically reconstructs the tokenizer’s vocabulary during inference without retraining. + *

+ * Based on minbpe, algorithmically follows along the + * GPT 2 tokenizer + */ +public class LlamaTokenizer implements Tokenizer { + private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + // general fields + private final Pattern compiledPattern; + private final Vocabulary vocabulary; + // model-specific fields + private final Map, Integer> merges; + private final Map specialTokens; + + public String regexPattern() { + if (compiledPattern == null) { + return null; + } + return compiledPattern.pattern(); + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + return specialTokens.containsValue(tokenIndex); + } + + @Override + public boolean shouldDisplayToken(int token) { + return !isSpecialToken(token); + } + + public LlamaTokenizer(Map metadata, Vocabulary vocabulary) { + // load from metadata + String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); + List> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) + .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); + int allTokens = vocabulary.size(); + int baseTokens = 128000; // assume all tokens after the base ones are special. + int reservedSpecialTokens = allTokens - baseTokens; + List specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); + + assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); + + Map specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); + + // init tokenizer object fields + this.vocabulary = vocabulary; + this.compiledPattern = Pattern.compile(LLAMA_3_PATTERN); + this.specialTokens = new HashMap<>(specialTokens); + this.merges = new HashMap<>(); + for (Pair pair : merges) { + int firstIndex = pair.first(); + int secondIndex = pair.second(); + int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow(); + this.merges.put(pair, mergeIndex); + } + } + + private int[] encodeImpl(String text) { + return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); + } + + /** + * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens. + * allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens + * if none_raise, then an error is raised if any special token is encountered in text + * this is the default tiktoken behavior right now as well + * any other behavior is either annoying, or a major footgun. + */ + public List encode(String text, Set allowedSpecial) { + // decode the user desire w.r.t. handling of special tokens + Set special = allowedSpecial; + assert getSpecialTokens().keySet().containsAll(special); + if (special.isEmpty()) { + // shortcut: if no special tokens, just use the ordinary encoding + return encodeOrdinary(text); + } + + // otherwise, we have to be careful with potential special tokens in text + // we handle special tokens by splitting the text + // based on the occurrence of any exact match with any of the special tokens + // we can use re.split for this. 
note that surrounding the pattern with () + // makes it into a capturing group, so the special tokens will be included + String specialPattern = special + .stream() + .map(Pattern::quote) + .collect(Collectors.joining("|", "(", ")")); + + String[] specialChunks = text.split(specialPattern); + // now all the special characters are separated from the rest of the text + // all chunks of text are encoded separately, then results are joined + List ids = new ArrayList<>(); + for (String part : specialChunks) { + if (special.contains(part)) { + // this is a special token, encode it separately as a special case + ids.add(getSpecialTokens().get(part)); + } else { + // this is an ordinary sequence, encode it normally + ids.addAll(encodeOrdinary(part)); + } + } + return ids; + } + + private static List findAll(Pattern pattern, String text) { + List allMatches = new ArrayList<>(); + Matcher matcher = pattern.matcher(text); + while (matcher.find()) { + allMatches.add(matcher.group()); + } + return allMatches; + } + + /** + * Encoding that ignores any special tokens. + */ + public List encodeOrdinary(String text) { + // split text into chunks of text by categories defined in regex pattern + List textChunks = findAll(compiledPattern, text); + // all chunks of text are encoded separately, then results are joined + List ids = new ArrayList<>(); + for (String chunk : textChunks) { + List chunkIds = encodeChunk(chunk); + ids.addAll(chunkIds); + } + return ids; + } + + private Map, Integer> getStats(List ids) { + Map, Integer> map = new HashMap<>(); + for (int i = 0; i + 1 < ids.size(); i++) { + Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); + map.put(key, map.getOrDefault(key, 0) + 1); + } + return map; + } + + private List encodeChunk(String chunk) { + // return the token ids + // let's begin. first, convert all bytes to integers in range 0..255 + List ids = new ArrayList<>(); + for (int b : chunk.toCharArray()) { + int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow(); + ids.add(tokenIndex); + } + + while (ids.size() >= 2) { + // find the pair with the lowest merge index + Map, Integer> stats = getStats(ids); + Pair pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow(); + // subtle: if there are no more merges available, the key will + // result in an inf for every single pair, and the min will be + // just the first pair in the list, arbitrarily + // we can detect this terminating case by a membership check + if (!this.merges.containsKey(pair)) { + break; // nothing else can be merged anymore + } + // otherwise let's merge the best pair (lowest merge index) + int idx = this.merges.get(pair); + ids = merge(ids, pair, idx); + } + return ids; + } + + private static List merge(List ids, Pair pair, int idx) { + List newids = new ArrayList<>(); + int i = 0; + while (i < ids.size()) { + // if not at the very last position AND the pair matches, replace it + if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { + newids.add(idx); + i += 2; + } else { + newids.add(ids.get(i)); + i += 1; + } + } + return newids; + } + + public String decodeImpl(List tokens) { + StringBuilder sb = new StringBuilder(); + for (int token : tokens) { + String tokenString = vocabulary.get(token); + sb.append(tokenString); + } + return sb.toString(); + } + + /** + * Returns list of utf-8 byte and a corresponding list of unicode strings. 
+ * The reversible bpe codes work on unicode strings. + * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + * This is a significant percentage of your normal, say, 32K bpe vocab. + * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + * And avoids mapping to whitespace/control characters the bpe code barfs on. + */ + private static Map bytesToUnicode() { + List bs = new ArrayList<>(); + IntStream.rangeClosed('!', '~').forEach(bs::add); + IntStream.rangeClosed('¡', '¬').forEach(bs::add); + IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); + + List cs = new ArrayList<>(bs); + int n = 0; + for (int b = 0; b < 256; ++b) { + if (!bs.contains(b)) { + bs.add(b); + cs.add(256 + n); + n += 1; + } + } + + // return dict(zip(bs, cs)) + return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); + } + + static final Map BYTE_ENCODER = bytesToUnicode(); + static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + + public int[] encode(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return encodeImpl(sb.toString()); + } + + @Override + public List encodeAsList(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return Arrays.stream(encodeImpl(sb.toString())).boxed().toList(); + } + + @Override + public String decode(List tokens) { + String decoded = decodeImpl(tokens); + int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray(); + byte[] rawBytes = new byte[decodedBytesAsInts.length]; + for (int i = 0; i < decoded.length(); i++) { + rawBytes[i] = (byte) decodedBytesAsInts[i]; + } + return new String(rawBytes, StandardCharsets.UTF_8); + } +} diff --git a/src/main/java/com/example/tokenizer/impl/MistralTokenizer.java b/src/main/java/com/example/tokenizer/impl/MistralTokenizer.java new file mode 100644 index 00000000..9fba5b5d --- /dev/null +++ b/src/main/java/com/example/tokenizer/impl/MistralTokenizer.java @@ -0,0 +1,175 @@ +package com.example.tokenizer.impl; + +import com.example.tokenizer.vocabulary.Vocabulary; + +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * TikToken-style BPE tokenizer with byte fallback. + *

+ * TikToken-style: + * A Byte Pair Encoding (BPE) strategy that converts text to UTF-8 bytes. + * Frequent pairs of bytes (or tokens) are merged according to a learned vocabulary. + * This reduces long words into common subwords or whole-word tokens. + * If a word or character isn't found, it falls back to byte-level tokens. + *
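A minimal sketch of this lookup-then-byte-fallback step, with an invented vocabulary and `byte0` offset; the score-driven merge loop that follows in the real implementation below is omitted here.

```java
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Toy illustration of byte fallback: known codepoints use their vocabulary id,
// unknown codepoints are split into <0xNN> byte tokens (one per UTF-8 byte).
public class ByteFallbackSketch {
    public static void main(String[] args) {
        // Hypothetical single-character vocabulary entries (ids are invented).
        Map<String, Integer> vocab = Map.of("H", 10, "e", 11, "l", 12, "o", 13, "▁", 14);
        int byte0 = 3; // hypothetical id of "<0x00>"; byte value b maps to id byte0 + b

        String text = "Hello é".replace(' ', '▁'); // the Mistral tokenizer marks spaces with '▁'
        List<Integer> tokens = new ArrayList<>();
        for (int i = 0, cp; i < text.length(); i += Character.charCount(cp)) {
            cp = text.codePointAt(i);
            String piece = Character.toString(cp);
            Integer id = vocab.get(piece);
            if (id != null) {
                tokens.add(id);                                    // found in the vocabulary
            } else {
                for (byte b : piece.getBytes(StandardCharsets.UTF_8)) {
                    tokens.add(Byte.toUnsignedInt(b) + byte0);     // fall back to raw byte tokens
                }
            }
        }
        System.out.println(tokens); // [10, 11, 12, 12, 13, 14, 198, 172] -- 'é' became two byte tokens
    }
}
```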

+ * Byte fallback: + * A fail-safe mechanism. + * It ensures every byte has a token, so any input (even unknown words, misspellings, foreign languages, emojis, or binary) can be tokenized. + * If a token is not found in the merges or vocabulary, it will fall back to the individual byte. + * Each byte is wrapped as a special token like <0xF0> — these are part of the tokenizer’s extended vocabulary. + * This guarantees reversibility: every string can be tokenized and decoded back exactly. + */ +public class MistralTokenizer implements Tokenizer { + private static final String MISTRAL_PATTERN = "\\S+|\\s+"; + // general fields + private final Pattern compiledPattern; + private final Vocabulary vocabulary; + // model-specific fields + private final Map specialTokens; + private final int[] tokenType; + private final int byte0; + + public String regexPattern() { + if (compiledPattern == null) { + return null; + } + return compiledPattern.pattern(); + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + return getTokenType(tokenIndex) != 1; + } + + @Override + public boolean shouldDisplayToken(int token) { + int tokenType = getTokenType(token); + return tokenType == 1 || tokenType == 6; + } + + public int getTokenType(int tokenIndex) { + return tokenType[tokenIndex]; + } + + // @formatter:off + public MistralTokenizer(Map metadata, Vocabulary vocabulary) { + // load from metadata + int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); + List specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList(); + Map specialTokens = + IntStream.range(0, specialTokensList.size()) + .boxed() + .collect(Collectors.toMap( + t -> vocabulary.get(t), + t -> t) + ); + // init tokenizer object fields + this.vocabulary = vocabulary; + this.compiledPattern = null; + this.specialTokens = new HashMap<>(specialTokens); + this.tokenType = tokenTypes; + this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow(); + } + // @formatter:on + + private List encodeImpl(String text) { + + List tokens = new ArrayList<>(); + + // first encode every individual codepoint in the input string + for (int i = 0, cpi; i < text.length(); i += Character.charCount(cpi)) { + cpi = text.codePointAt(i); + + String singleCodepoint = Character.toString(cpi); + int id = vocabulary.getIndex(singleCodepoint).orElse(-1); + + if (id != -1) { + // we found this codepoint in vocab, add it as a token + tokens.add(id); + } else { + // byte_fallback encoding: just encode each byte as a token + // +byte0 here to skip all the control and special tokens e.g. , , + // so the individual bytes only start at token <0x00> + for (byte b : singleCodepoint.getBytes(StandardCharsets.UTF_8)) { + tokens.add(Byte.toUnsignedInt(b) + byte0); + } + } + } + + // merge the best consecutive pair each iteration, according the scores in vocab_scores + while (true) { + float best_score = -1e10f; + int best_id = -1; + int best_idx = -1; + + for (int i = 0; i < tokens.size() - 1; ++i) { + // check if we can merge the pair (tokens[i], tokens[i+1]) + String str_buffer = vocabulary.get(tokens.get(i)) + vocabulary.get(tokens.get(i + 1)); + int id = vocabulary.getIndex(str_buffer).orElse(-1); + if (id != -1 && vocabulary.getScore(id) > best_score) { + // this merge pair exists in vocab! 
record its score and position + best_score = vocabulary.getScore(id); + best_id = id; + best_idx = i; + } + } + + if (best_idx == -1) { + break; // we couldn't find any more pairs to merge, so we're done + } + + // merge the consecutive pair (best_idx, best_idx+1) into new token best_id + tokens.set(best_idx, best_id); + tokens.remove(best_idx + 1); + } + + return tokens; + } + + /** + * Modified original signature from mistral.java: List encode(String text); + */ + @Override + public List encode(String text, Set allowedSpecial) { + return encodeImpl(text.replace(' ', '▁')); + } + + @Override + public List encodeAsList(String text) { + // pass an empty set to comply with method signature. + return encode(text, Collections.emptySet()); + } + + @Override + public String decode(List tokens) { + StringBuilder sb = new StringBuilder(); + for (int token : tokens) { + String tokenString = vocabulary.get(token); + if (isSpecialToken(token)) { + // some tokens designate raw bytes e.g. '<0x10>' + String prefix = "<0x"; + String suffix = ">"; + if (tokenString.length() == 6 && tokenString.startsWith(prefix) && tokenString.endsWith(suffix)) { + String code = tokenString.substring(prefix.length(), tokenString.length() - suffix.length()); + int cp = Integer.parseInt(code, 16); + tokenString = Character.toString(cp); + } + } else { + tokenString = tokenString.replace('▁', ' '); + + } + sb.append(tokenString); + } + return sb.toString(); + } +} diff --git a/src/main/java/com/example/tokenizer/impl/Tokenizer.java b/src/main/java/com/example/tokenizer/impl/Tokenizer.java index 4cfd8b8c..21915197 100644 --- a/src/main/java/com/example/tokenizer/impl/Tokenizer.java +++ b/src/main/java/com/example/tokenizer/impl/Tokenizer.java @@ -1,233 +1,34 @@ package com.example.tokenizer.impl; -import com.example.core.types.Pair; -import com.example.tokenizer.vocabulary.Vocabulary; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.HashMap; import java.util.HexFormat; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -public class Tokenizer { - private final Pattern compiledPattern; - private final Vocabulary vocabulary; - private final Map, Integer> merges; - private final Map specialTokens; - - public String regexPattern() { - if (compiledPattern == null) { - return null; - } - return compiledPattern.pattern(); - } - - public Map getSpecialTokens() { - return specialTokens; - } - - public boolean isSpecialToken(int tokenIndex) { - return specialTokens.containsValue(tokenIndex); - } - - public Tokenizer(Vocabulary vocabulary, List> merges, String regexPattern, Map specialTokens) { - this.vocabulary = vocabulary; - this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null; - this.specialTokens = new HashMap<>(specialTokens); - this.merges = new HashMap<>(); - for (Pair pair : merges) { - int firstIndex = pair.first(); - int secondIndex = pair.second(); - int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow(); - this.merges.put(pair, mergeIndex); - } - } - - private int[] encodeImpl(String text) { - return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); - } - - /** - * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens. 
- * allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens - * if none_raise, then an error is raised if any special token is encountered in text - * this is the default tiktoken behavior right now as well - * any other behavior is either annoying, or a major footgun. - */ - List encode(String text, Set allowedSpecial) { - // decode the user desire w.r.t. handling of special tokens - Set special = allowedSpecial; - assert getSpecialTokens().keySet().containsAll(special); - if (special.isEmpty()) { - // shortcut: if no special tokens, just use the ordinary encoding - return encodeOrdinary(text); - } - - // otherwise, we have to be careful with potential special tokens in text - // we handle special tokens by splitting the text - // based on the occurrence of any exact match with any of the special tokens - // we can use re.split for this. note that surrounding the pattern with () - // makes it into a capturing group, so the special tokens will be included - String specialPattern = special - .stream() - .map(Pattern::quote) - .collect(Collectors.joining("|", "(", ")")); - - String[] specialChunks = text.split(specialPattern); - // now all the special characters are separated from the rest of the text - // all chunks of text are encoded separately, then results are joined - List ids = new ArrayList<>(); - for (String part : specialChunks) { - if (special.contains(part)) { - // this is a special token, encode it separately as a special case - ids.add(getSpecialTokens().get(part)); - } else { - // this is an ordinary sequence, encode it normally - ids.addAll(encodeOrdinary(part)); - } - } - return ids; - } - - private static List findAll(Pattern pattern, String text) { - List allMatches = new ArrayList<>(); - Matcher matcher = pattern.matcher(text); - while (matcher.find()) { - allMatches.add(matcher.group()); - } - return allMatches; - } - - /** - * Encoding that ignores any special tokens. - */ - public List encodeOrdinary(String text) { - // split text into chunks of text by categories defined in regex pattern - List textChunks = findAll(compiledPattern, text); - // all chunks of text are encoded separately, then results are joined - List ids = new ArrayList<>(); - for (String chunk : textChunks) { - List chunkIds = encodeChunk(chunk); - ids.addAll(chunkIds); - } - return ids; - } - - private Map, Integer> getStats(List ids) { - Map, Integer> map = new HashMap<>(); - for (int i = 0; i + 1 < ids.size(); i++) { - Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); - map.put(key, map.getOrDefault(key, 0) + 1); - } - return map; - } - private List encodeChunk(String chunk) { - // return the token ids - // let's begin. 
first, convert all bytes to integers in range 0..255 - List ids = new ArrayList<>(); - for (int b : chunk.toCharArray()) { - int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow(); - ids.add(tokenIndex); - } +public interface Tokenizer { + String regexPattern(); - while (ids.size() >= 2) { - // find the pair with the lowest merge index - Map, Integer> stats = getStats(ids); - Pair pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow(); - // subtle: if there are no more merges available, the key will - // result in an inf for every single pair, and the min will be - // just the first pair in the list, arbitrarily - // we can detect this terminating case by a membership check - if (!this.merges.containsKey(pair)) { - break; // nothing else can be merged anymore - } - // otherwise let's merge the best pair (lowest merge index) - int idx = this.merges.get(pair); - ids = merge(ids, pair, idx); - } - return ids; - } - - private static List merge(List ids, Pair pair, int idx) { - List newids = new ArrayList<>(); - int i = 0; - while (i < ids.size()) { - // if not at the very last position AND the pair matches, replace it - if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { - newids.add(idx); - i += 2; - } else { - newids.add(ids.get(i)); - i += 1; - } - } - return newids; - } + Map getSpecialTokens(); - public String decodeImpl(List tokens) { - StringBuilder sb = new StringBuilder(); - for (int token : tokens) { - String tokenString = vocabulary.get(token); - sb.append(tokenString); - } - return sb.toString(); - } + boolean isSpecialToken(int tokenIndex); /** - * Returns list of utf-8 byte and a corresponding list of unicode strings. - * The reversible bpe codes work on unicode strings. - * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - * This is a significant percentage of your normal, say, 32K bpe vocab. - * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - * And avoids mapping to whitespace/control characters the bpe code barfs on. + * Determines if a token should be displayed during streaming output. + * This filters out special tokens, control characters, or other non-displayable content. 
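+     * For example, the Mistral implementation in this patch displays GGML token types 1 and 6
+     * (the normal and byte-fallback entries of the "tokenizer.ggml.token_type" metadata) and hides the rest.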
+ * + * @param token the token to check + * @return true if the token should be displayed to the user, false otherwise */ - private static Map bytesToUnicode() { - List bs = new ArrayList<>(); - IntStream.rangeClosed('!', '~').forEach(bs::add); - IntStream.rangeClosed('¡', '¬').forEach(bs::add); - IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); - - List cs = new ArrayList<>(bs); - int n = 0; - for (int b = 0; b < 256; ++b) { - if (!bs.contains(b)) { - bs.add(b); - cs.add(256 + n); - n += 1; - } - } + boolean shouldDisplayToken(int token); - // return dict(zip(bs, cs)) - return IntStream.range(0, bs.size()) - .boxed() - .collect(Collectors.toMap(bs::get, cs::get)); - } + List encode(String text, Set allowedSpecial); - static final Map BYTE_ENCODER = bytesToUnicode(); - static final Map BYTE_DECODER = BYTE_ENCODER.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + List encodeAsList(String text); - public int[] encode(String text) { - StringBuilder sb = new StringBuilder(); - byte[] bytes = text.getBytes(StandardCharsets.UTF_8); - for (byte b : bytes) { - sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); - } - return encodeImpl(sb.toString()); - } + String decode(List tokens); - public static String replaceControlCharacters(int[] codePoints) { + // Utility method for all tokenizers, implemented as static. + static String replaceControlCharacters(int[] codePoints) { // we don't want to print control characters // which distort the output (e.g. \n or much worse) // https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117 @@ -243,22 +44,10 @@ public static String replaceControlCharacters(int[] codePoints) { return chars.toString(); } - public static String replaceControlCharacters(String str) { + // Utility method for all tokenizers, implemented as static. 
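+    // Convenience overload: converts the string to code points and delegates to the int[] variant above,
+    // so, for instance, a newline is passed through while other control characters are escaped before printing.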
+ static String replaceControlCharacters(String str) { return replaceControlCharacters(str.codePoints().toArray()); } - public List encodeAsList(String text) { - return Arrays.stream(encode(text)).boxed().toList(); - } - - public String decode(List tokens) { - String decoded = decodeImpl(tokens); - int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray(); - byte[] rawBytes = new byte[decodedBytesAsInts.length]; - for (int i = 0; i < decoded.length(); i++) { - rawBytes[i] = (byte) decodedBytesAsInts[i]; - } - return new String(rawBytes, StandardCharsets.UTF_8); - } } diff --git a/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java b/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java index 77c793ba..a45c0737 100644 --- a/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java +++ b/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java @@ -1,14 +1,15 @@ package com.example.tokenizer.vocabulary; +import java.util.Arrays; import java.util.Map; import java.util.OptionalInt; import java.util.stream.Collectors; import java.util.stream.IntStream; - public record Vocabulary(String[] tokens, float[] scores, Map tokenToIndex) { private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; + // @formatter:off public Vocabulary(String[] vocabulary, float[] scores) { this(vocabulary, scores, IntStream.range(0, vocabulary.length) @@ -16,6 +17,7 @@ public Vocabulary(String[] vocabulary, float[] scores) { .collect(Collectors.toMap(i -> vocabulary[i], i -> i)) ); } + // @formatter:on public String get(int tokenIndex) { return tokens[tokenIndex]; @@ -26,16 +28,37 @@ public OptionalInt getIndex(String token) { return value != null ? OptionalInt.of(value) : OptionalInt.empty(); } - public static Vocabulary loadVocabulary(Map metadata) { - String model = (String) metadata.get("tokenizer.ggml.model"); - if (!TOKENIZER_LLAMA_3_MODEL.equals(model)) { - throw new IllegalArgumentException("expected " + TOKENIZER_LLAMA_3_MODEL + " but found " + model); - } + public static Vocabulary loadLlamaVocabulary(Map metadata) { String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); return new Vocabulary(tokens, null); } + public static Vocabulary loadMistralVocabulary(Map metadata) { + String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); + float[] scores = (float[]) metadata.get("tokenizer.ggml.scores"); + Vocabulary v = new Vocabulary(tokens, scores); + return v; + } + public int size() { return tokens.length; } + + /** + * Only for Mistral. 
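+     * The Llama vocabulary is created without scores (loadLlamaVocabulary passes null), so calling
+     * this on a Llama model would throw; use it only with vocabularies from loadMistralVocabulary.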
+ */ + public float getScore(int tokenIndex) { + return scores[tokenIndex]; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("Vocabulary:\n"); + sb.append("Tokens: ").append(Arrays.toString(tokens)).append("\n"); + sb.append("Scores: ").append(Arrays.toString(scores)).append("\n"); + sb.append("Token to Index Map:\n"); + tokenToIndex.forEach((token, index) -> sb.append(" ").append(token).append(" -> ").append(index).append("\n")); + return sb.toString(); + } } \ No newline at end of file diff --git a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java index cfa19806..77c6b56b 100644 --- a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java +++ b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java @@ -1,8 +1,8 @@ package com.example.tornadovm; import com.example.auxiliary.Tuple2; -import com.example.inference.engine.impl.Configuration; -import com.example.inference.engine.impl.Llama; +import com.example.model.Configuration; +import com.example.model.Model; import com.example.loader.weights.State; import com.example.loader.weights.Weights; import uk.ac.manchester.tornado.api.GridScheduler; @@ -61,7 +61,7 @@ public class TornadoVMLayerPlanner { * @param model * The Llama model instance containing configuration and weights */ - public TornadoVMLayerPlanner(State state, Llama model) { + public TornadoVMLayerPlanner(State state, Model model) { this.state = state; this.config = model.configuration(); this.weights = model.weights(); @@ -82,80 +82,80 @@ public Tuple2, GridScheduler> setupTornadoForwardPlanLa .persistOnDevice(state.wrapX); taskGraphs.add(activationUpdate.snapshot()); - TaskGraph unifiedLayer = null; - for (int layerIndex =0; layerIndex < config.numberOfLayers; layerIndex++) { - unifiedLayer = new TaskGraph("layer_" + layerIndex); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex], - weights.wqLayered[layerIndex], - weights.wkLayered[layerIndex], - weights.wvLayered[layerIndex], - weights.woLayered[layerIndex], - weights.rms_ffn_weightLayered[layerIndex], - weights.w1Layered[layerIndex], - weights.w2Layered[layerIndex], - weights.w3Layered[layerIndex] + TaskGraph unifiedLayer = null; + for (int layerIndex =0; layerIndex < config.numberOfLayers(); layerIndex++) { + unifiedLayer = new TaskGraph("layer_" + layerIndex); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + //Copy-in weights per layer for batched-layered layout + weights.rms_att_weightLayered[layerIndex], + weights.wqLayered[layerIndex], + weights.wkLayered[layerIndex], + weights.wvLayered[layerIndex], + weights.woLayered[layerIndex], + weights.rms_ffn_weightLayered[layerIndex], + weights.w1Layered[layerIndex], + weights.w2Layered[layerIndex], + weights.w3Layered[layerIndex] + ); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) + .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, + state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) + .task("qmatmul", 
TransformerComputeKernelsLayered::matrixVectorGeneric, context, + state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, + state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, + state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("rope", TransformerComputeKernelsLayered::ropeRotation,context, + state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), + config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, + state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) + .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()) + .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, + state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) + .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, + state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) + .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, + state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, + state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice( + state.wrapX ); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim, config.rmsNormEps, state.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim, config.kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim, config.kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation,context, - state.positionHolder, state.wrapQ, state.wrapK, config.kvDim, - 
config.headSize) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim, layerIndex, config.contextLength) - .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads, config.headSize, config.kvDim, config.kvMul, - state.positionHolder, layerIndex, config.contextLength) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim, config.rmsNormEps, state.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, - state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim, config.hiddenDim, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - state.wrapX - ); - taskGraphs.add(unifiedLayer.snapshot()); - } + taskGraphs.add(unifiedLayer.snapshot()); + } - TaskGraph lastUnifiedLayer = unifiedLayer; - TaskGraph logits = new TaskGraph("logits") - .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), - state.wrapX - ) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.tempLogits - ) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapLogits, - weights.wclsHalfFloat, - weights.rms_final_weight_as_floatArray - ) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, - state.wrapX, config.dim, config.rmsNormEps, state.localSize) - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, - weights.rms_final_weight_as_floatArray, state.tempLogits); - logits = configureQuantizedMatrixVectorFinalWeight(logits); - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); - taskGraphs.add(logits.snapshot()); - // @formatter:on + TaskGraph lastUnifiedLayer = unifiedLayer; + TaskGraph logits = new TaskGraph("logits") + .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), + state.wrapX + ) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.tempLogits + ) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapLogits, + weights.wclsHalfFloat, + weights.rms_final_weight_as_floatArray + ) + .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) + .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, + weights.rms_final_weight_as_floatArray, state.tempLogits); + logits = configureQuantizedMatrixVectorFinalWeight(logits); + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + 
taskGraphs.add(logits.snapshot()); + // @formatter:on return new Tuple2<>(taskGraphs, setupGridSchedulersLayered()); } @@ -189,7 +189,7 @@ private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) { case Q4_0: logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, // - config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // + config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // break; default: throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported."); @@ -272,56 +272,56 @@ private GridScheduler setupGridSchedulersLayered() { // config.dim / 2 Worker for RoPE // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid ropeWorker = new WorkerGrid1D(config.dim / 2); - ropeWorker.setGlobalWork(config.dim / 2, 1, 1); + WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2); + ropeWorker.setGlobalWork(config.dim() / 2, 1, 1); ropeWorker.setLocalWork(128, 1, 1); // config.dim Worker for Row major access // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) // CUDA equivalent: kernel<<>> - int configDimRowMajorGlobal = config.dim * LOCAL_WORK_GROUP_SIZE_ALLOC; + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); // config.kvDim Worker for Row major access // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) // CUDA equivalent: kernel<<>> - int configKvDimRowMajorGlobal = config.kvDim * LOCAL_WORK_GROUP_SIZE_ALLOC; + int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); // config.hiddenDim * 32 Worker for Row major access // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) // CUDA equivalent: kernel<<>> - int configHiddenDimRowMajor = config.hiddenDim * LOCAL_WORK_GROUP_SIZE_ALLOC; + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); // RMSNorm worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim); - rmsNormWorker.setGlobalWork(config.dim, 1, 1); // Set global work size to total dimension + WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); + rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension rmsNormWorker.setLocalWork(256, 1, 1); // Set local work size to 256 (standard efficient size) // Parallel attention worker configuration // OpenCL equivalent: 
clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads); + WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4 - parallelAttentionWorker.setGlobalWork(config.numberOfHeads * 8, 1, 1); + parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1); parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 4 (for parallel attention) // Copy to caches worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim); - copyToCachesWorker.setGlobalWork(config.dim, 1, 1); + WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); + copyToCachesWorker.setGlobalWork(config.dim(), 1, 1); copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches) // Map workers to tasks tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers; i++) { + for (int i = 0; i < config.numberOfLayers(); i++) { tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); @@ -340,7 +340,7 @@ private GridScheduler setupGridSchedulersLayered() { // Vocabulary worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) // CUDA equivalent: kernel<<>> - int vocabSizeRowMajor = config.vocabularySize * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); @@ -364,56 +364,56 @@ private GridScheduler setupGridSchedulersLayeredNonNvidia() { // config.dim / 2 Worker for RoPE // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid ropeWorker = new WorkerGrid1D(config.dim / 2); - ropeWorker.setGlobalWork(config.dim / 2, 1, 1); + WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2); + ropeWorker.setGlobalWork(config.dim() / 2, 1, 1); ropeWorker.setLocalWork(128, 1, 1); // config.dim Worker for Row major access // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) // CUDA equivalent: kernel<<>> - int configDimRowMajorGlobal = config.dim * LOCAL_WORK_GROUP_SIZE_ALLOC; + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); // config.kvDim Worker for Row major access // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) // CUDA equivalent: 
kernel<<>> - int configKvDimRowMajorGlobal = config.kvDim * LOCAL_WORK_GROUP_SIZE_ALLOC; + int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); // config.hiddenDim * 32 Worker for Row major access // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) // CUDA equivalent: kernel<<>> - int configHiddenDimRowMajor = config.hiddenDim * LOCAL_WORK_GROUP_SIZE_ALLOC; + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); // RMSNorm worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim); - rmsNormWorker.setGlobalWork(config.dim, 1, 1); // Set global work size to total dimension + WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); + rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension rmsNormWorker.setLocalWork(256, 1, 1); // Set local work size to 256 (standard efficient size) // Parallel attention worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads); + WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4 - parallelAttentionWorker.setGlobalWork(config.numberOfHeads * 8, 1, 1); + parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1); parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 4 (for parallel attention) // Copy to caches worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) // CUDA equivalent: kernel<<>> - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim); - copyToCachesWorker.setGlobalWork(config.dim, 1, 1); + WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); + copyToCachesWorker.setGlobalWork(config.dim(), 1, 1); copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches) // Map workers to tasks tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers; i++) { + for (int i = 0; i < config.numberOfLayers(); i++) { tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); @@ -432,7 +432,7 @@ private GridScheduler setupGridSchedulersLayeredNonNvidia() { // Vocabulary worker configuration // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) // CUDA equivalent: kernel<<>> - int vocabSizeRowMajor = config.vocabularySize * LOCAL_WORK_GROUP_SIZE_ALLOC * 
THREAD_SCALE_FOR_LOGITS; + int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); @@ -458,7 +458,7 @@ public Tuple2, GridScheduler> setupTornadoForwardPlanLa taskGraphs.add(activationUpdate.snapshot()); TaskGraph unifiedLayer = null; - for (int layerIndex =0; layerIndex < config.numberOfLayers; layerIndex++) { + for (int layerIndex =0; layerIndex < config.numberOfLayers(); layerIndex++) { unifiedLayer = new TaskGraph("layer_" + layerIndex); unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -475,38 +475,38 @@ public Tuple2, GridScheduler> setupTornadoForwardPlanLa ); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim, config.rmsNormEps, state.localSize) + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) .task("reductionFinalNormalization" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim, config.rmsNormEps) + config.dim(), config.rmsNormEps()) .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC) + state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim, config.kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC) + state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim, config.kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC) + state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("rope", TransformerComputeKernelsLayered::ropeRotation,context, - state.positionHolder, state.wrapQ, state.wrapK, config.kvDim, - config.headSize) + state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), + config.headSize()) .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim, layerIndex, config.contextLength) + state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads, config.headSize, config.kvDim, config.kvMul, config.vocabularySize, - state.positionHolder, state.wrapAtt, layerIndex, config.contextLength) + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.vocabularySize(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()) 
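+                        // The next task ("matmul1") is the attention output projection: it multiplies wrapXb by the
+                        // per-layer wo weights and accumulates the residual back into wrapX before the FFN block.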
.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC) + state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim, config.rmsNormEps, state.localSize) + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) .task("reductionFinalNormalizationFFN" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, - config.dim, config.rmsNormEps) + config.dim(), config.rmsNormEps()) .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, - state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim, config.hiddenDim, LOCAL_WORK_GROUP_SIZE_ALLOC) + state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC) + state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .persistOnDevice( state.wrapX ); @@ -528,9 +528,9 @@ public Tuple2, GridScheduler> setupTornadoForwardPlanLa weights.rms_final_weight_as_floatArray ) .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, - state.wrapX, config.dim, config.rmsNormEps, state.localSize) + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) .task("reductionFinalNormalizationLogits" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, - config.dim, config.rmsNormEps) + config.dim(), config.rmsNormEps()) .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray, state.tempLogits); logits = configureQuantizedMatrixVectorFinalWeight(logits); diff --git a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java index 17ba6fac..eb194603 100644 --- a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java @@ -1,8 +1,8 @@ package com.example.tornadovm; import com.example.auxiliary.Tuple2; -import com.example.inference.engine.impl.Configuration; -import com.example.inference.engine.impl.Llama; +import com.example.model.Configuration; +import com.example.model.Model; import com.example.loader.weights.State; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -22,7 +22,7 @@ public class TornadoVMMasterPlan { public TornadoExecutionPlan executionPlan; List taskGraphs; - public TornadoVMMasterPlan(State state, Llama model, boolean isNvidia) { + public TornadoVMMasterPlan(State state, Model model, boolean isNvidia) { TornadoVMLayerPlanner tornadoVMLayerPlanner = 
    new TornadoVMLayerPlanner(state, model);
        Tuple2<List<ImmutableTaskGraph>, GridScheduler> tornadoVMPlan = isNvidia ? tornadoVMLayerPlanner.setupTornadoForwardPlanLayered() : tornadoVMLayerPlanner.setupTornadoForwardPlanLayeredNonNvidia();
        this.taskGraphs = tornadoVMPlan.getFirst();
@@ -43,7 +43,7 @@ public TornadoVMMasterPlan(State state, Llama model, boolean isNvidia) {
     * @param model The Llama model instance
     * @return The initialized TornadoVMMasterPlan ready for inference
     */
-    public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Llama model) {
+    public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
        // Initialize timing variables outside conditional blocks to avoid scope issues
        long startTime = System.nanoTime();
        long planCreationTime = 0;
@@ -117,7 +117,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
        // 2. Execute each transformer layer graph sequentially
        // Each graph computes attention and feed-forward transformations for one layer
-        for (int layer = 0; layer < config.numberOfLayers; layer++) {
+        for (int layer = 0; layer < config.numberOfLayers(); layer++) {
            executionPlan.withGraph(getLayerGraphIndex(layer))
                    .withGridScheduler(scheduler)
                    .execute();
@@ -166,12 +166,12 @@ public void forceCopyInReadOnlyDataLayered() {
        executionPlan.withGraph(0).withGridScheduler(scheduler).execute();
        // Execute layer processing graphs
-        for (int layer = 0; layer < config.numberOfLayers; layer++) {
+        for (int layer = 0; layer < config.numberOfLayers(); layer++) {
            executionPlan.withGraph(layer + 1).withGridScheduler(scheduler).execute();
        }
        // Execute logits graph
-        executionPlan.withGraph(config.numberOfLayers + 1).withGridScheduler(scheduler).execute();
+        executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(scheduler).execute();
    }

    /**