From da281adae805a0e7ceeea2effaf09296c1f0b226 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Sat, 14 May 2022 00:06:16 -0700 Subject: [PATCH 01/12] feat: refactoring segmentation in partitioning Signed-off-by: Bo Wang --- core/compiler.cpp | 4 +- core/partitioning/partitioning.cpp | 94 +++++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index b684b808f5..e520f43664 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -198,7 +198,7 @@ void AddIfBlockToGraph( auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); }; new_if_block->cloneFrom(cur_block_graph->block(), env); - if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) { + if (cur_block_graph->inputs().size() && cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) { if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) { auto self = new_g->insertInput(0, "self_1"); self->setType(cur_block_graph->inputs()[0]->type()); @@ -293,7 +293,7 @@ GraphAndMapping ConstructFallbackGraph( // Set the output as the produced tuple new_g->registerOutput(return_tuple_node->outputs()[0]); } else { - if (old_to_new_g.count(block->outputs()[0])) { + if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) { new_g->registerOutput(old_to_new_g[block->outputs()[0]]); } } diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 63161217e4..7f4187762c 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -145,6 +145,36 @@ void getDirtyNodes( } } +void find_all_fallback_nodes(std::unordered_map &fallback_nodes) { + std::queue q; + for (auto &node : fallback_nodes) { + q.push(node.first); + } + + std::unordered_set visited_nodes; + while (!q.empty()) { + auto cur_node = q.front(); + q.pop(); + // for every node that produces this fallback node's NonTensor input, they should fallback too + for (auto input : cur_node->inputs()) { + if (!isTensor(input) && fallback_nodes.insert({input->node(), 4}).second) { + q.push(input->node()); + } + } + // for every node that consumes this fallback node's NonTensor output, they should fallback too + for (auto output : cur_node->outputs()) { + if (!isTensor(output)) { + for (auto use : output->uses()) { + auto node = use.user; + if (fallback_nodes.insert({node, 4}).second) { + q.push(node); + } + } + } + } + } +} + std::pair, SegmentedBlock> segmentBlocksWithTensorListInputs( SegmentedBlock& seg_block, const std::unordered_map& tensorlist_inputs) { @@ -491,6 +521,24 @@ bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set& fallback_nodes) { + if (fallback_nodes.count(n)) { + if (fallback_nodes.at(n) == 0) { + LOG_GRAPH("Node not supported by conversion: " << util::node_info(n)); + } else if (fallback_nodes.at(n) == 1) { + LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n)); + } else if (fallback_nodes.at(n) == 2) { + LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n)); + } else { + LOG_GRAPH("Node fallback to Torch because the NonTensor dependencies with other fallback nodes: " << util::node_info(n)); + } + return false; + } + + LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n)); + return true; +} + void finalize_block( PartitionedGraph& g, SegmentedBlock::SegmentedBlockTarget kind, @@ -501,12 +549,52 @@ void finalize_block( 
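The two helpers added above are the heart of this refactor: check_node_fallback reads a reason code out of the map (0 = op unsupported by conversion, 1 = op forced to Torch, 2 = enclosing module forced to Torch, 4 = non-Tensor dependency on another fallback node), and find_all_fallback_nodes spreads fallback status along non-Tensor edges until a fixed point. As a standalone sketch of that propagation — a hypothetical simplified node type stands in for torch::jit::Node, this is not code from the patch:

#include <queue>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for torch::jit::Node; only non-Tensor edges matter here.
struct Node {
  std::vector<Node*> nontensor_producers; // nodes producing this node's non-Tensor inputs
  std::vector<Node*> nontensor_consumers; // nodes consuming this node's non-Tensor outputs
};

// Spread fallback status; reason code 4 marks "non-Tensor dependency on
// another fallback node", matching the patch's convention.
void propagate_fallback(std::unordered_map<Node*, int>& fallback_nodes) {
  std::queue<Node*> q;
  for (auto& kv : fallback_nodes) {
    q.push(kv.first);
  }
  while (!q.empty()) {
    Node* cur = q.front();
    q.pop();
    for (Node* p : cur->nontensor_producers) {
      if (fallback_nodes.insert({p, 4}).second) { // insert succeeded -> newly discovered
        q.push(p);
      }
    }
    for (Node* c : cur->nontensor_consumers) {
      if (fallback_nodes.insert({c, 4}).second) {
        q.push(c);
      }
    }
  }
}

A node already in the map keeps its original reason code, since insert() is a no-op on existing keys; that is also what guarantees the walk terminates.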
LOG_DEBUG(g.back()); } + +// use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback) +// we use a map to indicate the reason why it's fallback to torch +std::unordered_map get_fallback_nodes(torch::jit::Block* block, const std::unordered_set& forced_fallback_ops) { + std::unordered_map fallback_nodes; + auto nodes = block->nodes(); + for (const auto n : nodes) { + if (n->kind() == torch::jit::prim::Constant) { + continue; + } + + // If the op is not supported by the conversion phase it should run in PyTorch + if (!conversion::OpSupported(n)) { + fallback_nodes.insert({n, 0}); + } + + // If the user specifies the op to run in Torch it should run in PyTorch + if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) { + fallback_nodes.insert({n, 1}); + } + + // If the user specifies the module containing this op to run in torch it should run in PyTorch + const auto to_compile_sym = c10::Symbol::attr("to_compile"); + if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { + fallback_nodes.insert({n, 2}); + } + + } + return fallback_nodes; +} + PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) { auto min_block_size = partition_info.min_block_size; std::unordered_set forced_fallback_ops( partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end()); +// get the initial fallback nodes (nodes that are unsupported or forced fallback) + auto fallback_nodes = get_fallback_nodes(block, forced_fallback_ops); + +// For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this input should also fallback +// Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node that produces this input should also fallback +// TODO: don't need to fallback the TensorList related nodes once the collection feature is supported + find_all_fallback_nodes(fallback_nodes); + auto nodes = block->nodes(); + PartitionedGraph segmented_blocks; // segment the nodes @@ -517,7 +605,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa continue; } - if (should_run_in_trt(n, forced_fallback_ops)) { + if (check_node_fallback(n, fallback_nodes)) { in_prog_trt_blk_nodes.push_back(n); // If there is an active PyTorch block and we have passed the threshold for a valid TRT @@ -570,7 +658,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); } - if (!in_prog_pyt_blk_nodes.empty()) { + if (!in_prog_pyt_blk_nodes.empty() || !in_prog_trt_blk_nodes.empty()) { in_prog_pyt_blk_nodes.insert( in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); @@ -589,7 +677,7 @@ PartitionedGraph Partition( PartitionedGraph segmented_blocks = segment_graph(block, partition_info); // resolve nonTensor inputs/outputs - resolveNonTensorInputs(segmented_blocks); + // resolveNonTensorInputs(segmented_blocks); // register input/output torch::jit::Value for segmented graphs LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs"); From 4f974fd5683ec0d0b7559853ccf65cc821d91b27 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 18 May 2022 21:31:47 -0700 Subject: [PATCH 02/12] feat: cover more cases after refactoring segmentation Signed-off-by: 
Bo Wang --- core/partitioning/partitioning.cpp | 420 ++++------------------------- 1 file changed, 46 insertions(+), 374 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 7f4187762c..c26a4f1f8f 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -30,24 +30,6 @@ inline bool isTensor(torch::jit::Value* val) { return val->type()->isSubtypeOf(torch::jit::TensorType::get()); } -bool isAllNodesSupported(const std::vector& nodes) { - for (auto node : nodes) { - if (!conversion::OpSupported(node)) { - return false; - } - } - return true; -} - -bool containTargetInputs(torch::jit::Node* n, const std::unordered_set& target_inputs) { - for (auto input : n->inputs()) { - if (!isTensorOrTensorList(input) && target_inputs.count(input)) { - return true; - } - } - return false; -} - bool containNonTensorOutputs(torch::jit::Node* n) { for (auto output : n->outputs()) { if (!isTensorOrTensorList(output)) { @@ -81,73 +63,9 @@ std::vector getDependencyNodes(const std::vector getOutputNodes( - torch::jit::Value* value, - const std::unordered_set& seg_block_nodes) { - // use bfs to get the DAG outputs nodes for input value - std::queue q; - std::vector stk; - std::unordered_set visited; - q.push(value); - - // top-down order traversing - while (!q.empty()) { - auto cur_val = q.front(); - q.pop(); - for (auto use : cur_val->uses()) { - auto node = use.user; - // use node must be in seg_block_nodes - if (seg_block_nodes.count(node) && !visited.count(node)) { - stk.push_back(node); - visited.insert(node); - // travel its' all outputs - for (auto output : node->outputs()) { - if (!isTensor(output)) { - q.push(output); - } - } - } - } - } - - // top-down order and we don't need to reverse it - return stk; -} - -void getDirtyNodes( - std::unordered_set& dirty_nodes, - const std::unordered_set& seg_block_nodes) { - std::queue q; - for (auto& node : dirty_nodes) { - q.push(node); - } - dirty_nodes.clear(); - - while (!q.empty()) { - auto cur_node = q.front(); - q.pop(); - if (!dirty_nodes.count(cur_node) && seg_block_nodes.count(cur_node)) { - dirty_nodes.insert(cur_node); - for (auto input : cur_node->inputs()) { - if (!isTensorOrTensorList(input)) { - q.push(input->node()); - } - } - for (auto output : cur_node->outputs()) { - if (!isTensorOrTensorList(output)) { - for (auto use : output->uses()) { - auto node = use.user; - q.push(node); - } - } - } - } - } -} - -void find_all_fallback_nodes(std::unordered_map &fallback_nodes) { +void find_all_fallback_nodes(std::unordered_map& fallback_nodes) { std::queue q; - for (auto &node : fallback_nodes) { + for (auto& node : fallback_nodes) { q.push(node.first); } @@ -175,254 +93,26 @@ void find_all_fallback_nodes(std::unordered_map &fallbac } } -std::pair, SegmentedBlock> segmentBlocksWithTensorListInputs( - SegmentedBlock& seg_block, - const std::unordered_map& tensorlist_inputs) { - std::unordered_set all_append_nodes; - std::unordered_map append_blocks; - const std::unordered_set seg_block_nodes( - seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); - for (auto input_pair : tensorlist_inputs) { - auto append_nodes = getOutputNodes(input_pair.first, seg_block_nodes); - append_blocks[input_pair.first] = SegmentedBlock(input_pair.second.target(), append_nodes); - all_append_nodes.insert(append_nodes.begin(), append_nodes.end()); - } - - std::vector trt_nodes; - for (auto node : seg_block.raw_nodes()) { - if (all_append_nodes.count(node) == 0) { - 
trt_nodes.emplace_back(node); - } - } - SegmentedBlock trt_block(SegmentedBlock::kTensorRT, trt_nodes); - - return std::pair, SegmentedBlock>(append_blocks, trt_block); -} - -PartitionedGraph segmentBlocksWithSpecifiedInputs( - SegmentedBlock& seg_block, - std::vector& inputs_to_resolve) { - std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); - PartitionedGraph new_seg_blocks; - // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the - // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block - if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) { - // if current node is prim::If, just ensure that we have all required input in kTorch - if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) { - new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes); - new_seg_blocks.push_back(seg_block); - } else { - dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); - new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes); - } - } else { - // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again - std::unordered_set inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end()); - std::vector tensorrt_nodes, pytorch_nodes; - - // take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use - // dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produces non_tensor - // values consumed by dirty nodes - std::unordered_set dirty_nodes; - const std::unordered_set seg_block_nodes( - seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); - - for (auto n : seg_block.raw_nodes()) { - if (containTargetInputs(n, inputs_to_resolve_set)) { - dirty_nodes.insert(n); - } - } - getDirtyNodes(dirty_nodes, seg_block_nodes); - for (auto n : seg_block.raw_nodes()) { - if (dirty_nodes.count(n)) { - if (!tensorrt_nodes.empty()) { - new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes); - tensorrt_nodes.clear(); +void resolveTRTNonTensorInputs(PartitionedGraph& segmented_blocks) { + // if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine + // because we have already found the interface between Torch and TRT in segmentation phase + // what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs + for (size_t i = 0; i < segmented_blocks.size(); ++i) { + if (segmented_blocks[i].target() == SegmentedBlock::kTensorRT) { + std::vector inputs_to_resolve; + for (auto input : segmented_blocks[i].raw_inputs()) { + if (!isTensor(input)) { + inputs_to_resolve.push_back(input); } - pytorch_nodes.push_back(n); - } else { - if (!pytorch_nodes.empty()) { - new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes); - pytorch_nodes.clear(); - } - tensorrt_nodes.push_back(n); } - } - - // Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly. 
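The re-segmentation loop being deleted here follows the same two-accumulator pattern as segment_graph itself: nodes accumulate into the current TensorRT or Torch run, and the opposite run is flushed into a finished block whenever the target flips. A standalone sketch of that pattern — integers stand in for nodes, min_block_size handling omitted, not code from the patch:

#include <functional>
#include <vector>

enum class Target { kTensorRT, kTorch };
struct Block {
  Target target;
  std::vector<int> nodes;
};

std::vector<Block> segment(const std::vector<int>& nodes,
                           const std::function<bool(int)>& runs_in_trt) {
  std::vector<Block> blocks;
  std::vector<int> trt_run, torch_run;
  auto flush = [&](Target t, std::vector<int>& run) {
    if (!run.empty()) {
      blocks.push_back({t, run});
      run.clear();
    }
  };
  for (int n : nodes) {
    if (runs_in_trt(n)) {
      flush(Target::kTorch, torch_run); // target flipped Torch -> TRT
      trt_run.push_back(n);
    } else {
      flush(Target::kTensorRT, trt_run); // target flipped TRT -> Torch
      torch_run.push_back(n);
    }
  }
  flush(Target::kTensorRT, trt_run); // at most one run is still non-empty here
  flush(Target::kTorch, torch_run);
  return blocks;
}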
- if (!tensorrt_nodes.empty()) { - new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes); - } else { - new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes); - } - } - - return new_seg_blocks; -} - -PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { - // reconstruct segmented_block if this block requires nonTensor input - std::vector inputs_to_resolve; - // Gather all non-tensor inputs for this block - for (auto input : seg_block.raw_inputs()) { - if (!isTensorOrTensorList(input)) { - inputs_to_resolve.push_back(input); - } - } - return segmentBlocksWithSpecifiedInputs(seg_block, inputs_to_resolve); -} - -std::unordered_map getInputUsageCounts( - const PartitionedGraph& segmented_blocks, - const std::function& condition) { - // usage_counts is a map which stores non-tensor inputs as keys and the values are indices of segmented blocks which - // have these non-tensor inputs. Iterate through the graph (segmented blocks) from bottom to top. When we find a - // non-tensor input in a segmented block of index "i", store it in the usage_counts map. Now for each non-tensor - // inputs recorded in the usage_counts map, we check if any previous segmented block (segmented block index i goes - // from n-1 to 0) generated/contains this non-tensor input. If so, we set this idx as the produce_id as it produces - // the non-tensor input. - std::unordered_map usage_counts; - for (int i = segmented_blocks.size() - 1; i >= 0; --i) { - for (auto input : segmented_blocks[i].raw_inputs()) { - if (condition(input)) { - segmented_blocks[i].target() == SegmentedBlock::kTorch ? usage_counts[input].torch_use_id.push_back(i) - : usage_counts[input].tensorrt_use_id.push_back(i); - } - } - - // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block - // that has/produces it. 
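In code, the bookkeeping this comment describes (and which this patch deletes) boils down to the sketch below — integers stand in for torch::jit::Value, block types are hypothetical simplifications, not code from the repository. Walk the blocks from last to first, record each consumer, and keep overwriting produce_id so it settles on the earliest block that produces the value:

#include <unordered_map>
#include <vector>

struct UsageInfo {
  int produce_id = -1;              // earliest block producing/containing the value
  std::vector<int> torch_use_id;    // Torch blocks consuming it
  std::vector<int> tensorrt_use_id; // TensorRT blocks consuming it
};

struct Blk {
  bool is_torch;
  std::vector<int> nontensor_inputs; // non-Tensor values consumed
  std::vector<int> contained;        // values produced/contained
};

std::unordered_map<int, UsageInfo> usage_counts(const std::vector<Blk>& blocks) {
  std::unordered_map<int, UsageInfo> counts;
  for (int i = static_cast<int>(blocks.size()) - 1; i >= 0; --i) {
    for (int v : blocks[i].nontensor_inputs) {
      auto& info = counts[v];
      (blocks[i].is_torch ? info.torch_use_id : info.tensorrt_use_id).push_back(i);
    }
    for (int v : blocks[i].contained) {
      auto it = counts.find(v);
      if (it != counts.end()) {
        it->second.produce_id = i; // i only decreases, so this ends at the earliest producer
      }
    }
  }
  return counts;
}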
- for (auto& use : usage_counts) { - // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value - if (segmented_blocks[i].contain_raw_value(use.first)) { - use.second.produce_id = i; + if (!inputs_to_resolve.empty()) { + std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); + dependency_nodes.insert( + dependency_nodes.end(), segmented_blocks[i].raw_nodes().begin(), segmented_blocks[i].raw_nodes().end()); + segmented_blocks[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes); } } } - return usage_counts; -} - -std::unordered_map::iterator> getIdxtoIterMap( - std::list& segmented_blocks_list) { - std::unordered_map::iterator> idx_to_iter; - auto iter = segmented_blocks_list.begin(); - for (uint64_t i = 0; i < segmented_blocks_list.size(); ++i, ++iter) { - idx_to_iter[i] = iter; - } - return idx_to_iter; -} - -void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) { - // get input usage counts and blocks_list - std::list segmented_blocks_list(segmented_blocks.cbegin(), segmented_blocks.cend()); - auto usage_counts = getInputUsageCounts( - segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); }); - auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list); - - std::map> - torch_values_to_fix; // Only need to resolve values generated by tensorrt - std::set tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs - - // update blocks_list - std::unordered_set updated_segments; - for (auto& use : usage_counts) { - auto use_info = use.second; - // if the segment that produce this nonTensor value is kTensorRT but consumed in kTorch, inject nodes in the first - // kTorch segment. - if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) { - auto first_torch_id = use_info.torch_use_id.back(); - torch_values_to_fix[first_torch_id].push_back(use.first); - } - // kTensorRT segments always need to inject nodes for the nonTensor inputs - for (auto i : use_info.tensorrt_use_id) { - tensorrt_blocks_to_fix.insert(i); - } - } - for (auto torch_block_pair : torch_values_to_fix) { - auto to_inject_blocks = - segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second); - auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]); - segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); - } - - for (auto i : tensorrt_blocks_to_fix) { - auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]); - auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); - segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); - } - - segmented_blocks.clear(); - segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end()); - return; -} - -void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) { - // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which - // produces/contains it. 
- auto usage_counts = - getInputUsageCounts(segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList(input); }); - - // Get idx of the segblock to its iterator mapping - std::list segmented_blocks_list(segmented_blocks.cbegin(), segmented_blocks.cend()); - auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list); - - std::unordered_set updated_segments; - // we need to re-segment TensorRT segments whose inputs are TensorLists - for (auto& use : usage_counts) { - auto use_info = use.second; - // For a particular tensorlist input, traverse through all ids of segmented blocks whose target is TensorRT - for (auto i : use_info.tensorrt_use_id) { - if (!updated_segments.count(i)) { - // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this - // tensorlist input} - std::unordered_map tensorlistinput_to_segblock; - for (auto input : segmented_blocks[i].raw_inputs()) { - if (isTensorList(input)) { - tensorlistinput_to_segblock[input] = segmented_blocks[usage_counts[input].produce_id]; - } - } - - // For each tensorlist input in tensorlistinput_to_segblock, get the node which actually uses this input. - // Once we retrieve the node, we remove it from the current TensorRT segmented_blocks[i]. This node should be - // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first - // place. - auto seg_blocks = segmentBlocksWithTensorListInputs(segmented_blocks[i], tensorlistinput_to_segblock); - auto append_blocks = seg_blocks.first; - auto trt_block = seg_blocks.second; - // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that - // uses tensorlist input removed. - auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]); - if (trt_block.raw_nodes().size() > 0) { - segmented_blocks_list.insert(next_iter, trt_block); - } - - // append blocks' nodes to the producer seg_block - for (auto append_block : append_blocks) { - auto input = append_block.first; // corresponds to the tensorlist input - auto block = append_block.second; - // append nodes to segmented_blocks_list - auto producer = idx_to_iter[usage_counts[input].produce_id]; - for (auto n : block.raw_nodes()) { - producer->cloneNode(n); - } - } - updated_segments.insert(i); - } - } - } - segmented_blocks.clear(); - segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end()); - return; -} - -void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shared_ptr g - // make sure that all inputs should be tensor - LOG_DEBUG("Resolving nonTensor inputs/outputs of segmented_blocks"); - resolveNonTensorInputBlocks(segmented_blocks); - - // we need to re-segment tensorrt blocks whose inputs are tensorlists (eg: Tensor [] instead of Tensor). 
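All of this re-segmentation machinery becomes unnecessary once fallback propagation runs before segmentation: a TensorRT block that still consumes a non-Tensor value can simply prepend the producer chain of that value, because segmentation has already guaranteed those producers were assigned to TensorRT. That is what the new resolveTRTNonTensorInputs above does. A standalone sketch under the same simplifications as before, with the dependency lookup passed in as a hypothetical callback rather than the real getDependencyNodes:

#include <functional>
#include <utility>
#include <vector>

struct SegBlock {
  bool is_trt;
  std::vector<int> nontensor_inputs; // non-Tensor values this block consumes
  std::vector<int> nodes;            // node ids in execution order
};

void resolve_trt_nontensor_inputs(
    std::vector<SegBlock>& blocks,
    const std::function<std::vector<int>(const std::vector<int>&)>& dependency_nodes) {
  for (auto& b : blocks) {
    if (!b.is_trt || b.nontensor_inputs.empty()) {
      continue;
    }
    // Producer chain first, then the block's own nodes, preserving order.
    std::vector<int> merged = dependency_nodes(b.nontensor_inputs);
    merged.insert(merged.end(), b.nodes.begin(), b.nodes.end());
    b.nodes = std::move(merged);
  }
}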
- LOG_DEBUG("Resolving inputs of type TensorList in segmented_blocks"); - resolveTensorListInputBlocks(segmented_blocks); } void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) { @@ -453,6 +143,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo } // if no output, then register the last node's output as current graph's output if (seg_block.raw_outputs().empty()) { + LOG_DEBUG(seg_block << " no output\n"); // for Torch segments, register input as output if (seg_block.target() == SegmentedBlock::kTorch) { seg_block.registerOutput(seg_block.raw_inputs()[0]); @@ -460,7 +151,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo // for TensorRT segments, register last nonInput Tensor outputs for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) { for (auto node_output : seg_block.raw_nodes()[i]->outputs()) { - if (isTensorOrTensorList(node_output)) + if (isTensor(node_output)) seg_block.registerOutput(node_output); } if (!seg_block.raw_outputs().empty()) @@ -497,30 +188,6 @@ bool checkLoopEvaluatable(torch::jit::Node* n) { return compile_to_trt; } -bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set& torch_ops) { - // If the op is not supported by the conversion phase it should run in PyTorch - if (!conversion::OpSupported(n)) { - LOG_GRAPH("Node not supported by conversion: " << util::node_info(n)); - return false; - } - - // If the user specifies the op to run in Torch it should run in PyTorch - if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) { - LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n)); - return false; - } - - // If the user specifies the module containing this op to run in torch it should run in PyTorch - const auto to_compile_sym = c10::Symbol::attr("to_compile"); - if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { - LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n)); - return false; - } - - LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n)); - return true; -} - bool check_node_fallback(torch::jit::Node* n, const std::unordered_map& fallback_nodes) { if (fallback_nodes.count(n)) { if (fallback_nodes.at(n) == 0) { @@ -530,7 +197,9 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map get_fallback_nodes(torch::jit::Block* block, const std::unordered_set& forced_fallback_ops) { +std::unordered_map get_fallback_nodes( + torch::jit::Block* block, + const std::unordered_set& forced_fallback_ops) { std::unordered_map fallback_nodes; auto nodes = block->nodes(); for (const auto n : nodes) { @@ -560,22 +230,21 @@ std::unordered_map get_fallback_nodes(torch::jit::Block* continue; } - // If the op is not supported by the conversion phase it should run in PyTorch - if (!conversion::OpSupported(n)) { - fallback_nodes.insert({n, 0}); - } - - // If the user specifies the op to run in Torch it should run in PyTorch - if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) { - fallback_nodes.insert({n, 1}); - } + // If the op is not supported by the conversion phase it should run in PyTorch + if (!conversion::OpSupported(n)) { + fallback_nodes.insert({n, 0}); + } - // If the user specifies the module containing this op to run in torch it should run in PyTorch - const auto to_compile_sym = c10::Symbol::attr("to_compile"); - if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { - 
fallback_nodes.insert({n, 2}); - } + // If the user specifies the op to run in Torch it should run in PyTorch + if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) { + fallback_nodes.insert({n, 1}); + } + // If the user specifies the module containing this op to run in torch it should run in PyTorch + const auto to_compile_sym = c10::Symbol::attr("to_compile"); + if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { + fallback_nodes.insert({n, 2}); + } } return fallback_nodes; } @@ -585,12 +254,13 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa std::unordered_set forced_fallback_ops( partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end()); -// get the initial fallback nodes (nodes that are unsupported or forced fallback) + // get the initial fallback nodes (nodes that are unsupported or forced fallback) auto fallback_nodes = get_fallback_nodes(block, forced_fallback_ops); -// For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this input should also fallback -// Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node that produces this input should also fallback -// TODO: don't need to fallback the TensorList related nodes once the collection feature is supported + // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this + // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node + // that produces this input should also fallback + // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported find_all_fallback_nodes(fallback_nodes); auto nodes = block->nodes(); @@ -676,8 +346,10 @@ PartitionedGraph Partition( LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks"); PartitionedGraph segmented_blocks = segment_graph(block, partition_info); + // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks + // resolve nonTensor inputs/outputs - // resolveNonTensorInputs(segmented_blocks); + resolveTRTNonTensorInputs(segmented_blocks); // register input/output torch::jit::Value for segmented graphs LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs"); From 22d91f5e6aa9c81ca77bd544324b23262cf7bbdb Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 2 Jun 2022 12:06:35 -0700 Subject: [PATCH 03/12] fix: fix the bug that tag Constant node as fallback node Signed-off-by: Bo Wang --- core/partitioning/partitioning.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index c26a4f1f8f..92f581e77a 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -75,7 +75,8 @@ void find_all_fallback_nodes(std::unordered_map& fallbac q.pop(); // for every node that produces this fallback node's NonTensor input, they should fallback too for (auto input : cur_node->inputs()) { - if (!isTensor(input) && fallback_nodes.insert({input->node(), 4}).second) { + if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant && + fallback_nodes.insert({input->node(), 4}).second) { q.push(input->node()); } } @@ -84,7 +85,7 @@ void find_all_fallback_nodes(std::unordered_map& fallbac if (!isTensor(output)) { for (auto use 
: output->uses()) { auto node = use.user; - if (fallback_nodes.insert({node, 4}).second) { + if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) { q.push(node); } } From 1df7cbb4f703d399a97901779fbc59034a8c8932 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 8 Jun 2022 01:26:07 -0700 Subject: [PATCH 04/12] fix: fix the bug that getDependencyNodes misses modifying nodes Signed-off-by: Bo Wang --- core/partitioning/partitioning.cpp | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 92f581e77a..94d922e82c 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -39,7 +39,28 @@ bool containNonTensorOutputs(torch::jit::Node* n) { return false; } -std::vector getDependencyNodes(const std::vector& vals) { +bool isModifyingNodes(torch::jit::Node* node) { + std::unordered_set modifying_node_set{"aten::append"}; + return modifying_node_set.find(node->kind().toQualString()) != modifying_node_set.end(); +} + +std::vector findModifyingNodes(torch::jit::Value* val, const std::unordered_set &seg_block_nodes) { + std::vector modifying_nodes; + for (auto use: val->uses()) { + torch::jit::Node* node = use.user; + if (seg_block_nodes.find(node) != seg_block_nodes.end()) { + break; + } + if (isModifyingNodes(node)) { + modifying_nodes.push_back(node); + } + } + return modifying_nodes; +} + +std::vector getDependencyNodes(const std::vector& vals, const SegmentedBlock &seg_block) { + // get all nodes in the segmentedblock + std::unordered_set seg_block_nodes(seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); // use bfs to get the DAG dependency nodes for input value std::queue> q( std::deque(vals.begin(), vals.end())); @@ -51,6 +72,8 @@ std::vector getDependencyNodes(const std::vectornode(); if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) { visited.insert(node); + auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes); + stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend()); stk.push_back(node); for (auto input : node->inputs()) { if (!isTensorOrTensorList(input)) { @@ -107,7 +130,7 @@ void resolveTRTNonTensorInputs(PartitionedGraph& segmented_blocks) { } } if (!inputs_to_resolve.empty()) { - std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve); + std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve, segmented_blocks[i]); dependency_nodes.insert( dependency_nodes.end(), segmented_blocks[i].raw_nodes().begin(), segmented_blocks[i].raw_nodes().end()); segmented_blocks[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes); @@ -144,7 +167,6 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo } // if no output, then register the last node's output as current graph's output if (seg_block.raw_outputs().empty()) { - LOG_DEBUG(seg_block << " no output\n"); // for Torch segments, register input as output if (seg_block.target() == SegmentedBlock::kTorch) { seg_block.registerOutput(seg_block.raw_inputs()[0]); @@ -276,6 +298,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa continue; } + if (check_node_fallback(n, fallback_nodes)) { in_prog_trt_blk_nodes.push_back(n); From 0bd8a9834a17d35c9d24c05a18712e0fcbbf3a1c Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 8 Jun 2022 14:48:35 -0700 Subject: [PATCH 05/12] refactor: refactored the method of getting 
modifying nodes Signed-off-by: Bo Wang --- core/partitioning/partitioning.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 94d922e82c..f661646b44 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -39,9 +39,17 @@ bool containNonTensorOutputs(torch::jit::Node* n) { return false; } -bool isModifyingNodes(torch::jit::Node* node) { - std::unordered_set modifying_node_set{"aten::append"}; - return modifying_node_set.find(node->kind().toQualString()) != modifying_node_set.end(); +bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) { + const auto& schema = node->schema(); + for (size_t i = 0; i < node->inputs().size(); ++i) { + if (node->inputs()[i] == val) { + const at::AliasInfo* formal = schema.arguments()[i].alias_info(); + if (formal->isWrite()) { + return true; + } + } + } + return false; } std::vector findModifyingNodes(torch::jit::Value* val, const std::unordered_set &seg_block_nodes) { @@ -51,7 +59,7 @@ std::vector findModifyingNodes(torch::jit::Value* val, const if (seg_block_nodes.find(node) != seg_block_nodes.end()) { break; } - if (isModifyingNodes(node)) { + if (isModifyingNodes(node, val)) { modifying_nodes.push_back(node); } } @@ -298,7 +306,6 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa continue; } - if (check_node_fallback(n, fallback_nodes)) { in_prog_trt_blk_nodes.push_back(n); From 8d9a8d4e0b106113befdd66820cba8c556fb3714 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Tue, 14 Jun 2022 16:52:58 -0700 Subject: [PATCH 06/12] fix: nodes in If block should fallback it it dependes on any nonTensor value that's produced in outer block Signed-off-by: Bo Wang --- core/compiler.cpp | 14 +++++++++----- core/partitioning/partitioning.cpp | 30 +++++++++++++++++++----------- core/partitioning/partitioning.h | 3 ++- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index e520f43664..4a4389bea3 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -198,7 +198,8 @@ void AddIfBlockToGraph( auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); }; new_if_block->cloneFrom(cur_block_graph->block(), env); - if (cur_block_graph->inputs().size() && cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) { + if (cur_block_graph->inputs().size() && + cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) { if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) { auto self = new_g->insertInput(0, "self_1"); self->setType(cur_block_graph->inputs()[0]->type()); @@ -223,13 +224,14 @@ GraphAndMapping ConstructFallbackGraph( torch::jit::Block* block, std::unordered_map example_tensor_map, CompileSpec cfg, - ir::StaticParams static_params) { + ir::StaticParams static_params, + std::unordered_map& fallback_nodes) { auto convert_cfg = cfg.convert_info; auto partition_info = cfg.partition_info; auto new_g = std::make_shared(); - auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info); + auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes); // the mapping from lowering graph => fallback global graph std::unordered_map old_to_new_g; @@ -270,7 +272,7 @@ GraphAndMapping ConstructFallbackGraph( std::vector graph_and_mappings; for 
(auto cur_block : if_node->blocks()) { graph_and_mappings.push_back( - ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params)); + ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes)); } AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g); @@ -430,7 +432,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) !(cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) { auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types); - auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params); + std::unordered_map fallback_nodes; + auto graph_and_mapping = + ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes); new_g = graph_and_mapping.first; LOG_INFO("Segmented Graph: " << *new_g); diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index f661646b44..93f377ab51 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -52,9 +52,11 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) { return false; } -std::vector findModifyingNodes(torch::jit::Value* val, const std::unordered_set &seg_block_nodes) { +std::vector findModifyingNodes( + torch::jit::Value* val, + const std::unordered_set& seg_block_nodes) { std::vector modifying_nodes; - for (auto use: val->uses()) { + for (auto use : val->uses()) { torch::jit::Node* node = use.user; if (seg_block_nodes.find(node) != seg_block_nodes.end()) { break; @@ -66,7 +68,9 @@ std::vector findModifyingNodes(torch::jit::Value* val, const return modifying_nodes; } -std::vector getDependencyNodes(const std::vector& vals, const SegmentedBlock &seg_block) { +std::vector getDependencyNodes( + const std::vector& vals, + const SegmentedBlock& seg_block) { // get all nodes in the segmentedblock std::unordered_set seg_block_nodes(seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); // use bfs to get the DAG dependency nodes for input value @@ -251,10 +255,10 @@ void finalize_block( // use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback) // we use a map to indicate the reason why it's fallback to torch -std::unordered_map get_fallback_nodes( +void get_fallback_nodes( torch::jit::Block* block, - const std::unordered_set& forced_fallback_ops) { - std::unordered_map fallback_nodes; + const std::unordered_set& forced_fallback_ops, + std::unordered_map& fallback_nodes) { auto nodes = block->nodes(); for (const auto n : nodes) { if (n->kind() == torch::jit::prim::Constant) { @@ -277,16 +281,19 @@ std::unordered_map get_fallback_nodes( fallback_nodes.insert({n, 2}); } } - return fallback_nodes; + return; } -PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) { +PartitionedGraph segment_graph( + torch::jit::Block* block, + const PartitionInfo& partition_info, + std::unordered_map& fallback_nodes) { auto min_block_size = partition_info.min_block_size; std::unordered_set forced_fallback_ops( partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end()); // get the initial fallback nodes (nodes that are unsupported or forced fallback) - auto fallback_nodes = get_fallback_nodes(block, forced_fallback_ops); + get_fallback_nodes(block, 
forced_fallback_ops, fallback_nodes); // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node @@ -371,11 +378,12 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa PartitionedGraph Partition( torch::jit::Block* block, std::unordered_map& example_tensor_map, - const PartitionInfo& partition_info) { + const PartitionInfo& partition_info, + std::unordered_map& fallback_nodes) { LOG_DEBUG(partition_info); // segment lowering global graph into blocks LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks"); - PartitionedGraph segmented_blocks = segment_graph(block, partition_info); + PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes); // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h index 31ecfebf25..6140f7f74d 100644 --- a/core/partitioning/partitioning.h +++ b/core/partitioning/partitioning.h @@ -21,7 +21,8 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa PartitionedGraph Partition( torch::jit::Block* block, std::unordered_map& example_tensor_map, - const PartitionInfo& partition_info); + const PartitionInfo& partition_info, + std::unordered_map& fallback_nodes); std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g); From da9b13c61d9ab1ca2f1e28da146e5a16b92a6a99 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 16 Jun 2022 13:42:11 -0700 Subject: [PATCH 07/12] tests: add test for inplace op in if block Signed-off-by: Bo Wang --- core/partitioning/partitioning.h | 2 +- tests/core/partitioning/BUILD | 3 +- tests/core/partitioning/test_conditionals.cpp | 31 +++++++++++++++++++ .../test_resolve_nontensor_inputs.cpp | 9 ++++-- tests/core/partitioning/test_segmentation.cpp | 18 +++++++---- .../core/partitioning/test_shape_analysis.cpp | 6 ++-- tests/modules/custom_models.py | 12 +++++++ tests/modules/hub.py | 4 +++ 8 files changed, 72 insertions(+), 13 deletions(-) diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h index 6140f7f74d..77cece8e63 100644 --- a/core/partitioning/partitioning.h +++ b/core/partitioning/partitioning.h @@ -16,7 +16,7 @@ namespace partitioning { typedef std::vector PartitionedGraph; -PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info); +PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info, std::unordered_map &fallback_nodes); PartitionedGraph Partition( torch::jit::Block* block, diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD index 98c549e11d..ec5e9c77fc 100644 --- a/tests/core/partitioning/BUILD +++ b/tests/core/partitioning/BUILD @@ -13,7 +13,8 @@ filegroup( "//tests/modules:mobilenet_v2_traced.jit.pt", "//tests/modules:conditional_scripted.jit.pt", "//tests/modules:loop_fallback_eval_scripted.jit.pt", - "//tests/modules:loop_fallback_no_eval_scripted.jit.pt"] + "//tests/modules:loop_fallback_no_eval_scripted.jit.pt", + "//tests/modules:inplace_op_if_scripted.jit.pt"] ) partitioning_test( diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index 9698559f80..691812dc6d 100644 --- a/tests/core/partitioning/test_conditionals.cpp 
+++ b/tests/core/partitioning/test_conditionals.cpp @@ -42,3 +42,34 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) { ASSERT_TRUE(conditional_engines_count == 2); } + + +TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) { + torch::jit::script::Module mod; + try { + mod = torch::jit::load("tests/modules/inplace_op_if_scripted.jit.pt"); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + return; + } + + const std::vector> input_shapes = {{4, 4}, {4, 4}}; + std::vector jit_inputs_ivalues; + std::vector trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randint(5, in_shape, {at::kCUDA}); + jit_inputs_ivalues.push_back(in.clone()); + trt_inputs_ivalues.push_back(in.clone()); + } + + std::vector inputs{torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})}; + auto g = mod.get_method("forward").graph(); + torch_tensorrt::core::CompileSpec cfg(inputs); + cfg.partition_info.enabled = true; + cfg.partition_info.forced_fallback_operators.push_back("prim::ListConstruct"); + + auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); + auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); + auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6)); +} diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index facdd31151..33368e3ba2 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -123,8 +123,9 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) { input_types.insert({g->inputs()[i], {at::kFloat}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { @@ -181,8 +182,9 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) { input_types.insert({g->inputs()[i], {at::kFloat}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { @@ -372,7 +374,8 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { input_types.insert({g->inputs()[i], {at::kFloat}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + std::unordered_map fallback_nodes; + auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { diff --git 
a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp index 1a77833577..7af2b2c921 100644 --- a/tests/core/partitioning/test_segmentation.cpp +++ b/tests/core/partitioning/test_segmentation.cpp @@ -74,8 +74,9 @@ TEST(Partitioning, SegmentSequentialModelCorrectly) { torch_tensorrt::core::partitioning::PartitionInfo partition_info; partition_info.enabled = true; + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2)); ASSERT_TRUE( @@ -109,8 +110,9 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { torch_tensorrt::core::partitioning::PartitionInfo partition_info; partition_info.enabled = true; partition_info.min_block_size = 3; + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1)); ASSERT_TRUE( @@ -144,8 +146,9 @@ TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { torch_tensorrt::core::partitioning::PartitionInfo partition_info; partition_info.enabled = true; partition_info.forced_fallback_operators.push_back("aten::relu"); + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info,fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3)); ASSERT_TRUE( @@ -179,8 +182,9 @@ TEST(Partitioning, SegmentBranchModelCorrectly) { torch_tensorrt::core::partitioning::PartitionInfo partition_info; partition_info.enabled = true; + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2)); ASSERT_TRUE( @@ -215,8 +219,9 @@ TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { torch_tensorrt::core::partitioning::PartitionInfo partition_info; partition_info.enabled = true; partition_info.min_block_size = 3; + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1)); ASSERT_TRUE( @@ -255,8 +260,9 @@ TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { torch_tensorrt::core::partitioning::PartitionInfo partition_info; partition_info.enabled = true; partition_info.forced_fallback_operators.push_back("aten::relu"); + std::unordered_map fallback_nodes; torch_tensorrt::core::partitioning::PartitionedGraph segmented_blocks = - 
torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3)); ASSERT_TRUE( diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp index 8effa821ae..7bcabc0d51 100644 --- a/tests/core/partitioning/test_shape_analysis.cpp +++ b/tests/core/partitioning/test_shape_analysis.cpp @@ -66,8 +66,9 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) { input_types.insert({g->inputs()[i], {at::kFloat}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); ASSERT_TRUE(checkSegmentedBlockInputShape( segmented_blocks, @@ -116,8 +117,9 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) { input_types.insert({g->inputs()[i], {at::kFloat}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); + std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info); + torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); ASSERT_TRUE(checkSegmentedBlockInputShape( segmented_blocks, diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index 252a3a2b5d..aac83ffe7c 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -87,6 +87,18 @@ def forward(self, x): return x +# Sample Inplace OP in Conditional Block Model +class FallbackInplaceOPIf(nn.Module): + def __init__(self): + super(FallbackInplaceOPIf, self).__init__() + def forward(self, x, y): + mod_list = [x] + if x.sum() > y.sum(): + mod_list.append(y) + z = torch.cat(mod_list) + return z + + def BertModule(): model_name = "bert-base-uncased" enc = BertTokenizer.from_pretrained(model_name) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 3ac2de5ac6..ac8a5c3afa 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -96,6 +96,10 @@ "model": cm.FallbackIf(), "path": "script" }, + "inplace_op_if": { + "model": cm.FallbackInplaceOPIf(), + "path": "script" + }, "bert-base-uncased": { "model": cm.BertModule(), "path": "trace" From e929b65cab75ce931d49329449b061d92760a921 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 16 Jun 2022 14:06:37 -0700 Subject: [PATCH 08/12] chore: apply linting Signed-off-by: Bo Wang --- core/partitioning/partitioning.h | 5 ++++- tests/core/partitioning/test_conditionals.cpp | 4 ++-- tests/core/partitioning/test_resolve_nontensor_inputs.cpp | 3 ++- tests/core/partitioning/test_segmentation.cpp | 2 +- tests/modules/custom_models.py | 2 ++ 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h index 77cece8e63..fce88134b7 100644 --- a/core/partitioning/partitioning.h +++ b/core/partitioning/partitioning.h @@ -16,7 +16,10 @@ namespace partitioning { typedef std::vector PartitionedGraph; -PartitionedGraph segment_graph(torch::jit::Block* 
block, const PartitionInfo& partition_info, std::unordered_map &fallback_nodes); +PartitionedGraph segment_graph( + torch::jit::Block* block, + const PartitionInfo& partition_info, + std::unordered_map& fallback_nodes); PartitionedGraph Partition( torch::jit::Block* block, diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index 691812dc6d..86d701dd4a 100644 --- a/tests/core/partitioning/test_conditionals.cpp +++ b/tests/core/partitioning/test_conditionals.cpp @@ -43,7 +43,6 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) { ASSERT_TRUE(conditional_engines_count == 2); } - TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) { torch::jit::script::Module mod; try { @@ -62,7 +61,8 @@ TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) { trt_inputs_ivalues.push_back(in.clone()); } - std::vector inputs{torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})}; + std::vector inputs{ + torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})}; auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(inputs); cfg.partition_info.enabled = true; diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index 33368e3ba2..10115459d0 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -375,7 +375,8 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::unordered_map fallback_nodes; - auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); + auto segmented_blocks = + torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp index 7af2b2c921..bf32bcf918 100644 --- a/tests/core/partitioning/test_segmentation.cpp +++ b/tests/core/partitioning/test_segmentation.cpp @@ -148,7 +148,7 @@ TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { partition_info.forced_fallback_operators.push_back("aten::relu"); std::unordered_map fallback_nodes; std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info,fallback_nodes); + torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); ASSERT_TRUE( checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3)); ASSERT_TRUE( diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index aac83ffe7c..20d501045f 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -89,8 +89,10 @@ def forward(self, x): # Sample Inplace OP in Conditional Block Model class FallbackInplaceOPIf(nn.Module): + def __init__(self): super(FallbackInplaceOPIf, self).__init__() + def forward(self, x, y): mod_list = [x] if x.sum() > y.sum(): From 2af4a7c2a80909f2039d9ae70717229bbc9f50ee Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 22 Jun 2022 06:50:35 -0700 Subject: [PATCH 09/12] fix: Fix tests and corner cases for resolving non tensor inputs in 
fallback Signed-off-by: Dheeraj Peri --- core/partitioning/partitioning.cpp | 2 +- tests/core/partitioning/test_resolve_nontensor_inputs.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 93f377ab51..476d6fcfba 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -44,7 +44,7 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) { for (size_t i = 0; i < node->inputs().size(); ++i) { if (node->inputs()[i] == val) { const at::AliasInfo* formal = schema.arguments()[i].alias_info(); - if (formal->isWrite()) { + if (formal && formal->isWrite()) { return true; } } diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index 10115459d0..fea202fc65 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -257,7 +257,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) { torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto fallback_g = new_mod.get_method("forward").graph(); int count = count_trt_engines(fallback_g); - ASSERT_TRUE(count == 2); + ASSERT_TRUE(count == 1); } TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { From 8e1e904cd6a056d661c9199b29553882a0e6c098 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 22 Jun 2022 07:01:24 -0700 Subject: [PATCH 10/12] chore: Fix linter issues Signed-off-by: Dheeraj Peri --- tests/core/partitioning/test_conditionals.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index 86d701dd4a..e2cbbb549b 100644 --- a/tests/core/partitioning/test_conditionals.cpp +++ b/tests/core/partitioning/test_conditionals.cpp @@ -61,8 +61,8 @@ TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) { trt_inputs_ivalues.push_back(in.clone()); } - std::vector inputs{ - torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})}; + std::vector inputs{torch_tensorrt::core::ir::Input({4, 4}), + torch_tensorrt::core::ir::Input({4, 4})}; auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(inputs); cfg.partition_info.enabled = true; From 2b09d1d0cfbf8b7b7e5eef3be6a4bf2a7f24970f Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 22 Jun 2022 07:24:33 -0700 Subject: [PATCH 11/12] chore: add misplaced no eval scripted loop fallback model Signed-off-by: Dheeraj Peri --- tests/modules/hub.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index ac8a5c3afa..7befcf0397 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -92,6 +92,10 @@ "model": cm.LoopFallbackEval(), "path": "script" }, + "loop_fallback_no_eval_scripted": { + "model": cm.LoopFallbackNoEval(), + "path": "script" + }, "conditional": { "model": cm.FallbackIf(), "path": "script" From 85306d87c1d5d3cbc414be4067bba9bdb73cfbdc Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 22 Jun 2022 07:25:52 -0700 Subject: [PATCH 12/12] chore: Minor fix in hub.py model name Signed-off-by: Dheeraj Peri --- tests/modules/hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 7befcf0397..1702628f20 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -92,7 
+92,7 @@ "model": cm.LoopFallbackEval(), "path": "script" }, - "loop_fallback_no_eval_scripted": { + "loop_fallback_no_eval": { "model": cm.LoopFallbackNoEval(), "path": "script" },
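For reference, the in-place detection that patches 05 and 09 converge on looks like this once the null check from patch 09 is folded into the patch 05 rewrite — assembled from the hunks above, not a new implementation. A node modifies val exactly when the schema argument bound to val carries a write alias annotation, such as the (a!) on the self argument of aten::append:

bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
  const auto& schema = node->schema();
  for (size_t i = 0; i < node->inputs().size(); ++i) {
    if (node->inputs()[i] == val) {
      const at::AliasInfo* formal = schema.arguments()[i].alias_info();
      // alias_info() may be null for arguments without alias annotations,
      // hence the extra check added in patch 09.
      if (formal && formal->isWrite()) {
        return true;
      }
    }
  }
  return false;
}

This replaces the original hard-coded {"aten::append"} set, so any schema-annotated in-place op is detected, not just list append.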