diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 4a0d1c451f..7064517dce 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -93,6 +93,20 @@ def get_type_result_resolver_function(self): """ return self.result_type_resolver_fn_ + @property + def nin(self): + """ + Returns the number of arguments treated as inputs. + """ + return 1 + + @property + def nout(self): + """ + Returns the number of arguments treated as outputs. + """ + return 1 + @property def types(self): """Returns information about types supported by @@ -531,6 +545,20 @@ def get_type_promotion_path_acceptance_function(self): """ return self.acceptance_fn_ + @property + def nin(self): + """ + Returns the number of arguments treated as inputs. + """ + return 2 + + @property + def nout(self): + """ + Returns the number of arguments treated as outputs. + """ + return 1 + @property def types(self): """Returns information about types supported by diff --git a/dpctl/tests/elementwise/test_elementwise_classes.py b/dpctl/tests/elementwise/test_elementwise_classes.py index b7f1d26d6e..634b1fbdea 100644 --- a/dpctl/tests/elementwise/test_elementwise_classes.py +++ b/dpctl/tests/elementwise/test_elementwise_classes.py @@ -78,3 +78,27 @@ def test_binary_class_str_repr(): kl_n = binary_fn.__name__ assert kl_n in s assert kl_n in r + + +def test_unary_class_nin(): + nin = unary_fn.nin + assert isinstance(nin, int) + assert nin == 1 + + +def test_binary_class_nin(): + nin = binary_fn.nin + assert isinstance(nin, int) + assert nin == 2 + + +def test_unary_class_nout(): + nout = unary_fn.nout + assert isinstance(nout, int) + assert nout == 1 + + +def test_binary_class_nout(): + nout = binary_fn.nout + assert isinstance(nout, int) + assert nout == 1