diff --git a/llama-tornado b/llama-tornado
index 675fe204..112a7ca1 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -14,17 +14,19 @@ from pathlib import Path
 from typing import List, Optional, Dict, Any
 from enum import Enum
 
+
 class Backend(Enum):
     OPENCL = "opencl"
     PTX = "ptx"
 
+
 class LlamaRunner:
     """Main class for managing LLM execution with GPU acceleration."""
 
     def __init__(self):
-        self.java_home = os.environ.get('JAVA_HOME')
-        self.tornado_sdk = os.environ.get('TORNADO_SDK')
-        self.llama_root = os.environ.get('LLAMA_ROOT')
+        self.java_home = os.environ.get("JAVA_HOME")
+        self.tornado_sdk = os.environ.get("TORNADO_SDK")
+        self.llama_root = os.environ.get("LLAMA_ROOT")
 
         if not all([self.java_home, self.tornado_sdk, self.llama_root]):
             print("Error: Required environment variables not set")
@@ -35,9 +37,9 @@ class LlamaRunner:
     def _validate_paths(self):
         """Validate that required paths exist."""
         paths_to_check = {
-            'JAVA_HOME': self.java_home,
-            'TORNADO_SDK': self.tornado_sdk,
-            'LLAMA_ROOT': self.llama_root
+            "JAVA_HOME": self.java_home,
+            "TORNADO_SDK": self.tornado_sdk,
+            "LLAMA_ROOT": self.llama_root,
         }
 
         for name, path in paths_to_check.items():
@@ -62,7 +64,8 @@ class LlamaRunner:
             "--enable-preview",
             f"-Djava.library.path={self.tornado_sdk}/lib",
             "-Djdk.module.showModuleResolution=false",
-            "--module-path", self.module_path_colon_sep([".", f"{self.tornado_sdk}/share/java/tornado"]),
+            "--module-path",
+            self.module_path_colon_sep([".", f"{self.tornado_sdk}/share/java/tornado"]),
         ]
 
         # TornadoVM configuration
@@ -86,22 +89,38 @@ class LlamaRunner:
 
         debug_config = []
         if args.debug:
-            debug_config.extend([
-                "-Dtornado.debug=true",
-                "-Dtornado.threadInfo=True" if args.threads else "-Dtornado.threadInfo=false"
-            ])
+            debug_config.extend(
+                [
+                    "-Dtornado.debug=true",
+                    "-Dtornado.threadInfo=True"
+                    if args.threads
+                    else "-Dtornado.threadInfo=false",
+                ]
+            )
         else:
-            debug_config.extend([
-                "-Dtornado.threadInfo=True" if args.threads else "-Dtornado.threadInfo=false",
-                "-Dtornado.debug=false"
-            ])
+            debug_config.extend(
+                [
+                    "-Dtornado.threadInfo=True"
+                    if args.threads
+                    else "-Dtornado.threadInfo=false",
+                    "-Dtornado.debug=false",
+                ]
+            )
 
         # Additional debug options
-        debug_config.extend([
-            "-Dtornado.fullDebug=True" if args.full_dump else "-Dtornado.fullDebug=false",
-            "-Dtornado.printKernel=True" if args.print_kernel else "-Dtornado.printKernel=false",
-            "-Dtornado.print.bytecodes=True" if args.print_bytecodes else "-Dtornado.print.bytecodes=false"
-        ])
+        debug_config.extend(
+            [
+                "-Dtornado.fullDebug=True"
+                if args.full_dump
+                else "-Dtornado.fullDebug=false",
+                "-Dtornado.printKernel=True"
+                if args.print_kernel
+                else "-Dtornado.printKernel=false",
+                "-Dtornado.print.bytecodes=True"
+                if args.print_bytecodes
+                else "-Dtornado.print.bytecodes=false",
+            ]
+        )
 
         cmd.extend(debug_config)
 
@@ -115,7 +134,7 @@ class LlamaRunner:
             "-Dtornado.enable.mathOptimizations=false",
             "-Dtornado.enable.nativeFunctions=true",
             "-Dtornado.loop.interchange=true",
-            f"-Dtornado.eventpool.maxwaitevents={args.max_wait_events}"
+            f"-Dtornado.eventpool.maxwaitevents={args.max_wait_events}",
         ]
 
         cmd.extend(tornado_runtime_config)
