ENH Enables array_api for LinearDiscriminantAnalysis #102

Closed
wants to merge 9 commits into from
17 changes: 15 additions & 2 deletions sklearn/_config.py
@@ -9,6 +9,7 @@
"working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
"print_changed_only": True,
"display": "text",
"array_api_dispatch": False,
}
_threadlocal = threading.local()

@@ -40,7 +41,11 @@ def get_config():


def set_config(
assume_finite=None, working_memory=None, print_changed_only=None, display=None
assume_finite=None,
working_memory=None,
print_changed_only=None,
display=None,
array_api_dispatch=None,
):
"""Set global scikit-learn configuration

@@ -95,11 +100,18 @@ def set_config(
local_config["print_changed_only"] = print_changed_only
if display is not None:
local_config["display"] = display
if array_api_dispatch is not None:
local_config["array_api_dispatch"] = array_api_dispatch


@contextmanager
def config_context(
*, assume_finite=None, working_memory=None, print_changed_only=None, display=None
*,
assume_finite=None,
working_memory=None,
print_changed_only=None,
display=None,
array_api_dispatch=None,
):
"""Context manager for global scikit-learn configuration.

@@ -171,6 +183,7 @@ def config_context(
working_memory=working_memory,
print_changed_only=print_changed_only,
display=display,
array_api_dispatch=array_api_dispatch,
)

try:
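For context, a short usage sketch (not part of the diff) of how the new array_api_dispatch flag would be toggled through the public set_config / config_context API once this change lands:

import sklearn

# enable array API dispatch globally
sklearn.set_config(array_api_dispatch=True)
sklearn.set_config(array_api_dispatch=False)  # and switch it back off

# or enable it only within a block; the previous value is restored on exit
with sklearn.config_context(array_api_dispatch=True):
    ...  # estimators fitted here may dispatch to the input's array namespace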
107 changes: 61 additions & 46 deletions sklearn/discriminant_analysis.py
@@ -11,15 +11,18 @@

import warnings
import numpy as np
import scipy.linalg
from scipy import linalg
from scipy.special import expit
import math

from .base import BaseEstimator, TransformerMixin, ClassifierMixin
from .base import _ClassNamePrefixFeaturesOutMixin
from .linear_model._base import LinearClassifierMixin
from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance
from .utils.multiclass import unique_labels
from .utils.validation import check_is_fitted
from .utils._array_api import get_namespace
from .utils.multiclass import check_classification_targets
from .utils.extmath import softmax
from .preprocessing import StandardScaler
@@ -110,11 +113,17 @@ def _class_means(X, y):
means : array-like of shape (n_classes, n_features)
Class means.
"""
classes, y = np.unique(y, return_inverse=True)
cnt = np.bincount(y)
means = np.zeros(shape=(len(classes), X.shape[1]))
np.add.at(means, y, X)
means /= cnt[:, None]
xp, is_array_api = get_namespace(X)
classes, y = xp.unique_inverse(y)
means = xp.zeros(shape=(classes.shape[0], X.shape[1]))

if is_array_api:
for i in range(classes.shape[0]):
means[i, :] = xp.mean(X[y == i], axis=0)
else:
cnt = np.bincount(y)
np.add.at(means, y, X)
means /= cnt[:, None]
return means
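For context, a tiny illustration (assuming numpy>=1.22 with the experimental numpy.array_api module) of the unique_inverse call that replaces np.unique(y, return_inverse=True) above:

import numpy.array_api as xp

y = xp.asarray([2, 0, 2, 1])
classes, inverse = xp.unique_inverse(y)
# classes -> [0, 1, 2]
# inverse -> [2, 0, 2, 1], i.e. the position of each sample's label in `classes`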


@@ -466,64 +475,67 @@ def _solve_svd(self, X, y):
y : array-like of shape (n_samples,) or (n_samples, n_targets)
Target values.
"""
xp, is_array_api = get_namespace(X)

if is_array_api:
svd = xp.linalg.svd
else:
svd = scipy.linalg.svd

n_samples, n_features = X.shape
n_classes = len(self.classes_)
n_classes = self.classes_.shape[0]

self.means_ = _class_means(X, y)
if self.store_covariance:
self.covariance_ = _class_cov(X, y, self.priors_)

Xc = []
for idx, group in enumerate(self.classes_):
Xg = X[y == group, :]
Xc.append(Xg - self.means_[idx])
Xg = X[y == group]
Xc.append(Xg - self.means_[idx, :])

self.xbar_ = np.dot(self.priors_, self.means_)
self.xbar_ = self.priors_ @ self.means_

Xc = np.concatenate(Xc, axis=0)
Xc = xp.concat(Xc, axis=0)

# 1) within (univariate) scaling by within-class std-dev
std = Xc.std(axis=0)
std = xp.std(Xc, axis=0)
# avoid division by zero in normalization
std[std == 0] = 1.0
fac = 1.0 / (n_samples - n_classes)

# 2) Within variance scaling
X = np.sqrt(fac) * (Xc / std)
X = math.sqrt(fac) * (Xc / std)
# SVD of centered (within-)scaled data
U, S, Vt = linalg.svd(X, full_matrices=False, check_finite=False)
U, S, Vt = svd(X, full_matrices=False)

rank = np.sum(S > self.tol)
rank = xp.sum(xp.astype(S > self.tol, xp.int32))
# Scaling of within covariance is: V' 1/S
scalings = (Vt[:rank] / std).T / S[:rank]
scalings = (Vt[:rank, :] / std).T / S[:rank]

# 3) Between variance scaling
# Scale weighted centers
X = np.dot(
(
(np.sqrt((n_samples * self.priors_) * fac))
* (self.means_ - self.xbar_).T
).T,
scalings,
)
X = (
(xp.sqrt((n_samples * self.priors_) * fac)) * (self.means_ - self.xbar_).T
).T @ scalings
# Centers are living in a space with n_classes-1 dim (maximum)
# Use SVD to find projection in the space spanned by the
# (n_classes) centers
_, S, Vt = linalg.svd(X, full_matrices=0, check_finite=False)
_, S, Vt = svd(X, full_matrices=False)

if self._max_components == 0:
self.explained_variance_ratio_ = np.empty((0,), dtype=S.dtype)
self.explained_variance_ratio_ = xp.empty((0,), dtype=S.dtype)
else:
self.explained_variance_ratio_ = (S ** 2 / np.sum(S ** 2))[
self.explained_variance_ratio_ = (S ** 2 / xp.sum(S ** 2))[
: self._max_components
]

rank = np.sum(S > self.tol * S[0])
self.scalings_ = np.dot(scalings, Vt.T[:, :rank])
coef = np.dot(self.means_ - self.xbar_, self.scalings_)
self.intercept_ = -0.5 * np.sum(coef ** 2, axis=1) + np.log(self.priors_)
self.coef_ = np.dot(coef, self.scalings_.T)
self.intercept_ -= np.dot(self.xbar_, self.coef_.T)
rank = xp.sum(xp.astype(S > self.tol * S[0], xp.int32))
self.scalings_ = scalings @ Vt.T[:, :rank]
coef = (self.means_ - self.xbar_) @ self.scalings_
self.intercept_ = -0.5 * xp.sum(coef ** 2, axis=1) + xp.log(self.priors_)
self.coef_ = coef @ self.scalings_.T
self.intercept_ -= self.xbar_ @ self.coef_.T
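One reason for the xp.astype(...) before xp.sum(...) in the rank computations above: the strict numpy.array_api namespace only allows reductions over numeric dtypes, so a boolean mask has to be cast before it can be counted. A tiny illustration (assuming numpy.array_api is importable):

import numpy.array_api as xp

S = xp.asarray([3.0, 1.0, 1e-12])
rank = int(xp.sum(xp.astype(S > 1e-6, xp.int32)))  # counts entries above tol -> 2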

def fit(self, X, y):
"""Fit the Linear Discriminant Analysis model.
@@ -547,33 +559,36 @@ def fit(self, X, y):
self : object
Fitted estimator.
"""
xp, _ = get_namespace(X)
X, y = self._validate_data(
X, y, ensure_min_samples=2, dtype=[np.float64, np.float32]
X, y, ensure_min_samples=2, dtype=[xp.float64, xp.float32]
)
self.classes_ = unique_labels(y)
n_samples, _ = X.shape
n_classes = len(self.classes_)
n_classes = self.classes_.shape[0]

if n_samples == n_classes:
raise ValueError(
"The number of samples must be more than the number of classes."
)

if self.priors is None: # estimate priors from sample
_, y_t = np.unique(y, return_inverse=True) # non-negative ints
self.priors_ = np.bincount(y_t) / float(len(y))
_, cnts = xp.unique_counts(y) # non-negative ints
self.priors_ = xp.astype(cnts, xp.float64) / float(y.shape[0])
else:
self.priors_ = np.asarray(self.priors)
self.priors_ = xp.asarray(self.priors)

if (self.priors_ < 0).any():
if xp.any(self.priors_ < 0):
raise ValueError("priors must be non-negative")
if not np.isclose(self.priors_.sum(), 1.0):
warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning)
self.priors_ = self.priors_ / self.priors_.sum()

# TODO: implement isclose in wrapper?
Owner Author:
Need to implement our own isclose.

# if not np.isclose(np.sum(self.priors_), 1.0):
# warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning)
# self.priors_ = self.priors_ / self.priors_.sum()
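A minimal sketch of the helper hinted at in the comment above; the name _isclose_scalar is hypothetical and the tolerances simply mirror numpy's defaults:

def _isclose_scalar(xp, a, b, rtol=1e-05, atol=1e-08):
    # namespace-agnostic |a - b| <= atol + rtol * |b| for 0-d arrays
    return bool(xp.abs(a - b) <= atol + rtol * xp.abs(b))

# the commented-out block could then become, for example:
# if not _isclose_scalar(xp, xp.sum(self.priors_), xp.asarray(1.0)):
#     warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning)
#     self.priors_ = self.priors_ / xp.sum(self.priors_)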

# Maximum number of components no matter what n_components is
# specified:
max_components = min(len(self.classes_) - 1, X.shape[1])
max_components = min(n_classes - 1, X.shape[1])

if self.n_components is None:
self._max_components = max_components
@@ -614,12 +629,12 @@ def fit(self, X, y):
"'lsqr', and 'eigen').".format(self.solver)
)
if self.classes_.size == 2: # treat binary case as a special case
self.coef_ = np.array(
self.coef_[1, :] - self.coef_[0, :], ndmin=2, dtype=X.dtype
)
self.intercept_ = np.array(
self.intercept_[1] - self.intercept_[0], ndmin=1, dtype=X.dtype
coef_ = xp.asarray(self.coef_[1, :] - self.coef_[0, :], dtype=X.dtype)
self.coef_ = xp.reshape(coef_, (1, -1))
intercept_ = xp.asarray(
self.intercept_[1] - self.intercept_[0], dtype=X.dtype
)
self.intercept_ = xp.reshape(intercept_, 1)
self._n_features_out = self._max_components
return self

12 changes: 9 additions & 3 deletions sklearn/linear_model/_base.py
@@ -35,6 +35,7 @@
from ..utils.extmath import _incremental_mean_and_var
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
from ..utils._array_api import get_namespace
from ..utils._seq_dataset import ArrayDataset32, CSRDataset32
from ..utils._seq_dataset import ArrayDataset64, CSRDataset64
from ..utils.validation import check_is_fitted, _check_sample_weight
@@ -405,8 +406,9 @@ def decision_function(self, X):
check_is_fitted(self)

X = self._validate_data(X, accept_sparse="csr", reset=False)
xp, _ = get_namespace(X)
scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
return scores.ravel() if scores.shape[1] == 1 else scores
return xp.reshape(scores, -1) if scores.shape[1] == 1 else scores

def predict(self, X):
"""
@@ -422,12 +424,16 @@ def predict(self, X):
y_pred : ndarray of shape (n_samples,)
Vector containing the class labels for each sample.
"""
xp, _ = get_namespace(X)
scores = self.decision_function(X)
if len(scores.shape) == 1:
indices = (scores > 0).astype(int)
indices = xp.astype(scores > 0, int)
else:
indices = scores.argmax(axis=1)
return self.classes_[indices]
# Should really use `np.take`
# return np.take(self.classes_, indices, axis=0)
# assuming classes are [0, 1, 2, ....]
return indices
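A possible follow-up sketch for the np.take comment above (not part of this diff): take() only exists in newer array API namespaces, so the attribute check and the fallback to plain integer indexing are assumptions rather than what the PR does.

if hasattr(xp, "take"):
    # look the labels up in classes_ while staying in the input's namespace
    return xp.take(self.classes_, indices, axis=0)
# fallback: integer-array indexing, which works for numpy inputs
return self.classes_[indices]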

def _predict_proba_lr(self, X):
"""Probability estimation for OvR logistic regression.
29 changes: 29 additions & 0 deletions sklearn/tests/test_discriminant_analysis.py
@@ -4,6 +4,8 @@

from scipy import linalg

from sklearn.base import clone
from sklearn._config import config_context
from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
@@ -667,3 +669,30 @@ def test_get_feature_names_out():
dtype=object,
)
assert_array_equal(names_out, expected_names_out)


def test_lda_array_api():
"""Check that the array_api Array gives the same results as ndarrays."""
pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API")
xp = pytest.importorskip("numpy.array_api")

X_xp = xp.asarray(X)
y_xp = xp.asarray(y)

lda = LinearDiscriminantAnalysis()
lda.fit(X, y)

lda_xp = clone(lda)
with config_context(array_api_dispatch=True):
lda_xp.fit(X_xp, y_xp)

lda_attributes_array = {
key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
}
for key in lda_attributes_array:
lda_xp_param = getattr(lda_xp, key)
assert hasattr(lda_xp_param, "__array_namespace__")

assert_allclose(
lda_attributes_array[key], lda_xp_param, err_msg=f"{key} not the same"
)