Support shape analysis for dynamic fallback #1111

Closed · wants to merge 2 commits
10 changes: 5 additions & 5 deletions core/compiler.cpp
@@ -222,7 +222,7 @@ void AddIfBlockToGraph(
GraphAndMapping ConstructFallbackGraph(
torch::jit::script::Module& new_mod,
torch::jit::Block* block,
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> example_tensor_maps,
CompileSpec cfg,
ir::StaticParams static_params,
std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
@@ -231,7 +231,7 @@ GraphAndMapping ConstructFallbackGraph(

auto new_g = std::make_shared<torch::jit::Graph>();

- auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
+ auto segmented_blocks = partitioning::Partition(block, example_tensor_maps, partition_info, fallback_nodes);

// the mapping from lowering graph => fallback global graph
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -272,7 +272,7 @@ GraphAndMapping ConstructFallbackGraph(
std::vector<GraphAndMapping> graph_and_mappings;
for (auto cur_block : if_node->blocks()) {
graph_and_mappings.push_back(
- ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
+ ConstructFallbackGraph(new_mod, cur_block, example_tensor_maps, cfg, static_params, fallback_nodes));
}
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

@@ -408,10 +408,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
if (cfg.partition_info.enabled &&
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
- auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+ auto input_ivalues_maps = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
auto graph_and_mapping =
- ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
+ ConstructFallbackGraph(new_mod, g->block(), input_ivalues_maps, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
// rename the graph inputs after fallback to ensure PyTorch deserializes them correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
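For orientation (this note is not part of the diff, and the ExampleMap alias is hypothetical): the signature changes above replace the single example-input map with a vector of maps, one entry for static shapes and three entries for dynamic shapes.

#include <unordered_map>
#include <vector>
#include "torch/csrc/jit/ir/ir.h"

// Hypothetical alias, for illustration only.
using ExampleMap =
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>;

// size() == 1 -> static shapes; size() == 3 -> one map each for min/opt/max.
std::vector<ExampleMap> example_tensor_maps;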
2 changes: 1 addition & 1 deletion core/lowering/lowering.cpp
@@ -65,7 +65,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::RemoveNOPs(g);
passes::AliasOperators(g);
passes::SiluToSigmoidMultipication(g);
- passes::RemoveSingleUse0DTensors(g);
+ // passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
LOG_GRAPH(*g);
}
35 changes: 35 additions & 0 deletions core/partitioning/SegmentedBlock.h
@@ -6,6 +6,7 @@
#include "NvInfer.h"
#include "core/ir/ir.h"
#include "core/partitioning/PartitionInfo.h"
#include "core/util/trt_util.h"
#include "torch/csrc/jit/ir/ir.h"

namespace torch_tensorrt {
@@ -76,6 +77,40 @@ struct SegmentedBlock {
void register_inshapes(std::vector<ir::Input>& in_shapes) {
in_shapes_ = in_shapes;
}

void register_opt_shapes(std::vector<ir::Input>& opt_shapes) {
assert(in_shapes_.size() == opt_shapes.size());
for (size_t i = 0; i < opt_shapes.size(); i++) {
in_shapes_[i].opt = opt_shapes[i].opt;
}
}

void register_max_shapes(std::vector<ir::Input>& max_shapes) {
assert(in_shapes_.size() == max_shapes.size());
for (size_t i = 0; i < max_shapes.size(); i++) {
in_shapes_[i].max = max_shapes[i].max;
}
}

void construct_dynamic_shape() {
for (size_t i = 0; i < in_shapes_.size(); i++) {
std::vector<int64_t> dyn_shape;
for (int j = 0; j < in_shapes_[i].input_shape.nbDims; j++) {
std::set<uint64_t> dim;
dim.insert(in_shapes_[i].min.d[j]);
dim.insert(in_shapes_[i].opt.d[j]);
dim.insert(in_shapes_[i].max.d[j]);
if (dim.size() != 1) {
dyn_shape.push_back(-1);
in_shapes_[i].input_is_dynamic = true;
} else {
dyn_shape.push_back(in_shapes_[i].opt.d[j]);
}
}
in_shapes_[i].input_shape = util::toDims(dyn_shape);
}
}

const std::vector<ir::Input>& in_shapes() const {
return in_shapes_;
}
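To make construct_dynamic_shape() concrete: a dimension whose value differs across the min/opt/max profile points becomes -1 (dynamic); otherwise the opt value is kept. Below is a minimal standalone sketch of that rule (not from the PR; collapseToDynamic is a hypothetical helper):

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Collapse per-dimension min/opt/max values into one shape suitable for a
// dynamic TensorRT input: -1 where the profile varies, the opt value elsewhere.
std::vector<int64_t> collapseToDynamic(
    const std::vector<int64_t>& min,
    const std::vector<int64_t>& opt,
    const std::vector<int64_t>& max) {
  std::vector<int64_t> dyn;
  for (size_t j = 0; j < opt.size(); j++) {
    std::set<int64_t> dim{min[j], opt[j], max[j]};
    dyn.push_back(dim.size() != 1 ? -1 : opt[j]);
  }
  return dyn;
}

int main() {
  // The batch dimension varies (1/8/16) -> -1; the other dimensions are fixed.
  auto dyn = collapseToDynamic(
      {1, 3, 224, 224}, {8, 3, 224, 224}, {16, 3, 224, 224});
  for (auto d : dyn) {
    std::cout << d << " "; // prints: -1 3 224 224
  }
  return 0;
}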
4 changes: 2 additions & 2 deletions core/partitioning/partitioning.cpp
@@ -435,7 +435,7 @@ PartitionedGraph segment_graph(

PartitionedGraph Partition(
torch::jit::Block* block,
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_maps,
const PartitionInfo& partition_info,
std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
LOG_DEBUG(partition_info);
@@ -453,7 +453,7 @@
registerSegmentsOutputs(segmented_blocks, block);

// run shape analysis on each segmented block
- runShapeAnalysis(segmented_blocks, example_tensor_map, partition_info);
+ runShapeAnalysis(segmented_blocks, example_tensor_maps, partition_info);

for (uint64_t i = 0; i < segmented_blocks.size(); i++) {
segmented_blocks[i].update_id(i);
2 changes: 1 addition & 1 deletion core/partitioning/partitioning.h
@@ -37,7 +37,7 @@ PartitionedGraph segment_graph(

PartitionedGraph Partition(
torch::jit::Block* block,
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_map,
const PartitionInfo& partition_info,
std::unordered_map<torch::jit::Node*, int>& fallback_nodes);

91 changes: 83 additions & 8 deletions core/partitioning/shape_analysis.cpp
@@ -8,9 +8,43 @@ namespace torch_tensorrt {
namespace core {
namespace partitioning {

- std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> generateRandomInputs(
std::unordered_map<const torch::jit::Value*, ir::Input>& inputs,
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>& types) {
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> ivalue_maps;

bool is_dynamic = false;
for (auto& input : inputs) {
if (input.second.input_is_dynamic)
is_dynamic = true;
}
if (is_dynamic) {
LOG_WARNING("Dynamic fallback encountered");
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map_min, ivalue_map_opt, ivalue_map_max;
for (auto& input : inputs) {
auto cur_min = input.second.min;
auto cur_opt = input.second.opt;
auto cur_max = input.second.max;
std::vector<int64_t> min_shape, opt_shape, max_shape;
min_shape.insert(min_shape.begin(), std::begin(cur_min.d), std::begin(cur_min.d) + cur_min.nbDims);
opt_shape.insert(opt_shape.begin(), std::begin(cur_opt.d), std::begin(cur_opt.d) + cur_opt.nbDims);
max_shape.insert(max_shape.begin(), std::begin(cur_max.d), std::begin(cur_max.d) + cur_max.nbDims);
auto type_opt = types[input.first];
auto type = at::kFloat;
if (type_opt) {
type = type_opt.value();
} else {
LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
}
auto in_min = at::randint(5, min_shape, {at::kCUDA}).to(type);
auto in_opt = at::randint(5, opt_shape, {at::kCUDA}).to(type);
auto in_max = at::randint(5, max_shape, {at::kCUDA}).to(type);
ivalue_map_min[input.first] = in_min.clone();
ivalue_map_opt[input.first] = in_opt.clone();
ivalue_map_max[input.first] = in_max.clone();
}
return {ivalue_map_min, ivalue_map_opt, ivalue_map_max};
}
// generate random inputs for running the PyTorch segments (static-shape path)
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;

@@ -30,12 +64,13 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
ivalue_map[input.first] = in.clone();
in_i++;
}
- return ivalue_map;
+ return {ivalue_map};
}

void getSegmentsOutputByRunning(
SegmentedBlock& seg_block,
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
int register_iteration,
const PartitionInfo& partition_info) {
// create a module to run the graph
auto g = seg_block.g();
@@ -63,6 +98,12 @@ void getSegmentsOutputByRunning(

std::vector<torch::jit::IValue> jit_inputs_ivalues;

for (auto& input : seg_block.raw_inputs()) {
LOG_DEBUG(
"Register input ivalues_maps for torch::jit::Value* " << input->debugName() << ", produced from "
<< util::node_info(input->node()));
}

// set input IValues; currently supports Tensor/Int arguments passed between different segments
for (auto& input : seg_block.raw_inputs()) {
TORCHTRT_CHECK(
@@ -111,6 +152,9 @@
size_t idx = 0;
for (auto& output : seg_block.raw_outputs()) {
ivalues_maps[output] = jit_results[idx++];
LOG_DEBUG(
"Register output ivalues_maps for torch::jit::Value* " << output->debugName() << ", produced from "
<< util::node_info(output->node()));
}

// set the input shape for each segmented block so we will use it in the conversion process
@@ -146,19 +190,50 @@ void getSegmentsOutputByRunning(
input_types.push_back(cur_ivalue.toTensor().scalar_type());
}
}

- seg_block.register_inshapes(input_shapes);
+ LOG_DEBUG("Begin register shape");
+ if (register_iteration == 0)
+ seg_block.register_inshapes(input_shapes);
+ else if (register_iteration == 1)
+ seg_block.register_opt_shapes(input_shapes);
+ else if (register_iteration == 2)
+ seg_block.register_max_shapes(input_shapes);
seg_block.register_intypes(input_types);
+ LOG_DEBUG("Done");
}

void runShapeAnalysis(
std::vector<SegmentedBlock>& segmented_blocks,
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_maps,
const PartitionInfo& partition_info) {
// register every segment's input shape and its running output IValues
- for (auto& seg_block : segmented_blocks) {
- torch::jit::ConstantPooling(seg_block.g());
- getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info);
- }
if (example_tensor_maps.size() == 1) {
int i = 0;
for (auto& seg_block : segmented_blocks) {
torch::jit::ConstantPooling(seg_block.g());
LOG_DEBUG("Running the graph @" << i);
getSegmentsOutputByRunning(seg_block, example_tensor_maps[0], 0, partition_info);
i++;
}
} else if (example_tensor_maps.size() == 3) {
int i = 0;
for (auto& seg_block : segmented_blocks) {
torch::jit::ConstantPooling(seg_block.g());
LOG_DEBUG("Running min graph @" << i);
getSegmentsOutputByRunning(seg_block, example_tensor_maps[0], 0, partition_info);
i++;
}
i = 0;
for (auto& seg_block : segmented_blocks) {
torch::jit::ConstantPooling(seg_block.g());
LOG_DEBUG("Running opt graph @" << i);
getSegmentsOutputByRunning(seg_block, example_tensor_maps[1], 1, partition_info);
i++;
}
i = 0;
for (auto& seg_block : segmented_blocks) {
torch::jit::ConstantPooling(seg_block.g());
LOG_DEBUG("Running max graph @" << i);
getSegmentsOutputByRunning(seg_block, example_tensor_maps[2], 2, partition_info);
i++;
}
for (auto& seg_block : segmented_blocks)
seg_block.construct_dynamic_shape();
}
return;
}
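Taken together, the dynamic path in runShapeAnalysis() executes every segment three times, once per profile point, then merges the recorded shapes via construct_dynamic_shape(). A hedged sketch of that driving idea (assumes libtorch; the profile shapes are illustrative, and the PR allocates on CUDA via at::randint(5, shape, {at::kCUDA}) where this sketch stays on CPU):

#include <torch/torch.h>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical optimization profile for one input: min / opt / max shapes.
  std::vector<std::vector<int64_t>> profile = {
      {1, 3, 224, 224}, // min
      {8, 3, 224, 224}, // opt
      {16, 3, 224, 224}}; // max
  for (const auto& shape : profile) {
    // Small random integers cast to float, mirroring at::randint(5, ...).
    auto in = torch::randint(5, shape).to(torch::kFloat);
    // ... run each segmented block on `in`, record the observed input/output
    // shapes, and register them as min/opt/max for that block ...
  }
  return 0;
}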
4 changes: 2 additions & 2 deletions core/partitioning/shape_analysis.h
@@ -6,13 +6,13 @@ namespace torch_tensorrt {
namespace core {
namespace partitioning {

- std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> generateRandomInputs(
std::unordered_map<const torch::jit::Value*, ir::Input>& input_ranges,
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>& input_types);

void runShapeAnalysis(
std::vector<SegmentedBlock>& segmented_blocks,
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
+ std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& ivalues_maps,
const PartitionInfo& partition_info);

} // namespace partitioning