from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer


@keras_export("keras.layers.RMSNormalization")
class RMSNormalization(Layer):
    """Root Mean Square (RMS) Normalization layer.

    This layer normalizes the input tensor based on its RMS value.

    It performs the operation described in
    [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
    by Biao Zhang et al.

    The layer scales the normalized outputs via `scale`, a learnable
    scaling factor that is initialized to ones.

    Let `x` be the intermediate activations for a mini-batch. With
    scaling enabled, the normalization is computed as:

    ```python
    rms_normalization(x) = x * rsqrt(mean(square(x)) + epsilon) * scale
    ```
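
    For reference, the same computation can be sketched in plain NumPy
    (`rms_norm_reference` is an illustrative helper for this docstring,
    not part of the Keras API):

    ```python
    import numpy as np

    def rms_norm_reference(x, scale, epsilon=1e-6):
        # Mean of squares over the last axis, kept for broadcasting.
        mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
        return x / np.sqrt(mean_square + epsilon) * scale
    ```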

    For example:

    >>> layer = keras.layers.RMSNormalization()
    >>> layer.build([5, 20, 30, 10])
    >>> print(layer.scale.shape)
    (10,)
    >>> layer(np.random.rand(1, 10)).numpy()
    array([[0.35098287, 1.0495652 , 1.4645109 , 1.2944688 , 0.31124955,
            1.2768592 , 1.184331  , 0.17474432, 0.49955517, 1.2428929 ]],
          dtype=float32)

    Args:
        axis: int or list of int. The axis or axes on which to perform
            the normalization. Defaults to `-1`.
        epsilon: float. A small number added to the mean of squares to
            avoid division by zero. Defaults to `1e-6`.
    """

    def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def build(self, input_shape):
        # The learnable `scale` weight has one entry per feature along
        # each normalized axis.
        if isinstance(self.axis, list):
            shape = tuple([input_shape[dim] for dim in self.axis])
        else:
            shape = (input_shape[self.axis],)
            # Wrap a single int axis in a list so downstream code always
            # sees a list of axes.
            self.axis = [self.axis]

        self.scale = self.add_weight(
            name="scale", shape=shape, initializer="ones"
        )

        self.built = True

    def call(self, x):
        """Applies RMS normalization to the input tensor.

        Args:
            x: Input tensor of shape `(batch_size, input_dim)`.

        Returns:
            The RMS-normalized tensor of the same shape
            `(batch_size, input_dim)`, scaled by the learned `scale`
            parameter.
        """
        return ops.rms_normalization(
            x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
        )

    def compute_output_shape(self, input_shape):
        if isinstance(self.axis, int):
            axes = [self.axis]
        else:
            axes = self.axis

        # Normalization does not change the shape; only validate that
        # every requested axis is within the rank of the input.
        for axis in axes:
            if axis >= len(input_shape) or axis < -len(input_shape):
                raise ValueError(
                    f"Axis {axis} is out of bounds for "
                    f"input shape {input_shape}. "
                    f"Received: axis={self.axis}"
                )
        return input_shape

    def get_config(self):
        config = {
            "axis": self.axis,
            "epsilon": self.epsilon,
        }
        base_config = super().get_config()
        return {**base_config, **config}
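

# A minimal sanity-check sketch, assuming a working Keras backend and
# NumPy are available: it compares the layer output against the direct
# computation `x * rsqrt(mean(square(x)) + epsilon) * scale` from the
# docstring. The `__main__` guard keeps it out of normal imports.
if __name__ == "__main__":
    import numpy as np

    layer = RMSNormalization()
    x = np.random.rand(2, 8).astype("float32")
    y = ops.convert_to_numpy(layer(x))

    # A freshly built layer has `scale` initialized to ones, so the
    # reference reduces to x * rsqrt(mean(x**2, axis=-1) + epsilon).
    mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
    expected = x / np.sqrt(mean_square + layer.epsilon)
    np.testing.assert_allclose(y, expected, rtol=1e-5, atol=1e-5)
    print("RMSNormalization matches the NumPy reference.")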