fix: handling of default attrs in SimplifiedLayerNormalization + LayerNormalization🐛 #2396
Conversation
Codecov Report. Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #2396      +/-   ##
==========================================
- Coverage   70.37%   69.91%   -0.47%
==========================================
  Files         199      200       +1
  Lines       25216    25470     +254
  Branches     2686     2688       +2
==========================================
+ Hits        17747    17807      +60
- Misses       6540     6735     +195
+ Partials      929      928       -1
skip_sum_pattern_2 = op.Add(input, skip)
skip_sum = pattern.OrValue([skip_sum_pattern_1, skip_sum_pattern_2], name="skip_sum")

skip_sum = op.Add(input, skip)
if self._has_bias and not self._bias_pre_add:
    skip_sum = op.Add(skip_sum, bias)
I chose to enable commute(...), since we didn't check for all variants in this addition, only in the lines above.
I am curious, did you see patterns that this missed? In principle this is ok, but it could increase the fusion time. We also need to update the implementation of commute() to make use of pattern-disjunction, which will be more efficient.
Thanks for asking, I've only looked at a limited number of models so far, which all had the bias term as second input. I can undo this change to avoid performance regressions 👍
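For reference, here is a minimal sketch of what the explicit disjunction could look like for the bias addition, mirroring the `pattern.OrValue` handling of `skip_sum` above instead of relying on `commute()`; the function and value names are illustrative, not the PR's actual code.

```python
from onnxscript.rewriter import pattern


def skip_sum_with_bias_pattern(op, input, skip, bias):
    # Match Add(input, skip) in either operand order via explicit disjunction.
    skip_sum_1 = op.Add(input, skip)
    skip_sum_2 = op.Add(skip, input)
    skip_sum = pattern.OrValue([skip_sum_1, skip_sum_2], name="skip_sum")
    # Likewise match the post-add bias in either operand order, so no
    # commute() pass over the whole rule is needed for this node.
    biased_1 = op.Add(skip_sum, bias)
    biased_2 = op.Add(bias, skip_sum)
    return pattern.OrValue([biased_1, biased_2], name="biased_skip_sum")
```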
Pull Request Overview
This PR fixes how default attributes (`epsilon`, `stash_type`) are handled in both `LayerNormalization` and `SimplifiedLayerNormalization` fusions, adds a BART encoder model to the fusion tests, and introduces commuted-input support for `SkipLayerNormalization` rules.
- Extract the default `epsilon` from the matched node instead of requiring it in the pattern signature (a sketch follows this list)
- Add `test_bart_encoder` to validate fusion with default-attribute cases
- Enable commuted-input variants by applying `.commute()` to fusion rules
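A minimal sketch of the first bullet, assuming `attributes.get_float` returns `None` for a missing attribute; the rewrite-function name and the exact replacement call are illustrative rather than the PR's code:

```python
def rewrite_skip_simplified_layer_norm(op, input, skip, gamma, simplified_layer_norm, **_):
    # epsilon is no longer bound in the pattern signature; read it off the
    # matched node and fall back to the ONNX default when it was omitted.
    epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
    if epsilon is None:
        epsilon = 1e-5  # ONNX default epsilon for (Simplified)LayerNormalization
    return op.SkipSimplifiedLayerNormalization(
        input, skip, gamma, epsilon=epsilon, _domain="com.microsoft"
    )
```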
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
File | Description
---|---
skip_normalization_test.py | Added `test_bart_encoder` to cover default-attribute fusions
skip_normalization.py | Refactored patterns to drop default attrs, extract `epsilon` in the rewrite, and apply rule commutation
Comments suppressed due to low confidence (2)
onnxscript/rewriter/ort_fusions/skip_normalization_test.py:73

The test uses `fuse_skip_layer_normalization(model)` but there is no import for that symbol in this file. Please add `from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_layer_normalization` (or adjust the import path) to ensure the function is available.

    fuse_skip_layer_normalization(model)
onnxscript/rewriter/ort_fusions/skip_normalization.py:231

The new `.commute()` calls are applied only to the full `SkipLayerNormalization` rules. To allow commuted inputs for `SkipSimplifiedLayerNormalization` as well, you should apply `.commute()` to the simplified-layer ruleset (if defined) or include those rules here before applying `apply_fusion_rules`.

    skip_layer_normalization_ruleset = pattern.RewriteRuleSet(
    **_,
):
    epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
You extract `epsilon` from the matched node but do not extract or forward `stash_type`. If a non-default `stash_type` was used, it will be lost in the fused op. Consider retrieving `stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type")` and passing it into `SkipSimplifiedLayerNormalization`.

Suggested change:

    epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
    stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type")
I guess there is no stash type for fused layer norm ops? https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.SkipLayerNormalization
I assume you mean also for SimplifiedLayerNorm and SkipSimplifiedLayerNorm? Unfortunately, I don't see the doc for the first op. But if it is absent in both ops, it seems safe to ignore it.
Let's consider the two cases separately:
(a) For LayerNorm (the rule down below), we are starting with the ONNX op, which supports stash_type. But the SkipLayerNorm in ORT doesn't seem to support stash_type. Is my understanding correct? If so, the rewrite should have a check condition to see if the stash_type has the value supported by the SkipLayerNorm ... if not, we should skip the optimization.
(b) For SimplifiedLayerNorm, if stash_type is not supported by either op, we can ignore it.
However, for (a): we should understand what the default behavior of the ORT op is: does it use stash_type == FP32, or does it use stash_type == input-type? The two are different.
Thank you. 👍 Yes, neither `SkipLayerNorm` nor `SkipSimplifiedLayerNorm` seems to support an external stash_type (see op.LayerNormalization + SkipLayerNormalization). I'm not very proficient with C/C++, but to my understanding the internal precision for the statistics depends on the `strict` mode and on the input type `T`, e.g., float or bfloat16. In strict mode the computation should be done in fp32; in non-strict mode, the precision follows the input precision: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@gramalingam Should we then check for stash_type=None or 1 to be safe?
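A sketch of such a guard, assuming a standalone condition function in the style of the pattern rewriter (the function name and exact signature are hypothetical): fuse only when `stash_type` is absent or equals 1 (FLOAT), since the ORT contrib op exposes no `stash_type` attribute.

```python
def _stash_type_is_default(context, layer_norm, **_):
    # Hypothetical check condition: skip the fusion unless the matched
    # LayerNormalization stashes statistics in FLOAT (the ONNX default),
    # because SkipLayerNormalization cannot express any other stash_type.
    stash_type = layer_norm.producer().attributes.get_int("stash_type")
    return stash_type is None or stash_type == 1  # 1 == TensorProto.FLOAT
```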
if self._has_bias and not self._bias_pre_add:
    skip_sum = op.Add(skip_sum, bias)

normalized = op.LayerNormalization(
    skip_sum,
    gamma,
    beta,
@gramalingam `beta` is an optional input. I'd lean toward matching both variants (with and without bias).
Sorry, do you mean all 4 combinations (with and without beta, with and without bias)? That can be done by dropping beta here and specifying `_allow_other_inputs=True`. The rewriter should then forward the corresponding inputs to the rewritten node.
Sorry, I meant with and without `beta`. I'm wondering, if I drop `beta` here, would it still be forwarded correctly, since beta is the third input to `LayerNormalization` but the fourth input to `SkipLayerNormalization`? (https://onnx.ai/onnx/operators/onnx__LayerNormalization.html and https://github.com/microsoft/onnxruntime/blob/rel-1.20.0/docs/ContribOperators.md#commicrosoftskiplayernormalization) In my quick tests, I got different outputs.
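For illustration, a hypothetical rewrite that places each operand explicitly in the fused op, which is one way to avoid the index mismatch (beta is `LayerNormalization`'s third input but `SkipLayerNormalization`'s fourth); this is a sketch, not the PR's code, and it omits attributes such as epsilon.

```python
def rewrite_skip_layer_norm(op, input, skip, gamma, beta=None, bias=None, **_):
    # Explicit operand placement for com.microsoft.SkipLayerNormalization:
    # (input, skip, gamma, beta, bias). Blind positional forwarding from
    # LayerNormalization (X, Scale, B) would drop beta into the wrong slot.
    return op.SkipLayerNormalization(
        input,
        skip,
        gamma,
        beta,   # optional; LayerNormalization's B
        bias,   # optional; bias on the skip path
        _domain="com.microsoft",
    )
```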
`SkipLayerNormFusion` currently does not fuse ops if `stash_type` is at its default (=1) or `epsilon` is at its default (=1e-5) for `LayerNormalization` and `SimplifiedLayerNormalization`.

This PR:
- `LayerNormalization`, `SimplifiedLayerNormalization`
- `EmbedLayerNormalization`

Closes #2378.
@shubhambhokare1 @justinchuby Could you please review? Any feedback is greatly appreciated.
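For completeness, a usage sketch of running this fusion over a model, assuming the import path suggested in the review comment above and that the fusion entry point takes an `onnxscript.ir` model in place (file names are placeholders):

```python
import onnx
from onnxscript import ir
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_layer_normalization

# Load an ONNX model and convert it to the onnxscript IR.
model = ir.serde.deserialize_model(onnx.load("bart_encoder.onnx"))  # placeholder file name
fuse_skip_layer_normalization(model)  # apply the SkipLayerNormalization fusion rules
onnx.save(ir.serde.serialize_model(model), "bart_encoder_fused.onnx")
```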