Use array_api in Gaussian Mixture Models #99
Changes from 3 commits
6093daf
ab1ab7a
bd7a4e6
6124b00
57167b4
1f8e25a
46b5be4
19abb1a
7983c55
aeb1488
f574e76
b518f5b
fb12c03
bf5975f
3c08e1d
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,12 +5,16 @@ | |
# License: BSD 3 clause | ||
|
||
import numpy as np | ||
from math import log | ||
from functools import partial | ||
|
||
from scipy import linalg | ||
import scipy | ||
|
||
from ._base import BaseMixture, _check_shape | ||
from ..utils import check_array | ||
from ..utils.extmath import row_norms | ||
from ..utils._array_api import get_namespace | ||
|
||
|
||
############################################################################### | ||
|
@@ -171,12 +175,13 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): | |
covariances : array, shape (n_components, n_features, n_features) | ||
The covariance matrix of the current components. | ||
""" | ||
np, _ = get_namespace(resp, X, nk) | ||
n_components, n_features = means.shape | ||
covariances = np.empty((n_components, n_features, n_features)) | ||
for k in range(n_components): | ||
diff = X - means[k] | ||
covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k] | ||
covariances[k].flat[:: n_features + 1] += reg_covar | ||
diff = X - means[k, :] | ||
covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k] | ||
np.reshape(covariances[k, :, :], (-1,))[:: n_features + 1] += reg_covar | ||
Workaround for no

Hard to read either way; I don't think this is a portable solution.

In this case, the right-hand side is a float. The line is adding a float to the diagonal, something like this:
import numpy as np
covariances = np.ones((3, 3))
reg_covar = 4.0
np.fill_diagonal(covariances, covariances.diagonal() + reg_covar)
With only
import numpy.array_api as xp
covariances = xp.ones((3, 3))
reg_covar = 4.0
covariances += reg_covar * xp.eye(3)
I could be missing another way of "adding a scalar to the diagonal" using only

There's
In [1]: import numpy.array_api as xp
In [2]: covariances = xp.ones((3, 3))
In [3]: diag = xp.linalg.diagonal(covariances)
In [4]: reg_covar = 4.0
In [5]: diag += reg_covar
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-5-dc9034854296> in <module>
----> 1 diag += reg_covar
~/.conda/envs/cupy-scipy/lib/python3.9/site-packages/numpy/array_api/_array_object.py in __iadd__(self, other)
739 if other is NotImplemented:
740 return other
--> 741 self._array.__iadd__(other._array)
742 return self
743
ValueError: output array is read-only

You can never assume something is a view because multiple libraries don't have such a concept. And the ones that do are inconsistent with each other. tl;dr relying on views is always a bug.

Here's some additional info on why relying on views is not portable: This is again an example where the desire to support simultaneously efficient NumPy and portable Array API codes leads to two code paths.

This discussion about JAX not having

Athan pointed out that
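The thread above concludes that writing through a view of the diagonal is not portable. A minimal sketch of the out-of-place alternative mentioned in the discussion (adding `reg_covar` times an identity matrix) could look like the following. `add_to_diagonal` is a hypothetical helper name, not something from the PR, and plain NumPy stands in for the array namespace `xp`:

```python
import numpy as np


def add_to_diagonal(xp, a, value):
    """Return a new array equal to `a` with `value` added to its diagonal.

    Hypothetical helper: it only uses `eye`, multiplication and addition,
    so it does not depend on views or in-place writes.
    """
    n = a.shape[-1]
    return a + value * xp.eye(n, dtype=a.dtype)


covariances = np.ones((3, 3))
# NumPy is passed as the namespace here purely for illustration.
regularized = add_to_diagonal(np, covariances, 4.0)
```

The trade-off is an extra temporary of the same shape as the covariance matrix, which is the cost of not assuming view semantics.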
||
return covariances | ||
|
||
|
||
|
@@ -286,8 +291,9 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): | |
The covariance matrix of the current components. | ||
The shape depends of the covariance_type. | ||
""" | ||
nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps | ||
means = np.dot(resp.T, X) / nk[:, np.newaxis] | ||
np, _ = get_namespace(X, resp) | ||
nk = np.sum(resp, axis=0) + 10 * np.finfo(resp.dtype).eps | ||
means = resp.T @ X / np.reshape(nk, (-1, 1)) | ||
Using
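The hunk above replaces `nk[:, np.newaxis]` with `np.reshape(nk, (-1, 1))` and `np.dot` with the `@` operator. A small check with plain NumPy and made-up random inputs (the shapes are chosen only for illustration) suggests the two spellings give the same means:

```python
import numpy as np

rng = np.random.default_rng(0)
resp = rng.random((5, 2))  # illustrative responsibilities: 5 samples, 2 components
X = rng.random((5, 3))     # illustrative data: 5 samples, 3 features

nk = np.sum(resp, axis=0) + 10 * np.finfo(resp.dtype).eps

# Old spelling from the diff: indexing with np.newaxis and np.dot.
means_old = np.dot(resp.T, X) / nk[:, np.newaxis]
# New spelling from the diff: explicit reshape and the @ operator.
means_new = resp.T @ X / np.reshape(nk, (-1, 1))
```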
||
covariances = { | ||
"full": _estimate_gaussian_covariances_full, | ||
"tied": _estimate_gaussian_covariances_tied, | ||
|
@@ -321,27 +327,31 @@ def _compute_precision_cholesky(covariances, covariance_type): | |
"or collapsed samples). Try to decrease the number of components, " | ||
"or increase reg_covar." | ||
) | ||
np, is_array_api = get_namespace(covariances) | ||
if is_array_api: | ||
cholesky = np.linalg.cholesky | ||
solve = np.linalg.solve | ||
else: | ||
cholesky = partial(scipy.linalg.cholesky, lower=True) | ||
solve = partial(scipy.linalg.solve_triangular, lower=True) | ||
Need to use

Could be added in the future perhaps? @IvanYashchuk WDYT?

This is also the territory where the type dispatcher for Replacing
In [1]: from scipy_dispatch import linalg
...: import cupy
In [2]: import cupy.array_api as xp
<ipython-input-2-23333abb466b>:1: UserWarning: The numpy.array_api submodule is still experimental. See NEP 47.
import cupy.array_api as xp
In [3]: a = cupy.array_api.asarray(cupy.random.random((3, 3)))
In [4]: b = cupy.array_api.asarray(cupy.random.random((3,)))
In [5]: import scipy_dispatch.cupy_backend.linalg # activate dispatching
C:\Users\Ivan\dev\scipy-singledispatch\scipy_dispatch\cupy_backend\linalg.py:12: UserWarning: The numpy.array_api submodule is still experimental. See NEP 47.
import numpy.array_api
In [6]: type(linalg.solve_triangular(a, b))
Out[6]: cupy.array_api._array_object.Array

At a glance,
I'm guessing this has been discussed at length somewhere. Is there a document that explains why multiple dispatch and

Right, regular SciPy functions can be registered to work for
It might be enough in many cases.
The requirements are being discussed at https://discuss.scientific-python.org/t/requirements-and-discussion-of-a-type-dispatcher-for-the-ecosystem/157/34. A few reasons for It may be the case that we actually do not need everything that

Full disclosure: I designed and implemented Just for reference,
I like PyTorch's behaviour better when it comes to the first three points. I don't care about the naming of the parameters though.

There are probably more reasons, see for example the discussion in scipy/scipy#14356 (review) and following comments about conversion and supporting multiple types for the same parameter (union of ndarray and dtype, list/scalar/str inputs, etc.). That PR also adds some developer docs with discussion on this topic. Unless it can be determined that either such things do not need to be supported in the future or there is a clean upgrade path later on, I don't think there's a point in using
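For context on the two code paths in this hunk: a general `linalg.solve` against the identity yields the same precision Cholesky factor as a triangular solve, only without exploiting the triangular structure. A minimal sketch with plain NumPy standing in for an Array API namespace; `precision_cholesky` is an illustrative name, not a function from the PR:

```python
import numpy as np


def precision_cholesky(xp, covariance):
    """Sketch of the `is_array_api` branch for a single covariance matrix."""
    n_features = covariance.shape[0]
    # Lower-triangular factor L with L @ L.T == covariance.
    cov_chol = xp.linalg.cholesky(covariance)
    # General solve of L @ Z = I, then transpose: the result is inv(L).T.
    return xp.linalg.solve(cov_chol, xp.eye(n_features)).T


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
covariance = A @ A.T + 4.0 * np.eye(4)  # symmetric positive definite by construction
prec_chol = precision_cholesky(np, covariance)
```

The product `prec_chol @ prec_chol.T` should then recover the inverse covariance, which is the property the mixture code relies on.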
||
|
||
if covariance_type == "full": | ||
n_components, n_features, _ = covariances.shape | ||
precisions_chol = np.empty((n_components, n_features, n_features)) | ||
for k, covariance in enumerate(covariances): | ||
for k in range(n_components): | ||
try: | ||
cov_chol = linalg.cholesky(covariance, lower=True) | ||
cov_chol = cholesky(covariances[k, :, :]) | ||
except linalg.LinAlgError: | ||
raise ValueError(estimate_precision_error_message) | ||
precisions_chol[k] = linalg.solve_triangular( | ||
cov_chol, np.eye(n_features), lower=True | ||
).T | ||
precisions_chol[k, :, :] = solve(cov_chol, np.eye(n_features)).T | ||
|
||
elif covariance_type == "tied": | ||
_, n_features = covariances.shape | ||
try: | ||
cov_chol = linalg.cholesky(covariances, lower=True) | ||
cov_chol = cholesky(covariances) | ||
except linalg.LinAlgError: | ||
raise ValueError(estimate_precision_error_message) | ||
precisions_chol = linalg.solve_triangular( | ||
cov_chol, np.eye(n_features), lower=True | ||
).T | ||
precisions_chol = linalg.solve(cov_chol, np.eye(n_features)).T | ||
else: | ||
if np.any(np.less_equal(covariances, 0.0)): | ||
raise ValueError(estimate_precision_error_message) | ||
|
@@ -373,11 +383,11 @@ def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features): | |
log_det_precision_chol : array-like of shape (n_components,) | ||
The determinant of the precision matrix for each component. | ||
""" | ||
np, _ = get_namespace(matrix_chol) | ||
if covariance_type == "full": | ||
n_components, _, _ = matrix_chol.shape | ||
log_det_chol = np.sum( | ||
np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1 | ||
) | ||
matrix_col_reshape = np.reshape(matrix_chol, (n_components, -1)) | ||
log_det_chol = np.sum(np.log(matrix_col_reshape[:, :: n_features + 1]), axis=1) | ||
|
||
elif covariance_type == "tied": | ||
log_det_chol = np.sum(np.log(np.diag(matrix_chol))) | ||
|
@@ -413,6 +423,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): | |
------- | ||
log_prob : array, shape (n_samples, n_components) | ||
""" | ||
np, _ = get_namespace(X, means, precisions_chol) | ||
n_samples, n_features = X.shape | ||
n_components, _ = means.shape | ||
# The determinant of the precision matrix from the Cholesky decomposition | ||
|
@@ -423,8 +434,10 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): | |
|
||
if covariance_type == "full": | ||
log_prob = np.empty((n_samples, n_components)) | ||
for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)): | ||
y = np.dot(X, prec_chol) - np.dot(mu, prec_chol) | ||
for k in range(n_components): | ||
Cannot iterate from
||
mu = means[k, :] | ||
prec_chol = precisions_chol[k, :, :] | ||
y = X @ prec_chol - mu @ prec_chol | ||
log_prob[:, k] = np.sum(np.square(y), axis=1) | ||
|
||
elif covariance_type == "tied": | ||
|
@@ -450,7 +463,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): | |
) | ||
# Since we are using the precision of the Cholesky decomposition, | ||
# `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol` | ||
return -0.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det | ||
return -0.5 * (n_features * log(2 * np.pi) + log_prob) + log_det | ||
|
||
|
||
class GaussianMixture(BaseMixture): | ||
|
@@ -742,6 +755,7 @@ def _m_step(self, X, log_resp): | |
the point of each sample in X. | ||
""" | ||
n_samples, _ = X.shape | ||
np, _ = get_namespace(X, log_resp) | ||
self.weights_, self.means_, self.covariances_ = _estimate_gaussian_parameters( | ||
X, np.exp(log_resp), self.reg_covar, self.covariance_type | ||
) | ||
|
@@ -756,6 +770,7 @@ def _estimate_log_prob(self, X): | |
) | ||
|
||
def _estimate_log_weights(self): | ||
np, _ = get_namespace(self.weights_) | ||
return np.log(self.weights_) | ||
|
||
def _compute_lower_bound(self, _, log_prob_norm): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
"""Tools to support array_api.""" | ||
import numpy as np | ||
from scipy.special import logsumexp as sp_logsumexp | ||
from .._config import get_config | ||
|
||
|
||
def get_namespace(*xs): | ||
# `xs` contains one or more arrays, or possibly Python scalars (accepting | ||
# those is a matter of taste, but doesn't seem unreasonable). | ||
# Returns a tuple: (array_namespace, is_array_api) | ||
|
||
if not get_config()["array_api_dispatch"]: | ||
return np, False | ||
Global configuration option to control the dispatching.
||
|
||
namespaces = { | ||
x.__array_namespace__() if hasattr(x, "__array_namespace__") else None | ||
for x in xs | ||
if not isinstance(x, (bool, int, float, complex)) | ||
} | ||
|
||
if not namespaces: | ||
# one could special-case np.ndarray above or use np.asarray here if | ||
# older numpy versions need to be supported. | ||
raise ValueError("Unrecognized array input") | ||
|
||
if len(namespaces) != 1: | ||
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") | ||
|
||
(xp,) = namespaces | ||
if xp is None: | ||
# Use numpy as default | ||
return np, False | ||
|
||
return xp, True | ||
Returns a boolean so the caller can easily tell if we are using the
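To illustrate the contract of `get_namespace` described above, here is a standalone sketch that mirrors its logic without scikit-learn. The `dispatch` keyword is an assumption that replaces the `get_config()["array_api_dispatch"]` lookup, and `FakeArray` / `fake_xp` are hypothetical stand-ins for a real Array API library. A plain list is used as the "no namespace" input because, depending on the NumPy version, `ndarray` may itself expose `__array_namespace__`:

```python
import types

import numpy as np


def get_namespace_sketch(*xs, dispatch=True):
    """Return (array_namespace, is_array_api), mirroring the diff's logic."""
    if not dispatch:
        return np, False

    namespaces = {
        x.__array_namespace__() if hasattr(x, "__array_namespace__") else None
        for x in xs
        if not isinstance(x, (bool, int, float, complex))
    }
    if not namespaces:
        raise ValueError("Unrecognized array input")
    if len(namespaces) != 1:
        raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")

    (xp,) = namespaces
    if xp is None:
        return np, False  # fall back to NumPy
    return xp, True


# Hypothetical namespace and array type, for illustration only.
fake_xp = types.ModuleType("fake_xp")


class FakeArray:
    def __array_namespace__(self, api_version=None):
        return fake_xp
```

Python scalars are skipped when collecting namespaces, so `get_namespace_sketch(FakeArray(), 2.0)` resolves to the fake namespace, while mixing `FakeArray()` with a plain list raises `ValueError`.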
||
|
||
|
||
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): | ||
np, is_array_api = get_namespace(a) | ||
|
||
# Use SciPy if a is an ndarray | ||
if not is_array_api: | ||
return sp_logsumexp( | ||
a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign | ||
) | ||
Hopefully this is not needed in the future.

This we should fix in the standard I think. It should have
||
|
||
if b is not None: | ||
a, b = np.broadcast_arrays(a, b) | ||
if np.any(b == 0): | ||
a = a + 0.0 # promote to at least float | ||
a[b == 0] = -np.inf | ||
|
||
a_max = np.max(a, axis=axis, keepdims=True) | ||
|
||
if a_max.ndim > 0: | ||
a_max[~np.isfinite(a_max)] = 0 | ||
elif not np.isfinite(a_max): | ||
a_max = 0 | ||
|
||
if b is not None: | ||
b = np.asarray(b) | ||
tmp = b * np.exp(a - a_max) | ||
else: | ||
tmp = np.exp(a - a_max) | ||
|
||
# suppress warnings about log of zero | ||
s = np.sum(tmp, axis=axis, keepdims=keepdims) | ||
if return_sign: | ||
sgn = np.sign(s) | ||
s *= sgn # /= makes more sense but we need zero -> zero | ||
out = np.log(s) | ||
|
||
if not keepdims: | ||
a_max = np.squeeze(a_max, axis=axis) | ||
out += a_max | ||
|
||
if return_sign: | ||
return out, sgn | ||
else: | ||
return out |
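The core of the fallback above is the usual max-shift trick. A reduced sketch with plain NumPy, leaving out the `b` and `return_sign` handling; the function name is illustrative, and `np.where` is used instead of boolean-mask assignment, which is a substitution rather than what the diff does:

```python
import numpy as np


def logsumexp_shifted(a, axis=None):
    """Compute log(sum(exp(a))) along `axis` without overflow."""
    a_max = np.max(a, axis=axis, keepdims=True)
    # Replace non-finite maxima by 0 so that `a - a_max` stays well defined.
    a_max = np.where(np.isfinite(a_max), a_max, 0.0)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis)


a = np.array([[1000.0, 1000.0], [-1000.0, -1000.0]])
result = logsumexp_shifted(a, axis=1)
```

Subtracting the per-row maximum keeps every exponent at or below zero, so inputs that would overflow or underflow in a direct `exp` remain representable.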
Workaround for no errstate in array_api.

I don't think floating point warnings will ever be portable. They're not even consistent in NumPy, and a constant source of pain. Maybe we need (later) a utility context manager errstate that is a no-op or delegates to a library-specific implementation, to remove the if-else here.
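One possible shape for the utility context manager suggested in the comment above is sketched below. This is an assumption about how such a helper might look, not an existing scikit-learn or Array API function: it delegates to the namespace's own `errstate` when one exists and otherwise does nothing.

```python
import contextlib
import types

import numpy as np


def errstate(xp, **kwargs):
    """Hypothetical helper: use `xp.errstate` if available, else a no-op."""
    library_errstate = getattr(xp, "errstate", None)
    if library_errstate is None:
        return contextlib.nullcontext()
    return library_errstate(**kwargs)


# With NumPy this delegates to np.errstate and silences the divide warning.
with errstate(np, divide="ignore"):
    log_zero = np.log(np.zeros(1))

# A namespace without `errstate` (a made-up empty module here) gets a no-op.
bare_namespace = types.ModuleType("bare_namespace")
noop = errstate(bare_namespace, divide="ignore")
```

Call sites could then wrap the `np.log(s)` step unconditionally instead of branching on the namespace.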