Skip to content

Commit 740bbac

Browse files
committed
Linter
1 parent 428a472 commit 740bbac

File tree

5 files changed

+368
-363
lines changed

5 files changed

+368
-363
lines changed

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 100 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
import android.os.Looper;
1717
import android.system.ErrnoException;
1818
import android.system.Os;
19-
19+
import android.util.Log;
2020
import com.google.gson.Gson;
21-
2221
import java.io.File;
2322
import java.io.FileWriter;
2423
import java.io.IOException;
@@ -28,100 +27,111 @@
2827

2928
public class BenchmarkActivity extends Activity {
3029

31-
File mModel;
32-
int mNumIter;
33-
int mNumWarmupIter;
34-
String mTokenizerPath;
35-
float mTemperature;
36-
String mPrompt;
37-
38-
HandlerThread mHandlerThread;
39-
BenchmarkHandler mHandler;
40-
41-
List<BenchmarkMetric> mResult;
42-
43-
@Override
44-
protected void onCreate(Bundle savedInstanceState) {
45-
super.onCreate(savedInstanceState);
46-
47-
try {
48-
Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
49-
} catch (ErrnoException e) {
50-
finish();
51-
}
52-
53-
Intent intent = getIntent();
54-
File modelDir = new File(intent.getStringExtra("model_dir"));
55-
File model =
56-
Arrays.stream(modelDir.listFiles())
57-
.filter(file -> file.getName().endsWith(".pte"))
58-
.findFirst()
59-
.get();
60-
61-
int numIter = intent.getIntExtra("num_iter", 50);
62-
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
63-
String tokenizerPath = intent.getStringExtra("tokenizer_path");
64-
float temperature = intent.getFloatExtra("temperature", 0.8f);
65-
String prompt = intent.getStringExtra("prompt");
66-
67-
mModel = model;
68-
mNumIter = numIter;
69-
mNumWarmupIter = numWarmupIter;
70-
mTokenizerPath = tokenizerPath;
71-
mTemperature = temperature;
72-
mPrompt = prompt;
73-
if (mPrompt == null) {
74-
mPrompt = "The ultimate answer";
75-
}
76-
mResult = new ArrayList<>();
77-
78-
mHandlerThread = new HandlerThread("ModelRunner");
79-
mHandlerThread.start();
80-
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
81-
82-
mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
83-
}
30+
File mModel;
31+
int mNumIter;
32+
int mNumWarmupIter;
33+
String mTokenizerPath;
34+
float mTemperature;
35+
String mPrompt;
8436

85-
void writeResult() {
86-
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
87-
Gson gson = new Gson();
88-
writer.write(gson.toJson(mResult));
89-
} catch (IOException e) {
90-
e.printStackTrace();
91-
} finally {
92-
finish();
93-
}
94-
}
95-
}
37+
HandlerThread mHandlerThread;
38+
BenchmarkHandler mHandler;
9639

