|
15 | 15 | * ============================================================================= |
16 | 16 | */ |
17 | 17 |
|
18 | | -import {getCoordsDataType, getMainHeaderAndGlobalIndexString, WebGPUProgram} from './webgpu_program'; |
| 18 | +import {DataType} from '@tensorflow/tfjs-core'; |
| 19 | +import {getCoordsDataType, getMainHeaderAndGlobalIndexString, mapToWgslTypes, WebGPUProgram} from './webgpu_program'; |
19 | 20 | import {computeDispatch, flatDispatchLayout} from './webgpu_util'; |
20 | 21 |
|
21 | 22 | export class ScatterProgram implements WebGPUProgram { |
22 | | - variableNames = ['updates', 'indices', 'defaultValue']; |
| 23 | + variableNames = ['updates', 'indices']; |
23 | 24 | uniforms: string; |
24 | 25 | outputShape: number[]; |
| 26 | + sumDupeIndices: boolean; |
25 | 27 | shaderKey: string; |
26 | 28 | dispatchLayout: {x: number[]}; |
27 | 29 | dispatch: [number, number, number]; |
28 | 30 | workGroupSize: [number, number, number] = [64, 1, 1]; |
29 | | - workPerThread = 4; |
30 | | - size = true; |
31 | | - indicesSnippet: string; |
32 | | - strideString: string; |
33 | | - updatesSnippet: string; |
| 31 | + updatesRank: number; |
| 32 | + indicesRank: number; |
| 33 | + sliceDimGreaterThanOne: boolean; |
| 34 | + atomic = true; |
| 35 | + type: DataType; |
34 | 36 |
|
/**
 * Configures the scatter program. This note describes the added ("+") side of
 * the diff below; lines marked "-" are the implementation being replaced.
 *
 * @param flattenXShape Shape used to build the dispatch layout and the
 *     dispatch size (per the inline comment, this is the |updates| shape
 *     rather than the output shape).
 * @param sliceDim Only compared against 1 here; the result is stored in
 *     `sliceDimGreaterThanOne` and becomes part of the shader key.
 * @param indicesRank Stored on the instance and included in the shader key.
 * @param updatesRank Stored on the instance and included in the shader key.
 * @param strides Only its length is read here, to choose the WGSL type of the
 *     `strides` uniform via `getCoordsDataType`.
 * @param shape Assigned to `outputShape`.
 * @param outputDtype Stored as `type` and included in the shader key.
 * @param sumDupeIndices Stored on the instance and included in the shader
 *     key. Defaults to true.
 */
35 | 37 | constructor( |
36 | | - updateSize: number, sliceDim: number, indicesRank: number, |
| 38 | + flattenXShape: number[], sliceDim: number, indicesRank: number, |
37 | 39 | updatesRank: number, strides: number[], shape: number[], |
38 | | - summingDupeIndex = true) { |
| 40 | + outputDtype: DataType, sumDupeIndices = true) { |
39 | 41 | this.outputShape = shape; |
40 | | - this.dispatchLayout = flatDispatchLayout(this.outputShape); |
41 | | - this.dispatch = computeDispatch( |
42 | | - this.dispatchLayout, this.outputShape, this.workGroupSize, |
43 | | - [this.workPerThread, 1, 1]); |
44 | | - const sliceDimGreaterThanOne = sliceDim > 1; |
45 | | - this.shaderKey = |
46 | | - `scatter_${indicesRank}_${updatesRank}_${sliceDimGreaterThanOne}`; |
| 42 | + this.type = outputDtype; |
| 43 | + this.sumDupeIndices = sumDupeIndices; |
| 44 | + this.dispatchLayout = flatDispatchLayout(flattenXShape); |
| 45 | + // Dispatching based on |updates| shape instead of output shape. |
| 46 | + this.dispatch = |
| 47 | + computeDispatch(this.dispatchLayout, flattenXShape, this.workGroupSize); |
| 48 | + this.sliceDimGreaterThanOne = sliceDim > 1; |
| 49 | + this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${ |
| 50 | + this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`; |
47 | 51 | const stridesType = getCoordsDataType(strides.length); |
48 | | - this.uniforms = |
49 | | - `updateSize : i32, sliceDim : i32, strides: ${stridesType},`; |
| 52 | + this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`; |
| 53 | + this.updatesRank = updatesRank; |
| 54 | + this.indicesRank = indicesRank; |
| 55 | + } |
| 56 | + |
| 57 | + getUserCode(): string { |
50 | 58 | let indicesString = ''; |
51 | | - if (indicesRank === 1) { |
52 | | - indicesString = 'i'; |
53 | | - } else if (indicesRank === 2) { |
54 | | - indicesString = 'i, j'; |
| 59 | + if (this.indicesRank === 1) { |
| 60 | + indicesString = 'coords[0]'; |
| 61 | + } else if (this.indicesRank === 2) { |
| 62 | + indicesString = 'coords[0], j'; |
55 | 63 | } |
56 | | - this.indicesSnippet = `getIndices(${indicesString})`; |
| 64 | + const indicesSnippet = `getIndices(${indicesString})`; |
| 65 | + |
| 66 | + const strideString = this.sliceDimGreaterThanOne ? 'uniforms.strides[j]' : |
| 67 | + 'uniforms.strides'; |
57 | 68 |
|
58 | | - let updatesString = ''; |
59 | | - if (updatesRank === 1) { |
60 | | - updatesString = 'i'; |
61 | | - } else if (updatesRank === 2) { |
62 | | - updatesString = 'i, coords[1]'; |
| 69 | + let outCoordsString = ''; |
| 70 | + let getUpdatesCoordsFromFlatIndex = ''; |
| 71 | + if (this.dispatchLayout.x.length === 1) { |
| 72 | + outCoordsString = 'flattenedIndex'; |
| 73 | + getUpdatesCoordsFromFlatIndex = ` |
| 74 | + fn getUpdatesCoordsFromFlatIndex(index : i32) -> i32 { |
| 75 | + return index; |
| 76 | + } |
| 77 | + `; |
| 78 | + } else if (this.dispatchLayout.x.length === 2) { |
| 79 | + outCoordsString = 'vec2<i32>(flattenedIndex, coords[1])'; |
| 80 | + getUpdatesCoordsFromFlatIndex = ` |
| 81 | + fn getUpdatesCoordsFromFlatIndex(index : i32) -> vec2<i32> { |
| 82 | + // N.B. |updates| could be a scalar tensor, conceptually representing a |
| 83 | + // 2D tensor with all values equal to that. By design, its size must be |
| 84 | + // the same as |outShape[1]| in one dimension, and |indicesShape[0]| |
| 85 | + // gives the other. |
| 86 | + let sliceSize = uniforms.outShape[1]; |
| 87 | + let d0 = index / sliceSize; |
| 88 | + let d1 = index - d0 * sliceSize; |
| 89 | + return vec2<i32>(d0, d1); |
| 90 | + } |
| 91 | + `; |
63 | 92 | } |
64 | | - this.updatesSnippet = `getUpdates(${updatesString})`; |
| 93 | + const updatesString = |
| 94 | + Array.from({length: this.updatesRank}, (_, idx) => `coords[${idx}]`); |
| 95 | + const updatesSnippet = `getUpdates(${updatesString.join(', ')})`; |
65 | 96 |
|
66 | | - this.strideString = |
67 | | - sliceDimGreaterThanOne ? 'uniforms.strides[j]' : 'uniforms.strides'; |
68 | | - } |
| 97 | + const atomicRMW = (ptr: string, val: string) => { |
| 98 | + let atomicAddSnippet = `atomicAdd(${ptr}, bitcast<i32>(${val}))`; |
| 99 | + if (this.type === 'float32') { |
| 100 | + atomicAddSnippet = ` |
| 101 | + { |
| 102 | + var oldBits = 0; |
| 103 | + var newBits = bitcast<i32>(${val}); |
| 104 | + loop { |
| 105 | + let info = atomicCompareExchangeWeak(${ptr}, oldBits, newBits); |
| 106 | + if (info.exchanged) { |
| 107 | + break; |
| 108 | + } |
| 109 | + oldBits = info.old_value; |
| 110 | + let oldValue = bitcast<f32>(oldBits); |
| 111 | + let newValue = oldValue + (${val}); |
| 112 | + newBits = bitcast<i32>(newValue); |
| 113 | + } |
| 114 | + } |
| 115 | + `; |
| 116 | + } |
| 117 | + const atomicStoreSnippet = `atomicStore(${ptr}, bitcast<i32>(${val}));`; |
| 118 | + return this.sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet; |
| 119 | + }; |
69 | 120 |
|
70 | | - getUserCode(): string { |
71 | 121 | const userCode = ` |
| 122 | + ${getUpdatesCoordsFromFlatIndex} |
| 123 | +
|
72 | 124 | ${getMainHeaderAndGlobalIndexString()} |
73 | 125 |
|
74 | | - let globalIndex = index * ${this.workPerThread}; |
75 | | - if (globalIndex < uniforms.size) { |
76 | | - var sum = vec4<f32>(0.0); |
77 | | - var found = vec4<bool>(false); |
78 | | - for (var i = 0; i < uniforms.updateSize; i = i + 1) { |
79 | | - var flattenedIndex = 0; |
80 | | - for (var j = 0; j < uniforms.sliceDim; j = j + 1) { |
81 | | - let indexInside = i32(round(${this.indicesSnippet})); |
82 | | - flattenedIndex = flattenedIndex + indexInside * ${ |
83 | | - this.strideString}; |
84 | | - } |
85 | | - for (var innerIndex = 0; innerIndex < ${ |
86 | | - this.workPerThread}; innerIndex = innerIndex + 1) { |
87 | | - let curIndex = globalIndex + innerIndex; |
88 | | - let coords = getCoordsFromIndex(curIndex); |
89 | | - if (flattenedIndex == coords[0]) { |
90 | | - sum[innerIndex] = sum[innerIndex] + ${this.updatesSnippet}; |
91 | | - found[innerIndex] = true; |
92 | | - } |
93 | | - } |
94 | | - } |
95 | | - for (var innerIndex = 0; innerIndex < ${ |
96 | | - this.workPerThread}; innerIndex = innerIndex + 1) { |
97 | | - let curIndex = globalIndex + innerIndex; |
98 | | - if (curIndex < uniforms.size) |
99 | | - { |
100 | | - setOutputAtIndex(curIndex, mix(getDefaultValue(), sum[innerIndex], f32(found[innerIndex]))); |
101 | | - } |
| 126 | + if (index < uniforms.size) { |
| 127 | + let coords = getUpdatesCoordsFromFlatIndex(index); |
| 128 | + var flattenedIndex = 0; |
| 129 | + for (var j = 0; j < uniforms.sliceDim; j = j + 1) { |
| 130 | + let indexInside = i32(round(${indicesSnippet})); |
| 131 | + flattenedIndex = flattenedIndex + indexInside * ${strideString}; |
102 | 132 | } |
| 133 | + let updateValue = |
| 134 | + ${mapToWgslTypes(this.type, false)}(${updatesSnippet}); |
| 135 | + let flatIndex = getOutputIndexFromCoords(${outCoordsString}); |
| 136 | +
|
| 137 | + ${atomicRMW('&result[flatIndex]', 'updateValue')}; |
103 | 138 | } |
104 | 139 | }`; |
105 | 140 | return userCode; |
|
0 commit comments