Skip to content
This repository was archived by the owner on Jan 7, 2023. It is now read-only.

Commit 6e5bf38

Browse files
committed
stretch: make fields argument optional, and stretch all fields if fields=None
1 parent 4b1683e commit 6e5bf38

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

root_numpy/_utils.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def stack(recs, fields=None):
7272
return np.hstack([rec[fields] for rec in recs])
7373

7474

75-
def stretch(arr, fields):
75+
def stretch(arr, fields=None):
7676
"""Stretch an array.
7777
7878
Stretch an array by ``hstack()``-ing multiple array fields while
@@ -83,8 +83,8 @@ def stretch(arr, fields):
8383
----------
8484
arr : NumPy structured or record array
8585
The array to be stretched.
86-
fields : list of strings
87-
A list of column names to stretch.
86+
fields : list of strings, optional (default=None)
87+
A list of column names to stretch. If None, then stretch all fields.
8888
8989
Returns
9090
-------
@@ -105,19 +105,21 @@ def stretch(arr, fields):
105105
"""
106106
dt = []
107107
has_array_field = False
108-
has_scalar_filed = False
109108
first_array = None
110109

110+
if fields is None:
111+
fields = arr.dtype.names
112+
111113
# Construct dtype
112-
for c in fields:
113-
if _is_object_field(arr, c):
114-
dt.append((c, arr[c][0].dtype))
114+
for field in fields:
115+
if _is_object_field(arr, field):
116+
dt.append((field, arr[field][0].dtype))
115117
has_array_field = True
116-
first_array = c if first_array is None else first_array
118+
if first_array is None:
119+
first_array = field
117120
else:
118121
# Assume scalar
119-
dt.append((c, arr[c].dtype))
120-
has_scalar_filed = True
122+
dt.append((field, arr[field].dtype))
121123

122124
if not has_array_field:
123125
raise RuntimeError("No array column specified")
@@ -126,21 +128,21 @@ def stretch(arr, fields):
126128
numrec = np.sum(len_array)
127129
ret = np.empty(numrec, dtype=dt)
128130

129-
for c in fields:
130-
if _is_object_field(arr, c):
131+
for field in fields:
132+
if _is_object_field(arr, field):
131133
# FIXME: this is rather inefficient since the stack
132134
# is copied over to the return value
133-
stack = np.hstack(arr[c])
135+
stack = np.hstack(arr[field])
134136
if len(stack) != numrec:
135137
raise ValueError(
136138
"Array lengths do not match: "
137139
"expected %d but found %d in %s" %
138-
(numrec, len(stack), c))
139-
ret[c] = stack
140+
(numrec, len(stack), field))
141+
ret[field] = stack
140142
else:
141143
# FIXME: this is rather inefficient since the repeat result
142144
# is copied over to the return value
143-
ret[c] = np.repeat(arr[c], len_array)
145+
ret[field] = np.repeat(arr[field], len_array)
144146

145147
return ret
146148

0 commit comments

Comments
 (0)