[models] Support for Qwen3 models #37
Conversation
Force-pushed from 1306591 to 0369d50.
… base class hierarchy
Force-pushed from 2343297 to 7dc5056.
ready for review
        
          
Review threads (outdated, resolved):
- src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java
- src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java (3 threads)
- src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java
Pull Request Overview
This PR adds support for Qwen3 models to the codebase, implementing a modular architecture that refactors model loading and inference engines to support multiple model types. The implementation includes both CPU and GPU inference paths through TornadoVM for Qwen3 models, alongside architectural improvements to the existing LLaMA and Mistral model support.
Key changes include:
- Adding Qwen3 model support with specialized tokenization, configuration, and inference logic
- Refactoring the model loading system to use a modular pattern with abstract base classes
- Implementing separate state management and weight handling for different model architectures
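As a rough illustration of the modular loading pattern described above, a sketch is shown below. Only ModelLoader and Qwen3Configuration are named in this PR; every other class, method, and value here is an assumption for illustration, not the PR's actual code.

```java
import java.nio.file.Path;

// Illustrative sketch only: an abstract loader base with one concrete loader per
// architecture, roughly as the PR's refactoring describes.
abstract class ModelLoader<C> {
    protected final Path ggufPath;

    protected ModelLoader(Path ggufPath) {
        this.ggufPath = ggufPath;
    }

    /** Parse architecture-specific metadata (hyperparameters, vocab, etc.). */
    protected abstract C loadConfiguration();

    /** Wire configuration, tokenizer, and weights into a runnable model. */
    public abstract Runnable loadModel();
}

// Hypothetical stand-in for the PR's Qwen3Configuration.
record Qwen3Config(int dim, int numberOfHeads, int numberOfKeyValueHeads) {}

class Qwen3Loader extends ModelLoader<Qwen3Config> {
    Qwen3Loader(Path ggufPath) { super(ggufPath); }

    @Override
    protected Qwen3Config loadConfiguration() {
        // Real code would read these values from the model file's metadata.
        return new Qwen3Config(1024, 16, 8);
    }

    @Override
    public Runnable loadModel() {
        Qwen3Config config = loadConfiguration();
        return () -> System.out.println("Loaded Qwen3 with dim=" + config.dim());
    }
}
```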
Reviewed Changes
Copilot reviewed 44 out of 44 changed files in this pull request and generated 5 comments.
| File | Description | 
|---|---|
| TornadoVMLayerPlanner.java | Refactored to support generic model types with parameterized base class | 
| Qwen3TornadoVMLayerPlanner.java | New Qwen3-specific GPU execution planner with custom kernel configurations | 
| Qwen3Kernels.java | Qwen3-specific GPU kernels including RMSNorm and RoPE rotation implementations | 
| Qwen3Tokenizer.java | Complete Qwen3 tokenizer implementation with BPE encoding/decoding | 
| Model architecture files | New Qwen3Configuration, Qwen3 model class, and supporting infrastructure | 
| Weight/State refactoring | Separated standard and TornadoVM weight classes, model-specific state classes | 
| Model loader refactoring | Abstract ModelLoader base with concrete implementations for each model type | 
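The "parameterized base class" planner refactor in the first two rows could look roughly like the sketch below. Only the class names TornadoVMLayerPlanner and Qwen3TornadoVMLayerPlanner come from the PR; the generics and methods are assumptions for illustration.

```java
// Rough sketch of a generic planner base with an architecture-specific subclass.
abstract class TornadoVMLayerPlanner<S, W, C> {
    protected final S state;
    protected final W weights;
    protected final C config;

    protected TornadoVMLayerPlanner(S state, W weights, C config) {
        this.state = state;
        this.weights = weights;
        this.config = config;
    }

    /** Each architecture schedules its own sequence of GPU kernels per layer. */
    abstract void planLayer(int layerIndex);
}

class Qwen3TornadoVMLayerPlanner extends TornadoVMLayerPlanner<Object, Object, Object> {
    Qwen3TornadoVMLayerPlanner(Object state, Object weights, Object config) {
        super(state, weights, config);
    }

    @Override
    void planLayer(int layerIndex) {
        // Qwen3-specific kernels (RMSNorm, RoPE rotation, attention, FFN) would be
        // appended to the task graph here.
    }
}
```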
Comments suppressed due to low confidence (3)
src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java:441
- [nitpick] The variable name 'shared_tile_max_holder' is verbose and the comment suggests it's a workaround. Consider renaming to 'tileMaxBuffer' for clarity and consistency with other buffer variables.
        float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max
src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java:623
- [nitpick] The parameter name 'hb' is not descriptive. Consider renaming to 'output' or 'outputBuffer' to match the comment and improve readability.
            FloatArray hb,                  // output
src/main/java/com/example/inference/state/Qwen3State.java:25
- The variable 'nEmbdHead' is assigned 'numberOfHeads()' but based on context, it should likely be 'numberOfHeadsValue()' or a calculated embedding head size. This naming suggests a mismatch between the variable name and its actual value.
        int nEmbdHead = qwen3config.numberOfHeads();
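For context, a hedged illustration of the head-size relationship this comment points at follows; the real Qwen3Configuration accessor names and values may differ.

```java
// Illustration only: per-head embedding size versus number of heads.
class HeadSizeExample {
    public static void main(String[] args) {
        int dim = 2048;                      // model embedding dimension (example value)
        int numberOfHeads = 16;              // attention heads (example value)
        int headSize = dim / numberOfHeads;  // per-head embedding size: 2048 / 16 = 128
        // 'nEmbdHead' presumably wants headSize (128), not numberOfHeads (16).
        System.out.println(headSize);
    }
}
```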
@orionpapadakis also, update the README with Qwen models instructions, etc.
Applied consistent formatting using @Formatter directives to enhance readability. Improved class documentation with detailed JavaDoc comments for methods and constructors, clarifying their purpose and parameters. Adjusted code style for multiline constructs and added missing comments where necessary.
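For reference, a minimal sketch of the kind of formatter directive and JavaDoc the commit message describes, assuming IntelliJ-style @formatter markers; the PR's actual code and class names may differ.

```java
// Illustrative only: JavaDoc plus formatter on/off markers around a hand-aligned block.
final class RopeExample {

    /**
     * Precomputes rotary position embedding (RoPE) frequencies for one head.
     *
     * @param headSize per-head embedding size
     * @param theta    RoPE base frequency
     * @return one frequency per pair of dimensions
     */
    static float[] precomputeFreqs(int headSize, float theta) {
        // @formatter:off
        // (hand-aligned block kept as-is by the IDE formatter)
        float[] freqs = new float[headSize / 2];
        for (int i = 0; i < headSize; i += 2) {
            freqs[i / 2] = (float) Math.pow(theta, -(double) i / headSize);
        }
        // @formatter:on
        return freqs;
    }
}
```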
Force-pushed from 62a3dd8 to 2fd98ef.
Ongoing work for #19
Checklist:
- [x] CPU inference path in a working state
- [x] GPU inference path in a working state