Skip to content

Commit d24a818

Browse files
committed
Support shape analysis for dynamic fallback
Signed-off-by: Cheng Hang <[email protected]>
1 parent 666a263 commit d24a818

File tree

6 files changed

+128
-18
lines changed

6 files changed

+128
-18
lines changed

core/compiler.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,15 @@ void AddIfBlockToGraph(
221221
GraphAndMapping ConstructFallbackGraph(
222222
torch::jit::script::Module& new_mod,
223223
torch::jit::Block* block,
224-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
224+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> example_tensor_maps,
225225
CompileSpec cfg,
226226
ir::StaticParams static_params) {
227227
auto convert_cfg = cfg.convert_info;
228228
auto partition_info = cfg.partition_info;
229229

230230
auto new_g = std::make_shared<torch::jit::Graph>();
231231

232-
auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
232+
auto segmented_blocks = partitioning::Partition(block, example_tensor_maps, partition_info);
233233

234234
// the mapping from lowering graph => fallback global graph
235235
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -270,7 +270,7 @@ GraphAndMapping ConstructFallbackGraph(
270270
std::vector<GraphAndMapping> graph_and_mappings;
271271
for (auto cur_block : if_node->blocks()) {
272272
graph_and_mappings.push_back(
273-
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
273+
ConstructFallbackGraph(new_mod, cur_block, example_tensor_maps, cfg, static_params));
274274
}
275275
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
276276

@@ -429,8 +429,8 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
429429
if (cfg.partition_info.enabled &&
430430
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
431431
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
432-
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
433-
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
432+
auto input_ivalues_maps = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
433+
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_maps, cfg, static_params);
434434
new_g = graph_and_mapping.first;
435435
LOG_INFO("Segmented Graph: " << *new_g);
436436

core/partitioning/SegmentedBlock.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "NvInfer.h"
77
#include "core/ir/ir.h"
88
#include "core/partitioning/PartitionInfo.h"
9+
#include "core/util/trt_util.h"
910
#include "torch/csrc/jit/ir/ir.h"
1011

1112
namespace torch_tensorrt {
@@ -76,6 +77,40 @@ struct SegmentedBlock {
7677
void register_inshapes(std::vector<ir::Input>& in_shapes) {
7778
in_shapes_ = in_shapes;
7879
}
80+
81+
void register_opt_shapes(std::vector<ir::Input>& opt_shapes) {
82+
assert(in_shapes_.size() == opt_shapes.size());
83+
for (size_t i = 0; i < opt_shapes.size(); i++) {
84+
in_shapes_[i].opt = opt_shapes[i].opt;
85+
}
86+
}
87+
88+
void register_max_shapes(std::vector<ir::Input>& max_shapes) {
89+
assert(in_shapes_.size() == max_shapes.size());
90+
for (size_t i = 0; i < max_shapes.size(); i++) {
91+
in_shapes_[i].max = max_shapes[i].max;
92+
}
93+
}
94+
95+
void construct_dynamic_shape() {
96+
for (size_t i = 0; i < in_shapes_.size(); i++) {
97+
std::vector<int64_t> dyn_shape;
98+
for (int j = 0; j < in_shapes_[i].input_shape.nbDims; j++) {
99+
std::set<uint64_t> dim;
100+
dim.insert(in_shapes_[i].min.d[j]);
101+
dim.insert(in_shapes_[i].opt.d[j]);
102+
dim.insert(in_shapes_[i].max.d[j]);
103+
if (dim.size() != 1) {
104+
dyn_shape.push_back(-1);
105+
in_shapes_[i].input_is_dynamic = true;
106+
} else {
107+
dyn_shape.push_back(in_shapes_[i].opt.d[j]);
108+
}
109+
}
110+
in_shapes_[i].input_shape = util::toDims(dyn_shape);
111+
}
112+
}
113+
79114
const std::vector<ir::Input>& in_shapes() const {
80115
return in_shapes_;
81116
}

core/partitioning/partitioning.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa
581581

582582
PartitionedGraph Partition(
583583
torch::jit::Block* block,
584-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
584+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_maps,
585585
const PartitionInfo& partition_info) {
586586
LOG_DEBUG(partition_info);
587587
// segment lowering global graph into blocks
@@ -596,7 +596,7 @@ PartitionedGraph Partition(
596596
registerSegmentsOutputs(segmented_blocks, block);
597597

598598
// run shape analysis on each segmented block
599-
runShapeAnalysis(segmented_blocks, example_tensor_map, partition_info);
599+
runShapeAnalysis(segmented_blocks, example_tensor_maps, partition_info);
600600

601601
for (uint64_t i = 0; i < segmented_blocks.size(); i++) {
602602
segmented_blocks[i].update_id(i);

core/partitioning/partitioning.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa
2020

2121
PartitionedGraph Partition(
2222
torch::jit::Block* block,
23-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
23+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_map,
2424
const PartitionInfo& partition_info);
2525

2626
std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g);

core/partitioning/shape_analysis.cpp

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,43 @@ namespace torch_tensorrt {
88
namespace core {
99
namespace partitioning {
1010

11-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
11+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> generateRandomInputs(
1212
std::unordered_map<const torch::jit::Value*, ir::Input>& inputs,
1313
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>& types) {
14+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> ivalue_maps;
15+
16+
bool is_dynamic = false;
17+
for (auto& input : inputs) {
18+
if (input.second.input_is_dynamic)
19+
is_dynamic = true;
20+
}
21+
if (is_dynamic) {
22+
LOG_WARNING("Dynamic fallback encountered");
23+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map_min, ivalue_map_opt, ivalue_map_max;
24+
for (auto& input : inputs) {
25+
auto cur_min = input.second.min;
26+
auto cur_opt = input.second.opt;
27+
auto cur_max = input.second.max;
28+
std::vector<int64_t> min_shape, opt_shape, max_shape;
29+
min_shape.insert(min_shape.begin(), std::begin(cur_min.d), std::begin(cur_min.d) + cur_min.nbDims);
30+
opt_shape.insert(opt_shape.begin(), std::begin(cur_opt.d), std::begin(cur_opt.d) + cur_opt.nbDims);
31+
max_shape.insert(max_shape.begin(), std::begin(cur_max.d), std::begin(cur_max.d) + cur_max.nbDims);
32+
auto type_opt = types[input.first];
33+
auto type = at::kFloat;
34+
if (type_opt) {
35+
type = type_opt.value();
36+
} else {
37+
LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
38+
}
39+
auto in_min = at::randint(5, min_shape, {at::kCUDA}).to(type);
40+
auto in_opt = at::randint(5, opt_shape, {at::kCUDA}).to(type);
41+
auto in_max = at::randint(5, max_shape, {at::kCUDA}).to(type);
42+
ivalue_map_min[input.first] = in_min.clone();
43+
ivalue_map_opt[input.first] = in_opt.clone();
44+
ivalue_map_max[input.first] = in_max.clone();
45+
}
46+
return {ivalue_map_min, ivalue_map_opt, ivalue_map_max};
47+
}
1448
// generate random inputs for running pytorch segments
1549
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;
1650

@@ -30,12 +64,13 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
3064
ivalue_map[input.first] = in.clone();
3165
in_i++;
3266
}
33-
return ivalue_map;
67+
return {ivalue_map};
3468
}
3569

3670
void getSegmentsOutputByRunning(
3771
SegmentedBlock& seg_block,
3872
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
73+
int register_iteration,
3974
const PartitionInfo& partition_info) {
4075
// create a module to run the graph
4176
auto g = seg_block.g();
@@ -63,6 +98,12 @@ void getSegmentsOutputByRunning(
6398

6499
std::vector<torch::jit::IValue> jit_inputs_ivalues;
65100

101+
for (auto& input : seg_block.raw_inputs()) {
102+
LOG_DEBUG(
103+
"Register input ivalues_maps for torch::jit::Value* " << input->debugName() << ", produced from "
104+
<< util::node_info(input->node()));
105+
}
106+
66107
// set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
67108
for (auto& input : seg_block.raw_inputs()) {
68109
TORCHTRT_CHECK(
@@ -111,6 +152,9 @@ void getSegmentsOutputByRunning(
111152
size_t idx = 0;
112153
for (auto& output : seg_block.raw_outputs()) {
113154
ivalues_maps[output] = jit_results[idx++];
155+
LOG_DEBUG(
156+
"Register output ivalues_maps for torch::jit::Value* " << output->debugName() << ", produced from "
157+
<< util::node_info(output->node()));
114158
}
115159

116160
// set input shape for each segmented block so we wil use it in conversion process
@@ -146,19 +190,50 @@ void getSegmentsOutputByRunning(
146190
input_types.push_back(cur_ivalue.toTensor().scalar_type());
147191
}
148192
}
149-
150-
seg_block.register_inshapes(input_shapes);
193+
LOG_DEBUG("Begin register shape");
194+
if (register_iteration == 0)
195+
seg_block.register_inshapes(input_shapes);
196+
else if (register_iteration == 1)
197+
seg_block.register_opt_shapes(input_shapes);
198+
else if (register_iteration == 2)
199+
seg_block.register_max_shapes(input_shapes);
151200
seg_block.register_intypes(input_types);
201+
LOG_DEBUG("Done");
152202
}
153203

154204
void runShapeAnalysis(
155205
std::vector<SegmentedBlock>& segmented_blocks,
156-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
206+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_maps,
157207
const PartitionInfo& partition_info) {
158208
// register every segment's input shape, and it's running output IValues
159-
for (auto& seg_block : segmented_blocks) {
160-
torch::jit::ConstantPooling(seg_block.g());
161-
getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info);
209+
if (example_tensor_maps.size() == 1) {
210+
int i = 0;
211+
for (auto& seg_block : segmented_blocks) {
212+
torch::jit::ConstantPooling(seg_block.g());
213+
LOG_DEBUG("Running the graph @" << i);
214+
getSegmentsOutputByRunning(seg_block, example_tensor_maps[0], 0, partition_info);
215+
i++;
216+
}
217+
} else if (example_tensor_maps.size() == 3) {
218+
int i = 0;
219+
for (auto& seg_block : segmented_blocks) {
220+
torch::jit::ConstantPooling(seg_block.g());
221+
LOG_DEBUG("Running min graph @" << i);
222+
getSegmentsOutputByRunning(seg_block, example_tensor_maps[0], 0, partition_info);
223+
i++;
224+
}
225+
for (auto& seg_block : segmented_blocks) {
226+
torch::jit::ConstantPooling(seg_block.g());
227+
LOG_DEBUG("Running opt graph @" << i);
228+
getSegmentsOutputByRunning(seg_block, example_tensor_maps[1], 1, partition_info);
229+
}
230+
for (auto& seg_block : segmented_blocks) {
231+
torch::jit::ConstantPooling(seg_block.g());
232+
LOG_DEBUG("Running max graph @" << i);
233+
getSegmentsOutputByRunning(seg_block, example_tensor_maps[2], 2, partition_info);
234+
}
235+
for (auto& seg_block : segmented_blocks)
236+
seg_block.construct_dynamic_shape();
162237
}
163238
return;
164239
}

core/partitioning/shape_analysis.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ namespace torch_tensorrt {
66
namespace core {
77
namespace partitioning {
88

9-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
9+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> generateRandomInputs(
1010
std::unordered_map<const torch::jit::Value*, ir::Input>& input_ranges,
1111
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>& input_types);
1212

1313
void runShapeAnalysis(
1414
std::vector<SegmentedBlock>& segmented_blocks,
15-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
15+
std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& ivalues_maps,
1616
const PartitionInfo& partition_info);
1717

1818
} // namespace partitioning

0 commit comments

Comments (0)