
LayerNormalization with rms_scaling documentation is different from implementation #21234

@mzhukova


The documentation for `rms_scaling` states:

rms_scaling: If True, center and scale are ignored, and the inputs are scaled by gamma and the inverse square root of the square of all inputs. This is an approximate and faster approach that avoids ever computing the mean of the input.
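Read literally, and assuming that "the square of all inputs" means the mean of the squared inputs over the normalized axis (as in RMSNorm), the documented behavior would be the following. Here $n$ is the number of features along `axis`, $\gamma$ is the learned scale and $\epsilon$ is the layer's `epsilon`; no mean $\mu$ appears anywhere:

```latex
y_i = \gamma_i \, \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} + \epsilon}}
```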

However, the implementation actually does the following:

        if self.rms_scaling:
            # Calculate outputs with only variance and gamma if rms scaling
            # is enabled
            # Calculate the variance along self.axis (layer activations).
            variance = ops.var(inputs, axis=self.axis, keepdims=True)
            inv = ops.rsqrt(variance + self.epsilon)

            outputs = (
                inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
            )
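To make the difference concrete, here is a small NumPy sketch (an illustration, not Keras code) that mirrors the quoted branch next to an RMSNorm-style computation matching the docstring. The function names `implemented_rms_scaling` and `documented_rms_scaling` and the `epsilon=1e-3` default are hypothetical choices for this example. Note that in the quoted branch the numerator is the raw input, so the output is not mean-centered either: it is neither plain layer normalization nor RMS normalization.

```python
import numpy as np


def implemented_rms_scaling(x, gamma, epsilon=1e-3, axis=-1):
    # Mirrors the quoted branch: the *variance* (which subtracts the mean)
    # goes in the denominator, while the numerator is the uncentered input.
    variance = np.var(x, axis=axis, keepdims=True)
    return x * (1.0 / np.sqrt(variance + epsilon)) * gamma


def documented_rms_scaling(x, gamma, epsilon=1e-3, axis=-1):
    # What the docstring describes (RMSNorm-style): only the mean of the
    # squared inputs is used; the mean of the inputs is never computed.
    mean_square = np.mean(np.square(x), axis=axis, keepdims=True)
    return x * (1.0 / np.sqrt(mean_square + epsilon)) * gamma


x = np.array([[1.0, 2.0, 3.0, 4.0]])  # non-zero mean (2.5)
gamma = np.ones(4)
print(np.allclose(implemented_rms_scaling(x, gamma),
                  documented_rms_scaling(x, gamma)))  # False

z = x - x.mean(axis=-1, keepdims=True)  # zero-mean version of x
print(np.allclose(implemented_rms_scaling(z, gamma),
                  documented_rms_scaling(z, gamma)))  # True
```

The two only agree for zero-mean inputs, where the variance equals the mean of the squares; for any other input they diverge.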

So the mean is in fact used: `ops.var` computes the variance, which subtracts the mean of the inputs before squaring, rather than the root mean square (RMS) of the inputs.
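The dependence on the mean follows from the standard variance identity. With $\mu = \frac{1}{n}\sum_{j} x_j$ taken over the normalized axis, the denominator computed by the quoted branch expands as below, so it matches an RMS denominator only when $\mu = 0$:

```latex
\operatorname{Var}(x) = \frac{1}{n}\sum_{j=1}^{n} (x_j - \mu)^{2}
                      = \frac{1}{n}\sum_{j=1}^{n} x_j^{2} - \mu^{2},
\qquad
y_i = \gamma_i \, \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} - \mu^{2} + \epsilon}}
```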

There was also a discussion during the addition of RMS Normalization (#20911 (comment)) that confirms this behavior.

I think the docs could use an update to clarify this behavior. Right now, it sounds like the mean isn't used when rms_scaling is on, but the code suggests otherwise.

Labels: type:docs (Need to modify the documentation)
