Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit c7a57b9

Browse files
SayHelloToWorldNikhil Thorat
authored andcommitted
Fix AdagradOptimizer bug (#183)
* Update session_test.ts * Update adagrad_optimizer.ts * Update adagrad_optimizer.ts * Update model-builder.ts
1 parent f4b135c commit c7a57b9

File tree

3 files changed

+3
-8
lines changed

3 files changed

+3
-8
lines changed

demos/model-builder/model-builder.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ export class ModelBuilder extends ModelBuilderPolymer {
369369
break;
370370
}
371371
case "adagrad": {
372-
this.needMomentum = true;
373372
break;
374373
}
375374
default: {
@@ -390,7 +389,7 @@ export class ModelBuilder extends ModelBuilderPolymer {
390389
return new RMSPropOptimizer(+this.learningRate, +this.gamma);
391390
}
392391
case 'adagrad': {
393-
return new AdagradOptimizer(+this.learningRate, +this.momentum);
392+
return new AdagradOptimizer(+this.learningRate);
394393
}
395394
default: {
396395
throw new Error(`Unknown optimizer "${this.selectedOptimizerName}"`);

src/graph/optimizers/adagrad_optimizer.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@ import {Optimizer} from './optimizer';
2424

2525
export class AdagradOptimizer extends Optimizer {
2626
constructor(
27-
protected learningRate: number, protected momentum: number,
28-
specifiedVariableList?: Node[]) {
27+
protected learningRate: number, specifiedVariableList?: Node[]) {
2928
super(learningRate, specifiedVariableList);
30-
this.m = Scalar.new(momentum);
3129
this.eps = Scalar.new(1e-6);
3230
}
3331

@@ -74,12 +72,10 @@ export class AdagradOptimizer extends Optimizer {
7472

7573
dispose() {
7674
super.dispose();
77-
this.m.dispose();
7875
this.eps.dispose();
7976
this.accumulatedSquaredGradients.dispose();
8077
}
8178

8279
private accumulatedSquaredGradients = new TensorArrayMap();
83-
private m: Scalar;
8480
private eps: Scalar;
8581
}

src/graph/session_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ describe('Session', () => {
354354
const y = g.reduceSum(g.add(g.matmul(w, x), b));
355355

356356
const safeMode = true;
357-
const optimizer = new AdagradOptimizer(0.1, 0.5);
357+
const optimizer = new AdagradOptimizer(0.1);
358358
const math = new NDArrayMathCPU(safeMode);
359359
const session = new Session(g, math);
360360
const inputProvider: InputProvider = {

0 commit comments

Comments
 (0)