Skip to content

Commit 043bf0e

Browse files
authored
webgpu: support AvgPoolGrad kernel (#7188)
1 parent ae902e5 commit 043bf0e

File tree

5 files changed

+175
-8
lines changed

5 files changed

+175
-8
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util} from '@tensorflow/tfjs-core';
19+
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
20+
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
21+
22+
export class AvgPool2DBackpropProgram implements WebGPUProgram {
23+
outputShape: number[];
24+
shaderKey: string;
25+
dispatchLayout: {x: number[]};
26+
dispatch: [number, number, number];
27+
variableNames = ['dy'];
28+
uniforms =
29+
`stride : vec2<i32>, pads : vec2<i32>, dilation : vec2<i32>, filterDims : vec2<i32>,
30+
outHeight : i32, outWidth : i32, avgMultiplier : f32,`;
31+
workgroupSize: [number, number, number] = [64, 1, 1];
32+
size = true;
33+
34+
constructor(convInfo: backend_util.Conv2DInfo) {
35+
this.outputShape = convInfo.inShape;
36+
37+
this.dispatchLayout = flatDispatchLayout(this.outputShape);
38+
39+
this.dispatch = computeDispatch(
40+
this.dispatchLayout, this.outputShape, this.workgroupSize);
41+
42+
this.shaderKey = `avg_pool2d_backprop`;
43+
}
44+
45+
getUserCode(): string {
46+
const userCode = `
47+
${main('index')} {
48+
if (index < uniforms.size) {
49+
let coords = getCoordsFromIndex(index);
50+
let batch = coords[0];
51+
let d = coords[3];
52+
53+
let dyRCCorner = vec2<i32>(coords.yz) - uniforms.pads;
54+
let dyRCorner = dyRCCorner.x;
55+
let dyCCorner = dyRCCorner.y;
56+
57+
// Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
58+
// ? = to be determined. : = across all values in that axis.
59+
var dotProd = 0.0;
60+
for (var wR = 0; wR < uniforms.filterDims[0]; wR = wR + uniforms.dilation[0]) {
61+
let dyR = f32(dyRCorner + wR) / f32(uniforms.stride[0]);
62+
63+
if (dyR < 0.0 || dyR >= f32(uniforms.outHeight) || fract(dyR) > 0.0) {
64+
continue;
65+
}
66+
let idyR = i32(dyR);
67+
68+
for (var wC = 0; wC < uniforms.filterDims[1]; wC = wC + uniforms.dilation[1]) {
69+
let dyC = f32(dyCCorner + wC) / f32(uniforms.stride[1]);
70+
71+
if (dyC < 0.0 || dyC >= f32(uniforms.outWidth) || fract(dyC) > 0.0) {
72+
continue;
73+
}
74+
let idyC = i32(dyC);
75+
76+
let dyValue = getDy(batch, idyR, idyC, d);
77+
78+
dotProd = dotProd + dyValue * uniforms.avgMultiplier;
79+
}
80+
}
81+
setOutputAtIndex(index, dotProd);
82+
}
83+
}
84+
`;
85+
return userCode;
86+
}
87+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {AvgPoolGrad, AvgPoolGradAttrs, AvgPoolGradInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';
19+
20+
import {AvgPool2DBackpropProgram} from '../avg_pool2d_backprop_webgpu';
21+
import {WebGPUBackend} from '../backend_webgpu';
22+
import {assertNotComplex} from '../webgpu_util';
23+
24+
export function avgPoolGrad(args: {
25+
inputs: AvgPoolGradInputs,
26+
backend: WebGPUBackend,
27+
attrs: AvgPoolGradAttrs
28+
}): TensorInfo {
29+
const {inputs, backend, attrs} = args;
30+
const {dy, input} = inputs;
31+
const x = input;
32+
assertNotComplex([dy, input], 'avgPoolGrad');
33+
const {filterSize, strides, pad} = attrs;
34+
35+
const convInfo = backend_util.computePool2DInfo(
36+
x.shape as [number, number, number, number], filterSize, strides,
37+
1 /* dilations */, pad);
38+
const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);
39+
const avgMultiplier = 1 / (convInfo.filterHeight * convInfo.filterWidth);
40+
const uniformData = [
41+
{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}, {
42+
type: 'int32',
43+
data: [
44+
convInfo.effectiveFilterHeight - 1 - convInfo.padInfo.top,
45+
convInfo.effectiveFilterWidth - 1 - convInfo.padInfo.left
46+
]
47+
},
48+
{type: 'int32', data: [convInfo.dilationHeight, convInfo.dilationWidth]}, {
49+
type: 'int32',
50+
data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth]
51+
},
52+
{type: 'int32', data: [convInfo.outHeight]},
53+
{type: 'int32', data: [convInfo.outWidth]},
54+
{type: 'float32', data: [avgMultiplier]}
55+
];
56+
return backend.runWebGPUProgram(
57+
avgPoolBackpropProgram, [dy], x.dtype, uniformData);
58+
}
59+
60+
export const avgPoolGradConfig: KernelConfig = {
61+
kernelName: AvgPoolGrad,
62+
backendName: 'webgpu',
63+
kernelFunc: avgPoolGrad as unknown as KernelFunc
64+
};

