Skip to content

Commit 428a472

Browse files
committed
Minibench refactor
Allow running the generic model benchmark before the LLM benchmark
1 parent fa52f0c commit 428a472

12 files changed: +449 additions, −409 deletions

extension/benchmark/android/benchmark/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench
4343

4444
### Generic model
4545
```
46-
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
46+
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
4747
--es model_dir /data/local/tmp/minibench
4848
```
4949

5050
### LLM
5151
```
52-
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
52+
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
5353
--es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
5454
```
5555

extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ phases:
114114
adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180
115115

116116
if [ -n "$BIN_FOUND" ]; then
117-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
117+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
118118
--es "model_dir" "/data/local/tmp/minibench" \
119119
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
120120
elif [ -n "$MODEL_FOUND" ]; then
121-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
121+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
122122
--es "model_dir" "/data/local/tmp/minibench" \
123123
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
124124
else

extension/benchmark/android/benchmark/app/build.gradle.kts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
plugins { id("com.android.application") }
9+
plugins { id("com.android.application")
10+
id("org.jetbrains.kotlin.android")
11+
}
1012

1113
android {
1214
namespace = "org.pytorch.minibench"
@@ -29,8 +31,11 @@ android {
2931
}
3032
}
3133
compileOptions {
32-
sourceCompatibility = JavaVersion.VERSION_1_8
33-
targetCompatibility = JavaVersion.VERSION_1_8
34+
sourceCompatibility = JavaVersion.VERSION_17
35+
targetCompatibility = JavaVersion.VERSION_17
36+
}
37+
kotlinOptions {
38+
jvmTarget = "17"
3439
}
3540
}
3641

