Bug Description
When compiling the BERT base uncased model (bert-base-uncased) via the FX path, one of the following errors is encountered, depending on the entry point used:
Via torchtrt.compile(model, ir="fx",...)
Traceback (most recent call last):
  File "bert.py", line 163, in <module>
    trt_mod = torchtrt.compile(traced, ir="fx", **compile_spec)
  File "~/TensorRT/py/torch_tensorrt/_compile.py", line 142, in compile
    return torch_tensorrt.fx.compile(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 247, in <lambda>
    trace_func=lambda module, inputs: acc_tracer.trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 667, in trace
    traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 585, in rewriter_base_trace
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 309, in trace
    return super().trace(rewritten, concrete_args), rewritten
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "<eval_with_key>.1", line 9, in forward
TypeError: slice indices must be integers or None or have an __index__ method
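This TypeError is raised by Python itself when a slice bound is an object without an __index__ method. One plausible (unconfirmed) trigger during FX symbolic tracing is slicing a concrete Python sequence with a traced size, because a torch.fx.Proxy cannot be converted to an int. The sketch below reproduces that mechanism on a hypothetical toy module (SliceBySeqLen and its attributes are illustrative names); it does not establish the root cause inside BERT.

```python
import torch
import torch.fx


class SliceBySeqLen(torch.nn.Module):
    """Hypothetical toy module: slices a plain Python list by a tensor dimension."""

    def __init__(self):
        super().__init__()
        self.labels = ["a", "b", "c", "d"]  # concrete Python list, not a tensor

    def forward(self, x):
        seq_len = x.size(1)  # an int in eager mode, a torch.fx.Proxy while tracing
        kept = self.labels[:seq_len]  # list slicing needs a real int (or __index__)
        return x[:, : len(kept)]


module = SliceBySeqLen()
x = torch.zeros(1, 3)
eager_out = module(x)  # runs fine eagerly because seq_len is a plain int

try:
    torch.fx.symbolic_trace(module)
    trace_error = None
except TypeError as err:
    trace_error = str(err)

print(trace_error)
```

Whether line 9 of the failing forward does something similar could be checked by printing the generated code of the transformers-traced module (GraphModule exposes it as its .code attribute).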
Via dynamo_model = torch._dynamo.optimize(fx2trt_compiler)(model)
Note: dynamo_model(*inputs) must be called to trigger compilation and surface the error, since Dynamo compiles the model lazily on its first call.
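Because Dynamo captures and compiles a graph only when the optimized module is first called, wrapping the model alone never reaches the backend. The sketch below illustrates this with a hypothetical recording backend standing in for fx2trt_compiler, passed through the public torch.compile entry point, which accepts the same kind of backend callable (graph module plus example inputs, returning a callable).

```python
from typing import List

import torch
import torch.fx

compiled_graphs: List[torch.fx.GraphModule] = []


def recording_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Hypothetical stand-in for fx2trt_compiler: record the graph, run it unchanged."""
    compiled_graphs.append(gm)
    return gm.forward  # a Dynamo backend must return a callable


model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU()).eval()
dynamo_model = torch.compile(model, backend=recording_backend)

graphs_before_call = len(compiled_graphs)  # nothing has been compiled yet
x = torch.randn(3, 4)
with torch.no_grad():
    out = dynamo_model(x)  # the first call triggers graph capture and the backend
graphs_after_call = len(compiled_graphs)

print(graphs_before_call, graphs_after_call)
```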
Traceback (most recent call last):
  File "~/TensorRT/py/torch_tensorrt/fx/test/tracer/dynamo_backend.py", line 96, in fx2trt_compiler
    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
  File "~/TensorRT/py/torch_tensorrt/fx/test/tracer/dynamo_backend.py", line 27, in fx2trt
    acc_model = acc_tracer.trace(model, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 681, in trace
    acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/shape_prop.py", line 185, in propagate
    return super().run(*args)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 63, in run_node
    result = self._run_node(n)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 43, in _run_node
    return super().run_node(n)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/shape_prop.py", line 152, in run_node
    raise RuntimeError(
RuntimeError: ShapeProp error for: node=%embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_ids, %self_embeddings_word_embeddings_weight), kwargs = {padding_idx: 0, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) with meta={}
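ShapeProp re-raises failures under its own RuntimeError, so the message above does not say what torch.nn.functional.embedding actually raised. A minimal sketch of recovering the chained exception, using a hypothetical toy module whose indices are deliberately out of range; it illustrates the debugging step only and does not claim that out-of-range ids are the cause for BERT.

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class ToyEmbedding(torch.nn.Module):
    """Hypothetical module mirroring the failing node: a functional embedding lookup."""

    def __init__(self, num_embeddings: int = 8, dim: int = 4):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_embeddings, dim))

    def forward(self, input_ids):
        return torch.nn.functional.embedding(input_ids, self.weight, padding_idx=0)


gm = torch.fx.symbolic_trace(ToyEmbedding())
bad_ids = torch.tensor([[1, 2, 99]])  # 99 is out of range for 8 embeddings

try:
    ShapeProp(gm).propagate(bad_ids)
    root_cause = None
except RuntimeError as wrapper:
    # ShapeProp wraps the original error; it is kept on the exception chain.
    root_cause = wrapper.__cause__ or wrapper.__context__

print(type(root_cause).__name__, root_cause)
```

Wrapping the acc_tracer.trace(...) call in dynamo_backend.py with the same except clause may expose the underlying error for the BERT inputs.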
Via torch_tensorrt.fx.compile(..., is_aten=True,...)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 153, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 430, in compose_bmm
    if len(real_other.meta["val"].size()) == 2:
KeyError: 'val'
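The KeyError means compose_bmm read node.meta["val"] on a node whose metadata was never populated. For comparison, torch.fx.symbolic_trace leaves node.meta without "val"; that key is typically written by proxy-tensor based tracers such as make_fx. The sketch below is a hypothetical defensive lookup, not the Torch-TensorRT pass: it prefers meta["val"], falls back to the "tensor_meta" entry written by torch.fx's ShapeProp, and returns None when neither exists.

```python
from typing import Optional

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


def tensor_rank(node: torch.fx.Node) -> Optional[int]:
    """Hypothetical helper: rank of a node's output, or None if no shape metadata exists."""
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        return val.dim()
    tensor_meta = node.meta.get("tensor_meta")  # written by torch.fx ShapeProp
    if tensor_meta is not None and hasattr(tensor_meta, "shape"):
        return len(tensor_meta.shape)
    return None


class ToyBmm(torch.nn.Module):
    def forward(self, a, b):
        return torch.bmm(a, b)


gm = torch.fx.symbolic_trace(ToyBmm())
bmm_node = next(n for n in gm.graph.nodes if n.op == "call_function")

rank_before = tensor_rank(bmm_node)  # symbolic_trace records no shape metadata
ShapeProp(gm).propagate(torch.randn(2, 3, 4), torch.randn(2, 4, 5))
rank_after = tensor_rank(bmm_node)
print(rank_before, rank_after)
```

Whether the pass should skip such nodes or the tracer should guarantee meta["val"] on every node is a separate design decision; the sketch only shows the guard.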
To Reproduce
Steps to reproduce the behavior:
- Initialize the model: BertModel.from_pretrained("bert-base-uncased").eval().cuda()
- Initialize two input tensors (input_ids and attention_mask), for example: torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
- (Optional) Trace the model with the transformers FX tooling: transformers.utils.fx.symbolic_trace(model, input_names=["input_ids", "attention_mask"])
- Compile the model via any of the FX entry points listed above
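The reproduction steps above can be sketched as a script. It assumes transformers, torch_tensorrt and a CUDA device are available; make_inputs is a hypothetical helper, and the keyword arguments passed to torchtrt.compile are an assumption, since the compile_spec used in bert.py is not shown.

```python
from typing import Dict

import torch


def make_inputs(batch: int = 1, seq_len: int = 14, device: str = "cuda") -> Dict[str, torch.Tensor]:
    """Hypothetical helper building the two inputs described in the steps above."""
    # randint(0, 1, ...) yields all zeros, matching the example in the report.
    shape = (batch, seq_len)
    return {
        "input_ids": torch.randint(0, 1, shape, dtype=torch.int32).to(device),
        "attention_mask": torch.randint(0, 1, shape, dtype=torch.int32).to(device),
    }


def reproduce(device: str = "cuda"):
    # Imported lazily so this file can be loaded without these packages installed.
    import torch_tensorrt as torchtrt
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace

    model = BertModel.from_pretrained("bert-base-uncased").eval().to(device)
    inputs = make_inputs(device=device)
    # Optional step: trace with the transformers FX tooling first.
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
    # Assumed compile arguments; the report's compile_spec is not shown.
    return torchtrt.compile(
        traced, ir="fx", inputs=[inputs["input_ids"], inputs["attention_mask"]]
    )
```

Per this report, calling reproduce() in the environment listed below is expected to fail with the first traceback above.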
Expected behavior
The model should compile successfully via the FX path.
Environment
- Transformers: 4.26.1
- Torch-TensorRT Version: a219e05
- PyTorch Version: 2.0.0.dev20230209+cu117
- CPU Architecture: Intel Xeon CPU
- OS: Ubuntu 20.04
- How you installed PyTorch: pip
- Build command you used: python setup.py develop
- Are you using local sources or building from archives: local
- Python version: 3.8.13
- CUDA version: 11.7
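Most of the fields above can be gathered programmatically. A small sketch, assuming only that torch is installed; transformers and torch_tensorrt are reported as None when they are missing or fail to import. PyTorch also ships a fuller report via python -m torch.utils.collect_env.

```python
import importlib
import platform
from typing import Dict, Optional

import torch


def package_version(name: str) -> Optional[str]:
    """Return a package's __version__, or None if it is missing or fails to import."""
    try:
        module = importlib.import_module(name)
    except Exception:  # not installed, or a broken install
        return None
    return getattr(module, "__version__", None)


def collect_environment() -> Dict[str, Optional[str]]:
    return {
        "Transformers": package_version("transformers"),
        "Torch-TensorRT Version": package_version("torch_tensorrt"),
        "PyTorch Version": torch.__version__,
        "CPU Architecture": platform.machine(),
        "OS": platform.platform(),
        "Python version": platform.python_version(),
        "CUDA version": torch.version.cuda,  # None for CPU-only builds
    }


for field, value in collect_environment().items():
    print(f"- {field}: {value}")
```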
Additional context
Relevant to Issue #1634 and PR #1648, which aim to achieve 1:1 parity between the FX and TorchScript (TS) model-compatibility tests.