diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index efc85a2960..aaee9a9254 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -62,7 +62,7 @@ import * as binaryop_gpu from './binaryop_gpu'; import {BinaryOpProgram} from './binaryop_gpu'; import * as binaryop_packed_gpu from './binaryop_packed_gpu'; import {BinaryOpPackedProgram} from './binaryop_packed_gpu'; -import {createCanvas, getWebGLContext} from './canvas_util'; +import {createCanvas} from './canvas_util'; import {ClipProgram} from './clip_gpu'; import {ClipPackedProgram} from './clip_packed_gpu'; import {ComplexAbsProgram} from './complex_abs_gpu'; @@ -130,6 +130,7 @@ import {UnaryOpProgram} from './unaryop_gpu'; import * as unary_packed_op from './unaryop_packed_gpu'; import {UnaryOpPackedProgram} from './unaryop_packed_gpu'; import {UnpackProgram} from './unpack_gpu'; +import {getActiveContext} from './webgl_context_manager'; import * as webgl_util from './webgl_util'; type KernelInfo = { @@ -221,7 +222,6 @@ export class MathBackendWebGL implements KernelBackend { private dataRefCount = new WeakMap(); private numBytesInGPU = 0; - private canvas: HTMLCanvasElement; private fromPixels2DContext: CanvasRenderingContext2D| OffscreenCanvasRenderingContext2D; @@ -248,15 +248,12 @@ export class MathBackendWebGL implements KernelBackend { } if (gpgpu == null) { - const gl = getWebGLContext(ENV.getNumber('WEBGL_VERSION')); this.binaryCache = getBinaryCache(ENV.getNumber('WEBGL_VERSION')); - this.gpgpu = new GPGPUContext(gl); - this.canvas = gl.canvas; + this.gpgpu = new GPGPUContext(); this.gpgpuCreatedLocally = true; } else { this.binaryCache = {}; this.gpgpuCreatedLocally = false; - this.canvas = gpgpu.gl.canvas; } this.textureManager = new TextureManager(this.gpgpu); this.numMBBeforeWarning = numMBBeforeWarning(); @@ -2500,11 +2497,7 @@ export class MathBackendWebGL implements KernelBackend { return; } this.textureManager.dispose(); - if (this.canvas != null && this.canvas.remove != null) { - this.canvas.remove(); - } else { - this.canvas = null; - } + if (this.fromPixels2DContext != null && //@ts-ignore this.fromPixels2DContext.canvas.remove) { @@ -2515,6 +2508,11 @@ export class MathBackendWebGL implements KernelBackend { this.gpgpu.program = null; this.gpgpu.dispose(); } + + const gl = getActiveContext(); + if (gl.canvas != null && gl.canvas.remove != null) { + gl.canvas.remove(); + } this.disposed = true; } diff --git a/src/backends/webgl/canvas_util.ts b/src/backends/webgl/canvas_util.ts index 447ab5b358..4acf6d1f59 100644 --- a/src/backends/webgl/canvas_util.ts +++ b/src/backends/webgl/canvas_util.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -15,8 +15,6 @@ * ============================================================================= */ -const contexts: {[key: string]: WebGLRenderingContext} = {}; - const WEBGL_ATTRIBUTES: WebGLContextAttributes = { alpha: false, antialias: false, @@ -27,34 +25,6 @@ const WEBGL_ATTRIBUTES: WebGLContextAttributes = { failIfMajorPerformanceCaveat: true }; -export function setWebGLContext( - webGLVersion: number, gl: WebGLRenderingContext) { - contexts[webGLVersion] = gl; -} - -export function getWebGLContext(webGLVersion: number): WebGLRenderingContext { - if (!(webGLVersion in contexts)) { - contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion); - } - const gl = contexts[webGLVersion]; - if (gl.isContextLost()) { - delete contexts[webGLVersion]; - return getWebGLContext(webGLVersion); - } - - gl.disable(gl.DEPTH_TEST); - gl.disable(gl.STENCIL_TEST); - gl.disable(gl.BLEND); - gl.disable(gl.DITHER); - gl.disable(gl.POLYGON_OFFSET_FILL); - gl.disable(gl.SAMPLE_COVERAGE); - gl.enable(gl.SCISSOR_TEST); - gl.enable(gl.CULL_FACE); - gl.cullFace(gl.BACK); - - return contexts[webGLVersion]; -} - export function createCanvas(webGLVersion: number) { if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) { return new OffscreenCanvas(300, 150); @@ -65,7 +35,19 @@ export function createCanvas(webGLVersion: number) { } } -function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext { +export function cleanupDOMCanvasWebGLRenderingContext( + context: WebGLRenderingContext) { + if (context == null) { + throw new Error('Shold not hit this case'); + } + const canvas = context.canvas; + if (canvas != null && canvas.remove != null) { + canvas.remove(); + } +} + +export function createDOMCanvasWebGLRenderingContext(webGLVersion: number): + WebGLRenderingContext { if (webGLVersion !== 1 && webGLVersion !== 2) { throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); } @@ -73,7 +55,6 @@ function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext { canvas.addEventListener('webglcontextlost', (ev: Event) => { ev.preventDefault(); - delete contexts[webGLVersion]; }, false); if (webGLVersion === 1) { return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) || diff --git a/src/backends/webgl/canvas_util_test.ts b/src/backends/webgl/canvas_util_test.ts index e4e01ec072..daa4a6db94 100644 --- a/src/backends/webgl/canvas_util_test.ts +++ b/src/backends/webgl/canvas_util_test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -18,12 +18,13 @@ import {ENV} from '../../environment'; import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util'; -import {getWebGLContext} from './canvas_util'; +import {createDOMCanvasWebGLRenderingContext} from './canvas_util'; describeWithFlags('canvas_util', BROWSER_ENVS, () => { it('Returns a valid canvas', () => { - const canvas = getWebGLContext(ENV.getNumber('WEBGL_VERSION')).canvas as ( - HTMLCanvasElement | OffscreenCanvas); + const canvas = + createDOMCanvasWebGLRenderingContext(ENV.getNumber('WEBGL_VERSION')) + .canvas as (HTMLCanvasElement | OffscreenCanvas); expect( (canvas instanceof HTMLCanvasElement) || (canvas instanceof OffscreenCanvas)) @@ -31,14 +32,15 @@ describeWithFlags('canvas_util', BROWSER_ENVS, () => { }); it('Returns a valid gl context', () => { - const gl = getWebGLContext(ENV.getNumber('WEBGL_VERSION')); + const gl = + createDOMCanvasWebGLRenderingContext(ENV.getNumber('WEBGL_VERSION')); expect(gl.isContextLost()).toBe(false); }); }); describeWithFlags('canvas_util webgl2', {flags: {WEBGL_VERSION: 2}}, () => { it('is ok when the user requests webgl 1 canvas', () => { - const canvas = getWebGLContext(1).canvas; + const canvas = createDOMCanvasWebGLRenderingContext(1).canvas; expect((canvas instanceof HTMLCanvasElement)).toBe(true); }); }); diff --git a/src/backends/webgl/clip_gpu.ts b/src/backends/webgl/clip_gpu.ts index bb9dde1b9e..55eaa68ba0 100644 --- a/src/backends/webgl/clip_gpu.ts +++ b/src/backends/webgl/clip_gpu.ts @@ -17,6 +17,7 @@ import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; +import {getActiveContext} from './webgl_context_manager'; export class ClipProgram implements GPGPUProgram { variableNames = ['A']; @@ -51,8 +52,9 @@ export class ClipProgram implements GPGPUProgram { this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'min'); this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'max'); } - gpgpu.gl.uniform1f(this.minLoc, min); - gpgpu.gl.uniform1f(this.maxLoc, max); + const gl = getActiveContext(); + gl.uniform1f(this.minLoc, min); + gl.uniform1f(this.maxLoc, max); }; } } diff --git a/src/backends/webgl/clip_packed_gpu.ts b/src/backends/webgl/clip_packed_gpu.ts index 566eb088cf..aa176e6b31 100644 --- a/src/backends/webgl/clip_packed_gpu.ts +++ b/src/backends/webgl/clip_packed_gpu.ts @@ -17,6 +17,7 @@ import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; +import {getActiveContext} from './webgl_context_manager'; export class ClipPackedProgram implements GPGPUProgram { variableNames = ['A']; @@ -53,8 +54,9 @@ export class ClipPackedProgram implements GPGPUProgram { this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'min'); this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'max'); } - gpgpu.gl.uniform1f(this.minLoc, min); - gpgpu.gl.uniform1f(this.maxLoc, max); + const gl = getActiveContext(); + gl.uniform1f(this.minLoc, min); + gl.uniform1f(this.maxLoc, max); }; } } diff --git a/src/backends/webgl/fill_gpu.ts b/src/backends/webgl/fill_gpu.ts index 3822ed0746..5469d37321 100644 --- a/src/backends/webgl/fill_gpu.ts +++ b/src/backends/webgl/fill_gpu.ts @@ -17,6 +17,7 @@ import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; +import {getActiveContext} from './webgl_context_manager'; export class FillProgram implements GPGPUProgram { variableNames: string[]; @@ -43,7 +44,7 @@ export class FillProgram implements GPGPUProgram { if (this.valueLoc == null) { this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value'); } - gpgpu.gl.uniform1f(this.valueLoc, value); + getActiveContext().uniform1f(this.valueLoc, value); }; } } diff --git a/src/backends/webgl/flags_webgl_test.ts b/src/backends/webgl/flags_webgl_test.ts index 2cd07d0ac3..1589fc8270 100644 --- a/src/backends/webgl/flags_webgl_test.ts +++ b/src/backends/webgl/flags_webgl_test.ts @@ -19,7 +19,7 @@ import * as device_util from '../../device_util'; import {ENV} from '../../environment'; import {webgl_util} from '../../webgl'; -import * as canvas_util from './canvas_util'; +import * as webgl_context_manager from './webgl_context_manager'; describe('HAS_WEBGL', () => { beforeEach(() => ENV.reset()); @@ -197,7 +197,7 @@ describe('WEBGL_MAX_TEXTURE_SIZE', () => { ENV.reset(); webgl_util.MAX_TEXTURE_SIZE = null; - spyOn(canvas_util, 'getWebGLContext').and.returnValue({ + spyOn(webgl_context_manager, 'getContextByVersion').and.returnValue({ MAX_TEXTURE_SIZE: 101, getParameter: (param: number) => { if (param === 101) { @@ -223,7 +223,7 @@ describe('WEBGL_MAX_TEXTURES_IN_SHADER', () => { ENV.reset(); webgl_util.MAX_TEXTURES_IN_SHADER = null; - spyOn(canvas_util, 'getWebGLContext').and.callFake(() => { + spyOn(webgl_context_manager, 'getContextByVersion').and.callFake(() => { return { MAX_TEXTURE_IMAGE_UNITS: 101, getParameter: (param: number) => { diff --git a/src/backends/webgl/gpgpu_context.ts b/src/backends/webgl/gpgpu_context.ts index a86ad11d95..cd560a3c4e 100644 --- a/src/backends/webgl/gpgpu_context.ts +++ b/src/backends/webgl/gpgpu_context.ts @@ -19,10 +19,10 @@ import {ENV} from '../../environment'; import {PixelData, TypedArray} from '../../types'; import * as util from '../../util'; -import {getWebGLContext, setWebGLContext} from './canvas_util'; import * as gpgpu_util from './gpgpu_util'; import {TextureConfig} from './gpgpu_util'; import * as tex_util from './tex_util'; +import {getActiveContext} from './webgl_context_manager'; import {WebGL1DisjointQueryTimerExtension, WebGL2DisjointQueryTimerExtension} from './webgl_types'; import * as webgl_util from './webgl_util'; @@ -32,7 +32,6 @@ export interface FenceContext { } export class GPGPUContext { - gl: WebGLRenderingContext; textureFloatExtension: {}; textureHalfFloatExtension: {}; colorBufferFloatExtension: {}; @@ -48,38 +47,34 @@ export class GPGPUContext { private disjoint: boolean; private textureConfig: TextureConfig; - constructor(gl?: WebGLRenderingContext) { - const glVersion = ENV.getNumber('WEBGL_VERSION'); - if (gl != null) { - this.gl = gl; - setWebGLContext(glVersion, gl); - } else { - this.gl = getWebGLContext(glVersion); - } + constructor() { + const gl = getActiveContext(); + webgl_util.checkWebGLError(gl); + // WebGL 2.0 enables texture floats without an extension. if (ENV.getNumber('WEBGL_VERSION') === 1) { - this.textureFloatExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, 'OES_texture_float'); + this.textureFloatExtension = + webgl_util.getExtensionOrThrow(gl, this.debug, 'OES_texture_float'); this.colorBufferFloatExtension = - this.gl.getExtension('WEBGL_color_buffer_float'); + gl.getExtension('WEBGL_color_buffer_float'); if (!ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { this.textureHalfFloatExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, 'OES_texture_half_float'); + gl, this.debug, 'OES_texture_half_float'); this.colorBufferHalfFloatExtension = - this.gl.getExtension('EXT_color_buffer_half_float'); + gl.getExtension('EXT_color_buffer_half_float'); } } else { this.colorBufferFloatExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, 'EXT_color_buffer_float'); + gl, this.debug, 'EXT_color_buffer_float'); } - this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl, this.debug); - this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl, this.debug); - this.framebuffer = webgl_util.createFramebuffer(this.gl, this.debug); + this.vertexBuffer = gpgpu_util.createVertexBuffer(gl, this.debug); + this.indexBuffer = gpgpu_util.createIndexBuffer(gl, this.debug); + this.framebuffer = webgl_util.createFramebuffer(gl, this.debug); this.textureConfig = - gpgpu_util.getTextureConfig(this.gl, this.textureHalfFloatExtension); + gpgpu_util.getTextureConfig(gl, this.textureHalfFloatExtension); } private get debug(): boolean { @@ -103,18 +98,21 @@ export class GPGPUContext { 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + 'disposing.'); } - const gl = this.gl; - webgl_util.callAndCheck(gl, this.debug, () => gl.finish()); - webgl_util.callAndCheck( - gl, this.debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); + const debug = true; + const gl = getActiveContext(); + webgl_util.checkWebGLError(gl); + webgl_util.callAndCheck(gl, debug, () => gl.finish()); + // TODO(kreeger): This bind framebuffer call can throw an INVALID_OPERATION + // error on WebGL2 - fix this. webgl_util.callAndCheck( - gl, this.debug, () => gl.deleteFramebuffer(this.framebuffer)); + gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); webgl_util.callAndCheck( - gl, this.debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, null)); + gl, debug, () => gl.deleteFramebuffer(this.framebuffer)); webgl_util.callAndCheck( - gl, this.debug, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null)); + gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, null)); webgl_util.callAndCheck( - gl, this.debug, () => gl.deleteBuffer(this.indexBuffer)); + gl, debug, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null)); + webgl_util.callAndCheck(gl, debug, () => gl.deleteBuffer(this.indexBuffer)); this.disposed = true; } @@ -122,60 +120,64 @@ export class GPGPUContext { WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createFloat32MatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + getActiveContext(), this.debug, rows, columns, this.textureConfig); } public createFloat16MatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createFloat16MatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + getActiveContext(), this.debug, rows, columns, this.textureConfig); } public createUnsignedBytesMatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createUnsignedBytesMatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + getActiveContext(), this.debug, rows, columns, this.textureConfig); } public uploadPixelDataToTexture( texture: WebGLTexture, pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement) { this.throwIfDisposed(); - gpgpu_util.uploadPixelDataToTexture(this.gl, this.debug, texture, pixels); + gpgpu_util.uploadPixelDataToTexture( + getActiveContext(), this.debug, texture, pixels); } public uploadDenseMatrixToTexture( texture: WebGLTexture, width: number, height: number, data: TypedArray) { this.throwIfDisposed(); gpgpu_util.uploadDenseMatrixToTexture( - this.gl, this.debug, texture, width, height, data, this.textureConfig); + getActiveContext(), this.debug, texture, width, height, data, + this.textureConfig); } public createFloat16PackedMatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createFloat16PackedMatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + getActiveContext(), this.debug, rows, columns, this.textureConfig); } public createPackedMatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createPackedMatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + getActiveContext(), true, rows, columns, this.textureConfig); } public deleteMatrixTexture(texture: WebGLTexture) { this.throwIfDisposed(); if (this.outputTexture === texture) { webgl_util.unbindColorTextureFromFramebuffer( - this.gl, this.debug, this.framebuffer); + getActiveContext(), this.debug, this.framebuffer); this.outputTexture = null; } + console.log(' texture: ' + texture); webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.deleteTexture(texture)); + getActiveContext(), true, + () => getActiveContext().deleteTexture(texture)); } public downloadByteEncodedFloatMatrixFromOutputTexture( @@ -183,34 +185,35 @@ export class GPGPUContext { return this.downloadMatrixDriver( texture, () => gpgpu_util.downloadByteEncodedFloatMatrixFromOutputTexture( - this.gl, this.debug, rows, columns, this.textureConfig)); + getActiveContext(), this.debug, rows, columns, this.textureConfig)); } public downloadPackedMatrixFromBuffer( buffer: WebGLBuffer, batch: number, rows: number, columns: number, physicalRows: number, physicalCols: number): Float32Array { return gpgpu_util.downloadPackedMatrixFromBuffer( - this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, - this.textureConfig); + getActiveContext(), buffer, batch, rows, columns, physicalRows, + physicalCols, this.textureConfig); } public downloadFloat32MatrixFromBuffer(buffer: WebGLBuffer, size: number): Float32Array { - return gpgpu_util.downloadFloat32MatrixFromBuffer(this.gl, buffer, size); + return gpgpu_util.downloadFloat32MatrixFromBuffer( + getActiveContext(), buffer, size); } public createBufferFromTexture( texture: WebGLTexture, rows: number, columns: number): WebGLBuffer { this.bindTextureToFrameBuffer(texture); const result = gpgpu_util.createBufferFromOutputTexture( - this.gl as WebGL2RenderingContext, this.debug, rows, columns, + getActiveContext() as WebGL2RenderingContext, this.debug, rows, columns, this.textureConfig); this.unbindTextureToFrameBuffer(); return result; } public createAndWaitForFence(): Promise { - const fenceContext = this.createFence(this.gl); + const fenceContext = this.createFence(getActiveContext()); return this.pollFence(fenceContext); } @@ -254,14 +257,14 @@ export class GPGPUContext { return this.downloadMatrixDriver( texture, () => gpgpu_util.downloadMatrixFromPackedOutputTexture( - this.gl, this.debug, physicalRows, physicalCols)); + getActiveContext(), this.debug, physicalRows, physicalCols)); } private vertexAttrsAreBound = false; public createProgram(fragmentShaderSource: string): WebGLProgram { this.throwIfDisposed(); - const gl = this.gl; + const gl = getActiveContext(); const fragmentShader: WebGLShader = webgl_util.createFragmentShader(gl, this.debug, fragmentShaderSource); const vertexShader: WebGLShader = @@ -293,7 +296,8 @@ export class GPGPUContext { } if (program != null) { webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.deleteProgram(program)); + getActiveContext(), this.debug, + () => getActiveContext().deleteProgram(program)); } } @@ -301,10 +305,11 @@ export class GPGPUContext { this.throwIfDisposed(); this.program = program; if ((this.program != null) && this.debug) { - webgl_util.validateProgram(this.gl, this.debug, this.program); + webgl_util.validateProgram(getActiveContext(), this.debug, this.program); } webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.useProgram(program)); + getActiveContext(), this.debug, + () => getActiveContext().useProgram(program)); } public getUniformLocation( @@ -313,10 +318,10 @@ export class GPGPUContext { this.throwIfDisposed(); if (shouldThrow) { return webgl_util.getProgramUniformLocationOrThrow( - this.gl, this.debug, program, uniformName); + getActiveContext(), this.debug, program, uniformName); } else { return webgl_util.getProgramUniformLocation( - this.gl, program, uniformName); + getActiveContext(), program, uniformName); } } @@ -324,14 +329,14 @@ export class GPGPUContext { number { this.throwIfDisposed(); return webgl_util.callAndCheck( - this.gl, this.debug, - () => this.gl.getAttribLocation(program, attribute)); + getActiveContext(), this.debug, + () => getActiveContext().getAttribLocation(program, attribute)); } public getUniformLocationNoThrow(program: WebGLProgram, uniformName: string): WebGLUniformLocation { this.throwIfDisposed(); - return this.gl.getUniformLocation(program, uniformName); + return getActiveContext().getUniformLocation(program, uniformName); } public setInputMatrixTexture( @@ -340,8 +345,8 @@ export class GPGPUContext { this.throwIfDisposed(); this.throwIfNoProgram(); webgl_util.bindTextureToProgramUniformSampler( - this.gl, this.debug, this.program, inputMatrixTexture, uniformLocation, - textureUnit); + getActiveContext(), this.debug, this.program, inputMatrixTexture, + uniformLocation, textureUnit); } public setOutputMatrixTexture( @@ -372,15 +377,15 @@ export class GPGPUContext { public debugValidate() { if (this.program != null) { - webgl_util.validateProgram(this.gl, this.debug, this.program); + webgl_util.validateProgram(getActiveContext(), this.debug, this.program); } - webgl_util.validateFramebuffer(this.gl); + webgl_util.validateFramebuffer(getActiveContext()); } public executeProgram() { this.throwIfDisposed(); this.throwIfNoProgram(); - const gl = this.gl; + const gl = getActiveContext(); if (this.debug) { this.debugValidate(); } @@ -391,7 +396,8 @@ export class GPGPUContext { public blockUntilAllProgramsCompleted() { this.throwIfDisposed(); - webgl_util.callAndCheck(this.gl, this.debug, () => this.gl.finish()); + webgl_util.callAndCheck( + getActiveContext(), this.debug, () => getActiveContext().finish()); } private getQueryTimerExtension(): WebGL1DisjointQueryTimerExtension @@ -399,7 +405,7 @@ export class GPGPUContext { if (this.disjointQueryTimerExtension == null) { this.disjointQueryTimerExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, + getActiveContext(), this.debug, ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 'EXT_disjoint_timer_query_webgl2' : @@ -420,7 +426,7 @@ export class GPGPUContext { beginQuery(): WebGLQuery { if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { - const gl2 = this.gl as WebGL2RenderingContext; + const gl2 = getActiveContext() as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); const query = gl2.createQuery(); @@ -435,7 +441,7 @@ export class GPGPUContext { endQuery() { if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { - const gl2 = this.gl as WebGL2RenderingContext; + const gl2 = getActiveContext() as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); gl2.endQuery(ext.TIME_ELAPSED_EXT); return; @@ -463,7 +469,7 @@ export class GPGPUContext { } if (queryTimerVersion === 2) { - const gl2 = this.gl as WebGL2RenderingContext; + const gl2 = getActiveContext() as WebGL2RenderingContext; const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT); // Return milliseconds. @@ -485,13 +491,13 @@ export class GPGPUContext { } if (queryTimerVersion === 2) { - const gl2 = this.gl as WebGL2RenderingContext; + const gl2 = getActiveContext() as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE); if (this.disjoint == null) { - this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); + this.disjoint = getActiveContext().getParameter(ext.GPU_DISJOINT_EXT); } return available && !this.disjoint; @@ -501,7 +507,7 @@ export class GPGPUContext { const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT); if (this.disjoint == null) { - this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); + this.disjoint = getActiveContext().getParameter(ext.GPU_DISJOINT_EXT); } return available && !this.disjoint; @@ -543,22 +549,22 @@ export class GPGPUContext { private bindTextureToFrameBuffer(texture: WebGLTexture) { this.throwIfDisposed(); webgl_util.bindColorTextureToFramebuffer( - this.gl, this.debug, texture, this.framebuffer); + getActiveContext(), this.debug, texture, this.framebuffer); if (this.debug) { - webgl_util.validateFramebuffer(this.gl); + webgl_util.validateFramebuffer(getActiveContext()); } } private unbindTextureToFrameBuffer() { if (this.outputTexture != null) { webgl_util.bindColorTextureToFramebuffer( - this.gl, this.debug, this.outputTexture, this.framebuffer); + getActiveContext(), this.debug, this.outputTexture, this.framebuffer); if (this.debug) { - webgl_util.validateFramebuffer(this.gl); + webgl_util.validateFramebuffer(getActiveContext()); } } else { webgl_util.unbindColorTextureFromFramebuffer( - this.gl, this.debug, this.framebuffer); + getActiveContext(), this.debug, this.framebuffer); } } @@ -576,7 +582,7 @@ export class GPGPUContext { outputMatrixTextureMaybePacked: WebGLTexture, width: number, height: number) { this.throwIfDisposed(); - const gl = this.gl; + const gl = getActiveContext(); webgl_util.bindColorTextureToFramebuffer( gl, this.debug, outputMatrixTextureMaybePacked, this.framebuffer); if (this.debug) { @@ -593,7 +599,8 @@ export class GPGPUContext { x: number, y: number, width: number, height: number) { this.throwIfDisposed(); webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.scissor(x, y, width, height)); + getActiveContext(), this.debug, + () => getActiveContext().scissor(x, y, width, height)); } private throwIfDisposed() { diff --git a/src/backends/webgl/gpgpu_context_test.ts b/src/backends/webgl/gpgpu_context_test.ts index d03bae8f32..0bdd3ec63b 100644 --- a/src/backends/webgl/gpgpu_context_test.ts +++ b/src/backends/webgl/gpgpu_context_test.ts @@ -22,6 +22,7 @@ import {WEBGL_ENVS} from './backend_webgl_test_registry'; import {getGlslDifferences} from './glsl_version'; import {GPGPUContext, linearSearchLastTrue} from './gpgpu_context'; import * as tex_util from './tex_util'; +import {getActiveContext} from './webgl_context_manager'; const DOWNLOAD_FLOAT_ENVS = { flags: {'WEBGL_DOWNLOAD_FLOAT_ENABLED': true}, @@ -57,7 +58,8 @@ describeWithFlags( const output = gpgpu.createFloat32MatrixTexture(rows, columns); gpgpu.setOutputMatrixTexture(output, rows, columns); const expected = new Int32Array([0, 0, columns, rows]); - expect(gpgpu.gl.getParameter(gpgpu.gl.VIEWPORT)).toEqual(expected); + const gl = getActiveContext(); + expect(gl.getParameter(gl.VIEWPORT)).toEqual(expected); gpgpu.deleteMatrixTexture(output); }); }); @@ -95,7 +97,8 @@ describeWithFlags( const [width, height] = tex_util.getPackedMatrixTextureShapeWidthHeight(rows, columns); const expected = new Int32Array([0, 0, width, height]); - expect(gpgpu.gl.getParameter(gpgpu.gl.VIEWPORT)).toEqual(expected); + const gl = getActiveContext(); + expect(gl.getParameter(gl.VIEWPORT)).toEqual(expected); }); }); @@ -133,7 +136,8 @@ describeWithFlags( it('sets the scissor box to the requested parameters', () => { gpgpu.setOutputMatrixWriteRegion(0, 1, 2, 3); - const scissorBox = gpgpu.gl.getParameter(gpgpu.gl.SCISSOR_BOX); + const gl = getActiveContext(); + const scissorBox = gl.getParameter(gl.SCISSOR_BOX); expect(scissorBox[0]).toEqual(2); expect(scissorBox[1]).toEqual(0); expect(scissorBox[2]).toEqual(3); @@ -155,12 +159,6 @@ describeWithFlags('GPGPUContext', DOWNLOAD_FLOAT_ENVS, () => { gpgpu.dispose(); }); - it('throws an error if used after dispose', () => { - const gpgpuContext = new GPGPUContext(); - gpgpuContext.dispose(); - expect(gpgpuContext.dispose).toThrowError(); - }); - it('throws an error if validation is on and framebuffer incomplete', () => { const glsl = getGlslDifferences(); const src = `${glsl.version} @@ -177,6 +175,14 @@ describeWithFlags('GPGPUContext', DOWNLOAD_FLOAT_ENVS, () => { }); }); +describeWithFlags('GPGPUContext dispose', DOWNLOAD_FLOAT_ENVS, () => { + it('throws an error if used after dispose', () => { + const gpgpuContext = new GPGPUContext(); + gpgpuContext.dispose(); + expect(gpgpuContext.dispose).toThrowError(); + }); +}); + describe('gpgpu_context linearSearchLastTrue', () => { it('[false]', () => { const a: boolean[] = [false]; diff --git a/src/backends/webgl/gpgpu_math.ts b/src/backends/webgl/gpgpu_math.ts index a284da2550..d57bb1e2f5 100644 --- a/src/backends/webgl/gpgpu_math.ts +++ b/src/backends/webgl/gpgpu_math.ts @@ -24,6 +24,7 @@ import {GPGPUContext} from './gpgpu_context'; import * as shader_compiler from './shader_compiler'; import {InputInfo, ShapeInfo} from './shader_compiler'; import {TextureData} from './tex_util'; +import {getActiveContext} from './webgl_context_manager'; export interface GPGPUProgram { variableNames: string[]; @@ -162,14 +163,15 @@ export function runProgram( } gpgpu.setProgram(binary.webGLProgram); + const gl = getActiveContext(); // Set special uniforms (NAN, INFINITY) if (ENV.getNumber('WEBGL_VERSION') === 1) { if (binary.infLoc !== null) { - gpgpu.gl.uniform1f(binary.infLoc, Infinity); + gl.uniform1f(binary.infLoc, Infinity); } } if (binary.nanLoc !== null) { - gpgpu.gl.uniform1f(binary.nanLoc, NaN); + gl.uniform1f(binary.nanLoc, NaN); } // Set user-defined inputs @@ -186,20 +188,20 @@ export function runProgram( if (input.isUniform) { // Upload the values of the tensor as uniform. if (util.sizeFromShape(input.shape) < 2) { - gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]); + gl.uniform1f(varLoc, input.uniformValues[0]); } else { let vals = input.uniformValues; if (!(vals instanceof Float32Array)) { vals = new Float32Array(vals); } - gpgpu.gl.uniform1fv(varLoc, vals); + gl.uniform1fv(varLoc, vals); } return; } // If the input was sliced, upload the flat offset index. if (input.texData.slice != null && varOffsetLoc != null) { - gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); + gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); } gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i); diff --git a/src/backends/webgl/gpgpu_util_test.ts b/src/backends/webgl/gpgpu_util_test.ts index 553e542dcc..29dd250ee7 100644 --- a/src/backends/webgl/gpgpu_util_test.ts +++ b/src/backends/webgl/gpgpu_util_test.ts @@ -19,114 +19,127 @@ import {describeWithFlags} from '../../jasmine_util'; import {WEBGL_ENVS} from './backend_webgl_test_registry'; import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_util from './gpgpu_util'; +import {getActiveContext} from './webgl_context_manager'; describeWithFlags('gpgpu_util createWebGLContext', WEBGL_ENVS, () => { let gpgpu: GPGPUContext; + let gl: WebGLRenderingContext; beforeEach(() => { gpgpu = new GPGPUContext(); + gl = getActiveContext(); }); afterEach(() => { gpgpu.dispose(); + gl = null; }); it('disables DEPTH_TEST and STENCIL_TEST', () => { - expect(gpgpu.gl.getParameter(gpgpu.gl.DEPTH_TEST)).toEqual(false); - expect(gpgpu.gl.getParameter(gpgpu.gl.STENCIL_TEST)).toEqual(false); + expect(gl.getParameter(gl.DEPTH_TEST)).toEqual(false); + expect(gl.getParameter(gl.STENCIL_TEST)).toEqual(false); }); it('disables BLEND', () => { - expect(gpgpu.gl.getParameter(gpgpu.gl.BLEND)).toEqual(false); + expect(gl.getParameter(gl.BLEND)).toEqual(false); }); it('disables DITHER, POLYGON_OFFSET_FILL', () => { - expect(gpgpu.gl.getParameter(gpgpu.gl.DITHER)).toEqual(false); - expect(gpgpu.gl.getParameter(gpgpu.gl.POLYGON_OFFSET_FILL)).toEqual(false); + expect(gl.getParameter(gl.DITHER)).toEqual(false); + expect(gl.getParameter(gl.POLYGON_OFFSET_FILL)).toEqual(false); }); it('enables CULL_FACE with BACK', () => { - expect(gpgpu.gl.getParameter(gpgpu.gl.CULL_FACE)).toEqual(true); - expect(gpgpu.gl.getParameter(gpgpu.gl.CULL_FACE_MODE)) - .toEqual(gpgpu.gl.BACK); + expect(gl.getParameter(gl.CULL_FACE)).toEqual(true); + expect(gl.getParameter(gl.CULL_FACE_MODE)).toEqual(gl.BACK); }); it('enables SCISSOR_TEST', () => { - expect(gpgpu.gl.getParameter(gpgpu.gl.SCISSOR_TEST)).toEqual(true); + expect(gl.getParameter(gl.SCISSOR_TEST)).toEqual(true); }); }); describeWithFlags('gpgpu_util createFloat32MatrixTexture', WEBGL_ENVS, () => { + let gl: WebGLRenderingContext; + beforeEach(() => { + gl = getActiveContext(); + }); + + afterEach(() => { + gl = null; + }); + it('sets the TEXTURE_WRAP S+T parameters to CLAMP_TO_EDGE', () => { const gpgpu = new GPGPUContext(); - const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl); + const textureConfig = gpgpu_util.getTextureConfig(gl); const debug = false; - const tex = gpgpu_util.createFloat32MatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); - expect( - gpgpu.gl.getTexParameter(gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_WRAP_S)) - .toEqual(gpgpu.gl.CLAMP_TO_EDGE); - expect( - gpgpu.gl.getTexParameter(gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_WRAP_T)) - .toEqual(gpgpu.gl.CLAMP_TO_EDGE); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, null); + const tex = + gpgpu_util.createFloat32MatrixTexture(gl, debug, 32, 32, textureConfig); + gl.bindTexture(gl.TEXTURE_2D, tex); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S)) + .toEqual(gl.CLAMP_TO_EDGE); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T)) + .toEqual(gl.CLAMP_TO_EDGE); + gl.bindTexture(gl.TEXTURE_2D, null); gpgpu.deleteMatrixTexture(tex); gpgpu.dispose(); }); it('sets the TEXTURE_[MIN|MAG]_FILTER parameters to NEAREST', () => { const gpgpu = new GPGPUContext(); - const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl); + const textureConfig = gpgpu_util.getTextureConfig(gl); const debug = false; - const tex = gpgpu_util.createFloat32MatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); - expect(gpgpu.gl.getTexParameter( - gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_MIN_FILTER)) - .toEqual(gpgpu.gl.NEAREST); - expect(gpgpu.gl.getTexParameter( - gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_MAG_FILTER)) - .toEqual(gpgpu.gl.NEAREST); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, null); + const tex = + gpgpu_util.createFloat32MatrixTexture(gl, debug, 32, 32, textureConfig); + gl.bindTexture(gl.TEXTURE_2D, tex); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER)) + .toEqual(gl.NEAREST); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER)) + .toEqual(gl.NEAREST); + gl.bindTexture(gl.TEXTURE_2D, null); gpgpu.deleteMatrixTexture(tex); gpgpu.dispose(); }); }); describeWithFlags('gpgpu_util createPackedMatrixTexture', WEBGL_ENVS, () => { + let gl: WebGLRenderingContext; + beforeEach(() => { + gl = getActiveContext(); + }); + + afterEach(() => { + gl = null; + }); + it('sets the TEXTURE_WRAP S+T parameters to CLAMP_TO_EDGE', () => { const gpgpu = new GPGPUContext(); - const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl); + const textureConfig = gpgpu_util.getTextureConfig(gl); const debug = false; - const tex = gpgpu_util.createPackedMatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); - expect( - gpgpu.gl.getTexParameter(gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_WRAP_S)) - .toEqual(gpgpu.gl.CLAMP_TO_EDGE); - expect( - gpgpu.gl.getTexParameter(gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_WRAP_T)) - .toEqual(gpgpu.gl.CLAMP_TO_EDGE); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, null); + const tex = + gpgpu_util.createPackedMatrixTexture(gl, debug, 32, 32, textureConfig); + gl.bindTexture(gl.TEXTURE_2D, tex); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S)) + .toEqual(gl.CLAMP_TO_EDGE); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T)) + .toEqual(gl.CLAMP_TO_EDGE); + gl.bindTexture(gl.TEXTURE_2D, null); gpgpu.deleteMatrixTexture(tex); gpgpu.dispose(); }); it('sets the TEXTURE_[MIN|MAG]_FILTER parameters to NEAREST', () => { const gpgpu = new GPGPUContext(); - const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl); + const textureConfig = gpgpu_util.getTextureConfig(gl); const debug = false; - const tex = gpgpu_util.createPackedMatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); - expect(gpgpu.gl.getTexParameter( - gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_MIN_FILTER)) - .toEqual(gpgpu.gl.NEAREST); - expect(gpgpu.gl.getTexParameter( - gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_MAG_FILTER)) - .toEqual(gpgpu.gl.NEAREST); - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, null); + const tex = + gpgpu_util.createPackedMatrixTexture(gl, debug, 32, 32, textureConfig); + gl.bindTexture(gl.TEXTURE_2D, tex); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER)) + .toEqual(gl.NEAREST); + expect(gl.getTexParameter(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER)) + .toEqual(gl.NEAREST); + gl.bindTexture(gl.TEXTURE_2D, null); gpgpu.deleteMatrixTexture(tex); gpgpu.dispose(); }); diff --git a/src/backends/webgl/multinomial_gpu.ts b/src/backends/webgl/multinomial_gpu.ts index e1c8f1e165..0b8dd46508 100644 --- a/src/backends/webgl/multinomial_gpu.ts +++ b/src/backends/webgl/multinomial_gpu.ts @@ -17,6 +17,7 @@ import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; +import {getActiveContext} from './webgl_context_manager'; export class MultinomialProgram implements GPGPUProgram { variableNames = ['probs']; @@ -59,7 +60,7 @@ export class MultinomialProgram implements GPGPUProgram { if (this.seedLoc == null) { this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed'); } - gpgpu.gl.uniform1f(this.seedLoc, seed); + getActiveContext().uniform1f(this.seedLoc, seed); }; } } diff --git a/src/backends/webgl/slice_gpu.ts b/src/backends/webgl/slice_gpu.ts index 1547fdd31e..eee02ff65a 100644 --- a/src/backends/webgl/slice_gpu.ts +++ b/src/backends/webgl/slice_gpu.ts @@ -18,6 +18,7 @@ import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; import {getCoordsDataType} from './shader_compiler'; +import {getActiveContext} from './webgl_context_manager'; export class SliceProgram implements GPGPUProgram { variableNames = ['source']; @@ -69,7 +70,7 @@ export class SliceProgram implements GPGPUProgram { return; } } - gpgpu.gl.uniform1iv(this.startLoc, start); + getActiveContext().uniform1iv(this.startLoc, start); }; } } diff --git a/src/backends/webgl/slice_packed_gpu.ts b/src/backends/webgl/slice_packed_gpu.ts index 6fa26f58e7..fc5d197b2e 100644 --- a/src/backends/webgl/slice_packed_gpu.ts +++ b/src/backends/webgl/slice_packed_gpu.ts @@ -20,6 +20,7 @@ import {getChannels} from '../packing_util'; import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; import {getCoordsDataType} from './shader_compiler'; +import {getActiveContext} from './webgl_context_manager'; export class SlicePackedProgram implements GPGPUProgram { variableNames = ['source']; @@ -73,7 +74,7 @@ export class SlicePackedProgram implements GPGPUProgram { void main() { ${dtype} coords = getOutputCoords(); ${dtype} sourceLoc; - ${sourceLocSetup} + ${sourceLocSetup} vec4 result = vec4(0.); ${upperRow} ${lowerRow} @@ -97,7 +98,7 @@ export class SlicePackedProgram implements GPGPUProgram { return; } } - gpgpu.gl.uniform1iv(this.startLoc, start); + getActiveContext().uniform1iv(this.startLoc, start); }; } -} \ No newline at end of file +} diff --git a/src/backends/webgl/webgl_context_manager.ts b/src/backends/webgl/webgl_context_manager.ts new file mode 100644 index 0000000000..20fce05806 --- /dev/null +++ b/src/backends/webgl/webgl_context_manager.ts @@ -0,0 +1,133 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENV} from '../../environment'; + +import {cleanupDOMCanvasWebGLRenderingContext, createDOMCanvasWebGLRenderingContext} from './canvas_util'; +import {callAndCheck, checkWebGLError} from './webgl_util'; + +let count = 0; +const contexts: {[key: string]: WebGLRenderingContext} = {}; +let contextFactory: (version: number) => WebGLRenderingContext = null; +let contextCleanup: (context: WebGLRenderingContext) => void = null; + +/** + * Sets the callback for creating new WebGLRenderingContext instances. + * @param factory The callback function that returns a context instance. + */ +export function setContextFactory( + factory: (version: number) => WebGLRenderingContext) { + contextFactory = factory; +} + +/** + * Sets the callback for cleaning up WebGLRenderingContext instances. + * @param cleanup The callback function to cleanup the passed in context + * instance. + */ +export function setContextCleanup( + cleanup: (context: WebGLRenderingContext) => void) { + contextCleanup = cleanup; +} + +/** + * Returns the current WebGLRenderingContext based on the ENV flag for + * 'WEBGL_VERSION'. + */ +export function getActiveContext(): WebGLRenderingContext { + return getContextByVersion(ENV.getNumber('WEBGL_VERSION')); +} + +/** + * Returns the WebGLRenderingContext for a given version number. + * @param version The specific version of WebGL to request. + */ +export function getContextByVersion(version: number): WebGLRenderingContext { + // Default to browser context creation is running in the browser. + if (contextFactory == null) { + if (ENV.getBool('IS_BROWSER')) { + // TODO(kreeger): Is there a better place to register this? + contextFactory = createDOMCanvasWebGLRenderingContext; + } else { + throw new Error('Default WebGLRenderingContext factory was not set!'); + } + } + + if (!(version in contexts)) { + contexts[version] = traceGLCalls(contextFactory(version), ++count); + bootstrapWebGLContext(contexts[version]); + checkWebGLError(contexts[version]); + } + + const gl = contexts[version]; + if (gl.isContextLost()) { + checkWebGLError(contexts[version]); + disposeWebGLContext(version); + return getContextByVersion(version); + } + checkWebGLError(contexts[version]); + return contexts[version]; +} + +function disposeWebGLContext(version: number) { + if ((version in contexts)) { + if (contextCleanup == null) { + if (ENV.getBool('IS_BROWSER')) { + // TODO(kreeger): Is there a better place to register this? + contextCleanup = cleanupDOMCanvasWebGLRenderingContext; + } + } + if (contextCleanup != null) { + contextCleanup(contexts[version]); + } + delete contexts[version]; + } +} + +function bootstrapWebGLContext(gl: WebGLRenderingContext) { + // TODO - check GL calls here too. + callAndCheck(gl, ENV.getBool('DEBUG'), () => gl.disable(gl.DEPTH_TEST)); + // gl.disable(gl.DEPTH_TEST); + gl.disable(gl.STENCIL_TEST); + gl.disable(gl.BLEND); + gl.disable(gl.DITHER); + gl.disable(gl.POLYGON_OFFSET_FILL); + gl.disable(gl.SAMPLE_COVERAGE); + gl.enable(gl.SCISSOR_TEST); + gl.enable(gl.CULL_FACE); + gl.cullFace(gl.BACK); +} + +function traceGLCalls(ctx: WebGLRenderingContext, idx: number) { + const handler = { + // tslint:disable-next-line:no-any + get(target: any, prop: PropertyKey, receiver: any): any { + const propValue = target[prop]; + + if (typeof (propValue) === 'function') { + console.log( + ' gl.' + prop.toString() + ' = ' + target.constructor.name + + ' ' + idx); + // tslint:disable-next-line:only-arrow-functions + return function() { + return propValue.apply(target, arguments); + }; + } + return propValue; + }, + }; + return new Proxy(ctx, handler); +} diff --git a/src/backends/webgl/webgl_context_manager_test.ts b/src/backends/webgl/webgl_context_manager_test.ts new file mode 100644 index 0000000000..fe0a23d571 --- /dev/null +++ b/src/backends/webgl/webgl_context_manager_test.ts @@ -0,0 +1,85 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util'; + +import {getActiveContext, getContextByVersion, setContextCleanup, setContextFactory} from './webgl_context_manager'; + +describeWithFlags('webgl_context_manager', BROWSER_ENVS, () => { + it('returns the active context for browser WebGL', () => { + const canvas = getActiveContext(); + expect( + (canvas instanceof WebGL2RenderingContext) || + (canvas instanceof WebGLRenderingContext)) + .toBe(true); + }); +}); + +describeWithFlags( + 'webgl_context_manager webgl2', {flags: {WEBGL_VERSION: 2}}, () => { + it('returns webgl1 canvas under webgl2', () => { + const canvas = getContextByVersion(1); + expect(canvas instanceof WebGLRenderingContext).toBe(true); + }); + }); + +describe('webgl_context_manager create/cleanup', () => { + afterAll(() => { + // Reset global context creation and cleanup: + setContextCleanup(null); + setContextFactory(null); + }); + + it('should call factory method to create WebGLRenderingContext', () => { + let created = false; + let cleanedup = false; + let contextLost = false; + const contextFake = { + disable: (cap: number) => {}, + enable: (cap: number) => {}, + cullFace: (cap: number) => {}, + isContextLost: () => { + return contextLost; + } + } as WebGLRenderingContext; + + setContextFactory((version: number) => { + created = true; + return contextFake; + }); + + // Request context version '10' to bypass any cached system WebGL versions: + const context = getContextByVersion(10); + expect(created).toBe(true); + expect(context).toBe(contextFake); + + // Mark fake context as disposed so it will be cleanedup on next context + // creation request: + setContextCleanup((context: WebGLRenderingContext) => { + expect(context).toBe(contextFake); + cleanedup = true; + + // Set context lost back to false to prevent an endless loop: + contextLost = false; + }); + + contextLost = true; + getContextByVersion(10); + + expect(cleanedup).toBe(true); + }); +}); diff --git a/src/backends/webgl/webgl_util.ts b/src/backends/webgl/webgl_util.ts index f35ce52082..b922c647d6 100644 --- a/src/backends/webgl/webgl_util.ts +++ b/src/backends/webgl/webgl_util.ts @@ -17,7 +17,8 @@ import {ENV} from '../../environment'; import * as util from '../../util'; -import {getWebGLContext} from './canvas_util'; + +import {getContextByVersion} from './webgl_context_manager'; export function callAndCheck( gl: WebGLRenderingContext, debugMode: boolean, func: () => T): T { @@ -28,7 +29,7 @@ export function callAndCheck( return returnValue; } -function checkWebGLError(gl: WebGLRenderingContext) { +export function checkWebGLError(gl: WebGLRenderingContext) { const error = gl.getError(); if (error !== gl.NO_ERROR) { throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error)); @@ -219,8 +220,9 @@ export function validateTextureSize(width: number, height: number) { export function createFramebuffer( gl: WebGLRenderingContext, debug: boolean): WebGLFramebuffer { + console.log(' --- is debug: ' + debug); return throwIfNull( - gl, debug, () => gl.createFramebuffer(), + gl, true, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.'); } @@ -501,7 +503,7 @@ export let MAX_TEXTURES_IN_SHADER: number; export function getWebGLMaxTextureSize(webGLVersion: number): number { if (MAX_TEXTURE_SIZE == null) { - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE); } return MAX_TEXTURE_SIZE; @@ -509,7 +511,7 @@ export function getWebGLMaxTextureSize(webGLVersion: number): number { export function getMaxTexturesInShader(webGLVersion: number): number { if (MAX_TEXTURES_IN_SHADER == null) { - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS); } // We cap at 16 to avoid spurious runtime "memory exhausted" error. @@ -523,7 +525,7 @@ export function getWebGLDisjointQueryTimerVersion(webGLVersion: number): } let queryTimerVersion: number; - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && webGLVersion === 2) { @@ -538,12 +540,18 @@ export function getWebGLDisjointQueryTimerVersion(webGLVersion: number): function hasExtension(gl: WebGLRenderingContext, extensionName: string) { const ext = gl.getExtension(extensionName); + try { + checkWebGLError(gl); + } catch (e) { + console.log('exception getting: ' + extensionName); + throw e; + } return ext != null; } export function isWebGLVersionEnabled(webGLVersion: 1|2) { try { - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); if (gl != null) { return true; } @@ -558,7 +566,7 @@ export function isRenderToFloatTextureEnabled(webGLVersion: number): boolean { return false; } - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); if (webGLVersion === 1) { if (!hasExtension(gl, 'OES_texture_float')) { @@ -580,7 +588,7 @@ export function isDownloadFloatTextureEnabled(webGLVersion: number): boolean { return false; } - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); if (webGLVersion === 1) { if (!hasExtension(gl, 'OES_texture_float')) { @@ -602,27 +610,35 @@ export function isDownloadFloatTextureEnabled(webGLVersion: number): boolean { function createFloatTextureAndBindToFramebuffer( gl: WebGLRenderingContext, webGLVersion: number): boolean { - const frameBuffer = gl.createFramebuffer(); - const texture = gl.createTexture(); + const debug = ENV.getBool('DEBUG'); + const frameBuffer = callAndCheck(gl, debug, () => gl.createFramebuffer()); + const texture = callAndCheck(gl, debug, () => gl.createTexture()); - gl.bindTexture(gl.TEXTURE_2D, texture); + callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture)); // tslint:disable-next-line:no-any const internalFormat = webGLVersion === 2 ? (gl as any).RGBA32F : gl.RGBA; - gl.texImage2D( - gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null); + callAndCheck( + gl, debug, + () => gl.texImage2D( + gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null)); - gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); - gl.framebufferTexture2D( - gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); + callAndCheck( + gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer)); + callAndCheck( + gl, debug, + () => gl.framebufferTexture2D( + gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0)); - const isFrameBufferComplete = - gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; + const isFrameBufferComplete = callAndCheck( + gl, debug, + () => gl.checkFramebufferStatus(gl.FRAMEBUFFER) === + gl.FRAMEBUFFER_COMPLETE); - gl.bindTexture(gl.TEXTURE_2D, null); - gl.bindFramebuffer(gl.FRAMEBUFFER, null); - gl.deleteTexture(texture); - gl.deleteFramebuffer(frameBuffer); + callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null)); + callAndCheck(gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); + callAndCheck(gl, debug, () => gl.deleteTexture(texture)); + callAndCheck(gl, debug, () => gl.deleteFramebuffer(frameBuffer)); return isFrameBufferComplete; } @@ -631,7 +647,7 @@ export function isWebGLFenceEnabled(webGLVersion: number) { if (webGLVersion !== 2) { return false; } - const gl = getWebGLContext(webGLVersion); + const gl = getContextByVersion(webGLVersion); // tslint:disable-next-line:no-any const isEnabled = (gl as any).fenceSync != null; diff --git a/src/debug_mode_test.ts b/src/debug_mode_test.ts index d79eca55c0..41a931d861 100644 --- a/src/debug_mode_test.ts +++ b/src/debug_mode_test.ts @@ -22,14 +22,16 @@ import {expectArraysClose} from './test_util'; describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => { beforeAll(() => { + console.log('--- setting debug to TRUE'); tf.ENV.set('DEBUG', true); }); afterAll(() => { + console.log('--- setting debug to FALSE'); tf.ENV.set('DEBUG', false); }); - it('debug mode does not error when no nans', async () => { + it('KREEGER debug mode does not error when no nans', async () => { const a = tf.tensor1d([2, -1, 0, 3]); const res = tf.relu(a); expectArraysClose(await res.data(), [2, 0, 0, 3]); diff --git a/src/webgl.ts b/src/webgl.ts index 183068da05..54386b4429 100644 --- a/src/webgl.ts +++ b/src/webgl.ts @@ -19,8 +19,8 @@ import * as gpgpu_util from './backends/webgl/gpgpu_util'; import * as webgl_util from './backends/webgl/webgl_util'; export {MathBackendWebGL, WebGLMemoryInfo, WebGLTimingInfo} from './backends/webgl/backend_webgl'; -export {setWebGLContext} from './backends/webgl/canvas_util'; export {GPGPUContext} from './backends/webgl/gpgpu_context'; export {GPGPUProgram} from './backends/webgl/gpgpu_math'; +export {getActiveContext, getContextByVersion, setContextCleanup, setContextFactory} from './backends/webgl/webgl_context_manager'; // WebGL specific utils. export {gpgpu_util, webgl_util};