|
16 | 16 | */ |
17 | 17 |
|
18 | 18 | import {backend_util, util} from '@tensorflow/tfjs-core'; |
19 | | - |
| 19 | +import {typeSnippet} from './activation_util'; |
| 20 | +import {makeMatMulPackedVec4Source} from './matmul_packed_vec4_webgpu'; |
20 | 21 | import {makeMatMulPackedSource} from './matmul_packed_webgpu'; |
21 | 22 | import {WebGPUProgram} from './webgpu_program'; |
22 | 23 | import {computeDispatch, computeWorkGroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util'; |
23 | 24 |
|
/**
 * Builds the WGSL source for the three callbacks (`mm_readA`, `mm_readB`,
 * `mm_write`) that are emitted alongside the packed matmul source.
 *
 * @param innerElementSize Number of scalar lanes handled per access. Only 1
 *     and 4 have a weight-read snippet; any other value that `typeSnippet`
 *     accepts ends in the "not supported" error below.
 * @returns The WGSL text of the three functions.
 * @throws Error when `innerElementSize` is neither 1 nor 4.
 */
function conv2dTransposeCommonSnippet(innerElementSize = 4) {
  // WGSL type name for one element of the given width; reused for every
  // return type and zero constructor in the generated code.
  const elemType = typeSnippet(innerElementSize);
  const zeroValue = `${elemType}(0.0)`;

  // Body of the in-bounds branch of mm_readB.
  const weightReadFor = (width: number) => {
    if (width === 1) {
      return 'return W[getIndexFromCoords4D(coord, uniforms.wShape)];';
    }
    if (width === 4) {
      // Four consecutive `col` entries of W are read one scalar at a time
      // and packed into a single vec4.
      return `
            let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
            let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
            let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
            let v0 = W[getIndexFromCoords4D(coord, uniforms.wShape)];
            let v1 = W[getIndexFromCoords4D(coord1, uniforms.wShape)];
            let v2 = W[getIndexFromCoords4D(coord2, uniforms.wShape)];
            let v3 = W[getIndexFromCoords4D(coord3, uniforms.wShape)];
            return vec4<f32>(v0, v1, v2, v3);
            `;
    }
    throw new Error(`innerElementSize ${width} is not supported.`);
  };

  // Body of mm_readA: bounds check on (row, col), then map the output
  // position back to a position in x. Positions that are out of range or
  // fall between strides (non-zero fractional part) yield zero.
  const readABody = `if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
      
      let outRow = row / uniforms.outShape[2];
      let outCol = row % uniforms.outShape[2];

      let WRow = col / (uniforms.filterDims[1] * uniforms.outBackprop[3]);
      let WCol = col / uniforms.outBackprop[3] % uniforms.filterDims[1];
      let xR = f32(outRow - uniforms.pads[0] + WRow) / f32(uniforms.stride[0]);
      let xC = f32(outCol - uniforms.pads[1] + WCol) / f32(uniforms.stride[1]);
      if (xR < 0.0 || xR >= f32(uniforms.outBackprop[1]) || fract(xR) > 0.0) {
        return ${zeroValue};
      }
      if (xC < 0.0 || xC >= f32(uniforms.outBackprop[2]) || fract(xC) > 0.0) {
        return ${zeroValue};
      }
      let coord = vec4<i32>(
          batch,
          i32(xR),
          i32(xC),
          col % uniforms.outBackprop[3]);
      return x[getIndexFromCoords4D(coord, uniforms.xShape)/${
      innerElementSize}];
    }
    return ${zeroValue};`;

  // Evaluated after the typeSnippet call above, so an unsupported width
  // still fails in the same order as before.
  const readBBody = weightReadFor(innerElementSize);
  const lastLaneOffset = innerElementSize - 1;

  return `
  fn mm_readA(row : i32, colIn : i32, globalId : vec3<u32>) -> ${elemType} {
    let col = colIn * ${innerElementSize};
    var batch = i32(globalId.z);
    ${readABody}
  }

  fn mm_readB(row : i32, colIn : i32, globalId : vec3<u32>) -> ${elemType} {
    let col = colIn * ${innerElementSize};
    let coordX = uniforms.filterDims.x - 1 -
        row / (uniforms.filterDims[1] * uniforms.outBackprop[3]);
    let coordY = uniforms.filterDims.y - 1 -
        (row / uniforms.outBackprop[3]) % uniforms.filterDims[1];
    if (row < uniforms.dimInner && col < uniforms.dimBOuter &&
        coordX >= 0 && coordY >= 0) {
      let rowInner = row % uniforms.outBackprop[3];
      let coord = vec4<i32>(coordX, coordY, col, rowInner);
      ${readBBody}
    }
    return ${zeroValue};
  }

  fn mm_write(row : i32, colIn : i32, valueInput : ${
      elemType}, globalId : vec3<u32>) {
    let col = colIn * ${innerElementSize};
    if (row < uniforms.dimAOuter && (col + ${
      lastLaneOffset}) < uniforms.dimBOuter) {
      var batch = i32(globalId.z);
      var value = valueInput;
      let outCoord = vec4<i32>(
          batch,
          row / uniforms.outShape[2],
          row % uniforms.outShape[2],
          col);
      result[getIndexFromCoords4D(outCoord, uniforms.outShape)/${
      innerElementSize}] = value;
    }
  }`;
}
| 116 | + |
/**
 * WebGPU program whose shader is the conv2d-transpose read/write callbacks
 * followed by a packed matmul source. The output has the shape of
 * `convInfo.inShape`.
 *
 * When both `inChannels` and `outChannels` are multiples of 4 the vec4
 * matmul source is used and `x` is bound as `vec4<f32>`; otherwise the
 * scalar packed matmul source is used.
 */
export class Conv2DDerInputMMProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[], y: number[], z: number[]};
  dispatch: [number, number, number];
  variableNames = ['x', 'W'];
  // Only assigned on the vec4 path.
  variableTypes: string[];
  uniforms =
      'filterDims : vec2<i32>, pads : vec2<i32>, stride : vec2<i32>, outBackprop : vec4<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,';
  workGroupSize: [number, number, number];
  elementsPerThread: [number, number, number];
  tileAOuter: number;
  tileBOuter: number;
  tileInner: number;
  innerElementSize: number;
  isVec4?: boolean;

  /**
   * @param convInfo Convolution description; must use the `channelsLast`
   *     data format (asserted below).
   */
  constructor(convInfo: backend_util.Conv2DInfo) {
    this.outputShape = convInfo.inShape;

    util.assert(
        convInfo.dataFormat === 'channelsLast',
        () => 'TODO: NCHW is unimplemented');

    const channelsAreMultipleOf4 =
        convInfo.inChannels % 4 === 0 && convInfo.outChannels % 4 === 0;
    this.isVec4 = channelsAreMultipleOf4;

    this.dispatchLayout = {x: [3], y: [1, 2], z: [0]};
    this.workGroupSize = computeWorkGroupSizeForConv2d(
        this.dispatchLayout, this.outputShape, this.isVec4);
    this.elementsPerThread = computeWorkPerThreadForConv2d(
        this.dispatchLayout, this.outputShape, this.isVec4);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workGroupSize,
        this.elementsPerThread);

    const [groupSizeX, groupSizeY] = this.workGroupSize;
    const [perThreadX, perThreadY] = this.elementsPerThread;

    this.innerElementSize = channelsAreMultipleOf4 ? 4 : perThreadX;
    if (channelsAreMultipleOf4) {
      this.variableTypes = ['vec4<f32>', 'f32'];
    }

    // Tile extents are work-group size times per-thread work on each axis.
    this.tileAOuter = groupSizeY * perThreadY;
    this.tileBOuter = groupSizeX * perThreadX;
    this.tileInner = Math.max(groupSizeX * this.innerElementSize, groupSizeY);

    this.shaderKey = `conv2DDerInputMM_${this.isVec4}_${
        this.elementsPerThread}_${this.innerElementSize}`;
  }

  /**
   * Assembles the shader: the conv2d-transpose callbacks (width 4 on the
   * vec4 path, width 1 otherwise) followed by the matching matmul source.
   */
  getUserCode(): string {
    // The matmul source is generated before the callbacks, as before.
    const matMulSource = this.isVec4 ?
        makeMatMulPackedVec4Source(
            this.elementsPerThread, this.tileAOuter, this.tileBOuter,
            this.tileInner, this.innerElementSize) :
        makeMatMulPackedSource(this.elementsPerThread, this.workGroupSize);
    const callbackWidth = this.isVec4 ? 4 : 1;
    const callbacks = conv2dTransposeCommonSnippet(callbackWidth);

    return `
    ${callbacks}
    ${matMulSource}
    `;
  }
}
0 commit comments