diff --git a/tfjs-backend-webgl/src/gpgpu_context.ts b/tfjs-backend-webgl/src/gpgpu_context.ts index 3bebba68e22..279a83e0811 100644 --- a/tfjs-backend-webgl/src/gpgpu_context.ts +++ b/tfjs-backend-webgl/src/gpgpu_context.ts @@ -539,11 +539,15 @@ export class GPGPUContext { return; } // Start a new loop that polls. + let scheduleFn = undefined; + if ('setTimeoutCustom' in env().platform) { + scheduleFn = env().platform.setTimeoutCustom.bind(env().platform); + } util.repeatedTry(() => { this.pollItems(); // End the loop if no more items to poll. return this.itemsToPoll.length === 0; - }); + }, () => 0, null, scheduleFn); } private bindTextureToFrameBuffer(texture: WebGLTexture) { diff --git a/tfjs-backend-webgpu/src/kernels/FromPixels.ts b/tfjs-backend-webgpu/src/kernels/FromPixels.ts index d33ee65c138..eedcaf72c30 100644 --- a/tfjs-backend-webgpu/src/kernels/FromPixels.ts +++ b/tfjs-backend-webgpu/src/kernels/FromPixels.ts @@ -64,8 +64,10 @@ export function fromPixels(args: { [pixels.width, pixels.height]; const outputShape = [height, width, numChannels]; + // Disable importExternalTexture temporarily as it has problem in spec and + // browser impl const importVideo = - env().getBool('WEBGPU_IMPORT_EXTERNAL_TEXTURE') && isVideo; + false && env().getBool('WEBGPU_IMPORT_EXTERNAL_TEXTURE') && isVideo; const isVideoOrImage = isVideo || isImage; if (isImageBitmap || isCanvas || isVideoOrImage) { let textureInfo: TextureInfo; diff --git a/tfjs-core/src/flags.ts b/tfjs-core/src/flags.ts index 08269c19ccb..6c9470bc707 100644 --- a/tfjs-core/src/flags.ts +++ b/tfjs-core/src/flags.ts @@ -82,3 +82,6 @@ ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false); /** Whether to enable canvas2d willReadFrequently for GPU backends */ ENV.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false); + +/** Whether to use setTimeoutCustom */ +ENV.registerFlag('USE_SETTIMEOUTCUSTOM', () => false); diff --git a/tfjs-core/src/platforms/platform.ts b/tfjs-core/src/platforms/platform.ts index 7a6a5979e0e..0b0ebbf4c56 100644 --- a/tfjs-core/src/platforms/platform.ts +++ b/tfjs-core/src/platforms/platform.ts @@ -46,4 +46,6 @@ export interface Platform { encode(text: string, encoding: string): Uint8Array; /** Decode the provided bytes into a string using the provided encoding. */ decode(bytes: Uint8Array, encoding: string): string; + + setTimeoutCustom?(functionRef: Function, delay: number): void; } diff --git a/tfjs-core/src/platforms/platform_browser.ts b/tfjs-core/src/platforms/platform_browser.ts index ee4cd61b267..4df4f9c73e9 100644 --- a/tfjs-core/src/platforms/platform_browser.ts +++ b/tfjs-core/src/platforms/platform_browser.ts @@ -29,6 +29,12 @@ export class PlatformBrowser implements Platform { // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder private textEncoder: TextEncoder; + // For setTimeoutCustom + private messageName = 'setTimeoutCustom'; + private functionRefs: Function[] = []; + private handledMessageCount = 0; + private hasEventListener = false; + fetch(path: string, init?: RequestInit): Promise { return fetch(path, init); } @@ -50,6 +56,39 @@ export class PlatformBrowser implements Platform { decode(bytes: Uint8Array, encoding: string): string { return new TextDecoder(encoding).decode(bytes); } + + // If the setTimeout nesting level is greater than 5 and timeout is less + // than 4ms, timeout will be clamped to 4ms, which hurts the perf. + // Interleaving window.postMessage and setTimeout will trick the browser and + // avoid the clamp. + setTimeoutCustom(functionRef: Function, delay: number): void { + if (!window || !env().getBool('USE_SETTIMEOUTCUSTOM')) { + setTimeout(functionRef, delay); + return; + } + + this.functionRefs.push(functionRef); + setTimeout(() => { + window.postMessage( + {name: this.messageName, index: this.functionRefs.length - 1}, '*'); + }, delay); + + if (!this.hasEventListener) { + this.hasEventListener = true; + window.addEventListener('message', (event: MessageEvent) => { + if (event.source === window && event.data.name === this.messageName) { + event.stopPropagation(); + const functionRef = this.functionRefs[event.data.index]; + functionRef(); + this.handledMessageCount++; + if (this.handledMessageCount === this.functionRefs.length) { + this.functionRefs = []; + this.handledMessageCount = 0; + } + } + }, true); + } + } } if (env().get('IS_BROWSER')) { diff --git a/tfjs-core/src/platforms/platform_browser_test.ts b/tfjs-core/src/platforms/platform_browser_test.ts index 16827d1a823..46ec45569ec 100644 --- a/tfjs-core/src/platforms/platform_browser_test.ts +++ b/tfjs-core/src/platforms/platform_browser_test.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {env} from '../environment'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {PlatformBrowser} from './platform_browser'; @@ -88,3 +89,62 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => { expect(s).toEqual('Здраво'); }); }); + +describeWithFlags('setTimeout', BROWSER_ENVS, () => { + const totalCount = 100; + // Skip the first few samples because the browser does not clamp the timeout + const skipCount = 5; + + it('setTimeout', (done) => { + let count = 0; + let startTime = performance.now(); + let totalTime = 0; + setTimeout(_testSetTimeout, 0); + + function _testSetTimeout() { + const endTime = performance.now(); + count++; + if (count > skipCount) { + totalTime += endTime - startTime; + } + if (count === totalCount) { + const averageTime = totalTime / (totalCount - skipCount); + console.log(`averageTime of setTimeout is ${averageTime} ms`); + expect(averageTime).toBeGreaterThan(4); + done(); + return; + } + startTime = performance.now(); + setTimeout(_testSetTimeout, 0); + } + }); + + it('setTimeoutCustom', (done) => { + let count = 0; + let startTime = performance.now(); + let totalTime = 0; + let originUseSettimeoutcustom: boolean; + + originUseSettimeoutcustom = env().getBool('USE_SETTIMEOUTCUSTOM'); + env().set('USE_SETTIMEOUTCUSTOM', true); + env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0); + + function _testSetTimeoutCustom() { + const endTime = performance.now(); + count++; + if (count > skipCount) { + totalTime += endTime - startTime; + } + if (count === totalCount) { + const averageTime = totalTime / (totalCount - skipCount); + console.log(`averageTime of setTimeoutCustom is ${averageTime} ms`); + expect(averageTime).toBeLessThan(4); + done(); + env().set('USE_SETTIMEOUTCUSTOM', originUseSettimeoutcustom); + return; + } + startTime = performance.now(); + env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0); + } + }); +}); diff --git a/tfjs-core/src/util_base.ts b/tfjs-core/src/util_base.ts index f47f9ac6268..09349a7c687 100644 --- a/tfjs-core/src/util_base.ts +++ b/tfjs-core/src/util_base.ts @@ -303,7 +303,9 @@ export function rightPad(a: string, size: number): string { export function repeatedTry( checkFn: () => boolean, delayFn = (counter: number) => 0, - maxCounter?: number): Promise { + maxCounter?: number, + scheduleFn: (functionRef: Function, delay: number) => void = + setTimeout): Promise { return new Promise((resolve, reject) => { let tryCount = 0; @@ -321,7 +323,7 @@ export function repeatedTry( reject(); return; } - setTimeout(tryFn, nextBackoff); + scheduleFn(tryFn, nextBackoff); }; tryFn(); @@ -504,8 +506,8 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean { return true; } -export function isTypedArray(a: {}): - a is Float32Array|Int32Array|Uint8Array|Uint8ClampedArray { +export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array| + Uint8ClampedArray { return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array || a instanceof Uint8ClampedArray; } @@ -524,9 +526,9 @@ export function bytesPerElement(dtype: DataType): number { /** * Returns the approximate number of bytes allocated in the string array - 2 - * bytes per character. Computing the exact bytes for a native string in JS is - * not possible since it depends on the encoding of the html page that serves - * the website. + * bytes per character. Computing the exact bytes for a native string in JS + * is not possible since it depends on the encoding of the html page that + * serves the website. */ export function bytesFromStringArray(arr: Uint8Array[]): number { if (arr == null) { @@ -556,9 +558,9 @@ export function inferDtype(values: TensorLike): DataType { } if (values instanceof Float32Array) { return 'float32'; - } else if (values instanceof Int32Array - || values instanceof Uint8Array - || values instanceof Uint8ClampedArray) { + } else if ( + values instanceof Int32Array || values instanceof Uint8Array || + values instanceof Uint8ClampedArray) { return 'int32'; } else if (isNumber(values)) { return 'float32'; @@ -712,8 +714,8 @@ export function locToIndex( } /** - * Computes the location (multidimensional index) in a tensor/multidimentional - * array for a given flat index. + * Computes the location (multidimensional index) in a + * tensor/multidimentional array for a given flat index. * * @param index Index in flat array. * @param rank Rank of tensor. @@ -744,8 +746,8 @@ export function isPromise(object: any): object is Promise { // We chose to not use 'obj instanceOf Promise' for two reasons: // 1. It only reliably works for es6 Promise, not other Promise // implementations. - // 2. It doesn't work with framework that uses zone.js. zone.js monkey patch - // the async calls, so it is possible the obj (patched) is comparing to a - // pre-patched Promise. + // 2. It doesn't work with framework that uses zone.js. zone.js monkey + // patch the async calls, so it is possible the obj (patched) is + // comparing to a pre-patched Promise. return object && object.then && typeof object.then === 'function'; } diff --git a/tfjs-core/src/util_test.ts b/tfjs-core/src/util_test.ts index 41d73b84287..8683e76bdd4 100644 --- a/tfjs-core/src/util_test.ts +++ b/tfjs-core/src/util_test.ts @@ -584,8 +584,8 @@ describeWithFlags('util.toNestedArray for a complex tensor', ALL_ENVS, () => { describe('util.fetch', () => { it('should call the platform fetch', () => { - spyOn(tf.env().platform, 'fetch').and - .callFake(async () => ({} as unknown as Response)); + spyOn(tf.env().platform, 'fetch') + .and.callFake(async () => ({} as unknown as Response)); util.fetch('test/path', {method: 'GET'}); diff --git a/tools/karma_template.conf.js b/tools/karma_template.conf.js index 5ad1adb877c..b836fc348ed 100644 --- a/tools/karma_template.conf.js +++ b/tools/karma_template.conf.js @@ -61,7 +61,7 @@ const CUSTOM_LAUNCHERS = { win_10_chrome: { base: 'BrowserStack', browser: 'chrome', - browser_version: '101.0', + browser_version: '104.0', os: 'Windows', os_version: '10' },