
Commit 52e2f05

feat!: Turning on partial compilation by default
BREAKING CHANGE: This commit turns on partial compilation by default. Modules containing unsupported operations will now be run partially in PyTorch and partially in TensorRT.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent: a234335

17 files changed: +244 −82 lines

core/compiler.cpp

Lines changed: 40 additions & 5 deletions
@@ -253,6 +253,7 @@ GraphAndMapping ConstructFallbackGraph(
       }
       // update the input ranges for each segments
       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
@@ -288,7 +289,7 @@ GraphAndMapping ConstructFallbackGraph(
 }
 
 
-void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
+void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, ir::TypeMap& first_use_type_map) {
   // Associate input specs with inputs
   cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
 
@@ -303,9 +304,31 @@ void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::G
   } else if (!est_type_opt && !spec.dtype_is_user_defined) {
     // If we cannot calculate the type and the user did not define the type, then default to FP32
     LOG_WARNING(
-        "Cannot deterime input type from calcuations in graph for input "
+        "Cannot infer input type from calcuations in graph for input "
        << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
     spec.dtype = nvinfer1::DataType::kFLOAT;
+  } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+    if (!est_type_opt) {
+      LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+    } else {
+      if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+        std::stringstream ss;
+        ss <<"For input " << in->debugName() << ", found user specified input dtype as ";
+        ss << cfg.convert_info.inputs.find(in)->second.dtype;
+        ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+        ss << est_type_opt.value() << std::endl;
+        ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+        ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+        ss << "compatibility with PyTorch's data type convention is required.\n";
+        ss << "If you do indeed see errors at runtime either:\n";
+        ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+        ss << "- Disable partial compilation by setting require_full_compilation to True";
+        auto warn_str = ss.str();
+        LOG_WARNING(warn_str);
+        // Overwrite type map with user settings
+        first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+      }
+    }
   } else {
     // The user defined the type so no changes are necessary
   }
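The new branch above only fires when the user pinned an input dtype and partial compilation is enabled; if the dtype inferred from the graph disagrees, the compiler warns and keeps the user setting, also overwriting the first-use type map so downstream partitioning sees the same dtype. A rough distillation of the decision table, as a standalone sketch (the enum and function below are illustrative stand-ins, not TRTorch internals):

```cpp
// Illustrative sketch of the dtype-resolution rules added above; the types and
// function here are hypothetical, not part of the TRTorch API.
#include <iostream>
#include <optional>

enum class DType { kFloat32, kHalf, kInt8 };

// - no user dtype, no inferred dtype        -> default to Float32 (warn)
// - user dtype set, partial compilation on,
//   and it conflicts with the inferred one  -> keep the user dtype (warn), since the
//   Torch segments of a partially compiled graph follow PyTorch's dtype convention
// - otherwise                               -> keep whatever is known
DType resolve_input_dtype(std::optional<DType> user, std::optional<DType> inferred, bool partial_compilation) {
  if (!user && !inferred) {
    std::cout << "WARNING: cannot infer input type, assuming Float32\n";
    return DType::kFloat32;
  }
  if (user && inferred && partial_compilation && *user != *inferred) {
    std::cout << "WARNING: user dtype conflicts with the dtype inferred from the graph; "
                 "using the user setting (may fail at runtime)\n";
    return *user;
  }
  return user ? *user : *inferred;
}

int main() {
  // User forces FP16 while the graph's weights imply FP32, with partitioning enabled.
  resolve_input_dtype(DType::kHalf, DType::kFloat32, /*partial_compilation=*/true);
}
```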
@@ -317,10 +340,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
 
   auto g = graph_and_parameters.first;
+  TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
@@ -357,11 +381,21 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
-  if (cfg.partition_info.enabled) {
+  if (cfg.partition_info.enabled
+      && (cfg.lower_info.forced_fallback_modules.size() == 0
+          && cfg.partition_info.forced_fallback_operators.size() == 0
+          && conversion::VerifyConverterSupportForBlock(g->block(), true))) {
+    LOG_INFO("Skipping partitioning since model is fully supported");
+  }
+
+  if (cfg.partition_info.enabled
+      && !(cfg.lower_info.forced_fallback_modules.size() == 0
+           && cfg.partition_info.forced_fallback_operators.size() == 0
+           && conversion::VerifyConverterSupportForBlock(g->block(), false))) {
     auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
     auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
     new_g = graph_and_mapping.first;
@@ -374,6 +408,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       return mod;
     }
   } else {
+    TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
     auto device_spec = cfg.convert_info.engine_settings.device;
     auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
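Taken together, the two new conditions decide whether CompileGraph partitions at all: partitioning is skipped only when nothing is forced to run in Torch and every op has a converter, and the non-partitioned path now fails fast via TRTORCH_CHECK when an op is unsupported. Because the second condition is written with a double negation, here is a hypothetical distillation of the logic (struct and function names are illustrative only, not the TRTorch API):

```cpp
// Hypothetical distillation of the new CompileGraph control flow.
#include <cstdio>

struct Flags {
  bool partitioning_enabled;
  bool has_forced_fallback_modules;
  bool has_forced_fallback_operators;
  bool every_op_convertible;
};

// Matches the two conditions added above: a graph goes straight to a single
// TensorRT engine only when nothing is forced back to Torch and every op converts.
bool should_partition(const Flags& f) {
  bool fully_supported =
      !f.has_forced_fallback_modules && !f.has_forced_fallback_operators && f.every_op_convertible;
  return f.partitioning_enabled && !fully_supported;
}

int main() {
  Flags fully_supported_model{true, false, false, true};   // -> single TRT engine
  Flags model_with_unsupported_op{true, false, false, false}; // -> Torch/TRT fallback graph
  std::printf("fully supported partitions: %d\n", should_partition(fully_supported_model));
  std::printf("unsupported-op model partitions: %d\n", should_partition(model_with_unsupported_op));
}
```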

core/conversion/conversion.cpp

Lines changed: 12 additions & 6 deletions
@@ -491,7 +491,7 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }
 
-bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
+bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
   auto unsupported_ops = GetUnsupportedOpsInBlock(b);
 
   if (unsupported_ops.size() != 0) {
@@ -506,16 +506,20 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
     unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
     unsupported_msg << std::endl << "In Module:" << std::endl;
 
-    LOG_ERROR(unsupported_msg.str());
+    if (suppress_errors) {
+      LOG_ERROR(unsupported_msg.str());
+    }
 
     for (const auto n : b->nodes()) {
       auto schema = n->maybeSchema();
       if (schema) {
         for (const auto& x : unsupported_ops) {
           if (x.first == schema->operator_name()) {
-            LOG_ERROR(
-                "Unsupported operator: " << *schema << std::endl
-                << trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
+            if (suppress_errors) {
+              LOG_ERROR(
+                  "Unsupported operator: " << *schema << std::endl
+                  << trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
+            }
           }
         }
       }
@@ -531,7 +535,9 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
     unsupported_msg
        << "This may be because there are no operators that can be added to the TensorRT graph or all operators have a resolved compile time value."
        << std::endl;
-    LOG_ERROR(unsupported_msg.str());
+    if (suppress_errors) {
+      LOG_ERROR(unsupported_msg.str());
+    }
     return false;
   }
 

core/conversion/conversion.h

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ std::string ConvertBlockToEngine(
 
 bool OpSupported(const torch::jit::Node* n);
 
-bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
+bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors=false);
 
 c10::optional<torch::jit::IValue> EvaluateNode(
     ConversionCtx* ctx,

core/ir/ir.cpp

Lines changed: 103 additions & 0 deletions
@@ -45,6 +45,109 @@ std::vector<const torch::jit::Value*> get_tensor_inputs(
   return input_tensors;
 }
 
+c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in) {
+  TRTORCH_ASSERT(in->owningGraph() == b->owningGraph(), "Provided input is not part of the provided graph");
+  c10::optional<at::ScalarType> dtype = {};
+
+  auto b_ins = b->inputs();
+  std::unordered_set<torch::jit::Value*> b_in_set(b_ins.begin(), b_ins.end());
+
+  TRTORCH_ASSERT(
+      in->type() == c10::TensorType::get(), "Input is not a tensor, cannot check for dtype based on calculation");
+
+  auto consumers = in->uses();
+  auto search_list = std::vector<torch::jit::Use>(consumers.begin(), consumers.end());
+
+  for (auto& u : search_list) {
+    auto n = u.user;
+    LOG_GRAPH("Node we are looking at: " << util::node_info(n));
+    auto ins = n->inputs();
+    auto outs = n->outputs();
+
+    bool outputs_tensor = false;
+    for (auto o : outs) {
+      if (o->type() == c10::TensorType::get()) {
+        outputs_tensor = true;
+        break;
+      }
+    }
+
+    if (!outputs_tensor) {
+      LOG_GRAPH("Node " << util::node_info(n) << " does not output a tensor, skipping");
+      continue;
+    }
+
+    LOG_GRAPH("Node " << util::node_info(n) << " outputs a tensor");
+
+    // If all input tensors are block inputs then this node will not give us useful type info so move to the next one
+    bool all_n_ins_are_b_ins = true;
+    for (auto in : ins) {
+      if (b_in_set.find(in) == b_in_set.end()) {
+        all_n_ins_are_b_ins = false;
+        break;
+      }
+    }
+
+    if (all_n_ins_are_b_ins) {
+      LOG_GRAPH(
+          "All inputs to Node " << util::node_info(n) << " are graph inputs, cannot be used to determine input type");
+      for (auto o : outs) {
+        if (o->type() == c10::TensorType::get()) {
+          auto o_uses = o->uses();
+          search_list.insert(search_list.end(), o_uses.begin(), o_uses.end());
+        }
+      }
+      continue;
+    }
+
+    // If node outputs a Tensor it might be a result of tensor calcuation so check to see
+    // if any inputs to the calculation can give us hints
+    c10::optional<torch::jit::Node*> const_tensor_n = {};
+
+    // Backtrace to constants which will immediately give us the Tensor type if possible
+    for (auto in : ins) {
+      LOG_GRAPH("Input to node: " << util::node_info(in->node()));
+      if (in->type()->isSubtypeOf(torch::jit::TensorType::get())) {
+        LOG_GRAPH("Input outputs a Tensor");
+        if (in->node()->kind() == torch::jit::prim::Constant) {
+          LOG_GRAPH("Input is a constant");
+          auto const_val = in->node()->t(c10::attr::value);
+          LOG_GRAPH("Found that constant tensor has type: " << const_val.scalar_type());
+          dtype = {const_val.scalar_type()};
+          goto exit_first_calc_dtype;
+        }
+      }
+    }
+
+    // Add all tensor outputs to search list if we still dont know
+    for (auto o : outs) {
+      if (o->type() == c10::TensorType::get()) {
+        auto o_uses = o->uses();
+        search_list.insert(search_list.end(), o_uses.begin(), o_uses.end());
+      }
+    }
+  }
+exit_first_calc_dtype:
+  if (dtype) {
+    LOG_GRAPH("Estimated input type is " << dtype.value());
+  } else {
+    LOG_GRAPH("Cannot determine input types from graph");
+  }
+  return dtype;
+}
+
+TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) {
+  TypeMap types;
+
+  for (auto i : b->inputs()) {
+    if (i->type() == c10::TensorType::get()) {
+      torch::jit::Value* in = i;
+      types.insert({in, get_value_first_calc_dtype_opt(b, i)});
+    }
+  }
+  return types;
+}
+
 } // namespace ir
 } // namespace core
 } // namespace trtorch
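The traversal above walks the uses of a block input, skipping nodes whose operands are all graph inputs, until it reaches a calculation with a constant tensor operand whose scalar type can stand in for the input's dtype. A minimal usage sketch, assuming libtorch and the repo's core/ir/ir.h are on the include path (the IR string and build setup are illustrative, not from the commit):

```cpp
// Minimal usage sketch (not part of the commit): probe block inputs for an inferable dtype.
#include <iostream>
#include <memory>
#include <torch/csrc/jit/ir/irparser.h>
#include "core/ir/ir.h"

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  // A graph whose only node consumes two graph inputs; there is no constant
  // tensor to backtrace to, so the estimate comes back empty.
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor, %y : Tensor):
      %z : Tensor = aten::mul(%x, %y)
      return (%z))IR", g.get());

  auto types = trtorch::core::ir::get_block_first_calc_dtypes_opt(g->block());
  for (const auto& kv : types) {
    if (kv.second) {
      std::cout << kv.first->debugName() << ": " << kv.second.value() << "\n";
    } else {
      std::cout << kv.first->debugName() << ": could not be inferred\n";
    }
  }
}
```

This is the same helper CompileGraph now calls (as ir::get_block_first_calc_dtypes_opt) before MapInputsAndDetermineDTypes.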

core/ir/ir.h

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,11 @@ std::vector<const torch::jit::Value*> get_tensor_inputs(
     std::shared_ptr<torch::jit::Graph>& g,
     StaticParams& static_params);
 
+using TypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;
+
+c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
+ir::TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);
+
 } // namespace ir
 } // namespace core
 } // namespace trtorch

core/lowering/LowerInfo.cpp

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ namespace lowering {
 
 std::ostream& operator<<(std::ostream& os, const LowerInfo& l) {
   os << "Settings requested for Lowering:" << std::endl;
-  os << " Forced Fallback Modules: [" << std::endl;
+  os << " torch_executed_modules: [" << std::endl;
   for (auto i : l.forced_fallback_modules) {
     os << " " << i << std::endl;
   }

core/partitioning/PartitionInfo.cpp

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ std::ostream& operator<<(std::ostream& os, const PartitionInfo& s) {
   if (s.enabled) {
     os << "True";
     os << "\n \"min_block_size\": " << s.min_block_size \
-       << "\n \"forced_fallback_operators\": [";
+       << "\n \"torch_executed_operators\": [";
     for (auto i : s.forced_fallback_operators) {
       os <<"\n " << i << ',';
     }

core/util/BUILD

Lines changed: 0 additions & 3 deletions
@@ -27,9 +27,6 @@ cc_library(
     hdrs = [
         "jit_util.h",
     ],
-    srcs = [
-        "jit_util.cpp"
-    ],
     deps = [
         ":macros"
     ] + select({

core/util/jit_util.h

Lines changed: 0 additions & 4 deletions
@@ -9,7 +9,6 @@ namespace trtorch {
 namespace core {
 namespace util {
 
-using InputTypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;
 
 inline std::string node_info(const torch::jit::Node* n) {
   std::stringstream ss;
@@ -62,9 +61,6 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
   return source_code;
 }
 
-c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
-InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);
-
 } // namespace util
 } // namespace core
 } // namespace trtorch

core/util/logging/TRTorchLogger.cpp

Lines changed: 2 additions & 2 deletions
@@ -125,9 +125,9 @@ namespace {
 
 TRTorchLogger& get_global_logger() {
 #ifndef NDEBUG
-  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
+  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
 #else
-  static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
+  static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kWARNING, false);
 #endif
   return global_logger;
 }
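With this change a release build reports warnings by default (previously errors only), and debug builds default to kDEBUG rather than full kGRAPH tracing. Assuming the public logging API in cpp/api is unchanged by this commit (an assumption, not shown in the diff), callers who prefer the old quieter default can lower the threshold themselves:

```cpp
// Sketch only: assumes the existing trtorch::logging API in cpp/api is unchanged.
#include "trtorch/logging.h"

void restore_quiet_logging() {
  // Restore the previous release-build default of reporting errors only.
  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
}
```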

cpp/bin/trtorchc/README.md

Lines changed: 17 additions & 13 deletions
@@ -36,8 +36,8 @@ OPTIONS:
       --allow-gpu-fallback            (Only used when targeting DLA
                                       (device-type)) Lets engine run layers on
                                       GPU if they are not supported on DLA
-      --allow-torch-fallback          Enable layers to run in torch if they
-                                      are not supported in TensorRT
+      --require-full-compilation      Require that the model should be fully
+                                      compiled to TensorRT or throw an error
       --disable-tf32                  Prevent Float32 layers from using the
                                       TF32 data format
       --sparse-weights                Enable sparsity for weights of conv and
@@ -63,18 +63,22 @@ OPTIONS:
       --calibration-cache-file=[file_path]
                                       Path to calibration cache file to use
                                       for post training quantization
-      --ffo=[forced_fallback_ops...],
-      --forced-fallback-op=[forced_fallback_ops...]
+      --teo=[torch-executed-ops...],
+      --torch-executed-ops=[torch-executed-ops...]
                                       (Repeatable) Operator in the graph that
-                                      should be forced to fallback to Pytorch
-                                      for execution (allow torch fallback must
-                                      be set)
-      --ffm=[forced_fallback_mods...],
-      --forced-fallback-mod=[forced_fallback_mods...]
-                                      (Repeatable) Module that should be
-                                      forced to fallback to Pytorch for
-                                      execution (allow torch fallback must be
-                                      set)
+                                      should always be run in PyTorch for
+                                      execution (partial compilation must be
+                                      enabled)
+      --tem=[torch-executed-mods...],
+      --torch-executed-mods=[torch-executed-mods...]
+                                      (Repeatable) Module that should always
+                                      be run in Pytorch for execution (partial
+                                      compilation must be enabled)
+      --mbs=[torch-executed-mods...],
+      --min-block-size=[torch-executed-mods...]
+                                      Minimum number of contiguous TensorRT
+                                      supported ops to compile a subgraph to
+                                      TensorRT
       --embed-engine                  Whether to treat input file as a
                                       serialized TensorRT engine and embed it
                                       into a TorchScript module (device spec
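Put together, a hypothetical invocation that partially compiles a module while pinning an op and a submodule to Torch might look like the following; the flags come from the option list above, while the paths, shape, and names are placeholders (check trtorchc --help for the authoritative syntax):

```sh
# Hypothetical example; only the flag names are taken from the README diff above.
trtorchc model.ts trt_model.ts "(1,3,224,224)" \
    --torch-executed-ops="aten::upsample_nearest2d" \
    --torch-executed-mods="my_package.MyUnsupportedBlock" \
    --min-block-size=3
```

Conversely, passing --require-full-compilation restores the previous all-or-nothing behavior.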
