10
10
11
11
import android .app .Activity ;
12
12
import android .content .Intent ;
13
+ import android .os .AsyncTask ;
13
14
import android .os .Bundle ;
14
- import android .os .Handler ;
15
- import android .os .HandlerThread ;
16
- import android .os .Looper ;
15
+ import android .os .Debug ;
17
16
import android .system .ErrnoException ;
18
17
import android .system .Os ;
19
18
import com .google .gson .Gson ;
22
21
import java .io .IOException ;
23
22
import java .util .ArrayList ;
24
23
import java .util .Arrays ;
24
+ import java .util .Collections ;
25
25
import java .util .List ;
26
+ import java .util .stream .Collectors ;
27
+ import org .pytorch .executorch .Module ;
26
28
27
29
public class BenchmarkActivity extends Activity {
28
-
29
- File mModel ;
30
- int mNumIter ;
31
- int mNumWarmupIter ;
32
- String mTokenizerPath ;
33
- float mTemperature ;
34
- String mPrompt ;
35
-
36
- HandlerThread mHandlerThread ;
37
- BenchmarkHandler mHandler ;
38
-
39
- List <BenchmarkMetric > mResult ;
40
-
41
30
@ Override
42
31
protected void onCreate (Bundle savedInstanceState ) {
43
32
super .onCreate (savedInstanceState );
@@ -58,79 +47,95 @@ protected void onCreate(Bundle savedInstanceState) {
58
47
59
48
int numIter = intent .getIntExtra ("num_iter" , 50 );
60
49
int numWarmupIter = intent .getIntExtra ("num_warm_up_iter" , 10 );
61
- String tokenizerPath = intent .getStringExtra ("tokenizer_path" );
62
- float temperature = intent .getFloatExtra ("temperature" , 0.8f );
63
- String prompt = intent .getStringExtra ("prompt" );
64
-
65
- mModel = model ;
66
- mNumIter = numIter ;
67
- mNumWarmupIter = numWarmupIter ;
68
- mTokenizerPath = tokenizerPath ;
69
- mTemperature = temperature ;
70
- mPrompt = prompt ;
71
- if (mPrompt == null ) {
72
- mPrompt = "The ultimate answer" ;
73
- }
74
- mResult = new ArrayList <>();
75
50
76
- mHandlerThread = new HandlerThread ("ModelRunner" );
77
- mHandlerThread .start ();
78
- mHandler = new BenchmarkHandler (mHandlerThread .getLooper (), this );
51
+ long pssIdle = Debug .getPss ();
79
52
80
- mHandler . sendEmptyMessage ( BenchmarkHandler . MESSAGE_RUN_BENCHMARK );
81
- }
53
+ // TODO: Format the string with a parsable format
54
+ Stats stats = new Stats ();
82
55
83
- void writeResult () {
84
- try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
85
- Gson gson = new Gson ();
86
- writer .write (gson .toJson (mResult ));
87
- } catch (IOException e ) {
88
- e .printStackTrace ();
89
- } finally {
90
- finish ();
91
- }
92
- }
93
- }
56
+ new AsyncTask <Void , Void , Void >() {
57
+ @ Override
58
+ protected Void doInBackground (Void ... voids ) {
94
59
95
- class BenchmarkHandler extends Handler {
96
- public static int MESSAGE_RUN_BENCHMARK = 1 ;
97
- public static int MESSAGE_LLM_RUN_BENCHMARK = 2 ;
60
+ // Record the time it takes to load the model and the forward method
61
+ stats .loadStart = System .nanoTime ();
62
+ Module module = Module .load (model .getPath ());
63
+ stats .errorCode = module .loadMethod ("forward" );
64
+ stats .loadEnd = System .nanoTime ();
98
65
99
- ModelRunner mModelRunner ;
100
- BenchmarkActivity mBenchmarkActivity ;
66
+ for (int i = 0 ; i < numWarmupIter ; i ++) {
67
+ module .forward ();
68
+ }
101
69
102
- LlmModelRunner mLlmModelRunner ;
103
- LlmBenchmark mLlmBenchmark ;
70
+ for (int i = 0 ; i < numIter ; i ++) {
71
+ long start = System .nanoTime ();
72
+ module .forward ();
73
+ double forwardMs = (System .nanoTime () - start ) * 1e-6 ;
74
+ stats .latency .add (forwardMs );
75
+ }
76
+ return null ;
77
+ }
104
78
105
- public BenchmarkHandler (Looper looper , BenchmarkActivity benchmarkActivity ) {
106
- super (looper );
107
- mModelRunner = new ModelRunner ();
108
- mBenchmarkActivity = benchmarkActivity ;
79
+ @ Override
80
+ protected void onPostExecute (Void aVoid ) {
81
+
82
+ final BenchmarkMetric .BenchmarkModel benchmarkModel =
83
+ BenchmarkMetric .extractBackendAndQuantization (model .getName ().replace (".pte" , "" ));
84
+ final List <BenchmarkMetric > results = new ArrayList <>();
85
+ // The list of metrics we have atm includes:
86
+ // Avg inference latency after N iterations
87
+ // Currently the result has large variance from outliers, so only use
88
+ // 80% samples in the middle (trimmean 0.2)
89
+ Collections .sort (stats .latency );
90
+ int resultSize = stats .latency .size ();
91
+ List <Double > usedLatencyResults =
92
+ stats .latency .subList (resultSize / 10 , resultSize * 9 / 10 );
93
+
94
+ results .add (
95
+ new BenchmarkMetric (
96
+ benchmarkModel ,
97
+ "avg_inference_latency(ms)" ,
98
+ stats .latency .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
99
+ 0.0f ));
100
+ results .add (
101
+ new BenchmarkMetric (
102
+ benchmarkModel ,
103
+ "trimmean_inference_latency(ms)" ,
104
+ usedLatencyResults .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
105
+ 0.0f ));
106
+ // Model load time
107
+ results .add (
108
+ new BenchmarkMetric (
109
+ benchmarkModel ,
110
+ "model_load_time(ms)" ,
111
+ (stats .loadEnd - stats .loadStart ) * 1e-6 ,
112
+ 0.0f ));
113
+ // Load status
114
+ results .add (new BenchmarkMetric (benchmarkModel , "load_status" , stats .errorCode , 0 ));
115
+ // RAM PSS usage
116
+ results .add (
117
+ new BenchmarkMetric (
118
+ benchmarkModel , "ram_pss_usage(mb)" , (Debug .getPss () - pssIdle ) / 1024 , 0 ));
119
+
120
+ try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
121
+ Gson gson = new Gson ();
122
+ writer .write (gson .toJson (results ));
123
+ } catch (IOException e ) {
124
+ e .printStackTrace ();
125
+ }
126
+ }
127
+ }.execute ();
109
128
}
129
+ }
130
+
131
/**
 * Mutable accumulator for one benchmark run.
 *
 * <p>Populated by the benchmark task: {@code loadStart}/{@code loadEnd} bracket
 * {@code Module.load(...)} + {@code loadMethod("forward")} via {@code System.nanoTime()},
 * {@code latency} collects per-iteration {@code forward()} times in milliseconds, and
 * {@code errorCode} holds the result of {@code loadMethod("forward")} (0 = success).
 * Fields are intentionally package-visible; the activity reads them directly.
 */
class Stats {
  // System.nanoTime() immediately before model load begins.
  long loadStart;
  // System.nanoTime() immediately after loadMethod("forward") returns.
  long loadEnd;
  // Per-iteration forward() latency samples, in milliseconds.
  List<Double> latency = new ArrayList<>();
  // Return code of Module.loadMethod("forward"); 0 means success.
  int errorCode = 0;

  @Override
  public String toString() {
    // Join with ", " so individual samples remain distinguishable. The previous
    // empty-string joiner concatenated the doubles into one unparsable run
    // (e.g. "1.52.25"), which the in-code TODO about a parsable format flagged.
    return "latency: "
        + latency.stream().map(Object::toString).collect(Collectors.joining(", "));
  }
}
0 commit comments