Bug Description
When I try to optimize my YOLOv5 model, the following crash occurs:
```
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x1 : Tensor,
      %x2 : Tensor,
      %x3 : Tensor,
      %x4 : Tensor):
  %5 : int = prim::Constant[value=1]() # test_insert.py:17:0
  %t1 : Tensor = aten::add(%x1, %x2, %5) # test_insert.py:17:0
  %t2 : Tensor = aten::add(%x1, %x3, %5) # test_insert.py:18:0
  %t3 : Tensor = aten::add(%x1, %x4, %5) # test_insert.py:19:0
  %t4 : Tensor = aten::add(%t1, %t2, %5) # test_insert.py:20:0
  %10 : Tensor = aten::add(%t4, %t3, %5) # test_insert.py:21:0
  return (%10)
GRAPH: [Torch-TensorRT] - Node we are looking at: %t1 : Tensor = aten::add(%x1, %x2, %5) # test_insert.py:17:0
GRAPH: [Torch-TensorRT] - Node %t1 : Tensor = aten::add(%x1, %x2, %5) # test_insert.py:17:0 outputs a tensor
GRAPH: [Torch-TensorRT] - Input to node: %x1 : Tensor, %x2 : Tensor, %x3 : Tensor, %x4 : Tensor = prim::Param()
GRAPH: [Torch-TensorRT] - Input outputs a Tensor
GRAPH: [Torch-TensorRT] - Input to node: %x1 : Tensor, %x2 : Tensor, %x3 : Tensor, %x4 : Tensor = prim::Param()
GRAPH: [Torch-TensorRT] - Input outputs a Tensor
GRAPH: [Torch-TensorRT] - Input to node: %5 : int = prim::Constant[value=1]() # test_insert.py:17:0
GRAPH: [Torch-TensorRT] - Node we are looking at: %t2 : Tensor = aten::add(%x1, %x3, %5) # test_insert.py:18:0
GRAPH: [Torch-TensorRT] - Node %t2 : Tensor = aten::add(%x1, %x3, %5) # test_insert.py:18:0 outputs a tensor
GRAPH: [Torch-TensorRT] - Input to node: %x1 : Tensor, %x2 : Tensor, %x3 : Tensor, %x4 : Tensor = prim::Param()
GRAPH: [Torch-TensorRT] - Input outputs a Tensor
GRAPH: [Torch-TensorRT] - Input to node: %x1 : Tensor, %x2 : Tensor, %x3 : Tensor, %x4 : Tensor = prim::Param()
GRAPH: [Torch-TensorRT] - Input outputs a Tensor
GRAPH: [Torch-TensorRT] - Input to node: %5 : int = prim::Constant[value=1]() # test_insert.py:17:0
Segmentation fault (core dumped)
```
I wrote a minimal example that reproduces this bug.
To Reproduce
Just run:
```python
import torch
import torch.nn as nn
import torch_tensorrt as tt


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x1, x2, x3, x4):
        t1 = x1 + x2
        t2 = x1 + x3
        t3 = x1 + x4
        t4 = t1 + t2
        t5 = t4 + t3
        return t5


a = Model().cuda().eval()
example_inputs = [torch.ones([1, 3, 20, 20]).cuda() for _ in range(4)]
b = torch.jit.trace(a, example_inputs)

compile_settings = {
    "inputs": [tt.Input(shape=[1, 3, 20, 20]) for _ in range(4)],
}
tt.logging.set_reportable_log_level(tt.logging.Level.Graph)
tt.compile(b, **compile_settings)
```
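The distinctive feature of this repro is that `%x1` fans out into three separate `aten::add` nodes. As a quick structural sketch (plain Python, not Torch-TensorRT code; whether this fan-out is what actually triggers the crash is my assumption), counting the consumers of each value in the lowered graph shows `x1` is the only multiply-consumed input:

```python
# Dataflow edges of the lowered graph: node -> the values it consumes.
graph = {
    "t1": ["x1", "x2"],
    "t2": ["x1", "x3"],
    "t3": ["x1", "x4"],
    "t4": ["t1", "t2"],
    "t5": ["t4", "t3"],
}

# Count how many nodes consume each value.
consumers = {}
for node, inputs in graph.items():
    for value in inputs:
        consumers[value] = consumers.get(value, 0) + 1

print(consumers["x1"])  # 3 -- every other value is consumed exactly once
```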
Expected behavior
The model should compile successfully instead of segfaulting.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version: 1.0.0
- PyTorch Version: 1.10.0
- CPU Architecture: Intel(R) Xeon(R) Platinum 8352Y CPU @ 2.20GHz
- OS (e.g., Linux): CentOS 7
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.6.8
- CUDA version: 11.4
- GPU models and configuration: A30
- Any other relevant information:
Additional context
I have traced this bug to torch_tensorrt::core::ir::get_value_first_calc_dtype_opt and fixed it locally. Once the bug is confirmed, I will open a PR with the fix.
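For illustration only, here is a guarded dtype search over the graph above, sketched in plain Python. This is not the Torch-TensorRT C++ source and not necessarily the real fix; the function name, the traversal direction (walking a value's users), and the idea that a guard against revisiting shared values is what is missing are all my assumptions:

```python
def first_calc_dtype(value, users, node_dtype, visited=None):
    """Hypothetical sketch: return the dtype of the first downstream
    node that computes a concrete dtype, guarding against revisits."""
    if visited is None:
        visited = set()
    if value in visited:  # guard: shared values (like x1) are visited once
        return None
    visited.add(value)
    if value in node_dtype:  # this node already has a calculated dtype
        return node_dtype[value]
    for user in users.get(value, []):
        dtype = first_calc_dtype(user, users, node_dtype, visited)
        if dtype is not None:
            return dtype
    return None


# Value -> the nodes that consume it, mirroring the lowered graph above.
users = {
    "x1": ["t1", "t2", "t3"],
    "t1": ["t4"],
    "t2": ["t4"],
    "t3": ["t5"],
    "t4": ["t5"],
}
print(first_calc_dtype("x1", users, {"t5": "float32"}))  # float32
```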