🐛 Describe the bug
Exporting the model below with real-tensor propagation enabled (`fake_tensor_propagate_real_tensors=True`) fails with a data-dependent error. For some reason, `ShapeEnv.unbacked_var_to_val` isn't populated:
import torch
from torch.export import export

class Baz(torch.nn.Module):
    def forward(self, x):
        x = torch.where(x <= 0.5)[0]  # data-dependent length (unbacked symbol u0)
        y = torch.randn(x.shape[0], 4)
        if y.numel() < 200:  # branches on 4*u0
            return x + y[:, 0]

model = Baz()
inputs = (torch.randn(64, 32),)
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
    ep = export(model, inputs, strict=False)
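Why the branch needs a data-dependent answer: `x.shape[0]` after `torch.where(x <= 0.5)` is the unbacked symbol `u0` (created by `nonzero`), so `y.numel()` is `4*u0`. A `(64, 32)` input has 2048 elements, so `u0` lies in [0, 2048], which matches the `VR[0, 2048]` range reported in the logs. The guard `4*u0 < 200` holds for `u0 <= 49` and fails for `u0 >= 50`, so the range alone cannot decide it. The toy interval check below illustrates this; `decide_lt` and `U0_RANGE` are hypothetical names for this sketch and are not PyTorch's `ValueRanges` logic.

```python
from typing import Optional

def decide_lt(coeff: int, bound: int, lo: int, hi: int) -> Optional[bool]:
    """For coeff > 0, decide whether coeff*u < bound for every u in [lo, hi].

    Returns True or False when the whole range agrees, or None when the
    answer depends on the actual value of u (a data-dependent guard).
    """
    if coeff * hi < bound:
        return True   # even the largest u satisfies the inequality
    if coeff * lo >= bound:
        return False  # even the smallest u violates it
    return None       # undecidable from the range alone

# nonzero() on a 64x32 input yields at most 64 * 32 = 2048 rows.
U0_RANGE = (0, 64 * 32)

print(decide_lt(4, 200, *U0_RANGE))  # None: 4*0 < 200, but 4*2048 >= 200
```

Narrowing the range to either side of the threshold (`[0, 49]` or `[50, 2048]`) would make the guard statically decidable; without that, export needs a real value for `u0`.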
Logs:
(pytorch-3.10) [/data/users/pianpwk/pytorch (5e0788bef)]$ TORCH_LOGS="+dynamic" python test/export/test_export.py -k test_crash_real_tensor
libnccl.so.2: cannot open shared object file: No such file or directory
libnccl.so.2: cannot open shared object file: No such file or directory
/data/users/pianpwk/pytorch/test/export/test_export.py:118: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("testlib::returns_tensor_symint")
/data/users/pianpwk/pytorch/test/export/test_export.py:131: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("testlib::foo")
V0910 14:06:11.334000 1210439 torch/fx/experimental/symbolic_shapes.py:2498] create_env
/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/z3/z3core.py:5: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
import pkg_resources
/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/pkg_resources/__init__.py:3144: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/pkg_resources/__init__.py:3144: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('pyannote')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/pkg_resources/__init__.py:3144: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
I0910 14:06:11.408000 1210439 torch/fx/experimental/symbolic_shapes.py:3317] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:426 in nonzero)
V0910 14:06:11.409000 1210439 torch/fx/experimental/symbolic_shapes.py:4734] _update_var_to_range u0 = VR[0, 2048] (update)
I0910 14:06:11.409000 1210439 torch/fx/experimental/symbolic_shapes.py:5481] constrain_symbol_range u0 [0, 2048]
V0910 14:06:11.430000 1210439 torch/fx/experimental/symbolic_shapes.py:5358] runtime_assert u0 >= 0 == True [statically known]
V0910 14:06:11.434000 1210439 torch/fx/experimental/symbolic_shapes.py:5201] eval Eq(u0, 0) == False [statically known]
V0910 14:06:11.439000 1210439 torch/fx/experimental/symbolic_shapes.py:5358] runtime_assert u0 >= 0 == True [statically known]
V0910 14:06:11.439000 1210439 torch/fx/experimental/symbolic_shapes.py:5358] runtime_assert u0 >= 0 == True [statically known]
V0910 14:06:11.442000 1210439 torch/fx/experimental/symbolic_shapes.py:5358] runtime_assert u0 >= 0 == True [statically known]
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] Data dependent variable 'u0' allocated at:
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 8125, in <module>
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] run_tests()
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/testing/_internal/common_utils.py", line 1273, in run_tests
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] unittest.main(argv=argv)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 101, in __init__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] self.runTests()
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 271, in runTests
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] self.result = testRunner.run(self.test)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/runner.py", line 184, in run
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] test(result)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self.run(*args, **kwds)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] test(result)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self.run(*args, **kwds)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] test(result)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 650, in __call__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self.run(*args, **kwds)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/testing/_internal/common_utils.py", line 3112, in run
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] self._run_custom(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/testing/_internal/common_utils.py", line 3084, in _run_custom
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] super_run(result=result)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 591, in run
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] self._callTestMethod(testMethod)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/home/pianpwk/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] method()
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/testing/_internal/common_utils.py", line 2979, in wrapper
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] method(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 1192, in test_crash_real_tensor
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] ep = export(model, inputs, strict=False)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/__init__.py", line 273, in export
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return _export(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 990, in wrapper
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] ep = fn(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/exported_program.py", line 114, in wrapper
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return fn(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1880, in _export
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] export_artifact = export_func( # type: ignore[operator]
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1683, in _non_strict_export
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] aten_export_artifact = _to_aten_func( # type: ignore[operator]
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 637, in _export_to_aten_ir
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] gm, graph_signature = transform(aot_export_module)(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1611, in _aot_export_non_strict
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] fx_g, metadata, in_spec, out_spec = _aot_export_function(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] fx_g, meta = create_aot_dispatcher_function(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return _create_aot_dispatcher_function(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] fw_metadata = run_functionalized_fw_and_collect_metadata(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] flat_f_outs = f(*flat_f_args)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] tree_out = fn(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] out = mod(*args[params_len:], **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self._call_impl(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return forward_call(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1598, in forward
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] tree_out = self._export_root(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self._call_impl(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return forward_call(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 1184, in forward
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] x = torch.where(x <= 0.5)[0]
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_export/non_strict_utils.py", line 520, in __torch_function__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return func(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_subclasses/functional_tensor.py", line 534, in __torch_dispatch__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] outs_unwrapped = func._op_dk(
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/utils/_stats.py", line 21, in wrapper
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return fn(*args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self.dispatch(func, types, args, kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 1694, in dispatch
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return self._dispatch_impl(func, types, args, kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 1983, in _dispatch_impl
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] op_impl_out = op_impl(self, func, *args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_impls.py", line 147, in dispatch_to_op_implementations_dict
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_impls.py", line 426, in nonzero
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] nnz = fake_mode.shape_env.create_unbacked_symint()
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] File "/data/users/pianpwk/pytorch/torch/fx/experimental/recording.py", line 262, in wrapper
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679] return retlog(fn(*args, **kwargs))
V0910 14:06:11.449000 1210439 torch/fx/experimental/symbolic_shapes.py:4679]
W0910 14:06:11.450000 1210439 torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(4*u0 < 200, hint=None, size_oblivious=False, forcing_spec=False
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(4*u0 < 200, None), **{'fx_node': False})
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] Traceback (most recent call last):
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] File "/data/users/pianpwk/pytorch/torch/fx/experimental/recording.py", line 262, in wrapper
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] return retlog(fn(*args, **kwargs))
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5238, in _evaluate_expr
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] raise self._make_data_dependent_error(
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression 4*u0 < 200 (unhinted: 4*u0 < 200). (Size-like symbols: u0)
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298]
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] Potential framework code culprit (scroll up for full backtrace):
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 1186, in forward
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] if y.numel() < 200:
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298]
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] For more information, run with TORCH_LOGS="dynamic"
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298]
E0910 14:06:11.450000 1210439 torch/fx/experimental/recording.py:298] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
inductor []
E
======================================================================
ERROR: test_crash_real_tensor (__main__.TestExport)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data/users/pianpwk/pytorch/torch/testing/_internal/common_utils.py", line 2979, in wrapper
method(*args, **kwargs)
File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 1192, in test_crash_real_tensor
ep = export(model, inputs, strict=False)
File "/data/users/pianpwk/pytorch/torch/export/__init__.py", line 273, in export
return _export(
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1017, in wrapper
raise e
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 990, in wrapper
ep = fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/exported_program.py", line 114, in wrapper
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1880, in _export
export_artifact = export_func( # type: ignore[operator]
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1683, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 637, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1611, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/data/users/pianpwk/pytorch/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
out = mod(*args[params_len:], **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1598, in forward
tree_out = self._export_root(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 1186, in forward
if y.numel() < 200:
File "/data/users/pianpwk/pytorch/torch/__init__.py", line 680, in __bool__
return self.node.bool_()
File "/data/users/pianpwk/pytorch/torch/fx/experimental/sym_node.py", line 511, in bool_
return self.guard_bool("", 0)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5238, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression 4*u0 < 200 (unhinted: 4*u0 < 200). (Size-like symbols: u0)
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/pianpwk/pytorch/test/export/test_export.py", line 1186, in forward
if y.numel() < 200:
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
To execute this test, run the following from the base repo dir:
python test/export/test_export.py TestExport.test_crash_real_tensor
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 1.258s
FAILED (errors=1)
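The report implies that, with real-tensor propagation on, the real value of `u0` should be recorded in `ShapeEnv.unbacked_var_to_val` and used when a guard such as `4*u0 < 200` has no hint (`hint=None` in the warning above). The pure-Python sketch below models only that expectation: `evaluate_guard` and `DataDependentError` are hypothetical stand-ins, not PyTorch APIs, and the recorded count `1400` is an illustrative number, since the real count depends on the random input.

```python
from typing import Callable, Dict

class DataDependentError(Exception):
    """Stand-in for GuardOnDataDependentSymNode in this sketch."""

def evaluate_guard(
    name: str,
    guard: Callable[[int], bool],
    unbacked_var_to_val: Dict[str, int],
) -> bool:
    """Evaluate a guard on an unbacked symbol using its recorded real value.

    Raises DataDependentError when no value was recorded, which is the
    situation this issue reports.
    """
    if name not in unbacked_var_to_val:
        raise DataDependentError(
            f"Could not guard on data-dependent expression involving {name}"
        )
    return guard(unbacked_var_to_val[name])

def numel_lt_200(u0: int) -> bool:
    # `y.numel() < 200` with y of shape (u0, 4)
    return 4 * u0 < 200

# Populated mapping: the guard resolves from the recorded value.
print(evaluate_guard("u0", numel_lt_200, {"u0": 1400}))  # False

# Empty mapping: the sketch fails the way the export does.
try:
    evaluate_guard("u0", numel_lt_200, {})
except DataDependentError as err:
    print("error:", err)
```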
Versions
Collecting environment information...
PyTorch version: 2.5.0a0+git6e41540
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A
OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.34
Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-0_fbk12_hardened_11583_g0bef9520ca2b-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.9.2.1
/usr/lib64/libcudnn_adv.so.9.2.1
/usr/lib64/libcudnn_cnn.so.9.2.1
/usr/lib64/libcudnn_engines_precompiled.so.9.2.1
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.2.1
/usr/lib64/libcudnn_graph.so.9.2.1
/usr/lib64/libcudnn_heuristic.so.9.2.1
/usr/lib64/libcudnn_ops.so.9.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 184
On-line CPU(s) list: 0-183
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 184
Socket(s): 1
Stepping: 1
BogoMIPS: 4792.80
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 11.5 MiB (184 instances)
L1i cache: 11.5 MiB (184 instances)
L2 cache: 92 MiB (184 instances)
L3 cache: 2.9 GiB (184 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-183
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] bert_pytorch==0.0.1a4
[pip3] executorch==0.4.0.dev20240807+cpu
[pip3] flake8==7.1.1
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] onnx==1.16.1
[pip3] onnxruntime==1.18.0
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.11.0
[pip3] pytorch-lightning==2.0.7
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.5.0a0+git2032f10
[pip3] torch_geometric==2.4.0
[pip3] torch-stoi==0.2.1
[pip3] torchao==0.3.1
[pip3] torchaudio==2.4.0a0+b3f6f51
[pip3] torchdiffeq==0.2.4
[pip3] torchmetrics==1.0.3
[pip3] torchpippy==0.2.0+17dae2c
[pip3] torchrec==0.9.0a0+5e30669
[pip3] torchsde==0.2.6
[pip3] torchsr==1.0.4
[pip3] torchtune==0.0.0
[pip3] torchvision==0.19.0a0+a8a9f2d
[pip3] torchx==0.7.0
[pip3] triton==3.0.0
[conda] bert-pytorch 0.0.1a4 dev_0
[conda] blas 1.0 mkl
[conda] executorch 0.4.0.dev20240807+cpu pypi_0 pypi
[conda] magma-cuda121 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-include 2023.1.0 h06a4308_46344
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310h5eee18b_0
[conda] mkl_random 1.2.4 py310hdb19cb5_0
[conda] numpy 1.23.5 pypi_0 pypi
[conda] open-clip-torch 2.24.0 pypi_0 pypi
[conda] optree 0.11.0 pypi_0 pypi
[conda] pytorch-lightning 2.0.7 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.24 dev_0
[conda] pytorch-triton 3.0.0+45fff310c8 pypi_0 pypi
[conda] pytorch3d 0.7.7 dev_0
[conda] torch 2.3.1 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-stoi 0.2.1 pypi_0 pypi
[conda] torchao 0.4.0 pypi_0 pypi
[conda] torchaudio 2.4.0a0+b3f6f51 dev_0
[conda] torchbench 0.1 dev_0
[conda] torchdiffeq 0.2.4 pypi_0 pypi
[conda] torchmetrics 1.0.3 pypi_0 pypi
[conda] torchpippy 0.2.0+17dae2c dev_0
[conda] torchrec 0.9.0a0+5e30669 dev_0
[conda] torchsde 0.2.6 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchtune 0.0.0 pypi_0 pypi
[conda] torchvision 0.16.2 pypi_0 pypi
[conda] torchx 0.7.0 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi
cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4