 #include "core/lowering/passes/passes.h"
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/ir/constants.h"
 
 namespace trtorch {
 namespace core {
@@ -67,6 +68,7 @@ void registerSegmentInOutIValues(
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();
+  // LOG_INFO(*copy_g << "(copy graph)\n");
 
   // create tuple for multiple outputs
   if (seg_block.raw_outputs().size() > 1) {
@@ -163,19 +165,53 @@ void registerSegmentsInputsOutputs(
     input_values.insert(graph_output);
   }
 
-  for (auto& mini_graph_input : input_values) {
-    for (auto& seg_block : segmented_blocks) {
+  // be careful here: some in-place operations don't return any values
+  for (auto& seg_block : segmented_blocks) {
+    for (auto& mini_graph_input : input_values) {
       if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
               seg_block.raw_inputs().end() &&
           seg_block.contain_raw_input(mini_graph_input)) {
         seg_block.registerOutput(mini_graph_input);
       }
     }
+    if (seg_block.raw_outputs().empty()) {
+      seg_block.registerOutput(seg_block.raw_inputs()[0]);
+    }
   }
 
   return;
 }
 
+void eraseNonTensorInputsOutputs(
+    SegmentedBlock& seg_block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
+  if (seg_block.target() == SegmentedBlock::kTorch)
+    return;
+  auto mini_graph = seg_block.g();
+
+  for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
+    // erase this input and prepend a prim::Constant if it's not a Tensor or a list of Tensors
+    if (!seg_block.raw_inputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
+        !seg_block.raw_inputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
+      auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps[seg_block.raw_inputs()[i]]);
+      seg_block.inputs()[i]->replaceAllUsesWith(new_val);
+      seg_block.eraseInput(i);
+    }
+  }
+
+  for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
+    if (!seg_block.raw_outputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
+        !seg_block.raw_outputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
+      seg_block.eraseOutput(i);
+    }
+  }
+
+  // not sure whether to delete this block or just fall back to PyTorch
+  if (seg_block.raw_outputs().empty()) {
+    seg_block.update_target(SegmentedBlock::kTorch);
+  }
+}
+
 void construct_segments(
     std::vector<torch::jit::Node*>& pytorch_nodes,
     std::vector<torch::jit::Node*>& tensorrt_nodes,
@@ -240,6 +276,7 @@ std::vector<SegmentedBlock> segment_graph(
   // register every segment's input shape, and it's running output Ivalues
   for (auto& seg_block : segmented_blocks) {
     registerSegmentInOutIValues(seg_block, ivalues_maps);
+    eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
   }
 
   return segmented_blocks;
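
As an aside for reviewers less familiar with the JIT IR calls this patch leans on, here is a minimal standalone sketch (not part of the patch) of the pattern eraseNonTensorInputsOutputs applies to each TensorRT mini-graph: replace a non-Tensor graph input with a prim::Constant built from its already-evaluated IValue, rewire the uses, and erase the now-dead input. The toy graph (inputs x and alpha, a single aten::mul node) and the constant value 2.5 are made up for illustration; only the torch::jit calls mirror the patch.

// Standalone sketch (illustration only): bake a non-Tensor graph input into a
// prim::Constant and erase the input, as eraseNonTensorInputsOutputs does for
// each TensorRT mini-graph. Build against libtorch.
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>

#include <iostream>
#include <memory>

int main() {
  auto g = std::make_shared<torch::jit::Graph>();

  // Toy graph: y = x * alpha, with x a Tensor input and alpha a float input.
  auto* x = g->addInput("x"); // new graph inputs default to Tensor type
  auto* alpha = g->addInput("alpha");
  alpha->setType(c10::FloatType::get());
  auto* mul = g->insertNode(g->create(c10::Symbol::fromQualString("aten::mul"), {x, alpha}));
  g->registerOutput(mul->output());

  // Pretend the shape/value propagation step already produced an IValue for alpha.
  c10::IValue evaluated_alpha = 2.5;

  {
    // Insert the constant before its first use so the graph stays topologically valid.
    torch::jit::WithInsertPoint guard(mul);
    auto* cst = torch::jit::insertConstant(*g, evaluated_alpha);
    alpha->replaceAllUsesWith(cst);
  }
  g->eraseInput(1); // alpha is now unused; only the Tensor input remains

  std::cout << *g << '\n'; // prints the rewritten graph: %x is the only input
  return 0;
}

One detail worth calling out in the patch itself: both erase loops walk the inputs and outputs from the back (i = size() - 1 down to 0) because eraseInput/eraseOutput shift every later index down by one, so iterating in reverse keeps the remaining indices valid.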