Skip to content

Commit 90cedbe

Browse files
ezyangfacebook-github-bot
authored andcommitted
Some minor type stub improvements (#118529)
Summary: I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to python/mypy#14425 (comment) Signed-off-by: Edward Z. Yang <[email protected]> X-link: pytorch/pytorch#118529 Approved by: https://github.com/Skylion007 Reviewed By: clee2000 Differential Revision: D53296779 Pulled By: ezyang fbshipit-source-id: 95799914350e50aeedf1acad93d71b79cad827c8
1 parent 8e2c99f commit 90cedbe

File tree

1 file changed

+12
-12
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+12
-12
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2066,18 +2066,18 @@ def defake(x):
20662066
size: "torch._prims_common.ShapeType"
20672067
stride: "torch._prims_common.StrideType"
20682068
if x._has_symbolic_sizes_strides:
2069-
size = [
2070-
s.node.shape_env.size_hint(s.node.expr)
2071-
if isinstance(s, torch.SymInt)
2072-
else s
2073-
for s in x.size()
2074-
]
2075-
stride = [
2076-
s.node.shape_env.size_hint(s.node.expr)
2077-
if isinstance(s, torch.SymInt)
2078-
else s
2079-
for s in x.stride()
2080-
]
2069+
size = []
2070+
for s in x.size():
2071+
if isinstance(s, torch.SymInt):
2072+
size.append(s.node.shape_env.size_hint(s.node.expr))
2073+
else:
2074+
size.append(s)
2075+
stride = []
2076+
for s in x.stride():
2077+
if isinstance(s, torch.SymInt):
2078+
stride.append(s.node.shape_env.size_hint(s.node.expr))
2079+
else:
2080+
stride.append(s)
20812081
else:
20822082
size = x.size()
20832083
stride = x.stride()

0 commit comments

Comments
 (0)