Skip to content

Commit 5edb812

Browse files
authored
[WebGL] Fix NHWC packed depthwise conv2d for dilation=3 (#6662)
FIX * Update conv_packed_gpu_depthwise.ts * add test
1 parent 17c0b71 commit 5edb812

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

tfjs-backend-webgl/src/conv_packed_gpu_depthwise.ts

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -193,18 +193,21 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {
193193
if (dilationWidth > 1) {
194194
mainLoop += `
195195
xCOffset -= 2;
196-
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${
197-
colIndex}Ready == 0) {
198-
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
199-
xTexelC${colIndex}Ready = 1;
196+
if (xCOffset >= 0 && xCOffset < inDims[1]) {
197+
previous = getX(batch, xR, xCOffset, d1);
198+
xC${colIndex + 1} = vec4(previous.zw, xTexelC${
199+
colIndex + 1}.xy);
200+
} else {
201+
xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${
202+
colIndex + 1}.xy);
200203
}
201204
`;
205+
} else {
206+
mainLoop += `
207+
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${
208+
colIndex + 1}.xy);
209+
`;
202210
}
203-
204-
mainLoop += `
205-
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${
206-
colIndex + 1}.xy);
207-
`;
208211
} else {
209212
// If dilation is 1 and padding is odd, we have already read the
210213
// texel when constructing the previous x value. Here we can

tfjs-core/src/ops/depthwise_conv2d_test.ts

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -717,6 +717,37 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => {
717717
]);
718718
});
719719

720+
it('input=1x8x8x2,f=3,s=1,d=3,p=valid,chMul=1', async () => {
721+
const fSize = 3;
722+
const pad = 'valid';
723+
const stride = 1;
724+
const inDepth = 2;
725+
const dilation = 3;
726+
727+
const x = tf.tensor4d(
728+
[
729+
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
730+
9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16,
731+
17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24,
732+
25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32,
733+
33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40,
734+
41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48,
735+
49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56,
736+
57, 57, 58, 58, 59, 59, 60, 60, 61, 61, 62, 62, 63, 63, 64, 64
737+
],
738+
[1, 8, 8, inDepth]);
739+
740+
const w = tf.tensor4d(
741+
[9, 1, 8, 2, 7, 3, 6, 4, 5, 5, 4, 6, 3, 7, 2, 8, 1, 9],
742+
[fSize, fSize, inDepth, 1],
743+
);
744+
const result = tf.depthwiseConv2d(x, w, stride, pad, 'NHWC', dilation);
745+
746+
expect(result.shape).toEqual([1, 2, 2, 2]);
747+
expectArraysClose(
748+
await result.data(), [810, 1710, 855, 1755, 1170, 2070, 1215, 2115]);
749+
});
750+
720751
it('Tensor3D is allowed', async () => {
721752
const fSize = 2;
722753
const pad = 'same';

0 commit comments

Comments
 (0)