Skip to content

Commit 1b63e5d

Browse files
committed
Merge branch 'master' of https://github.com/tensorflow/tfjs into jax2tfjs
2 parents 1dac6d1 + d515f4d commit 1b63e5d

File tree

10 files changed

+158
-252
lines changed

10 files changed

+158
-252
lines changed

tfjs-backend-webgpu/src/from_pixels_webgpu_test.ts

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
import * as tf from '@tensorflow/tfjs-core';
19+
import {test_util} from '@tensorflow/tfjs-core';
1920
import {WebGPUBackend} from './backend_webgpu';
2021
import {describeWebGPU} from './test_util';
2122

@@ -37,26 +38,15 @@ describeWebGPU('fromPixels', () => {
3738
const textureManager = backend.textureManager;
3839
textureManager.dispose();
3940

40-
const video = document.createElement('video');
4141
const source = document.createElement('source');
4242
source.src =
4343
// tslint:disable-next-line:max-line-length
4444
'data:video/mp4;base64,AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1wNDEAAAAIZnJlZQAAAu1tZGF0AAACrQYF//+p3EXpvebZSLeWLNgg2SPu73gyNjQgLSBjb3JlIDE1NSByMjkwMSA3ZDBmZjIyIC0gSC4yNjQvTVBFRy00IEFWQyBjb2RlYyAtIENvcHlsZWZ0IDIwMDMtMjAxOCAtIGh0dHA6Ly93d3cudmlkZW9sYW4ub3JnL3gyNjQuaHRtbCAtIG9wdGlvbnM6IGNhYmFjPTEgcmVmPTMgZGVibG9jaz0xOjA6MCBhbmFseXNlPTB4MzoweDExMyBtZT1oZXggc3VibWU9NyBwc3k9MSBwc3lfcmQ9MS4wMDowLjAwIG1peGVkX3JlZj0xIG1lX3JhbmdlPTE2IGNocm9tYV9tZT0xIHRyZWxsaXM9MSA4eDhkY3Q9MSBjcW09MCBkZWFkem9uZT0yMSwxMSBmYXN0X3Bza2lwPTEgY2hyb21hX3FwX29mZnNldD0tMiB0aHJlYWRzPTMgbG9va2FoZWFkX3RocmVhZHM9MSBzbGljZWRfdGhyZWFkcz0wIG5yPTAgZGVjaW1hdGU9MSBpbnRlcmxhY2VkPTAgYmx1cmF5X2NvbXBhdD0wIGNvbnN0cmFpbmVkX2ludHJhPTAgYmZyYW1lcz0zIGJfcHlyYW1pZD0yIGJfYWRhcHQ9MSBiX2JpYXM9MCBkaXJlY3Q9MSB3ZWlnaHRiPTEgb3Blbl9nb3A9MCB3ZWlnaHRwPTIga2V5aW50PTI1MCBrZXlpbnRfbWluPTEgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVzaD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTI4LjAgcWNvbXA9MC42MCBxcG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAAwZYiEAD//8m+P5OXfBeLGOfKE3xkODvFZuBflHv/+VwJIta6cbpIo4ABLoKBaYTkTAAAC7m1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAAPoAAEAAAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAIYdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAAAAPoAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAACgAAAAWgAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAD6AAAAAAAAQAAAAABkG1kaWEAAAAgbWRoZAAAAAAAAAAAAAAAAAAAQAAAAEAAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAAVmlkZW9IYW5kbGVyAAAAATttaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVmAAAAAAAAAAEAAAAMdXJsIAAAAAEAAAD7c3RibAAAAJdzdHNkAAAAAAAAAAEAAACHYXZjMQAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAACgAFoASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQACv/hABhnZAAKrNlCjfkhAAADAAEAAAMAAg8SJZYBAAZo6+JLIsAAAAAYc3R0cwAAAAAAAAABAAAAAQAAQAAAAAAcc3RzYwAAAAAAAAABAAAAAQAAAAEAAAABAAAAFHN0c3oAAAAAAAAC5QAAAAEAAAAUc3RjbwAAAAAAAAABAAAAMAAAAGJ1ZHRhAAAAWm1ldGEAAAAAAAAAIWhkbHIAAAAAAAAAAG1kaXJhc
HBsAAAAAAAAAAAAAAAALWlsc3QAAAAlqXRvbwAAAB1kYXRhAAAAAQAAAABMYXZmNTguMTIuMTAw';
4545
source.type = 'video/mp4';
46-
video.appendChild(source);
47-
document.body.appendChild(video);
48-
49-
video.autoplay = true;
50-
video.loop = true;
51-
video.muted = true;
52-
video.preload = 'auto';
53-
await video.play();
5446

55-
// ensure video element to be loaded
56-
if ('requestVideoFrameCallback' in video) {
57-
// tslint:disable-next-line:no-any
58-
await new Promise(go => (video as any).requestVideoFrameCallback(go));
59-
}
47+
const video = await test_util.createVideoElement(source);
48+
document.body.appendChild(video);
49+
await test_util.play(video);
6050

6151
{
6252
tf.env().set('WEBGPU_IMPORT_EXTERNAL_TEXTURE', true);

tfjs-backend-webgpu/src/kernels/ScatterNd.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,10 @@
1818
import {backend_util, KernelConfig, KernelFunc, ScatterNd, ScatterNdAttrs, ScatterNdInputs, TensorInfo, util} from '@tensorflow/tfjs-core';
1919

2020
import {WebGPUBackend} from '../backend_webgpu';
21+
import {ScatterProgram} from '../scatter_webgpu';
2122

2223
import {fill} from './Fill';
2324
import {reshape} from './Reshape';
24-
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';
2525

2626
export function scatterNd(args: {
2727
inputs: ScatterNdInputs,
@@ -54,7 +54,7 @@ export function scatterNd(args: {
5454
{type: 'int32', data: [sliceRank]}, {type: 'int32', data: strides},
5555
{type: 'int32', data: [size]}
5656
];
57-
const program = new ScatterOptimizedProgram(
57+
const program = new ScatterProgram(
5858
flattenX.shape, sliceRank, flattenIndices.shape.length,
5959
flattenX.shape.length, strides, flattenShape, type);
6060
const res = backend.runWebGPUProgram(

tfjs-backend-webgpu/src/kernels/SparseToDense.ts

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ import {backend_util, KernelConfig, KernelFunc, Rank, SparseToDense, SparseToDen
1919

2020
import {WebGPUBackend} from '../backend_webgpu';
2121
import {scatterImplCPU} from '../kernel_utils/shared';
22-
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';
22+
import {ScatterProgram} from '../scatter_webgpu';
2323

2424
import {identity} from './Identity';
2525
import {reshape} from './Reshape';
@@ -89,7 +89,7 @@ export function sparseToDense(args: {
8989
break;
9090
case 1:
9191
if (true) {
92-
const program = new ScatterOptimizedProgram(
92+
const program = new ScatterProgram(
9393
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
9494
$sparseValues.shape.length, strides, flattenShape, type,
9595
sumDupeIndices);
@@ -101,15 +101,15 @@ export function sparseToDense(args: {
101101
default:
102102
if (true) {
103103
// First replace the default value with 0 at indices.
104-
const program = new ScatterOptimizedProgram(
104+
const program = new ScatterProgram(
105105
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
106106
zero.shape.length, strides, flattenShape, type, sumDupeIndices);
107107
backend.runWebGPUProgram(
108108
program, [zero, $sparseIndices], type, uniformData, $denseValues);
109109
}
110110
{
111111
// Then replace 0 with the (sum of) sparse value(s) at indices.
112-
const program = new ScatterOptimizedProgram(
112+
const program = new ScatterProgram(
113113
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
114114
$sparseValues.shape.length, strides, flattenShape, type);
115115
backend.runWebGPUProgram(

tfjs-backend-webgpu/src/matmul_packed_webgpu.ts

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ const calculateResultSnippet =
160160

161161
export function makeMatMulPackedVec4Source(
162162
workPerThread: number[], workGroupSize: [number, number, number],
163-
transposeA = false, tileInner = 32, splitK = false,
163+
transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32,
164164
isVectorA = false): string {
165165
const tileAOuter = workGroupSize[1] * workPerThread[1];
166166
const tileBOuter = workGroupSize[0] * workPerThread[0];
@@ -209,8 +209,10 @@ export function makeMatMulPackedVec4Source(
209209
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
210210
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
211211
212-
let numTiles = ${splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'};
213-
var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'};
212+
let numTiles = ${
213+
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` :
214+
'(uniforms.dimInner - 1) / TileInner + 1'};
215+
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
214216
215217
var acc: array<vec4<f32>, RowPerThread>;
216218
@@ -281,7 +283,8 @@ const readDataFromSubASnippet = (transposeA: boolean) => {
281283

282284
export function makeMatMulPackedSource(
283285
workPerThread: number[], workGroupSize: [number, number, number],
284-
transposeA = false, tileInner = 32, splitK = false): string {
286+
transposeA = false, tileInner = 32, splitK = false,
287+
splitedDimInner = 32): string {
285288
const tileAOuter = workPerThread[1] * workGroupSize[1];
286289
const tileBOuter = workPerThread[0] * workGroupSize[0];
287290
const tileAWidth = transposeA ? tileAOuter : tileInner;
@@ -323,8 +326,9 @@ export function makeMatMulPackedSource(
323326
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
324327
325328
let numTiles = ${
326-
splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'};
327-
var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'};
329+
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` :
330+
'(uniforms.dimInner - 1) / TileInner + 1'};
331+
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
328332
329333
var acc : array<array<f32, ColPerThread>, RowPerThread>;
330334
@@ -565,7 +569,7 @@ export class MatMulPackedProgram implements WebGPUProgram {
565569
this.isVec4 ?
566570
makeMatMulPackedVec4Source(
567571
this.elementsPerThread, this.workGroupSize, this.transposeA,
568-
this.tileInner, false, this.isVectorA) :
572+
this.tileInner, false, null, this.isVectorA) :
569573
(this.isVectorA ? makeVectorMatrixProductSource(
570574
this.workGroupSize, this.transposeA) :
571575
makeMatMulPackedSource(

tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ export class MatMulSplitKProgram implements WebGPUProgram {
3737
batchAEqualOne: boolean;
3838
batchBEqualOne: boolean;
3939
isVec4 = false;
40-
tileInner = 32;
40+
splitedDimInner = 128;
4141

4242
constructor(
4343
outputShape: [number, number, number], dimInner: number,
@@ -51,7 +51,8 @@ export class MatMulSplitKProgram implements WebGPUProgram {
5151
this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 ||
5252
!transposeA && dimInner % 4 === 0) &&
5353
this.outputShape[2] % 4 === 0;
54-
this.elementsPerThread = [4, 4, this.tileInner];
54+
this.elementsPerThread = [4, 4, this.splitedDimInner];
55+
5556
if (!this.isVec4) {
5657
if (this.outputShape[1] < 16) {
5758
this.elementsPerThread[1] = 1;
@@ -119,10 +120,10 @@ export class MatMulSplitKProgram implements WebGPUProgram {
119120
${
120121
this.isVec4 ? makeMatMulPackedVec4Source(
121122
this.elementsPerThread, this.workGroupSize,
122-
this.transposeA, this.tileInner, true) :
123+
this.transposeA, 32, true, this.splitedDimInner) :
123124
makeMatMulPackedSource(
124125
this.elementsPerThread, this.workGroupSize,
125-
this.transposeA, this.tileInner, true)}
126+
this.transposeA, 32, true, this.splitedDimInner)}
126127
`;
127128
return userCode;
128129
}

tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts

Lines changed: 0 additions & 142 deletions
This file was deleted.

0 commit comments

Comments
 (0)