26 changes: 14 additions & 12 deletions tfjs-backend-webgl/src/gather_nd_gpu.ts
@@ -25,24 +25,26 @@ export class GatherNDProgram implements GPGPUProgram {
private sliceDim: number, private strides: number[], shape: number[],
private paramsShape: number[]) {
this.outputShape = shape;
const stridesType = getCoordsDataType(strides.length);
const dtype = getCoordsDataType(shape.length);
const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
const paramsShapeType = getCoordsDataType(paramsShape.length);
const paramsShapeString = paramsShape.length > 1 ? 'paramsShape[j]' : 'paramsShape';

let mainLoop = `
int index;`;
for (let j = 0; j < this.sliceDim; j++) {
mainLoop += `
index = round(getIndices(coords[0], ${j}));
out_of_bounds = out_of_bounds || index < 0;
out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]};
flattenIndex += index * ${this.strides[j]};`;
}

this.userCode = `
${stridesType} strides = ${stridesType}(${this.strides});
${paramsShapeType} paramsShape = ${paramsShapeType}(${this.paramsShape});
void main() {
${dtype} coords = getOutputCoords();
int flattenIndex = 0;
bool out_of_bounds = false;
for (int j = 0; j < ${this.sliceDim}; j++) {
int index = round(getIndices(coords[0], j));
out_of_bounds = out_of_bounds || index < 0;
out_of_bounds = out_of_bounds || index >= ${paramsShapeString};
flattenIndex += index * ${strideString};
}

${mainLoop}

setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));
}
`;
19 changes: 19 additions & 0 deletions tfjs-backend-webgl/src/webgl_ops_test.ts
@@ -918,6 +918,15 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => {
expectArraysEqual(await g.data(), [0, 1, 2, 0]);
});

it('works for out of bounds indices 1d', async () => {
const x = tf.tensor1d([...Array(4).keys()].map(e => e + 1), 'int32');
const indices = [0, 1, 2, 5];
const ind = tf.tensor2d(indices, [4, 1], 'int32');
const g = tf.gatherND(x, ind);
const expected = [1, 2, 3, 0];
expectArraysEqual(await g.data(), expected);
});

it('works for out of bounds indices 2d', async () => {
const x =
tf.tensor2d([...Array(4).keys()].map(e => e + 1), [2, 2], 'int32');
@@ -957,4 +966,14 @@ describeWithFlags('gatherNd', WEBGL_ENVS, () => {
];
expectArraysEqual(await g.data(), expected);
});

it('works for out of bounds indices 5d', async () => {
const x = tf.tensor5d(
[...Array(32).keys()].map(e => e + 1), [2, 2, 2, 2, 2], 'int32');
const indices = [0, 0, 0, 0, 0, 2, 1, 1, 1, 1];
const ind = tf.tensor2d(indices, [2, 5], 'int32');
const g = tf.gatherND(x, ind);
const expected = [1, 0];
expectArraysEqual(await g.data(), expected);
});
});
121 changes: 121 additions & 0 deletions tfjs-backend-webgpu/src/conv2d_naive_webgpu.ts
@@ -0,0 +1,121 @@
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch} from './webgpu_util';

export class Conv2DNaiveProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
uniforms =
'filterDims: vec2<i32>, pad: vec2<i32>, stride: vec2<i32>, dilation: vec2<i32>,';
workGroupSize: [number, number, number] = [4, 4, 8];
addBias: boolean;
activation: backend_util.Activation;
hasPreluActivationWeights: boolean;
isChannelsLast: boolean;

constructor(
convInfo: backend_util.Conv2DInfo, addBias = false,
activation: backend_util.Activation = null,
hasPreluActivationWeights = false) {
this.outputShape = convInfo.outShape;
this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
this.dispatchLayout = this.isChannelsLast ? {x: [2], y: [1], z: [0, 3]} :
{x: [3], y: [2], z: [0, 1]};
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize);
this.addBias = addBias;
this.activation = activation;
this.hasPreluActivationWeights = hasPreluActivationWeights;

if (addBias) {
this.variableNames.push('bias');
}

if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
}

this.shaderKey = `conv2dnaive_${this.activation}_${this.isChannelsLast}`;
}

getUserCode(): string {
const userCode = `
${
activationFnSnippet(
this.activation, this.hasPreluActivationWeights, false, 4)}
fn readInp(batch : i32, row : i32, col : i32, chan : i32) -> f32{
let coords = vec4<i32>(batch, row, col, chan);
if (coordsInBounds4D(coords, uniforms.xShape)) {
return getX(batch, row, col, chan);
} else {
return 0.0;
}
}
fn readFilt(row : i32, col : i32, xChannel : i32, outChannel : i32) -> f32{
let coords = vec4<i32>(row, col, xChannel, outChannel);
if(coordsInBounds4D(coords, uniforms.wShape)) {
return getW(row, col, xChannel, outChannel);
} else {
return 0.0;
}
}
fn writeResult(batch : i32, row : i32, col : i32, chan : i32, valueIn : f32) {
let coords = ${
this.isChannelsLast ? `vec4<i32>(batch, row, col, chan);` :
`vec4<i32>(batch, chan, row, col);`}
if (coordsInBounds4D(coords, uniforms.outShape)) {
var value = valueIn;
${biasActivationSnippet(this.addBias, this.activation)}
setOutputAtCoords(coords.x, coords.y, coords.z, coords.w, value);
}
}
${main('index')} {
let coords = getOutputCoords();
let batch = coords[0];
let outChannel = ${this.isChannelsLast ? `coords[3];` : `coords[1];`}
let outRow = ${this.isChannelsLast ? `coords[1];` : `coords[2];`}
let outCol = ${this.isChannelsLast ? `coords[2];` : `coords[3];`}
var acc : f32 = 0.0;
for (var row = 0; row < uniforms.filterDims[0]; row = row + 1) {
for (var col = 0; col < uniforms.filterDims[1]; col = col + 1) {
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * row - uniforms.pad[0];
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * col - uniforms.pad[1];
for (var xChannel = 0; xChannel < ${
this.isChannelsLast ? `uniforms.xShape[3];` :
`uniforms.xShape[1];`} xChannel = xChannel + 1) {
${
this.isChannelsLast ? `let v = readInp(batch, xRow, xCol, xChannel);` :
`let v = readInp(batch, xChannel, xRow, xCol);`}
let f = readFilt(row, col, xChannel, outChannel);
acc = acc + v * f;
}
}
}
writeResult(batch, outRow, outCol, outChannel, acc);
}
`;
return userCode;
}
}
5 changes: 5 additions & 0 deletions tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -65,3 +65,8 @@ ENV.registerFlag('WEBGPU_USE_PROFILE_TOOL', () => false);
* Whether to use import API.
*/
ENV.registerFlag('WEBGPU_IMPORT_EXTERNAL_TEXTURE', () => true);

/**
* Whether to use conv2dNaive for debugging.
*/
ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_DEBUG', () => false);
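Note: a minimal sketch (not part of this diff) of how the new debug flag could be toggled from application code, assuming the standard tf.env() flag API in @tensorflow/tfjs-core:

import * as tf from '@tensorflow/tfjs';

// Force the naive conv2d program instead of the matmul-based path (debug only).
tf.env().set('WEBGPU_USE_NAIVE_CONV2D_DEBUG', true);

const x = tf.randomNormal([1, 8, 8, 4]) as tf.Tensor4D;   // NHWC input
const w = tf.randomNormal([3, 3, 4, 8]) as tf.Tensor4D;   // HWIO filter
// On the webgpu backend this should now dispatch Conv2DNaiveProgram.
const y = tf.conv2d(x, w, 1, 'same');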
50 changes: 31 additions & 19 deletions tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts
@@ -15,10 +15,12 @@
* =============================================================================
*/

import {backend_util, TensorInfo} from '@tensorflow/tfjs-core';
import {backend_util, env, TensorInfo} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {Conv2DMMProgram} from '../conv2d_mm_webgpu';
import {Conv2DNaiveProgram} from '../conv2d_naive_webgpu';
import {WebGPUProgram} from '../webgpu_program';

import {batchMatMulImpl} from './BatchMatMul_impl';
import {reshape} from './Reshape';
@@ -184,12 +186,15 @@ export function conv2DImpl({
convInfo.filterHeight === convInfo.inHeight &&
convInfo.filterWidth === convInfo.inWidth &&
convInfo.padInfo.type === 'VALID';
if (sameSize ||
(convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' ||
convInfo.padInfo.type === 'VALID'))) {
const useNaiveConv2d = env().getBool('WEBGPU_USE_NAIVE_CONV2D_DEBUG');

if (!useNaiveConv2d &&
(sameSize ||
(convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' ||
convInfo.padInfo.type === 'VALID')))) {
return conv2dByMatMul({
x,
filter,
@@ -202,25 +207,32 @@
});
}

const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth :
convInfo.outChannels;
const dimBOuter = isChannelsLast ? convInfo.outChannels :
convInfo.outHeight * convInfo.outWidth;
const dimInner =
convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels;
let program: WebGPUProgram;
const padInfo = [convInfo.padInfo.top, convInfo.padInfo.left];
const dimensions = [
{type: 'int32', data: [convInfo.filterHeight, convInfo.filterWidth]},
{type: 'int32', data: [...padInfo]},
{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]},
{type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]},
{type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]},
{type: 'int32', data: [dimInner]}
{type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]}
];
if (useNaiveConv2d) {
program = new Conv2DNaiveProgram(
convInfo, hasBias, activation, hasPreluActivationWeights);
} else {
const dimAOuter = isChannelsLast ? convInfo.outHeight * convInfo.outWidth :
convInfo.outChannels;
const dimBOuter = isChannelsLast ? convInfo.outChannels :
convInfo.outHeight * convInfo.outWidth;
const dimInner =
convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels;
dimensions.push(
{type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]},
{type: 'int32', data: [dimInner]});

const program = new Conv2DMMProgram(
convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation,
hasPreluActivationWeights);
program = new Conv2DMMProgram(
convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation,
hasPreluActivationWeights);
}

const intermediates: TensorInfo[] = [];
const inputVar: TensorInfo[] = [x, filter];
29 changes: 28 additions & 1 deletion tfjs-layers/src/exports_layers.ts
@@ -24,7 +24,8 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';
import { Rescaling, RescalingArgs } from './layers/preprocessing/image_preprocessing';
import {Rescaling, RescalingArgs} from './layers/preprocessing/image_preprocessing';
import {Resizing, ResizingArgs} from './layers/preprocessing/image_resizing';

// TODO(cais): Add doc string to all the public static functions in this
// class; include executable JavaScript code snippets where applicable
@@ -1729,3 +1730,29 @@ export function masking(args?: MaskingArgs) {
export function rescaling(args?: RescalingArgs) {
return new Rescaling(args);
}

/**
* A preprocessing layer which resizes images.
* This layer resizes an image input to a target height and width. The input
* should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
* format. Input pixel values can be of any range (e.g. `[0., 1.)` or `[0,
 * 255]`) and of integer or floating point dtype. By default, the layer will
* output floats.
*
* Arguments:
* - `height`: number, the height for the output tensor.
* - `width`: number, the width for the output tensor.
* - `interpolation`: string, the method for image resizing interpolation.
* - `cropToAspectRatio`: boolean, whether to keep image aspect ratio.
*
* Input shape:
* Arbitrary.
*
* Output shape:
* height, width, num channels.
*
* @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'}
*/
export function resizing(args?: ResizingArgs) {
return new Resizing(args);
}
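A usage sketch for the new layer, assuming it is exported as tf.layers.resizing and accepts the arguments documented above (the interpolation value shown is an assumption):

import * as tf from '@tensorflow/tfjs';

// Resize a batch of 32x32 RGB images to 24x24.
const resize = tf.layers.resizing({height: 24, width: 24, interpolation: 'bilinear'});
const images = tf.zeros([8, 32, 32, 3]);         // [batch, height, width, channels]
const out = resize.apply(images) as tf.Tensor;   // expected shape: [8, 24, 24, 3]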
@@ -8,7 +8,7 @@
* =============================================================================
*/

import {LayerArgs, Layer} from '../../engine/topology';
import { LayerArgs, Layer } from '../../engine/topology';
import { serialization, Tensor, mul, add, tidy } from '@tensorflow/tfjs-core';
import { getExactlyOneTensor } from '../../utils/types_utils';
import * as K from '../../backend/tfjs_backend';
29 changes: 21 additions & 8 deletions tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts
@@ -1,9 +1,23 @@
import { Tensor, randomNormal, mul, add} from '@tensorflow/tfjs-core';
import { Rescaling } from './image_preprocessing';
import { describeMathCPUAndGPU, expectTensorsClose } from '../../utils/test_utils';
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

describeMathCPUAndGPU('Rescaling Layer', () => {
/**
* Unit Tests for image rescaling layer.
*/

import {add, mul, randomNormal, Tensor} from '@tensorflow/tfjs-core';
import {describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils';

import {Rescaling} from './image_preprocessing';

describeMathCPUAndGPU('Rescaling Layer', () => {
it('Check if input shape matches output shape', () => {
const scale = 1.0 / 127.5;
const offset = 0;
@@ -31,16 +45,15 @@ describeMathCPUAndGPU('Rescaling Layer', () => {
const expectedOutputTensor = add(mul(intTensor, scale), offset);
const scalingLayer = new Rescaling({scale, offset});
const outputTensor = scalingLayer.apply(intTensor) as Tensor;
expect(outputTensor.dtype).toBe('float32');
expectTensorsClose(outputTensor, expectedOutputTensor);
expect(outputTensor.dtype).toBe('float32');
expectTensorsClose(outputTensor, expectedOutputTensor);
});

it('Config holds correct name', () => {
const scale = 1.0 / 127.5;
const offset = -1.0;
const scalingLayer = new Rescaling({scale, offset, name: 'Rescaling'});
const config = scalingLayer.getConfig();
expect(config.name).toEqual('Rescaling');
expect(config.name).toEqual('Rescaling');
});

});