diff --git a/core/ir/GraphInputs.cpp b/core/ir/GraphInputs.cpp
index 645624f2f1..792189137a 100644
--- a/core/ir/GraphInputs.cpp
+++ b/core/ir/GraphInputs.cpp
@@ -68,6 +68,7 @@ GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
inputs = flattened_inputs;
input_signature = input_signature_;
collection_inputs = collection_inputs_;
+ LOG_DEBUG("Collection Input Size: " << collection_inputs_.size());
}
} // namespace ir
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 3583fe7f10..4b3ff65ff0 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -500,7 +500,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa
LOG_DEBUG(
"In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
in_prog_pyt_blk_nodes.insert(
- in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+ in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
}
in_prog_trt_blk_nodes.clear();
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index 68c8af8ff3..03bf75e04c 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -214,11 +214,22 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV
} else if(input_ivalue.isCustomClass()) {
core::ir::Input cur_input = (*(input_ivalue.toCustomClass())).toInternalInput();
    converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input)));
+ } else if(input_ivalue.isPyObject()) {
+ auto py_object_holder = input_ivalue.toPyObjectHolder();
+ auto infer_type = py_object_holder->tryToInferType();
+ auto type = infer_type.type();
+ torch::jit::IValue ival = py_object_holder->toIValue(type);
+ torch::jit::IValue converted_item;
+ to_internal_input_signature(ival, converted_item);
+ converted_ivalue = torch::jit::IValue(converted_item);
+ } else {
+ LOG_ERROR("Unknown input spec type");
}
}
core::CompileSpec init_compile_spec(CompileSpec external) {
if (external.inputs.size() > 0) {
+ LOG_DEBUG("init_compile_spec with input vector");
    std::vector<core::ir::Input> internal_inputs;
for (auto i : external.inputs) {
internal_inputs.push_back(i.toInternalInput());
@@ -226,6 +237,7 @@ core::CompileSpec init_compile_spec(CompileSpec external) {
core::CompileSpec internal(internal_inputs);
return internal;
} else {
+ LOG_DEBUG("init_compile_spec with input signature");
torch::jit::IValue converted_input_signature;
to_internal_input_signature(external.input_signature.signature_ivalue, converted_input_signature);
core::CompileSpec internal(converted_input_signature);
diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
index 7d80e6cf4e..8403a446d9 100644
--- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
+++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -2,6 +2,7 @@
#include "pybind11/stl.h"
#include "Python.h"
+#include "ATen/core/jit_type.h"
#include "core/compiler.h"
#include "core/conversion/conversion.h"
#include "tensorrt_classes.h"
@@ -179,7 +180,11 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("format", &Input::format);
   py::class_<InputSignature>(m, "InputSignature")
- .def(py::init<>())
+ .def(pybind11::init([](py::object py_obj) {
+ InputSignature input_signature;
+ input_signature.signature_ivalue = torch::jit::toIValue(std::move(py_obj), c10::PyObjectType::get(), c10::nullopt);
+ return input_signature;
+ }))
.def("__str__", &InputSignature::to_str)
.def_readwrite("_signature_ivalue", &InputSignature::signature_ivalue);
diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py
index 45973562ab..59a0a54615 100644
--- a/py/torch_tensorrt/ts/_compile_spec.py
+++ b/py/torch_tensorrt/ts/_compile_spec.py
@@ -168,8 +168,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:
return info
-def _parse_input_signature(input_signature: Any) -> _C.InputSignature:
- print(input_signature)
+def _parse_input_signature(input_signature: Any):
if isinstance(input_signature, tuple):
input_list = []
for item in input_signature:
@@ -180,7 +179,7 @@ def _parse_input_signature(input_signature: Any) -> _C.InputSignature:
input_list = []
for item in input_signature:
input = _parse_input_signature(item)
- input_list.append(input)
+ input_list.append(input)
return input_list
elif isinstance(input_signature, Input) or isinstance(input_signature, torch.Tensor):
i = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature
@@ -202,17 +201,14 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
elif compile_spec["input_signature"] is not None:
log(Level.Warning, "Input signature parsing is an experimental feature, behavior and APIs may change")
- signature =_parse_input_signature(compile_spec["input_signature"])
- print(signature)
- info.input_signature = signature
+ signature = _parse_input_signature(compile_spec["input_signature"])
+ info.input_signature = _C.InputSignature(signature) # py_object
else:
raise KeyError(
"Module input definitions are requried to compile module. Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec"
)
- #assert(len(info.inputs) > 0 or compile_spec["input_signature"] is not None, "Require at least one input definition to compile model")
-
if "enabled_precisions" in compile_spec:
info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])
diff --git a/tests/cpp/test_collection.cpp b/tests/cpp/test_collection.cpp
index e6592a4e20..c269ebac17 100644
--- a/tests/cpp/test_collection.cpp
+++ b/tests/cpp/test_collection.cpp
@@ -39,7 +39,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
input_range.push_back({in0.sizes(), torch::kF16});
torch_tensorrt::ts::CompileSpec compile_settings(input_range);
compile_settings.require_full_compilation = true;
- compile_settings.min_block_size = 1;
+ compile_settings.min_block_size = 3;
// // FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
@@ -88,7 +88,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
- compile_settings.min_block_size = 1;
+ compile_settings.min_block_size = 3;
// // FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
@@ -153,7 +153,7 @@ TEST(CppAPITests, TestCollectionListInput) {
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
- compile_settings.min_block_size = 1;
+ compile_settings.min_block_size = 3;
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
// // FP16 execution
@@ -206,7 +206,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
- compile_settings.min_block_size = 1;
+ compile_settings.min_block_size = 3;
// compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
@@ -276,7 +276,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
- compile_settings.min_block_size = 1;
+ compile_settings.min_block_size = 3;
// Need to skip the conversion of __getitem__ and ListConstruct
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
@@ -346,7 +346,7 @@ TEST(CppAPITests, TestCollectionComplexModel) {
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
- compile_settings.min_block_size = 1;
+ compile_settings.min_block_size = 3;
// Need to skip the conversion of __getitem__ and ListConstruct
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
diff --git a/tests/py/test_collections.py b/tests/py/test_collections.py
index 1f6694e5ba..fb532c52bd 100644
--- a/tests/py/test_collections.py
+++ b/tests/py/test_collections.py
@@ -29,9 +29,11 @@ def setUp(self):
def test_compile(self):
compile_spec = {
- "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))),
+ "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
"device": torchtrt.Device("gpu:0"),
- "enabled_precisions": {torch.float}
+ "enabled_precisions": {torch.float},
+ "require_full_compilation": False,
+ "min_block_size": 3
}
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -45,9 +47,11 @@ def setUp(self):
def test_compile(self):
compile_spec = {
- "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
+ "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
"device": torchtrt.Device("gpu:0"),
- "enabled_precisions": {torch.float}
+ "enabled_precisions": {torch.float},
+ "require_full_compilation": False,
+ "min_block_size": 3
}
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -61,9 +65,11 @@ def setUp(self):
def test_compile(self):
compile_spec = {
- "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))),
+ "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
"device": torchtrt.Device("gpu:0"),
- "enabled_precisions": {torch.float}
+ "enabled_precisions": {torch.float},
+ "require_full_compilation": False,
+ "min_block_size": 3
}
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -79,9 +85,11 @@ def setUp(self):
def test_compile(self):
compile_spec = {
- "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
+ "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
"device": torchtrt.Device("gpu:0"),
- "enabled_precisions": {torch.float}
+ "enabled_precisions": {torch.float},
+ "require_full_compilation": False,
+ "min_block_size": 3
}
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -98,9 +106,11 @@ def setUp(self):
def test_compile(self):
compile_spec = {
- "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
+ "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
"device": torchtrt.Device("gpu:0"),
- "enabled_precisions": {torch.float}
+ "enabled_precisions": {torch.float},
+ "require_full_compilation": False,
+ "min_block_size": 3
}
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)