@@ -10,16 +10,25 @@ namespace torch_tensorrt {
1010namespace core {
1111namespace partitioning {
1212
13- at::Tensor generateSingleInput (ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
14- auto cur_shape = input.input_shape ;
15- std::vector<int64_t > shape;
13+ at::Tensor generateSingleInput (
14+ ir::Input& input,
15+ c10::optional<at::ScalarType>& type_opt,
16+ const ir::ShapeMode& shape_mode) {
17+ nvinfer1::Dims input_shape = input.input_shape ;
18+ if (input.input_is_dynamic ) {
19+ if (shape_mode == ir::ShapeMode::kMIN ) {
20+ input_shape = input.min ;
21+ } else if (shape_mode == ir::ShapeMode::kOPT ) {
22+ input_shape = input.opt ;
23+ } else {
24+ input_shape = input.max ;
25+ }
26+ }
1627
1728 // Initialize min and max ranges for random number selection
1829 int LoValIncl = 0 ;
1930 int HiValExcl = 2 ;
2031
21- shape.insert (shape.begin (), std::begin (cur_shape.d ), std::begin (cur_shape.d ) + cur_shape.nbDims );
22-
2332 auto type = at::kFloat ;
2433 if (type_opt) {
2534 type = type_opt.value ();
@@ -29,14 +38,15 @@ at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>&
2938
3039 // Make the value range for input tensor a uniform (float) distribution
3140 // over [LoValIncl, HiValExcl), then cast to the desired dtype
32- auto in = ((HiValExcl - LoValIncl) * at::rand (shape , {at::kCUDA }) + LoValIncl).to (type);
41+ auto in = ((HiValExcl - LoValIncl) * at::rand (util::toVec (input_shape) , {at::kCUDA }) + LoValIncl).to (type);
3342
3443 return in;
3544}
3645
3746std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs (
3847 std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
39- std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types) {
48+ std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types,
49+ const ir::ShapeMode& shape_mode) {
4050 // generate random inputs for running pytorch segments
4151 std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;
4252
@@ -45,21 +55,21 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
4555 c10::TypePtr elementType = c10::TensorType::get ();
4656 auto generic_list = c10::impl::GenericList (elementType);
4757 for (size_t i = 0 ; i < input.second .size (); i++) {
48- auto in = generateSingleInput (input.second [i], types[input.first ][i]);
58+ auto in = generateSingleInput (input.second [i], types[input.first ][i], shape_mode );
4959 generic_list.push_back (in.clone ());
5060 }
5161 ivalue_map[input.first ] = c10::IValue (generic_list);
5262 } else if (input.first ->type ()->kind () == torch::jit::TypeKind::TupleType) {
5363 // create tuple
5464 std::vector<torch::jit::IValue> list;
5565 for (size_t i = 0 ; i < input.second .size (); i++) {
56- auto in = generateSingleInput (input.second [i], types[input.first ][i]);
66+ auto in = generateSingleInput (input.second [i], types[input.first ][i], shape_mode );
5767 list.push_back (in.clone ());
5868 }
5969 auto tuple = c10::ivalue::Tuple::create (list); // create tuple ptr
6070 ivalue_map[input.first ] = c10::IValue (tuple);
6171 } else {
62- auto in = generateSingleInput (input.second [0 ], types[input.first ][0 ]);
72+ auto in = generateSingleInput (input.second [0 ], types[input.first ][0 ], shape_mode );
6373 ivalue_map[input.first ] = in.clone ();
6474 }
6575 }
@@ -124,7 +134,8 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
124134void getSegmentsOutputByRunning (
125135 SegmentedBlock& seg_block,
126136 std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
127- const PartitioningInfo& partitioning_info) {
137+ const PartitioningInfo& partitioning_info,
138+ const ir::ShapeMode& shape_mode) {
128139 // create a module to run the graph
129140 auto g = seg_block.g ();
130141 auto copy_g = g->copy ();
@@ -235,7 +246,7 @@ void getSegmentsOutputByRunning(
235246 }
236247
237248 // set input shape for each segmented block so we will use it in conversion process
238- std::vector<ir::Input > input_shapes;
249+ std::vector<std::vector< int64_t > > input_shapes;
239250 std::vector<at::ScalarType> input_types;
240251 for (size_t i = 0 ; i < seg_block.inputs ().size (); ++i) {
241252 if (ivalues_maps[seg_block.raw_inputs ()[i]].isTensor ()) {
@@ -270,15 +281,19 @@ void getSegmentsOutputByRunning(
270281 // TODO: tuple and list inputs in subgraph
271282 }
272283
273- seg_block.register_inshapes (input_shapes);
284+ seg_block.register_inshapes (input_shapes, shape_mode );
274285 seg_block.register_intypes (input_types);
275286}
276287
// runShapeAnalysis: walks every segmented block stored under `block` in
// ctx->partitioned_blocks. For each one it runs torch::jit::ConstantPooling on the
// segment's graph and then calls getSegmentsOutputByRunning with the shared
// example_tensor_map, ctx->settings and shape_mode.
//   ctx                - partitioning context; supplies partitioned_blocks and settings
//   block              - key selecting which list of segmented blocks to process
//   example_tensor_map - passed by non-const reference to every getSegmentsOutputByRunning call
//   shape_mode         - forwarded unchanged to getSegmentsOutputByRunning
// NOTE(review): this text is a diff rendering; "-" lines show the signature/call before
// shape_mode was added and "+" lines show them after.
277- void runShapeAnalysis (PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
288+ void runShapeAnalysis (
289+ PartitioningCtx* ctx,
290+ torch::jit::Block* block,
291+ ExampleIValues& example_tensor_map,
292+ const ir::ShapeMode& shape_mode) {
278293 // register every segment's input shape, and its running output IValues
279294 for (auto & seg_block : ctx->partitioned_blocks [block]) {
280295 torch::jit::ConstantPooling (seg_block.g ());
281- getSegmentsOutputByRunning (seg_block, example_tensor_map, ctx->settings );
296+ getSegmentsOutputByRunning (seg_block, example_tensor_map, ctx->settings , shape_mode );
282297 }
283298 return ;
284299}
0 commit comments