1616 */
1717
1818import { DataType } from '@tensorflow/tfjs-core' ;
19- import { getCoordsDataType , getMainHeaderAndGlobalIndexString , WebGPUProgram } from './webgpu_program' ;
19+ import { getCoordsDataType , getMainHeaderAndGlobalIndexString , mapToWgslTypes , WebGPUProgram } from './webgpu_program' ;
2020import { computeDispatch , flatDispatchLayout } from './webgpu_util' ;
2121
2222export class ScatterOptimizedProgram implements WebGPUProgram {
2323 variableNames = [ 'updates' , 'indices' ] ;
2424 uniforms : string ;
2525 outputShape : number [ ] ;
26+ sumDupeIndices : boolean ;
2627 shaderKey : string ;
2728 dispatchLayout : { x : number [ ] } ;
2829 dispatch : [ number , number , number ] ;
@@ -36,16 +37,17 @@ export class ScatterOptimizedProgram implements WebGPUProgram {
3637 constructor (
3738 flattenXShape : number [ ] , sliceDim : number , indicesRank : number ,
3839 updatesRank : number , strides : number [ ] , shape : number [ ] ,
39- outputDtype : DataType ) {
40+ outputDtype : DataType , sumDupeIndices = true ) {
4041 this . outputShape = shape ;
4142 this . type = outputDtype ;
43+ this . sumDupeIndices = sumDupeIndices ;
4244 this . dispatchLayout = flatDispatchLayout ( flattenXShape ) ;
4345 // Dispatching based on |updates| shape instead of output shape.
4446 this . dispatch =
4547 computeDispatch ( this . dispatchLayout , flattenXShape , this . workGroupSize ) ;
4648 this . sliceDimGreaterThanOne = sliceDim > 1 ;
4749 this . shaderKey = `scatter_${ indicesRank } _${ updatesRank } _${
48- this . sliceDimGreaterThanOne } _${ outputDtype } `;
50+ this . sliceDimGreaterThanOne } _${ outputDtype } _ ${ sumDupeIndices } `;
4951 const stridesType = getCoordsDataType ( strides . length ) ;
5052 this . uniforms = `sliceDim : i32, strides: ${ stridesType } , size: i32,` ;
5153 this . updatesRank = updatesRank ;
@@ -64,45 +66,57 @@ export class ScatterOptimizedProgram implements WebGPUProgram {
6466 const strideString = this . sliceDimGreaterThanOne ? 'uniforms.strides[j]' :
6567 'uniforms.strides' ;
6668
67- let updatesString = '' ;
6869 let outCoordsString = '' ;
6970 let getUpdatesCoordsFromFlatIndex = '' ;
70- if ( this . updatesRank === 1 ) {
71- updatesString = 'coords[0]' ;
71+ if ( this . dispatchLayout . x . length === 1 ) {
7272 outCoordsString = 'flattenedIndex' ;
7373 getUpdatesCoordsFromFlatIndex = `
7474 fn getUpdatesCoordsFromFlatIndex(index : i32) -> i32 {
7575 return index;
7676 }
7777 ` ;
78- } else if ( this . updatesRank === 2 ) {
79- updatesString = 'coords[0], coords[1]' ;
78+ } else if ( this . dispatchLayout . x . length === 2 ) {
8079 outCoordsString = 'vec2<i32>(flattenedIndex, coords[1])' ;
8180 getUpdatesCoordsFromFlatIndex = `
8281 fn getUpdatesCoordsFromFlatIndex(index : i32) -> vec2<i32> {
83- let d0 = index / uniforms.updatesShape[1];
84- let d1 = index - d0 * uniforms.updatesShape[1];
82+ // N.B. |updates| could be a scalar tensor, conceptually representing a
83+ // 2D tensor with all values equal to that. By design, its size must be
84+ // the same as |outShape[1]| in one dimension, and |indicesShape[0]|
85+ // gives the other.
86+ let sliceSize = uniforms.outShape[1];
87+ let d0 = index / sliceSize;
88+ let d1 = index - d0 * sliceSize;
8589 return vec2<i32>(d0, d1);
8690 }
8791 ` ;
8892 }
89- const updatesSnippet = `getUpdates(${ updatesString } )` ;
93+ const updatesString =
94+ Array . from ( { length : this . updatesRank } , ( _ , idx ) => `coords[${ idx } ]` ) ;
95+ const updatesSnippet = `getUpdates(${ updatesString . join ( ', ' ) } )` ;
9096
91- // atomicAdd only supports uint/int type. For float, we use
92- // atomicCompareExchangeWeak to simulate.
93- const atomicAddSnippet = this . type === 'int32' ?
94- `atomicAdd(&(result[flatIndex]), i32(updateValue));` :
95- `
96- var oldValue = atomicLoad(&(result[flatIndex]));
97- var exchanged = false;
98- for (; !exchanged;) {
99- let newValueF32 = bitcast<f32>(oldValue) + updateValue;
100- let newValue = bitcast<i32>(newValueF32);
101- let res = atomicCompareExchangeWeak(&(result[flatIndex]), oldValue, newValue);
102- oldValue = res.old_value;
103- exchanged = res.exchanged;
104- }
105- ` ;
97+ const atomicRMW = ( ptr : string , val : string ) => {
98+ let atomicAddSnippet = `atomicAdd(${ ptr } , bitcast<i32>(${ val } ))` ;
99+ if ( this . type === 'float32' ) {
100+ atomicAddSnippet = `
101+ {
102+ var oldBits = 0;
103+ var newBits = bitcast<i32>(${ val } );
104+ loop {
105+ let info = atomicCompareExchangeWeak(${ ptr } , oldBits, newBits);
106+ if (info.exchanged) {
107+ break;
108+ }
109+ oldBits = info.old_value;
110+ let oldValue = bitcast<f32>(oldBits);
111+ let newValue = oldValue + (${ val } );
112+ newBits = bitcast<i32>(newValue);
113+ }
114+ }
115+ ` ;
116+ }
117+ const atomicStoreSnippet = `atomicStore(${ ptr } , bitcast<i32>(${ val } ));` ;
118+ return this . sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet ;
119+ } ;
106120
107121 const userCode = `
108122 ${ getUpdatesCoordsFromFlatIndex }
@@ -116,10 +130,11 @@ export class ScatterOptimizedProgram implements WebGPUProgram {
116130 let indexInside = i32(round(${ indicesSnippet } ));
117131 flattenedIndex = flattenedIndex + indexInside * ${ strideString } ;
118132 }
119- let updateValue = ${ updatesSnippet } ;
133+ let updateValue =
134+ ${ mapToWgslTypes ( this . type , false ) } (${ updatesSnippet } );
120135 let flatIndex = getOutputIndexFromCoords(${ outCoordsString } );
121136
122- ${ atomicAddSnippet }
137+ ${ atomicRMW ( '&result[flatIndex]' , 'updateValue' ) } ;
123138 }
124139 }` ;
125140 return userCode ;
0 commit comments