@@ -96,6 +96,29 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }

+std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
+  std::set<torch::jit::Node*> dependent_nodes;
+  for (auto val : n->outputs()) {
+    for (auto use : val->uses()) {
+      dependent_nodes.insert(use.user);
+    }
+  }
+  if (const auto* schema = n->maybeSchema()) {
+    for (size_t i = 0; i < n->inputs().size(); ++i) {
+      const at::AliasInfo* formal = schema->arguments()[i].alias_info();
+      if (formal && formal->isWrite()) {
+        for (auto use : n->inputs()[i]->uses()) {
+          torch::jit::Node* use_node = use.user;
+          if (use_node->isAfter(n)) {
+            dependent_nodes.insert(use_node);
+          }
+        }
+      }
+    }
+  }
+  return dependent_nodes;
+}
+
 // check if the input and output of the graph is Tensor after collection is enabled. If it is, then fallback related
 // nodes
 void fallback_graph_nontensor_in_out(
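
For context on the alias handling above, here is a minimal sketch of the case the schema check covers. The IR and the wrapper function are illustrative only, not part of this patch: aten::add_ marks its first argument as a write, so a later reader of %x depends on the in-place op even though it never consumes the op's output.

#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

// Illustrative only: dependency flowing through a mutated value, not through
// an output.
void alias_dependency_example() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor, %y : Tensor):
      %one : int = prim::Constant[value=1]()
      %z : Tensor = aten::add_(%x, %y, %one)
      %out : Tensor = aten::relu(%x)
      return (%out))IR",
      g.get());
  // For the aten::add_ node, getDependentNodes would also return the
  // aten::relu node: add_'s schema flags argument 0 as a write, and relu
  // reads %x after that write (use_node->isAfter(n)), even though relu
  // never uses %z.
}
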
@@ -145,6 +168,7 @@ void find_all_fallback_nodes(
         for (auto use : output->uses()) {
           auto node = use.user;
           if (node->kind() != torch::jit::prim::Constant &&
+
               global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
             q.push(node);
           }
@@ -319,6 +343,7 @@ std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
     size_t min_block_size) {
   auto nodes = block->nodes();
   std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
   std::vector<torch::jit::Node*> min_block_fallback_nodes;
   for (const auto n : nodes) {
     if (n->kind() == torch::jit::prim::Constant)
@@ -328,11 +353,16 @@ std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
     if (!global_fallback_nodes.count(n)) {
       // if this node is not in fallback nodes, then it's in trt segments
       cur_trt_nodes.push_back(n);
+      auto dependent_nodes = getDependentNodes(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
     } else {
-      if (cur_trt_nodes.size() < min_block_size) {
-        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      if (cur_trt_nodes_uses.count(n)) {
+        if (cur_trt_nodes.size() < min_block_size) {
+          min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+        }
+        cur_trt_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      cur_trt_nodes.clear();
     }
   }
   if (cur_trt_nodes.size() < min_block_size) {
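
The behavioral change in this hunk: a fallback node only terminates (and possibly demotes) the current TRT run when it actually consumes one of that run's nodes; an independent fallback node no longer splits the run. A self-contained toy model of the rule, with an invented ToyNode type in place of real torch::jit nodes:

#include <cstddef>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Invented stand-in for a graph node: whether it must fall back, and which
// earlier nodes it consumes.
struct ToyNode {
  std::string name;
  bool fallback;
  std::set<std::string> consumes;
};

int main() {
  const std::size_t min_block_size = 3;
  // t1..t3 are TRT-eligible; f falls back but consumes none of them.
  std::vector<ToyNode> nodes = {
      {"t1", false, {}}, {"t2", false, {"t1"}}, {"f", true, {}}, {"t3", false, {"t2"}}};
  std::vector<std::string> cur_trt_run, forced_to_torch;
  for (const auto& n : nodes) {
    if (!n.fallback) {
      cur_trt_run.push_back(n.name);
      continue;
    }
    bool depends_on_run = false;
    for (const auto& name : cur_trt_run) {
      depends_on_run = depends_on_run || n.consumes.count(name) > 0;
    }
    if (depends_on_run) { // mirrors cur_trt_nodes_uses.count(n)
      if (cur_trt_run.size() < min_block_size) {
        forced_to_torch.insert(forced_to_torch.end(), cur_trt_run.begin(), cur_trt_run.end());
      }
      cur_trt_run.clear();
    } // an independent fallback node leaves the run intact
  }
  // The old rule flushed unconditionally at "f" and demoted {t1, t2};
  // with the dependency check the run reaches min_block_size.
  std::cout << "final TRT run size: " << cur_trt_run.size() << "\n"; // 3
}
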
@@ -362,6 +392,59 @@ void find_min_block_size_fallback_nodes(
   }
 }

+void merge_adjacent_segments_list_in_new_partition(
+    PartitionedGraph& original_partition,
+    PartitionedGraph& new_partition,
+    SegmentedBlock::SegmentedBlockTarget& segment_kind,
+    std::vector<size_t>& same_type_segment_idx) {
+  TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list");
+  if (same_type_segment_idx.size() == 1) {
+    new_partition.push_back(original_partition[same_type_segment_idx[0]]);
+  } else {
+    auto first_idx = same_type_segment_idx[0];
+    for (size_t i = 1; i < same_type_segment_idx.size(); ++i) {
+      TORCHTRT_CHECK(
+          same_type_segment_idx[i] == (first_idx + i),
+          "Unable to merge non-sequential segments: " << same_type_segment_idx);
+    }
+    LOG_DEBUG(
+        "Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx);
+    std::vector<torch::jit::Node*> nodes;
+    for (auto segment_to_merge : same_type_segment_idx) {
+      const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes();
+      nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end());
+    }
+    new_partition.emplace_back(segment_kind, nodes);
+  }
+}
+
+PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) {
+  PartitionedGraph new_partition;
+  SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch;
+  std::vector<size_t> same_type_segment_idx;
+  for (size_t i = 0UL; i < original_partition.size(); ++i) {
+    auto& segment = original_partition[i];
+    if (same_type_segment_idx.empty()) {
+      segment_kind = segment.target();
+    } else if (segment_kind != segment.target() || segment.do_not_merge()) {
+      merge_adjacent_segments_list_in_new_partition(
+          original_partition, new_partition, segment_kind, same_type_segment_idx);
+      same_type_segment_idx.clear();
+      segment_kind = segment.target();
+    }
+    if (segment.do_not_merge()) {
+      new_partition.push_back(segment);
+    } else {
+      same_type_segment_idx.push_back(i);
+    }
+  }
+  if (!same_type_segment_idx.empty()) {
+    merge_adjacent_segments_list_in_new_partition(
+        original_partition, new_partition, segment_kind, same_type_segment_idx);
+  }
+  return new_partition;
+}
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
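
What the two helpers above do, as a self-contained toy (the Seg type is an invented stand-in for SegmentedBlock): adjacent segments of the same kind collapse into one, while segments flagged do_not_merge, which segment_graph below sets for conditionals and non-evaluatable loops, are emitted untouched and break the run.

#include <cstddef>
#include <iostream>
#include <vector>

enum class Kind { kTorch, kTensorRT };
// Invented stand-in for SegmentedBlock: just a kind and the do_not_merge flag.
struct Seg {
  Kind kind;
  bool do_not_merge;
};

int main() {
  std::vector<Seg> partition = {
      {Kind::kTorch, false},    // |
      {Kind::kTorch, false},    // | merged into a single Torch segment
      {Kind::kTensorRT, false}, // kept as-is
      {Kind::kTorch, true},     // e.g. a prim::If block: never merged
      {Kind::kTorch, false}};   // kept as-is (its neighbor refuses to merge)
  std::vector<Seg> merged;
  std::size_t i = 0;
  while (i < partition.size()) {
    if (partition[i].do_not_merge) {
      merged.push_back(partition[i++]);
      continue;
    }
    // Extend the run while the kind matches and merging is allowed.
    std::size_t j = i + 1;
    while (j < partition.size() && !partition[j].do_not_merge && partition[j].kind == partition[i].kind) {
      ++j;
    }
    merged.push_back(partition[i]); // stands in for emplacing the concatenated node list
    i = j;
  }
  std::cout << "segments after merge: " << merged.size() << "\n"; // 4, down from 5
}
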
@@ -387,56 +470,73 @@ PartitionedGraph segment_graph(

   // segment the nodes
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
+  std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
   for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
+    auto dependent_nodes = getDependentNodes(n);
     // the outputs of trt subgraph shouldn't be collections
     if (check_node_fallback(n, global_fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());

-      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
-      // block then segment and reset the active PyTorch block
-      if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+      // If we hit a TRT node that is dependent on nodes in the active PyTorch block, finalize the block to
+      // materialize those dependencies in the graph
+      if (cur_pyt_nodes_uses.count(n)) {
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
       }
     } else {
-      // If there is an active TRT block that is valid segment and reset the active TRT block
-      // otherwise add it to the active PyTorch block and reset
-      if (in_prog_trt_blk_nodes.size() >= min_block_size) {
-        finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
-      } else {
-        LOG_DEBUG(
-            "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
-        in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+      // The current node depends on the active TRT block; finalize that block to materialize its dependencies in
+      // the graph, or fold it into the active PyTorch block
+      if (cur_trt_nodes_uses.count(n)) {
+        // If the active TRT block is a valid segment, finalize it and reset;
+        // otherwise fold it into the active PyTorch block and reset
+        if (in_prog_trt_blk_nodes.size() >= min_block_size) {
+          finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+        } else {
+          LOG_DEBUG(
+              "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
+          in_prog_pyt_blk_nodes.insert(
+              in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+          cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end());
+        }
+        in_prog_trt_blk_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      in_prog_trt_blk_nodes.clear();
       // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
       // we shouldn't inject node for this block in dependency analysis process
       if (n->kind() == torch::jit::prim::If) {
         LOG_DEBUG(
             "Hit a conditional statement, finalizing in progress PYT block and creating a new one for the conditional");
         if (!in_prog_pyt_blk_nodes.empty()) {
           finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
         }
         auto cond_node = std::vector<torch::jit::Node*>{n};
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node);
+        segmented_blocks.back().do_not_merge(true);
         continue;
       } else if (n->kind() == torch::jit::prim::Loop) {
         if (!in_prog_pyt_blk_nodes.empty()) {
           finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
         }
         if (checkLoopEvaluatable(n)) {
           in_prog_trt_blk_nodes.push_back(n);
+          cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
         } else {
           auto loop_node = std::vector<torch::jit::Node*>{n};
           finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
+          segmented_blocks.back().do_not_merge(true);
         }
         continue;
       }
       in_prog_pyt_blk_nodes.push_back(n);
+      cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
     }
   }

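
The bookkeeping pattern used throughout the loop above, in isolation: every node placed into a block records its dependents (from getDependentNodes) in that block's uses set, so "does the incoming node consume anything in this block?" becomes a constant-time lookup. A toy sketch with string keys standing in for torch::jit::Node*:

#include <iostream>
#include <string>
#include <unordered_set>

int main() {
  std::unordered_set<std::string> cur_pyt_nodes_uses;
  // Placing node "a" into the active PyTorch block: record its dependents,
  // i.e. the nodes that consume "a"'s outputs (or read values it mutates).
  cur_pyt_nodes_uses.insert("c");
  cur_pyt_nodes_uses.insert("d");
  // A TRT-eligible node "c" arrives next:
  const std::string n = "c";
  if (cur_pyt_nodes_uses.count(n)) {
    // "c" consumes the active PyTorch block, so that block must be finalized
    // first; this materializes the dependency edge before the TRT segment
    // that will contain "c".
    std::cout << "finalize Torch block before " << n << "\n";
    cur_pyt_nodes_uses.clear();
  }
}
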
@@ -451,6 +551,8 @@ PartitionedGraph segment_graph(
         in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
     finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
+
+  segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks);
   return segmented_blocks;
 }

@@ -465,7 +567,7 @@ PartitionedGraph Partition(
   fallback_graph_nontensor_in_out(block, global_fallback_nodes);

   // segment lowering global graph into blocks
-  LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
+  LOG_DEBUG("Partitioning source module into PyTorch and TensorRT sub blocks");
   PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);

   // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks