Skip to content

Commit 6147d4f

Browse files
committed
chore: Support more types conversion for minigraph inputs
Signed-off-by: Bo Wang <[email protected]>
1 parent da09e4b commit 6147d4f

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

core/partitioning/partitioning.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,18 @@ void registerSegmentInOutIValues(
9191

9292
// set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
9393
for (auto& input : seg_block.raw_inputs()) {
94-
if (!ivalues_maps.count(input)) {
95-
std::cerr << "could find graph input ivalues\n";
96-
}
94+
TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
9795
if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
9896
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
9997
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
10098
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
10199
} else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) {
102100
jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
103-
} else if (input->type()->isSubtypeOf(torch::jit::ListType::ofTensors())) {
101+
} else if (input->type()->kind() == torch::jit::TypeKind::ListType) {
104102
jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
105103
} else {
106-
std::cerr << "Currently not support the type cast for input type " << input->type()->str() << ".\n";
104+
TRTORCH_CHECK(input->type()->kind() == torch::jit::TypeKind::TupleType, "Input for mini graph is not TupleType.");
105+
jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
107106
}
108107
}
109108

0 commit comments

Comments
 (0)