Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5cb5951

Browse files
authored
Add math.multinomial and math.oneHot (#160)
* add multinomial * add cpu implementation and tests * add onehot operation * resolve comments
1 parent 7a64583 commit 5cb5951

File tree

9 files changed

+428
-5
lines changed

9 files changed

+428
-5
lines changed

package.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"devDependencies": {
1414
"@types/jasmine": "~2.5.53",
1515
"@types/polymer": "~1.1.31",
16+
"@types/seedrandom": "~2.4.27",
1617
"bower": "~1.8.0",
1718
"browserify": "~14.4.0",
1819
"cross-spawn": "~5.1.0",
@@ -37,5 +38,8 @@
3738
"build": "tsc",
3839
"test": "karma start",
3940
"lint": "tslint -p . --type-check -t verbose"
41+
},
42+
"dependencies": {
43+
"seedrandom": "~2.4.3"
4044
}
4145
}

src/math/math.ts

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ export abstract class NDArrayMath {
194194
return result;
195195
}
196196

197+
/** Disposes the math object and any resources used by it. */
198+
dispose() {}
199+
197200
/**
198201
* Computes the dot product of two matrices, A * B. These must be matrices,
199202
* use matrixTimesVector and vectorTimesMatrix, dotProduct, and outerProduct
@@ -270,9 +273,8 @@ export abstract class NDArrayMath {
270273
`rank ${matrix.rank}.`);
271274
util.assert(
272275
v.size === matrix.shape[0],
273-
`Error in vectorTimesMatrix: size of first rank 1 input (${v.size}) ` +
274-
`must match inner dimension of second rank 2 input, but got ` +
275-
`rank ${matrix.rank}.`);
276+
`Error in vectorTimesMatrix: size of vector (${v.size}) ` +
277+
`must match first dimension of matrix (${matrix.shape[0]})`);
276278

277279
return this.matMul(v.as2D(1, -1), matrix).as1D();
278280
}
@@ -1574,6 +1576,52 @@ export abstract class NDArrayMath {
15741576
});
15751577
return [res[0].as2D(1, -1), res[1].as2D(1, -1)];
15761578
}
1579+
1580+
/**
1581+
* Draws samples from a multinomial distribution.
1582+
*
1583+
* @param probabilities 1D array with normalized outcome probabilities.
1584+
* @param numSamples Number of samples to draw.
1585+
* @param seed Optional. The seed number.
1586+
*/
1587+
multinomial(probabilities: Array1D, numSamples: number, seed?: number):
1588+
Array1D {
1589+
const numOutcomes = probabilities.size;
1590+
if (numOutcomes < 2) {
1591+
throw new Error(
1592+
`Error in multinomial: you need at least 2 outcomes, but got ` +
1593+
`${numOutcomes}.`);
1594+
}
1595+
seed = seed || Math.random();
1596+
return this.executeOp(
1597+
'multinomial',
1598+
() => this.multinomialInternal(probabilities, numSamples, seed));
1599+
}
1600+
protected abstract multinomialInternal(
1601+
probabilities: Array1D, numSamples: number, seed: number): Array1D;
1602+
1603+
/**
1604+
* Returns a one-hot tensor. The locations represented by `indices` take
1605+
* value `onValue` (defaults to 1), while all other locations take value
1606+
* `offValue` (defaults to 0).
1607+
*
1608+
* @param indices 1D Array of indices.
1609+
* @param depth The depth of the one hot dimension.
1610+
* @param onValue A number used to fill in output when the index matches the
1611+
* location.
1612+
* @param offValue A number used to fill in the output when the index does not
1613+
* match the location.
1614+
*/
1615+
oneHot(indices: Array1D, depth: number, onValue = 1, offValue = 0): Array2D {
1616+
if (depth < 2) {
1617+
throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
1618+
}
1619+
return this.executeOp(
1620+
'oneHot', () => this.oneHotInternal(indices, depth, onValue, offValue));
1621+
}
1622+
protected abstract oneHotInternal(
1623+
indices: Array1D, depth: number, onValue: number,
1624+
offValue: number): Array2D;
15771625
}
15781626

15791627
export enum MatrixOrientation {

src/math/math_cpu.ts

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
* =============================================================================
1616
*/
1717

18+
import * as seedrandom from 'seedrandom';
1819
import * as util from '../util';
19-
2020
import * as concat_util from './concat_util';
2121
import * as conv_util from './conv_util';
2222
import {ConvInfo} from './conv_util';
@@ -31,7 +31,8 @@ export class NDArrayMathCPU extends NDArrayMath {
3131

3232
protected cloneInternal<T extends NDArray>(ndarray: T): T {
3333
return NDArray.make(
34-
ndarray.shape, {values: new Float32Array(ndarray.getValues())}) as T;
34+
ndarray.shape,
35+
{values: new Float32Array(ndarray.getValues())}) as T;
3536
}
3637

3738
protected slice1DInternal(input: Array1D, begin: number, size: number):
@@ -982,4 +983,48 @@ export class NDArrayMathCPU extends NDArrayMath {
982983
}
983984
return Array3D.make(x.shape, {values: outValues});
984985
}
986+
987+
protected multinomialInternal(
988+
probabilities: Array1D, numSamples: number, seed: number): Array1D {
989+
const probVals = probabilities.getValues();
990+
991+
// The cdf won't include the last event. It will be implicit if not other
992+
// event happened.
993+
const cdf = new Float32Array(probabilities.size - 1);
994+
cdf[0] = probVals[0];
995+
for (let event = 1; event < cdf.length; ++event) {
996+
cdf[event] = cdf[event - 1] + probVals[event];
997+
}
998+
999+
const random = seedrandom(seed.toString());
1000+
const res = new Float32Array(numSamples);
1001+
1002+
for (let i = 0; i < numSamples; ++i) {
1003+
const r = random();
1004+
1005+
// Assume last event happened by default.
1006+
res[i] = cdf.length;
1007+
1008+
for (let event = 0; event < cdf.length; event++) {
1009+
if (r < cdf[event]) {
1010+
res[i] = event;
1011+
break;
1012+
}
1013+
}
1014+
}
1015+
1016+
return Array1D.new(res);
1017+
}
1018+
1019+
protected oneHotInternal(
1020+
indices: Array1D, depth: number, onValue: number,
1021+
offValue: number): Array2D {
1022+
const res = new Float32Array(indices.size * depth);
1023+
res.fill(offValue);
1024+
1025+
for (let event = 0; event < indices.size; ++event) {
1026+
res[event * depth + indices.get(event)] = onValue;
1027+
}
1028+
return Array2D.new([indices.size, depth], res);
1029+
}
9851030
}

src/math/math_gpu.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ import {LogSumExpProgram} from './webgl/logsumexp_gpu';
3838
import {MaxPool2DBackpropProgram} from './webgl/max_pool_backprop_gpu';
3939
import {MinMaxProgram} from './webgl/minmax_gpu';
4040
import {MatMulProgram} from './webgl/mulmat_gpu';
41+
import {MultinomialProgram} from './webgl/multinomial_gpu';
42+
import {OneHotProgram} from './webgl/onehot_gpu';
4143
import {Pool2DProgram} from './webgl/pool_gpu';
4244
import {ReduceSumProgram} from './webgl/reducesum_gpu';
4345
import {ResizeBilinear3DProgram} from './webgl/resize_bilinear_gpu';
@@ -424,6 +426,20 @@ export class NDArrayMathGPU extends NDArrayMath {
424426
return this.compileAndRun(program, [x]);
425427
}
426428

