fix: handling of default attrs in SimplifiedLayerNormalization + LayerNormalization🐛 #2396


Open · wants to merge 7 commits into main

Conversation

@KarelZe (Contributor) commented Jun 17, 2025

SkipLayerNormFusion currently does not fuse ops if stash_type is at its default (=1) or epsilon is at its default (=1e-5) for LayerNormalization and SimplifiedLayerNormalization.

This PR:

  • fixes the handling of default attrs in LayerNormalization and SimplifiedLayerNormalization
  • adds a BART encoder as a new test model. I added this model because some of its stash_type attributes are at their default. The model is versatile and can also be used to test other fusions, e.g., EmbedLayerNormalization.
  • allows for commuted inputs.

Closes #2378.

@shubhambhokare1 @justinchuby Could you please review? Any feedback is greatly appreciated.
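
For context, the underlying issue can be reproduced with the plain onnx helper API: when a model builder leaves epsilon and stash_type at their defaults, the attributes are simply absent from the node proto, so a fusion pattern that insists on seeing them will not match. A minimal sketch (illustrative only, not the code in this PR; the get_float_attr helper is hypothetical):

    from onnx import NodeProto, helper

    # A LayerNormalization node with epsilon and stash_type left at their
    # ONNX defaults: neither attribute is present in the serialized node.
    node = helper.make_node(
        "LayerNormalization",
        inputs=["x", "gamma", "beta"],
        outputs=["y"],
    )
    assert len(node.attribute) == 0  # nothing for an attribute-based match to see

    def get_float_attr(n: NodeProto, name: str, default: float) -> float:
        """Read a float attribute, falling back to the spec default when absent."""
        for attr in n.attribute:
            if attr.name == name:
                return helper.get_attribute_value(attr)
        return default

    # The fusion therefore has to supply the defaults itself
    # (epsilon=1e-5, stash_type=1 per the ONNX spec).
    epsilon = get_float_attr(node, "epsilon", 1e-5)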


codecov bot commented Jun 17, 2025

Codecov Report

Attention: Patch coverage is 22.48062% with 200 lines in your changes missing coverage. Please review.

Project coverage is 69.91%. Comparing base (59340c6) to head (ba4a971).
Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
...cript/rewriter/ort_fusions/models/_bart_encoder.py | 17.01% | 198 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2396      +/-   ##
==========================================
- Coverage   70.37%   69.91%   -0.47%     
==========================================
  Files         199      200       +1     
  Lines       25216    25470     +254     
  Branches     2686     2688       +2     
==========================================
+ Hits        17747    17807      +60     
- Misses       6540     6735     +195     
+ Partials      929      928       -1     


skip_sum_pattern_2 = op.Add(input, skip)
skip_sum = pattern.OrValue([skip_sum_pattern_1, skip_sum_pattern_2], name="skip_sum")

skip_sum = op.Add(input, skip)
if self._has_bias and not self._bias_pre_add:
    skip_sum = op.Add(skip_sum, bias)
@KarelZe (Contributor, Author) commented Jun 17, 2025

I chose to enable commute(...) because we only checked for all commuted variants in the lines above, not in this addition.

Collaborator commented:

I am curious, did you see patterns that this missed? In principle this is ok, but it could increase the fusion time. We also need to update the implementation of commute() to make use of pattern-disjunction, which will be more efficient.

KarelZe (Contributor, Author) commented:

Thanks for asking. I've only looked at a limited number of models so far, and all of them had the bias term as the second input. I can undo this change to avoid performance regressions 👍
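
For reference, here is what the fused op computes, sketched in plain numpy (a rough reading of the ONNX LayerNormalization spec and the com.microsoft SkipLayerNormalization contrib-op description; the actual ORT kernel may differ in precision details, and the extra mean/inv_std_var outputs are omitted). Because addition is commutative, Add(input, skip) and Add(skip, input) produce the same skip_sum, which is why both input orders need to be matched:

    import numpy as np

    def skip_layer_norm_reference(x, skip, gamma, beta=None, bias=None, epsilon=1e-5):
        """Reference semantics: add skip (and optional bias), then apply
        LayerNormalization over the last axis."""
        skip_sum = x + skip  # same result as skip + x
        if bias is not None:
            skip_sum = skip_sum + bias
        mean = skip_sum.mean(axis=-1, keepdims=True)
        var = skip_sum.var(axis=-1, keepdims=True)
        normalized = (skip_sum - mean) / np.sqrt(var + epsilon)
        out = normalized * gamma
        if beta is not None:
            out = out + beta
        return out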

@KarelZe KarelZe marked this pull request as ready for review June 17, 2025 12:13
@KarelZe KarelZe marked this pull request as draft June 17, 2025 12:47
@justinchuby justinchuby requested review from gramalingam and Copilot and removed request for gramalingam June 17, 2025 15:57
@justinchuby (Collaborator) commented: @gramalingam

@Copilot Copilot AI left a comment

Pull Request Overview

This PR fixes how default attributes (epsilon, stash_type) are handled in both LayerNormalization and SimplifiedLayerNormalization fusions, adds a BART encoder model to the fusion tests, and introduces commuted-input support for SkipLayerNormalization rules.

  • Extract default epsilon from the matched node instead of requiring it in the pattern signature
  • Add test_bart_encoder to validate fusion with default-attribute cases
  • Enable commuted-input variants by applying .commute() to fusion rules

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File | Description
skip_normalization_test.py | Added test_bart_encoder to cover default-attribute fusions
skip_normalization.py | Refactored patterns to drop default attrs, extract epsilon in rewrite, and apply rule commutation
Comments suppressed due to low confidence (2)

onnxscript/rewriter/ort_fusions/skip_normalization_test.py:73

  • The test uses fuse_skip_layer_normalization(model) but there is no import for that symbol in this file. Please add from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_layer_normalization (or adjust the import path) to ensure the function is available.
        fuse_skip_layer_normalization(model)

onnxscript/rewriter/ort_fusions/skip_normalization.py:231

  • The new .commute() calls are applied only to the full SkipLayerNormalization rules. To allow commuted inputs for SkipSimplifiedLayerNormalization as well, you should apply .commute() to the simplified-layer ruleset (if defined) or include those here before applying apply_fusion_rules.
skip_layer_normalization_ruleset = pattern.RewriteRuleSet(

    **_,
):
    epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
Copilot AI commented Jun 17, 2025

You extract epsilon from the matched node but do not extract or forward stash_type. If a non-default stash_type was used, it will be lost in the fused op. Consider retrieving stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type") and passing it into SkipSimplifiedLayerNormalization.

Suggested change:
  epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
+ stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type")


Collaborator commented:

I assume you also mean for SimplifiedLayerNorm and SkipSimplifiedLayerNorm? Unfortunately, I don't see the doc for the first op. But if it is absent in both ops, it seems safe to ignore it.

Collaborator commented:

Let's consider the two cases separately:

(a) For LayerNorm (the rule down below), we are starting with the ONNX op, which supports stash_type. But the SkipLayerNorm in ORT doesn't seem to support stash_type. Is my understanding correct? If so, the rewrite should have a check condition to see if the stash_type has the value supported by the SkipLayerNorm ... if not, we should skip the optimization.

(b) For SimplifiedLayerNorm, if stash_type is not supported by either op, we can ignore it.

However, for (a): we should understand what the default behavior of the ORT ops is: do they use stash_type == FP32 or stash_type == input-type? The two are different.

KarelZe (Contributor, Author) commented:

Thank you. 👍 Yes, neither SkipLayerNorm nor SimplifiedSkipLayerNorm seems to support an external stash_type (see op.LayerNormalization + SkipLayerNormalization). I'm not very proficient with C/C++, but to my understanding the internal precision for the statistics depends on strict mode and on the input type T, e.g., float or bfloat16. In strict mode the computation should be done in fp32; in non-strict mode, the precision depends on the input precision: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @gramalingam Should we then check for stash_type=None or 1 to be safe?
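
One possible shape for that check, sketched against the plain onnx protobuf API rather than the rewriter's check-condition hook (illustrative only; the function name and how it would be wired into the rule are assumptions): treat an absent stash_type as the default of 1 (TensorProto.FLOAT) and only allow the fusion when it has that value, since that is the precision the ORT kernel can be expected to reproduce in strict mode.

    from onnx import NodeProto, helper

    def stash_type_allows_fusion(layer_norm_node: NodeProto) -> bool:
        """Return True if stash_type is absent or 1 (FP32), i.e. a value the
        ORT SkipLayerNormalization kernel can be expected to reproduce."""
        stash_type = 1  # ONNX default for LayerNormalization
        for attr in layer_norm_node.attribute:
            if attr.name == "stash_type":
                stash_type = helper.get_attribute_value(attr)
        return stash_type == 1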

if self._has_bias and not self._bias_pre_add:
    skip_sum = op.Add(skip_sum, bias)

normalized = op.LayerNormalization(
    skip_sum,
    gamma,
    beta,
KarelZe (Contributor, Author) commented:

@gramalingam beta is an optional input. I'd lean toward matching both variants (w and w/o bias).

Collaborator commented:

Sorry, do you mean all 4 combinations (w and w/o beta, w/wo bias)? That can be done by dropping beta here, and specifying _allow_other_inputs=True. The rewriter should then forward the corresponding inputs to the rewritten node.

KarelZe (Contributor, Author) commented:

Sorry, I meant with and w/o beta. I'm wondering: if I drop beta here, would it still be forwarded correctly, given that beta is the third input to LayerNormalization but the fourth input to SkipLayerNormalization? (https://onnx.ai/onnx/operators/onnx__LayerNormalization.html and https://github.com/microsoft/onnxruntime/blob/rel-1.20.0/docs/ContribOperators.md#commicrosoftskiplayernormalization) In my quick tests, I got different outputs.
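
To make the position mismatch concrete (node construction only, following the ONNX operator page and the contrib-op doc linked above; input names are illustrative): beta is the third input of LayerNormalization but the fourth input of SkipLayerNormalization, because skip is inserted as the second input, so purely positional forwarding of the trailing inputs would not line up.

    from onnx import helper

    # ONNX LayerNormalization: X, Scale, B -- beta is the 3rd input.
    layer_norm = helper.make_node(
        "LayerNormalization",
        inputs=["skip_sum", "gamma", "beta"],
        outputs=["normalized"],
    )

    # com.microsoft SkipLayerNormalization: input, skip, gamma, beta, bias --
    # beta is the 4th input, after skip.
    skip_layer_norm = helper.make_node(
        "SkipLayerNormalization",
        inputs=["input", "skip", "gamma", "beta", "bias"],
        outputs=["output"],
        domain="com.microsoft",
        epsilon=1e-5,
    )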

@KarelZe KarelZe marked this pull request as ready for review June 18, 2025 04:14
Successfully merging this pull request may close these issues: Handling of default attrs of LayerNormalization in SkipLayerNormFusion

3 participants