🐛 [Bug] Compilation Error on GPT-2 #1455

Closed
@gs-olive

Bug Description

When converting the GPT-2 network (https://huggingface.co/gpt2) from TorchScript to Torch-TRT, the following error is encountered:

compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Unsupported input data type unsigned char

To Reproduce

Steps to reproduce the behavior:

  1. Run torch_tensorrt.compile with the GPT-2 model as input, using fp32 precision.
  2. Choose a fixed input size of [1, 128] and enable truncate_long_and_double with a 12 GB workspace.
  3. Pass in model keyword args to disable attention and hidden state outputs (see the sketch below).
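A minimal reproduction sketch along these lines is shown below; the use of GPT2LMHeadModel from HuggingFace transformers, the specific from_pretrained keyword arguments, and the random token input are assumptions made for illustration:

import torch
import torch_tensorrt
from transformers import GPT2LMHeadModel

# Disable attention/hidden-state outputs and caching so tracing returns plain tensors
model = GPT2LMHeadModel.from_pretrained(
    "gpt2",
    torchscript=True,
    use_cache=False,
    output_attentions=False,
    output_hidden_states=False,
).eval()

# Fixed-size token input of shape [1, 128]
input_ids = torch.randint(0, 50257, (1, 128), dtype=torch.int64)
traced = torch.jit.trace(model, input_ids)

trt_model = torch_tensorrt.compile(
    traced,
    inputs=[torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32)],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    workspace_size=12 << 30,  # 12 GB
)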

Expected behavior

The model should compile successfully to Torch-TRT. In particular, internal (non-user-provided) type-casting issues should not cause errors.

Environment

  • Torch-TensorRT Version: 1.3.0a0+e3b99294
  • PyTorch Version: 1.13.0.dev20220921+cu116
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.6

Additional context

The problematic data in GPT-2 seems to be this bias term, instantiated in the attention module, which has type uint8. In both the TorchScript IR and the model code (example 1, example 2), it seems that this bias term is generally cast to a bool. The error is thrown in this code segment:

c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
  TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
}
The conversion of a uint8 type to a TRT DataType fails; however, simply patching this conversion does not fix the issue either, as an out-of-bounds error follows later.
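For reference, the buffer can be inspected directly. This is a minimal sketch assuming the HuggingFace transformers GPT2Model; the dtype of this buffer depends on the transformers release:

from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")
# Causal-mask buffer registered in the attention module; uint8 in the
# transformers version used here (later releases register it as bool)
bias = model.h[0].attn.bias
print(bias.dtype, bias.shape)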

Temporary Solution

A temporary fix to this problem is to add the following to the compilation arguments in torch_tensorrt.compile:

torch_tensorrt.compile(..., torch_executed_ops=["aten::where"], ...)

This workaround works because it happens to exclude the code which uses and processes the uint8 tensor; however, it is only a temporary fix and does not resolve the underlying issue.
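Applied to the reproduction sketch above, the workaround would look like the following (the argument values are the assumed ones from the reproduction steps):

trt_model = torch_tensorrt.compile(
    traced,
    inputs=[torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32)],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    workspace_size=12 << 30,
    # Run aten::where in PyTorch so the uint8 mask never reaches TensorRT
    torch_executed_ops=["aten::where"],
)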

Steps to a Solution

  • Fix mismatched dimension issue in aten::where
  • Make at::kByte a valid input type
