1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
@@ -83,6 +83,7 @@
from keras.src.layers.normalization.layer_normalization import (
LayerNormalization,
)
from keras.src.layers.normalization.rms_normalization import RMSNormalization
from keras.src.layers.normalization.spectral_normalization import (
SpectralNormalization,
)
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -91,6 +91,7 @@
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import rms_normalization
from keras.src.ops.nn import selu
from keras.src.ops.nn import separable_conv
from keras.src.ops.nn import sigmoid
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -35,6 +35,7 @@
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import rms_normalization
from keras.src.ops.nn import selu
from keras.src.ops.nn import separable_conv
from keras.src.ops.nn import sigmoid
1 change: 1 addition & 0 deletions keras/api/layers/__init__.py
@@ -83,6 +83,7 @@
from keras.src.layers.normalization.layer_normalization import (
LayerNormalization,
)
from keras.src.layers.normalization.rms_normalization import RMSNormalization
from keras.src.layers.normalization.spectral_normalization import (
SpectralNormalization,
)
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -91,6 +91,7 @@
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import rms_normalization
from keras.src.ops.nn import selu
from keras.src.ops.nn import separable_conv
from keras.src.ops.nn import sigmoid
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -35,6 +35,7 @@
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import rms_normalization
from keras.src.ops.nn import selu
from keras.src.ops.nn import separable_conv
from keras.src.ops.nn import sigmoid
1 change: 1 addition & 0 deletions keras/src/layers/__init__.py
@@ -57,6 +57,7 @@
from keras.src.layers.normalization.layer_normalization import (
LayerNormalization,
)
from keras.src.layers.normalization.rms_normalization import RMSNormalization
from keras.src.layers.normalization.spectral_normalization import (
SpectralNormalization,
)
4 changes: 3 additions & 1 deletion keras/src/layers/normalization/layer_normalization.py
@@ -86,7 +86,9 @@ class LayerNormalization(Layer):
rms_scaling: If True, `center` and `scale` are ignored, and the
inputs are scaled by `gamma` and the inverse square root
of the square of all inputs. This is an approximate and faster
approach that avoids ever computing the mean of the input.
approach that avoids ever computing the mean of the input. Note that
this *isn't* equivalent to the computation that the
`keras.layers.RMSNormalization` layer performs.
beta_initializer: Initializer for the beta weight. Defaults to zeros.
gamma_initializer: Initializer for the gamma weight. Defaults to ones.
beta_regularizer: Optional regularizer for the beta weight.
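Editorial note (not part of the diff): a minimal sketch contrasting `LayerNormalization(rms_scaling=True)` with the new `RMSNormalization` layer, assuming the `RMSNormalization` API introduced below; it makes no claim about the exact `rms_scaling` formula beyond the docstring's warning that the two computations are not equivalent.

```python
import numpy as np
import keras

x = np.random.rand(2, 8).astype("float32")

# LayerNormalization with rms_scaling=True: the approximate variant noted above.
ln_rms = keras.layers.LayerNormalization(rms_scaling=True)
# The new RMSNormalization layer: x * rsqrt(mean(square(x), axis=-1) + eps) * scale.
rms = keras.layers.RMSNormalization()

a = keras.ops.convert_to_numpy(ln_rms(x))
b = keras.ops.convert_to_numpy(rms(x))
# Per the docstring note above, these outputs are not expected to match exactly.
print(np.allclose(a, b, atol=1e-5))
```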
98 changes: 98 additions & 0 deletions keras/src/layers/normalization/rms_normalization.py
@@ -0,0 +1,98 @@
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer


@keras_export("keras.layers.RMSNormalization")
class RMSNormalization(Layer):
"""Root Mean Square (RMS) Normalization layer.

This layer normalizes the input tensor based on its RMS value.

The Keras layer performs the operation as described in
[Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
by Biao Zhang et al.


If `scale` is enabled, the layer will scale the normalized outputs via
a learnable scaling factor.

With scaling enabled, the normalization equation
is as follows:

Let the intermediate activations for a mini-batch be the `inputs`.

```python
rms_normalization(x) = x * rsqrt(mean(square(x))) * scale
```

For example:

>>> layer = keras.layers.RMSNormalization()
>>> layer.build([5, 20, 30, 10])
>>> print(layer.scale.shape)
(10,)
>>> layer(np.random.rand(1, 10)).numpy()
array([[0.35098287, 1.0495652 , 1.4645109 , 1.2944688 , 0.31124955,
1.2768592 , 1.184331 , 0.17474432, 0.49955517, 1.2428929 ]],
dtype=float32)

Args:
axis: int. The axis on which to perform the normalization.
epsilon: float. A small number to add to avoid division by zero.
"""

def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
super().__init__(**kwargs)
self.axis = axis
self.epsilon = epsilon

def build(self, input_shape):
if isinstance(self.axis, list):
shape = tuple([input_shape[dim] for dim in self.axis])
else:
shape = (input_shape[self.axis],)
self.axis = [self.axis]

self.scale = self.add_weight(
name="scale", shape=shape, initializer="ones"
)

self.built = True

def call(self, x):
"""Applies RMS normalization to the input tensor.

Args:
x: Input tensor of shape (batch_size, input_dim).

Returns:
The RMS-normalized tensor of the same shape (batch_size, input_dim),
scaled by the learned `scale` parameter.
"""
return ops.rms_normalization(
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
)

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def get_config(self):
config = {
"axis": self.axis,
"epsilon": self.epsilon,
}
base_config = super().get_config()
return {**base_config, **config}
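A short usage sketch of the layer defined above (editorial illustration, not part of the diff); the shapes follow the `build` logic shown in this file.

```python
import numpy as np
import keras

layer = keras.layers.RMSNormalization(axis=-1, epsilon=1e-6)
x = np.random.rand(4, 10).astype("float32")
y = layer(x)  # build() creates `scale` with shape (10,); call() normalizes along axis -1

print(layer.scale.shape)  # (10,)
print(y.shape)  # (4, 10)
```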
69 changes: 69 additions & 0 deletions keras/src/layers/normalization/rms_normalization_test.py
@@ -0,0 +1,69 @@
import numpy as np
import pytest

from keras.src import layers
from keras.src import ops
from keras.src import testing


class RMSNormalizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_ln_basics(self):
self.run_layer_test(
layers.RMSNormalization,
init_kwargs={},
input_shape=(4, 2),
expected_output_shape=(4, 2),
expected_num_trainable_weights=1,
expected_num_seed_generators=0,
)
self.run_layer_test(
layers.RMSNormalization,
init_kwargs={
"axis": -1,
},
input_shape=(4, 2),
expected_output_shape=(4, 2),
expected_num_trainable_weights=1,
expected_num_seed_generators=0,
)

def test_correctness(self):
layer = layers.RMSNormalization()
layer.build(input_shape=(2, 2, 2))
inputs = np.random.normal(
loc=5.0, scale=10.0, size=(1000, 2, 2, 2)
).astype("float32")

inputs = ops.convert_to_tensor(inputs)

out = layer(inputs)
expected = (
inputs
* ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True))
* layer.scale
)

self.assertAllClose(out, expected, atol=1e-1)

def test_output(self):
layer = layers.RMSNormalization()
inputs = np.arange(10).astype("float32")[None, :]
out = layer(inputs)
self.assertAllClose(
out,
[
[
0.0,
0.18731716,
0.37463433,
0.5619515,
0.74926865,
0.9365858,
1.123903,
1.3112202,
1.4985373,
1.6858544,
]
],
)
77 changes: 77 additions & 0 deletions keras/src/ops/nn.py
@@ -10,6 +10,7 @@
from keras.src.backend.common.backend_utils import (
compute_conv_transpose_output_shape,
)
from keras.src.backend.common.keras_tensor import is_keras_tensor
from keras.src.ops import operation_utils
from keras.src.ops.operation import Operation
from keras.src.ops.operation_utils import reduce_shape
@@ -2653,6 +2654,82 @@ def dot_product_attention(
)


class RMSNorm(Operation):
def __init__(self, scale, axis=-1, epsilon=None):
super().__init__()
self.axis = axis
self.scale = scale
self.epsilon = epsilon

def compute_output_spec(self, x):
return KerasTensor(shape=x.shape)

def call(self, x):
return _rms_normalization(
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
)


@keras_export(
[
"keras.ops.rms_normalization",
"keras.ops.nn.rms_normalization",
]
)
def rms_normalization(x, scale=1, axis=-1, epsilon=None):
"""Performs Root Mean Square (RMS) normalization on `x`.

The Keras operation implements the operation as described in
[Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
by Biao Zhang et al.

The operation is different from LayerNormalization with RMS scaling.

It is defined as `rms_normalization(x) = x * rsqrt(mean(square(x))) * scale`

Args:
x: Input tensor.
scale: Optional scaling factor for the normalization.
axis: The axis or axes along which to perform normalization.
Defaults to `-1`.
epsilon: A lower bound value for the norm.
Defaults to `backend.epsilon()`.

Returns:
The normalized array.

Example:

>>> x = np.random.rand(1, 10)
>>> x_norm = keras.ops.rms_normalization(x)
>>> print(x_norm)
array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865,
0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
"""
if any_symbolic_tensors((x,)):
return RMSNorm(scale=scale, axis=axis, epsilon=epsilon).symbolic_call(x)
return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon)


def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
x = backend.convert_to_tensor(x)
if len(x.shape) == 0:
x = backend.numpy.expand_dims(x, axis=0)
if epsilon is None:
epsilon = backend.epsilon()

if not is_keras_tensor(scale):
scale = backend.convert_to_tensor(scale, dtype=x.dtype)
if not is_keras_tensor(epsilon):
epsilon = backend.convert_to_tensor(epsilon, dtype=x.dtype)

rrms = backend.math.rsqrt(
backend.numpy.mean(backend.numpy.square(x), axis=axis, keepdims=True)
+ epsilon
)
return (x * rrms) * scale


class Polar(Operation):
def __init__(self):
super().__init__()
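Editorial note (not part of the diff): a hedged NumPy reference for the op's formula, mirroring `_rms_normalization` above; the helper name and tolerance are illustrative, and it assumes the defaults `scale=1` and `epsilon=backend.epsilon()` (1e-7 unless reconfigured).

```python
import numpy as np
import keras


def rms_normalization_reference(x, scale=1.0, axis=-1, epsilon=1e-7):
    # Mirrors _rms_normalization: x * rsqrt(mean(square(x), axis) + epsilon) * scale.
    rrms = 1.0 / np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + epsilon)
    return x * rrms * scale


x = np.random.rand(1, 10).astype("float32")
ref = rms_normalization_reference(x)
out = keras.ops.convert_to_numpy(keras.ops.rms_normalization(x))
print(np.allclose(ref, out, atol=1e-5))  # expected True under these assumptions
```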
6 changes: 6 additions & 0 deletions keras/src/ops/nn_test.py
@@ -3142,3 +3142,9 @@ def test_invalid_strategy_ctc_decode(self):
beam_width=beam_width,
top_paths=top_paths,
)

def test_rms_normalization(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(
knn.rms_normalization(x, (None, 2, 3)).shape, (None, 2, 3)
)
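The new test above only checks the symbolic output shape. A hedged sketch of a companion eager-value check in the same test style (hypothetical test name, not part of the diff); it assumes the default `scale=1` and `epsilon=backend.epsilon()` (1e-7).

```python
def test_rms_normalization_value(self):
    # Hypothetical companion check: eager output should match a NumPy reference
    # computed with the assumed defaults scale=1 and epsilon=1e-7.
    x = np.random.rand(2, 3).astype("float32")
    expected = x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + 1e-7)
    self.assertAllClose(knn.rms_normalization(x), expected, atol=1e-5)
```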