Skip to content

pybind11_object_dealloc error #239

@allenling

Description

@allenling

Description

export onnx file from torch, then build trt engine from onnx file in a for loop

got error

terminate called after throwing an instance of 'std::runtime_error'
  what():  pybind11_object_dealloc(): Tried to deallocate unregistered instance!

Environment

TensorRT Version: 6.0.1.5
GPU Type: RTX2080
Nvidia Driver Version:
CUDA Version: 10.1
CUDNN Version: 7.6.5
Operating System + Version: Ubuntu16.04
Python Version (if applicable): 3.5.2
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.3
ONNX Version: 1.6
Baremetal or Container (if container which image + tag):

Relevant Files

test a simple net

import tensorrt as trt
import torch


class TestNet(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.conv = torch.nn.Conv1d(2, 4, kernel_size=(1,))
        return
    def forward(self, x):
        res = self.conv(x)
        return res


def export_onnx():
    onnx_files = []
    test_net = TestNet()
    test_net.cuda()
    it = torch.randn(1, 2, 3).to("cuda:0")
    for i in range(12):
        fname = "trt_pybind_%s.onnx" % i
        torch.onnx.export(test_net, it, fname, export_params=True, input_names=["data"],
                          keep_initializers_as_inputs=True)
        onnx_files.append(fname)
    return onnx_files


def onnx_to_trt_loop_without_with(onnx_files):
    for onnx_file_path in onnx_files:
        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(TRT_LOGGER)
        network = builder.create_network()
        parser = trt.OnnxParser(network, TRT_LOGGER)
        with open(onnx_file_path, 'rb') as onnx_file:
            suc = parser.parse(onnx_file.read())
            print("trt parse %s %s" % (onnx_file_path, suc))
            if suc is False:
                return
            builder.max_workspace_size = 1<<30
            builder.fp16_mode = False
            builder.max_batch_size = 1
            builder.strict_type_constraints = False
            engine = builder.build_cuda_engine(network)
            print("binding(0)", engine.get_binding_name(0))
    return


def onnx_to_trt_loop(onnx_files):
    for onnx_file_path in onnx_files:
        with trt.Logger(trt.Logger.INFO) as TRT_LOGGER, trt.Builder(TRT_LOGGER) as builder,\
        builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
            with open(onnx_file_path, 'rb') as onnx_file:
                suc = parser.parse(onnx_file.read())
                print("trt parse %s %s" % (onnx_file_path, suc))
                if suc is False:
                    return
            builder.max_workspace_size = 1<<30
            builder.fp16_mode = False
            builder.max_batch_size = 1
            builder.strict_type_constraints = False
            with builder.build_cuda_engine(network) as engine:
                print("binding(0)", engine.get_binding_name(0))
    return


def onnx_to_trt_with_single_file(onnx_file_path):
    with trt.Logger(trt.Logger.INFO) as TRT_LOGGER, trt.Builder(TRT_LOGGER) as builder,\
    builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        with open(onnx_file_path, 'rb') as onnx_file:
            suc = parser.parse(onnx_file.read())
            print("trt parse %s %s" % (onnx_file_path, suc))
            if suc is False:
                return
        builder.max_workspace_size = 1<<30
        builder.fp16_mode = False
        builder.max_batch_size = 1
        builder.strict_type_constraints = False
        with builder.build_cuda_engine(network) as engine:
            print("binding(0)", engine.get_binding_name(0))
    return


def build_engine_outside_loop():
    onnx_files = export_onnx()
    for onnx_file_path in onnx_files:
        onnx_to_trt_with_single_file(onnx_file_path)
    return


def build_engine_loop():
    onnx_files = export_onnx()
    onnx_to_trt_loop(onnx_files)
    # onnx_to_trt_loop_without_with(onnx_files)
    return

def main():
    build_engine_loop()
    # build_engine_outside_loop()
    print("main done")
    return
    
    
if __name__ == "__main__":
    main()

Steps To Reproduce

onnx_to_trt_loop build many engines in a for loop, but got the error

onnx_to_trt_loop_without_with, which build many engines in a for loop, would be fine

the deference between onnx_to_trt_loop and onnx_to_trt_loop_without_with is latter remove with statement

further more, build_engine_outside_loop would be fine, which use with statement but move for loop outside

Metadata

Metadata

Assignees

No one assigned

    Labels

    Module:ONNXIssues relating to ONNX usage and import

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions