Skip to content

Commit 72e56e7

Browse files
authored
[TRT RTX EP] Fix bug for generating the correct subgraph in GetCapability (#26132)
### Description In the current TRT RTX EP / TRT EP implementation, constructing the `IndexedSubGraph` can, in some cases, include a node's unused output as one of the subgraph's outputs. As a result, GetCapability returns an incorrect `IndexedSubGraph` to ORT. This change adds logic to prevent a node's unused output from being added. With this fix, we avoid generating an incorrect EPContext model in which the EPContext node has an unused output.
1 parent bdffd76 commit 72e56e7

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,12 +1466,15 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
14661466
fused_inputs.erase(it);
14671467
erased.insert(output);
14681468
}
1469-
// Only when output is neither in input list nor erased list, add the output to output list
1469+
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
14701470
else if (erased.find(output) == erased.end()) {
14711471
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
14721472
graph_outputs_to_add[output] = output_order;
14731473
}
1474-
fused_outputs[output] = output_order++;
1474+
1475+
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
1476+
fused_outputs[output] = output_order++;
1477+
}
14751478
}
14761479
}
14771480
}

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,12 +2114,15 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
21142114
fused_inputs.erase(it);
21152115
erased.insert(output);
21162116
}
2117-
// Only when output is neither in input list nor erased list, add the output to output list
2117+
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
21182118
else if (erased.find(output) == erased.end()) {
21192119
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
21202120
graph_outputs_to_add[output] = output_order;
21212121
}
2122-
fused_outputs[output] = output_order++;
2122+
2123+
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
2124+
fused_outputs[output] = output_order++;
2125+
}
21232126
}
21242127
}
21252128
}

0 commit comments

Comments
 (0)