From 0d38a993c4edbce8da958942360a0e03dcd4d0fa Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 21 May 2025 18:03:08 -0700 Subject: [PATCH 1/6] input and output tag API --- .../org/pytorch/executorch/ModuleE2ETest.kt | 3 +- .../executorch/ModuleInstrumentationTest.kt | 2 + .../pytorch/executorch/MethodMetadata.java | 27 ++++++++++++- .../java/org/pytorch/executorch/Module.java | 14 ++++++- extension/android/jni/jni_layer.cpp | 39 +++++++++++++++++-- 5 files changed, 78 insertions(+), 7 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 381b1bd99d1..1ef9e9e035c 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -68,9 +68,8 @@ class ModuleE2ETest { inputStream.close() val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) - val expectedBackends = arrayOf("XnnpackBackend") Assert.assertArrayEquals( - expectedBackends, + arrayOf("XnnpackBackend"), module.getMethodMetadata("forward").getBackends(), ) } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 84dc064bf26..50d469e606d 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -62,6 +62,8 @@ class ModuleInstrumentationTest { Assert.assertArrayEquals(arrayOf("forward"), module.getMethods()) Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty()) + Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags) + Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags) } @Test diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java index b2dde35a2d8..de2ae14fb19 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java @@ -11,8 +11,9 @@ /** Helper class to access the metadata for a method from a Module */ public class MethodMetadata { private String mName; - private String[] mBackends; + private int[] mInputTags; + private int[] mOutputTags; MethodMetadata setName(String name) { mName = name; @@ -37,4 +38,28 @@ MethodMetadata setBackends(String[] backends) { public String[] getBackends() { return mBackends; } + + /** + * @return Output tags + */ + public int[] getOutputTags() { + return mOutputTags; + } + + MethodMetadata setOutputTags(int[] outputTags) { + mOutputTags = outputTags; + return this; + } + + /** + * @return Input tags + */ + public int[] getInputTags() { + return mInputTags; + } + + MethodMetadata setInputTags(int[] inputTags) { + mInputTags = inputTags; + return this; + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 4e06d3cbc79..73dd64b3a74 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -59,7 +59,13 @@ Map populateMethodMeta() { Map metadata = new HashMap(); for (int i = 0; i < methods.length; i++) { String name = methods[i]; - metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name))); + metadata.put( + name, + new MethodMetadata() + .setName(name) + .setBackends(getUsedBackends(name)) + .setInputTags(getInputTags(name)) + .setOutputTags(getOutputTags(name))); } return metadata; @@ -204,6 +210,12 @@ public String[] readLogBuffer() { @DoNotStrip private native String[] readLogBufferNative(); + @DoNotStrip + private native int[] getInputTags(String method); + + @DoNotStrip + private native int[] getOutputTags(String method); + /** * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. * diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index bbe47e98a06..515acf16caa 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -453,10 +453,11 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref> getUsedBackends( facebook::jni::alias_ref methodName) { - auto methodMeta = module_->method_meta(methodName->toStdString()).get(); + auto method_meta = + module_->method_meta(methodName->toStdString()).get(); std::unordered_set backends; - for (auto i = 0; i < methodMeta.num_backends(); i++) { - backends.insert(methodMeta.get_backend_name(i).get()); + for (auto i = 0; i < method_meta.num_backends(); i++) { + backends.insert(method_meta.get_backend_name(i).get()); } facebook::jni::local_ref> ret = @@ -471,6 +472,36 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return ret; } + facebook::jni::local_ref getInputTags( + facebook::jni::alias_ref methodName) { + auto method_meta = + module_->method_meta(methodName->toStdString()).get(); + auto num_inputs = method_meta.num_inputs(); + facebook::jni::local_ref ret = + facebook::jni::JArrayInt::newArray(num_inputs); + + int i = 0; + for (int i = 0; i < num_inputs; i++) { + ret->pin()[i] = static_cast(method_meta.input_tag(i).get()); + } + return ret; + } + + facebook::jni::local_ref getOutputTags( + facebook::jni::alias_ref methodName) { + auto method_meta = + module_->method_meta(methodName->toStdString()).get(); + auto num_outputs = method_meta.num_outputs(); + facebook::jni::local_ref ret = + facebook::jni::JArrayInt::newArray(num_outputs); + + int i = 0; + for (int i = 0; i < num_outputs; i++) { + ret->pin()[i] = static_cast(method_meta.output_tag(i).get()); + } + return ret; + } + static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), @@ -480,6 +511,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("etdump", ExecuTorchJni::etdump), makeNativeMethod("getMethods", ExecuTorchJni::getMethods), makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), + makeNativeMethod("getInputTags", ExecuTorchJni::getInputTags), + makeNativeMethod("getOutputTags", ExecuTorchJni::getOutputTags), }); } }; From d9563922ecc3d8e4930f2ec460f0458e90a9fac6 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 21 May 2025 18:05:11 -0700 Subject: [PATCH 2/6] Fix --- .../executorch/LlmModuleInstrumentationTest.kt | 9 +++++---- .../java/org/pytorch/executorch/ModuleE2ETest.kt | 6 +++--- .../executorch/ModuleInstrumentationTest.kt | 15 ++++++++------- .../org/pytorch/executorch/TensorImageUtils.kt | 12 +++++++----- extension/android/jni/jni_layer.cpp | 2 -- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 43ce302a7a6..0baa1d5191c 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -11,9 +11,6 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.json.JSONException import org.json.JSONObject @@ -24,6 +21,9 @@ import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.extension.llm.LlmCallback import org.pytorch.executorch.extension.llm.LlmModule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException /** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */ @RunWith(AndroidJUnit4::class) @@ -101,7 +101,8 @@ class LlmModuleInstrumentationTest : LlmCallback { val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 tokensPerSecond.add(tps) - } catch (_: JSONException) {} + } catch (_: JSONException) { + } } companion object { diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 1ef9e9e035c..9ba97bdc3ac 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -13,15 +13,15 @@ import android.graphics.BitmapFactory import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor +import java.io.File +import java.io.IOException +import java.net.URISyntaxException /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 50d469e606d..2468a9d2ff2 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -11,18 +11,18 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith +import java.io.File +import java.io.IOException +import java.net.URISyntaxException +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) @@ -152,7 +152,8 @@ class ModuleInstrumentationTest { val results = module.forward() Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() - } catch (_: InterruptedException) {} + } catch (_: InterruptedException) { + } } val threads = arrayOfNulls(numThreads) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt index cb2e365a4c5..dd3f5a1dfd0 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt @@ -16,9 +16,11 @@ import java.nio.FloatBuffer * [android.media.Image] source. */ object TensorImageUtils { - @JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) + @JvmField + var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) - @JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) + @JvmField + var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) /** * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in @@ -144,9 +146,9 @@ object TensorImageUtils { private fun checkRotateCWDegrees(rotateCWDegrees: Int) { require( !(rotateCWDegrees != 0 && - rotateCWDegrees != 90 && - rotateCWDegrees != 180 && - rotateCWDegrees != 270) + rotateCWDegrees != 90 && + rotateCWDegrees != 180 && + rotateCWDegrees != 270) ) { "rotateCWDegrees must be one of 0, 90, 180, 270" } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 515acf16caa..5b6c570cad6 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -480,7 +480,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref ret = facebook::jni::JArrayInt::newArray(num_inputs); - int i = 0; for (int i = 0; i < num_inputs; i++) { ret->pin()[i] = static_cast(method_meta.input_tag(i).get()); } @@ -495,7 +494,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref ret = facebook::jni::JArrayInt::newArray(num_outputs); - int i = 0; for (int i = 0; i < num_outputs; i++) { ret->pin()[i] = static_cast(method_meta.output_tag(i).get()); } From 0dc2ba5cc5f6374c087fb553cd6a9ec091a64eea Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 22 May 2025 14:00:21 -0700 Subject: [PATCH 3/6] Linter --- extension/android/jni/jni_layer.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 2ecd1ad2ed9..8f136902da4 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -453,8 +453,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref> getUsedBackends( facebook::jni::alias_ref methodName) { - auto method_meta = - module_->method_meta(methodName->toStdString()).get(); + auto method_meta = module_->method_meta(methodName->toStdString()).get(); std::unordered_set backends; for (auto i = 0; i < method_meta.num_backends(); i++) { backends.insert(method_meta.get_backend_name(i).get()); @@ -474,8 +473,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref getInputTags( facebook::jni::alias_ref methodName) { - auto method_meta = - module_->method_meta(methodName->toStdString()).get(); + auto method_meta = module_->method_meta(methodName->toStdString()).get(); auto num_inputs = method_meta.num_inputs(); facebook::jni::local_ref ret = facebook::jni::JArrayInt::newArray(num_inputs); @@ -488,8 +486,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref getOutputTags( facebook::jni::alias_ref methodName) { - auto method_meta = - module_->method_meta(methodName->toStdString()).get(); + auto method_meta = module_->method_meta(methodName->toStdString()).get(); auto num_outputs = method_meta.num_outputs(); facebook::jni::local_ref ret = facebook::jni::JArrayInt::newArray(num_outputs); From 66602763a4940126f6f455a198c15e3d79a53423 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 22 May 2025 16:08:33 -0700 Subject: [PATCH 4/6] fmt --- .../LlmModuleInstrumentationTest.kt | 159 +++++---- .../org/pytorch/executorch/ModuleE2ETest.kt | 150 ++++----- .../executorch/ModuleInstrumentationTest.kt | 313 +++++++++--------- .../pytorch/executorch/TensorImageUtils.kt | 257 +++++++------- 4 files changed, 437 insertions(+), 442 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 0baa1d5191c..5af37d09ffb 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -11,6 +11,9 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.json.JSONException import org.json.JSONObject @@ -21,102 +24,98 @@ import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.extension.llm.LlmCallback import org.pytorch.executorch.extension.llm.LlmModule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException /** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */ @RunWith(AndroidJUnit4::class) class LlmModuleInstrumentationTest : LlmCallback { - private val results: MutableList = ArrayList() - private val tokensPerSecond: MutableList = ArrayList() - private var llmModule: LlmModule? = null + private val results: MutableList = ArrayList() + private val tokensPerSecond: MutableList = ArrayList() + private var llmModule: LlmModule? = null - @Before - @Throws(IOException::class) - fun setUp() { - // copy zipped test resources to local device - val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) - var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, addPteFile) - inputStream.close() + @Before + @Throws(IOException::class) + fun setUp() { + // copy zipped test resources to local device + val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) + var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, addPteFile) + inputStream.close() - val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) - inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, tokenizerFile) - inputStream.close() + val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) + inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, tokenizerFile) + inputStream.close() - llmModule = - LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) - } + llmModule = + LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) + } - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testGenerate() { - val loadResult = llmModule!!.load() - // Check that the model can be load successfully - Assert.assertEquals(OK.toLong(), loadResult.toLong()) + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testGenerate() { + val loadResult = llmModule!!.load() + // Check that the model can be load successfully + Assert.assertEquals(OK.toLong(), loadResult.toLong()) - llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) - Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong()) - Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) - } + llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) + Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong()) + Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testGenerateAndStop() { - llmModule!!.generate( - TEST_PROMPT, - SEQ_LEN, - object : LlmCallback { - override fun onResult(result: String) { - this@LlmModuleInstrumentationTest.onResult(result) - llmModule!!.stop() - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testGenerateAndStop() { + llmModule!!.generate( + TEST_PROMPT, + SEQ_LEN, + object : LlmCallback { + override fun onResult(result: String) { + this@LlmModuleInstrumentationTest.onResult(result) + llmModule!!.stop() + } - override fun onStats(stats: String) { - this@LlmModuleInstrumentationTest.onStats(stats) - } - }, - ) + override fun onStats(stats: String) { + this@LlmModuleInstrumentationTest.onStats(stats) + } + }, + ) - val stoppedResultSize = results.size - Assert.assertTrue(stoppedResultSize < SEQ_LEN) - } + val stoppedResultSize = results.size + Assert.assertTrue(stoppedResultSize < SEQ_LEN) + } - override fun onResult(result: String) { - results.add(result) - } + override fun onResult(result: String) { + results.add(result) + } - override fun onStats(stats: String) { - var tps = 0f - try { - val jsonObject = JSONObject(stats) - val numGeneratedTokens = jsonObject.getInt("generated_tokens") - val inferenceEndMs = jsonObject.getInt("inference_end_ms") - val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") - tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 - tokensPerSecond.add(tps) - } catch (_: JSONException) { - } - } + override fun onStats(stats: String) { + var tps = 0f + try { + val jsonObject = JSONObject(stats) + val numGeneratedTokens = jsonObject.getInt("generated_tokens") + val inferenceEndMs = jsonObject.getInt("inference_end_ms") + val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") + tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 + tokensPerSecond.add(tps) + } catch (_: JSONException) {} + } - companion object { - private const val TEST_FILE_NAME = "/stories.pte" - private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" - private const val TEST_PROMPT = "Hello" - private const val OK = 0x00 - private const val SEQ_LEN = 32 + companion object { + private const val TEST_FILE_NAME = "/stories.pte" + private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" + private const val TEST_PROMPT = "Hello" + private const val OK = 0x00 + private const val SEQ_LEN = 32 - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } + private fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName } + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 9ba97bdc3ac..8366849a3a5 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -13,104 +13,104 @@ import android.graphics.BitmapFactory import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor -import java.io.File -import java.io.IOException -import java.net.URISyntaxException /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) class ModuleE2ETest { - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - @Throws(IOException::class, URISyntaxException::class) - fun testClassification(filePath: String) { - val pteFile = File(getTestFilePath(filePath)) - val inputStream = javaClass.getResourceAsStream(filePath) - FileUtils.copyInputStreamToFile(inputStream, pteFile) - inputStream.close() + @Throws(IOException::class, URISyntaxException::class) + fun testClassification(filePath: String) { + val pteFile = File(getTestFilePath(filePath)) + val inputStream = javaClass.getResourceAsStream(filePath) + FileUtils.copyInputStreamToFile(inputStream, pteFile) + inputStream.close() - val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg") - var bitmap = BitmapFactory.decodeStream(imgInputStream) - bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true) - imgInputStream.close() + val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg") + var bitmap = BitmapFactory.decodeStream(imgInputStream) + bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true) + imgInputStream.close() - val inputTensor = - bitmapToFloat32Tensor( - bitmap, - TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, - TensorImageUtils.TORCHVISION_NORM_STD_RGB, - ) + val inputTensor = + bitmapToFloat32Tensor( + bitmap, + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, + TensorImageUtils.TORCHVISION_NORM_STD_RGB, + ) - val module = Module.load(getTestFilePath(filePath)) + val module = Module.load(getTestFilePath(filePath)) - val results = module.forward(EValue.from(inputTensor)) - Assert.assertTrue(results[0].isTensor) - val scores = results[0].toTensor().dataAsFloatArray + val results = module.forward(EValue.from(inputTensor)) + Assert.assertTrue(results[0].isTensor) + val scores = results[0].toTensor().dataAsFloatArray - val bananaClass = 954 // From ImageNet 1K - Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong()) - } + val bananaClass = 954 // From ImageNet 1K + Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong()) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testXnnpackBackendRequired() { - val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte")) - val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte") - FileUtils.copyInputStreamToFile(inputStream, pteFile) - inputStream.close() + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testXnnpackBackendRequired() { + val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte")) + val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte") + FileUtils.copyInputStreamToFile(inputStream, pteFile) + inputStream.close() - val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) - Assert.assertArrayEquals( - arrayOf("XnnpackBackend"), - module.getMethodMetadata("forward").getBackends(), - ) - } + val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) + Assert.assertArrayEquals( + arrayOf("XnnpackBackend"), + module.getMethodMetadata("forward").getBackends(), + ) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMv2Fp32() { - testClassification("/mv2_xnnpack_fp32.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMv2Fp32() { + testClassification("/mv2_xnnpack_fp32.pte") + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMv3Fp32() { - testClassification("/mv3_xnnpack_fp32.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMv3Fp32() { + testClassification("/mv3_xnnpack_fp32.pte") + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testResnet50() { - testClassification("/resnet50_xnnpack_q8.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testResnet50() { + testClassification("/resnet50_xnnpack_q8.pte") + } - companion object { - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } + companion object { + private fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName + } - fun argmax(array: FloatArray): Int { - require(array.isNotEmpty()) { "Array cannot be empty" } - var maxIndex = 0 - var maxValue = array[0] - for (i in 1 until array.size) { - if (array[i] > maxValue) { - maxValue = array[i] - maxIndex = i - } - } - return maxIndex + fun argmax(array: FloatArray): Int { + require(array.isNotEmpty()) { "Array cannot be empty" } + var maxIndex = 0 + var maxValue = array[0] + for (i in 1 until array.size) { + if (array[i] > maxValue) { + maxValue = array[i] + maxIndex = i } + } + return maxIndex } + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 2468a9d2ff2..350a73dad1b 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -11,180 +11,179 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import org.apache.commons.io.FileUtils -import org.junit.Assert -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith import java.io.File import java.io.IOException import java.net.URISyntaxException import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils +import org.junit.Assert +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) class ModuleInstrumentationTest { - @Before - @Throws(IOException::class) - fun setUp() { - // copy zipped test resources to local device - val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) - var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, addPteFile) - inputStream.close() - - val nonPteFile = File(getTestFilePath(NON_PTE_FILE_NAME)) - inputStream = javaClass.getResourceAsStream(NON_PTE_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, nonPteFile) - inputStream.close() - } - - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testModuleLoadAndForward() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - + @Before + @Throws(IOException::class) + fun setUp() { + // copy zipped test resources to local device + val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) + var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, addPteFile) + inputStream.close() + + val nonPteFile = File(getTestFilePath(NON_PTE_FILE_NAME)) + inputStream = javaClass.getResourceAsStream(NON_PTE_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, nonPteFile) + inputStream.close() + } + + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testModuleLoadAndForward() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) + } + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMethodMetadata() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + Assert.assertArrayEquals(arrayOf("forward"), module.getMethods()) + Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty()) + Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags) + Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadMethodAndForward() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadForwardExplicit() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val results = module.execute(FORWARD_METHOD) + Assert.assertTrue(results[0].isTensor) + } + + @Test(expected = RuntimeException::class) + @Throws(IOException::class) + fun testModuleLoadNonExistantFile() { + val module = Module.load(getTestFilePath(MISSING_FILE_NAME)) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadMethodNonExistantMethod() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(NONE_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + } + + @Test(expected = RuntimeException::class) + @Throws(IOException::class) + fun testNonPteFile() { + val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + } + + @Test + @Throws(IOException::class) + fun testLoadOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + module.destroy() + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) + } + + @Test + @Throws(IOException::class) + fun testForwardOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + + module.destroy() + + val results = module.forward() + Assert.assertEquals(0, results.size.toLong()) + } + + @Test + @Throws(InterruptedException::class, IOException::class) + fun testForwardFromMultipleThreads() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val numThreads = 100 + val latch = CountDownLatch(numThreads) + val completed = AtomicInteger(0) + + val runnable = Runnable { + try { + latch.countDown() + latch.await(5000, TimeUnit.MILLISECONDS) val results = module.forward() Assert.assertTrue(results[0].isTensor) + completed.incrementAndGet() + } catch (_: InterruptedException) {} } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMethodMetadata() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - Assert.assertArrayEquals(arrayOf("forward"), module.getMethods()) - Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty()) - Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags) - Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadMethodAndForward() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadForwardExplicit() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.execute(FORWARD_METHOD) - Assert.assertTrue(results[0].isTensor) - } - - @Test(expected = RuntimeException::class) - @Throws(IOException::class) - fun testModuleLoadNonExistantFile() { - val module = Module.load(getTestFilePath(MISSING_FILE_NAME)) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadMethodNonExistantMethod() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(NONE_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) - } - - @Test(expected = RuntimeException::class) - @Throws(IOException::class) - fun testNonPteFile() { - val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) - } - - @Test - @Throws(IOException::class) - fun testLoadOnDestroyedModule() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - module.destroy() - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) - } - - @Test - @Throws(IOException::class) - fun testForwardOnDestroyedModule() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - module.destroy() - - val results = module.forward() - Assert.assertEquals(0, results.size.toLong()) + val threads = arrayOfNulls(numThreads) + for (i in 0 until numThreads) { + threads[i] = Thread(runnable) + threads[i]!!.start() } - @Test - @Throws(InterruptedException::class, IOException::class) - fun testForwardFromMultipleThreads() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val numThreads = 100 - val latch = CountDownLatch(numThreads) - val completed = AtomicInteger(0) - - val runnable = Runnable { - try { - latch.countDown() - latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - completed.incrementAndGet() - } catch (_: InterruptedException) { - } - } - - val threads = arrayOfNulls(numThreads) - for (i in 0 until numThreads) { - threads[i] = Thread(runnable) - threads[i]!!.start() - } - - for (i in 0 until numThreads) { - threads[i]!!.join() - } - - Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) + for (i in 0 until numThreads) { + threads[i]!!.join() } - companion object { - private const val TEST_FILE_NAME = "/ModuleAdd.pte" - private const val MISSING_FILE_NAME = "/missing.pte" - private const val NON_PTE_FILE_NAME = "/test.txt" - private const val FORWARD_METHOD = "forward" - private const val NONE_METHOD = "none" - private const val OK = 0x00 - private const val INVALID_STATE = 0x2 - private const val INVALID_ARGUMENT = 0x12 - private const val ACCESS_FAILED = 0x22 - - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } + Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) + } + + companion object { + private const val TEST_FILE_NAME = "/ModuleAdd.pte" + private const val MISSING_FILE_NAME = "/missing.pte" + private const val NON_PTE_FILE_NAME = "/test.txt" + private const val FORWARD_METHOD = "forward" + private const val NONE_METHOD = "none" + private const val OK = 0x00 + private const val INVALID_STATE = 0x2 + private const val INVALID_ARGUMENT = 0x12 + private const val ACCESS_FAILED = 0x22 + + private fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName } + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt index dd3f5a1dfd0..d4c796ea6cb 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt @@ -16,149 +16,146 @@ import java.nio.FloatBuffer * [android.media.Image] source. */ object TensorImageUtils { - @JvmField - var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) + @JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) - @JvmField - var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) + @JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) - /** - * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in - * parameters mean and std. - * - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - @JvmStatic - fun bitmapToFloat32Tensor( - bitmap: Bitmap, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - ): Tensor { - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) + /** + * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in + * parameters mean and std. + * + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + @JvmStatic + fun bitmapToFloat32Tensor( + bitmap: Bitmap, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + ): Tensor { + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) - return bitmapToFloat32Tensor( - bitmap, - 0, - 0, - bitmap.width, - bitmap.height, - normMeanRGB, - normStdRGB, - ) - } + return bitmapToFloat32Tensor( + bitmap, + 0, + 0, + bitmap.width, + bitmap.height, + normMeanRGB, + normStdRGB, + ) + } - /** - * Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in - * parameters mean and std to specified [java.nio.FloatBuffer] with specified offset. - * - * @param bitmap [android.graphics.Bitmap] as a source for Tensor data - * @param x - x coordinate of top left corner of bitmap's area - * @param y - y coordinate of top left corner of bitmap's area - * @param width - width of bitmap's area - * @param height - height of bitmap's area - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - fun bitmapToFloatBuffer( - bitmap: Bitmap, - x: Int, - y: Int, - width: Int, - height: Int, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - outBuffer: FloatBuffer, - outBufferOffset: Int, - ) { - checkOutBufferCapacity(outBuffer, outBufferOffset, width, height) - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) - val pixelsCount = height * width - val pixels = IntArray(pixelsCount) - bitmap.getPixels(pixels, 0, width, x, y, width, height) - val offsetB = 2 * pixelsCount - for (i in 0..99) { - val c = pixels[i] - Log.i("Image", ": " + i + " " + ((c shr 16) and 0xff)) - } - for (i in 0 until pixelsCount) { - val c = pixels[i] - val r = ((c shr 16) and 0xff) / 255.0f - val g = ((c shr 8) and 0xff) / 255.0f - val b = ((c) and 0xff) / 255.0f - outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]) - outBuffer.put(outBufferOffset + pixelsCount + i, (g - normMeanRGB[1]) / normStdRGB[1]) - outBuffer.put(outBufferOffset + offsetB + i, (b - normMeanRGB[2]) / normStdRGB[2]) - } + /** + * Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in + * parameters mean and std to specified [java.nio.FloatBuffer] with specified offset. + * + * @param bitmap [android.graphics.Bitmap] as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + fun bitmapToFloatBuffer( + bitmap: Bitmap, + x: Int, + y: Int, + width: Int, + height: Int, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + outBuffer: FloatBuffer, + outBufferOffset: Int, + ) { + checkOutBufferCapacity(outBuffer, outBufferOffset, width, height) + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) + val pixelsCount = height * width + val pixels = IntArray(pixelsCount) + bitmap.getPixels(pixels, 0, width, x, y, width, height) + val offsetB = 2 * pixelsCount + for (i in 0..99) { + val c = pixels[i] + Log.i("Image", ": " + i + " " + ((c shr 16) and 0xff)) } + for (i in 0 until pixelsCount) { + val c = pixels[i] + val r = ((c shr 16) and 0xff) / 255.0f + val g = ((c shr 8) and 0xff) / 255.0f + val b = ((c) and 0xff) / 255.0f + outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]) + outBuffer.put(outBufferOffset + pixelsCount + i, (g - normMeanRGB[1]) / normStdRGB[1]) + outBuffer.put(outBufferOffset + offsetB + i, (b - normMeanRGB[2]) / normStdRGB[2]) + } + } - /** - * Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with - * specified in parameters mean and std. - * - * @param bitmap [android.graphics.Bitmap] as a source for Tensor data - * @param x - x coordinate of top left corner of bitmap's area - * @param y - y coordinate of top left corner of bitmap's area - * @param width - width of bitmap's area - * @param height - height of bitmap's area - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - fun bitmapToFloat32Tensor( - bitmap: Bitmap, - x: Int, - y: Int, - width: Int, - height: Int, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - ): Tensor { - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) + /** + * Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with + * specified in parameters mean and std. + * + * @param bitmap [android.graphics.Bitmap] as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + fun bitmapToFloat32Tensor( + bitmap: Bitmap, + x: Int, + y: Int, + width: Int, + height: Int, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + ): Tensor { + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) - val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) - bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0) - return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) - } + val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) + bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0) + return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) + } - private fun checkOutBufferCapacity( - outBuffer: FloatBuffer, - outBufferOffset: Int, - tensorWidth: Int, - tensorHeight: Int, - ) { - check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { - "Buffer underflow" - } + private fun checkOutBufferCapacity( + outBuffer: FloatBuffer, + outBufferOffset: Int, + tensorWidth: Int, + tensorHeight: Int, + ) { + check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { + "Buffer underflow" } + } - private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) { - require(!(tensorHeight <= 0 || tensorWidth <= 0)) { - "tensorHeight and tensorWidth must be positive" - } + private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) { + require(!(tensorHeight <= 0 || tensorWidth <= 0)) { + "tensorHeight and tensorWidth must be positive" } + } - private fun checkRotateCWDegrees(rotateCWDegrees: Int) { - require( - !(rotateCWDegrees != 0 && - rotateCWDegrees != 90 && - rotateCWDegrees != 180 && - rotateCWDegrees != 270) - ) { - "rotateCWDegrees must be one of 0, 90, 180, 270" + private fun checkRotateCWDegrees(rotateCWDegrees: Int) { + require( + !(rotateCWDegrees != 0 && + rotateCWDegrees != 90 && + rotateCWDegrees != 180 && + rotateCWDegrees != 270)) { + "rotateCWDegrees must be one of 0, 90, 180, 270" } - } + } - private fun checkNormStdArg(normStdRGB: FloatArray) { - require(normStdRGB.size == 3) { "normStdRGB length must be 3" } - } + private fun checkNormStdArg(normStdRGB: FloatArray) { + require(normStdRGB.size == 3) { "normStdRGB length must be 3" } + } - private fun checkNormMeanArg(normMeanRGB: FloatArray) { - require(normMeanRGB.size == 3) { "normMeanRGB length must be 3" } - } + private fun checkNormMeanArg(normMeanRGB: FloatArray) { + require(normMeanRGB.size == 3) { "normMeanRGB length must be 3" } + } } From a3a02ef4ecf94e17cb55a9cdeb6f9095a6081ebd Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 22 May 2025 16:10:14 -0700 Subject: [PATCH 5/6] format --- .../LlmModuleInstrumentationTest.kt | 159 ++++----- .../org/pytorch/executorch/ModuleE2ETest.kt | 150 ++++----- .../executorch/ModuleInstrumentationTest.kt | 313 +++++++++--------- .../RuntimeInstrumentationTest.java | 7 +- .../pytorch/executorch/TensorImageUtils.kt | 257 +++++++------- 5 files changed, 447 insertions(+), 439 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 5af37d09ffb..0baa1d5191c 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -11,9 +11,6 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.json.JSONException import org.json.JSONObject @@ -24,98 +21,102 @@ import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.extension.llm.LlmCallback import org.pytorch.executorch.extension.llm.LlmModule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException /** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */ @RunWith(AndroidJUnit4::class) class LlmModuleInstrumentationTest : LlmCallback { - private val results: MutableList = ArrayList() - private val tokensPerSecond: MutableList = ArrayList() - private var llmModule: LlmModule? = null + private val results: MutableList = ArrayList() + private val tokensPerSecond: MutableList = ArrayList() + private var llmModule: LlmModule? = null - @Before - @Throws(IOException::class) - fun setUp() { - // copy zipped test resources to local device - val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) - var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, addPteFile) - inputStream.close() + @Before + @Throws(IOException::class) + fun setUp() { + // copy zipped test resources to local device + val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) + var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, addPteFile) + inputStream.close() - val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) - inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, tokenizerFile) - inputStream.close() + val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) + inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, tokenizerFile) + inputStream.close() - llmModule = - LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) - } + llmModule = + LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) + } - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testGenerate() { - val loadResult = llmModule!!.load() - // Check that the model can be load successfully - Assert.assertEquals(OK.toLong(), loadResult.toLong()) + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testGenerate() { + val loadResult = llmModule!!.load() + // Check that the model can be load successfully + Assert.assertEquals(OK.toLong(), loadResult.toLong()) - llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) - Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong()) - Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) - } + llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) + Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong()) + Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testGenerateAndStop() { - llmModule!!.generate( - TEST_PROMPT, - SEQ_LEN, - object : LlmCallback { - override fun onResult(result: String) { - this@LlmModuleInstrumentationTest.onResult(result) - llmModule!!.stop() - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testGenerateAndStop() { + llmModule!!.generate( + TEST_PROMPT, + SEQ_LEN, + object : LlmCallback { + override fun onResult(result: String) { + this@LlmModuleInstrumentationTest.onResult(result) + llmModule!!.stop() + } - override fun onStats(stats: String) { - this@LlmModuleInstrumentationTest.onStats(stats) - } - }, - ) + override fun onStats(stats: String) { + this@LlmModuleInstrumentationTest.onStats(stats) + } + }, + ) - val stoppedResultSize = results.size - Assert.assertTrue(stoppedResultSize < SEQ_LEN) - } + val stoppedResultSize = results.size + Assert.assertTrue(stoppedResultSize < SEQ_LEN) + } - override fun onResult(result: String) { - results.add(result) - } + override fun onResult(result: String) { + results.add(result) + } - override fun onStats(stats: String) { - var tps = 0f - try { - val jsonObject = JSONObject(stats) - val numGeneratedTokens = jsonObject.getInt("generated_tokens") - val inferenceEndMs = jsonObject.getInt("inference_end_ms") - val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") - tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 - tokensPerSecond.add(tps) - } catch (_: JSONException) {} - } + override fun onStats(stats: String) { + var tps = 0f + try { + val jsonObject = JSONObject(stats) + val numGeneratedTokens = jsonObject.getInt("generated_tokens") + val inferenceEndMs = jsonObject.getInt("inference_end_ms") + val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") + tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 + tokensPerSecond.add(tps) + } catch (_: JSONException) { + } + } - companion object { - private const val TEST_FILE_NAME = "/stories.pte" - private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" - private const val TEST_PROMPT = "Hello" - private const val OK = 0x00 - private const val SEQ_LEN = 32 + companion object { + private const val TEST_FILE_NAME = "/stories.pte" + private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" + private const val TEST_PROMPT = "Hello" + private const val OK = 0x00 + private const val SEQ_LEN = 32 - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName + private fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName + } } - } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 8366849a3a5..9ba97bdc3ac 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -13,104 +13,104 @@ import android.graphics.BitmapFactory import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor +import java.io.File +import java.io.IOException +import java.net.URISyntaxException /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) class ModuleE2ETest { - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - @Throws(IOException::class, URISyntaxException::class) - fun testClassification(filePath: String) { - val pteFile = File(getTestFilePath(filePath)) - val inputStream = javaClass.getResourceAsStream(filePath) - FileUtils.copyInputStreamToFile(inputStream, pteFile) - inputStream.close() + @Throws(IOException::class, URISyntaxException::class) + fun testClassification(filePath: String) { + val pteFile = File(getTestFilePath(filePath)) + val inputStream = javaClass.getResourceAsStream(filePath) + FileUtils.copyInputStreamToFile(inputStream, pteFile) + inputStream.close() - val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg") - var bitmap = BitmapFactory.decodeStream(imgInputStream) - bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true) - imgInputStream.close() - - val inputTensor = - bitmapToFloat32Tensor( - bitmap, - TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, - TensorImageUtils.TORCHVISION_NORM_STD_RGB, - ) + val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg") + var bitmap = BitmapFactory.decodeStream(imgInputStream) + bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true) + imgInputStream.close() - val module = Module.load(getTestFilePath(filePath)) + val inputTensor = + bitmapToFloat32Tensor( + bitmap, + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, + TensorImageUtils.TORCHVISION_NORM_STD_RGB, + ) - val results = module.forward(EValue.from(inputTensor)) - Assert.assertTrue(results[0].isTensor) - val scores = results[0].toTensor().dataAsFloatArray + val module = Module.load(getTestFilePath(filePath)) - val bananaClass = 954 // From ImageNet 1K - Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong()) - } + val results = module.forward(EValue.from(inputTensor)) + Assert.assertTrue(results[0].isTensor) + val scores = results[0].toTensor().dataAsFloatArray - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testXnnpackBackendRequired() { - val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte")) - val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte") - FileUtils.copyInputStreamToFile(inputStream, pteFile) - inputStream.close() + val bananaClass = 954 // From ImageNet 1K + Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong()) + } - val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) - Assert.assertArrayEquals( - arrayOf("XnnpackBackend"), - module.getMethodMetadata("forward").getBackends(), - ) - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testXnnpackBackendRequired() { + val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte")) + val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte") + FileUtils.copyInputStreamToFile(inputStream, pteFile) + inputStream.close() - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMv2Fp32() { - testClassification("/mv2_xnnpack_fp32.pte") - } + val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) + Assert.assertArrayEquals( + arrayOf("XnnpackBackend"), + module.getMethodMetadata("forward").getBackends(), + ) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMv3Fp32() { - testClassification("/mv3_xnnpack_fp32.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMv2Fp32() { + testClassification("/mv2_xnnpack_fp32.pte") + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testResnet50() { - testClassification("/resnet50_xnnpack_q8.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMv3Fp32() { + testClassification("/mv3_xnnpack_fp32.pte") + } - companion object { - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testResnet50() { + testClassification("/resnet50_xnnpack_q8.pte") } - fun argmax(array: FloatArray): Int { - require(array.isNotEmpty()) { "Array cannot be empty" } - var maxIndex = 0 - var maxValue = array[0] - for (i in 1 until array.size) { - if (array[i] > maxValue) { - maxValue = array[i] - maxIndex = i + companion object { + private fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName + } + + fun argmax(array: FloatArray): Int { + require(array.isNotEmpty()) { "Array cannot be empty" } + var maxIndex = 0 + var maxValue = array[0] + for (i in 1 until array.size) { + if (array[i] > maxValue) { + maxValue = array[i] + maxIndex = i + } + } + return maxIndex } - } - return maxIndex } - } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 350a73dad1b..2468a9d2ff2 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -11,179 +11,180 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith +import java.io.File +import java.io.IOException +import java.net.URISyntaxException +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) class ModuleInstrumentationTest { - @Before - @Throws(IOException::class) - fun setUp() { - // copy zipped test resources to local device - val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) - var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, addPteFile) - inputStream.close() - - val nonPteFile = File(getTestFilePath(NON_PTE_FILE_NAME)) - inputStream = javaClass.getResourceAsStream(NON_PTE_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, nonPteFile) - inputStream.close() - } - - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testModuleLoadAndForward() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - } - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMethodMetadata() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - Assert.assertArrayEquals(arrayOf("forward"), module.getMethods()) - Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty()) - Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags) - Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadMethodAndForward() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadForwardExplicit() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.execute(FORWARD_METHOD) - Assert.assertTrue(results[0].isTensor) - } - - @Test(expected = RuntimeException::class) - @Throws(IOException::class) - fun testModuleLoadNonExistantFile() { - val module = Module.load(getTestFilePath(MISSING_FILE_NAME)) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadMethodNonExistantMethod() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(NONE_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) - } - - @Test(expected = RuntimeException::class) - @Throws(IOException::class) - fun testNonPteFile() { - val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) - } - - @Test - @Throws(IOException::class) - fun testLoadOnDestroyedModule() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - module.destroy() - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) - } - - @Test - @Throws(IOException::class) - fun testForwardOnDestroyedModule() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - module.destroy() - - val results = module.forward() - Assert.assertEquals(0, results.size.toLong()) - } - - @Test - @Throws(InterruptedException::class, IOException::class) - fun testForwardFromMultipleThreads() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val numThreads = 100 - val latch = CountDownLatch(numThreads) - val completed = AtomicInteger(0) - - val runnable = Runnable { - try { - latch.countDown() - latch.await(5000, TimeUnit.MILLISECONDS) + @Before + @Throws(IOException::class) + fun setUp() { + // copy zipped test resources to local device + val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) + var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, addPteFile) + inputStream.close() + + val nonPteFile = File(getTestFilePath(NON_PTE_FILE_NAME)) + inputStream = javaClass.getResourceAsStream(NON_PTE_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, nonPteFile) + inputStream.close() + } + + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testModuleLoadAndForward() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + val results = module.forward() Assert.assertTrue(results[0].isTensor) - completed.incrementAndGet() - } catch (_: InterruptedException) {} } - val threads = arrayOfNulls(numThreads) - for (i in 0 until numThreads) { - threads[i] = Thread(runnable) - threads[i]!!.start() + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMethodMetadata() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + Assert.assertArrayEquals(arrayOf("forward"), module.getMethods()) + Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty()) + Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags) + Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadMethodAndForward() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadForwardExplicit() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val results = module.execute(FORWARD_METHOD) + Assert.assertTrue(results[0].isTensor) + } + + @Test(expected = RuntimeException::class) + @Throws(IOException::class) + fun testModuleLoadNonExistantFile() { + val module = Module.load(getTestFilePath(MISSING_FILE_NAME)) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadMethodNonExistantMethod() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(NONE_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + } + + @Test(expected = RuntimeException::class) + @Throws(IOException::class) + fun testNonPteFile() { + val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + } + + @Test + @Throws(IOException::class) + fun testLoadOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + module.destroy() + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) + } + + @Test + @Throws(IOException::class) + fun testForwardOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + + module.destroy() + + val results = module.forward() + Assert.assertEquals(0, results.size.toLong()) } - for (i in 0 until numThreads) { - threads[i]!!.join() + @Test + @Throws(InterruptedException::class, IOException::class) + fun testForwardFromMultipleThreads() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val numThreads = 100 + val latch = CountDownLatch(numThreads) + val completed = AtomicInteger(0) + + val runnable = Runnable { + try { + latch.countDown() + latch.await(5000, TimeUnit.MILLISECONDS) + val results = module.forward() + Assert.assertTrue(results[0].isTensor) + completed.incrementAndGet() + } catch (_: InterruptedException) { + } + } + + val threads = arrayOfNulls(numThreads) + for (i in 0 until numThreads) { + threads[i] = Thread(runnable) + threads[i]!!.start() + } + + for (i in 0 until numThreads) { + threads[i]!!.join() + } + + Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) } - Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) - } - - companion object { - private const val TEST_FILE_NAME = "/ModuleAdd.pte" - private const val MISSING_FILE_NAME = "/missing.pte" - private const val NON_PTE_FILE_NAME = "/test.txt" - private const val FORWARD_METHOD = "forward" - private const val NONE_METHOD = "none" - private const val OK = 0x00 - private const val INVALID_STATE = 0x2 - private const val INVALID_ARGUMENT = 0x12 - private const val ACCESS_FAILED = 0x22 - - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName + companion object { + private const val TEST_FILE_NAME = "/ModuleAdd.pte" + private const val MISSING_FILE_NAME = "/missing.pte" + private const val NON_PTE_FILE_NAME = "/test.txt" + private const val FORWARD_METHOD = "forward" + private const val NONE_METHOD = "none" + private const val OK = 0x00 + private const val INVALID_STATE = 0x2 + private const val INVALID_ARGUMENT = 0x12 + private const val ACCESS_FAILED = 0x22 + + private fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName + } } - } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java index 27114b4cc77..574aca9723c 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java @@ -11,10 +11,13 @@ import static org.junit.Assert.assertNotNull; import androidx.test.ext.junit.runners.AndroidJUnit4; -import org.junit.runner.RunWith; + import org.junit.Test; +import org.junit.runner.RunWith; -/** Unit tests for {@link ExecuTorchRuntime}. */ +/** + * Unit tests for {@link ExecuTorchRuntime}. + */ @RunWith(AndroidJUnit4.class) public class RuntimeInstrumentationTest { diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt index d4c796ea6cb..dd3f5a1dfd0 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt @@ -16,146 +16,149 @@ import java.nio.FloatBuffer * [android.media.Image] source. */ object TensorImageUtils { - @JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) + @JvmField + var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) - @JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) + @JvmField + var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) - /** - * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in - * parameters mean and std. - * - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - @JvmStatic - fun bitmapToFloat32Tensor( - bitmap: Bitmap, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - ): Tensor { - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) + /** + * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in + * parameters mean and std. + * + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + @JvmStatic + fun bitmapToFloat32Tensor( + bitmap: Bitmap, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + ): Tensor { + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) - return bitmapToFloat32Tensor( - bitmap, - 0, - 0, - bitmap.width, - bitmap.height, - normMeanRGB, - normStdRGB, - ) - } - - /** - * Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in - * parameters mean and std to specified [java.nio.FloatBuffer] with specified offset. - * - * @param bitmap [android.graphics.Bitmap] as a source for Tensor data - * @param x - x coordinate of top left corner of bitmap's area - * @param y - y coordinate of top left corner of bitmap's area - * @param width - width of bitmap's area - * @param height - height of bitmap's area - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - fun bitmapToFloatBuffer( - bitmap: Bitmap, - x: Int, - y: Int, - width: Int, - height: Int, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - outBuffer: FloatBuffer, - outBufferOffset: Int, - ) { - checkOutBufferCapacity(outBuffer, outBufferOffset, width, height) - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) - val pixelsCount = height * width - val pixels = IntArray(pixelsCount) - bitmap.getPixels(pixels, 0, width, x, y, width, height) - val offsetB = 2 * pixelsCount - for (i in 0..99) { - val c = pixels[i] - Log.i("Image", ": " + i + " " + ((c shr 16) and 0xff)) + return bitmapToFloat32Tensor( + bitmap, + 0, + 0, + bitmap.width, + bitmap.height, + normMeanRGB, + normStdRGB, + ) } - for (i in 0 until pixelsCount) { - val c = pixels[i] - val r = ((c shr 16) and 0xff) / 255.0f - val g = ((c shr 8) and 0xff) / 255.0f - val b = ((c) and 0xff) / 255.0f - outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]) - outBuffer.put(outBufferOffset + pixelsCount + i, (g - normMeanRGB[1]) / normStdRGB[1]) - outBuffer.put(outBufferOffset + offsetB + i, (b - normMeanRGB[2]) / normStdRGB[2]) + + /** + * Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in + * parameters mean and std to specified [java.nio.FloatBuffer] with specified offset. + * + * @param bitmap [android.graphics.Bitmap] as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + fun bitmapToFloatBuffer( + bitmap: Bitmap, + x: Int, + y: Int, + width: Int, + height: Int, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + outBuffer: FloatBuffer, + outBufferOffset: Int, + ) { + checkOutBufferCapacity(outBuffer, outBufferOffset, width, height) + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) + val pixelsCount = height * width + val pixels = IntArray(pixelsCount) + bitmap.getPixels(pixels, 0, width, x, y, width, height) + val offsetB = 2 * pixelsCount + for (i in 0..99) { + val c = pixels[i] + Log.i("Image", ": " + i + " " + ((c shr 16) and 0xff)) + } + for (i in 0 until pixelsCount) { + val c = pixels[i] + val r = ((c shr 16) and 0xff) / 255.0f + val g = ((c shr 8) and 0xff) / 255.0f + val b = ((c) and 0xff) / 255.0f + outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]) + outBuffer.put(outBufferOffset + pixelsCount + i, (g - normMeanRGB[1]) / normStdRGB[1]) + outBuffer.put(outBufferOffset + offsetB + i, (b - normMeanRGB[2]) / normStdRGB[2]) + } } - } - /** - * Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with - * specified in parameters mean and std. - * - * @param bitmap [android.graphics.Bitmap] as a source for Tensor data - * @param x - x coordinate of top left corner of bitmap's area - * @param y - y coordinate of top left corner of bitmap's area - * @param width - width of bitmap's area - * @param height - height of bitmap's area - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - fun bitmapToFloat32Tensor( - bitmap: Bitmap, - x: Int, - y: Int, - width: Int, - height: Int, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - ): Tensor { - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) + /** + * Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with + * specified in parameters mean and std. + * + * @param bitmap [android.graphics.Bitmap] as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + fun bitmapToFloat32Tensor( + bitmap: Bitmap, + x: Int, + y: Int, + width: Int, + height: Int, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + ): Tensor { + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) - val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) - bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0) - return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) - } + val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) + bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0) + return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) + } - private fun checkOutBufferCapacity( - outBuffer: FloatBuffer, - outBufferOffset: Int, - tensorWidth: Int, - tensorHeight: Int, - ) { - check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { - "Buffer underflow" + private fun checkOutBufferCapacity( + outBuffer: FloatBuffer, + outBufferOffset: Int, + tensorWidth: Int, + tensorHeight: Int, + ) { + check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { + "Buffer underflow" + } } - } - private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) { - require(!(tensorHeight <= 0 || tensorWidth <= 0)) { - "tensorHeight and tensorWidth must be positive" + private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) { + require(!(tensorHeight <= 0 || tensorWidth <= 0)) { + "tensorHeight and tensorWidth must be positive" + } } - } - private fun checkRotateCWDegrees(rotateCWDegrees: Int) { - require( - !(rotateCWDegrees != 0 && - rotateCWDegrees != 90 && - rotateCWDegrees != 180 && - rotateCWDegrees != 270)) { - "rotateCWDegrees must be one of 0, 90, 180, 270" + private fun checkRotateCWDegrees(rotateCWDegrees: Int) { + require( + !(rotateCWDegrees != 0 && + rotateCWDegrees != 90 && + rotateCWDegrees != 180 && + rotateCWDegrees != 270) + ) { + "rotateCWDegrees must be one of 0, 90, 180, 270" } - } + } - private fun checkNormStdArg(normStdRGB: FloatArray) { - require(normStdRGB.size == 3) { "normStdRGB length must be 3" } - } + private fun checkNormStdArg(normStdRGB: FloatArray) { + require(normStdRGB.size == 3) { "normStdRGB length must be 3" } + } - private fun checkNormMeanArg(normMeanRGB: FloatArray) { - require(normMeanRGB.size == 3) { "normMeanRGB length must be 3" } - } + private fun checkNormMeanArg(normMeanRGB: FloatArray) { + require(normMeanRGB.size == 3) { "normMeanRGB length must be 3" } + } } From 06f8e7b871760fe6ed45f6c04456e96afb5895e9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 22 May 2025 16:12:08 -0700 Subject: [PATCH 6/6] Update --- .../executorch/LlmModuleInstrumentationTest.kt | 9 ++++----- .../java/org/pytorch/executorch/ModuleE2ETest.kt | 9 +++++---- .../executorch/ModuleInstrumentationTest.kt | 15 +++++++-------- .../executorch/RuntimeInstrumentationTest.java | 7 ++----- .../org/pytorch/executorch/TensorImageUtils.kt | 12 +++++------- 5 files changed, 23 insertions(+), 29 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 0baa1d5191c..43ce302a7a6 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -11,6 +11,9 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.json.JSONException import org.json.JSONObject @@ -21,9 +24,6 @@ import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.extension.llm.LlmCallback import org.pytorch.executorch.extension.llm.LlmModule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException /** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */ @RunWith(AndroidJUnit4::class) @@ -101,8 +101,7 @@ class LlmModuleInstrumentationTest : LlmCallback { val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 tokensPerSecond.add(tps) - } catch (_: JSONException) { - } + } catch (_: JSONException) {} } companion object { diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 9ba97bdc3ac..381b1bd99d1 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -13,15 +13,15 @@ import android.graphics.BitmapFactory import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor -import java.io.File -import java.io.IOException -import java.net.URISyntaxException /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) @@ -68,8 +68,9 @@ class ModuleE2ETest { inputStream.close() val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) + val expectedBackends = arrayOf("XnnpackBackend") Assert.assertArrayEquals( - arrayOf("XnnpackBackend"), + expectedBackends, module.getMethodMetadata("forward").getBackends(), ) } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 2468a9d2ff2..50d469e606d 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -11,18 +11,18 @@ import android.Manifest import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import org.apache.commons.io.FileUtils -import org.junit.Assert -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith import java.io.File import java.io.IOException import java.net.URISyntaxException import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils +import org.junit.Assert +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) @@ -152,8 +152,7 @@ class ModuleInstrumentationTest { val results = module.forward() Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() - } catch (_: InterruptedException) { - } + } catch (_: InterruptedException) {} } val threads = arrayOfNulls(numThreads) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java index 574aca9723c..27114b4cc77 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.java @@ -11,13 +11,10 @@ import static org.junit.Assert.assertNotNull; import androidx.test.ext.junit.runners.AndroidJUnit4; - -import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.Test; -/** - * Unit tests for {@link ExecuTorchRuntime}. - */ +/** Unit tests for {@link ExecuTorchRuntime}. */ @RunWith(AndroidJUnit4.class) public class RuntimeInstrumentationTest { diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt index dd3f5a1dfd0..cb2e365a4c5 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt @@ -16,11 +16,9 @@ import java.nio.FloatBuffer * [android.media.Image] source. */ object TensorImageUtils { - @JvmField - var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) + @JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) - @JvmField - var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) + @JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) /** * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in @@ -146,9 +144,9 @@ object TensorImageUtils { private fun checkRotateCWDegrees(rotateCWDegrees: Int) { require( !(rotateCWDegrees != 0 && - rotateCWDegrees != 90 && - rotateCWDegrees != 180 && - rotateCWDegrees != 270) + rotateCWDegrees != 90 && + rotateCWDegrees != 180 && + rotateCWDegrees != 270) ) { "rotateCWDegrees must be one of 0, 90, 180, 270" }