Commit f34e230

refactor!: Changing the C++ API to be snake case
BREAKING CHANGE: This changes the C++ ::ts APIs to snake case, and CompileModule becomes just compile.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 4d606bc commit f34e230

24 files changed (+46, -45 lines)
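For downstream C++ code the rename is mechanical but source-breaking: every call site must move from the old CamelCase names to the snake_case ones. A minimal before/after sketch of a call site, assuming a TorchScript module on disk (the model path and input shape are illustrative, not taken from this commit):

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  // Load a scripted module (path is illustrative)
  auto mod = torch::jit::load("model.ts");

  // Fixed-shape CompileSpec, mirroring the README example below
  auto spec = torch_tensorrt::ts::CompileSpec({{1, 3, 224, 224}});
  spec.enabled_precisions = {torch::kHalf};

  // Before f34e230: auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
  // After f34e230:
  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);

  auto in = torch::randn({1, 3, 224, 224}, torch::kCUDA).to(torch::kHalf);
  auto out = trt_mod.forward({in});
  return 0;
}

The same one-line substitutions apply throughout the diff: CheckMethodOperatorSupport becomes check_method_operator_support, ConvertMethodToTRTEngine becomes convert_method_to_trt_engine, and EmbedEngineInNewModule becomes embed_engine_in_new_module.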

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -56,4 +56,5 @@ examples/int8/ptq/ptq
 examples/int8/qat/qat
 examples/int8/training/vgg16/data/*
 examples/int8/datasets/data/*
-env/**/*
+env/**/*
+bazel-Torch-TensorRT-Preview

README.md

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ More Information / System Architecture:

 ## Building a docker container for Torch-TensorRT Preview

-We provide a `Dockerfile` in `docker/` directory. We build `Torch-TensorRT` on top of a `Pytorch NGC container` which provide basic dependencies (like CUDA, CUDNN, CUBLAS, TensorRT, Pytorch and others) The dependency libraries in the container can be found in the <a href="https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html">release notes</a>.
+We provide a `Dockerfile` in `docker/` directory. We build `Torch-TensorRT` on top of a `Pytorch NGC container` which provide basic dependencies (like CUDA, CUDNN, CUBLAS, TensorRT, Pytorch and others) The dependency libraries in the container can be found in the <a href="https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html">release notes</a>.

 Please follow this instruction to build a Docker container.

@@ -41,7 +41,7 @@ auto compile_settings = torch_tensorrt::ts::CompileSpec({input});
 // FP16 execution
 compile_settings.enabled_precisions = {torch::kHalf};
 // Compile module
-auto trt_mod = torch_tensorrt::ts::CompileModule(ts_mod, compile_settings);
+auto trt_mod = torch_tensorrt::ts::compile(ts_mod, compile_settings);
 // Run like normal
 auto results = trt_mod.forward({in_tensor});
 // Save module for later

core/partitioning/README.md

Lines changed: 1 addition & 1 deletion
@@ -62,6 +62,6 @@ torchtrt::ts::CompileSpec cfg(input_sizes);
 cfg.torch_fallback = torchtrt::CompileSpec::TorchFallback(true);
 cfg.torch_fallback.min_block_size = 2;
 cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu");
-auto trt_mod = torchtrt::ts::CompileModule(mod, cfg);
+auto trt_mod = torchtrt::ts::compile(mod, cfg);
 auto out = trt_mod.forward({in});
 ```

cpp/bin/torchtrtc/main.cpp

Lines changed: 3 additions & 3 deletions
@@ -600,7 +600,7 @@ int main(int argc, char** argv) {
   // Instead of compiling, just embed engine in a PyTorch module
   if (embed_engine) {
     std::string serialized_engine = read_buf(real_input_path);
-    auto trt_mod = torchtrt::ts::EmbedEngineInNewModule(serialized_engine, compile_settings.device);
+    auto trt_mod = torchtrt::ts::embed_engine_in_new_module(serialized_engine, compile_settings.device);
     trt_mod.save(real_output_path);
     return 0;
   }
@@ -622,12 +622,12 @@ int main(int argc, char** argv) {
   }

   if (save_engine) {
-    auto engine = torchtrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_settings);
+    auto engine = torchtrt::ts::convert_method_to_trt_engine(mod, "forward", compile_settings);
     std::ofstream out(real_output_path);
     out << engine;
     out.close();
   } else {
-    auto trt_mod = torchtrt::ts::CompileModule(mod, compile_settings);
+    auto trt_mod = torchtrt::ts::compile(mod, compile_settings);

     if (!no_threshold_check &&
         (compile_settings.enabled_precisions.size() == 1 &&

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 4 additions & 4 deletions
@@ -701,7 +701,7 @@ struct TORCHTRT_API CompileSpec {
  *
  * @returns bool: Method is supported by Torch-TensorRT.TorchScript
  */
-TORCHTRT_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);
+TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);

 /**
  * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -717,7 +717,7 @@ TORCHTRT_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, s
  *
  * @return: A new module trageting a TensorRT engine
  */
-TORCHTRT_API torch::jit::Module CompileModule(const torch::jit::Module& module, CompileSpec info);
+TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);

 /**
  * @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
@@ -733,7 +733,7 @@ TORCHTRT_API torch::jit::Module CompileModule(const torch::jit::Module& module,
  * @return: std::string: Serialized TensorRT engine equivilant to the method
  * graph
  */
-TORCHTRT_API std::string ConvertMethodToTRTEngine(
+TORCHTRT_API std::string convert_method_to_trt_engine(
     const torch::jit::Module& module,
     std::string method_name,
     CompileSpec info);
@@ -751,6 +751,6 @@ TORCHTRT_API std::string ConvertMethodToTRTEngine(
  *
  * @return: A new module trageting a TensorRT engine
  */
-TORCHTRT_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine, Device device);
+TORCHTRT_API torch::jit::Module embed_engine_in_new_module(const std::string& engine, Device device);
 } // namespace torchscript
 } // namespace torch_tensorrt
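Taken together, the renamed TorchScript-frontend surface composes as sketched below. This is a hedged example of the four entry points after the rename, not code from the commit (the helper name, output path, and control flow are illustrative; error handling is elided):

#include <fstream>
#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

// Sketch: exercise the four renamed ::ts entry points on one module.
void build_trt(torch::jit::Module& mod, torch_tensorrt::ts::CompileSpec spec) {
  // 1. Confirm every operator in the method has a converter before building
  if (!torch_tensorrt::ts::check_method_operator_support(mod, "forward")) {
    return;
  }

  // 2. Compile the whole module in place...
  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);

  // 3. ...or lower a single method to a raw serialized TensorRT engine,
  auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", spec);

  // 4. then wrap that engine back up as a standalone TorchScript module.
  auto wrapped = torch_tensorrt::ts::embed_engine_in_new_module(engine, spec.device);
  wrapped.save("/tmp/engine_module.ts"); // path is illustrative
}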

cpp/src/torch_tensorrt.cpp

Lines changed: 4 additions & 4 deletions
@@ -12,11 +12,11 @@ namespace torchscript {
 // Defined in compile_spec.cpp
 torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);

-bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
+bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
   return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
 }

-std::string ConvertMethodToTRTEngine(
+std::string convert_method_to_trt_engine(
     const torch::jit::script::Module& module,
     std::string method_name,
     CompileSpec info) {
@@ -26,14 +26,14 @@ std::string ConvertMethodToTRTEngine(
   return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
 }

-torch::jit::script::Module CompileModule(const torch::jit::script::Module& module, CompileSpec info) {
+torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {
   LOG_DEBUG(get_build_info());
   // Want to export a much simpler (non TRT header dependent) API so doing the
   // type conversion here
   return torch_tensorrt::core::CompileGraph(module, to_internal_compile_spec(info));
 }

-torch::jit::Module EmbedEngineInNewModule(const std::string& engine, Device device) {
+torch::jit::Module embed_engine_in_new_module(const std::string& engine, Device device) {
   return torch_tensorrt::core::EmbedEngineInNewModule(engine, to_internal_cuda_device(device));
 }

examples/benchmark/main.cpp

Lines changed: 2 additions & 2 deletions
@@ -127,11 +127,11 @@ int main(int argc, const char* argv[]) {
   compile_spec.enabled_precisions.insert(torch::kF16);
 #endif

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
-  auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_spec);
+  auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", compile_spec);
   std::ofstream out("/tmp/engine_converted_from_jit.trt");
   out << engine;
   out.close();

examples/int8/ptq/main.cpp

Lines changed: 2 additions & 2 deletions
@@ -56,14 +56,14 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M

 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
-  auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_spec);
+  auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", compile_spec);
   std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif

   std::cout << "Compiling and quantizing module" << std::endl;
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);
   return std::move(trt_mod);
 }

examples/int8/qat/main.cpp

Lines changed: 2 additions & 2 deletions
@@ -40,14 +40,14 @@ torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::ji

 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
-  auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_spec);
+  auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", compile_spec);
   std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif

   std::cout << "Compiling and quantizing module" << std::endl;
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);
   return std::move(trt_mod);
 }

tests/accuracy/test_dla_fp16_accuracy.cpp

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ TEST_P(AccuracyTests, DLAFP16AccuracyIsClose) {
   compile_spec.device.allow_gpu_fallback = true;
   compile_spec.workspace_size = 1 << 28;

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

   torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {

tests/accuracy/test_dla_int8_accuracy.cpp

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ TEST_P(AccuracyTests, DLAINT8AccuracyIsClose) {
   torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

   // Compile Graph
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

   // Check the INT8 accuracy in TRT
   torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});

tests/accuracy/test_fp16_accuracy.cpp

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
   auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
   compile_spec.enabled_precisions.insert(torch::kF16);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

   torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {

tests/accuracy/test_fp32_accuracy.cpp

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ TEST_P(AccuracyTests, FP32AccuracyIsClose) {
   auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
   compile_spec.enabled_precisions.insert(torch::kF32);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

   torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {

tests/accuracy/test_int8_accuracy.cpp

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ TEST_P(AccuracyTests, INT8AccuracyIsClose) {
   torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

   // Compile Graph
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

   // Check the INT8 accuracy in TRT
   torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});

tests/cpp/test_compiled_modules.cpp

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
   std::vector<at::Tensor> jit_results;
   jit_results.push_back(jit_results_ivalues.toTensor());

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, input_shapes);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());

tests/cpp/test_default_input_types.cpp

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@ TEST_P(CppAPITests, InputsUseDefaultFP32) {
   auto spec = torch_tensorrt::ts::CompileSpec({in});
   spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());
@@ -38,7 +38,7 @@ TEST_P(CppAPITests, InputsUseDefaultFP16) {

   mod.to(torch::kHalf);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());
@@ -60,7 +60,7 @@ TEST_P(CppAPITests, InputsUseDefaultFP16WithoutFP16Enabled) {

   mod.to(torch::kHalf);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());
@@ -84,7 +84,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) {

   mod.to(torch::kHalf);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());
@@ -106,7 +106,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
   auto spec = torch_tensorrt::ts::CompileSpec({in});
   spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());

tests/cpp/test_example_tensors.cpp

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ TEST_P(CppAPITests, InputsFromTensors) {

   auto spec = torch_tensorrt::ts::CompileSpec({trt_inputs_ivalues[0].toTensor()});

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());

tests/cpp/test_module_fallback.cpp

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
   cfg.torch_executed_modules.push_back("torchvision.models.resnet.BasicBlock");

   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, cfg);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
   auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
 }
@@ -54,7 +54,7 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
   cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");

   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, cfg);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);

   auto g = trt_mod.get_method("forward").graph();
   auto nodes = g->block()->nodes();

tests/cpp/test_modules_as_engines.cpp

Lines changed: 2 additions & 2 deletions
@@ -40,8 +40,8 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
   cudaGetDevice(&device_id);
   compile_spec.device.device_type = torch_tensorrt::Device::DeviceType::kGPU;
   compile_spec.device.gpu_id = device_id;
-  auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", input_ranges);
-  auto trt_mod = torch_tensorrt::ts::EmbedEngineInNewModule(engine, compile_spec.device);
+  auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", input_ranges);
+  auto trt_mod = torch_tensorrt::ts::embed_engine_in_new_module(engine, compile_spec.device);

   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
   std::vector<at::Tensor> trt_results;

tests/cpp/test_multi_gpu_serde.cpp

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
   std::vector<at::Tensor> jit_results;
   jit_results.push_back(jit_results_ivalues.toTensor());

-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, input_shapes);
+  auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes);

   // Deliberately changing the device ID. torch_tensorrt runtime should correct the Device ID internally
   torch_tensorrt::set_device(1);

tests/cpp/test_multiple_registered_engines.cpp

Lines changed: 2 additions & 2 deletions
@@ -41,13 +41,13 @@ TEST(CppAPITest, CanRunMultipleEngines) {
   std::vector<at::Tensor> jit2_results;
   jit2_results.push_back(jit2_results_ivalues.toTensor());

-  auto trt_mod1 = torch_tensorrt::ts::CompileModule(mod1, input_shapes);
+  auto trt_mod1 = torch_tensorrt::ts::compile(mod1, input_shapes);
   torch::jit::IValue trt1_results_ivalues =
       torch_tensorrt::tests::util::RunModuleForward(trt_mod1, trt1_inputs_ivalues);
   std::vector<at::Tensor> trt1_results;
   trt1_results.push_back(trt1_results_ivalues.toTensor());

-  auto trt_mod2 = torch_tensorrt::ts::CompileModule(mod2, input_shapes);
+  auto trt_mod2 = torch_tensorrt::ts::compile(mod2, input_shapes);
   torch::jit::IValue trt2_results_ivalues =
       torch_tensorrt::tests::util::RunModuleForward(trt_mod2, trt2_inputs_ivalues);
   std::vector<at::Tensor> trt2_results;

tests/cpp/test_runtime_thread_safety.cpp

Lines changed: 2 additions & 2 deletions
@@ -52,8 +52,8 @@ TEST(CppAPITests, RuntimeThreadSafety) {
   // FP32 execution
   compile_settings.enabled_precisions = {torch::kFloat};
   compile_settings.strict_types = true;
-  auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_settings);
-  std::cout << "torch_tensorrt::ts::CompileModule" << std::endl;
+  auto trt_mod = torch_tensorrt::ts::compile(mod, compile_settings);
+  std::cout << "torch_tensorrt::ts::compile" << std::endl;

   int num_threads = 10;
   std::vector<torch::jit::IValue> out_vec(num_threads), trt_out_vec(num_threads);

tests/cpp/test_serialization.cpp

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ TEST_P(CppAPITests, SerializedModuleIsStillCorrect) {
     pre_serialized_inputs_ivalues.push_back(in.clone());
   }

-  auto pre_serialized_mod = torch_tensorrt::ts::CompileModule(mod, input_shapes);
+  auto pre_serialized_mod = torch_tensorrt::ts::compile(mod, input_shapes);
   torch::jit::IValue pre_serialized_results_ivalues =
       torch_tensorrt::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
   std::vector<at::Tensor> pre_serialized_results;
@@ -57,7 +57,7 @@ TEST_P(CppAPITests, SerializedDynamicModuleIsStillCorrect) {
   }

   auto pre_serialized_mod =
-      torch_tensorrt::ts::CompileModule(mod, torch_tensorrt::ts::CompileSpec(toInputRangesDynamic(input_shapes)));
+      torch_tensorrt::ts::compile(mod, torch_tensorrt::ts::CompileSpec(toInputRangesDynamic(input_shapes)));
   torch::jit::IValue pre_serialized_results_ivalues =
       torch_tensorrt::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
   std::vector<at::Tensor> pre_serialized_results;

tests/util/run_forward.cpp

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::Module& mod, std::v
     input_ranges.push_back(in.sizes());
   }

-  auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", input_ranges);
+  auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", input_ranges);
   return RunEngine(engine, inputs);
 }
