@@ -8,9 +8,43 @@ namespace torch_tensorrt {
 namespace core {
 namespace partitioning {

-std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
+std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> generateRandomInputs(
     std::unordered_map<const torch::jit::Value*, ir::Input>& inputs,
     std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>& types) {
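+  // With dynamic fallback, this returns one example-input map for static shapes, or three
+  // maps (built at the min, opt, and max shapes) when any input is dynamic.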
+  std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>> ivalue_maps;
+
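+  // Check whether any registered input carries a dynamic shape range.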
+  bool is_dynamic = false;
+  for (auto& input : inputs) {
+    if (input.second.input_is_dynamic)
+      is_dynamic = true;
+  }
+  if (is_dynamic) {
+    LOG_WARNING("Dynamic fallback encountered");
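+    // For dynamic inputs, build three sets of example tensors: one at each input's min
+    // shape, one at its opt shape, and one at its max shape.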
+    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map_min, ivalue_map_opt, ivalue_map_max;
+    for (auto& input : inputs) {
+      auto cur_min = input.second.min;
+      auto cur_opt = input.second.opt;
+      auto cur_max = input.second.max;
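+      // Copy the first nbDims entries of each nvinfer1::Dims into int64_t shape vectors.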
+      std::vector<int64_t> min_shape, opt_shape, max_shape;
+      min_shape.insert(min_shape.begin(), std::begin(cur_min.d), std::begin(cur_min.d) + cur_min.nbDims);
+      opt_shape.insert(opt_shape.begin(), std::begin(cur_opt.d), std::begin(cur_opt.d) + cur_opt.nbDims);
+      max_shape.insert(max_shape.begin(), std::begin(cur_max.d), std::begin(cur_max.d) + cur_max.nbDims);
+      auto type_opt = types[input.first];
+      auto type = at::kFloat;
+      if (type_opt) {
+        type = type_opt.value();
+      } else {
+        LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
+      }
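+      // Fill random integer tensors (values in [0, 5)) on CUDA at each shape, then cast to the resolved dtype.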
+      auto in_min = at::randint(5, min_shape, {at::kCUDA}).to(type);
+      auto in_opt = at::randint(5, opt_shape, {at::kCUDA}).to(type);
+      auto in_max = at::randint(5, max_shape, {at::kCUDA}).to(type);
+      ivalue_map_min[input.first] = in_min.clone();
+      ivalue_map_opt[input.first] = in_opt.clone();
+      ivalue_map_max[input.first] = in_max.clone();
+    }
+    return {ivalue_map_min, ivalue_map_opt, ivalue_map_max};
+  }

   // generate random inputs for running pytorch segments
   std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;
@@ -30,12 +64,13 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
     ivalue_map[input.first] = in.clone();
     in_i++;
   }
-  return ivalue_map;
+  return {ivalue_map};
 }

 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
+    int register_iteration,
     const PartitionInfo& partition_info) {
   // create a module to run the graph
   auto g = seg_block.g();
@@ -63,6 +98,12 @@ void getSegmentsOutputByRunning(

   std::vector<torch::jit::IValue> jit_inputs_ivalues;

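+  // Log each input value this segment expects and the node that produces it, to help trace missing ivalues.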
+  for (auto& input : seg_block.raw_inputs()) {
+    LOG_DEBUG(
+        "Register input ivalues_maps for torch::jit::Value* " << input->debugName() << ", produced from "
+                                                              << util::node_info(input->node()));
+  }
+
   // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
   for (auto& input : seg_block.raw_inputs()) {
     TORCHTRT_CHECK(
@@ -111,6 +152,9 @@ void getSegmentsOutputByRunning(
   size_t idx = 0;
   for (auto& output : seg_block.raw_outputs()) {
     ivalues_maps[output] = jit_results[idx++];
+    LOG_DEBUG(
+        "Register output ivalues_maps for torch::jit::Value* " << output->debugName() << ", produced from "
+                                                               << util::node_info(output->node()));
   }

   // set input shape for each segmented block so we wil use it in conversion process
@@ -146,19 +190,50 @@ void getSegmentsOutputByRunning(
       input_types.push_back(cur_ivalue.toTensor().scalar_type());
     }
   }
-
-  seg_block.register_inshapes(input_shapes);
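+  // register_iteration selects which shape set this pass records: 0 = min (also the static case), 1 = opt, 2 = max.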
+  LOG_DEBUG("Begin register shape");
+  if (register_iteration == 0)
+    seg_block.register_inshapes(input_shapes);
+  else if (register_iteration == 1)
+    seg_block.register_opt_shapes(input_shapes);
+  else if (register_iteration == 2)
+    seg_block.register_max_shapes(input_shapes);
   seg_block.register_intypes(input_types);
+  LOG_DEBUG("Done");
 }

 void runShapeAnalysis(
     std::vector<SegmentedBlock>& segmented_blocks,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
+    std::vector<std::unordered_map<const torch::jit::Value*, torch::jit::IValue>>& example_tensor_maps,
     const PartitionInfo& partition_info) {
   // register every segment's input shape, and it's running output IValues
-  for (auto& seg_block : segmented_blocks) {
-    torch::jit::ConstantPooling(seg_block.g());
-    getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info);
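+  // One map means static input shapes; three maps mean dynamic shapes (min/opt/max profiles).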
+  if (example_tensor_maps.size() == 1) {
+    int i = 0;
+    for (auto& seg_block : segmented_blocks) {
+      torch::jit::ConstantPooling(seg_block.g());
+      LOG_DEBUG("Running the graph @" << i);
+      getSegmentsOutputByRunning(seg_block, example_tensor_maps[0], 0, partition_info);
+      i++;
+    }
+  } else if (example_tensor_maps.size() == 3) {
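+    // Dynamic case: run every segment once per profile so min, opt, and max shapes are all
+    // recorded, then construct_dynamic_shape() combines them into each block's dynamic shape info.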
+    int i = 0;
+    for (auto& seg_block : segmented_blocks) {
+      torch::jit::ConstantPooling(seg_block.g());
+      LOG_DEBUG("Running min graph @" << i);
+      getSegmentsOutputByRunning(seg_block, example_tensor_maps[0], 0, partition_info);
+      i++;
+    }
+    for (auto& seg_block : segmented_blocks) {
+      torch::jit::ConstantPooling(seg_block.g());
+      LOG_DEBUG("Running opt graph @" << i);
+      getSegmentsOutputByRunning(seg_block, example_tensor_maps[1], 1, partition_info);
+    }
+    for (auto& seg_block : segmented_blocks) {
+      torch::jit::ConstantPooling(seg_block.g());
+      LOG_DEBUG("Running max graph @" << i);
+      getSegmentsOutputByRunning(seg_block, example_tensor_maps[2], 2, partition_info);
+    }
+    for (auto& seg_block : segmented_blocks)
+      seg_block.construct_dynamic_shape();
   }
   return;
 }