[model] Add support for Mistral models #17
Conversation
Introduced `kvDim` and `kvMul` methods in `Configuration` and `MistralConfiguration` to enhance model configuration flexibility. Refactored the TornadoVM classes to generalize handling of different models by replacing `Llama`-specific types with the `Model` interface. Streamlined the token generation logic to support conditional GPU execution with TornadoVM.
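For context, a minimal sketch of how such default methods typically derive the key/value dimensions under grouped-query attention; the accessor names here are assumptions, not necessarily the PR's actual signatures:

```java
// Sketch only: accessor names are illustrative, not the PR's actual API.
public interface Configuration {
    int dim();                   // embedding dimension
    int numberOfHeads();         // attention (query) heads
    int numberOfKeyValueHeads(); // key/value heads; fewer than query heads under GQA

    // Size of the key/value projection; equals dim() when every query
    // head has its own KV head, smaller under grouped-query attention.
    default int kvDim() {
        return dim() * numberOfKeyValueHeads() / numberOfHeads();
    }

    // Number of query heads sharing a single key/value head.
    default int kvMul() {
        return numberOfHeads() / numberOfKeyValueHeads();
    }
}
```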
Move format classes from `auxiliary.format` to `model.format` to fix dependency direction. These classes are only used by model classes, so co-locating them improves package cohesion.
Pull Request Overview
This PR integrates the Mistral LLM into the GPULlama3 codebase by introducing model and tokenizer abstractions, extending chat formatting logic, and refactoring inference loading.
- Refactor `Llama` into a generic `Model` interface and add a `Mistral` implementation
- Introduce a `Tokenizer` interface with `LlamaTokenizer` and `MistralTokenizer`
- Abstract `ChatFormat` for both Llama and Mistral and enhance `ModelLoader` to detect and load each
Reviewed Changes
Copilot reviewed 29 out of 29 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| model/llama/LlamaConfiguration.java | New record implementing Configuration for Llama |
| model/llama/Llama.java | Updated to implement generic Model interface |
| model/format/ChatFormat.java | Interface and factory to select Llama/Mistral format |
| model/format/MistralChatFormat.java | New chat formatter for Mistral-specific tokens |
| model/format/LlamaChatFormat.java | Refined Llama chat formatting under ChatFormat |
| loader/weights/ModelLoader.java | Detects GGUF metadata and loads the appropriate model |
| Model.java | Unified inference entry points under the Model interface |
Comments suppressed due to low confidence (2)
src/main/java/com/example/model/format/LlamaChatFormat.java:60

This loop refers to `LlamaChatFormat.Message`, but `Message` is defined in the `ChatFormat` interface. Change to `for (ChatFormat.Message message : dialog)` to match the interface type.

```java
for (LlamaChatFormat.Message message : dialog) {
```
src/main/java/com/example/Model.java:103

`List<Integer>` does not have `getLast()` or `removeLast()` methods. Use `responseTokens.get(responseTokens.size() - 1)` and `responseTokens.remove(responseTokens.size() - 1)` instead.

```java
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
```
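Note that `List.getLast()` and `removeLast()` do exist from Java 21 onward (via `SequencedCollection`), so the suggestion only applies on earlier targets. A sketch of the index-based form, reusing the names from the quoted snippet; the removal follows the comment's mention of `removeLast()`, though the PR's actual control flow may differ:

```java
// Index-based equivalent of getLast()/removeLast(), valid before Java 21.
if (!responseTokens.isEmpty()
        && stopTokens.contains(responseTokens.get(responseTokens.size() - 1))) {
    responseTokens.remove(responseTokens.size() - 1); // drop the trailing stop token
}
```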
```java
public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
    // Check by vocabulary size as fallback
    if (vocabSize != null) {
```
Is there any other way to detect the model here instead of checking sizes?
Indeed, checking the vocab size seems odd, but it should be noted that it is the third check in line and just acts as a fallback in case the model name (1st check) and the tokenizer metadata (2nd check) are not enough. IMHO, we can keep only the model name check.
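For illustration, the three-tier fallback chain described here might look roughly like the following; the metadata keys and the vocab-size constant are assumptions, not the PR's actual code:

```java
import java.util.Map;

// Illustrative sketch of the detection fallback chain discussed above.
final class ModelDetection {
    enum ModelType { LLAMA, MISTRAL }

    // Assumed value; the real check would use whatever vocabulary size
    // the supported Mistral GGUF files actually report.
    static final int ASSUMED_MISTRAL_VOCAB_SIZE = 32_768;

    static ModelType detect(Map<String, Object> metadata, Integer vocabSize) {
        // 1st check: model name from GGUF metadata
        String name = (String) metadata.get("general.name");
        if (name != null && name.toLowerCase().contains("mistral")) {
            return ModelType.MISTRAL;
        }
        // 2nd check: tokenizer metadata (LlamaTokenizer is GPT-2-style BPE here)
        String tok = (String) metadata.get("tokenizer.ggml.model");
        if ("gpt2".equals(tok)) {
            return ModelType.LLAMA;
        }
        // 3rd check: vocabulary size, only as a last-resort fallback
        if (vocabSize != null && vocabSize == ASSUMED_MISTRAL_VOCAB_SIZE) {
            return ModelType.MISTRAL;
        }
        return ModelType.LLAMA;
    }
}
```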
Fixes #18
One more thing! Can you update the README with the link and instructions to download the Mistral model -> https://huggingface.co/beehive-lab
done
@orionpapadakis thanks, LGTM. Let me test it and add the changes for the normalization layers, then we can merge.
Summary
This PR integrates the Mistral LLM into the `GPULlama3` repository. To support this integration, several architectural changes and refactorings were made to promote component abstraction and a modular, extensible design.
Key Changes
1. Model Abstraction
The `Llama` class has been refactored into a `Model` interface (under a new `model` package). `Llama`-specific functionality is moved to `model.llama.Llama`; the new `Mistral` implementation is added in `model.mistral.Mistral`.
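A rough sketch of the shape this abstraction could take, with method names inferred from the tokenizer, chat-format, and configuration pieces described elsewhere in this PR; the actual interface is likely richer:

```java
// Illustrative only; the real Model interface in the PR may differ.
public interface Model {
    Configuration configuration(); // Llama- or Mistral-specific hyperparameters
    Tokenizer tokenizer();         // model-specific tokenizer
    ChatFormat chatFormat();       // model-specific prompt formatting
}
```

`model.llama.Llama` and `model.mistral.Mistral` would then each implement this interface, letting the TornadoVM classes and the token-generation loop work against `Model` instead of a concrete type.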
2. Tokenizer Abstraction

Introduced a `Tokenizer` interface under the `tokenizer.impl` package. Implemented two tokenizers:
- `LlamaTokenizer` for GPT-2-style BPE.
- `MistralTokenizer` for TikToken-style BPE.

The `Vocabulary` class has been relocated to the `tokenizer` package.
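Based on the description above, the interface could be as small as an encode/decode pair (a sketch; the PR's actual interface may expose more, e.g. special-token handling):

```java
import java.util.List;

// Minimal sketch of the Tokenizer abstraction described above.
public interface Tokenizer {
    List<Integer> encode(String text);    // text -> token ids
    String decode(List<Integer> tokens);  // token ids -> text
}
```

`LlamaTokenizer` and `MistralTokenizer` would then implement the same contract over their respective BPE schemes.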
3. ChatFormat Abstraction

`ChatFormat` functionality has been refactored into an abstract form to support model-specific formatting and enable future extensions.
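The file table above mentions an interface plus a factory that selects the Llama or Mistral format; a hedged sketch of that pattern, where the constructor arguments and `encodeDialog` are assumptions:

```java
import java.util.List;

public interface ChatFormat {
    // Message lives on the interface (see the Copilot comment above).
    record Message(String role, String content) { }

    List<Integer> encodeDialog(List<Message> dialog);

    // Factory choosing the model-specific formatter.
    static ChatFormat of(Model model) {
        return (model instanceof Mistral)
                ? new MistralChatFormat(model.tokenizer())
                : new LlamaChatFormat(model.tokenizer());
    }
}
```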
4. Inference Refactoring

Inference logic is decoupled from the model classes.
Introduced a dedicated `inference` package with:
- `InferenceEngine`: entry point for token generation (`generateToken`, `generateTokenGPU` methods).
- `InferenceCore`: contains reusable core operations (e.g., `rmsnorm`, `forward`).
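Put together, a call site could look roughly like this; the `useTornadoVM` flag and all parameter names are illustrative, while `loadModel`, `generateToken`, and `generateTokenGPU` are named in the PR:

```java
// Hypothetical wiring of the pieces above; parameters are assumptions.
Model model = ModelLoader.loadModel(ggufPath, contextLength, true);
List<Integer> promptTokens = ChatFormat.of(model).encodeDialog(dialog);

List<Integer> response = useTornadoVM
        ? InferenceEngine.generateTokenGPU(model, promptTokens, maxTokens)
        : InferenceEngine.generateToken(model, promptTokens, maxTokens);
```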