Skip to content

Commit d515f4d

Browse files
authored
[webgpu] s/ScatterOptimizedProgram/ScatterProgram/g (#6761)
1 parent a366cc2 commit d515f4d

File tree

5 files changed

+102
-216
lines changed

5 files changed

+102
-216
lines changed

tfjs-backend-webgpu/src/kernels/ScatterNd.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,10 @@
1818
import {backend_util, KernelConfig, KernelFunc, ScatterNd, ScatterNdAttrs, ScatterNdInputs, TensorInfo, util} from '@tensorflow/tfjs-core';
1919

2020
import {WebGPUBackend} from '../backend_webgpu';
21+
import {ScatterProgram} from '../scatter_webgpu';
2122

2223
import {fill} from './Fill';
2324
import {reshape} from './Reshape';
24-
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';
2525

2626
export function scatterNd(args: {
2727
inputs: ScatterNdInputs,
@@ -54,7 +54,7 @@ export function scatterNd(args: {
5454
{type: 'int32', data: [sliceRank]}, {type: 'int32', data: strides},
5555
{type: 'int32', data: [size]}
5656
];
57-
const program = new ScatterOptimizedProgram(
57+
const program = new ScatterProgram(
5858
flattenX.shape, sliceRank, flattenIndices.shape.length,
5959
flattenX.shape.length, strides, flattenShape, type);
6060
const res = backend.runWebGPUProgram(

tfjs-backend-webgpu/src/kernels/SparseToDense.ts

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ import {backend_util, KernelConfig, KernelFunc, Rank, SparseToDense, SparseToDen
1919

2020
import {WebGPUBackend} from '../backend_webgpu';
2121
import {scatterImplCPU} from '../kernel_utils/shared';
22-
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';
22+
import {ScatterProgram} from '../scatter_webgpu';
2323

2424
import {identity} from './Identity';
2525
import {reshape} from './Reshape';
@@ -89,7 +89,7 @@ export function sparseToDense(args: {
8989
break;
9090
case 1:
9191
if (true) {
92-
const program = new ScatterOptimizedProgram(
92+
const program = new ScatterProgram(
9393
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
9494
$sparseValues.shape.length, strides, flattenShape, type,
9595
sumDupeIndices);
@@ -101,15 +101,15 @@ export function sparseToDense(args: {
101101
default:
102102
if (true) {
103103
// First replace the default value with 0 at indices.
104-
const program = new ScatterOptimizedProgram(
104+
const program = new ScatterProgram(
105105
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
106106
zero.shape.length, strides, flattenShape, type, sumDupeIndices);
107107
backend.runWebGPUProgram(
108108
program, [zero, $sparseIndices], type, uniformData, $denseValues);
109109
}
110110
{
111111
// Then replace 0 with the (sum of) sparse value(s) at indices.
112-
const program = new ScatterOptimizedProgram(
112+
const program = new ScatterProgram(
113113
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
114114
$sparseValues.shape.length, strides, flattenShape, type);
115115
backend.runWebGPUProgram(

tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts

Lines changed: 0 additions & 142 deletions
This file was deleted.

tfjs-backend-webgpu/src/scatter_webgpu.ts

Lines changed: 96 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -15,91 +15,126 @@
1515
* =============================================================================
1616
*/
1717

18-
import {getCoordsDataType, getMainHeaderAndGlobalIndexString, WebGPUProgram} from './webgpu_program';
18+
import {DataType} from '@tensorflow/tfjs-core';
19+
import {getCoordsDataType, getMainHeaderAndGlobalIndexString, mapToWgslTypes, WebGPUProgram} from './webgpu_program';
1920
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
2021

2122
export class ScatterProgram implements WebGPUProgram {
22-
variableNames = ['updates', 'indices', 'defaultValue'];
23+
variableNames = ['updates', 'indices'];
2324
uniforms: string;
2425
outputShape: number[];
26+
sumDupeIndices: boolean;
2527
shaderKey: string;
2628
dispatchLayout: {x: number[]};
2729
dispatch: [number, number, number];
2830
workGroupSize: [number, number, number] = [64, 1, 1];
29-
workPerThread = 4;
30-
size = true;
31-
indicesSnippet: string;
32-
strideString: string;
33-
updatesSnippet: string;
31+
updatesRank: number;
32+
indicesRank: number;
33+
sliceDimGreaterThanOne: boolean;
34+
atomic = true;
35+
type: DataType;
3436

3537
constructor(
36-
updateSize: number, sliceDim: number, indicesRank: number,
38+
flattenXShape: number[], sliceDim: number, indicesRank: number,
3739
updatesRank: number, strides: number[], shape: number[],
38-
summingDupeIndex = true) {
40+
outputDtype: DataType, sumDupeIndices = true) {
3941
this.outputShape = shape;
40-
this.dispatchLayout = flatDispatchLayout(this.outputShape);
41-
this.dispatch = computeDispatch(
42-
this.dispatchLayout, this.outputShape, this.workGroupSize,
43-
[this.workPerThread, 1, 1]);
44-
const sliceDimGreaterThanOne = sliceDim > 1;
45-
this.shaderKey =
46-
`scatter_${indicesRank}_${updatesRank}_${sliceDimGreaterThanOne}`;
42+
this.type = outputDtype;
43+
this.sumDupeIndices = sumDupeIndices;
44+
this.dispatchLayout = flatDispatchLayout(flattenXShape);
45+
// Dispatching based on |updates| shape instead of output shape.
46+
this.dispatch =
47+
computeDispatch(this.dispatchLayout, flattenXShape, this.workGroupSize);
48+
this.sliceDimGreaterThanOne = sliceDim > 1;
49+
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
50+
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
4751
const stridesType = getCoordsDataType(strides.length);
48-
this.uniforms =
49-
`updateSize : i32, sliceDim : i32, strides: ${stridesType},`;
52+
this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`;
53+
this.updatesRank = updatesRank;
54+
this.indicesRank = indicesRank;
55+
}
56+
57+
getUserCode(): string {
5058
let indicesString = '';
51-
if (indicesRank === 1) {
52-
indicesString = 'i';
53-
} else if (indicesRank === 2) {
54-
indicesString = 'i, j';
59+
if (this.indicesRank === 1) {
60+
indicesString = 'coords[0]';
61+
} else if (this.indicesRank === 2) {
62+
indicesString = 'coords[0], j';
5563
}
56-
this.indicesSnippet = `getIndices(${indicesString})`;
64+
const indicesSnippet = `getIndices(${indicesString})`;
65+
66+
const strideString = this.sliceDimGreaterThanOne ? 'uniforms.strides[j]' :
67+
'uniforms.strides';
5768

58-
let updatesString = '';
59-
if (updatesRank === 1) {
60-
updatesString = 'i';
61-
} else if (updatesRank === 2) {
62-
updatesString = 'i, coords[1]';
69+
let outCoordsString = '';
70+
let getUpdatesCoordsFromFlatIndex = '';
71+
if (this.dispatchLayout.x.length === 1) {
72+
outCoordsString = 'flattenedIndex';
73+
getUpdatesCoordsFromFlatIndex = `
74+
fn getUpdatesCoordsFromFlatIndex(index : i32) -> i32 {
75+
return index;
76+
}
77+
`;
78+
} else if (this.dispatchLayout.x.length === 2) {
79+
outCoordsString = 'vec2<i32>(flattenedIndex, coords[1])';
80+
getUpdatesCoordsFromFlatIndex = `
81+
fn getUpdatesCoordsFromFlatIndex(index : i32) -> vec2<i32> {
82+
// N.B. |updates| could be a scalar tensor, conceptually representing a
83+
// 2D tensor with all values equal to that. By design, its size must be
84+
// the same as |outShape[1]| in one dimension, and |indicesShape[0]|
85+
// gives the other.
86+
let sliceSize = uniforms.outShape[1];
87+
let d0 = index / sliceSize;
88+
let d1 = index - d0 * sliceSize;
89+
return vec2<i32>(d0, d1);
90+
}
91+
`;
6392
}
64-
this.updatesSnippet = `getUpdates(${updatesString})`;
93+
const updatesString =
94+
Array.from({length: this.updatesRank}, (_, idx) => `coords[${idx}]`);
95+
const updatesSnippet = `getUpdates(${updatesString.join(', ')})`;
6596

66-
this.strideString =
67-
sliceDimGreaterThanOne ? 'uniforms.strides[j]' : 'uniforms.strides';
68-
}
97+
const atomicRMW = (ptr: string, val: string) => {
98+
let atomicAddSnippet = `atomicAdd(${ptr}, bitcast<i32>(${val}))`;
99+
if (this.type === 'float32') {
100+
atomicAddSnippet = `
101+
{
102+
var oldBits = 0;
103+
var newBits = bitcast<i32>(${val});
104+
loop {
105+
let info = atomicCompareExchangeWeak(${ptr}, oldBits, newBits);
106+
if (info.exchanged) {
107+
break;
108+
}
109+
oldBits = info.old_value;
110+
let oldValue = bitcast<f32>(oldBits);
111+
let newValue = oldValue + (${val});
112+
newBits = bitcast<i32>(newValue);
113+
}
114+
}
115+
`;
116+
}
117+
const atomicStoreSnippet = `atomicStore(${ptr}, bitcast<i32>(${val}));`;
118+
return this.sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet;
119+
};
69120

70-
getUserCode(): string {
71121
const userCode = `
122+
${getUpdatesCoordsFromFlatIndex}
123+
72124
${getMainHeaderAndGlobalIndexString()}
73125
74-
let globalIndex = index * ${this.workPerThread};
75-
if (globalIndex < uniforms.size) {
76-
var sum = vec4<f32>(0.0);
77-
var found = vec4<bool>(false);
78-
for (var i = 0; i < uniforms.updateSize; i = i + 1) {
79-
var flattenedIndex = 0;
80-
for (var j = 0; j < uniforms.sliceDim; j = j + 1) {
81-
let indexInside = i32(round(${this.indicesSnippet}));
82-
flattenedIndex = flattenedIndex + indexInside * ${
83-
this.strideString};
84-
}
85-
for (var innerIndex = 0; innerIndex < ${
86-
this.workPerThread}; innerIndex = innerIndex + 1) {
87-
let curIndex = globalIndex + innerIndex;
88-
let coords = getCoordsFromIndex(curIndex);
89-
if (flattenedIndex == coords[0]) {
90-
sum[innerIndex] = sum[innerIndex] + ${this.updatesSnippet};
91-
found[innerIndex] = true;
92-
}
93-
}
94-
}
95-
for (var innerIndex = 0; innerIndex < ${
96-
this.workPerThread}; innerIndex = innerIndex + 1) {
97-
let curIndex = globalIndex + innerIndex;
98-
if (curIndex < uniforms.size)
99-
{
100-
setOutputAtIndex(curIndex, mix(getDefaultValue(), sum[innerIndex], f32(found[innerIndex])));
101-
}
126+
if (index < uniforms.size) {
127+
let coords = getUpdatesCoordsFromFlatIndex(index);
128+
var flattenedIndex = 0;
129+
for (var j = 0; j < uniforms.sliceDim; j = j + 1) {
130+
let indexInside = i32(round(${indicesSnippet}));
131+
flattenedIndex = flattenedIndex + indexInside * ${strideString};
102132
}
133+
let updateValue =
134+
${mapToWgslTypes(this.type, false)}(${updatesSnippet});
135+
let flatIndex = getOutputIndexFromCoords(${outCoordsString});
136+
137+
${atomicRMW('&result[flatIndex]', 'updateValue')};
103138
}
104139
}`;
105140
return userCode;

0 commit comments

Comments
 (0)