tfjs-backend-webgpu/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import {atanConfig} from './kernels/Atan';
3232
import {atan2Config} from './kernels/Atan2';
3333
import {atanhConfig} from './kernels/Atanh';
3434
import {avgPoolConfig} from './kernels/AvgPool';
35+
import {avgPoolGradConfig} from './kernels/AvgPoolGrad';
3536
import {batchMatMulConfig} from './kernels/BatchMatMul';
3637
import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND';
3738
import {bincountConfig} from './kernels/Bincount';
@@ -170,6 +171,7 @@ const kernelConfigs: KernelConfig[] = [
170171
atan2Config,
171172
atanhConfig,
172173
avgPoolConfig,
174+
avgPoolGradConfig,
173175
batchMatMulConfig,
174176
batchToSpaceNDConfig,
175177
bincountConfig,

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,6 @@ const TEST_FILTERS: TestFilter[] = [
4545
'gradient', // Not yet implemented.
4646
]
4747
},
48-
{
49-
startsWith: 'avgPool ',
50-
excludes: [
51-
'gradient', // Not yet implemented.
52-
]
53-
},
5448
{
5549
startsWith: 'batchToSpaceND ',
5650
excludes: [
@@ -186,6 +180,12 @@ const TEST_FILTERS: TestFilter[] = [
186180
'poolBackprop', // maxPoolBackprop not yet implemented.
187181
]
188182
},
183+
{
184+
startsWith: 'poolBackprop ',
185+
excludes: [
186+
'max', // maxPoolBackprop not yet implemented.
187+
]
188+
},
189189
{
190190
startsWith: 'prod ',
191191
excludes: [
@@ -264,7 +264,6 @@ const TEST_FILTERS: TestFilter[] = [
264264
'maxPoolBackprop ',
265265
'maxPoolWithArgmax ',
266266
'multinomial ',
267-
'poolBackprop ',
268267
'raggedGather ',
269268
'raggedRange ',
270269
'raggedTensorToTensor ',

tfjs-backend-webgpu/src/webgpu_util.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17-
import {DataType} from '@tensorflow/tfjs-core';
17+
import {DataType, TensorInfo, util} from '@tensorflow/tfjs-core';
1818

1919
const arrayProduct = (arr: number[]) => {
2020
let product = 1;
@@ -161,6 +161,21 @@ export function isWebGPUSupported(): boolean {
161161
!!navigator.gpu;
162162
}
163163

164+
export function assertNotComplex(
165+
tensor: TensorInfo|TensorInfo[], opName: string): void {
166+
if (!Array.isArray(tensor)) {
167+
tensor = [tensor];
168+
}
169+
tensor.forEach(t => {
170+
if (t != null) {
171+
util.assert(
172+
t.dtype !== 'complex64',
173+
() => `${opName} does not support complex64 tensors ` +
174+
'in the WebGPU backend.');
175+
}
176+
});
177+
}
178+
164179
export enum MatMulProgramType {
165180
MatMulReduceProgram,
166181
MatMulSplitKProgram,

0 commit comments

Comments
 (0)