diff --git a/tfjs-backend-webgl/src/gather_nd_gpu.ts b/tfjs-backend-webgl/src/gather_nd_gpu.ts index 52cba42bf6f..507f9f743b6 100644 --- a/tfjs-backend-webgl/src/gather_nd_gpu.ts +++ b/tfjs-backend-webgl/src/gather_nd_gpu.ts @@ -25,24 +25,26 @@ export class GatherNDProgram implements GPGPUProgram { private sliceDim: number, private strides: number[], shape: number[], private paramsShape: number[]) { this.outputShape = shape; - const stridesType = getCoordsDataType(strides.length); const dtype = getCoordsDataType(shape.length); - const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides'; - const paramsShapeType = getCoordsDataType(paramsShape.length); - const paramsShapeString = paramsShape.length > 1 ? 'paramsShape[j]' : 'paramsShape'; + + let mainLoop = ` + int index;`; + for (let j = 0; j < this.sliceDim; j++) { + mainLoop += ` + index = round(getIndices(coords[0], ${j})); + out_of_bounds = out_of_bounds || index < 0; + out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]}; + flattenIndex += index * ${this.strides[j]};`; + } + this.userCode = ` - ${stridesType} strides = ${stridesType}(${this.strides}); - ${paramsShapeType} paramsShape = ${paramsShapeType}(${this.paramsShape}); void main() { ${dtype} coords = getOutputCoords(); int flattenIndex = 0; bool out_of_bounds = false; - for (int j = 0; j < ${this.sliceDim}; j++) { - int index = round(getIndices(coords[0], j)); - out_of_bounds = out_of_bounds || index < 0; - out_of_bounds = out_of_bounds || index >= ${paramsShapeString}; - flattenIndex += index * ${strideString}; - } + + ${mainLoop} + setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1])); } `; diff --git a/tfjs-backend-webgl/src/webgl_ops_test.ts b/tfjs-backend-webgl/src/webgl_ops_test.ts index 531314679aa..22d99ea7364 100644 --- a/tfjs-backend-webgl/src/webgl_ops_test.ts +++ b/tfjs-backend-webgl/src/webgl_ops_test.ts @@ -918,6 +918,15 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => { expectArraysEqual(await g.data(), [0, 1, 2, 0]); }); + it('works for out of bounds indices 1d', async () => { + const x = tf.tensor1d([...Array(4).keys()].map(e => e + 1), 'int32'); + const indices = [0, 1, 2, 5]; + const ind = tf.tensor2d(indices, [4, 1], 'int32'); + const g = tf.gatherND(x, ind); + const expected = [1, 2, 3, 0]; + expectArraysEqual(await g.data(), expected); + }); + it('works for out of bounds indices 2d', async () => { const x = tf.tensor2d([...Array(4).keys()].map(e => e + 1), [2, 2], 'int32'); @@ -957,4 +966,14 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => { ]; expectArraysEqual(await g.data(), expected); }); + + it('works for out of bounds indices 5d', async () => { + const x = tf.tensor5d( + [...Array(32).keys()].map(e => e + 1), [2, 2, 2, 2, 2], 'int32'); + const indices = [0, 0, 0, 0, 0, 2, 1, 1, 1, 1]; + const ind = tf.tensor2d(indices, [2, 5], 'int32'); + const g = tf.gatherND(x, ind); + const expected = [1, 0]; + expectArraysEqual(await g.data(), expected); + }); }); diff --git a/tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts b/tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts new file mode 100644 index 00000000000..834e1c5ead6 --- /dev/null +++ b/tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util} from '@tensorflow/tfjs-core'; + +import {activationFnSnippet, biasActivationSnippet} from './activation_util'; +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch} from './webgpu_util'; + +export class Conv2DNaiveProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + dispatchLayout: {x: number[], y: number[], z: number[]}; + dispatch: [number, number, number]; + variableNames = ['x', 'W']; + uniforms = + 'filterDims: vec2, pad: vec2, stride: vec2, dilation: vec2,'; + workGroupSize: [number, number, number] = [4, 4, 8]; + addBias: boolean; + activation: backend_util.Activation; + hasPreluActivationWeights: boolean; + isChannelsLast: boolean; + + constructor( + convInfo: backend_util.Conv2DInfo, addBias = false, + activation: backend_util.Activation = null, + hasPreluActivationWeights = false) { + this.outputShape = convInfo.outShape; + this.isChannelsLast = convInfo.dataFormat === 'channelsLast'; + this.dispatchLayout = this.isChannelsLast ? {x: [2], y: [1], z: [0, 3]} : + {x: [3], y: [2], z: [0, 1]}; + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workGroupSize); + this.addBias = addBias; + this.activation = activation; + this.hasPreluActivationWeights = hasPreluActivationWeights; + + if (addBias) { + this.variableNames.push('bias'); + } + + if (hasPreluActivationWeights) { + this.variableNames.push('preluActivationWeights'); + } + + this.shaderKey = `conv2dnaive_${this.activation}_${this.isChannelsLast}`; + } + + getUserCode(): string { + const userCode = ` + ${ + activationFnSnippet( + this.activation, this.hasPreluActivationWeights, false, 4)} + fn readInp(batch : i32, row : i32, col : i32, chan : i32) -> f32{ + let coords = vec4(batch, row, col, chan); + if (coordsInBounds4D(coords, uniforms.xShape)) { + return getX(batch, row, col, chan); + } else { + return 0.0; + } + } + fn readFilt(row : i32, col : i32, xChannel : i32, outChannel : i32) -> f32{ + let coords = vec4(row, col, xChannel, outChannel); + if(coordsInBounds4D(coords, uniforms.wShape)) { + return getW(row, col, xChannel, outChannel); + } else { + return 0.0; + } + } + fn writeResult(batch : i32, row : i32, col : i32, chan : i32, valueIn : f32) { + let coords = ${ + this.isChannelsLast ? `vec4(batch, row, col, chan);` : + `vec4(batch, chan, row, col);`} + if (coordsInBounds4D(coords, uniforms.outShape)) { + var value = valueIn; + ${biasActivationSnippet(this.addBias, this.activation)} + setOutputAtCoords(coords.x, coords.y, coords.z, coords.w, value); + } + } + ${main('index')} { + let coords = getOutputCoords(); + let batch = coords[0]; + let outChannel = ${this.isChannelsLast ? `coords[3];` : `coords[1];`} + let outRow = ${this.isChannelsLast ? `coords[1];` : `coords[2];`} + let outCol = ${this.isChannelsLast ? `coords[2];` : `coords[3];`} + var acc : f32 = 0.0; + for (var row = 0; row < uniforms.filterDims[0]; row = row + 1) { + for (var col = 0; col < uniforms.filterDims[1]; col = col + 1) { + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * row - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * col - uniforms.pad[1]; + for (var xChannel = 0; xChannel < ${ + this.isChannelsLast ? `uniforms.xShape[3];` : + `uniforms.xShape[1];`} xChannel = xChannel + 1) { + ${ + this.isChannelsLast ? `let v = readInp(batch, xRow, xCol, xChannel);` : + `let v = readInp(batch, xChannel, xRow, xCol);`} + let f = readFilt(row, col, xChannel, outChannel); + acc = acc + v * f; + } + } + } + writeResult(batch, outRow, outCol, outChannel, acc); + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts index 2df5d8cd5d8..a5313d44c88 100644 --- a/tfjs-backend-webgpu/src/flags_webgpu.ts +++ b/tfjs-backend-webgpu/src/flags_webgpu.ts @@ -65,3 +65,8 @@ ENV.registerFlag('WEBGPU_USE_PROFILE_TOOL', () => false); * Whether to use import API. */ ENV.registerFlag('WEBGPU_IMPORT_EXTERNAL_TEXTURE', () => true); + +/** + * Whether to use conv2dNaive for debugging. + */ +ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_DEBUG', () => false); diff --git a/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts b/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts index e46cc8558c4..cb6552abff2 100644 --- a/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts +++ b/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts @@ -15,10 +15,12 @@ * ============================================================================= */ -import {backend_util, TensorInfo} from '@tensorflow/tfjs-core'; +import {backend_util, env, TensorInfo} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; import {Conv2DMMProgram} from '../conv2d_mm_webgpu'; +import {Conv2DNaiveProgram} from '../conv2d_naive_webgpu'; +import {WebGPUProgram} from '../webgpu_program'; import {batchMatMulImpl} from './BatchMatMul_impl'; import {reshape} from './Reshape'; @@ -184,12 +186,15 @@ export function conv2DImpl({ convInfo.filterHeight === convInfo.inHeight && convInfo.filterWidth === convInfo.inWidth && convInfo.padInfo.type === 'VALID'; - if (sameSize || - (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && - convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && - convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && - (convInfo.padInfo.type === 'SAME' || - convInfo.padInfo.type === 'VALID'))) { + const useNaiveConv2d = env().getBool('WEBGPU_USE_NAIVE_CONV2D_DEBUG'); + + if (!useNaiveConv2d && + (sameSize || + (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && + convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && + convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && + (convInfo.padInfo.type === 'SAME' || + convInfo.padInfo.type === 'VALID')))) { return conv2dByMatMul({ x, filter, @@ -202,25 +207,32 @@ export function conv2DImpl({ }); } - const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth : - convInfo.outChannels; - const dimBOuter = isChannelsLast ? convInfo.outChannels : - convInfo.outHeight * convInfo.outWidth; - const dimInner = - convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels; + let program: WebGPUProgram; const padInfo = [convInfo.padInfo.top, convInfo.padInfo.left]; const dimensions = [ {type: 'int32', data: [convInfo.filterHeight, convInfo.filterWidth]}, {type: 'int32', data: [...padInfo]}, {type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}, - {type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]}, - {type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]}, - {type: 'int32', data: [dimInner]} + {type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]} ]; + if (useNaiveConv2d) { + program = new Conv2DNaiveProgram( + convInfo, hasBias, activation, hasPreluActivationWeights); + } else { + const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth : + convInfo.outChannels; + const dimBOuter = isChannelsLast ? convInfo.outChannels : + convInfo.outHeight * convInfo.outWidth; + const dimInner = + convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels; + dimensions.push( + {type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]}, + {type: 'int32', data: [dimInner]}); - const program = new Conv2DMMProgram( - convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation, - hasPreluActivationWeights); + program = new Conv2DMMProgram( + convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation, + hasPreluActivationWeights); + } const intermediates: TensorInfo[] = []; const inputVar: TensorInfo[] = [x, filter]; diff --git a/tfjs-layers/src/BUILD.bazel b/tfjs-layers/src/BUILD.bazel index de5ee80349b..7909b9af0a1 100644 --- a/tfjs-layers/src/BUILD.bazel +++ b/tfjs-layers/src/BUILD.bazel @@ -78,9 +78,7 @@ ts_library( ts_library( name = "tfjs-layers_test_lib", - # disable testonly for the issue in the nodejs build target. - # https://github.com/bazelbuild/rules_nodejs/pull/2984 - #testonly = True, + testonly = True, srcs = glob(TEST_SRCS) + [":tests"], module_name = "@tensorflow/tfjs-layers/dist", deps = [ diff --git a/tfjs-layers/src/exports_layers.ts b/tfjs-layers/src/exports_layers.ts index 0904e02d01c..0e7b9da6ca4 100755 --- a/tfjs-layers/src/exports_layers.ts +++ b/tfjs-layers/src/exports_layers.ts @@ -24,6 +24,7 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding'; import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling'; import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent'; import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers'; +import { Rescaling, RescalingArgs } from './layers/preprocessing/image_preprocessing'; // TODO(cais): Add doc string to all the public static functions in this // class; include exectuable JavaScript code snippets where applicable @@ -1697,3 +1698,34 @@ export function alphaDropout(args: AlphaDropoutArgs) { export function masking(args?: MaskingArgs) { return new Masking(args); } + +/** + * A preprocessing layer which rescales input values to a new range. + * + * This layer rescales every value of an input (often an image) by multiplying + * by `scale` and adding `offset`. + * + * For instance: + * 1. To rescale an input in the ``[0, 255]`` range + * to be in the `[0, 1]` range, you would pass `scale=1/255`. + * 2. To rescale an input in the ``[0, 255]`` range to be in the `[-1, 1]` + * range, you would pass `scale=1./127.5, offset=-1`. + * The rescaling is applied both during training and inference. Inputs can be + * of integer or floating point dtype, and by default the layer will output + * floats. + * + * Arguments: + * - `scale`: Float, the scale to apply to the inputs. + * - `offset`: Float, the offset to apply to the inputs. + * + * Input shape: + * Arbitrary. + * + * Output shape: + * Same as input. + * + * @doc {heading: 'Layers', subheading: 'Rescaling', namespace: 'layers'} + */ +export function rescaling(args?: RescalingArgs) { + return new Rescaling(args); +} diff --git a/tfjs-layers/src/layers/preprocessing/image_preprocessing.ts b/tfjs-layers/src/layers/preprocessing/image_preprocessing.ts new file mode 100644 index 00000000000..184c244c540 --- /dev/null +++ b/tfjs-layers/src/layers/preprocessing/image_preprocessing.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2022 CodeSmith LLC + * + * Use of this source code is governed by an MIT-style + * license that can be found in the LICENSE file or at + * https://opensource.org/licenses/MIT. + * ============================================================================= + */ + +import {LayerArgs, Layer} from '../../engine/topology'; +import { serialization, Tensor, mul, add, tidy } from '@tensorflow/tfjs-core'; +import { getExactlyOneTensor } from '../../utils/types_utils'; +import * as K from '../../backend/tfjs_backend'; +import { Kwargs } from '../../types'; + +export declare interface RescalingArgs extends LayerArgs { + scale: number; + offset?: number; +} + +/** + * Preprocessing Rescaling Layer + * + * This rescales images by a scaling and offset factor + */ +export class Rescaling extends Layer { + /** @nocollapse */ + static className = 'Rescaling'; + private readonly scale: number; + private readonly offset: number; + constructor(args: RescalingArgs) { + super(args); + + this.scale = args.scale; + + if(args.offset) { + this.offset = args.offset; + } else { + this.offset = 0; + } + } + + getConfig(): serialization.ConfigDict { + const config: serialization.ConfigDict = { + 'scale': this.scale, + 'offset': this.offset + }; + + const baseConfig = super.getConfig(); + Object.assign(config, baseConfig); + return config; + } + + call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor[]|Tensor { + return tidy(() => { + inputs = getExactlyOneTensor(inputs); + if(inputs.dtype !== 'float32') { + inputs = K.cast(inputs, 'float32'); + } + return add(mul(inputs, this.scale), this.offset); + }); + } +} + +serialization.registerClass(Rescaling); diff --git a/tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts b/tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts new file mode 100644 index 00000000000..d40e2fec879 --- /dev/null +++ b/tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts @@ -0,0 +1,46 @@ +import { Tensor, randomNormal, mul, add} from '@tensorflow/tfjs-core'; +import { Rescaling } from './image_preprocessing'; +import { describeMathCPUAndGPU, expectTensorsClose } from '../../utils/test_utils'; + +describeMathCPUAndGPU('Rescaling Layer', () => { + + it('Check if input shape matches output shape', () => { + const scale = 1.0 / 127.5; + const offset = 0; + const input = randomNormal([2, 4, 5, 3]); + const expectedOutputTensor = add(mul(input, scale), offset); + const scalingLayer = new Rescaling({scale, offset}); + const layerOutputTensor = scalingLayer.apply(input) as Tensor; + expect(expectedOutputTensor.shape).toEqual(layerOutputTensor.shape); + }); + + it('Rescales input layer based on scaling factor and offset', () => { + const scale = 1.0 / 127.5; + const offset = -1.0; + const input = randomNormal([2, 4, 5, 3]); + const expectedOutputTensor = add(mul(input, scale), offset); + const scalingLayer = new Rescaling({scale, offset}); + const layerOutputTensor = scalingLayer.apply(input) as Tensor; + expectTensorsClose(layerOutputTensor, expectedOutputTensor); + }); + + it('Recasts dtype to float32', () => { + const scale = 1.0 / 127.5; + const offset = -1.0; + const intTensor = randomNormal([2, 4, 5, 3], 7, 2, 'int32'); + const expectedOutputTensor = add(mul(intTensor, scale), offset); + const scalingLayer = new Rescaling({scale, offset}); + const outputTensor = scalingLayer.apply(intTensor) as Tensor; + expect(outputTensor.dtype).toBe('float32'); + expectTensorsClose(outputTensor, expectedOutputTensor); + }); + + it('Config holds correct name', () => { + const scale = 1.0 / 127.5; + const offset = -1.0; + const scalingLayer = new Rescaling({scale, offset, name: 'Rescaling'}); + const config = scalingLayer.getConfig(); + expect(config.name).toEqual('Rescaling'); + }); + +}); diff --git a/tfjs-layers/src/setup_tests.ts b/tfjs-layers/src/setup_tests.ts index 1b57c6f82b1..b9ef137e2d4 100644 --- a/tfjs-layers/src/setup_tests.ts +++ b/tfjs-layers/src/setup_tests.ts @@ -28,7 +28,8 @@ import '@tensorflow/tfjs-backend-webgl'; // tslint:disable-next-line: no-imports-from-dist import {parseTestEnvFromKarmaFlags, registerTestEnv, setTestEnvs, TEST_ENVS} from '@tensorflow/tfjs-core/dist/jasmine_util'; -registerTestEnv({ +// Register test environments. +const webgl1TestEnv = { name: 'webgl1', backendName: 'webgl', flags: { @@ -37,6 +38,17 @@ registerTestEnv({ 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 }, isDataSync: true +}; +registerTestEnv(webgl1TestEnv); +registerTestEnv({name: 'cpu', backendName: 'cpu'}); +registerTestEnv({ + name: 'webgl2', + backendName: 'webgl', + flags: { + 'WEBGL_VERSION': 2, + 'WEBGL_CPU_FORWARD': false, + 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 + } }); // Allow flags to override test envs @@ -44,8 +56,13 @@ registerTestEnv({ declare let __karma__: any; if (typeof __karma__ !== 'undefined') { const testEnv = parseTestEnvFromKarmaFlags(__karma__.config.args, TEST_ENVS); + if (testEnv != null) { setTestEnvs([testEnv]); + } else { + // Exclude webgl1 unless it is specifically requested because it causes + // test flakiness when switching between webgl1 and webgl2. + setTestEnvs(TEST_ENVS.filter(env => env !== webgl1TestEnv)); } } diff --git a/tfjs-layers/src/utils/test_utils.ts b/tfjs-layers/src/utils/test_utils.ts index 98289759b12..c1413d0d6fb 100644 --- a/tfjs-layers/src/utils/test_utils.ts +++ b/tfjs-layers/src/utils/test_utils.ts @@ -14,22 +14,10 @@ import {memory, Tensor, test_util, util} from '@tensorflow/tfjs-core'; // tslint:disable-next-line: no-imports-from-dist -import {ALL_ENVS, describeWithFlags, registerTestEnv} from '@tensorflow/tfjs-core/dist/jasmine_util'; +import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; import {ValueError} from '../errors'; -// Register backends. -registerTestEnv({name: 'cpu', backendName: 'cpu'}); -registerTestEnv({ - name: 'webgl2', - backendName: 'webgl', - flags: { - 'WEBGL_VERSION': 2, - 'WEBGL_CPU_FORWARD': false, - 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 - } -}); - /** * Expect values are close between a Tensor or number array. * @param actual diff --git a/tfjs-tflite/BUILD.bazel b/tfjs-tflite/BUILD.bazel index b20a3264c63..a30686ec811 100644 --- a/tfjs-tflite/BUILD.bazel +++ b/tfjs-tflite/BUILD.bazel @@ -103,5 +103,6 @@ test_suite( name = "tests", tests = [ ":tfjs-tflite_test", + "//tfjs-tflite/src:worker_test", ], ) diff --git a/tfjs-tflite/src/BUILD.bazel b/tfjs-tflite/src/BUILD.bazel index 07176ac7fd4..2171aeda521 100644 --- a/tfjs-tflite/src/BUILD.bazel +++ b/tfjs-tflite/src/BUILD.bazel @@ -15,6 +15,7 @@ load("@build_bazel_rules_nodejs//:index.bzl", "copy_to_bin") load("//tools:defaults.bzl", "esbuild", "ts_library") +load("//tools:tfjs_web_test.bzl", "tfjs_web_test") package(default_visibility = ["//visibility:public"]) @@ -55,7 +56,10 @@ copy_to_bin( ts_library( name = "tfjs-tflite_test_lib", - srcs = glob(TEST_SRCS), + srcs = glob( + TEST_SRCS, + exclude = ["worker_test.ts"], + ), module_name = "@tensorflow/tfjs-tflite/dist", deps = [ ":tfjs-tflite_lib", @@ -83,3 +87,46 @@ esbuild( "//tfjs-tflite/wasm:wasm_files", ], ) + +ts_library( + name = "worker_test_lib", + srcs = [ + "worker_test.ts", + ], + deps = [ + "//tfjs-backend-cpu/src:tfjs-backend-cpu_lib", + "//tfjs-core/src:tfjs-core_lib", + "//tfjs-core/src:tfjs-core_src_lib", + ], +) + +tfjs_web_test( + name = "worker_test", + browsers = [ + "bs_chrome_mac", + "bs_firefox_mac", + "bs_safari_mac", + # Temporarily disabled because BrowserStack does not support loading + # absolute paths in iOS, which is required for loading the worker. + # https://www.browserstack.com/question/39573 + # "bs_ios_12", + "bs_android_9", + "win_10_chrome", + ], + static_files = [ + # For the webworker + "//tfjs-core:tf-core.min.js", + "//tfjs-core:tf-core.min.js.map", + "//tfjs-backend-cpu:tf-backend-cpu.min.js", + "//tfjs-backend-cpu:tf-backend-cpu.min.js.map", + "//tfjs-tflite:tf-tflite.min.js", + "//tfjs-tflite:tf-tflite.min.js.map", + "//tfjs-tflite/wasm:wasm_files", + "//tfjs-tflite/test_files:add4.tflite", + ], + deps = [ + ":worker_test_lib", + "@npm//long:long__umd", + "@npm//seedrandom:seedrandom__umd", + ], +) diff --git a/tfjs-tflite/src/worker_test.ts b/tfjs-tflite/src/worker_test.ts new file mode 100644 index 00000000000..4865470c2df --- /dev/null +++ b/tfjs-tflite/src/worker_test.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import '@tensorflow/tfjs-backend-cpu'; + +const str2workerURL = (str: string): string => { + const blob = + new Blob([str], {type: 'application/javascript'}); + return URL.createObjectURL(blob); +}; + +// The source code of a web worker. +const workerTest = ` +importScripts(location.origin + '/base/tfjs/tfjs-core/tf-core.min.js'); +importScripts(location.origin + + '/base/tfjs/tfjs-backend-cpu/tf-backend-cpu.min.js'); +// Import order matters. TFLite must be imported after tfjs core. +importScripts(location.origin + '/base/tfjs/tfjs-tflite/tf-tflite.min.js'); + +// Setting wasm path is required. It can be set to CDN if needed, +// but that's not a good idea for a test. +tflite.setWasmPath('/base/tfjs/tfjs-tflite/wasm/'); +async function main() { + // This is a test model that adds two tensors of shape [1, 4]. + const model = await tflite.loadTFLiteModel(location.origin + '/base/tfjs/tfjs-tflite/test_files/add4.tflite'); + + const a = tf.tensor2d([[1, 2, 3, 4]]); + const b = tf.tensor2d([[5, 6, 7, 8]]); + const output = model.predict([a, b]); + + self.postMessage({data: output.dataSync()}); +} + +main(); +`; + +describe('tflite in worker', () => { + it('runs a model', (done) => { + const worker = new Worker(str2workerURL(workerTest)); + worker.onmessage = (msg) => { + const data = msg.data.data; + expect([...data]).toEqual([6, 8, 10, 12]); + done(); + }; + }, 15_000); +}); diff --git a/tfjs-tflite/test_files/BUILD.bazel b/tfjs-tflite/test_files/BUILD.bazel new file mode 100644 index 00000000000..ba0b9d6c0eb --- /dev/null +++ b/tfjs-tflite/test_files/BUILD.bazel @@ -0,0 +1,21 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +package(default_visibility = ["//visibility:public"]) + +exports_files([ + # add4.tflite adds two tensors of shape [1,4] + "add4.tflite", +]) diff --git a/tfjs-tflite/test_files/add4.tflite b/tfjs-tflite/test_files/add4.tflite new file mode 100644 index 00000000000..6ee2860e16b Binary files /dev/null and b/tfjs-tflite/test_files/add4.tflite differ