From a64c50168258c1090a50519db156dc93304060a1 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 29 Apr 2021 02:34:41 -0500
Subject: [PATCH 1/7] feat: support prim::If

Signed-off-by: Bo Wang
---
 core/compiler.cpp                    | 147 ++++++++++++++++++++-------
 core/partitioning/SegmentedBlock.cpp |   2 +-
 core/partitioning/SegmentedBlock.h   |   2 +-
 core/partitioning/partitioning.cpp   |  41 ++++----
 core/partitioning/partitioning.h     |   5 +-
 core/partitioning/shape_analysis.cpp |  12 ++-
 core/partitioning/shape_analysis.h   |   3 +-
 7 files changed, 149 insertions(+), 63 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index 33dca957b2..987546f7c1 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -11,6 +11,7 @@
 
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
+#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
 #include "torch/csrc/jit/passes/pass_manager.h"
@@ -171,10 +172,111 @@ void AddSegmentedBlockToGraph(
   for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
     old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
   }
+  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
+  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
+    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
+      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
+    }
+  }
 
   return;
 }
 
+typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>> GraphAndMapping;
+
+GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    CompileSpec cfg, int& trt_engine_id, conversion::GraphParams named_params) {
+  auto convert_cfg = cfg.convert_info;
+  auto partition_info = cfg.partition_info;
+
+  auto new_g = std::make_shared<torch::jit::Graph>();
+
+  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+
+  // the mapping from lowering graph => fallback global graph
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
+  for (auto input : block->inputs()) {
+    util::getOrAddInputForValue(input, new_g, old_to_new_g);
+  }
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+
+    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      std::vector<ir::InputRange> input_ranges;
+      for (auto& shape : seg_block.in_shape()) {
+        input_ranges.push_back(ir::InputRange(shape));
+      }
+      // update the input ranges for each segment
+      convert_cfg.input_ranges = input_ranges;
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      auto temp_g = std::make_shared<torch::jit::Graph>();
+      AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++, true);
+
+      seg_block.update_graph(temp_g);
+      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+    } else {
+      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
+        auto outer_node = seg_block.raw_nodes()[0];
+        torch::jit::IfView if_view(outer_node);
+
+
+        // convert the 2 blocks in prim::If and get the converted graph with mappings
+        std::vector<GraphAndMapping> graph_and_mappings;
+        for (auto cur_block : outer_node->blocks()) {
+          graph_and_mappings.push_back(ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
+        }
+
+        // create a new if node in new_g and add corresponding inputs
+        auto new_if =
+            new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
+        new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
+
+
+        for (auto graph_and_mapping : graph_and_mappings) {
+          auto new_if_block = new_if->addBlock();
+          auto cur_block_graph = graph_and_mapping.first;
+          auto cur_block_mapping = graph_and_mapping.second;
+          std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
+          for (auto& i : cur_block_mapping) {
+            // for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's then graph's input
+            if (old_to_new_g.count(i.first)) {
+              block_graph_to_new_g[i.second] = old_to_new_g[i.first];
+            }
+          }
+
+          auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
+          new_if_block->cloneFrom(cur_block_graph->block(), env);
+          if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+            block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
+          }
+          for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
+            new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
+            new_if_block->eraseInput(i);
+          }
+        }
+        for (auto ov : if_view.outputs()) {
+          auto no = new_if->addOutput();
+          old_to_new_g[ov] = no;
+          no->copyMetadata(ov);
+        }
+
+        LOG_INFO(*new_g << "new g with if\n");
+      } else {
+        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+      }
+    }
+  }
+
+  for (auto& output : block->outputs()) {
+    if (old_to_new_g.count(output)) {
+      new_g->registerOutput(old_to_new_g[output]);
+    }
+  }
+  return {new_g, old_to_new_g};
+}
+
 torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
   // TODO: Should be doing a functional transform but need PR #31978
   // [jit] More robust mangling
@@ -190,46 +292,23 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     auto g = graph_and_parameters.first;
     auto params = graph_and_parameters.second;
     auto named_params = conversion::get_named_params(g->inputs(), params);
-    auto convert_cfg = std::move(cfg.convert_info);
     LOG_INFO(*g << "(LoweringGraph)\n");
 
     // segment the graph and convert segmented TensorRT block
-    auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
-    if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
-      LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-      return mod;
-    }
+//    auto segmented_blocks = partitioning::Partition(g->block(), convert_cfg.input_ranges, cfg.partition_info);
+//    if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
+//      LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+//      return mod;
+//    }
 
     int trt_engine_id = 0;
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-    // add global graph's input to old_to_new_g mapping
-    for (auto input : g->inputs()) {
-      util::getOrAddInputForValue(input, new_g, old_to_new_g);
-    }
-    for (auto& seg_block : segmented_blocks) {
-      LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
-      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-        std::vector<ir::InputRange> input_ranges;
-        for (auto& shape : seg_block.in_shape()) {
-          input_ranges.push_back(ir::InputRange(shape));
-        }
-        // update the input ranges for each segments
-        convert_cfg.input_ranges = input_ranges;
-        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
-        auto temp_g = std::make_shared<torch::jit::Graph>();
-        AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++, true);
-
-        seg_block.update_graph(temp_g);
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      } else {
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      }
-    }
-
-    for (auto& output : g->outputs()) {
-      new_g->registerOutput(old_to_new_g[output]);
+    std::unordered_map<torch::jit::Value*, ir::InputRange> input_ranges;
+    for (size_t i = 0; i < g->inputs().size(); ++i) {
+      input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
     }
-
+    auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
+    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
+    new_g = graph_and_mapping.first;
     LOG_INFO(*new_g << "(FallbackGraph)\n");
 
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
diff --git a/core/partitioning/SegmentedBlock.cpp b/core/partitioning/SegmentedBlock.cpp
index a940ddf799..505b565482 100644
--- a/core/partitioning/SegmentedBlock.cpp
+++ b/core/partitioning/SegmentedBlock.cpp
@@ -4,7 +4,7 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
-SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
+SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
     : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
   for (auto& node : nodes) {
     nodes_.push_back(node);
diff --git a/core/partitioning/SegmentedBlock.h b/core/partitioning/SegmentedBlock.h
index 72e21069f2..dd674949cb 100644
--- a/core/partitioning/SegmentedBlock.h
+++ b/core/partitioning/SegmentedBlock.h
@@ -20,7 +20,7 @@ struct SegmentedBlock {
 
   SegmentedBlock() = default;
   SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
-  SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes);
+  SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
   SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
 
   torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index e6f4451362..8b0f887ced 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -2,7 +2,6 @@
 #include <queue>
 
 #include "core/conversion/conversion.h"
-#include "core/partitioning/shape_analysis.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 
@@ -118,7 +117,7 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
   return std::move(new_seg_blocks);
 }
 
-void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
   // for NonTensor inputs in TensorRT segments, count the usages on Torch segments and TensorRT segments
   std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
   for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
@@ -161,7 +160,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
   return;
 }
 
-void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
   // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
   std::set<torch::jit::Value*> input_values;
   for (auto& seg_block : segmented_blocks) {
@@ -170,7 +169,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
     }
   }
 
-  for (auto& graph_output : g->outputs()) {
+  for (auto& graph_output : block->outputs()) {
     input_values.insert(graph_output);
   }
 
@@ -219,12 +218,13 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
   return;
 }
 
-std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info) {
+std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_operators(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
-  auto nodes = g->block()->nodes();
+
+  auto nodes = block->nodes();
   std::vector<SegmentedBlock> segmented_blocks;
 
   // segment the nodes
@@ -248,6 +248,16 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
       pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
     }
     tensorrt_nodes.clear();
+    // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
+    // we shouldn't inject node for this block in dependency analysis process
+    if (n->kind() == torch::jit::prim::If) {
+      if (!pytorch_nodes.empty()) {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        pytorch_nodes.clear();
+      }
+      segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+      continue;
+    }
     pytorch_nodes.push_back(n);
   }
 }
@@ -265,30 +275,23 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
 }
 
 std::vector<SegmentedBlock> Partition(
-    std::shared_ptr<torch::jit::Graph> g,
-    std::vector<ir::InputRange>& input_ranges,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
     const PartitionInfo& partition_info) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
-  std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);
+  std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
 
   // resolve nonTensor inputs/outputs
-  resolveNonTensorInputs(segmented_blocks, g);
+  resolveNonTensorInputs(segmented_blocks);
 
   // register input/output torch::jit::Value for segmented graphs
-  registerSegmentsOutputs(segmented_blocks, g);
-
-  // store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
-  std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
-  std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
-  for (size_t i = 0; i < g->inputs().size(); ++i) {
-    ivalues_maps[g->inputs()[i]] = random_inputs[i];
-  }
+  registerSegmentsOutputs(segmented_blocks, block);
 
   // register every segment's input shape, and its running output IValues
   for (auto& seg_block : segmented_blocks) {
     torch::jit::ConstantPooling(seg_block.g());
-    getSegmentsOutputByRunning(seg_block, ivalues_maps);
+    getSegmentsOutputByRunning(seg_block, input_ivalues_map);
   }
 
   return segmented_blocks;
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index af857237a3..c10a3774dd 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -5,6 +5,7 @@
 #include "core/ir/ir.h"
 #include "core/partitioning/PartitionInfo.h"
 #include "core/partitioning/SegmentedBlock.h"
+#include "core/partitioning/shape_analysis.h"
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/ir/ir.h"
 
@@ -17,8 +18,8 @@ typedef std::vector<SegmentedBlock> PartitionedGraph;
 PartitionedGraph segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info);
 
 std::vector<SegmentedBlock> Partition(
-    std::shared_ptr<torch::jit::Graph> g,
-    std::vector<ir::InputRange>& input_ranges,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
     const PartitionInfo& partition_info);
 
 } // namespace partitioning
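[Review note] The shape_analysis changes that follow switch generateRandomInputs from a positional vector to a map keyed by torch::jit::Value*. Once partitioning recurses into prim::If blocks, a segment may consume values that are not top-level graph inputs, so recorded shapes have to be looked up by value identity instead of by input position. A standalone sketch of the per-input tensor generation (the helper name is illustrative, not from the patch):

    #include <vector>
    #include "ATen/ATen.h"

    // Mirrors what generateRandomInputs does for each recorded input range:
    // a random integer tensor with values in [0, 5), on the GPU.
    at::Tensor randomInputFor(const std::vector<int64_t>& shape) {
      return at::randint(5, shape, {at::kCUDA});
    }
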
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 42a016575d..110afbbb3a 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -6,17 +6,19 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
-std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::InputRange>& input_ranges) {
+std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+    std::unordered_map<torch::jit::Value*, ir::InputRange>& input_ranges) {
   // generate random inputs for running pytorch segments
+  std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalue_maps;
   std::vector<torch::jit::IValue> random_inputs;
   for (auto& input_range : input_ranges) {
-    auto cur_shape = input_range.input_shape;
+    auto cur_shape = input_range.second.input_shape;
     std::vector<int64_t> shape;
     shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
     auto in = at::randint(5, shape, {at::kCUDA});
-    random_inputs.push_back(in.clone());
+    ivalue_maps[input_range.first] = in.clone();
   }
-  return random_inputs;
+  return ivalue_maps;
 }
 
 void getSegmentsOutputByRunning(
@@ -50,7 +52,7 @@ void getSegmentsOutputByRunning(
 
   // set inputs ivalues, now supports Tensor/Int to pass arguments between different segments
   for (auto& input : seg_block.raw_inputs()) {
-    TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
+    TRTORCH_CHECK(ivalues_maps.count(input), "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
     if (input->node()->kind() == torch::jit::prim::Param) {
       jit_inputs_ivalues.push_back(ivalues_maps[input]);
     } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h
index 152de3dcce..ee9fb00d19 100644
--- a/core/partitioning/shape_analysis.h
+++ b/core/partitioning/shape_analysis.h
@@ -5,7 +5,8 @@
 namespace trtorch {
 namespace core {
 namespace partitioning {
 
-std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::InputRange>& input_ranges);
+std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+    std::unordered_map<torch::jit::Value*, ir::InputRange>& input_ranges);
 
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
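[Review note] The core structural move in patch 1: ConstructFallbackGraph operates on a torch::jit::Block instead of a whole Graph and calls itself on each block of a prim::If node, so every branch is partitioned and converted on its own and then stitched back in as a new prim::If. Reduced to a standalone sketch (illustrative names, using only the public torch::jit IR API):

    #include "torch/csrc/jit/ir/ir.h"

    // Visit a block and every nested prim::If branch, the same way
    // ConstructFallbackGraph descends into conditionals.
    void visitBlock(torch::jit::Block* block) {
      for (auto node : block->nodes()) {
        if (node->kind() == torch::jit::prim::If) {
          // blocks()[0] is the then-branch, blocks()[1] is the else-branch
          for (auto sub_block : node->blocks()) {
            visitBlock(sub_block);
          }
        }
      }
    }

In the patch, each recursive call additionally returns a (graph, value-mapping) pair so the caller can wire the converted branch back to values of the enclosing graph.
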
From 9823fff2575d49c0e2db9e4a630b211dca3113fb Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 29 Apr 2021 14:23:34 -0500
Subject: [PATCH 2/7] chore: improve the implementation of prim::If

Signed-off-by: Bo Wang
---
 core/compiler.cpp                    | 121 +++++++++++++++-------
 core/partitioning/partitioning.cpp   |   1 -
 core/partitioning/shape_analysis.cpp |   4 +-
 3 files changed, 72 insertions(+), 54 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index 987546f7c1..5a8b3b8565 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -182,11 +182,62 @@ void AddSegmentedBlockToGraph(
   return;
 }
 
-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>> GraphAndMapping;
+typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
+    GraphAndMapping;
 
-GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torch::jit::Block* block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
-    CompileSpec cfg, int& trt_engine_id, conversion::GraphParams named_params) {
+void AddIfBlockToGraph(
+    std::shared_ptr<torch::jit::Graph>& new_g,
+    torch::jit::Node* if_node,
+    const std::vector<GraphAndMapping>& graph_and_mappings,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+  torch::jit::IfView if_view(if_node);
+
+  // create a new if node in new_g and add corresponding inputs
+  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
+  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
+
+  for (auto graph_and_mapping : graph_and_mappings) {
+    auto new_if_block = new_if->addBlock();
+    auto cur_block_graph = graph_and_mapping.first;
+    auto cur_block_mapping = graph_and_mapping.second;
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
+    for (auto& i : cur_block_mapping) {
+      // for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's
+      // then graph's input
+      if (old_to_new_g.count(i.first)) {
+        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
+      }
+    }
+
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
+    new_if_block->cloneFrom(cur_block_graph->block(), env);
+    if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+        auto self = new_g->insertInput(0, "self_1");
+        self->setType(loop_graph->inputs()[0]->type());
+      }
+      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
+    }
+    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
+      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
+      new_if_block->eraseInput(i);
+    }
+  }
+  for (auto ov : if_view.outputs()) {
+    auto no = new_if->addOutput();
+    old_to_new_g[ov] = no;
+    no->copyMetadata(ov);
+  }
+  return;
+}
+
+GraphAndMapping ConstructFallbackGraph(
+    torch::jit::script::Module& new_mod,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    CompileSpec cfg,
+    int& trt_engine_id,
+    conversion::GraphParams named_params) {
   auto convert_cfg = cfg.convert_info;
   auto partition_info = cfg.partition_info;
 
@@ -218,51 +269,16 @@ GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torc
       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
     } else {
       if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto outer_node = seg_block.raw_nodes()[0];
-        torch::jit::IfView if_view(outer_node);
-
-
+        auto if_node = seg_block.raw_nodes()[0];
 
         // convert the 2 blocks in prim::If and get the converted graph with mappings
         std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : outer_node->blocks()) {
-          graph_and_mappings.push_back(ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
+        for (auto cur_block : if_node->blocks()) {
+          graph_and_mappings.push_back(
+              ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
         }
+        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
 
-        // create a new if node in new_g and add corresponding inputs
-        auto new_if =
-            new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-        new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-
-        for (auto graph_and_mapping : graph_and_mappings) {
-          auto new_if_block = new_if->addBlock();
-          auto cur_block_graph = graph_and_mapping.first;
-          auto cur_block_mapping = graph_and_mapping.second;
-          std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-          for (auto& i : cur_block_mapping) {
-            // for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's then graph's input
-            if (old_to_new_g.count(i.first)) {
-              block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-            }
-          }
-
-          auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-          new_if_block->cloneFrom(cur_block_graph->block(), env);
-          if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-            block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-          }
-          for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-            new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-            new_if_block->eraseInput(i);
-          }
-        }
-        for (auto ov : if_view.outputs()) {
-          auto no = new_if->addOutput();
-          old_to_new_g[ov] = no;
-          no->copyMetadata(ov);
-        }
-
-        LOG_INFO(*new_g << "new g with if\n");
       } else {
         AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
       }
     }
@@ -294,23 +310,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     auto named_params = conversion::get_named_params(g->inputs(), params);
     LOG_INFO(*g << "(LoweringGraph)\n");
 
-    // segment the graph and convert segmented TensorRT block
-//    auto segmented_blocks = partitioning::Partition(g->block(), convert_cfg.input_ranges, cfg.partition_info);
-//    if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
-//      LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-//      return mod;
-//    }
-
     int trt_engine_id = 0;
     std::unordered_map<torch::jit::Value*, ir::InputRange> input_ranges;
     for (size_t i = 0; i < g->inputs().size(); ++i) {
       input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
     }
     auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
-    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
+    auto graph_and_mapping =
+        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
     new_g = graph_and_mapping.first;
     LOG_INFO(*new_g << "(FallbackGraph)\n");
 
+    // if there is no tensorrt engine self in the fallback graph, there was no conversion, so we just return the
+    // initial module
+    if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+      LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+      return mod;
+    }
+
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
     auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
     new_mod.type()->addMethod(new_method);
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 8b0f887ced..c42eb11c8c 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -223,7 +223,6 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
   std::unordered_set<std::string> forced_fallback_operators(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
-
   auto nodes = block->nodes();
   std::vector<SegmentedBlock> segmented_blocks;
 
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 110afbbb3a..b3a8d68369 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -52,7 +52,9 @@ void getSegmentsOutputByRunning(
 
   // set inputs ivalues, now supports Tensor/Int to pass arguments between different segments
   for (auto& input : seg_block.raw_inputs()) {
-    TRTORCH_CHECK(ivalues_maps.count(input), "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
+    TRTORCH_CHECK(
+        ivalues_maps.count(input),
+        "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
     if (input->node()->kind() == torch::jit::prim::Param) {
       jit_inputs_ivalues.push_back(ivalues_maps[input]);
     } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
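[Review note] The extracted AddIfBlockToGraph is built around torch::jit::Block::cloneFrom, which copies a source block's nodes and resolves every value defined outside that block through a caller-supplied callback. A self-contained sketch of the mechanism (function and variable names are illustrative, not from the patch):

    #include <unordered_map>
    #include "torch/csrc/jit/ir/ir.h"

    // Clone `src` as one branch of a fresh prim::If in `g`, remapping free
    // values through `mapping`; values without an entry are kept as-is.
    void cloneBranch(
        std::shared_ptr<torch::jit::Graph> g,
        torch::jit::Block* src,
        torch::jit::Value* cond,
        std::unordered_map<torch::jit::Value*, torch::jit::Value*>& mapping) {
      auto new_if = g->insertNode(g->create(torch::jit::prim::If, {}, 0));
      new_if->addInput(cond);
      auto dest = new_if->addBlock();
      dest->cloneFrom(src, [&](torch::jit::Value* v) -> torch::jit::Value* {
        auto it = mapping.find(v);
        return it == mapping.end() ? v : it->second;
      });
    }

The patch's actual callback, util::getOrAddInputForValue, goes one step further and promotes unmapped values to new inputs of the destination graph rather than passing them through.
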
From d00627f1679edd3576323e5a07010ca88ee2c1a8 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Mon, 3 May 2021 23:32:48 -0500
Subject: [PATCH 3/7] fix: make sure that prim::If is in raw_nodes()[0] in
 dependency analysis

Signed-off-by: Bo Wang
---
 core/compiler.cpp                  |  7 ++++---
 core/partitioning/partitioning.cpp | 15 +++++++++++----
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index 5a8b3b8565..00fb408259 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -196,14 +196,15 @@ void AddIfBlockToGraph(
   auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
   new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
 
+  // iterate over all blocks and add them to new created prim::If
   for (auto graph_and_mapping : graph_and_mappings) {
     auto new_if_block = new_if->addBlock();
     auto cur_block_graph = graph_and_mapping.first;
     auto cur_block_mapping = graph_and_mapping.second;
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
     for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's
-      // then graph's input
+      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then it's
+      // mini graph's input
       if (old_to_new_g.count(i.first)) {
         block_graph_to_new_g[i.second] = old_to_new_g[i.first];
       }
@@ -214,7 +215,7 @@ void AddIfBlockToGraph(
     if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
       if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
         auto self = new_g->insertInput(0, "self_1");
-        self->setType(loop_graph->inputs()[0]->type());
+        self->setType(cur_block_graph->inputs()[0]->type());
       }
       block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
     }
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index c42eb11c8c..5da14f1e4d 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -84,8 +84,14 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
   // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
   // one new block
   if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
-    dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-    new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+    // if current node is prim::If, just ensure that we have all required input in kTorch
+    if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
+      new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+      new_seg_blocks.push_back(seg_block);
+    } else {
+      dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+      new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+    }
   } else {
     // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
     std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
@@ -141,8 +147,9 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
     if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
       int first_torch_id = use_info.torch_use_id.front();
       if (!updated_segments.count(first_torch_id)) {
-        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
-        segmented_blocks[first_torch_id] = new_torch_block;
+        auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
+        segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
+        segmented_blocks.insert(segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
         updated_segments.insert(first_torch_id);
       }
     } else {

From 2af79355b4e45fdeef63089c872cbe6e8c6d4d29 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Tue, 4 May 2021 18:05:21 -0500
Subject: [PATCH 4/7] chore: apply linting

Signed-off-by: Bo Wang
---
 core/compiler.cpp                  | 7 +++----
 core/partitioning/partitioning.cpp | 5 +++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index dfcd69d252..2f2f2020f4 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -203,8 +203,8 @@ void AddIfBlockToGraph(
     auto cur_block_mapping = graph_and_mapping.second;
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
     for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then it's
-      // mini graph's input
+      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
+      // it's mini graph's input
       if (old_to_new_g.count(i.first)) {
         block_graph_to_new_g[i.second] = old_to_new_g[i.first];
       }
@@ -317,8 +317,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
       input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
     }
     auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
-    auto graph_and_mapping =
-        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
+    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
     new_g = graph_and_mapping.first;
     LOG_INFO(*new_g << "(FallbackGraph)\n");
 
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 692878a943..6590cd5d05 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -2,8 +2,8 @@
 #include <queue>
 
 #include "core/conversion/conversion.h"
-#include "torch/csrc/jit/passes/constant_pooling.h"
 #include "core/partitioning/shape_analysis.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 
@@ -150,7 +150,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
       if (!updated_segments.count(first_torch_id)) {
         auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
         segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-        segmented_blocks.insert(segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+        segmented_blocks.insert(
+            segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
         updated_segments.insert(first_torch_id);
       }
     } else {
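[Review note] The behavioral fix in patch 3 is the splice in resolveNonTensorInputs: injectNodesForNonTensorInputs can now return two blocks for a prim::If segment (the hoisted dependency nodes plus the untouched If block), so the caller must replace one element of the block list with a whole sequence. The pattern in isolation (illustrative helper, not from the patch):

    #include <vector>

    // Replace vec[pos] with an arbitrary run of replacement elements.
    template <typename T>
    void spliceAt(std::vector<T>& vec, size_t pos, const std::vector<T>& replacements) {
      vec.erase(vec.begin() + pos);
      vec.insert(vec.begin() + pos, replacements.begin(), replacements.end());
    }

Keeping the prim::If node alone in its own SegmentedBlock is what lets the fallback pass treat the conditional as a single Torch op at this level while still recursing into its branches.
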
From f2053776c2437d3efee893e7d7b55832f1d34235 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Fri, 14 May 2021 19:56:19 -0500
Subject: [PATCH 5/7] test: add test suite for conditionals

Signed-off-by: Bo Wang
---
 tests/core/partitioning/BUILD                 | 30 +++++++++++--
 tests/core/partitioning/test_conditionals.cpp | 43 +++++++++++++++++++
 tests/modules/hub.py                          | 28 ++++++++++++
 3 files changed, 98 insertions(+), 3 deletions(-)
 create mode 100644 tests/core/partitioning/test_conditionals.cpp

diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD
index 89dcd6fabd..e6e9a6ae99 100644
--- a/tests/core/partitioning/BUILD
+++ b/tests/core/partitioning/BUILD
@@ -7,6 +7,13 @@ config_setting(
     }
 )
 
+filegroup(
+    name = "jit_models",
+    srcs = ["//tests/modules:resnet50_traced.jit.pt",
+            "//tests/modules:mobilenet_v2_traced.jit.pt",
+            "//tests/modules:conditional_scripted.jit.pt"]
+)
+
 partitioning_test(
     name = "test_segmentation",
 )
@@ -35,17 +42,34 @@ cc_test(
         "//conditions:default": ["@libtorch//:libtorch"],
     }),
     data = [
-        "//tests/modules:jit_models"
+        ":jit_models"
     ]
 )
 
+cc_test(
+    name = "test_conditionals",
+    srcs = ["test_conditionals.cpp"],
+    deps = [
+        "//tests/util",
+        "//core",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    data = [
+        ":jit_models"
+    ]
+)
+
 test_suite(
-    name = "partitioning_tests",
+    name = "partitioning_test",
     tests = [
         ":test_segmentation",
         ":test_shape_analysis",
         ":test_tensorrt_conversion",
         ":test_stitched_graph",
-        ":test_fallback_graph_output"
+        ":test_fallback_graph_output",
+        ":test_conditionals"
     ]
 )
\ No newline at end of file
diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp
new file mode 100644
index 0000000000..016b434fab
--- /dev/null
+++ b/tests/core/partitioning/test_conditionals.cpp
@@ -0,0 +1,43 @@
+#include <string>
+#include <unordered_set>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/script.h"
+
+size_t count_trt_engines_in_conditionals(std::shared_ptr<torch::jit::Graph> g) {
+  size_t count = 0;
+  for (auto n : g->nodes()) {
+    if (n->kind() == torch::jit::prim::If) {
+      std::vector<torch::jit::Block*> blocks{n->blocks()[0], n->blocks()[1]};
+      for (auto cur_block : blocks) {
+        for (auto n : cur_block->nodes()) {
+          if (n->kind().toQualString() == std::string("tensorrt::execute_engine")) {
+            ++count;
+          }
+        }
+      }
+    }
+  }
+  return count;
+}
+
+TEST(Partitioning, FallbackOnConditionalsCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
+  trtorch::core::CompileSpec cfg(input_ranges);
+  cfg.partition_info.enabled = true;
+  torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
+  auto g = new_mod.get_method("forward").graph();
+
+  auto conditional_engines_count = count_trt_engines_in_conditionals(g);
+
+  ASSERT_TRUE(conditional_engines_count == 2);
+}
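[Review note] The expected count of 2 follows from the FallbackIf model added to hub.py below: ConstructFallbackGraph recurses into both branches of the scripted if, and each branch holds enough supported ops to produce one TensorRT engine nested inside the final graph's prim::If. Because forward begins with a ReLU, which branch executes at runtime depends on the input's sign; a standalone smoke test of the saved module (illustrative, not part of the suite):

    #include "torch/script.h"

    int main() {
      auto mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
      // relu keeps 1.0, so x_first > 0 and the then-branch runs
      auto pos = torch::ones({3, 3, 16, 16}, torch::kCUDA);
      // relu clamps -1.0 to 0, so x_first == 0 and the else-branch runs
      auto neg = -torch::ones({3, 3, 16, 16}, torch::kCUDA);
      auto out_then = mod.forward({pos}).toTensor();
      auto out_else = mod.forward({neg}).toTensor();
      return (out_then.defined() && out_else.defined()) ? 0 : 1;
    }
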
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index e2a0516e0a..5a521f439a 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -95,3 +95,31 @@ def forward(self, x):
 
 trace_model = torch.jit.trace(model, x)
 torch.jit.save(trace_model, "pooling_traced.jit.pt")
+
+
+# Sample Conditional Model (for testing partitioning and fallback in conditionals)
+class FallbackIf(torch.nn.Module):
+    def __init__(self):
+        super(FallbackIf, self).__init__()
+        self.relu1 = torch.nn.ReLU()
+        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
+        self.log_sig = torch.nn.LogSigmoid()
+        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
+        self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1)
+
+    def forward(self, x):
+        x = self.relu1(x)
+        x_first = x[0][0][0][0].item()
+        if x_first > 0:
+            x = self.conv1(x)
+            x1 = self.log_sig(x)
+            x2 = self.conv2(x)
+            x = self.conv3(x1 + x2)
+        else:
+            x = self.log_sig(x)
+            x = self.conv1(x)
+        return x
+
+conditional_model = FallbackIf().eval().cuda()
+conditional_script_model = torch.jit.script(conditional_model)
+torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
\ No newline at end of file

From ff87956f0bdd6a4992d2eb7394989f9b4f2b8ef9 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 26 Jul 2021 11:18:35 -0700
Subject: [PATCH 6/7] fix: Fix testcases using old InputRange API

Signed-off-by: Dheeraj Peri
---
 core/compiler.cpp                               | 42 +------------------
 tests/core/partitioning/test_conditionals.cpp   |  4 +-
 tests/core/partitioning/test_shape_analysis.cpp |  2 +-
 tests/modules/hub.py                            |  8 ----
 4 files changed, 4 insertions(+), 52 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index 4dea4a76f5..02a617a823 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -261,7 +261,7 @@ GraphAndMapping ConstructFallbackGraph(
     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
       std::vector<ir::Input> inputs;
       for (auto& shape : seg_block.in_shape()) {
-        inputs.push_back(ir::InputRange(shape));
+        inputs.push_back(ir::Input(shape));
       }
       // update the input ranges for each segment
       convert_cfg.inputs = inputs;
@@ -332,46 +332,6 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     return mod;
   }
 
-// <<<<<<< HEAD
-// =======
-//   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-//   // add global graph's input to old_to_new_g mapping
-//   for (auto input : g->inputs()) {
-//     util::getOrAddInputForValue(input, new_g, old_to_new_g);
-//   }
-//   for (auto& seg_block : segmented_blocks) {
-//     std::string cur_block_target =
-//         seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
-//     LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
-//     std::ostringstream trt_engine_id;
-//     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-//     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-//       std::vector<ir::Input> inputs;
-//       for (auto& shape : seg_block.in_shape()) {
-//         inputs.push_back(ir::Input(shape));
-//       }
-//       // update the input ranges for each segments
-//       convert_cfg.inputs = inputs;
-//       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
-//       auto temp_g = std::make_shared<torch::jit::Graph>();
-//       auto device_spec = convert_cfg.engine_settings.device;
-//       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-//       AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-//
-//       seg_block.update_graph(temp_g);
-//       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-//     } else {
-//       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-//     }
-//   }
-//
-//   for (auto& output : g->outputs()) {
-//     new_g->registerOutput(old_to_new_g[output]);
-//   }
-//
-//   LOG_INFO(*new_g << "(FallbackGraph)\n");
-//
-// >>>>>>> master
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
   auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
   new_mod.type()->addMethod(new_method);
diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp
index 016b434fab..27b3fc2c10 100644
--- a/tests/core/partitioning/test_conditionals.cpp
+++ b/tests/core/partitioning/test_conditionals.cpp
@@ -31,8 +31,8 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
     return;
   }
 
-  std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
-  trtorch::core::CompileSpec cfg(input_ranges);
+  std::vector<trtorch::core::ir::Input> inputs{trtorch::core::ir::Input({3, 3, 16, 16})};
+  trtorch::core::CompileSpec cfg(inputs);
   cfg.partition_info.enabled = true;
   torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
   auto g = new_mod.get_method("forward").graph();
diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp
index 48bb2474c8..9129acfc3f 100644
--- a/tests/core/partitioning/test_shape_analysis.cpp
+++ b/tests/core/partitioning/test_shape_analysis.cpp
@@ -107,7 +107,7 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
   inputs.push_back(trtorch::core::ir::Input({16, 32, 3, 3}));
   inputs.push_back(trtorch::core::ir::Input({16}));
 
-  std::unordered_map<torch::jit::Value*, trtorch::core::ir::InputRange> inputs_map;
+  std::unordered_map<torch::jit::Value*, trtorch::core::ir::Input> inputs_map;
   for (size_t i = 0; i < g->inputs().size(); ++i) {
     inputs_map.insert({g->inputs()[i], inputs[i]});
   }
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index c858169cd2..ceca1edef7 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -54,18 +54,10 @@
         "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
         "path": "both"
     },
-    "fcn_resnet101": {
-        "model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
-        "path": "script"
-    },
     "ssd": {
         "model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
         "path": "trace"
     },
-    "faster_rcnn": {
-        "model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
-        "path": "script"
-    },
     "efficientnet_b0": {
         "model": timm.create_model('efficientnet_b0', pretrained=True),
         "path": "script"
From 0b49965113b7f5f638b234f5a656705500bc6ebc Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Wed, 28 Jul 2021 09:04:13 -0700
Subject: [PATCH 7/7] refactor: apply linting

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 tests/modules/hub.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index ceca1edef7..986a84776c 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -100,6 +100,7 @@ def forward(self, x):
 
 # Sample Conditional Model (for testing partitioning and fallback in conditionals)
 class FallbackIf(torch.nn.Module):
+
     def __init__(self):
         super(FallbackIf, self).__init__()
         self.relu1 = torch.nn.ReLU()
@@ -121,6 +122,7 @@ def forward(self, x):
         x = self.conv1(x)
         return x
 
+
 conditional_model = FallbackIf().eval().cuda()
 conditional_script_model = torch.jit.script(conditional_model)
 torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
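[Review note] Two closing remarks on the tail of the series. Patch 6 is mostly mechanical: while the series was in flight, ir::InputRange was superseded upstream by ir::Input (and convert_cfg.input_ranges by convert_cfg.inputs), so the new code paths and tests move to the new spelling; the call-site change is one-for-one:

    // before (old API, as introduced earlier in this series)
    std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
    trtorch::core::CompileSpec cfg(input_ranges);

    // after (new API)
    std::vector<trtorch::core::ir::Input> inputs{trtorch::core::ir::Input({3, 3, 16, 16})};
    trtorch::core::CompileSpec cfg(inputs);

Patch 6 also deletes the stale commented-out fallback path (the "<<<<<<< HEAD ... >>>>>>> master" block) left behind by a merge. With the full series applied, the new conditional test should be runnable through the cc_test target added in patch 5, presumably as:

    bazel test //tests/core/partitioning:test_conditionals

(the exact invocation depends on the local libtorch/TensorRT setup).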