Skip to content

Commit 7753eab

Browse files
fco-dvvfdev-5ahmedo421nF0rmedpradyumnar-sahaj
authored
ddp precision recall (#1646)
* Recall/Precision metrics for ddp : average == false and multilabel == true * For v0.4.3 - Add more versionadded, versionchanged tags - Change v0.5… (#1612) * For v0.4.3 - Add more versionadded, versionchanged tags - Change v0.5.0 to v0.4.3 * Update ignite/contrib/metrics/regression/canberra_metric.py Co-authored-by: vfdev <[email protected]> * Update ignite/contrib/metrics/regression/manhattan_distance.py Co-authored-by: vfdev <[email protected]> * Update ignite/contrib/metrics/regression/r2_score.py Co-authored-by: vfdev <[email protected]> * Update ignite/handlers/checkpoint.py Co-authored-by: vfdev <[email protected]> * address PR comments Co-authored-by: vfdev <[email protected]> * added TimeLimit handler with its test and doc (#1611) * added TimeLimit handler with its test and doc * fixed documentation * fixed docstring and formatting * flake8 fix trailing whitespace :) * modified class logger , default value and tests * changed rounding to nearest integer * tests refactored , docs modified * fixed default value , removed global logger * fixing formatting * Added versionadded * added test for engine termination Co-authored-by: vfdev <[email protected]> * Update handlers to use setup_logger (#1617) * Fixes #1614 - Updated handlers EarlyStopping and TerminateOnNan - Replaced `logging.getLogger` with `setup_logger` in the mentioned handlers * Updated `TimeLimit` handler. Replaced use of `logger.getLogger` with `setup_logger` from `ignite.utils` Co-authored-by: Pradyumna Rahul K <[email protected]> Co-authored-by: Sylvain Desroziers <[email protected]> * Managing Deprecation using decorators (#1585) * Starter code for managing deprecation * Make functions deprecated using the `@deprecated` decorator * Add arguments to the @deprecated decorator to customize it for each function * Improve `@deprecated` decorator and add tests * Replaced the `raise` keyword with added `warnings` * Added tests several possibilities of the decorator usage * Removing the test deprecation to check tests * Add static typing, fix mypy errors * Make `@deprecated` to raise Exceptions or Warning * The `@deprecated` decorator will now always emit warning unless explicitly asked to raise an Exception * Fix mypy errors * Fix mypy errors (hopefully) * Fix the test `test_deprecated_setup_any_logging` * Change the test to work with the `@deprecated` decorator * Change to snake_case, handle mypy ignores * Improve Type Annotations * Update common.py * For v0.4.3 - Add more versionadded, versionchanged tags - Change v0.5… (#1612) * For v0.4.3 - Add more versionadded, versionchanged tags - Change v0.5.0 to v0.4.3 * Update ignite/contrib/metrics/regression/canberra_metric.py Co-authored-by: vfdev <[email protected]> * Update ignite/contrib/metrics/regression/manhattan_distance.py Co-authored-by: vfdev <[email protected]> * Update ignite/contrib/metrics/regression/r2_score.py Co-authored-by: vfdev <[email protected]> * Update ignite/handlers/checkpoint.py Co-authored-by: vfdev <[email protected]> * address PR comments Co-authored-by: vfdev <[email protected]> * `version` -> version Co-authored-by: vfdev <[email protected]> Co-authored-by: François COKELAER <[email protected]> Co-authored-by: Sylvain Desroziers <[email protected]> * Create documentation.md * Distributed tests on Windows should be skipped until fixed. (#1620) * modified CONTRIBUTING.md * bash instead of sh * Added Checkpoint.get_default_score_fn (#1621) * Added Checkpoint.get_default_score_fn to simplify best_model_handler creation * Added score_sign argument * Updated docs * Update about.rst * Update pre-commit hooks and CONTRIBUTING.md (#1622) * Change pre-commit config and CONTRIBUTING.md - Update hook versions - Remove seed-isort-config - Add black profile to isort * Fix files based on new pre-commit config * Add meaningful exclusions to prettier - Also update actions workflow files to match local pre-commit * added requirements.txt and updated readme.md (#1624) * added requirements.txt and updated readme.md * Update examples/contrib/cifar10/README.md Co-authored-by: vfdev <[email protected]> * Update examples/contrib/cifar10/requirements.txt Co-authored-by: vfdev <[email protected]> Co-authored-by: vfdev <[email protected]> * Replace relative paths with raw.githubusercontent (#1629) * Updated cifar10 example (#1632) * Updates for cifar10 example * Updates for cifar10 example * More updates * Updated code * Fixed code-formatting * Fixed failling CI and typos for cifar10 examples (#1633) * Updates for cifar10 example * Updates for cifar10 example * More updates * Updated code * Fixed code-formatting * Fixed typo and failing CI * Fixed hvd spawn fail and better synced qat code * Removed temporary hack to install pth 1.7.1 (#1638) - updated default pth image for gpu tests - updated TORCH_CUDA_ARCH_LIST - fixed /merge -> /head in trigger ci pipeline * [docker] Pillow -> Pillow-SIMD (#1509) (#1639) * [docker] Pillow -> Pillow-SIMD (#1509) * [docker] Pillow -> Pillow-SIMD * replace pillow with pillow-simd in base docker files * chore(docker): apt-get autoremove after pillow-simd installation * apt-get install at once, autoremove g++ * install g++ in pillow installation layer Co-authored-by: Sylvain Desroziers <[email protected]> * Fix g++ install issue Co-authored-by: Jeff Yang <[email protected]> Co-authored-by: Sylvain Desroziers <[email protected]> * Fix multinode tests script (#1631) * fix run_multinode_tests_in_docker.sh : run tests with docker python version * add missing modules * build an image with test env and add 'nnodes' 'nproc_per_node' 'gpu' as parameters * #1615 : change nproc_per_node default to 4 * #1615 : fix for gpu enabled tests / container rm step at the end of the script * add xfail decorator for tests/ignite/engine/test_deterministic.py::test_multinode_distrib_cpu * fix script gpu_options * add default tol=1e-6 for _test_distrib_compute_on_criterion * fix for "RuntimeError: trying to initialize the default process group twice!" * tolerance for test_multinode_distrib_cpu case only * fix assert None error * autopep8 fix Co-authored-by: vfdev <[email protected]> Co-authored-by: Sylvain Desroziers <[email protected]> Co-authored-by: fco-dv <[email protected]> * remove warning for average=False and is_multilabel=True * update docstring and {precision, recall} tests according to test_multilabel_input_NCHW Co-authored-by: vfdev <[email protected]> Co-authored-by: Ahmed Omar <[email protected]> Co-authored-by: Pradyumna Rahul <[email protected]> Co-authored-by: Pradyumna Rahul K <[email protected]> Co-authored-by: Sylvain Desroziers <[email protected]> Co-authored-by: Devanshu Shah <[email protected]> Co-authored-by: Debojyoti Chakraborty <[email protected]> Co-authored-by: Jeff Yang <[email protected]> Co-authored-by: fco-dv <[email protected]>
1 parent 41925f0 commit 7753eab

File tree

4 files changed

+35
-63
lines changed

4 files changed

+35
-63
lines changed

ignite/metrics/precision.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,6 @@ def __init__(
2020
is_multilabel: bool = False,
2121
device: Union[str, torch.device] = torch.device("cpu"),
2222
):
23-
if idist.get_world_size() > 1:
24-
if (not average) and is_multilabel:
25-
warnings.warn(
26-
"Precision/Recall metrics do not work in distributed setting when average=False "
27-
"and is_multilabel=True. Results are not reduced across computing devices. Computed result "
28-
"corresponds to the local rank's (single process) result.",
29-
RuntimeWarning,
30-
)
3123

3224
self._average = average
3325
self.eps = 1e-20
@@ -53,12 +45,14 @@ def compute(self) -> Union[torch.Tensor, float]:
5345
raise NotComputableError(
5446
f"{self.__class__.__name__} must have at least one example before it can be computed."
5547
)
56-
57-
if not (self._type == "multilabel" and not self._average):
58-
if not self._is_reduced:
48+
if not self._is_reduced:
49+
if not (self._type == "multilabel" and not self._average):
5950
self._true_positives = idist.all_reduce(self._true_positives) # type: ignore[assignment]
6051
self._positives = idist.all_reduce(self._positives) # type: ignore[assignment]
61-
self._is_reduced = True # type: bool
52+
else:
53+
self._true_positives = cast(torch.Tensor, idist.all_gather(self._true_positives))
54+
self._positives = cast(torch.Tensor, idist.all_gather(self._positives))
55+
self._is_reduced = True # type: bool
6256

6357
result = self._true_positives / (self._positives + self.eps)
6458

@@ -107,11 +101,6 @@ def thresholded_output_transform(output):
107101
as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger
108102
than available RAM.
109103
110-
.. warning::
111-
112-
In multilabel cases, if average is False, current implementation does not work with distributed computations.
113-
Results are not reduced across the GPUs. Computed result corresponds to the local rank's (single GPU) result.
114-
115104
116105
Args:
117106
output_transform (callable, optional): a callable that is used to transform the

ignite/metrics/recall.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,6 @@ def thresholded_output_transform(output):
4848
as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger
4949
than available RAM.
5050
51-
.. warning::
52-
53-
In multilabel cases, if average is False, current implementation does not work with distributed computations.
54-
Results are not reduced across the GPUs. Computed result corresponds to the local rank's (single GPU) result.
55-
5651
5752
Args:
5853
output_transform (callable, optional): a callable that is used to transform the

tests/ignite/metrics/test_precision.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ def update(engine, i):
792792

793793
engine = Engine(update)
794794

795-
pr = Precision(average=average, is_multilabel=True)
795+
pr = Precision(average=average, is_multilabel=True, device=metric_device)
796796
pr.attach(engine, "pr")
797797

798798
data = list(range(n_iters))
@@ -808,13 +808,13 @@ def update(engine, i):
808808
else:
809809
assert res == res2
810810

811+
np_y_preds = to_numpy_multilabel(y_preds)
812+
np_y_true = to_numpy_multilabel(y_true)
813+
assert pr._type == "multilabel"
814+
res = res if average else res.mean().item()
811815
with warnings.catch_warnings():
812816
warnings.simplefilter("ignore", category=UndefinedMetricWarning)
813-
true_res = precision_score(
814-
to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
815-
)
816-
817-
assert pytest.approx(res) == true_res
817+
assert precision_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)
818818

819819
metric_devices = ["cpu"]
820820
if device.type != "xla":
@@ -823,22 +823,16 @@ def update(engine, i):
823823
for metric_device in metric_devices:
824824
_test(average=True, n_epochs=1, metric_device=metric_device)
825825
_test(average=True, n_epochs=2, metric_device=metric_device)
826+
_test(average=False, n_epochs=1, metric_device=metric_device)
827+
_test(average=False, n_epochs=2, metric_device=metric_device)
826828

827-
if idist.get_world_size() > 1:
828-
with pytest.warns(
829-
RuntimeWarning,
830-
match="Precision/Recall metrics do not work in distributed setting when "
831-
"average=False and is_multilabel=True",
832-
):
833-
pr = Precision(average=False, is_multilabel=True)
834-
835-
y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
836-
y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
837-
pr.update((y_pred, y))
838-
pr_compute1 = pr.compute()
839-
pr_compute2 = pr.compute()
840-
assert len(pr_compute1) == 4 * 6 * 8
841-
assert (pr_compute1 == pr_compute2).all()
829+
pr1 = Precision(is_multilabel=True, average=True)
830+
pr2 = Precision(is_multilabel=True, average=False)
831+
y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
832+
y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
833+
pr1.update((y_pred, y))
834+
pr2.update((y_pred, y))
835+
assert pr1.compute() == pytest.approx(pr2.compute().mean().item())
842836

843837

844838
def _test_distrib_accumulator_device(device):

tests/ignite/metrics/test_recall.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -808,13 +808,13 @@ def update(engine, i):
808808
else:
809809
assert res == res2
810810

811+
np_y_preds = to_numpy_multilabel(y_preds)
812+
np_y_true = to_numpy_multilabel(y_true)
813+
assert re._type == "multilabel"
814+
res = res if average else res.mean().item()
811815
with warnings.catch_warnings():
812816
warnings.simplefilter("ignore", category=UndefinedMetricWarning)
813-
true_res = recall_score(
814-
to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
815-
)
816-
817-
assert pytest.approx(res) == true_res
817+
assert recall_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)
818818

