diff --git a/src/math/math.ts b/src/math/math.ts index 55d3428574..f09f72c95d 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -826,6 +826,15 @@ export abstract class NDArrayMath { } protected abstract sinInternal(ndarray: T): T; + /** + * Computes atan of the input NDArray element-wise, y = atan(x). + * @param ndarray The input NDArray. + */ + atan(ndarray: T): T { + return this.executeOp('atan', () => this.atanInternal(ndarray)); + } + protected abstract atanInternal(ndarray: T): T; + /** * Computes step of the input NDArray element-wise, y = 1 if x > 0 | 0 if x <= * 0 diff --git a/src/math/math_cpu.ts b/src/math/math_cpu.ts index 207f5e3a59..27372e31a4 100644 --- a/src/math/math_cpu.ts +++ b/src/math/math_cpu.ts @@ -364,6 +364,15 @@ export class NDArrayMathCPU extends NDArrayMath { return NDArray.make(ndarray.shape, {values: resultValues}); } + protected atanInternal(ndarray: T): T { + const resultValues = new Float32Array(ndarray.size); + const values = ndarray.getValues(); + for (let i = 0; i < values.length; ++i) { + resultValues[i] = Math.atan(values[i]); + } + return NDArray.make(ndarray.shape, {values: resultValues}); + } + protected stepInternal(ndarray: T): T { const resultValues = new Float32Array(ndarray.size); const values = ndarray.getValues(); diff --git a/src/math/math_cpu_test.ts b/src/math/math_cpu_test.ts index b913729a65..c4ef8a7b16 100644 --- a/src/math/math_cpu_test.ts +++ b/src/math/math_cpu_test.ts @@ -914,6 +914,20 @@ describe('NDArrayMathCPU unary ops', () => { const expected = [Math.sin(4), NaN, Math.sin(0)]; expect(res).toEqual(new Float32Array(expected)); }); + + it('atan', () => { + const a = Array1D.new([4, -3, 0]); + const res = math.atan(a).getValues(); + const expected = [Math.atan(4), Math.atan(-3), Math.atan(0)]; + expect(res).toEqual(new Float32Array(expected)); + }); + + it('atan propagates NaNs', () => { + const a = Array1D.new([4, NaN, 0]); + const res = math.atan(a).getValues(); + const expected = [Math.atan(4), NaN, Math.atan(0)]; + expect(res).toEqual(new Float32Array(expected)); + }); }); describe('NDArrayMathCPU scalar OP ndarray', () => { diff --git a/src/math/math_gpu.ts b/src/math/math_gpu.ts index e11aaab27b..43d75a24e0 100644 --- a/src/math/math_gpu.ts +++ b/src/math/math_gpu.ts @@ -276,6 +276,11 @@ export class NDArrayMathGPU extends NDArrayMath { return this.compileAndRun(program, [a]); } + protected atanInternal(a: T): T { + const program = new UnaryOpProgram(a.shape, UnaryOp.ATAN); + return this.compileAndRun(program, [a]); + } + protected stepInternal(a: T): T { const program = new UnaryOpProgram(a.shape, UnaryOp.STEP); return this.compileAndRun(program, [a]); diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts index e6d334f4b7..9b4e6a1908 100644 --- a/src/math/math_gpu_test.ts +++ b/src/math/math_gpu_test.ts @@ -991,6 +991,27 @@ describe('NDArrayMathGPU unary ops', () => { test_util.expectArraysClose(res, new Float32Array(expected), 1e-4); a.dispose(); }); + + it('atan', () => { + const values = [1, -3, 2, 7, -4]; + const a = Array1D.new(values); + const result = math.atan(a); + const expected = new Float32Array(a.size); + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atan(values[i]); + } + test_util.expectArraysClose(result.getValues(), expected, 1e-3); + + a.dispose(); + }); + + it('atan propagates NaNs', () => { + const a = Array1D.new([4, NaN, 0]); + const res = math.atan(a).getValues(); + const expected = [Math.atan(4), NaN, Math.atan(0)]; + test_util.expectArraysClose(res, new Float32Array(expected), 1e-4); + a.dispose(); + }); }); describe('NDArrayMathGPU min/max', () => { diff --git a/src/math/webgl/unaryop_gpu.ts b/src/math/webgl/unaryop_gpu.ts index ec82eb646a..08acadf4f4 100644 --- a/src/math/webgl/unaryop_gpu.ts +++ b/src/math/webgl/unaryop_gpu.ts @@ -18,7 +18,7 @@ import {GPGPUProgram} from './gpgpu_math'; export enum UnaryOp { - EXP, LOG, SQRT, NEG, RELU, SIGMOID, STEP, SIN, TANH + EXP, LOG, SQRT, NEG, RELU, SIGMOID, STEP, SIN, TANH, ATAN } export class UnaryOpProgram implements GPGPUProgram { @@ -70,6 +70,9 @@ function getOpSnippet(op: UnaryOp) { case UnaryOp.TANH: return `float e2x = exp(-2.0 * abs(v)); float r = sign(v) * (1.0 - e2x) / (1.0 + e2x);`; + case UnaryOp.ATAN: + return CHECK_NAN_SNIPPET + + 'float r = atan(v);'; default: throw Error('Unrecognized unary op type ' + op); }