From 3e090ee44bfa67f3e84b84dff6357528d96e53c5 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 3 May 2022 11:40:53 -0700 Subject: [PATCH 1/3] fix: Avoid resolving non-tensor inputs to torch segment_blocks unneccessarily Signed-off-by: Michael Feliz --- core/partitioning/partitioning.cpp | 62 +++++----- .../test_resolve_nontensor_inputs.cpp | 110 ++++++++++++++++++ 2 files changed, 143 insertions(+), 29 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 2a1e3f8943..04aed0532a 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -137,17 +137,8 @@ std::pair, SegmentedBlock return std::pair, SegmentedBlock>(append_blocks, trt_block); } -PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { - // reconstruct segmented_block if this block requires nonTensor input - std::vector nontensor_inputs; - // Gather all non-tensor inputs for this seg_block - for (auto input : seg_block.raw_inputs()) { - if (!isTensorOrTensorList(input)) { - nontensor_inputs.push_back(input); - } - } - - std::vector dependency_nodes = getDependencyNodes(nontensor_inputs); +PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, std::vector inputs_to_resolve){ + std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); PartitionedGraph new_seg_blocks; // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block @@ -162,7 +153,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { } } else { // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again - std::unordered_set nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end()); + std::unordered_set inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end()); std::vector tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end()); bool prev_non_tensor_outputs = false; @@ -170,7 +161,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { // Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node. // In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT // SegmentedBlock. - if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) { + if (containTargetInputs(n, inputs_to_resolve_set) || prev_non_tensor_outputs) { // If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a // TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments. if (!tensorrt_nodes.empty()) { @@ -201,6 +192,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { return new_seg_blocks; } +PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { + // reconstruct segmented_block if this block requires nonTensor input + std::vector inputs_to_resolve; + // Gather all non-tensor inputs for this block + for (auto input : seg_block.raw_inputs()) { + if (!isTensorOrTensorList(input)) { + inputs_to_resolve.push_back(input); + } + } + return segmentBlocksWithSpecifiedInputs(seg_block, inputs_to_resolve); +} + std::unordered_map getInputUsageCounts( const PartitionedGraph& segmented_blocks, const std::function& condition) { @@ -248,6 +251,9 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); }); auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list); + std::map> torch_values_to_fix; //Only need to resolve values generated by tensorrt + std::set tensorrt_blocks_to_fix; //Need to resolve ALL non-tensor inputs + // update blocks_list std::unordered_set updated_segments; for (auto& use : usage_counts) { @@ -256,27 +262,25 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { // kTorch segment. if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) { auto first_torch_id = use_info.torch_use_id.back(); - if (!updated_segments.count(first_torch_id)) { - // Segmented Blocks with non-tensor inputs will have to be re-segmented as - // Torch-TensorRT doesn't support non-tensor inputs for a module. - auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]); - auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]); - segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); - updated_segments.insert(first_torch_id); - } + torch_values_to_fix[first_torch_id].push_back(use.first); } // kTensorRT segments always need to inject nodes for the nonTensor inputs for (auto i : use_info.tensorrt_use_id) { - if (!updated_segments.count(i)) { - // Segmented Blocks with non-tensor inputs will have to be re-segmented as - // Torch-TensorRT doesn't support non-tensor inputs for a module. - auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]); - auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); - segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); - updated_segments.insert(i); - } + tensorrt_blocks_to_fix.insert(i); } } + for(auto torch_block_pair : torch_values_to_fix){ + auto to_inject_blocks = segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second); + auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]); + segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); + } + + for(auto i : tensorrt_blocks_to_fix){ + auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]); + auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); + segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); + } + segmented_blocks.clear(); segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end()); return; diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index a83d2330e4..ee04c2ed12 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -257,3 +257,113 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) { int count = count_trt_engines(fallback_g); ASSERT_TRUE(count == 2); } + +TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { + /* parseIR does not support "= aten::_set_item" so we will build this graph manually + const auto graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %2 : str = prim::Constant[value="INS"]() + %3 : str = prim::Constant[value="OUTS"]() + %4 : bool = prim::Constant[value=0]() + %5 : int = prim::Constant[value=-1]() + %6 : Dict(str, Tensor) = prim::DictConstruct() + = aten::_set_item(%6, %2, %x) + %7 : Tensor = aten::__getitem__(%6, %2) + %8 : Tensor = aten::lt(%7, %y) + %9 : Tensor?[] = prim::ListConstruct(%8) + %10 : int = prim::dtype(%7) + %11 : Device = prim::device(%7) + %12 : Tensor = aten::tensor(%5, %10, %11, %4) + %13 : Tensor = aten::index_put_(%7, %9, %12, %4) + = aten::_set_item(%6, %3, %7) + %14 : Tensor = aten::__getitem__(%6, %2) + %15 : Tensor = aten::__getitem__(%6, %3) + return (%14, %15))IR"; + */ + auto g = std::make_shared(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + torch::jit::IValue ins_key("INS"); + auto ins_key_val = g->insertConstant(ins_key); + torch::jit::IValue outs_key("OUTS"); + auto outs_key_val = g->insertConstant(outs_key); + torch::jit::IValue zero(0); + auto false_const_val = g->insertConstant(zero); + false_const_val->setType(c10::BoolType::get()); + torch::jit::IValue neg_one(-1); + auto neg_one_const_val = g->insertConstant(neg_one); + auto dict_node = g->createDict(ins_key_val->type(), x->type(), torch::jit::ArrayRef(), torch::jit::ArrayRef()); + g->insertNode(dict_node); + auto set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), torch::jit::ArrayRef{dict_node->output(), ins_key_val, x}, 0); + g->insertNode(set_node); + auto get_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef{dict_node->output(), ins_key_val}, 1); + g->insertNode(get_node); + auto lt_node = g->create(torch::jit::Symbol::fromQualString("aten::lt"), torch::jit::ArrayRef{get_node->output(), y}, 1); + g->insertNode(lt_node); + auto list_node = g->createList(at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef{lt_node->output()}); + g->insertNode(list_node); + auto dtype_node = g->create(torch::jit::Symbol::fromQualString("prim::dtype"), torch::jit::ArrayRef{get_node->output()}, 1); + dtype_node->output()->setType(neg_one_const_val->type()); + g->insertNode(dtype_node); + auto device_node = g->create(torch::jit::Symbol::fromQualString("prim::device"), torch::jit::ArrayRef{get_node->output()}, 1); + device_node->output()->setType(c10::DeviceObjType::get()); + g->insertNode(device_node); + auto tensor_node = g->create(torch::jit::Symbol::fromQualString("aten::tensor"), torch::jit::ArrayRef{neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, 1); + g->insertNode(tensor_node); + auto index_put_node = g->create(torch::jit::Symbol::fromQualString("aten::index_put_"), + torch::jit::ArrayRef{get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, 1); + g->insertNode(index_put_node); + auto out_set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), + torch::jit::ArrayRef{dict_node->output(), outs_key_val, get_node->output()}, 0); + g->insertNode(out_set_node); + auto get_ins_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef{dict_node->output(), ins_key_val}, 1); + g->insertNode(get_ins_node); + auto get_outs_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef{dict_node->output(), outs_key_val}, 1); + g->insertNode(get_outs_node); + g->registerOutput(get_ins_node->output()); + g->registerOutput(get_outs_node->output()); + + torch_tensorrt::core::partitioning::PartitionInfo partition_info; + partition_info.enabled = true; + std::vector inputs; + inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); + inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); + + std::unordered_map inputs_map; + std::unordered_map> input_types; + for (size_t i = 0; i < g->inputs().size(); ++i) { + inputs_map.insert({g->inputs()[i], inputs[i]}); + input_types.insert({g->inputs()[i], {at::kFloat}}); + } + auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); + auto segmented_blocks = + torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + + int torch_block_cnt = 0, trt_block_cnt = 0; + for (const auto& segmented_block : segmented_blocks) { + if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) { + ++trt_block_cnt; + ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) { + return type_ptr->isSubtypeOf(torch::jit::TensorType::get()); + })); + } else { + ++torch_block_cnt; + bool output_dict = false; + bool input_dict = false; + auto dict_type = dict_node->output()->type(); + for (auto in : segmented_block.raw_inputs()) { + if(in->type()->isSubtypeOf(dict_type)){ + input_dict = true; + } + } + for (auto out : segmented_block.raw_outputs()) { + if(out->type()->isSubtypeOf(dict_type)){ + output_dict = true; + } + } + EXPECT_TRUE(output_dict ^ input_dict); + } + } + ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2); +} From 58df59ac3a77db2d5d2f8dc7416dbcae904c06fa Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 3 May 2022 11:42:02 -0700 Subject: [PATCH 2/3] Convert segmentBlocksWithSpecifiedInputs input vector to a const ref Signed-off-by: Michael Feliz --- core/partitioning/partitioning.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 04aed0532a..6f0a932298 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -57,7 +57,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) { return false; } -std::vector getDependencyNodes(std::vector& vals) { +std::vector getDependencyNodes(const std::vector& vals) { // use bfs to get the DAG dependency nodes for input value std::queue> q( std::deque(vals.begin(), vals.end())); @@ -137,7 +137,7 @@ std::pair, SegmentedBlock return std::pair, SegmentedBlock>(append_blocks, trt_block); } -PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, std::vector inputs_to_resolve){ +PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, const std::vector &inputs_to_resolve){ std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); PartitionedGraph new_seg_blocks; // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the From b3d9a212a180c5e661ceeb840348f45992863a0c Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 3 May 2022 16:05:29 -0700 Subject: [PATCH 3/3] Linting Signed-off-by: Michael Feliz --- core/partitioning/partitioning.cpp | 16 +++-- .../test_resolve_nontensor_inputs.cpp | 72 ++++++++++++++----- 2 files changed, 63 insertions(+), 25 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 6f0a932298..bffd3b4748 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -137,7 +137,9 @@ std::pair, SegmentedBlock return std::pair, SegmentedBlock>(append_blocks, trt_block); } -PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, const std::vector &inputs_to_resolve){ +PartitionedGraph segmentBlocksWithSpecifiedInputs( + SegmentedBlock& seg_block, + const std::vector& inputs_to_resolve) { std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); PartitionedGraph new_seg_blocks; // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the @@ -251,8 +253,9 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); }); auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list); - std::map> torch_values_to_fix; //Only need to resolve values generated by tensorrt - std::set tensorrt_blocks_to_fix; //Need to resolve ALL non-tensor inputs + std::map> + torch_values_to_fix; // Only need to resolve values generated by tensorrt + std::set tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs // update blocks_list std::unordered_set updated_segments; @@ -269,13 +272,14 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { tensorrt_blocks_to_fix.insert(i); } } - for(auto torch_block_pair : torch_values_to_fix){ - auto to_inject_blocks = segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second); + for (auto torch_block_pair : torch_values_to_fix) { + auto to_inject_blocks = + segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second); auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]); segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); } - for(auto i : tensorrt_blocks_to_fix){ + for (auto i : tensorrt_blocks_to_fix) { auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]); auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index ee04c2ed12..facdd31151 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -293,33 +293,68 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { false_const_val->setType(c10::BoolType::get()); torch::jit::IValue neg_one(-1); auto neg_one_const_val = g->insertConstant(neg_one); - auto dict_node = g->createDict(ins_key_val->type(), x->type(), torch::jit::ArrayRef(), torch::jit::ArrayRef()); + auto dict_node = g->createDict( + ins_key_val->type(), + x->type(), + torch::jit::ArrayRef(), + torch::jit::ArrayRef()); g->insertNode(dict_node); - auto set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), torch::jit::ArrayRef{dict_node->output(), ins_key_val, x}, 0); + auto set_node = g->create( + torch::jit::Symbol::fromQualString("aten::_set_item"), + torch::jit::ArrayRef{dict_node->output(), ins_key_val, x}, + 0); g->insertNode(set_node); - auto get_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef{dict_node->output(), ins_key_val}, 1); + auto get_node = g->create( + torch::jit::Symbol::fromQualString("aten::__getitem__"), + torch::jit::ArrayRef{dict_node->output(), ins_key_val}, + 1); g->insertNode(get_node); - auto lt_node = g->create(torch::jit::Symbol::fromQualString("aten::lt"), torch::jit::ArrayRef{get_node->output(), y}, 1); + auto lt_node = g->create( + torch::jit::Symbol::fromQualString("aten::lt"), + torch::jit::ArrayRef{get_node->output(), y}, + 1); g->insertNode(lt_node); - auto list_node = g->createList(at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef{lt_node->output()}); + auto list_node = g->createList( + at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef{lt_node->output()}); g->insertNode(list_node); - auto dtype_node = g->create(torch::jit::Symbol::fromQualString("prim::dtype"), torch::jit::ArrayRef{get_node->output()}, 1); + auto dtype_node = g->create( + torch::jit::Symbol::fromQualString("prim::dtype"), + torch::jit::ArrayRef{get_node->output()}, + 1); dtype_node->output()->setType(neg_one_const_val->type()); g->insertNode(dtype_node); - auto device_node = g->create(torch::jit::Symbol::fromQualString("prim::device"), torch::jit::ArrayRef{get_node->output()}, 1); + auto device_node = g->create( + torch::jit::Symbol::fromQualString("prim::device"), + torch::jit::ArrayRef{get_node->output()}, + 1); device_node->output()->setType(c10::DeviceObjType::get()); g->insertNode(device_node); - auto tensor_node = g->create(torch::jit::Symbol::fromQualString("aten::tensor"), torch::jit::ArrayRef{neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, 1); + auto tensor_node = g->create( + torch::jit::Symbol::fromQualString("aten::tensor"), + torch::jit::ArrayRef{ + neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, + 1); g->insertNode(tensor_node); - auto index_put_node = g->create(torch::jit::Symbol::fromQualString("aten::index_put_"), - torch::jit::ArrayRef{get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, 1); + auto index_put_node = g->create( + torch::jit::Symbol::fromQualString("aten::index_put_"), + torch::jit::ArrayRef{ + get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, + 1); g->insertNode(index_put_node); - auto out_set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), - torch::jit::ArrayRef{dict_node->output(), outs_key_val, get_node->output()}, 0); + auto out_set_node = g->create( + torch::jit::Symbol::fromQualString("aten::_set_item"), + torch::jit::ArrayRef{dict_node->output(), outs_key_val, get_node->output()}, + 0); g->insertNode(out_set_node); - auto get_ins_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef{dict_node->output(), ins_key_val}, 1); + auto get_ins_node = g->create( + torch::jit::Symbol::fromQualString("aten::__getitem__"), + torch::jit::ArrayRef{dict_node->output(), ins_key_val}, + 1); g->insertNode(get_ins_node); - auto get_outs_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef{dict_node->output(), outs_key_val}, 1); + auto get_outs_node = g->create( + torch::jit::Symbol::fromQualString("aten::__getitem__"), + torch::jit::ArrayRef{dict_node->output(), outs_key_val}, + 1); g->insertNode(get_outs_node); g->registerOutput(get_ins_node->output()); g->registerOutput(get_outs_node->output()); @@ -337,10 +372,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { input_types.insert({g->inputs()[i], {at::kFloat}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - auto segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); - int torch_block_cnt = 0, trt_block_cnt = 0; + int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) { ++trt_block_cnt; @@ -353,12 +387,12 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { bool input_dict = false; auto dict_type = dict_node->output()->type(); for (auto in : segmented_block.raw_inputs()) { - if(in->type()->isSubtypeOf(dict_type)){ + if (in->type()->isSubtypeOf(dict_type)) { input_dict = true; } } for (auto out : segmented_block.raw_outputs()) { - if(out->type()->isSubtypeOf(dict_type)){ + if (out->type()->isSubtypeOf(dict_type)) { output_dict = true; } }