@@ -9,19 +9,46 @@ namespace trtorch {
 namespace core {
 namespace partitioning {

-inline bool isTensorOrTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
-      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
 struct usage_info {
   int produce_id = -1;
   std::vector<int> torch_use_id;
   std::vector<int> tensorrt_use_id;
 };

+inline bool isTensorOrTensorList(torch::jit::Value* val) {
+  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
+      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
+}
+
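+// returns true only if every node in the list has a registered TensorRT converter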
+bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
+  for (auto node : nodes) {
+    if (!conversion::OpSupported(node)) {
+      return false;
+    }
+  }
+  return true;
+}
+
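+// returns true if n consumes any of the target nonTensor values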
+bool containNonTensorInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
+  for (auto input : n->inputs()) {
+    if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
+      return true;
+    }
+  }
+  return false;
+}
+
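+// returns true if n produces an output that is neither a Tensor nor a list of Tensors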
+bool containNonTensorOutputs(torch::jit::Node* n) {
+  for (auto output : n->outputs()) {
+    if (!isTensorOrTensorList(output)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
-  // using bfs to get the DAG dependency nodes for input value
+  // use BFS to get the DAG dependency nodes for the input values
   std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
       std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
   std::unordered_set<torch::jit::Node*> visited;
@@ -43,17 +70,50 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
   return stk;
 }

-SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
   // reconstruct segmented_block if this block requires nonTensor input
   std::vector<torch::jit::Value*> nontensor_inputs;
   for (auto input : seg_block.raw_inputs()) {
     if (!isTensorOrTensorList(input)) {
       nontensor_inputs.push_back(input);
     }
   }
-  std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
-  new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-  return std::move(SegmentedBlock(seg_block.target(), new_block_nodes));
+  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
+
+  std::vector<SegmentedBlock> new_seg_blocks;
+  // if the current block is kTorch, or it is kTensorRT and all of the dependency nodes are supported, construct
+  // only one new block
+  if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
+    dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+    new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+  } else {
+    // if the current block is kTensorRT but the dependency nodes contain unsupported nodes, we have to segment again
+    std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
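+    // at least one dependency node is unsupported, so the whole dependency chain runs in Torch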
+    new_seg_blocks.emplace_back(SegmentedBlock::kTorch, dependency_nodes);
+    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
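+    // set when the previous node produced a nonTensor output; the node that follows it must also stay in Torch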
+    bool prev_non_tensor_outputs = false;
+    for (auto n : seg_block.raw_nodes()) {
+      // the node belongs in a kTorch block if it uses a nonTensor input that is produced in a kTorch block
+      if (containNonTensorInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
+        if (!tensorrt_nodes.empty()) {
+          new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+          tensorrt_nodes.clear();
+        }
+        pytorch_nodes.push_back(n);
+        prev_non_tensor_outputs = containNonTensorOutputs(n);
+      } else {
+        if (!pytorch_nodes.empty()) {
+          new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+          pytorch_nodes.clear();
+        }
+        tensorrt_nodes.push_back(n);
+      }
+    }
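+    // emit whichever run of nodes is still pending after the loop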
+    if (!tensorrt_nodes.empty()) {
+      new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+    } else {
+      new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+    }
+  }
+  return new_seg_blocks;
 }

 void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
@@ -80,16 +140,17 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
     if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
       int first_torch_id = use_info.torch_use_id.front();
       if (!updated_segments.count(first_torch_id)) {
-        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
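+        // a kTorch block always collapses into a single merged block, so front() is the complete result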
+        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
         segmented_blocks[first_torch_id] = new_torch_block;
         updated_segments.insert(first_torch_id);
       }
     } else {
       // kTensorRT segments always need to inject nodes for the nonTensor inputs
       for (int i : use_info.tensorrt_use_id) {
         if (!updated_segments.count(i)) {
-          auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
-          segmented_blocks[i] = new_seg_block;
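+          // a kTensorRT block may be re-segmented into several blocks; splice them in where the original block was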
+          auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
+          segmented_blocks.erase(segmented_blocks.begin() + i);
+          segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
           updated_segments.insert(i);
         }
       }