Skip to content

Commit bec2196

Browse files
authored
[layers] Fix bug: LayerNormalization registered as BatchNormalization (#2174)
erroneously. Fixes #2170 BUG
1 parent 3833990 commit bec2196

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

tfjs-layers/src/layers/normalization.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ export interface LayerNormalizationLayerArgs extends LayerArgs {
472472

473473
export class LayerNormalization extends Layer {
474474
/** @nocollapse */
475-
static className = 'BatchNormalization';
475+
static className = 'LayerNormalization';
476476

477477
private axis: number|number[];
478478
readonly epsilon: number;

tfjs-layers/src/layers/normalization_test.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,17 @@ describe('LayerNormalization Layer: Symbolic', () => {
692692
const layerPrime = tfl.layers.layerNormalization(tsConfig);
693693
expect(layerPrime.getConfig()).toEqual(layer.getConfig());
694694
});
695+
696+
it('Deserialize model with BatchNorm Layer', async () => {
697+
// tslint:disable:max-line-length
698+
const modelJSONString =
699+
`{"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "Dense", "config": {"name": "dense", "trainable": true, "batch_input_shape": [null, 5], "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "BatchNormalization", "config": {"name": "batch_normalization", "trainable": true, "dtype": "float32", "axis": [1], "momentum": 0.99, "epsilon": 0.001, "center": true, "scale": true, "beta_initializer": {"class_name": "Zeros", "config": {}}, "gamma_initializer": {"class_name": "Ones", "config": {}}, "moving_mean_initializer": {"class_name": "Zeros", "config": {}}, "moving_variance_initializer": {"class_name": "Ones", "config": {}}, "beta_regularizer": null, "gamma_regularizer": null, "beta_constraint": null, "gamma_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.2.4-tf", "backend": "tensorflow"}`;
700+
// tslint:enable:max-line-length
701+
const model = await tfl.models.modelFromJSON(JSON.parse(modelJSONString));
702+
const ys = model.predict(zeros([3, 5])) as Tensor;
703+
expect(ys.shape).toEqual([3, 1]);
704+
expect(model.layers[1].getWeights().length).toEqual(4);
705+
});
695706
});
696707

697708
describeMathCPUAndGPU('LayerNormalization Layer: Tensor', () => {

0 commit comments

Comments
 (0)