
Use array_api in Gaussian Mixture Models #99

Open · wants to merge 15 commits into main
Changes from 3 commits
17 changes: 15 additions & 2 deletions sklearn/_config.py
@@ -9,6 +9,7 @@
"working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
"print_changed_only": True,
"display": "text",
"array_api_dispatch": False,
}
_threadlocal = threading.local()

@@ -40,7 +41,11 @@ def get_config():


def set_config(
assume_finite=None, working_memory=None, print_changed_only=None, display=None
assume_finite=None,
working_memory=None,
print_changed_only=None,
display=None,
array_api_dispatch=None,
):
"""Set global scikit-learn configuration

@@ -95,11 +100,18 @@ def set_config(
local_config["print_changed_only"] = print_changed_only
if display is not None:
local_config["display"] = display
if array_api_dispatch is not None:
local_config["array_api_dispatch"] = array_api_dispatch


@contextmanager
def config_context(
*, assume_finite=None, working_memory=None, print_changed_only=None, display=None
*,
assume_finite=None,
working_memory=None,
print_changed_only=None,
display=None,
array_api_dispatch=None,
):
"""Context manager for global scikit-learn configuration.

@@ -171,6 +183,7 @@ def config_context(
working_memory=working_memory,
print_changed_only=print_changed_only,
display=display,
array_api_dispatch=array_api_dispatch,
)

try:
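For context, a minimal sketch of how the new option is meant to be used (illustrative only, assuming this branch; init_params="random" avoids the KMeans-based initialization, which still requires NumPy):

import numpy.array_api as xp

from sklearn import config_context
from sklearn.mixture import GaussianMixture

X = xp.asarray([[1.0, 2.0], [1.2, 1.9], [5.0, 6.0], [5.1, 6.2]])

# Dispatch to the input's Array API namespace only inside this block;
# the default (array_api_dispatch=False) keeps the NumPy/SciPy code paths.
with config_context(array_api_dispatch=True):
    gm = GaussianMixture(n_components=2, init_params="random").fit(X)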
20 changes: 14 additions & 6 deletions sklearn/mixture/_base.py
@@ -9,13 +9,13 @@
from time import time

import numpy as np
from scipy.special import logsumexp

from .. import cluster
from ..base import BaseEstimator
from ..base import DensityMixin
from ..exceptions import ConvergenceWarning
from ..utils import check_random_state
from ..utils._array_api import get_namespace, logsumexp
from ..utils.validation import check_is_fitted


@@ -136,6 +136,7 @@ def _initialize_parameters(self, X, random_state):
used for the method chosen to initialize the parameters.
"""
n_samples, _ = X.shape
np, _ = get_namespace(X)

if self.init_params == "kmeans":
resp = np.zeros((n_samples, self.n_components))
@@ -149,7 +150,8 @@
resp[np.arange(n_samples), label] = 1
elif self.init_params == "random":
resp = random_state.rand(n_samples, self.n_components)
resp /= resp.sum(axis=1)[:, np.newaxis]
resp = np.asarray(resp)
resp /= np.reshape(np.sum(resp, axis=1), (-1, 1))
else:
raise ValueError(
"Unimplemented initialization method '%s'" % self.init_params
@@ -225,6 +227,7 @@ def fit_predict(self, X, y=None):
labels : array, shape (n_samples,)
Component labels.
"""
np, _ = get_namespace(X)
X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_min_samples=2)
if X.shape[0] < self.n_components:
raise ValueError(
@@ -291,7 +294,7 @@ def fit_predict(self, X, y=None):
# for any value of max_iter and tol (and any random_state).
_, log_resp = self._e_step(X)

return log_resp.argmax(axis=1)
return np.argmax(log_resp, axis=1)

def _e_step(self, X):
"""E step.
@@ -309,6 +312,7 @@ def _e_step(self, X):
Logarithm of the posterior probabilities (or responsibilities) of
the point of each sample in X.
"""
np, _ = get_namespace(X)
log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
return np.mean(log_prob_norm), log_resp

@@ -527,11 +531,15 @@ def _estimate_log_prob_resp(self, X):
log_responsibilities : array, shape (n_samples, n_components)
logarithm of the responsibilities
"""
np, is_array_api = get_namespace(X)
weighted_log_prob = self._estimate_weighted_log_prob(X)
log_prob_norm = logsumexp(weighted_log_prob, axis=1)
with np.errstate(under="ignore"):
# ignore underflow
log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
if is_array_api:
log_resp = weighted_log_prob - np.reshape(log_prob_norm, (-1, 1))
else:
with np.errstate(under="ignore"):
# ignore underflow
log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
Owner Author: Workaround for no errstate in array_api.

Reviewer: I don't think floating point warnings will ever be portable. They're not even consistent in NumPy, and they are a constant source of pain. Maybe we need (later) a utility errstate context manager that is a no-op or delegates to a library-specific implementation, to remove the if-else here.

return log_prob_norm, log_resp

def _print_verbose_msg_init_beg(self, n_init):
55 changes: 35 additions & 20 deletions sklearn/mixture/_gaussian_mixture.py
@@ -5,12 +5,16 @@
# License: BSD 3 clause

import numpy as np
from math import log
from functools import partial

from scipy import linalg
import scipy

from ._base import BaseMixture, _check_shape
from ..utils import check_array
from ..utils.extmath import row_norms
from ..utils._array_api import get_namespace


###############################################################################
@@ -171,12 +175,13 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar):
covariances : array, shape (n_components, n_features, n_features)
The covariance matrix of the current components.
"""
np, _ = get_namespace(resp, X, nk)
n_components, n_features = means.shape
covariances = np.empty((n_components, n_features, n_features))
for k in range(n_components):
diff = X - means[k]
covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
covariances[k].flat[:: n_features + 1] += reg_covar
diff = X - means[k, :]
covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k]
np.reshape(covariances[k, :, :], (-1,))[:: n_features + 1] += reg_covar
Owner Author: Workaround for no flat in array_api.

Reviewer: Hard to read either way; I don't think this is a portable solution. covariances is not a 1-D array; wouldn't it be better to reshape the right-hand side here to match the shape of the left-hand side (or broadcast correctly)?

Owner Author:
In this case, the right-hand side is a float. The line is adding a float to the diagonal, something like this:

import numpy as np

covariances = np.ones((3, 3))
reg_covar = 4.0

np.fill_diagonal(covariances, covariances.diagonal() + reg_covar)

With only array_api, I see two options:

  1. The current one with reshaping and slicing.
  2. Create the diagonal array on the right-hand side (which would allocate more memory compared to option 1):
import numpy.array_api as xp

covariances = xp.ones((3, 3))
reg_covar = 4.0
covariances += reg_covar * xp.eye(3)

I could be missing another way of "adding a scalar to the diagonal" using only array_api.

Reviewer: There's linalg.diagonal, which returns the matrix diagonal, but the standard doesn't specify whether it's a view or a copy. The numpy.array_api implementation wraps np.diagonal, which returns a non-writable view:

In [1]: import numpy.array_api as xp
In [2]: covariances = xp.ones((3, 3))
In [3]: diag = xp.linalg.diagonal(covariances)
In [4]: reg_covar = 4.0
In [5]: diag += reg_covar
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-dc9034854296> in <module>
----> 1 diag += reg_covar

~/.conda/envs/cupy-scipy/lib/python3.9/site-packages/numpy/array_api/_array_object.py in __iadd__(self, other)
    739         if other is NotImplemented:
    740             return other
--> 741         self._array.__iadd__(other._array)
    742         return self
    743

ValueError: output array is read-only

Reviewer: You can never assume something is a view, because multiple libraries don't have such a concept, and the ones that do are inconsistent with each other. tl;dr: relying on views is always a bug.

Reviewer:
Here's some additional info on why relying on views is not portable:
https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html

This is again an example where the desire to simultaneously support efficient NumPy and portable Array API code leads to two code paths.

Reviewer: This discussion about JAX not having fill_diagonal is probably relevant: jax-ml/jax#2680. The portable solutions are (a) using eye, or (b) adding a for-loop of scalar in-place ops. It wouldn't surprise me if the for-loop is fast compared to the operation above, so it'd be fine and more readable. You could also special-case numpy.ndarray if desired.

Reviewer: Athan pointed out that put will solve this, and should get into the standard soonish.

return covariances


@@ -286,8 +291,9 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type):
The covariance matrix of the current components.
The shape depends of the covariance_type.
"""
nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
means = np.dot(resp.T, X) / nk[:, np.newaxis]
np, _ = get_namespace(X, resp)
nk = np.sum(resp, axis=0) + 10 * np.finfo(resp.dtype).eps
means = resp.T @ X / np.reshape(nk, (-1, 1))
Owner Author: Using @ instead of dot, which I find nicer.

covariances = {
"full": _estimate_gaussian_covariances_full,
"tied": _estimate_gaussian_covariances_tied,
@@ -321,27 +327,31 @@ def _compute_precision_cholesky(covariances, covariance_type):
"or collapsed samples). Try to decrease the number of components, "
"or increase reg_covar."
)
np, is_array_api = get_namespace(covariances)
if is_array_api:
cholesky = np.linalg.cholesky
solve = np.linalg.solve
else:
cholesky = partial(scipy.linalg.cholesky, lower=True)
solve = partial(scipy.linalg.solve_triangular, lower=True)
Owner Author: Need to use np.linalg.solve because array_api does not have solve_triangular.

Reviewer: Could be added in the future perhaps? @IvanYashchuk WDYT?

Reviewer: solve_triangular could be added to the Array API spec in the future. It wouldn't be terribly difficult to add it to numpy.array_api and cupy.array_api. This functionality is part of level-3 BLAS, and it's available in PyTorch (torch.linalg.solve_triangular), in CuPy (cupyx.scipy.linalg.solve_triangular), and I imagine in other libraries as well.

Reviewer:
This is also the territory where the type dispatcher for scipy.linalg.solve_triangular would be handy. CuPy's SciPy should also start working with cupy.array_api inputs (it doesn't currently) and SciPy should work with numpy.array_api inputs.

Replacing scipy with scipy_dispatch (installed with python -m pip install git+https://github.com/IvanYashchuk/scipy-singledispatch.git@master) would give a working prototype:

In [1]: from scipy_dispatch import linalg
   ...: import cupy

In [2]: import cupy.array_api as xp
<ipython-input-2-23333abb466b>:1: UserWarning: The numpy.array_api submodule is still experimental. See NEP 47.
  import cupy.array_api as xp

In [3]: a = cupy.array_api.asarray(cupy.random.random((3, 3)))

In [4]: b = cupy.array_api.asarray(cupy.random.random((3,)))

In [5]: import scipy_dispatch.cupy_backend.linalg # activate dispatching
C:\Users\Ivan\dev\scipy-singledispatch\scipy_dispatch\cupy_backend\linalg.py:12: UserWarning: The numpy.array_api submodule is still experimental. See NEP 47.
  import numpy.array_api

In [6]: type(linalg.solve_triangular(a, b))
Out[6]: cupy.array_api._array_object.Array

Owner Author (thomasjpfan, Jan 7, 2022):
At a glance, scipy_dispatch looks to be a simple wrapper around singledispatch and would cover a majority of what users want. If the input is a CuPy array, use CuPy operators. If a user wants to use an Intel operator on their NumPy array, they can register a single dispatch on np.ndarray.

uarray adds multiple dispatch, but I do not know if there is a need for it. Is it enough to dispatch based on the first argument and then make sure that all other arguments are compatible?

I'm guessing this has been discussed at length somewhere. Is there a document that explains why multiple dispatch and uarray are required?

Reviewer:
> If a user wants to use an Intel operator on their NumPy array, they can register a single dispatch on np.ndarray.

Right, regular SciPy functions can be registered to work for typing.Any or object type. And an alternative implementation specific to NumPy could be registered using np.ndarray type.

> Is it enough to dispatch based on the first argument and then make sure that all other arguments are compatible?

It might be enough in many cases.

> I'm guessing this has been discussed at length somewhere. Is there a document that explains why multiple dispatch and uarray are required?

The requirements are being discussed at https://discuss.scientific-python.org/t/requirements-and-discussion-of-a-type-dispatcher-for-the-ecosystem/157/34. A few reasons for uarray are listed in this post (https://discuss.scientific-python.org/t/a-proposed-design-for-supporting-multiple-array-types-across-scipy-scikit-learn-scikit-image-and-beyond/131). The main ones are: 1. the ability to specify locally, using context managers, which backend to use; 2. the ability to register a different backend for the same array type (most often for np.ndarray); 3. it's already used in scipy.fft, and we have a PR for scipy.ndimage.

It may be the case that we don't actually need everything uarray provides, and uarray looks scary enough to several people on both the library and user sides that it might be worth implementing the dispatching with a simpler option first (that is, singledispatch, or Plum, which in my tests adds less overhead than singledispatch).

Reviewer:

Full disclosure: I designed and implemented linalg.solve_triangular in PyTorch.

Just for reference, torch.linalg.solve_triangular diverges from scipy.linalg.solve_triangular in the following ways:

  • PyTorch does not expose BLAS' trans parameter (which is a bit confusing when used with lower), but rather handles this internally by looking at the strides of the tensor.
  • PyTorch has an upper keyword-only parameter without a default. SciPy has lower=False. We went with upper to be consistent with linalg.cholesky.
  • PyTorch implements a left=True parameter that, when False, solves XA = B.
  • SciPy's unit_diagonal parameter is called unitriangular in PyTorch.
  • SciPy has a number of extra parameters, namely overwrite_b=False, debug=None, check_finite=True.

I like PyTorch's behaviour better when it comes to the first three points. I don't care about the naming of the parameters though.

Reviewer:

> It may be the case that we don't actually need everything uarray provides, and uarray looks scary enough to several people on both the library and user sides that it might be worth implementing the dispatching with a simpler option first (that is, singledispatch, or Plum, which in my tests adds less overhead than singledispatch).

There are probably more reasons; see for example the discussion in scipy/scipy#14356 (review) and the following comments about conversion and supporting multiple types for the same parameter (union of ndarray and dtype, list/scalar/str inputs, etc.). That PR also adds some developer docs with discussion on this topic.

Unless it can be determined that either such things do not need to be supported in the future or there is a clean upgrade path later on, I don't think there's a point in using singledispatch.
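For concreteness, a minimal sketch of the singledispatch option discussed above (hypothetical, not part of any existing package):

from functools import singledispatch

import scipy.linalg

@singledispatch
def solve_triangular(a, b, lower=False):
    # Default implementation: anything NumPy-compatible goes to SciPy.
    return scipy.linalg.solve_triangular(a, b, lower=lower)

# A GPU backend would register an implementation for its own array type,
# e.g. for CuPy:
#
#   import cupy
#   import cupyx.scipy.linalg
#
#   @solve_triangular.register(cupy.ndarray)
#   def _(a, b, lower=False):
#       return cupyx.scipy.linalg.solve_triangular(a, b, lower=lower)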


if covariance_type == "full":
n_components, n_features, _ = covariances.shape
precisions_chol = np.empty((n_components, n_features, n_features))
for k, covariance in enumerate(covariances):
for k in range(n_components):
try:
cov_chol = linalg.cholesky(covariance, lower=True)
cov_chol = cholesky(covariances[k, :, :])
except linalg.LinAlgError:
raise ValueError(estimate_precision_error_message)
precisions_chol[k] = linalg.solve_triangular(
cov_chol, np.eye(n_features), lower=True
).T
precisions_chol[k, :, :] = solve(cov_chol, np.eye(n_features)).T

elif covariance_type == "tied":
_, n_features = covariances.shape
try:
cov_chol = linalg.cholesky(covariances, lower=True)
cov_chol = cholesky(covariances)
except linalg.LinAlgError:
raise ValueError(estimate_precision_error_message)
precisions_chol = linalg.solve_triangular(
cov_chol, np.eye(n_features), lower=True
).T
precisions_chol = solve(cov_chol, np.eye(n_features)).T
else:
if np.any(np.less_equal(covariances, 0.0)):
raise ValueError(estimate_precision_error_message)
@@ -373,11 +383,11 @@ def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features):
log_det_precision_chol : array-like of shape (n_components,)
The determinant of the precision matrix for each component.
"""
np, _ = get_namespace(matrix_chol)
if covariance_type == "full":
n_components, _, _ = matrix_chol.shape
log_det_chol = np.sum(
np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1
)
matrix_chol_reshape = np.reshape(matrix_chol, (n_components, -1))
log_det_chol = np.sum(np.log(matrix_chol_reshape[:, :: n_features + 1]), axis=1)

elif covariance_type == "tied":
log_det_chol = np.sum(np.log(np.diag(matrix_chol)))
Expand Down Expand Up @@ -413,6 +423,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):
-------
log_prob : array, shape (n_samples, n_components)
"""
np, _ = get_namespace(X, means, precisions_chol)
n_samples, n_features = X.shape
n_components, _ = means.shape
# The determinant of the precision matrix from the Cholesky decomposition
@@ -423,8 +434,10 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):

if covariance_type == "full":
log_prob = np.empty((n_samples, n_components))
for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
for k in range(n_components):
Owner Author: Cannot iterate over array_api arrays, so we must index along the axis explicitly. (Which I prefer.)

mu = means[k, :]
prec_chol = precisions_chol[k, :, :]
y = X @ prec_chol - mu @ prec_chol
log_prob[:, k] = np.sum(np.square(y), axis=1)

elif covariance_type == "tied":
@@ -450,7 +463,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):
)
# Since we are using the precision of the Cholesky decomposition,
# `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol`
return -0.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det
return -0.5 * (n_features * log(2 * np.pi) + log_prob) + log_det


class GaussianMixture(BaseMixture):
@@ -742,6 +755,7 @@ def _m_step(self, X, log_resp):
the point of each sample in X.
"""
n_samples, _ = X.shape
np, _ = get_namespace(X, log_resp)
self.weights_, self.means_, self.covariances_ = _estimate_gaussian_parameters(
X, np.exp(log_resp), self.reg_covar, self.covariance_type
)
@@ -756,6 +770,7 @@ def _estimate_log_prob(self, X):
)

def _estimate_log_weights(self):
np, _ = get_namespace(self.weights_)
return np.log(self.weights_)

def _compute_lower_bound(self, _, log_prob_norm):
79 changes: 79 additions & 0 deletions sklearn/utils/_array_api.py
@@ -0,0 +1,79 @@
"""Tools to support array_api."""
import numpy as np
from scipy.special import logsumexp as sp_logsumexp
from .._config import get_config


def get_namespace(*xs):
# `xs` contains one or more arrays, or possibly Python scalars (accepting
# those is a matter of taste, but doesn't seem unreasonable).
# Returns a tuple: (array_namespace, is_array_api)

if not get_config()["array_api_dispatch"]:
return np, False
Owner Author: Global configuration option to control the dispatching.


namespaces = {
x.__array_namespace__() if hasattr(x, "__array_namespace__") else None
for x in xs
if not isinstance(x, (bool, int, float, complex))
}

if not namespaces:
# one could special-case np.ndarray above or use np.asarray here if
# older numpy versions need to be supported.
raise ValueError("Unrecognized array input")

if len(namespaces) != 1:
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")

(xp,) = namespaces
if xp is None:
# Use numpy as default
return np, False

return xp, True
Owner Author: Returns a boolean so the caller can easily tell if we are using the array_api namespace.
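A quick illustration of the return contract (illustrative, assuming dispatch is enabled globally):

import numpy as np
import numpy.array_api as xp

from sklearn import set_config
from sklearn.utils._array_api import get_namespace

set_config(array_api_dispatch=True)

# np.ndarray has no __array_namespace__, so NumPy is used as the default.
ns, is_array_api = get_namespace(np.ones(3))   # (numpy, False)

# numpy.array_api arrays advertise their namespace.
ns, is_array_api = get_namespace(xp.ones(3))   # (numpy.array_api, True)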



def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
np, is_array_api = get_namespace(a)

# Use SciPy if a is an ndarray
if not is_array_api:
return sp_logsumexp(
a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign
)
Owner Author: Hopefully this is not needed in the future.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This we should fix in the standard I think. It should have logsumexp.


if b is not None:
a, b = np.broadcast_arrays(a, b)
if np.any(b == 0):
a = a + 0.0 # promote to at least float
a[b == 0] = -np.inf

a_max = np.max(a, axis=axis, keepdims=True)

if a_max.ndim > 0:
a_max[~np.isfinite(a_max)] = 0
elif not np.isfinite(a_max):
a_max = 0

if b is not None:
b = np.asarray(b)
tmp = b * np.exp(a - a_max)
else:
tmp = np.exp(a - a_max)

# suppress warnings about log of zero
s = np.sum(tmp, axis=axis, keepdims=keepdims)
if return_sign:
sgn = np.sign(s)
s *= sgn # /= makes more sense but we need zero -> zero
out = np.log(s)

if not keepdims:
a_max = np.squeeze(a_max, axis=axis)
out += a_max

if return_sign:
return out, sgn
else:
return out
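A small parity check between the two code paths (untested sketch; ._array is the private NumPy buffer of numpy.array_api arrays, used here only to compare results):

import numpy as np
import numpy.array_api as xp
from scipy.special import logsumexp as sp_logsumexp

from sklearn import set_config
from sklearn.utils._array_api import logsumexp

set_config(array_api_dispatch=True)

a = np.array([[0.0, 1.0], [2.0, 3.0]])

expected = sp_logsumexp(a, axis=1)        # SciPy path (np.ndarray input)
got = logsumexp(xp.asarray(a), axis=1)    # Array API path

assert np.allclose(expected, got._array)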