Skip to content

Commit a414346

Browse files
committed
added back mask method that does condition inversion
added condition testing to where that raised ValueError on an invalid condition (e.g. not an ndarray like object) added tests for same
1 parent 8034116 commit a414346

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

pandas/core/frame.py

Lines changed: 18 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -4882,6 +4882,9 @@ def where(self, cond, other=NA, inplace=False):
4882
-------
4882
-------
4883
wh: DataFrame
4883
wh: DataFrame
4884
"""
4884
"""
4885+
if not hasattr(cond,'shape'):
4886+
raise ValueError('where requires an ndarray like object for its condition')
4887+
4885
if isinstance(cond, np.ndarray):
4888
if isinstance(cond, np.ndarray):
4886
if cond.shape != self.shape:
4889
if cond.shape != self.shape:
4887
raise ValueError('Array onditional must be same shape as self')
4890
raise ValueError('Array onditional must be same shape as self')
@@ -4901,6 +4904,21 @@ def where(self, cond, other=NA, inplace=False):
4901
rs = np.where(cond, self, other)
4904
rs = np.where(cond, self, other)
4902
return self._constructor(rs, self.index, self.columns)
4905
return self._constructor(rs, self.index, self.columns)
4903

4906

4907+
def mask(self, cond):
4908+
"""
4909+
Returns copy of self whose values are replaced with nan if the
4910+
inverted condition is True
4911+
4912+
Parameters
4913+
----------
4914+
cond: boolean DataFrame or array
4915+
4916+
Returns
4917+
-------
4918+
wh: DataFrame
4919+
"""
4920+
return self.where(~cond, NA)
4921+
4904
_EMPTY_SERIES = Series([])
4922
_EMPTY_SERIES = Series([])
4905

4923

4906

4924

pandas/tests/test_frame.py

Lines changed: 12 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -5220,6 +5220,18 @@ def test_where(self):
5220
err2 = cond.ix[:2, :].values
5220
err2 = cond.ix[:2, :].values
5221
self.assertRaises(ValueError, df.where, err2, other1)
5221
self.assertRaises(ValueError, df.where, err2, other1)
5222

5222

5223+
# invalid conditions
5224+
self.assertRaises(ValueError, df.mask, True)
5225+
self.assertRaises(ValueError, df.mask, 0)
5226+
5227+
def test_mask(self):
5228+
df = DataFrame(np.random.randn(5, 3))
5229+
cond = df > 0
5230+
5231+
rs = df.where(cond, np.nan)
5232+
assert_frame_equal(rs, df.mask(df <= 0))
5233+
assert_frame_equal(rs, df.mask(~cond))
5234+
5223

5235

5224
#----------------------------------------------------------------------
5236
#----------------------------------------------------------------------
5225
# Transposing
5237
# Transposing

0 commit comments

Comments
 (0)