diff --git a/sklearn/_config.py b/sklearn/_config.py index c41c180012056..8786f66d6b4c5 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -9,6 +9,7 @@ "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), "print_changed_only": True, "display": "text", + "array_api_dispatch": False, } _threadlocal = threading.local() @@ -40,7 +41,11 @@ def get_config(): def set_config( - assume_finite=None, working_memory=None, print_changed_only=None, display=None + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + array_api_dispatch=None, ): """Set global scikit-learn configuration @@ -95,11 +100,18 @@ def set_config( local_config["print_changed_only"] = print_changed_only if display is not None: local_config["display"] = display + if array_api_dispatch is not None: + local_config["array_api_dispatch"] = array_api_dispatch @contextmanager def config_context( - *, assume_finite=None, working_memory=None, print_changed_only=None, display=None + *, + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + array_api_dispatch=None, ): """Context manager for global scikit-learn configuration. @@ -171,6 +183,7 @@ def config_context( working_memory=working_memory, print_changed_only=print_changed_only, display=display, + array_api_dispatch=array_api_dispatch, ) try: diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index c28b6e48b9a22..697389be85805 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -11,8 +11,10 @@ import warnings import numpy as np +import scipy.linalg from scipy import linalg from scipy.special import expit +import math from .base import BaseEstimator, TransformerMixin, ClassifierMixin from .base import _ClassNamePrefixFeaturesOutMixin @@ -20,6 +22,7 @@ from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance from .utils.multiclass import unique_labels from .utils.validation import check_is_fitted +from .utils._array_api import get_namespace from .utils.multiclass import check_classification_targets from .utils.extmath import softmax from .preprocessing import StandardScaler @@ -110,11 +113,17 @@ def _class_means(X, y): means : array-like of shape (n_classes, n_features) Class means. """ - classes, y = np.unique(y, return_inverse=True) - cnt = np.bincount(y) - means = np.zeros(shape=(len(classes), X.shape[1])) - np.add.at(means, y, X) - means /= cnt[:, None] + xp, is_array_api = get_namespace(X) + classes, y = xp.unique_inverse(y) + means = xp.zeros(shape=(classes.shape[0], X.shape[1])) + + if is_array_api: + for i in range(classes.shape[0]): + means[i, :] = xp.mean(X[y == i], axis=0) + else: + cnt = np.bincount(y) + np.add.at(means, y, X) + means /= cnt[:, None] return means @@ -466,8 +475,15 @@ def _solve_svd(self, X, y): y : array-like of shape (n_samples,) or (n_samples, n_targets) Target values. """ + xp, is_array_api = get_namespace(X) + + if is_array_api: + svd = xp.linalg.svd + else: + svd = scipy.linalg.svd + n_samples, n_features = X.shape - n_classes = len(self.classes_) + n_classes = self.classes_.shape[0] self.means_ = _class_means(X, y) if self.store_covariance: @@ -475,55 +491,51 @@ def _solve_svd(self, X, y): Xc = [] for idx, group in enumerate(self.classes_): - Xg = X[y == group, :] - Xc.append(Xg - self.means_[idx]) + Xg = X[y == group] + Xc.append(Xg - self.means_[idx, :]) - self.xbar_ = np.dot(self.priors_, self.means_) + self.xbar_ = self.priors_ @ self.means_ - Xc = np.concatenate(Xc, axis=0) + Xc = xp.concat(Xc, axis=0) # 1) within (univariate) scaling by with classes std-dev - std = Xc.std(axis=0) + std = xp.std(Xc, axis=0) # avoid division by zero in normalization std[std == 0] = 1.0 fac = 1.0 / (n_samples - n_classes) # 2) Within variance scaling - X = np.sqrt(fac) * (Xc / std) + X = math.sqrt(fac) * (Xc / std) # SVD of centered (within)scaled data - U, S, Vt = linalg.svd(X, full_matrices=False, check_finite=False) + U, S, Vt = svd(X, full_matrices=False) - rank = np.sum(S > self.tol) + rank = xp.sum(xp.astype(S > self.tol, xp.int32)) # Scaling of within covariance is: V' 1/S - scalings = (Vt[:rank] / std).T / S[:rank] + scalings = (Vt[:rank, :] / std).T / S[:rank] # 3) Between variance scaling # Scale weighted centers - X = np.dot( - ( - (np.sqrt((n_samples * self.priors_) * fac)) - * (self.means_ - self.xbar_).T - ).T, - scalings, - ) + X = ( + (xp.sqrt((n_samples * self.priors_) * fac)) * (self.means_ - self.xbar_).T + ).T @ scalings # Centers are living in a space with n_classes-1 dim (maximum) # Use SVD to find projection in the space spanned by the # (n_classes) centers - _, S, Vt = linalg.svd(X, full_matrices=0, check_finite=False) + _, S, Vt = svd(X, full_matrices=False) if self._max_components == 0: - self.explained_variance_ratio_ = np.empty((0,), dtype=S.dtype) + self.explained_variance_ratio_ = xp.empty((0,), dtype=S.dtype) else: - self.explained_variance_ratio_ = (S ** 2 / np.sum(S ** 2))[ + self.explained_variance_ratio_ = (S ** 2 / xp.sum(S ** 2))[ : self._max_components ] - rank = np.sum(S > self.tol * S[0]) - self.scalings_ = np.dot(scalings, Vt.T[:, :rank]) - coef = np.dot(self.means_ - self.xbar_, self.scalings_) - self.intercept_ = -0.5 * np.sum(coef ** 2, axis=1) + np.log(self.priors_) - self.coef_ = np.dot(coef, self.scalings_.T) - self.intercept_ -= np.dot(self.xbar_, self.coef_.T) + rank = xp.sum(xp.astype(S > self.tol * S[0], xp.int32)) + self.scalings_ = scalings @ Vt.T[:, :rank] + coef = (self.means_ - self.xbar_) @ self.scalings_ + self.intercept_ = -0.5 * xp.sum(coef ** 2, axis=1) + xp.log(self.priors_) + self.coef_ = coef @ self.scalings_.T + self.intercept_ -= self.xbar_ @ self.coef_.T def fit(self, X, y): """Fit the Linear Discriminant Analysis model. @@ -547,12 +559,13 @@ def fit(self, X, y): self : object Fitted estimator. """ + xp, _ = get_namespace(X) X, y = self._validate_data( - X, y, ensure_min_samples=2, dtype=[np.float64, np.float32] + X, y, ensure_min_samples=2, dtype=[xp.float64, xp.float32] ) self.classes_ = unique_labels(y) n_samples, _ = X.shape - n_classes = len(self.classes_) + n_classes = self.classes_.shape[0] if n_samples == n_classes: raise ValueError( @@ -560,20 +573,22 @@ def fit(self, X, y): ) if self.priors is None: # estimate priors from sample - _, y_t = np.unique(y, return_inverse=True) # non-negative ints - self.priors_ = np.bincount(y_t) / float(len(y)) + _, cnts = xp.unique_counts(y) # non-negative ints + self.priors_ = xp.astype(cnts, xp.float64) / float(y.shape[0]) else: - self.priors_ = np.asarray(self.priors) + self.priors_ = xp.asarray(self.priors) - if (self.priors_ < 0).any(): + if xp.any(self.priors_ < 0): raise ValueError("priors must be non-negative") - if not np.isclose(self.priors_.sum(), 1.0): - warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning) - self.priors_ = self.priors_ / self.priors_.sum() + + # TODO: implement isclose in wrapper? + # if not np.isclose(np.sum(self.priors_), 1.0): + # warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning) + # self.priors_ = self.priors_ / self.priors_.sum() # Maximum number of components no matter what n_components is # specified: - max_components = min(len(self.classes_) - 1, X.shape[1]) + max_components = min(n_classes - 1, X.shape[1]) if self.n_components is None: self._max_components = max_components @@ -614,12 +629,12 @@ def fit(self, X, y): "'lsqr', and 'eigen').".format(self.solver) ) if self.classes_.size == 2: # treat binary case as a special case - self.coef_ = np.array( - self.coef_[1, :] - self.coef_[0, :], ndmin=2, dtype=X.dtype - ) - self.intercept_ = np.array( - self.intercept_[1] - self.intercept_[0], ndmin=1, dtype=X.dtype + coef_ = xp.asarray(self.coef_[1, :] - self.coef_[0, :], dtype=X.dtype) + self.coef_ = xp.reshape(coef_, (1, -1)) + intercept_ = xp.asarray( + self.intercept_[1] - self.intercept_[0], dtype=X.dtype ) + self.intercept_ = xp.reshape(intercept_, 1) self._n_features_out = self._max_components return self diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index 652556cf1e702..9ee59520f73e4 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -35,6 +35,7 @@ from ..utils.extmath import _incremental_mean_and_var from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale from ..utils.fixes import sparse_lsqr +from ..utils._array_api import get_namespace from ..utils._seq_dataset import ArrayDataset32, CSRDataset32 from ..utils._seq_dataset import ArrayDataset64, CSRDataset64 from ..utils.validation import check_is_fitted, _check_sample_weight @@ -405,8 +406,9 @@ def decision_function(self, X): check_is_fitted(self) X = self._validate_data(X, accept_sparse="csr", reset=False) + xp, _ = get_namespace(X) scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_ - return scores.ravel() if scores.shape[1] == 1 else scores + return xp.reshape(scores, -1) if scores.shape[1] == 1 else scores def predict(self, X): """ @@ -422,12 +424,16 @@ def predict(self, X): y_pred : ndarray of shape (n_samples,) Vector containing the class labels for each sample. """ + xp, _ = get_namespace(X) scores = self.decision_function(X) if len(scores.shape) == 1: - indices = (scores > 0).astype(int) + indices = xp.astype(scores > 0, int) else: indices = scores.argmax(axis=1) - return self.classes_[indices] + # Should really use `np.take` + # return np.take(self.classes_, indices, axis=0) + # assuming classes are [0, 1, 2, ....] + return indices def _predict_proba_lr(self, X): """Probability estimation for OvR logistic regression. diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index 40ede1feba547..3ce9d1ec0fd42 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -4,6 +4,8 @@ from scipy import linalg +from sklearn.base import clone +from sklearn._config import config_context from sklearn.utils import check_random_state from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal @@ -667,3 +669,30 @@ def test_get_feature_names_out(): dtype=object, ) assert_array_equal(names_out, expected_names_out) + + +def test_lda_array_api(): + """Check that the array_api Array gives the same results as ndarrays.""" + pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API") + xp = pytest.importorskip("numpy.array_api") + + X_xp = xp.asarray(X) + y_xp = xp.asarray(y) + + lda = LinearDiscriminantAnalysis() + lda.fit(X, y) + + lda_xp = clone(lda) + with config_context(array_api_dispatch=True): + lda_xp.fit(X_xp, y_xp) + + gm_attributes_array = { + key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray) + } + for key in gm_attributes_array: + gm_xp_param = getattr(lda_xp, key) + assert hasattr(gm_xp_param, "__array_namespace__") + + assert_allclose( + gm_attributes_array[key], gm_xp_param, err_msg=f"{key} not the same" + ) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py new file mode 100644 index 0000000000000..345478eb1ef88 --- /dev/null +++ b/sklearn/utils/_array_api.py @@ -0,0 +1,98 @@ +"""Tools to support array_api.""" +import numpy +from .._config import get_config + + +class _ArrayAPIWrapper: + def __init__(self, array_namespace): + self._namespace = array_namespace + + def __getattr__(self, name): + return getattr(self._namespace, name) + + def astype(self, x, dtype, *, copy=True, casting="unsafe"): + # support casting for NumPy + if self._namespace.__name__ == "numpy.array_api": + x_np = x.astype(dtype, casting=casting, copy=copy) + return self._namespace.asarray(x_np) + + f = self._namespace.astype + return f(x, dtype, copy=copy) + + def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): + # support order in NumPy + if self._namespace.__name__ == "numpy.array_api": + if copy: + x_np = numpy.array(obj, dtype=dtype, order=order, copy=True) + else: + x_np = numpy.asarray(obj, dtype=dtype, order=order) + return self._namespace(x_np) + + f = self._namespace.asarray + return f(obj, dtype=dtype, device=device, copy=copy) + + def may_share_memory(self, a, b): + # support may_share_memory in NumPy + if self._namespace.__name__ == "numpy.array_api": + return numpy.may_share_memory(a, b) + + # The safe choice is to return True for all other array_api Arrays + return True + + +class _NumPyApiWrapper: + def __getattr__(self, name): + return getattr(numpy, name) + + def astype(self, x, dtype, *, copy=True, casting="unsafe"): + # astype is not defined in the top level NumPy namespace + return x.astype(dtype, copy=copy, casting=casting) + + def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): + # copy is in the ArrayAPI spec but not in NumPy's asarray + if copy: + return numpy.array(obj, dtype=dtype, order=order, copy=True) + else: + return numpy.asarray(obj, dtype=dtype, order=order) + + def unique_inverse(self, x): + return numpy.unique(x, return_inverse=True) + + def unique_counts(self, x): + return numpy.unique(x, return_counts=True) + + def unique_values(self, x): + return numpy.unique(x) + + def concat(self, arrays, *, axis=None): + return numpy.concatenate(arrays, axis=axis) + + +def get_namespace(*xs): + # `xs` contains one or more arrays, or possibly Python scalars (accepting + # those is a matter of taste, but doesn't seem unreasonable). + # Returns a tuple: (array_namespace, is_array_api) + + if not get_config()["array_api_dispatch"]: + return _NumPyApiWrapper(), False + + namespaces = { + x.__array_namespace__() if hasattr(x, "__array_namespace__") else None + for x in xs + if not isinstance(x, (bool, int, float, complex)) + } + + if not namespaces: + # one could special-case np.ndarray above or use np.asarray here if + # older numpy versions need to be supported. + raise ValueError("Unrecognized array input") + + if len(namespaces) != 1: + raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") + + (xp,) = namespaces + if xp is None: + # Use numpy as default + return _NumPyApiWrapper(), False + + return _ArrayAPIWrapper(xp), True diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 4e5981042f277..4494ff410f36a 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -17,11 +17,13 @@ import numpy as np from .validation import check_array, _assert_all_finite +from ..utils._array_api import get_namespace def _unique_multiclass(y): - if hasattr(y, "__array__"): - return np.unique(np.asarray(y)) + xp, is_array_api = get_namespace(y) + if hasattr(y, "__array__") or is_array_api: + return xp.unique_values(xp.asarray(y)) else: return set(y) @@ -70,6 +72,7 @@ def unique_labels(*ys): >>> unique_labels([1, 2, 10], [5, 11]) array([ 1, 2, 5, 10, 11]) """ + xp, is_array_api = get_namespace(*ys) if not ys: raise ValueError("No argument has been passed.") # Check that we don't mix label format @@ -102,13 +105,17 @@ def unique_labels(*ys): if not _unique_labels: raise ValueError("Unknown label type: %s" % repr(ys)) - ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys)) + if is_array_api: + # array_api does not allow for mixed dtypes + unique_ys = xp.concat([_unique_labels(y) for y in ys]) + return xp.unique_values(unique_ys) + ys_labels = set(chain.from_iterable((i for i in _unique_labels(y)) for y in ys)) # Check that we don't mix string type with number type if len(set(isinstance(label, str) for label in ys_labels)) > 1: raise ValueError("Mix of label input types (string and number)") - return np.array(sorted(ys_labels)) + return xp.asarray(sorted(ys_labels)) def _is_integral_float(y): @@ -143,17 +150,18 @@ def is_multilabel(y): >>> is_multilabel(np.array([[1, 0, 0]])) True """ - if hasattr(y, "__array__") or isinstance(y, Sequence): + xp, is_array_api = get_namespace(y) + if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api: # DeprecationWarning will be replaced by ValueError, see NEP 34 # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html with warnings.catch_warnings(): warnings.simplefilter("error", np.VisibleDeprecationWarning) try: - y = np.asarray(y) + y = xp.asarray(y) except np.VisibleDeprecationWarning: # dtype=object should be provided explicitly for ragged arrays, # see NEP 34 - y = np.array(y, dtype=object) + y = xp.asarray(y, dtype=object) if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1): return False @@ -163,14 +171,14 @@ def is_multilabel(y): y = y.tocsr() return ( len(y.data) == 0 - or np.unique(y.data).size == 1 + or xp.unique_values(y.data).size == 1 and ( y.dtype.kind in "biu" - or _is_integral_float(np.unique(y.data)) # bool, int, uint + or _is_integral_float(xp.unique_values(y.data)) # bool, int, uint ) ) else: - labels = np.unique(y) + labels = xp.unique_values(y) return len(labels) < 3 and ( y.dtype.kind in "biu" or _is_integral_float(labels) # bool, int, uint @@ -269,9 +277,12 @@ def type_of_target(y, input_name=""): >>> type_of_target(np.array([[0, 1], [1, 1]])) 'multilabel-indicator' """ + xp, is_array_api = get_namespace(y) valid = ( - isinstance(y, Sequence) or issparse(y) or hasattr(y, "__array__") - ) and not isinstance(y, str) + (isinstance(y, Sequence) or issparse(y) or hasattr(y, "__array__")) + and not isinstance(y, str) + or is_array_api + ) if not valid: raise ValueError( @@ -290,11 +301,11 @@ def type_of_target(y, input_name=""): with warnings.catch_warnings(): warnings.simplefilter("error", np.VisibleDeprecationWarning) try: - y = np.asarray(y) + y = xp.asarray(y) except np.VisibleDeprecationWarning: # dtype=object should be provided explicitly for ragged arrays, # see NEP 34 - y = np.asarray(y, dtype=object) + y = xp.asarray(y, dtype=object) # The old sequence of sequences format try: @@ -326,12 +337,12 @@ def type_of_target(y, input_name=""): suffix = "" # [1, 2, 3] or [[1], [2], [3]] # check float and contains non-integer float values - if y.dtype.kind == "f" and np.any(y != y.astype(int)): + if y.dtype.kind == "f" and xp.any(y != y.astype(int)): # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] _assert_all_finite(y, input_name=input_name) return "continuous" + suffix - if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1): + if (xp.unique_values(y).shape[0] > 2) or (y.ndim >= 2 and len(y[0]) > 1): return "multiclass" + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] else: return "binary" # [1, 2] or [["a"], ["b"]] diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index cf2265d5b21cd..3487e93ee136b 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -15,6 +15,7 @@ import operator import numpy as np +import numpy import scipy.sparse as sp from inspect import signature, isclass, Parameter @@ -29,6 +30,7 @@ from ..exceptions import PositiveSpectrumWarning from ..exceptions import NotFittedError from ..exceptions import DataConversionWarning +from ..utils._array_api import get_namespace FLOAT_DTYPES = (np.float64, np.float32, np.float16) @@ -94,24 +96,27 @@ def _assert_all_finite( # validation is also imported in extmath from .extmath import _safe_accumulator_op + xp, _ = get_namespace(X) + if _get_config()["assume_finite"]: return - X = np.asanyarray(X) + + X = xp.asarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent # false positives from overflow in sum method. The sum is also calculated # safely to reduce dtype induced overflows. is_float = X.dtype.kind in "fc" - if is_float and (np.isfinite(_safe_accumulator_op(np.sum, X))): + if is_float and (xp.isfinite(_safe_accumulator_op(xp.sum, X))): pass elif is_float: if ( allow_nan - and np.isinf(X).any() + and xp.any(xp.isinf(X)) or not allow_nan - and not np.isfinite(X).all() + and not xp.all(xp.isfinite(X)) ): - if not allow_nan and np.isnan(X).any(): + if not allow_nan and xp.any(xp.isnan(X)): type_err = "NaN" else: msg_dtype = msg_dtype if msg_dtype is not None else X.dtype @@ -122,7 +127,7 @@ def _assert_all_finite( not allow_nan and estimator_name and input_name == "X" - and np.isnan(X).any() + and xp.any(xp.isnan(X)) ): # Improve the error message on how to handle missing values in # scikit-learn. @@ -139,8 +144,8 @@ def _assert_all_finite( raise ValueError(msg_err) # for object dtype data, we only check for NaNs (GH-13254) - elif X.dtype == np.dtype("object") and not allow_nan: - if _object_dtype_isnan(X).any(): + elif X.dtype == object and not allow_nan: + if xp.any(_object_dtype_isnan(X)): raise ValueError("Input contains NaN") @@ -703,6 +708,7 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html", # noqa FutureWarning, ) + xp, _ = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -748,7 +754,7 @@ def check_array( if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. - dtype = np.float64 + dtype = xp.float64 else: dtype = None @@ -818,7 +824,7 @@ def check_array( # Conversion float -> int should not contain NaN or # inf (numpy#14412). We cannot use casting='safe' because # then conversion float -> int would be disallowed. - array = np.asarray(array, order=order) + array = xp.asarray(array, order=order) if array.dtype.kind == "f": _assert_all_finite( array, @@ -827,9 +833,9 @@ def check_array( estimator_name=estimator_name, input_name=input_name, ) - array = array.astype(dtype, casting="unsafe", copy=False) + array = xp.astype(array, dtype, casting="unsafe", copy=False) else: - array = np.asarray(array, order=order, dtype=dtype) + array = xp.asarray(array, order=order, dtype=dtype) except ComplexWarning as complex_warning: raise ValueError( "Complex data not supported\n{}\n".format(array) @@ -870,7 +876,7 @@ def check_array( stacklevel=2, ) try: - array = array.astype(np.float64) + array = xp.astype(array, np.float64) except ValueError as e: raise ValueError( "Unable to convert array of bytes/strings " @@ -908,8 +914,8 @@ def check_array( % (n_features, array.shape, ensure_min_features, context) ) - if copy and np.may_share_memory(array, array_orig): - array = np.array(array, dtype=dtype, order=order) + if copy and xp.may_share_memory(array, array_orig): + array = xp.asarray(array, dtype=dtype, order=order, copy=True) return array @@ -1119,10 +1125,11 @@ def column_or_1d(y, *, warn=False): ValueError If `y` is not a 1D array or a 2D array with a single row or column. """ - y = np.asarray(y) - shape = np.shape(y) + xp, _ = get_namespace(y) + y = xp.asarray(y) + shape = y.shape if len(shape) == 1: - return np.ravel(y) + return xp.reshape(y, -1) if len(shape) == 2 and shape[1] == 1: if warn: warnings.warn( @@ -1132,7 +1139,7 @@ def column_or_1d(y, *, warn=False): DataConversionWarning, stacklevel=2, ) - return np.ravel(y) + return xp.reshape(y, -1) raise ValueError( "y should be a 1d array, got an array of shape {} instead.".format(shape) @@ -1333,6 +1340,7 @@ def check_non_negative(X, whom): whom : str Who passed X to this function. """ + xp, _ = get_namespace(X) # avoid X.min() on sparse matrix since it also sorts the indices if sp.issparse(X): if X.format in ["lil", "dok"]: @@ -1342,7 +1350,7 @@ def check_non_negative(X, whom): else: X_min = X.data.min() else: - X_min = X.min() + X_min = xp.min(X) if X_min < 0: raise ValueError("Negative values in data passed to %s" % whom)