Skip to content

Commit da09e4b

Browse files
committed
chore: support passing BoolType/ListType arguments for segments
Signed-off-by: Bo Wang <[email protected]>
1 parent d90a300 commit da09e4b

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

core/partitioning/partitioning.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "partitioning.h"
2+
#include "core/conversion/evaluators/eval_util.h"
23
#include "core/lowering/passes/passes.h"
34
#include "core/util/prelude.h"
45
#include "torch/csrc/jit/api/module.h"
@@ -97,6 +98,12 @@ void registerSegmentInOutIValues(
9798
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
9899
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
99100
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
101+
} else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) {
102+
jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
103+
} else if (input->type()->isSubtypeOf(torch::jit::ListType::ofTensors())) {
104+
jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
105+
} else {
106+
std::cerr << "Currently not support the type cast for input type " << input->type()->str() << ".\n";
100107
}
101108
}
102109

0 commit comments

Comments
 (0)