Skip to content

Fix mypy errors attributed to pytorch_lightning.loggers.neptune #13692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cda3a6d
update pyproject.toml for ci code-check
jxtngx Jul 17, 2022
5bcc364
update missing annotations and return types
jxtngx Jul 17, 2022
93037aa
Merge branch 'master' into codeq/neptune-logger
jxtngx Jul 23, 2022
85268f3
Merge branch 'master' into codeq/neptune-logger
jxtngx Jul 25, 2022
dbebeab
Merge branch 'master' into codeq/neptune-logger
jxtngx Jul 26, 2022
923a49f
Merge branch 'master' into codeq/neptune-logger
jxtngx Jul 27, 2022
b92d227
Update pyproject.toml
awaelchli Jul 27, 2022
8e4bc43
Update src/pytorch_lightning/loggers/neptune.py
jxtngx Jul 27, 2022
5696f88
Update src/pytorch_lightning/loggers/neptune.py
jxtngx Jul 27, 2022
2817b7b
update
jxtngx Jul 28, 2022
5b17417
update
jxtngx Jul 28, 2022
861b067
specific types and improve typing for dict reduction
awaelchli Jul 28, 2022
4f049a9
update
jxtngx Jul 28, 2022
93a6a0f
Merge branch 'master' into codeq/neptune-logger
jxtngx Jul 28, 2022
575f150
update
jxtngx Jul 29, 2022
93ff29a
Merge branch 'master' into codeq/neptune-logger
jxtngx Jul 29, 2022
1fbc783
Merge branch 'master' into codeq/neptune-logger
jxtngx Aug 1, 2022
8603488
Merge branch 'master' into codeq/neptune-logger
jxtngx Aug 1, 2022
987337f
update for pytest
jxtngx Aug 1, 2022
48e0db1
Merge branch 'master' into codeq/neptune-logger
jxtngx Aug 1, 2022
bdcba72
Merge branch 'master' into codeq/neptune-logger
Aug 2, 2022
f97012d
create run if not existing yet
Aug 2, 2022
ebd6e31
adjust testing mocks
Aug 2, 2022
c6302bd
Revert "adjust testing mocks"
Aug 2, 2022
2833aa4
Revert "create run if not existing yet"
Aug 2, 2022
e65d3d2
allow for optional name and version
Aug 2, 2022
d9bdbf1
commit suggestion
jxtngx Aug 3, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ module = [
"pytorch_lightning.core.saving",
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.loggers.neptune",
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
"pytorch_lightning.profilers.simple",
Expand Down
67 changes: 37 additions & 30 deletions src/pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
import os
import warnings
from argparse import Namespace
from functools import reduce
from typing import Any, Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence, Set, Union
from weakref import ReferenceType

import torch
from torch import Tensor

from pytorch_lightning import __version__
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import _RequirementAvailable
Expand Down Expand Up @@ -272,7 +270,7 @@ def __init__(
prefix: str = "training",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**neptune_run_kwargs,
**neptune_run_kwargs: Any,
):
# verify if user passed proper init arguments
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
Expand All @@ -290,26 +288,27 @@ def __init__(
self._api_key = api_key
self._run_instance = run
self._neptune_run_kwargs = neptune_run_kwargs
self._run_short_id = None
self._run_short_id: Optional[str] = None

if self._run_instance is not None:
self._retrieve_run_data()

# make sure that we've log integration version for outside `Run` instances
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
self._run_instance[_INTEGRATION_VERSION_KEY] = pl.__version__

def _retrieve_run_data(self):
def _retrieve_run_data(self) -> None:
try:
self._run_instance.wait()
self._run_short_id = self._run_instance["sys/id"].fetch()
self._run_name = self._run_instance["sys/name"].fetch()
if self._run_instance is not None:
self._run_instance.wait()
self._run_short_id = self._run_instance["sys/id"].fetch()
self._run_name = self._run_instance["sys/name"].fetch()
except NeptuneOfflineModeFetchException:
self._run_short_id = "OFFLINE"
self._run_name = "offline-name"

@property
def _neptune_init_args(self):
args = {}
def _neptune_init_args(self) -> Dict:
args: Dict = {}
# Backward compatibility in case of previous version retrieval
try:
args = self._neptune_run_kwargs
Expand All @@ -334,7 +333,7 @@ def _neptune_init_args(self):

return args

def _construct_path_with_prefix(self, *keys) -> str:
def _construct_path_with_prefix(self, *keys: str) -> str:
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
if self._prefix:
return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys])
Expand All @@ -347,7 +346,7 @@ def _verify_input_arguments(
name: Optional[str],
run: Optional["Run"],
neptune_run_kwargs: dict,
):
) -> None:
legacy_kwargs_msg = (
"Following kwargs are deprecated: {legacy_kwargs}.\n"
"If you are looking for the Neptune logger using legacy Python API,"
Expand Down Expand Up @@ -393,17 +392,17 @@ def _verify_input_arguments(
" you can't provide other neptune.init() parameters.\n"
)

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# Run instance can't be pickled
state["_run_instance"] = None
return state

def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__ = state
self._run_instance = neptune.init(**self._neptune_init_args)

@property
@property # type: ignore[misc]
@rank_zero_experiment
def experiment(self) -> Run:
r"""
Expand Down Expand Up @@ -433,15 +432,15 @@ def training_step(self, batch, batch_idx):
"""
return self.run

@property
@property # type: ignore[misc]
@rank_zero_experiment
def run(self) -> Run:
try:
if not self._run_instance:
self._run_instance = neptune.init(**self._neptune_init_args)
self._retrieve_run_data()
# make sure that we've log integration version for newly created
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
self._run_instance[_INTEGRATION_VERSION_KEY] = pl.__version__

return self._run_instance
except NeptuneLegacyProjectException as e:
Expand Down Expand Up @@ -531,7 +530,7 @@ def save_dir(self) -> Optional[str]:
return os.path.join(os.getcwd(), ".neptune")

@rank_zero_only
def log_model_summary(self, model, max_depth=-1):
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
model_str = str(ModelSummary(model=model, max_depth=max_depth))
self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
content=model_str, extension="txt"
Expand Down Expand Up @@ -600,14 +599,16 @@ def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Ch
return filepath

@classmethod
def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]:
def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]:
"""Returns all paths to properties which were already logged in `namespace`"""
structure_keys = namespace.split(cls.LOGGER_JOIN_CHAR)
uploaded_models_dict = reduce(lambda d, k: d[k], [exp_structure, *structure_keys])
structure_keys: List[str] = namespace.split(cls.LOGGER_JOIN_CHAR)
for key in structure_keys:
exp_structure = exp_structure[key]
uploaded_models_dict = exp_structure
return set(cls._dict_paths(uploaded_models_dict))

