Description
The documentation states:
rms_scaling: If True, center and scale are ignored, and the inputs are scaled by gamma and the inverse square root of the square of all inputs. This is an approximate and faster approach that avoids ever computing the mean of the input.
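For reference, a minimal sketch of what that description would correspond to, i.e. true RMS scaling that never computes the mean (this is my own illustration using keras.ops, not the actual layer code, and the function name is mine):

    from keras import ops

    def rms_scale_sketch(inputs, gamma, axis=-1, epsilon=1e-6):
        # Scale by the inverse root mean square of the inputs;
        # neither the mean nor the variance is ever computed.
        mean_square = ops.mean(ops.square(inputs), axis=axis, keepdims=True)
        return inputs * ops.rsqrt(mean_square + epsilon) * gamma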
However, in the implementation, it actually does the following:
if self.rms_scaling:
    # Calculate outputs with only variance and gamma if rms scaling
    # is enabled
    # Calculate the variance along self.axis (layer activations).
    variance = ops.var(inputs, axis=self.axis, keepdims=True)
    inv = ops.rsqrt(variance + self.epsilon)
    outputs = (
        inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
    )
So the mean is indeed used: the code computes the variance, which subtracts the mean, rather than the root mean square of the inputs.
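To make the difference concrete, here is a small standalone comparison using keras.ops directly (not the layer itself); gamma is omitted since it would only multiply both results by the same factor. For an input with nonzero mean the two scalings clearly diverge:

    import numpy as np
    from keras import ops

    x = ops.convert_to_tensor(np.array([[1.0, 2.0, 3.0, 4.0]], dtype="float32"))
    eps = 1e-6

    # What the implementation does: scale by the inverse square root of the
    # variance (ops.var subtracts the mean internally).
    var_scaled = x * ops.rsqrt(ops.var(x, axis=-1, keepdims=True) + eps)

    # What the docstring describes: scale by the inverse square root of the
    # mean of the squares (no mean subtraction anywhere).
    rms_scaled = x * ops.rsqrt(
        ops.mean(ops.square(x), axis=-1, keepdims=True) + eps
    )

    print(ops.convert_to_numpy(var_scaled))  # approx [[0.894, 1.789, 2.683, 3.578]]
    print(ops.convert_to_numpy(rms_scaled))  # approx [[0.365, 0.730, 1.095, 1.461]]

The two outputs only coincide when the inputs already have zero mean along the normalization axis.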
There was also a discussion during the addition of RMS Normalization (#20911 (comment)) that confirms this behavior.
I think the docs could use an update to clarify this behavior. Right now, it sounds like the mean isn't used when rms_scaling is on, but the code suggests otherwise.
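If it helps, one possible clarified wording for the rms_scaling entry (just a suggestion, not final text) would be along these lines:

    rms_scaling: If True, center and scale are ignored, and the inputs are
        scaled by gamma and the inverse square root of the variance along the
        normalization axis. Note that, unlike classic RMS normalization, this
        still involves computing the mean of the input (as part of the variance).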