Skip to content

Commit cd9daa1

Browse files
authored
Merge branch 'master' into deprecation-decorators
2 parents a310730 + 80972b5 commit cd9daa1

File tree

16 files changed

+252
-147
lines changed

16 files changed

+252
-147
lines changed

.github/workflows/code-style.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ jobs:
2222
code-style:
2323
runs-on: ubuntu-latest
2424
steps:
25-
- uses: actions/checkout@master
25+
- uses: actions/checkout@v2
2626
- uses: actions/setup-python@v2
2727
with:
28-
python-version: "3.7"
28+
python-version: "3.8"
2929
- run: |
3030
python -m pip install autopep8 "black==19.10b0" "isort==4.3.21"
3131
isort -rc .

CONTRIBUTING.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,17 @@ git merge upstream/master
203203
### Writing documentation
204204

205205
Ignite uses [Google style](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
206-
for formatting docstrings. Length of line inside docstrings block must be limited to 120 characters.
206+
for formatting docstrings and
207+
208+
- [`.. versionadded::`] directive for adding new classes, class methods, functions,
209+
- [`.. versionchanged::`] directive for adding new arguments, changing internal behaviours, fixing bugs and
210+
- [`.. deprecated::`] directive for deprecations.
211+
212+
Length of line inside docstrings block must be limited to 120 characters.
213+
214+
[`.. versionadded::`]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded
215+
[`.. versionchanged::`]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionchanged
216+
[`.. deprecated::`]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-deprecated
207217

208218
#### Local documentation building and deploying
209219

docs/source/conf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,10 @@
193193
# -- Options for intersphinx extension ---------------------------------------
194194

195195
# Example configuration for intersphinx: refer to the Python standard library.
196-
intersphinx_mapping = {"https://docs.python.org/3/": None}
196+
intersphinx_mapping = {
197+
"python": ("https://docs.python.org/3", None),
198+
"torch": ("https://pytorch.org/docs/stable/", None),
199+
}
197200

198201
# -- Options for todo extension ----------------------------------------------
199202

ignite/handlers/checkpoint.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class Checkpoint(Serializable):
9898
details.
9999
include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
100100
there must not be another object in ``to_save`` with key ``checkpointer``.
101+
greater_or_equal (bool): if `True`, the latest equally scored model is stored. Otherwise, the first model.
102+
Default, `False`.
101103
102104
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
103105
torch.nn.parallel.DistributedDataParallel.html
@@ -245,6 +247,8 @@ def score_function(engine):
245247
trainer.run(data_loader, max_epochs=10)
246248
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]
247249
250+
.. versionchanged:: 0.4.3
251+
Added ``greater_or_equal`` parameter.
248252
"""
249253

250254
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
@@ -261,6 +265,7 @@ def __init__(
261265
global_step_transform: Optional[Callable] = None,
262266
filename_pattern: Optional[str] = None,
263267
include_self: bool = False,
268+
greater_or_equal: bool = False,
264269
) -> None:
265270

266271
if to_save is not None: # for compatibility with ModelCheckpoint
@@ -301,6 +306,7 @@ def __init__(
301306
self.filename_pattern = filename_pattern
302307
self._saved = [] # type: List["Checkpoint.Item"]
303308
self.include_self = include_self
309+
self.greater_or_equal = greater_or_equal
304310

305311
def reset(self) -> None:
306312
"""Method to reset saved checkpoint names.
@@ -339,6 +345,12 @@ def _check_lt_n_saved(self, or_equal: bool = False) -> bool:
339345
return True
340346
return len(self._saved) < self.n_saved + int(or_equal)
341347

348+
def _compare_fn(self, new: Union[int, float]) -> bool:
349+
if self.greater_or_equal:
350+
return new >= self._saved[0].priority
351+
else:
352+
return new > self._saved[0].priority
353+
342354
def __call__(self, engine: Engine) -> None:
343355

344356
global_step = None
@@ -354,7 +366,7 @@ def __call__(self, engine: Engine) -> None:
354366
global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
355367
priority = global_step
356368