97-
class BenchmarkHandler extends Handler {
98-
public static int MESSAGE_RUN_BENCHMARK = 1;
99-
public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
40+
List<BenchmarkMetric> mResult;
10041

101-
ModelRunner mModelRunner;
102-
BenchmarkActivity mBenchmarkActivity;
42+
@Override
43+
protected void onCreate(Bundle savedInstanceState) {
44+
super.onCreate(savedInstanceState);
10345

104-
LlmModelRunner mLlmModelRunner;
105-
LlmBenchmark mLlmBenchmark;
46+
try {
47+
Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
48+
} catch (ErrnoException e) {
49+
finish();
50+
}
10651

107-
public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
108-
super(looper);
109-
mModelRunner = new ModelRunner();
110-
mBenchmarkActivity = benchmarkActivity;
52+
Intent intent = getIntent();
53+
File modelDir = new File(intent.getStringExtra("model_dir"));
54+
File model =
55+
Arrays.stream(modelDir.listFiles())
56+
.filter(file -> file.getName().endsWith(".pte"))
57+
.findFirst()
58+
.get();
59+
60+
int numIter = intent.getIntExtra("num_iter", 50);
61+
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
62+
String tokenizerPath = intent.getStringExtra("tokenizer_path");
63+
float temperature = intent.getFloatExtra("temperature", 0.8f);
64+
String prompt = intent.getStringExtra("prompt");
65+
66+
mModel = model;
67+
mNumIter = numIter;
68+
mNumWarmupIter = numWarmupIter;
69+
mTokenizerPath = tokenizerPath;
70+
mTemperature = temperature;
71+
mPrompt = prompt;
72+
if (mPrompt == null) {
73+
mPrompt = "The ultimate answer";
74+
}
75+
mResult = new ArrayList<>();
76+
77+
mHandlerThread = new HandlerThread("ModelRunner");
78+
mHandlerThread.start();
79+
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
80+
81+
mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
82+
}
83+
84+
void writeResult() {
85+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
86+
Gson gson = new Gson();
87+
writer.write(gson.toJson(mResult));
88+
} catch (IOException e) {
89+
e.printStackTrace();
90+
} finally {
91+
finish();
11192
}
93+
}
94+
}
11295

113-
@Override
114-
public void handleMessage(android.os.Message msg) {
115-
if (msg.what == MESSAGE_RUN_BENCHMARK) {
116-
mModelRunner.runBenchmark(mBenchmarkActivity.mModel, mBenchmarkActivity.mNumWarmupIter, mBenchmarkActivity.mNumIter, mBenchmarkActivity.mResult);
117-
118-
if (mBenchmarkActivity.mTokenizerPath == null) {
119-
mBenchmarkActivity.writeResult();
120-
} else {
121-
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
122-
}
123-
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
124-
mLlmBenchmark = new LlmBenchmark(mBenchmarkActivity, mBenchmarkActivity.mModel.getPath(), mBenchmarkActivity.mTokenizerPath, mBenchmarkActivity.mPrompt, mBenchmarkActivity.mTemperature, mBenchmarkActivity.mResult);
125-
}
96+
class BenchmarkHandler extends Handler {
97+
public static int MESSAGE_RUN_BENCHMARK = 1;
98+
public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
99+
100+
ModelRunner mModelRunner;
101+
BenchmarkActivity mBenchmarkActivity;
102+
103+
LlmModelRunner mLlmModelRunner;
104+
LlmBenchmark mLlmBenchmark;
105+
106+
public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
107+
super(looper);
108+
mModelRunner = new ModelRunner();
109+
mBenchmarkActivity = benchmarkActivity;
110+
}
111+
112+
@Override
113+
public void handleMessage(android.os.Message msg) {
114+
if (msg.what == MESSAGE_RUN_BENCHMARK) {
115+
mModelRunner.runBenchmark(
116+
mBenchmarkActivity.mModel,
117+
mBenchmarkActivity.mNumWarmupIter,
118+
mBenchmarkActivity.mNumIter,
119+
mBenchmarkActivity.mResult);
120+
121+
if (mBenchmarkActivity.mTokenizerPath == null) {
122+
mBenchmarkActivity.writeResult();
123+
} else {
124+
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
125+
}
126+
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
127+
mLlmBenchmark =
128+
new LlmBenchmark(
129+
mBenchmarkActivity,
130+
mBenchmarkActivity.mModel.getName(),
131+
mBenchmarkActivity.mTokenizerPath,
132+
mBenchmarkActivity.mPrompt,
133+
mBenchmarkActivity.mTemperature,
134+
mBenchmarkActivity.mResult);
126135
}
136+
}
127137
}

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,66 +10,65 @@
1010

1111
import android.app.ActivityManager;
1212
import android.os.Build;
13-
1413
import java.util.regex.Matcher;
1514
import java.util.regex.Pattern;
1615

