From d239486ef178a541636ce06b7ca9f05dcb620b4b Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 27 Nov 2023 09:31:01 -0800 Subject: [PATCH 1/2] Implements `nin` and `nout` for element-wise funcs `nin` and `nout` properties return the number of arguments to the function treated as inputs or outputs, respectively --- dpctl/tensor/_elementwise_common.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 4a0d1c451f..7064517dce 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -93,6 +93,20 @@ def get_type_result_resolver_function(self): """ return self.result_type_resolver_fn_ + @property + def nin(self): + """ + Returns the number of arguments treated as inputs. + """ + return 1 + + @property + def nout(self): + """ + Returns the number of arguments treated as outputs. + """ + return 1 + @property def types(self): """Returns information about types supported by @@ -531,6 +545,20 @@ def get_type_promotion_path_acceptance_function(self): """ return self.acceptance_fn_ + @property + def nin(self): + """ + Returns the number of arguments treated as inputs. + """ + return 2 + + @property + def nout(self): + """ + Returns the number of arguments treated as outputs. + """ + return 1 + @property def types(self): """Returns information about types supported by From ad258b82aa9ca0b04d2aaab6eb8768d6f9be8eb0 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 27 Nov 2023 09:41:51 -0800 Subject: [PATCH 2/2] Adds tests for `nin` and `nout` properties --- .../elementwise/test_elementwise_classes.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dpctl/tests/elementwise/test_elementwise_classes.py b/dpctl/tests/elementwise/test_elementwise_classes.py index b7f1d26d6e..634b1fbdea 100644 --- a/dpctl/tests/elementwise/test_elementwise_classes.py +++ b/dpctl/tests/elementwise/test_elementwise_classes.py @@ -78,3 +78,27 @@ def test_binary_class_str_repr(): kl_n = binary_fn.__name__ assert kl_n in s assert kl_n in r + + +def test_unary_class_nin(): + nin = unary_fn.nin + assert isinstance(nin, int) + assert nin == 1 + + +def test_binary_class_nin(): + nin = binary_fn.nin + assert isinstance(nin, int) + assert nin == 2 + + +def test_unary_class_nout(): + nout = unary_fn.nout + assert isinstance(nout, int) + assert nout == 1 + + +def test_binary_class_nout(): + nout = binary_fn.nout + assert isinstance(nout, int) + assert nout == 1