Skip to content

Commit a1e0556

Browse files
authored
Improve typing coverage (#175)
* Improve typing coverage * Even more types * Fixes * Update changelog * Unified docstrings * Improve error messages for unsupported spaces
1 parent a10e3ae commit a1e0556

25 files changed

+206
-138
lines changed

docs/misc/changelog.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,31 @@
33
Changelog
44
==========
55

6+
7+
Pre-Release 0.10.0a0 (WIP)
8+
------------------------------
9+
10+
Breaking Changes:
11+
^^^^^^^^^^^^^^^^^
12+
13+
New Features:
14+
^^^^^^^^^^^^^
15+
16+
Bug Fixes:
17+
^^^^^^^^^^
18+
19+
Deprecations:
20+
^^^^^^^^^^^^^
21+
22+
Others:
23+
^^^^^^^
24+
- Improved typing coverage
25+
- Improved error messages for unsupported spaces
26+
27+
Documentation:
28+
^^^^^^^^^^^^^^
29+
30+
631
Pre-Release 0.9.0 (2020-10-03)
732
------------------------------
833

stable_baselines3/common/atari_wrappers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
except ImportError:
1010
cv2 = None
1111

12-
from stable_baselines3.common.type_aliases import GymStepReturn
12+
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
1313

1414

1515
class NoopResetEnv(gym.Wrapper):
@@ -146,7 +146,7 @@ def step(self, action: int) -> GymStepReturn:
146146

147147
return max_frame, total_reward, done, info
148148

149-
def reset(self, **kwargs):
149+
def reset(self, **kwargs) -> GymObs:
150150
return self.env.reset(**kwargs)
151151

152152

stable_baselines3/common/base_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def set_parameters(
478478
load_path_or_dict: Union[str, Dict[str, Dict]],
479479
exact_match: bool = True,
480480
device: Union[th.device, str] = "auto",
481-
):
481+
) -> None:
482482
"""
483483
Load parameters from a given zip-file or a nested dictionary containing parameters for
484484
different modules (see ``get_parameters``).
@@ -610,7 +610,7 @@ def load(
610610
model.policy.reset_noise() # pytype: disable=attribute-error
611611
return model
612612

613-
def get_parameters(self):
613+
def get_parameters(self) -> Dict[str, Dict]:
614614
"""
615615
Return the parameters of the agent. This includes parameters from different networks, e.g.
616616
critics (value functions) and policies (pi functions).

stable_baselines3/common/buffers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from abc import ABC, abstractmethod
23
from typing import Generator, Optional, Union
34

45
import numpy as np
@@ -16,7 +17,7 @@
1617
from stable_baselines3.common.vec_env import VecNormalize
1718

1819

19-
class BaseBuffer(object):
20+
class BaseBuffer(ABC):
2021
"""
2122
Base class that represent a buffer (rollout or replay)
2223
@@ -102,7 +103,10 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
102103
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
103104
return self._get_samples(batch_inds, env=env)
104105

105-
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None):
106+
@abstractmethod
107+
def _get_samples(
108+
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
109+
) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
106110
"""
107111
:param batch_inds:
108112
:param env:

stable_baselines3/common/callbacks.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import warnings
33
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, List, Optional, Union
4+
from typing import Any, Callable, Dict, List, Optional, Union
55

66
import gym
77
import numpy as np
@@ -217,9 +217,10 @@ class CheckpointCallback(BaseCallback):
217217
:param save_freq:
218218
:param save_path: Path to the folder where the model will be saved.
219219
:param name_prefix: Common prefix to the saved models
220+
:param verbose:
220221
"""
221222

222-
def __init__(self, save_freq: int, save_path: str, name_prefix="rl_model", verbose=0):
223+
def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
223224
super(CheckpointCallback, self).__init__(verbose)
224225
self.save_freq = save_freq
225226
self.save_path = save_path
@@ -247,7 +248,7 @@ class ConvertCallback(BaseCallback):
247248
:param verbose:
248249
"""
249250

250-
def __init__(self, callback, verbose=0):
251+
def __init__(self, callback: Callable, verbose: int = 0):
251252
super(ConvertCallback, self).__init__(verbose)
252253
self.callback = callback
253254

