Bug Description
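ConvTranspose operations (ConvTranspose2d/3d) fail to convert. Building the TensorRT engine fails with "Could not find any implementation for node ... aten::_convolution", and the debug log reports an incorrect output shape for the layer.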
To Reproduce
import torch
import torch.nn.functional as F
import torch.nn as nn
import trtorch
trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug)
chans = 32
model = nn.Sequential(nn.ConvTranspose2d(chans*2, chans, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False))
# Swapping in this Conv2d instead converts fine:
#model = nn.Sequential(nn.Conv2d(chans*2, chans, kernel_size=3, padding=1, stride=2, bias=False))
model = model.cuda().eval()
indata = torch.rand(16, chans*2, 32, 32).cuda()
# Standard PyTorch
out = model(indata)
print('std:', out.shape)
# TRTorch
traced_model = torch.jit.trace(
    model, [indata]
)
traced_model = traced_model.cuda()
trt_model = trtorch.compile(traced_model, {
    "input_shapes": [indata.shape],
    "op_precision": torch.float32,
    "max_batch_size": 1,
    "torch_fallback": {
        "enabled": False,
        "force_fallback_ops": [
            "aten::_convolution",
        ],
    }
})
out = trt_model(indata)
print('trt:', out.shape)
DEBUG: [TRTorch] - Settings requested for Lowering:
Forced Fallback Modules: [
]
DEBUG: [TRTorch] - After marking operations for torch fallback: graph(%input : Tensor):
%2 : int[] = prim::Constant[value=[1, 1]]()
%3 : int[] = prim::Constant[value=[2, 2]]()
%4 : NoneType = prim::Constant(), scope: __module.0
%5 : int = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%6 : bool = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%7 : bool = prim::Constant[value=0](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%self.0.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
return (%9)
DEBUG: [TRTorch] - Post unpack var: graph(%input : Tensor):
%2 : int[] = prim::Constant[value=[1, 1]]()
%3 : int[] = prim::Constant[value=[2, 2]]()
%4 : NoneType = prim::Constant(), scope: __module.0
%5 : int = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%6 : bool = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%7 : bool = prim::Constant[value=0](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%self.0.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
return (%9)
DEBUG: [TRTorch] - RemoveNOPs - Note: Removing operators that have no meaning in TRT
INFO: [TRTorch] - graph(%input : Tensor):
%2 : int[] = prim::Constant[value=[1, 1]]()
%3 : int[] = prim::Constant[value=[2, 2]]()
%4 : NoneType = prim::Constant(), scope: __module.0
%5 : int = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%6 : bool = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%7 : bool = prim::Constant[value=0](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
%self.0.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
return (%9)
(CompileGraph)
INFO: [TRTorch Conversion Context] - [MemUsageChange] Init CUDA: CPU +241, GPU +0, now: CPU 2247, GPU 1450 (MiB)
DEBUG: [TRTorch] - Settings requested for TensorRT engine:
Enabled Precisions: Float32
TF32 Floating Point Computation Enabled: 1
Truncate Long and Double: 0
Make Refittable Engine: 0
Debuggable Engine: 0
Strict Types: 0
GPU ID: 0
Allow GPU Fallback (if running on DLA): 0
Min Timing Iterations: 2
Avg Timing Iterations: 1
Max Workspace Size: 0
Max Batch Size: 1
Device Type: GPU
GPU ID: 0
Engine Capability: standard
Calibrator Created: 0
INFO: [TRTorch Conversion Context] - Converting Block
DEBUG: [TRTorch Conversion Context] -
INFO: [TRTorch Conversion Context] - Adding Input input (named: input_0): Input(shape: [16, 64, 32, 32], dtype: Float32, format: NCHW\Contiguous\Linear) in engine (conversion.AddInputs)
DEBUG: [TRTorch Conversion Context] - Evaluating %2 : int[] = prim::Constant[value=[1, 1]]()
DEBUG: [TRTorch Conversion Context] - Found the value to be: [1, 1]
DEBUG: [TRTorch Conversion Context] - Evaluating %3 : int[] = prim::Constant[value=[2, 2]]()
DEBUG: [TRTorch Conversion Context] - Found the value to be: [2, 2]
DEBUG: [TRTorch Conversion Context] - Evaluating %4 : NoneType = prim::Constant(), scope: __module.0
DEBUG: [TRTorch Conversion Context] - Found the value to be: None
DEBUG: [TRTorch Conversion Context] - Evaluating %5 : int = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
DEBUG: [TRTorch Conversion Context] - Found the value to be: 1
DEBUG: [TRTorch Conversion Context] - Evaluating %6 : bool = prim::Constant[value=1](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
DEBUG: [TRTorch Conversion Context] - Found the value to be: True
DEBUG: [TRTorch Conversion Context] - Evaluating %7 : bool = prim::Constant[value=0](), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0
DEBUG: [TRTorch Conversion Context] - Found the value to be: False
DEBUG: [TRTorch Conversion Context] - Evaluating %self.0.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
DEBUG: [TRTorch Conversion Context] - Found the value to be a tensor (shape [64, 32, 3, 3])
INFO: [TRTorch Conversion Context] - Adding Layer %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (ctx.AddLayer)
DEBUG: [TRTorch Conversion Context] - Node input is an already converted tensor
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [TRTorch] - Weights: [64, 32, 3, 3]
Data Type: Float32
Number of input maps: 32
Number of output maps: 64
Element shape: [3,3]
DEBUG: [TRTorch] - Input dims: [16, 64, 32, 32]
DEBUG: [TRTorch] - Weights: Weights: [64, 32, 3, 3]
Data Type: Float32
Number of input maps: 32
Number of output maps: 64
Element shape: [3,3]
DEBUG: [TRTorch] - stride: [2, 2]
DEBUG: [TRTorch] - padding: [1, 1]
DEBUG: [TRTorch] - dilation: [1, 1]
DEBUG: [TRTorch] - out_padding: [1, 1]
DEBUG: [TRTorch] - groups: 1
DEBUG: [TRTorch] - Output tensor shape: [16, 32, 63, 63]
INFO: [TRTorch Conversion Context] - Marking Output 9 named output_0 in engine (ctx.MarkOutput)
INFO: [TRTorch Conversion Context] - [MemUsageSnapshot] Builder begin: CPU 2247 MiB, GPU 1450 MiB
DEBUG: [TRTorch Conversion Context] - Applying generic optimizations to the graph for inference.
DEBUG: [TRTorch Conversion Context] - Original: 1 layers
DEBUG: [TRTorch Conversion Context] - After dead-layer removal: 1 layers
DEBUG: [TRTorch Conversion Context] - After Myelin optimization: 1 layers
DEBUG: [TRTorch Conversion Context] - After scale fusion: 1 layers
DEBUG: [TRTorch Conversion Context] - After vertical fusions: 1 layers
DEBUG: [TRTorch Conversion Context] - After dupe layer removal: 1 layers
DEBUG: [TRTorch Conversion Context] - After final dead-layer removal: 1 layers
DEBUG: [TRTorch Conversion Context] - After tensor merging: 1 layers
DEBUG: [TRTorch Conversion Context] - After concat removal: 1 layers
DEBUG: [TRTorch Conversion Context] - Graph construction and optimization completed in 0.00277048 seconds.
DEBUG: [TRTorch Conversion Context] - Using cublasLt a tactic source
INFO: [TRTorch Conversion Context] - [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +126, GPU +156, now: CPU 2373, GPU 1606 (MiB)
DEBUG: [TRTorch Conversion Context] - Using cuDNN as a tactic source
INFO: [TRTorch Conversion Context] - [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2373, GPU 1614 (MiB)
WARNING: [TRTorch Conversion Context] - Detected invalid timing cache, setup a local cache instead
DEBUG: [TRTorch Conversion Context] - Constructing optimization profile number 0 [1/1].
DEBUG: [TRTorch Conversion Context] - *************** Autotuning Reformat:Float(65536,1024,32,1) -> Float(65536,1,2048,64) ***************
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: Optimizer Reformat (Reformat)
DEBUG: [TRTorch Conversion Context] - Tactic: 1002 Time: 0.019456
DEBUG: [TRTorch Conversion Context] - Tactic: 0 Time: 0.021504
DEBUG: [TRTorch Conversion Context] - Fastest Tactic: 1002 Time: 0.019456
DEBUG: [TRTorch Conversion Context] - *************** Autotuning format combination: Float(65536,1024,32,1) -> Float(127008,3969,63,1) ***************
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (CudnnDeconvolution)
DEBUG: [TRTorch Conversion Context] - CudnnDeconvolution has no valid tactics for this config, skipping
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (GemmDeconvolution)
DEBUG: [TRTorch Conversion Context] - Tactic: 0 skipped. Scratch requested: 25165824, available: 0
INFO: [TRTorch Conversion Context] - Some tactics do not have sufficient workspace memory to run. Increasing workspace size may increase performance, please check verbose output.
DEBUG: [TRTorch Conversion Context] - Fastest Tactic: -3360065831133338131 Time: 3.40282e+38
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (CaskDeconvolution)
DEBUG: [TRTorch Conversion Context] - CaskDeconvolution has no valid tactics for this config, skipping
DEBUG: [TRTorch Conversion Context] - *************** Autotuning format combination: Float(65536,1,2048,64) -> Float(127008,1,2016,32) ***************
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (CudnnDeconvolution)
DEBUG: [TRTorch Conversion Context] - CudnnDeconvolution has no valid tactics for this config, skipping
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (GemmDeconvolution)
DEBUG: [TRTorch Conversion Context] - GemmDeconvolution has no valid tactics for this config, skipping
DEBUG: [TRTorch Conversion Context] - --------------- Timing Runner: %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0 (CaskDeconvolution)
DEBUG: [TRTorch Conversion Context] - CaskDeconvolution has no valid tactics for this config, skipping
DEBUG: [TRTorch Conversion Context] - Deleting timing cache: 1 entries, 0 hits
INFO: [TRTorch Conversion Context] - [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 2378, GPU 1598 (MiB)
ERROR: [TRTorch Conversion Context] - 10: [optimizer.cpp::computeCosts::1855] Error Code 10: Internal Error (Could not find any implementation for node %9 : Tensor = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6), scope: __module.0 # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:919:0.)
ERROR: [TRTorch Conversion Context] - 2: [builder.cpp::buildSerializedNetwork::417] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed.)
Traceback (most recent call last):
  File "convt_test.py", line 25, in <module>
    trt_model = trtorch.compile(traced_model, {
  File "/opt/conda/lib/python3.8/site-packages/trtorch/_compiler.py", line 81, in compile
    compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_compile_spec(compile_spec))
RuntimeError: [Error thrown at core/conversion/conversionctx/ConversionCtx.cpp:160] Building serialized network failed in TensorRT
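For reference, the positional operands of the failing node can be lined up against the `aten::_convolution` schema. The sketch below assumes the 13-argument schema used by PyTorch 1.10 (`input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32`). The constant values are copied from the graph dump above, and `decode_args` is a hypothetical helper. Under that assumption, the node is a transposed convolution (`transposed=True`) with `stride=[2, 2]`, `padding=[1, 1]` and `output_padding=[1, 1]`, which agrees with the `out_padding: [1, 1]` line in the debug log.

```python
# Schema of aten::_convolution as assumed for PyTorch 1.10 (13 arguments).
SCHEMA = [
    "input", "weight", "bias", "stride", "padding", "dilation",
    "transposed", "output_padding", "groups",
    "benchmark", "deterministic", "cudnn_enabled", "allow_tf32",
]

# prim::Constant values copied from the TorchScript graph dump in the log.
CONSTANTS = {
    "%2": [1, 1],
    "%3": [2, 2],
    "%4": None,
    "%5": 1,
    "%6": True,
    "%7": False,
}

# Operands of: %9 = aten::_convolution(%input, %self.0.weight, %4, %3, %2, %2, %6, %2, %5, %7, %7, %6, %6)
OPERANDS = ["%input", "%self.0.weight", "%4", "%3", "%2", "%2",
            "%6", "%2", "%5", "%7", "%7", "%6", "%6"]


def decode_args(operands, constants, schema):
    """Map each operand to its schema name, resolving graph constants (hypothetical helper)."""
    if len(operands) != len(schema):
        raise ValueError(f"expected {len(schema)} operands, got {len(operands)}")
    # Tensor operands (input, weight) are not in CONSTANTS, so their SSA names are kept.
    return {name: constants.get(op, op) for name, op in zip(schema, operands)}


args = decode_args(OPERANDS, CONSTANTS, SCHEMA)
for name, value in args.items():
    print(f"{name:15} = {value}")
```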
Expected behavior
ConvTranspose2d/3d layers should convert without errors. With the code above, the expected output shape is [16, 32, 64, 64], as produced by standard PyTorch, not the [16, 32, 63, 63] reported in the debug log.
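The expected shape follows from the output-size formula in the PyTorch `ConvTranspose2d` documentation. The sketch below evaluates that formula for the repro's settings with and without `output_padding` (`convtranspose_out_size` and `nchw_strides` are hypothetical helper names). Dropping the `output_padding` term yields exactly the 63 reported in the debug log. The contiguous NCHW strides of the 63×63 shape, `(127008, 3969, 63, 1)`, also match the output format in the log's autotuning lines, which suggests the engine was planned for the smaller size; this is an inference from the log, not a confirmed root cause.

```python
from math import prod


def convtranspose_out_size(h_in, kernel_size, stride=1, padding=0,
                           output_padding=0, dilation=1):
    """Output length along one spatial dim, per PyTorch's ConvTranspose formula."""
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def nchw_strides(shape):
    """Element strides of a contiguous (row-major) tensor of the given shape."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))


# Repro settings: 32x32 input, kernel 3, stride 2, padding 1, output_padding 1.
expected = convtranspose_out_size(32, 3, stride=2, padding=1, output_padding=1)
without_op = convtranspose_out_size(32, 3, stride=2, padding=1, output_padding=0)

print(expected, without_op)              # 64 63
print(nchw_strides((16, 64, 32, 32)))    # input:  (65536, 1024, 32, 1)
print(nchw_strides((16, 32, 64, 64)))    # expected output: (131072, 4096, 64, 1)
print(nchw_strides((16, 32, 63, 63)))    # shape in the log: (127008, 3969, 63, 1)
```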
Environment
- TRTorch Version: 0.5.0
- PyTorch Version: 1.10.0
- CPU Architecture: x86_64
- OS: Ubuntu 20.04.2 LTS
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used: TRTorch/docker/Dockerfile.21.07
- Python version: 3.8.10
- CUDA version: 11.4
Additional context
Conv2d/3d convert fine; the issue is specific to ConvTranspose.
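One way to see why `output_padding` only matters for the transposed case: with `stride=2`, the forward `Conv2d` output-size formula from the PyTorch documentation maps two different input sizes to the same output size, so a transposed layer needs `output_padding` to choose which one to produce. The sketch below applies that formula to the commented-out `Conv2d` in the repro (`conv_out_size` is a hypothetical helper). It only illustrates the shape arithmetic and does not show where the converter drops the term.

```python
def conv_out_size(h_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dim, per PyTorch's Conv2d formula."""
    return (h_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# The commented-out Conv2d in the repro: kernel 3, stride 2, padding 1.
for h_in in (31, 32):
    print(h_in, "->", conv_out_size(h_in, 3, stride=2, padding=1))
# 31 -> 16
# 32 -> 16
```

Going the other way, a `ConvTranspose2d` with the same kernel, stride and padding maps 16 back to 31 without `output_padding` and to 32 with `output_padding=1`, so that term is what selects the original size.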