diff --git a/core/ir/GraphInputs.cpp b/core/ir/GraphInputs.cpp index 645624f2f1..792189137a 100644 --- a/core/ir/GraphInputs.cpp +++ b/core/ir/GraphInputs.cpp @@ -68,6 +68,7 @@ GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) { inputs = flattened_inputs; input_signature = input_signature_; collection_inputs = collection_inputs_; + LOG_DEBUG("Collection Input Size: " << collection_inputs_.size()); } } // namespace ir diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 3583fe7f10..4b3ff65ff0 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -500,7 +500,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa LOG_DEBUG( "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block"); in_prog_pyt_blk_nodes.insert( - in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); + in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); } in_prog_trt_blk_nodes.clear(); // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 68c8af8ff3..03bf75e04c 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -214,11 +214,22 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV } else if(input_ivalue.isCustomClass()) { core::ir::Input cur_input = (*(input_ivalue.toCustomClass<Input>())).toInternalInput(); converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input))); + } else if(input_ivalue.isPyObject()) { + auto py_object_holder = input_ivalue.toPyObjectHolder(); + auto infer_type = py_object_holder->tryToInferType(); + auto type = infer_type.type(); + torch::jit::IValue ival = 
py_object_holder->toIValue(type); + torch::jit::IValue converted_item; + to_internal_input_signature(ival, converted_item); + converted_ivalue = torch::jit::IValue(converted_item); + } else { + LOG_ERROR("Unknown input spec type"); } } core::CompileSpec init_compile_spec(CompileSpec external) { if (external.inputs.size() > 0) { + LOG_DEBUG("init_compile_spec with input vector"); std::vector<core::ir::Input> internal_inputs; for (auto i : external.inputs) { internal_inputs.push_back(i.toInternalInput()); @@ -226,6 +237,7 @@ core::CompileSpec init_compile_spec(CompileSpec external) { core::CompileSpec internal(internal_inputs); return internal; } else { + LOG_DEBUG("init_compile_spec with input signature"); torch::jit::IValue converted_input_signature; to_internal_input_signature(external.input_signature.signature_ivalue, converted_input_signature); core::CompileSpec internal(converted_input_signature); diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 7d80e6cf4e..8403a446d9 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -2,6 +2,7 @@ #include "pybind11/stl.h" #include "Python.h" +#include "ATen/core/jit_type.h" #include "core/compiler.h" #include "core/conversion/conversion.h" #include "tensorrt_classes.h" @@ -179,7 +180,11 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("format", &Input::format); py::class_<InputSignature>(m, "InputSignature") - .def(py::init<>()) + .def(pybind11::init([](py::object py_obj) { + InputSignature input_signature; + input_signature.signature_ivalue = torch::jit::toIValue(std::move(py_obj), c10::PyObjectType::get(), c10::nullopt); + return input_signature; + })) .def("__str__", &InputSignature::to_str) .def_readwrite("_signature_ivalue", &InputSignature::signature_ivalue); diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 45973562ab..59a0a54615 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ 
b/py/torch_tensorrt/ts/_compile_spec.py @@ -168,8 +168,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback: return info -def _parse_input_signature(input_signature: Any) -> _C.InputSignature: - print(input_signature) +def _parse_input_signature(input_signature: Any): if isinstance(input_signature, tuple): input_list = [] for item in input_signature: @@ -180,7 +179,7 @@ def _parse_input_signature(input_signature: Any) -> _C.InputSignature: input_list = [] for item in input_signature: input = _parse_input_signature(item) - input_list.append(input) + input_list.append(input) return input_list elif isinstance(input_signature, Input) or isinstance(input_signature, torch.Tensor): i = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature @@ -202,17 +201,14 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: elif compile_spec["input_signature"] is not None: log(Level.Warning, "Input signature parsing is an experimental feature, behavior and APIs may change") - signature =_parse_input_signature(compile_spec["input_signature"]) - print(signature) - info.input_signature = signature + signature = _parse_input_signature(compile_spec["input_signature"]) + info.input_signature = _C.InputSignature(signature) # py_object else: raise KeyError( "Module input definitions are requried to compile module. 
Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec" ) - #assert(len(info.inputs) > 0 or compile_spec["input_signature"] is not None, "Require at least one input definition to compile model") - if "enabled_precisions" in compile_spec: info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"]) diff --git a/tests/cpp/test_collection.cpp b/tests/cpp/test_collection.cpp index e6592a4e20..c269ebac17 100644 --- a/tests/cpp/test_collection.cpp +++ b/tests/cpp/test_collection.cpp @@ -39,7 +39,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) { input_range.push_back({in0.sizes(), torch::kF16}); torch_tensorrt::ts::CompileSpec compile_settings(input_range); compile_settings.require_full_compilation = true; - compile_settings.min_block_size = 1; + compile_settings.min_block_size = 3; // // FP16 execution compile_settings.enabled_precisions = {torch::kHalf}; @@ -88,7 +88,7 @@ TEST(CppAPITests, TestCollectionTupleInput) { auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); compile_settings.require_full_compilation = false; - compile_settings.min_block_size = 1; + compile_settings.min_block_size = 3; // // FP16 execution compile_settings.enabled_precisions = {torch::kHalf}; @@ -153,7 +153,7 @@ TEST(CppAPITests, TestCollectionListInput) { auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); compile_settings.require_full_compilation = false; - compile_settings.min_block_size = 1; + compile_settings.min_block_size = 3; compile_settings.torch_executed_ops.push_back("aten::__getitem__"); // // FP16 execution @@ -206,7 +206,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) { auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); compile_settings.require_full_compilation = false; - compile_settings.min_block_size = 1; + compile_settings.min_block_size = 3; // compile_settings.torch_executed_ops.push_back("prim::TupleConstruct"); @@ -276,7 
+276,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) { auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); compile_settings.require_full_compilation = false; - compile_settings.min_block_size = 1; + compile_settings.min_block_size = 3; // Need to skip the conversion of __getitem__ and ListConstruct compile_settings.torch_executed_ops.push_back("aten::__getitem__"); @@ -346,7 +346,7 @@ TEST(CppAPITests, TestCollectionComplexModel) { auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); compile_settings.require_full_compilation = false; - compile_settings.min_block_size = 1; + compile_settings.min_block_size = 3; // Need to skip the conversion of __getitem__ and ListConstruct compile_settings.torch_executed_ops.push_back("aten::__getitem__"); diff --git a/tests/py/test_collections.py b/tests/py/test_collections.py index 1f6694e5ba..fb532c52bd 100644 --- a/tests/py/test_collections.py +++ b/tests/py/test_collections.py @@ -29,9 +29,11 @@ def setUp(self): def test_compile(self): compile_spec = { - "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))), + "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),), "device": torchtrt.Device("gpu:0"), - "enabled_precisions": {torch.float} + "enabled_precisions": {torch.float}, + "require_full_compilation": False, + "min_block_size": 3 } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) @@ -45,9 +47,11 @@ def setUp(self): def test_compile(self): compile_spec = { - "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]), + "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],), "device": torchtrt.Device("gpu:0"), - "enabled_precisions": {torch.float} + "enabled_precisions": {torch.float}, + "require_full_compilation": False, + "min_block_size": 3 } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) @@ -61,9 +65,11 @@ 
def setUp(self): def test_compile(self): compile_spec = { - "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))), + "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),), "device": torchtrt.Device("gpu:0"), - "enabled_precisions": {torch.float} + "enabled_precisions": {torch.float}, + "require_full_compilation": False, + "min_block_size": 3 } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) @@ -79,9 +85,11 @@ def setUp(self): def test_compile(self): compile_spec = { - "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]), + "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],), "device": torchtrt.Device("gpu:0"), - "enabled_precisions": {torch.float} + "enabled_precisions": {torch.float}, + "require_full_compilation": False, + "min_block_size": 3 } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) @@ -98,9 +106,11 @@ def setUp(self): def test_compile(self): compile_spec = { - "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]), + "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],), "device": torchtrt.Device("gpu:0"), - "enabled_precisions": {torch.float} + "enabled_precisions": {torch.float}, + "require_full_compilation": False, + "min_block_size": 3 } trt_mod = torchtrt.ts.compile(self.model, **compile_spec)