Skip to content

Commit

Permalink
[layers] Fix bug: LayerNormalization registered as BatchNormalization
Browse files Browse the repository at this point in the history
errorneously.

Fixes tensorflow#2170
  • Loading branch information
caisq committed Oct 8, 2019
1 parent 3833990 commit f04ace5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tfjs-layers/src/layers/normalization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ export interface LayerNormalizationLayerArgs extends LayerArgs {

export class LayerNormalization extends Layer {
/** @nocollapse */
static className = 'BatchNormalization';
static className = 'LayerNormalization';

private axis: number|number[];
readonly epsilon: number;
Expand Down
11 changes: 11 additions & 0 deletions tfjs-layers/src/layers/normalization_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,17 @@ describe('LayerNormalization Layer: Symbolic', () => {
const layerPrime = tfl.layers.layerNormalization(tsConfig);
expect(layerPrime.getConfig()).toEqual(layer.getConfig());
});

it('Deserialize model with BatchNorm Layer', async () => {
// tslint:disable:max-line-length
const modelJSONString =
`{"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "Dense", "config": {"name": "dense", "trainable": true, "batch_input_shape": [null, 5], "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "BatchNormalization", "config": {"name": "batch_normalization", "trainable": true, "dtype": "float32", "axis": [1], "momentum": 0.99, "epsilon": 0.001, "center": true, "scale": true, "beta_initializer": {"class_name": "Zeros", "config": {}}, "gamma_initializer": {"class_name": "Ones", "config": {}}, "moving_mean_initializer": {"class_name": "Zeros", "config": {}}, "moving_variance_initializer": {"class_name": "Ones", "config": {}}, "beta_regularizer": null, "gamma_regularizer": null, "beta_constraint": null, "gamma_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.2.4-tf", "backend": "tensorflow"}`;
// tslint:enable:max-line-length
const model = await tfl.models.modelFromJSON(JSON.parse(modelJSONString));
const ys = model.predict(zeros([3, 5])) as Tensor;
expect(ys.shape).toEqual([3, 1]);
expect(model.layers[1].getWeights().length).toEqual(4);
});
});

describeMathCPUAndGPU('LayerNormalization Layer: Tensor', () => {
Expand Down

0 comments on commit f04ace5

Please sign in to comment.