Skip to content

Commit cf328d3

Browse files
authored
webgpu: Optimize AvgPool when filter size = input size (#6762)
* webgpu: Optimize AvgPool when filter size = input size AvgPool is very pool in cityscapes architecture in DeepLabV3. With this change, AvgPool becomes 3.07 ms from 24.77 ms.
1 parent 0557e78 commit cf328d3

File tree

4 files changed

+113
-55
lines changed

4 files changed

+113
-55
lines changed

tfjs-backend-webgpu/src/kernels/AvgPool.ts

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,10 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17-
import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';
17+
import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';
1818

1919
import {WebGPUBackend} from '../backend_webgpu';
20-
21-
import {identity} from './Identity';
22-
import {Pool2DProgram} from '../pool2d_webgpu';
23-
import {PoolWithFilterSizeEqualsOneProgram} from '../pool_filtersizeone_webgpu';
20+
import {poolImpl} from './Pool_impl';
2421

2522
export function avgPool(
2623
args: {inputs: AvgPoolInputs, backend: WebGPUBackend, attrs: AvgPoolAttrs}):
@@ -32,30 +29,8 @@ export function avgPool(
3229
const convInfo = backend_util.computePool2DInfo(
3330
x.shape as [number, number, number, number], filterSize, strides,
3431
dilations, pad, dimRoundingMode);
35-
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
36-
util.arraysEqual(convInfo.inShape, convInfo.outShape)) {
37-
return identity({inputs: {x}, backend});
38-
}
39-
40-
let program: Pool2DProgram|PoolWithFilterSizeEqualsOneProgram;
41-
const dimensions =
42-
[{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}];
43-
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1) {
44-
program = new PoolWithFilterSizeEqualsOneProgram(convInfo);
45-
} else {
46-
program = new Pool2DProgram(convInfo, 'avg');
47-
dimensions.push(
48-
{type: 'int32', data: [convInfo.padInfo.top, convInfo.padInfo.left]}, {
49-
type: 'int32',
50-
data: [convInfo.dilationHeight, convInfo.dilationWidth]
51-
},
52-
{type: 'int32', data: [convInfo.inHeight, convInfo.inWidth]}, {
53-
type: 'int32',
54-
data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth]
55-
});
56-
}
5732

58-
return backend.runWebGPUProgram(program, [x], x.dtype, dimensions);
33+
return poolImpl(x, convInfo, 'avg', backend);
5934
}
6035

6136
export const avgPoolConfig: KernelConfig = {

tfjs-backend-webgpu/src/kernels/MaxPool.ts

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17-
import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, TensorInfo, util} from '@tensorflow/tfjs-core';
17+
import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, TensorInfo} from '@tensorflow/tfjs-core';
1818

1919
import {WebGPUBackend} from '../backend_webgpu';
20-
import {identity} from './Identity';
21-
import {Pool2DProgram} from '../pool2d_webgpu';
22-
import {PoolWithFilterSizeEqualsOneProgram} from '../pool_filtersizeone_webgpu';
20+
import {poolImpl} from './Pool_impl';
2321

2422
export function maxPool(
2523
args: {inputs: MaxPoolInputs, backend: WebGPUBackend, attrs: MaxPoolAttrs}):
@@ -31,30 +29,8 @@ export function maxPool(
3129
const convInfo = backend_util.computePool2DInfo(
3230
x.shape as [number, number, number, number], filterSize, strides,
3331
dilations, pad, dimRoundingMode);
34-
let program: Pool2DProgram|PoolWithFilterSizeEqualsOneProgram;
35-
const dimensions = [];
36-
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1) {
37-
if (util.arraysEqual(convInfo.inShape, convInfo.outShape)) {
38-
return identity({inputs: {x}, backend});
39-
}
40-
program = new PoolWithFilterSizeEqualsOneProgram(convInfo);
41-
dimensions.push(
42-
{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]});
43-
} else {
44-
program = new Pool2DProgram(convInfo, 'max');
45-
dimensions.push(
46-
{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]},
47-
{type: 'int32', data: [convInfo.padInfo.top, convInfo.padInfo.left]}, {
48-
type: 'int32',
49-
data: [convInfo.dilationHeight, convInfo.dilationWidth]
50-
},
51-
{type: 'int32', data: [convInfo.inHeight, convInfo.inWidth]}, {
52-
type: 'int32',
53-
data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth]
54-
});
55-
}
5632

