🐛 [Bug] Transformers BERT Model does not compile via FX Path #1673

Closed
@gs-olive

Bug Description

Compiling the BERT base uncased model via the FX path fails. Each of the three entry points below raises a different error:

Via torchtrt.compile(model, ir="fx",...)
Traceback (most recent call last):
  File "bert.py", line 163, in <module>
    trt_mod = torchtrt.compile(traced, ir="fx", **compile_spec)
  File "~/TensorRT/py/torch_tensorrt/_compile.py", line 142, in compile
    return torch_tensorrt.fx.compile(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 247, in <lambda>
    trace_func=lambda module, inputs: acc_tracer.trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 667, in trace
    traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 585, in rewriter_base_trace
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 309, in trace
    return super().trace(rewritten, concrete_args), rewritten
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "<eval_with_key>.1", line 9, in forward
TypeError: slice indices must be integers or None or have an __index__ method
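
For context, this `TypeError` is what Python raises whenever a slice bound is an object without an `__index__` method. One plausible reading of the trace (an assumption, not confirmed here) is that a symbolic value ends up as a slice bound on a concrete sequence during tracing. A minimal pure-Python sketch of that failure mode, where `FakeProxy` is a hypothetical stand-in and not the real `torch.fx.Proxy`:

```python
class FakeProxy:
    """Hypothetical stand-in for a symbolic value; it defines no __index__."""

    def __init__(self, name):
        self.name = name


def slice_with(bound):
    # A concrete Python sequence, loosely playing the role of a position-id buffer.
    position_ids = list(range(512))
    return position_ids[0:bound]


def try_slice(bound):
    """Return the sliced list, or the TypeError message if slicing fails."""
    try:
        return slice_with(bound)
    except TypeError as exc:
        return str(exc)


concrete = try_slice(14)                        # an int bound slices normally
symbolic = try_slice(FakeProxy("seq_length"))   # a bound without __index__ fails
print(len(concrete))
print(symbolic)
```

If this reading is right, the slice bound would need to be a concrete integer (or the slice itself would need to be traced symbolically) before line 9 of the generated `forward` can run.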

Via dynamo_model = torch._dynamo.optimize(fx2trt_compiler)(model)

Note: dynamo_model(*inputs) must be called to cause model compilation and elicit the error.

Traceback (most recent call last):
  File "~/TensorRT/py/torch_tensorrt/fx/test/tracer/dynamo_backend.py", line 96, in fx2trt_compiler
    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
  File "~/TensorRT/py/torch_tensorrt/fx/test/tracer/dynamo_backend.py", line 27, in fx2trt
    acc_model = acc_tracer.trace(model, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 681, in trace
    acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/shape_prop.py", line 185, in propagate
    return super().run(*args)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 63, in run_node
    result = self._run_node(n)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 43, in _run_node
    return super().run_node(n)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/shape_prop.py", line 152, in run_node
    raise RuntimeError(
RuntimeError: ShapeProp error for: node=%embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_ids, %self_embeddings_word_embeddings_weight), kwargs = {padding_idx: 0, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) with meta={}

Via torch_tensorrt.fx.compile(..., is_aten=True,...)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 153, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 430, in compose_bmm
    if len(real_other.meta["val"].size()) == 2:
KeyError: 'val'
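
The last frame indexes `real_other.meta["val"]` directly, so any node whose metadata has no `val` entry raises `KeyError`. The sketch below only illustrates that pattern and one defensive alternative. `FakeNode`, `FakeVal`, and `rank_or_none` are hypothetical names for illustration, not part of Torch-TensorRT or `torch.fx`, and this is not a proposed patch for `compose_bmm`:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class FakeVal:
    """Hypothetical stand-in for a tensor-like value exposing size()."""
    shape: Tuple[int, ...]

    def size(self):
        return self.shape


@dataclass
class FakeNode:
    """Hypothetical stand-in for a graph node carrying a meta dict."""
    name: str
    meta: Dict[str, Any] = field(default_factory=dict)


def rank_or_none(node: FakeNode) -> Optional[int]:
    """Return the rank recorded in node.meta['val'], or None if it is absent."""
    val = node.meta.get("val")  # .get avoids the KeyError seen in the trace
    if val is None:
        return None
    return len(val.size())


with_val = FakeNode("weight", {"val": FakeVal((768, 768))})
without_val = FakeNode("weight_no_meta")

print(rank_or_none(with_val))
print(rank_or_none(without_val))
```

Whether skipping such nodes is correct, or whether the missing `val` metadata points to an earlier pass that failed to populate it, is not determined by the trace alone.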

To Reproduce

Steps to reproduce the behavior:

  1. Initialize model: BertModel.from_pretrained("bert-base-uncased").eval().cuda()
  2. Initialize two input tensors, for example: torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
  3. (Optional) Use the transformers tools to trace the model via: transformers.utils.fx.symbolic_trace(model, input_names=["input_ids", "attention_mask"])
  4. Compile the model using FX
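
The steps above can be sketched as a single script. This is a reconstruction under assumptions, not the original `bert.py`: the contents of `compile_spec` are not given in this report, so the `inputs` and `enabled_precisions` entries below are guesses, and the function name `reproduce` is made up. It requires a CUDA machine with `torch`, `transformers`, and `torch_tensorrt` installed.

```python
def reproduce():
    """Sketch of the repro steps; heavy imports are kept local to the function."""
    import torch
    import torch_tensorrt as torchtrt
    import transformers
    from transformers import BertModel

    # Step 1: initialize the model in eval mode on the GPU.
    model = BertModel.from_pretrained("bert-base-uncased").eval().cuda()

    # Step 2: two int32 inputs of shape (1, 14). Note that randint's upper
    # bound is exclusive, so randint(0, 1, ...) yields all zeros.
    inputs = [
        torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
        for _ in range(2)
    ]

    # Step 3 (optional): trace with the transformers FX utilities.
    traced = transformers.utils.fx.symbolic_trace(
        model, input_names=["input_ids", "attention_mask"]
    )

    # Step 4: compile via the FX path.
    # ASSUMPTION: the original compile_spec is not shown in this report.
    compile_spec = {"inputs": inputs, "enabled_precisions": {torch.float}}
    return torchtrt.compile(traced, ir="fx", **compile_spec)


# Uncomment on a CUDA machine with the packages above installed:
# reproduce()
```

Calling `reproduce()` should hit the first traceback under the versions listed in Environment; the other two tracebacks come from different entry points and are not covered by this sketch.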

Expected behavior

The model should compile successfully via the FX path.

Environment

  • Transformers: 4.26.1
  • Torch-TensorRT Version (e.g. 1.0.0): a219e05
  • PyTorch Version (e.g. 1.0): 2.0.0.dev20230209+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7

Additional context

Related to Issue #1634 and PR #1648, which aim to establish 1:1 parity between the FX and TS model compatibility tests.
