Skip to content

Commit 98c2c53

Browse files
authored
Revert "Minibench refactor (#10376)" (#10405)
This reverts commit ebab7bb. Broke llama benchmarking.
1 parent 80fc3fc commit 98c2c53

10 files changed

+240
-283
lines changed

extension/benchmark/android/benchmark/README.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench
4343

4444
### Generic model
4545
```
46-
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
46+
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
4747
--es model_dir /data/local/tmp/minibench
4848
```
4949

5050
### LLM
5151
```
52-
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
52+
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
5353
--es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
5454
```
5555

extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,11 +114,11 @@ phases:
114114
adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180
115115

116116
if [ -n "$BIN_FOUND" ]; then
117-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
117+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
118118
--es "model_dir" "/data/local/tmp/minibench" \
119119
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
120120
elif [ -n "$MODEL_FOUND" ]; then
121-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
121+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
122122
--es "model_dir" "/data/local/tmp/minibench" \
123123
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
124124
else

extension/benchmark/android/benchmark/app/build.gradle.kts

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
plugins { id("com.android.application")
10-
id("org.jetbrains.kotlin.android")
11-
}
9+
plugins { id("com.android.application") }
1210

1311
android {
1412
namespace = "org.pytorch.minibench"
@@ -31,11 +29,8 @@ android {
3129
}
3230
}
3331
compileOptions {
34-
sourceCompatibility = JavaVersion.VERSION_17
35-
targetCompatibility = JavaVersion.VERSION_17
36-
}
37-
kotlinOptions {
38-
jvmTarget = "17"
32+
sourceCompatibility = JavaVersion.VERSION_1_8
33+
targetCompatibility = JavaVersion.VERSION_1_8
3934
}
4035
}
4136

