diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 018b565421..4632744790 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -353,11 +353,9 @@ auto aten_registrations TORCHTRT_UNUSED = return {}; } }, - EvalOptions().validSchemas({ - "aten::add.int(int a, int b) -> (int)", - "aten::add.float(float a, float b) -> (float)", - "aten::add.str(str a, str b) -> (str)" - })}) + EvalOptions().validSchemas({"aten::add.int(int a, int b) -> (int)", + "aten::add.float(float a, float b) -> (float)", + "aten::add.str(str a, str b) -> (str)"})}) .evaluator({c10::Symbol::fromQualString("aten::add_"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isList()) { diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp index 77aec78b08..63581af1ca 100644 --- a/core/lowering/passes/exception_elimination.cpp +++ b/core/lowering/passes/exception_elimination.cpp @@ -44,9 +44,9 @@ struct ExceptionOrPassPatternElimination { bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException; bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException; - //if (!arm1_starts_with_exception && !arm2_starts_with_exception) { - // Neither arm matches the pattern - // return false; + // if (!arm1_starts_with_exception && !arm2_starts_with_exception) { + // Neither arm matches the pattern + // return false; //} /// Check if this Node hosts a pattern like so: @@ -90,7 +90,7 @@ struct ExceptionOrPassPatternElimination { for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { auto n = *it; if (n->kind() == prim::If && isExceptionOrPassNode(n)) { - LOG_ERROR("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl); + LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl); it.destroyCurrent(); } } @@ -104,7 +104,7 @@ void EliminateExceptionOrPassPattern(std::shared_ptr graph) { ExceptionOrPassPatternElimination eppe(std::move(graph)); eppe.run(); if (graph) { - LOG_ERROR("Post Eliminate Exception or Pass Patterns: " << *graph); + LOG_GRAPH("Post Eliminate Exception or Pass Patterns: " << *graph); } } diff --git a/core/lowering/register_trt_placeholder_ops.cpp b/core/lowering/register_trt_placeholder_ops.cpp index 5ba8171208..17d7d3f47a 100644 --- a/core/lowering/register_trt_placeholder_ops.cpp +++ b/core/lowering/register_trt_placeholder_ops.cpp @@ -10,7 +10,10 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() { RegisterOperators trt_placeholder_ops_reg({ /// Op marks a Tensor to be conveted from an Torch Tensor /// to a TRT constant Tensor - Operator("trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()), + Operator( + "trt::const(Tensor val) -> Tensor", + [](Stack& stack) { /*noop*/ }, + aliasAnalysisFromSchema()), }); } // namespace jit diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp old mode 100755 new mode 100644 index ed6a38ec4f..63161217e4 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -57,7 +57,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) { return false; } -std::vector getDependencyNodes(std::vector& vals) { +std::vector getDependencyNodes(const std::vector& vals) { // use bfs to get the DAG dependency nodes for input value std::queue> q( std::deque(vals.begin(), vals.end())); @@ -169,17 +169,10 @@ std::pair, SegmentedBlock return std::pair, SegmentedBlock>(append_blocks, trt_block); } -PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { - // reconstruct segmented_block if this block requires nonTensor input - std::vector nontensor_inputs; - // Gather all non-tensor inputs for this seg_block - for (auto input : seg_block.raw_inputs()) { - if (!isTensorOrTensorList(input)) { - nontensor_inputs.push_back(input); - } - } - - std::vector dependency_nodes = getDependencyNodes(nontensor_inputs); +PartitionedGraph segmentBlocksWithSpecifiedInputs( + SegmentedBlock& seg_block, + std::vector& inputs_to_resolve) { + std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); PartitionedGraph new_seg_blocks; // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block @@ -194,7 +187,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { } } else { // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again - std::unordered_set nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end()); + std::unordered_set inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end()); std::vector tensorrt_nodes, pytorch_nodes; // take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use @@ -205,7 +198,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); for (auto n : seg_block.raw_nodes()) { - if (containTargetInputs(n, nontensor_inputs_set)) { + if (containTargetInputs(n, inputs_to_resolve_set)) { dirty_nodes.insert(n); } } @@ -237,6 +230,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { return new_seg_blocks; } +PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { + // reconstruct segmented_block if this block requires nonTensor input + std::vector inputs_to_resolve; + // Gather all non-tensor inputs for this block + for (auto input : seg_block.raw_inputs()) { + if (!isTensorOrTensorList(input)) { + inputs_to_resolve.push_back(input); + } + } + return segmentBlocksWithSpecifiedInputs(seg_block, inputs_to_resolve); +} + std::unordered_map getInputUsageCounts( const PartitionedGraph& segmented_blocks, const std::function& condition) { @@ -284,6 +289,10 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); }); auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list); + std::map> + torch_values_to_fix; // Only need to resolve values generated by tensorrt + std::set tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs + // update blocks_list std::unordered_set updated_segments; for (auto& use : usage_counts) { @@ -292,27 +301,26 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { // kTorch segment. if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) { auto first_torch_id = use_info.torch_use_id.back(); - if (!updated_segments.count(first_torch_id)) { - // Segmented Blocks with non-tensor inputs will have to be re-segmented as - // Torch-TensorRT doesn't support non-tensor inputs for a module. - auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]); - auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]); - segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); - updated_segments.insert(first_torch_id); - } + torch_values_to_fix[first_torch_id].push_back(use.first); } // kTensorRT segments always need to inject nodes for the nonTensor inputs for (auto i : use_info.tensorrt_use_id) { - if (!updated_segments.count(i)) { - // Segmented Blocks with non-tensor inputs will have to be re-segmented as - // Torch-TensorRT doesn't support non-tensor inputs for a module. - auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]); - auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); - segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); - updated_segments.insert(i); - } + tensorrt_blocks_to_fix.insert(i); } } + for (auto torch_block_pair : torch_values_to_fix) { + auto to_inject_blocks = + segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second); + auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]); + segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); + } + + for (auto i : tensorrt_blocks_to_fix) { + auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]); + auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); + segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); + } + segmented_blocks.clear(); segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end()); return; diff --git a/tests/core/lowering/test_exception_elimination_pass.cpp b/tests/core/lowering/test_exception_elimination_pass.cpp index 5a0931ee8d..b7e4ac00d1 100644 --- a/tests/core/lowering/test_exception_elimination_pass.cpp +++ b/tests/core/lowering/test_exception_elimination_pass.cpp @@ -44,7 +44,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { auto if_block0 = if_node->addBlock(); auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block0->appendNode(exception_node); - /*auto if_block1 =*/ if_node->addBlock(); + /*auto if_block1 =*/if_node->addBlock(); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node); @@ -97,7 +97,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) { bool_node->output()->setType(torch::jit::BoolType::get()); g->insertNode(bool_node); auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); - /*auto if_block0 = */if_node->addBlock(); + /*auto if_block0 = */ if_node->addBlock(); auto if_block1 = if_node->addBlock(); auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block1->appendNode(exception_node); @@ -154,7 +154,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) { auto if_block0 = if_node->addBlock(); auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y}); if_block0->appendNode(append_node); - /*auto if_block1 = */if_node->addBlock(); + /*auto if_block1 = */ if_node->addBlock(); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node); diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index a83d2330e4..facdd31151 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) { int count = count_trt_engines(fallback_g); ASSERT_TRUE(count == 2); } + +TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { + /* parseIR does not support "= aten::_set_item" so we will build this graph manually + const auto graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %2 : str = prim::Constant[value="INS"]() + %3 : str = prim::Constant[value="OUTS"]() + %4 : bool = prim::Constant[value=0]() + %5 : int = prim::Constant[value=-1]() + %6 : Dict(str, Tensor) = prim::DictConstruct() + = aten::_set_item(%6, %2, %x) + %7 : Tensor = aten::__getitem__(%6, %2) + %8 : Tensor = aten::lt(%7, %y) + %9 : Tensor?[] = prim::ListConstruct(%8) + %10 : int = prim::dtype(%7) + %11 : Device = prim::device(%7) + %12 : Tensor = aten::tensor(%5, %10, %11, %4) + %13 : Tensor = aten::index_put_(%7, %9, %12, %4) + = aten::_set_item(%6, %3, %7) + %14 : Tensor = aten::__getitem__(%6, %2) + %15 : Tensor = aten::__getitem__(%6, %3) + return (%14, %15))IR"; + */ + auto g = std::make_shared(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + torch::jit::IValue ins_key("INS"); + auto ins_key_val = g->insertConstant(ins_key); + torch::jit::IValue outs_key("OUTS"); + auto outs_key_val = g->insertConstant(outs_key); + torch::jit::IValue zero(0); + auto false_const_val = g->insertConstant(zero); + false_const_val->setType(c10::BoolType::get()); + torch::jit::IValue neg_one(-1); + auto neg_one_const_val = g->insertConstant(neg_one); + auto dict_node = g->createDict( + ins_key_val->type(), + x->type(), + torch::jit::ArrayRef(), + torch::jit::ArrayRef()); + g->insertNode(dict_node); + auto set_node = g->create( + torch::jit::Symbol::fromQualString("aten::_set_item"), + torch::jit::ArrayRef{dict_node->output(), ins_key_val, x}, + 0); + g->insertNode(set_node); + auto get_node = g->create( + torch::jit::Symbol::fromQualString("aten::__getitem__"), + torch::jit::ArrayRef{dict_node->output(), ins_key_val}, + 1); + g->insertNode(get_node); + auto lt_node = g->create( + torch::jit::Symbol::fromQualString("aten::lt"), + torch::jit::ArrayRef{get_node->output(), y}, + 1); + g->insertNode(lt_node); + auto list_node = g->createList( + at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef{lt_node->output()}); + g->insertNode(list_node); + auto dtype_node = g->create( + torch::jit::Symbol::fromQualString("prim::dtype"), + torch::jit::ArrayRef{get_node->output()}, + 1); + dtype_node->output()->setType(neg_one_const_val->type()); + g->insertNode(dtype_node); + auto device_node = g->create( + torch::jit::Symbol::fromQualString("prim::device"), + torch::jit::ArrayRef{get_node->output()}, + 1); + device_node->output()->setType(c10::DeviceObjType::get()); + g->insertNode(device_node); + auto tensor_node = g->create( + torch::jit::Symbol::fromQualString("aten::tensor"), + torch::jit::ArrayRef{ + neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, + 1); + g->insertNode(tensor_node); + auto index_put_node = g->create( + torch::jit::Symbol::fromQualString("aten::index_put_"), + torch::jit::ArrayRef{ + get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, + 1); + g->insertNode(index_put_node); + auto out_set_node = g->create( + torch::jit::Symbol::fromQualString("aten::_set_item"), + torch::jit::ArrayRef{dict_node->output(), outs_key_val, get_node->output()}, + 0); + g->insertNode(out_set_node); + auto get_ins_node = g->create( + torch::jit::Symbol::fromQualString("aten::__getitem__"), + torch::jit::ArrayRef{dict_node->output(), ins_key_val}, + 1); + g->insertNode(get_ins_node); + auto get_outs_node = g->create( + torch::jit::Symbol::fromQualString("aten::__getitem__"), + torch::jit::ArrayRef{dict_node->output(), outs_key_val}, + 1); + g->insertNode(get_outs_node); + g->registerOutput(get_ins_node->output()); + g->registerOutput(get_outs_node->output()); + + torch_tensorrt::core::partitioning::PartitionInfo partition_info; + partition_info.enabled = true; + std::vector inputs; + inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); + inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); + + std::unordered_map inputs_map; + std::unordered_map> input_types; + for (size_t i = 0; i < g->inputs().size(); ++i) { + inputs_map.insert({g->inputs()[i], inputs[i]}); + input_types.insert({g->inputs()[i], {at::kFloat}}); + } + auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); + auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + + int torch_block_cnt = 0, trt_block_cnt = 0; + for (const auto& segmented_block : segmented_blocks) { + if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) { + ++trt_block_cnt; + ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) { + return type_ptr->isSubtypeOf(torch::jit::TensorType::get()); + })); + } else { + ++torch_block_cnt; + bool output_dict = false; + bool input_dict = false; + auto dict_type = dict_node->output()->type(); + for (auto in : segmented_block.raw_inputs()) { + if (in->type()->isSubtypeOf(dict_type)) { + input_dict = true; + } + } + for (auto out : segmented_block.raw_outputs()) { + if (out->type()->isSubtypeOf(dict_type)) { + output_dict = true; + } + } + EXPECT_TRUE(output_dict ^ input_dict); + } + } + ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2); +}