
🐛 [Bug] Crash or throw error with multiple inputs #475

Closed
@seed93

Description

Bug Description

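trtorch.compile either crashes with a segmentation fault, or the compiled module throws an error when called, if the model takes multiple inputs that are packed into a List[Tuple[Tensor, Tensor]] inside forward.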

To Reproduce

import torch
import torch.nn as nn
import trtorch
from typing import Tuple, List
from torch import Tensor

class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()
    def forward(self, inputs: List[Tuple[Tensor, Tensor]]):
        return inputs[0]

class ExportModel(nn.Module):
    def __init__(self, model):
        super(ExportModel, self).__init__()
        self.model = model
    def forward(self, t1, t2):
        outputs = self.model([(t1,t2)])
        return outputs[1]

export_model = ExportModel(Model1())
torch_script_module = torch.jit.script(export_model) 

input_list = [torch.tensor([[1,2,3]]), torch.tensor([[1,2,3]])]

result = torch_script_module(*input_list)

# Collect the static shape of each input for TRTorch
input_shape = [t.shape for t in input_list]

compile_settings = {
    "input_shapes": input_shape,
    "op_precision": torch.float32,
    "torch_fallback" : {
      "enabled" : True,
      "min_block_size" : 1
    }
}
print(compile_settings)
trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
result = trt_ts_module(*input_list)
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")

The code above throws the following error:

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

RuntimeError: forward() expected at most 2 argument(s) but received 3 argument(s). Declaration: forward.forward(Tensor t1.1, Tensor t2.1) -> (Tensor t2.1)

If I change torch_fallback["enabled"] to False, trtorch.compile itself crashes with a segmentation fault.
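For comparison, the same model runs fine as plain TorchScript. The sketch below uses only PyTorch (no TRTorch, no GPU) and mirrors the repro's Model1/ExportModel; printing `inlined_graph` is just one way to see that the two tensor inputs are packed into a tuple inside a list before the submodule is called, which is the structure TRTorch has to handle here.

```python
import torch
import torch.nn as nn
from typing import List, Tuple
from torch import Tensor


class Model1(nn.Module):
    def forward(self, inputs: List[Tuple[Tensor, Tensor]]):
        return inputs[0]


class ExportModel(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, t1: Tensor, t2: Tensor):
        # The two inputs are packed into List[Tuple[Tensor, Tensor]] here.
        outputs = self.model([(t1, t2)])
        return outputs[1]


scripted = torch.jit.script(ExportModel(Model1()))
t1 = torch.tensor([[1, 2, 3]])
t2 = torch.tensor([[4, 5, 6]])

# Plain TorchScript execution works and returns the second input.
out = scripted(t1, t2)
print(out)

# The inlined graph shows the tuple/list construction around the inputs.
graph_text = str(scripted.inlined_graph)
print(graph_text)
```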

Expected behavior

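trtorch.compile should produce a working module that returns the same result as the TorchScript module for the same inputs.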
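This expectation can be phrased as a simple check. A minimal sketch follows; `outputs_match` is a hypothetical helper, not part of TRTorch. In the repro, `reference` would be the TorchScript module and `candidate` the TRTorch-compiled one; the usage below compares an eager module with its scripted version only so that it runs without TensorRT.

```python
import torch
from typing import Callable, Sequence


def outputs_match(reference: Callable[..., torch.Tensor],
                  candidate: Callable[..., torch.Tensor],
                  inputs: Sequence[torch.Tensor],
                  atol: float = 1e-5) -> bool:
    """Hypothetical helper: True if candidate(*inputs) matches reference(*inputs)."""
    with torch.no_grad():
        expected = reference(*inputs).cpu()
        actual = candidate(*inputs).cpu()
    if expected.shape != actual.shape:
        return False
    # Cast to float so integer outputs (as in the repro) also work with allclose.
    return torch.allclose(expected.float(), actual.float(), atol=atol)


# Usage without TRTorch: compare an eager module with its TorchScript version.
model = torch.nn.Linear(3, 2)
scripted = torch.jit.script(model)
x = torch.randn(1, 3)
print(outputs_match(model, scripted, [x]))
```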

Environment

TRTorch release v0.3.0

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