@@ -126,25 +145,35 @@ class LlamaRunner:
         # Module configuration - varies by backend
         module_config = [
-            f"--upgrade-module-path", f"{self.tornado_sdk}/share/java/graalJars",
+            f"--upgrade-module-path",
+            f"{self.tornado_sdk}/share/java/graalJars",
             f"@{self.tornado_sdk}/etc/exportLists/common-exports",
         ]
 
         # Add backend-specific exports and modules
         if args.backend == Backend.OPENCL:
-            module_config.extend([
-                f"@{self.tornado_sdk}/etc/exportLists/opencl-exports",
-                "--add-modules", "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl",
-            ])
+            module_config.extend(
+                [
+                    f"@{self.tornado_sdk}/etc/exportLists/opencl-exports",
+                    "--add-modules",
+                    "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl",
+                ]
+            )
         elif args.backend == Backend.PTX:
-            module_config.extend([
-                f"@{self.tornado_sdk}/etc/exportLists/ptx-exports",
-                "--add-modules", "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx",
-            ])
-
-        module_config.extend([
-            "-cp", f"{self.llama_root}/target/gpu-llama3-1.0-SNAPSHOT.jar",
-            "com.example.LlamaApp"
-        ])
+            module_config.extend(
+                [
+                    f"@{self.tornado_sdk}/etc/exportLists/ptx-exports",
+                    "--add-modules",
+                    "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx",
+                ]
+            )
+
+        module_config.extend(
+            [
+                "-cp",
+                f"{self.llama_root}/target/gpu-llama3-1.0-SNAPSHOT.jar",
+                "com.example.LlamaApp",
+            ]
+        )
 
         cmd.extend(module_config)
         return cmd
@@ -152,13 +181,20 @@ class LlamaRunner:
     def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]:
         """Add LLaMA-specific arguments to the command."""
         llama_args = [
-            "-m", args.model_path,
-            "--temperature", str(args.temperature),
-            "--top-p", str(args.top_p),
-            "--seed", str(args.seed),
-            "--max-tokens", str(args.max_tokens),
-            "--stream", str(args.stream).lower(),
-            "--echo", str(args.echo).lower()
+            "-m",
+            args.model_path,
+            "--temperature",
+            str(args.temperature),
+            "--top-p",
+            str(args.top_p),
+            "--seed",
+            str(args.seed),
+            "--max-tokens",
+            str(args.max_tokens),
+            "--stream",
+            str(args.stream).lower(),
+            "--echo",
+            str(args.echo).lower(),
         ]
 
         if args.prompt:
@@ -191,20 +227,22 @@ class LlamaRunner:
             escaped_cmd = []
             for arg in cmd:
                 # Escape arguments that contain spaces or special characters
-                if ' ' in arg or '"' in arg or "'" in arg:
+                if " " in arg or '"' in arg or "'" in arg:
                     escaped_cmd.append(f'"{arg}"')
                 else:
                     escaped_cmd.append(arg)
 
             # Print as a continuous line that can be easily copied
-            print(' '.join(escaped_cmd))
+            print(" ".join(escaped_cmd))
             print("-" * 80)
             print()
 
             # If user only wants to see the command without executing
            if not args.execute_after_show:
                 print("Command built successfully. Exiting without execution.")
-                print("Use --execute-after-show to run the command after displaying it.")
+                print(
+                    "Use --execute-after-show to run the command after displaying it."
+                )
                 return 0
 
         if args.verbose:
@@ -227,6 +265,7 @@ class LlamaRunner:
             print(f"Error: {e}")
             return 1
 
+
 def load_env_from_script():
     system = platform.system()
 
@@ -234,7 +273,9 @@ def load_env_from_script():
         # Call set_paths.cmd and capture output as environment
         result = subprocess.run(
             ["cmd.exe", "/c", "set_paths.cmd && set"],
-            capture_output=True, text=True, shell=False
+            capture_output=True,
+            text=True,
+            shell=False,
         )
         if result.returncode != 0:
             print("Failed to run set_paths.cmd")
@@ -242,121 +283,192 @@ def load_env_from_script():
 
         # Parse environment variables from output
         for line in result.stdout.splitlines():
-            if '=' in line:
-                key, value = line.strip().split('=', 1)
+            if "=" in line:
+                key, value = line.strip().split("=", 1)
                 os.environ[key] = value
 
     elif system in ("Linux", "Darwin"):
         # Source the set_paths file and capture env
-        command = ['bash', '-c', 'source ./set_paths && env']
+        command = ["bash", "-c", "source ./set_paths && env"]
         result = subprocess.run(command, capture_output=True, text=True)
         if result.returncode != 0:
             print("Failed to source set_paths")
             sys.exit(1)
 
         for line in result.stdout.splitlines():
