@@ -20,14 +20,13 @@ import {mapActivationToShaderProgram} from './activation_util';
2020import { getWorkGroupSizeString , WebGPUProgram } from './webgpu_program' ;
2121import { computeDispatch } from './webgpu_util' ;
2222
23- export class DepthwiseConv2D3x3Program implements WebGPUProgram {
23+ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
2424 outputShape : number [ ] ;
2525 shaderKey : string ;
2626 dispatchLayout : { x : number [ ] , y : number [ ] , z : number [ ] } ;
2727 dispatch : [ number , number , number ] ;
2828 variableNames = [ 'x' , 'W' ] ;
29- uniforms =
30- 'pad : vec2<i32>, stride : vec2<i32>, dilation : vec2<i32>, inDims : vec2<i32>,' ;
29+ uniforms = 'pad : vec2<i32>, inDims : vec2<i32>,' ;
3130 workGroupSize : [ number , number , number ] = [ 4 , 4 , 4 ] ;
3231 convInfo : backend_util . Conv2DInfo ;
3332 addBias : boolean ;
@@ -39,9 +38,9 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
3938 convInfo : backend_util . Conv2DInfo , addBias = false ,
4039 activation : backend_util . Activation = null , hasPreluActivation = false ) {
4140 this . outputShape = convInfo . outShape ;
42- this . dispatchLayout = { x : [ 0 , 1 ] , y : [ 2 ] , z : [ 3 ] } ;
41+ this . dispatchLayout = { x : [ 3 ] , y : [ 2 ] , z : [ 0 , 1 ] } ;
4342 this . dispatch = computeDispatch (
44- this . dispatchLayout , this . outputShape , this . workGroupSize , [ 1 , 4 , 4 ] ) ;
43+ this . dispatchLayout , this . outputShape , this . workGroupSize , [ 4 , 4 , 1 ] ) ;
4544
4645 util . assert (
4746 convInfo . dataFormat === 'channelsLast' ,
@@ -59,7 +58,8 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
5958 this . activation = activation ;
6059 this . hasPreluActivation = hasPreluActivation ;
6160
62- this . shaderKey = `depthwise3x3_${ activation } ` ;
61+ this . shaderKey = `depthwiseVec4_${ activation } _${
62+ this . convInfo . filterHeight } _${ this . convInfo . filterWidth } `;
6363 }
6464
6565 getUserCode ( ) : string {
@@ -87,65 +87,53 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
8787 const addBiasSnippet = this . addBias ?
8888 'dotProd[i] = dotProd[i] + getBiasByOutputCoords(coords);' :
8989 '' ;
90-
90+ // Here 4 is the work per thread in X dimension.
91+ const xNumber = 4 + this . convInfo . filterWidth - 1 ;
9192 const userCode = `
9293 ${ activationSnippet }
93-
94+ fn readX(batch : i32, row : i32, col : i32, channel : i32) -> vec4<f32> {
95+ var value = vec4<f32>(0.0);
96+ if (row >=0 && row < uniforms.inDims[0] && col >=0 && col < uniforms.inDims[1])
97+ {
98+ value = getX(batch, row, col, channel);
99+ }
100+ return value;
101+ }
94102 ${ getWorkGroupSizeString ( ) }
95103 fn main(@builtin(global_invocation_id) globalId: vec3<u32>) {
96- let batch = 0 ;
97- let r = i32(globalId.x) ;
104+ let batch = i32(globalId.z) / uniforms.outShape[1] ;
105+ let r = i32(globalId.z) % uniforms.outShape[1] ;
98106 let c = i32(globalId.y) * 4;
99- let d2 = i32(globalId.z) * 4;
100- let xRCCorner = vec2<i32>(r, c) * uniforms.stride - uniforms.pad;
101- let d1 = d2;
102- let q = 0;
107+ let d1 = i32(globalId.x) * 4;
108+ let xRCCorner = vec2<i32>(r, c) - uniforms.pad;
103109
104110 let xRCorner = xRCCorner.x;
105111 let xCCorner = xRCCorner.y;
106-
107- var wVals : array<vec4<f32>, 9>;
108- wVals[0] = getW(0, 0, d1, q);
109- wVals[1] = getW(0, 1, d1, q);
110- wVals[2] = getW(0, 2, d1, q);
111- wVals[3] = getW(1, 0, d1, q);
112- wVals[4] = getW(1, 1, d1, q);
113- wVals[5] = getW(1, 2, d1, q);
114- wVals[6] = getW(2, 0, d1, q);
115- wVals[7] = getW(2, 1, d1, q);
116- wVals[8] = getW(2, 2, d1, q);
117-
118- var xVals : array<array<vec4<f32>, 6>, 3>;
119- for (var wR = 0; wR < 3; wR = wR + 1) {
120- let xR = xRCorner + wR * uniforms.dilation[0];
121- for (var wC = 0; wC < 6; wC = wC + 1) {
122- let xC = xCCorner + wC * uniforms.dilation[1];
123- if (xR < 0 || xR >= uniforms.inDims[0] || xC < 0 || xC >= uniforms.inDims[1]) {
124- xVals[wR][wC] = vec4<f32>(0.0);
125- } else {
126- xVals[wR][wC] = getX(batch, xR, xC, d1);
127- }
128- }
129- }
130-
112+ var xVals : array<vec4<f32>, ${ xNumber } >;
131113 var dotProd : array<vec4<f32>, 4>;
132114 dotProd[0] = vec4<f32>(0.0);
133115 dotProd[1] = vec4<f32>(0.0);
134116 dotProd[2] = vec4<f32>(0.0);
135117 dotProd[3] = vec4<f32>(0.0);
136118
137- for (var wR = 0; wR < 3; wR = wR + 1) {
138- for (var wC = 0; wC < 3; wC = wC + 1) {
139- let indexW = wR * 3 + wC;
140- dotProd[0] = dotProd[0] + xVals[wR][0 + wC] * wVals[indexW];
141- dotProd[1] = dotProd[1] + xVals[wR][1 + wC] * wVals[indexW];
142- dotProd[2] = dotProd[2] + xVals[wR][2 + wC] * wVals[indexW];
143- dotProd[3] = dotProd[3] + xVals[wR][3 + wC] * wVals[indexW];
119+ // Use constant instead of uniform can give better performance.
120+ for (var wR = 0; wR < ${ this . convInfo . filterHeight } ; wR = wR + 1) {
121+ let xR = xRCorner + wR;
122+ for (var i = 0; i < ${ xNumber } ; i++)
123+ {
124+ xVals[i] = readX(batch, xR, xCCorner + i, d1);
125+ }
126+ for (var wC = 0; wC < ${ this . convInfo . filterWidth } ; wC = wC + 1) {
127+ let wValue = getW(wR, wC, d1, 0);
128+ dotProd[0] = dotProd[0] + xVals[0 + wC] * wValue;
129+ dotProd[1] = dotProd[1] + xVals[1 + wC] * wValue;
130+ dotProd[2] = dotProd[2] + xVals[2 + wC] * wValue;
131+ dotProd[3] = dotProd[3] + xVals[3 + wC] * wValue;
144132 }
145133 }
146134
147135 for (var i = 0; i < 4; i = i + 1) {
148- let coords = vec4<i32>(batch, r, c + i, d2 );
136+ let coords = vec4<i32>(batch, r, c + i, d1 );
149137 if (coordsInBounds4D(coords, uniforms.outShape)) {
150138 ${ addBiasSnippet }
151139 ${ applyActivationSnippet }
0 commit comments