
Commit 476910f

williamwen42 authored and Chao1Han committed
[flex attention] change "==" to "is" in inspect parameter comparison (pytorch#165003)
Patch for pytorch#164760. This doesn't actually fix the underlying torch function issue, though. Explanation: `is` is traced differently from `__eq__`, so we avoid ever attempting to evaluate `torch.eq(tensor, inspect._empty)` in the first place.

Pull Request resolved: pytorch#165003
Approved by: https://github.com/mlazos
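The distinction matters because `is` performs a pure identity check and never dispatches to the operand's `__eq__`. A minimal standalone sketch, using a hypothetical `TensorLike` class with an overloaded `__eq__` to stand in for `torch.Tensor` (whose `__eq__` dispatches to `torch.eq`):

```python
import inspect

class TensorLike:
    """Stand-in for torch.Tensor: "==" dispatches to a custom __eq__,
    just as Tensor.__eq__ dispatches to torch.eq under tracing."""
    def __eq__(self, other):
        raise RuntimeError("__eq__ invoked during signature inspection")

def fn(score, b, h, q_idx, kv_idx, buf=TensorLike()):
    return score

empty = inspect.Parameter.empty

# "is" is an identity check: no __eq__ call on the default value.
n = sum(1 for p in inspect.signature(fn).parameters.values()
        if p.default is empty)
print(n)  # 5

# "==" evaluates TensorLike.__eq__ on the buf default and blows up.
try:
    sum(1 for p in inspect.signature(fn).parameters.values()
        if p.default == empty)
except RuntimeError as err:
    caught = str(err)
    print(caught)  # __eq__ invoked during signature inspection
```

With a real `torch.Tensor` default under Dynamo tracing, the `==` path similarly ends up trying to trace `torch.eq(tensor, inspect._empty)`, which is what the patch avoids.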
1 parent c8b25c0 commit 476910f

File tree: 1 file changed (+1, −1 lines)


torch/nn/attention/flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
     num_positional_args = sum(
         1
         for param in inspect.signature(fn).parameters.values()
-        if param.default == inspect.Parameter.empty
+        if param.default is inspect.Parameter.empty
     )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5: