Skip to content

Commit 2a1db35

Browse files
GregoryComer authored and facebook-github-bot committed
(WIP) Make Android Module thread-safe and prevent destruction during inference (#9833)
Summary: While the Android Module interface was originally not designed to be thread safe, we've seen a sizable number of issues pop up due to users not fully meeting the thread safety requirements that we impose on the caller. Empirically, this is not always obvious when writing app code and can sneak in in subtle ways. Common issues are calling forward from a different thread while one inference is already in progress and not synchronizing module cleanup with inference. Both have caused crashes that are sometimes difficult for users to debug. This PR attempts to mitigate these issues by adding explicit synchronization in the Java Module class. Both method load and execution are behind a lock, and destroy will warn and avoid immediate destruction if an inference is in progress. I'm hesitant to directly acquire the lock in destroy, since it can get called in certain cleanup paths. Instead, I'm just warning and setting the native peer to null so it should get GC'd once out of use. Differential Revision: D72273052
1 parent 376d66a commit 2a1db35

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.io.InputStream;
2626
import java.net.URI;
2727
import java.net.URISyntaxException;
28+
import java.util.concurrent.CountDownLatch;
29+
import java.util.concurrent.atomic.AtomicInteger;
2830
import java.io.IOException;
2931
import java.io.File;
3032
import java.io.FileOutputStream;
@@ -42,6 +44,7 @@ public class ModuleInstrumentationTest {
4244
private static String FORWARD_METHOD = "forward";
4345
private static String NONE_METHOD = "none";
4446
private static int OK = 0x00;
47+
private static int INVALID_STATE = 0x2;
4548
private static int INVALID_ARGUMENT = 0x12;
4649
private static int ACCESS_FAILED = 0x22;
4750

@@ -124,4 +127,59 @@ public void testNonPteFile() throws IOException{
124127
int loadMethod = module.loadMethod(FORWARD_METHOD);
125128
assertEquals(loadMethod, INVALID_ARGUMENT);
126129
}
130+
131+
@Test
132+
public void testLoadOnDestroyedModule() throws IOException{
133+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
134+
135+
module.destroy();
136+
137+
int loadMethod = module.loadMethod(FORWARD_METHOD);
138+
assertEquals(loadMethod, INVALID_STATE);
139+
}
140+
141+
@Test
142+
public void testForwardOnDestroyedModule() throws IOException{
143+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
144+
145+
int loadMethod = module.loadMethod(FORWARD_METHOD);
146+
assertEquals(loadMethod, OK);
147+
148+
module.destroy();
149+
150+
EValue[] results = module.forward();
151+
assertEquals(0, results.length);
152+
}
153+
154+
@Test
155+
public void testForwardFromMultipleThreads() throws IOException {
156+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
157+
158+
int numThreads = 100;
159+
CountDownLatch latch = new CountDownLatch(numThreads);
160+
AtomicInteger completed = new AtomicInteger(0);
161+
162+
Runnable runnable = new Runnable() {
163+
@Override
164+
public void run() {
165+
latch.countDown();
166+
latch.await(5000, java.util.concurrent.TimeUnit.MILLISECONDS);
167+
EValue[] results = module.forward();
168+
assertTrue(results[0].isTensor());
169+
completed.incrementAndGet();
170+
}
171+
};
172+
173+
Thread[] threads = new Thread[numThreads];
174+
for (int i = 0; i < numThreads; i++) {
175+
threads[i] = new Thread(runnable);
176+
thrads[i].start();
177+
}
178+
179+
for (int i = 0; i < numThreads; i++) {
180+
threads[i].join();
181+
}
182+
183+
assertEquals(numThreads, completed.get());
184+
}
127185
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010

1111
import com.facebook.soloader.nativeloader.NativeLoader;
1212
import com.facebook.soloader.nativeloader.SystemDelegate;
13+
14+
import android.util.Log;
15+
16+
import java.util.concurrent.locks.Lock;
17+
import java.util.concurrent.locks.ReentrantLock;
18+
1319
import org.pytorch.executorch.annotations.Experimental;
1420

1521
/**
@@ -34,6 +40,9 @@ public class Module {
3440

3541
/** Reference to the NativePeer object of this module. */
3642
private NativePeer mNativePeer;
43+
44+
/** Lock protecting the non-thread safe methods in NativePeer. */
45+
private Lock mLock = new ReentrantLock();
3746

3847
/**
3948
* Loads a serialized ExecuTorch module from the specified path on the disk.
@@ -72,7 +81,16 @@ public static Module load(final String modelPath) {
7281
* @return return value from the 'forward' method.
7382
*/
7483
public EValue[] forward(EValue... inputs) {
75-
return mNativePeer.forward(inputs);
84+
try {
85+
mLock.lock();
86+
if (mNativePeer == null) {
87+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
88+
return new EValue[0];
89+
}
90+
return mNativePeer.forward(inputs);
91+
} finally {
92+
mLock.unlock();
93+
}
7694
}
7795

7896
/**
@@ -83,7 +101,16 @@ public EValue[] forward(EValue... inputs) {
83101
* @return return value from the method.
84102
*/
85103
public EValue[] execute(String methodName, EValue... inputs) {
86-
return mNativePeer.execute(methodName, inputs);
104+
try {
105+
mLock.lock();
106+
if (mNativePeer == null) {
107+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
108+
return new EValue[0];
109+
}
110+
return mNativePeer.execute(methodName, inputs);
111+
} finally {
112+
mLock.unlock();
113+
}
87114
}
88115

89116
/**
@@ -96,7 +123,16 @@ public EValue[] execute(String methodName, EValue... inputs) {
96123
* @return the Error code if there was an error loading the method
97124
*/
98125
public int loadMethod(String methodName) {
99-
return mNativePeer.loadMethod(methodName);
126+
try {
127+
mLock.lock();
128+
if (mNativePeer == null) {
129+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
130+
return 0x2; // InvalidState
131+
}
132+
return mNativePeer.loadMethod(methodName);
133+
} finally {
134+
mLock.unlock();
135+
}
100136
}
101137

102138
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
@@ -111,6 +147,16 @@ public String[] readLogBuffer() {
111147
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
112148
*/
113149
public void destroy() {
114-
mNativePeer.resetNative();
150+
if (mLock.tryLock()) {
151+
try {
152+
mNativePeer.resetNative();
153+
mNativePeer = null;
154+
} finally {
155+
mLock.unlock();
156+
}
157+
} else {
158+
mNativePeer = null;
159+
Log.w("ExecuTorch", "Destroy was called while the module was in use. Resources will not be immediately released.");
160+
}
115161
}
116162
}

0 commit comments

Comments
 (0)