Skip to content

Commit 12d6f79

Browse files
smessmer authored and facebook-github-bot committed
Optional inputs and outputs (pytorch#19289)
Summary: Pull Request resolved: pytorch#19289 Allow optional inputs and outputs in native c10 operators Reviewed By: dzhulgakov Differential Revision: D14931927 fbshipit-source-id: 48f8bec009c6374345b34d933f148c08bb4f7118
1 parent fa96de2 commit 12d6f79

File tree

6 files changed

+535
-2
lines changed

6 files changed

+535
-2
lines changed

aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,104 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenFallbackKernelWith
456456
EXPECT_EQ(4, outputs[0].toInt());
457457
}
458458

459+
c10::optional<Tensor> called_arg2;
460+
c10::optional<int64_t> called_arg3;
461+
c10::optional<std::string> called_arg4;
462+
463+
void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
464+
called = true;
465+
called_arg2 = arg2;
466+
called_arg3 = arg3;
467+
called_arg4 = arg4;
468+
}
469+
470+
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
471+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", &kernelWithOptInputWithoutOutput);
472+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
473+
ASSERT_TRUE(op.has_value());
474+
475+
called = false;
476+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
477+
EXPECT_EQ(0, outputs.size());
478+
479+
EXPECT_TRUE(called);
480+
EXPECT_TRUE(called_arg2.has_value());
481+
EXPECT_EQ(called_arg2->type_id(), TensorType2());
482+
EXPECT_FALSE(called_arg3.has_value());
483+
EXPECT_TRUE(called_arg4.has_value());
484+
EXPECT_EQ(*called_arg4, "text");
485+
486+
called = false;
487+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
488+
EXPECT_EQ(0, outputs.size());
489+
490+
EXPECT_TRUE(called);
491+
EXPECT_FALSE(called_arg2.has_value());
492+
EXPECT_TRUE(called_arg3.has_value());
493+
EXPECT_EQ(*called_arg3, 4);
494+
EXPECT_FALSE(called_arg4.has_value());
495+
}
496+
497+
c10::optional<Tensor> kernelWithOptInputWithOutput(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
498+
called = true;
499+
called_arg2 = arg2;
500+
called_arg3 = arg3;
501+
called_arg4 = arg4;
502+
return arg2;
503+
}
504+
505+
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) {
506+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", &kernelWithOptInputWithOutput);
507+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
508+
ASSERT_TRUE(op.has_value());
509+
510+
called = false;
511+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
512+
EXPECT_EQ(1, outputs.size());
513+
EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id());
514+
515+
EXPECT_TRUE(called);
516+
EXPECT_TRUE(called_arg2.has_value());
517+
EXPECT_EQ(called_arg2->type_id(), TensorType2());
518+
EXPECT_FALSE(called_arg3.has_value());
519+
EXPECT_TRUE(called_arg4.has_value());
520+
EXPECT_EQ(*called_arg4, "text");
521+
522+
called = false;
523+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
524+
EXPECT_EQ(1, outputs.size());
525+
EXPECT_TRUE(outputs[0].isNone());
526+
527+
EXPECT_TRUE(called);
528+
EXPECT_FALSE(called_arg2.has_value());
529+
EXPECT_TRUE(called_arg3.has_value());
530+
EXPECT_EQ(*called_arg3, 4);
531+
EXPECT_FALSE(called_arg4.has_value());
532+
}
533+
534+
std::tuple<c10::optional<Tensor>, c10::optional<int64_t>, c10::optional<std::string>>
535+
kernelWithOptInputWithMultipleOutputs(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
536+
return std::make_tuple(arg2, arg3, arg4);
537+
}
538+
539+
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) {
540+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", &kernelWithOptInputWithMultipleOutputs);
541+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
542+
ASSERT_TRUE(op.has_value());
543+
544+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
545+
EXPECT_EQ(3, outputs.size());
546+
EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id());
547+
EXPECT_TRUE(outputs[1].isNone());
548+
EXPECT_EQ("text", outputs[2].toString()->string());
549+
550+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
551+
EXPECT_EQ(3, outputs.size());
552+
EXPECT_TRUE(outputs[0].isNone());
553+
EXPECT_EQ(4, outputs[1].toInt());
554+
EXPECT_TRUE(outputs[2].isNone());
555+
}
556+
459557
std::tuple<int64_t, Tensor> kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
460558
return {};
461559
}

aten/src/ATen/core/op_registration/kernel_function_test.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,104 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenFallbackKernelWithoutTen
458458
EXPECT_EQ(4, outputs[0].toInt());
459459
}
460460

461+
c10::optional<Tensor> called_arg2;
462+
c10::optional<int64_t> called_arg3;
463+
c10::optional<std::string> called_arg4;
464+
465+
void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
466+
called = true;
467+
called_arg2 = arg2;
468+
called_arg3 = arg3;
469+
called_arg4 = arg4;
470+
}
471+
472+
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
473+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", kernel<decltype(kernelWithOptInputWithoutOutput), &kernelWithOptInputWithoutOutput>(), dispatchKey(TensorType1()));
474+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
475+
ASSERT_TRUE(op.has_value());
476+
477+
called = false;
478+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
479+
EXPECT_EQ(0, outputs.size());
480+
481+
EXPECT_TRUE(called);
482+
EXPECT_TRUE(called_arg2.has_value());
483+
EXPECT_EQ(called_arg2->type_id(), TensorType2());
484+
EXPECT_FALSE(called_arg3.has_value());
485+
EXPECT_TRUE(called_arg4.has_value());
486+
EXPECT_EQ(*called_arg4, "text");
487+
488+
called = false;
489+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
490+
EXPECT_EQ(0, outputs.size());
491+
492+
EXPECT_TRUE(called);
493+
EXPECT_FALSE(called_arg2.has_value());
494+
EXPECT_TRUE(called_arg3.has_value());
495+
EXPECT_EQ(*called_arg3, 4);
496+
EXPECT_FALSE(called_arg4.has_value());
497+
}
498+
499+
c10::optional<Tensor> kernelWithOptInputWithOutput(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
500+
called = true;
501+
called_arg2 = arg2;
502+
called_arg3 = arg3;
503+
called_arg4 = arg4;
504+
return arg2;
505+
}
506+
507+
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) {
508+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", kernel<decltype(kernelWithOptInputWithOutput), &kernelWithOptInputWithOutput>(), dispatchKey(TensorType1()));
509+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
510+
ASSERT_TRUE(op.has_value());
511+
512+
called = false;
513+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
514+
EXPECT_EQ(1, outputs.size());
515+
EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id());
516+
517+
EXPECT_TRUE(called);
518+
EXPECT_TRUE(called_arg2.has_value());
519+
EXPECT_EQ(called_arg2->type_id(), TensorType2());
520+
EXPECT_FALSE(called_arg3.has_value());
521+
EXPECT_TRUE(called_arg4.has_value());
522+
EXPECT_EQ(*called_arg4, "text");
523+
524+
called = false;
525+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
526+
EXPECT_EQ(1, outputs.size());
527+
EXPECT_TRUE(outputs[0].isNone());
528+
529+
EXPECT_TRUE(called);
530+
EXPECT_FALSE(called_arg2.has_value());
531+
EXPECT_TRUE(called_arg3.has_value());
532+
EXPECT_EQ(*called_arg3, 4);
533+
EXPECT_FALSE(called_arg4.has_value());
534+
}
535+
536+
std::tuple<c10::optional<Tensor>, c10::optional<int64_t>, c10::optional<std::string>>
537+
kernelWithOptInputWithMultipleOutputs(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
538+
return std::make_tuple(arg2, arg3, arg4);
539+
}
540+
541+
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) {
542+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", kernel<decltype(kernelWithOptInputWithMultipleOutputs), &kernelWithOptInputWithMultipleOutputs>(), dispatchKey(TensorType1()));
543+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
544+
ASSERT_TRUE(op.has_value());
545+
546+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
547+
EXPECT_EQ(3, outputs.size());
548+
EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id());
549+
EXPECT_TRUE(outputs[1].isNone());
550+
EXPECT_EQ("text", outputs[2].toString()->string());
551+
552+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
553+
EXPECT_EQ(3, outputs.size());
554+
EXPECT_TRUE(outputs[0].isNone());
555+
EXPECT_EQ(4, outputs[1].toInt());
556+
EXPECT_TRUE(outputs[2].isNone());
557+
}
558+
461559
std::tuple<int64_t, Tensor> kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
462560
return {};
463561
}

