Skip to content

Cleanup _libcall and _usercall distinction #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 22, 2022
Merged
165 changes: 96 additions & 69 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,91 @@ class OnnxClosure:
function: Any


UserModeValue = Any
EagerModeValue = Any
ExtendedModeValue = Any

# UserModeValue = Union[Optional[np.ndarray], List["UserModeValue"], Tuple["UserModeValue", ...]]

# EagerModeValue = Union[
# Optional["tensor.Tensor"], List["EagerModeValue"], Tuple["EagerModeValue", ...]
# ]

# ExtendedModeValue = Union[
# Optional["tensor.Tensor"],
# List["ExtendedModeValue"],
# Tuple["ExtendedModeValue", ...],
# np.ndarray,
# int,
# float,
# bool,
# ]


def _adapt_to_eager_mode(inputs: ExtendedModeValue) -> EagerModeValue:
"""Adapts inputs into representation used by onnxscript eager mode.

This does the following transformations:
* It adds an onnxscript Tensor wrapper around numpy arrays, which
allows the use of overloaded operators like + to be controlled by onnxscript.
* It also provides a promotion of scalars into tensors as a convenience.
This is needed to complement the similar promotion supported by the
onnxscript converter (for example, when an attribute is promoted and used
as an input argument).

Args:
inputs: a list/tuple of inputs to an ONNX function

Returns:
a pair (wrapped_inputs, flag) where flag indicates whether any numpy array
was wrapped into a Tensor.
"""
has_array = False

def adapt(input: ExtendedModeValue) -> EagerModeValue:
if isinstance(input, np.ndarray):
nonlocal has_array
has_array = True
return tensor.Tensor(input)
elif isinstance(input, tensor.Tensor):
return input
elif isinstance(input, (bool, int, float)):
return tensor.Tensor(np.array(input))
elif input is None:
return None
elif isinstance(input, list):
return [adapt(elt) for elt in input]
elif isinstance(input, tuple):
return tuple(adapt(elt) for elt in input)
raise TypeError(f"Unexpected input type {type(input)}.")

result = adapt(inputs)
return result, has_array
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional but I think some tests will be helpful



def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue:
"""Unwraps Tensor wrapper around numpy arrays.

Args:
output: output of an ONNX function, which can be either a single
onnx value or a list/tuple of onnx values.

Returns:
unwrapped output
"""
if isinstance(output, tensor.Tensor):
return output.value
elif output is None:
return None
elif isinstance(output, list):
return [_adapt_to_user_mode(elt) for elt in output]
elif isinstance(output, tuple):
return tuple(_adapt_to_user_mode(elt) for elt in output)
elif isinstance(output, np.ndarray):
return output
raise TypeError(f"Unexpected type {type(output)}.")


class OnnxFunction(Op):
"""Represents an ONNX op for which a function-body has been defined in onnxscript.

Expand Down Expand Up @@ -185,75 +270,17 @@ def fun(*args, **kwargs):

def __call__(self, *args, **kwargs):
"""Implements an eager-mode execution of an onnxscript function."""
if len(args) == 0:
# Operator Constant, it is usually called within a function.
return self._libcall(**kwargs)
if isinstance(args[0], tensor.Tensor):
return self._libcall(*args, **kwargs)
return self._usercall(*args, **kwargs)

def _usercall(self, *args, **kwargs):
"""Eager mode"""
new_args = []
for i, a in enumerate(args):
if isinstance(a, np.ndarray):
new_args.append(tensor.Tensor(a))
elif isinstance(a, (bool, int, float)):
new_args.append(tensor.Tensor(np.array(a)))
else:
raise TypeError(f"Unexpected input type {type(a)} for an input {i}.")
res = self.function(*new_args, **kwargs)
if isinstance(res, np.ndarray):
return res
if isinstance(res, tensor.Tensor):
return res.value
if isinstance(res, (list, tuple)):
unwrapped = []
for i, r in enumerate(res):
if isinstance(r, np.ndarray):
unwrapped.append(r)
elif isinstance(r, tensor.Tensor):
unwrapped.append(r.value)
else:
raise TypeError(
f"Unexpected output type {type(r)} for an output {i} "
f"in function {self.function!r}."
)
if isinstance(res, tuple):
return tuple(unwrapped)
return unwrapped
raise TypeError(f"Unexpected output type {type(res)} in function {self.function!r}.")

def _libcall(self, *args, **kwargs):
"""This method must be called when a function decoracted with `script`
calls another one decorated with `script`.
"""
new_args = []
for i, a in enumerate(args):
if isinstance(a, tensor.Tensor):
new_args.append(a)
elif isinstance(a, bool):
# TODO: default values for function parameters
# are not properly handled yet. This section
# should disappear.
new_args.append(tensor.Tensor(np.array(a)))
else:
raise TypeError(f"Unexpected input type {type(a)} for an input {i}.")
res = self.function(*new_args, **kwargs)
if isinstance(res, tensor.Tensor):
return res
if isinstance(res, tuple):
unwrapped = []
for i, r in enumerate(res):
if isinstance(r, tensor.Tensor):
unwrapped.append(r)
else:
raise TypeError(
f"Unexpected output type {type(r)} for an output {i} "
f"in function {self.function!r}."
)
return tuple(unwrapped)
raise TypeError(f"Unexpected output type {type(res)} in function {self.function!r}.")
new_args, has_array = _adapt_to_eager_mode(args)
result = self.function(*new_args, **kwargs)

# We use a heuristic to decide whether to return output values as
# numpy arrays or tensor.Tensors. If the function has at least one
# numpy array as input, we return numpy arrays. Otherwise, we return
# tensor.Tensors. We could use a user-specified flag to control this
# or explicitly track whether this is a top-level function-call or
# a nested function-call.

return _adapt_to_user_mode(result) if has_array else result

def to_function_proto(self, domain=None):
"""Converts the function into :class:`onnx.FunctionProto`."""
Expand Down