Skip to content

Commit b4feb49

Browse files
ArvindSridharnarendasan
authored andcommitted
Add lowering info logic
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 02b23cb commit b4feb49

File tree

5 files changed

+9
-5
lines changed

5 files changed

+9
-5
lines changed

core/compiler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct CompileSpec {
1919
partitioning::PartitionInfo partition_info;
2020
};
2121

22-
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
22+
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
2323

2424
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
2525

core/lowering/lowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct LowerInfo {
1717
bool disable_cse = false;
1818
std::vector<std::string> forced_fallback_modules;
1919
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
20-
}
20+
};
2121

2222
void LowerBlock(torch::jit::Block* b);
2323
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);

cpp/api/include/trtorch/trtorch.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,9 @@ struct TRTORCH_API CompileSpec {
577577
/// A list of names of operations that will explicitly run in PyTorch
578578
std::vector<std::string> forced_fallback_ops;
579579

580+
/// A list of names of modules that will explicitly run in PyTorch
581+
std::vector<std::string> forced_fallback_modules;
582+
580583
/**
581584
* @brief Construct a default Torch Fallback object, fallback will be off
582585
*/
@@ -781,7 +784,7 @@ TRTORCH_API void dump_build_info();
781784
*
782785
* @returns bool: Method is supported by TRTorch
783786
*/
784-
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);
787+
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name, CompileSpec info);
785788

786789
/**
787790
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT

cpp/api/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
375375
internal.partition_info.enabled = external.torch_fallback.enabled;
376376
internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
377377
internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
378+
internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;
378379

379380
switch (external.device.device_type) {
380381
case CompileSpec::Device::DeviceType::kDLA:

cpp/api/src/trtorch.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ namespace trtorch {
1111
core::CompileSpec to_internal_compile_spec(CompileSpec external);
1212
core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device);
1313

14-
bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
15-
return core::CheckMethodOperatorSupport(module, method_name);
14+
bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name, CompileSpec info) {
15+
return core::CheckMethodOperatorSupport(module, method_name, to_internal_compile_spec(info));
1616
}
1717

1818
std::string ConvertGraphToTRTEngine(

0 commit comments

Comments
 (0)