Skip to content

Commit 4939453

Browse files
authored
[webgpu] Support convTranspose vec4 (#6603)
1 parent 36a0548 commit 4939453

File tree

4 files changed

+178
-75
lines changed

4 files changed

+178
-75
lines changed

tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts

Lines changed: 124 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -16,107 +16,163 @@
1616
*/
1717

1818
import {backend_util, util} from '@tensorflow/tfjs-core';
19-
19+
import {typeSnippet} from './activation_util';
20+
import {makeMatMulPackedVec4Source} from './matmul_packed_vec4_webgpu';
2021
import {makeMatMulPackedSource} from './matmul_packed_webgpu';
2122
import {WebGPUProgram} from './webgpu_program';
2223
import {computeDispatch, computeWorkGroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';
2324

25+
/**
 * Builds the WGSL source for the three callbacks that the packed matmul
 * shader source expects when it is used for conv2dTranspose: `mm_readA`,
 * `mm_readB` and `mm_write`.
 *
 * - `mm_readA` reads from `x`. It maps (row, col) to an `x` coordinate by
 *   subtracting the padding and dividing by the stride, and returns zero when
 *   the result is out of range or not a whole number.
 * - `mm_readB` reads from `W` using filter coordinates counted back from
 *   `filterDims - 1`.
 * - `mm_write` stores a value into `result` when it lies inside the output.
 *
 * @param innerElementSize Number of f32 components handled per access. 1
 *     emits scalar code, 4 emits code that reads/writes four components at
 *     once.
 * @returns The WGSL source for the three functions.
 * @throws Error if `innerElementSize` is neither 1 nor 4.
 */
function conv2dTransposeCommonSnippet(innerElementSize = 4) {
  // WGSL type name used for return types, zero values and `valueInput`.
  // NOTE(review): presumably 'f32' for 1 and 'vec4<f32>' for 4 -- confirm
  // against typeSnippet in activation_util.
  const dataType = typeSnippet(innerElementSize);

  // Body of the in-bounds branch of mm_readB.
  const getWSnippet = (elementSize: number) => {
    switch (elementSize) {
      case 1:
        return 'return W[getIndexFromCoords4D(coord, uniforms.wShape)];';
      case 4:
        // W is read one scalar at a time; four consecutive values along the
        // third filter dimension are gathered into a single vec4.
        return `
            let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
            let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
            let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
            let v0 = W[getIndexFromCoords4D(coord, uniforms.wShape)];
            let v1 = W[getIndexFromCoords4D(coord1, uniforms.wShape)];
            let v2 = W[getIndexFromCoords4D(coord2, uniforms.wShape)];
            let v3 = W[getIndexFromCoords4D(coord3, uniforms.wShape)];
            return vec4<f32>(v0, v1, v2, v3);
            `;
      default:
        throw new Error(
            `innerElementSize ${elementSize} is not supported.`);
    }
  };

  // xR / xC are computed in f32 so that positions which are not an exact
  // multiple of the stride (fract > 0) can be detected and read as zero.
  const readASnippet = `
      let outRow = row / uniforms.outShape[2];
      let outCol = row % uniforms.outShape[2];

      let WRow = col / (uniforms.filterDims[1] * uniforms.outBackprop[3]);
      let WCol = col / uniforms.outBackprop[3] % uniforms.filterDims[1];
      let xR = f32(outRow - uniforms.pads[0] + WRow) / f32(uniforms.stride[0]);
      let xC = f32(outCol - uniforms.pads[1] + WCol) / f32(uniforms.stride[1]);
      if (xR < 0.0 || xR >= f32(uniforms.outBackprop[1]) || fract(xR) > 0.0) {
        return ${dataType}(0.0);
      }
      if (xC < 0.0 || xC >= f32(uniforms.outBackprop[2]) || fract(xC) > 0.0) {
        return ${dataType}(0.0);
      }
      let coord = vec4<i32>(
          batch,
          i32(xR),
          i32(xC),
          col % uniforms.outBackprop[3]);
      return x[getIndexFromCoords4D(coord, uniforms.xShape)/${
      innerElementSize}];`;

  const sampleA = `if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
        ${readASnippet}
      }
      return ${dataType}(0.0);`;

  // In every function `colIn` counts elements of `innerElementSize`
  // components, so it is scaled back to a scalar column index first.
  const userCode = `
  fn mm_readA(row : i32, colIn : i32, globalId : vec3<u32>) -> ${dataType} {
    let col = colIn * ${innerElementSize};
    var batch = i32(globalId.z);
    ${sampleA}
  }

  fn mm_readB(row : i32, colIn : i32, globalId : vec3<u32>) -> ${dataType} {
    let col = colIn * ${innerElementSize};
    let coordX = uniforms.filterDims.x - 1 -
        row / (uniforms.filterDims[1] * uniforms.outBackprop[3]);
    let coordY = uniforms.filterDims.y - 1 -
        (row / uniforms.outBackprop[3]) % uniforms.filterDims[1];
    if (row < uniforms.dimInner && col < uniforms.dimBOuter &&
        coordX >= 0 && coordY >= 0) {
      let rowInner = row % uniforms.outBackprop[3];
      let coord = vec4<i32>(coordX, coordY, col, rowInner);
      ${getWSnippet(innerElementSize)}
    }
    return ${dataType}(0.0);
  }

  fn mm_write(row : i32, colIn : i32, valueInput : ${
      dataType}, globalId : vec3<u32>) {
    let col = colIn * ${innerElementSize};
    if (row < uniforms.dimAOuter && (col + ${
      innerElementSize - 1}) < uniforms.dimBOuter) {
      var batch = i32(globalId.z);
      var value = valueInput;
      let outCoord = vec4<i32>(
          batch,
          row / uniforms.outShape[2],
          row % uniforms.outShape[2],
          col);
      result[getIndexFromCoords4D(outCoord, uniforms.outShape)/${
      innerElementSize}] = value;
    }
  }`;
  return userCode;
}
116+
24117
/**
 * WebGPU program that evaluates conv2dTranspose through a packed matrix
 * multiplication. The output has the shape `convInfo.inShape`.
 *
 * When both `inChannels` and `outChannels` are multiples of 4 the program
 * uses the vec4 matmul source and declares `x` as `vec4<f32>`; otherwise it
 * uses the scalar packed matmul source.
 */
export class Conv2DDerInputMMProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[], y: number[], z: number[]};
  dispatch: [number, number, number];
  variableNames = ['x', 'W'];
  // Only assigned on the vec4 path; left unset otherwise.
  variableTypes: string[];
  uniforms =
      'filterDims : vec2<i32>, pads : vec2<i32>, stride : vec2<i32>, outBackprop : vec4<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,';
  workGroupSize: [number, number, number];
  elementsPerThread: [number, number, number];
  // Tile sizes; only passed on to the vec4 matmul source in getUserCode.
  tileAOuter: number;
  tileBOuter: number;
  tileInner: number;
  innerElementSize: number;
  isVec4?: boolean;

  /**
   * @param convInfo Convolution description. Only the 'channelsLast' data
   *     format is accepted; anything else fails the assertion.
   */
  constructor(convInfo: backend_util.Conv2DInfo) {
    this.outputShape = convInfo.inShape;

    util.assert(
        convInfo.dataFormat === 'channelsLast',
        () => 'TODO: NCHW is unimplemented');
    this.isVec4 =
        convInfo.inChannels % 4 === 0 && convInfo.outChannels % 4 === 0;
    this.dispatchLayout = {x: [3], y: [1, 2], z: [0]};
    this.workGroupSize = computeWorkGroupSizeForConv2d(
        this.dispatchLayout, this.outputShape, this.isVec4);
    this.elementsPerThread = computeWorkPerThreadForConv2d(
        this.dispatchLayout, this.outputShape, this.isVec4);

    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workGroupSize,
        this.elementsPerThread);

    if (this.isVec4) {
      this.innerElementSize = 4;
      // x is read as vec4<f32>, W stays scalar.
      this.variableTypes = ['vec4<f32>', 'f32'];
    } else {
      this.innerElementSize = this.elementsPerThread[0];
    }
    this.tileAOuter = this.workGroupSize[1] * this.elementsPerThread[1];
    this.tileBOuter = this.workGroupSize[0] * this.elementsPerThread[0];
    this.tileInner = Math.max(
        this.workGroupSize[0] * this.innerElementSize, this.workGroupSize[1]);
    // The key includes every value that changes the generated shader text.
    this.shaderKey = `conv2DDerInputMM_${this.isVec4}_${
        this.elementsPerThread}_${this.innerElementSize}`;
  }

  /**
   * @returns WGSL source: the conv2dTranspose read/write callbacks followed
   *     by the matmul source that calls them.
   */
  getUserCode(): string {
    const matMulSource = this.isVec4 ?
        makeMatMulPackedVec4Source(
            this.elementsPerThread, this.tileAOuter, this.tileBOuter,
            this.tileInner, this.innerElementSize) :
        makeMatMulPackedSource(this.elementsPerThread, this.workGroupSize);
    // The scalar path always emits element-size-1 callbacks, regardless of
    // the value stored in this.innerElementSize.
    const userCode = `
    ${conv2dTransposeCommonSnippet(this.isVec4 ? 4 : 1)}
    ${matMulSource}
    `;
    return userCode;
  }
}

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ const TEST_FILTERS: TestFilter[] = [
6060
'gradient', // gradient function not found.
6161
]
6262
},
63+
{
64+
startsWith: 'conv2dTranspose ',
65+
excludes: [
66+
'gradient', // gradient function not found.
67+
]
68+
},
6369
{
6470
startsWith: 'cumprod ',
6571
excludes: [
@@ -283,7 +289,6 @@ const TEST_FILTERS: TestFilter[] = [
283289
'avgPool3dBackprop ',
284290
'bincount ',
285291
'broadcastArgs ',
286-
'conv2dTranspose ',
287292
'conv2DBackpropFilter ',
288293
'gradient with clones, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', // Conv2DBackpropFilter
289294
'conv1d gradients', // Conv2DBackpropFilter

tfjs-core/src/ops/conv2d_transpose_test.ts

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,14 +207,14 @@ describeWithFlags('conv2dTranspose', ALL_ENVS, () => {
207207
inputShape);
208208
const w = tf.tensor4d(
209209
[
210-
0., 1., 2., 3., 4., 5., 6., 7., 8.,
211-
9., 10., 11., 12., 13., 14., 15.
210+
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
211+
15.
212212
],
213213
[fSize, fSize, origInputDepth, origOutputDepth]);
214214

215215
expect(
216216
() => tf.conv2dTranspose(
217-
x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode))
217+
x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode))
218218
.toThrowError();
219219
});
220220

