File tree 3 files changed +21
-2
lines changed 3 files changed +21
-2
lines changed Original file line number Diff line number Diff line change @@ -6315,6 +6315,16 @@ def make_dist_and_execute(t, d):
6315
6315
for _ in range (2 ):
6316
6316
make_dist_and_execute (torch .randn (10 ), SubCateg )
6317
6317
6318
+ def test_bitwise_print_precedence (self ):
6319
+ import math
6320
+
6321
+ @torch .compile (fullgraph = True , dynamic = True )
6322
+ def f (x ):
6323
+ torch ._check (math .floor ((x .size (0 ) | 3 ) * 4 ) == 12 )
6324
+ return x .sin ()
6325
+
6326
+ f (torch .randn (2 ))
6327
+
6318
6328
def test_tensor_split_within_device_cm (self ):
6319
6329
@torch .compile (fullgraph = True )
6320
6330
def split (x ):
Original file line number Diff line number Diff line change 25
25
from sympy .core .operations import LatticeOp , ShortCircuit
26
26
from sympy .core .sorting import ordered
27
27
from sympy .core .traversal import walk
28
+ from sympy .printing .precedence import PRECEDENCE
28
29
from sympy .utilities .iterables import sift
29
30
30
31
from .numbers import int_oo
@@ -1350,8 +1351,16 @@ def eval(cls, a):
1350
1351
1351
1352
1352
1353
def make_opaque_bitwise_fn (name , real_op_name ):
1354
+ if name == "bitwise_and" :
1355
+ prec = PRECEDENCE ["BitwiseAnd" ]
1356
+ elif name == "bitwise_or" :
1357
+ prec = PRECEDENCE ["BitwiseOr" ]
1358
+ else :
1359
+ raise AssertionError (f"unrecognized { name } " )
1360
+
1353
1361
class BitwiseFn (sympy .Function ):
1354
1362
_torch_handler_name = name
1363
+ precedence : int = prec
1355
1364
1356
1365
@classmethod
1357
1366
def eval (cls , a , b ):
Original file line number Diff line number Diff line change @@ -25,10 +25,10 @@ def _print_Relational(self, expr: sympy.Expr) -> str:
25
25
return self .stringify (expr .args , f" { expr .rel_op } " , precedence (expr ))
26
26
27
27
def _print_BitwiseFn_bitwise_and (self , expr : sympy .Expr ) -> str :
28
- return self .stringify (expr .args , " & " , PRECEDENCE ["Atom" ] - 0.5 )
28
+ return self .stringify (expr .args , " & " , PRECEDENCE ["BitwiseAnd" ] )
29
29
30
30
def _print_BitwiseFn_bitwise_or (self , expr : sympy .Expr ) -> str :
31
- return self .stringify (expr .args , " | " , PRECEDENCE ["Atom" ] - 0.5 )
31
+ return self .stringify (expr .args , " | " , PRECEDENCE ["BitwiseOr" ] )
32
32
33
33
# NB: this is OK to put here, because Mod is only defined for positive
34
34
# numbers, and so across C/Python its behavior is consistent
You can’t perform that action at this time.
0 commit comments