File tree Expand file tree Collapse file tree 5 files changed +9
-5
lines changed Expand file tree Collapse file tree 5 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -19,7 +19,7 @@ struct CompileSpec {
19
19
partitioning::PartitionInfo partition_info;
20
20
};
21
21
22
- bool CheckMethodOperatorSupport (const torch::jit::script::Module& mod, std::string method_name);
22
+ bool CheckMethodOperatorSupport (const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg );
23
23
24
24
std::string ConvertGraphToTRTEngine (const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
25
25
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ struct LowerInfo {
17
17
bool disable_cse = false ;
18
18
std::vector<std::string> forced_fallback_modules;
19
19
friend std::ostream& operator <<(std::ostream& os, const LowerInfo& l);
20
- }
20
+ };
21
21
22
22
void LowerBlock (torch::jit::Block* b);
23
23
void LowerGraph (std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
Original file line number Diff line number Diff line change @@ -577,6 +577,9 @@ struct TRTORCH_API CompileSpec {
577
577
// / A list of names of operations that will explicitly run in PyTorch
578
578
std::vector<std::string> forced_fallback_ops;
579
579
580
+ // / A list of names of modules that will explicitly run in PyTorch
581
+ std::vector<std::string> forced_fallback_modules;
582
+
580
583
/* *
581
584
* @brief Construct a default Torch Fallback object, fallback will be off
582
585
*/
@@ -781,7 +784,7 @@ TRTORCH_API void dump_build_info();
781
784
*
782
785
* @returns bool: Method is supported by TRTorch
783
786
*/
784
- TRTORCH_API bool CheckMethodOperatorSupport (const torch::jit::Module& module , std::string method_name);
787
+ TRTORCH_API bool CheckMethodOperatorSupport (const torch::jit::Module& module , std::string method_name, CompileSpec info );
785
788
786
789
/* *
787
790
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
Original file line number Diff line number Diff line change @@ -375,6 +375,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
375
375
internal.partition_info .enabled = external.torch_fallback .enabled ;
376
376
internal.partition_info .min_block_size = external.torch_fallback .min_block_size ;
377
377
internal.partition_info .forced_fallback_operators = external.torch_fallback .forced_fallback_ops ;
378
+ internal.lower_info .forced_fallback_modules = external.torch_fallback .forced_fallback_modules ;
378
379
379
380
switch (external.device .device_type ) {
380
381
case CompileSpec::Device::DeviceType::kDLA :
Original file line number Diff line number Diff line change @@ -11,8 +11,8 @@ namespace trtorch {
11
11
core::CompileSpec to_internal_compile_spec (CompileSpec external);
12
12
core::runtime::CudaDevice to_internal_cuda_device (CompileSpec::Device device);
13
13
14
- bool CheckMethodOperatorSupport (const torch::jit::script::Module& module , std::string method_name) {
15
- return core::CheckMethodOperatorSupport (module , method_name);
14
+ bool CheckMethodOperatorSupport (const torch::jit::script::Module& module , std::string method_name, CompileSpec info ) {
15
+ return core::CheckMethodOperatorSupport (module , method_name, to_internal_compile_spec (info) );
16
16
}
17
17
18
18
std::string ConvertGraphToTRTEngine (
You can’t perform that action at this time.
0 commit comments