From fe5e7dbbe918369772677b9e355f649bb1616e8d Mon Sep 17 00:00:00 2001 From: xhcao Date: Wed, 21 Sep 2022 17:25:54 +0800 Subject: [PATCH 1/5] webgpu: add conv2d naive version for debugging (#6837) --- .../src/conv2d_naive_webgpu.ts | 121 ++++++++++++++++++ tfjs-backend-webgpu/src/flags_webgpu.ts | 5 + .../src/kernels/Conv2D_impl.ts | 50 +++++--- 3 files changed, 157 insertions(+), 19 deletions(-) create mode 100644 tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts diff --git a/tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts b/tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts new file mode 100644 index 00000000000..834e1c5ead6 --- /dev/null +++ b/tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util} from '@tensorflow/tfjs-core'; + +import {activationFnSnippet, biasActivationSnippet} from './activation_util'; +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch} from './webgpu_util'; + +export class Conv2DNaiveProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + dispatchLayout: {x: number[], y: number[], z: number[]}; + dispatch: [number, number, number]; + variableNames = ['x', 'W']; + uniforms = + 'filterDims: vec2, pad: vec2, stride: vec2, dilation: vec2,'; + workGroupSize: [number, number, number] = [4, 4, 8]; + addBias: boolean; + activation: backend_util.Activation; + hasPreluActivationWeights: boolean; + isChannelsLast: boolean; + + constructor( + convInfo: backend_util.Conv2DInfo, addBias = false, + activation: backend_util.Activation = null, + hasPreluActivationWeights = false) { + this.outputShape = convInfo.outShape; + this.isChannelsLast = convInfo.dataFormat === 'channelsLast'; + this.dispatchLayout = this.isChannelsLast ? {x: [2], y: [1], z: [0, 3]} : + {x: [3], y: [2], z: [0, 1]}; + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workGroupSize); + this.addBias = addBias; + this.activation = activation; + this.hasPreluActivationWeights = hasPreluActivationWeights; + + if (addBias) { + this.variableNames.push('bias'); + } + + if (hasPreluActivationWeights) { + this.variableNames.push('preluActivationWeights'); + } + + this.shaderKey = `conv2dnaive_${this.activation}_${this.isChannelsLast}`; + } + + getUserCode(): string { + const userCode = ` + ${ + activationFnSnippet( + this.activation, this.hasPreluActivationWeights, false, 4)} + fn readInp(batch : i32, row : i32, col : i32, chan : i32) -> f32{ + let coords = vec4(batch, row, col, chan); + if (coordsInBounds4D(coords, uniforms.xShape)) { + return getX(batch, row, col, chan); + } else { + return 0.0; + } + } + fn readFilt(row : i32, col : i32, xChannel : i32, outChannel : i32) -> f32{ + let coords = vec4(row, col, xChannel, outChannel); + if(coordsInBounds4D(coords, uniforms.wShape)) { + return getW(row, col, xChannel, outChannel); + } else { + return 0.0; + } + } + fn writeResult(batch : i32, row : i32, col : i32, chan : i32, valueIn : f32) { + let coords = ${ + this.isChannelsLast ? `vec4(batch, row, col, chan);` : + `vec4(batch, chan, row, col);`} + if (coordsInBounds4D(coords, uniforms.outShape)) { + var value = valueIn; + ${biasActivationSnippet(this.addBias, this.activation)} + setOutputAtCoords(coords.x, coords.y, coords.z, coords.w, value); + } + } + ${main('index')} { + let coords = getOutputCoords(); + let batch = coords[0]; + let outChannel = ${this.isChannelsLast ? `coords[3];` : `coords[1];`} + let outRow = ${this.isChannelsLast ? `coords[1];` : `coords[2];`} + let outCol = ${this.isChannelsLast ? `coords[2];` : `coords[3];`} + var acc : f32 = 0.0; + for (var row = 0; row < uniforms.filterDims[0]; row = row + 1) { + for (var col = 0; col < uniforms.filterDims[1]; col = col + 1) { + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * row - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * col - uniforms.pad[1]; + for (var xChannel = 0; xChannel < ${ + this.isChannelsLast ? `uniforms.xShape[3];` : + `uniforms.xShape[1];`} xChannel = xChannel + 1) { + ${ + this.isChannelsLast ? `let v = readInp(batch, xRow, xCol, xChannel);` : + `let v = readInp(batch, xChannel, xRow, xCol);`} + let f = readFilt(row, col, xChannel, outChannel); + acc = acc + v * f; + } + } + } + writeResult(batch, outRow, outCol, outChannel, acc); + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts index 2df5d8cd5d8..a5313d44c88 100644 --- a/tfjs-backend-webgpu/src/flags_webgpu.ts +++ b/tfjs-backend-webgpu/src/flags_webgpu.ts @@ -65,3 +65,8 @@ ENV.registerFlag('WEBGPU_USE_PROFILE_TOOL', () => false); * Whether to use import API. */ ENV.registerFlag('WEBGPU_IMPORT_EXTERNAL_TEXTURE', () => true); + +/** + * Whether to use conv2dNaive for debugging. + */ +ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_DEBUG', () => false); diff --git a/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts b/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts index e46cc8558c4..cb6552abff2 100644 --- a/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts +++ b/tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts @@ -15,10 +15,12 @@ * ============================================================================= */ -import {backend_util, TensorInfo} from '@tensorflow/tfjs-core'; +import {backend_util, env, TensorInfo} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; import {Conv2DMMProgram} from '../conv2d_mm_webgpu'; +import {Conv2DNaiveProgram} from '../conv2d_naive_webgpu'; +import {WebGPUProgram} from '../webgpu_program'; import {batchMatMulImpl} from './BatchMatMul_impl'; import {reshape} from './Reshape'; @@ -184,12 +186,15 @@ export function conv2DImpl({ convInfo.filterHeight === convInfo.inHeight && convInfo.filterWidth === convInfo.inWidth && convInfo.padInfo.type === 'VALID'; - if (sameSize || - (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && - convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && - convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && - (convInfo.padInfo.type === 'SAME' || - convInfo.padInfo.type === 'VALID'))) { + const useNaiveConv2d = env().getBool('WEBGPU_USE_NAIVE_CONV2D_DEBUG'); + + if (!useNaiveConv2d && + (sameSize || + (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && + convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && + convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && + (convInfo.padInfo.type === 'SAME' || + convInfo.padInfo.type === 'VALID')))) { return conv2dByMatMul({ x, filter, @@ -202,25 +207,32 @@ export function conv2DImpl({ }); } - const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth : - convInfo.outChannels; - const dimBOuter = isChannelsLast ? convInfo.outChannels : - convInfo.outHeight * convInfo.outWidth; - const dimInner = - convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels; + let program: WebGPUProgram; const padInfo = [convInfo.padInfo.top, convInfo.padInfo.left]; const dimensions = [ {type: 'int32', data: [convInfo.filterHeight, convInfo.filterWidth]}, {type: 'int32', data: [...padInfo]}, {type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}, - {type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]}, - {type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]}, - {type: 'int32', data: [dimInner]} + {type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]} ]; + if (useNaiveConv2d) { + program = new Conv2DNaiveProgram( + convInfo, hasBias, activation, hasPreluActivationWeights); + } else { + const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth : + convInfo.outChannels; + const dimBOuter = isChannelsLast ? convInfo.outChannels : + convInfo.outHeight * convInfo.outWidth; + const dimInner = + convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels; + dimensions.push( + {type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]}, + {type: 'int32', data: [dimInner]}); - const program = new Conv2DMMProgram( - convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation, - hasPreluActivationWeights); + program = new Conv2DMMProgram( + convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation, + hasPreluActivationWeights); + } const intermediates: TensorInfo[] = []; const inputVar: TensorInfo[] = [x, filter]; From cff0c5c37aac52ffa166fbe99aa2b6fbf7ddee97 Mon Sep 17 00:00:00 2001 From: Linchenn <40653845+Linchenn@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:28:38 -0700 Subject: [PATCH 2/5] Fix gatherND for 5D inputs (#6832) WebGL backend is supposed to support 5D or 6D inputs for gatherND op. However, WebGL backend throws an error for such inputs because of the following codes in gatherND. Take an input tensor with (1,48,48,17,2) shape for example: ivec5 paramsShape = ivec5(1,48,48,17,2); ... paramsShape[j]; This is problematic because ivec5 is a strcuture defined by us, which could be accessed through [] and could be accessed only through .x/.y/.z/.w/.u. This PR fixes this problem by applying paramsShape's value directly into the code, instead of using an intermediate ivec5. --- tfjs-backend-webgl/src/gather_nd_gpu.ts | 26 +++++++++++++----------- tfjs-backend-webgl/src/webgl_ops_test.ts | 19 +++++++++++++++++ 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/tfjs-backend-webgl/src/gather_nd_gpu.ts b/tfjs-backend-webgl/src/gather_nd_gpu.ts index 52cba42bf6f..507f9f743b6 100644 --- a/tfjs-backend-webgl/src/gather_nd_gpu.ts +++ b/tfjs-backend-webgl/src/gather_nd_gpu.ts @@ -25,24 +25,26 @@ export class GatherNDProgram implements GPGPUProgram { private sliceDim: number, private strides: number[], shape: number[], private paramsShape: number[]) { this.outputShape = shape; - const stridesType = getCoordsDataType(strides.length); const dtype = getCoordsDataType(shape.length); - const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides'; - const paramsShapeType = getCoordsDataType(paramsShape.length); - const paramsShapeString = paramsShape.length > 1 ? 'paramsShape[j]' : 'paramsShape'; + + let mainLoop = ` + int index;`; + for (let j = 0; j < this.sliceDim; j++) { + mainLoop += ` + index = round(getIndices(coords[0], ${j})); + out_of_bounds = out_of_bounds || index < 0; + out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]}; + flattenIndex += index * ${this.strides[j]};`; + } + this.userCode = ` - ${stridesType} strides = ${stridesType}(${this.strides}); - ${paramsShapeType} paramsShape = ${paramsShapeType}(${this.paramsShape}); void main() { ${dtype} coords = getOutputCoords(); int flattenIndex = 0; bool out_of_bounds = false; - for (int j = 0; j < ${this.sliceDim}; j++) { - int index = round(getIndices(coords[0], j)); - out_of_bounds = out_of_bounds || index < 0; - out_of_bounds = out_of_bounds || index >= ${paramsShapeString}; - flattenIndex += index * ${strideString}; - } + + ${mainLoop} + setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1])); } `; diff --git a/tfjs-backend-webgl/src/webgl_ops_test.ts b/tfjs-backend-webgl/src/webgl_ops_test.ts index 531314679aa..22d99ea7364 100644 --- a/tfjs-backend-webgl/src/webgl_ops_test.ts +++ b/tfjs-backend-webgl/src/webgl_ops_test.ts @@ -918,6 +918,15 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => { expectArraysEqual(await g.data(), [0, 1, 2, 0]); }); + it('works for out of bounds indices 1d', async () => { + const x = tf.tensor1d([...Array(4).keys()].map(e => e + 1), 'int32'); + const indices = [0, 1, 2, 5]; + const ind = tf.tensor2d(indices, [4, 1], 'int32'); + const g = tf.gatherND(x, ind); + const expected = [1, 2, 3, 0]; + expectArraysEqual(await g.data(), expected); + }); + it('works for out of bounds indices 2d', async () => { const x = tf.tensor2d([...Array(4).keys()].map(e => e + 1), [2, 2], 'int32'); @@ -957,4 +966,14 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => { ]; expectArraysEqual(await g.data(), expected); }); + + it('works for out of bounds indices 5d', async () => { + const x = tf.tensor5d( + [...Array(32).keys()].map(e => e + 1), [2, 2, 2, 2, 2], 'int32'); + const indices = [0, 0, 0, 0, 0, 2, 1, 1, 1, 1]; + const ind = tf.tensor2d(indices, [2, 5], 'int32'); + const g = tf.gatherND(x, ind); + const expected = [1, 0]; + expectArraysEqual(await g.data(), expected); + }); }); From eaf8f05b2462d203c669ed675251f4d179390d1d Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 22 Sep 2022 10:08:56 -0700 Subject: [PATCH 3/5] Test running tfjs-tflite in a webworker (#6844) Add a new test //tfjs-tflite/src/worker_test to test running tflite in a webworker. The test runs a simple model that just adds the two inputs together. --- tfjs-tflite/BUILD.bazel | 1 + tfjs-tflite/src/BUILD.bazel | 49 ++++++++++++++++++++++- tfjs-tflite/src/worker_test.ts | 60 +++++++++++++++++++++++++++++ tfjs-tflite/test_files/BUILD.bazel | 21 ++++++++++ tfjs-tflite/test_files/add4.tflite | Bin 0 -> 952 bytes 5 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 tfjs-tflite/src/worker_test.ts create mode 100644 tfjs-tflite/test_files/BUILD.bazel create mode 100644 tfjs-tflite/test_files/add4.tflite diff --git a/tfjs-tflite/BUILD.bazel b/tfjs-tflite/BUILD.bazel index b20a3264c63..a30686ec811 100644 --- a/tfjs-tflite/BUILD.bazel +++ b/tfjs-tflite/BUILD.bazel @@ -103,5 +103,6 @@ test_suite( name = "tests", tests = [ ":tfjs-tflite_test", + "//tfjs-tflite/src:worker_test", ], ) diff --git a/tfjs-tflite/src/BUILD.bazel b/tfjs-tflite/src/BUILD.bazel index 07176ac7fd4..2171aeda521 100644 --- a/tfjs-tflite/src/BUILD.bazel +++ b/tfjs-tflite/src/BUILD.bazel @@ -15,6 +15,7 @@ load("@build_bazel_rules_nodejs//:index.bzl", "copy_to_bin") load("//tools:defaults.bzl", "esbuild", "ts_library") +load("//tools:tfjs_web_test.bzl", "tfjs_web_test") package(default_visibility = ["//visibility:public"]) @@ -55,7 +56,10 @@ copy_to_bin( ts_library( name = "tfjs-tflite_test_lib", - srcs = glob(TEST_SRCS), + srcs = glob( + TEST_SRCS, + exclude = ["worker_test.ts"], + ), module_name = "@tensorflow/tfjs-tflite/dist", deps = [ ":tfjs-tflite_lib", @@ -83,3 +87,46 @@ esbuild( "//tfjs-tflite/wasm:wasm_files", ], ) + +ts_library( + name = "worker_test_lib", + srcs = [ + "worker_test.ts", + ], + deps = [ + "//tfjs-backend-cpu/src:tfjs-backend-cpu_lib", + "//tfjs-core/src:tfjs-core_lib", + "//tfjs-core/src:tfjs-core_src_lib", + ], +) + +tfjs_web_test( + name = "worker_test", + browsers = [ + "bs_chrome_mac", + "bs_firefox_mac", + "bs_safari_mac", + # Temporarily disabled because BrowserStack does not support loading + # absolute paths in iOS, which is required for loading the worker. + # https://www.browserstack.com/question/39573 + # "bs_ios_12", + "bs_android_9", + "win_10_chrome", + ], + static_files = [ + # For the webworker + "//tfjs-core:tf-core.min.js", + "//tfjs-core:tf-core.min.js.map", + "//tfjs-backend-cpu:tf-backend-cpu.min.js", + "//tfjs-backend-cpu:tf-backend-cpu.min.js.map", + "//tfjs-tflite:tf-tflite.min.js", + "//tfjs-tflite:tf-tflite.min.js.map", + "//tfjs-tflite/wasm:wasm_files", + "//tfjs-tflite/test_files:add4.tflite", + ], + deps = [ + ":worker_test_lib", + "@npm//long:long__umd", + "@npm//seedrandom:seedrandom__umd", + ], +) diff --git a/tfjs-tflite/src/worker_test.ts b/tfjs-tflite/src/worker_test.ts new file mode 100644 index 00000000000..4865470c2df --- /dev/null +++ b/tfjs-tflite/src/worker_test.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import '@tensorflow/tfjs-backend-cpu'; + +const str2workerURL = (str: string): string => { + const blob = + new Blob([str], {type: 'application/javascript'}); + return URL.createObjectURL(blob); +}; + +// The source code of a web worker. +const workerTest = ` +importScripts(location.origin + '/base/tfjs/tfjs-core/tf-core.min.js'); +importScripts(location.origin + + '/base/tfjs/tfjs-backend-cpu/tf-backend-cpu.min.js'); +// Import order matters. TFLite must be imported after tfjs core. +importScripts(location.origin + '/base/tfjs/tfjs-tflite/tf-tflite.min.js'); + +// Setting wasm path is required. It can be set to CDN if needed, +// but that's not a good idea for a test. +tflite.setWasmPath('/base/tfjs/tfjs-tflite/wasm/'); +async function main() { + // This is a test model that adds two tensors of shape [1, 4]. + const model = await tflite.loadTFLiteModel(location.origin + '/base/tfjs/tfjs-tflite/test_files/add4.tflite'); + + const a = tf.tensor2d([[1, 2, 3, 4]]); + const b = tf.tensor2d([[5, 6, 7, 8]]); + const output = model.predict([a, b]); + + self.postMessage({data: output.dataSync()}); +} + +main(); +`; + +describe('tflite in worker', () => { + it('runs a model', (done) => { + const worker = new Worker(str2workerURL(workerTest)); + worker.onmessage = (msg) => { + const data = msg.data.data; + expect([...data]).toEqual([6, 8, 10, 12]); + done(); + }; + }, 15_000); +}); diff --git a/tfjs-tflite/test_files/BUILD.bazel b/tfjs-tflite/test_files/BUILD.bazel new file mode 100644 index 00000000000..ba0b9d6c0eb --- /dev/null +++ b/tfjs-tflite/test_files/BUILD.bazel @@ -0,0 +1,21 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +package(default_visibility = ["//visibility:public"]) + +exports_files([ + # add4.tflite adds two tensors of shape [1,4] + "add4.tflite", +]) diff --git a/tfjs-tflite/test_files/add4.tflite b/tfjs-tflite/test_files/add4.tflite new file mode 100644 index 0000000000000000000000000000000000000000..6ee2860e16b4c4a7afe80792817b34e821e35960 GIT binary patch literal 952 zcmaJ=%SyvQ6us4|#g_U+Ds+*BO9@R?5O<|22v(&AaUsM&q6XS1F?H=HxO62hL|pg@ zet-)^RT@(z3e z;kXDss=%g9D=XSeY&XmXYAmTb z3kfgfc0Jz;dVc72UF*&bZoO+?Cc$AJo`FZ;0k{EZ^OWRbVjuUe-1p1=tN;QJ><|6W zCYEg9SZCS~^6I)#(`#gAZRyj{H{v56Ln=Q}R0T9(1z_J>`+IGzdF`Wy&~C1Xo2r=*zJG9O5hwR8W1_-QSyhr<6+I?mA-a0)Qq z33`bom*&GcB4-t#FC43TFJPU+|B)SpUWglVou=LCY_VPBF&6#3j_}9xb>OeRQ}C9d t8< Date: Thu, 22 Sep 2022 14:57:12 -0700 Subject: [PATCH 4/5] Rescaling Preprocessing Layer (#6840) * Rescaling Preprocessing Layer Co-authored-by: David Kim (@koyykdy) Brian Zheng (@brianzheng123) * PR issues resolved * linting and PR issues resolved Co-authored-by: Adam Lang (@AdamLang96) Co-authored-by: (@Brianzheng123) Co-authored-by: David Don Young Kim Co-authored-by: David Don Young Kim <36175976+koyykdy@users.noreply.github.com> --- tfjs-layers/src/exports_layers.ts | 32 +++++++++ .../preprocessing/image_preprocessing.ts | 66 +++++++++++++++++++ .../preprocessing/image_preprocessing_test.ts | 46 +++++++++++++ 3 files changed, 144 insertions(+) create mode 100644 tfjs-layers/src/layers/preprocessing/image_preprocessing.ts create mode 100644 tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts diff --git a/tfjs-layers/src/exports_layers.ts b/tfjs-layers/src/exports_layers.ts index 0904e02d01c..0e7b9da6ca4 100755 --- a/tfjs-layers/src/exports_layers.ts +++ b/tfjs-layers/src/exports_layers.ts @@ -24,6 +24,7 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding'; import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling'; import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent'; import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers'; +import { Rescaling, RescalingArgs } from './layers/preprocessing/image_preprocessing'; // TODO(cais): Add doc string to all the public static functions in this // class; include exectuable JavaScript code snippets where applicable @@ -1697,3 +1698,34 @@ export function alphaDropout(args: AlphaDropoutArgs) { export function masking(args?: MaskingArgs) { return new Masking(args); } + +/** + * A preprocessing layer which rescales input values to a new range. + * + * This layer rescales every value of an input (often an image) by multiplying + * by `scale` and adding `offset`. + * + * For instance: + * 1. To rescale an input in the ``[0, 255]`` range + * to be in the `[0, 1]` range, you would pass `scale=1/255`. + * 2. To rescale an input in the ``[0, 255]`` range to be in the `[-1, 1]` + * range, you would pass `scale=1./127.5, offset=-1`. + * The rescaling is applied both during training and inference. Inputs can be + * of integer or floating point dtype, and by default the layer will output + * floats. + * + * Arguments: + * - `scale`: Float, the scale to apply to the inputs. + * - `offset`: Float, the offset to apply to the inputs. + * + * Input shape: + * Arbitrary. + * + * Output shape: + * Same as input. + * + * @doc {heading: 'Layers', subheading: 'Rescaling', namespace: 'layers'} + */ +export function rescaling(args?: RescalingArgs) { + return new Rescaling(args); +} diff --git a/tfjs-layers/src/layers/preprocessing/image_preprocessing.ts b/tfjs-layers/src/layers/preprocessing/image_preprocessing.ts new file mode 100644 index 00000000000..184c244c540 --- /dev/null +++ b/tfjs-layers/src/layers/preprocessing/image_preprocessing.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2022 CodeSmith LLC + * + * Use of this source code is governed by an MIT-style + * license that can be found in the LICENSE file or at + * https://opensource.org/licenses/MIT. + * ============================================================================= + */ + +import {LayerArgs, Layer} from '../../engine/topology'; +import { serialization, Tensor, mul, add, tidy } from '@tensorflow/tfjs-core'; +import { getExactlyOneTensor } from '../../utils/types_utils'; +import * as K from '../../backend/tfjs_backend'; +import { Kwargs } from '../../types'; + +export declare interface RescalingArgs extends LayerArgs { + scale: number; + offset?: number; +} + +/** + * Preprocessing Rescaling Layer + * + * This rescales images by a scaling and offset factor + */ +export class Rescaling extends Layer { + /** @nocollapse */ + static className = 'Rescaling'; + private readonly scale: number; + private readonly offset: number; + constructor(args: RescalingArgs) { + super(args); + + this.scale = args.scale; + + if(args.offset) { + this.offset = args.offset; + } else { + this.offset = 0; + } + } + + getConfig(): serialization.ConfigDict { + const config: serialization.ConfigDict = { + 'scale': this.scale, + 'offset': this.offset + }; + + const baseConfig = super.getConfig(); + Object.assign(config, baseConfig); + return config; + } + + call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor[]|Tensor { + return tidy(() => { + inputs = getExactlyOneTensor(inputs); + if(inputs.dtype !== 'float32') { + inputs = K.cast(inputs, 'float32'); + } + return add(mul(inputs, this.scale), this.offset); + }); + } +} + +serialization.registerClass(Rescaling); diff --git a/tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts b/tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts new file mode 100644 index 00000000000..d40e2fec879 --- /dev/null +++ b/tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts @@ -0,0 +1,46 @@ +import { Tensor, randomNormal, mul, add} from '@tensorflow/tfjs-core'; +import { Rescaling } from './image_preprocessing'; +import { describeMathCPUAndGPU, expectTensorsClose } from '../../utils/test_utils'; + +describeMathCPUAndGPU('Rescaling Layer', () => { + + it('Check if input shape matches output shape', () => { + const scale = 1.0 / 127.5; + const offset = 0; + const input = randomNormal([2, 4, 5, 3]); + const expectedOutputTensor = add(mul(input, scale), offset); + const scalingLayer = new Rescaling({scale, offset}); + const layerOutputTensor = scalingLayer.apply(input) as Tensor; + expect(expectedOutputTensor.shape).toEqual(layerOutputTensor.shape); + }); + + it('Rescales input layer based on scaling factor and offset', () => { + const scale = 1.0 / 127.5; + const offset = -1.0; + const input = randomNormal([2, 4, 5, 3]); + const expectedOutputTensor = add(mul(input, scale), offset); + const scalingLayer = new Rescaling({scale, offset}); + const layerOutputTensor = scalingLayer.apply(input) as Tensor; + expectTensorsClose(layerOutputTensor, expectedOutputTensor); + }); + + it('Recasts dtype to float32', () => { + const scale = 1.0 / 127.5; + const offset = -1.0; + const intTensor = randomNormal([2, 4, 5, 3], 7, 2, 'int32'); + const expectedOutputTensor = add(mul(intTensor, scale), offset); + const scalingLayer = new Rescaling({scale, offset}); + const outputTensor = scalingLayer.apply(intTensor) as Tensor; + expect(outputTensor.dtype).toBe('float32'); + expectTensorsClose(outputTensor, expectedOutputTensor); + }); + + it('Config holds correct name', () => { + const scale = 1.0 / 127.5; + const offset = -1.0; + const scalingLayer = new Rescaling({scale, offset, name: 'Rescaling'}); + const config = scalingLayer.getConfig(); + expect(config.name).toEqual('Rescaling'); + }); + +}); From 278eaca0c2a56ba308bf8004de2feaab1bee10e1 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Fri, 23 Sep 2022 10:46:20 -0700 Subject: [PATCH 5/5] Avoid testing webgl1 in webgl2 layers test (#6849) Prevent the webgl1 test env from being used in tfjs-layers unless it is explicitly requested (such as by running //tfjs-layers:tfjs-layers_webgl1_test). --- tfjs-layers/src/BUILD.bazel | 4 +--- tfjs-layers/src/setup_tests.ts | 19 ++++++++++++++++++- tfjs-layers/src/utils/test_utils.ts | 14 +------------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tfjs-layers/src/BUILD.bazel b/tfjs-layers/src/BUILD.bazel index de5ee80349b..7909b9af0a1 100644 --- a/tfjs-layers/src/BUILD.bazel +++ b/tfjs-layers/src/BUILD.bazel @@ -78,9 +78,7 @@ ts_library( ts_library( name = "tfjs-layers_test_lib", - # disable testonly for the issue in the nodejs build target. - # https://github.com/bazelbuild/rules_nodejs/pull/2984 - #testonly = True, + testonly = True, srcs = glob(TEST_SRCS) + [":tests"], module_name = "@tensorflow/tfjs-layers/dist", deps = [ diff --git a/tfjs-layers/src/setup_tests.ts b/tfjs-layers/src/setup_tests.ts index 1b57c6f82b1..b9ef137e2d4 100644 --- a/tfjs-layers/src/setup_tests.ts +++ b/tfjs-layers/src/setup_tests.ts @@ -28,7 +28,8 @@ import '@tensorflow/tfjs-backend-webgl'; // tslint:disable-next-line: no-imports-from-dist import {parseTestEnvFromKarmaFlags, registerTestEnv, setTestEnvs, TEST_ENVS} from '@tensorflow/tfjs-core/dist/jasmine_util'; -registerTestEnv({ +// Register test environments. +const webgl1TestEnv = { name: 'webgl1', backendName: 'webgl', flags: { @@ -37,6 +38,17 @@ registerTestEnv({ 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 }, isDataSync: true +}; +registerTestEnv(webgl1TestEnv); +registerTestEnv({name: 'cpu', backendName: 'cpu'}); +registerTestEnv({ + name: 'webgl2', + backendName: 'webgl', + flags: { + 'WEBGL_VERSION': 2, + 'WEBGL_CPU_FORWARD': false, + 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 + } }); // Allow flags to override test envs @@ -44,8 +56,13 @@ registerTestEnv({ declare let __karma__: any; if (typeof __karma__ !== 'undefined') { const testEnv = parseTestEnvFromKarmaFlags(__karma__.config.args, TEST_ENVS); + if (testEnv != null) { setTestEnvs([testEnv]); + } else { + // Exclude webgl1 unless it is specifically requested because it causes + // test flakiness when switching between webgl1 and webgl2. + setTestEnvs(TEST_ENVS.filter(env => env !== webgl1TestEnv)); } } diff --git a/tfjs-layers/src/utils/test_utils.ts b/tfjs-layers/src/utils/test_utils.ts index 98289759b12..c1413d0d6fb 100644 --- a/tfjs-layers/src/utils/test_utils.ts +++ b/tfjs-layers/src/utils/test_utils.ts @@ -14,22 +14,10 @@ import {memory, Tensor, test_util, util} from '@tensorflow/tfjs-core'; // tslint:disable-next-line: no-imports-from-dist -import {ALL_ENVS, describeWithFlags, registerTestEnv} from '@tensorflow/tfjs-core/dist/jasmine_util'; +import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; import {ValueError} from '../errors'; -// Register backends. -registerTestEnv({name: 'cpu', backendName: 'cpu'}); -registerTestEnv({ - name: 'webgl2', - backendName: 'webgl', - flags: { - 'WEBGL_VERSION': 2, - 'WEBGL_CPU_FORWARD': false, - 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 - } -}); - /** * Expect values are close between a Tensor or number array. * @param actual