@@ -314,7 +315,7 @@ def __init__(
314315
self.evaluations_timesteps = []
315316
self.evaluations_length = []
316317

317-
def _init_callback(self):
318+
def _init_callback(self) -> None:
318319
# Does not work in some corner cases, where the wrapper is not the same
319320
if not isinstance(self.training_env, type(self.eval_env)):
320321
warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}")
@@ -450,7 +451,7 @@ def __init__(self, max_episodes: int, verbose: int = 0):
450451
self._total_max_episodes = max_episodes
451452
self.n_episodes = 0
452453

453-
def _init_callback(self):
454+
def _init_callback(self) -> None:
454455
# At start set total max according to number of envirnments
455456
self._total_max_episodes = self.max_episodes * self.training_env.num_envs
456457

stable_baselines3/common/cmd_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from stable_baselines3.common.atari_wrappers import AtariWrapper
88
from stable_baselines3.common.monitor import Monitor
9-
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
9+
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
1010

1111

1212
def make_vec_env(
@@ -19,7 +19,7 @@ def make_vec_env(
1919
env_kwargs: Optional[Dict[str, Any]] = None,
2020
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
2121
vec_env_kwargs: Optional[Dict[str, Any]] = None,
22-
):
22+
) -> VecEnv:
2323
"""
2424
Create a wrapped, monitored ``VecEnv``.
2525
By default it uses a ``DummyVecEnv`` which is usually faster
@@ -85,7 +85,7 @@ def make_atari_env(
8585
env_kwargs: Optional[Dict[str, Any]] = None,
8686
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
8787
vec_env_kwargs: Optional[Dict[str, Any]] = None,
88-
):
88+
) -> VecEnv:
8989
"""
9090
Create a wrapped, monitored VecEnv for Atari.
9191
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.

stable_baselines3/common/distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Probability distributions."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Dict, List, Optional, Tuple, Union
55

66
import gym
77
import torch as th
@@ -19,7 +19,7 @@ def __init__(self):
1919
super(Distribution, self).__init__()
2020

2121
@abstractmethod
22-
def proba_distribution_net(self, *args, **kwargs):
22+
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
2323
"""Create the layers and parameters that represent the distribution.
2424
2525
Subclasses must define this, but the arguments and return type vary between

stable_baselines3/common/logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class SeqWriter(object):
5050
sequence writer
5151
"""
5252

53-
def write_sequence(self, sequence: List):
53+
def write_sequence(self, sequence: List) -> None:
5454
"""
5555
write_sequence an array to file
5656

stable_baselines3/common/monitor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import os
66
import time
77
from glob import glob
8-
from typing import Any, Dict, List, Optional, Tuple
8+
from typing import List, Optional, Tuple, Union
99

1010
import gym
1111
import numpy as np
1212
import pandas
1313

14+
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
15+
1416

1517
class Monitor(gym.Wrapper):
1618
"""
@@ -62,7 +64,7 @@ def __init__(
6264
self.total_steps = 0
6365
self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
6466

65-
def reset(self, **kwargs) -> np.ndarray:
67+
def reset(self, **kwargs) -> GymObs:
6668
"""
6769
Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
6870
@@ -83,7 +85,7 @@ def reset(self, **kwargs) -> np.ndarray:
8385
self.current_reset_info[key] = value
8486
return self.env.reset(**kwargs)
8587

86-
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]:
88+
def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
8789
"""
8890
Step the environment with the given action
8991
@@ -112,7 +114,7 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, A
112114
self.total_steps += 1
113115
return observation, reward, done, info
114116

115-
def close(self):
117+
def close(self) -> None:
116118
"""
117119
Closes the environment
118120
"""

stable_baselines3/common/noise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def base_noise(self) -> ActionNoise:
139139
return self._base_noise
140140

141141
@base_noise.setter
142-
def base_noise(self, base_noise: ActionNoise):
142+
def base_noise(self, base_noise: ActionNoise) -> None:
143143
if base_noise is None:
144144
raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
145145
if not isinstance(base_noise, ActionNoise):

0 commit comments

Comments
 (0)