@@ -257,3 +257,113 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257257 int count = count_trt_engines (fallback_g);
258258 ASSERT_TRUE (count == 2 );
259259}
260+
261+ TEST (Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
262+ /* parseIR does not support "= aten::_set_item" so we will build this graph manually
263+ const auto graph = R"IR(
264+ graph(%x : Tensor,
265+ %y : Tensor):
266+ %2 : str = prim::Constant[value="INS"]()
267+ %3 : str = prim::Constant[value="OUTS"]()
268+ %4 : bool = prim::Constant[value=0]()
269+ %5 : int = prim::Constant[value=-1]()
270+ %6 : Dict(str, Tensor) = prim::DictConstruct()
271+ = aten::_set_item(%6, %2, %x)
272+ %7 : Tensor = aten::__getitem__(%6, %2)
273+ %8 : Tensor = aten::lt(%7, %y)
274+ %9 : Tensor?[] = prim::ListConstruct(%8)
275+ %10 : int = prim::dtype(%7)
276+ %11 : Device = prim::device(%7)
277+ %12 : Tensor = aten::tensor(%5, %10, %11, %4)
278+ %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
279+ = aten::_set_item(%6, %3, %7)
280+ %14 : Tensor = aten::__getitem__(%6, %2)
281+ %15 : Tensor = aten::__getitem__(%6, %3)
282+ return (%14, %15))IR";
283+ */
284+ auto g = std::make_shared<torch::jit::Graph>();
285+ auto x = g->insertInput (0 , " x" );
286+ auto y = g->insertInput (1 , " y" );
287+ torch::jit::IValue ins_key (" INS" );
288+ auto ins_key_val = g->insertConstant (ins_key);
289+ torch::jit::IValue outs_key (" OUTS" );
290+ auto outs_key_val = g->insertConstant (outs_key);
291+ torch::jit::IValue zero (0 );
292+ auto false_const_val = g->insertConstant (zero);
293+ false_const_val->setType (c10::BoolType::get ());
294+ torch::jit::IValue neg_one (-1 );
295+ auto neg_one_const_val = g->insertConstant (neg_one);
296+ auto dict_node = g->createDict (ins_key_val->type (), x->type (), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
297+ g->insertNode (dict_node);
298+ auto set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x}, 0 );
299+ g->insertNode (set_node);
300+ auto get_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
301+ g->insertNode (get_node);
302+ auto lt_node = g->create (torch::jit::Symbol::fromQualString (" aten::lt" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y}, 1 );
303+ g->insertNode (lt_node);
304+ auto list_node = g->createList (at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
305+ g->insertNode (list_node);
306+ auto dtype_node = g->create (torch::jit::Symbol::fromQualString (" prim::dtype" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
307+ dtype_node->output ()->setType (neg_one_const_val->type ());
308+ g->insertNode (dtype_node);
309+ auto device_node = g->create (torch::jit::Symbol::fromQualString (" prim::device" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
310+ device_node->output ()->setType (c10::DeviceObjType::get ());
311+ g->insertNode (device_node);
312+ auto tensor_node = g->create (torch::jit::Symbol::fromQualString (" aten::tensor" ), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val}, 1 );
313+ g->insertNode (tensor_node);
314+ auto index_put_node = g->create (torch::jit::Symbol::fromQualString (" aten::index_put_" ),
315+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), list_node->output (), tensor_node->output (), false_const_val}, 1 );
316+ g->insertNode (index_put_node);
317+ auto out_set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ),
318+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()}, 0 );
319+ g->insertNode (out_set_node);
320+ auto get_ins_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
321+ g->insertNode (get_ins_node);
322+ auto get_outs_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val}, 1 );
323+ g->insertNode (get_outs_node);
324+ g->registerOutput (get_ins_node->output ());
325+ g->registerOutput (get_outs_node->output ());
326+
327+ torch_tensorrt::core::partitioning::PartitionInfo partition_info;
328+ partition_info.enabled = true ;
329+ std::vector<torch_tensorrt::core::ir::Input> inputs;
330+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
331+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
332+
333+ std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
334+ std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
335+ for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
336+ inputs_map.insert ({g->inputs ()[i], inputs[i]});
337+ input_types.insert ({g->inputs ()[i], {at::kFloat }});
338+ }
339+ auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs (inputs_map, input_types);
340+ auto segmented_blocks =
341+ torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
342+
343+ int torch_block_cnt = 0 , trt_block_cnt = 0 ;
344+ for (const auto & segmented_block : segmented_blocks) {
345+ if (segmented_block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT ) {
346+ ++trt_block_cnt;
347+ ASSERT_TRUE (checkSegmentedBlockInputType (segmented_block, [](torch::jit::TypePtr type_ptr) {
348+ return type_ptr->isSubtypeOf (torch::jit::TensorType::get ());
349+ }));
350+ } else {
351+ ++torch_block_cnt;
352+ bool output_dict = false ;
353+ bool input_dict = false ;
354+ auto dict_type = dict_node->output ()->type ();
355+ for (auto in : segmented_block.raw_inputs ()) {
356+ if (in->type ()->isSubtypeOf (dict_type)){
357+ input_dict = true ;
358+ }
359+ }
360+ for (auto out : segmented_block.raw_outputs ()) {
361+ if (out->type ()->isSubtypeOf (dict_type)){
362+ output_dict = true ;
363+ }
364+ }
365+ EXPECT_TRUE (output_dict ^ input_dict);
366+ }
367+ }
368+ ASSERT_TRUE (trt_block_cnt == 1 && torch_block_cnt == 2 );
369+ }
0 commit comments