diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 06cae3475e8..3b7a04c2e97 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch) find_package(executorch CONFIG REQUIRED) target_link_options_shared_lib(executorch) -add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp) +add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp) set(link_libraries) list( diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java new file mode 100644 index 00000000000..27114b4cc77 --- /dev/null +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import static org.junit.Assert.assertNotNull; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import org.junit.runner.RunWith; +import org.junit.Test; + +/** Unit tests for {@link ExecuTorchRuntime}. */ +@RunWith(AndroidJUnit4.class) +public class RuntimeInstrumentationTest { + + @Test + public void testRuntimeApi() { + String[] ops = ExecuTorchRuntime.getRegisteredOps(); + String[] backends = ExecuTorchRuntime.getRegisteredBackends(); + + assertNotNull(ops); + assertNotNull(backends); + + for (String op : ops) { + assertNotNull(op); + } + + for (String backend : backends) { + assertNotNull(backend); + } + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 14ab77a3a70..8e2f259ef3a 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -8,6 +8,7 @@ package org.pytorch.executorch; +import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; @@ -30,4 +31,12 @@ private ExecuTorchRuntime() {} public static ExecuTorchRuntime getRuntime() { return sInstance; } + + /** Get all registered ops. */ + @DoNotStrip + public static native String[] getRegisteredOps(); + + /** Get all registered backends. */ + @DoNotStrip + public static native String[] getRegisteredBackends(); } diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 6fe03a58891..9ffe0525707 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -28,7 +28,7 @@ non_fbcode_target(_kind = executorch_generated_lib, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_jni", - srcs = ["jni_layer.cpp", "log.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS, soname = "libexecutorch.$(ext)", @@ -49,7 +49,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_jni_full", - srcs = ["jni_layer.cpp", "log.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS, soname = "libexecutorch.$(ext)", @@ -74,6 +74,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, srcs = [ "jni_layer.cpp", "jni_layer_llama.cpp", + "jni_layer_runtime.cpp", ], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ @@ -113,6 +114,10 @@ runtime.export_file( name = "jni_layer.cpp", ) +runtime.export_file( + name = "jni_layer_runtime.cpp", +) + runtime.cxx_library( name = "jni_headers", exported_headers = [ diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index bbe47e98a06..c3ffe77a0cb 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -491,9 +491,11 @@ extern void register_natives_for_llm(); // No op if we don't build LLM void register_natives_for_llm() {} #endif +extern void register_natives_for_runtime(); JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); register_natives_for_llm(); + register_natives_for_runtime(); }); } diff --git a/extension/android/jni/jni_layer_runtime.cpp b/extension/android/jni/jni_layer_runtime.cpp new file mode 100644 index 00000000000..890e1d0fad9 --- /dev/null +++ b/extension/android/jni/jni_layer_runtime.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +namespace executorch_jni { +namespace runtime = ::executorch::ET_RUNTIME_NAMESPACE; + +class AndroidRuntimeJni : public facebook::jni::JavaClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/ExecuTorchRuntime;"; + + static void registerNatives() { + javaClassStatic()->registerNatives({ + makeNativeMethod( + "getRegisteredOps", AndroidRuntimeJni::getRegisteredOps), + makeNativeMethod( + "getRegisteredBackends", AndroidRuntimeJni::getRegisteredBackends), + }); + } + + // Returns a string array of all registered ops + static facebook::jni::local_ref> + getRegisteredOps(facebook::jni::alias_ref) { + auto kernels = runtime::get_registered_kernels(); + auto result = facebook::jni::JArrayClass::newArray(kernels.size()); + + for (size_t i = 0; i < kernels.size(); ++i) { + auto op = facebook::jni::make_jstring(kernels[i].name_); + result->setElement(i, op.get()); + } + + return result; + } + + // Returns a string array of all registered backends + static facebook::jni::local_ref> + getRegisteredBackends(facebook::jni::alias_ref) { + int num_backends = runtime::get_num_registered_backends(); + auto result = facebook::jni::JArrayClass::newArray(num_backends); + + for (int i = 0; i < num_backends; ++i) { + auto name_result = runtime::get_backend_name(i); + const char* name = ""; + + if (name_result.ok()) { + name = *name_result; + } + + auto backend_str = facebook::jni::make_jstring(name); + result->setElement(i, backend_str.get()); + } + + return result; + } +}; + +} // namespace executorch_jni + +void register_natives_for_runtime() { + executorch_jni::AndroidRuntimeJni::registerNatives(); +} diff --git a/extension/android/jni/selective_jni.buck.bzl b/extension/android/jni/selective_jni.buck.bzl index 1a921e6ef1e..d557606b7d1 100644 --- a/extension/android/jni/selective_jni.buck.bzl +++ b/extension/android/jni/selective_jni.buck.bzl @@ -4,10 +4,12 @@ load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs. load("@fbsource//xplat/executorch/extension/android/jni:build_defs.bzl", "ET_JNI_COMPILER_FLAGS") def selective_jni_target(name, deps, srcs = [], soname = "libexecutorch.$(ext)"): - non_fbcode_target(_kind = fb_android_cxx_library, + non_fbcode_target( + _kind = fb_android_cxx_library, name = name, srcs = [ "//xplat/executorch/extension/android/jni:jni_layer.cpp", + "//xplat/executorch/extension/android/jni:jni_layer_runtime.cpp", ] + srcs, allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS,