Skip to content

Commit e3771c3

Browse files
authored
SVM Outlier detector (#814)
* Add svm pytorch sgd backend * Add svm pytorch bgd backend * Add svm tests * Add svm frontend implementation * Add docstrings
1 parent a6ec861 commit e3771c3

File tree

7 files changed

+1045
-0
lines changed

7 files changed

+1045
-0
lines changed

alibi_detect/od/_svm.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
2+
3+
import numpy as np
4+
5+
from alibi_detect.base import (BaseDetector, FitMixin, ThresholdMixin,
6+
outlier_prediction_dict)
7+
from alibi_detect.exceptions import _catch_error as catch_error
8+
from alibi_detect.od.pytorch import SgdSVMTorch, BgdSVMTorch
9+
from alibi_detect.utils._types import Literal
10+
from alibi_detect.utils.frameworks import BackendValidator
11+
from alibi_detect.version import __version__
12+
13+
14+
if TYPE_CHECKING:
15+
import torch
16+
17+
18+
backends = {
19+
'pytorch': {
20+
'sgd': SgdSVMTorch,
21+
'bgd': BgdSVMTorch
22+
}
23+
}
24+
25+
26+
class SVM(BaseDetector, ThresholdMixin, FitMixin):
27+
def __init__(
28+
self,
29+
nu: float,
30+
n_components: Optional[int] = None,
31+
kernel: 'torch.nn.Module' = None,
32+
optimization: Literal['sgd', 'bgd'] = 'sgd',
33+
backend: Literal['pytorch'] = 'pytorch',
34+
device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
35+
) -> None:
36+
"""One-Class Support vector machine (OCSVM) outlier detector.
37+
38+
The one-class Support vector machine outlier detector fits a one-class SVM to the reference data.
39+
40+
Rather than the typical approach of optimizing the exact kernel OCSVM objective through a dual formulation,
41+
here we instead map the data into the kernel's RKHS and then solve the linear optimization problem
42+
directly through its primal formulation. The Nystroem approximation is used to speed up training and inference
43+
by approximating the kernel's RKHS.
44+
45+
We provide two options, specified by the `optimization` parameter, for optimizing the one-class svm. `''sgd''`
46+
wraps the `SGDOneClassSVM` class from the sklearn package and the other, `''bgd''` uses a custom implementation
47+
in PyTorch. The PyTorch approach is tailored for operation on GPUs. Instead of applying stochastic gradient
48+
descent (one data point at a time) with a fixed learning rate schedule it performs full gradient descent with
49+
step size chosen at each iteration via line search. Note that on a CPU this would not necessarily be preferable
50+
to SGD as we would have to iterate through both data points and candidate step sizes, however on GPU all of the
51+
operations are vectorized/parallelized. Moreover, the Nystroem approximation has complexity `O(n^2m)` where
52+
`n` is the number of reference instances and `m` defines the number of inducing points. This can therefore be
53+
expensive for large reference sets and benefits from implementation on the GPU.
54+
55+
In general if using a small dataset then using the `''cpu''` with the optimization `''sgd''` is the best choice.
56+
Whereas if using a large dataset then using the `''gpu''` with the optimization `''bgd''` is the best choice.
57+
58+
Parameters
59+
----------
60+
nu
61+
The proportion of the training data that should be considered outliers. Note that this does not necessarily
62+
correspond to the false positive rate on test data, which is still defined when calling the
63+
`infer_threshold` method. `nu` should be thought of as a regularization parameter that affects how smooth
64+
the svm decision boundary is.
65+
n_components
66+
Number of components in the Nystroem approximation, By default uses all of them.
67+
kernel
68+
Kernel function to use for outlier detection. Should be an instance of a subclass of `torch.nn.Module`. If
69+
not specified then defaults to the `GaussianRBF`.
70+
optimization
71+
Optimization method to use. Choose from ``'sgd'`` or ``'bgd'``. Defaults to ``'sgd'``.
72+
backend
73+
Backend used for outlier detection. Defaults to ``'pytorch'``. Options are ``'pytorch'``.
74+
device
75+
Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by
76+
passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of ``torch.device``.
77+
78+
Raises
79+
------
80+
NotImplementedError
81+
If choice of `backend` is not implemented.
82+
ValueError
83+
If choice of `optimization` is not valid.
84+
ValueError
85+
If `n_components` is not a positive integer.
86+
"""
87+
super().__init__()
88+
89+
if optimization not in ('sgd', 'bgd'):
90+
raise ValueError(f'Optimization {optimization} not recognized. Choose from `sgd` or `bgd`.')
91+
92+
if n_components is not None and n_components <= 0:
93+
raise ValueError(f'n_components must be a positive integer, got {n_components}.')
94+
95+
backend_str: str = backend.lower()
96+
BackendValidator(
97+
backend_options={'pytorch': ['pytorch']},
98+
construct_name=self.__class__.__name__
99+
).verify_backend(backend_str)
100+
101+
backend_cls = backends[backend][optimization]
102+
args: Dict[str, Any] = {
103+
'n_components': n_components,
104+
'kernel': kernel,
105+
'nu': nu
106+
}
107+
args['device'] = device
108+
self.backend = backend_cls(**args)
109+
110+
def fit(
111+
self,
112+
x_ref: np.ndarray,
113+
tol: float = 1e-6,
114+
max_iter: int = 1000,
115+
step_size_range: Tuple[float, float] = (1e-8, 1.0),
116+
n_step_sizes: int = 16,
117+
n_iter_no_change: int = 25,
118+
verbose: int = 0,
119+
) -> None:
120+
"""Fit the detector on reference data.
121+
122+
Uses the choice of optimization method to fit the svm model to the data.
123+
124+
Parameters
125+
----------
126+
x_ref
127+
Reference data used to fit the detector.
128+
tol
129+
Convergence threshold used to fit the detector. Used for both ``'sgd'`` and ``'bgd'`` optimizations.
130+
Defaults to ``1e-3``.
131+
max_iter
132+
The maximum number of optimization steps. Used for both ``'sgd'`` and ``'bgd'`` optimizations.
133+
step_size_range
134+
The range of values to be considered for the gradient descent step size at each iteration. This is specified
135+
as a tuple of the form `(min_eta, max_eta)` and only used for the ``'bgd'`` optimization.
136+
n_step_sizes
137+
The number of step sizes in the defined range to be tested for loss reduction. This many points are spaced
138+
evenly along the range in log space. This is only used for the ``'bgd'`` optimization.
139+
n_iter_no_change
140+
The number of iterations over which the loss must decrease by `tol` in order for optimization to continue.
141+
This is only used for the ``'bgd'`` optimization..
142+
verbose
143+
Verbosity level during training. ``0`` is silent, ``1`` prints fit status. If using `bgd`, fit displays a
144+
progress bar. Otherwise, if using `sgd` then we output the Sklearn `SGDOneClassSVM.fit()` logs.
145+
"""
146+
self.backend.fit(
147+
self.backend._to_tensor(x_ref),
148+
**self.backend.format_fit_kwargs(locals())
149+
)
150+
151+
@catch_error('NotFittedError')
152+
def score(self, x: np.ndarray) -> np.ndarray:
153+
"""Score `x` instances using the detector.
154+
155+
Scores the data using the fitted svm model. The higher the score, the more anomalous the instance.
156+
157+
Parameters
158+
----------
159+
x
160+
Data to score. The shape of `x` should be `(n_instances, n_features)`.
161+
162+
Returns
163+
-------
164+
Outlier scores. The shape of the scores is `(n_instances,)`. The higher the score, the more anomalous the \
165+
instance.
166+
167+
Raises
168+
------
169+
NotFittedError
170+
If called before detector has been fit.
171+
"""
172+
score = self.backend.score(self.backend._to_tensor(x))
173+
return self.backend._to_numpy(score)
174+
175+
@catch_error('NotFittedError')
176+
def infer_threshold(self, x: np.ndarray, fpr: float) -> None:
177+
"""Infer the threshold for the SVM detector.
178+
179+
The threshold is computed so that the outlier detector would incorrectly classify `fpr` proportion of the
180+
reference data as outliers.
181+
182+
Parameters
183+
----------
184+
x
185+
Reference data used to infer the threshold.
186+
fpr
187+
False positive rate used to infer the threshold. The false positive rate is the proportion of
188+
instances in `x` that are incorrectly classified as outliers. The false positive rate should
189+
be in the range ``(0, 1)``.
190+
191+
Raises
192+
------
193+
ValueError
194+
Raised if `fpr` is not in ``(0, 1)``.
195+
NotFittedError
196+
If called before detector has been fit.
197+
"""
198+
self.backend.infer_threshold(self.backend._to_tensor(x), fpr)
199+
200+
@catch_error('NotFittedError')
201+
def predict(self, x: np.ndarray) -> Dict[str, Any]:
202+
"""Predict whether the instances in `x` are outliers or not.
203+
204+
Scores the instances in `x` and if the threshold was inferred, returns the outlier labels and p-values as well.
205+
206+
Parameters
207+
----------
208+
x
209+
Data to predict. The shape of `x` should be `(n_instances, n_features)`.
210+
211+
Returns
212+
-------
213+
Dictionary with keys 'data' and 'meta'. 'data' contains the outlier scores. If threshold inference was \
214+
performed, 'data' also contains the threshold value, outlier labels and p-vals . The shape of the scores is \
215+
`(n_instances,)`. The higher the score, the more anomalous the instance. 'meta' contains information about \
216+
the detector.
217+
218+
Raises
219+
------
220+
NotFittedError
221+
If called before detector has been fit.
222+
"""
223+
outputs = self.backend.predict(self.backend._to_tensor(x))
224+
output = outlier_prediction_dict()
225+
output['data'] = {
226+
**output['data'],
227+
**self.backend._to_numpy(outputs)
228+
}
229+
output['meta'] = {
230+
**output['meta'],
231+
'name': self.__class__.__name__,
232+
'detector_type': 'outlier',
233+
'online': False,
234+
'version': __version__,
235+
}
236+
return output

alibi_detect/od/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
KernelPCATorch, LinearPCATorch = import_optional('alibi_detect.od.pytorch.pca', ['KernelPCATorch', 'LinearPCATorch'])
77
Ensembler = import_optional('alibi_detect.od.pytorch.ensemble', ['Ensembler'])
88
GMMTorch = import_optional('alibi_detect.od.pytorch.gmm', ['GMMTorch'])
9+
BgdSVMTorch, SgdSVMTorch = import_optional('alibi_detect.od.pytorch.svm', ['BgdSVMTorch', 'SgdSVMTorch'])

0 commit comments

Comments
 (0)