This repository was archived by the owner on Oct 7, 2024. It is now read-only.
2 changes: 1 addition & 1 deletion bsuite/baselines/jax/actor_critic/agent.py
@@ -56,7 +56,7 @@ def __init__(
     # Define loss function.
     def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
       """"Actor-critic loss."""
-      logits, values = network(trajectory.observations)
+      logits, values = network(trajectory.observations)  # pytype: disable=wrong-arg-types  # jax-ndarray
       td_errors = rlax.td_lambda(
           v_tm1=values[:-1],
           r_t=trajectory.rewards,
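For context, the comment added here is an inline pytype suppression that silences a single error class on the line it annotates; the trailing # jax-ndarray appears to be a bookkeeping tag rather than a second directive. A minimal standalone sketch (not part of this diff, using a hypothetical function) of how such a suppression behaves:

def takes_int(x: int) -> int:
  """Toy function with an int-only annotation."""
  return x + 1

value: float = 2.0
# pytype reports wrong-arg-types for passing a float here; the inline directive
# silences exactly that error class on this line, while the call still runs at runtime.
result = takes_int(value)  # pytype: disable=wrong-arg-types
print(result)  # 3.0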
4 changes: 2 additions & 2 deletions bsuite/baselines/utils/sequence_test.py
@@ -31,8 +31,8 @@ def test_buffer(self):
     max_sequence_length = 10
     obs_shape = (3, 3)
     buffer = sequence.Buffer(
-        obs_spec=specs.Array(obs_shape, dtype=np.float),
-        action_spec=specs.Array((), dtype=np.int),
+        obs_spec=specs.Array(obs_shape, dtype=float),
+        action_spec=specs.Array((), dtype=int),
         max_sequence_length=max_sequence_length)
     dummy_step = dm_env.transition(observation=np.zeros(obs_shape), reward=0.)

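For context on the dtype change: np.float and np.int were deprecated aliases for the Python builtins (deprecated in NumPy 1.20, removed in NumPy 1.24), so specs built with them fail on recent NumPy. A minimal standalone sketch (not part of this diff) of the replacement, assuming dm_env is installed:

from dm_env import specs

# Builtin float / int are what the removed np.float / np.int aliases pointed to;
# np.float64 / np.int64 also work if an explicit width is preferred.
obs_spec = specs.Array(shape=(3, 3), dtype=float)
action_spec = specs.Array(shape=(), dtype=int)

print(obs_spec.dtype)     # float64
print(action_spec.dtype)  # int64 on most platforms (maps to the C long, so platform-dependent)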
4 changes: 2 additions & 2 deletions bsuite/environments/cartpole.py
@@ -159,10 +159,10 @@ def _reset(self) -> dm_env.TimeStep:
     raise NotImplementedError('This environment implements its own auto-reset.')

   def action_spec(self):
-    return specs.DiscreteArray(dtype=np.int, num_values=3, name='action')
+    return specs.DiscreteArray(dtype=int, num_values=3, name='action')

   def observation_spec(self):
-    return specs.Array(shape=(1, 6), dtype=np.float32, name='state')
+    return specs.Array(shape=(1, 6), dtype=np.float32, name='observation')

   @property
   def observation(self) -> np.ndarray:
4 changes: 2 additions & 2 deletions bsuite/environments/catch.py
@@ -99,12 +99,12 @@ def _step(self, action: int) -> dm_env.TimeStep:
   def observation_spec(self) -> specs.BoundedArray:
     """Returns the observation spec."""
     return specs.BoundedArray(shape=self._board.shape, dtype=self._board.dtype,
-                              name="board", minimum=0, maximum=1)
+                              name="observation", minimum=0, maximum=1)

   def action_spec(self) -> specs.DiscreteArray:
     """Returns the action spec."""
     return specs.DiscreteArray(
-        dtype=np.int, num_values=len(_ACTIONS), name="action")
+        dtype=int, num_values=len(_ACTIONS), name="action")

   def _observation(self) -> np.ndarray:
     self._board.fill(0.)
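The renames in this file and in cartpole.py make the observation specs report name="observation" instead of ad-hoc names like "board" or "state". A minimal standalone sketch (not part of this diff; the board shape is illustrative) of what the updated Catch spec describes:

import numpy as np
from dm_env import specs

# Illustrative board; Catch's real shape comes from its constructor arguments.
board = np.zeros((10, 5), dtype=np.float32)

obs_spec = specs.BoundedArray(
    shape=board.shape, dtype=board.dtype,
    name="observation", minimum=0, maximum=1)

obs_spec.validate(board)  # raises ValueError if any entry falls outside [0, 1]
print(obs_spec.name)      # observation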
2 changes: 1 addition & 1 deletion bsuite/experiments/cartpole_swingup/cartpole_swingup.py
@@ -129,7 +129,7 @@ def _reset(self) -> dm_env.TimeStep:
     raise NotImplementedError('This environment implements its own auto-reset.')

   def action_spec(self):
-    return specs.DiscreteArray(dtype=np.int, num_values=3, name='action')
+    return specs.DiscreteArray(dtype=int, num_values=3, name='action')

   def observation_spec(self):
     return specs.Array(shape=(1, 8), dtype=np.float32, name='state')
2 changes: 1 addition & 1 deletion bsuite/experiments/deep_sea/analysis.py
@@ -37,7 +37,7 @@ def _check_data(df: pd.DataFrame) -> None:
 def find_solution(df_in: pd.DataFrame,
                   sweep_vars: Optional[Sequence[str]] = None,
                   merge: bool = True,
-                  thresh: float = 0.8,
+                  thresh: float = 0.9,
                   num_episodes: int = NUM_EPISODES) -> pd.DataFrame:
   """Find first episode that gets below thresh regret by sweep_vars."""
   # Check data has the necessary columns for deep sea
4 changes: 2 additions & 2 deletions bsuite/logging/logging_utils.py
@@ -15,7 +15,7 @@
 # ============================================================================
 """Read functionality for local csv-based experiments."""

-import collections
+from collections import abc
 import copy
 from typing import Any, Callable, List, Mapping, Sequence, Tuple, Union

@@ -72,7 +72,7 @@ def load_multiple_runs(
   # Convert any inputs to dictionary format.
   if isinstance(path_collection, six.string_types):
     path_collection = {path_collection: path_collection}
-  if not isinstance(path_collection, collections.Mapping):
+  if not isinstance(path_collection, abc.Mapping):
     path_collection = {path: path for path in path_collection}

   # Loop through multiple bsuite runs, and apply single_load_fn to each.
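For context on the import change: the collection ABCs were only re-exported from the top-level collections module for backwards compatibility, and that alias was removed in Python 3.10, so the isinstance check must go through collections.abc. A standalone sketch (not part of this diff; normalize_paths is a hypothetical helper mirroring the logic above):

from collections import abc
from typing import Dict, Mapping, Sequence, Union

def normalize_paths(
    path_collection: Union[str, Mapping[str, str], Sequence[str]]) -> Dict[str, str]:
  """Coerces a string, mapping, or sequence of paths into a name -> path dict."""
  if isinstance(path_collection, str):
    return {path_collection: path_collection}
  if not isinstance(path_collection, abc.Mapping):
    return {path: path for path in path_collection}
  return dict(path_collection)

print(normalize_paths('/tmp/run_a'))
print(normalize_paths(['/tmp/run_a', '/tmp/run_b']))
print(normalize_paths({'baseline': '/tmp/run_a'}))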
5 changes: 3 additions & 2 deletions bsuite/utils/gym_wrapper.py
@@ -97,8 +97,9 @@ def reward_range(self) -> Tuple[float, float]:

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 def space2spec(space: gym.Space, name: Optional[str] = None):
   """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs.
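The guard added here (and in the four wrappers.py methods below) prevents unbounded recursion when _env has not been set yet: __getattr__ only runs after normal lookup fails, and operations such as copy or pickle create the instance without calling __init__, so looking up self._env would re-enter __getattr__ indefinitely. A standalone sketch (not part of this diff, using stand-in Env and Wrapper classes) of the pattern:

import copy

class Env:
  """Stand-in for a wrapped environment."""

  def reset(self):
    return 'first timestep'

class Wrapper:
  """Delegates unknown attributes to the wrapped environment, safely."""

  def __init__(self, env: Env):
    self._env = env

  def __getattr__(self, attr):
    # Only reached when normal attribute lookup fails. If "_env" itself is
    # missing (e.g. copy/pickle built the instance without running __init__),
    # delegating would trigger __getattr__ again; fall back to the default
    # lookup so a plain AttributeError is raised instead.
    if "_env" in self.__dict__:
      return getattr(self._env, attr)
    return super().__getattribute__(attr)

wrapped = Wrapper(Env())
print(wrapped.reset())              # delegated to Env.reset
print(copy.copy(wrapped).reset())   # no RecursionError with the guard in place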
19 changes: 12 additions & 7 deletions bsuite/utils/wrappers.py
@@ -134,8 +134,9 @@ def raw_env(self):

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 def _logarithmic_logging(episode: int,
                          ratios: Optional[Sequence[float]] = None) -> bool:
@@ -173,8 +174,9 @@ def step(self, action):

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 def _small_state_to_image(shape: Sequence[int],
                           observation: np.ndarray) -> np.ndarray:
@@ -307,8 +309,9 @@ def bsuite_info(self) -> Dict[str, Any]:

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
-
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)

 class RewardScale(environments.Environment):
   """Reward Scale environment wrapper."""
@@ -370,4 +373,6 @@ def bsuite_info(self) -> Dict[str, Any]:

   def __getattr__(self, attr):
     """Delegate attribute access to underlying environment."""
-    return getattr(self._env, attr)
+    if "_env" in self.__dict__:
+      return getattr(self._env, attr)
+    return super().__getattribute__(attr)