Skip to content

Commit 36357a0

Browse files
dalejungjreback
authored andcommitted
PRF: Optimize 2d take operations.
Use memory views
1 parent a4870b5 commit 36357a0

File tree

3 files changed

+195
-620
lines changed

3 files changed

+195
-620
lines changed

pandas/core/common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,17 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan,
675675
# to crash when trying to cast it to dtype)
676676
dtype, fill_value = arr.dtype, arr.dtype.type()
677677

678+
flip_order = False
679+
if arr.ndim == 2:
680+
if arr.flags.f_contiguous:
681+
flip_order = True
682+
683+
if flip_order:
684+
arr = arr.T
685+
axis = arr.ndim - axis - 1
686+
if out is not None:
687+
out = out.T
688+
678689
# at this point, it's guaranteed that dtype can hold both the arr values
679690
# and the fill_value
680691
if out is None:
@@ -692,7 +703,11 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan,
692703

693704
func = _get_take_nd_function(arr.ndim, arr.dtype, out.dtype,
694705
axis=axis, mask_info=mask_info)
706+
695707
func(arr, indexer, out, fill_value)
708+
709+
if flip_order:
710+
out = out.T
696711
return out
697712

698713

pandas/src/generate_code.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def take_1d_%(name)s_%(dest)s(ndarray[%(c_type_in)s] values,
8080

8181
take_2d_axis0_template = """@cython.wraparound(False)
8282
@cython.boundscheck(False)
83-
def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
83+
def take_2d_axis0_%(name)s_%(dest)s(%(c_type_in)s[:, :] values,
8484
ndarray[int64_t] indexer,
85-
ndarray[%(c_type_out)s, ndim=2] out,
85+
%(c_type_out)s[:, :] out,
8686
fill_value=np.nan):
8787
cdef:
8888
Py_ssize_t i, j, k, n, idx
@@ -127,9 +127,9 @@ def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
127127

128128
take_2d_axis1_template = """@cython.wraparound(False)
129129
@cython.boundscheck(False)
130-
def take_2d_axis1_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
130+
def take_2d_axis1_%(name)s_%(dest)s(%(c_type_in)s[:, :] values,
131131
ndarray[int64_t] indexer,
132-
ndarray[%(c_type_out)s, ndim=2] out,
132+
%(c_type_out)s[:, :] out,
133133
fill_value=np.nan):
134134
cdef:
135135
Py_ssize_t i, j, k, n, idx
@@ -143,34 +143,12 @@ def take_2d_axis1_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
143143
144144
fv = fill_value
145145
146-
IF %(can_copy)s:
147-
cdef:
148-
%(c_type_out)s *v
149-
%(c_type_out)s *o
150-
151-
#GH3130
152-
if (values.strides[0] == out.strides[0] and
153-
values.strides[0] == sizeof(%(c_type_out)s) and
154-
sizeof(%(c_type_out)s) * n >= 256):
155-
156-
for j from 0 <= j < k:
157-
idx = indexer[j]
158-
if idx == -1:
159-
for i from 0 <= i < n:
160-
out[i, j] = fv
161-
else:
162-
v = &values[0, idx]
163-
o = &out[0, j]
164-
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * n))
165-
return
166-
167-
for j from 0 <= j < k:
168-
idx = indexer[j]
169-
if idx == -1:
170-
for i from 0 <= i < n:
146+
for i from 0 <= i < n:
147+
for j from 0 <= j < k:
148+
idx = indexer[j]
149+
if idx == -1:
171150
out[i, j] = fv
172-
else:
173-
for i from 0 <= i < n:
151+
else:
174152
out[i, j] = %(preval)svalues[i, idx]%(postval)s
175153
176154
"""

0 commit comments

Comments
 (0)