-            if '=' in line:
-                key, value = line.strip().split('=', 1)
+            if "=" in line:
+                key, value = line.strip().split("=", 1)
                 os.environ[key] = value
     else:
         print(f"Unsupported OS: {system}")
         sys.exit(1)
 
+
 def create_parser() -> argparse.ArgumentParser:
     """Create and configure the argument parser."""
     parser = argparse.ArgumentParser(
         prog="llama-tornado",
         description="GPU-accelerated LLM runner using TornadoVM",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
     # Required arguments
-    parser.add_argument("--model", dest="model_path", required=True,
-                        help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)")
+    parser.add_argument(
+        "--model",
+        dest="model_path",
+        required=True,
+        help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)",
+    )
 
     # LLM arguments
     llm_group = parser.add_argument_group("LLaMA Configuration")
     llm_group.add_argument("--prompt", help="Input prompt for the model")
     llm_group.add_argument("-sp", "--system-prompt", help="System prompt for the model")
-    llm_group.add_argument("--temperature", type=float, default=0.1,
-                           help="Sampling temperature (0.0 to 2.0)")
-    llm_group.add_argument("--top-p", type=float, default=0.95,
-                           help="Top-p sampling parameter")
-    llm_group.add_argument("--seed", type=int, default=None,
-                           help="Random seed (default: current timestamp)")
-    llm_group.add_argument("-n", "--max-tokens", type=int, default=512,
-                           help="Maximum number of tokens to generate")
-    llm_group.add_argument("--stream", type=bool, default=True,
-                           help="Enable streaming output")
-    llm_group.add_argument("--echo", type=bool, default=False,
-                           help="Echo the input prompt")
-    llm_group.add_argument("--suffix", help="Suffix for fill-in-the-middle request (Codestral)")
+    llm_group.add_argument(
+        "--temperature",
+        type=float,
+        default=0.1,
+        help="Sampling temperature (0.0 to 2.0)",
+    )
+    llm_group.add_argument(
+        "--top-p", type=float, default=0.95, help="Top-p sampling parameter"
+    )
+    llm_group.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Random seed (default: current timestamp)",
+    )
+    llm_group.add_argument(
+        "-n",
+        "--max-tokens",
+        type=int,
+        default=512,
+        help="Maximum number of tokens to generate",
+    )
+    llm_group.add_argument(
+        "--stream", type=bool, default=True, help="Enable streaming output"
+    )
+    llm_group.add_argument(
+        "--echo", type=bool, default=False, help="Echo the input prompt"
+    )
+    llm_group.add_argument(
+        "--suffix", help="Suffix for fill-in-the-middle request (Codestral)"
+    )
 
     # Mode selection
     mode_group = parser.add_argument_group("Mode Selection")
-    mode_group.add_argument("-i", "--interactive", action="store_true",
-                            help="Run in interactive/chat mode")
-    mode_group.add_argument("--instruct", action="store_true", default=True,
-                            help="Run in instruction mode (default)")
+    mode_group.add_argument(
+        "-i", "--interactive", action="store_true", help="Run in interactive/chat mode"
+    )
+    mode_group.add_argument(
+        "--instruct",
+        action="store_true",
+        default=True,
+        help="Run in instruction mode (default)",
+    )
 
     # Hardware configuration
     hw_group = parser.add_argument_group("Hardware Configuration")
-    hw_group.add_argument("--gpu", dest="use_gpu", action="store_true",
-                          help="Enable GPU acceleration")
-    hw_group.add_argument("--opencl", dest="backend", action="store_const", const=Backend.OPENCL,
-                          help="Use OpenCL backend (default)")
-    hw_group.add_argument("--ptx", dest="backend", action="store_const", const=Backend.PTX,
-                          help="Use PTX/CUDA backend")
-    hw_group.add_argument("--gpu-memory", default="7GB",
-                          help="GPU memory allocation")
-    hw_group.add_argument("--heap-min", default="20g",
-                          help="Minimum JVM heap size")
-    hw_group.add_argument("--heap-max", default="20g",
-                          help="Maximum JVM heap size")
+    hw_group.add_argument(
+        "--gpu", dest="use_gpu", action="store_true", help="Enable GPU acceleration"
+    )
+    hw_group.add_argument(
+        "--opencl",
+        dest="backend",
+        action="store_const",
+        const=Backend.OPENCL,
+        help="Use OpenCL backend (default)",
+    )
+    hw_group.add_argument(
+        "--ptx",
+        dest="backend",
+        action="store_const",
+        const=Backend.PTX,
+        help="Use PTX/CUDA backend",
+    )
+    hw_group.add_argument("--gpu-memory", default="7GB", help="GPU memory allocation")
+    hw_group.add_argument("--heap-min", default="20g", help="Minimum JVM heap size")
+    hw_group.add_argument("--heap-max", default="20g", help="Maximum JVM heap size")
 
     # Debug and profiling
     debug_group = parser.add_argument_group("Debug and Profiling")
