diff --git a/core/compiler.cpp b/core/compiler.cpp
index 1f7ab3aa47..02a617a823 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -11,6 +11,7 @@
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
+#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
@@ -173,10 +174,131 @@ void AddSegmentedBlockToGraph(
   for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
     old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
   }
+  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
+  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
+    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
+      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
+    }
+  }
 
   return;
 }
 
+typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
+    GraphAndMapping;
+
+void AddIfBlockToGraph(
+    std::shared_ptr<torch::jit::Graph>& new_g,
+    torch::jit::Node* if_node,
+    const std::vector<GraphAndMapping>& graph_and_mappings,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+  torch::jit::IfView if_view(if_node);
+
+  // create a new prim::If node in new_g and add the corresponding inputs
+  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
+  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
+
+  // iterate over all blocks and add them to the newly created prim::If
+  for (auto graph_and_mapping : graph_and_mappings) {
+    auto new_if_block = new_if->addBlock();
+    auto cur_block_graph = graph_and_mapping.first;
+    auto cur_block_mapping = graph_and_mapping.second;
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
+    for (auto& i : cur_block_mapping) {
+      // for every pair (old_value => mini graph value) in the block's mapping: if old_value also appears in
+      // old_to_new_g, then it is an input of the mini graph
+      if (old_to_new_g.count(i.first)) {
+        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
+      }
+    }
+
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
+    new_if_block->cloneFrom(cur_block_graph->block(), env);
+    if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+        auto self = new_g->insertInput(0, "self_1");
+        self->setType(cur_block_graph->inputs()[0]->type());
+      }
+      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
+    }
+    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
+      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
+      new_if_block->eraseInput(i);
+    }
+  }
+  for (auto ov : if_view.outputs()) {
+    auto no = new_if->addOutput();
+    old_to_new_g[ov] = no;
+    no->copyMetadata(ov);
+  }
+  return;
+}
+
+GraphAndMapping ConstructFallbackGraph(
+    torch::jit::script::Module& new_mod,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    CompileSpec cfg,
+    conversion::GraphParams named_params) {
+  auto convert_cfg = cfg.convert_info;
+  auto partition_info = cfg.partition_info;
+
+  auto new_g = std::make_shared<torch::jit::Graph>();
+
+  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+
+  // the mapping from lowering graph => fallback global graph
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
+  for (auto input : block->inputs()) {
+    util::getOrAddInputForValue(input, new_g, old_to_new_g);
+  }
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+    std::ostringstream trt_engine_id;
+    trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      std::vector<ir::Input> inputs;
+      for (auto& shape : seg_block.in_shape()) {
+        inputs.push_back(ir::Input(shape));
+      }
+      // update the input ranges for each segment
+      convert_cfg.inputs = inputs;
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      auto temp_g = std::make_shared<torch::jit::Graph>();
+      auto device_spec = convert_cfg.engine_settings.device;
+      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
+
+      seg_block.update_graph(temp_g);
+      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+    } else {
+      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
+        auto if_node = seg_block.raw_nodes()[0];
+
+        // convert the two blocks in prim::If and get the converted graphs with their value mappings
+        std::vector<GraphAndMapping> graph_and_mappings;
+        for (auto cur_block : if_node->blocks()) {
+          graph_and_mappings.push_back(
+              ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
+        }
+        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+
+      } else {
+        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+      }
+    }
+  }
+
+  for (auto& output : block->outputs()) {
+    if (old_to_new_g.count(output)) {
+      new_g->registerOutput(old_to_new_g[output]);
+    }
+  }
+  return {new_g, old_to_new_g};
+}
+
 torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
   // TODO: Should be doing a functional transform but need PR #31978
   // [jit] More robust mangling
@@ -192,53 +314,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
   auto g = graph_and_parameters.first;
   auto params = graph_and_parameters.second;
   auto named_params = conversion::get_named_params(g->inputs(), params);
-  auto convert_cfg = std::move(cfg.convert_info);
-  LOG_INFO(*g << "(LoweringGraph)\n");
+  LOG_INFO("(LoweredGraph)\n" << *g);
 
-  // segment the graph and convert segmented TensorRT block
-  auto segmented_blocks = partitioning::Partition(g, convert_cfg.inputs, cfg.partition_info);
-  if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
+  std::unordered_map<torch::jit::Value*, ir::Input> inputs;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
+  }
+  auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
+  auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
+  new_g = graph_and_mapping.first;
+  LOG_INFO("(FallbackGraph)\n" << *new_g);
+
+  // if no TensorRT engine was added to the fallback graph (its first input is not the module's "self"),
+  // no conversion happened, so just return the initial module
+  if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
     LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
     return mod;
   }
 
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-  // add global graph's input to old_to_new_g mapping
-  for (auto input : g->inputs()) {
-    util::getOrAddInputForValue(input, new_g, old_to_new_g);
-  }
-  for (auto& seg_block : segmented_blocks) {
-    std::string cur_block_target =
-        seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
"TensorRT" : "Torch"; - LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n"); - std::ostringstream trt_engine_id; - trt_engine_id << reinterpret_cast(&seg_block); - if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) { - std::vector inputs; - for (auto& shape : seg_block.in_shape()) { - inputs.push_back(ir::Input(shape)); - } - // update the input ranges for each segments - convert_cfg.inputs = inputs; - auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params); - auto temp_g = std::make_shared(); - auto device_spec = convert_cfg.engine_settings.device; - auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type); - AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true); - - seg_block.update_graph(temp_g); - AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g); - } else { - AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g); - } - } - - for (auto& output : g->outputs()) { - new_g->registerOutput(old_to_new_g[output]); - } - - LOG_INFO(*new_g << "(FallbackGraph)\n"); - auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g); auto schema = util::GenerateGraphSchema(new_method->name(), new_g); new_mod.type()->addMethod(new_method); diff --git a/core/partitioning/SegmentedBlock.cpp b/core/partitioning/SegmentedBlock.cpp index a940ddf799..505b565482 100644 --- a/core/partitioning/SegmentedBlock.cpp +++ b/core/partitioning/SegmentedBlock.cpp @@ -4,7 +4,7 @@ namespace trtorch { namespace core { namespace partitioning { -SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, std::vector& nodes) +SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector& nodes) : target_(blk_target), g_(std::make_shared()) { for (auto& node : nodes) { nodes_.push_back(node); diff --git a/core/partitioning/SegmentedBlock.h b/core/partitioning/SegmentedBlock.h index 6baf59f1ab..097346ad2b 100644 --- a/core/partitioning/SegmentedBlock.h +++ b/core/partitioning/SegmentedBlock.h @@ -20,7 +20,7 @@ struct SegmentedBlock { SegmentedBlock() = default; SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared()) {} - SegmentedBlock(SegmentedBlockTarget blk_target, std::vector& nodes); + SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector& nodes); SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr g) : target_(blk_target), g_(g) {} torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v); diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 539525cc01..b4b68e5e5e 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -3,6 +3,7 @@ #include #include "core/conversion/conversion.h" #include "core/partitioning/shape_analysis.h" +#include "torch/csrc/jit/passes/constant_pooling.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" namespace trtorch { @@ -85,8 +86,14 @@ std::vector segmentBlocksWithNonTensorInputs(SegmentedBlock& seg // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) { - dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end()); - 
-    new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+    // if the current node is prim::If, just ensure that we have all required inputs in the kTorch block
+    if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
+      new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+      new_seg_blocks.push_back(seg_block);
+    } else {
+      dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+      new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+    }
   } else {
     // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
     std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
@@ -127,7 +134,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
   return std::move(new_seg_blocks);
 }
 
-void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
   // create a list so we can insert SegmentedBlock without losing the iterators
   std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.begin(), segmented_blocks.end());
   std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
@@ -169,8 +176,10 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
     if (!updated_segments.count(first_torch_id)) {
       // Segmented Blocks with non-tensor inputs will have to be re-segmented as
       // TRTorch doesn't support non-tensor inputs for a module.
-      auto new_torch_block = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]).front();
-      *idx_to_iter[first_torch_id] = new_torch_block;
+      auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
+      segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
+      segmented_blocks.insert(
+          segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
       updated_segments.insert(first_torch_id);
     }
   }
@@ -191,7 +200,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
   return;
 }
 
-void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
   // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
   std::set<torch::jit::Value*> input_values;
   for (auto& seg_block : segmented_blocks) {
@@ -200,7 +209,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
     }
   }
 
-  for (auto& graph_output : g->outputs()) {
+  for (auto& graph_output : block->outputs()) {
     input_values.insert(graph_output);
   }
 
@@ -249,12 +258,12 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
   return;
 }
 
-std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info) {
+std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_operators(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
-  auto nodes = g->block()->nodes();
+  auto nodes = block->nodes();
   std::vector<SegmentedBlock> segmented_blocks;
 
   // segment the nodes
@@ -278,6 +287,16 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
       pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
     }
     tensorrt_nodes.clear();
+    // if there is a prim::If, the whole if node will be encapsulated in one SegmentedBlock,
+    // so we shouldn't inject its nodes into other blocks during the dependency analysis process
+    if (n->kind() == torch::jit::prim::If) {
+      if (!pytorch_nodes.empty()) {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        pytorch_nodes.clear();
+      }
+      segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+      continue;
+    }
     pytorch_nodes.push_back(n);
   }
 }
@@ -295,21 +314,21 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
 }
 
 std::vector<SegmentedBlock> Partition(
-    std::shared_ptr<torch::jit::Graph> g,
-    std::vector<ir::Input>& inputs,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, ir::Input>& input_ivalues_map,
     const PartitionInfo& partition_info) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
-  std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);
+  std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
 
   // resolve nonTensor inputs/outputs
-  resolveNonTensorInputs(segmented_blocks, g);
+  resolveNonTensorInputs(segmented_blocks);
 
   // register input/output torch::jit::Value for segmented graphs
-  registerSegmentsOutputs(segmented_blocks, g);
+  registerSegmentsOutputs(segmented_blocks, block);
 
   // run shape analysis on each segmented block
-  runShapeAnalysis(segmented_blocks, inputs, g);
+  runShapeAnalysis(segmented_blocks, input_ivalues_map);
 
   return segmented_blocks;
 }
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index aa62167429..61ae5ec98f 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -5,6 +5,7 @@
 #include "core/ir/ir.h"
 #include "core/partitioning/PartitionInfo.h"
 #include "core/partitioning/SegmentedBlock.h"
+#include "core/partitioning/shape_analysis.h"
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/ir/ir.h"
 
@@ -14,13 +15,13 @@ namespace partitioning {
 
 typedef std::vector<SegmentedBlock> PartitionedGraph;
 
-PartitionedGraph segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info);
+PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info);
 
 std::vector<SegmentedBlock> Partition(
-    std::shared_ptr<torch::jit::Graph> g,
-    std::vector<ir::Input>& inputs,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, ir::Input>& input_ivalues_map,
     const PartitionInfo& partition_info);
 
 } // namespace partitioning
 } // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 88d3f055d5..a3807831a7 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -7,17 +7,20 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
-std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::Input>& inputs) {
+std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+    std::unordered_map<torch::jit::Value*, ir::Input>& inputs) {
   // generate random inputs for running pytorch segments
+  std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalue_maps;
   std::vector<torch::jit::IValue> random_inputs;
+
   for (auto& input : inputs) {
-    auto cur_shape = input.input_shape;
+    auto cur_shape = input.second.input_shape;
     std::vector<int64_t> shape;
     shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
     auto in = at::randint(5, shape, {at::kCUDA});
-    random_inputs.push_back(in.clone());
+    ivalue_maps[input.first] = in.clone();
   }
-  return random_inputs;
+  return ivalue_maps;
 }
 
 void getSegmentsOutputByRunning(
@@ -51,7 +54,9 @@ void getSegmentsOutputByRunning(
 
   // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
   for (auto& input : seg_block.raw_inputs()) {
-    TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
+    TRTORCH_CHECK(
+        ivalues_maps.count(input),
+        "Could not find torch::jit::Value* " << input->debugName() << " in the lowering graph for mini graph input.\n");
     if (input->node()->kind() == torch::jit::prim::Param) {
       jit_inputs_ivalues.push_back(ivalues_maps[input]);
     } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
@@ -100,15 +105,7 @@ void getSegmentsOutputByRunning(
 
 void runShapeAnalysis(
     std::vector<SegmentedBlock>& segmented_blocks,
-    std::vector<ir::Input>& inputs,
-    std::shared_ptr<torch::jit::Graph> g) {
-  // store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
-  std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
-  std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(inputs);
-  for (size_t i = 0; i < g->inputs().size(); ++i) {
-    ivalues_maps[g->inputs()[i]] = random_inputs[i];
-  }
-
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : segmented_blocks) {
     torch::jit::ConstantPooling(seg_block.g());
diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h
index a92715c3c9..ef3f82dc90 100644
--- a/core/partitioning/shape_analysis.h
+++ b/core/partitioning/shape_analysis.h
@@ -6,13 +6,13 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
-std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::Input>& inputs);
+std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+    std::unordered_map<torch::jit::Value*, ir::Input>& input_ranges);
 
 void runShapeAnalysis(
     std::vector<SegmentedBlock>& segmented_blocks,
-    std::vector<ir::Input>& inputs,
-    std::shared_ptr<torch::jit::Graph> g);
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps);
 
 } // namespace partitioning
 } // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
diff --git a/tests/core/BUILD b/tests/core/BUILD
index ab0d46f7d1..b1f7d119bb 100644
--- a/tests/core/BUILD
+++ b/tests/core/BUILD
@@ -3,6 +3,6 @@ test_suite(
     tests = [
         "//tests/core/conversion:conversion_tests",
        "//tests/core/lowering:lowering_tests",
-        "//tests/core/partitioning:partitioning_tests"
+        "//tests/core/partitioning:partitioning_test"
    ],
 )
diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD
index 89dcd6fabd..e6e9a6ae99 100644
--- a/tests/core/partitioning/BUILD
+++ b/tests/core/partitioning/BUILD
@@ -7,6 +7,13 @@ config_setting(
     }
 )
 
+filegroup(
+    name = "jit_models",
+    srcs = ["//tests/modules:resnet50_traced.jit.pt",
+            "//tests/modules:mobilenet_v2_traced.jit.pt",
+            "//tests/modules:conditional_scripted.jit.pt"]
+)
+
 partitioning_test(
     name = "test_segmentation",
 )
@@ -35,17 +42,34 @@ cc_test(
         "//conditions:default": ["@libtorch//:libtorch"],
     }),
     data = [
-        "//tests/modules:jit_models"
+        ":jit_models"
     ]
 )
 
+cc_test(
+    name = "test_conditionals",
+    srcs = ["test_conditionals.cpp"],
+    deps = [
+        "//tests/util",
+        "//core",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    data = [
+        ":jit_models"
+    ]
+)
+
 test_suite(
-    name = "partitioning_tests",
+    name = "partitioning_test",
     tests = [
         ":test_segmentation",
         ":test_shape_analysis",
         ":test_tensorrt_conversion",
         ":test_stitched_graph",
-        ":test_fallback_graph_output"
+        ":test_fallback_graph_output",
+        ":test_conditionals"
     ]
-)
\ No newline at end of file
+)
diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp
new file mode 100644
index 0000000000..27b3fc2c10
--- /dev/null
+++ b/tests/core/partitioning/test_conditionals.cpp
@@ -0,0 +1,43 @@
+#include <string>
+#include <unordered_set>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/script.h"
+
+size_t count_trt_engines_in_conditionals(std::shared_ptr<torch::jit::Graph> g) {
+  size_t count = 0;
+  for (auto n : g->nodes()) {
+    if (n->kind() == torch::jit::prim::If) {
+      std::vector<torch::jit::Block*> blocks{n->blocks()[0], n->blocks()[1]};
+      for (auto cur_block : blocks) {
+        for (auto n : cur_block->nodes()) {
+          if (n->kind().toQualString() == std::string("tensorrt::execute_engine")) {
+            ++count;
+          }
+        }
+      }
+    }
+  }
+  return count;
+}
+
+TEST(Partitioning, FallbackOnConditionalsCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  std::vector<trtorch::core::ir::Input> inputs{trtorch::core::ir::Input({3, 3, 16, 16})};
+  trtorch::core::CompileSpec cfg(inputs);
+  cfg.partition_info.enabled = true;
+  torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
+  auto g = new_mod.get_method("forward").graph();
+
+  auto conditional_engines_count = count_trt_engines_in_conditionals(g);
+
+  ASSERT_TRUE(conditional_engines_count == 2);
+}
diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp
index 43b5317146..89b0548652 100644
--- a/tests/core/partitioning/test_segmentation.cpp
+++ b/tests/core/partitioning/test_segmentation.cpp
@@ -67,7 +67,7 @@ TEST(Partitioning, SegmentSequentialModelCorrectly) {
   trtorch::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::segment_graph(g, partition_info);
+      trtorch::core::partitioning::segment_graph(g->block(), partition_info);
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 2));
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3}, {4}}));
@@ -100,7 +100,7 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
   partition_info.enabled = true;
   partition_info.min_block_size = 3;
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::segment_graph(g, partition_info);
+      trtorch::core::partitioning::segment_graph(g->block(), partition_info);
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 1));
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
@@ -133,7 +133,7 @@ TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
   partition_info.enabled = true;
   partition_info.forced_fallback_operators.push_back("aten::relu");
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::segment_graph(g, partition_info);
+      trtorch::core::partitioning::segment_graph(g->block(), partition_info);
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 3));
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 2));
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0}, {1}, {2}, {3}, {4}}));
@@ -166,7 +166,7 @@ TEST(Partitioning, SegmentBranchModelCorrectly) {
   trtorch::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::segment_graph(g, partition_info);
+      trtorch::core::partitioning::segment_graph(g->block(), partition_info);
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 2));
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3, 4, 5, 6}}));
@@ -200,7 +200,7 @@ TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) {
   partition_info.enabled = true;
   partition_info.min_block_size = 3;
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::segment_graph(g, partition_info);
+      trtorch::core::partitioning::segment_graph(g->block(), partition_info);
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 1));
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4, 5, 6}}));
@@ -234,8 +234,8 @@ TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) {
   partition_info.enabled = true;
   partition_info.forced_fallback_operators.push_back("aten::relu");
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::segment_graph(g, partition_info);
+      trtorch::core::partitioning::segment_graph(g->block(), partition_info);
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 3));
   ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 2));
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3}, {4}, {5, 6}}));
-}
\ No newline at end of file
+}
diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp
index 6f220478de..9129acfc3f 100644
--- a/tests/core/partitioning/test_shape_analysis.cpp
+++ b/tests/core/partitioning/test_shape_analysis.cpp
@@ -59,8 +59,14 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
   inputs.push_back(trtorch::core::ir::Input({8, 16, 3, 3}));
   inputs.push_back(trtorch::core::ir::Input({8}));
 
+  std::unordered_map<torch::jit::Value*, trtorch::core::ir::Input> inputs_map;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs_map.insert({g->inputs()[i], inputs[i]});
+  }
+  auto input_ivalues_map = trtorch::core::partitioning::generateRandomInputs(inputs_map);
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::Partition(g, inputs, partition_info);
+      trtorch::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+
   ASSERT_TRUE(checkSegmentedBlockInputShape(
       segmented_blocks,
       {{{3, 3, 16, 16}, {32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}},
@@ -101,8 +107,14 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
   inputs.push_back(trtorch::core::ir::Input({16, 32, 3, 3}));
   inputs.push_back(trtorch::core::ir::Input({16}));
 
+  std::unordered_map<torch::jit::Value*, trtorch::core::ir::Input> inputs_map;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs_map.insert({g->inputs()[i], inputs[i]});
+  }
+  auto input_ivalues_map = trtorch::core::partitioning::generateRandomInputs(inputs_map);
   std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
-      trtorch::core::partitioning::Partition(g, inputs, partition_info);
+      trtorch::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+
   ASSERT_TRUE(checkSegmentedBlockInputShape(
       segmented_blocks,
       {{{3, 3, 16, 16}, {32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}},
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index 239fc5c2e9..986a84776c 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -54,18 +54,10 @@
         "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
         "path": "both"
     },
-    "fcn_resnet101": {
-        "model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
-        "path": "script"
-    },
     "ssd": {
         "model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
         "path": "trace"
     },
-    "faster_rcnn": {
-        "model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
-        "path": "script"
-    },
     "efficientnet_b0": {
         "model": timm.create_model('efficientnet_b0', pretrained=True),
         "path": "script"
@@ -104,3 +96,33 @@ def forward(self, x):
 
 trace_model = torch.jit.trace(model, x)
 torch.jit.save(trace_model, "pooling_traced.jit.pt")
+
+
+# Sample conditional model (for testing partitioning and fallback in conditionals)
+class FallbackIf(torch.nn.Module):
+
+    def __init__(self):
+        super(FallbackIf, self).__init__()
+        self.relu1 = torch.nn.ReLU()
+        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
+        self.log_sig = torch.nn.LogSigmoid()
+        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
+        self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1)
+
+    def forward(self, x):
+        x = self.relu1(x)
+        x_first = x[0][0][0][0].item()
+        if x_first > 0:
+            x = self.conv1(x)
+            x1 = self.log_sig(x)
+            x2 = self.conv2(x)
+            x = self.conv3(x1 + x2)
+        else:
+            x = self.log_sig(x)
+            x = self.conv1(x)
+        return x
+
+
+conditional_model = FallbackIf().eval().cuda()
+conditional_script_model = torch.jit.script(conditional_model)
+torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
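
Usage note: the snippet below is a minimal sketch (not part of the patch) of how the reworked block-based Partition() API introduced above can be driven; it mirrors the updated tests. It assumes a lowered graph `g`, a `std::vector<trtorch::core::ir::Input>` named `inputs` with one entry per graph input, and a `partition_info` already exist; the {3, 3, 16, 16} style shapes used elsewhere in the tests are illustrative, not required.

    // Map every graph input Value* to its ir::Input shape, as the new API expects.
    std::unordered_map<torch::jit::Value*, trtorch::core::ir::Input> inputs_map;
    for (size_t i = 0; i < g->inputs().size(); ++i) {
      inputs_map.insert({g->inputs()[i], inputs[i]});
    }
    // Random example tensors let shape analysis execute each Torch segment eagerly.
    auto input_ivalues_map = trtorch::core::partitioning::generateRandomInputs(inputs_map);
    // Partition() now takes a torch::jit::Block* instead of a Graph, so nested blocks
    // (e.g. the arms of a prim::If) can be partitioned recursively via the same entry point.
    auto segmented_blocks =
        trtorch::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);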