@@ -45,7 +40,6 @@ dependencies {
4540
implementation("com.facebook.fbjni:fbjni:0.5.1")
4641
implementation("com.google.code.gson:gson:2.8.6")
4742
implementation("org.json:json:20250107")
48-
implementation("androidx.core:core-ktx:1.13.1")
4943
testImplementation("junit:junit:4.13.2")
5044
androidTestImplementation("androidx.test.ext:junit:1.2.1")
5145
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")

extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,14 @@
2121
</intent-filter>
2222
</activity>
2323

24+
<activity
25+
android:name=".LlmBenchmarkActivity"
26+
android:exported="true">
27+
<intent-filter>
28+
<action android:name="org.pytorch.minibench.BENCHMARK" />
29+
</intent-filter>
30+
</activity>
31+
2432
</application>
2533

2634
</manifest>

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 85 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -10,10 +10,9 @@
1010

1111
import android.app.Activity;
1212
import android.content.Intent;
13+
import android.os.AsyncTask;
1314
import android.os.Bundle;
14-
import android.os.Handler;
15-
import android.os.HandlerThread;
16-
import android.os.Looper;
15+
import android.os.Debug;
1716
import android.system.ErrnoException;
1817
import android.system.Os;
1918
import com.google.gson.Gson;
@@ -22,22 +21,12 @@
2221
import java.io.IOException;
2322
import java.util.ArrayList;
2423
import java.util.Arrays;
24+
import java.util.Collections;
2525
import java.util.List;
26+
import java.util.stream.Collectors;
27+
import org.pytorch.executorch.Module;
2628

2729
public class BenchmarkActivity extends Activity {
28-
29-
File mModel;
30-
int mNumIter;
31-
int mNumWarmupIter;
32-
String mTokenizerPath;
33-
float mTemperature;
34-
String mPrompt;
35-
36-
HandlerThread mHandlerThread;
37-
BenchmarkHandler mHandler;
38-
39-
List<BenchmarkMetric> mResult;
40-
4130
@Override
4231
protected void onCreate(Bundle savedInstanceState) {
4332
super.onCreate(savedInstanceState);
@@ -58,79 +47,95 @@ protected void onCreate(Bundle savedInstanceState) {
5847

5948
int numIter = intent.getIntExtra("num_iter", 50);
6049
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
61-
String tokenizerPath = intent.getStringExtra("tokenizer_path");
62-
float temperature = intent.getFloatExtra("temperature", 0.8f);
63-
String prompt = intent.getStringExtra("prompt");
64-
65-
mModel = model;
66-
mNumIter = numIter;
67-
mNumWarmupIter = numWarmupIter;
68-
mTokenizerPath = tokenizerPath;
69-
mTemperature = temperature;
70-
mPrompt = prompt;
71-
if (mPrompt == null) {
72-
mPrompt = "The ultimate answer";
73-
}
74-
mResult = new ArrayList<>();
7550

76-
mHandlerThread = new HandlerThread("ModelRunner");
77-
mHandlerThread.start();
78-
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
51+
long pssIdle = Debug.getPss();
7952

80-
mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
81-
}
53+
// TODO: Format the string with a parsable format
54+
Stats stats = new Stats();
8255

83-
void writeResult() {
84-
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
85-
Gson gson = new Gson();
86-
writer.write(gson.toJson(mResult));
87-
} catch (IOException e) {
88-
e.printStackTrace();
89-
} finally {
90-
finish();
91-
}
92-
}
93-
}
56+
new AsyncTask<Void, Void, Void>() {
57+
@Override
58+
protected Void doInBackground(Void... voids) {
9459

95-
class BenchmarkHandler extends Handler {
96-
public static int MESSAGE_RUN_BENCHMARK = 1;
97-
public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
60+
// Record the time it takes to load the model and the forward method
61+
stats.loadStart = System.nanoTime();
62+
Module module = Module.load(model.getPath());
63+
stats.errorCode = module.loadMethod("forward");
64+
stats.loadEnd = System.nanoTime();
9865

99-
ModelRunner mModelRunner;
100-
BenchmarkActivity mBenchmarkActivity;
66+
for (int i = 0; i < numWarmupIter; i++) {
67+
module.forward();
68+
}
10169

102-
LlmModelRunner mLlmModelRunner;
103-
LlmBenchmark mLlmBenchmark;
70+
for (int i = 0; i < numIter; i++) {
71+
long start = System.nanoTime();
72+
module.forward();
73+
double forwardMs = (System.nanoTime() - start) * 1e-6;
74+
stats.latency.add(forwardMs);
75+
}
76+
return null;
77+
}
10478

105-
public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
106-
super(looper);
107-
mModelRunner = new ModelRunner();
108-
mBenchmarkActivity = benchmarkActivity;
79+
@Override
80+
protected void onPostExecute(Void aVoid) {
81+
82+
final BenchmarkMetric.BenchmarkModel benchmarkModel =
83+
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
84+
final List<BenchmarkMetric> results = new ArrayList<>();
85+
// The list of metrics we have atm includes:
86+
// Avg inference latency after N iterations
87+
// Currently the result has large variance from outliers, so only use
88+
// 80% samples in the middle (trimmean 0.2)
89+
Collections.sort(stats.latency);
90+
int resultSize = stats.latency.size();
91+
List<Double> usedLatencyResults =
92+
stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
93+
94+
results.add(
95+
new BenchmarkMetric(
96+
benchmarkModel,
97+
"avg_inference_latency(ms)",
98+
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
99+
0.0f));
100+
results.add(
101+
new BenchmarkMetric(
102+
benchmarkModel,
103+
"trimmean_inference_latency(ms)",
104+
usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
105+
0.0f));
106+
// Model load time
107+
results.add(
108+
new BenchmarkMetric(
109+
benchmarkModel,
110+
"model_load_time(ms)",
111+
(stats.loadEnd - stats.loadStart) * 1e-6,
112+
0.0f));
113+
// Load status
114+
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
115+
// RAM PSS usage
116+
results.add(
117+
new BenchmarkMetric(
118+
benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
119+
120+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
121+
Gson gson = new Gson();
122+
writer.write(gson.toJson(results));
123+
} catch (IOException e) {
124+
e.printStackTrace();
125+
}
126+
}
127+
}.execute();
109128
}
129+
}
130+
131+
class Stats {
132+
long loadStart;
133+
long loadEnd;
134+
List<Double> latency = new ArrayList<>();
135+
int errorCode = 0;
110136

111137
@Override
112-
public void handleMessage(android.os.Message msg) {
113-
if (msg.what == MESSAGE_RUN_BENCHMARK) {
114-
mModelRunner.runBenchmark(
115-
mBenchmarkActivity.mModel,
116-
mBenchmarkActivity.mNumWarmupIter,
117-
mBenchmarkActivity.mNumIter,
118-
mBenchmarkActivity.mResult);
119-
120-
if (mBenchmarkActivity.mTokenizerPath == null) {
121-
mBenchmarkActivity.writeResult();
122-
} else {
123-
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
124-
}
125-
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
126-
mLlmBenchmark =
127-
new LlmBenchmark(
128-
mBenchmarkActivity,
129-
mBenchmarkActivity.mModel.getPath(),
130-
mBenchmarkActivity.mTokenizerPath,
131-
mBenchmarkActivity.mPrompt,
132-
mBenchmarkActivity.mTemperature,
133-
mBenchmarkActivity.mResult);
134-
}
138+
public String toString() {
139+
return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
135140
}
136141
}

0 commit comments

Comments (0)