Closed
Description
Bug Description
TRTorch crashes or throws an error when the model has multiple inputs.
To Reproduce
from typing import Tuple, Dict, List

import torch
import torch.nn as nn
from torch import Tensor
import trtorch
class Model1(nn.Module):
    """Toy submodule whose forward() accepts a list of tensor pairs.

    It holds no parameters; it simply hands back the first pair it is given.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]:
        """Return ``inputs[0]``, the first ``(Tensor, Tensor)`` pair of the list."""
        first_pair = inputs[0]
        return first_pair
class ExportModel(nn.Module):
    """Adapter that lets a list-of-pairs model be called with two plain tensors."""

    def __init__(self, model):
        super().__init__()
        # The wrapped model; it is called with a one-element list of (t1, t2) pairs.
        self.model = model

    def forward(self, t1, t2):
        """Pack ``t1`` and ``t2`` into ``[(t1, t2)]``, run the wrapped model,
        and return element 1 of whatever it produces."""
        pairs = [(t1, t2)]
        result = self.model(pairs)
        return result[1]
# Reproduction: script a two-input wrapper module, then compile it with TRTorch.
export_model = ExportModel(Model1())
torch_script_module = torch.jit.script(export_model)

# NOTE(review): torch.tensor() on Python ints yields int64 tensors, while
# op_precision below is float32 -- confirm whether the integer dtype matters
# for this failure.
input_list = [torch.tensor([[1, 2, 3]]), torch.tensor([[1, 2, 3]])]

# Sanity check: the scripted module runs in plain TorchScript before compilation.
result = torch_script_module(*input_list)

# One static shape per forward() argument, in argument order.
# (Loop variable renamed so it no longer shadows the builtin `input`.)
input_shape = [tensor.shape for tensor in input_list]

compile_settings = {
    "input_shapes": input_shape,
    "op_precision": torch.float32,
    "torch_fallback": {
        "enabled": True,
        "min_block_size": 1,
    },
}
print(compile_settings)

trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
result = trt_ts_module(*input_list)
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
Running the code above throws the following error:
~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
RuntimeError: forward() expected at most 2 argument(s) but received 3 argument(s). Declaration: forward.forward(Tensor t1.1, Tensor t2.1) -> (Tensor t2.1)
If I change `torch_fallback` → `enabled` to `False`, `trtorch.compile` crashes with a segmentation fault.
Expected behavior
The model is converted successfully.
Environment
TRTorch release v0.3.0