diff --git a/cpp/include/trtorch/trtorch.h b/cpp/include/trtorch/trtorch.h
index c62e947694..314b9fa6e4 100644
--- a/cpp/include/trtorch/trtorch.h
+++ b/cpp/include/trtorch/trtorch.h
@@ -427,7 +427,7 @@ struct TRTORCH_API CompileSpec {
   Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
 
   /**
-   * @brief Construct a new Input Range object dynamic input size from
+   * @brief Construct a new Input spec object dynamic input size from
    * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
    * supported sizes. dtype (Expected data type for the input) defaults to PyTorch
    * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
@@ -462,7 +462,7 @@ struct TRTORCH_API CompileSpec {
       TensorFormat format = TensorFormat::kContiguous);
 
   /**
-   * @brief Construct a new Input Range object dynamic input size from
+   * @brief Construct a new Input spec object dynamic input size from
    * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
    * supported sizes. dtype (Expected data type for the input) defaults to PyTorch
    * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
@@ -479,7 +479,7 @@ struct TRTORCH_API CompileSpec {
       TensorFormat format = TensorFormat::kContiguous);
 
   /**
-   * @brief Construct a new Input Range object dynamic input size from
+   * @brief Construct a new Input spec object dynamic input size from
    * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
    * supported sizes
    *
@@ -496,6 +496,16 @@ struct TRTORCH_API CompileSpec {
       DataType dtype,
       TensorFormat format = TensorFormat::kContiguous);
 
+  /**
+   * @brief Construct a new Input spec object using a torch tensor as an example
+   * The tensor's shape, type and layout inform the spec's values
+   *
+   * Note: You cannot set dynamic shape through this method, you must use an alternative constructor
+   *
+   * @param tensor Reference tensor to set shape, type and layout
+   */
+  Input(at::Tensor tensor);
+
   bool get_explicit_set_dtype() {
     return explicit_set_dtype;
   }
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 24c771cad4..bd2de946a4 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -287,6 +287,26 @@ CompileSpec::Input::Input(
   this->input_is_dynamic = true;
 }
 
+CompileSpec::Input::Input(at::Tensor tensor) {
+  this->opt_shape = tensor.sizes().vec();
+  this->min_shape = tensor.sizes().vec();
+  this->max_shape = tensor.sizes().vec();
+  this->shape = tensor.sizes().vec();
+  this->dtype = tensor.scalar_type();
+  this->explicit_set_dtype = true;
+  TRTORCH_ASSERT(
+      tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
+      "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
+  at::MemoryFormat frmt;
+  if (tensor.is_contiguous(at::MemoryFormat::Contiguous)) {
+    frmt = at::MemoryFormat::Contiguous;
+  } else {
+    frmt = at::MemoryFormat::ChannelsLast;
+  }
+  this->format = frmt;
+  this->input_is_dynamic = false;
+}
+
 /* ==========================================*/
 
 core::ir::Input to_internal_input(CompileSpec::InputRange& i) {
diff --git a/py/trtorch/Input.py b/py/trtorch/Input.py
index d36d1eb3b6..51cf4f6860 100644
--- a/py/trtorch/Input.py
+++ b/py/trtorch/Input.py
@@ -196,3 +196,16 @@ def _parse_format(format: Any) -> _types.TensorFormat:
         else:
             raise TypeError(
                 "Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
+
+    @classmethod
+    def _from_tensor(cls, t: torch.Tensor):
+        if not any([
+                t.is_contiguous(memory_format=torch.contiguous_format),
+                t.is_contiguous(memory_format=torch.channels_last)
+        ]):
+            raise ValueError(
+                "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last"
+            )
+        frmt = torch.contiguous_format if t.is_contiguous(
+            memory_format=torch.contiguous_format) else torch.channels_last
+        return cls(shape=t.shape, dtype=t.dtype, format=frmt)
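
Both new constructors above, the C++ `Input(at::Tensor)` and the Python `Input._from_tensor`, hinge on the same memory-format probe: the example tensor must be contiguous in either the standard row-major layout or the channels_last layout, and whichever matches becomes the spec's format. A minimal standalone sketch of that probe; the helper name `probe_format` is hypothetical and not part of this patch:

```python
import torch

# Hypothetical helper (not in the patch) mirroring the check used by
# CompileSpec::Input(at::Tensor) and trtorch.Input._from_tensor above.
def probe_format(t: torch.Tensor) -> torch.memory_format:
    if t.is_contiguous(memory_format=torch.contiguous_format):
        return torch.contiguous_format  # standard row-major (e.g. NCHW) layout
    if t.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last  # NHWC layout
    # Non-contiguous views (e.g. the result of transpose()) land here,
    # matching the ValueError/TRTORCH_ASSERT in the patch.
    raise ValueError("Tensor does not have a supported contiguous memory format")

nchw = torch.randn(1, 3, 224, 224)
print(probe_format(nchw))                                         # torch.contiguous_format
print(probe_format(nchw.to(memory_format=torch.channels_last)))   # torch.channels_last
```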
trtorch.TensorFormat") + + @classmethod + def _from_tensor(cls, t: torch.Tensor): + if not any([ + t.is_contiguous(memory_format=torch.contiguous_format), + t.is_contiguous(memory_format=torch.channels_last) + ]): + raise ValueError( + "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last" + ) + frmt = torch.contiguous_format if t.is_contiguous( + memory_format=torch.contiguous_format) else torch.channels_last + return cls(shape=t.shape, dtype=t.dtype, format=frmt) diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py index dc2e7095fa..dc2d36dd1e 100644 --- a/py/trtorch/_compile_spec.py +++ b/py/trtorch/_compile_spec.py @@ -174,7 +174,12 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec: info.inputs = _parse_input_ranges(compile_spec["input_shapes"]) if "inputs" in compile_spec: - info.inputs = [i._to_internal() for i in compile_spec["inputs"]] + if not all([isinstance(i, torch.Tensor) or isinstance(i, trtorch.Input) for i in compile_spec["inputs"]]): + raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format( + [typeof(i) for i in compile_spec["inputs"]])) + + inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]] + info.inputs = [i._to_internal() for i in inputs] if "op_precision" in compile_spec and "enabled_precisions" in compile_spec: raise KeyError( diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD index 3d35abf366..8023502849 100644 --- a/tests/cpp/BUILD +++ b/tests/cpp/BUILD @@ -16,7 +16,8 @@ test_suite( ":test_modules_as_engines", ":test_multiple_registered_engines", ":test_serialization", - ":test_module_fallback" + ":test_module_fallback", + ":test_example_tensors" ], ) @@ -28,7 +29,8 @@ test_suite( ":test_modules_as_engines", ":test_multiple_registered_engines", ":test_serialization", - ":test_module_fallback" + ":test_module_fallback", + ":test_example_tensors" ], ) @@ -43,6 +45,17 @@ cc_test( ], ) +cc_test( + name = "test_example_tensors", + srcs = ["test_example_tensors.cpp"], + data = [ + "//tests/modules:jit_models", + ], + deps = [ + ":cpp_api_test", + ], +) + cc_test( name = "test_serialization", srcs = ["test_serialization.cpp"], diff --git a/tests/cpp/test_example_tensors.cpp b/tests/cpp/test_example_tensors.cpp new file mode 100644 index 0000000000..f0f509f996 --- /dev/null +++ b/tests/cpp/test_example_tensors.cpp @@ -0,0 +1,23 @@ +#include "cpp_api_test.h" + +TEST_P(CppAPITests, InputsFromTensors) { + std::vector jit_inputs_ivalues; + std::vector trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randn(in_shape, {at::kCUDA}); + jit_inputs_ivalues.push_back(in.clone()); + trt_inputs_ivalues.push_back(in.clone()); + } + + auto spec = trtorch::CompileSpec({trt_inputs_ivalues[0].toTensor()}); + + auto trt_mod = trtorch::CompileGraph(mod, spec); + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); + std::vector trt_results; + trt_results.push_back(trt_results_ivalues.toTensor()); +} + +INSTANTIATE_TEST_SUITE_P( + CompiledModuleForwardIsCloseSuite, + CppAPITests, + testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); diff --git a/tests/py/test_api.py b/tests/py/test_api.py index 94239b1475..f3b99232c1 100644 --- a/tests/py/test_api.py +++ b/tests/py/test_api.py @@ -73,6 +73,27 @@ def test_compile_script(self): same = (trt_mod(self.input) - 
diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD
index 3d35abf366..8023502849 100644
--- a/tests/cpp/BUILD
+++ b/tests/cpp/BUILD
@@ -16,7 +16,8 @@ test_suite(
         ":test_modules_as_engines",
         ":test_multiple_registered_engines",
         ":test_serialization",
-        ":test_module_fallback"
+        ":test_module_fallback",
+        ":test_example_tensors"
     ],
 )
 
@@ -28,7 +29,8 @@ test_suite(
         ":test_modules_as_engines",
         ":test_multiple_registered_engines",
         ":test_serialization",
-        ":test_module_fallback"
+        ":test_module_fallback",
+        ":test_example_tensors"
     ],
 )
 
@@ -43,6 +45,17 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "test_example_tensors",
+    srcs = ["test_example_tensors.cpp"],
+    data = [
+        "//tests/modules:jit_models",
+    ],
+    deps = [
+        ":cpp_api_test",
+    ],
+)
+
 cc_test(
     name = "test_serialization",
     srcs = ["test_serialization.cpp"],
diff --git a/tests/cpp/test_example_tensors.cpp b/tests/cpp/test_example_tensors.cpp
new file mode 100644
index 0000000000..f0f509f996
--- /dev/null
+++ b/tests/cpp/test_example_tensors.cpp
@@ -0,0 +1,23 @@
+#include "cpp_api_test.h"
+
+TEST_P(CppAPITests, InputsFromTensors) {
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randn(in_shape, {at::kCUDA});
+    jit_inputs_ivalues.push_back(in.clone());
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  auto spec = trtorch::CompileSpec({trt_inputs_ivalues[0].toTensor()});
+
+  auto trt_mod = trtorch::CompileGraph(mod, spec);
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    CompiledModuleForwardIsCloseSuite,
+    CppAPITests,
+    testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index 94239b1475..f3b99232c1 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -73,6 +73,27 @@ def test_compile_script(self):
         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
         self.assertTrue(same < 2e-2)
 
+    def test_from_torch_tensor(self):
+        compile_spec = {
+            "inputs": [self.input],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+            },
+            "enabled_precisions": {torch.float}
+        }
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-2)
+
+    def test_device(self):
+        compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-2)
+
 
 class TestCompileHalf(ModelTestCase):
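
Because `_parse_compile_spec` converts per element, `trtorch.Input` specs and example tensors can be mixed in a single `inputs` list, and anything else is rejected up front with a `KeyError`. A hedged sketch of that behavior; the shapes are illustrative:

```python
import torch
import trtorch

# Mixed specs: an explicit Input alongside an example tensor.
compile_spec = {
    "inputs": [
        trtorch.Input(shape=[1, 3, 224, 224], dtype=torch.float),  # kept as-is
        torch.randn(1, 3, 224, 224),  # rewritten via trtorch.Input._from_tensor
    ],
    "enabled_precisions": {torch.float},
}
# A non-spec entry such as a bare shape list fails validation:
# {"inputs": [[1, 3, 224, 224]]}  ->  KeyError listing the offending types
```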