diff --git a/tfjs-backend-webgpu/src/kernels/AvgPool.ts b/tfjs-backend-webgpu/src/kernels/AvgPool.ts index fe3598560b2..716d5b97182 100644 --- a/tfjs-backend-webgpu/src/kernels/AvgPool.ts +++ b/tfjs-backend-webgpu/src/kernels/AvgPool.ts @@ -14,13 +14,10 @@ * limitations under the License. * ============================================================================= */ -import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; - -import {identity} from './Identity'; -import {Pool2DProgram} from '../pool2d_webgpu'; -import {PoolWithFilterSizeEqualsOneProgram} from '../pool_filtersizeone_webgpu'; +import {poolImpl} from './Pool_impl'; export function avgPool( args: {inputs: AvgPoolInputs, backend: WebGPUBackend, attrs: AvgPoolAttrs}): @@ -32,30 +29,8 @@ export function avgPool( const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, dilations, pad, dimRoundingMode); - if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && - util.arraysEqual(convInfo.inShape, convInfo.outShape)) { - return identity({inputs: {x}, backend}); - } - - let program: Pool2DProgram|PoolWithFilterSizeEqualsOneProgram; - const dimensions = - [{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}]; - if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1) { - program = new PoolWithFilterSizeEqualsOneProgram(convInfo); - } else { - program = new Pool2DProgram(convInfo, 'avg'); - dimensions.push( - {type: 'int32', data: [convInfo.padInfo.top, convInfo.padInfo.left]}, { - type: 'int32', - data: [convInfo.dilationHeight, convInfo.dilationWidth] - }, - {type: 'int32', data: [convInfo.inHeight, convInfo.inWidth]}, { - type: 'int32', - data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth] - }); - } - return backend.runWebGPUProgram(program, [x], x.dtype, dimensions); + return poolImpl(x, convInfo, 'avg', backend); } export const avgPoolConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/MaxPool.ts b/tfjs-backend-webgpu/src/kernels/MaxPool.ts index 16f1808ccf2..cc42e7cbede 100644 --- a/tfjs-backend-webgpu/src/kernels/MaxPool.ts +++ b/tfjs-backend-webgpu/src/kernels/MaxPool.ts @@ -14,12 +14,10 @@ * limitations under the License. * ============================================================================= */ -import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, TensorInfo} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; -import {identity} from './Identity'; -import {Pool2DProgram} from '../pool2d_webgpu'; -import {PoolWithFilterSizeEqualsOneProgram} from '../pool_filtersizeone_webgpu'; +import {poolImpl} from './Pool_impl'; export function maxPool( args: {inputs: MaxPoolInputs, backend: WebGPUBackend, attrs: MaxPoolAttrs}): @@ -31,30 +29,8 @@ export function maxPool( const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, dilations, pad, dimRoundingMode); - let program: Pool2DProgram|PoolWithFilterSizeEqualsOneProgram; - const dimensions = []; - if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1) { - if (util.arraysEqual(convInfo.inShape, convInfo.outShape)) { - return identity({inputs: {x}, backend}); - } - program = new PoolWithFilterSizeEqualsOneProgram(convInfo); - dimensions.push( - {type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}); - } else { - program = new Pool2DProgram(convInfo, 'max'); - dimensions.push( - {type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}, - {type: 'int32', data: [convInfo.padInfo.top, convInfo.padInfo.left]}, { - type: 'int32', - data: [convInfo.dilationHeight, convInfo.dilationWidth] - }, - {type: 'int32', data: [convInfo.inHeight, convInfo.inWidth]}, { - type: 'int32', - data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth] - }); - } - return backend.runWebGPUProgram(program, [x], x.dtype, dimensions); + return poolImpl(x, convInfo, 'max', backend); } export const maxPoolConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/Pool_impl.ts b/tfjs-backend-webgpu/src/kernels/Pool_impl.ts new file mode 100644 index 00000000000..c853c5aad42 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Pool_impl.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {Pool2DProgram} from '../pool2d_webgpu'; +import {PoolWithFilterSizeEqualsOneProgram} from '../pool_filtersizeone_webgpu'; + +import {identity} from './Identity'; +import {max} from './Max'; +import {mean} from './Mean'; +import {reshape} from './Reshape'; + +type PoolType = 'max'|'avg'; +export function poolImpl( + x: TensorInfo, convInfo: backend_util.Conv2DInfo, poolType: PoolType, + backend: WebGPUBackend): TensorInfo { + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { + return identity({inputs: {x}, backend}); + } + + if (convInfo.filterWidth === convInfo.inWidth && + convInfo.filterHeight === convInfo.inHeight && convInfo.batchSize === 1 && + convInfo.padInfo.type === 'VALID') { + const length = x.shape.length; + const reshapeX = reshape({ + inputs: {x}, + backend, + attrs: { + shape: [ + x.shape[length - 3] * x.shape[length - 2] /* height * width */, + x.shape[length - 1] /* channel */ + ] + } + }); + let reduceX; + if (poolType === 'avg') { + reduceX = mean( + {inputs: {x: reshapeX}, backend, attrs: {axis: 0, keepDims: false}}); + } else { + util.assert(poolType === 'max', () => `Invalid pool type ${poolType}`); + reduceX = max({ + inputs: {x: reshapeX}, + backend, + attrs: {reductionIndices: 0, keepDims: false} + }); + } + + const result = reshape( + {inputs: {x: reduceX}, backend, attrs: {shape: convInfo.outShape}}); + backend.disposeData(reshapeX.dataId); + backend.disposeData(reduceX.dataId); + return result; + } + + let program: Pool2DProgram|PoolWithFilterSizeEqualsOneProgram; + const dimensions = + [{type: 'int32', data: [convInfo.strideHeight, convInfo.strideWidth]}]; + if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1) { + program = new PoolWithFilterSizeEqualsOneProgram(convInfo); + } else { + if (poolType === 'avg') { + program = new Pool2DProgram(convInfo, 'avg'); + } else { + util.assert(poolType === 'max', () => `Invalid pool type ${poolType}`); + program = new Pool2DProgram(convInfo, 'max'); + } + + dimensions.push( + {type: 'int32', data: [convInfo.padInfo.top, convInfo.padInfo.left]}, { + type: 'int32', + data: [convInfo.dilationHeight, convInfo.dilationWidth] + }, + {type: 'int32', data: [convInfo.inHeight, convInfo.inWidth]}, { + type: 'int32', + data: [convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth] + }); + } + + return backend.runWebGPUProgram(program, [x], x.dtype, dimensions); +} diff --git a/tfjs-core/src/ops/avg_pool_test.ts b/tfjs-core/src/ops/avg_pool_test.ts index a4c0a21fb66..aea4b3590d0 100644 --- a/tfjs-core/src/ops/avg_pool_test.ts +++ b/tfjs-core/src/ops/avg_pool_test.ts @@ -93,6 +93,17 @@ describeWithFlags('avgPool', ALL_ENVS, () => { expectArraysClose(await result.data(), [2.5, 3, 3.5, 4]); }); + it('x=[2,2,3] f=[2,2] s=1 p=valid', async () => { + // Feed forward. + const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const fSize = 2; + const strides = 1; + const result = tf.avgPool(a, fSize, strides, 'valid'); + + expect(result.shape).toEqual([1, 1, 3]); + expectArraysClose(await result.data(), [5.5, 6.5, 7.5]); + }); + it('x=[3,3,1] f=[3,3] s=1 p=explicit', async () => { // Feed forward. const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]);