357-
if self._check_lt_n_saved() or self._saved[0].priority < priority:
369+
if self._check_lt_n_saved() or self._compare_fn(priority):
358370

359371
priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"
360372

ignite/metrics/accuracy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None:
9292

9393

9494
class Accuracy(_BaseClassification):
95-
"""
96-
Calculates the accuracy for binary, multiclass and multilabel data.
95+
r"""Calculates the accuracy for binary, multiclass and multilabel data.
96+
97+
.. math:: \text{Accuracy} = \frac{ TP + TN }{ TP + TN + FP + FN }
98+
99+
where :math:`\text{TP}` is true positives, :math:`\text{TN}` is true negatives,
100+
:math:`\text{FP}` is false positives and :math:`\text{FN}` is false negatives.
97101
98102
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
99103
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).

ignite/metrics/confusion_matrix.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def normalize(matrix: torch.Tensor, average: str) -> torch.Tensor:
130130

131131

132132
def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
133-
"""Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.
133+
r"""Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.
134+
135+
.. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }
134136
135137
Args:
136138
cm (ConfusionMatrix): instance of confusion matrix metric

ignite/metrics/fbeta.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@ def Fbeta(
1717
output_transform: Optional[Callable] = None,
1818
device: Union[str, torch.device] = torch.device("cpu"),
1919
) -> MetricsLambda:
20-
"""Calculates F-beta score
20+
r"""Calculates F-beta score.
21+
22+
.. math::
23+
F_\beta = \left( 1 + \beta^2 \right) * \frac{ \text{precision} * \text{recall} }
24+
{ \left( \beta^2 * \text{precision} \right) + \text{recall} }
25+
26+
where :math:`\beta` is a positive real factor.
2127
2228
Args:
2329
beta (float): weight of precision in harmonic mean

ignite/metrics/mean_absolute_error.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010

1111
class MeanAbsoluteError(Metric):
12-
"""
13-
Calculates the mean absolute error.
12+
r"""Calculates `the mean absolute error <https://en.wikipedia.org/wiki/Mean_absolute_error>`_.
13+
14+
.. math:: \text{MAE} = \frac{1}{N} \sum_{i=1}^N \lvert y_{i} - x_{i} \rvert
15+
16+
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.
1417
1518
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1619
"""

ignite/metrics/mean_pairwise_distance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111

1212
class MeanPairwiseDistance(Metric):
13-
"""
14-
Calculates the mean pairwise distance: average of pairwise distances computed on provided batches.
13+
"""Calculates the mean :class:`~torch.nn.PairwiseDistance`.
14+
Average of pairwise distances computed on provided batches.
1515
1616
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1717
"""

ignite/metrics/mean_squared_error.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010