-    debug_group.add_argument("--debug", action="store_true",
-                             help="Enable debug output")
-    debug_group.add_argument("--profiler", action="store_true",
-                             help="Enable TornadoVM profiler")
-    debug_group.add_argument("--profiler-dump-dir",
-                             default="/home/mikepapadim/repos/gpu-llama3.java/prof.json",
-                             help="Directory for profiler output")
+    debug_group.add_argument("--debug", action="store_true", help="Enable debug output")
+    debug_group.add_argument(
+        "--profiler", action="store_true", help="Enable TornadoVM profiler"
+    )
+    debug_group.add_argument(
+        "--profiler-dump-dir",
+        default="/home/mikepapadim/repos/gpu-llama3.java/prof.json",
+        help="Directory for profiler output",
+    )
 
     # TornadoVM Execution Verbose options
     verbose_group = parser.add_argument_group("TornadoVM Execution Verbose")
-    verbose_group.add_argument("--print-bytecodes", dest="print_bytecodes", action="store_true",
-                               help="Print bytecodes (tornado.print.bytecodes=true)")
-    verbose_group.add_argument("--print-threads", dest="threads", action="store_true",
-                               help="Print thread information (tornado.threadInfo=true)")
-    verbose_group.add_argument("--print-kernel", dest="print_kernel", action="store_true",
-                               help="Print kernel information (tornado.printKernel=true)")
-    verbose_group.add_argument("--full-dump", dest="full_dump", action="store_true",
-                               help="Enable full debug dump (tornado.fullDebug=true)")
-    verbose_group.add_argument("--verbose-init", dest="verbose_init", action="store_true",
-                               help="Enable timers for TornadoVM initialization (llama.EnableTimingForTornadoVMInit=true)")
-
+    verbose_group.add_argument(
+        "--print-bytecodes",
+        dest="print_bytecodes",
+        action="store_true",
+        help="Print bytecodes (tornado.print.bytecodes=true)",
+    )
+    verbose_group.add_argument(
+        "--print-threads",
+        dest="threads",
+        action="store_true",
+        help="Print thread information (tornado.threadInfo=true)",
+    )
+    verbose_group.add_argument(
+        "--print-kernel",
+        dest="print_kernel",
+        action="store_true",
+        help="Print kernel information (tornado.printKernel=true)",
+    )
+    verbose_group.add_argument(
+        "--full-dump",
+        dest="full_dump",
+        action="store_true",
+        help="Enable full debug dump (tornado.fullDebug=true)",
+    )
+    verbose_group.add_argument(
+        "--verbose-init",
+        dest="verbose_init",
+        action="store_true",
+        help="Enable timers for TornadoVM initialization (llama.EnableTimingForTornadoVMInit=true)",
+    )
 
     # Command display options
     command_group = parser.add_argument_group("Command Display Options")
-    command_group.add_argument("--show-command", action="store_true",
-                               help="Display the full Java command that will be executed")
-    command_group.add_argument("--execute-after-show", action="store_true",
-                               help="Execute the command after showing it (use with --show-command)")
+    command_group.add_argument(
+        "--show-command",
+        action="store_true",
+        help="Display the full Java command that will be executed",
+    )
+    command_group.add_argument(
+        "--execute-after-show",
+        action="store_true",
+        help="Execute the command after showing it (use with --show-command)",
+    )
 
     # Advanced options
     advanced_group = parser.add_argument_group("Advanced Options")
-    advanced_group.add_argument("--opencl-flags",
-                                default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only",
-                                help="OpenCL compiler flags")
-    advanced_group.add_argument("--max-wait-events", type=int, default=32000,
-                                help="Maximum wait events for TornadoVM event pool")
-    advanced_group.add_argument("--verbose", "-v", action="store_true",
-                                help="Verbose output")
+    advanced_group.add_argument(
+        "--opencl-flags",
+        default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only",
+        help="OpenCL compiler flags",
+    )
+    advanced_group.add_argument(
+        "--max-wait-events",
+        type=int,
+        default=32000,
+        help="Maximum wait events for TornadoVM event pool",
+    )
+    advanced_group.add_argument(
+        "--verbose", "-v", action="store_true", help="Verbose output"
+    )
 
     return parser
 
