Skip to content

Commit bdfb9d7

Browse files
authored
[webgpu] Migrate sparseToDense to the atomic-based kernel (#6552)
PERF The sparseToDense op takes an optional default value. Unlike scatterNd, the output cannot be initialized with fill(), since the default value is a scalar tensor (which could be the result of a previous op) rather than a scalar number. The (horrible!) workaround here is to broadcast the value with tile(). The other challenge is whether the kernel should discard the original value at an index or accumulate onto it. The magic is performed by splitting the op into two "scatter" steps: 1) replace the default value with 0, and 2) add the input sparse values to 0 or whatever is there. This avoids a bitmap for recording whether the output element at an index has been updated by another invocation. Closes #6525
1 parent 466807f commit bdfb9d7

File tree

3 files changed

+120
-42
lines changed

3 files changed

+120
-42
lines changed

tfjs-backend-webgpu/src/kernels/SparseToDense.ts

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ import {backend_util, KernelConfig, KernelFunc, Rank, SparseToDense, SparseToDen
1919

2020
import {WebGPUBackend} from '../backend_webgpu';
2121
import {scatterImplCPU} from '../kernel_utils/shared';
22-
import {ScatterProgram} from '../scatter_webgpu';
22+
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';
2323

24+
import {identity} from './Identity';
2425
import {reshape} from './Reshape';
26+
import {tile} from './Tile';
2527

2628
export function sparseToDense(args: {
2729
inputs: SparseToDenseInputs,
@@ -46,24 +48,85 @@ export function sparseToDense(args: {
4648
sliceRank, strides, $defaultValue, sumDupeIndices);
4749
return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
4850
}
51+
52+
const flattenShape = [outputSize / sliceSize, sliceSize];
53+
54+
const $sparseIndices = reshape({
55+
inputs: {x: sparseIndices},
56+
backend,
57+
attrs: {shape: [numUpdates, sliceRank]}
58+
});
59+
const $sparseValues = sparseValues.shape.length ?
60+
reshape({
61+
inputs: {x: sparseValues},
62+
backend,
63+
attrs: {shape: [numUpdates, sliceSize]}
64+
}) :
65+
identity({inputs: {x: sparseValues}, backend});
66+
67+
const type = $sparseValues.dtype;
68+
const zero =
69+
backend.makeTensorInfo([], type, util.makeZerosTypedArray(1, type));
70+
71+
// Fill output tensor with the default value.
72+
const $defaultValue = reshape({
73+
inputs: {x: defaultValue},
74+
backend,
75+
attrs: {shape: Array(flattenShape.length).fill(1)}
76+
});
77+
const $denseValues =
78+
tile({inputs: {x: $defaultValue}, backend, attrs: {reps: flattenShape}});
79+
80+
const size = util.sizeFromShape([numUpdates, sliceSize]);
4981
const uniformData = [
50-
{type: 'int32', data: [numUpdates]},
5182
{type: 'int32', data: [sliceRank]},
5283
{type: 'int32', data: strides},
84+
{type: 'int32', data: [size]},
5385
];
54-
const program = new ScatterProgram(
55-
numUpdates, sliceRank, sparseIndices.shape.length,
56-
sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
5786

58-
const res = backend.runWebGPUProgram(
59-
program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype,
60-
uniformData);
87+
switch (numUpdates) {
88+
case 0:
89+
break;
90+
case 1:
91+
if (true) {
92+
const program = new ScatterOptimizedProgram(
93+
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
94+
$sparseValues.shape.length, strides, flattenShape, type,
95+
sumDupeIndices);
96+
backend.runWebGPUProgram(
97+
program, [$sparseValues, $sparseIndices], type, uniformData,
98+
$denseValues);
99+
}
100+
break;
101+
default:
102+
if (true) {
103+
// First replace the default value with 0 at indices.
104+
const program = new ScatterOptimizedProgram(
105+
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
106+
zero.shape.length, strides, flattenShape, type, sumDupeIndices);
107+
backend.runWebGPUProgram(
108+
program, [zero, $sparseIndices], type, uniformData, $denseValues);
109+
}
110+
{
111+
// Then replace 0 with the (sum of) sparse value(s) at indices.
112+
const program = new ScatterOptimizedProgram(
113+
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
114+
$sparseValues.shape.length, strides, flattenShape, type);
115+
backend.runWebGPUProgram(
116+
program, [$sparseValues, $sparseIndices], type, uniformData,
117+
$denseValues);
118+
}
119+
}
61120

62-
const reshaped =
63-
reshape({inputs: {x: res}, backend, attrs: {shape: outputShape}});
121+
const denseValues = reshape(
122+
{inputs: {x: $denseValues}, backend, attrs: {shape: outputShape}});
64123

65-
backend.disposeData(res.dataId);
66-
return reshaped;
124+
backend.disposeData($sparseIndices.dataId);
125+
backend.disposeData($sparseValues.dataId);
126+
backend.disposeData($defaultValue.dataId);
127+
backend.disposeData(zero.dataId);
128+
backend.disposeData($denseValues.dataId);
129+
return denseValues;
67130
}
68131

69132
export const sparseToDenseConfig: KernelConfig = {

tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
*/
1717

1818
import {DataType} from '@tensorflow/tfjs-core';
19-
import {getCoordsDataType, getMainHeaderAndGlobalIndexString, WebGPUProgram} from './webgpu_program';
19+
import {getCoordsDataType, getMainHeaderAndGlobalIndexString, mapToWgslTypes, WebGPUProgram} from './webgpu_program';
2020
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
2121

2222
export class ScatterOptimizedProgram implements WebGPUProgram {
2323
variableNames = ['updates', 'indices'];
2424
uniforms: string;
2525
outputShape: number[];
26+
sumDupeIndices: boolean;
2627
shaderKey: string;
2728
dispatchLayout: {x: number[]};
2829
dispatch: [number, number, number];
@@ -36,16 +37,17 @@ export class ScatterOptimizedProgram implements WebGPUProgram {
3637
constructor(
3738
flattenXShape: number[], sliceDim: number, indicesRank: number,
3839
updatesRank: number, strides: number[], shape: number[],
39-
outputDtype: DataType) {
40+
outputDtype: DataType, sumDupeIndices = true) {
4041
this.outputShape = shape;
4142
this.type = outputDtype;
43+
this.sumDupeIndices = sumDupeIndices;
4244
this.dispatchLayout = flatDispatchLayout(flattenXShape);
4345
// Dispatching based on |updates| shape instead of output shape.
4446
this.dispatch =
4547
computeDispatch(this.dispatchLayout, flattenXShape, this.workGroupSize);
4648
this.sliceDimGreaterThanOne = sliceDim > 1;
4749
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
48-
this.sliceDimGreaterThanOne}_${outputDtype}`;
50+
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
4951
const stridesType = getCoordsDataType(strides.length);
5052
this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`;
5153
this.updatesRank = updatesRank;
@@ -64,45 +66,57 @@ export class ScatterOptimizedProgram implements WebGPUProgram {
6466
const strideString = this.sliceDimGreaterThanOne ? 'uniforms.strides[j]' :
6567
'uniforms.strides';
6668

67-
let updatesString = '';
6869
let outCoordsString = '';
6970
let getUpdatesCoordsFromFlatIndex = '';
70-
if (this.updatesRank === 1) {
71-
updatesString = 'coords[0]';
71+
if (this.dispatchLayout.x.length === 1) {
7272
outCoordsString = 'flattenedIndex';
7373
getUpdatesCoordsFromFlatIndex = `
7474
fn getUpdatesCoordsFromFlatIndex(index : i32) -> i32 {
7575
return index;
7676
}
7777
`;
78-
} else if (this.updatesRank === 2) {
79-
updatesString = 'coords[0], coords[1]';
78+
} else if (this.dispatchLayout.x.length === 2) {
8079
outCoordsString = 'vec2<i32>(flattenedIndex, coords[1])';
8180
getUpdatesCoordsFromFlatIndex = `
8281
fn getUpdatesCoordsFromFlatIndex(index : i32) -> vec2<i32> {
83-
let d0 = index / uniforms.updatesShape[1];
84-
let d1 = index - d0 * uniforms.updatesShape[1];
82+
// N.B. |updates| could be a scalar tensor, conceptually representing a
83+
// 2D tensor with all values equal to that. By design, its size must be
84+
// the same as |outShape[1]| in one dimension, and |indicesShape[0]|
85+
// gives the other.
86+
let sliceSize = uniforms.outShape[1];
87+
let d0 = index / sliceSize;
88+
let d1 = index - d0 * sliceSize;
8589
return vec2<i32>(d0, d1);
8690
}
8791
`;
8892
}
89-
const updatesSnippet = `getUpdates(${updatesString})`;
93+
const updatesString =
94+
Array.from({length: this.updatesRank}, (_, idx) => `coords[${idx}]`);
95+
const updatesSnippet = `getUpdates(${updatesString.join(', ')})`;
9096

91-
// atomicAdd only supports uint/int type. For float, we use
92-
// atomicCompareExchangeWeak to simulate.
93-
const atomicAddSnippet = this.type === 'int32' ?
94-
`atomicAdd(&(result[flatIndex]), i32(updateValue));` :
95-
`
96-
var oldValue = atomicLoad(&(result[flatIndex]));
97-
var exchanged = false;
98-
for (; !exchanged;) {
99-
let newValueF32 = bitcast<f32>(oldValue) + updateValue;
100-
let newValue = bitcast<i32>(newValueF32);
101-
let res = atomicCompareExchangeWeak(&(result[flatIndex]), oldValue, newValue);
102-
oldValue = res.old_value;
103-
exchanged = res.exchanged;
104-
}
105-
`;
97+
const atomicRMW = (ptr: string, val: string) => {
98+
let atomicAddSnippet = `atomicAdd(${ptr}, bitcast<i32>(${val}))`;
99+
if (this.type === 'float32') {
100+
atomicAddSnippet = `
101+
{
102+
var oldBits = 0;
103+
var newBits = bitcast<i32>(${val});
104+
loop {
105+
let info = atomicCompareExchangeWeak(${ptr}, oldBits, newBits);
106+
if (info.exchanged) {
107+
break;
108+
}
109+
oldBits = info.old_value;
110+
let oldValue = bitcast<f32>(oldBits);
111+
let newValue = oldValue + (${val});
112+
newBits = bitcast<i32>(newValue);
113+
}
114+
}
115+
`;
116+
}
117+
const atomicStoreSnippet = `atomicStore(${ptr}, bitcast<i32>(${val}));`;
118+
return this.sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet;
119+
};
106120

107121
const userCode = `
108122
${getUpdatesCoordsFromFlatIndex}
@@ -116,10 +130,11 @@ export class ScatterOptimizedProgram implements WebGPUProgram {
116130
let indexInside = i32(round(${indicesSnippet}));
117131
flattenedIndex = flattenedIndex + indexInside * ${strideString};
118132
}
119-
let updateValue = ${updatesSnippet};
133+
let updateValue =
134+
${mapToWgslTypes(this.type, false)}(${updatesSnippet});
120135
let flatIndex = getOutputIndexFromCoords(${outCoordsString});
121136
122-
${atomicAddSnippet}
137+
${atomicRMW('&result[flatIndex]', 'updateValue')};
123138
}
124139
}`;
125140
return userCode;

tfjs-backend-webgpu/src/webgpu_program.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ const commonSnippet = `
367367
type InputInfo = {
368368
dtype: DataType; shape: number[]; name: string;
369369
};
370-
type WGSLDataType = 'f32'|'i32'|'vec4<f32>'|'vec4<i32>'|'vec4<bool>';
370+
export type WGSLDataType = 'f32'|'i32'|'vec4<f32>'|'vec4<i32>'|'vec4<bool>';
371371

372372
/**
373373
* Derives logical coordinates from a flat index. Performs integer division
@@ -754,7 +754,7 @@ function isFlatDispatch(program: WebGPUProgram): boolean {
754754
return program.dispatch[1] === 1 && program.dispatch[2] === 1;
755755
}
756756

757-
function mapToWgslTypes(type: DataType, isVec4: boolean): WGSLDataType|
757+
export function mapToWgslTypes(type: DataType, isVec4: boolean): WGSLDataType|
758758
DataType {
759759
if (type === 'float32') {
760760
return isVec4 ? 'vec4<f32>' : 'f32';

0 commit comments

Comments (0)