Skip to content

Commit c6b41eb

Browse files
ezyangaditew01
authored andcommitted
Fix precedence of bitwise and/or printing (pytorch#143197)
Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#143197 Approved by: https://github.com/albanD, https://github.com/williamwen42
1 parent d3a4f8d commit c6b41eb

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

test/dynamo/test_repros.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6315,6 +6315,16 @@ def make_dist_and_execute(t, d):
63156315
for _ in range(2):
63166316
make_dist_and_execute(torch.randn(10), SubCateg)
63176317

6318+
def test_bitwise_print_precedence(self):
6319+
import math
6320+
6321+
@torch.compile(fullgraph=True, dynamic=True)
6322+
def f(x):
6323+
torch._check(math.floor((x.size(0) | 3) * 4) == 12)
6324+
return x.sin()
6325+
6326+
f(torch.randn(2))
6327+
63186328
def test_tensor_split_within_device_cm(self):
63196329
@torch.compile(fullgraph=True)
63206330
def split(x):

torch/utils/_sympy/functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sympy.core.operations import LatticeOp, ShortCircuit
2626
from sympy.core.sorting import ordered
2727
from sympy.core.traversal import walk
28+
from sympy.printing.precedence import PRECEDENCE
2829
from sympy.utilities.iterables import sift
2930

3031
from .numbers import int_oo
@@ -1350,8 +1351,16 @@ def eval(cls, a):
13501351

13511352

13521353
def make_opaque_bitwise_fn(name, real_op_name):
1354+
if name == "bitwise_and":
1355+
prec = PRECEDENCE["BitwiseAnd"]
1356+
elif name == "bitwise_or":
1357+
prec = PRECEDENCE["BitwiseOr"]
1358+
else:
1359+
raise AssertionError(f"unrecognized {name}")
1360+
13531361
class BitwiseFn(sympy.Function):
13541362
_torch_handler_name = name
1363+
precedence: int = prec
13551364

13561365
@classmethod
13571366
def eval(cls, a, b):

torch/utils/_sympy/printers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def _print_Relational(self, expr: sympy.Expr) -> str:
2525
return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))
2626

2727
def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
28-
return self.stringify(expr.args, " & ", PRECEDENCE["Atom"] - 0.5)
28+
return self.stringify(expr.args, " & ", PRECEDENCE["BitwiseAnd"])
2929

3030
def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
31-
return self.stringify(expr.args, " | ", PRECEDENCE["Atom"] - 0.5)
31+
return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"])
3232

3333
# NB: this is OK to put here, because Mod is only defined for positive
3434
# numbers, and so across C/Python its behavior is consistent

0 commit comments

Comments
 (0)