Commit 1818813

[feat] Add dependency awareness to torch-trt partitioning (#40)
Adds a heuristic to torch-trt partitioning's segmentation that avoids materializing a segment until a dependency of that segment is hit. This can significantly reduce the number of segments/engines in cases where a linear traversal of the TorchScript nodes would otherwise produce alternating Torch and TRT segments that do not depend on each other.

Fixes # (issue)

Type of change:

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

Checklist:

- [ ] My code follows the style guidelines of this project (you can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that the relevant reviewers are notified
1 parent 5d1acba commit 1818813
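
To make the heuristic concrete, here is a minimal self-contained C++ sketch using toy types only (nothing from the Torch-TensorRT codebase): with alternating Torch/TRT nodes that have no cross-kind dependencies, flushing a segment only when a node of the other kind actually depends on it yields two segments instead of four.

```cpp
#include <iostream>
#include <set>
#include <vector>

struct ToyNode {
  int id;
  bool runs_on_trt;
  std::set<int> inputs; // ids of the nodes this node consumes
};

int main() {
  // Alternating TRT / Torch nodes with no cross-kind dependencies. A traversal
  // that flushes on every kind change would emit 4 segments; flushing only
  // when a node depends on the in-progress segment of the other kind emits 2.
  std::vector<ToyNode> nodes = {
      {0, true, {}}, {1, false, {}}, {2, true, {0}}, {3, false, {1}}};

  std::vector<int> trt_seg, pyt_seg;
  std::set<int> trt_ids, pyt_ids;
  int segments = 0;

  auto depends_on = [](const ToyNode& n, const std::set<int>& ids) {
    for (int in : n.inputs)
      if (ids.count(in)) return true;
    return false;
  };

  for (const auto& n : nodes) {
    if (n.runs_on_trt) {
      if (depends_on(n, pyt_ids)) { // materialize the Torch segment only now
        ++segments;
        pyt_seg.clear();
        pyt_ids.clear();
      }
      trt_seg.push_back(n.id);
      trt_ids.insert(n.id);
    } else {
      if (depends_on(n, trt_ids)) { // materialize the TRT segment only now
        ++segments;
        trt_seg.clear();
        trt_ids.clear();
      }
      pyt_seg.push_back(n.id);
      pyt_ids.insert(n.id);
    }
  }
  segments += !trt_seg.empty();
  segments += !pyt_seg.empty();
  std::cout << "segments: " << segments << "\n"; // prints "segments: 2"
}
```

The sketch checks a node's inputs against the in-progress segment, which is equivalent to the forward dependent-set bookkeeping the commit actually adds.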

File tree

5 files changed: +472 −330 lines


core/partitioning/SegmentedBlock.h

Lines changed: 9 additions & 0 deletions
```diff
@@ -95,6 +95,14 @@ struct SegmentedBlock {
     return target_;
   }
 
+  bool do_not_merge(void) const {
+    return do_not_merge_;
+  }
+
+  void do_not_merge(bool x) {
+    do_not_merge_ = x;
+  }
+
   friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);
 
  private:
@@ -107,6 +115,7 @@ struct SegmentedBlock {
   std::vector<torch::jit::Node*> nodes_;
   std::shared_ptr<torch::jit::Graph> g_;
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
+  bool do_not_merge_ = false;
 };
 
 std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
```
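
As a reading aid, a tiny self-contained sketch of the accessor pattern this header adds (a stand-in type, not the real SegmentedBlock): the const overload reads the flag, the bool overload sets it, and the merging pass added in partitioning.cpp below treats flagged blocks as barriers.

```cpp
#include <iostream>

// Stand-in for SegmentedBlock, just to show the getter/setter overload pair.
struct Block {
  bool do_not_merge(void) const { return do_not_merge_; }
  void do_not_merge(bool x) { do_not_merge_ = x; }

 private:
  bool do_not_merge_ = false;
};

int main() {
  Block b;
  b.do_not_merge(true); // e.g. marking a block that wraps a prim::If
  std::cout << std::boolalpha << b.do_not_merge() << "\n"; // prints true
}
```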

core/partitioning/partitioning.cpp

Lines changed: 119 additions & 17 deletions
```diff
@@ -96,6 +96,29 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
+std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
+  std::set<torch::jit::Node*> dependent_nodes;
+  for (auto val : n->outputs()) {
+    for (auto use : val->uses()) {
+      dependent_nodes.insert(use.user);
+    }
+  }
+  if (const auto* schema = n->maybeSchema()) {
+    for (size_t i = 0; i < n->inputs().size(); ++i) {
+      const at::AliasInfo* formal = schema->arguments()[i].alias_info();
+      if (formal && formal->isWrite()) {
+        for (auto use : n->inputs()[i]->uses()) {
+          torch::jit::Node* use_node = use.user;
+          if (use_node->isAfter(n)) {
+            dependent_nodes.insert(use_node);
+          }
+        }
+      }
+    }
+  }
+  return dependent_nodes;
+}
+
 // check if the input and output of the graph is Tensor after collection is enabled. If it is, then fallback related
 // nodes
 void fallback_graph_nontensor_in_out(
```
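
The following self-contained toy model (an assumed simplification, not repo code) mirrors what getDependentNodes gathers: every direct consumer of a node's outputs, plus, when the node's schema marks an argument as written in place, every use of that argument that is ordered after the node.

```cpp
#include <iostream>
#include <set>
#include <vector>

// Toy node: just enough structure to mirror the two cases covered above.
struct Node {
  int position;                        // topological position in the block
  std::vector<int> output_consumers;   // users of this node's outputs
  bool writes_first_input = false;     // models alias_info()->isWrite()
  std::vector<int> first_input_users;  // all users of that (mutated) input
};

std::set<int> dependents(const Node& n) {
  // Case 1: every direct consumer of an output is a dependent.
  std::set<int> out(n.output_consumers.begin(), n.output_consumers.end());
  // Case 2: if the node writes an input in place, any use of that value
  // ordered after the node also depends on it (mirrors the isAfter() check).
  if (n.writes_first_input) {
    for (int u : n.first_input_users)
      if (u > n.position) out.insert(u);
  }
  return out;
}

int main() {
  // Node 3 mutates its first input (think aten::add_). Node 1 read the value
  // before the write, node 5 reads it after; only node 5 becomes a dependent.
  Node n{3, {4}, true, {1, 5}};
  for (int d : dependents(n)) std::cout << d << ' '; // prints "4 5"
  std::cout << '\n';
}
```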
```diff
@@ -145,6 +168,7 @@ void find_all_fallback_nodes(
     for (auto use : output->uses()) {
       auto node = use.user;
       if (node->kind() != torch::jit::prim::Constant &&
+
           global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
         q.push(node);
       }
@@ -319,6 +343,7 @@ std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
     size_t min_block_size) {
   auto nodes = block->nodes();
   std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
   std::vector<torch::jit::Node*> min_block_fallback_nodes;
   for (const auto n : nodes) {
     if (n->kind() == torch::jit::prim::Constant)
@@ -328,11 +353,16 @@ std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
     if (!global_fallback_nodes.count(n)) {
       // if this node is not in fallback nodes, then it's in trt segments
       cur_trt_nodes.push_back(n);
+      auto dependent_nodes = getDependentNodes(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
     } else {
-      if (cur_trt_nodes.size() < min_block_size) {
-        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      if (cur_trt_nodes_uses.count(n)) {
+        if (cur_trt_nodes.size() < min_block_size) {
+          min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+        }
+        cur_trt_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      cur_trt_nodes.clear();
     }
   }
   if (cur_trt_nodes.size() < min_block_size) {
```
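
A small worked example of the adjusted min-block-size rule (toy model, assuming min_block_size = 2): an independent Torch node no longer splits an in-progress TRT run, so the run below reaches the minimum size and is not folded back into fallback.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

struct N {
  bool trt;                 // node would run in TensorRT
  bool depends_on_trt_run;  // stands in for cur_trt_nodes_uses.count(n)
};

int main() {
  const std::size_t min_block_size = 2;
  std::vector<N> nodes = {
      {true, false},   // TRT node starts a run
      {false, false},  // independent Torch node: no longer flushes the run
      {true, false},   // the same TRT run grows to size 2
      {false, true},   // dependent Torch node: run is flushed at size 2
  };

  std::vector<std::size_t> run;
  std::size_t forced_to_fallback = 0;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    if (nodes[i].trt) {
      run.push_back(i);
    } else if (nodes[i].depends_on_trt_run) {
      if (run.size() < min_block_size) forced_to_fallback += run.size();
      run.clear();
    }  // else: the TRT run survives across the independent Torch node
  }
  if (run.size() < min_block_size) forced_to_fallback += run.size();
  std::cout << forced_to_fallback << "\n"; // prints 0: the run met the minimum
}
```

Under the previous rule the first Torch node would have flushed a size-1 TRT run and forced that node to fall back.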
```diff
@@ -362,6 +392,59 @@ void find_min_block_size_fallback_nodes(
   }
 }
 
+void merge_adjacent_segments_list_in_new_partition(
+    PartitionedGraph& original_partition,
+    PartitionedGraph& new_partition,
+    SegmentedBlock::SegmentedBlockTarget& segment_kind,
+    std::vector<size_t>& same_type_segment_idx) {
+  TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list");
+  if (same_type_segment_idx.size() == 1) {
+    new_partition.push_back(original_partition[same_type_segment_idx[0]]);
+  } else {
+    auto first_idx = same_type_segment_idx[0];
+    for (size_t i = 1; i < same_type_segment_idx.size(); ++i) {
+      TORCHTRT_CHECK(
+          same_type_segment_idx[i] == (first_idx + i),
+          "Unable to merge non-sequential segments: " << same_type_segment_idx);
+    }
+    LOG_DEBUG(
+        "Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx);
+    std::vector<torch::jit::Node*> nodes;
+    for (auto segment_to_merge : same_type_segment_idx) {
+      const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes();
+      nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end());
+    }
+    new_partition.emplace_back(segment_kind, nodes);
+  }
+}
+
+PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) {
+  PartitionedGraph new_partition;
+  SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch;
+  std::vector<size_t> same_type_segment_idx;
+  for (size_t i = 0UL; i < original_partition.size(); ++i) {
+    auto& segment = original_partition[i];
+    if (same_type_segment_idx.empty()) {
+      segment_kind = segment.target();
+    } else if (segment_kind != segment.target() || segment.do_not_merge()) {
+      merge_adjacent_segments_list_in_new_partition(
+          original_partition, new_partition, segment_kind, same_type_segment_idx);
+      same_type_segment_idx.clear();
+      segment_kind = segment.target();
+    }
+    if (segment.do_not_merge()) {
+      new_partition.push_back(segment);
+    } else {
+      same_type_segment_idx.push_back(i);
+    }
+  }
+  if (!same_type_segment_idx.empty()) {
+    merge_adjacent_segments_list_in_new_partition(
+        original_partition, new_partition, segment_kind, same_type_segment_idx);
+  }
+  return new_partition;
+}
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
```
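
The merge pass's intended behavior can be sketched with plain values (a simplified model of the functions above, not the repo code): adjacent segments of the same target collapse into one, while a do_not_merge segment is copied through unmerged and keeps the runs on either side of it apart.

```cpp
#include <iostream>
#include <vector>

struct Seg {
  char target;                // 'T' = TensorRT, 'P' = PyTorch
  bool do_not_merge = false;  // e.g. a block wrapping prim::If
};

// Collapse runs of adjacent same-target segments; do_not_merge segments are
// copied through unmerged and also terminate the runs around them.
std::vector<Seg> merge_adjacent(const std::vector<Seg>& in) {
  std::vector<Seg> out;
  for (const auto& s : in) {
    if (!s.do_not_merge && !out.empty() && !out.back().do_not_merge &&
        out.back().target == s.target) {
      continue; // absorbed into the previous segment of the same kind
    }
    out.push_back(s);
  }
  return out;
}

int main() {
  // T T P [P barrier] P T  collapses to  T P [P barrier] P T.
  std::vector<Seg> partition = {{'T'}, {'T'}, {'P'}, {'P', true}, {'P'}, {'T'}};
  for (const auto& s : merge_adjacent(partition))
    std::cout << s.target << (s.do_not_merge ? "! " : " ");
  std::cout << "\n"; // prints "T P P! P T"
}
```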
```diff
@@ -387,56 +470,73 @@ PartitionedGraph segment_graph(
 
   // segment the nodes
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
+  std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
   for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
+    auto dependent_nodes = getDependentNodes(n);
     // the outputs of trt subgraph shouldn't be collections
     if (check_node_fallback(n, global_fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
 
-      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
-      // block then segment and reset the active PyTorch block
-      if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+      // If we hit a TRT node that is dependent on nodes in the active PyTorch block, finalize the block to materialize
+      // those dependencies in the graph
+      if (cur_pyt_nodes_uses.count(n)) {
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
       }
     } else {
-      // If there is an active TRT block that is valid segment and reset the active TRT block
-      // otherwise add it to the active PyTorch block and reset
-      if (in_prog_trt_blk_nodes.size() >= min_block_size) {
-        finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
-      } else {
-        LOG_DEBUG(
-            "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
-        in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+      // The current node is dependent on the active TRT block, finalize it to materialize those dependencies in the
+      // graph or add them to the active PyTorch block
+      if (cur_trt_nodes_uses.count(n)) {
+        // If there is an active TRT block that is valid segment and reset the active TRT block
+        // otherwise add it to the active PyTorch block and reset
+        if (in_prog_trt_blk_nodes.size() >= min_block_size) {
+          finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+        } else {
+          LOG_DEBUG(
+              "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
+          in_prog_pyt_blk_nodes.insert(
+              in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+          cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end());
+        }
+        in_prog_trt_blk_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      in_prog_trt_blk_nodes.clear();
       // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
       // we shouldn't inject node for this block in dependency analysis process
       if (n->kind() == torch::jit::prim::If) {
        LOG_DEBUG(
             "Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
         if (!in_prog_pyt_blk_nodes.empty()) {
           finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
         }
         auto cond_node = std::vector<torch::jit::Node*>{n};
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node);
+        segmented_blocks.back().do_not_merge(true);
         continue;
       } else if (n->kind() == torch::jit::prim::Loop) {
         if (!in_prog_pyt_blk_nodes.empty()) {
           finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
         }
         if (checkLoopEvaluatable(n)) {
           in_prog_trt_blk_nodes.push_back(n);
+          cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
         } else {
           auto loop_node = std::vector<torch::jit::Node*>{n};
           finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
+          segmented_blocks.back().do_not_merge(true);
         }
         continue;
       }
       in_prog_pyt_blk_nodes.push_back(n);
+      cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
     }
   }
 
@@ -451,6 +551,8 @@ PartitionedGraph segment_graph(
         in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
     finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
+
+  segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks);
   return segmented_blocks;
 }
 
@@ -465,7 +567,7 @@ PartitionedGraph Partition(
   fallback_graph_nontensor_in_out(block, global_fallback_nodes);
 
   // segment lowering global graph into blocks
-  LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
+  LOG_DEBUG("Partitioning source module into PyTorch and TensorRT sub blocks");
   PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);
 
   // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
```

tests/core/partitioning/test_conditionals.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,7 +40,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
 
   auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);
 
-  ASSERT_TRUE(conditional_engines_count == 2);
+  ASSERT_TRUE(conditional_engines_count == 1);
 }
 
 TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
```

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -201,7 +201,7 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
       }));
     }
   }
-  ASSERT_TRUE(trt_block_cnt == 2 && torch_block_cnt == 2);
+  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
 }
 
 TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
```