1111
class MeanSquaredError(Metric):
12-
"""
13-
Calculates the mean squared error.
12+
r"""Calculates the `mean squared error <https://en.wikipedia.org/wiki/Mean_squared_error>`_.
13+
14+
.. math:: \text{MSE} = \frac{1}{N} \sum_{i=1}^N \left(y_{i} - x_{i} \right)^2
15+
16+
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.
1417
1518
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1619
"""

ignite/metrics/precision.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,11 @@ def compute(self) -> Union[torch.Tensor, float]:
6969

7070

7171
class Precision(_BasePrecisionRecall):
72-
"""
73-
Calculates precision for binary and multiclass data.
72+
r"""Calculates precision for binary and multiclass data.
73+
74+
.. math:: \text{Precision} = \frac{ TP }{ TP + FP }
75+
76+
where :math:`\text{TP}` is true positives and :math:`\text{FP}` is false positives.
7477
7578
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
7679
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).

ignite/metrics/psnr.py

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,16 @@ class PSNR(Metric):
1212
r"""Computes average `Peak signal-to-noise ratio (PSNR) <https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio>`_.
1313
1414
.. math::
15-
\text{PSNR}(i, j) = 10 * \log_{10}\left(\frac{ MAX_{i}^2 }{ \text{ MSE } }\right)
15+
\text{PSNR}(I, J) = 10 * \log_{10}\left(\frac{ MAX_{I}^2 }{ \text{ MSE } }\right)
1616
1717
where :math:`\text{MSE}` is `mean squared error <https://en.wikipedia.org/wiki/Mean_squared_error>`_.
1818
19-
Note:
20-
A few things to note:
21-
22-
- `y_pred` and `y` **must** have (batch_size, ...) shape.
23-
- `y_pred` and `y` **must** have same dtype and same shape.
24-
- If `y_pred` and `y` have almost identical values, the result will be infinity.
19+
- `y_pred` and `y` **must** have (batch_size, ...) shape.
20+
- `y_pred` and `y` **must** have same dtype and same shape.
2521
2622
Args:
27-
data_range (int or float, optional): The data range of the target image (distance between minimum
28-
and maximum possible values). If not provided, it will be estimated from the input
29-
data type: ``1.0`` for float/double tensor or ``255`` for unsigned 8-bit tensor.
23+
data_range (int or float): The data range of the target image (distance between minimum
24+
and maximum possible values).
3025
For other data types, please set the data range, otherwise an exception will be raised.
3126
output_transform (callable, optional): A callable that is used to transform the Engine’s
3227
process_function’s output into the form expected by the metric.
@@ -46,27 +41,42 @@ def process_function(engine, batch):
4641
# ...
4742
return y_pred, y
4843
engine = Engine(process_function)
49-
psnr = PSNR()
44+
psnr = PSNR(data_range=1.0)
5045
psnr.attach(engine, "psnr")
5146
# ...
5247
state = engine.run(data)
5348
print(f"PSNR: {state.metrics['psnr']}")
5449
50+
This metric by default accepts Grayscale or RGB images. But if you have YCbCr or YUV images, only
51+
Y channel is needed for computing PSNR. And, this can be done with ``output_transform``. For instance,
52+
53+
.. code-block:: python
54+
55+
def get_y_channel(output):
56+
y_pred, y = output
57+
# y_pred and y are (B, 3, H, W) and YCbCr or YUV images
58+
# let's select y channel
59+
return y_pred[:, 0, ...], y[:, 0, ...]
60+
61+
psnr = PSNR(data_range=219, output_transform=get_y_channel)
62+
psnr.attach(engine, "psnr")
63+
# ...
64+
state = engine.run(data)
65+
print(f"PSNR: {state.metrics['psrn']}")
66+
5567
.. versionadded:: 0.5.0
5668
"""
5769

5870
def __init__(
5971
self,
60-
data_range: Optional[Union[int, float]] = None,
72+
data_range: Union[int, float],
6173
output_transform: Callable = lambda x: x,
6274
device: Union[str, torch.device] = torch.device("cpu"),
6375
):
6476
super().__init__(output_transform=output_transform, device=device)
6577
self.data_range = data_range
6678

67-
def _check_shape_dtype_drange(
68-
self, output: Sequence[torch.Tensor], data_range: Union[int, float, None]
69-
) -> Union[int, float]:
79+
def _check_shape_dtype(self, output: Sequence[torch.Tensor]) -> None:
7080
y_pred, y = output
7181
if y_pred.dtype != y.dtype:
7282
raise TypeError(
@@ -78,40 +88,19 @@ def _check_shape_dtype_drange(
7888
f"Expected y_pred and y to have the same shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
7989
)
8090

81-
if data_range is None:
82-
try:
83-
dmin, dmax = _dtype_range[y.dtype]
84-
except KeyError:
85-
raise ValueError(
86-
"Range for this dtype cannot be automatically estimated. Please manually specify the data_range."
87-
)
88-
true_min, true_max = y.min(), y.max()
89-
if true_max > dmax or true_min < dmin:
90-
raise ValueError(
91-
"y has intensity values outside the range expected "
92-
"for its data type. Please manually specify the data_range."
93-
)
94-
if true_min >= 0:
95-
# most common case (255 for uint8, 1 for float)
96-
data_range = dmax
97-
else:
98-
data_range = dmax - dmin
99-
return data_range
100-
10191
@reinit__is_reduced
10292
def reset(self) -> None:
10393
self._sum_of_batchwise_psnr = torch.tensor(0.0, dtype=torch.float64, device=self._device)
10494
self._num_examples = 0
10595

10696
@reinit__is_reduced
10797
def update(self, output: Sequence[torch.Tensor]) -> None:
108-
data_range = self.data_range
109-
data_range = self._check_shape_dtype_drange(output, data_range)
98+
self._check_shape_dtype(output)
11099
y_pred, y = output[0].detach(), output[1].detach()
111100

112101
dim = tuple(range(1, y.ndim))
113102
mse_error = torch.pow(y_pred.double() - y.view_as(y_pred).double(), 2).mean(dim=dim)
114-
self._sum_of_batchwise_psnr += torch.sum(10.0 * torch.log10(data_range ** 2 / mse_error)).to(
103+
self._sum_of_batchwise_psnr += torch.sum(10.0 * torch.log10(self.data_range ** 2 / (mse_error + 1e-10))).to(
115104
device=self._device
116105
)
117106
self._num_examples += y.shape[0]
@@ -121,11 +110,3 @@ def compute(self) -> torch.Tensor:
121110
if self._num_examples == 0:
122111
raise NotComputableError("PSNR must have at least one example before it can be computed.")
123112
return self._sum_of_batchwise_psnr / self._num_examples
124-
125-
126-
_dtype_range = {
127-
torch.uint8: (0, 255),
128-
torch.float16: (-1.0, 1.0),
129-
torch.float32: (-1.0, 1.0),
130-
torch.float64: (-1.0, 1.0),
131-
}

ignite/metrics/recall.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111

1212
class Recall(_BasePrecisionRecall):
13-
"""
14-
Calculates recall for binary and multiclass data.
13+
r"""Calculates recall for binary and multiclass data.
14+
15+
.. math:: \text{Recall} = \frac{ TP }{ TP + FN }
16+
17+
where :math:`\text{TP}` is true positives and :math:`\text{FN}` is false negatives.
1518
1619
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1720
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).

ignite/metrics/root_mean_squared_error.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010

1111
class RootMeanSquaredError(MeanSquaredError):
12-
"""
13-
Calculates the root mean squared error.
12+
r"""Calculates the `root mean squared error <https://en.wikipedia.org/wiki/Root-mean-square_deviation>`_.
13+
14+
.. math:: \text{RMSE} = \sqrt{ \frac{1}{N} \sum_{i=1}^N \left(y_{i} - x_{i} \right)^2 }
15+
16+
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.
1417
1518
- ``update`` must receive output of the form (y_pred, y) or `{'y_pred': y_pred, 'y': y}`.
1619
"""

