Skip to content

Commit a312219

Browse files
authored
Prune metric: helpers and inputs 3/n (#6547)
* _basic_input_validation * _check_shape_and_type_consistency * _check_num_classes_binary * _check_num_classes_mc * _check_num_classes_ml * _check_top_k * _check_classification_inputs * _input_format_classification * _reduce_stat_scores * DataType * rest * flake8 * chlog
1 parent 0f07eaf commit a312219

File tree

15 files changed

+20
-549
lines changed

15 files changed

+20
-549
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6868
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
6969

7070
[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
71+
72+
[#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),
7173

7274
)
7375

pytorch_lightning/metrics/classification/helpers.py

Lines changed: 0 additions & 535 deletions
This file was deleted.

pytorch_lightning/metrics/functional/accuracy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from typing import Optional, Tuple
1515

1616
import torch
17-
18-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
17+
from torchmetrics.classification.checks import _input_format_classification
18+
from torchmetrics.utilities.enums import DataType
1919

2020

2121
def _accuracy_update(

pytorch_lightning/metrics/functional/auroc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from typing import Optional, Sequence, Tuple
1616

1717
import torch
18+
from torchmetrics.classification.checks import _input_format_classification
19+
from torchmetrics.utilities.enums import DataType
1820

19-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
2021
from pytorch_lightning.metrics.functional.auc import auc
2122
from pytorch_lightning.metrics.functional.roc import roc
2223
from pytorch_lightning.utilities import LightningEnum

pytorch_lightning/metrics/functional/confusion_matrix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from typing import Optional
1515

1616
import torch
17+
from torchmetrics.classification.checks import _input_format_classification
18+
from torchmetrics.utilities.enums import DataType
1719

18-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
1920
from pytorch_lightning.utilities import rank_zero_warn
2021

2122

pytorch_lightning/metrics/functional/hamming_distance.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from typing import Tuple, Union
1515

1616
import torch
17-
18-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
17+
from torchmetrics.classification.checks import _input_format_classification
1918

2019

2120
def _hamming_distance_update(

pytorch_lightning/metrics/functional/precision_recall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from typing import Optional
1515

1616
import torch
17+
from torchmetrics.classification.stat_scores import _reduce_stat_scores
1718

18-
from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores
1919
from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update
2020
from pytorch_lightning.utilities import rank_zero_warn
2121

pytorch_lightning/metrics/functional/stat_scores.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from typing import Optional, Tuple
1515

1616
import torch
17-
18-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
17+
from torchmetrics.classification.checks import _input_format_classification
1918

2019

2120
def _del_column(tensor: torch.Tensor, index: int):

pytorch_lightning/trainer/connectors/env_vars_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable:
2323
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
2424
input arguments should be moved automatically to the correct device.
2525
"""
26+
2627
@wraps(fn)
2728
def insert_env_defaults(self, *args, **kwargs):
2829
cls = self.__class__ # get the class

tests/metrics/classification/test_accuracy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import pytest
55
import torch
66
from sklearn.metrics import accuracy_score as sk_accuracy
7+
from torchmetrics.classification.checks import _input_format_classification
8+
from torchmetrics.utilities.enums import DataType
79

810
from pytorch_lightning.metrics import Accuracy
9-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
1011
from pytorch_lightning.metrics.functional import accuracy
1112
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
1213
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls

0 commit comments

Comments
 (0)