|
1 | 1 | import os |
2 | 2 | import warnings |
3 | 3 | from abc import ABC, abstractmethod |
4 | | -from typing import Any, Dict, List, Optional, Union |
| 4 | +from typing import Any, Callable, Dict, List, Optional, Union |
5 | 5 |
|
6 | 6 | import gym |
7 | 7 | import numpy as np |
@@ -217,9 +217,10 @@ class CheckpointCallback(BaseCallback): |
217 | 217 | :param save_freq: |
218 | 218 | :param save_path: Path to the folder where the model will be saved. |
219 | 219 | :param name_prefix: Common prefix to the saved models |
| 220 | + :param verbose: |
220 | 221 | """ |
221 | 222 |
|
222 | | - def __init__(self, save_freq: int, save_path: str, name_prefix="rl_model", verbose=0): |
| 223 | + def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0): |
223 | 224 | super(CheckpointCallback, self).__init__(verbose) |
224 | 225 | self.save_freq = save_freq |
225 | 226 | self.save_path = save_path |
@@ -247,7 +248,7 @@ class ConvertCallback(BaseCallback): |
247 | 248 | :param verbose: |
248 | 249 | """ |
249 | 250 |
|
250 | | - def __init__(self, callback, verbose=0): |
| 251 | + def __init__(self, callback: Callable, verbose: int = 0): |
251 | 252 | super(ConvertCallback, self).__init__(verbose) |
252 | 253 | self.callback = callback |
253 | 254 |
|
@@ -314,7 +315,7 @@ def __init__( |
314 | 315 | self.evaluations_timesteps = [] |
315 | 316 | self.evaluations_length = [] |
316 | 317 |
|
317 | | - def _init_callback(self): |
| 318 | + def _init_callback(self) -> None: |
318 | 319 | # Does not work in some corner cases, where the wrapper is not the same |
319 | 320 | if not isinstance(self.training_env, type(self.eval_env)): |
320 | 321 | warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}") |
@@ -450,7 +451,7 @@ def __init__(self, max_episodes: int, verbose: int = 0): |
450 | 451 | self._total_max_episodes = max_episodes |
451 | 452 | self.n_episodes = 0 |
452 | 453 |
|
453 | | - def _init_callback(self): |
| 454 | + def _init_callback(self) -> None: |
454 | 455 | # At start set total max according to number of envirnments |
455 | 456 | self._total_max_episodes = self.max_episodes * self.training_env.num_envs |
456 | 457 |
|
|
0 commit comments