57-
return backend.runWebGPUProgram(program, [x], x.dtype, dimensions);
33+
return poolImpl(x, convInfo, 'max', backend);
5834
}
5935

6036
export const maxPoolConfig: KernelConfig = {
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';
18+
19+
import {WebGPUBackend} from '../backend_webgpu';
20+
import {Pool2DProgram} from '../pool2d_webgpu';
21+
import {PoolWithFilterSizeEqualsOneProgram} from '../pool_filtersizeone_webgpu';
22+
23+
import {identity} from './Identity';
24+
import {max} from './Max';
25+
import {mean} from './Mean';
26+
import {reshape} from './Reshape';
27+
28+
type PoolType = 'max'|'avg';
29+
export function poolImpl(
30+
x: TensorInfo, convInfo: backend_util.Conv2DInfo, poolType: PoolType,
31+
backend: WebGPUBackend): TensorInfo {
32+
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
33+
util.arraysEqual(convInfo.inShape, convInfo.outShape)) {
34+
return identity({inputs: {x}, backend});
35+
}
36+
37+
if (convInfo.filterWidth === convInfo.inWidth &&
38+
convInfo.filterHeight === convInfo.inHeight && convInfo.batchSize === 1 &&
39+
convInfo.padInfo.type === 'VALID') {
40+
const length = x.shape.length;
41+
const reshapeX = reshape({
42+
inputs: {x},
43+
backend,
44+
attrs: {
45+
shape: [
46+
x.shape[length - 3] * x.shape[length - 2] /* height * width */,
47+
x.shape[length - 1] /* channel */
48+
]
49+
}
50+
});
51+
let reduceX;
52+
if (poolType === 'avg') {
53+
reduceX = mean(
54+
{inputs: {x: reshapeX}, backend, attrs: {axis: 0, keepDims: false}});
55+
} else {
56+
util.assert(poolType === 'max', () => `Invalid pool type ${poolType}`);
57+
reduceX = max({
58+
inputs: {x: reshapeX},
59+
backend,
60+
attrs: {reductionIndices: 0, keepDims: false}
61+
});
62+
}
63+
64+
const result = reshape(
65+
{inputs: {x: reduceX}, backend, attrs: {shape: convInfo.outShape}});
66+
backend.disposeData(reshapeX.dataId);
67+
backend.disposeData(reduceX.dataId);
68+
return result;
69+
}
70+
71+
let program: Pool2DProgram|PoolWithFilterSizeEqualsOneProgram;
72+
const dimensions =
73+
[{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}];
74+
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1) {
75+
program = new PoolWithFilterSizeEqualsOneProgram(convInfo);
76+
} else {
77+
if (poolType === 'avg') {
78+
program = new Pool2DProgram(convInfo, 'avg');
79+
} else {
80+
util.assert(poolType === 'max', () => `Invalid pool type ${poolType}`);
81+
program = new Pool2DProgram(convInfo, 'max');
82+
}
83+
84+
dimensions.push(
85+
{type: 'int32', data: [convInfo.padInfo.top, convInfo.padInfo.left]}, {
86+
type: 'int32',
87+
data: [convInfo.dilationHeight, convInfo.dilationWidth]
88+
},
89+
{type: 'int32', data: [convInfo.inHeight, convInfo.inWidth]}, {
90+
type: 'int32',
91+
data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth]
92+
});
93+
}
94+
95+
return backend.runWebGPUProgram(program, [x], x.dtype, dimensions);
96+
}

tfjs-core/src/ops/avg_pool_test.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,17 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
9393
expectArraysClose(await result.data(), [2.5, 3, 3.5, 4]);
9494
});
9595

96+
it('x=[2,2,3] f=[2,2] s=1 p=valid', async () => {
97+
// Feed forward.
98+
const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
99+
const fSize = 2;
100+
const strides = 1;
101+
const result = tf.avgPool(a, fSize, strides, 'valid');
102+
103+
expect(result.shape).toEqual([1, 1, 3]);
104+
expectArraysClose(await result.data(), [5.5, 6.5, 7.5]);
105+
});
106+
96107
it('x=[3,3,1] f=[3,3] s=1 p=explicit', async () => {
97108
// Feed forward.
98109
const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]);

0 commit comments

Comments
 (0)