
Commit 62a3dd8
Refactor and improve code formatting across multiple files
Applied consistent formatting using //@formatter:off/on directives to enhance readability. Improved class documentation with detailed JavaDoc comments for methods and constructors, clarifying their purpose and parameters. Adjusted code style for multiline constructs and added missing comments where necessary.
1 parent dabbdfb commit 62a3dd8

25 files changed: +316 −120 lines

src/main/java/com/example/aot/AOT.java

Lines changed: 5 additions & 10 deletions

@@ -32,8 +32,8 @@ public final class AOT {
 
     static LlamaModelLoader modelLoader;
 
-
-    record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {}
+    record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {
+    }
 
     private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
 
@@ -49,12 +49,8 @@ private static PartialModel preLoadGGUF(String modelPath) {
             GGUF gguf = GGUF.loadModel(path);
             try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
                 modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false);
-                return new PartialModel(
-                        path.getFileName().toString(),
-                        modelLoader.loadModel(), // TODO: needs proper handling for AOT
-                        gguf.getTensorDataOffset(),
-                        gguf.getTensorInfos()
-                );
+                return new PartialModel(path.getFileName().toString(), modelLoader.loadModel(), // TODO: needs proper handling for AOT
+                        gguf.getTensorDataOffset(), gguf.getTensorInfos());
             }
         } catch (IOException e) {
             throw new RuntimeException(e);
@@ -78,8 +74,7 @@ public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
             return null;
         }
         Llama baseModel = preLoaded.model();
-        try (var timer = Timer.log("Load tensors from pre-loaded model");
-                var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
+        try (var timer = Timer.log("Load tensors from pre-loaded model"); var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
            // Load only the tensors (mmap slices).
            Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
            Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
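
For context on the hunks above: AOT preloading is gated on the llama.PreloadGGUF system property, which preLoadGGUF reads once while initializing the static PRELOADED_GGUF field, and tryUsePreLoaded returns null when the requested file does not match the preloaded one. A minimal sketch of opting in from plain Java; the launcher class and model path are hypothetical, and in a real ahead-of-time build the property would normally be set at build time rather than in main:

// Hypothetical launcher: the property must be set before the AOT class is
// first touched, because PRELOADED_GGUF is initialized in a static field.
public class PreloadDemo {
    public static void main(String[] args) throws Exception {
        System.setProperty("llama.PreloadGGUF", "/models/example.gguf"); // hypothetical path
        // First use of AOT runs preLoadGGUF; a matching path reuses the preloaded model.
        var model = AOT.tryUsePreLoaded(java.nio.file.Path.of("/models/example.gguf"), 512);
    }
}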

src/main/java/com/example/auxiliary/Utf8Mask.java

Lines changed: 2 additions & 0 deletions

@@ -2,9 +2,11 @@
 
 /** mask of a byte-sequence in UTF-8 encoding */
 public record Utf8Mask(int mask, int pattern, int len) {
+    //@formatter:off
     public static final Utf8Mask[] MASKS = {
             new Utf8Mask(0b11100000, 0b11000000, 2),
             new Utf8Mask(0b11110000, 0b11100000, 3),
             new Utf8Mask(0b11111000, 0b11110000, 4)
     };
+    //@formatter:on
 }
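
The directives bracket the MASKS table so the formatter keeps its hand alignment. The three mask/pattern pairs mirror the UTF-8 lead-byte encoding: a byte matching 110xxxxx opens a 2-byte sequence, 1110xxxx a 3-byte one, and 11110xxx a 4-byte one. A hedged sketch of how such a table is typically consulted (this helper is illustrative, not part of the commit):

// Length of the UTF-8 sequence implied by a lead byte, or 1 for ASCII and
// continuation bytes; the & against an int mask discards sign extension.
static int utf8SequenceLength(byte leadByte) {
    for (Utf8Mask m : Utf8Mask.MASKS) {
        if ((leadByte & m.mask()) == m.pattern()) {
            return m.len();
        }
    }
    return 1;
}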

src/main/java/com/example/inference/state/LlamaState.java

Lines changed: 3 additions & 3 deletions

@@ -56,9 +56,9 @@ protected StateFields createStateFields(Configuration config) {
         fields.positionHolder = new IntArray(1);
 
         // Temporary arrays
-        fields.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
+        fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
 
         return fields;
     }
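
Only the spacing changes here, but the expression deserves a gloss: 1 + ((dim + localSize - 1) / localSize) is integer ceiling division plus one extra slot, which reads as one partial result per work-group of localSize threads plus, presumably, a slot for the final reduced value. With hypothetical sizes:

int dim = 4096, localSize = 256;                   // illustrative values only
int n = 1 + ((dim + localSize - 1) / localSize);   // ceil(4096/256) = 16, so n = 17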

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 5 additions & 7 deletions

@@ -52,10 +52,8 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
 
         // Key-value cache with Qwen3 dimensions
-        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa))
-                .limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
-        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa))
-                .limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
 
         // TornadoVM wrappers with Qwen3-specific sizes
         fields.wrapX = new FloatArray(config.dim());

@@ -76,9 +74,9 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.positionHolder = new IntArray(1);
 
         // Temporary arrays
-        fields.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
+        fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
 
         return fields;
     }
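
Joining the cache allocations onto single lines leaves the idiom intact: Stream.generate(...).limit(n).toArray(...) builds one independently allocated tensor per layer, since the generator's lambda runs once per element. A self-contained sketch with plain arrays standing in for ArrayFloatTensor and hypothetical dimensions:

// One (contextLength x nEmbdGqa) buffer per layer; each generated element
// is fresh storage, so layers never alias one another.
static float[][][] allocateCache(int layers, int contextLength, int nEmbdGqa) {
    return java.util.stream.Stream.generate(() -> new float[contextLength][nEmbdGqa])
            .limit(layers)
            .toArray(float[][][]::new);
}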

src/main/java/com/example/inference/state/State.java

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 /**
  * Base class for State
  */
-public abstract class State{
+public abstract class State {
 
     // current wave of activations
     public final FloatTensor x; // activation at current time stamp (dim,)

src/main/java/com/example/inference/weights/Weights.java

Lines changed: 10 additions & 0 deletions

@@ -2,6 +2,16 @@
 
 import com.example.core.model.GGMLType;
 
+/**
+ * GPULlama3.java uses two distinct weight types:
+ * <ul>
+ * <li><b>StandardWeights:</b> designed for standard Java-based inference on the CPU.</li>
+ * <li><b>TornadoWeights:</b> optimized for GPU-accelerated inference using TornadoVM.</li>
+ * </ul>
+ *
+ * The packages <code>weights.standard</code> and <code>weights.tornado</code> define
+ * base classes and model-specific implementations for weights in their respective formats.
+ */
 public interface Weights {
 
     GGMLType getWeightType();
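
Since the new JavaDoc names the two weight families, a hedged sketch of telling them apart at a call site may help; only StandardWeights and getWeightType() appear in this diff, so the GPU branch below is an assumption about the weights.tornado package:

// Illustrative dispatch only; the project's real inference entry points
// may select the CPU or TornadoVM path differently.
static String describePath(Weights w) {
    if (w instanceof StandardWeights) {
        return "CPU inference, weight type " + w.getWeightType();
    }
    return "TornadoVM GPU inference, weight type " + w.getWeightType();
}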

src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java

Lines changed: 60 additions & 3 deletions

@@ -3,12 +3,69 @@
 import com.example.core.model.GGMLType;
 import com.example.core.model.tensor.FloatTensor;
 
+/**
+ * A model-specific implementation of {@link StandardWeights} for the Llama model.
+ * This class encapsulates the weights required for performing inference
+ * using the Llama model in the standard CPU-based format.
+ *
+ * <p><b>Note:</b> This weight format is also used for the Mistral model.</p>
+ */
 public class LlamaStandardWeights extends StandardWeights {
 
-    public LlamaStandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatTensor[] rms_ffn_weight,
-            FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
-        super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType);
+    // @formatter:off
+    /**
+     * Constructor for LlamaStandardWeights.
+     *
+     * @param token_embedding_table The token embedding table tensor.
+     * @param rms_att_weight Array of RMS attention weight tensors.
+     * @param wq Array of query weight tensors.
+     * @param wk Array of key weight tensors.
+     * @param wv Array of value weight tensors.
+     * @param wo Array of output weight tensors.
+     * @param rms_ffn_weight Array of RMS feed-forward network weights.
+     * @param w1 Array of first feed-forward layer weights.
+     * @param w2 Array of second feed-forward layer weights.
+     * @param w3 Array of third feed-forward layer weights.
+     * @param rms_final_weight Final RMS weight tensor.
+     * @param freq_cis_real Real part of frequency cis tensor.
+     * @param freq_cis_imag Imaginary part of frequency cis tensor.
+     * @param wcls Classifier weight tensor for the output logits.
+     * @param weightType The GGML weight type.
+     */
+    public LlamaStandardWeights(
+            FloatTensor token_embedding_table,
+            FloatTensor[] rms_att_weight,
+            FloatTensor[] wq,
+            FloatTensor[] wk,
+            FloatTensor[] wv,
+            FloatTensor[] wo,
+            FloatTensor[] rms_ffn_weight,
+            FloatTensor[] w1,
+            FloatTensor[] w2,
+            FloatTensor[] w3,
+            FloatTensor rms_final_weight,
+            FloatTensor freq_cis_real,
+            FloatTensor freq_cis_imag,
+            FloatTensor wcls,
+            GGMLType weightType) {
+        // call to StandardWeights constructor
+        super(token_embedding_table,
+                rms_att_weight,
+                wq,
+                wk,
+                wv,
+                wo,
+                rms_ffn_weight,
+                w1,
+                w2,
+                w3,
+                rms_final_weight,
+                freq_cis_real,
+                freq_cis_imag,
+                wcls,
+                weightType);
     }
+    // @formatter:on
 
     @Override
     public GGMLType getWeightType() {

src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java

Lines changed: 61 additions & 6 deletions

@@ -3,20 +3,75 @@
 import com.example.core.model.GGMLType;
 import com.example.core.model.tensor.FloatTensor;
 
+/**
+ * A model-specific implementation of {@link StandardWeights} for the Qwen-3 model.
+ * This class defines the weights required for performing inference
+ * using the Qwen-3 model in the standard CPU-based format.
+ */
 public class Qwen3StandardWeights extends StandardWeights {
     public final FloatTensor[] attnKNorm, attnQNorm;
 
-    public Qwen3StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight,
-            FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
-            FloatTensor[] attnKNorm, FloatTensor[] attnQNorm,
+    // @formatter:off
+    /**
+     * Constructor for {@code Qwen3StandardWeights}.
+     *
+     * @param token_embedding_table The token embedding table, used to map tokens to embeddings.
+     * @param rms_att_weight The array of Root Mean Square (RMS) attention weights.
+     * @param wq The array of query weight tensors for attention layers.
+     * @param wk The array of key weight tensors for attention layers.
+     * @param wv The array of value weight tensors for attention layers.
+     * @param wo The array of output weight tensors for attention layers.
+     * @param attnKNorm The array of normalization tensors for attention keys.
+     * @param attnQNorm The array of normalization tensors for attention queries.
+     * @param rms_ffn_weight The array of RMS weights for feed-forward neural network layers.
+     * @param w1 The array of first weight tensors for feed-forward layers.
+     * @param w2 The array of second weight tensors for feed-forward layers.
+     * @param w3 The array of third weight tensors for feed-forward layers.
+     * @param rms_final_weight The RMS weight used for final output normalization.
+     * @param freq_cis_real The real part of the frequency position encodings.
+     * @param freq_cis_imag The imaginary part of the frequency position encodings.
+     * @param wcls The weight tensor for the classification head.
+     * @param weightType The type of the weights, defined as {@link GGMLType}.
+     */
+    public Qwen3StandardWeights(
+            FloatTensor token_embedding_table,
+            FloatTensor[] rms_att_weight,
+            FloatTensor[] wq,
+            FloatTensor[] wk,
+            FloatTensor[] wv,
+            FloatTensor[] wo,
+            FloatTensor[] attnKNorm,
+            FloatTensor[] attnQNorm,
             FloatTensor[] rms_ffn_weight,
-            FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
-            FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
+            FloatTensor[] w1,
+            FloatTensor[] w2,
+            FloatTensor[] w3,
+            FloatTensor rms_final_weight,
+            FloatTensor freq_cis_real,
+            FloatTensor freq_cis_imag,
+            FloatTensor wcls,
+            GGMLType weightType) {
         // call to StandardWeights constructor
-        super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType);
+        super(token_embedding_table,
+                rms_att_weight,
+                wq,
+                wk,
+                wv,
+                wo,
+                rms_ffn_weight,
+                w1,
+                w2,
+                w3,
+                rms_final_weight,
+                freq_cis_real,
+                freq_cis_imag,
+                wcls,
+                weightType);
+        // init Qwen3-specific fields
         this.attnKNorm = attnKNorm;
         this.attnQNorm = attnQNorm;
     }
+    // @formatter:on
 
     @Override
     public GGMLType getWeightType() {

src/main/java/com/example/inference/weights/standard/StandardWeights.java

Lines changed: 21 additions & 31 deletions

@@ -4,6 +4,11 @@
 import com.example.core.model.tensor.FloatTensor;
 import com.example.inference.weights.Weights;
 
+/**
+ * Base class that represents the standard weight format used for Java-based CPU inference.
+ * This abstract class provides the foundation for defining model-specific
+ * weights in the StandardWeights format.
+ */
 public abstract class StandardWeights implements Weights {
     // token embedding table
     public final FloatTensor token_embedding_table; // (vocab_size, dim)

@@ -14,8 +19,6 @@ public abstract class StandardWeights implements Weights {
     public final FloatTensor[] wk; // (layer, n_kv_heads, head_size)
     public final FloatTensor[] wv; // (layer, n_kv_heads * head_size)
     public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
-    //public final FloatTensor[] attnKNorm; // qwen3
-    //public final FloatTensor[] attnQNorm; // qwen3
     public final FloatTensor[] rms_ffn_weight; // (layer, dim)
 
     // weights for ffn

@@ -33,41 +36,27 @@ public abstract class StandardWeights implements Weights {
     // (optional) classifier weights for the logits, on the last layer
     protected final GGMLType weightType;
 
+    //@formatter:off
     /**
      * Constructor for standard (non-TornadoVM) mode
      *
-     * @param token_embedding_table
-     *            Token embeddings matrix
-     * @param rms_att_weight
-     *            RMSNorm weights for attention layers
-     * @param wq
-     *            Query weight matrices
-     * @param wk
-     *            Key weight matrices
-     * @param wv
-     *            Value weight matrices
-     * @param wo
-     *            Output projection matrices
-     * @param rms_ffn_weight
-     *            RMSNorm weights for FFN layers
-     * @param w1
-     *            First FFN weight matrices
-     * @param w2
-     *            Second FFN weight matrices
-     * @param w3
-     *            Third FFN weight matrices (gate)
-     * @param rms_final_weight
-     *            Final layer normalization weights
-     * @param freq_cis_real
-     *            RoPE cosine components
-     * @param freq_cis_imag
-     *            RoPE sine components
-     * @param wcls
-     *            Classifier weights for output logits
+     * @param token_embedding_table Token embeddings matrix
+     * @param rms_att_weight RMSNorm weights for attention layers
+     * @param wq Query weight matrices
+     * @param wk Key weight matrices
+     * @param wv Value weight matrices
+     * @param wo Output projection matrices
+     * @param rms_ffn_weight RMSNorm weights for FFN layers
+     * @param w1 First FFN weight matrices
+     * @param w2 Second FFN weight matrices
+     * @param w3 Third FFN weight matrices (gate)
+     * @param rms_final_weight Final layer normalization weights
+     * @param freq_cis_real RoPE cosine components
+     * @param freq_cis_imag RoPE sine components
+     * @param wcls Classifier weights for output logits
      */
     protected StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight,
             FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
-            //FloatTensor[] attnKNorm, FloatTensor[] attnQNorm,
             FloatTensor[] rms_ffn_weight,
             FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
             FloatTensor rms_final_weight,

@@ -92,4 +81,5 @@ protected StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight,
         this.freq_cis_imag = freq_cis_imag;
         this.weightType = weightType;
     }
+    //@formatter:on
 }
