Skip to content

Commit 3a12685

Browse files
committed
ENH: SparseDataFrame supports scipy.sparse.spmatrix in setitem
1 parent 9d6d2fe commit 3a12685

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

pandas/core/internals.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,6 +2571,16 @@ def _astype(self, dtype, copy=False, raise_on_error=True, values=None,
25712571
return self.make_block_same_class(values=values,
25722572
placement=self.mgr_locs)
25732573

2574+
def _can_hold_element(self, element):
2575+
element = np.asanyarray(element)
2576+
return np.issubdtype(element.dtype, self.sp_values.dtype)
2577+
2578+
def _try_cast(self, element):
2579+
try:
2580+
return np.asarray(element, dtype=self.sp_values.dtype)
2581+
except ValueError:
2582+
return element
2583+
25742584
def __len__(self):
25752585
try:
25762586
return self.sp_index.length

pandas/core/sparse/frame.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,15 @@ def __getitem__(self, key):
433433
else:
434434
return self._get_item_cache(key)
435435

436+
def __setitem__(self, key, value):
437+
if is_scipy_sparse(value):
438+
if any(ax == 1 for ax in value.shape): # 1d spmatrix
439+
value = SparseArray(value, fill_value=self._default_fill_value)
440+
else:
441+
# 2d; make it iterable
442+
value = list(value.tocsc().T)
443+
super().__setitem__(key, value)
444+
436445
@Appender(DataFrame.get_value.__doc__, indents=0)
437446
def get_value(self, index, col, takeable=False):
438447
if takeable is True:

pandas/tests/sparse/test_frame.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,33 @@ def test_setitem_array(self):
540540
self.frame['F'].reindex(index),
541541
check_names=False)
542542

543+
def test_setitem_spmatrix(self):
544+
# GH-15634
545+
tm.skip_if_no_package('scipy')
546+
from scipy.sparse import csr_matrix
547+
548+
sdf = self.frame.copy(False)
549+
550+
# 1d -- column
551+
spm = csr_matrix(np.arange(len(sdf))).T
552+
sdf['X'] = spm
553+
assert (sdf[['X']].to_coo() != spm.tocoo()).nnz == 0
554+
555+
# 1d -- existing column
556+
sdf['A'] = spm.T
557+
assert (sdf[['X']].to_coo() != spm.tocoo()).nnz == 0
558+
559+
# 1d row -- changing series contents not yet supported
560+
spm = csr_matrix(np.arange(sdf.shape[1])).astype(float)
561+
idx = np.r_[[False, True], np.full(sdf.shape[0] - 2, False)]
562+
tm.assert_raises_regex(TypeError, 'assignment',
563+
lambda: sdf.__setitem__(idx, spm))
564+
565+
# 2d -- 2 columns
566+
spm = csr_matrix(np.eye(len(sdf))[:, :2])
567+
sdf[['X', 'A']] = spm
568+
assert (sdf[['X', 'A']].to_coo() != spm.tocoo()).nnz == 0
569+
543570
def test_delitem(self):
544571
A = self.frame['A']
545572
C = self.frame['C']

0 commit comments

Comments
 (0)