From f67c0704e1dead6a71807a4b209aaec6d5a93777 Mon Sep 17 00:00:00 2001 From: ravitezu Date: Mon, 5 Oct 2020 23:11:29 +0530 Subject: [PATCH 1/3] [3] [contrib/metrics] setup typing in contrib part of the library --- ignite/contrib/metrics/average_precision.py | 8 +++++--- ignite/contrib/metrics/gpu_info.py | 9 +++++---- .../contrib/metrics/precision_recall_curve.py | 8 +++++--- ignite/contrib/metrics/regression/_base.py | 18 ++++++++++-------- .../metrics/regression/canberra_metric.py | 4 ++-- .../regression/fractional_absolute_error.py | 4 +++- .../metrics/regression/fractional_bias.py | 4 +++- .../geometric_mean_absolute_error.py | 4 +++- .../geometric_mean_relative_absolute_error.py | 4 +++- .../metrics/regression/manhattan_distance.py | 4 ++-- .../regression/maximum_absolute_error.py | 4 +++- .../regression/mean_absolute_relative_error.py | 4 +++- .../contrib/metrics/regression/mean_error.py | 4 +++- .../metrics/regression/mean_normalized_bias.py | 4 +++- .../regression/median_absolute_error.py | 6 ++++-- .../median_absolute_percentage_error.py | 6 ++++-- .../median_relative_absolute_error.py | 6 ++++-- ignite/contrib/metrics/regression/r2_score.py | 6 +++--- .../metrics/regression/wave_hedges_distance.py | 4 +++- ignite/contrib/metrics/roc_auc.py | 14 ++++++++------ 20 files changed, 79 insertions(+), 46 deletions(-) diff --git a/ignite/contrib/metrics/average_precision.py b/ignite/contrib/metrics/average_precision.py index 589dbe506213..042a5b5ddf6e 100644 --- a/ignite/contrib/metrics/average_precision.py +++ b/ignite/contrib/metrics/average_precision.py @@ -1,7 +1,9 @@ +from typing import Any, Callable + from ignite.metrics import EpochMetric -def average_precision_compute_fn(y_preds, y_targets): +def average_precision_compute_fn(y_preds: Any, y_targets: Any): try: from sklearn.metrics import average_precision_score except ImportError: @@ -22,7 +24,7 @@ class AveragePrecision(EpochMetric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - check_compute_fn (bool): Optional default False. If True, `average_precision_score + check_compute_fn (bool): Default False. If True, `average_precision_score `_ is run on the first batch of data to ensure there are no issues. User will be warned in case there are any issues computing the function. 
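The `Callable` annotation adopted for `output_transform` matches the documented usage of this metric; as a minimal sketch (assuming a model that emits raw logits, so a sigmoid is applied in the transform), mirroring the `activated_output_transform` example already present in the metric's docstring:

    import torch
    from ignite.contrib.metrics import AveragePrecision

    # output_transform adapts the engine's process_function output into the
    # (y_pred, y) pair consumed by sklearn's average_precision_score
    def activated_output_transform(output):
        y_pred, y = output
        return torch.sigmoid(y_pred), y

    avg_precision = AveragePrecision(activated_output_transform)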
@@ -41,7 +43,7 @@ def activated_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False): super(AveragePrecision, self).__init__( average_precision_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn ) diff --git a/ignite/contrib/metrics/gpu_info.py b/ignite/contrib/metrics/gpu_info.py index 35e848f39058..3d1382604271 100644 --- a/ignite/contrib/metrics/gpu_info.py +++ b/ignite/contrib/metrics/gpu_info.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- import warnings +from typing import Any import torch -from ignite.engine import Events +from ignite.engine import Engine, EventEnum, Events from ignite.metrics import Metric @@ -54,7 +55,7 @@ def __init__(self): def reset(self): pass - def update(self, output): + def update(self, output: Any): pass def compute(self): @@ -64,7 +65,7 @@ def compute(self): return [] return data["gpu"] - def completed(self, engine, name): + def completed(self, engine: Engine, name: str): data = self.compute() if len(data) < 1: warnings.warn("No GPU information available") @@ -103,5 +104,5 @@ def completed(self, engine, name): # Do not set GPU utilization information pass - def attach(self, engine, name="gpu", event_name=Events.ITERATION_COMPLETED): + def attach(self, engine: Engine, name: str = "gpu", event_name: EventEnum = Events.ITERATION_COMPLETED): engine.add_event_handler(event_name, self.completed, name) diff --git a/ignite/contrib/metrics/precision_recall_curve.py b/ignite/contrib/metrics/precision_recall_curve.py index cc6cf18595de..d5394f41bffb 100644 --- a/ignite/contrib/metrics/precision_recall_curve.py +++ b/ignite/contrib/metrics/precision_recall_curve.py @@ -1,7 +1,9 @@ +from typing import Any, Callable + from ignite.metrics import EpochMetric -def precision_recall_curve_compute_fn(y_preds, y_targets): +def precision_recall_curve_compute_fn(y_preds: Any, y_targets: Any): try: from sklearn.metrics import precision_recall_curve except ImportError: @@ -23,7 +25,7 @@ class PrecisionRecallCurve(EpochMetric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - check_compute_fn (bool): Optional default False. If True, `precision_recall_curve + check_compute_fn (bool): Default False. If True, `precision_recall_curve `_ is run on the first batch of data to ensure there are no issues. User will be warned in case there are any issues computing the function. 
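The same pattern applies here; a hedged sketch of the typed constructor (scikit-learn is assumed to be installed, since the compute function imports it lazily):

    import torch
    from ignite.contrib.metrics import PrecisionRecallCurve

    def sigmoid_output_transform(output):
        y_pred, y = output
        return torch.sigmoid(y_pred), y

    # compute() returns sklearn's (precision, recall, thresholds) arrays
    prc = PrecisionRecallCurve(sigmoid_output_transform)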
@@ -42,7 +44,7 @@ def activated_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False): super(PrecisionRecallCurve, self).__init__( precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn ) diff --git a/ignite/contrib/metrics/regression/_base.py b/ignite/contrib/metrics/regression/_base.py index b08cf655e6f5..caae16dfc21a 100644 --- a/ignite/contrib/metrics/regression/_base.py +++ b/ignite/contrib/metrics/regression/_base.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Callable, Union +from typing import Any, Callable import torch @@ -7,7 +7,7 @@ from ignite.metrics.metric import reinit__is_reduced -def _check_output_shapes(output): +def _check_output_shapes(output: Any): y_pred, y = output if y_pred.shape != y.shape: raise ValueError("Input data shapes should be the same, but given {} and {}".format(y_pred.shape, y.shape)) @@ -21,7 +21,7 @@ def _check_output_shapes(output): raise ValueError("Input y should have shape (N,) or (N, 1), but given {}".format(y.shape)) -def _check_output_types(output): +def _check_output_types(output: Any): y_pred, y = output if y_pred.dtype not in (torch.float16, torch.float32, torch.float64): raise TypeError("Input y_pred dtype should be float 16, 32 or 64, but given {}".format(y_pred.dtype)) @@ -36,7 +36,7 @@ class _BaseRegression(Metric): # method `_update`. @reinit__is_reduced - def update(self, output): + def update(self, output: Any): _check_output_shapes(output) _check_output_types(output) y_pred, y = output[0].detach(), output[1].detach() @@ -50,7 +50,7 @@ def update(self, output): self._update((y_pred, y)) @abstractmethod - def _update(self, output): + def _update(self, output: Any): pass @@ -59,14 +59,16 @@ class _BaseRegressionEpoch(EpochMetric): # `update` method check the shapes and call internal overloaded method `_update`. # Class internally stores complete history of predictions and targets of type float32. 
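These base-class annotations document the `update`/`_update` split; a sketch of a custom regression metric that relies on that contract (the class name is hypothetical, not part of the library):

    import torch
    from ignite.contrib.metrics.regression._base import _BaseRegression

    class MeanAbsoluteErrorSketch(_BaseRegression):
        def reset(self):
            self._sum_of_errors = 0.0
            self._num_examples = 0

        def _update(self, output):
            # update() has already validated shapes/dtypes and detached both tensors
            y_pred, y = output
            self._sum_of_errors += torch.abs(y.view_as(y_pred) - y_pred).sum().item()
            self._num_examples += y.shape[0]

        def compute(self):
            return self._sum_of_errors / max(self._num_examples, 1)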
- def __init__(self, compute_fn, output_transform=lambda x: x, check_compute_fn: bool = True): + def __init__( + self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True, + ): super(_BaseRegressionEpoch, self).__init__( compute_fn=compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn ) - def _check_type(self, output): + def _check_type(self, output: Any): _check_output_types(output) super(_BaseRegressionEpoch, self)._check_type(output) - def _check_shape(self, output): + def _check_shape(self, output: Any): _check_output_shapes(output) diff --git a/ignite/contrib/metrics/regression/canberra_metric.py b/ignite/contrib/metrics/regression/canberra_metric.py index 23cba54ea2cf..40ed893a3f3d 100644 --- a/ignite/contrib/metrics/regression/canberra_metric.py +++ b/ignite/contrib/metrics/regression/canberra_metric.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Any, Callable, Union import torch @@ -34,7 +34,7 @@ def __init__( def reset(self): self._sum_of_errors = torch.tensor(0.0, device=self._device) - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y)) self._sum_of_errors += torch.sum(errors).to(self._device) diff --git a/ignite/contrib/metrics/regression/fractional_absolute_error.py b/ignite/contrib/metrics/regression/fractional_absolute_error.py index 4e554c1e5a57..f4436477f0e1 100644 --- a/ignite/contrib/metrics/regression/fractional_absolute_error.py +++ b/ignite/contrib/metrics/regression/fractional_absolute_error.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -24,7 +26,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = 2 * torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred))) self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/regression/fractional_bias.py b/ignite/contrib/metrics/regression/fractional_bias.py index 335f3a975e76..50d2edbde395 100644 --- a/ignite/contrib/metrics/regression/fractional_bias.py +++ b/ignite/contrib/metrics/regression/fractional_bias.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -25,7 +27,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = 2 * (y.view_as(y_pred) - y_pred) / (y_pred + y.view_as(y_pred)) self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py b/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py index f610bcb9b06a..959f03744675 100644 --- a/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +++ b/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -24,7 +26,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = torch.log(torch.abs(y.view_as(y_pred) - y_pred)) self._sum_of_errors += torch.sum(errors) diff --git 
a/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py b/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py index 5da526b390fd..70dfb7d7ec0c 100644 --- a/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +++ b/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -26,7 +28,7 @@ def reset(self): self._num_examples = 0 self._sum_of_errors = 0.0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output self._sum_y += y.sum() self._num_examples += y.shape[0] diff --git a/ignite/contrib/metrics/regression/manhattan_distance.py b/ignite/contrib/metrics/regression/manhattan_distance.py index e8624284c3c0..80be4c59c0ad 100644 --- a/ignite/contrib/metrics/regression/manhattan_distance.py +++ b/ignite/contrib/metrics/regression/manhattan_distance.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Any, Callable, Union import torch @@ -33,7 +33,7 @@ def __init__( def reset(self): self._sum_of_errors = torch.tensor(0.0, device=self._device) - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = torch.abs(y - y_pred) self._sum_of_errors += torch.sum(errors).to(self._device) diff --git a/ignite/contrib/metrics/regression/maximum_absolute_error.py b/ignite/contrib/metrics/regression/maximum_absolute_error.py index 4c63d243114c..0e0aed4fbf9a 100644 --- a/ignite/contrib/metrics/regression/maximum_absolute_error.py +++ b/ignite/contrib/metrics/regression/maximum_absolute_error.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -24,7 +26,7 @@ class MaximumAbsoluteError(_BaseRegression): def reset(self): self._max_of_absolute_errors = -1 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output mae = torch.abs(y_pred - y.view_as(y_pred)).max().item() if self._max_of_absolute_errors < mae: diff --git a/ignite/contrib/metrics/regression/mean_absolute_relative_error.py b/ignite/contrib/metrics/regression/mean_absolute_relative_error.py index e423c0b92eb7..97a490bd9619 100644 --- a/ignite/contrib/metrics/regression/mean_absolute_relative_error.py +++ b/ignite/contrib/metrics/regression/mean_absolute_relative_error.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -25,7 +27,7 @@ def reset(self): self._sum_of_absolute_relative_errors = 0.0 self._num_samples = 0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output if (y == 0).any(): raise NotComputableError("The ground truth has 0.") diff --git a/ignite/contrib/metrics/regression/mean_error.py b/ignite/contrib/metrics/regression/mean_error.py index f95f9644fd57..fc36c7330349 100644 --- a/ignite/contrib/metrics/regression/mean_error.py +++ b/ignite/contrib/metrics/regression/mean_error.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -25,7 +27,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = y.view_as(y_pred) - y_pred self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/regression/mean_normalized_bias.py 
b/ignite/contrib/metrics/regression/mean_normalized_bias.py index d1ea8603a9dd..4bad5af5c7b9 100644 --- a/ignite/contrib/metrics/regression/mean_normalized_bias.py +++ b/ignite/contrib/metrics/regression/mean_normalized_bias.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -25,7 +27,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output if (y == 0).any(): diff --git a/ignite/contrib/metrics/regression/median_absolute_error.py b/ignite/contrib/metrics/regression/median_absolute_error.py index 215198e2be7e..2e0e16568e75 100644 --- a/ignite/contrib/metrics/regression/median_absolute_error.py +++ b/ignite/contrib/metrics/regression/median_absolute_error.py @@ -1,9 +1,11 @@ +from typing import Any, Callable + import torch from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch -def median_absolute_error_compute_fn(y_pred, y): +def median_absolute_error_compute_fn(y_pred: Any, y: Any): e = torch.abs(y.view_as(y_pred) - y_pred) return torch.median(e).item() @@ -31,5 +33,5 @@ class MedianAbsoluteError(_BaseRegressionEpoch): """ - def __init__(self, output_transform=lambda x: x): + def __init__(self, output_transform: Callable = lambda x: x): super(MedianAbsoluteError, self).__init__(median_absolute_error_compute_fn, output_transform) diff --git a/ignite/contrib/metrics/regression/median_absolute_percentage_error.py b/ignite/contrib/metrics/regression/median_absolute_percentage_error.py index 9c30afa4ad41..4657138981b3 100644 --- a/ignite/contrib/metrics/regression/median_absolute_percentage_error.py +++ b/ignite/contrib/metrics/regression/median_absolute_percentage_error.py @@ -1,9 +1,11 @@ +from typing import Any, Callable + import torch from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch -def median_absolute_percentage_error_compute_fn(y_pred, y): +def median_absolute_percentage_error_compute_fn(y_pred: Any, y: Any): e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred)) return 100.0 * torch.median(e).item() @@ -31,7 +33,7 @@ class MedianAbsolutePercentageError(_BaseRegressionEpoch): """ - def __init__(self, output_transform=lambda x: x): + def __init__(self, output_transform: Callable = lambda x: x): super(MedianAbsolutePercentageError, self).__init__( median_absolute_percentage_error_compute_fn, output_transform ) diff --git a/ignite/contrib/metrics/regression/median_relative_absolute_error.py b/ignite/contrib/metrics/regression/median_relative_absolute_error.py index b83730f36bac..2b5614ae9c95 100644 --- a/ignite/contrib/metrics/regression/median_relative_absolute_error.py +++ b/ignite/contrib/metrics/regression/median_relative_absolute_error.py @@ -1,9 +1,11 @@ +from typing import Any, Callable + import torch from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch -def median_relative_absolute_error_compute_fn(y_pred, y): +def median_relative_absolute_error_compute_fn(y_pred: Any, y: Any): e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred) - torch.mean(y)) return torch.median(e).item() @@ -31,5 +33,5 @@ class MedianRelativeAbsoluteError(_BaseRegressionEpoch): """ - def __init__(self, output_transform=lambda x: x): + def __init__(self, output_transform: Callable = lambda x: x): super(MedianRelativeAbsoluteError, self).__init__(median_relative_absolute_error_compute_fn, output_transform) diff --git 
a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index 1c9fd9a72d08..6607d9f30aaf 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Any, Callable, Union import torch @@ -22,7 +22,7 @@ class R2Score(_BaseRegression): """ def __init__( - self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ): self._num_examples = None self._sum_of_errors = None @@ -37,7 +37,7 @@ def reset(self): self._y_sq_sum = torch.tensor(0.0, device=self._device) self._y_sum = torch.tensor(0.0, device=self._device) - def _update(self, output): + def _update(self, output: Any): y_pred, y = output self._num_examples += y.shape[0] self._sum_of_errors += torch.sum(torch.pow(y_pred - y, 2)).to(self._device) diff --git a/ignite/contrib/metrics/regression/wave_hedges_distance.py b/ignite/contrib/metrics/regression/wave_hedges_distance.py index d5572455ba4d..344cb0836b8e 100644 --- a/ignite/contrib/metrics/regression/wave_hedges_distance.py +++ b/ignite/contrib/metrics/regression/wave_hedges_distance.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from ignite.contrib.metrics.regression._base import _BaseRegression @@ -21,7 +23,7 @@ class WaveHedgesDistance(_BaseRegression): def reset(self): self._sum_of_errors = 0.0 - def _update(self, output): + def _update(self, output: Any): y_pred, y = output errors = torch.abs(y.view_as(y_pred) - y_pred) / torch.max(y_pred, y.view_as(y_pred)) self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/roc_auc.py b/ignite/contrib/metrics/roc_auc.py index b2e8235e9dd7..067f0dc54560 100644 --- a/ignite/contrib/metrics/roc_auc.py +++ b/ignite/contrib/metrics/roc_auc.py @@ -1,7 +1,9 @@ +from typing import Any, Callable + from ignite.metrics import EpochMetric -def roc_auc_compute_fn(y_preds, y_targets): +def roc_auc_compute_fn(y_preds: Any, y_targets: Any): try: from sklearn.metrics import roc_auc_score except ImportError: @@ -12,7 +14,7 @@ def roc_auc_compute_fn(y_preds, y_targets): return roc_auc_score(y_true, y_pred) -def roc_auc_curve_compute_fn(y_preds, y_targets): +def roc_auc_curve_compute_fn(y_preds: Any, y_targets: Any): try: from sklearn.metrics import roc_curve except ImportError: @@ -34,7 +36,7 @@ class ROC_AUC(EpochMetric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - check_compute_fn (bool): Optional default False. If True, `roc_curve + check_compute_fn (bool): Default False. If True, `roc_curve `_ is run on the first batch of data to ensure there are no issues. User will be warned in case there are any issues computing the function. 
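A usage sketch for the updated signature (the lambda stands in for a real validation step, and `check_compute_fn=True` makes the metric try `roc_auc_score` on the first batch and warn on failure, as the docstring above describes):

    import torch
    from ignite.contrib.metrics import ROC_AUC
    from ignite.engine import Engine

    def sigmoid_output_transform(output):
        y_pred, y = output
        return torch.sigmoid(y_pred), y

    # a stand-in evaluator whose process_function yields (logits, targets)
    evaluator = Engine(lambda engine, batch: (torch.randn(8), torch.randint(0, 2, (8,)).float()))

    roc_auc = ROC_AUC(sigmoid_output_transform, check_compute_fn=True)
    roc_auc.attach(evaluator, "roc_auc")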
@@ -53,7 +55,7 @@ def activated_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False): super(ROC_AUC, self).__init__( roc_auc_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn ) @@ -70,7 +72,7 @@ class RocCurve(EpochMetric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.roc_curve + check_compute_fn (bool): Default False. If True, `sklearn.metrics.roc_curve `_ is run on the first batch of data to ensure there are no issues. User will be warned in case there are any issues computing the function. @@ -89,7 +91,7 @@ def activated_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False): super(RocCurve, self).__init__( roc_auc_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn ) From c9771106c01061eadb06f5080a086044528c222f Mon Sep 17 00:00:00 2001 From: ravitezu Date: Tue, 6 Oct 2020 11:02:10 +0530 Subject: [PATCH 2/3] review changes --- ignite/contrib/metrics/average_precision.py | 6 ++++-- ignite/contrib/metrics/gpu_info.py | 4 ++-- ignite/contrib/metrics/precision_recall_curve.py | 6 ++++-- ignite/contrib/metrics/regression/_base.py | 14 +++++++------- .../contrib/metrics/regression/canberra_metric.py | 4 ++-- .../regression/fractional_absolute_error.py | 4 ++-- .../contrib/metrics/regression/fractional_bias.py | 4 ++-- .../regression/geometric_mean_absolute_error.py | 4 ++-- .../geometric_mean_relative_absolute_error.py | 5 +++-- .../metrics/regression/manhattan_distance.py | 4 ++-- .../metrics/regression/maximum_absolute_error.py | 4 ++-- .../regression/mean_absolute_relative_error.py | 4 ++-- ignite/contrib/metrics/regression/mean_error.py | 4 ++-- .../metrics/regression/mean_normalized_bias.py | 4 ++-- .../metrics/regression/median_absolute_error.py | 4 ++-- .../regression/median_absolute_percentage_error.py | 4 ++-- .../regression/median_relative_absolute_error.py | 4 ++-- ignite/contrib/metrics/regression/r2_score.py | 4 ++-- .../metrics/regression/wave_hedges_distance.py | 4 ++-- ignite/contrib/metrics/roc_auc.py | 8 +++++--- 20 files changed, 53 insertions(+), 46 deletions(-) diff --git a/ignite/contrib/metrics/average_precision.py b/ignite/contrib/metrics/average_precision.py index 042a5b5ddf6e..d90f7e59ee7b 100644 --- a/ignite/contrib/metrics/average_precision.py +++ b/ignite/contrib/metrics/average_precision.py @@ -1,9 +1,11 @@ -from typing import Any, Callable +from typing import Callable + +import torch from ignite.metrics import EpochMetric -def average_precision_compute_fn(y_preds: Any, y_targets: Any): +def average_precision_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): try: from sklearn.metrics import average_precision_score except ImportError: diff --git a/ignite/contrib/metrics/gpu_info.py b/ignite/contrib/metrics/gpu_info.py index 3d1382604271..e6f7cebf3278 100644 --- a/ignite/contrib/metrics/gpu_info.py +++ b/ignite/contrib/metrics/gpu_info.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import warnings -from typing import 
Any +from typing import Tuple import torch @@ -55,7 +55,7 @@ def __init__(self): def reset(self): pass - def update(self, output: Any): + def update(self, output: Tuple[torch.Tensor, torch.Tensor]): pass def compute(self): diff --git a/ignite/contrib/metrics/precision_recall_curve.py b/ignite/contrib/metrics/precision_recall_curve.py index d5394f41bffb..3138bc94c12f 100644 --- a/ignite/contrib/metrics/precision_recall_curve.py +++ b/ignite/contrib/metrics/precision_recall_curve.py @@ -1,9 +1,11 @@ -from typing import Any, Callable +from typing import Callable + +import torch from ignite.metrics import EpochMetric -def precision_recall_curve_compute_fn(y_preds: Any, y_targets: Any): +def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): try: from sklearn.metrics import precision_recall_curve except ImportError: diff --git a/ignite/contrib/metrics/regression/_base.py b/ignite/contrib/metrics/regression/_base.py index caae16dfc21a..58b935a352de 100644 --- a/ignite/contrib/metrics/regression/_base.py +++ b/ignite/contrib/metrics/regression/_base.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Callable +from typing import Callable, Tuple import torch @@ -7,7 +7,7 @@ from ignite.metrics.metric import reinit__is_reduced -def _check_output_shapes(output: Any): +def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output if y_pred.shape != y.shape: raise ValueError("Input data shapes should be the same, but given {} and {}".format(y_pred.shape, y.shape)) @@ -21,7 +21,7 @@ def _check_output_shapes(output: Any): raise ValueError("Input y should have shape (N,) or (N, 1), but given {}".format(y.shape)) -def _check_output_types(output: Any): +def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output if y_pred.dtype not in (torch.float16, torch.float32, torch.float64): raise TypeError("Input y_pred dtype should be float 16, 32 or 64, but given {}".format(y_pred.dtype)) @@ -36,7 +36,7 @@ class _BaseRegression(Metric): # method `_update`. 
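Replacing `Any` with `Tuple[torch.Tensor, torch.Tensor]` makes the `(y_pred, y)` contract statically checkable, while the existing helpers keep guarding it at runtime; a small sketch of the runtime side (it uses a private helper, so it is purely illustrative):

    import torch
    from ignite.contrib.metrics.regression._base import _check_output_shapes

    _check_output_shapes((torch.zeros(8), torch.zeros(8)))    # passes silently
    # _check_output_shapes((torch.zeros(8), torch.zeros(4)))  # raises ValueError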
@reinit__is_reduced - def update(self, output: Any): + def update(self, output: Tuple[torch.Tensor, torch.Tensor]): _check_output_shapes(output) _check_output_types(output) y_pred, y = output[0].detach(), output[1].detach() @@ -50,7 +50,7 @@ def update(self, output: Any): self._update((y_pred, y)) @abstractmethod - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): pass @@ -66,9 +66,9 @@ def __init__( compute_fn=compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn ) - def _check_type(self, output: Any): + def _check_type(self, output: Tuple[torch.Tensor, torch.Tensor]): _check_output_types(output) super(_BaseRegressionEpoch, self)._check_type(output) - def _check_shape(self, output: Any): + def _check_shape(self, output: Tuple[torch.Tensor, torch.Tensor]): _check_output_shapes(output) diff --git a/ignite/contrib/metrics/regression/canberra_metric.py b/ignite/contrib/metrics/regression/canberra_metric.py index 40ed893a3f3d..b44fee22eff4 100644 --- a/ignite/contrib/metrics/regression/canberra_metric.py +++ b/ignite/contrib/metrics/regression/canberra_metric.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Callable, Tuple, Union import torch @@ -34,7 +34,7 @@ def __init__( def reset(self): self._sum_of_errors = torch.tensor(0.0, device=self._device) - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y)) self._sum_of_errors += torch.sum(errors).to(self._device) diff --git a/ignite/contrib/metrics/regression/fractional_absolute_error.py b/ignite/contrib/metrics/regression/fractional_absolute_error.py index f4436477f0e1..b1dcd9ec0532 100644 --- a/ignite/contrib/metrics/regression/fractional_absolute_error.py +++ b/ignite/contrib/metrics/regression/fractional_absolute_error.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -26,7 +26,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = 2 * torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred))) self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/regression/fractional_bias.py b/ignite/contrib/metrics/regression/fractional_bias.py index 50d2edbde395..10e78b3b3e24 100644 --- a/ignite/contrib/metrics/regression/fractional_bias.py +++ b/ignite/contrib/metrics/regression/fractional_bias.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -27,7 +27,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = 2 * (y.view_as(y_pred) - y_pred) / (y_pred + y.view_as(y_pred)) self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py b/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py index 959f03744675..92cb057a7db1 100644 --- a/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +++ b/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -26,7 +26,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def 
_update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = torch.log(torch.abs(y.view_as(y_pred) - y_pred)) self._sum_of_errors += torch.sum(errors) diff --git a/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py b/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py index 70dfb7d7ec0c..043f60956854 100644 --- a/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +++ b/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Tuple import torch from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class GeometricMeanRelativeAbsoluteError(_BaseRegression): @@ -28,7 +29,7 @@ def reset(self): self._num_examples = 0 self._sum_of_errors = 0.0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output self._sum_y += y.sum() self._num_examples += y.shape[0] diff --git a/ignite/contrib/metrics/regression/manhattan_distance.py b/ignite/contrib/metrics/regression/manhattan_distance.py index 80be4c59c0ad..6c3456e62eb8 100644 --- a/ignite/contrib/metrics/regression/manhattan_distance.py +++ b/ignite/contrib/metrics/regression/manhattan_distance.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Callable, Tuple, Union import torch @@ -33,7 +33,7 @@ def __init__( def reset(self): self._sum_of_errors = torch.tensor(0.0, device=self._device) - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = torch.abs(y - y_pred) self._sum_of_errors += torch.sum(errors).to(self._device) diff --git a/ignite/contrib/metrics/regression/maximum_absolute_error.py b/ignite/contrib/metrics/regression/maximum_absolute_error.py index 0e0aed4fbf9a..a866a5f22bb7 100644 --- a/ignite/contrib/metrics/regression/maximum_absolute_error.py +++ b/ignite/contrib/metrics/regression/maximum_absolute_error.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -26,7 +26,7 @@ class MaximumAbsoluteError(_BaseRegression): def reset(self): self._max_of_absolute_errors = -1 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output mae = torch.abs(y_pred - y.view_as(y_pred)).max().item() if self._max_of_absolute_errors < mae: diff --git a/ignite/contrib/metrics/regression/mean_absolute_relative_error.py b/ignite/contrib/metrics/regression/mean_absolute_relative_error.py index 97a490bd9619..05f8da5ab549 100644 --- a/ignite/contrib/metrics/regression/mean_absolute_relative_error.py +++ b/ignite/contrib/metrics/regression/mean_absolute_relative_error.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -27,7 +27,7 @@ def reset(self): self._sum_of_absolute_relative_errors = 0.0 self._num_samples = 0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output if (y == 0).any(): raise NotComputableError("The ground truth has 0.") diff --git a/ignite/contrib/metrics/regression/mean_error.py b/ignite/contrib/metrics/regression/mean_error.py index fc36c7330349..35761d668fb5 100644 --- a/ignite/contrib/metrics/regression/mean_error.py +++ b/ignite/contrib/metrics/regression/mean_error.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple 
import torch @@ -27,7 +27,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = y.view_as(y_pred) - y_pred self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/regression/mean_normalized_bias.py b/ignite/contrib/metrics/regression/mean_normalized_bias.py index 4bad5af5c7b9..2514dd2e5ce1 100644 --- a/ignite/contrib/metrics/regression/mean_normalized_bias.py +++ b/ignite/contrib/metrics/regression/mean_normalized_bias.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -27,7 +27,7 @@ def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output if (y == 0).any(): diff --git a/ignite/contrib/metrics/regression/median_absolute_error.py b/ignite/contrib/metrics/regression/median_absolute_error.py index 2e0e16568e75..3551020d0881 100644 --- a/ignite/contrib/metrics/regression/median_absolute_error.py +++ b/ignite/contrib/metrics/regression/median_absolute_error.py @@ -1,11 +1,11 @@ -from typing import Any, Callable +from typing import Callable import torch from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch -def median_absolute_error_compute_fn(y_pred: Any, y: Any): +def median_absolute_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor): e = torch.abs(y.view_as(y_pred) - y_pred) return torch.median(e).item() diff --git a/ignite/contrib/metrics/regression/median_absolute_percentage_error.py b/ignite/contrib/metrics/regression/median_absolute_percentage_error.py index 4657138981b3..9b048cc9240c 100644 --- a/ignite/contrib/metrics/regression/median_absolute_percentage_error.py +++ b/ignite/contrib/metrics/regression/median_absolute_percentage_error.py @@ -1,11 +1,11 @@ -from typing import Any, Callable +from typing import Callable import torch from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch -def median_absolute_percentage_error_compute_fn(y_pred: Any, y: Any): +def median_absolute_percentage_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor): e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred)) return 100.0 * torch.median(e).item() diff --git a/ignite/contrib/metrics/regression/median_relative_absolute_error.py b/ignite/contrib/metrics/regression/median_relative_absolute_error.py index 2b5614ae9c95..a45889ef9ea2 100644 --- a/ignite/contrib/metrics/regression/median_relative_absolute_error.py +++ b/ignite/contrib/metrics/regression/median_relative_absolute_error.py @@ -1,11 +1,11 @@ -from typing import Any, Callable +from typing import Callable import torch from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch -def median_relative_absolute_error_compute_fn(y_pred: Any, y: Any): +def median_relative_absolute_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor): e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred) - torch.mean(y)) return torch.median(e).item() diff --git a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index 6607d9f30aaf..2610bac946db 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Callable, Tuple, Union import torch @@ -37,7 +37,7 @@ def reset(self): 
self._y_sq_sum = torch.tensor(0.0, device=self._device) self._y_sum = torch.tensor(0.0, device=self._device) - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output self._num_examples += y.shape[0] self._sum_of_errors += torch.sum(torch.pow(y_pred - y, 2)).to(self._device) diff --git a/ignite/contrib/metrics/regression/wave_hedges_distance.py b/ignite/contrib/metrics/regression/wave_hedges_distance.py index 344cb0836b8e..c56e93d97344 100644 --- a/ignite/contrib/metrics/regression/wave_hedges_distance.py +++ b/ignite/contrib/metrics/regression/wave_hedges_distance.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Tuple import torch @@ -23,7 +23,7 @@ class WaveHedgesDistance(_BaseRegression): def reset(self): self._sum_of_errors = 0.0 - def _update(self, output: Any): + def _update(self, output: Tuple[torch.Tensor, torch.Tensor]): y_pred, y = output errors = torch.abs(y.view_as(y_pred) - y_pred) / torch.max(y_pred, y.view_as(y_pred)) self._sum_of_errors += torch.sum(errors).item() diff --git a/ignite/contrib/metrics/roc_auc.py b/ignite/contrib/metrics/roc_auc.py index 067f0dc54560..d2230daaf58b 100644 --- a/ignite/contrib/metrics/roc_auc.py +++ b/ignite/contrib/metrics/roc_auc.py @@ -1,9 +1,11 @@ -from typing import Any, Callable +from typing import Callable + +import torch from ignite.metrics import EpochMetric -def roc_auc_compute_fn(y_preds: Any, y_targets: Any): +def roc_auc_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): try: from sklearn.metrics import roc_auc_score except ImportError: @@ -14,7 +16,7 @@ def roc_auc_compute_fn(y_preds: Any, y_targets: Any): return roc_auc_score(y_true, y_pred) -def roc_auc_curve_compute_fn(y_preds: Any, y_targets: Any): +def roc_auc_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): try: from sklearn.metrics import roc_curve except ImportError: From 73804c9a22b90f96906f096aaa090d3c7820b704 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 6 Oct 2020 10:42:19 +0200 Subject: [PATCH 3/3] Update gpu_info.py --- ignite/contrib/metrics/gpu_info.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/contrib/metrics/gpu_info.py b/ignite/contrib/metrics/gpu_info.py index e6f7cebf3278..16a39c684bb7 100644 --- a/ignite/contrib/metrics/gpu_info.py +++ b/ignite/contrib/metrics/gpu_info.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import warnings -from typing import Tuple +from typing import Tuple, Union import torch @@ -104,5 +104,5 @@ def completed(self, engine: Engine, name: str): # Do not set GPU utilization information pass - def attach(self, engine: Engine, name: str = "gpu", event_name: EventEnum = Events.ITERATION_COMPLETED): + def attach(self, engine: Engine, name: str = "gpu", event_name: Union[str, EventEnum] = Events.ITERATION_COMPLETED): engine.add_event_handler(event_name, self.completed, name)
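The widening of `event_name` to `Union[str, EventEnum]` in this final patch matches `Engine.add_event_handler`, which also accepts custom events registered by plain string name; a sketch of why that matters (constructing GpuInfo itself additionally assumes pynvml is installed and a GPU is visible):

    from ignite.engine import Engine
    from ignite.contrib.metrics import GpuInfo

    trainer = Engine(lambda engine, batch: None)
    trainer.register_events("custom_monitor_event")

    # both an EventEnum member and a registered string are now valid event_name
    GpuInfo().attach(trainer, name="gpu", event_name="custom_monitor_event")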