Skip to content

Commit 3231ee8

Browse files
authored
TYP: Type MaskedArray.compress (numpy#29480)
1 parent 4ab0c51 commit 3231ee8

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

numpy/ma/core.pyi

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,36 @@ class MaskedArray(ndarray[_ShapeT_co, _DTypeT_co]):
614614
set_fill_value: Any
615615
def filled(self, /, fill_value: _ScalarLike_co | None = None) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
616616
def compressed(self) -> ndarray[tuple[int], _DTypeT_co]: ...
617-
def compress(self, condition, axis=..., out=...): ...
617+
618+
@overload
619+
def compress(
620+
self,
621+
condition: _ArrayLikeBool_co,
622+
axis: _ShapeLike | None,
623+
out: _ArrayT
624+
) -> _ArrayT: ...
625+
@overload
626+
def compress(
627+
self,
628+
condition: _ArrayLikeBool_co,
629+
axis: _ShapeLike | None = None,
630+
*,
631+
out: _ArrayT
632+
) -> _ArrayT: ...
633+
@overload
634+
def compress(
635+
self,
636+
condition: _ArrayLikeBool_co,
637+
axis: None = None,
638+
out: None = None
639+
) -> MaskedArray[tuple[int], _DTypeT_co]: ...
640+
@overload
641+
def compress(
642+
self,
643+
condition: _ArrayLikeBool_co,
644+
axis: _ShapeLike | None = None,
645+
out: None = None
646+
) -> MaskedArray[_AnyShape, _DTypeT_co]: ...
618647

619648
# TODO: How to deal with the non-commutative nature of `==` and `!=`?
620649
# xref numpy/numpy#17368

numpy/typing/tests/data/reveal/ma.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ assert_type(np.ma.count(MAR_o, None, True), NDArray[np.int_])
301301

302302
assert_type(MAR_f4.compressed(), np.ndarray[tuple[int], np.dtype[np.float32]])
303303

304+
assert_type(MAR_f4.compress([True, False]), np.ma.MaskedArray[tuple[int], np.dtype[np.float32]])
305+
assert_type(MAR_f4.compress([True, False], axis=0), MaskedArray[np.float32])
306+
assert_type(MAR_f4.compress([True, False], axis=0, out=MAR_subclass), MaskedArraySubclassC)
307+
assert_type(MAR_f4.compress([True, False], 0, MAR_subclass), MaskedArraySubclassC)
308+
304309
assert_type(np.ma.compressed(MAR_i8), np.ndarray[tuple[int], np.dtype[np.int64]])
305310
assert_type(np.ma.compressed([[1, 2, 3]]), np.ndarray[tuple[int], np.dtype])
306311

0 commit comments

Comments
 (0)