
Commit 19ecc64

refactor!(//cpp): Inlining partial compilation settings since the feature is now on by default

BREAKING CHANGE: This commit changes the API for automatic fallback, inlining the partial compilation settings in preparation for the feature being turned on by default. Instead of a `torch_fallback` field with its associated struct, the compile spec now carries four new fields:

```c++
bool require_full_compilation = true;
uint64_t min_block_size = 3;
std::vector<std::string> torch_executed_ops = {};
std::vector<std::string> torch_executed_modules = {};
```

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 2a0d1c8 commit 19ecc64
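For illustration, a minimal migration sketch against the new spec, assuming a hypothetical helper, input shape, and op name; the commented lines show the removed `torch_fallback` API:

```c++
#include <string>
#include <vector>

#include "trtorch/trtorch.h"

trtorch::CompileSpec make_spec() { // hypothetical helper
  // Hypothetical input shape; real code would match the module's inputs
  std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  trtorch::CompileSpec cfg(input_shapes);

  // Before this commit (removed API):
  // cfg.torch_fallback.enabled = true;
  // cfg.torch_fallback.min_block_size = 3;
  // cfg.torch_fallback.forced_fallback_ops.push_back("aten::max_pool2d");

  // After this commit, partial compilation is on by default and its
  // settings live directly on the compile spec:
  cfg.min_block_size = 3;
  cfg.torch_executed_ops.push_back("aten::max_pool2d"); // hypothetical op
  return cfg;
}
```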

File tree

4 files changed: +68 −61 lines changed


cpp/include/trtorch/trtorch.h

Lines changed: 21 additions & 37 deletions
```diff
@@ -516,38 +516,6 @@ struct TRTORCH_API CompileSpec {
     bool explicit_set_dtype;
   };
 
-  /**
-   * @brief A struct to hold fallback info
-   */
-  struct TRTORCH_API TorchFallback {
-    /// enable the automatic fallback feature
-    bool enabled = false;
-
-    /// minimum consecutive operation number that needs to be satisfied to convert to TensorRT
-    uint64_t min_block_size = 1;
-
-    /// A list of names of operations that will explicitly run in PyTorch
-    std::vector<std::string> forced_fallback_ops;
-
-    /// A list of names of modules that will explicitly run in PyTorch
-    std::vector<std::string> forced_fallback_modules;
-
-    /**
-     * @brief Construct a default Torch Fallback object, fallback will be off
-     */
-    TorchFallback() = default;
-
-    /**
-     * @brief Construct from a bool
-     */
-    TorchFallback(bool enabled) : enabled(enabled) {}
-
-    /**
-     * @brief Constructor for setting min_block_size
-     */
-    TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
-  };
-
   /**
    * @brief Construct a new Extra Info object
    * Convienence constructor to set fixed input size from vectors describing
@@ -643,11 +611,6 @@ struct TRTORCH_API CompileSpec {
    */
   Device device;
 
-  /**
-   * @brief Settings related to partial compilation
-   */
-  TorchFallback torch_fallback;
-
   /**
    * Sets the restrictions for the engine (CUDA Safety)
    */
@@ -676,6 +639,27 @@ struct TRTORCH_API CompileSpec {
    * Calibration dataloaders for each input for post training quantizatiom
    */
  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
+
+  /**
+   * Require the full module be compiled to TensorRT instead of potentially running unsupported operations in PyTorch
+   */
+  bool require_full_compilation = false;
+
+  /**
+   * Minimum number of contiguous supported operators to compile a subgraph to TensorRT
+   */
+  uint64_t min_block_size = 3;
+
+  /**
+   * List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+   */
+  std::vector<std::string> torch_executed_ops;
+
+
+  /**
+   * List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+   */
+  std::vector<std::string> torch_executed_modules;
 };
 
 /**
```
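With the struct gone, strict behavior becomes a single flag on the spec. A minimal sketch under stated assumptions (`compile_strict` is a hypothetical helper; the module and shapes come from the caller):

```c++
#include <torch/script.h>

#include <vector>

#include "trtorch/trtorch.h"

// Hypothetical helper around the new flag
auto compile_strict(const torch::jit::script::Module& mod,
                    std::vector<std::vector<int64_t>> input_shapes) {
  trtorch::CompileSpec cfg(input_shapes);

  // Require the whole module to convert to TensorRT rather than
  // letting unsupported pieces run in PyTorch
  cfg.require_full_compilation = true;

  // Both torch_executed_* lists must stay empty in this mode; the
  // checks added in compile_spec.cpp below reject the spec otherwise
  return trtorch::CompileGraph(mod, cfg);
}
```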

cpp/src/compile_spec.cpp

Lines changed: 13 additions & 19 deletions
```diff
@@ -323,21 +323,6 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
     internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
 
-  // /* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype
-  // for inputs they will follow PyTorch convetions */ for (size_t i = 0; i < external.inputs.size(); i++) {
-  //   if (!external.inputs[i].get_explicit_set_dtype()) {
-  //     auto& precisions = internal.convert_info.engine_settings.enabled_precisions;
-  //     auto& internal_ins = internal.convert_info.inputs;
-  //     if (precisions.find(nvinfer1::DataType::kINT8) != precisions.end()) {
-  //       internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
-  //     } else if (precisions.find(nvinfer1::DataType::kHALF) != precisions.end()) {
-  //       internal_ins[i].dtype = nvinfer1::DataType::kHALF;
-  //     } else {
-  //       internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
-  //     }
-  //   }
-  // }
-
   internal.convert_info.engine_settings.sparse_weights = external.sparse_weights;
   internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
   internal.convert_info.engine_settings.refit = external.refit;
@@ -346,10 +331,19 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.strict_types = external.strict_types;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
-  internal.partition_info.enabled = external.torch_fallback.enabled;
-  internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
-  internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
-  internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;
+
+  TRTORCH_CHECK(!(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
+      "require_full_compilation is enabled however the list of ops to run in torch is not empty (Found "
+          << external.torch_executed_ops.size() << " ops)");
+
+  TRTORCH_CHECK(!(external.require_full_compilation && (external.torch_executed_modules.size() > 0)),
+      "require_full_compilation is enabled however the list of modules to run in torch is not empty (Found "
+          << external.torch_executed_modules.size() << " modules)");
+
+  internal.partition_info.enabled = external.require_full_compilation;
+  internal.partition_info.min_block_size = external.min_block_size;
+  internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops);
+  internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules);
 
   switch (external.device.device_type) {
     case CompileSpec::Device::DeviceType::kDLA:
```
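These two `TRTORCH_CHECK` guards reject contradictory specs up front: naming ops or modules to keep in PyTorch while also requiring full compilation. A hedged sketch of the failure mode (helper and op name are hypothetical; error text quoted from the checks above):

```c++
#include <torch/script.h>

#include <vector>

#include "trtorch/trtorch.h"

void demo_conflicting_spec(const torch::jit::script::Module& mod,
                           std::vector<std::vector<int64_t>> input_shapes) {
  trtorch::CompileSpec cfg(input_shapes);
  cfg.require_full_compilation = true;
  cfg.torch_executed_ops.push_back("aten::size"); // hypothetical op

  // Trips the first TRTORCH_CHECK above:
  //   "require_full_compilation is enabled however the list of ops to
  //    run in torch is not empty (Found 1 ops)"
  auto trt_mod = trtorch::CompileGraph(mod, cfg);
  (void)trt_mod; // never reached
}
```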

tests/core/BUILD

Lines changed: 31 additions & 0 deletions
```diff
@@ -1,6 +1,37 @@
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    }
+)
+
+filegroup(
+    name = "jit_models",
+    srcs = ["//tests/modules:mobilenet_v2_scripted.jit.pt"]
+)
+
+cc_test(
+    name = "test_detecting_input_type",
+    srcs = ["test_detecting_input_type.cpp"],
+    deps = [
+        "//tests/util",
+        "//core",
+        "//core/lowering",
+        "//core/util:prelude",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    data = [
+        ":jit_models"
+    ]
+)
+
 test_suite(
     name = "core_tests",
     tests = [
+        ":test_detecting_input_type",
         "//tests/core/conversion:conversion_tests",
         "//tests/core/lowering:lowering_tests",
         "//tests/core/partitioning:partitioning_tests"
```

tests/cpp/test_module_fallback.cpp

Lines changed: 3 additions & 5 deletions
```diff
@@ -23,8 +23,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
   }
 
   trtorch::CompileSpec cfg(input_shapes);
-  cfg.torch_fallback.enabled = true;
-  cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock");
+  cfg.torch_executed_modules.push_back("torchvision.models.resnet.BasicBlock");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = trtorch::CompileGraph(mod, cfg);
@@ -51,9 +50,8 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
   }
 
   trtorch::CompileSpec cfg(input_shapes);
-  cfg.torch_fallback.enabled = true;
-  cfg.torch_fallback.min_block_size = 5;
-  cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
+  cfg.min_block_size = 5;
+  cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = trtorch::CompileGraph(mod, cfg);
```
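As the MobileNet test name suggests, the spec-level `min_block_size = 5` keeps any run of fewer than five contiguous TensorRT-supported operators in PyTorch, so the forced-fallback `ConvBNActivation` modules do not shatter the remaining graph into many tiny engines; the module compiles to a single TensorRT engine instead.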
