Commit fa52f0c

[Android] New config API for Llm init and generate
Differential Revision: D73409726
Pull Request resolved: #10345
1 parent 3a66b14 commit fa52f0c

File tree

5 files changed: +326 -4 lines


extension/android/BUCK

Lines changed: 2 additions & 0 deletions
@@ -25,7 +25,9 @@ non_fbcode_target(_kind = fb_android_library,
     name = "executorch_llama",
     srcs = [
         "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java",
+        "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java",
         "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java",
+        "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java",
     ],
     autoglob = False,
     language = "JAVA",

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java

Lines changed: 1 addition & 2 deletions
@@ -31,8 +31,7 @@ public interface LlmCallback {
   /**
    * Called when the statistics for the generate() is available.
    *
-   * The result will be a JSON string. See extension/llm/stats.h for the field
-   * definitions.
+   * <p>The result will be a JSON string. See extension/llm/stats.h for the field definitions.
    *
    * @param stats JSON string containing the statistics for the generate()
    */
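
A minimal callback sketch (not part of the diff): it assumes the interface's onResult(String) method, which is not shown in this hunk, alongside the onStats(String) method documented above:

import org.pytorch.executorch.extension.llm.LlmCallback;

public class PrintingCallback implements LlmCallback {
  @Override
  public void onResult(String result) {
    System.out.print(result); // streamed generation output
  }

  @Override
  public void onStats(String stats) {
    // stats is a JSON string; see extension/llm/stats.h for the field definitions
    System.out.println("generate() stats: " + stats);
  }
}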

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch.extension.llm;

/**
 * Configuration class for controlling text generation parameters in LLM operations.
 *
 * <p>This class provides settings for text generation behavior including output formatting,
 * generation limits, and sampling parameters. Instances should be created using the {@link
 * #create()} method and the fluent builder pattern.
 */
public class LlmGenerationConfig {
  private final boolean echo;
  private final int maxNewTokens;
  private final boolean warming;
  private final int seqLen;
  private final float temperature;

  private LlmGenerationConfig(Builder builder) {
    this.echo = builder.echo;
    this.maxNewTokens = builder.maxNewTokens;
    this.warming = builder.warming;
    this.seqLen = builder.seqLen;
    this.temperature = builder.temperature;
  }

  /**
   * Creates a new Builder instance for constructing generation configurations.
   *
   * @return a new Builder with default configuration values
   */
  public static Builder create() {
    return new Builder();
  }

  /**
   * @return true if input prompt should be included in the output
   */
  public boolean isEcho() {
    return echo;
  }

  /**
   * @return maximum number of tokens to generate (-1 for unlimited)
   */
  public int getMaxNewTokens() {
    return maxNewTokens;
  }

  /**
   * @return true if model warming is enabled
   */
  public boolean isWarming() {
    return warming;
  }

  /**
   * @return maximum sequence length for generation (-1 for default)
   */
  public int getSeqLen() {
    return seqLen;
  }

  /**
   * @return temperature value for sampling (higher = more random)
   */
  public float getTemperature() {
    return temperature;
  }

  /**
   * Builder class for constructing LlmGenerationConfig instances.
   *
   * <p>Provides a fluent interface for configuring generation parameters with sensible defaults.
   * All methods return the builder instance to enable method chaining.
   */
  public static class Builder {
    private boolean echo = true;
    private int maxNewTokens = -1;
    private boolean warming = false;
    private int seqLen = -1;
    private float temperature = 0.8f;

    Builder() {}

    /**
     * Sets whether to include the input prompt in the generated output.
     *
     * @param echo true to include input prompt, false to return only new tokens
     * @return this builder instance
     */
    public Builder echo(boolean echo) {
      this.echo = echo;
      return this;
    }

    /**
     * Sets the maximum number of new tokens to generate.
     *
     * @param maxNewTokens the token limit (-1 for unlimited generation)
     * @return this builder instance
     */
    public Builder maxNewTokens(int maxNewTokens) {
      this.maxNewTokens = maxNewTokens;
      return this;
    }

    /**
     * Enables or disables model warming.
     *
     * @param warming true to generate initial tokens for model warmup
     * @return this builder instance
     */
    public Builder warming(boolean warming) {
      this.warming = warming;
      return this;
    }

    /**
     * Sets the maximum sequence length for generation.
     *
     * @param seqLen maximum sequence length (-1 for default behavior)
     * @return this builder instance
     */
    public Builder seqLen(int seqLen) {
      this.seqLen = seqLen;
      return this;
    }

    /**
     * Sets the temperature for random sampling.
     *
     * @param temperature sampling temperature (typical range 0.0-1.0)
     * @return this builder instance
     */
    public Builder temperature(float temperature) {
      this.temperature = temperature;
      return this;
    }

    /**
     * Constructs the LlmGenerationConfig instance with the configured parameters.
     *
     * @return new LlmGenerationConfig instance with current builder settings
     */
    public LlmGenerationConfig build() {
      return new LlmGenerationConfig(this);
    }
  }
}
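
A usage sketch (not part of the diff) exercising the builder above; the parameter values are illustrative:

import org.pytorch.executorch.extension.llm.LlmGenerationConfig;

public class GenConfigExample {
  public static void main(String[] args) {
    LlmGenerationConfig genConfig =
        LlmGenerationConfig.create() // returns a Builder with defaults
            .echo(false)             // return only newly generated tokens
            .maxNewTokens(128)       // stop after 128 new tokens
            .temperature(0.7f)       // slightly cooler than the 0.8f default
            .build();
    System.out.println("maxNewTokens = " + genConfig.getMaxNewTokens());
  }
}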

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 2 additions & 2 deletions
@@ -15,8 +15,8 @@
 import org.pytorch.executorch.annotations.Experimental;
 
 /**
- * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to
- * generate text from the model.
+ * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text
+ * from the model.
  *
  * <p>Warning: These APIs are experimental and subject to change without notice
  */
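
The commit title advertises new config-based init and generate APIs on LlmModule, but the corresponding signatures are not in the hunks shown here. The snippet below is therefore a hypothetical sketch: the config-accepting constructor and generate() overload are assumptions inferred from the commit title, not confirmed by this diff:

// Hypothetical usage; every LlmModule signature below is assumed, not shown in this diff.
LlmModule module = new LlmModule(moduleConfig);  // assumed config-based constructor
module.generate("Tell me a story", genConfig, new PrintingCallback()); // assumed overload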

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java

Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch.extension.llm;

/**
 * Configuration class for initializing a LlmModule.
 *
 * <p>Instances should be created using the {@link #create()} method and the fluent builder
 * pattern.
 */
public class LlmModuleConfig {
  private final String modulePath;
  private final String tokenizerPath;
  private final float temperature;
  private final String dataPath;
  private final int modelType;

  private LlmModuleConfig(Builder builder) {
    this.modulePath = builder.modulePath;
    this.tokenizerPath = builder.tokenizerPath;
    this.temperature = builder.temperature;
    this.dataPath = builder.dataPath;
    this.modelType = builder.modelType;
  }

  /** Model type constant for text-only models. */
  public static final int MODEL_TYPE_TEXT = 1;

  /** Model type constant for text-and-vision multimodal models. */
  public static final int MODEL_TYPE_TEXT_VISION = 2;

  /**
   * Creates a new Builder instance for constructing LlmModuleConfig objects.
   *
   * @return a new Builder instance with default configuration values
   */
  public static Builder create() {
    return new Builder();
  }

  // Getters with documentation
  /**
   * @return Path to the compiled model module (.pte file)
   */
  public String getModulePath() {
    return modulePath;
  }

  /**
   * @return Path to the tokenizer file or directory
   */
  public String getTokenizerPath() {
    return tokenizerPath;
  }

  /**
   * @return Temperature value for sampling (higher = more random)
   */
  public float getTemperature() {
    return temperature;
  }

  /**
   * @return Optional path to additional data files
   */
  public String getDataPath() {
    return dataPath;
  }

  /**
   * @return Type of model (text-only or text-vision)
   */
  public int getModelType() {
    return modelType;
  }

  /**
   * Builder class for constructing LlmModuleConfig instances with optional parameters.
   *
   * <p>The builder provides a fluent interface for configuring model parameters and validates
   * required fields before construction.
   */
  public static class Builder {
    private String modulePath;
    private String tokenizerPath;
    private float temperature = 0.8f;
    private String dataPath = "";
    private int modelType = MODEL_TYPE_TEXT;

    Builder() {}

    /**
     * Sets the path to the module.
     *
     * @param modulePath Path to module
     * @return This builder instance for method chaining
     */
    public Builder modulePath(String modulePath) {
      this.modulePath = modulePath;
      return this;
    }

    /**
     * Sets the path to the tokenizer.
     *
     * @param tokenizerPath Path to tokenizer
     * @return This builder instance for method chaining
     */
    public Builder tokenizerPath(String tokenizerPath) {
      this.tokenizerPath = tokenizerPath;
      return this;
    }

    /**
     * Sets the temperature for sampling generation.
     *
     * @param temperature Temperature value (typical range 0.0-1.0)
     * @return This builder instance for method chaining
     */
    public Builder temperature(float temperature) {
      this.temperature = temperature;
      return this;
    }

    /**
     * Sets the path to optional additional data files.
     *
     * @param dataPath Path to supplementary data resources
     * @return This builder instance for method chaining
     */
    public Builder dataPath(String dataPath) {
      this.dataPath = dataPath;
      return this;
    }

    /**
     * Sets the model type (text-only or multimodal).
     *
     * @param modelType One of MODEL_TYPE_TEXT or MODEL_TYPE_TEXT_VISION
     * @return This builder instance for method chaining
     */
    public Builder modelType(int modelType) {
      this.modelType = modelType;
      return this;
    }

    /**
     * Constructs the LlmModuleConfig instance with validated parameters.
     *
     * @return New LlmModuleConfig instance with configured values
     * @throws IllegalArgumentException if required fields are missing
     */
    public LlmModuleConfig build() {
      if (modulePath == null || tokenizerPath == null) {
        throw new IllegalArgumentException("Module path and tokenizer path are required");
      }
      return new LlmModuleConfig(this);
    }
  }
}
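
A usage sketch (not part of the diff) for the module config builder; the .pte and tokenizer paths are hypothetical placeholders:

import org.pytorch.executorch.extension.llm.LlmModuleConfig;

public class ModuleConfigExample {
  public static void main(String[] args) {
    LlmModuleConfig moduleConfig =
        LlmModuleConfig.create()
            .modulePath("/data/local/tmp/llama.pte")        // hypothetical path
            .tokenizerPath("/data/local/tmp/tokenizer.bin") // hypothetical path
            .modelType(LlmModuleConfig.MODEL_TYPE_TEXT)
            .build(); // throws IllegalArgumentException if either required path is missing
    System.out.println("Loading " + moduleConfig.getModulePath());
  }
}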