@@ -239,14 +239,14 @@ describeWithFlags('conv2dTranspose', ALL_ENVS, () => {
239239
inputShape);
240240
const w = tf.tensor4d(
241241
[
242-
0., 1., 2., 3., 4., 5., 6., 7., 8.,
243-
9., 10., 11., 12., 13., 14., 15.
242+
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
243+
15.
244244
],
245245
[fSize, fSize, origInputDepth, origOutputDepth]);
246246

247247
expect(
248248
() => tf.conv2dTranspose(
249-
x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode))
249+
x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode))
250250
.toThrowError();
251251
});
252252

@@ -666,4 +666,44 @@ describeWithFlags('conv2dTranspose', ALL_ENVS, () => {
666666
expect(result.shape).toEqual([2, 2, 1]);
667667
expectArraysClose(await result.data(), expected);
668668
});
669+
670+
// Both channel counts are 8 (multiples of 4); this case was added together
// with the WebGPU vec4 conv2dTranspose path.
it('input=8x8x8,output=4x4x8,f=8,s=1,inDepth=8,p=same vec4', async () => {
  const origInputDepth = 8;
  const origOutputDepth = 8;
  const inputShape: [number, number, number, number] =
      [1, 8, 8, origOutputDepth];
  const fSize = 8;
  const origPad = 'same';
  const origStride: [number, number] = [1, 1];
  const wShape: [number, number, number, number] =
      [fSize, fSize, origInputDepth, origOutputDepth];

  // Element counts come from the shapes themselves so that the data always
  // matches the tensor it fills.
  const elementCount = (shape: number[]): number =>
      shape.reduce((total, dim) => total * dim, 1);
  // Both tensors are filled with the repeating pattern 0, 1, 2, 3, 4.
  const cyclicData = (length: number): number[] =>
      Array.from({length}, (_, i) => i % 5);

  const inputData = cyclicData(elementCount(inputShape));
  const wData = cyclicData(elementCount(wShape));

  const x = tf.tensor4d(inputData, inputShape);
  const w = tf.tensor4d(wData, wShape);
  const result = tf.conv2dTranspose(
      x, w, [1, 4, 4, origInputDepth], origStride, origPad);
  expect(result.shape).toEqual([1, 4, 4, 8]);

  const expected = [
    512, 533, 469, 550, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550, 506,
    512, 533, 469, 550, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550, 506,
    512, 533, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550, 506, 512, 533,
    469, 550, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550, 506, 512, 533,
    469, 550, 506, 512, 550, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550,
    506, 512, 533, 469, 550, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550,
    506, 512, 533, 469, 550, 506, 469, 550, 506, 512, 533, 469, 550, 506, 512,
    533, 469, 550, 506, 512, 533, 469, 550, 506, 512, 533, 469, 550, 506, 512,
    533, 469, 550, 506, 512, 533, 469, 550
  ];
  expectArraysClose(await result.data(), expected);
});
669709
});

tfjs-node/src/run_tests.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ const IGNORE_LIST: string[] = [
104104
// Node backend which uses TF 2.4.0 doesn't support explicit padding
105105
'conv2dTranspose test-tensorflow {} input=3x3x1,d2=1,f=2,s=2,p=explicit',
106106
// tslint:disable-next-line:max-line-length
107+
'conv2dTranspose test-tensorflow {} input=8x8x8,output=4x4x8,f=8,s=1,inDepth=8,p=same vec4',
108+
// tslint:disable-next-line:max-line-length
107109
'conv2dTranspose test-tensorflow {} gradient input=[1,3,3,1] f=[2,2,2,1] s=[1,1] p=explicit',
108110
'fused conv2d test-tensorflow {} basic in NCHW',
109111
'fused conv2d test-tensorflow {} im2row in NCHW',

0 commit comments

Comments
 (0)