diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp old mode 100644 new mode 100755 index 2a1e3f8943..ed6a38ec4f --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -68,6 +68,7 @@ std::vector getDependencyNodes(std::vectornode(); if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) { + visited.insert(node); stk.push_back(node); for (auto input : node->inputs()) { if (!isTensorOrTensorList(input)) { @@ -89,14 +90,14 @@ std::vector getOutputNodes( std::unordered_set visited; q.push(value); - // top-down order traveling + // top-down order traversing while (!q.empty()) { auto cur_val = q.front(); q.pop(); for (auto use : cur_val->uses()) { auto node = use.user; // use node must be in seg_block_nodes - if (seg_block_nodes.count(node) != 0 && !visited.count(node)) { + if (seg_block_nodes.count(node) && !visited.count(node)) { stk.push_back(node); visited.insert(node); // travel its' all outputs @@ -109,10 +110,41 @@ std::vector getOutputNodes( } } - // top-down order and we don't need reverse it + // top-down order and we don't need to reverse it return stk; } +void getDirtyNodes( + std::unordered_set& dirty_nodes, + const std::unordered_set& seg_block_nodes) { + std::queue q; + for (auto& node : dirty_nodes) { + q.push(node); + } + dirty_nodes.clear(); + + while (!q.empty()) { + auto cur_node = q.front(); + q.pop(); + if (!dirty_nodes.count(cur_node) && seg_block_nodes.count(cur_node)) { + dirty_nodes.insert(cur_node); + for (auto input : cur_node->inputs()) { + if (!isTensorOrTensorList(input)) { + q.push(input->node()); + } + } + for (auto output : cur_node->outputs()) { + if (!isTensorOrTensorList(output)) { + for (auto use : output->uses()) { + auto node = use.user; + q.push(node); + } + } + } + } + } +} + std::pair, SegmentedBlock> segmentBlocksWithTensorListInputs( SegmentedBlock& seg_block, const std::unordered_map& tensorlist_inputs) { @@ -163,25 +195,29 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { } else { // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again std::unordered_set nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end()); - std::vector tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end()); + std::vector tensorrt_nodes, pytorch_nodes; - bool prev_non_tensor_outputs = false; + // take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use + // dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produces non_tensor + // values consumed by dirty nodes + std::unordered_set dirty_nodes; + const std::unordered_set seg_block_nodes( + seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); + + for (auto n : seg_block.raw_nodes()) { + if (containTargetInputs(n, nontensor_inputs_set)) { + dirty_nodes.insert(n); + } + } + getDirtyNodes(dirty_nodes, seg_block_nodes); for (auto n : seg_block.raw_nodes()) { - // Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node. - // In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT - // SegmentedBlock. - if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) { - // If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a - // TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments. + if (dirty_nodes.count(n)) { if (!tensorrt_nodes.empty()) { new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes); tensorrt_nodes.clear(); } pytorch_nodes.push_back(n); - prev_non_tensor_outputs = containNonTensorOutputs(n); } else { - // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a - // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments. if (!pytorch_nodes.empty()) { new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes); pytorch_nodes.clear(); @@ -190,7 +226,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) { } } - // Form the last segmented_block with the left over nodes in tensorrt_nodes or pytorch_nodes correspondingly. + // Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly. if (!tensorrt_nodes.empty()) { new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes); } else { diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index a899b3313e..d24b1f980a 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -86,6 +86,8 @@ void getSegmentsOutputByRunning( jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar()); } else if (input->type()->kind() == torch::jit::TypeKind::DictType) { jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict()); + } else if (input->type()->kind() == torch::jit::TypeKind::DeviceObjType) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toDevice()); } else { TORCHTRT_THROW_ERROR( "Expected to find type " << input->type()->str() << " for value " << input->debugName()