diff --git a/tfjs-tflite/BUILD.bazel b/tfjs-tflite/BUILD.bazel index b20a3264c63..a30686ec811 100644 --- a/tfjs-tflite/BUILD.bazel +++ b/tfjs-tflite/BUILD.bazel @@ -103,5 +103,6 @@ test_suite( name = "tests", tests = [ ":tfjs-tflite_test", + "//tfjs-tflite/src:worker_test", ], ) diff --git a/tfjs-tflite/src/BUILD.bazel b/tfjs-tflite/src/BUILD.bazel index 07176ac7fd4..2171aeda521 100644 --- a/tfjs-tflite/src/BUILD.bazel +++ b/tfjs-tflite/src/BUILD.bazel @@ -15,6 +15,7 @@ load("@build_bazel_rules_nodejs//:index.bzl", "copy_to_bin") load("//tools:defaults.bzl", "esbuild", "ts_library") +load("//tools:tfjs_web_test.bzl", "tfjs_web_test") package(default_visibility = ["//visibility:public"]) @@ -55,7 +56,10 @@ copy_to_bin( ts_library( name = "tfjs-tflite_test_lib", - srcs = glob(TEST_SRCS), + srcs = glob( + TEST_SRCS, + exclude = ["worker_test.ts"], + ), module_name = "@tensorflow/tfjs-tflite/dist", deps = [ ":tfjs-tflite_lib", @@ -83,3 +87,46 @@ esbuild( "//tfjs-tflite/wasm:wasm_files", ], ) + +ts_library( + name = "worker_test_lib", + srcs = [ + "worker_test.ts", + ], + deps = [ + "//tfjs-backend-cpu/src:tfjs-backend-cpu_lib", + "//tfjs-core/src:tfjs-core_lib", + "//tfjs-core/src:tfjs-core_src_lib", + ], +) + +tfjs_web_test( + name = "worker_test", + browsers = [ + "bs_chrome_mac", + "bs_firefox_mac", + "bs_safari_mac", + # Temporarily disabled because BrowserStack does not support loading + # absolute paths in iOS, which is required for loading the worker. + # https://www.browserstack.com/question/39573 + # "bs_ios_12", + "bs_android_9", + "win_10_chrome", + ], + static_files = [ + # For the webworker + "//tfjs-core:tf-core.min.js", + "//tfjs-core:tf-core.min.js.map", + "//tfjs-backend-cpu:tf-backend-cpu.min.js", + "//tfjs-backend-cpu:tf-backend-cpu.min.js.map", + "//tfjs-tflite:tf-tflite.min.js", + "//tfjs-tflite:tf-tflite.min.js.map", + "//tfjs-tflite/wasm:wasm_files", + "//tfjs-tflite/test_files:add4.tflite", + ], + deps = [ + ":worker_test_lib", + "@npm//long:long__umd", + "@npm//seedrandom:seedrandom__umd", + ], +) diff --git a/tfjs-tflite/src/worker_test.ts b/tfjs-tflite/src/worker_test.ts new file mode 100644 index 00000000000..4865470c2df --- /dev/null +++ b/tfjs-tflite/src/worker_test.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import '@tensorflow/tfjs-backend-cpu'; + +const str2workerURL = (str: string): string => { + const blob = + new Blob([str], {type: 'application/javascript'}); + return URL.createObjectURL(blob); +}; + +// The source code of a web worker. +const workerTest = ` +importScripts(location.origin + '/base/tfjs/tfjs-core/tf-core.min.js'); +importScripts(location.origin + + '/base/tfjs/tfjs-backend-cpu/tf-backend-cpu.min.js'); +// Import order matters. TFLite must be imported after tfjs core. +importScripts(location.origin + '/base/tfjs/tfjs-tflite/tf-tflite.min.js'); + +// Setting wasm path is required. It can be set to CDN if needed, +// but that's not a good idea for a test. +tflite.setWasmPath('/base/tfjs/tfjs-tflite/wasm/'); +async function main() { + // This is a test model that adds two tensors of shape [1, 4]. + const model = await tflite.loadTFLiteModel(location.origin + '/base/tfjs/tfjs-tflite/test_files/add4.tflite'); + + const a = tf.tensor2d([[1, 2, 3, 4]]); + const b = tf.tensor2d([[5, 6, 7, 8]]); + const output = model.predict([a, b]); + + self.postMessage({data: output.dataSync()}); +} + +main(); +`; + +describe('tflite in worker', () => { + it('runs a model', (done) => { + const worker = new Worker(str2workerURL(workerTest)); + worker.onmessage = (msg) => { + const data = msg.data.data; + expect([...data]).toEqual([6, 8, 10, 12]); + done(); + }; + }, 15_000); +}); diff --git a/tfjs-tflite/test_files/BUILD.bazel b/tfjs-tflite/test_files/BUILD.bazel new file mode 100644 index 00000000000..ba0b9d6c0eb --- /dev/null +++ b/tfjs-tflite/test_files/BUILD.bazel @@ -0,0 +1,21 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +package(default_visibility = ["//visibility:public"]) + +exports_files([ + # add4.tflite adds two tensors of shape [1,4] + "add4.tflite", +]) diff --git a/tfjs-tflite/test_files/add4.tflite b/tfjs-tflite/test_files/add4.tflite new file mode 100644 index 00000000000..6ee2860e16b Binary files /dev/null and b/tfjs-tflite/test_files/add4.tflite differ