@@ -40,6 +45,7 @@ dependencies {
4045
implementation("com.facebook.fbjni:fbjni:0.5.1")
4146
implementation("com.google.code.gson:gson:2.8.6")
4247
implementation("org.json:json:20250107")
48+
implementation("androidx.core:core-ktx:1.13.1")
4349
testImplementation("junit:junit:4.13.2")
4450
androidTestImplementation("androidx.test.ext:junit:1.2.1")
4551
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")

extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@
2121
</intent-filter>
2222
</activity>
2323

24-
<activity
25-
android:name=".LlmBenchmarkActivity"
26-
android:exported="true">
27-
<intent-filter>
28-
<action android:name="org.pytorch.minibench.BENCHMARK" />
29-
</intent-filter>
30-
</activity>
31-
3224
</application>
3325

3426
</manifest>

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 87 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -10,132 +10,118 @@
1010

1111
import android.app.Activity;
1212
import android.content.Intent;
13-
import android.os.AsyncTask;
1413
import android.os.Bundle;
15-
import android.os.Debug;
14+
import android.os.Handler;
15+
import android.os.HandlerThread;
16+
import android.os.Looper;
1617
import android.system.ErrnoException;
1718
import android.system.Os;
19+
1820
import com.google.gson.Gson;
21+
1922
import java.io.File;
2023
import java.io.FileWriter;
2124
import java.io.IOException;
2225
import java.util.ArrayList;
2326
import java.util.Arrays;
24-
import java.util.Collections;
2527
import java.util.List;
26-
import java.util.stream.Collectors;
27-
import org.pytorch.executorch.Module;
2828

2929
public class BenchmarkActivity extends Activity {
30-
@Override
31-
protected void onCreate(Bundle savedInstanceState) {
32-
super.onCreate(savedInstanceState);
33-
34-
try {
35-
Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
36-
} catch (ErrnoException e) {
37-
finish();
38-
}
39-
40-
Intent intent = getIntent();
41-
File modelDir = new File(intent.getStringExtra("model_dir"));
42-
File model =
43-
Arrays.stream(modelDir.listFiles())
44-
.filter(file -> file.getName().endsWith(".pte"))
45-
.findFirst()
46-
.get();
4730

48-
int numIter = intent.getIntExtra("num_iter", 50);
49-
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
31+
File mModel;
32+
int mNumIter;
33+
int mNumWarmupIter;
34+
String mTokenizerPath;
35+
float mTemperature;
36+
String mPrompt;
5037

51-
long pssIdle = Debug.getPss();
38+
HandlerThread mHandlerThread;
39+
BenchmarkHandler mHandler;
5240

53-
// TODO: Format the string with a parsable format
54-
Stats stats = new Stats();
41+
List<BenchmarkMetric> mResult;
5542

56-
new AsyncTask<Void, Void, Void>() {
57-
@Override
58-
protected Void doInBackground(Void... voids) {
43+
@Override
44+
protected void onCreate(Bundle savedInstanceState) {
45+
super.onCreate(savedInstanceState);
5946

60-
// Record the time it takes to load the model and the forward method
61-
stats.loadStart = System.nanoTime();
62-
Module module = Module.load(model.getPath());
63-
stats.errorCode = module.loadMethod("forward");
64-
stats.loadEnd = System.nanoTime();
65-
66-
for (int i = 0; i < numWarmupIter; i++) {
67-
module.forward();
47+
try {
48+
Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
49+
} catch (ErrnoException e) {
50+
finish();
6851
}
6952

70-
for (int i = 0; i < numIter; i++) {
71-
long start = System.nanoTime();
72-
module.forward();
73-
double forwardMs = (System.nanoTime() - start) * 1e-6;
74-
stats.latency.add(forwardMs);
53+
Intent intent = getIntent();
54+
File modelDir = new File(intent.getStringExtra("model_dir"));
55+
File model =
56+
Arrays.stream(modelDir.listFiles())
57+
.filter(file -> file.getName().endsWith(".pte"))
58+
.findFirst()
59+
.get();
60+
61+
int numIter = intent.getIntExtra("num_iter", 50);
62+
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
63+
String tokenizerPath = intent.getStringExtra("tokenizer_path");
64+
float temperature = intent.getFloatExtra("temperature", 0.8f);
65+
String prompt = intent.getStringExtra("prompt");
66+
67+
mModel = model;
68+
mNumIter = numIter;
69+
mNumWarmupIter = numWarmupIter;
70+
mTokenizerPath = tokenizerPath;
71+
mTemperature = temperature;
72+
mPrompt = prompt;
73+
if (mPrompt == null) {
74+
mPrompt = "The ultimate answer";
7575
}
76-
return null;
77-
}
78-
79-
@Override
80-
protected void onPostExecute(Void aVoid) {
81-
82-
final BenchmarkMetric.BenchmarkModel benchmarkModel =
83-
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
84-
final List<BenchmarkMetric> results = new ArrayList<>();
85-
// The list of metrics we have atm includes:
86-
// Avg inference latency after N iterations
87-
// Currently the result has large variance from outliers, so only use
88-
// 80% samples in the middle (trimmean 0.2)
89-
Collections.sort(stats.latency);
90-
int resultSize = stats.latency.size();
91-
List<Double> usedLatencyResults =
92-
stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
93-
94-
results.add(
95-
new BenchmarkMetric(
96-
benchmarkModel,
97-
"avg_inference_latency(ms)",
98-
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
99-
0.0f));
100-
results.add(
101-
new BenchmarkMetric(
102-
benchmarkModel,
103-
"trimmean_inference_latency(ms)",
104-
usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
105-
0.0f));
106-
// Model load time
107-
results.add(
108-
new BenchmarkMetric(
109-
benchmarkModel,
110-
"model_load_time(ms)",
111-
(stats.loadEnd - stats.loadStart) * 1e-6,
112-
0.0f));
113-
// Load status
114-
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
115-
// RAM PSS usage
116-
results.add(
117-
new BenchmarkMetric(
118-
benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
76+
mResult = new ArrayList<>();
77+
78+
mHandlerThread = new HandlerThread("ModelRunner");
79+
mHandlerThread.start();
80+
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
81+
82+
mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
83+
}
11984

85+
void writeResult() {
12086
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
121-
Gson gson = new Gson();
122-
writer.write(gson.toJson(results));
87+
Gson gson = new Gson();
88+
writer.write(gson.toJson(mResult));
12389
} catch (IOException e) {
124-
e.printStackTrace();
90+
e.printStackTrace();
91+
} finally {
92+
finish();
12593
}
126-
}
127-
}.execute();
128-
}
94+
}
12995
}
13096

131-
class Stats {
132-
long loadStart;
133-
long loadEnd;
134-
List<Double> latency = new ArrayList<>();
135-
int errorCode = 0;
97+
class BenchmarkHandler extends Handler {
98+
public static int MESSAGE_RUN_BENCHMARK = 1;
99+
public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
100+
101+
ModelRunner mModelRunner;
102+
BenchmarkActivity mBenchmarkActivity;
136103

137-
@Override
138-
public String toString() {
139-
return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
140-
}
104+
LlmModelRunner mLlmModelRunner;
105+
LlmBenchmark mLlmBenchmark;
106+
107+
public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
108+
super(looper);
109+
mModelRunner = new ModelRunner();
110+
mBenchmarkActivity = benchmarkActivity;
111+
}
112+
113+
@Override
114+
public void handleMessage(android.os.Message msg) {
115+
if (msg.what == MESSAGE_RUN_BENCHMARK) {
116+
mModelRunner.runBenchmark(mBenchmarkActivity.mModel, mBenchmarkActivity.mNumWarmupIter, mBenchmarkActivity.mNumIter, mBenchmarkActivity.mResult);
117+
118+
if (mBenchmarkActivity.mTokenizerPath == null) {
119+
mBenchmarkActivity.writeResult();
120+
} else {
121+
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
122+
}
123+
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
124+
mLlmBenchmark = new LlmBenchmark(mBenchmarkActivity, mBenchmarkActivity.mModel.getPath(), mBenchmarkActivity.mTokenizerPath, mBenchmarkActivity.mPrompt, mBenchmarkActivity.mTemperature, mBenchmarkActivity.mResult);
125+
}
126+
}
141127
}

0 commit comments

Comments (0)