26 changes: 14 additions & 12 deletions tfjs-backend-webgl/src/gather_nd_gpu.ts
@@ -25,24 +25,26 @@ export class GatherNDProgram implements GPGPUProgram {
private sliceDim: number, private strides: number[], shape: number[],
private paramsShape: number[]) {
this.outputShape = shape;
const stridesType = getCoordsDataType(strides.length);
const dtype = getCoordsDataType(shape.length);
const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
const paramsShapeType = getCoordsDataType(paramsShape.length);
const paramsShapeString = paramsShape.length > 1 ? 'paramsShape[j]' : 'paramsShape';

let mainLoop = `
int index;`;
for (let j = 0; j < this.sliceDim; j++) {
mainLoop += `
index = round(getIndices(coords[0], ${j}));
out_of_bounds = out_of_bounds || index < 0;
out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]};
flattenIndex += index * ${this.strides[j]};`;
}

this.userCode = `
${stridesType} strides = ${stridesType}(${this.strides});
${paramsShapeType} paramsShape = ${paramsShapeType}(${this.paramsShape});
void main() {
${dtype} coords = getOutputCoords();
int flattenIndex = 0;
bool out_of_bounds = false;
for (int j = 0; j < ${this.sliceDim}; j++) {
int index = round(getIndices(coords[0], j));
out_of_bounds = out_of_bounds || index < 0;
out_of_bounds = out_of_bounds || index >= ${paramsShapeString};
flattenIndex += index * ${strideString};
}

${mainLoop}

setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));
}
`;
19 changes: 19 additions & 0 deletions tfjs-backend-webgl/src/webgl_ops_test.ts
@@ -918,6 +918,15 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => {
expectArraysEqual(await g.data(), [0, 1, 2, 0]);
});

it('works for out of bounds indices 1d', async () => {
const x = tf.tensor1d([...Array(4).keys()].map(e => e + 1), 'int32');
const indices = [0, 1, 2, 5];
const ind = tf.tensor2d(indices, [4, 1], 'int32');
const g = tf.gatherND(x, ind);
const expected = [1, 2, 3, 0];
expectArraysEqual(await g.data(), expected);
});

it('works for out of bounds indices 2d', async () => {
const x =
tf.tensor2d([...Array(4).keys()].map(e => e + 1), [2, 2], 'int32');
@@ -957,4 +966,14 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => {
];
expectArraysEqual(await g.data(), expected);
});

it('works for out of bounds indices 5d', async () => {
const x = tf.tensor5d(
[...Array(32).keys()].map(e => e + 1), [2, 2, 2, 2, 2], 'int32');
const indices = [0, 0, 0, 0, 0, 2, 1, 1, 1, 1];
const ind = tf.tensor2d(indices, [2, 5], 'int32');
const g = tf.gatherND(x, ind);
const expected = [1, 0];
expectArraysEqual(await g.data(), expected);
});
});
121 changes: 121 additions & 0 deletions tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts
@@ -0,0 +1,121 @@
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch} from './webgpu_util';

export class Conv2DNaiveProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
uniforms =
'filterDims: vec2<i32>, pad: vec2<i32>, stride: vec2<i32>, dilation: vec2<i32>,';
workGroupSize: [number, number, number] = [4, 4, 8];
addBias: boolean;
activation: backend_util.Activation;
hasPreluActivationWeights: boolean;
isChannelsLast: boolean;

constructor(
convInfo: backend_util.Conv2DInfo, addBias = false,
activation: backend_util.Activation = null,
hasPreluActivationWeights = false) {
this.outputShape = convInfo.outShape;
this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
this.dispatchLayout = this.isChannelsLast ? {x: [2], y: [1], z: [0, 3]} :
{x: [3], y: [2], z: [0, 1]};
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize);
this.addBias = addBias;
this.activation = activation;
this.hasPreluActivationWeights = hasPreluActivationWeights;

if (addBias) {
this.variableNames.push('bias');
}

if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
}

this.shaderKey = `conv2dnaive_${this.activation}_${this.isChannelsLast}`;
}

getUserCode(): string {
const userCode = `
${
activationFnSnippet(
this.activation, this.hasPreluActivationWeights, false, 4)}
fn readInp(batch : i32, row : i32, col : i32, chan : i32) -> f32{
let coords = vec4<i32>(batch, row, col, chan);
if (coordsInBounds4D(coords, uniforms.xShape)) {
return getX(batch, row, col, chan);
} else {
return 0.0;
}
}
fn readFilt(row : i32, col : i32, xChannel : i32, outChannel : i32) -> f32{
let coords = vec4<i32>(row, col, xChannel, outChannel);
if(coordsInBounds4D(coords, uniforms.wShape)) {
return getW(row, col, xChannel, outChannel);
} else {
return 0.0;
}
}
fn writeResult(batch : i32, row : i32, col : i32, chan : i32, valueIn : f32) {
let coords = ${
this.isChannelsLast ? `vec4<i32>(batch, row, col, chan);` :
`vec4<i32>(batch, chan, row, col);`}
if (coordsInBounds4D(coords, uniforms.outShape)) {
var value = valueIn;
${biasActivationSnippet(this.addBias, this.activation)}
setOutputAtCoords(coords.x, coords.y, coords.z, coords.w, value);
}
}
${main('index')} {
let coords = getOutputCoords();
let batch = coords[0];
let outChannel = ${this.isChannelsLast ? `coords[3];` : `coords[1];`}
let outRow = ${this.isChannelsLast ? `coords[1];` : `coords[2];`}
let outCol = ${this.isChannelsLast ? `coords[2];` : `coords[3];`}
var acc : f32 = 0.0;
for (var row = 0; row < uniforms.filterDims[0]; row = row + 1) {
for (var col = 0; col < uniforms.filterDims[1]; col = col + 1) {
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * row - uniforms.pad[0];
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * col - uniforms.pad[1];
for (var xChannel = 0; xChannel < ${
this.isChannelsLast ? `uniforms.xShape[3];` :
`uniforms.xShape[1];`} xChannel = xChannel + 1) {
${
this.isChannelsLast ? `let v = readInp(batch, xRow, xCol, xChannel);` :
`let v = readInp(batch, xChannel, xRow, xCol);`}
let f = readFilt(row, col, xChannel, outChannel);
acc = acc + v * f;
}
}
}
writeResult(batch, outRow, outCol, outChannel, acc);
}
`;
return userCode;
}
}
5 changes: 5 additions & 0 deletions tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -65,3 +65,8 @@ ENV.registerFlag('WEBGPU_USE_PROFILE_TOOL', () => false);
* Whether to use import API.
*/
ENV.registerFlag('WEBGPU_IMPORT_EXTERNAL_TEXTURE', () => true);

/**
* Whether to use conv2dNaive for debugging.
*/
ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_DEBUG', () => false);
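
To see how this debug flag would be toggled from user code, here is a minimal sketch (assuming the standard `tf.env().set()` flag API and a build of the WebGPU backend that includes this change; tensor shapes are arbitrary examples):

```ts
import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-webgpu';

async function debugConv2dWithNaiveProgram() {
  // NOTE: sketch only; assumes a tfjs build that registers this flag.
  // Debug-only flag added in this PR; it defaults to false.
  tf.env().set('WEBGPU_USE_NAIVE_CONV2D_DEBUG', true);
  await tf.setBackend('webgpu');

  const x = tf.randomNormal([1, 8, 8, 3]) as tf.Tensor4D;  // NHWC input
  const w = tf.randomNormal([3, 3, 3, 4]) as tf.Tensor4D;  // HWIO filter
  const y = tf.conv2d(x, w, /* strides */ 1, /* pad */ 'same');
  console.log(await y.data());
}

debugConv2dWithNaiveProgram();
```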
50 changes: 31 additions & 19 deletions tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts
@@ -15,10 +15,12 @@
* =============================================================================
*/

import {backend_util, TensorInfo} from '@tensorflow/tfjs-core';
import {backend_util, env, TensorInfo} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {Conv2DMMProgram} from '../conv2d_mm_webgpu';
import {Conv2DNaiveProgram} from '../conv2d_naive_webgpu';
import {WebGPUProgram} from '../webgpu_program';

import {batchMatMulImpl} from './BatchMatMul_impl';
import {reshape} from './Reshape';
@@ -184,12 +186,15 @@ export function conv2DImpl({
convInfo.filterHeight === convInfo.inHeight &&
convInfo.filterWidth === convInfo.inWidth &&
convInfo.padInfo.type === 'VALID';
if (sameSize ||
(convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' ||
convInfo.padInfo.type === 'VALID'))) {
const useNaiveConv2d = env().getBool('WEBGPU_USE_NAIVE_CONV2D_DEBUG');

if (!useNaiveConv2d &&
(sameSize ||
(convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' ||
convInfo.padInfo.type === 'VALID')))) {
return conv2dByMatMul({
x,
filter,
@@ -202,25 +207,32 @@
});
}

const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth :
convInfo.outChannels;
const dimBOuter = isChannelsLast ? convInfo.outChannels :
convInfo.outHeight * convInfo.outWidth;
const dimInner =
convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels;
let program: WebGPUProgram;
const padInfo = [convInfo.padInfo.top, convInfo.padInfo.left];
const dimensions = [
{type: 'int32', data: [convInfo.filterHeight, convInfo.filterWidth]},
{type: 'int32', data: [...padInfo]},
{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]},
{type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]},
{type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]},
{type: 'int32', data: [dimInner]}
{type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]}
];
if (useNaiveConv2d) {
program = new Conv2DNaiveProgram(
convInfo, hasBias, activation, hasPreluActivationWeights);
} else {
const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth :
convInfo.outChannels;
const dimBOuter = isChannelsLast ? convInfo.outChannels :
convInfo.outHeight * convInfo.outWidth;
const dimInner =
convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels;
dimensions.push(
{type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]},
{type: 'int32', data: [dimInner]});

const program = new Conv2DMMProgram(
convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation,
hasPreluActivationWeights);
program = new Conv2DMMProgram(
convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation,
hasPreluActivationWeights);
}

const intermediates: TensorInfo[] = [];
const inputVar: TensorInfo[] = [x, filter];
4 changes: 1 addition & 3 deletions tfjs-layers/src/BUILD.bazel
@@ -78,9 +78,7 @@ ts_library(

ts_library(
name = "tfjs-layers_test_lib",
# disable testonly for the issue in the nodejs build target.
# https://github.com/bazelbuild/rules_nodejs/pull/2984
#testonly = True,
testonly = True,
srcs = glob(TEST_SRCS) + [":tests"],
module_name = "@tensorflow/tfjs-layers/dist",
deps = [
32 changes: 32 additions & 0 deletions tfjs-layers/src/exports_layers.ts
@@ -24,6 +24,7 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';
import {Rescaling, RescalingArgs} from './layers/preprocessing/image_preprocessing';

// TODO(cais): Add doc string to all the public static functions in this
// class; include executable JavaScript code snippets where applicable
@@ -1697,3 +1698,34 @@ export function alphaDropout(args: AlphaDropoutArgs) {
export function masking(args?: MaskingArgs) {
return new Masking(args);
}

/**
* A preprocessing layer which rescales input values to a new range.
*
* This layer rescales every value of an input (often an image) by multiplying
* by `scale` and adding `offset`.
*
* For instance:
* 1. To rescale an input in the `[0, 255]` range
* to be in the `[0, 1]` range, you would pass `scale=1/255`.
* 2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]`
* range, you would pass `scale=1/127.5, offset=-1`.
*
* The rescaling is applied both during training and inference. Inputs can be
* of integer or floating point dtype, and by default the layer will output
* floats.
*
* Arguments:
* - `scale`: Float, the scale to apply to the inputs.
* - `offset`: Float, the offset to apply to the inputs.
*
* Input shape:
* Arbitrary.
*
* Output shape:
* Same as input.
*
* @doc {heading: 'Layers', subheading: 'Rescaling', namespace: 'layers'}
*/
export function rescaling(args?: RescalingArgs) {
return new Rescaling(args);
}
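
As a usage illustration, here is a minimal sketch of the new export (assuming it is surfaced as `tf.layers.rescaling` and that `RescalingArgs` takes the `scale` and `offset` fields described above):

```ts
import * as tf from '@tensorflow/tfjs';

// NOTE: sketch; assumes the export above is exposed as tf.layers.rescaling.
// Rescale uint8-style image values from [0, 255] to [-1, 1]:
// scale = 1/127.5, offset = -1.
const rescale = tf.layers.rescaling({scale: 1 / 127.5, offset: -1});

const image = tf.tensor4d([0, 127.5, 255, 255], [1, 2, 2, 1]);
const scaled = rescale.apply(image) as tf.Tensor;
scaled.print();  // approximately [-1, 0, 1, 1]
```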