diff --git a/onnxscript/values.py b/onnxscript/values.py index 472dab2898..ed919c99df 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -142,6 +142,91 @@ class OnnxClosure: function: Any +UserModeValue = Any +EagerModeValue = Any +ExtendedModeValue = Any + +# UserModeValue = Union[Optional[np.ndarray], List["UserModeValue"], Tuple["UserModeValue", ...]] + +# EagerModeValue = Union[ +# Optional["tensor.Tensor"], List["EagerModeValue"], Tuple["EagerModeValue", ...] +# ] + +# ExtendedModeValue = Union[ +# Optional["tensor.Tensor"], +# List["ExtendedModeValue"], +# Tuple["ExtendedModeValue", ...], +# np.ndarray, +# int, +# float, +# bool, +# ] + + +def _adapt_to_eager_mode(inputs: ExtendedModeValue) -> EagerModeValue: + """Adapts inputs into representation used by onnxscript eager mode. + + This does the following transformations: + * It adds an onnxscript Tensor wrapper around numpy arrays, which + allows the use of overloaded operators like + to be controlled by onnxscript. + * It also provides a promotion of scalars into tensors as a convenience. + This is needed to complement the similar promotion supported by the + onnxscript converter (for example, when an attribute is promoted and used + as an input argument). + + Args: + inputs: a list/tuple of inputs to an ONNX function + + Returns: + a pair (wrapped_inputs, flag) where flag indicates whether any numpy array + was wrapped into a Tensor. + """ + has_array = False + + def adapt(input: ExtendedModeValue) -> EagerModeValue: + if isinstance(input, np.ndarray): + nonlocal has_array + has_array = True + return tensor.Tensor(input) + elif isinstance(input, tensor.Tensor): + return input + elif isinstance(input, (bool, int, float)): + return tensor.Tensor(np.array(input)) + elif input is None: + return None + elif isinstance(input, list): + return [adapt(elt) for elt in input] + elif isinstance(input, tuple): + return tuple(adapt(elt) for elt in input) + raise TypeError(f"Unexpected input type {type(input)}.") + + result = adapt(inputs) + return result, has_array + + +def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue: + """Unwraps Tensor wrapper around numpy arrays. + + Args: + output: output of an ONNX function, which can be either a single + onnx value or a list/tuple of onnx values. + + Returns: + unwrapped output + """ + if isinstance(output, tensor.Tensor): + return output.value + elif output is None: + return None + elif isinstance(output, list): + return [_adapt_to_user_mode(elt) for elt in output] + elif isinstance(output, tuple): + return tuple(_adapt_to_user_mode(elt) for elt in output) + elif isinstance(output, np.ndarray): + return output + raise TypeError(f"Unexpected type {type(output)}.") + + class OnnxFunction(Op): """Represents an ONNX op for which a function-body has been defined in onnxscript. @@ -185,75 +270,17 @@ def fun(*args, **kwargs): def __call__(self, *args, **kwargs): """Implements an eager-mode execution of an onnxscript function.""" - if len(args) == 0: - # Operator Constant, it is usually called within a function. - return self._libcall(**kwargs) - if isinstance(args[0], tensor.Tensor): - return self._libcall(*args, **kwargs) - return self._usercall(*args, **kwargs) - - def _usercall(self, *args, **kwargs): - """Eager mode""" - new_args = [] - for i, a in enumerate(args): - if isinstance(a, np.ndarray): - new_args.append(tensor.Tensor(a)) - elif isinstance(a, (bool, int, float)): - new_args.append(tensor.Tensor(np.array(a))) - else: - raise TypeError(f"Unexpected input type {type(a)} for an input {i}.") - res = self.function(*new_args, **kwargs) - if isinstance(res, np.ndarray): - return res - if isinstance(res, tensor.Tensor): - return res.value - if isinstance(res, (list, tuple)): - unwrapped = [] - for i, r in enumerate(res): - if isinstance(r, np.ndarray): - unwrapped.append(r) - elif isinstance(r, tensor.Tensor): - unwrapped.append(r.value) - else: - raise TypeError( - f"Unexpected output type {type(r)} for an output {i} " - f"in function {self.function!r}." - ) - if isinstance(res, tuple): - return tuple(unwrapped) - return unwrapped - raise TypeError(f"Unexpected output type {type(res)} in function {self.function!r}.") - - def _libcall(self, *args, **kwargs): - """This method must be called when a function decoracted with `script` - calls another one decorated with `script`. - """ - new_args = [] - for i, a in enumerate(args): - if isinstance(a, tensor.Tensor): - new_args.append(a) - elif isinstance(a, bool): - # TODO: default values for function parameters - # are not properly handled yet. This section - # should disappear. - new_args.append(tensor.Tensor(np.array(a))) - else: - raise TypeError(f"Unexpected input type {type(a)} for an input {i}.") - res = self.function(*new_args, **kwargs) - if isinstance(res, tensor.Tensor): - return res - if isinstance(res, tuple): - unwrapped = [] - for i, r in enumerate(res): - if isinstance(r, tensor.Tensor): - unwrapped.append(r) - else: - raise TypeError( - f"Unexpected output type {type(r)} for an output {i} " - f"in function {self.function!r}." - ) - return tuple(unwrapped) - raise TypeError(f"Unexpected output type {type(res)} in function {self.function!r}.") + new_args, has_array = _adapt_to_eager_mode(args) + result = self.function(*new_args, **kwargs) + + # We use a heuristic to decide whether to return output values as + # numpy arrays or tensor.Tensors. If the function has at least one + # numpy array as input, we return numpy arrays. Otherwise, we return + # tensor.Tensors. We could use a user-specified flag to control this + # or explicitly track whether this is a top-level function-call or + # a nested function-call. + + return _adapt_to_user_mode(result) if has_array else result def to_function_proto(self, domain=None): """Converts the function into :class:`onnx.FunctionProto`."""