1716
class BenchmarkMetric {
18-
public static class BenchmarkModel {
19-
// The model name, i.e. stories110M
20-
String name;
21-
String backend;
22-
String quantization;
17+
public static class BenchmarkModel {
18+
// The model name, i.e. stories110M
19+
String name;
20+
String backend;
21+
String quantization;
2322

24-
public BenchmarkModel(final String name, final String backend, final String quantization) {
25-
this.name = name;
26-
this.backend = backend;
27-
this.quantization = quantization;
28-
}
23+
public BenchmarkModel(final String name, final String backend, final String quantization) {
24+
this.name = name;
25+
this.backend = backend;
26+
this.quantization = quantization;
2927
}
28+
}
3029

31-
BenchmarkModel benchmarkModel;
30+
BenchmarkModel benchmarkModel;
3231

33-
// The metric name, i.e. TPS
34-
String metric;
32+
// The metric name, i.e. TPS
33+
String metric;
3534

36-
// The actual value and the option target value
37-
double actualValue;
38-
double targetValue;
35+
// The actual value and the option target value
36+
double actualValue;
37+
double targetValue;
3938

40-
public static class DeviceInfo {
41-
// Let's see which information we want to include here
42-
final String device = Build.BRAND;
43-
// The phone model and Android release version
44-
final String arch = Build.MODEL;
45-
final String os = "Android " + Build.VERSION.RELEASE;
46-
final long totalMem = new ActivityManager.MemoryInfo().totalMem;
47-
final long availMem = new ActivityManager.MemoryInfo().availMem;
48-
}
39+
public static class DeviceInfo {
40+
// Let's see which information we want to include here
41+
final String device = Build.BRAND;
42+
// The phone model and Android release version
43+
final String arch = Build.MODEL;
44+
final String os = "Android " + Build.VERSION.RELEASE;
45+
final long totalMem = new ActivityManager.MemoryInfo().totalMem;
46+
final long availMem = new ActivityManager.MemoryInfo().availMem;
47+
}
4948

50-
DeviceInfo deviceInfo = new DeviceInfo();
49+
DeviceInfo deviceInfo = new DeviceInfo();
5150

52-
public BenchmarkMetric(
53-
final BenchmarkModel benchmarkModel,
54-
final String metric,
55-
final double actualValue,
56-
final double targetValue) {
57-
this.benchmarkModel = benchmarkModel;
58-
this.metric = metric;
59-
this.actualValue = actualValue;
60-
this.targetValue = targetValue;
61-
}
51+
public BenchmarkMetric(
52+
final BenchmarkModel benchmarkModel,
53+
final String metric,
54+
final double actualValue,
55+
final double targetValue) {
56+
this.benchmarkModel = benchmarkModel;
57+
this.metric = metric;
58+
this.actualValue = actualValue;
59+
this.targetValue = targetValue;
60+
}
6261

63-
// TODO (huydhn): Figure out a way to extract the backend and quantization information from
64-
// the .pte model itself instead of parsing its name
65-
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
66-
final Matcher m =
67-
Pattern.compile("(?<name>\\w+)_(?<backend>[\\w\\+]+)_(?<quantization>\\w+)").matcher(model);
68-
if (m.matches()) {
69-
return new BenchmarkMetric.BenchmarkModel(
70-
m.group("name"), m.group("backend"), m.group("quantization"));
71-
} else {
72-
return new BenchmarkMetric.BenchmarkModel(model, "", "");
73-
}
62+
// TODO (huydhn): Figure out a way to extract the backend and quantization information from
63+
// the .pte model itself instead of parsing its name
64+
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
65+
final Matcher m =
66+
Pattern.compile("(?<name>\\w+)_(?<backend>[\\w\\+]+)_(?<quantization>\\w+)").matcher(model);
67+
if (m.matches()) {
68+
return new BenchmarkMetric.BenchmarkModel(
69+
m.group("name"), m.group("backend"), m.group("quantization"));
70+
} else {
71+
return new BenchmarkMetric.BenchmarkModel(model, "", "");
7472
}
73+
}
7574
}

0 commit comments

Comments
 (0)