@@ -219,19 +219,16 @@ void AddIfBlockToGraph(
219
219
return ;
220
220
}
221
221
222
- GraphAndMapping ConstructFallbackGraph (
222
+ GraphAndMapping ConstructFallbackGraph_ (
223
223
torch::jit::script::Module& new_mod,
224
224
torch::jit::Block* block,
225
- std::unordered_map< const torch::jit::Value*, torch::jit::IValue> example_tensor_map ,
226
- CompileSpec cfg ,
225
+ partitioning::PartitioningCtx* partitioning_ctx ,
226
+ conversion::ConversionInfo convert_info ,
227
227
ir::StaticParams static_params,
228
- std::unordered_map<torch::jit::Node*, int >& fallback_nodes) {
229
- auto convert_cfg = cfg.convert_info ;
230
- auto partitioning_info = cfg.partitioning_info ;
231
-
228
+ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map) {
232
229
auto new_g = std::make_shared<torch::jit::Graph>();
233
230
234
- auto segmented_blocks = partitioning::Partition (block, example_tensor_map, partitioning_info, fallback_nodes );
231
+ auto segmented_blocks = partitioning::partition (partitioning_ctx, block, example_tensor_map );
235
232
236
233
// the mapping from lowering graph => fallback global graph
237
234
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -240,7 +237,7 @@ GraphAndMapping ConstructFallbackGraph(
240
237
}
241
238
242
239
for (auto & seg_block : segmented_blocks) {
243
- LOG_INFO (seg_block << " (GraphInSegmentedBlock) \n " );
240
+ LOG_INFO (" Block segment: " << seg_block );
244
241
std::ostringstream trt_engine_id;
245
242
trt_engine_id << reinterpret_cast <const int *>(&seg_block);
246
243
@@ -254,12 +251,12 @@ GraphAndMapping ConstructFallbackGraph(
254
251
inputs.push_back (in);
255
252
}
256
253
// update the input ranges for each segment
257
- convert_cfg .inputs = ir::associate_specs_with_inputs (seg_block.g (), inputs, static_params);
254
+ convert_info .inputs = ir::associate_specs_with_inputs (seg_block.g (), inputs, static_params);
258
255
259
256
// TODO mapping Inputs Ivalue to flatten one here
260
- auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_cfg , static_params);
257
+ auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_info , static_params);
261
258
auto temp_g = std::make_shared<torch::jit::Graph>();
262
- auto device_spec = convert_cfg .engine_settings .device ;
259
+ auto device_spec = convert_info .engine_settings .device ;
263
260
auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
264
261
AddEngineToGraph (new_mod, temp_g, engine, cuda_device, trt_engine_id.str (), true );
265
262
@@ -272,8 +269,8 @@ GraphAndMapping ConstructFallbackGraph(
272
269
// convert the 2 blocks in prim::if and get the converted graph with mappings
273
270
std::vector<GraphAndMapping> graph_and_mappings;
274
271
for (auto cur_block : if_node->blocks ()) {
275
- graph_and_mappings.push_back (
276
- ConstructFallbackGraph ( new_mod, cur_block, example_tensor_map, cfg , static_params, fallback_nodes ));
272
+ graph_and_mappings.push_back (ConstructFallbackGraph_ (
273
+ new_mod, cur_block, partitioning_ctx, convert_info , static_params, example_tensor_map ));
277
274
}
278
275
AddIfBlockToGraph (new_g, if_node, graph_and_mappings, old_to_new_g);
279
276
@@ -303,13 +300,32 @@ GraphAndMapping ConstructFallbackGraph(
303
300
return {new_g, old_to_new_g};
304
301
}
305
302
303
+ GraphAndMapping ConstructFallbackGraph (
304
+ torch::jit::script::Module& new_mod,
305
+ torch::jit::Block* block,
306
+ CompileSpec cfg,
307
+ ir::StaticParams static_params,
308
+ ir::CollectionTypeMap first_use_types) {
309
+ auto convert_info = cfg.convert_info ;
310
+ auto partitioning_info = cfg.partitioning_info ;
311
+
312
+ auto partitioning_ctx = partitioning::PartitioningCtx (block, partitioning_info);
313
+ auto collection_input_ivalues_map =
314
+ partitioning::generateRandomInputs (partitioning_info.collection_input_spec_map , first_use_types);
315
+
316
+ return ConstructFallbackGraph_ (
317
+ new_mod, block, &partitioning_ctx, convert_info, static_params, collection_input_ivalues_map);
318
+ }
319
+
306
320
void MapInputsAndDetermineDTypes (
307
321
CompileSpec& cfg,
308
322
std::shared_ptr<torch::jit::Graph>& g,
309
323
ir::StaticParams& static_params,
310
324
ir::CollectionTypeMap& first_use_type_map) {
311
325
cfg.convert_info .collection_input_spec_map =
312
326
std::move (ir::associate_specs_with_collection_inputs (g, cfg.graph_inputs , static_params));
327
+ cfg.partitioning_info .collection_input_spec_map =
328
+ ir::CollectionInputSpecMap (cfg.convert_info .collection_input_spec_map );
313
329
314
330
auto collection_inputs = ir::get_collection_inputs (g, static_params);
315
331
LOG_DEBUG (
@@ -434,11 +450,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
434
450
(!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
435
451
cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
436
452
outputIsCollection)) {
437
- std::unordered_map<torch::jit::Node*, int > fallback_nodes;
438
- auto collection_input_ivalues_map =
439
- partitioning::generateRandomInputs (cfg.convert_info .collection_input_spec_map , first_use_types);
440
- auto graph_and_mapping = ConstructFallbackGraph (
441
- new_mod, g->block (), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
453
+ auto graph_and_mapping = ConstructFallbackGraph (new_mod, g->block (), cfg, static_params, first_use_types);
442
454
new_g = graph_and_mapping.first ;
443
455
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
444
456
for (size_t i = 0 ; i < new_g->inputs ().size (); ++i) {
0 commit comments