aten/src/ATen/core/op_registration/kernel_functor.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace detail {
2929
// cast it to the type that should be passed to the kernel function.
3030
// Examples: If the IValue contains a plain type like an int, return that.
3131
// If the IValue contains an IntList, return it as ArrayRef<int>.
32+
// TODO Should we move the IValue so we can avoid bumping the Tensor refcount?
3233
template<class T>
3334
struct ivalue_to_arg_type {
3435
static T call(const IValue& v) {
@@ -41,10 +42,34 @@ namespace detail {
4142
return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
4243
}
4344
};
45+
template<class T>
46+
struct ivalue_to_arg_type<optional<T>> {
47+
static optional<T> call(const IValue& v) {
48+
if (v.isNone()) {
49+
return nullopt;
50+
}
51+
return v.to<T>();
52+
}
53+
};
4454

4555
template<class T>
46-
IValue return_type_to_ivalue(T&& t) {
47-
return IValue(std::forward<T>(t));
56+
struct return_type_to_ivalue_ {
57+
static IValue call(T&& v) {
58+
return IValue(std::move(v));
59+
}
60+
};
61+
template<class T>
62+
struct return_type_to_ivalue_<optional<T>> {
63+
static IValue call(optional<T>&& v) {
64+
if (!v.has_value()) {
65+
return IValue();
66+
}
67+
return IValue(std::move(*v));
68+
}
69+
};
70+
template<class T>
71+
IValue return_type_to_ivalue(T&& v) {
72+
return return_type_to_ivalue_<guts::decay_t<T>>::call(std::move(v));
4873
}
4974

5075
template<class Functor, size_t... ivalue_arg_indices>

aten/src/ATen/core/op_registration/kernel_functor_test.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,110 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenFallbackKernelWithoutTens
601601
EXPECT_EQ(4, outputs[0].toInt());
602602
}
603603

604+
c10::optional<Tensor> called_arg2;
605+
c10::optional<int64_t> called_arg3;
606+
c10::optional<std::string> called_arg4;
607+
608+
struct KernelWithOptInputWithoutOutput final : OperatorKernel {
609+
void operator()(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
610+
called = true;
611+
called_arg2 = arg2;
612+
called_arg3 = arg3;
613+
called_arg4 = arg4;
614+
}
615+
};
616+
617+
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
618+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", kernel<KernelWithOptInputWithoutOutput>(), dispatchKey(TensorType1()));
619+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
620+
ASSERT_TRUE(op.has_value());
621+
622+
called = false;
623+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
624+
EXPECT_EQ(0, outputs.size());
625+
626+
EXPECT_TRUE(called);
627+
EXPECT_TRUE(called_arg2.has_value());
628+
EXPECT_EQ(called_arg2->type_id(), TensorType2());
629+
EXPECT_FALSE(called_arg3.has_value());
630+
EXPECT_TRUE(called_arg4.has_value());
631+
EXPECT_EQ(*called_arg4, "text");
632+
633+
called = false;
634+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
635+
EXPECT_EQ(0, outputs.size());
636+
637+
EXPECT_TRUE(called);
638+
EXPECT_FALSE(called_arg2.has_value());
639+
EXPECT_TRUE(called_arg3.has_value());
640+
EXPECT_EQ(*called_arg3, 4);
641+
EXPECT_FALSE(called_arg4.has_value());
642+
}
643+
644+
struct KernelWithOptInputWithOutput final : OperatorKernel {
645+
c10::optional<Tensor> operator()(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
646+
called = true;
647+
called_arg2 = arg2;
648+
called_arg3 = arg3;
649+
called_arg4 = arg4;
650+
return arg2;
651+
}
652+
};
653+
654+
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) {
655+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", kernel<KernelWithOptInputWithOutput>(), dispatchKey(TensorType1()));
656+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
657+
ASSERT_TRUE(op.has_value());
658+
659+
called = false;
660+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
661+
EXPECT_EQ(1, outputs.size());
662+
EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id());
663+
664+
EXPECT_TRUE(called);
665+
EXPECT_TRUE(called_arg2.has_value());
666+
EXPECT_EQ(called_arg2->type_id(), TensorType2());
667+
EXPECT_FALSE(called_arg3.has_value());
668+
EXPECT_TRUE(called_arg4.has_value());
669+
EXPECT_EQ(*called_arg4, "text");
670+
671+
called = false;
672+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
673+
EXPECT_EQ(1, outputs.size());
674+
EXPECT_TRUE(outputs[0].isNone());
675+
676+
EXPECT_TRUE(called);
677+
EXPECT_FALSE(called_arg2.has_value());
678+
EXPECT_TRUE(called_arg3.has_value());
679+
EXPECT_EQ(*called_arg3, 4);
680+
EXPECT_FALSE(called_arg4.has_value());
681+
}
682+
683+
struct KernelWithOptInputWithMultipleOutputs final : OperatorKernel {
684+
std::tuple<c10::optional<Tensor>, c10::optional<int64_t>, c10::optional<std::string>>
685+
operator()(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
686+
return std::make_tuple(arg2, arg3, arg4);
687+
}
688+
};
689+
690+
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) {
691+
auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", kernel<KernelWithOptInputWithMultipleOutputs>(), dispatchKey(TensorType1()));
692+
auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", "");
693+
ASSERT_TRUE(op.has_value());
694+
695+
auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text"));
696+
EXPECT_EQ(3, outputs.size());
697+
EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id());
698+
EXPECT_TRUE(outputs[1].isNone());
699+
EXPECT_EQ("text", outputs[2].toString()->string());
700+
701+
outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue());
702+
EXPECT_EQ(3, outputs.size());
703+
EXPECT_TRUE(outputs[0].isNone());
704+
EXPECT_EQ(4, outputs[1].toInt());
705+
EXPECT_TRUE(outputs[2].isNone());
706+
}
707+
604708
struct KernelForSchemaInference final : OperatorKernel {
605709
std::tuple<int64_t, Tensor> operator()(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
606710
return {};

0 commit comments

Comments
 (0)