From 6185e424dc0708fe8ff74aafd1a1a68ca35291e5 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Wed, 2 Apr 2025 06:31:16 -0700 Subject: [PATCH] (WIP) Make Android Module thread-safe and prevent destruction during inference (#9833) Summary: While the Android Module interface was originally not designed to be thread-safe, we've seen a sizable number of issues pop up due to users not fully meeting the thread-safety requirements that we impose on the caller. Empirically, these requirements are not always obvious when writing app code, and violations can sneak in subtly. Common issues are calling forward from a different thread while one inference is already in progress, and not synchronizing module cleanup with inference. Both have caused crashes that are sometimes difficult for users to debug. This PR attempts to mitigate these issues by adding explicit synchronization in the Java Module class. Method loading and execution are both guarded by a lock, and destroy will warn and avoid immediate destruction if an inference is in progress. I'm hesitant to directly acquire the lock in destroy, since it can get called in certain cleanup paths. Instead, I'm just warning and setting the native peer to null so that it can be garbage-collected once it is no longer in use. 
Differential Revision: D72273052 --- .../executorch/ModuleInstrumentationTest.java | 62 +++++++++++++++++++ .../java/org/pytorch/executorch/Module.java | 54 ++++++++++++++-- 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java index a25c0bf6343..f71351ae6ae 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java @@ -25,6 +25,8 @@ import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; import java.io.IOException; import java.io.File; import java.io.FileOutputStream; @@ -42,6 +44,7 @@ public class ModuleInstrumentationTest { private static String FORWARD_METHOD = "forward"; private static String NONE_METHOD = "none"; private static int OK = 0x00; + private static int INVALID_STATE = 0x2; private static int INVALID_ARGUMENT = 0x12; private static int ACCESS_FAILED = 0x22; @@ -124,4 +127,63 @@ public void testNonPteFile() throws IOException{ int loadMethod = module.loadMethod(FORWARD_METHOD); assertEquals(loadMethod, INVALID_ARGUMENT); } + + @Test + public void testLoadOnDestroyedModule() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + module.destroy(); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, INVALID_STATE); + } + + @Test + public void testForwardOnDestroyedModule() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, OK); + + 
module.destroy(); + + EValue[] results = module.forward(); + assertEquals(0, results.length); + } + + @Test + public void testForwardFromMultipleThreads() throws InterruptedException, IOException { + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int numThreads = 100; + CountDownLatch latch = new CountDownLatch(numThreads); + AtomicInteger completed = new AtomicInteger(0); + + Runnable runnable = new Runnable() { + @Override + public void run() { + try { + latch.countDown(); + latch.await(5000, java.util.concurrent.TimeUnit.MILLISECONDS); + EValue[] results = module.forward(); + assertTrue(results[0].isTensor()); + completed.incrementAndGet(); + } catch (InterruptedException e) { + + } + } + }; + + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < numThreads; i++) { + threads[i] = new Thread(runnable); + threads[i].start(); + } + + for (int i = 0; i < numThreads; i++) { + threads[i].join(); + } + + assertEquals(numThreads, completed.get()); + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 879b88c5f2f..f3f543dc2a8 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -8,8 +8,11 @@ package org.pytorch.executorch; +import android.util.Log; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; /** @@ -35,6 +38,9 @@ public class Module { /** Reference to the NativePeer object of this module. */ private NativePeer mNativePeer; + /** Lock protecting the non-thread safe methods in NativePeer. 
*/ + private Lock mLock = new ReentrantLock(); + /** * Loads a serialized ExecuTorch module from the specified path on the disk. * @@ -72,7 +78,16 @@ public static Module load(final String modelPath) { * @return return value from the 'forward' method. */ public EValue[] forward(EValue... inputs) { - return mNativePeer.forward(inputs); + try { + mLock.lock(); + if (mNativePeer == null) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } + return mNativePeer.forward(inputs); + } finally { + mLock.unlock(); + } } /** @@ -83,7 +98,16 @@ public EValue[] forward(EValue... inputs) { * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { - return mNativePeer.execute(methodName, inputs); + try { + mLock.lock(); + if (mNativePeer == null) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } + return mNativePeer.execute(methodName, inputs); + } finally { + mLock.unlock(); + } } /** @@ -96,7 +120,16 @@ public EValue[] execute(String methodName, EValue... inputs) { * @return the Error code if there was an error loading the method */ public int loadMethod(String methodName) { - return mNativePeer.loadMethod(methodName); + try { + mLock.lock(); + if (mNativePeer == null) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return 0x2; // InvalidState + } + return mNativePeer.loadMethod(methodName); + } finally { + mLock.unlock(); + } } /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ @@ -111,6 +144,19 @@ public String[] readLogBuffer() { * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. */ public void destroy() { - mNativePeer.resetNative(); + if (mLock.tryLock()) { + try { + mNativePeer.resetNative(); + } finally { + mNativePeer = null; + mLock.unlock(); + } + } else { + mNativePeer = null; + Log.w( + "ExecuTorch", + "Destroy was called while the module was in use. 
Resources will not be immediately" + + " released."); + } } }