Skip to content

Collection: Python api support #1071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/ir/GraphInputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
inputs = flattened_inputs;
input_signature = input_signature_;
collection_inputs = collection_inputs_;
LOG_DEBUG("Collection Input Size: " << collection_inputs_.size());
}

} // namespace ir
Expand Down
2 changes: 1 addition & 1 deletion core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa
LOG_DEBUG(
"In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
in_prog_pyt_blk_nodes.insert(
in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
}
in_prog_trt_blk_nodes.clear();
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
Expand Down
12 changes: 12 additions & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,30 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV
} else if(input_ivalue.isCustomClass()) {
core::ir::Input cur_input = (*(input_ivalue.toCustomClass<Input>())).toInternalInput();
converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input)));
} else if(input_ivalue.isPyObject()) {
auto py_object_holder = input_ivalue.toPyObjectHolder();
auto infer_type = py_object_holder->tryToInferType();
auto type = infer_type.type();
torch::jit::IValue ival = py_object_holder->toIValue(type);
torch::jit::IValue converted_item;
to_internal_input_signature(ival, converted_item);
converted_ivalue = torch::jit::IValue(converted_item);
} else {
LOG_ERROR("Unknown input spec type");
}
}

core::CompileSpec init_compile_spec(CompileSpec external) {
if (external.inputs.size() > 0) {
LOG_DEBUG("init_compile_spec with input vector");
std::vector<core::ir::Input> internal_inputs;
for (auto i : external.inputs) {
internal_inputs.push_back(i.toInternalInput());
}
core::CompileSpec internal(internal_inputs);
return internal;
} else {
LOG_DEBUG("init_compile_spec with input signature");
torch::jit::IValue converted_input_signature;
to_internal_input_signature(external.input_signature.signature_ivalue, converted_input_signature);
core::CompileSpec internal(converted_input_signature);
Expand Down
7 changes: 6 additions & 1 deletion py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "pybind11/stl.h"

#include "Python.h"
#include "ATen/core/jit_type.h"
#include "core/compiler.h"
#include "core/conversion/conversion.h"
#include "tensorrt_classes.h"
Expand Down Expand Up @@ -179,7 +180,11 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("format", &Input::format);

py::class_<InputSignature>(m, "InputSignature")
.def(py::init<>())
.def(pybind11::init([](py::object py_obj) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the namespace pybind11? is this something different?

InputSignature input_signature;
input_signature.signature_ivalue = torch::jit::toIValue(std::move(py_obj), c10::PyObjectType::get(), c10::nullopt);
return input_signature;
}))
.def("__str__", &InputSignature::to_str)
.def_readwrite("_signature_ivalue", &InputSignature::signature_ivalue);

Expand Down
12 changes: 4 additions & 8 deletions py/torch_tensorrt/ts/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:

return info

def _parse_input_signature(input_signature: Any) -> _C.InputSignature:
print(input_signature)
def _parse_input_signature(input_signature: Any):
if isinstance(input_signature, tuple):
input_list = []
for item in input_signature:
Expand All @@ -180,7 +179,7 @@ def _parse_input_signature(input_signature: Any) -> _C.InputSignature:
input_list = []
for item in input_signature:
input = _parse_input_signature(item)
input_list.append(input)
input_list.append(input)
return input_list
elif isinstance(input_signature, Input) or isinstance(input_signature, torch.Tensor):
i = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature
Expand All @@ -202,17 +201,14 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:

elif compile_spec["input_signature"] is not None:
log(Level.Warning, "Input signature parsing is an experimental feature, behavior and APIs may change")
signature =_parse_input_signature(compile_spec["input_signature"])
print(signature)
info.input_signature = signature
signature = _parse_input_signature(compile_spec["input_signature"])
info.input_signature = _C.InputSignature(signature) # py_object

else:
raise KeyError(
"Module input definitions are requried to compile module. Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec"
)

#assert(len(info.inputs) > 0 or compile_spec["input_signature"] is not None, "Require at least one input definition to compile model")

if "enabled_precisions" in compile_spec:
info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])

Expand Down
12 changes: 6 additions & 6 deletions tests/cpp/test_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
input_range.push_back({in0.sizes(), torch::kF16});
torch_tensorrt::ts::CompileSpec compile_settings(input_range);
compile_settings.require_full_compilation = true;
compile_settings.min_block_size = 1;
compile_settings.min_block_size = 3;

// // FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
Expand Down Expand Up @@ -88,7 +88,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
compile_settings.min_block_size = 1;
compile_settings.min_block_size = 3;

// // FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
Expand Down Expand Up @@ -153,7 +153,7 @@ TEST(CppAPITests, TestCollectionListInput) {

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
compile_settings.min_block_size = 1;
compile_settings.min_block_size = 3;
compile_settings.torch_executed_ops.push_back("aten::__getitem__");

// // FP16 execution
Expand Down Expand Up @@ -206,7 +206,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
compile_settings.min_block_size = 1;
compile_settings.min_block_size = 3;

// compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");

Expand Down Expand Up @@ -276,7 +276,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) {

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
compile_settings.min_block_size = 1;
compile_settings.min_block_size = 3;

// Need to skip the conversion of __getitem__ and ListConstruct
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
Expand Down Expand Up @@ -346,7 +346,7 @@ TEST(CppAPITests, TestCollectionComplexModel) {

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.require_full_compilation = false;
compile_settings.min_block_size = 1;
compile_settings.min_block_size = 3;

// Need to skip the conversion of __getitem__ and ListConstruct
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
Expand Down
30 changes: 20 additions & 10 deletions tests/py/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ def setUp(self):

def test_compile(self):
compile_spec = {
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))),
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float}
"enabled_precisions": {torch.float},
"require_full_compilation": False,
"min_block_size": 3
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
Expand All @@ -45,9 +47,11 @@ def setUp(self):

def test_compile(self):
compile_spec = {
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float}
"enabled_precisions": {torch.float},
"require_full_compilation": False,
"min_block_size": 3
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
Expand All @@ -61,9 +65,11 @@ def setUp(self):

def test_compile(self):
compile_spec = {
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))),
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float}
"enabled_precisions": {torch.float},
"require_full_compilation": False,
"min_block_size": 3
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
Expand All @@ -79,9 +85,11 @@ def setUp(self):

def test_compile(self):
compile_spec = {
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float}
"enabled_precisions": {torch.float},
"require_full_compilation": False,
"min_block_size": 3
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
Expand All @@ -98,9 +106,11 @@ def setUp(self):

def test_compile(self):
compile_spec = {
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float}
"enabled_precisions": {torch.float},
"require_full_compilation": False,
"min_block_size": 3
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
Expand Down