Skip to content

Commit 3e8d865

Browse files
authored
Rename EagerArray to Tensor (#121)
Tensor seems a more appropriate name.
1 parent 4ea094d commit 3e8d865

File tree

6 files changed

+36
-36
lines changed

6 files changed

+36
-36
lines changed

onnxscript/autocast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from onnx.defs import OpSchema
2-
from .eager_array import EagerArray
2+
from .tensor import Tensor
33
import numpy as np
44

55

@@ -54,7 +54,7 @@ def cast_inputs(get_type_info, cast, opschema, *args):
5454
def dynamic_cast_inputs(opschema, *args):
5555
'''Used for autocast during eager-mode execution.'''
5656
def get_type_info(x):
57-
return x.dtype if isinstance(x, EagerArray) else None
57+
return x.dtype if isinstance(x, Tensor) else None
5858

5959
def cast(x, typeinfo):
6060
if isinstance(x, (int, float)):
@@ -65,7 +65,7 @@ def cast(x, typeinfo):
6565
dtype = np.int32
6666
else: # isinstance(x, float):
6767
dtype = np.float32
68-
return EagerArray(np.array(x, dtype=dtype))
68+
return Tensor(np.array(x, dtype=dtype))
6969
else:
7070
return x
7171

onnxscript/eager_mode_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from onnxruntime.capi.onnxruntime_pybind11_state import Fail, InvalidGraph, InvalidArgument
1212
from .utils import values_to_value_infos, proto2text
1313
from .irbuilder import select_ir_version
14-
from .eager_array import EagerArray
14+
from .tensor import Tensor
1515

1616

1717
class EagerModeError(RuntimeError):
@@ -61,7 +61,7 @@ def os_to_ort_value(v):
6161
'''
6262
Converts an onnxscript encoding of an ONNX value into the encoding used by ORT.
6363
'''
64-
if isinstance(v, EagerArray):
64+
if isinstance(v, Tensor):
6565
return v.value
6666
elif isinstance(v, list):
6767
return v
@@ -78,7 +78,7 @@ def ort_to_os_value(v):
7878
Converts an ORT encoding of an ONNX value into the encoding used by onnxscript.
7979
'''
8080
if isinstance(v, np.ndarray):
81-
return EagerArray(v)
81+
return Tensor(v)
8282
elif isinstance(v, list):
8383
return v
8484
elif v is None:

onnxscript/onnx_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import onnx.helper
88

99

10-
class Tensor:
10+
class TensorType:
1111
# Reference implementation placeholder
1212
# represents a generic ONNX tensor type
1313
def __init__(self, dtype=onnx.TensorProto.UNDEFINED, shape=None) -> None:
@@ -61,7 +61,7 @@ def mk_dim(dim):
6161
s = None
6262
else:
6363
s = [shape]
64-
return Tensor(self.dtype, s)
64+
return TensorType(self.dtype, s)
6565

6666
def to_type_proto(self):
6767
return onnx.helper.make_tensor_type_proto(self.dtype, ())

onnxscript/eager_array.py renamed to onnxscript/tensor.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,30 @@
77
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
88

99

10-
class EagerArray:
10+
class Tensor:
1111
"""
12-
Wraps arrays to intercept calls to operators and use onnxruntime
13-
to process the output.
12+
An implementation of ONNX Tensors, based on a wrapper around numpy arrays.
13+
Serves to define overloaded ops with an ONNX/ONNXScript semantics.
1414
"""
1515

16-
def __init__(self, tensor, opset=None):
17-
if not isinstance(tensor, np.ndarray):
18-
raise TypeError(f"Unexpected type {type(tensor)}. It must be a numpy array.")
19-
self._tensor = tensor
16+
def __init__(self, nparray, opset=None):
17+
if not isinstance(nparray, np.ndarray):
18+
raise TypeError(f"Unexpected type {type(nparray)}. It must be a numpy array.")
19+
self._nparray = nparray
2020
from onnxscript.onnx_opset import default_opset
2121
self._opset = opset or default_opset
2222

2323
@property
2424
def value(self):
25-
return self._tensor
25+
return self._nparray
2626

2727
@property
2828
def shape(self):
29-
return self._tensor.shape
29+
return self._nparray.shape
3030

3131
@property
3232
def dtype(self):
33-
return self._tensor.dtype
33+
return self._nparray.dtype
3434

3535
@property
3636
def onnx_dtype(self):
@@ -52,7 +52,7 @@ def __getitem__(self, index):
5252
if isinstance(index, int):
5353
# case A[i]: indexing
5454
# promote integer input to tensor
55-
i = EagerArray(np.array(index))
55+
i = Tensor(np.array(index))
5656
# use Gather to perform indexing
5757
return op.Gather(self, i, axis=0)
5858
if not isinstance(index, (slice, tuple)):
@@ -78,13 +78,13 @@ def __getitem__(self, index):
7878
else:
7979
raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")
8080
indices = np.array(indices, dtype=np.int64).T
81-
starts = EagerArray(indices[0])
82-
ends = EagerArray(indices[1])
83-
axis = EagerArray(indices[2])
84-
steps = EagerArray(indices[3])
81+
starts = Tensor(indices[0])
82+
ends = Tensor(indices[1])
83+
axis = Tensor(indices[2])
84+
steps = Tensor(indices[3])
8585
result = op.Slice(self, starts, ends, axis, steps)
8686
if len(to_squeeze) > 0:
87-
result = EagerArray(np.squeeze(result.value, axis=tuple(to_squeeze)))
87+
result = Tensor(np.squeeze(result.value, axis=tuple(to_squeeze)))
8888
return result
8989

9090
def __mod__(self, other):

onnxscript/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import onnx
1010
from onnx import TensorProto, ValueInfoProto, ModelProto, FunctionProto
1111
from onnx.helper import make_tensor_type_proto, make_sequence_type_proto
12-
from .eager_array import EagerArray
12+
from .tensor import Tensor
1313

1414
# print utility unavailable in ONNX 1.12 or earlier:
1515
try:
@@ -23,7 +23,7 @@ def value_to_type_proto(val):
2323
'''
2424
Return the ONNX type of a python-value.
2525
'''
26-
if isinstance(val, (np.ndarray, EagerArray)):
26+
if isinstance(val, (np.ndarray, Tensor)):
2727
elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val.dtype]
2828
shape = val.shape
2929
return make_tensor_type_proto(elem_type, shape)

onnxscript/values.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from enum import IntFlag
1111
import numpy as np
1212
import onnx
13-
from .eager_array import EagerArray
13+
from .tensor import Tensor
1414
from .autocast import dynamic_cast_inputs
1515

1616

@@ -174,7 +174,7 @@ def __call__(self, *args, **kwargs):
174174
if len(args) == 0:
175175
# Operator Constant, it is usually called within a function.
176176
return self._libcall(**kwargs)
177-
if isinstance(args[0], EagerArray):
177+
if isinstance(args[0], Tensor):
178178
return self._libcall(*args, **kwargs)
179179
return self._usercall(*args, **kwargs)
180180

@@ -183,21 +183,21 @@ def _usercall(self, *args, **kwargs):
183183
new_args = []
184184
for i, a in enumerate(args):
185185
if isinstance(a, np.ndarray):
186-
new_args.append(EagerArray(a))
186+
new_args.append(Tensor(a))
187187
elif isinstance(a, bool):
188-
new_args.append(EagerArray(np.array(a)))
188+
new_args.append(Tensor(np.array(a)))
189189
else:
190190
raise TypeError(
191191
f"Unexpected input type {type(a)} for an input {i}.")
192192
res = self.function(*new_args, **kwargs)
193193
if isinstance(res, np.ndarray):
194194
return res
195-
if isinstance(res, EagerArray):
195+
if isinstance(res, Tensor):
196196
return res.value
197197
if isinstance(res, (list, tuple)):
198198
unwrapped = []
199199
for i, r in enumerate(res):
200-
if isinstance(r, EagerArray):
200+
if isinstance(r, Tensor):
201201
unwrapped.append(r.value)
202202
else:
203203
raise TypeError(
@@ -216,23 +216,23 @@ def _libcall(self, *args, **kwargs):
216216
"""
217217
new_args = []
218218
for i, a in enumerate(args):
219-
if isinstance(a, EagerArray):
219+
if isinstance(a, Tensor):
220220
new_args.append(a)
221221
elif isinstance(a, bool):
222222
# TODO: default values for function parameters
223223
# are not properly handled yet. This section
224224
# should disappear.
225-
new_args.append(EagerArray(np.array(a)))
225+
new_args.append(Tensor(np.array(a)))
226226
else:
227227
raise TypeError(
228228
f"Unexpected input type {type(a)} for an input {i}.")
229229
res = self.function(*new_args, **kwargs)
230-
if isinstance(res, EagerArray):
230+
if isinstance(res, Tensor):
231231
return res
232232
if isinstance(res, tuple):
233233
unwrapped = []
234234
for i, r in enumerate(res):
235-
if isinstance(r, EagerArray):
235+
if isinstance(r, Tensor):
236236
unwrapped.append(r)
237237
else:
238238
raise TypeError(

0 commit comments

Comments
 (0)