@@ -72,7 +72,7 @@ def stack(recs, fields=None):
72
72
return np .hstack ([rec [fields ] for rec in recs ])
73
73
74
74
75
- def stretch (arr , fields ):
75
+ def stretch (arr , fields = None ):
76
76
"""Stretch an array.
77
77
78
78
Stretch an array by ``hstack()``-ing multiple array fields while
@@ -83,8 +83,8 @@ def stretch(arr, fields):
83
83
----------
84
84
arr : NumPy structured or record array
85
85
The array to be stretched.
86
- fields : list of strings
87
- A list of column names to stretch.
86
+ fields : list of strings, optional (default=None)
87
+ A list of column names to stretch. If None, then stretch all fields.
88
88
89
89
Returns
90
90
-------
@@ -105,19 +105,21 @@ def stretch(arr, fields):
105
105
"""
106
106
dt = []
107
107
has_array_field = False
108
- has_scalar_filed = False
109
108
first_array = None
110
109
110
+ if fields is None :
111
+ fields = arr .dtype .names
112
+
111
113
# Construct dtype
112
- for c in fields :
113
- if _is_object_field (arr , c ):
114
- dt .append ((c , arr [c ][0 ].dtype ))
114
+ for field in fields :
115
+ if _is_object_field (arr , field ):
116
+ dt .append ((field , arr [field ][0 ].dtype ))
115
117
has_array_field = True
116
- first_array = c if first_array is None else first_array
118
+ if first_array is None :
119
+ first_array = field
117
120
else :
118
121
# Assume scalar
119
- dt .append ((c , arr [c ].dtype ))
120
- has_scalar_filed = True
122
+ dt .append ((field , arr [field ].dtype ))
121
123
122
124
if not has_array_field :
123
125
raise RuntimeError ("No array column specified" )
@@ -126,21 +128,21 @@ def stretch(arr, fields):
126
128
numrec = np .sum (len_array )
127
129
ret = np .empty (numrec , dtype = dt )
128
130
129
- for c in fields :
130
- if _is_object_field (arr , c ):
131
+ for field in fields :
132
+ if _is_object_field (arr , field ):
131
133
# FIXME: this is rather inefficient since the stack
132
134
# is copied over to the return value
133
- stack = np .hstack (arr [c ])
135
+ stack = np .hstack (arr [field ])
134
136
if len (stack ) != numrec :
135
137
raise ValueError (
136
138
"Array lengths do not match: "
137
139
"expected %d but found %d in %s" %
138
- (numrec , len (stack ), c ))
139
- ret [c ] = stack
140
+ (numrec , len (stack ), field ))
141
+ ret [field ] = stack
140
142
else :
141
143
# FIXME: this is rather inefficient since the repeat result
142
144
# is copied over to the return value
143
- ret [c ] = np .repeat (arr [c ], len_array )
145
+ ret [field ] = np .repeat (arr [field ], len_array )
144
146
145
147
return ret
146
148
0 commit comments