@classmethod
def _dict_paths(cls, d: dict, path_in_build: str = None) -> Generator:
def _dict_paths(cls, d: Dict[str, Any], path_in_build: str = None) -> Generator:
for k, v in d.items():
path = f"{path_in_build}/{k}" if path_in_build is not None else k
if not isinstance(v, dict):
Expand All @@ -618,6 +619,9 @@ def _dict_paths(cls, d: dict, path_in_build: str = None) -> Generator:
@property
def name(self) -> str:
"""Return the experiment name or 'offline-name' when exp is run in offline mode."""
if self._run_name is None:
_ = self.run
assert self._run_name is not None
return self._run_name

@property
Expand All @@ -626,10 +630,13 @@ def version(self) -> str:

It's Neptune Run's short_id
"""
if self._run_short_id is None:
_ = self.run
assert self._run_short_id is not None
return self._run_short_id

@staticmethod
def _signal_deprecated_api_usage(f_name, sample_code, raise_exception=False):
def _signal_deprecated_api_usage(f_name: str, sample_code: str, raise_exception: bool = False) -> None:
msg_suffix = (
f"If you are looking for the Neptune logger using legacy Python API,"
f" it's still available as part of neptune-contrib package:\n"
Expand All @@ -649,10 +656,10 @@ def _signal_deprecated_api_usage(f_name, sample_code, raise_exception=False):
raise ValueError("The function you've used is deprecated.\n" + msg_suffix)

@rank_zero_only
def log_metric(self, metric_name: str, metric_value: Union[Tensor, float, str], step: Optional[int] = None):
def log_metric(self, metric_name: str, metric_value: Union[Tensor, float, str], step: Optional[int] = None) -> None:
key = f"{self._prefix}/{metric_name}"
self._signal_deprecated_api_usage("log_metric", f"logger.run['{key}'].log(42)")
if torch.is_tensor(metric_value):
if isinstance(metric_value, Tensor):
metric_value = metric_value.cpu().detach()

self.run[key].log(metric_value, step=step)
Expand All @@ -678,12 +685,12 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
self.run[key].log(destination)

def set_property(self, *args, **kwargs):
def set_property(self, *args: Any, **kwargs: Any) -> None:
self._signal_deprecated_api_usage(
"log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
)

def append_tags(self, *args, **kwargs):
def append_tags(self, *args: Any, **kwargs: Any) -> None:
self._signal_deprecated_api_usage(
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True
)
2 changes: 2 additions & 0 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def _get_logger_with_mocks(**kwargs):
logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
run_attr_mock = MagicMock()
logger._run_instance.__getitem__.return_value = run_attr_mock
logger._run_name = "exp-name"
logger._run_short_id = "short"

return logger, run_instance_mock, run_attr_mock

Expand Down