429+
protected multinomialInternal(
430+
probs: Array1D, numSamples: number, seed: number): Array1D {
431+
const program = new MultinomialProgram(probs.size, numSamples);
432+
const customSetup = program.getCustomSetupFunc(seed);
433+
return this.compileAndRun(program, [probs], null, customSetup);
434+
}
435+
436+
protected oneHotInternal(
437+
indices: ndarray.Array1D, depth: number, onValue: number,
438+
offValue: number): ndarray.Array2D {
439+
const program = new OneHotProgram(indices.size, depth, onValue, offValue);
440+
return this.compileAndRun(program, [indices]);
441+
}
442+
427443
private getAndSaveBinary(key: string, getBinary: () => GPGPUBinary):
428444
GPGPUBinary {
429445
if (!(key in this.binaryCache)) {

src/math/multinomial_test.ts

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/**
2+
* @license
3+
* Copyright 2017 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {NDArrayMath} from './math';
19+
import {NDArrayMathCPU} from './math_cpu';
20+
import {NDArrayMathGPU} from './math_gpu';
21+
import {Array1D} from './ndarray';
22+
23+
function executeTests(mathFactory: () => NDArrayMath) {
24+
let math: NDArrayMath;
25+
26+
beforeEach(() => {
27+
math = mathFactory();
28+
math.startScope();
29+
});
30+
31+
afterEach(() => {
32+
math.endScope(null);
33+
math.dispose();
34+
});
35+
36+
it('Flip a fair coin and check bounds', () => {
37+
const probs = Array1D.new([0.5, 0.5]);
38+
const result = math.multinomial(probs, 100);
39+
expect(result.shape).toEqual([100]);
40+
const [min, max] = getBounds(result.getValues());
41+
expect(min >= 0 - 1e-4);
42+
expect(max <= 1 + 1e-4);
43+
});
44+
45+
it('Flip a two-sided coin with 100% of heads', () => {
46+
const probs = Array1D.new([1, 0]);
47+
const result = math.multinomial(probs, 100);
48+
expect(result.shape).toEqual([100]);
49+
const [min, max] = getBounds(result.getValues());
50+
expect(min).toBeCloseTo(0, 1e-4);
51+
expect(max).toBeCloseTo(0, 1e-4);
52+
});
53+
54+
it('Flip a two-sided coin with 100% of tails', () => {
55+
const probs = Array1D.new([0, 1]);
56+
const result = math.multinomial(probs, 100);
57+
expect(result.shape).toEqual([100]);
58+
const [min, max] = getBounds(result.getValues());
59+
expect(min).toBeCloseTo(1, 1e-4);
60+
expect(max).toBeCloseTo(1, 1e-4);
61+
});
62+
63+
it('Flip a single-sided coin throws error', () => {
64+
const probs = Array1D.new([1]);
65+
expect(() => math.multinomial(probs, 100)).toThrowError();
66+
});
67+
68+
it('Flip a ten-sided coin and check bounds', () => {
69+
const n = 10;
70+
const probs = Array1D.zeros([n]);
71+
for (let i = 0; i < n; ++i) {
72+
probs.set(1 / n, i);
73+
}
74+
const result = math.multinomial(probs, 100);
75+
expect(result.shape).toEqual([100]);
76+
const [min, max] = getBounds(result.getValues());
77+
expect(min >= 0 - 1e-4);
78+
expect(max <= 9 + 1e-4);
79+
});
80+
81+
function getBounds(a: Float32Array) {
82+
let min = Number.MAX_VALUE;
83+
let max = Number.MIN_VALUE;
84+
85+
for (let i = 0; i < a.length; ++i) {
86+
min = Math.min(min, a[i]);
87+
max = Math.max(max, a[i]);
88+
}
89+
return [min, max];
90+
}
91+
}
92+
93+
describe('mathCPU multinomial', () => {
94+
executeTests(() => new NDArrayMathCPU());
95+
});
96+
97+
describe('mathGPU multinomial', () => {
98+
executeTests(() => new NDArrayMathGPU());
99+
});

src/math/onehot_test.ts

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/**
2+
* @license
3+
* Copyright 2017 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as test_util from '../test_util';
19+
import {NDArrayMath} from './math';
20+
import {NDArrayMathCPU} from './math_cpu';
21+
import {NDArrayMathGPU} from './math_gpu';
22+
import {Array1D} from './ndarray';
23+
24+
function executeTests(mathFactory: () => NDArrayMath) {
25+
let math: NDArrayMath;
26+
27+
beforeEach(() => {
28+
math = mathFactory();
29+
math.startScope();
30+
});
31+
32+
afterEach(() => {
33+
math.endScope(null);
34+
math.dispose();
35+
});
36+
37+
it('Depth 1 throws error', () => {
38+
const indices = Array1D.new([0, 0, 0]);
39+
expect(() => math.oneHot(indices, 1)).toThrowError();
40+
});
41+
42+
it('Depth 2, diagonal', () => {
43+
const indices = Array1D.new([0, 1]);
44+
const res = math.oneHot(indices, 2);
45+
const expected = new Float32Array([1, 0, 0, 1]);
46+
expect(res.shape).toEqual([2, 2]);
47+
test_util.expectArraysClose(res.getValues(), expected);
48+
});
49+
50+
it('Depth 2, transposed diagonal', () => {
51+
const indices = Array1D.new([1, 0]);
52+
const res = math.oneHot(indices, 2);
53+
const expected = new Float32Array([0, 1, 1, 0]);
54+
expect(res.shape).toEqual([2, 2]);
55+
test_util.expectArraysClose(res.getValues(), expected);
56+
});
57+
58+
it('Depth 3, 4 events', () => {
59+
const indices = Array1D.new([2, 1, 2, 0]);
60+
const res = math.oneHot(indices, 3);
61+
const expected = new Float32Array([0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0]);
62+
expect(res.shape).toEqual([4, 3]);
63+
test_util.expectArraysClose(res.getValues(), expected);
64+
});
65+
66+
it('Depth 2 onValue=3, offValue=-2', () => {
67+
const indices = Array1D.new([0, 1]);
68+
const res = math.oneHot(indices, 2, 3, -2);
69+
const expected = new Float32Array([3, -2, -2, 3]);
70+
expect(res.shape).toEqual([2, 2]);
71+
test_util.expectArraysClose(res.getValues(), expected);
72+
});
73+
}
74+
75+
describe('mathCPU oneHot', () => {
76+
executeTests(() => new NDArrayMathCPU());
77+
});
78+
79+
describe('mathGPU oneHot', () => {
80+
executeTests(() => new NDArrayMathGPU());
81+
});

0 commit comments

Comments
 (0)