Skip to content

Commit 10ba7e8

Browse files
Refactor Model design. Abandon Records, adopt interface with abstract base class hierarchy
1 parent f1a0089 commit 10ba7e8

File tree

5 files changed

+102
-11
lines changed

5 files changed

+102
-11
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.example.model;
2+
3+
import com.example.inference.weights.Weights;
4+
import com.example.model.format.ChatFormat;
5+
import com.example.tokenizer.impl.Tokenizer;
6+
import com.example.tornadovm.TornadoVMMasterPlan;
7+
8+
public abstract class AbstractModel implements Model {
9+
10+
protected Tokenizer tokenizer;
11+
protected Weights weights;
12+
protected ChatFormat chatFormat;
13+
/**
14+
* Represents the master plan for the TornadoVM execution in the context of the model.
15+
* This variable is used to manage the execution flow or strategy within the TornadoVM environment.
16+
* <p>
17+
* Initialized *only* when Tornado is used, via {@link TornadoVMMasterPlan#initializeTornadoVMPlan}.
18+
* </p>
19+
*/
20+
protected TornadoVMMasterPlan plan;
21+
22+
protected AbstractModel(Tokenizer tokenizer, Weights weights, ChatFormat chatFormat, TornadoVMMasterPlan plan) {
23+
this.tokenizer = tokenizer;
24+
this.weights = weights;
25+
this.chatFormat = chatFormat;
26+
this.plan = plan;
27+
}
28+
29+
// Common methods across models
30+
31+
public Weights weights() {
32+
return weights;
33+
}
34+
35+
public ChatFormat chatFormat() {
36+
return chatFormat;
37+
}
38+
39+
public TornadoVMMasterPlan tornadoVMPlan() {
40+
return plan;
41+
}
42+
43+
public void setTornadoVMPlan(TornadoVMMasterPlan plan) {
44+
this.plan = plan;
45+
}
46+
47+
}

src/main/java/com/example/model/Model.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ public interface Model {
2929

3030
ChatFormat chatFormat();
3131

32+
TornadoVMMasterPlan tornadoVMPlan();
33+
34+
void setTornadoVMPlan(TornadoVMMasterPlan plan);
35+
3236
ModelType getModelType();
3337

3438
State createNewState();

src/main/java/com/example/model/llama/Llama.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import com.example.inference.state.LlamaState;
77
import com.example.inference.state.State;
88
import com.example.inference.weights.Weights;
9-
import com.example.model.Model;
9+
import com.example.model.AbstractModel;
1010
import com.example.model.ModelType;
1111
import com.example.model.format.ChatFormat;
1212
import com.example.tokenizer.impl.LlamaTokenizer;
@@ -16,10 +16,22 @@
1616
import java.util.Set;
1717
import java.util.function.IntConsumer;
1818

19-
public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
19+
public class Llama extends AbstractModel {
2020

21-
/* For explicit use */
22-
private LlamaTokenizer getAsLlamaTokenizer() {
21+
LlamaConfiguration configuration;
22+
23+
public Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
24+
super(tokenizer, weights, chatFormat, null);
25+
this.configuration = configuration;
26+
}
27+
28+
@Override
29+
public LlamaConfiguration configuration() {
30+
return configuration;
31+
}
32+
33+
@Override
34+
public LlamaTokenizer tokenizer() {
2335
return (LlamaTokenizer) tokenizer;
2436
}
2537

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import com.example.inference.state.LlamaState;
77
import com.example.inference.state.State;
88
import com.example.inference.weights.Weights;
9-
import com.example.model.Model;
9+
import com.example.model.AbstractModel;
1010
import com.example.model.ModelType;
1111
import com.example.model.format.ChatFormat;
1212
import com.example.tokenizer.impl.MistralTokenizer;
@@ -16,11 +16,23 @@
1616
import java.util.Set;
1717
import java.util.function.IntConsumer;
1818

19-
public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
19+
public class Mistral extends AbstractModel {
2020

21-
/* For explicit use */
22-
private MistralTokenizer getAsMistralTokenizer() {
23-
return (MistralTokenizer) tokenizer;
21+
MistralConfiguration configuration;
22+
23+
public Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
24+
super(tokenizer, weights, chatFormat, null);
25+
this.configuration = configuration;
26+
}
27+
28+
@Override
29+
public MistralConfiguration configuration() {
30+
return configuration;
31+
}
32+
33+
@Override
34+
public MistralTokenizer tokenizer() {
35+
return (MistralTokenizer)tokenizer;
2436
}
2537

2638
@Override

src/main/java/com/example/model/qwen3/Qwen3.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,38 @@
66
import com.example.inference.state.Qwen3State;
77
import com.example.inference.state.State;
88
import com.example.inference.weights.Weights;
9-
import com.example.model.Model;
9+
import com.example.model.AbstractModel;
1010
import com.example.model.ModelType;
1111
import com.example.model.format.ChatFormat;
12+
import com.example.tokenizer.impl.Qwen3Tokenizer;
1213
import com.example.tokenizer.impl.Tokenizer;
1314

1415
import java.util.List;
1516
import java.util.Set;
1617
import java.util.function.IntConsumer;
1718

18-
public record Qwen3(Qwen3Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
19+
public class Qwen3 extends AbstractModel {
20+
21+
Qwen3Configuration configuration;
22+
23+
public Qwen3(Qwen3Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
24+
super(tokenizer, weights, chatFormat, null);
25+
this.configuration = configuration;
26+
}
27+
28+
public Qwen3Configuration configuration() {
29+
return configuration;
30+
}
1931

2032
@Override
2133
public ModelType getModelType() {
2234
return ModelType.QWEN_3;
2335
}
2436

37+
public Qwen3Tokenizer tokenizer() {
38+
return (Qwen3Tokenizer) tokenizer;
39+
}
40+
2541
@Override
2642
public State createNewState() {
2743
State state = new Qwen3State(configuration(), -1);

0 commit comments

Comments
 (0)