Skip to content
This repository was archived by the owner on Jan 7, 2023. It is now read-only.

Commit 4108861

Browse files
committed
Merge pull request #185 from ndawe/stretch
[MRG] stretch() improvements
2 parents c332a0f + cc2f178 commit 4108861

File tree

2 files changed

+87
-87
lines changed

2 files changed

+87
-87
lines changed

root_numpy/_utils.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
VLEN = np.vectorize(len)
1818

1919

20-
def _is_object_field(arr, col):
21-
return arr.dtype[col] == 'O'
22-
23-
2420
def rec2array(rec, fields=None):
2521
"""Convert a record array into a ndarray with a homogeneous data type.
2622
@@ -72,7 +68,7 @@ def stack(recs, fields=None):
7268
return np.hstack([rec[fields] for rec in recs])
7369

7470

75-
def stretch(arr, fields):
71+
def stretch(arr, fields=None):
7672
"""Stretch an array.
7773
7874
Stretch an array by ``hstack()``-ing multiple array fields while
@@ -83,8 +79,8 @@ def stretch(arr, fields):
8379
----------
8480
arr : NumPy structured or record array
8581
The array to be stretched.
86-
fields : list of strings
87-
A list of column names to stretch.
82+
fields : list of strings, optional (default=None)
83+
A list of column names to stretch. If None, then stretch all fields.
8884
8985
Returns
9086
-------
@@ -103,44 +99,51 @@ def stretch(arr, fields):
10399
dtype=[('scalar', '<i8'), ('array', '<f8')])
104100
105101
"""
106-
dt = []
107-
has_array_field = False
108-
has_scalar_filed = False
109-
first_array = None
110-
111-
# Construct dtype
112-
for c in fields:
113-
if _is_object_field(arr, c):
114-
dt.append((c, arr[c][0].dtype))
115-
has_array_field = True
116-
first_array = c if first_array is None else first_array
117-
else:
118-
# Assume scalar
119-
dt.append((c, arr[c].dtype))
120-
has_scalar_filed = True
121-
122-
if not has_array_field:
123-
raise RuntimeError("No array column specified")
124-
125-
len_array = VLEN(arr[first_array])
126-
numrec = np.sum(len_array)
127-
ret = np.empty(numrec, dtype=dt)
128-
129-
for c in fields:
130-
if _is_object_field(arr, c):
131-
# FIXME: this is rather inefficient since the stack
132-
# is copied over to the return value
133-
stack = np.hstack(arr[c])
134-
if len(stack) != numrec:
102+
dtype = []
103+
len_array = None
104+
105+
if fields is None:
106+
fields = arr.dtype.names
107+
108+
# Construct dtype and check consistency
109+
for field in fields:
110+
dt = arr.dtype[field]
111+
if dt == 'O' or len(dt.shape):
112+
if dt == 'O':
113+
# Variable-length array field
114+
lengths = VLEN(arr[field])
115+
else:
116+
lengths = np.repeat(dt.shape[0], arr.shape[0])
117+
# Fixed-length array field
118+
if len_array is None:
119+
len_array = lengths
120+
elif not np.array_equal(lengths, len_array):
135121
raise ValueError(
136-
"Array lengths do not match: "
137-
"expected %d but found %d in %s" %
138-
(numrec, len(stack), c))
139-
ret[c] = stack
122+
"inconsistent lengths of array columns in input")
123+
if dt == 'O':
124+
dtype.append((field, arr[field][0].dtype))
125+
else:
126+
dtype.append((field, arr[field].dtype, dt.shape[1:]))
127+
else:
128+
# Scalar field
129+
dtype.append((field, dt))
130+
131+
if len_array is None:
132+
raise RuntimeError("no array column in input")
133+
134+
# Build stretched output
135+
ret = np.empty(np.sum(len_array), dtype=dtype)
136+
for field in fields:
137+
dt = arr.dtype[field]
138+
if dt == 'O' or len(dt.shape) == 1:
139+
# Variable-length or 1D fixed-length array field
140+
ret[field] = np.hstack(arr[field])
141+
elif len(dt.shape):
142+
# Multidimensional fixed-length array field
143+
ret[field] = np.vstack(arr[field])
140144
else:
141-
# FIXME: this is rather inefficient since the repeat result
142-
# is copied over to the return value
143-
ret[c] = np.repeat(arr[c], len_array)
145+
# Scalar field
146+
ret[field] = np.repeat(arr[field], len_array)
144147

145148
return ret
146149

root_numpy/tests.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -574,56 +574,53 @@ def test_fill_graph():
574574

575575

576576
def test_stretch():
577-
nrec = 5
578-
arr = np.empty(nrec,
577+
arr = np.empty(5,
579578
dtype=[
580579
('scalar', np.int),
581-
('df1', 'O'),
582-
('df2', 'O'),
583-
('df3', 'O')])
584-
585-
for i in range(nrec):
586-
df1 = np.array(range(i + 1), dtype=np.float)
587-
df2 = np.array(range(i + 1), dtype=np.int) * 2
588-
df3 = np.array(range(i + 1), dtype=np.double) * 3
589-
arr[i] = (i, df1, df2, df3)
580+
('vl1', 'O'),
581+
('vl2', 'O'),
582+
('vl3', 'O'),
583+
('fl1', np.int, (2, 2)),
584+
('fl2', np.float, (2, 3)),
585+
('fl3', np.double, (3, 2))])
586+
587+
for i in range(arr.shape[0]):
588+
vl1 = np.array(range(i + 1), dtype=np.int)
589+
vl2 = np.array(range(i + 2), dtype=np.float) * 2
590+
vl3 = np.array(range(2), dtype=np.double) * 3
591+
fl1 = np.array(range(4), dtype=np.int).reshape((2, 2))
592+
fl2 = np.array(range(6), dtype=np.float).reshape((2, 3))
593+
fl3 = np.array(range(6), dtype=np.double).reshape((3, 2))
594+
arr[i] = (i, vl1, vl2, vl3, fl1, fl2, fl3)
595+
596+
# no array columns included
597+
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])
590598

591-
stretched = rnp.stretch(
592-
arr, ['scalar', 'df1', 'df2', 'df3'])
599+
# lengths don't match
600+
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'vl1', 'vl2',])
601+
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'fl1', 'fl3',])
602+
assert_raises(ValueError, rnp.stretch, arr)
593603

604+
# variable-length stretch
605+
stretched = rnp.stretch(arr, ['scalar', 'vl1',])
594606
assert_equal(stretched.dtype,
595-
[('scalar', np.int),
596-
('df1', np.float),
597-
('df2', np.int),
598-
('df3', np.double)])
599-
assert_equal(stretched.size, 15)
600-
601-
assert_almost_equal(stretched['df1'][14], 4.0)
602-
assert_almost_equal(stretched['df2'][14], 8)
603-
assert_almost_equal(stretched['df3'][14], 12.0)
604-
assert_almost_equal(stretched['scalar'][14], 4)
605-
assert_almost_equal(stretched['scalar'][13], 4)
606-
assert_almost_equal(stretched['scalar'][12], 4)
607-
assert_almost_equal(stretched['scalar'][11], 4)
608-
assert_almost_equal(stretched['scalar'][10], 4)
609-
assert_almost_equal(stretched['scalar'][9], 3)
610-
611-
arr = np.empty(1, dtype=[('scalar', np.int),])
612-
arr[0] = (1,)
613-
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])
607+
[('scalar', np.int),
608+
('vl1', np.int)])
609+
assert_equal(stretched.shape[0], 15)
610+
assert_array_equal(
611+
stretched['scalar'],
612+
np.repeat(arr['scalar'], np.vectorize(len)(arr['vl1'])))
614613

615-
nrec = 5
616-
arr = np.empty(nrec,
617-
dtype=[
618-
('scalar', np.int),
619-
('df1', 'O'),
620-
('df2', 'O')])
621-
622-
for i in range(nrec):
623-
df1 = np.array(range(i + 1), dtype=np.float)
624-
df2 = np.array(range(i + 2), dtype=np.int) * 2
625-
arr[i] = (i, df1, df2)
626-
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'df1', 'df2'])
614+
# fixed-length stretch
615+
stretched = rnp.stretch(arr, ['scalar', 'vl3', 'fl1', 'fl2',])
616+
assert_equal(stretched.dtype,
617+
[('scalar', np.int),
618+
('vl3', np.double),
619+
('fl1', np.int, (2,)),
620+
('fl2', np.float, (3,))])
621+
assert_equal(stretched.shape[0], 10)
622+
assert_array_equal(
623+
stretched['scalar'], np.repeat(arr['scalar'], 2))
627624

628625

629626
def test_blockwise_inner_join():

0 commit comments

Comments
 (0)