This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 61cdbe9

Lewuathe authored and Nikhil Thorat committed
Add Adam optimizer (#170)
* Add Adam optimizer see: https://arxiv.org/abs/1412.6980
* Merge branch 'master' into adam-optimizer
* Merge branch 'master' into adam-optimizer
1 parent c7a57b9 commit 61cdbe9
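
For reference, the update rule this commit implements is the standard Adam rule from the paper linked above, with g_t the gradient, m and v the first- and second-moment accumulators, and the accB1/accB2 fields below tracking the beta powers used for bias correction:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
    \theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)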

File tree: 3 files changed, +191 −0 lines changed
src/graph/optimizers/adam_optimizer.ts

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {NDArrayMath} from '../../math/math';
import {NDArray, Scalar} from '../../math/ndarray';
import {Node} from '../graph';
import {SessionRuntime} from '../session';
import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map';

import {Optimizer} from './optimizer';

export class AdamOptimizer extends Optimizer {
  constructor(
      protected learningRate: number,
      private beta1: number, private beta2: number,
      specifiedVariableList?: Node[]) {
    super(learningRate, specifiedVariableList);
    this.eps = Scalar.new(1e-8);
    // b1 and b2 hold the initial values of the beta1/beta2 hyperparameters.
    this.b1 = Scalar.new(this.beta1);
    this.b2 = Scalar.new(this.beta2);
    // accB1 and accB2 are updated after every batch.
    this.accB1 = Scalar.new(this.beta1);
    this.accB2 = Scalar.new(this.beta2);
  }

  beforeBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    super.beforeBatch(
        math, batchSize, runtime, activationArrayMap, gradientArrayMap);

    if (this.firstMoment.size() === 0) {
      this.variableNodes.forEach(node => {
        this.firstMoment.set(node.output, NDArray.zeros(node.output.shape));
      });
    }

    if (this.secondMoment.size() === 0) {
      this.variableNodes.forEach(node => {
        this.secondMoment.set(node.output, NDArray.zeros(node.output.shape));
      });
    }
  }

  afterBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    math.scope((keep) => {
      this.variableNodes.forEach(node => {
        const oldVariable = activationArrayMap.get(node.output);
        const gradient = this.variableGradients.get(node.output);

        const oldFirstMoment = this.firstMoment.get(node.output);
        const oldSecondMoment = this.secondMoment.get(node.output);

        const newFirstMoment = math.scaledArrayAdd(
            this.b1, oldFirstMoment, math.sub(this.one, this.b1), gradient);
        const gradientSquare = math.multiply(gradient, gradient);
        const newSecondMoment = math.scaledArrayAdd(
            this.b2, oldSecondMoment, math.sub(this.one, this.b2),
            gradientSquare);

        const biasCorrectedFirstMoment = math.divide(
            newFirstMoment, math.sub(this.one, this.accB1));
        const biasCorrectedSecondMoment = math.divide(
            newSecondMoment, math.sub(this.one, this.accB2));

        // this.c (the negated learning rate) and this.one (the constant 1)
        // come from the Optimizer base class.
        const variable = math.scaledArrayAdd(
            this.c, math.divide(biasCorrectedFirstMoment,
                math.add(math.sqrt(biasCorrectedSecondMoment), this.eps)),
            this.one, oldVariable);
        activationArrayMap.set(node.output, keep(variable));
        node.data = variable;

        this.firstMoment.set(node.output, keep(newFirstMoment));
        this.secondMoment.set(node.output, keep(newSecondMoment));

        oldVariable.dispose();
        gradient.dispose();
        oldFirstMoment.dispose();
        oldSecondMoment.dispose();
      });
      // accB1 and accB2 hold beta1 and beta2 raised to
      // the power t (the iteration count).
      this.accB1 = keep(math.multiply(this.accB1, this.b1));
      this.accB2 = keep(math.multiply(this.accB2, this.b2));
    });

    this.variableGradients.dispose();
    this.variableGradients = new TensorArrayMap();
  }

  dispose() {
    super.dispose();
    this.firstMoment.dispose();
    this.secondMoment.dispose();
    this.eps.dispose();
    this.b1.dispose();
    this.b2.dispose();
    this.accB1.dispose();
    this.accB2.dispose();
  }

  // Running average of the gradient (first moment).
  private firstMoment = new TensorArrayMap();
  // Running average of the squared gradient (second moment).
  private secondMoment = new TensorArrayMap();
  private eps: Scalar;
  private b1: Scalar;
  private b2: Scalar;
  private accB1: Scalar;
  private accB2: Scalar;
}
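
To make the NDArray arithmetic in afterBatch easier to check by hand, here is a minimal plain-number sketch of the same per-variable step (illustrative only; adamStep is a hypothetical helper, not part of the library):

function adamStep(
    w: number[], grad: number[], m: number[], v: number[],
    lr: number, beta1: number, beta2: number, t: number,
    eps = 1e-8): {w: number[], m: number[], v: number[]} {
  // Exponential moving averages of the gradient and the squared gradient.
  const newM = m.map((mi, i) => beta1 * mi + (1 - beta1) * grad[i]);
  const newV = v.map((vi, i) => beta2 * vi + (1 - beta2) * grad[i] * grad[i]);
  const newW = w.map((wi, i) => {
    // Bias correction: divide by 1 - beta^t, as afterBatch does with accB1/accB2.
    const mHat = newM[i] / (1 - Math.pow(beta1, t));
    const vHat = newV[i] / (1 - Math.pow(beta2, t));
    return wi - lr * mHat / (Math.sqrt(vHat) + eps);
  });
  return {w: newW, m: newM, v: newV};
}

With grad = [2, 4], beta1 = 0.8, beta2 = 0.9, lr = 0.1 and t = 1, this reproduces the first assertion in the test below: the weights move from [0, 0] to [-0.1, -0.1].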

src/graph/session_test.ts

Lines changed: 60 additions & 0 deletions
@@ -27,6 +27,7 @@ import {MomentumOptimizer} from './optimizers/momentum_optimizer';
 import {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
 import {SGDOptimizer} from './optimizers/sgd_optimizer';
 import {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
+import {AdamOptimizer} from './optimizers/adam_optimizer';
 import {FeedDictionary, FeedEntry, Session} from './session';

 describe('FeedDictionary', () => {
@@ -498,4 +499,63 @@ describe('Session', () => {
       dydw2, new Float32Array([-.4, -.8]), 2e-5);
     });
   });
+
+  it('adam', () => {
+    const x = g.placeholder('x', [2]);
+    const w = g.variable('w', NDArray.zeros([1, 2]));
+    const b = g.variable('b', NDArray.zeros([1]));
+    const y = g.reduceSum(g.add(g.matmul(w, x), b));
+
+    const safeMode = true;
+    const optimizer = new AdamOptimizer(0.1, 0.8, 0.9);
+    const math = new NDArrayMathCPU(safeMode);
+    const session = new Session(g, math);
+    const inputProvider: InputProvider = {
+      getNextCopy() {
+        return Array1D.new([2, 4]);
+      },
+      disposeCopy(math, example) {}
+    };
+
+    math.scope(() => {
+      // y = reduce_sum(w_1*x_1 + w_2*x_2 + b)
+      // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
+      //                beta1*old_first_m_w2 + (1-beta1)*grad_w2]
+      //             = [.4, .8]
+      // new_second_m = [beta2*old_second_m_w1 + (1-beta2)*grad_w1**2,
+      //                 beta2*old_second_m_w2 + (1-beta2)*grad_w2**2]
+      //              = [.4, 1.6]
+      // m = [new_first_m/(1-acc_beta1)] = [2, 4]
+      // v = [new_second_m/(1-acc_beta2)] = [4, 16]
+      // updates = [m_1/(sqrt(v_1) + eps),
+      //            m_2/(sqrt(v_2) + eps)]
+      //         = [1.0, 1.0]
+      // w = [w1_old - lr*updates_1, w2_old - lr*updates_2]
+      //   = [-0.1, -0.1]
+      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
+      const dydw = session.activationArrayMap.get(w).getValues();
+      test_util.expectArraysClose(
+          dydw, new Float32Array([-0.1, -0.1]), 1e-5);
+      // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
+      //                beta1*old_first_m_w2 + (1-beta1)*grad_w2]
+      //             = [0.8*0.4 + 0.2*2, 0.8*0.8 + 0.2*4]
+      //             = [0.72, 1.44]
+      // new_second_m = [beta2*old_second_m_w1 + (1-beta2)*grad_w1**2,
+      //                 beta2*old_second_m_w2 + (1-beta2)*grad_w2**2]
+      //              = [0.9*0.4 + 0.1*4, 0.9*1.6 + 0.1*16]
+      //              = [0.76, 3.04]
+      // m = [new_first_m/(1-acc_beta1)] = [2, 4]
+      // v = [new_second_m/(1-acc_beta2)] = [4, 16]
+      // updates = [m_1/(sqrt(v_1) + eps),
+      //            m_2/(sqrt(v_2) + eps)]
+      //         = [1.0, 1.0]
+      // w = [w1_old - lr*updates_1, w2_old - lr*updates_2]
+      //   = [-0.2, -0.2]
+      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
+      const dydw2 = session.activationArrayMap.get(w).getValues();
+      test_util.expectArraysClose(
+          dydw2, new Float32Array([-.2, -.2]), 2e-5);
+    });
+  });
 });
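
Continuing the hypothetical adamStep sketch from above, two consecutive calls reproduce both assertions in this test:

let state = {w: [0, 0], m: [0, 0], v: [0, 0]};
state = adamStep(state.w, [2, 4], state.m, state.v, 0.1, 0.8, 0.9, 1);
// state.w is now approximately [-0.1, -0.1]
state = adamStep(state.w, [2, 4], state.m, state.v, 0.1, 0.8, 0.9, 2);
// state.w is now approximately [-0.2, -0.2]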

src/index.ts

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ export {MomentumOptimizer} from './graph/optimizers/momentum_optimizer';
 export {Optimizer} from './graph/optimizers/optimizer';
 export {RMSPropOptimizer} from './graph/optimizers/rmsprop_optimizer';
 export {SGDOptimizer} from './graph/optimizers/sgd_optimizer';
+export {AdamOptimizer} from './graph/optimizers/adam_optimizer';
 export {CostReduction, FeedEntry, Session} from './graph/session';
 // tslint:disable-next-line:max-line-length
 export {GraphRunner, GraphRunnerEventObserver, MetricReduction} from './graph_runner';