819819
metric_devices = ["cpu"]
820820
if device.type != "xla":
@@ -823,22 +823,16 @@ def update(engine, i):
823823
for metric_device in metric_devices:
824824
_test(average=True, n_epochs=1, metric_device=metric_device)
825825
_test(average=True, n_epochs=2, metric_device=metric_device)
826+
_test(average=False, n_epochs=1, metric_device=metric_device)
827+
_test(average=False, n_epochs=2, metric_device=metric_device)
826828

827-
if idist.get_world_size() > 1:
828-
with pytest.warns(
829-
RuntimeWarning,
830-
match="Precision/Recall metrics do not work in distributed setting when "
831-
"average=False and is_multilabel=True",
832-
):
833-
re = Recall(average=False, is_multilabel=True)
834-
835-
y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
836-
y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
837-
re.update((y_pred, y))
838-
re_compute1 = re.compute()
839-
re_compute2 = re.compute()
840-
assert len(re_compute1) == 4 * 6 * 8
841-
assert (re_compute1 == re_compute2).all()
829+
re1 = Recall(is_multilabel=True, average=True)
830+
re2 = Recall(is_multilabel=True, average=False)
831+
y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
832+
y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
833+
re1.update((y_pred, y))
834+
re2.update((y_pred, y))
835+
assert re1.compute() == pytest.approx(re2.compute().mean().item())
842836

843837

844838
def _test_distrib_accumulator_device(device):

0 commit comments

Comments
 (0)