Refactor RMSNorm implementations to use torch.nn.functional.rms_norm #42461
+65
−325
What does this PR do?
Fixes #42398
This PR replaces custom RMSNorm/T5-style norm implementations (e.g. in Llama), which manually compute the variance and scaling, with the built-in `torch.nn.functional.rms_norm`, so the hand-written variance/rsqrt pattern collapses into a single call.
This keeps the behavior and epsilon handling the same while reducing the number of ops, which should improve performance for users without requiring any changes on their side.
To verify the performance and the numerical stability, I wrote the following test:
The results show the following:
Note: I found that for dtypes lower than `float32`, the old implementation keeps the result at `float32`, but my new one has the `dtype` of the input tensor. That is why I have to cast to `y_hf.dtype` (trying `float64`, for example, makes both implementations output `float64`). This can be changed, depending on what we want to accomplish.

Who can review?
@Rocketknight1