@@ -396,6 +396,21 @@ auto aten_registrations TRTORCH_UNUSED =
                     EvalOptions().validSchemas({
                         "aten::numel(Tensor self) -> int",
                     })})
399+ .evaluator({c10::Symbol::fromQualString (" aten::t" ),
400+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
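+                      // Transposes constant tensors at conversion time; runtime ITensor inputs are rejected below.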
+                      auto tensor_var = args.at(n->input(0));
+                      if (tensor_var.IValue()->isTensor()) {
+                        auto tensor = tensor_var.unwrapToTensor();
+                        return tensor.t();
+                      } else {
+                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
+                        return {};
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::t(Tensor self) -> Tensor",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
@@ -467,7 +482,65 @@ auto aten_registrations TRTORCH_UNUSED =
                       LOG_WARNING("Warning from TorchScript: " << *warning);
                       return {};
                     },
-                    EvalOptions()});
+                    EvalOptions()})
+        .evaluator({c10::Symbol::fromQualString("aten::arange"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      int input_size = n->inputs().size();
+                      int scalar_count = 0;
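+                      // Count the Scalar inputs; the trailing dtype/layout/device/pin_memory args arrive as None.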
+                      for (int i = 0; i < input_size; i++) {
+                        if (args.at(n->input(i)).IValue()->isScalar()) {
+                          scalar_count += 1;
+                        }
+                      }
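+                      // Dispatch on arity: 1 scalar -> arange(end), 2 -> arange(start, end), 3 -> arange(start, end, step).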
+                      if (scalar_count == 1) {
+                        if (args.at(n->input(0)).IValue()->isInt()) {
+                          int end_scalar = args.at(n->input(0)).unwrapToInt();
+                          return torch::arange(end_scalar);
+                        } else if (args.at(n->input(0)).IValue()->isDouble()) {
+                          float end_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
+                          return torch::arange(end_scalar);
+                        }
+                      } else if (scalar_count == 2) {
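+                        // Promote to float if either bound is a double, preserving arange's dtype inference.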
+                        if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble()) {
+                          float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
+                          float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
+                          return torch::arange(start_scalar, end_scalar);
+                        } else {
+                          int start_scalar = args.at(n->input(0)).unwrapToInt();
+                          int end_scalar = args.at(n->input(1)).unwrapToInt();
+                          return torch::arange(start_scalar, end_scalar);
+                        }
+                      } else if (scalar_count == 3) {
+                        if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble() ||
+                            args.at(n->input(2)).IValue()->isDouble()) {
+                          float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
+                          float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
+                          float step_scalar = args.at(n->input(2)).unwrapToScalar().to<float>();
+                          return torch::arange(start_scalar, end_scalar, step_scalar);
+                        } else {
+                          int start_scalar = args.at(n->input(0)).unwrapToInt();
+                          int end_scalar = args.at(n->input(1)).unwrapToInt();
+                          int step_scalar = args.at(n->input(2)).unwrapToInt();
+                          return torch::arange(start_scalar, end_scalar, step_scalar);
+                        }
+                      } else {
+                        TRTORCH_THROW_ERROR(
+                            "Invalid input argument size for aten::arange, input argument size: " << input_size);
+                      }
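+                      // A single scalar of an unexpected type falls through; nullopt reports that nothing was evaluated.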
+                      return {};
+                    },
+                    EvalOptions().validSchemas({
+                        R"SIG(aten::arange(Scalar end, *, int? dtype=None, int? layout=None,
+                                Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
+                        R"SIG(aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None,
+                                Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
+                        R"SIG(aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None,
+                                Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
+                    })});
 } // namespace
 } // namespace evaluators
 } // namespace conversion
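
A quick way to see the dtype behavior the new evaluators reproduce (a minimal libtorch sketch, not part of this commit; assumes a standard libtorch install):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Integer bounds infer a Long tensor, matching the evaluator's unwrapToInt() path.
  auto a = torch::arange(5);
  // Any double bound infers a floating-point tensor, matching the unwrapToScalar().to<float>() path.
  auto b = torch::arange(0.0, 2.5, 0.5);
  // aten::t on a constant: t() swaps the two dimensions of a 2-D tensor.
  auto c = torch::randn({2, 3}).t(); // shape [3, 2]
  std::cout << a.scalar_type() << " " << b.scalar_type() << "\n"; // prints: Long Float
  return 0;
}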