tests/ignite/handlers/test_checkpoint.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,3 +1526,36 @@ def test_checkpoint_reset_with_engine(dirname):
15261526
expected += [f"{_PREFIX}_{name}_{i}.pt" for i in [1 * 2, 2 * 2]]
15271527
assert sorted(os.listdir(dirname)) == sorted(expected)
15281528
assert "PREFIX_model_4.pt" in handler.last_checkpoint
1529+
1530+
1531+
def test_greater_or_equal():
1532+
scores = iter([1, 2, 2, 2])
1533+
1534+
def score_function(_):
1535+
return next(scores)
1536+
1537+
class Saver:
1538+
def __init__(self):
1539+
self.counter = 0
1540+
1541+
def __call__(self, c, f, m):
1542+
if self.counter == 0:
1543+
assert f == "model_1.pt"
1544+
else:
1545+
assert f == "model_2.pt"
1546+
self.counter += 1
1547+
1548+
handler = Saver()
1549+
1550+
checkpointer = Checkpoint(
1551+
to_save={"model": DummyModel()},
1552+
save_handler=handler,
1553+
score_function=score_function,
1554+
n_saved=2,
1555+
greater_or_equal=True,
1556+
)
1557+
trainer = Engine(lambda e, b: None)
1558+
1559+
for _ in range(4):
1560+
checkpointer(trainer)
1561+
assert handler.counter == 4

0 commit comments

Comments
 (0)