
Commit b96087b

ArvindSridhar authored and narendasan committed
Add partition logic and torch backend integration
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent b4feb49 commit b96087b

File tree: 10 files changed, +23 −7 lines

core/compiler.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -119,7 +119,7 @@ void AddEngineToGraph(
 }
 
 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
-  // Go through Lowering to simplify graph
+  // Go through Lowering to simplify graph and extract weight parameters
   auto graph_and_parameters = lowering::Lower(mod, method_name, lowering::LowerInfo());
 
   auto g = graph_and_parameters.first;
```

core/compiler.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@ struct CompileSpec {
   partitioning::PartitionInfo partition_info;
 };
 
-bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
```

core/partitioning/partitioning.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -274,7 +274,8 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
     }
 
     std::string node_string(n->kind().toQualString());
-    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
+    auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
+    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) && (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
       tensorrt_nodes.push_back(n);
       if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
         segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
```
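The new guard keys off a node-level `to_compile` attribute: even a TensorRT-convertible op is routed to a Torch block when a lowering pass has tagged its node with `to_compile = false`. Below is a minimal sketch of that tagging mechanism through PyTorch's Python IR bindings (illustrative only; in this commit the tagging itself happens in C++ during lowering, and `Tiny` is a placeholder module):

```python
import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Tiny())

# Tag every node to stay in PyTorch. segment_graph above would then put these
# nodes in a kTorch block even though aten::relu is TensorRT-convertible.
for n in scripted.graph.nodes():
    n.i_("to_compile", 0)          # set integer attribute (0 == false)
    assert n.i("to_compile") == 0  # read it back, mirroring n->i(...) in C++
```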

cpp/api/include/trtorch/trtorch.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -784,7 +784,7 @@ TRTORCH_API void dump_build_info();
  *
  * @returns bool: Method is supported by TRTorch
  */
-TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name, CompileSpec info);
+TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);
 
 /**
  * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
```

cpp/api/src/trtorch.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -11,8 +11,8 @@ namespace trtorch {
 core::CompileSpec to_internal_compile_spec(CompileSpec external);
 core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device);
 
-bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name, CompileSpec info) {
-  return core::CheckMethodOperatorSupport(module, method_name, to_internal_compile_spec(info));
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
+  return core::CheckMethodOperatorSupport(module, method_name);
 }
 
 std::string ConvertGraphToTRTEngine(
```
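With the spec argument gone, operator-support checking is a function of just the module and the method name. A short sketch of the equivalent call through the existing Python entry point, `trtorch.check_method_op_support` (`Tiny` is again a placeholder module):

```python
import torch
import trtorch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Tiny())

# No compile spec is needed just to ask whether a method converts cleanly.
if trtorch.check_method_op_support(scripted, "forward"):
    print("forward is fully supported by TRTorch")
```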

py/trtorch/_compile_spec.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -148,6 +148,10 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall
         assert isinstance(fallback_info["forced_fallback_ops"], list)
         info.forced_fallback_operators = fallback_info["forced_fallback_ops"]
 
+    if "forced_fallback_modules" in fallback_info:
+        assert isinstance(fallback_info["forced_fallback_modules"], list)
+        info.forced_fallback_modules = fallback_info["forced_fallback_modules"]
+
     return info
 
 
@@ -338,6 +342,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
     torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
     torch_fallback._set_min_block_size(parsed_spec.torch_fallback.min_block_size)
     torch_fallback._set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)
+    torch_fallback._set_forced_fallback_modules(parsed_spec.torch_fallback.forced_fallback_modules)
 
     backend_spec._set_device(d)
     backend_spec._set_torch_fallback(torch_fallback)
```
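From the user side, whole submodule types can now be forced to run in PyTorch by listing their class paths. A hedged sketch of the fallback section of a compile spec dict (the `forced_fallback_*` keys follow `_parse_torch_fallback` above; the op name and module path are illustrative, and other required spec fields are elided):

```python
import trtorch

compile_spec = {
    # ... input definitions and other required settings elided ...
    "torch_fallback": {
        "enabled": True,
        "min_block_size": 3,
        "forced_fallback_ops": ["aten::max_pool2d"],
        # New in this commit: keep these module types entirely in PyTorch.
        "forced_fallback_modules": ["mymodels.SelfAttention"],  # hypothetical path
    },
}
# trt_mod = trtorch.compile(scripted_module, compile_spec)
```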

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -37,6 +37,7 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_modules);
 
   static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
       torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
```

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -166,6 +166,11 @@ std::string TorchFallback::to_str() {
     ss << "        " << i << ',' << std::endl;
   }
   ss << "    ]" << std::endl;
+  ss << "    \"forced_fallback_modules\": [" << std::endl;
+  for (auto i : forced_fallback_modules) {
+    ss << "        " << i << ',' << std::endl;
+  }
+  ss << "    ]" << std::endl;
   ss << "  }" << std::endl;
   return ss.str();
 }
@@ -203,6 +208,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.partition_info.enabled = torch_fallback.enabled;
   info.partition_info.min_block_size = torch_fallback.min_block_size;
   info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
+  info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
   info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
 
   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
```
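Note the routing asymmetry in `toInternalCompileSpec`: forced-fallback operators land in `partition_info`, which `segment_graph` consults directly, while forced-fallback modules land in `lower_info`, presumably so the lowering passes can locate nodes originating in those modules and stamp them with the `to_compile` attribute that the partitioning guard above checks.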

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -90,11 +90,13 @@ struct TorchFallback : torch::CustomClassHolder {
   bool enabled;
   int64_t min_block_size;
   std::vector<std::string> forced_fallback_operators;
+  std::vector<std::string> forced_fallback_modules;
   TorchFallback() : enabled(false), min_block_size(1) {}
 
   ADD_FIELD_GET_SET(enabled, bool);
   ADD_FIELD_GET_SET(min_block_size, int64_t);
   ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
+  ADD_FIELD_GET_SET(forced_fallback_modules, std::vector<std::string>);
 
   std::string to_str();
 };
```

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -281,7 +281,8 @@ PYBIND11_MODULE(_C, m) {
       .def("__str__", &trtorch::pyapi::TorchFallback::to_str)
       .def_readwrite("enabled", &TorchFallback::enabled)
       .def_readwrite("min_block_size", &TorchFallback::min_block_size)
-      .def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);
+      .def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators)
+      .def_readwrite("forced_fallback_modules", &TorchFallback::forced_fallback_modules);
 
   m.doc() =
       "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
```
