@@ -10,9 +10,9 @@ namespace core {
namespace partitioning {

struct usage_info {
-  int produce_id = -1;
-  std::vector<int> torch_use_id;
-  std::vector<int> tensorrt_use_id;
+  size_t produce_id; // id of segmented block which contains a raw value of a given torch::jit::Value
+  std::vector<size_t> torch_use_id; // ids of segmented blocks which are of type Pytorch
+  std::vector<size_t> tensorrt_use_id; // ids of segmented blocks which are of type TensorRT
};

inline bool isTensorOrTensorList(torch::jit::Value* val) {
@@ -70,44 +70,54 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
  return stk;
}

-std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
  // reconstruct segmented_block if this block requires nonTensor input
  std::vector<torch::jit::Value*> nontensor_inputs;
+  // Gather all non-tensor inputs for this seg_block
  for (auto input : seg_block.raw_inputs()) {
    if (!isTensorOrTensorList(input)) {
      nontensor_inputs.push_back(input);
    }
  }
-  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);

+  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
  std::vector<SegmentedBlock> new_seg_blocks;
-  // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
-  // one new block
+  // if the current block is kTorch, or it is kTensorRT and all dependency nodes are also supported, merge the
+  // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
  if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
    dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
    new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
  } else {
    // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
    std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
-    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());
+
    bool prev_non_tensor_outputs = false;
    for (auto n : seg_block.raw_nodes()) {
-      // it's a kTorch block if it uses the nonTensor input and the nonTensor input is produced in kTorch block
+      // Check if the node has non-tensor inputs or if it consumes non-tensor outputs of the previous node.
+      // In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
+      // SegmentedBlock.
      if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
+        // If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
+        // TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
        if (!tensorrt_nodes.empty()) {
          new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
          tensorrt_nodes.clear();
        }
        pytorch_nodes.push_back(n);
        prev_non_tensor_outputs = containNonTensorOutputs(n);
      } else {
+        // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
+        // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
        if (!pytorch_nodes.empty()) {
          new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
          pytorch_nodes.clear();
        }
        tensorrt_nodes.push_back(n);
      }
    }
+
+    // Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly.
    if (!tensorrt_nodes.empty()) {
      new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
    } else {
@@ -118,7 +128,20 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
}

void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
-  // for NonTensor inputs in TensorRT segments, count the usages on Torch segments and TensorRT segments
+  // create a list so we can insert SegmentedBlock without losing the iterators
+  std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.begin(), segmented_blocks.end());
+  std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
+  auto iter = segmented_blocks_list.begin();
+  for (size_t i = 0; i < segmented_blocks.size(); ++i, ++iter) {
+    idx_to_iter[i] = iter;
+  }
+
+  // usage_counts is a map which stores non-tensor inputs as keys and the values are indices of segmented blocks which
+  // have these non-tensor inputs. Iterate through the graph (segmented blocks) from bottom to top. When we find a
+  // non-tensor input in a segmented block of index "i", store it in the usage_counts map. Now for each non-tensor
+  // input recorded in the usage_counts map, we check if any previous segmented block (segmented block index i goes
+  // from n-1 to 0) generated/contains this non-tensor input. If so, we set this idx as the produce_id as it produces
+  // the non-tensor input.
  std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
  for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
    for (auto input : segmented_blocks[i].raw_inputs()) {
@@ -127,36 +150,44 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
                                                               : usage_counts[input].tensorrt_use_id.push_back(i);
      }
    }
+
    for (auto& use : usage_counts) {
+      // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
      if (segmented_blocks[i].contain_raw_value(use.first)) {
        use.second.produce_id = i;
      }
    }
  }
+
  std::unordered_set<int> updated_segments;
  for (auto& use : usage_counts) {
    auto use_info = use.second;
    // if the segment that produce this nonTensor value is kTensorRT but consumed in kTorch, inject nodes in the first
-    // kTorch segments
+    // kTorch segment.
    if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
-      int first_torch_id = use_info.torch_use_id.front();
+      auto first_torch_id = use_info.torch_use_id.front();
      if (!updated_segments.count(first_torch_id)) {
-        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
-        segmented_blocks[first_torch_id] = new_torch_block;
+        // Segmented Blocks with non-tensor inputs will have to be re-segmented as
+        // TRTorch doesn't support non-tensor inputs for a module.
+        auto new_torch_block = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]).front();
+        *idx_to_iter[first_torch_id] = new_torch_block;
        updated_segments.insert(first_torch_id);
      }
-    } else {
-      // KTensorRT segments always need to inject nodes for the nonTensor inputs
-      for (int i : use_info.tensorrt_use_id) {
-        if (!updated_segments.count(i)) {
-          auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
-          segmented_blocks.erase(segmented_blocks.begin() + i);
-          segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
-          updated_segments.insert(i);
-        }
+    }
+    // kTensorRT segments always need to inject nodes for the nonTensor inputs
+    for (auto i : use_info.tensorrt_use_id) {
+      if (!updated_segments.count(i)) {
+        // Segmented Blocks with non-tensor inputs will have to be re-segmented as
+        // TRTorch doesn't support non-tensor inputs for a module.
+        auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
+        auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
+        segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
+        updated_segments.insert(i);
      }
    }
  }
+  segmented_blocks.clear();
+  segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
  return;
}
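A minimal standalone sketch (not part of the diff) of the iterator-stability pattern this change relies on: replacing one segmented block with several re-segmented blocks via `std::list::erase`/`insert`, so iterators stored for other blocks stay valid, then copying the list back into the vector. Block names and the `std::string` stand-in for `SegmentedBlock` are purely illustrative.

```cpp
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Stand-in for PartitionedGraph; "T"/"P" mark TensorRT/PyTorch blocks (hypothetical data).
  std::vector<std::string> blocks = {"T0", "P1", "T2"};

  // Mirror the blocks into a list and remember an iterator per index,
  // as resolveNonTensorInputs does with idx_to_iter.
  std::list<std::string> block_list(blocks.begin(), blocks.end());
  std::unordered_map<size_t, std::list<std::string>::iterator> idx_to_iter;
  auto it = block_list.begin();
  for (size_t i = 0; i < blocks.size(); ++i, ++it) {
    idx_to_iter[i] = it;
  }

  // Replace block 2 with two re-segmented blocks, analogous to what
  // segmentBlocksWithNonTensorInputs would return.
  std::vector<std::string> resegmented = {"P2a", "T2b"};
  auto next = block_list.erase(idx_to_iter[2]);
  block_list.insert(next, resegmented.begin(), resegmented.end());

  // idx_to_iter[0] and idx_to_iter[1] remain valid; copy the list back at the end.
  blocks.assign(block_list.begin(), block_list.end());
  for (const auto& b : blocks) {
    std::cout << b << " ";  // prints: T0 P1 P2a T2b
  }
  std::cout << std::endl;
  return 0;
}
```

Doing the same erase/insert on the vector itself (as the removed lines did) would shift the indices of every later block and invalidate any previously computed positions, which is why the list-based indirection was introduced.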