Skip to content

Commit f1685c1

Browse files
committed
Implement a sample Pokemon Battle Environment with server and client integration
1 parent 469a568 commit f1685c1

File tree

12 files changed

+983
-0
lines changed

12 files changed

+983
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Pokemon Battle Environment for OpenEnv.
3+
4+
This module provides OpenEnv integration for Pokemon battles via poke-env.
5+
6+
Example:
7+
>>> from envs.pokemon_env import PokemonEnv, PokemonAction
8+
>>>
9+
>>> # Connect to a running Pokemon Showdown server
10+
>>> env = PokemonEnv(battle_format="gen8randombattle")
11+
>>>
12+
>>> # Reset and interact
13+
>>> result = env.reset()
14+
>>> result = env.step(PokemonAction(action_type="move", action_index=0))
15+
>>> print(result.reward, result.done)
16+
>>>
17+
>>> # Cleanup
18+
>>> env.close()
19+
"""
20+
21+
from .client import PokemonEnv
22+
from .models import PokemonAction, PokemonObservation, PokemonState, PokemonData
23+
24+
__all__ = ["PokemonEnv", "PokemonAction", "PokemonObservation", "PokemonState", "PokemonData"]
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""
2+
Pokemon Battle Environment HTTP Client.
3+
4+
This module provides the client for connecting to a Pokemon Battle Environment server
5+
over HTTP.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from typing import Any, Dict, TYPE_CHECKING
11+
12+
from core.client_types import StepResult
13+
from core.http_env_client import HTTPEnvClient
14+
15+
from .models import PokemonAction, PokemonObservation, PokemonState, PokemonData
16+
17+
if TYPE_CHECKING:
18+
from core.containers.runtime import ContainerProvider
19+
20+
21+
class PokemonEnv(HTTPEnvClient[PokemonAction, PokemonObservation]):
22+
"""
23+
HTTP client for Pokemon Battle Environment.
24+
25+
This client connects to a Pokemon Battle Environment HTTP server and provides
26+
methods to interact with it: reset(), step(), and state access.
27+
28+
Example:
29+
>>> # Connect to a running server
30+
>>> client = PokemonEnv(base_url="http://localhost:8000")
31+
>>> result = client.reset()
32+
>>> print(result.observation.active_pokemon.species)
33+
>>>
34+
>>> # Take an action
35+
>>> result = client.step(PokemonAction(action_type="move", action_index=0))
36+
>>> print(result.reward, result.done)
37+
38+
Example with Docker:
39+
>>> # Automatically start container and connect
40+
>>> client = PokemonEnv.from_docker_image("pokemon-env:latest")
41+
>>> result = client.reset()
42+
>>> result = client.step(PokemonAction(action_type="switch", action_index=1))
43+
"""
44+
45+
def _step_payload(self, action: PokemonAction) -> Dict[str, Any]:
46+
"""
47+
Convert PokemonAction to JSON payload for step request.
48+
49+
Args:
50+
action: PokemonAction instance.
51+
52+
Returns:
53+
Dictionary representation suitable for JSON encoding.
54+
"""
55+
return {
56+
"action_type": action.action_type,
57+
"action_index": action.action_index,
58+
"move_id": action.move_id,
59+
"switch_pokemon": action.switch_pokemon,
60+
"mega_evolve": action.mega_evolve,
61+
"dynamax": action.dynamax,
62+
"terastallize": action.terastallize,
63+
}
64+
65+
def _parse_pokemon_data(self, data: Dict[str, Any]) -> PokemonData:
66+
"""Parse Pokemon data from JSON."""
67+
return PokemonData(
68+
species=data.get("species", "unknown"),
69+
hp_percent=data.get("hp_percent", 0.0),
70+
max_hp=data.get("max_hp", 100),
71+
current_hp=data.get("current_hp", 0),
72+
level=data.get("level", 50),
73+
status=data.get("status"),
74+
types=data.get("types", []),
75+
ability=data.get("ability"),
76+
item=data.get("item"),
77+
attack=data.get("attack", 0),
78+
defense=data.get("defense", 0),
79+
special_attack=data.get("special_attack", 0),
80+
special_defense=data.get("special_defense", 0),
81+
speed=data.get("speed", 0),
82+
boosts=data.get("boosts", {}),
83+
moves=data.get("moves", []),
84+
fainted=data.get("fainted", False),
85+
active=data.get("active", False),
86+
)
87+
88+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[PokemonObservation]:
89+
"""
90+
Parse server response into StepResult[PokemonObservation].
91+
92+
Args:
93+
payload: JSON response from server.
94+
95+
Returns:
96+
StepResult with PokemonObservation.
97+
"""
98+
obs_data = payload.get("observation", {})
99+
100+
active_pokemon = None
101+
if obs_data.get("active_pokemon"):
102+
active_pokemon = self._parse_pokemon_data(obs_data["active_pokemon"])
103+
104+
opponent_active = None
105+
if obs_data.get("opponent_active_pokemon"):
106+
opponent_active = self._parse_pokemon_data(obs_data["opponent_active_pokemon"])
107+
108+
team = [self._parse_pokemon_data(p) for p in obs_data.get("team", [])]
109+
opponent_team = [self._parse_pokemon_data(p) for p in obs_data.get("opponent_team", [])]
110+
111+
observation = PokemonObservation(
112+
active_pokemon=active_pokemon,
113+
opponent_active_pokemon=opponent_active,
114+
team=team,
115+
opponent_team=opponent_team,
116+
available_moves=obs_data.get("available_moves", []),
117+
available_switches=obs_data.get("available_switches", []),
118+
legal_actions=obs_data.get("legal_actions", []),
119+
field_conditions=obs_data.get("field_conditions", {}),
120+
turn=obs_data.get("turn", 0),
121+
forced_switch=obs_data.get("forced_switch", False),
122+
can_mega_evolve=obs_data.get("can_mega_evolve", False),
123+
can_dynamax=obs_data.get("can_dynamax", False),
124+
can_terastallize=obs_data.get("can_terastallize", False),
125+
battle_format=obs_data.get("battle_format", "gen8randombattle"),
126+
battle_id=obs_data.get("battle_id"),
127+
done=payload.get("done", False),
128+
reward=payload.get("reward"),
129+
metadata=obs_data.get("metadata", {}),
130+
)
131+
132+
return StepResult(
133+
observation=observation,
134+
reward=payload.get("reward"),
135+
done=payload.get("done", False),
136+
)
137+
138+
def _parse_state(self, payload: Dict[str, Any]) -> PokemonState:
139+
"""
140+
Parse server response into PokemonState object.
141+
142+
Args:
143+
payload: JSON response from /state endpoint.
144+
145+
Returns:
146+
PokemonState object with environment state information.
147+
"""
148+
return PokemonState(
149+
episode_id=payload.get("episode_id"),
150+
step_count=payload.get("step_count", 0),
151+
battle_format=payload.get("battle_format", "gen8randombattle"),
152+
player_username=payload.get("player_username", "player"),
153+
server_url=payload.get("server_url", "localhost:8000"),
154+
battle_id=payload.get("battle_id"),
155+
is_battle_finished=payload.get("is_battle_finished", False),
156+
battle_winner=payload.get("battle_winner"),
157+
)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
Data models for Pokemon Battle Environment.
3+
4+
This module defines the Action, Observation, and State types for Pokemon battles
5+
via poke-env integration.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from dataclasses import dataclass, field
11+
from typing import Any, Dict, List, Literal, Optional
12+
13+
from core.env_server import Action, Observation, State
14+
15+
16+
@dataclass
17+
class PokemonAction(Action):
18+
"""
19+
Action for Pokemon battles.
20+
21+
Attributes:
22+
action_type: Type of action - "move" or "switch"
23+
action_index: Index of the move (0-3) or switch target (0-5)
24+
move_id: Optional move identifier (e.g., "thunderbolt")
25+
switch_pokemon: Optional Pokemon to switch to (by species name or index)
26+
mega_evolve: Whether to mega evolve this turn (if applicable)
27+
dynamax: Whether to dynamax this turn (if applicable)
28+
terastallize: Whether to terastallize this turn (if applicable)
29+
"""
30+
action_type: Literal["move", "switch"] = "move"
31+
action_index: int = 0
32+
move_id: Optional[str] = None
33+
switch_pokemon: Optional[str] = None
34+
mega_evolve: bool = False
35+
dynamax: bool = False
36+
terastallize: bool = False
37+
38+
39+
@dataclass
40+
class PokemonData:
41+
"""Simplified Pokemon data for observations."""
42+
species: str
43+
hp_percent: float
44+
max_hp: int
45+
current_hp: int
46+
level: int
47+
status: Optional[str]
48+
types: List[str]
49+
ability: Optional[str]
50+
item: Optional[str]
51+
52+
attack: int
53+
defense: int
54+
special_attack: int
55+
special_defense: int
56+
speed: int
57+
58+
boosts: Dict[str, int] = field(default_factory=dict)
59+
moves: List[Dict[str, Any]] = field(default_factory=list)
60+
61+
fainted: bool = False
62+
active: bool = False
63+
64+
65+
@dataclass
66+
class PokemonObservation(Observation):
67+
"""
68+
Observation from Pokemon battle environment.
69+
70+
This represents the full battle state visible to the agent.
71+
72+
Attributes:
73+
active_pokemon: Currently active Pokemon on your side
74+
opponent_active_pokemon: Currently active opponent Pokemon
75+
team: Your full team of 6 Pokemon
76+
opponent_team: Opponent's team (may have limited visibility)
77+
available_moves: List of move indices you can use (0-3)
78+
available_switches: List of Pokemon indices you can switch to (0-5)
79+
legal_actions: Combined list of legal action descriptors
80+
field_conditions: Dict of field effects (weather, terrain, hazards, etc.)
81+
turn: Current turn number
82+
forced_switch: Whether you must switch (active Pokemon fainted)
83+
can_mega_evolve: Whether mega evolution is possible this turn
84+
can_dynamax: Whether dynamax is possible this turn
85+
can_terastallize: Whether terastallization is possible this turn
86+
battle_format: Battle format (e.g., "gen8randombattle", "gen8ou")
87+
"""
88+
active_pokemon: Optional[PokemonData] = None
89+
opponent_active_pokemon: Optional[PokemonData] = None
90+
team: List[PokemonData] = field(default_factory=list)
91+
opponent_team: List[PokemonData] = field(default_factory=list)
92+
93+
available_moves: List[int] = field(default_factory=list)
94+
available_switches: List[int] = field(default_factory=list)
95+
legal_actions: List[Dict[str, Any]] = field(default_factory=list)
96+
97+
field_conditions: Dict[str, Any] = field(default_factory=dict)
98+
turn: int = 0
99+
forced_switch: bool = False
100+
101+
can_mega_evolve: bool = False
102+
can_dynamax: bool = False
103+
can_terastallize: bool = False
104+
105+
battle_format: str = "gen8randombattle"
106+
battle_id: Optional[str] = None
107+
108+
109+
@dataclass
110+
class PokemonState(State):
111+
"""
112+
State for Pokemon battle environment.
113+
114+
Attributes:
115+
battle_format: Battle format being used
116+
player_username: Player's username
117+
server_url: Pokemon Showdown server URL
118+
battle_id: Current battle ID
119+
is_battle_finished: Whether the battle has concluded
120+
battle_winner: Winner of the battle (if finished)
121+
"""
122+
battle_format: str = "gen8randombattle"
123+
player_username: str = "player"
124+
server_url: str = "localhost:8000"
125+
battle_id: Optional[str] = None
126+
is_battle_finished: bool = False
127+
battle_winner: Optional[str] = None

0 commit comments

Comments
 (0)