+
 def main():
     """Main entry point."""
     load_env_from_script()
@@ -368,7 +480,7 @@ def main():
         args.seed = int(time.time())
 
     # Set default backend to OpenCL if not specified
-    if not hasattr(args, 'backend') or args.backend is None:
+    if not hasattr(args, "backend") or args.backend is None:
         args.backend = Backend.OPENCL
 
     # Handle mode selection logic
@@ -379,5 +491,6 @@ def main():
     runner = LlamaRunner()
     return runner.run(args)
 
+
 if __name__ == "__main__":
     sys.exit(main())
diff --git a/src/main/java/com/example/LlamaApp.java b/src/main/java/com/example/LlamaApp.java
index 826b35c0..5ea0cb23 100644
--- a/src/main/java/com/example/LlamaApp.java
+++ b/src/main/java/com/example/LlamaApp.java
@@ -5,8 +5,8 @@
 import com.example.inference.sampler.CategoricalSampler;
 import com.example.inference.sampler.Sampler;
 import com.example.inference.sampler.ToppSampler;
-import com.example.model.Model;
 import com.example.loader.weights.ModelLoader;
+import com.example.model.Model;
 import com.example.tornadovm.FloatArrayUtils;
 
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
@@ -106,16 +106,48 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
         return sampler;
     }
 
-    public static void main(String[] args) throws IOException {
-        Options options = Options.parseOptions(args);
-        Model model;
+    /**
+     * Loads the language model based on the given options.
+     * <p>
+     * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model.
+     * Otherwise, loads the model from the specified path using the model loader.
+     * </p>
+     *
+     * @param options the parsed CLI options containing model path and max token limit
+     * @return the loaded {@link Model} instance
+     * @throws IOException if the model fails to load
+     * @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
+     */
+    private static Model loadModel(Options options) throws IOException {
         if (USE_AOT) {
-            model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
-        } else {
-            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+            Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
+            if (model == null) {
+                throw new IllegalStateException("Failed to load precompiled AOT model.");
+            }
+            return model;
         }
-        assert model != null;
-        Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
+        return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+    }
+
+    private static Sampler createSampler(Model model, Options options) {
+        return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
+    }
+
+    /**
+     * Entry point for running the LLaMA-based model with provided command-line arguments.
+     *
+     * <p>Initializes model options, loads the appropriate model (either AOT or on-demand),
+     * configures the sampler, and runs either in interactive or single-instruction mode
+     * based on the input options.</p>
+     *
+     * @param args command-line arguments used to configure model path, temperature, seed, etc.
+     * @throws IOException if model loading or file operations fail.
+     */
+    public static void main(String[] args) throws IOException {
+        Options options = Options.parseOptions(args);
+        Model model = loadModel(options);
+        Sampler sampler = createSampler(model, options);
+
         if (options.interactive()) {
             model.runInteractive(sampler, options);
         } else {
diff --git a/src/main/java/com/example/model/Model.java b/src/main/java/com/example/model/Model.java
index 07799562..e42349b5 100644
--- a/src/main/java/com/example/model/Model.java
+++ b/src/main/java/com/example/model/Model.java
@@ -1,12 +1,12 @@
 package com.example.model;
 
+import com.example.Options;
 import com.example.auxiliary.LastRunMetrics;
-import com.example.model.format.ChatFormat;
 import com.example.inference.InferenceEngine;
 import com.example.inference.sampler.Sampler;
-import com.example.Options;
 import com.example.loader.weights.State;
 import com.example.loader.weights.Weights;
+import com.example.model.format.ChatFormat;
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.TornadoVMMasterPlan;
 
diff --git a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
index 77c6b56b..e1864a4b 100644
--- a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
+++ b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
@@ -41,7 +41,6 @@
  *
  * @see TaskGraph
  * @see GridScheduler
- * @see Llama
  */
 // @formatter:on
 public class TornadoVMLayerPlanner {
diff --git a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
index eb194603..dc578f7a 100644
--- a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
@@ -1,9 +1,9 @@
 package com.example.tornadovm;
 
 import com.example.auxiliary.Tuple2;
+import com.example.loader.weights.State;
 import com.example.model.Configuration;
 import com.example.model.Model;
-import com.example.loader.weights.State;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
@@ -24,7 +24,9 @@ public class TornadoVMMasterPlan {
 
     public TornadoVMMasterPlan(State state, Model model, boolean isNvidia) {
         TornadoVMLayerPlanner tornadoVMLayerPlanner = new TornadoVMLayerPlanner(state, model);
-        Tuple2