Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tfjs-backend-webgl/src/gpgpu_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,15 @@ export class GPGPUContext {
return;
}
// Start a new loop that polls.
let scheduleFn = undefined;
if ('setTimeoutCustom' in env().platform) {
scheduleFn = env().platform.setTimeoutCustom.bind(env().platform);
}
util.repeatedTry(() => {
this.pollItems();
// End the loop if no more items to poll.
return this.itemsToPoll.length === 0;
});
}, () => 0, null, scheduleFn);
}

private bindTextureToFrameBuffer(texture: WebGLTexture) {
Expand Down
4 changes: 3 additions & 1 deletion tfjs-backend-webgpu/src/kernels/FromPixels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ export function fromPixels(args: {
[pixels.width, pixels.height];
const outputShape = [height, width, numChannels];

// Disable importExternalTexture temporarily as it has problem in spec and
// browser impl
const importVideo =
env().getBool('WEBGPU_IMPORT_EXTERNAL_TEXTURE') && isVideo;
false && env().getBool('WEBGPU_IMPORT_EXTERNAL_TEXTURE') && isVideo;
const isVideoOrImage = isVideo || isImage;
if (isImageBitmap || isCanvas || isVideoOrImage) {
let textureInfo: TextureInfo;
Expand Down
3 changes: 3 additions & 0 deletions tfjs-core/src/flags.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,6 @@ ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);

/** Whether to enable canvas2d willReadFrequently for GPU backends */
ENV.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false);

/** Whether to use setTimeoutCustom */
ENV.registerFlag('USE_SETTIMEOUTCUSTOM', () => false);
2 changes: 2 additions & 0 deletions tfjs-core/src/platforms/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@ export interface Platform {
encode(text: string, encoding: string): Uint8Array;
/** Decode the provided bytes into a string using the provided encoding. */
decode(bytes: Uint8Array, encoding: string): string;

setTimeoutCustom?(functionRef: Function, delay: number): void;
}
39 changes: 39 additions & 0 deletions tfjs-core/src/platforms/platform_browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ export class PlatformBrowser implements Platform {
// https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
private textEncoder: TextEncoder;

// For setTimeoutCustom
private messageName = 'setTimeoutCustom';
private functionRefs: Function[] = [];
private handledMessageCount = 0;
private hasEventListener = false;

fetch(path: string, init?: RequestInit): Promise<Response> {
return fetch(path, init);
}
Expand All @@ -50,6 +56,39 @@ export class PlatformBrowser implements Platform {
decode(bytes: Uint8Array, encoding: string): string {
return new TextDecoder(encoding).decode(bytes);
}

// If the setTimeout nesting level is greater than 5 and timeout is less
// than 4ms, timeout will be clamped to 4ms, which hurts the perf.
// Interleaving window.postMessage and setTimeout will trick the browser and
// avoid the clamp.
setTimeoutCustom(functionRef: Function, delay: number): void {
if (!window || !env().getBool('USE_SETTIMEOUTCUSTOM')) {
setTimeout(functionRef, delay);
return;
}

this.functionRefs.push(functionRef);
setTimeout(() => {
window.postMessage(
{name: this.messageName, index: this.functionRefs.length - 1}, '*');
}, delay);

if (!this.hasEventListener) {
this.hasEventListener = true;
window.addEventListener('message', (event: MessageEvent) => {
if (event.source === window && event.data.name === this.messageName) {
event.stopPropagation();
const functionRef = this.functionRefs[event.data.index];
functionRef();
this.handledMessageCount++;
if (this.handledMessageCount === this.functionRefs.length) {
this.functionRefs = [];
this.handledMessageCount = 0;
}
}
}, true);
}
}
}

if (env().get('IS_BROWSER')) {
Expand Down
60 changes: 60 additions & 0 deletions tfjs-core/src/platforms/platform_browser_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* =============================================================================
*/

import {env} from '../environment';
import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util';

import {PlatformBrowser} from './platform_browser';
Expand Down Expand Up @@ -88,3 +89,62 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => {
expect(s).toEqual('Здраво');
});
});

describeWithFlags('setTimeout', BROWSER_ENVS, () => {
const totalCount = 100;
// Skip the first few samples because the browser does not clamp the timeout
const skipCount = 5;

it('setTimeout', (done) => {
let count = 0;
let startTime = performance.now();
let totalTime = 0;
setTimeout(_testSetTimeout, 0);

function _testSetTimeout() {
const endTime = performance.now();
count++;
if (count > skipCount) {
totalTime += endTime - startTime;
}
if (count === totalCount) {
const averageTime = totalTime / (totalCount - skipCount);
console.log(`averageTime of setTimeout is ${averageTime} ms`);
expect(averageTime).toBeGreaterThan(4);
done();
return;
}
startTime = performance.now();
setTimeout(_testSetTimeout, 0);
}
});

it('setTimeoutCustom', (done) => {
let count = 0;
let startTime = performance.now();
let totalTime = 0;
let originUseSettimeoutcustom: boolean;

originUseSettimeoutcustom = env().getBool('USE_SETTIMEOUTCUSTOM');
env().set('USE_SETTIMEOUTCUSTOM', true);
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);

function _testSetTimeoutCustom() {
const endTime = performance.now();
count++;
if (count > skipCount) {
totalTime += endTime - startTime;
}
if (count === totalCount) {
const averageTime = totalTime / (totalCount - skipCount);
console.log(`averageTime of setTimeoutCustom is ${averageTime} ms`);
expect(averageTime).toBeLessThan(4);
done();
env().set('USE_SETTIMEOUTCUSTOM', originUseSettimeoutcustom);
return;
}
startTime = performance.now();
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);
}
});
});
32 changes: 17 additions & 15 deletions tfjs-core/src/util_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ export function rightPad(a: string, size: number): string {

export function repeatedTry(
checkFn: () => boolean, delayFn = (counter: number) => 0,
maxCounter?: number): Promise<void> {
maxCounter?: number,
scheduleFn: (functionRef: Function, delay: number) => void =
setTimeout): Promise<void> {
return new Promise<void>((resolve, reject) => {
let tryCount = 0;

Expand All @@ -321,7 +323,7 @@ export function repeatedTry(
reject();
return;
}
setTimeout(tryFn, nextBackoff);
scheduleFn(tryFn, nextBackoff);
};

tryFn();
Expand Down Expand Up @@ -504,8 +506,8 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean {
return true;
}

export function isTypedArray(a: {}):
a is Float32Array|Int32Array|Uint8Array|Uint8ClampedArray {
export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}
Expand All @@ -524,9 +526,9 @@ export function bytesPerElement(dtype: DataType): number {

/**
* Returns the approximate number of bytes allocated in the string array - 2
* bytes per character. Computing the exact bytes for a native string in JS is
* not possible since it depends on the encoding of the html page that serves
* the website.
* bytes per character. Computing the exact bytes for a native string in JS
* is not possible since it depends on the encoding of the html page that
* serves the website.
*/
export function bytesFromStringArray(arr: Uint8Array[]): number {
if (arr == null) {
Expand Down Expand Up @@ -556,9 +558,9 @@ export function inferDtype(values: TensorLike): DataType {
}
if (values instanceof Float32Array) {
return 'float32';
} else if (values instanceof Int32Array
|| values instanceof Uint8Array
|| values instanceof Uint8ClampedArray) {
} else if (
values instanceof Int32Array || values instanceof Uint8Array ||
values instanceof Uint8ClampedArray) {
return 'int32';
} else if (isNumber(values)) {
return 'float32';
Expand Down Expand Up @@ -712,8 +714,8 @@ export function locToIndex(
}

/**
* Computes the location (multidimensional index) in a tensor/multidimentional
* array for a given flat index.
* Computes the location (multidimensional index) in a
* tensor/multidimentional array for a given flat index.
*
* @param index Index in flat array.
* @param rank Rank of tensor.
Expand Down Expand Up @@ -744,8 +746,8 @@ export function isPromise(object: any): object is Promise<unknown> {
// We chose to not use 'obj instanceOf Promise' for two reasons:
// 1. It only reliably works for es6 Promise, not other Promise
// implementations.
// 2. It doesn't work with framework that uses zone.js. zone.js monkey patch
// the async calls, so it is possible the obj (patched) is comparing to a
// pre-patched Promise.
// 2. It doesn't work with framework that uses zone.js. zone.js monkey
// patch the async calls, so it is possible the obj (patched) is
// comparing to a pre-patched Promise.
return object && object.then && typeof object.then === 'function';
}
4 changes: 2 additions & 2 deletions tfjs-core/src/util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,8 @@ describeWithFlags('util.toNestedArray for a complex tensor', ALL_ENVS, () => {

describe('util.fetch', () => {
it('should call the platform fetch', () => {
spyOn(tf.env().platform, 'fetch').and
.callFake(async () => ({} as unknown as Response));
spyOn(tf.env().platform, 'fetch')
.and.callFake(async () => ({} as unknown as Response));

util.fetch('test/path', {method: 'GET'});

Expand Down
2 changes: 1 addition & 1 deletion tools/karma_template.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ const CUSTOM_LAUNCHERS = {
win_10_chrome: {
base: 'BrowserStack',
browser: 'chrome',
browser_version: '101.0',
browser_version: '104.0',
os: 'Windows',
os_version: '10'
},
Expand Down