Skip to content

Commit 94c2094

Browse files
perimosocordiaeamueller
authored andcommitted
[MRG+1] BUG: MultiLabelBinarizer.fit_transform sometimes returns an invalid CSR matrix (scikit-learn#7750)
* BUG: MultiLabelBinarizer makes invalid CSR matrix See scipy/scipy#6719 for context. The gist is that the `inverse` array may have a different dtype than `yt.indices`, which causes trouble down the line because, in those cases, `yt.indices` and `yt.indptr` have different dtypes. Alternately, we could insert `yt.check_format(full_check=False)` after modifying the sparse matrix members. * Fixing for old numpy Older versions don't support kwargs for `astype` * Adding tests * line-wrapping * adding comment to tests [ci skip] * added rationale comment [ci skip]
1 parent 3ff22f5 commit 94c2094

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

sklearn/preprocessing/label.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,9 @@ def fit_transform(self, y):
732732
class_mapping = np.empty(len(tmp), dtype=dtype)
733733
class_mapping[:] = tmp
734734
self.classes_, inverse = np.unique(class_mapping, return_inverse=True)
735-
yt.indices = np.take(inverse, yt.indices)
735+
# ensure yt.indices keeps its current dtype
736+
yt.indices = np.array(inverse[yt.indices], dtype=yt.indices.dtype,
737+
copy=False)
736738

737739
if not self.sparse_output:
738740
yt = yt.toarray()

sklearn/preprocessing/tests/test_label.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def test_sparse_output_multilabel_binarizer():
226226
got = mlb.fit_transform(inp())
227227
assert_equal(issparse(got), sparse_output)
228228
if sparse_output:
229+
# verify CSR assumption that indices and indptr have same dtype
230+
assert_equal(got.indices.dtype, got.indptr.dtype)
229231
got = got.toarray()
230232
assert_array_equal(indicator_mat, got)
231233
assert_array_equal([1, 2, 3], mlb.classes_)
@@ -236,6 +238,8 @@ def test_sparse_output_multilabel_binarizer():
236238
got = mlb.fit(inp()).transform(inp())
237239
assert_equal(issparse(got), sparse_output)
238240
if sparse_output:
241+
# verify CSR assumption that indices and indptr have same dtype
242+
assert_equal(got.indices.dtype, got.indptr.dtype)
239243
got = got.toarray()
240244
assert_array_equal(indicator_mat, got)
241245
assert_array_equal([1, 2, 3], mlb.classes_)

0 commit comments

Comments
 (0)