Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit cbd08fe

Browse files
authored
Change conv2d to match TF API (#67)
* update the readme and npm package * update readme * change conv2d api to take padding same|valid * Merge branch 'master' into conv * move logic to math * some progress * Merge remote-tracking branch 'origin/master' into conv * Migrate all conv-related ops * rename conv2dTranspose to conv2dDerInput * switch shader indexing from float to int * revert graph_runner_test * self review * Merge remote-tracking branch 'origin/master' into conv * merge with the branch int_indexing * fix typos in shaders * merge with master * update pool ops * remove commented out code * simplify api * added unit tests for conv/pool * add doc * Merge master into conv
1 parent 19c6541 commit cbd08fe

26 files changed

+843
-618
lines changed

demos/benchmarks/conv_gpu_benchmark.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,22 @@ export const BENCHMARK_TEST: BenchmarkTest = (size: number) => {
2929
const texManager = new TextureManager(gpgpu);
3030
initializeGPU(gpgpu, texManager);
3131

32-
const inputDepth = 1;
33-
const inputShape: [number, number, number] = [size, size, inputDepth];
34-
const outputDepth = 1;
35-
const fieldSize = 11;
32+
const inDepth = 1;
33+
const inShape: [number, number, number] = [size, size, inDepth];
34+
const outDepth = 1;
35+
const filterSize = 11;
3636
const stride = 1;
37-
const zeroPad = conv_util.computeDefaultPad(inputShape, fieldSize, stride);
38-
3937
const hasBias = true;
40-
const program = new Conv2DProgram(
41-
inputShape, fieldSize, outputDepth, stride, zeroPad, hasBias);
38+
const convInfo = conv_util.computeConvInfo(
39+
inShape, filterSize, filterSize, outDepth, stride, stride, 'same');
40+
const program = new Conv2DProgram(convInfo, hasBias);
4241
const outputShape = program.outputShape as [number, number, number];
4342
const out = Array3D.zeros(outputShape);
44-
const x = Array3D.randUniform(inputShape, -1, 1);
45-
const wShape = conv_util.computeWeightsShape4D(1, outputDepth, fieldSize);
43+
const x = Array3D.randUniform(inShape, -1, 1);
44+
const wShape =
45+
conv_util.computeWeightsShape4D(1, outDepth, filterSize, filterSize);
4646
const W = Array4D.randUniform(wShape, -1, 1);
47-
const b = Array1D.randUniform([outputDepth], -1, 1);
47+
const b = Array1D.randUniform([outDepth], -1, 1);
4848
const inputs = [x, W, b];
4949
const binary = gpgpu_math.compileProgram(gpgpu, program, inputs, out);
5050

demos/benchmarks/conv_transpose_gpu_benchmark.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ limitations under the License.
1515

1616
import * as conv_util from '../../src/math/conv_util';
1717
import {Array3D, Array4D, initializeGPU} from '../../src/math/ndarray';
18-
import {Conv2DTransposeProgram} from '../../src/math/webgl/conv_backprop_gpu';
18+
import {Conv2DDerInputProgram} from '../../src/math/webgl/conv_backprop_gpu';
1919
import {GPGPUContext} from '../../src/math/webgl/gpgpu_context';
2020
import * as gpgpu_math from '../../src/math/webgl/gpgpu_math';
2121
import {TextureManager} from '../../src/math/webgl/texture_manager';
@@ -25,8 +25,8 @@ const OP_RUNS = 40;
2525

2626
export const BENCHMARK_TEST: BenchmarkTest = (size: number) => {
2727
const origInputDepth = 1;
28-
const origOutputDepth = 2;
29-
const xShape: [number, number, number] = [size, size, 1];
28+
const origOutputDepth = 1;
29+
const xShape: [number, number, number] = [size, size, origOutputDepth];
3030
const fieldSize = 11;
3131
const origStride = 1;
3232
const origPad = 1;
@@ -36,14 +36,15 @@ export const BENCHMARK_TEST: BenchmarkTest = (size: number) => {
3636
initializeGPU(gpgpu, texManager);
3737
gpgpu.enableAutomaticDebugValidation(true);
3838

39-
const hasBias = false;
40-
const program = new Conv2DTransposeProgram(
41-
xShape, fieldSize, origInputDepth, origStride, origPad, hasBias);
39+
const convInfo = conv_util.computeConvInfo(
40+
xShape, fieldSize, fieldSize, origOutputDepth, origStride, origStride,
41+
origPad);
42+
const program = new Conv2DDerInputProgram(convInfo);
4243
const outputShape = program.outputShape as [number, number, number];
4344
const out = Array3D.zeros(outputShape);
4445
const x = Array3D.randUniform(xShape, -1, 1);
4546
const wShape = conv_util.computeWeightsShape4D(
46-
origInputDepth, origOutputDepth, fieldSize);
47+
origInputDepth, origOutputDepth, fieldSize, fieldSize);
4748
const W = Array4D.randUniform(wShape, -1, 1);
4849
const inputs = [x, W];
4950
const binary = gpgpu_math.compileProgram(gpgpu, program, inputs, out);

demos/benchmarks/max_pool_backprop_gpu_benchmark.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ export const BENCHMARK_TEST: BenchmarkTest = (size: number) => {
2929
const texManager = new TextureManager(gpgpu);
3030
initializeGPU(gpgpu, texManager);
3131

32-
const outputDepth = 1;
33-
const dyShape: [number, number, number] = [size, size, outputDepth];
32+
const depth = 1;
33+
const dyShape: [number, number, number] = [size, size, depth];
34+
const xShape: [number, number, number] = [size, size, depth];
3435
const fSize = 11;
3536
const stride = 1;
36-
const zeroPad = conv_util.computeDefaultPad(dyShape, fSize, stride);
37-
const program = new MaxPool2DBackpropProgram(dyShape, fSize, stride, zeroPad);
37+
const convInfo = conv_util.computeConvInfo(
38+
xShape, fSize, fSize, depth, stride, stride, 'same');
39+
const program = new MaxPool2DBackpropProgram(convInfo);
3840
const res = NDArray.zeros(program.outputShape);
3941
const dy = Array3D.randUniform(dyShape, -1, 1);
4042
const positionsData = new Float32Array(dy.size);

demos/benchmarks/max_pool_gpu_benchmark.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,9 @@ function testMaxPool(size: number, positions: boolean): number {
4343
const xShape: [number, number, number] = [size, size, outputDepth];
4444
const fieldSize = 11;
4545
const stride = 1;
46-
const zeroPad = conv_util.computeDefaultPad(xShape, fieldSize, stride);
47-
48-
const program =
49-
new Pool2DProgram(xShape, fieldSize, stride, zeroPad, 'max', positions);
46+
const convInfo = conv_util.computeConvInfo(
47+
xShape, fieldSize, fieldSize, outputDepth, stride, stride, 'same');
48+
const program = new Pool2DProgram(convInfo, 'max', positions);
5049
const res = NDArray.zeros(program.outputShape);
5150
const x = Array3D.randUniform(xShape, -1, 1);
5251
const binary = gpgpu_math.compileProgram(gpgpu, program, [x], res);

demos/model-builder/layer_builder.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ export class Convolution2DLayerBuilder implements LayerBuilder {
207207
{
208208
label: 'Output depth',
209209
initialValue: (inputShape: number[]) =>
210-
this.outputDepth != null ? this.outputDepth : 1,
210+
this.outputDepth != null ? this.outputDepth : 1,
211211
type: 'number',
212212
min: 1,
213213
max: 1000,
@@ -319,7 +319,7 @@ export class ReshapeLayerBuilder implements LayerBuilder {
319319
initialValue: (inputShape: number[]) => inputShape.join(', '),
320320
type: 'text' as 'text',
321321
setValue: (value: string) => this.outputShape =
322-
value.split(',').map((value) => +value),
322+
value.split(',').map((value) => +value),
323323
getValue: () => this.outputShape.join(', ')
324324
}];
325325
}

src/graph.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -694,10 +694,9 @@ export class MaxPoolNode extends Node {
694694
graph: Graph, private x: Tensor, public fieldSize: number,
695695
public stride = 1, public zeroPad?: number) {
696696
super(
697-
graph, 'Max pool', {x},
698-
new Tensor(conv_util.computeOutputShape3D(
699-
x.shape as [number, number, number], fieldSize, x.shape[2], stride,
700-
zeroPad)));
697+
graph, 'Max pool', {x}, new Tensor(conv_util.computeOutputShape3D(
698+
x.shape as [number, number, number],
699+
fieldSize, x.shape[2], stride, zeroPad)));
701700
}
702701
validate() {
703702
util.assert(
@@ -875,4 +874,4 @@ export class ArgMaxEqualsNode extends Node {
875874
* @hidden
876875
*/
877876
export type ArrayData =
878-
NDArray|number|number[]|number[][]|number[][][]|number[][][][];
877+
NDArray | number | number[] | number[][] | number[][][] | number[][][][];

src/math/conv_util.ts

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,88 @@ limitations under the License.
1515

1616
import * as util from '../util';
1717

18+
/**
19+
* Information about the forward pass of a convolution/pooling operation.
20+
* It includes input and output shape, strides, filter size and padding
21+
* information.
22+
*/
23+
export type ConvInfo = {
24+
inShape: [number, number, number],
25+
outShape: [number, number, number],
26+
strideHeight: number,
27+
strideWidth: number,
28+
filterHeight: number,
29+
filterWidth: number,
30+
padInfo: {top: number, left: number, right: number, bottom: number}
31+
};
32+
33+
/**
34+
* Computes the information about a forward pass of a convolution/pooling
35+
* operation.
36+
*/
37+
export function computeConvInfo(
38+
inShape: [number, number, number], filterHeight: number,
39+
filterWidth: number, outDepth: number, strideHeight: number,
40+
strideWidth: number, pad: 'same'|'valid'|number): ConvInfo {
41+
if (typeof pad === 'number') {
42+
const outShape = computeOutputShape3D(
43+
inShape, filterHeight, outDepth, strideHeight, pad);
44+
return {
45+
inShape,
46+
outShape,
47+
padInfo: {top: pad, bottom: pad, left: pad, right: pad},
48+
strideHeight,
49+
strideWidth,
50+
filterHeight,
51+
filterWidth
52+
};
53+
}
54+
const inHeight = inShape[0];
55+
const inWidth = inShape[1];
56+
let outShape: [number, number, number];
57+
let padInfo: {left: number, top: number, bottom: number, right: number};
58+
if (pad === 'same') {
59+
const outHeight = Math.ceil(inHeight / strideHeight);
60+
const outWidth = Math.ceil(inWidth / strideWidth);
61+
outShape = [outHeight, outWidth, outDepth];
62+
const padAlongHeight =
63+
(outHeight - 1) * strideHeight + filterHeight - inHeight;
64+
const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
65+
const top = Math.floor(padAlongHeight / 2);
66+
const bottom = padAlongHeight - top;
67+
const left = Math.floor(padAlongWidth / 2);
68+
const right = padAlongWidth - left;
69+
padInfo = {top, bottom, left, right};
70+
} else if (pad === 'valid') {
71+
const outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
72+
const outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
73+
outShape = [outHeight, outWidth, outDepth];
74+
padInfo = {top: 0, bottom: 0, left: 0, right: 0};
75+
} else {
76+
throw Error(`Unknown padding parameter: ${pad}`);
77+
}
78+
return {
79+
inShape,
80+
outShape,
81+
padInfo,
82+
strideHeight,
83+
strideWidth,
84+
filterHeight,
85+
filterWidth
86+
};
87+
}
88+
89+
/**
90+
* @deprecated Use `conv_util.computeConvInfo` instead.
91+
*/
1892
export function computeOutputShape3D(
19-
inputShapeRowColDepth: [number, number, number], fieldSize: number,
20-
depth: number, stride: number, zeroPad?: number): [number, number, number] {
93+
inShape: [number, number, number], fieldSize: number, outDepth: number,
94+
stride: number, zeroPad?: number): [number, number, number] {
2195
if (zeroPad == null) {
22-
zeroPad = computeDefaultPad(inputShapeRowColDepth, fieldSize, stride);
96+
zeroPad = computeDefaultPad(inShape, fieldSize, stride);
2397
}
24-
const inputRows = inputShapeRowColDepth[0];
25-
const inputCols = inputShapeRowColDepth[1];
98+
const inputRows = inShape[0];
99+
const inputCols = inShape[1];
26100
const outputRows = (inputRows - fieldSize + 2 * zeroPad) / stride + 1;
27101
util.assert(
28102
util.isInt(outputRows),
@@ -35,7 +109,7 @@ export function computeOutputShape3D(
35109
`The output # of columns (${outputCols}) must be an integer. Change ` +
36110
`the stride and/or zero pad parameters`);
37111

38-
return [outputRows, outputCols, depth];
112+
return [outputRows, outputCols, outDepth];
39113
}
40114

41115
export function computeDefaultPad(
@@ -50,9 +124,9 @@ export function computeTexShapeFrom3D(
50124
}
51125

52126
export function computeWeightsShape4D(
53-
inputDepth: number, outputDepth: number,
54-
fSize: number): [number, number, number, number] {
55-
return [fSize, fSize, inputDepth, outputDepth];
127+
inputDepth: number, outputDepth: number, filterHeight: number,
128+
filterWidth: number): [number, number, number, number] {
129+
return [filterHeight, filterWidth, inputDepth, outputDepth];
56130
}
57131

58132
export function computeDilatedRC(

src/math/conv_util_test.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright 2017 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
import * as conv_util from './conv_util';
17+
18+
describe('conv_util computeConvInfo', () => {
19+
it('1x1 conv over 1x1 array with same pad', () => {
20+
const inShape: [number, number, number] = [1, 1, 1];
21+
const convInfo = conv_util.computeConvInfo(inShape, 1, 1, 1, 1, 1, 'same');
22+
expect(convInfo.outShape).toEqual([1, 1, 1]);
23+
});
24+
25+
it('2x2 conv over 3x3 array with same pad', () => {
26+
const inShape: [number, number, number] = [3, 3, 1];
27+
const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 1, 1, 'same');
28+
expect(convInfo.outShape).toEqual([3, 3, 1]);
29+
// Should produce non-even padding with extra pixel at the right/bottom.
30+
expect(convInfo.padInfo.left).toBe(0);
31+
expect(convInfo.padInfo.right).toBe(1);
32+
expect(convInfo.padInfo.top).toBe(0);
33+
expect(convInfo.padInfo.bottom).toBe(1);
34+
});
35+
36+
it('2x2 conv over 3x3 array with same pad', () => {
37+
const inShape: [number, number, number] = [3, 3, 1];
38+
const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 1, 1, 'same');
39+
expect(convInfo.outShape).toEqual([3, 3, 1]);
40+
});
41+
42+
it('2x2 conv over 3x3 array with valid pad', () => {
43+
const inShape: [number, number, number] = [3, 3, 1];
44+
const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 1, 1, 'valid');
45+
expect(convInfo.outShape).toEqual([2, 2, 1]);
46+
});
47+
48+
it('2x2 conv over 3x3 array with valid pad with stride 2', () => {
49+
const inShape: [number, number, number] = [3, 3, 1];
50+
const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 2, 2, 'valid');
51+
expect(convInfo.outShape).toEqual([1, 1, 1]);
52+
});
53+
54+
it('2x2 conv over 3x3 array with valid pad with stride 2', () => {
55+
const inShape: [number, number, number] = [3, 3, 1];
56+
const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 2, 2, 'valid');
57+
expect(convInfo.outShape).toEqual([1, 1, 1]);
58+
});
59+
60+
it('2x1 conv over 3x3 array with valid pad with stride 1', () => {
61+
const inShape: [number, number, number] = [3, 3, 1];
62+
const convInfo = conv_util.computeConvInfo(inShape, 2, 1, 1, 1, 1, 'valid');
63+
expect(convInfo.outShape).toEqual([2, 3, 1]);
64+
});
65+
66+
it('2x1 conv over 3x3 array with valid pad with strides h=2, w=1', () => {
67+
const inShape: [number, number, number] = [3, 3, 1];
68+
const convInfo = conv_util.computeConvInfo(inShape, 2, 1, 1, 2, 1, 'valid');
69+
expect(convInfo.outShape).toEqual([1, 3, 1]);
70+
});
71+
72+
it('1x2 conv over 3x3 array with valid pad with stride 1', () => {
73+
const inShape: [number, number, number] = [3, 3, 1];
74+
const convInfo = conv_util.computeConvInfo(inShape, 1, 2, 1, 1, 1, 'valid');
75+
expect(convInfo.outShape).toEqual([3, 2, 1]);
76+
});
77+
});

0 commit comments

Comments
 (0)