Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 78431ed

Browse files
author
Nikhil Thorat
authored
Add ndarray.getValuesAsync() method which returns a promise that resolves when values are ready. (#146)
* async values * merge * fix unit tests * remove loop demo * remove line from ndarray.ts * throw when lose context not supported * WEBGL_lose_context * respond to comments
1 parent 50c3805 commit 78431ed

14 files changed

+243
-77
lines changed

demos/benchmarks/conv_benchmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ export class ConvGPUBenchmark extends ConvBenchmark {
8080
gpgpu.dispose();
8181
};
8282

83-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
84-
gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => {
83+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
84+
gpgpu.runQuery(benchmark).then((timeElapsed: number) => {
8585
delayedCleanup();
8686
resolve(timeElapsed);
8787
});

demos/benchmarks/conv_transposed_benchmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ export class ConvTransposedGPUBenchmark extends ConvTransposedBenchmark {
8181
gpgpu.dispose();
8282
};
8383

84-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
85-
gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => {
84+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
85+
gpgpu.runQuery(benchmark).then((timeElapsed: number) => {
8686
delayedCleanup();
8787
resolve(timeElapsed);
8888
});

demos/benchmarks/logsumexp_benchmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ export class LogSumExpGPUBenchmark extends BenchmarkTest {
6666
gpgpu.dispose();
6767
};
6868

69-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
70-
gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => {
69+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
70+
gpgpu.runQuery(benchmark).then((timeElapsed: number) => {
7171
delayedCleanup();
7272
resolve(timeElapsed);
7373
});

demos/benchmarks/matmul_benchmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ export class MatmulGPUBenchmark extends BenchmarkTest {
8080
gpgpu.dispose();
8181
};
8282

83-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
84-
gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => {
83+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
84+
gpgpu.runQuery(benchmark).then((timeElapsed: number) => {
8585
delayedCleanup();
8686
resolve(timeElapsed);
8787
});

demos/benchmarks/pool_benchmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ export class PoolGPUBenchmark extends PoolBenchmark {
9696
gpgpu.dispose();
9797
};
9898

99-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
100-
gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => {
99+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
100+
gpgpu.runQuery(benchmark).then((timeElapsed: number) => {
101101
delayedCleanup();
102102
resolve(timeElapsed);
103103
});

demos/benchmarks/reduction_ops_benchmark.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ export class ReductionOpsGPUBenchmark extends ReductionOpsBenchmark {
7575
math.dispose();
7676
};
7777

78-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
79-
math.getGPGPUContext().runBenchmark(benchmark).then(
78+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
79+
math.getGPGPUContext().runQuery(benchmark).then(
8080
(timeElapsed: number) => {
8181
delayedCleanup();
8282
resolve(timeElapsed);

demos/benchmarks/unary_ops_benchmark.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ export class UnaryOpsGPUBenchmark extends UnaryOpsBenchmark {
103103
math.dispose();
104104
};
105105

106-
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) {
107-
math.getGPGPUContext().runBenchmark(benchmark).then(
106+
if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) {
107+
math.getGPGPUContext().runQuery(benchmark).then(
108108
(timeElapsed: number) => {
109109
delayedCleanup();
110110
resolve(timeElapsed);

src/environment.ts

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,18 @@ export enum Type {
2424
}
2525

2626
export interface Features {
27-
'WEBGL_DISJOINT_QUERY_TIMER'?: boolean;
27+
// Whether the disjoint_query_timer extension is an available extension.
28+
'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED'?: boolean;
29+
// Whether the timer object from the disjoint_query_timer extension gives
30+
// timing information that is reliable.
31+
'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE'?: boolean;
32+
// 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0.
2833
'WEBGL_VERSION'?: number;
2934
}
3035

3136
export const URL_PROPERTIES: URLProperty[] = [
32-
{name: 'WEBGL_DISJOINT_QUERY_TIMER', type: Type.BOOLEAN},
37+
{name: 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED', type: Type.BOOLEAN},
38+
{name: 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', type: Type.BOOLEAN},
3339
{name: 'WEBGL_VERSION', type: Type.NUMBER}
3440
];
3541

@@ -38,50 +44,51 @@ export interface URLProperty {
3844
type: Type;
3945
}
4046

41-
function isWebGL2Enabled() {
47+
function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext {
48+
if (webGLVersion === 0) {
49+
throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
50+
}
51+
4252
const tempCanvas = document.createElement('canvas');
43-
const gl = tempCanvas.getContext('webgl2') as WebGLRenderingContext;
53+
if (webGLVersion === 1) {
54+
return (tempCanvas.getContext('webgl') ||
55+
tempCanvas.getContext('experimental-webgl')) as
56+
WebGLRenderingContext;
57+
}
58+
return tempCanvas.getContext('webgl2') as WebGLRenderingContext;
59+
}
60+
61+
function loseContext(gl: WebGLRenderingContext) {
4462
if (gl != null) {
4563
const loseContextExtension = gl.getExtension('WEBGL_lose_context');
4664
if (loseContextExtension == null) {
4765
throw new Error(
4866
'Extension WEBGL_lose_context not supported on this browser.');
4967
}
5068
loseContextExtension.loseContext();
51-
return true;
5269
}
53-
return false;
5470
}
5571

56-
function isWebGL1Enabled() {
57-
const tempCanvas = document.createElement('canvas');
58-
const gl =
59-
(tempCanvas.getContext('webgl') ||
60-
tempCanvas.getContext('experimental-webgl')) as WebGLRenderingContext;
72+
function isWebGLVersionEnabled(webGLVersion: 1|2) {
73+
const gl = getWebGLRenderingContext(webGLVersion);
6174
if (gl != null) {
62-
const loseContextExtension = gl.getExtension('WEBGL_lose_context');
63-
if (loseContextExtension == null) {
64-
throw new Error(
65-
'Extension WEBGL_lose_context not supported on this browser.');
66-
}
67-
loseContextExtension.loseContext();
75+
loseContext(gl);
6876
return true;
6977
}
7078
return false;
7179
}
7280

73-
function evaluateFeature<K extends keyof Features>(feature: K): Features[K] {
74-
if (feature === 'WEBGL_DISJOINT_QUERY_TIMER') {
75-
return !device_util.isMobile();
76-
} else if (feature === 'WEBGL_VERSION') {
77-
if (isWebGL2Enabled()) {
78-
return 2;
79-
} else if (isWebGL1Enabled()) {
80-
return 1;
81-
}
82-
return 0;
81+
function isWebGLDisjointQueryTimerEnabled(webGLVersion: number) {
82+
const gl = getWebGLRenderingContext(webGLVersion);
83+
84+
const extensionName = webGLVersion === 1 ? 'EXT_disjoint_timer_query' :
85+
'EXT_disjoint_timer_query_webgl2';
86+
const ext = gl.getExtension(extensionName);
87+
const isExtEnabled = ext != null;
88+
if (gl != null) {
89+
loseContext(gl);
8390
}
84-
throw new Error(`Unknown feature ${feature}.`);
91+
return isExtEnabled;
8592
}
8693

8794
export class Environment {
@@ -98,10 +105,33 @@ export class Environment {
98105
return this.features[feature];
99106
}
100107

101-
this.features[feature] = evaluateFeature(feature);
108+
this.features[feature] = this.evaluateFeature(feature);
102109

103110
return this.features[feature];
104111
}
112+
113+
private evaluateFeature<K extends keyof Features>(feature: K): Features[K] {
114+
if (feature === 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED') {
115+
const webGLVersion = this.get('WEBGL_VERSION');
116+
117+
if (webGLVersion === 0) {
118+
return false;
119+
}
120+
121+
return isWebGLDisjointQueryTimerEnabled(webGLVersion);
122+
} else if (feature === 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') {
123+
return this.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED') &&
124+
!device_util.isMobile();
125+
} else if (feature === 'WEBGL_VERSION') {
126+
if (isWebGLVersionEnabled(2)) {
127+
return 2;
128+
} else if (isWebGLVersionEnabled(1)) {
129+
return 1;
130+
}
131+
return 0;
132+
}
133+
throw new Error(`Unknown feature ${feature}.`);
134+
}
105135
}
106136

107137
// Expects flags from URL in the format ?dljsflags=FLAG1:1,FLAG2:true.

src/environment_test.ts

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,101 @@
1515
* =============================================================================
1616
*/
1717
import * as device_util from './device_util';
18-
import {Environment} from './environment';
18+
import {Environment, Features} from './environment';
1919

20-
describe('disjoint query timer', () => {
21-
it('mobile', () => {
20+
describe('disjoint query timer enabled', () => {
21+
it('no webgl', () => {
22+
const features: Features = {'WEBGL_VERSION': 0};
23+
24+
const env = new Environment(features);
25+
26+
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')).toBe(false);
27+
});
28+
29+
it('webgl 1', () => {
30+
const features: Features = {'WEBGL_VERSION': 1};
31+
32+
spyOn(document, 'createElement').and.returnValue({
33+
getContext: (context: string) => {
34+
if (context === 'webgl' || context === 'experimental-webgl') {
35+
return {
36+
getExtension: (extensionName: string) => {
37+
if (extensionName === 'EXT_disjoint_timer_query') {
38+
return {};
39+
} else if (extensionName === 'WEBGL_lose_context') {
40+
return {loseContext: () => {}};
41+
}
42+
return null;
43+
}
44+
};
45+
}
46+
return null;
47+
}
48+
});
49+
50+
const env = new Environment(features);
51+
52+
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')).toBe(true);
53+
});
54+
55+
it('webgl 2', () => {
56+
const features: Features = {'WEBGL_VERSION': 2};
57+
58+
spyOn(document, 'createElement').and.returnValue({
59+
getContext: (context: string) => {
60+
if (context === 'webgl2') {
61+
return {
62+
getExtension: (extensionName: string) => {
63+
if (extensionName === 'EXT_disjoint_timer_query_webgl2') {
64+
return {};
65+
} else if (extensionName === 'WEBGL_lose_context') {
66+
return {loseContext: () => {}};
67+
}
68+
return null;
69+
}
70+
};
71+
}
72+
return null;
73+
}
74+
});
75+
76+
const env = new Environment(features);
77+
78+
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')).toBe(true);
79+
});
80+
81+
82+
});
83+
describe('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => {
84+
it('disjoint query timer disabled', () => {
85+
const features:
86+
Features = {'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED': false};
87+
88+
const env = new Environment(features);
89+
90+
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE'))
91+
.toBe(false);
92+
});
93+
94+
it('disjoint query timer enabled, mobile', () => {
95+
const features:
96+
Features = {'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED': true};
2297
spyOn(device_util, 'isMobile').and.returnValue(true);
2398

24-
const env = new Environment();
99+
const env = new Environment(features);
25100

26-
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER')).toBe(false);
101+
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE'))
102+
.toBe(false);
27103
});
28104

29-
it('not mobile', () => {
105+
it('disjoint query timer enabled, not mobile', () => {
106+
const features:
107+
Features = {'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED': true};
30108
spyOn(device_util, 'isMobile').and.returnValue(false);
31109

32-
const env = new Environment();
110+
const env = new Environment(features);
33111

34-
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER')).toBe(true);
112+
expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')).toBe(true);
35113
});
36114
});
37115

src/math/ndarray.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18+
import {ENV} from '../environment';
1819
import * as util from '../util';
1920

2021
import {GPGPUContext} from './webgl/gpgpu_context';
@@ -243,6 +244,27 @@ export class NDArray {
243244
return this.data.values;
244245
}
245246

247+
getValuesAsync(): Promise<Float32Array> {
248+
return new Promise<Float32Array>((resolve, reject) => {
249+
if (this.data.values != null) {
250+
resolve(this.data.values);
251+
return;
252+
}
253+
254+
if (!ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')) {
255+
resolve(this.getValues());
256+
return;
257+
}
258+
259+
// Construct an empty query. We're just interested in getting a callback
260+
// when the GPU command queue has executed until this point in time.
261+
const queryFn = () => {};
262+
GPGPU.runQuery(queryFn).then(() => {
263+
resolve(this.getValues());
264+
});
265+
});
266+
}
267+
246268
private uploadToGPU(preferredTexShape?: [number, number]) {
247269
throwIfGPUNotInitialized();
248270
this.data.textureShapeRC = webgl_util.getTextureShapeFromLogicalShape(

0 commit comments

Comments
 (0)