diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..d099bb0 Binary files /dev/null and b/.DS_Store differ diff --git a/llm/README.md b/llm/README.md new file mode 100644 index 0000000..43db7b6 --- /dev/null +++ b/llm/README.md @@ -0,0 +1,245 @@ +# Terra: Hybrid LLM and RL Approach + +A sophisticated framework combining Large Language Models (LLMs) and Reinforcement Learning (RL) for intelligent map exploration and excavation tasks. + +## πŸ“‹ Table of Contents + +- [Overview](#overview) +- [Architecture](#architecture) +- [Installation](#installation) +- [Configuration](#configuration) +- [Usage](#usage) + - [Basic Usage](#basic-usage) + - [Running on Clusters](#running-on-clusters) +- [Supported Models](#supported-models) +- [Level Index Reference](#level-index-reference) +- [Prompts Documentation](#prompts-documentation) +- [Project Structure](#project-structure) +- [Contributing](#contributing) + +## 🎯 Overview + +This module implements a hybrid approach that leverages both LLMs and RL policies to efficiently process excavation tasks across partitioned maps. The system intelligently delegates between fast RL policies and more sophisticated LLM-based decision making based on the complexity of each partition. + +## πŸ—οΈ Architecture + +![Hybrid VLM Architecture](assets/VLM_Schema.png) + +The system operates in four main phases: + +### 1. **Map Partitioning** +The map is divided into manageable sections using one of three methods: +- **Manual**: Direct specification of partition boundaries +- **Random**: Automated partitioning with configurable constraints + - Minimum width/height (percentage-based) + - Minimum target count per partition +- **LLM-based**: Intelligent partitioning using language models + +### 2. **Partition Processing** +Each partition is managed by a master LLM agent that decides between: +- **RL Policy** (`delegate_to_RL`): Fast, pre-trained policy for routine tasks +- **LLM Policy** (`delegate_to_LLM`): Sophisticated decision-making for complex scenarios + +### 3. **Synchronization** +- Global map updates after partition completion +- Cross-partition information synchronization +- Periodic re-evaluation of delegation strategy + +### 4. **Iteration** +- Process continues until map completion +- Automatic progression to next map with fresh partitioning + +## πŸš€ Installation + +### Prerequisites + +1. Ensure Terra and Terra baselines are installed +2. Install additional dependencies from `environment_llm.yaml`: + ```bash + conda env create -f environment_llm.yaml + conda activate terra-llm + ``` + +### API Keys Setup + +Export the API keys for your chosen model providers: + +```bash +# Google Models (Gemini) +export GOOGLE_API_KEY="your-api-key-here" + +# OpenAI Models (GPT, o3) +export OPENAI_API_KEY="your-api-key-here" + +# Anthropic Models (Claude) +export ANTHROPIC_API_KEY="your-api-key-here" +``` + +## βš™οΈ Configuration + +The main configuration file is [`config_llm.yaml`](config_llm.yaml). Key parameters include: + +- Partitioning strategy and constraints +- Model selection and API settings +- RL policy paths +- Iteration and synchronization intervals +- Logging and debugging options + +Prompts can be customized by modifying files in the [`prompts`](prompts) folder. 
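+As a concrete illustration of the delegation loop from the [Architecture](#architecture) section, here is a minimal sketch. Apart from the `delegate_to_RL`/`delegate_to_LLM` actions and the `llm_call_frequency` setting from `config_llm.yaml`, every name below (the environment API, `decide`, `act`) is hypothetical:
+
+```python
+def run_partition(partition_env, rl_policy, master_llm, llm_call_frequency=15):
+    """Hypothetical per-partition loop: the master LLM periodically picks a delegate."""
+    obs = partition_env.reset()
+    choice = "delegate_to_RL"
+    for step in range(partition_env.max_steps):
+        if step % llm_call_frequency == 0:
+            # Periodic re-evaluation of the delegation strategy
+            choice = master_llm.decide(obs)  # "delegate_to_RL" or "delegate_to_LLM"
+        if choice == "delegate_to_RL":
+            action = rl_policy(obs)          # fast, pre-trained policy
+        else:
+            action = master_llm.act(obs)     # sophisticated LLM decision making
+        obs, done = partition_env.step(action)
+        if done:
+            break                            # partition complete: sync and move on
+```
+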
+
+## πŸ“– Usage
+
+### Basic Usage
+
+Run the main script with the following command:
+
+```bash
+DATASET_PATH=<dataset_path> DATASET_SIZE=<dataset_size> python -m llm.main_llm \
+    --model_name <model_name> \
+    --model_key <model_key> \
+    --num_timesteps <num_timesteps> \
+    -s <seed> \
+    -n <num_envs> \
+    -run <rl_policy_path> \
+    --level_index <level_index>
+```
+
+#### Parameters
+
+| Parameter | Description |
+|-----------|-------------|
+| `DATASET_PATH` | Path to the Terra-generated map dataset |
+| `DATASET_SIZE` | Number of maps in the dataset |
+| `--model_name` | LLM model identifier (see [Supported Models](#supported-models)) |
+| `--model_key` | Provider key: `gpt`, `gemini`, or `claude` |
+| `--num_timesteps` | Maximum steps per episode |
+| `-s` | Random seed for reproducibility |
+| `-n` | Number of parallel environments |
+| `-run` | Path to pre-trained RL policy |
+| `--level_index` | Map difficulty level (see [Level Index](#level-index-reference)) |
+
+### Running Big Maps
+
+After generating big maps (currently tested only with 128x128), you can run the code exactly as described in the previous section.
+
+> [!WARNING]
+> Make sure to set the new map size in the `config_llm.yaml` file and adapt the visualization in Terra's `terra/viz/game/setting.py`. A value of 384 works for 128x128 maps (2x the value used for 64x64).
+
+### Running on Clusters
+
+Ensure API keys are properly configured in your cluster environment. Consult your cluster's [documentation](https://scicomp.ethz.ch/wiki/Main_Page) for specific setup instructions.
+
+For SLURM-based clusters (e.g., ETH ZΓΌrich Euler):
+
+```bash
+sbatch run_levels.slurm
+```
+
+Adapt the script's parameters to your actual values. The script runs the evaluation on a separate node for each level (see the [Level Index Reference](#level-index-reference) below).
+
+For random partitioning, where the best partition (highest coverage) is selected among all trials, use
+
+```bash
+sbatch run_levels_random.slurm
+```
+
+where `--n_maps` is the number of test maps and `--n_partitions_per_map` is the number of random-partition trials per map.
+
+Note that this script requires running each level independently, and you must also change the `config_llm.yaml` file to use random partitions.
+
+To aggregate the results into the final benchmarks, use the `result_aggregator.py` script.
+
+
+## πŸ€– Supported Models
+
+The framework supports multiple LLM providers through [LiteLLM](https://docs.litellm.ai/docs/):
+
+### OpenAI
+- `gpt-4o`
+- `gpt-4.1`
+- `gpt-5`
+- `o4-mini`
+- `o3`
+- `o3-mini`
+
+### Google
+- `gemini-1.5-flash-latest`
+- `gemini-2.0-flash`
+- `gemini-2.5-pro`
+- `gemini-2.5-flash`
+
+### Anthropic
+- `claude-3-haiku-20240307`
+- `claude-3-7-sonnet-20250219`
+- `claude-opus-4-20250514`
+- `claude-sonnet-4-20250514`
+
+For the latest supported models, refer to the [LiteLLM providers documentation](https://docs.litellm.ai/docs/providers).
+
+## πŸ“Š Level Index Reference
+
+| Level Name | Index | Description |
+|------------|-------|-------------|
+| All levels | None | Run all available levels |
+| Foundations | 0 | Basic excavation tasks |
+| Single Trenches | 1 | Simple linear excavations |
+| Double Trenches | 2 | Parallel excavation paths |
+| Double Diagonal | 3 | Angled parallel paths |
+| Triple Trenches | 4 | Complex parallel structures |
+| Triple Diagonal | 5 | Advanced angled patterns |
+
+## πŸ“ Prompts Documentation
+
+The system uses three types of specialized prompts:
+
+### 1. 
**Partitioning Agent**
+- [`partitioning.txt`](prompts/partitioning.txt): Standard adaptive partitioning
+- [`partitioning_exact.txt`](prompts/partitioning_exact.txt): Fixed excavator count (experimental)
+
+### 2. **Delegation Agent**
+- [`delegation_no_intervention.txt`](prompts/delegation_no_intervention.txt): Production-ready autonomous delegation
+- [`delegation.txt`](prompts/delegation.txt): Experimental intervention mode (not fully tested)
+
+### 3. **Excavator Agent**
+- [`excavator_llm_simple.txt`](prompts/excavator_llm_simple.txt): System prompt for LLM excavator control
+- [`excavator_action.txt`](prompts/excavator_action.txt): Context-aware status updates
+
+## πŸ“ Project Structure
+
+```
+llm/
+β”œβ”€β”€ assets/                  # Media and documentation assets
+β”œβ”€β”€ prompts/                 # Customizable prompt templates
+β”‚   β”œβ”€β”€ delegation_no_intervention.txt
+β”‚   β”œβ”€β”€ delegation.txt
+β”‚   β”œβ”€β”€ excavator_action.txt
+β”‚   β”œβ”€β”€ excavator_llm_simple.txt
+β”‚   β”œβ”€β”€ partitioning_exact.txt
+β”‚   └── partitioning.txt
+β”œβ”€β”€ __init__.py
+β”œβ”€β”€ config_llm.yaml          # Main configuration file
+β”œβ”€β”€ env_llm.py               # Individual environment management
+β”œβ”€β”€ env_manager_llm.py       # Global environment orchestration
+β”œβ”€β”€ eval_llm.py              # Benchmarking utilities
+β”œβ”€β”€ main_llm.py              # Entry point
+β”œβ”€β”€ prompt_manager_llm.py    # Prompt loading and management
+β”œβ”€β”€ session_manager_llm.py   # LLM agent lifecycle management
+└── utils_llm.py             # Helper functions and utilities
+```
+
+## 🀝 Contributing
+
+Found a bug or have a feature request? Please open an issue and tag @gioelemo. We welcome contributions that improve the framework's capabilities or documentation.
+
+### Development Guidelines
+
+1. Follow existing code style and conventions
+2. Add tests for new functionality
+3. Update documentation as needed
+4. Submit pull requests with clear descriptions
+
+---
+
+**Note**: This is an active research project. Performance may vary based on model selection and task complexity.
\ No newline at end of file
diff --git a/llm/__init__.py b/llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm/adk_llm.py b/llm/adk_llm.py
new file mode 100644
index 0000000..1445931
--- /dev/null
+++ b/llm/adk_llm.py
@@ -0,0 +1,399 @@
+from google import genai
+from google.genai import types
+
+import base64
+import json
+import cv2
+import re
+import os
+import logging
+import asyncio
+from typing import Optional, Dict, Any, List, Tuple
+
+# Set up logging
+logger = logging.getLogger("AutonomousExcavatorADK.llms")
+
+
+class LLM_query:
+    def __init__(self, model_name=None, model=None, system_message=None, action_size=None, runner=None, user_id=None, session_id=None):
+        """
+        Initialize the agent with the specified model and ADK runner. 
+
+        Args:
+            model_name: The specific model name to use
+            model: The model provider key ('gpt', 'claude', 'gemini')
+            system_message: The system prompt to use
+            action_size: Size of the environment's action space
+            runner: The ADK runner used to execute the agent
+            user_id: User ID for the ADK session
+            session_id: Session ID for the ADK session
+        """
+        self.model_key = model
+        logger.info(f'Model Key: {self.model_key}')
+
+        self.model_name = model_name
+        logger.info(f'Model Name: {self.model_name}')
+
+        self.messages = []
+        self.system_message = system_message
+        self.action_size = action_size
+        self.action_space = action_size
+        self.reset_count = 0
+        self.runner = runner
+        self.user_id = user_id
+        self.session_id = session_id
+
+    def encode_image(self, cv_image):
+        _, buffer = cv2.imencode(".jpg", cv_image)
+        return base64.b64encode(buffer).decode("utf-8")
+
+    async def query_LLM(self):
+        # TODO: Optimize the length of the messages sent to the model
+
+        response_text = ""
+
+        message = self.messages[-1]
+        async for event in self.runner.run_async(user_id=self.user_id, session_id=self.session_id, new_message=message):
+            print(f"  [Event] Author: {event.author}, Type: {type(event).__name__}, Final: {event.is_final_response()}")
+            if event.is_final_response():
+                if event.content and event.content.parts:
+                    response_text = event.content.parts[0].text
+                elif event.actions and event.actions.escalate:
+                    response_text = f"Agent escalated: {event.error_message or 'No specific message'}"
+                break
+            else:
+                if event.content and event.content.parts and event.content.parts[0].text:
+                    response_text += event.content.parts[0].text
+
+        self.response = response_text
+        self.reset_count = 0
+
+        # Return the output of the model
+        return self.response
+
+    def reset_model(self):
+
+        if self.reset_count >= 3:
+            return
+
+        self.messages = []
+
+        self.reset_count += 1
+
+        print('Model is re-initialized...')
+
+    def clean_model_output(self, output):
+        """
+        Clean the model output to ensure it's valid JSON. 
+ + Args: + output: The raw output from the model + + Returns: + Cleaned output string + """ + if not output: + logger.warning("Received empty output from model") + return "" + + # Remove any unescaped newline characters within the JSON string values + cleaned_output = re.sub(r'(?= 3: + return None + + if response == 'idchoicescreatedmodelobjectsystem_fingerprintusage': + response = self.get_response() + + return response + + + def generate_response(self, path) -> str: + response = self.get_response() + # Check if it is just reasoning or actual action output + self.path = path + + response_text = self.clean_response(response, path) + + action_output = self.check_action(response_text) + + return action_output, response_text + + def add_user_message(self, frame=None, user_msg=None, local_map=None, traversability_map=None): + + if self.model_key == 'gemini' or self.model_key == 'claude' or self.model_key == 'gpt': + if frame is not None and user_msg is not None and traversability_map is not None and local_map is not None: + image_data = self.encode_image(frame) + image_data_traversability = self.encode_image(traversability_map) + image_data_local_map = self.encode_image(local_map) + + # Create the list of Part objects + parts = [ + types.Part.from_bytes(data=image_data, mime_type="image/jpeg"), # Create image Part from bytes/base64 + types.Part.from_bytes(data=image_data_local_map, mime_type="image/jpeg"), + types.Part.from_bytes(data=image_data_traversability, mime_type="image/jpeg"), + types.Part.from_text(text=user_msg) # Create text Part from string + ] + + # Create a Content object with the role and the list of parts + user_content = types.Content(role="user", parts=parts) + + # Append the Content object to your messages list + self.messages.append(user_content) + + elif frame is not None and user_msg is not None and traversability_map is None: + image_data = self.encode_image(frame) + + # Create the list of Part objects + parts = [ + types.Part.from_bytes(data=image_data, mime_type="image/jpeg"), # Create image Part from bytes/base64 + types.Part.from_text(text=user_msg) # Create text Part from string + ] + + # Create a Content object with the role and the list of parts + user_content = types.Content(role="user", parts=parts) + + # Append the Content object to your messages list + self.messages.append(user_content) + + elif frame is not None and user_msg is None: + image_data = self.encode_image(frame) + + # Create the list of Part objects + parts = [ + types.Part.from_bytes(data=image_data, mime_type="image/jpeg"), # Create image Part from bytes/base64 + ] + + # Create a Content object with the role and the list of parts + user_content = types.Content(role="user", parts=parts) + + # Append the Content object to your messages list + self.messages.append(user_content) + + elif frame is None and user_msg is not None: + user_content = types.Content( + role="user", + parts=[types.Part.from_text(text=user_msg)] + ) + self.messages.append(user_content) + else: + pass + + # Ensure self.messages only contains types.Content objects + self.messages = [msg for msg in self.messages if isinstance(msg, types.Content)] + + def add_assistant_message(self, demo_str=None): + + if self.model_key =='gemini' or self.model_key == 'claude' or self.model_key == 'gpt': + if demo_str is not None: + self.messages.append( + { + "role": "model", + "parts": types.Part.from_text(text=demo_str), + } + ) + demo_str = None + return + + if self.response is not None: + #assistant_msg = self.response.text + assistant_msg = 
self.response
+                self.messages.append(
+                    types.Content(
+                        role="model",
+                        parts=[types.Part.from_text(text=assistant_msg)],
+                    )
+                )
+
+        else:
+            # Use types.Content here as well: add_user_message filters the history
+            # down to types.Content instances, so plain dicts would be dropped.
+            self.messages.append(
+                types.Content(
+                    role="model",
+                    parts=[types.Part.from_text(text=' ')],
+                )
+            )
+
+    def delete_messages(self):
+        print('Deleting Set of Messages...')
+
+        if self.model_key == 'gemini' or self.model_key == 'claude' or self.model_key == 'gpt':
+            # Check if the first message is a system message (using .role attribute)
+            if self.messages and isinstance(self.messages[0], types.Content) and self.messages[0].role == 'system':
+                print("Removing oldest user/model pair after system message.")
+                # Delete the oldest user message (at index 1 after system)
+                # and the oldest model message (at index 2 after system)
+                # Remove higher index first
+                if len(self.messages) > 2:  # Ensure there are at least 3 messages (system, user, model)
+                    self.messages.pop(2)  # Remove model
+                    self.messages.pop(1)  # Remove user
+                elif len(self.messages) > 1:  # If only system and user
+                    self.messages.pop(1)  # Remove user
+                # If only system message, do nothing as we need a pair to remove
+
+            else:
+                print("Removing oldest user/model pair.")
+                # Delete the oldest user message (at index 0)
+                # and the oldest model message (at index 1)
+                # Remove higher index first
+                if len(self.messages) > 1:  # Ensure there is at least a user and model message
+                    self.messages.pop(1)  # Remove model
+                    self.messages.pop(0)  # Remove user
+                elif len(self.messages) > 0:  # If only user message
+                    self.messages.pop(0)  # Remove user
+                # If list is empty, do nothing
+
+        else:
+            print("Unsupported model key; no messages deleted.")
\ No newline at end of file
diff --git a/llm/assets/VLM_Schema.png b/llm/assets/VLM_Schema.png
new file mode 100644
index 0000000..f6194ab
Binary files /dev/null and b/llm/assets/VLM_Schema.png differ
diff --git a/llm/config_llm.yaml b/llm/config_llm.yaml
new file mode 100644
index 0000000..b3465b7
--- /dev/null
+++ b/llm/config_llm.yaml
@@ -0,0 +1,44 @@
+# Configuration file for LLM (Large Language Model) settings
+# This file contains various settings for the LLM environment, agent behavior, partitioning, application,
+# rendering, intervention system, and enhanced delegation. 
+ +# Environment settings +original_map_size: 64 # Original map size for the environment + +# Agent behavior +force_delegate_to_rl: true # Force delegation to RL agent for testing +force_delegate_to_llm: false # Force delegation to LLM agent for testing +llm_call_frequency: 15 # Number of steps between LLM calls + +# Partitioning settings +use_manual_partitioning: false # Use manual partitioning +use_random_partitioning: true # Use random partitioning +max_num_partitions: 1 # Maximum number of partitions for LLM (Partitioning Agent) +use_exact_number_of_partitions: false # Use exact number of partitions for LLM (Partitioning Agent) +use_image_prompt: true # Use image prompt for LLM (Partitioning Agent) +use_exclusive_assignment: false # Use exclusive assignment for partition-specific target maps + +# Application settings +app_name: "ExcavatorGameApp" # Application name for ADK +user_id: "user_1" # User ID for ADK +session_id: "session_001" # Session ID for ADK +compute_bench_stats: true # Compute and print benchmark statistics + +# Rendering settings +grid_rendering: false # Use grid rendering for partitions +use_rendering: true # Use rendering for the environment +use_display: true # Use display for the environment +visualize_partitions: true # Visualize partitions +save_video: false # Save video of the environment +fps: 10 # Frames per second for video rendering + +# Intervention System Settings +enable_intervention: false # Enable intervention system +intervention_check_frequency: 5 # Check every N steps +stuck_detection_window: 10 # Number of recent steps to analyze +min_reward_threshold: 0.001 # Minimum reward expected in window +intervention_cooldown: 3 # Minimum steps between interventions + +# Enhanced delegation +enhanced_delegation_prompts: false # Use enhanced delegation prompts +include_stuck_context: false # Include stuck context in prompts \ No newline at end of file diff --git a/llm/env_llm.py b/llm/env_llm.py new file mode 100644 index 0000000..5925e26 --- /dev/null +++ b/llm/env_llm.py @@ -0,0 +1,237 @@ +from terra.env import TerraEnvBatch +import jax.numpy as jnp + +from llm.utils_llm import * + +class TerraEnvBatchWithMapOverride(TerraEnvBatch): + """ + Extended version of TerraEnvBatch that supports map overrides. + This class enables working with subsets of larger maps. + """ + def reset_with_map_override(self, env_cfgs, rngs, custom_pos=None, custom_angle=None, + target_map_override=None, traversability_mask_override=None, + padding_mask_override=None, dumpability_mask_override=None, + dumpability_mask_init_override=None, action_map_override=None, + agent_config_override=None): + """ + Reset the environment with custom map overrides. 
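+        Any override that changes the map edge length also resizes the
+        environment configuration to match before the reset is performed.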
+
+        Args:
+            env_cfgs: Environment configurations
+            rngs: Random number generators
+            custom_pos: Custom initial position
+            custom_angle: Custom initial angle
+            target_map_override: Override for target map
+            traversability_mask_override: Override for traversability mask
+            padding_mask_override: Override for padding mask
+            dumpability_mask_override: Override for dumpability mask
+            dumpability_mask_init_override: Override for initial dumpability mask
+            action_map_override: Override for action map
+            agent_config_override: Override for the agent configuration
+
+        Returns:
+            Initial timestep
+        """
+        # Determine the new edge length based on overrides
+        new_edge_length = None
+        if target_map_override is not None:
+            if len(target_map_override.shape) == 2:
+                new_edge_length = target_map_override.shape[0]  # Use the first dimension
+            else:
+                new_edge_length = target_map_override.shape[1]  # Use the second dimension for batched maps
+        elif action_map_override is not None:
+            if len(action_map_override.shape) == 2:
+                new_edge_length = action_map_override.shape[0]
+            else:
+                new_edge_length = action_map_override.shape[1]
+
+        # Update the env_cfg with the new map size and agent config if provided
+        if new_edge_length is not None or agent_config_override is not None:
+
+            # Update maps config if new edge length is provided
+            if new_edge_length is not None:
+                updated_maps_config = env_cfgs.maps._replace(
+                    edge_length_px=jnp.array([new_edge_length], dtype=jnp.int32)
+                )
+            else:
+                updated_maps_config = env_cfgs.maps
+
+            # Update agent config if override is provided
+            if agent_config_override is not None:
+                updated_agent_config = env_cfgs.agent._replace(**agent_config_override)
+            else:
+                updated_agent_config = env_cfgs.agent
+
+            # Update the env_cfgs with the new configurations
+            env_cfgs = env_cfgs._replace(
+                maps=updated_maps_config,
+                agent=updated_agent_config
+            )
+
+        # First reset with possibly updated env_cfgs
+        timestep = self.reset(env_cfgs, rngs, custom_pos, custom_angle)
+
+        # Then override maps if provided - use completely new arrays
+        if target_map_override is not None:
+            # Add batch dimension if needed
+            if len(target_map_override.shape) == 2:
+                target_map_override = target_map_override[None, ...]
+            timestep = timestep._replace(
+                state=timestep.state._replace(
+                    world=timestep.state.world._replace(
+                        target_map=timestep.state.world.target_map._replace(
+                            map=target_map_override
+                        )
+                    )
+                )
+            )
+
+        if traversability_mask_override is not None:
+            if len(traversability_mask_override.shape) == 2:
+                traversability_mask_override = traversability_mask_override[None, ...]
+            timestep = timestep._replace(
+                state=timestep.state._replace(
+                    world=timestep.state.world._replace(
+                        traversability_mask=timestep.state.world.traversability_mask._replace(
+                            map=traversability_mask_override
+                        )
+                    )
+                )
+            )
+
+        if padding_mask_override is not None:
+            if len(padding_mask_override.shape) == 2:
+                padding_mask_override = padding_mask_override[None, ...]
+            timestep = timestep._replace(
+                state=timestep.state._replace(
+                    world=timestep.state.world._replace(
+                        padding_mask=timestep.state.world.padding_mask._replace(
+                            map=padding_mask_override
+                        )
+                    )
+                )
+            )
+
+        if dumpability_mask_override is not None:
+            if len(dumpability_mask_override.shape) == 2:
+                dumpability_mask_override = dumpability_mask_override[None, ...]
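+            # The timestep is an immutable pytree, so each override is applied by
+            # rebuilding the nested state with _replace rather than mutating it.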
+ timestep = timestep._replace( + state=timestep.state._replace( + world=timestep.state.world._replace( + dumpability_mask=timestep.state.world.dumpability_mask._replace( + map=dumpability_mask_override + ) + ) + ) + ) + + if dumpability_mask_init_override is not None: + if len(dumpability_mask_init_override.shape) == 2: + dumpability_mask_init_override = dumpability_mask_init_override[None, ...] + timestep = timestep._replace( + state=timestep.state._replace( + world=timestep.state.world._replace( + dumpability_mask_init=timestep.state.world.dumpability_mask_init._replace( + map=dumpability_mask_init_override + ) + ) + ) + ) + + if action_map_override is not None: + if len(action_map_override.shape) == 2: + action_map_override = action_map_override[None, ...] + timestep = timestep._replace( + state=timestep.state._replace( + world=timestep.state.world._replace( + action_map=timestep.state.world.action_map._replace( + map=action_map_override + ) + ) + ) + ) + + # Update the env_cfg in the timestep state to ensure consistency + if new_edge_length is not None or agent_config_override is not None: + # Update the state's env_cfg if it exists + if hasattr(timestep.state, 'env_cfg'): + state_env_cfg = timestep.state.env_cfg + + if new_edge_length is not None: + state_env_cfg = state_env_cfg._replace( + maps=state_env_cfg.maps._replace( + edge_length_px=jnp.array([new_edge_length], dtype=jnp.int32) + ) + ) + + if agent_config_override is not None: + state_env_cfg = state_env_cfg._replace( + agent=state_env_cfg.agent._replace(**agent_config_override) + ) + + timestep = timestep._replace( + state=timestep.state._replace(env_cfg=state_env_cfg) + ) + + # Update the timestep's env_cfg if it exists at the top level + if hasattr(timestep, 'env_cfg'): + timestep_env_cfg = timestep.env_cfg + + if new_edge_length is not None: + timestep_env_cfg = timestep_env_cfg._replace( + maps=timestep_env_cfg.maps._replace( + edge_length_px=jnp.array([new_edge_length], dtype=jnp.int32) + ) + ) + + if agent_config_override is not None: + timestep_env_cfg = timestep_env_cfg._replace( + agent=timestep_env_cfg.agent._replace(**agent_config_override) + ) + # We need to manually update the observation to match the new maps + updated_obs = dict(timestep.observation) + + # Update all map-related observations + if target_map_override is not None and 'target_map' in updated_obs: + updated_obs['target_map'] = target_map_override + + if action_map_override is not None and 'action_map' in updated_obs: + updated_obs['action_map'] = action_map_override + + if dumpability_mask_override is not None and 'dumpability_mask' in updated_obs: + updated_obs['dumpability_mask'] = dumpability_mask_override + + if traversability_mask_override is not None and 'traversability_mask' in updated_obs: + updated_obs['traversability_mask'] = traversability_mask_override + + if padding_mask_override is not None and 'padding_mask' in updated_obs: + updated_obs['padding_mask'] = padding_mask_override + + if dumpability_mask_init_override is not None and 'dumpability_mask_init' in updated_obs: + updated_obs['dumpability_mask_init'] = dumpability_mask_init_override + + # Return the timestep with the updated observation + timestep = timestep._replace(observation=updated_obs) + + return timestep + + +class LargeMapTerraEnv(TerraEnvBatchWithMapOverride): + """A version of TerraEnvBatch specifically for 128x128 maps""" + + def reset_with_map_override(self, env_cfgs, rngs, custom_pos=None, custom_angle=None, + target_map_override=None, 
traversability_mask_override=None,
+                                padding_mask_override=None, dumpability_mask_override=None,
+                                dumpability_mask_init_override=None, action_map_override=None,
+                                agent_config_override=None):
+        """Reset with map overrides, delegating to the parent implementation to keep shapes consistent."""
+
+        # Call the TerraEnvBatchWithMapOverride's reset_with_map_override method directly
+        return TerraEnvBatchWithMapOverride.reset_with_map_override(
+            self, env_cfgs, rngs, custom_pos, custom_angle,
+            target_map_override, traversability_mask_override,
+            padding_mask_override, dumpability_mask_override,
+            dumpability_mask_init_override, action_map_override,
+            agent_config_override
+        )
\ No newline at end of file
diff --git a/llm/env_manager_llm.py b/llm/env_manager_llm.py
new file mode 100644
index 0000000..4599957
--- /dev/null
+++ b/llm/env_manager_llm.py
@@ -0,0 +1,1801 @@
+import numpy as np
+import jax
+
+import jax.numpy as jnp
+
+from terra.env import TerraEnv
+
+from llm.utils_llm import *
+from llm.adk_llm import *
+
+import pygame as pg
+
+from llm.env_llm import LargeMapTerraEnv
+
+class EnvironmentsManager:
+    """
+    Manages completely separate environments for large map and small maps.
+    Each environment has its own timestep, configuration, and state.
+    Only map data is exchanged between environments.
+    """
+
+    def __init__(self, seed, global_env_config, small_env_config=None, shuffle_maps=False, rendering=False, display=False, size=64):
+        """
+        Initialize with separate configurations for large and small environments.
+
+        Args:
+            seed: Random seed for reproducibility
+            global_env_config: Environment configuration for the large global map
+            small_env_config: Environment configuration for small maps (or None to derive from global)
+            shuffle_maps: Whether to shuffle maps
+            rendering: Whether to render the environment
+            display: Whether to display the rendered environment
+            size: Edge length of the global map in pixels (e.g. 64 or 128)
+        """
+        self.rng = jax.random.PRNGKey(seed)
+        self.global_env_config = global_env_config
+        self.shuffle_maps = shuffle_maps
+        # Create a custom small environment config if not provided
+        if small_env_config is None:
+            self.small_env_config = self._derive_small_environment_config()
+        else:
+            self.small_env_config = small_env_config
+
+        # Overlapping partition data - will be set externally
+        self.partitions = []
+        self.overlap_map = {}  # Maps partition_id -> set of overlapping partition_ids
+        self.overlap_regions = {}  # Cache overlap region calculations
+        self.rendering = rendering
+        self.display = display
+
+        # Initialize the global environment (128x128) with LargeMapTerraEnv
+        print("Initializing LargeMapTerraEnv for global environment...")
+
+        self.global_env = LargeMapTerraEnv(
+            rendering=rendering,
+            n_envs_x_rendering=1,
+            n_envs_y_rendering=1,
+            display=display,
+            shuffle_maps=shuffle_maps,
+        )
+
+        self.map_size_px = size  # Set map size based on provided size parameter
+        print(f"Global map size: {self.map_size_px}x{self.map_size_px} pixels")
+        self.small_agent_config = {
+            'height': jnp.array([9], dtype=jnp.int32),
+            'width': jnp.array([5], dtype=jnp.int32)
+        }
+        self.big_agent_config = {
+            'height': jnp.array([19], dtype=jnp.int32),
+            'width': jnp.array([11], dtype=jnp.int32)
+        }
+        # Initialize the small environment with regular TerraEnv (non-batched)
+        print("Initializing TerraEnv for small environment...")
+        self.small_env = TerraEnv.new(
+            maps_size_px=64,
+            rendering=False,
+            n_envs_x=1,
+            n_envs_y=1,
+            display=False,
+            agent_config_override=self.small_agent_config
+        )
+
+        # Store global map data
+        self.global_maps = {
+            'target_map': None,
+            'action_map': None,
+            'dumpability_mask': 
None, + 'dumpability_mask_init': None, + 'padding_mask': None, + 'traversability_mask': None, + 'trench_axes': None, + 'trench_type': None, + } + + # Define partition scheme + self.partitions = [] + + # Initialize global environment and extract maps + self._initialize_global_environment() + + # Track which environment is currently being displayed + self.current_display_env = "global" # or "small" + + # Track small environment state + self.small_env_timestep = None + self.current_partition_idx = None + + def _partitions_overlap(self, i: int, j: int) -> bool: + """Check if two partitions overlap.""" + p1_coords = self.partitions[i]['region_coords'] + p2_coords = self.partitions[j]['region_coords'] + + y1_start, x1_start, y1_end, x1_end = p1_coords + y2_start, x2_start, y2_end, x2_end = p2_coords + + # Check for overlap - rectangles overlap if they overlap in BOTH dimensions + y_overlap = (y1_start <= y2_end) and (y2_start <= y1_end) + x_overlap = (x1_start <= x2_end) and (x2_start <= x1_end) + + overlap_exists = y_overlap and x_overlap + + return overlap_exists + + def set_partitions(self, partitions): + """ + Set the partitions and compute overlap relationships. + """ + print(f"\n=== SETTING PARTITIONS ===") + self.partitions = partitions + + print(f"Partitions set:") + for i, partition in enumerate(self.partitions): + print(f" Partition {i}: {partition}") + + # Use the overlap computation + self._compute_overlap_relationships() + + print(f"Set {len(self.partitions)} partitions with overlaps computed.") + + def initialize_with_fixed_overlaps(self, partitions): + """ + Initialize partitions with fixed overlap detection. + """ + + # Set partitions using the fixed method + self.set_partitions(partitions) + + def step_simple(self, partition_idx: int, action, partition_states: dict): + """ + Simple step function - just steps the environment without any synchronization. + Synchronization happens separately. 
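+        See step_with_full_global_sync for the variant that synchronizes agent
+        positions across partitions before stepping.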
+ """ + partition_state = partition_states[partition_idx] + current_state = partition_state['timestep'].state + current_env_cfg = partition_state['timestep'].env_cfg + + # Extract required data for step + current_target_map = current_state.world.target_map.map + current_padding_mask = current_state.world.padding_mask.map + current_dumpability_mask_init = current_state.world.dumpability_mask_init.map + current_trench_axes = current_state.world.trench_axes + current_trench_type = current_state.world.trench_type + current_action_map = current_state.world.action_map.map + + # Step the environment + new_timestep = self.small_env.step( + state=current_state, + action=action, + target_map=current_target_map, + padding_mask=current_padding_mask, + trench_axes=current_trench_axes, + trench_type=current_trench_type, + dumpability_mask_init=current_dumpability_mask_init, + action_map=current_action_map, + env_cfg=current_env_cfg + ) + + return new_timestep + + def _create_clean_env_config(self): + """Create a clean environment config for 64x64 maps without batch dimensions""" + # If you have a reference to the original config structure, use it + # Otherwise, create a minimal one + try: + # Try to create from the global config but clean it up + base_cfg = self.small_env_config if hasattr(self, 'small_env_config') else self.global_env_config + + # Remove any batch dimensions by taking the first element + def unbatch(x): + if hasattr(x, 'shape') and len(x.shape) > 0 and x.shape[0] == 1: + return x[0] + return x + + clean_cfg = jax.tree_map(unbatch, base_cfg) + return clean_cfg + + except Exception as e: + print(f"Warning: Could not clean config: {e}") + # Return the original config and hope for the best + return self.global_env_config + + def initialize_small_environment(self, partition_idx): + """ + Initialize the small environment with map data from a specific global map partition. + Uses TerraEnv (non-batched) for better performance and simpler interface. 
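+        Raises a ValueError if partition_idx is out of range.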
+ """ + if partition_idx < 0 or partition_idx >= len(self.partitions): + raise ValueError(f"Invalid partition index: {partition_idx}") + + partition = self.partitions[partition_idx] + region_coords = partition['region_coords'] + custom_pos = partition['start_pos'] + custom_angle = partition['start_angle'] + + if self.map_size_px == 64: + sub_maps = { + 'target_map': create_sub_task_target_map_64x64(self.global_maps['target_map'], region_coords), #ok + 'action_map': self.global_maps['action_map'], + 'dumpability_mask': self.global_maps['dumpability_mask'], + 'dumpability_mask_init': self.global_maps['dumpability_mask_init'], + 'padding_mask': self.global_maps['padding_mask'], + 'traversability_mask': self.global_maps['traversability_mask'], #OK, keep the full traversability mask + } + else: + sub_maps = { + 'target_map': create_sub_task_target_map_64x64_big(self.global_maps['target_map'], region_coords), + 'action_map': create_sub_task_action_map_64x64_big(self.global_maps['action_map'], region_coords), + 'dumpability_mask': create_sub_task_dumpability_mask_64x64_big(self.global_maps['dumpability_mask'], region_coords), + 'dumpability_mask_init': create_sub_task_dumpability_mask_64x64_big(self.global_maps['dumpability_mask_init'], region_coords), + 'padding_mask': create_sub_task_padding_mask_64x64_big(self.global_maps['padding_mask'], region_coords), + 'traversability_mask': create_sub_task_traversability_mask_64x64_big(self.global_maps['traversability_mask'], region_coords), + } + + # Fix trench data shapes - remove batch dimension for single environment + trench_axes = self.global_maps['trench_axes'] + trench_type = self.global_maps['trench_type'] + + # Remove batch dimension if present + if trench_axes.shape[0] == 1: + trench_axes = trench_axes[0] # Shape: (3, 3) instead of (1, 3, 3) + if trench_type.shape[0] == 1: + trench_type = trench_type[0] # Shape: () instead of (1,) + trench_axes = trench_axes.astype(jnp.float32) + trench_type = trench_type.astype(jnp.int32) + + # Reset the small environment using TerraEnv's interface (no batching) + clean_env_cfg = self._create_clean_env_config() + print(f"Environment config created") + + self.rng, reset_key = jax.random.split(self.rng) + + try: + print("Resetting small environment with custom map data...") + + # Use TerraEnv's reset method directly - much cleaner interface + small_timestep = self.small_env.reset( + key=reset_key, + target_map=sub_maps['target_map'], + padding_mask=sub_maps['padding_mask'], + trench_axes=trench_axes, + trench_type=trench_type, + dumpability_mask_init=sub_maps['dumpability_mask_init'], + action_map=sub_maps['action_map'], + env_cfg=clean_env_cfg, + custom_pos=custom_pos, + custom_angle=custom_angle + ) + print("Small environment reset successfully.") + + # Store current small environment state + self.small_env_timestep = small_timestep + self.current_partition_idx = partition_idx + + # Set partition status to active + self.partitions[partition_idx]['status'] = 'active' + + # Switch display to small environment + self.current_display_env = "small" + return small_timestep + + except Exception as e: + import traceback + print(f"Error initializing small environment: {e}") + print(traceback.format_exc()) + raise + + def _update_world_map(self, world_state, map_name: str, new_map): + """ + Helper method to update a specific map in the world state. + This creates a new world state with the updated map. 
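+
+        Args:
+            world_state: The current (immutable) world state pytree
+            map_name: Attribute name of the map to update, e.g. 'traversability_mask'
+            new_map: The new map array to install
+
+        Returns:
+            A new world state containing the updated map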
+ """ + # Get the current map object + current_map_obj = getattr(world_state, map_name) + + # Create new map object with updated data + updated_map_obj = current_map_obj._replace(map=new_map) + + # Create new world state with updated map + updated_world = world_state._replace(**{map_name: updated_map_obj}) + + return updated_world + + def _derive_small_environment_config(self): + """ + Derive a configuration for small environments based on the global config. + Returns a modified config with appropriate size settings. + """ + # Create a copy of the global environment config + small_config = jax.tree_map(lambda x: x, self.global_env_config) + + # Modify map size and other relevant parameters + # This requires knowledge of the config structure + if hasattr(small_config, 'maps') and hasattr(small_config.maps, 'edge_length_px'): + small_config = small_config._replace( + maps=small_config.maps._replace( + edge_length_px=jnp.array([64], dtype=jnp.int32) + ) + ) + + # If map_size is a separate attribute + if hasattr(small_config, 'map_size'): + small_config = small_config._replace(map_size=64) + + return small_config + + def _initialize_global_environment(self): + """Initialize the global environment with proper batching""" + self.rng, reset_key = jax.random.split(self.rng) + + # Create array of keys for batching consistency + reset_keys = jax.random.split(reset_key, 1) # Shape: (1, 2) + + print("Initializing global environment...") + global_timestep = self.global_env.reset(self.global_env_config, reset_keys) + + # Extract and store global map data + self.global_maps['target_map'] = global_timestep.state.world.target_map.map[0].copy() + self.global_maps['action_map'] = global_timestep.state.world.action_map.map[0].copy() + self.global_maps['dumpability_mask'] = global_timestep.state.world.dumpability_mask.map[0].copy() + self.global_maps['dumpability_mask_init'] = global_timestep.state.world.dumpability_mask_init.map[0].copy() + self.global_maps['padding_mask'] = global_timestep.state.world.padding_mask.map[0].copy() + self.global_maps['traversability_mask'] = global_timestep.state.world.traversability_mask.map[0].copy() + self.global_maps['trench_axes'] = global_timestep.state.world.trench_axes.copy() + self.global_maps['trench_type'] = global_timestep.state.world.trench_type.copy() + + # Store global timestep + self.global_timestep = global_timestep + + print("Global environment initialized successfully.") + return self.global_timestep + + def map_position_small_to_global(self, small_pos, region_coords): + """ + Map agent position from small map coordinates to global map coordinates. + Assumes the small map places the region at (0,0), so we need to add offsets. + Returns position in (x, y) format for rendering. 
+ """ + y_start, x_start, y_end, x_end = region_coords + + # Extract position values - assuming agent position is [x, y] + if hasattr(small_pos, 'shape'): + if len(small_pos.shape) == 1 and small_pos.shape[0] == 2: + local_x = float(small_pos[0]) + local_y = float(small_pos[1]) + else: + local_x = float(small_pos.flatten()[0]) + local_y = float(small_pos.flatten()[1]) + else: + local_x = float(small_pos[0]) + local_y = float(small_pos[1]) + + # Adjust global coordinates based on map size + if self.map_size_px == 128: + global_x = local_x + y_start + global_y = local_y + x_start + else: + global_x = local_x + global_y = local_y + + return (int(global_y), int(global_x)) + + def is_small_task_completed(self): + """Check if the current small environment task is completed.""" + if self.small_env_timestep is None: + return False + + # Handle both scalar and array cases for done flag + done_value = self.small_env_timestep.done + if isinstance(done_value, jnp.ndarray): + if done_value.shape == (): # Scalar array + return bool(done_value) + elif len(done_value.shape) > 0: # Array with dimensions + return bool(done_value[0]) + else: + return bool(done_value) + else: + return bool(done_value) + + def _update_global_environment_display_with_all_agents(self, partition_states): + """ + Update the global environment display with ALL active agents. + Handle initialization errors properly. + """ + try: + self.rng, reset_key = jax.random.split(self.rng) + reset_keys = jax.random.split(reset_key, 1) + + # Collect all active agent positions and angles + all_agent_positions = [] + all_agent_angles_base = [] + all_agent_angles_cabin = [] + all_agent_loaded = [] + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] == 'active' and partition_state['timestep'] is not None: + # Get agent state from this partition + small_agent_state = partition_state['timestep'].state.agent.agent_state + partition = self.partitions[partition_idx] + region_coords = partition['region_coords'] + + small_pos = small_agent_state.pos_base + small_angle_base = small_agent_state.angle_base + small_angle_cabin = small_agent_state.angle_cabin + small_loaded = small_agent_state.loaded + + # Map position to global coordinates + global_pos = self.map_position_small_to_global(small_pos, region_coords) + + # Handle angle extraction + if hasattr(small_angle_base, 'shape'): + if small_angle_base.shape == (): + angle_base_val = int(small_angle_base) + elif len(small_angle_base.shape) >= 1: + angle_base_val = int(small_angle_base.flatten()[0]) + else: + angle_base_val = 0.0 + else: + angle_base_val = int(small_angle_base) + + if hasattr(small_angle_cabin, 'shape'): + if small_angle_cabin.shape == (): + angle_cabin_val = int(small_angle_cabin) + elif len(small_angle_cabin.shape) >= 1: + angle_cabin_val = int(small_angle_cabin.flatten()[0]) + else: + angle_cabin_val = 0.0 + else: + angle_cabin_val = int(small_angle_cabin) + + if hasattr(small_loaded, 'shape'): + if small_loaded.shape == (): + small_loaded = int(small_loaded) + elif len(small_loaded.shape) >= 1: + small_loaded = int(small_loaded.flatten()[0]) + else: + small_loaded = False + else: + small_loaded = int(small_loaded) + + all_agent_positions.append(global_pos) + all_agent_angles_base.append(angle_base_val) + all_agent_angles_cabin.append(angle_cabin_val) + all_agent_loaded.append(small_loaded) + + print(f"Agent {partition_idx} at global position: {global_pos}, angle base: {angle_base_val}, angle cabin: {angle_cabin_val}, loaded: {small_loaded}") 
+ + # Update global maps from small environments incrementally + if self.map_size_px == 64: + self.update_global_maps_from_all_small_environments_small(partition_states) + else: + self.update_global_maps_from_all_small_environments_big(partition_states) + + # Use first agent for reset position (others will be added during rendering) + custom_pos = all_agent_positions[0] if all_agent_positions else None + custom_angle = all_agent_angles_base[0] if all_agent_angles_base else None + + # Reset global environment with updated maps + self.global_timestep = self.global_env.reset_with_map_override( + self.global_env_config, + reset_keys, + custom_pos=custom_pos, + custom_angle=custom_angle, + target_map_override=self.global_maps['target_map'], + traversability_mask_override=self.global_maps['traversability_mask'], + padding_mask_override=self.global_maps['padding_mask'], + dumpability_mask_override=self.global_maps['dumpability_mask'], + dumpability_mask_init_override=self.global_maps['dumpability_mask_init'], + action_map_override=self.global_maps['action_map'], + agent_config_override=self.small_agent_config + ) + + # Store all agent positions for rendering - Initialize these attributes + if not hasattr(self.global_env, 'all_agent_positions'): + self.global_env.all_agent_positions = [] + if not hasattr(self.global_env, 'all_agent_angles_base'): + self.global_env.all_agent_angles_base = [] + if not hasattr(self.global_env, 'all_agent_angles_cabin'): + self.global_env.all_agent_angles_cabin = [] + if not hasattr(self.global_env, 'all_agent_loaded'): + self.global_env.all_agent_loaded = [] + + self.global_env.all_agent_positions = all_agent_positions + self.global_env.all_agent_angles_base = all_agent_angles_base + self.global_env.all_agent_angles_cabin = all_agent_angles_cabin + self.global_env.all_agent_loaded = all_agent_loaded + + print(f"Global environment updated with {len(all_agent_positions)} active agents.") + + except Exception as e: + print(f"Warning: Could not update global environment display: {e}") + import traceback + traceback.print_exc() + + def update_global_maps_from_all_small_environments_small(self, partition_states): + """ + Update global maps with changes from ALL active small environments. + Fixed to handle shape mismatches by properly extracting the correct region size. + """ + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] == 'active' and partition_state['timestep'] is not None: + partition = self.partitions[partition_idx] + y_start, x_start, y_end, x_end = partition['region_coords'] + + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + + # Get current state from small environment + small_state = partition_state['timestep'].state + + # Extract the maps from small environment (these are 64x64) + small_maps = { + 'dumpability_mask': small_state.world.dumpability_mask.map, + 'target_map': small_state.world.target_map.map, + 'action_map': small_state.world.action_map.map, + 'traversability_mask': small_state.world.traversability_mask.map, + 'padding_mask': small_state.world.padding_mask.map, + } + + for map_name, small_map in small_maps.items(): + # Extract the portion that matches the region size + extracted_region = small_map[region_slice] + + self.global_maps[map_name] = self.global_maps[map_name].at[region_slice].set(extracted_region) + + def update_global_maps_from_all_small_environments_big(self, partition_states): + """ + Update global maps with changes from ALL active small environments. 
+ Properly handles coordinate transformation between 64x64 partition maps and 128x128 global maps. + """ + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] == 'active' and partition_state['timestep'] is not None: + try: + partition = self.partitions[partition_idx] + y_start, x_start, y_end, x_end = partition['region_coords'] + + # Calculate actual region dimensions + region_height = y_end - y_start + 1 + region_width = x_end - x_start + 1 + + # Get current state from small environment + small_state = partition_state['timestep'].state + + # Extract the maps from small environment (these are 64x64) + small_maps = { + 'dumpability_mask': small_state.world.dumpability_mask.map, + 'target_map': small_state.world.target_map.map, + 'action_map': small_state.world.action_map.map, + 'traversability_mask': small_state.world.traversability_mask.map, + 'padding_mask': small_state.world.padding_mask.map, + } + + for map_name, small_map in small_maps.items(): + if small_map.shape != (64, 64): + continue + + # Extract region from the 64x64 local map using LOCAL coordinates + if region_height <= 64 and region_width <= 64: + # Extract the relevant portion from the TOP-LEFT of the 64x64 map + # This corresponds to the actual region data + extracted_region = small_map[:region_height, :region_width] + else: + continue + + # Now update the global map using GLOBAL coordinates + global_region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + + try: + # Verify global map exists and has correct shape + if map_name not in self.global_maps: + continue + + global_map = self.global_maps[map_name] + if global_map.shape != (128, 128): + continue + + # Update the global map + self.global_maps[map_name] = global_map.at[global_region_slice].set(extracted_region) + + except Exception as update_error: + print(f" Failed to update global {map_name}: {update_error}") + continue + + except Exception as partition_error: + print(f" Failed to process partition {partition_idx}: {partition_error}") + import traceback + traceback.print_exc() + continue + + def render_global_environment_with_multiple_agents(self, partition_states, VISUALIZE_PARTITIONS=False): + """ + Update and render global environment showing ALL active excavators. + Fixed to handle missing attributes gracefully. 
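+        Pass VISUALIZE_PARTITIONS=True to draw the partition boundaries on the global view.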
+ """ + # First update with all agents + self._update_global_environment_display_with_all_agents(partition_states) + + # Then render with multiple agents + try: + obs = self.global_timestep.observation + info = self.global_timestep.info + + # Pass additional agent positions to the rendering system + if (hasattr(self.global_env, 'all_agent_positions') and + hasattr(self.global_env, 'all_agent_angles_base') and + hasattr(self.global_env, 'all_agent_angles_cabin') and + hasattr(self.global_env, 'all_agent_loaded')): + + # Add all agent positions to the info for rendering + info['additional_agents'] = { + 'positions': self.global_env.all_agent_positions, + 'angles base': self.global_env.all_agent_angles_base, + 'angles cabin': self.global_env.all_agent_angles_cabin, + 'loaded': self.global_env.all_agent_loaded + } + else: + print("Warning: Agent attributes not properly initialized for rendering") + # Initialize empty lists to prevent further errors + info['additional_agents'] = { + 'positions': [], + 'angles base': [], + 'angles cabin': [], + 'loaded': [] + } + + if VISUALIZE_PARTITIONS: + info['show_partitions'] = True + info['partitions'] = self.partitions # Just pass the whole partition list + + # Pass agent config to rendering + info['agent_config'] = self.small_agent_config + + if self.rendering: + self.global_env.terra_env.render_obs_pygame(obs, info) + + except Exception as e: + print(f"Global rendering error: {e}") + import traceback + traceback.print_exc() + + def render_all_partition_views_grid(self, partition_states): + """ + Render all active partition views in a grid layout. + This shows what each agent sees simultaneously. + """ + + active_partitions = [idx for idx, state in partition_states.items() + if state['status'] == 'active'] + + if not active_partitions: + return + + # Get screen dimensions + screen = pg.display.get_surface() + if screen is None: + return + + screen_width, screen_height = screen.get_size() + + # Calculate grid layout + num_partitions = len(active_partitions) + cols = min(2, num_partitions) # Max 2 columns + rows = (num_partitions + cols - 1) // cols + + # Calculate size for each partition view + partition_width = screen_width // cols + partition_height = screen_height // rows + + # Clear screen + screen.fill((50, 50, 50)) + + # Render each partition + for i, partition_idx in enumerate(active_partitions): + partition_state = partition_states[partition_idx] + + # Calculate position in grid + col = i % cols + row = i // cols + x_offset = col * partition_width + y_offset = row * partition_height + + # Render this partition's view + self._render_single_partition_view( + screen, partition_state, partition_idx, + x_offset, y_offset, partition_width, partition_height + ) + + pg.display.flip() + + def _render_single_partition_view(self, screen, partition_state, partition_idx, + x_offset, y_offset, width, height): + """ + Render a single partition's view within the given screen area. 
+ """ + # Get the maps from the partition + current_timestep = partition_state['timestep'] + world = current_timestep.state.world + agent_state = current_timestep.state.agent.agent_state + + # Extract maps + target_map = world.target_map.map + action_map = world.action_map.map + traversability_mask = world.traversability_mask.map + agent_pos = agent_state.pos_base + + # Map dimensions + map_height, map_width = target_map.shape + + # Calculate tile size to fit in available space + tile_width = (width - 40) // map_width # Leave 40 pixels for margins + tile_height = (height - 60) // map_height # Leave 60 pixels for title and info + tile_size = max(2, min(tile_width, tile_height)) + + # Center the map in the available space + map_pixel_width = map_width * tile_size + map_pixel_height = map_height * tile_size + map_x = x_offset + (width - map_pixel_width) // 2 + map_y = y_offset + 40 # Leave space for title + + # Draw title + font = pg.font.Font(None, 32) + title = f"Partition {partition_idx}" + text = font.render(title, True, (255, 255, 255)) + screen.blit(text, (x_offset + 10, y_offset + 5)) + + # Draw the map + for y in range(map_height): + for x in range(map_width): + # Get cell values + target_val = target_map[y, x] + action_val = action_map[y, x] + traversable = traversability_mask[y, x] + + # Determine color based on cell state + if traversable == -1: # Agent position + color = (255, 100, 255) # Magenta + elif traversable == 1: # Obstacle (including other agents) + color = (255, 50, 50) # Red + elif action_val > 0: # Dumped soil + color = (139, 69, 19) # Brown + elif action_val == -1: # Dug area + color = (101, 67, 33) # Dark brown + elif target_val == -1: # Target to dig + color = (255, 255, 0) # Yellow + elif target_val == 1: # Target to dump + color = (0, 255, 0) # Green + else: # Free space + color = (220, 220, 220) # Light gray + + # Draw the tile + rect = pg.Rect( + map_x + x * tile_size, + map_y + y * tile_size, + tile_size, + tile_size + ) + pg.draw.rect(screen, color, rect) + + # Draw border around map + border_rect = pg.Rect(map_x - 1, map_y - 1, + map_pixel_width + 2, map_pixel_height + 2) + pg.draw.rect(screen, (255, 255, 255), border_rect, 1) + + # Draw agent position and stats + small_font = pg.font.Font(None, 20) + + # Agent position + pos_text = f"Agent: ({agent_pos[0]:.1f}, {agent_pos[1]:.1f})" + pos_surface = small_font.render(pos_text, True, (255, 255, 255)) + screen.blit(pos_surface, (x_offset + 10, y_offset + height - 40)) + + # Obstacle count (red cells = other agents + terrain obstacles) + obstacle_count = np.sum(traversability_mask == 1) + obstacle_text = f"Red obstacles: {obstacle_count}" + obstacle_surface = small_font.render(obstacle_text, True, (255, 100, 100)) + screen.blit(obstacle_surface, (x_offset + 10, y_offset + height - 20)) + + def _should_show_agent_in_partition(self, partition_idx, agent_y, agent_x): + """ + Determine if an agent at the given position should be visible to the partition. + + For global maps, you might want to: + 1. Show all agents everywhere (return True) + 2. Show agents only within a certain distance of the partition's region + 3. 
Show agents only within the partition's assigned region + """ + # Option 1: Show all agents everywhere (recommended for global maps) + return True + + # Option 2: Show agents within partition region + buffer + # if partition_idx < len(self.partitions): + # partition = self.partitions[partition_idx] + # y_start, x_start, y_end, x_end = partition['region_coords'] + # + # # Add buffer around partition region + # buffer = 10 + # return (y_start - buffer <= agent_y <= y_end + buffer and + # x_start - buffer <= agent_x <= x_end + buffer) + # + # return False + + def initialize_base_traversability_masks(self, partition_states): + """ + Store the initial clean traversability masks for each partition. + This captures the original terrain obstacles before any agent synchronization. + Call this ONCE after partition initialization but BEFORE any agent sync. + """ + if not hasattr(self, 'base_traversability_masks'): + self.base_traversability_masks = {} + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] == 'active': + # Get the current traversability mask + current_mask = partition_state['timestep'].state.world.traversability_mask.map.copy() + + # Clean ANY agent markers to get pure terrain + # -1 = agent position (clear to 0) + # 1 = could be terrain or agent obstacles (assume terrain at initialization) + # 0 = free space (keep as is) + + # Create a completely clean mask - only terrain obstacles, no agents + clean_mask = jnp.where( + current_mask == -1, # Remove any agent positions + 0, # Set to free space + jnp.where( + current_mask == 1, # Keep terrain obstacles + 1, + 0 # Everything else becomes free space + ) + ) + + self.base_traversability_masks[partition_idx] = clean_mask + + def _update_partition_with_other_agents(self, target_partition_idx, target_partition_state, + all_occupied_cells, partition_states): + """ + Update a partition's traversability mask to show other agents as obstacles. + Now properly preserves original terrain obstacles. 
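+        Falls back to the partition's current mask if
+        initialize_base_traversability_masks was never called.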
+ """ + current_timestep = target_partition_state['timestep'] + + # STEP 1: Start from the clean base mask (has original terrain obstacles but no agent obstacles) + if hasattr(self, 'base_traversability_masks') and target_partition_idx in self.base_traversability_masks: + # Start from clean base (original terrain obstacles preserved) + current_traversability = self.base_traversability_masks[target_partition_idx].copy() + + # Add back the current agent's position (-1) + original_traversability = current_timestep.state.world.traversability_mask.map + agent_mask = (original_traversability == -1) + current_traversability = jnp.where( + agent_mask, + -1, # Restore current agent position + current_traversability # Keep clean base with terrain obstacles + ) + else: + # Fallback: use current mask but this might not work perfectly + print(f"Warning: No base mask for partition {target_partition_idx}, using current mask") + current_traversability = current_timestep.state.world.traversability_mask.map.copy() + + # Try to clear only agent obstacles (this is less reliable) + # Keep -1 (current agent) and assume original 1s are terrain + # This fallback is not ideal - base masks are recommended + + # STEP 2: Add current positions of OTHER agents as obstacles + agents_added = 0 + cells_added = 0 + + for other_partition_idx, occupied_cells in all_occupied_cells.items(): + if other_partition_idx == target_partition_idx: + continue # Don't add self as obstacle + + for cell_y, cell_x in occupied_cells: + # Check if this cell should be visible in this partition + if self._should_show_agent_in_partition(target_partition_idx, cell_y, cell_x): + # Check bounds + if (0 <= cell_y < current_traversability.shape[0] and + 0 <= cell_x < current_traversability.shape[1]): + # Mark as obstacle (value = 1) - this represents another agent + # Only set if it's currently free space (0) to avoid overwriting terrain + if current_traversability[cell_y, cell_x] == 0: + current_traversability = current_traversability.at[cell_y, cell_x].set(1) + cells_added += 1 + + if cells_added > 0: + agents_added += 1 + + # STEP 3: Update the world state + updated_world = self._update_world_map( + current_timestep.state.world, + 'traversability_mask', + current_traversability + ) + updated_state = current_timestep.state._replace(world=updated_world) + updated_timestep = current_timestep._replace(state=updated_state) + + partition_states[target_partition_idx]['timestep'] = updated_timestep + + if agents_added > 0: + print(f" βœ“ Added {agents_added} other agents ({cells_added} cells) to partition {target_partition_idx}") + + def step_with_full_global_sync(self, partition_idx: int, action, partition_states: dict): + """ + Synchronize BEFORE stepping to prevent collisions. 
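
+        Sketch of the intended call pattern (assumes `action` is a valid Terra
+        action and `partition_states` maps partition indices to state dicts):
+
+            new_timestep = self.step_with_full_global_sync(
+                partition_idx=0, action=action, partition_states=partition_states
+            )
+            # Agent positions are synced and observations refreshed before the
+            # step, so partition 0 already sees other agents as obstacles.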
+        """
+        # Step 1: Sync current positions BEFORE any movement
+        self._sync_agent_positions_across_partitions(partition_states)
+
+        # Step 2: Update observations so agents see synchronized state
+        self._update_all_observations(partition_states)
+
+        # Step 3: NOW take the action with proper obstacle awareness
+        new_timestep = self.step_simple(partition_idx, action, partition_states)
+
+        # Step 4: Update the partition state
+        partition_states[partition_idx]['timestep'] = new_timestep
+
+        # Step 5: Extract changes and update global maps
+        self._update_global_maps_from_single_partition_small(partition_idx, partition_states)
+
+        # Step 6: Propagate changes to other partitions
+        if self.overlap_regions and self.map_size_px == 64:
+            self._sync_all_partitions_from_global_maps_excluding_traversability(partition_states)
+            #print(f"  ✓ Synced all partitions from global maps (excluding traversability)")
+
+        return new_timestep
+
+    def _update_global_maps_from_single_partition_big(self, source_partition_idx, partition_states):
+        """
+        Update global maps with proper handling for overlapping regions.
+        This version correctly handles coordinate mapping and overlap synchronization.
+        """
+        if source_partition_idx not in partition_states:
+            return
+
+        source_state = partition_states[source_partition_idx]['timestep'].state
+        partition = self.partitions[source_partition_idx]
+        region_coords = partition['region_coords']
+        y_start, x_start, y_end, x_end = region_coords
+
+        #print(f"  Updating global maps from partition {source_partition_idx}")
+        #print(f"    Partition region_coords: {region_coords}")
+
+        # Calculate the actual region dimensions
+        region_height = y_end - y_start + 1
+        region_width = x_end - x_start + 1
+
+        #print(f"    Region dimensions in global map: {region_height}x{region_width}")
+
+        # Define which maps to update globally
+        maps_to_update = [
+            'action_map',
+            'dumpability_mask',
+            'dumpability_mask_init',
+            'target_map'  # Include target_map for overlap sync
+        ]
+
+        # Update each map in the global storage
+        for map_name in maps_to_update:
+            # Get the current map from the partition (this is always 64x64)
+            partition_map = getattr(source_state.world, map_name).map
+            #print(f"    {map_name} partition shape: {partition_map.shape}")
+
+            # CRITICAL FIX: The partition map is 64x64, which represents the ENTIRE partition view
+            # We can only update the portion of the global map that fits within 64x64
+            # If the region is larger than 64x64, we can only update a 64x64 portion
+
+            # Calculate how much we can actually update
+            update_height = min(64, region_height)
+            update_width = min(64, region_width)
+
+            #print(f"    Will update {update_height}x{update_width} region in global map")
+
+            # Extract the data to update (the full partition map up to the update size)
+            data_to_update = partition_map[:update_height, :update_width]
+
+            # Define the target slice in the global map (only update what we can)
+            global_region_slice = (
+                slice(y_start, y_start + update_height),
+                slice(x_start, x_start + update_width)
+            )
+
+            # Update the global map with the extracted region
+            try:
+                self.global_maps[map_name] = self.global_maps[map_name].at[global_region_slice].set(data_to_update)
+            except Exception as e:
+                print(f"    ✗ Error updating global {map_name}: {e}")
+
+    def _sync_overlapping_regions_big_maps(self, source_partition_idx, partition_states):
+        """
+        Synchronize overlapping regions between partitions for big maps.
+        Handles the case where regions are larger than 64x64.
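
+        Illustrative sketch (hypothetical cached structures): with
+        self.overlap_map = {0: {1}, 1: {0}} and an entry for key (0, 1) in
+        self.overlap_regions, the call
+
+            self._sync_overlapping_regions_big_maps(0, partition_states)
+
+        copies the (0, 1) overlap from partition 0 into partition 1, using
+        'partition_i_slice' as the source slice because 0 < 1.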
+ """ + + # Get all partitions that overlap with the source + overlapping_partitions = self.overlap_map.get(source_partition_idx, set()) + + for target_partition_idx in overlapping_partitions: + if (target_partition_idx not in partition_states or + partition_states[target_partition_idx]['status'] != 'active'): + continue + + # Get overlap information + overlap_key = (min(source_partition_idx, target_partition_idx), + max(source_partition_idx, target_partition_idx)) + + if overlap_key not in self.overlap_regions: + continue + + overlap_info = self.overlap_regions[overlap_key] + + # Determine which partition is source and which is target + if source_partition_idx < target_partition_idx: + source_slice = overlap_info['partition_i_slice'] + target_slice = overlap_info['partition_j_slice'] + else: + source_slice = overlap_info['partition_j_slice'] + target_slice = overlap_info['partition_i_slice'] + + # Sync the overlapping region + self._sync_single_overlap_region( + source_partition_idx, target_partition_idx, + source_slice, target_slice, + partition_states + ) + + def _sync_single_overlap_region(self, source_idx, target_idx, source_slice, target_slice, partition_states): + """ + Sync a single overlapping region from source to target partition. + Handles shape mismatches by only syncing the actual overlapping data. + """ + source_state = partition_states[source_idx]['timestep'].state + target_state = partition_states[target_idx]['timestep'].state + + # Maps to sync (excluding traversability which is handled separately) + maps_to_sync = ['action_map', 'dumpability_mask', 'dumpability_mask_init'] + + for map_name in maps_to_sync: + try: + # Get source and target maps + source_map = getattr(source_state.world, map_name).map + target_map = getattr(target_state.world, map_name).map.copy() + + # Extract the actual shapes of the slices + source_y_slice, source_x_slice = source_slice + target_y_slice, target_x_slice = target_slice + + # Calculate actual dimensions of overlap region + source_height = min(source_y_slice.stop - source_y_slice.start, source_map.shape[0] - source_y_slice.start) + source_width = min(source_x_slice.stop - source_x_slice.start, source_map.shape[1] - source_x_slice.start) + target_height = min(target_y_slice.stop - target_y_slice.start, target_map.shape[0] - target_y_slice.start) + target_width = min(target_x_slice.stop - target_x_slice.start, target_map.shape[1] - target_x_slice.start) + + # Use the minimum dimensions to ensure compatibility + sync_height = min(source_height, target_height) + sync_width = min(source_width, target_width) + + # Create adjusted slices + adj_source_slice = ( + slice(source_y_slice.start, source_y_slice.start + sync_height), + slice(source_x_slice.start, source_x_slice.start + sync_width) + ) + adj_target_slice = ( + slice(target_y_slice.start, target_y_slice.start + sync_height), + slice(target_x_slice.start, target_x_slice.start + sync_width) + ) + + # Extract overlapping region from source + source_overlap_data = source_map[adj_source_slice] + + # Update target map with source data + target_map = target_map.at[adj_target_slice].set(source_overlap_data) + + # Update the world state + updated_world = self._update_world_map(target_state.world, map_name, target_map) + target_state = target_state._replace(world=updated_world) + + except Exception as e: + print(f" βœ— Error syncing {map_name}: {e}") + + # Special handling for target_map - merge instead of overwrite + try: + source_target_map = source_state.world.target_map.map + target_target_map = 
target_state.world.target_map.map.copy() + + # Use adjusted slices for target map as well + source_overlap_targets = source_target_map[adj_source_slice] + target_overlap_targets = target_target_map[adj_target_slice] + + # Merge logic: keep dig targets (-1) from both, prioritize source for conflicts + merged_targets = jnp.where( + (target_overlap_targets == -1) | (source_overlap_targets == -1), + -1, # Keep dig targets from either partition + source_overlap_targets # Otherwise use source + ) + + target_target_map = target_target_map.at[adj_target_slice].set(merged_targets) + updated_world = self._update_world_map(target_state.world, 'target_map', target_target_map) + target_state = target_state._replace(world=updated_world) + + except Exception as e: + print(f" βœ— Error merging target_map: {e}") + + # Update the partition state + updated_timestep = partition_states[target_idx]['timestep']._replace(state=target_state) + partition_states[target_idx]['timestep'] = updated_timestep + + def _calculate_overlap_region(self, partition_i: int, partition_j: int): + """ + Calculate the overlapping region between two partitions with correct coordinate mapping. + Handles cases where partitions are larger than 64x64. + """ + p1_coords = self.partitions[partition_i]['region_coords'] + p2_coords = self.partitions[partition_j]['region_coords'] + + y1_start, x1_start, y1_end, x1_end = p1_coords + y2_start, x2_start, y2_end, x2_end = p2_coords + + # Find intersection in global coordinates + overlap_y_start = max(y1_start, y2_start) + overlap_x_start = max(x1_start, x2_start) + overlap_y_end = min(y1_end, y2_end) + overlap_x_end = min(x1_end, x2_end) + + # Check if there's actual overlap + if overlap_y_start > overlap_y_end or overlap_x_start > overlap_x_end: + return None + + # Calculate local coordinates relative to each partition's origin + # But limit to 64x64 since that's the actual partition map size + local_i_y_start = overlap_y_start - y1_start + local_i_x_start = overlap_x_start - x1_start + local_i_y_end = overlap_y_end - y1_start + local_i_x_end = overlap_x_end - x1_start + + local_j_y_start = overlap_y_start - y2_start + local_j_x_start = overlap_x_start - x2_start + local_j_y_end = overlap_y_end - y2_start + local_j_x_end = overlap_x_end - x2_start + + # CRITICAL: Ensure local coordinates don't exceed 64x64 bounds + # The partition maps are always 64x64, even if the region is larger + local_i_y_start = max(0, min(local_i_y_start, 63)) + local_i_x_start = max(0, min(local_i_x_start, 63)) + local_i_y_end = max(0, min(local_i_y_end, 63)) + local_i_x_end = max(0, min(local_i_x_end, 63)) + + local_j_y_start = max(0, min(local_j_y_start, 63)) + local_j_x_start = max(0, min(local_j_x_start, 63)) + local_j_y_end = max(0, min(local_j_y_end, 63)) + local_j_x_end = max(0, min(local_j_x_end, 63)) + + return { + 'global_slice': (slice(overlap_y_start, overlap_y_end + 1), + slice(overlap_x_start, overlap_x_end + 1)), + 'partition_i_slice': (slice(local_i_y_start, local_i_y_end + 1), + slice(local_i_x_start, local_i_x_end + 1)), + 'partition_j_slice': (slice(local_j_y_start, local_j_y_end + 1), + slice(local_j_x_start, local_j_x_end + 1)), + 'overlap_bounds': (overlap_y_start, overlap_x_start, overlap_y_end, overlap_x_end) + } + + def _compute_overlap_relationships(self): + """ + Compute overlap relationships with corrected coordinate mapping. 
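
+        Worked example (hypothetical regions): for two 64x64 partitions with
+        region_coords (0, 0, 63, 63) and (0, 56, 63, 119), sharing an 8-column
+        strip, this produces
+
+            self.overlap_map == {0: {1}, 1: {0}}
+            info = self.overlap_regions[(0, 1)]
+            # info['partition_i_slice'] == (slice(0, 64), slice(56, 64))
+            # info['partition_j_slice'] == (slice(0, 64), slice(0, 8))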
+ """ + print(f"\n=== COMPUTING OVERLAP RELATIONSHIPS ===") + + self.overlap_map = {i: set() for i in range(len(self.partitions))} + self.overlap_regions = {} + + for i in range(len(self.partitions)): + for j in range(i + 1, len(self.partitions)): + print(f"\nChecking partitions {i} and {j}:") + + if self._partitions_overlap(i, j): + self.overlap_map[i].add(j) + self.overlap_map[j].add(i) + + # Use the overlap calculation + overlap_info = self._calculate_overlap_region(i, j) + if overlap_info is not None: + self.overlap_regions[(i, j)] = overlap_info + self.overlap_regions[(j, i)] = overlap_info # Symmetric + + # Debug: print overlap details + print(f" Overlap found:") + print(f" Global region: {overlap_info['overlap_bounds']}") + print(f" Partition {i} local slice: {overlap_info['partition_i_slice']}") + print(f" Partition {j} local slice: {overlap_info['partition_j_slice']}") + else: + print(f" Could not calculate overlap region!") + else: + print(f" No overlap detected") + + print(f"\n=== FINAL OVERLAP RELATIONSHIPS ===") + for i, partition in enumerate(self.partitions): + overlaps = list(self.overlap_map[i]) + print(f"Partition {i}: region={partition['region_coords']}, overlaps with {overlaps}") + + print(f"Total overlap regions cached: {len(self.overlap_regions)}") + + def _sync_all_partitions_from_global_maps_excluding_traversability_big(self, partition_states): + """ + Synchronize ALL partitions with updated global maps, handling 64x64 partition maps correctly. + """ + + for target_partition_idx, target_partition_state in partition_states.items(): + if target_partition_state['status'] != 'active': + continue + + # Get current state + current_timestep = target_partition_state['timestep'] + current_state = current_timestep.state + + # Create updated world state with global maps but preserve partition-specific targets + updated_world = self._create_world_with_global_maps_preserve_targets_big( + current_state.world, target_partition_idx + ) + + # Create updated state and timestep + updated_state = current_state._replace(world=updated_world) + updated_timestep = current_timestep._replace(state=updated_state) + + # Update the partition state + partition_states[target_partition_idx]['timestep'] = updated_timestep + + def _create_world_with_global_maps_preserve_targets_big(self, current_world, partition_idx): + """ + Create a new world state that uses global maps but correctly handles 64x64 partition size. 
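
+        Padding sketch (hypothetical region on a 128x128 global map): for
+        region_coords = (0, 96, 63, 127) only a (64, 32) slice exists, so it is
+        padded up to 64x64 with per-map defaults, e.g. for the action map:
+
+            region = self.global_maps['action_map'][0:64, 96:128]  # shape (64, 32)
+            padded = jnp.full((64, 64), 0, dtype=region.dtype)
+            padded = padded.at[:64, :32].set(region)  # right half stays default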
+ """ + # Get partition info + partition = self.partitions[partition_idx] + region_coords = partition['region_coords'] + y_start, x_start, y_end, x_end = region_coords + + # Calculate region dimensions + region_height = min(64, y_end - y_start + 1) + region_width = min(64, x_end - x_start + 1) + + #print(f" Creating world for partition {partition_idx}, extracting {region_height}x{region_width} from global") + + # Get the original partition-specific target map + if hasattr(self, 'partition_target_maps') and partition_idx in self.partition_target_maps: + partition_target_map = self.partition_target_maps[partition_idx] + else: + partition_target_map = current_world.target_map.map + + # Extract 64x64 regions from global maps for this partition + extracted_maps = {} + + for map_name in ['action_map', 'dumpability_mask', 'dumpability_mask_init', 'padding_mask']: + global_map = self.global_maps[map_name] + + # Extract the region from global map + extracted_region = global_map[y_start:y_start + region_height, x_start:x_start + region_width] + + # If extracted region is smaller than 64x64, pad it + if extracted_region.shape != (64, 64): + # Create a 64x64 map with appropriate default values + if map_name == 'action_map': + default_val = 0 # Free space + elif 'dumpability' in map_name: + default_val = 0 if 'dumpability_mask' in map_name else 1 # Can't dump / can dump for init + elif map_name == 'padding_mask': + default_val = 1 # Non-traversable padding + else: + default_val = 0 + + padded_map = jnp.full((64, 64), default_val, dtype=extracted_region.dtype) + padded_map = padded_map.at[:extracted_region.shape[0], :extracted_region.shape[1]].set(extracted_region) + extracted_maps[map_name] = padded_map + else: + extracted_maps[map_name] = extracted_region + + # Create updated world with extracted maps + updated_world = current_world._replace( + target_map=current_world.target_map._replace(map=partition_target_map), # Keep partition-specific + action_map=current_world.action_map._replace(map=extracted_maps['action_map']), + dumpability_mask=current_world.dumpability_mask._replace(map=extracted_maps['dumpability_mask']), + dumpability_mask_init=current_world.dumpability_mask_init._replace(map=extracted_maps['dumpability_mask_init']), + padding_mask=current_world.padding_mask._replace(map=extracted_maps['padding_mask']) + # NOTE: traversability_mask is handled separately by agent sync + ) + + return updated_world + + def step_with_full_global_sync_big(self, partition_idx: int, action, partition_states: dict): + """ + Enhanced step function with better error handling for shape mismatches. 
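
+        Usage sketch (big-map variant of step_with_full_global_sync):
+
+            new_timestep = self.step_with_full_global_sync_big(
+                partition_idx=0, action=action, partition_states=partition_states
+            )
+            # On an internal sync error this logs a traceback and falls back to
+            # the freshest available timestep instead of raising.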
+ """ + try: + # Step 1: Sync current positions BEFORE any movement + self._sync_agent_positions_across_partitions(partition_states) + + # Step 2: Update observations so agents see synchronized state + self._update_all_observations(partition_states) + + # Step 3: Take the action with proper obstacle awareness + new_timestep = self.step_simple(partition_idx, action, partition_states) + + # Step 4: Update the partition state + partition_states[partition_idx]['timestep'] = new_timestep + + # Step 5: Extract changes and update global maps + self._update_global_maps_from_single_partition_big(partition_idx, partition_states) + + # Step 6: Sync overlapping regions for big maps + if self.overlap_regions: + self._sync_overlapping_regions_big_maps(partition_idx, partition_states) + + # Step 7: Propagate changes to other partitions + if self.overlap_regions != {}: + self._sync_all_partitions_from_global_maps_excluding_traversability_big(partition_states) + + return new_timestep + + except Exception as e: + print(f" ERROR in step_with_full_global_sync for partition {partition_idx}: {e}") + import traceback + traceback.print_exc() + # Return the timestep even if sync failed + return new_timestep if 'new_timestep' in locals() else partition_states[partition_idx]['timestep'] + + def _update_partition_traversability_with_dumped_soil_and_dig_targets(self, target_partition_idx, target_partition_state, + all_agent_positions, partition_states): + """ + Clean approach to updating traversability that includes: + 1. Original terrain obstacles + 2. Dumped soil as obstacles + 3. Other agents as obstacles + 4. OTHER PARTITIONS' DIG TARGETS (-1) as obstacles (only dig targets, not dump targets) + + Traversability logic: + - 0: Free space (can drive through) + - 1: Obstacles (terrain + other agents + dumped soil + other partitions' dig targets) + - -1: Current agent position + """ + current_timestep = target_partition_state['timestep'] + + # STEP 1: Start from completely clean base mask (original terrain only) + if target_partition_idx in self.base_traversability_masks: + clean_traversability = self.base_traversability_masks[target_partition_idx].copy() + else: + current_mask = current_timestep.state.world.traversability_mask.map + clean_traversability = jnp.where( + (current_mask == -1) | (current_mask == 1), # Remove all agent markers + 0, # Set to free space + current_mask # Keep original terrain + ) + + # STEP 2: Add dumped soil areas as obstacles + action_map = current_timestep.state.world.action_map.map + dumped_areas = (action_map > 0) # Positive values = dumped soil + + # Mark dumped soil areas as obstacles (1) + clean_traversability = jnp.where( + dumped_areas, + 1, # Dumped soil = obstacle + clean_traversability # Keep existing values + ) + + # STEP 3: Add other partitions' DIG TARGETS (-1) as obstacles, but NOT dump targets (1) + other_dig_targets_blocked = 0 + + for other_partition_idx, other_partition_state in partition_states.items(): + if (other_partition_idx == target_partition_idx or + other_partition_state['status'] != 'active'): + continue + + # Get the original target map for the other partition + if hasattr(self, 'partition_target_maps') and other_partition_idx in self.partition_target_maps: + other_target_map = self.partition_target_maps[other_partition_idx] + + # Only block dig targets (-1), NOT dump targets (1) + other_dig_targets = (other_target_map == -1) # Only dig targets + # Note: We don't block dump targets (other_target_map == 1) because agents can potentially traverse dump areas + + # 
Mark dig targets as obstacles in current partition's traversability + clean_traversability = jnp.where( + other_dig_targets, + 1, # Other partitions' dig targets = obstacles + clean_traversability # Keep existing values + ) + + dig_targets_blocked_from_this_partition = jnp.sum(other_dig_targets) + other_dig_targets_blocked += dig_targets_blocked_from_this_partition + + # STEP 4: Add THIS partition's agent positions as agents (-1) + if target_partition_idx in all_agent_positions: + own_positions = all_agent_positions[target_partition_idx] + for cell_y, cell_x in own_positions: + if (0 <= cell_y < clean_traversability.shape[0] and + 0 <= cell_x < clean_traversability.shape[1]): + clean_traversability = clean_traversability.at[cell_y, cell_x].set(-1) + + # STEP 5: Add OTHER agents as OBSTACLES (1), not agents + other_agents_added = 0 + other_cells_added = 0 + + for other_partition_idx, other_positions in all_agent_positions.items(): + if other_partition_idx == target_partition_idx: + continue # Skip own agent + + for cell_y, cell_x in other_positions: + # Check bounds + if (0 <= cell_y < clean_traversability.shape[0] and + 0 <= cell_x < clean_traversability.shape[1]): + + # Add as OBSTACLE (1), not agent (-1) + # Only if it's currently free space (0) - don't overwrite own agent or existing obstacles + current_value = clean_traversability[cell_y, cell_x] + if current_value == 0: # Free space + clean_traversability = clean_traversability.at[cell_y, cell_x].set(1) + other_cells_added += 1 + + if other_cells_added > 0: + other_agents_added += 1 + + # STEP 6: Update the world state + updated_world = self._update_world_map( + current_timestep.state.world, + 'traversability_mask', + clean_traversability + ) + updated_state = current_timestep.state._replace(world=updated_world) + updated_timestep = current_timestep._replace(state=updated_state) + + partition_states[target_partition_idx]['timestep'] = updated_timestep + + def _sync_all_partitions_from_global_maps_excluding_traversability(self, partition_states): + """ + UPDATED: Synchronize ALL partitions with updated global maps, but preserve partition-specific targets. + """ + + for target_partition_idx, target_partition_state in partition_states.items(): + if target_partition_state['status'] != 'active': + continue + + # Get current state + current_timestep = target_partition_state['timestep'] + current_state = current_timestep.state + + # Create updated world state with global maps but preserve partition-specific targets + updated_world = self._create_world_with_global_maps_preserve_targets(current_state.world, target_partition_idx) + + # Create updated state and timestep + updated_state = current_state._replace(world=updated_world) + updated_timestep = current_timestep._replace(state=updated_state) + + # Update the partition state + partition_states[target_partition_idx]['timestep'] = updated_timestep + + def _update_all_observations(self, partition_states): + """ + Update observations for all partitions to match their synced states. 
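
+        Effect sketch (hypothetical partition 0):
+
+            self._update_all_observations(partition_states)
+            obs = partition_states[0]['timestep'].observation
+            # obs['traversability_mask'] now equals
+            # partition_states[0]['timestep'].state.world.traversability_mask.map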
+ """ + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] != 'active': + continue + + current_timestep = partition_state['timestep'] + + # Create updated observation that matches the synced state + updated_observation = self._create_observation_from_synced_state( + current_timestep.observation, + current_timestep.state.world + ) + + # Update the timestep with the new observation + updated_timestep = current_timestep._replace(observation=updated_observation) + partition_states[partition_idx]['timestep'] = updated_timestep + + def _update_global_maps_from_single_partition_small(self, source_partition_idx, partition_states): + """ + Update global maps but handle target_map specially. + Target maps should remain partition-specific and not be fully synchronized. + """ + if source_partition_idx not in partition_states: + return + + source_state = partition_states[source_partition_idx]['timestep'].state + partition = self.partitions[source_partition_idx] + region_coords = partition['region_coords'] + y_start, x_start, y_end, x_end = region_coords + + # Define which maps to update globally (EXCLUDE target_map) + maps_to_update = [ + 'action_map', + 'dumpability_mask', + 'dumpability_mask_init' + ] + + # Update each map in the global storage (EXCLUDING target_map) + for map_name in maps_to_update: + # Get the current map from the partition + partition_map = getattr(source_state.world, map_name).map + + # Extract the region that corresponds to this partition + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + partition_region = partition_map[region_slice] + + # Update the global map with this region + self.global_maps[map_name] = self.global_maps[map_name].at[region_slice].set(partition_region) + + # Handle target_map specially - update global but don't sync back to other partitions + target_map = source_state.world.target_map.map + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + target_region = target_map[region_slice] + + # Update global target map for tracking purposes, but partitions keep their own + self.global_maps['target_map'] = self.global_maps['target_map'].at[region_slice].set(target_region) + + def _create_observation_from_synced_state(self, original_observation, synced_world): + """ + Create an updated observation dictionary that reflects the synced world state. + """ + # Start with the original observation (copy all fields) + updated_observation = {} + for key, value in original_observation.items(): + updated_observation[key] = value + + # Update the critical fields with synced data + updated_observation['traversability_mask'] = synced_world.traversability_mask.map + updated_observation['action_map'] = synced_world.action_map.map + updated_observation['target_map'] = synced_world.target_map.map + updated_observation['dumpability_mask'] = synced_world.dumpability_mask.map + updated_observation['padding_mask'] = synced_world.padding_mask.map + + return updated_observation + + def _sync_agent_positions_across_partitions(self, partition_states): + """ + Properly sync agent positions with dumped soil and dig target blocking only. 
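
+        After this call, each active partition's traversability mask encodes:
+            -1 -> this partition's own agent
+             1 -> terrain, dumped soil, other agents, other partitions' dig targets
+             0 -> free space
+
+        Usage sketch:
+
+            self._sync_agent_positions_across_partitions(partition_states)
+            mask = partition_states[0]['timestep'].state.world.traversability_mask.map
+            # int(jnp.sum(mask == -1)) counts only partition 0's own agent cells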
+ """ + # Ensure base masks are initialized + if not hasattr(self, 'base_traversability_masks'): + #print(" WARNING: Base masks not initialized, initializing now...") + self.initialize_base_traversability_masks(partition_states) + + # Collect all current agent positions + all_agent_positions = {} + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] != 'active': + continue + + current_timestep = partition_state['timestep'] + traversability = current_timestep.state.world.traversability_mask.map + + # Find this partition's agent positions (value = -1) + agent_mask = (traversability == -1) + agent_positions = jnp.where(agent_mask) + + if len(agent_positions[0]) > 0: + occupied_cells = [] + for i in range(len(agent_positions[0])): + cell = (int(agent_positions[0][i]), int(agent_positions[1][i])) + occupied_cells.append(cell) + all_agent_positions[partition_idx] = occupied_cells + + # Update each partition with clean traversability including only dig target obstacles + for target_partition_idx, target_partition_state in partition_states.items(): + if target_partition_state['status'] != 'active': + continue + + self._update_partition_traversability_with_dumped_soil_and_dig_targets( + target_partition_idx, target_partition_state, + all_agent_positions, partition_states + ) + + def initialize_partition_specific_target_maps(self, partition_states): + """ + Store the original partition-specific target maps. + Each partition should only see their own targets, never targets from other partitions. + Call this ONCE after partition initialization. + """ + if not hasattr(self, 'partition_target_maps'): + self.partition_target_maps = {} + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] == 'active': + # Store the original target map for this partition + original_target_map = partition_state['timestep'].state.world.target_map.map.copy() + self.partition_target_maps[partition_idx] = original_target_map + + def _create_world_with_global_maps_preserve_targets(self, current_world, partition_idx): + """ + Create a new world state that uses global maps but preserves partition-specific targets. + """ + # Get the original partition-specific target map + if hasattr(self, 'partition_target_maps') and partition_idx in self.partition_target_maps: + partition_target_map = self.partition_target_maps[partition_idx] + else: + # Fallback to current target map + partition_target_map = current_world.target_map.map + + updated_world = current_world._replace( + target_map=current_world.target_map._replace(map=partition_target_map), # Keep partition-specific + action_map=current_world.action_map._replace(map=self.global_maps['action_map']), + dumpability_mask=current_world.dumpability_mask._replace(map=self.global_maps['dumpability_mask']), + dumpability_mask_init=current_world.dumpability_mask_init._replace(map=self.global_maps['dumpability_mask_init']), + padding_mask=current_world.padding_mask._replace(map=self.global_maps['padding_mask']) + # NOTE: traversability_mask will be handled separately by agent sync + ) + + return updated_world + + def assign_exclusive_targets_in_overlaps(self, partition_states): + """ + Assign targets in overlapping regions exclusively to one partition. + This prevents conflicts and double work. 
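
+        Usage sketch (assumes overlap relationships were computed first):
+
+            self._compute_overlap_relationships()
+            assigned = self.assign_exclusive_targets_in_overlaps(partition_states)
+            # assigned maps (global_y, global_x) -> owning partition index for
+            # every dig target that was visible to more than one partition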
+ """ + print("\n=== ASSIGNING EXCLUSIVE TARGETS IN OVERLAPPING REGIONS ===") + + # Track which targets have been assigned + assigned_targets = {} # (y, x) -> partition_idx + + # First pass: identify all targets in overlapping regions + overlap_targets = {} # (y, x) -> list of partition_idx that can see it + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] != 'active': + continue + + partition = self.partitions[partition_idx] + region_coords = partition['region_coords'] + y_start, x_start, y_end, x_end = region_coords + + # Get this partition's target map + if hasattr(self, 'partition_target_maps') and partition_idx in self.partition_target_maps: + target_map = self.partition_target_maps[partition_idx] + + # Find all dig targets in this partition + dig_targets = jnp.where(target_map == -1) + + for i in range(len(dig_targets[0])): + local_y = int(dig_targets[0][i]) + local_x = int(dig_targets[1][i]) + + # Convert to global coordinates + global_y = y_start + local_y + global_x = x_start + local_x + + # Check if this target is in an overlap region + for other_idx in self.overlap_map.get(partition_idx, set()): + other_partition = self.partitions[other_idx] + other_y_start, other_x_start, other_y_end, other_x_end = other_partition['region_coords'] + + # Check if this global coordinate is within the other partition's region + if (other_y_start <= global_y <= other_y_end and + other_x_start <= global_x <= other_x_end): + # This target is in an overlap region + coord = (global_y, global_x) + if coord not in overlap_targets: + overlap_targets[coord] = [] + overlap_targets[coord].append(partition_idx) + + # Second pass: assign overlapping targets based on strategy + targets_reassigned = 0 + for (global_y, global_x), partition_list in overlap_targets.items(): + if len(partition_list) > 1: + # Multiple partitions can see this target - assign to one + assigned_partition = self._choose_partition_for_target( + global_y, global_x, partition_list, partition_states + ) + assigned_targets[(global_y, global_x)] = assigned_partition + targets_reassigned += 1 + + print(f" Target at ({global_y}, {global_x}) assigned to partition {assigned_partition} " + f"(was visible to: {partition_list})") + + # Third pass: update partition target maps to remove non-assigned targets + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] != 'active': + continue + + partition = self.partitions[partition_idx] + region_coords = partition['region_coords'] + y_start, x_start, y_end, x_end = region_coords + + # Get current target map + if hasattr(self, 'partition_target_maps') and partition_idx in self.partition_target_maps: + target_map = self.partition_target_maps[partition_idx].copy() + modified = False + + # Check each target in this partition + dig_targets = jnp.where(target_map == -1) + for i in range(len(dig_targets[0])): + local_y = int(dig_targets[0][i]) + local_x = int(dig_targets[1][i]) + global_y = y_start + local_y + global_x = x_start + local_x + + # If this target was assigned to another partition, remove it + if ((global_y, global_x) in assigned_targets and + assigned_targets[(global_y, global_x)] != partition_idx): + # Change from dig target (-1) to free space (0) or dump area (1) + target_map = target_map.at[local_y, local_x].set(0) + modified = True + + if modified: + # Update the stored partition target map + self.partition_target_maps[partition_idx] = target_map + + # Update the actual partition's target map + 
current_timestep = partition_state['timestep'] + updated_world = self._update_world_map( + current_timestep.state.world, + 'target_map', + target_map + ) + updated_state = current_timestep.state._replace(world=updated_world) + updated_timestep = current_timestep._replace(state=updated_state) + partition_states[partition_idx]['timestep'] = updated_timestep + + print(f" Total targets reassigned: {targets_reassigned}") + print(f"=== TARGET ASSIGNMENT COMPLETE ===\n") + + return assigned_targets + + def _choose_partition_for_target(self, global_y, global_x, partition_list, partition_states): + """ + Choose which partition should handle a target in an overlapping region. + + Strategies: + 1. Closest agent + 2. Least loaded partition (fewer remaining targets) + 3. First-come-first-served (lowest partition index) + 4. Based on partition efficiency/performance + """ + # Strategy 1: Assign to partition with closest agent + min_distance = float('inf') + chosen_partition = partition_list[0] + + for partition_idx in partition_list: + if partition_states[partition_idx]['status'] != 'active': + continue + + # Get agent position + agent_state = partition_states[partition_idx]['timestep'].state.agent.agent_state + agent_pos = agent_state.pos_base + + # Calculate distance to target + # Note: agent position might need coordinate transformation + partition = self.partitions[partition_idx] + y_start, x_start, _, _ = partition['region_coords'] + + # Convert agent local position to global + agent_global_y = y_start + agent_pos[0] + agent_global_x = x_start + agent_pos[1] + + distance = jnp.sqrt((agent_global_y - global_y)**2 + (agent_global_x - global_x)**2) + + if distance < min_distance: + min_distance = distance + chosen_partition = partition_idx + + return chosen_partition + + def initialize_partition_specific_target_maps_with_exclusive_assignment(self, partition_states): + """ + Enhanced version that assigns exclusive targets after initialization. + """ + # First, do the regular initialization + if not hasattr(self, 'partition_target_maps'): + self.partition_target_maps = {} + + print("Storing partition-specific target maps...") + + for partition_idx, partition_state in partition_states.items(): + if partition_state['status'] == 'active': + # Store the original target map for this partition + original_target_map = partition_state['timestep'].state.world.target_map.map.copy() + self.partition_target_maps[partition_idx] = original_target_map + + # Count targets for verification + dig_targets = jnp.sum(original_target_map == -1) + dump_targets = jnp.sum(original_target_map == 1) + + print(f" Partition {partition_idx}: {dig_targets} dig targets, {dump_targets} dump targets") + + # Then assign exclusive targets in overlapping regions + if self.overlap_regions: + self.assign_exclusive_targets_in_overlaps(partition_states) \ No newline at end of file diff --git a/llm/eval_llm.py b/llm/eval_llm.py new file mode 100644 index 0000000..68dca20 --- /dev/null +++ b/llm/eval_llm.py @@ -0,0 +1,71 @@ +import jax.numpy as jnp +import numpy as np + +def compute_stats_llm(episode_done_once_list, episode_length_list, move_cumsum_list, + do_cumsum_list, areas_list, dig_tiles_per_target_map_init_list, + dug_tiles_per_action_map_list): + """ + Compute statistics from the results of multiple environments. + Args: + episode_done_once_list (list): List of booleans indicating if the episode was done once. + episode_length_list (list): List of integers representing the length of each episode. 
+        move_cumsum_list (list): List of cumulative sums of moves for each environment.
+        do_cumsum_list (list): List of cumulative sums of 'do' actions for each environment.
+        areas_list (list): List of areas for each environment.
+        dig_tiles_per_target_map_init_list (list): List of initial dig tiles per target map.
+        dug_tiles_per_action_map_list (list): List of dug tiles per action map.
+    Returns:
+        None: Prints the computed statistics.
+    """
+
+    episode_done_once = jnp.array(episode_done_once_list)
+    episode_length = jnp.array(episode_length_list)
+    move_cumsum = jnp.array(move_cumsum_list)
+    do_cumsum = jnp.array(do_cumsum_list)
+    areas = jnp.array(areas_list)
+    dig_tiles_per_target_map_init = jnp.array(dig_tiles_per_target_map_init_list)
+    dug_tiles_per_action_map = jnp.array(dug_tiles_per_action_map_list)
+
+    print("\nSummary of results across all environments:")
+    print(f"Episode done once: {episode_done_once}")
+    print(f"Episode length: {episode_length}")
+    print(f"Move cumsum: {move_cumsum}")
+    print(f"Do cumsum: {do_cumsum}")
+    print(f"Areas: {areas}")
+    print(f"Dig tiles per target map init: {dig_tiles_per_target_map_init}")
+    print(f"Dug tiles per action map: {dug_tiles_per_action_map}")
+
+    # Path efficiency -- only include finished envs
+    move_cumsum *= episode_done_once
+    path_efficiency = (move_cumsum / jnp.sqrt(areas))[episode_done_once]
+    path_efficiency_std = path_efficiency.std()
+    path_efficiency_mean = path_efficiency.mean()
+
+    # Workspaces efficiency -- only include finished envs
+    reference_workspace_area = 0.5 * np.pi * (8**2)
+    n_dig_actions = do_cumsum // 2
+    workspaces_efficiency = (
+        reference_workspace_area
+        * ((n_dig_actions * episode_done_once) / areas)[episode_done_once]
+    )
+    workspaces_efficiency_mean = workspaces_efficiency.mean()
+    workspaces_efficiency_std = workspaces_efficiency.std()
+
+    coverage_ratios = dug_tiles_per_action_map / dig_tiles_per_target_map_init
+    coverage_scores = episode_done_once + (~episode_done_once) * coverage_ratios
+    coverage_score_mean = coverage_scores.mean()
+    coverage_score_std = coverage_scores.std()
+
+    completion_rate = 100 * episode_done_once.sum() / len(episode_done_once)
+
+    print("\nStats:\n")
+    print(f"Completion: {completion_rate:.2f}%")
+
+    print(
+        f"Path efficiency: {path_efficiency_mean:.2f} ({path_efficiency_std:.2f})"
+    )
+    print(
+        f"Workspaces efficiency: {workspaces_efficiency_mean:.2f} ({workspaces_efficiency_std:.2f})"
+    )
+    print(f"Coverage: {coverage_score_mean:.2f} ({coverage_score_std:.2f})")
\ No newline at end of file
diff --git a/llm/game/README.md b/llm/game/README.md
new file mode 100644
index 0000000..695165d
--- /dev/null
+++ b/llm/game/README.md
@@ -0,0 +1,35 @@
+# LLM Game - Terra Game AI Player
+
+This module implements an LLM-based AI player for the Terra game, replacing manual keyboard controls with intelligent automated gameplay.
+
+## Code Execution
+The code can be executed with the following command, similarly to the hybrid RL-LLM policy.
+
+```bash
+DATASET_PATH= DATASET_SIZE= python -m llm.game.main_manual_llm --model_name --model_key --num_timesteps
+```
+The parameters are the same as those defined in the `README.md` file in the parent folder.
+
+**Note**: Ensure that the corresponding API keys are properly exported as environment variables before running the command.
+
+## Running Analysis
+To analyze the performance of your LLM game player:
+
+```bash
+python llm/game/analyze_results.py
+``` + +This script will: +- Process the gameplay logs +- Generate cumulative reward plots +- Create final reward visualizations +- Save all outputs to the `analysis/` folder + +## Performance Considerations + +- API Rate Limits: Be mindful of your LLM provider's rate limits +- Cost Management: Monitor API usage to control costs +- Response Time: Consider caching strategies for frequently occurring game states + +## References +This implementation is inspired by the [Atari-GPT](https://github.com/nwayt001/atari-gpt) project, which pioneered the use of LLMs for game playing. \ No newline at end of file diff --git a/llm/game/analyze_results.py b/llm/game/analyze_results.py new file mode 100644 index 0000000..fc3d65f --- /dev/null +++ b/llm/game/analyze_results.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Analyze and visualize results for the AutonomousExcavatorGame tested with different models. +""" + +import os +import json +import argparse +import numpy as np +import matplotlib.pyplot as plt +import csv +from collections import defaultdict +import logging + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("analysis.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger("autonomous-excavator.analysis") + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description='Analyze AutonomousExcavatorGame results') + parser.add_argument('--input_dir', default='./experiments', + help='Directory containing experiment results') + parser.add_argument('--output_dir', default='./analysis', + help='Directory to save analysis results') + parser.add_argument('--models', nargs='+', default=None, + help='Models to include in analysis (default: all models)') + return parser.parse_args() + +def find_experiment_dirs(input_dir, models=None): + """Find experiment directories for the AutonomousExcavatorGame.""" + experiment_dirs = [] + + # List all directories in the input directory + for item in os.listdir(input_dir): + item_path = os.path.join(input_dir, item) + if not os.path.isdir(item_path): + continue + + # Parse the directory name to extract the model name + # Assuming the format is "model_name_YYYY-MM-DD_HH-MM-SS" + parts = item.rsplit("_", 2) # Split from the right to separate timestamp + if len(parts) < 3: + continue # Skip directories that don't match the expected format + + model = "_".join(parts[:-2]) # Extract the model name (everything before the timestamp) + + # Filter by model if specified + if models and model not in models: + continue + + experiment_dirs.append((item_path, model)) + + return experiment_dirs +def load_results(experiment_dir): + """Load results from an experiment directory.""" + results = {} + + # Try to load actions and rewards + csv_path = os.path.join(experiment_dir, 'actions_rewards.csv') + if os.path.exists(csv_path): + actions = [] + rewards = [] + try: + with open(csv_path, 'r') as f: + reader = csv.reader(f) + next(reader) # Skip header + for row in reader: + if len(row) >= 2: + actions.append(int(row[0])) + rewards.append(float(row[1])) + results['actions'] = actions + results['cumulative_rewards'] = rewards + except Exception as e: + logger.error(f"Error loading CSV from {csv_path}: {str(e)}") + + return results + +def analyze_results(experiment_dirs): + """Analyze results from all experiment directories.""" + analysis = defaultdict(list) # Store a list of experiments for each model + + for 
exp_dir, model in experiment_dirs:
+        logger.info(f"Analyzing results for model {model} in directory {exp_dir}")
+        results = load_results(exp_dir)
+
+        if not results:
+            logger.warning(f"No results found in {exp_dir}")
+            continue
+
+        # Extract timestamp from the directory name
+        timestamp = "_".join(exp_dir.split("_")[-2:])  # Extract both date and time
+
+        # Calculate metrics
+        if 'cumulative_rewards' in results and results['cumulative_rewards']:
+            final_reward = results['cumulative_rewards'][-1]
+            max_reward = max(results['cumulative_rewards'])
+
+            analysis[model].append({
+                'final_reward': final_reward,
+                'max_reward': max_reward,
+                'rewards': results['cumulative_rewards'],
+                'timestamp': timestamp
+            })
+
+            logger.info(f"Model {model} ({timestamp}): Final reward = {final_reward}, Max reward = {max_reward}")
+
+    return analysis
+
+def plot_results(analysis, output_dir):
+    """Generate plots from the analysis results."""
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Plot reward curves for each experiment
+    plt.figure(figsize=(12, 6))
+    for model, experiments in analysis.items():
+        for exp in experiments:
+            rewards = exp['rewards']
+            timestamp = exp['timestamp']
+            plt.plot(rewards, label=f"{model} ({timestamp})")
+
+    plt.title('Cumulative Rewards for AutonomousExcavatorGame')
+    plt.xlabel('Steps')
+    plt.ylabel('Cumulative Reward')
+    plt.legend()
+    plt.grid(True)
+    plt.savefig(os.path.join(output_dir, 'cumulative_rewards.png'))
+    plt.close()
+
+    # Create a bar chart of final rewards for all experiments
+    models = []
+    final_rewards = []
+    for model, experiments in analysis.items():
+        for exp in experiments:
+            models.append(f"{model} ({exp['timestamp']})")
+            final_rewards.append(exp['final_reward'])
+
+    plt.figure(figsize=(10, 6))
+    plt.bar(models, final_rewards, color='skyblue')
+    plt.xlabel('Experiments')
+    plt.ylabel('Final Reward')
+    plt.title('Final Rewards by Experiment for AutonomousExcavatorGame')
+    plt.xticks(rotation=45, ha='right')
+    plt.tight_layout()
+    plt.savefig(os.path.join(output_dir, 'final_rewards.png'))
+    plt.close()
+
+    # Save the analysis as JSON
+    with open(os.path.join(output_dir, 'analysis.json'), 'w') as f:
+        # Convert the defaultdict to a plain dict for JSON serialization and
+        # drop the rewards arrays to keep the file small; each model maps to a
+        # list of experiment dicts, so 'rewards' is stripped per experiment
+        analysis_dict = {
+            model: [{k: v for k, v in exp.items() if k != 'rewards'} for exp in experiments]
+            for model, experiments in analysis.items()
+        }
+
+        json.dump(analysis_dict, f, indent=2)
+
+def main():
+    """Main function."""
+    args = parse_args()
+
+    # Find experiment directories
+    experiment_dirs = find_experiment_dirs(args.input_dir, args.models)
+
+    if not experiment_dirs:
+        logger.error(f"No experiment directories found in {args.input_dir}")
+        return
+
+    logger.info(f"Found {len(experiment_dirs)} experiment directories")
+
+    # Analyze results
+    analysis = analyze_results(experiment_dirs)
+
+    # Plot results
+    plot_results(analysis, args.output_dir)
+
+    logger.info(f"Analysis complete. 
Results saved to {args.output_dir}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/llm/game/main_manual_llm.py b/llm/game/main_manual_llm.py new file mode 100644 index 0000000..c10a4fe --- /dev/null +++ b/llm/game/main_manual_llm.py @@ -0,0 +1,345 @@ +import time +import jax +import jax.numpy as jnp +import pygame as pg +import json +import os +from tqdm import tqdm +import csv +import numpy as np +import datetime +import argparse + +from pygame.locals import ( + K_q, + QUIT, +) + +from terra.config import BatchConfig +from terra.config import EnvConfig +from terra.env import TerraEnvBatch +from llm.adk_llm import * + +from llm.utils_llm import * + + +from google.adk.agents import Agent +from google.adk.models.lite_llm import LiteLlm +from google.adk.sessions import InMemorySessionService +from google.adk.runners import Runner +from google.genai import types + +os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "False" + +def run_experiment(llm_model_name, llm_model_key, num_timesteps, seed): + """ + Run an LLM-based simulation experiment. + + Args: + model_name: The name of the LLM model to use. + model_key: The name of the LLM model key to use. + num_timesteps: The number of timesteps to run + + Returns: + None + """ + + f = open('llm/game/prompt.txt', 'r') + system_message = f.read() + f.close() + + batch_cfg = BatchConfig() + action_type = batch_cfg.action_type + n_envs_x = 1 + n_envs_y = 1 + n_envs = n_envs_x * n_envs_y + rng = jax.random.PRNGKey(seed) + env = TerraEnvBatch( + rendering=True, + display=True, + n_envs_x_rendering=n_envs_x, + n_envs_y_rendering=n_envs_y, + ) + + print("Starting the environment...") + start_time = time.time() + env_cfgs = jax.vmap(lambda x: EnvConfig.new())(jnp.arange(n_envs)) + rng, _rng = jax.random.split(rng) + _rng = _rng[None] + timestep = env.reset(env_cfgs, _rng) + + if llm_model_key == "gpt": + model_name_extended = "openai/{}".format(llm_model_name) + elif llm_model_key == "claude": + model_name_extended = "anthropic/{}".format(llm_model_name) + else: + model_name_extended = llm_model_name + + # Initialize the agent + print("Using model: ", model_name_extended) + + if llm_model_key == "gemini": + agent_excavator = Agent( + name="ExcavatorAgent", + model=model_name_extended, + description="You are an excavator agent. You can control the excavator to dig and move.", + instruction=system_message, + ) + else: + agent_excavator = Agent( + name="ExcavatorAgent", + model=LiteLlm(model=model_name_extended), + description="You are an excavator agent. You can control the excavator to dig and move.", + instruction=system_message, + ) + print("Agent initialized.") + + session_service = InMemorySessionService() + + APP_NAME = "ExcavatorGameApp" + USER_ID = "user_1" + SESSION_ID = "session_001" + + session = session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=SESSION_ID, + ) + + print("Session created. 
App: ", APP_NAME, " User ID: ", USER_ID, " Session ID: ", SESSION_ID) + + runner = Runner( + agent=agent_excavator, + app_name=APP_NAME, + session_service=session_service, + ) + + print(f"Runner initialized for agent {runner.agent.name}.") + + llm_query = LLM_query( + model_name=model_name_extended, + model=llm_model_key, + system_message=system_message, + action_size=7, + session_id=SESSION_ID, + runner=runner, + user_id=USER_ID, + ) + + print("LLM query initialized.") + + # Define the repeat_action function + def repeat_action(action, n_times=n_envs): + return action_type.new(action.action[None].repeat(n_times, 0)) + + # Trigger the JIT compilation + timestep = env.step(timestep, repeat_action(action_type.do_nothing()), _rng) + end_time = time.time() + print(f"Environment started. Compilation time: {end_time - start_time} seconds.") + + env.terra_env.render_obs_pygame(timestep.observation, timestep.info) + + playing = True + screen = pg.display.get_surface() + rewards = 0 + cumulative_rewards = [] + action_list = [] + steps_taken = 0 + num_timesteps = num_timesteps + frames = [] + + progress_bar = tqdm(total=num_timesteps, desc="Rollout", unit="steps") + + previous_action = [] + current_map = timestep.state.world.target_map.map[0] # Extract the target map + initial_target_num = jnp.sum(current_map < 0) # Count the initial target pixels + print("Initial target number: ", initial_target_num) + previous_map = current_map.copy() # Initialize the previous map + count_map_change = 0 + DETERMINISTIC = True + + while playing and steps_taken < num_timesteps: + for event in pg.event.get(): + if event.type == QUIT or (event.type == pg.KEYDOWN and event.key == K_q): + playing = False + + current_map = timestep.state.world.target_map.map[0] # Extract the target map + if previous_map is None or not jnp.array_equal(previous_map, current_map): + print("Map changed!") + count_map_change += 1 + initial_target_num = jnp.sum(current_map < 0) # Count the initial target pixels + print("Current target number: ", initial_target_num) + + previous_map = current_map.copy() # Update the previous map + previous_action = [] # Reset the previous action list + llm_query.delete_messages() # Clear previous messages + + game_state_image = capture_screen(screen) + frames.append(game_state_image) + + state = timestep.state + base_orientation = extract_base_orientation(state) + bucket_status = extract_bucket_status(state) # Extract bucket status + + start, target_positions = extract_positions(timestep.state) + nearest_target = find_nearest_target(start, target_positions) + + print(f"Current direction: {base_orientation['direction']}") + print(f"Bucket status: {bucket_status}") + print(f"Current position: {start} (y,x)") + print(f"Nearest target position: {nearest_target} (y,x)") + print(f"Previous action list: {previous_action}") + + usr_msg = ( + f"Analyze this game frame and the provided local map to select the optimal action. " + f"The base of the excavator is currently facing {base_orientation['direction']}. " + f"The bucket is currently {bucket_status}. " + f"The excavator is currently located at {start} (y,x). " + f"The target digging positions are {target_positions} (y,x). " + f"The list of the previous actions is {previous_action}. " + f"Ensure that the excavator base maintains a safe minimum distance (8 to 10 pixels) from the target area to allow proper alignment of the orange area with the purple area for efficient digging. " + f"Avoid moving too close to the purple area to prevent overlap with the base. 
" + f"If the previous action was digging and the bucket is still empty, moving backward can be an appropriate action to reposition. You can then try to dig in the next action. " + f"Focus on immediate gameplay elements visible in this specific frame and the spatial context from the map. " + f"Follow the format: {{\"reasoning\": \"detailed step-by-step analysis\", \"action\": X}}" + ) + + llm_query.add_user_message(frame=game_state_image, user_msg=usr_msg, local_map=None) + + action_output, reasoning = llm_query.generate_response("./") + + print(f"\n Action output: {action_output}, Reasoning: {reasoning}") + + llm_query.add_assistant_message() + + previous_action.append(action_output) + + # Create the action object + action = action_type.new(action_output) + + # Add a batch dimension to the action + action = jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), action) + + # Repeat the action for all environments + batched_action = repeat_action(action) + + # Perform the action in the environment + # rng, _rng = jax.random.split(rng) + # _rng = _rng[None] + # timestep = env.step(timestep, batched_action, _rng) + if DETERMINISTIC: + key = jnp.array([[count_map_change, count_map_change]], dtype=jnp.uint32) # Convert to a JAX array + + timestep = env.step( + timestep, + repeat_action(action), + key, + ) + else: + rng, _rng = jax.random.split(rng) + _rng = _rng[None] + + timestep = env.step( + timestep, + repeat_action(action), + _rng, + ) + + + # Update rewards and actions + print(f"Reward: {timestep.reward.item()}") + rewards += timestep.reward.item() + cumulative_rewards.append(rewards) + action_list.append(action_output) + + # Render the environment + env.terra_env.render_obs_pygame(timestep.observation, timestep.info) + + # if steps_taken % 5 == 1: + # agent.delete_messages() + + # Update progress + steps_taken += 1 + progress_bar.update(1) + progress_bar.set_postfix({"reward": rewards}) + + # Close progress bar + progress_bar.close() + + print(f"Rollout complete. Total reward: {rewards}") + + # Generate a timestamp + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + # Create a unique output directory for the model and timestamp + output_dir = os.path.join("experiments", f"{llm_model_name}_{current_time}") + print(f"Output directory: {output_dir}") + os.makedirs(output_dir, exist_ok=True) + + # Save actions and rewards to a CSV file + output_file = os.path.join(output_dir, "actions_rewards.csv") + with open(output_file, "w") as f: + writer = csv.writer(f) + writer.writerow(["actions", "cumulative_rewards"]) + for action, cum_reward in zip(action_list, cumulative_rewards): + writer.writerow([action, cum_reward]) + + print(f"Results saved to {output_file}") + + # Save the gameplay video + video_path = os.path.join(output_dir, "gameplay.mp4") + save_video(frames, video_path) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run an LLM-based simulation experiment.") + parser.add_argument( + "--model_name", + type=str, + required=True, + choices=["gpt-4o", + "gpt-4.1", + "gpt-5", + "o4-mini", + "o3", + "o3-mini", + "gemini-1.5-flash-latest", + "gemini-2.0-flash", + "gemini-2.5-pro", + "gemini-2.5-flash", + "claude-3-haiku-20240307", + "claude-3-7-sonnet-20250219", + "claude-opus-4-20250514", + "claude-sonnet-4-20250514", + ], + help="Name of the LLM model to use." + ) + parser.add_argument( + "--model_key", + type=str, + required=True, + choices=["gpt", + "gemini", + "claude"], + help="Name of the LLM model key to use." 
+    )
+    parser.add_argument(
+        "--num_timesteps",
+        type=int,
+        default=100,
+        help="Number of timesteps to run."
+    )
+    parser.add_argument(
+        "-s",
+        "--seed",
+        type=int,
+        default=0,
+        help="Random seed for the environment.",
+    )
+    args = parser.parse_args()
+
+    run_experiment(
+        args.model_name,
+        args.model_key,
+        args.num_timesteps,
+        args.seed,
+    )
\ No newline at end of file
diff --git a/llm/game/prompt.txt b/llm/game/prompt.txt
new file mode 100644
index 0000000..d47d8fc
--- /dev/null
+++ b/llm/game/prompt.txt
@@ -0,0 +1,34 @@
+You are an intelligent assistant responsible for selecting the optimal action for an autonomous excavator to efficiently dig and deposit soil.
+
+Inputs:
+(1) Images:
+    (a) Current game state,
+    (b) Target digging profile (positive values),
+(2) Positions:
+    (a) Excavator base (y, x),
+    (b) List of dig target coordinates (y, x); both (a) and (b) use top-down coordinates
+        (X-axis: left to right, Y-axis: top to bottom);
+(3) Past actions (list of previous actions);
+(4) Bucket state (loaded/empty);
+(5) Main base orientation (up, down, left, right, or unknown if it is none of the previous cases).
+
+Critical Constraints:
+(a) The excavator base must maintain a minimum distance from the digging target
+area to allow overlap of the orange area with the purple area;
+(b) Digging needs to be performed only once per target; repeated digging is inefficient;
+(c) The excavated area (marked in blue) is considered an obstacle and is not traversable.
+
+Objectives:
+(1) Ensure proper reach and spacing for efficient digging;
+(2) After digging, rotate the cabin and use action 6 (DO) to deposit soil, ensuring the
+deposit is far from the digging area to avoid obstruction during movements.
+
+Actions: '-1': DO_NOTHING, '0': FORWARD, '1': BACKWARD, '2': CLOCK, '3': ANTICLOCK,
+'4': CABIN_CLOCK, '5': CABIN_ANTICLOCK, '6': DO.
+
+Rules:
+- Avoid black obstacles on the map;
+- The red arrow shows the cabin's direction, and it turns grey when the bucket is loaded;
+- After digging, the base cannot move until the bucket is emptied.
+
+Output: {"reasoning": "Explanation of why this action is optimal.", "action": X}
diff --git a/llm/main_llm.py b/llm/main_llm.py
new file mode 100644
index 0000000..84a2765
--- /dev/null
+++ b/llm/main_llm.py
@@ -0,0 +1,781 @@
+"""
+Partially from https://github.com/RobertTLange/gymnax-blines
+"""
+
+import numpy as np
+import jax
+from utils.helpers import load_pkl_object
+
+import jax.numpy as jnp
+from utils.utils_ppo import obs_to_model_input
+
+from tensorflow_probability.substrates import jax as tfp
+from train import TrainConfig  # needed for unpickling checkpoints
+from terra.config import EnvConfig
+from terra.config import BatchConfig
+
+
+from llm.utils_llm import *
+from llm.adk_llm import *
+from terra.actions import (
+    WheeledAction,
+    TrackedAction,
+    WheeledActionType,
+    TrackedActionType,
+)
+
+import asyncio
+import os
+import argparse
+import datetime
+import json
+import pygame as pg
+
+from pygame.locals import (
+    K_q,
+    QUIT,
+)
+
+from llm.eval_llm import compute_stats_llm
+from llm.env_manager_llm import EnvironmentsManager
+
+os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "False"
+
+def run_experiment(llm_model_name, llm_model_key, num_timesteps, seed,
+                   run, current_experiment_number, env_manager, global_env_config, small_env_config=None):
+    """
+    Run an experiment with completely separate environments for global and small maps.
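+
+    For each map, the flow is: reset the global environment, set up partitions
+    and the master LLM agent, then step all active partitions in lockstep.
+    Every LLM_CALL_FREQUENCY steps the master agent chooses between the RL
+    policy, the LLM policy and, if enabled, a scripted intervention for stuck
+    excavators. Returns an info dict with the benchmark counters.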
+ """ + + (FORCE_DELEGATE_TO_RL, FORCE_DELEGATE_TO_LLM, LLM_CALL_FREQUENCY, + USE_MANUAL_PARTITIONING, MAX_NUM_PARTITIONS, VISUALIZE_PARTITIONS, + USE_IMAGE_PROMPT , APP_NAME, USER_ID, SESSION_ID, + GRID_RENDERING, ORIGINAL_MAP_SIZE, + USE_RENDERING, _, ENABLE_INTERVENTION, INTERVENTION_FREQUENCY, + STUCK_WINDOW, MIN_REWARD, USE_RANDOM_PARTITIONING, + USE_EXACT_NUMBER_OF_PARTITIONS, SAVE_VIDEO, FPS, _, USE_EXCLUSIVE_ASSIGNMENT + ) = setup_experiment_config() + + # Initialize once with proper batching + rng = jax.random.PRNGKey(seed) + rng, _rng = jax.random.split(rng) + rng_reset_initial = jax.random.split(_rng, 1) + + initial_custom_pos = None + initial_custom_angle = None + + # Initial setup + env_manager.global_env.timestep = env_manager.global_env.reset( + global_env_config, rng_reset_initial, initial_custom_pos, initial_custom_angle + ) + + batch_cfg = BatchConfig() + action_type = batch_cfg.action_type + + def repeat_action(action, n_times=1): + return action_type.new(action.action[None].repeat(n_times, 0)) + + # Trigger the JIT compilation + env_manager.global_env.timestep = env_manager.global_env.step( + env_manager.global_env.timestep, repeat_action(action_type.do_nothing()), rng_reset_initial + ) + + if USE_RENDERING: + env_manager.global_env.terra_env.render_obs_pygame( + env_manager.global_env.timestep.observation, env_manager.global_env.timestep.info + ) + + # Initialize variables for tracking progress across all maps + global_step = 0 + playing = True + current_map_index = 0 + max_maps = 10 # Set a reasonable limit for number of maps to process + + # For visualization and metrics across all maps + all_frames = [] + all_reward_seq = [] + all_global_step_rewards = [] + all_obs_seq = [] + all_action_list = [] + + tile_size = global_env_config.tile_size[0].item() + move_tiles = global_env_config.agent.move_tiles[0].item() + + action_type = batch_cfg.action_type + if action_type == TrackedAction: + move_actions = (TrackedActionType.FORWARD, TrackedActionType.BACKWARD) + l_actions = () + do_action = TrackedActionType.DO + elif action_type == WheeledAction: + move_actions = (WheeledActionType.FORWARD, WheeledActionType.BACKWARD) + l_actions = (WheeledActionType.CLOCK, WheeledActionType.ANTICLOCK) + do_action = WheeledActionType.DO + else: + raise (ValueError(f"{action_type=}")) + + obs = env_manager.global_env.timestep.observation + areas = (obs["target_map"] == -1).sum( + tuple([i for i in range(len(obs["target_map"].shape))][1:]) + ) * (tile_size**2) + target_maps_init = obs["target_map"].copy() + dig_tiles_per_target_map_init = (target_maps_init == -1).sum( + tuple([i for i in range(len(target_maps_init.shape))][1:]) + ) + reward_seq = [] + episode_done_once = None + episode_length = None + move_cumsum = None + do_cumsum = None + sub_task_seed = current_experiment_number + + screen = pg.display.get_surface() + + # MAIN LOOP - PROCESS MULTIPLE MAPS + while playing and global_step < num_timesteps and current_map_index < max_maps: + print(f"\n{'='*80}") + print(f"STARTING MAP {current_map_index}") + print(f"{'='*80}") + + # Reset to next map (reusing the same environment) + try: + #if current_map_index > 0: # Don't reset on first map since it's already initialized + reset_to_next_map(current_map_index, seed, env_manager, global_env_config, + initial_custom_pos, initial_custom_angle) + + env_manager.global_env.terra_env.render_obs_pygame( + env_manager.global_env.timestep.observation, + env_manager.global_env.timestep.info + ) + screen = pg.display.get_surface() + game_state_image = 
capture_screen(screen) + + llm_query, runner_delegation, session_manager, prompts = setup_partitions_and_llm( + current_map_index, ORIGINAL_MAP_SIZE, env_manager, + config, llm_model_name, llm_model_key, + APP_NAME, USER_ID, SESSION_ID, screen, + USE_MANUAL_PARTITIONING, USE_IMAGE_PROMPT, MAX_NUM_PARTITIONS,USE_EXACT_NUMBER_OF_PARTITIONS, USE_RANDOM_PARTITIONING, sub_task_seed) + partition_states, partition_models, active_partitions = initialize_partitions_for_current_map(env_manager, config, model_params) + + if USE_EXCLUSIVE_ASSIGNMENT: + env_manager.initialize_partition_specific_target_maps_with_exclusive_assignment(partition_states) + else: + env_manager.initialize_partition_specific_target_maps(partition_states) + + if partition_states is None: + print(f"Failed to initialize map {current_map_index}, moving to next map") + current_map_index += 1 + continue + + except Exception as e: + print(f"Error setting up map {current_map_index}: {e}") + current_map_index += 1 + continue + + # Track metrics for this map + map_frames = [] + map_reward_seq = [] + map_global_step_rewards = [] + map_obs_seq = [] + map_action_list = [] + + # First step delegate to RL agent + llm_decision = "delegate_to_rl" + + # MAP-SPECIFIC GAME LOOP + map_step = 0 + max_steps_per_map = num_timesteps + map_done = False # Track map completion + + while playing and active_partitions and map_step < max_steps_per_map and global_step < num_timesteps: + # Handle quit events + if USE_RENDERING: + for event in pg.event.get(): + if event.type == QUIT or (event.type == pg.KEYDOWN and event.key == K_q): + playing = False + + print(f"\nMap {current_map_index}, Step {map_step} (Global {global_step}) - " + f"Processing {len(active_partitions)} active partitions") + + if USE_RENDERING: + #Capture screen state + screen = pg.display.get_surface() + game_state_image = capture_screen(screen) + + else: + screen = None + game_state_image = None + map_frames.append(game_state_image) + + # Step all active partitions simultaneously + partitions_to_remove = [] + current_step_reward = 0.0 + + for partition_idx in active_partitions: + partition_state = partition_states[partition_idx] + print(f" Processing partition {partition_idx} (partition step {partition_state['step_count']})") + + try: + # Set the small environment to the current partition's state + env_manager.small_env_timestep = partition_state['timestep'] + env_manager.current_partition_idx = partition_idx + + current_observation = partition_state['timestep'].observation + + map_obs_seq.append(current_observation) + + # Extract partition info and create subsurface + partition_info = env_manager.partitions[partition_idx] + region_coords = partition_info['region_coords'] + y_start, x_start, y_end, x_end = region_coords + width = x_end - x_start + 1 + height = y_end - y_start + 1 + + if USE_RENDERING: + subsurface = extract_subsurface(screen, x_start, y_start, width, height, ORIGINAL_MAP_SIZE, global_env_config, partition_idx) + game_state_image_small = capture_screen(subsurface) + else: + game_state_image_small = None + + state = env_manager.small_env_timestep.state + base_orientation = extract_base_orientation(state) + bucket_status = extract_bucket_status(state) + + # LLM decision making + if global_step % LLM_CALL_FREQUENCY == 0 and global_step > 0 and \ + FORCE_DELEGATE_TO_RL is False and \ + FORCE_DELEGATE_TO_LLM is False: + + print(" Calling LLM agent for decision...") + + # Check if intervention is enabled and needed + needs_intervention = False + stuck_info = {'is_stuck': False, 
'reason': 'not_checked'} + + if ENABLE_INTERVENTION: + stuck_info = detect_stuck_excavator( + partition_state, + threshold_steps=STUCK_WINDOW, + min_reward_threshold=MIN_REWARD + ) + needs_intervention = should_intervene( + partition_state, + active_partitions, + intervention_frequency=INTERVENTION_FREQUENCY + ) + + if needs_intervention: + print(f" Partition {partition_idx} appears stuck: {stuck_info['reason']}") + print(f" Details: {stuck_info['details']}") + + try: + obs_dict = {k: v.tolist() for k, v in current_observation.items()} + observation_str = json.dumps(obs_dict) + + except AttributeError: + # Handle the case where current_observation is not a dictionary + observation_str = str(current_observation) + + # Enhanced context with stuck information + base_context = f"Map {current_map_index}, Step {map_step}" + if needs_intervention and ENABLE_INTERVENTION: + stuck_context = f" | STUCK: {stuck_info['reason']} - {stuck_info['details']}" + context = base_context + stuck_context + else: + context = base_context + + if USE_IMAGE_PROMPT: + delegation_prompt = get_delegation_prompt( + prompts, + "See image", + context=context, + ENABLE_INTERVENTION=ENABLE_INTERVENTION + ) + else: + delegation_prompt = get_delegation_prompt( + prompts, + observation_str, + context=context, + ENABLE_INTERVENTION=ENABLE_INTERVENTION + ) + delegation_session_id = f"{SESSION_ID}_map_{current_map_index}_delegation" # This creates "session_001_map_0_delegation" + delegation_user_id = f"{USER_ID}_delegation" # This creates "user_1_delegation" + + try: + if USE_IMAGE_PROMPT: + response = asyncio.run(call_agent_async_master( + delegation_prompt, + game_state_image_small, + runner_delegation, + delegation_user_id, + delegation_session_id, + session_manager + )) + else: + response = asyncio.run(call_agent_async_master( + delegation_prompt, + None, + runner_delegation, + delegation_user_id, + delegation_session_id, + session_manager + )) + + llm_response_text = response + print(f"LLM response: {llm_response_text}") + + if "delegate_to_rl" in llm_response_text.lower(): + llm_decision = "delegate_to_rl" + print("Delegating to RL agent based on LLM response.") + elif "delegate_to_llm" in llm_response_text.lower(): + llm_decision = "delegate_to_llm" + print("Delegating to LLM agent based on LLM response.") + elif "intervention" in llm_response_text.lower(): + llm_decision = "intervention" + print("INTERVENTION mode activated based on LLM response.") + else: + # Default fallback + if needs_intervention: + llm_decision = "intervention" + print("INTERVENTION mode activated due to detected stuck condition.") + else: + llm_decision = "delegate_to_rl" + + except Exception as adk_err: + print(f"Error during ADK agent communication: {adk_err}") + if needs_intervention: + llm_decision = "intervention" + print("INTERVENTION mode activated due to communication error and stuck condition.") + else: + llm_decision = "delegate_to_rl" + + if FORCE_DELEGATE_TO_LLM: + llm_decision = "delegate_to_llm" + elif FORCE_DELEGATE_TO_RL: + llm_decision = "delegate_to_rl" + + # Action selection + if llm_decision == "delegate_to_rl": + print(f" Partition {partition_idx} - Delegating to RL agent") + try: + batched_observation = add_batch_dimension_to_observation(current_observation) + obs = obs_to_model_input(batched_observation, partition_state['prev_actions_rl'], config) + + current_model = partition_models[partition_idx] + _, logits_pi = current_model['model'].apply(current_model['params'], obs) + pi = 
tfp.distributions.Categorical(logits=logits_pi) + + # Use map-specific random key + action_rng = jax.random.PRNGKey(seed + global_step * len(env_manager.partitions) + partition_idx + current_map_index * 10000) + action_rng, action_key, step_key = jax.random.split(action_rng, 3) + action_rl = pi.sample(seed=action_key) + + partition_state['actions'].append(action_rl) + map_action_list.append(action_rl) + + except Exception as rl_error: + print(f" ERROR getting action from RL model for partition {partition_idx}: {rl_error}") + action_rl = jnp.array(0) + partition_state['actions'].append(action_rl) + map_action_list.append(action_rl) + + elif llm_decision == "delegate_to_llm": + print(f" Partition {partition_idx} - Delegating to LLM agent") + + start = env_manager.small_env_timestep.state.agent.agent_state.pos_base + + msg = get_excavator_prompt(prompts, + base_orientation['direction'], + bucket_status, + start) + + llm_query.add_user_message(frame=game_state_image_small, user_msg=msg, local_map=None) + action_output, reasoning = llm_query.generate_response("./") + print(f"\n Action output: {action_output}, Reasoning: {reasoning}") + llm_query.add_assistant_message() + + action_rl = jnp.array([action_output], dtype=jnp.int32) + map_action_list.append(action_rl) + + elif llm_decision == "intervention" and ENABLE_INTERVENTION: + print(f" Partition {partition_idx} - INTERVENTION MODE") + total_interventions += 1 + if partition_idx not in partition_interventions: + partition_interventions[partition_idx] = 0 + partition_interventions[partition_idx] += 1 + + try: + stuck_info = detect_stuck_excavator( + partition_state, + threshold_steps=STUCK_WINDOW, + min_reward_threshold=MIN_REWARD + ) + action_rl = get_intervention_action(partition_state, stuck_info, action_type) + partition_state['actions'].append(action_rl) + map_action_list.append(action_rl) + + # Log intervention details + print(f" Intervention #{total_interventions} for partition {partition_idx}") + print(f" Reason: {stuck_info['reason']} | Action: {action_rl}") + except Exception as intervention_error: + print(f" ERROR during intervention for partition {partition_idx}: {intervention_error}") + # Fallback to a safe action + action_rl = jnp.array([0], dtype=jnp.int32) # Forward movement + partition_state['actions'].append(action_rl) + map_action_list.append(action_rl) + + else: + print(" Master Agent stop.") + action_rl = jnp.array([-1], dtype=jnp.int32) + map_action_list.append(action_rl) + + # Clear LLM messages periodically + if len(llm_query.messages) > 3: + llm_query.delete_messages() + + # Update action history and step environment + partition_state['prev_actions_rl'] = jnp.roll(partition_state['prev_actions_rl'], shift=1, axis=1) + partition_state['prev_actions_rl'] = partition_state['prev_actions_rl'].at[:, 0].set(action_rl) + + wrapped_action = wrap_action_llm(action_rl, action_type) + + # Take step with full sync + if ORIGINAL_MAP_SIZE == 64: + new_timestep = env_manager.step_with_full_global_sync(partition_idx, wrapped_action, partition_states) + else: + new_timestep = env_manager.step_with_full_global_sync_big(partition_idx, wrapped_action, partition_states) + partition_states[partition_idx]['timestep'] = new_timestep + partition_state['step_count'] += 1 + + # Process reward + reward = new_timestep.reward + if isinstance(reward, jnp.ndarray): + if reward.shape == (): + reward_val = float(reward) + elif len(reward.shape) > 0: + reward_val = float(reward.flatten()[0]) + else: + reward_val = float(reward) + else: + reward_val = 
float(reward) + + if not (jnp.isnan(reward_val) or jnp.isinf(reward_val)): + partition_state['rewards'].append(reward_val) + partition_state['total_reward'] += reward_val + map_reward_seq.append(reward_val) + current_step_reward += reward_val + print(f" Partition {partition_idx} - reward: {reward_val:.4f}, action: {action_rl}, done: {new_timestep.done}") + else: + print(f" Partition {partition_idx} - INVALID reward: {reward_val}, action: {action_rl}, done: {new_timestep.done}") + + # Check completion conditions + partition_completed = False + + if env_manager.is_small_task_completed(): + print(f" Partition {partition_idx} COMPLETED after {partition_state['step_count']} steps!") + print(f" Total reward for partition {partition_idx}: {partition_state['total_reward']:.4f}") + env_manager.partitions[partition_idx]['status'] = 'completed' + partition_state['status'] = 'completed' + partition_completed = True + + elif partition_state['step_count'] >= max_steps_per_map: + print(f" Partition {partition_idx} TIMED OUT") + env_manager.partitions[partition_idx]['status'] = 'failed' + partition_state['status'] = 'failed' + partition_completed = True + + elif jnp.isnan(reward): + print(f" Partition {partition_idx} FAILED due to NaN reward") + env_manager.partitions[partition_idx]['status'] = 'failed' + partition_state['status'] = 'failed' + partition_completed = True + + if partition_completed: + partitions_to_remove.append(partition_idx) + + except Exception as e: + print(f" ERROR stepping partition {partition_idx}: {e}") + if partition_idx < len(env_manager.partitions): + env_manager.partitions[partition_idx]['status'] = 'failed' + partition_state['status'] = 'failed' + partitions_to_remove.append(partition_idx) + + + # Remove completed/failed partitions + for partition_idx in partitions_to_remove: + if partition_idx in active_partitions: + active_partitions.remove(partition_idx) + print(f" Removed partition {partition_idx} from active list") + + print(f" Remaining active partitions: {active_partitions}") + map_global_step_rewards.append(current_step_reward) + print(f" Map {current_map_index} step {map_step} reward: {current_step_reward:.4f}") + + # Render + if GRID_RENDERING: + env_manager.render_all_partition_views_grid(partition_states) + else: + env_manager.render_global_environment_with_multiple_agents(partition_states, VISUALIZE_PARTITIONS) + + # After processing all partitions, check if map is complete + map_metrics = calculate_map_completion_metrics(partition_states) + map_done = map_metrics['done'] + + # Update done flag for this step + done = jnp.array(map_done) # Convert to JAX array for consistency + + map_step += 1 + global_step += 1 + + reward_seq.append(current_step_reward) + + if episode_done_once is None: + episode_done_once = done + if episode_length is None: + episode_length = jnp.zeros_like(done, dtype=jnp.int32) + if move_cumsum is None: + move_cumsum = jnp.zeros_like(done, dtype=jnp.int32) + if do_cumsum is None: + do_cumsum = jnp.zeros_like(done, dtype=jnp.int32) + + episode_done_once = episode_done_once | done + episode_length += ~episode_done_once + + move_cumsum_tmp = jnp.zeros_like(done, dtype=jnp.int32) + for move_action in move_actions: + move_mask = (action_rl == move_action) * (~episode_done_once) + move_cumsum_tmp += move_tiles * tile_size * move_mask.astype(jnp.int32) + for l_action in l_actions: + l_mask = (action_rl == l_action) * (~episode_done_once) + move_cumsum_tmp += 2 * move_tiles * tile_size * l_mask.astype(jnp.int32) + move_cumsum += move_cumsum_tmp + + 
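+        # do_cumsum counts DO actions issued before the episode first
+        # completes, mirroring the distance bookkeeping above: e.g. with
+        # move_tiles=2 and tile_size=16 (illustrative values only), one
+        # FORWARD adds 32 to move_cumsum and one wheeled base turn adds 64.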
do_cumsum += (action_rl == do_action) * (~episode_done_once) + + dug_tiles_per_action_map = (env_manager.global_maps['action_map'] == -1).sum() + + # Add map data to global collections + all_frames.extend(map_frames) + all_reward_seq.extend(map_reward_seq) + all_global_step_rewards.extend(map_global_step_rewards) + all_obs_seq.extend(map_obs_seq) + all_action_list.extend(map_action_list) + + # Move to next map + current_map_index += 1 + + # Check if we should continue + if not playing or global_step >= num_timesteps: + break + + print(f"\nTransitioning to map {current_map_index}...") + + # FINAL SUMMARY ACROSS ALL MAPS + print(f"\n{'='*80}") + print(f"EXPERIMENT COMPLETED - PROCESSED {current_map_index} MAPS") + print(f"{'='*80}") + + if SAVE_VIDEO: + # Save results + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + safe_model_name = llm_model_name.replace('/', '_') + output_dir = os.path.join("experiments", f"{safe_model_name}_{current_time}") + os.makedirs(output_dir, exist_ok=True) + + # Save video + video_path = os.path.join(output_dir, "gameplay_all_maps.mp4") + save_video(all_frames, video_path, FPS) + + print(f"\nResults saved to: {output_dir}") + + info = { + "episode_done_once": episode_done_once, + "episode_length": episode_length, + "move_cumsum": move_cumsum, + "do_cumsum": do_cumsum, + "areas": areas, + "dig_tiles_per_target_map_init": dig_tiles_per_target_map_init, + "dug_tiles_per_action_map": dug_tiles_per_action_map, + } + + # Print intervention statistics + if ENABLE_INTERVENTION and total_interventions > 0: # ← USED HERE + print(f"\nπŸ”§ INTERVENTION STATISTICS:") + print(f" Total interventions: {total_interventions}") + print(f" Interventions per partition: {partition_interventions}") + print(f" Intervention rate: {total_interventions/global_step:.1%}") + + # Save intervention stats + info["total_interventions"] = total_interventions + info["partition_interventions"] = partition_interventions + + return info + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run an LLM-based simulation experiment with RL agents.") + parser.add_argument( + "--model_name", + type=str, + required=True, + choices=["gpt-4o", + "gpt-4.1", + "gpt-5", + "o4-mini", + "o3", + "o3-mini", + "gemini-1.5-flash-latest", + "gemini-2.0-flash", + "gemini-2.5-pro", + "gemini-2.5-flash", + "claude-3-haiku-20240307", + "claude-3-7-sonnet-20250219", + "claude-opus-4-20250514", + "claude-sonnet-4-20250514", + ], + help="Name of the LLM model to use." + ) + parser.add_argument( + "--model_key", + type=str, + required=True, + choices=["gpt", + "gemini", + "claude"], + help="Name of the LLM model key to use." + ) + parser.add_argument( + "--num_timesteps", + type=int, + default=100, + help="Number of timesteps to run." + ) + parser.add_argument( + "-n", + "--n_envs", + type=int, + default=1, + help="Number of environments", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=0, + help="Random seed for the environment.", + ) + + parser.add_argument( + "-run", + "--run_name", + type=str, + default="/home/gioelemo/Documents/terra/tracked-dense.pkl", + help="Policy to use for the experiment. Must be a valid path to a .pkl file containing the policy.", + ) + + parser.add_argument( + "--level_index", + type=int, + default=None, + help="Index of the level to run from CurriculumGlobalConfig.levels. If None, runs all levels." 
+ ) + + + args = parser.parse_args() + NUM_ENVS = args.n_envs + + if args.level_index is not None: + import os + os.environ['TERRA_LEVEL_INDEX'] = str(args.level_index) + + episode_done_once_list = [] + episode_length_list = [] + move_cumsum_list = [] + do_cumsum_list = [] + areas_list = [] + dig_tiles_per_target_map_init_list = [] + dug_tiles_per_action_map_list = [] + + base_seed = args.seed + + (FORCE_DELEGATE_TO_RL, FORCE_DELEGATE_TO_LLM, LLM_CALL_FREQUENCY, + USE_MANUAL_PARTITIONING, MAX_NUM_PARTITIONS, VISUALIZE_PARTITIONS, + USE_IMAGE_PROMPT , APP_NAME, USER_ID, SESSION_ID, + GRID_RENDERING, ORIGINAL_MAP_SIZE, + USE_RENDERING, USE_DISPLAY, ENABLE_INTERVENTION, INTERVENTION_FREQUENCY, + STUCK_WINDOW, MIN_REWARD, USE_RANDOM_PARTITIONING, + USE_EXACT_NUMBER_OF_PARTITIONS, SAVE_VIDEO, FPS, COMPUTE_BENCH_STATS, _ + ) = setup_experiment_config() + + # Track intervention statistics + total_interventions = 0 + partition_interventions = {} + + agent_checkpoint_path = args.run_name + model_params = None + config = None + + print(f"Loading RL agent configuration from: {agent_checkpoint_path}") + log = load_pkl_object(agent_checkpoint_path) + config = log["train_config"] + model_params = log["model"] + + # Create the original environment configs for the full map + global_env_config = jax.tree_map( + lambda x: x[0][None, ...].repeat(1, 0), log["env_config"] + ) + + config.num_test_rollouts = 1 + config.num_devices = 1 + config.num_embeddings_agent_min = 60 + + # Initialize the environment manager ONCE with all maps + print("Initializing environment manager with all maps...") + env_manager = EnvironmentsManager( + seed=base_seed, + global_env_config=global_env_config, + small_env_config=None, + shuffle_maps=False, + rendering=USE_RENDERING, + display=USE_DISPLAY, + size=ORIGINAL_MAP_SIZE, + ) + print("Environment manager initialized.") + + for i in range(NUM_ENVS): + print(f"Running experiment {i+1}/{NUM_ENVS} with args: {args}") + global_env_config = jax.tree_map( + lambda x: x[0][None, ...].repeat(1, 0), log["env_config"] + ) + info = run_experiment( + args.model_name, + args.model_key, + args.num_timesteps, + base_seed + i * 1000, # Ensure different seeds + #base_seed, + args.run_name, + i+1, + env_manager, global_env_config + ) + # Collect results from this experiment + episode_done_once = info["episode_done_once"] + episode_length = info["episode_length"] + move_cumsum = info["move_cumsum"] + do_cumsum = info["do_cumsum"] + areas = info["areas"] + dig_tiles_per_target_map_init = info["dig_tiles_per_target_map_init"] + dug_tiles_per_action_map = info["dug_tiles_per_action_map"] + + episode_done_once_list.append(episode_done_once.item()) + episode_length_list.append(episode_length.item()) + move_cumsum_list.append(move_cumsum.item()) + do_cumsum_list.append(do_cumsum.item()) + areas_list.append(areas.item()) + dig_tiles_per_target_map_init_list.append(dig_tiles_per_target_map_init.item()) + dug_tiles_per_action_map_list.append(dug_tiles_per_action_map.item()) + + print("\nExperiment completed for all environments.") + if COMPUTE_BENCH_STATS: + compute_stats_llm(episode_done_once_list, episode_length_list, move_cumsum_list, + do_cumsum_list, areas_list, dig_tiles_per_target_map_init_list, + dug_tiles_per_action_map_list) + diff --git a/llm/main_llm_random.py b/llm/main_llm_random.py new file mode 100644 index 0000000..166cb12 --- /dev/null +++ b/llm/main_llm_random.py @@ -0,0 +1,998 @@ +""" +Partially from https://github.com/RobertTLange/gymnax-blines +""" + +import numpy as np +import jax 
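+# Parallel variant of main_llm.py: each job processes a single map (selected
+# via --map_idx) and runs --n_partitions_per_map random partitioning trials on
+# that same map, keeping the trial with the best coverage. A sketch of a
+# single-map job (hypothetical paths and values):
+#
+#   python -m llm.main_llm_random --model_name gpt-4o --model_key gpt \
+#       --num_timesteps 200 --map_idx 3 --n_maps 10 \
+#       --n_partitions_per_map 20 -s 0 -run /path/to/tracked-dense.pkl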
+from utils.helpers import load_pkl_object + +import jax.numpy as jnp +from utils.utils_ppo import obs_to_model_input + +from tensorflow_probability.substrates import jax as tfp +from train import TrainConfig # needed for unpickling checkpoints +from terra.config import EnvConfig +from terra.config import BatchConfig + + +from llm.utils_llm import * +from llm.adk_llm import * +from terra.actions import ( + WheeledAction, + TrackedAction, + WheeledActionType, + TrackedActionType, +) + +import asyncio +import os +import argparse +import datetime +import json +import pygame as pg + +from pygame.locals import ( + K_q, + QUIT, +) + +from llm.eval_llm import compute_stats_llm +from llm.env_manager_llm import EnvironmentsManager + +os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "False" + +def reset_to_same_map(map_index, seed, env_manager, global_env_config, + initial_custom_pos=None, initial_custom_angle=None): + """Reset the existing environment to the SAME map for different partition trials""" + print(f"\n{'='*60}") + print(f"RESETTING TO SAME MAP {map_index} (Different Partition Trial)") + print(f"{'='*60}") + + # Create SAME seed for this map to ensure we get the same map layout + # Use the map_index to determine the map, not the trial number + map_seed = seed + map_index * 1000 # This stays constant for all trials of the same map + map_rng = jax.random.PRNGKey(map_seed) + map_rng, reset_rng = jax.random.split(map_rng) + reset_keys = jax.random.split(reset_rng, 1) + + # Reset the existing environment to get the SAME map + env_manager.global_env.timestep = env_manager.global_env.reset( + global_env_config, reset_keys, initial_custom_pos, initial_custom_angle + ) + + # Extract and store the map data (this should be the same for all trials) + new_timestep = env_manager.global_env.timestep + env_manager.global_maps['target_map'] = new_timestep.state.world.target_map.map[0].copy() + env_manager.global_maps['action_map'] = new_timestep.state.world.action_map.map[0].copy() + env_manager.global_maps['dumpability_mask'] = new_timestep.state.world.dumpability_mask.map[0].copy() + env_manager.global_maps['dumpability_mask_init'] = new_timestep.state.world.dumpability_mask_init.map[0].copy() + env_manager.global_maps['padding_mask'] = new_timestep.state.world.padding_mask.map[0].copy() + env_manager.global_maps['traversability_mask'] = new_timestep.state.world.traversability_mask.map[0].copy() + env_manager.global_maps['trench_axes'] = new_timestep.state.world.trench_axes.copy() + env_manager.global_maps['trench_type'] = new_timestep.state.world.trench_type.copy() + + # Store the global timestep + env_manager.global_timestep = new_timestep + + # Clear any existing partitions data to ensure fresh start for new partition strategy + env_manager.partitions = [] + env_manager.overlap_map = {} + env_manager.overlap_regions = {} + + print(f"Environment reset to SAME map {map_index} for new partition trial") + print(f"Target map has {jnp.sum(env_manager.global_maps['target_map'] < 0)} dig targets") + +def run_single_map_corrected(map_idx, args, config, model_params, global_env_config): + """ + Run all partition trials for a single map and return the best result. + CORRECTED VERSION: Uses the same map for all partition trials. 
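+    Returns the trial_result dict of the best trial, where coverage is
+    dug tiles / initial dig targets, and writes a per-map JSON summary to
+    results_parallel/<model_name>/map_XXXX_results.json. Trials stop early
+    once one of them reaches full coverage.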
+ """ + print(f"\n{'='*80}") + print(f"PROCESSING MAP {map_idx} (Job for map index {map_idx})") + print(f"TESTING DIFFERENT PARTITIONS ON THE SAME MAP") + print(f"{'='*80}") + + # Lists to store all partition results for this map + map_results = [] + + # Initialize global variables for intervention tracking + global total_interventions, partition_interventions + total_interventions = 0 + partition_interventions = {} + + # Track best coverage found so far + best_coverage_so_far = 0.0 + early_stop_achieved = False + + # Initialize environment manager ONCE for this map + # Calculate seed that will give us the specific map we want + map_seed = args.seed + map_idx * 50000 + + env_manager = EnvironmentsManager( + seed=map_seed, # This determines which map we get + global_env_config=global_env_config, + small_env_config=None, + shuffle_maps=False, + rendering=args.use_rendering, + display=args.use_display + ) + + # Create environment config for this map + map_env_config = jax.tree_map( + lambda x: x[0][None, ...].repeat(1, 0), global_env_config + ) + + # Initialize the map ONCE + print(f"Initializing map {map_idx} with seed {map_seed}") + reset_to_same_map(map_idx, args.seed, env_manager, map_env_config) + + # Store the initial map state to verify it's the same across trials + initial_target_map = env_manager.global_maps['target_map'].copy() + total_dig_targets = jnp.sum(initial_target_map == -1).item() + print(f"Map {map_idx} has {total_dig_targets} dig targets") + + # Run multiple partition experiments on the SAME map + for partition_trial in range(args.n_partitions_per_map): + print(f"\nMap {map_idx}, Partition Trial {partition_trial + 1}/{args.n_partitions_per_map}") + print(f"Using SAME map with DIFFERENT random partitioning strategy") + + try: + # Reset to the SAME map (this should give us identical terrain) + if partition_trial > 0: # Don't reset on first trial since we just initialized + reset_to_same_map(map_idx, args.seed, env_manager, map_env_config) + + # Verify we got the same map + current_target_map = env_manager.global_maps['target_map'] + if not jnp.array_equal(initial_target_map, current_target_map): + print("WARNING: Map changed between trials! 
This shouldn't happen.") + print(f"Initial dig targets: {jnp.sum(initial_target_map == -1)}") + print(f"Current dig targets: {jnp.sum(current_target_map == -1)}") + else: + print("βœ“ Confirmed: Using same map as previous trials") + + # Make config and model_params available globally for the function + globals()['config'] = config + globals()['model_params'] = model_params + + # Reset intervention counters for this trial + total_interventions = 0 + partition_interventions = {} + + # Use different seed for partitioning randomness only + # This affects partition generation but NOT the map itself + partition_seed = map_seed + partition_trial * 1000 + + # Temporarily modify the environment manager's seed for partition generation + original_seed = env_manager.seed if hasattr(env_manager, 'seed') else None + if hasattr(env_manager, 'seed'): + env_manager.seed = partition_seed + + # Run experiment with the SAME map but DIFFERENT partitioning + info = run_experiment_single_map_trial( + args.model_name, # llm_model_name + args.model_key, # llm_model_key + args.num_timesteps, # num_timesteps + partition_seed, # seed for partitioning randomness + args.run_name, # run + partition_trial + 1, # current_experiment_number (affects partition seed) + env_manager, # env_manager (contains the same map) + map_env_config, # global_env_config + map_idx # map_idx to track which map we're on + ) + + # Restore original seed if it existed + if original_seed is not None and hasattr(env_manager, 'seed'): + env_manager.seed = original_seed + + # Calculate coverage (dug tiles / total target tiles) + total_target_tiles = info["dig_tiles_per_target_map_init"].item() + dug_tiles = info["dug_tiles_per_action_map"].item() + coverage = dug_tiles / total_target_tiles if total_target_tiles > 0 else 0.0 + + # Verify the map didn't change during execution + if total_target_tiles != total_dig_targets: + print(f"WARNING: Target tile count changed from {total_dig_targets} to {total_target_tiles}") + + # Store results for this partition trial + trial_result = { + 'map_idx': map_idx, + 'trial_idx': partition_trial, + 'coverage': coverage, + 'episode_done_once': info["episode_done_once"].item(), + 'episode_length': info["episode_length"].item(), + 'move_cumsum': info["move_cumsum"].item(), + 'do_cumsum': info["do_cumsum"].item(), + 'areas': info["areas"].item(), + 'dig_tiles_per_target_map_init': total_target_tiles, + 'dug_tiles_per_action_map': dug_tiles, + 'total_interventions': info.get('total_interventions', 0), + 'partition_interventions': info.get('partition_interventions', {}), + 'partition_seed_used': partition_seed + } + + map_results.append(trial_result) + + print(f" Trial {partition_trial + 1} coverage: {coverage:.4f}") + + # Update best coverage tracking + if coverage > best_coverage_so_far: + best_coverage_so_far = coverage + + # Optional: Early stopping if full coverage achieved + if coverage >= 1.0: # 100% coverage achieved + early_stop_achieved = True + print(f" πŸŽ‰ FULL COVERAGE ACHIEVED! 
Early stopping after trial {partition_trial + 1}") + break + + except Exception as e: + print(f" ERROR in trial {partition_trial + 1}: {e}") + import traceback + traceback.print_exc() + + # Add default values for failed trials + trial_result = { + 'map_idx': map_idx, + 'trial_idx': partition_trial, + 'coverage': 0.0, + 'episode_done_once': False, + 'episode_length': 0, + 'move_cumsum': 0, + 'do_cumsum': 0, + 'areas': 0, + 'dig_tiles_per_target_map_init': total_dig_targets, + 'dug_tiles_per_action_map': 0, + 'total_interventions': 0, + 'partition_interventions': {}, + 'partition_seed_used': partition_seed if 'partition_seed' in locals() else 0 + } + map_results.append(trial_result) + + # Clean up environment manager + del env_manager + + # Find the best partition for this map + if map_results: + best_result = max(map_results, key=lambda x: x['coverage']) + best_coverage = best_result['coverage'] + best_trial_idx = best_result['trial_idx'] + + print(f"\nBest partition strategy for map {map_idx}:") + print(f" Trial: {best_trial_idx}") + print(f" Coverage: {best_coverage:.4f}") + print(f" Episode done: {best_result['episode_done_once']}") + print(f" Episode length: {best_result['episode_length']}") + + # Print summary of all trials for this map + coverages = [r['coverage'] for r in map_results] + print(f" Trials completed: {len(map_results)}/{args.n_partitions_per_map}") + print(f" All trials coverage - Mean: {np.mean(coverages):.4f}, Std: {np.std(coverages):.4f}") + print(f" All trials coverage - Min: {np.min(coverages):.4f}, Max: {np.max(coverages):.4f}") + + # Save results for this map + map_results_data = { + 'map_idx': map_idx, + 'map_seed_used': map_seed, + 'total_dig_targets': total_dig_targets, + 'best_result': best_result, + 'all_results': map_results, + 'early_stop_achieved': early_stop_achieved, + 'trials_completed': len(map_results), + 'trials_planned': args.n_partitions_per_map, + 'summary_stats': { + 'mean_coverage': float(np.mean(coverages)), + 'std_coverage': float(np.std(coverages)), + 'min_coverage': float(np.min(coverages)), + 'max_coverage': float(np.max(coverages)) + } + } + + # Save individual map results + results_dir = f"results_parallel/{args.model_name.replace('/', '_')}" + os.makedirs(results_dir, exist_ok=True) + + map_filename = os.path.join(results_dir, f"map_{map_idx:04d}_results.json") + with open(map_filename, 'w') as f: + json.dump(map_results_data, f, indent=2) + + print(f"Results for map {map_idx} saved to: {map_filename}") + + return best_result + else: + print(f"No valid results for map {map_idx}") + return None + +def run_experiment_single_map_trial(llm_model_name, llm_model_key, num_timesteps, seed, + run, current_experiment_number, env_manager, global_env_config, map_idx): + """ + Modified version of run_experiment that processes only ONE partition trial on a single map. + This replaces the multi-map loop with a single-map, single-trial approach. 
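+    The map layout itself is fixed by the caller; the seed passed here only
+    drives the partition generation and the RL action sampling, so successive
+    trials compare partitioning strategies on identical terrain.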
+ """ + + (FORCE_DELEGATE_TO_RL, FORCE_DELEGATE_TO_LLM, LLM_CALL_FREQUENCY, + USE_MANUAL_PARTITIONING, MAX_NUM_PARTITIONS, VISUALIZE_PARTITIONS, + USE_IMAGE_PROMPT , APP_NAME, USER_ID, SESSION_ID, + GRID_RENDERING, ORIGINAL_MAP_SIZE, + USE_RENDERING, _, ENABLE_INTERVENTION, INTERVENTION_FREQUENCY, + STUCK_WINDOW, MIN_REWARD, USE_RANDOM_PARTITIONING, + USE_EXACT_NUMBER_OF_PARTITIONS, SAVE_VIDEO, FPS, _ + ) = setup_experiment_config() + + # Initialize global variables for intervention tracking + global total_interventions, partition_interventions + total_interventions = 0 + partition_interventions = {} + + # Initialize once with proper batching + rng = jax.random.PRNGKey(seed) + rng, _rng = jax.random.split(rng) + rng_reset_initial = jax.random.split(_rng, 1) + + initial_custom_pos = None + initial_custom_angle = None + + batch_cfg = BatchConfig() + action_type = batch_cfg.action_type + + def repeat_action(action, n_times=1): + return action_type.new(action.action[None].repeat(n_times, 0)) + + # Trigger the JIT compilation + env_manager.global_env.timestep = env_manager.global_env.step( + env_manager.global_env.timestep, repeat_action(action_type.do_nothing()), rng_reset_initial + ) + + if USE_RENDERING: + env_manager.global_env.terra_env.render_obs_pygame( + env_manager.global_env.timestep.observation, env_manager.global_env.timestep.info + ) + + # Initialize variables for tracking this single trial + global_step = 0 + playing = True + + # For visualization and metrics + all_frames = [] + all_reward_seq = [] + all_global_step_rewards = [] + all_obs_seq = [] + all_action_list = [] + + tile_size = global_env_config.tile_size[0].item() + move_tiles = global_env_config.agent.move_tiles[0].item() + + if action_type == TrackedAction: + move_actions = (TrackedActionType.FORWARD, TrackedActionType.BACKWARD) + l_actions = () + do_action = TrackedActionType.DO + elif action_type == WheeledAction: + move_actions = (WheeledActionType.FORWARD, WheeledActionType.BACKWARD) + l_actions = (WheeledActionType.CLOCK, WheeledActionType.ANTICLOCK) + do_action = WheeledActionType.DO + else: + raise ValueError(f"{action_type=}") + + obs = env_manager.global_env.timestep.observation + areas = (obs["target_map"] == -1).sum( + tuple([i for i in range(len(obs["target_map"].shape))][1:]) + ) * (tile_size**2) + target_maps_init = obs["target_map"].copy() + dig_tiles_per_target_map_init = (target_maps_init == -1).sum( + tuple([i for i in range(len(target_maps_init.shape))][1:]) + ) + reward_seq = [] + episode_done_once = None + episode_length = None + move_cumsum = None + do_cumsum = None + + # Use the current_experiment_number as partition trial seed modifier + sub_task_seed = current_experiment_number + + screen = pg.display.get_surface() if USE_RENDERING else None + + print(f"\n{'='*80}") + print(f"RUNNING PARTITION TRIAL {current_experiment_number} ON MAP {map_idx}") + print(f"{'='*80}") + + try: + # Render initial state + if USE_RENDERING: + env_manager.global_env.terra_env.render_obs_pygame( + env_manager.global_env.timestep.observation, + env_manager.global_env.timestep.info + ) + screen = pg.display.get_surface() + game_state_image = capture_screen(screen) + + # Setup partitions and LLM for this trial (this is where randomness comes in) + llm_query, runner_delegation, session_manager, prompts = setup_partitions_and_llm( + map_idx, ORIGINAL_MAP_SIZE, env_manager, + config, llm_model_name, llm_model_key, + APP_NAME, USER_ID, f"{SESSION_ID}_map_{map_idx}_trial_{current_experiment_number}", + screen, 
USE_MANUAL_PARTITIONING, USE_IMAGE_PROMPT, MAX_NUM_PARTITIONS, + USE_EXACT_NUMBER_OF_PARTITIONS, USE_RANDOM_PARTITIONING, sub_task_seed) + + partition_states, partition_models, active_partitions = initialize_partitions_for_current_map( + env_manager, config, model_params) + + env_manager.initialize_partition_specific_target_maps(partition_states) + + if partition_states is None: + raise Exception(f"Failed to initialize partitions for trial {current_experiment_number}") + + except Exception as e: + print(f"Error setting up trial {current_experiment_number}: {e}") + raise + + # Track metrics for this trial + trial_frames = [] + trial_reward_seq = [] + trial_global_step_rewards = [] + trial_obs_seq = [] + trial_action_list = [] + + # First step delegate to RL agent + llm_decision = "delegate_to_rl" + + # SINGLE TRIAL GAME LOOP + trial_step = 0 + max_steps_per_trial = num_timesteps + trial_done = False + + while playing and active_partitions and trial_step < max_steps_per_trial and global_step < num_timesteps: + # Handle quit events + if USE_RENDERING: + for event in pg.event.get(): + if event.type == QUIT or (event.type == pg.KEYDOWN and event.key == K_q): + playing = False + + print(f"\nTrial {current_experiment_number}, Step {trial_step} (Global {global_step}) - " + f"Processing {len(active_partitions)} active partitions") + + if USE_RENDERING: + # Capture screen state + screen = pg.display.get_surface() + game_state_image = capture_screen(screen) + else: + screen = None + game_state_image = None + + trial_frames.append(game_state_image) + + # Step all active partitions simultaneously + partitions_to_remove = [] + current_step_reward = 0.0 + + for partition_idx in active_partitions: + partition_state = partition_states[partition_idx] + print(f" Processing partition {partition_idx} (partition step {partition_state['step_count']})") + + try: + # Set the small environment to the current partition's state + env_manager.small_env_timestep = partition_state['timestep'] + env_manager.current_partition_idx = partition_idx + + current_observation = partition_state['timestep'].observation + trial_obs_seq.append(current_observation) + + # Extract partition info and create subsurface + partition_info = env_manager.partitions[partition_idx] + region_coords = partition_info['region_coords'] + y_start, x_start, y_end, x_end = region_coords + width = x_end - x_start + 1 + height = y_end - y_start + 1 + + if USE_RENDERING: + subsurface = extract_subsurface(screen, x_start, y_start, width, height, ORIGINAL_MAP_SIZE, global_env_config, partition_idx) + game_state_image_small = capture_screen(subsurface) + else: + game_state_image_small = None + + state = env_manager.small_env_timestep.state + base_orientation = extract_base_orientation(state) + bucket_status = extract_bucket_status(state) + + # LLM decision making + if global_step % LLM_CALL_FREQUENCY == 0 and global_step > 0 and \ + FORCE_DELEGATE_TO_RL is False and \ + FORCE_DELEGATE_TO_LLM is False: + + print(" Calling LLM agent for decision...") + + # Check if intervention is enabled and needed + needs_intervention = False + stuck_info = {'is_stuck': False, 'reason': 'not_checked'} + + if ENABLE_INTERVENTION: + stuck_info = detect_stuck_excavator( + partition_state, + threshold_steps=STUCK_WINDOW, + min_reward_threshold=MIN_REWARD + ) + needs_intervention = should_intervene( + partition_state, + active_partitions, + intervention_frequency=INTERVENTION_FREQUENCY + ) + + if needs_intervention: + print(f" Partition {partition_idx} appears stuck: 
{stuck_info['reason']}") + print(f" Details: {stuck_info['details']}") + + try: + obs_dict = {k: v.tolist() for k, v in current_observation.items()} + observation_str = json.dumps(obs_dict) + except AttributeError: + # Handle the case where current_observation is not a dictionary + observation_str = str(current_observation) + + # Enhanced context with stuck information + base_context = f"Map {map_idx}, Trial {current_experiment_number}, Step {trial_step}" + if needs_intervention and ENABLE_INTERVENTION: + stuck_context = f" | STUCK: {stuck_info['reason']} - {stuck_info['details']}" + context = base_context + stuck_context + else: + context = base_context + + if USE_IMAGE_PROMPT: + delegation_prompt = get_delegation_prompt( + prompts, + "See image", + context=context, + ENABLE_INTERVENTION=ENABLE_INTERVENTION + ) + else: + delegation_prompt = get_delegation_prompt( + prompts, + observation_str, + context=context, + ENABLE_INTERVENTION=ENABLE_INTERVENTION + ) + + delegation_session_id = f"{SESSION_ID}_map_{map_idx}_trial_{current_experiment_number}_delegation" + delegation_user_id = f"{USER_ID}_delegation" + + try: + if USE_IMAGE_PROMPT: + response = asyncio.run(call_agent_async_master( + delegation_prompt, + game_state_image_small, + runner_delegation, + delegation_user_id, + delegation_session_id, + session_manager + )) + else: + response = asyncio.run(call_agent_async_master( + delegation_prompt, + None, + runner_delegation, + delegation_user_id, + delegation_session_id, + session_manager + )) + + llm_response_text = response + print(f"LLM response: {llm_response_text}") + + if "delegate_to_rl" in llm_response_text.lower(): + llm_decision = "delegate_to_rl" + print("Delegating to RL agent based on LLM response.") + elif "delegate_to_llm" in llm_response_text.lower(): + llm_decision = "delegate_to_llm" + print("Delegating to LLM agent based on LLM response.") + elif "intervention" in llm_response_text.lower(): + llm_decision = "intervention" + print("INTERVENTION mode activated based on LLM response.") + else: + # Default fallback + if needs_intervention: + llm_decision = "intervention" + print("INTERVENTION mode activated due to detected stuck condition.") + else: + llm_decision = "delegate_to_rl" + + except Exception as adk_err: + print(f"Error during ADK agent communication: {adk_err}") + if needs_intervention: + llm_decision = "intervention" + print("INTERVENTION mode activated due to communication error and stuck condition.") + else: + llm_decision = "delegate_to_rl" + + if FORCE_DELEGATE_TO_LLM: + llm_decision = "delegate_to_llm" + elif FORCE_DELEGATE_TO_RL: + llm_decision = "delegate_to_rl" + + # Action selection + if llm_decision == "delegate_to_rl": + print(f" Partition {partition_idx} - Delegating to RL agent") + try: + batched_observation = add_batch_dimension_to_observation(current_observation) + obs = obs_to_model_input(batched_observation, partition_state['prev_actions_rl'], config) + + current_model = partition_models[partition_idx] + _, logits_pi = current_model['model'].apply(current_model['params'], obs) + pi = tfp.distributions.Categorical(logits=logits_pi) + + # Use trial-specific random key + action_rng = jax.random.PRNGKey(seed + global_step * len(env_manager.partitions) + partition_idx + current_experiment_number * 10000) + action_rng, action_key, step_key = jax.random.split(action_rng, 3) + action_rl = pi.sample(seed=action_key) + + partition_state['actions'].append(action_rl) + trial_action_list.append(action_rl) + + except Exception as rl_error: + print(f" ERROR 
getting action from RL model for partition {partition_idx}: {rl_error}") + action_rl = jnp.array(0) + partition_state['actions'].append(action_rl) + trial_action_list.append(action_rl) + + elif llm_decision == "delegate_to_llm": + print(f" Partition {partition_idx} - Delegating to LLM agent") + + start = env_manager.small_env_timestep.state.agent.agent_state.pos_base + + msg = get_excavator_prompt(prompts, + base_orientation['direction'], + bucket_status, + start) + + llm_query.add_user_message(frame=game_state_image_small, user_msg=msg, local_map=None) + action_output, reasoning = llm_query.generate_response("./") + print(f"\n Action output: {action_output}, Reasoning: {reasoning}") + llm_query.add_assistant_message() + + action_rl = jnp.array([action_output], dtype=jnp.int32) + trial_action_list.append(action_rl) + + elif llm_decision == "intervention" and ENABLE_INTERVENTION: + print(f" Partition {partition_idx} - INTERVENTION MODE") + total_interventions += 1 + if partition_idx not in partition_interventions: + partition_interventions[partition_idx] = 0 + partition_interventions[partition_idx] += 1 + + try: + stuck_info = detect_stuck_excavator( + partition_state, + threshold_steps=STUCK_WINDOW, + min_reward_threshold=MIN_REWARD + ) + action_rl = get_intervention_action(partition_state, stuck_info, action_type) + partition_state['actions'].append(action_rl) + trial_action_list.append(action_rl) + + # Log intervention details + print(f" Intervention #{total_interventions} for partition {partition_idx}") + print(f" Reason: {stuck_info['reason']} | Action: {action_rl}") + except Exception as intervention_error: + print(f" ERROR during intervention for partition {partition_idx}: {intervention_error}") + # Fallback to a safe action + action_rl = jnp.array([0], dtype=jnp.int32) # Forward movement + partition_state['actions'].append(action_rl) + trial_action_list.append(action_rl) + + else: + print(" Master Agent stop.") + action_rl = jnp.array([-1], dtype=jnp.int32) + trial_action_list.append(action_rl) + + # Clear LLM messages periodically + if len(llm_query.messages) > 3: + llm_query.delete_messages() + + # Update action history and step environment + partition_state['prev_actions_rl'] = jnp.roll(partition_state['prev_actions_rl'], shift=1, axis=1) + partition_state['prev_actions_rl'] = partition_state['prev_actions_rl'].at[:, 0].set(action_rl) + + wrapped_action = wrap_action_llm(action_rl, action_type) + + # Take step with full sync + new_timestep = env_manager.step_with_full_global_sync(partition_idx, wrapped_action, partition_states) + + partition_states[partition_idx]['timestep'] = new_timestep + partition_state['step_count'] += 1 + + # Process reward + reward = new_timestep.reward + if isinstance(reward, jnp.ndarray): + if reward.shape == (): + reward_val = float(reward) + elif len(reward.shape) > 0: + reward_val = float(reward.flatten()[0]) + else: + reward_val = float(reward) + else: + reward_val = float(reward) + + if not (jnp.isnan(reward_val) or jnp.isinf(reward_val)): + partition_state['rewards'].append(reward_val) + partition_state['total_reward'] += reward_val + trial_reward_seq.append(reward_val) + current_step_reward += reward_val + print(f" Partition {partition_idx} - reward: {reward_val:.4f}, action: {action_rl}, done: {new_timestep.done}") + else: + print(f" Partition {partition_idx} - INVALID reward: {reward_val}, action: {action_rl}, done: {new_timestep.done}") + + # Check completion conditions + partition_completed = False + + if env_manager.is_small_task_completed(): 
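+                    # Completion: the environment manager reports no remaining
+                    # dig targets for this partition (assumed to be what
+                    # is_small_task_completed checks).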
+ print(f" Partition {partition_idx} COMPLETED after {partition_state['step_count']} steps!") + print(f" Total reward for partition {partition_idx}: {partition_state['total_reward']:.4f}") + env_manager.partitions[partition_idx]['status'] = 'completed' + partition_state['status'] = 'completed' + partition_completed = True + + elif partition_state['step_count'] >= max_steps_per_trial: + print(f" Partition {partition_idx} TIMED OUT") + env_manager.partitions[partition_idx]['status'] = 'failed' + partition_state['status'] = 'failed' + partition_completed = True + + elif jnp.isnan(reward): + print(f" Partition {partition_idx} FAILED due to NaN reward") + env_manager.partitions[partition_idx]['status'] = 'failed' + partition_state['status'] = 'failed' + partition_completed = True + + if partition_completed: + partitions_to_remove.append(partition_idx) + + except Exception as e: + print(f" ERROR stepping partition {partition_idx}: {e}") + if partition_idx < len(env_manager.partitions): + env_manager.partitions[partition_idx]['status'] = 'failed' + partition_state['status'] = 'failed' + partitions_to_remove.append(partition_idx) + + # Remove completed/failed partitions + for partition_idx in partitions_to_remove: + if partition_idx in active_partitions: + active_partitions.remove(partition_idx) + print(f" Removed partition {partition_idx} from active list") + + print(f" Remaining active partitions: {active_partitions}") + trial_global_step_rewards.append(current_step_reward) + print(f" Trial {current_experiment_number} step {trial_step} reward: {current_step_reward:.4f}") + + # Render + if GRID_RENDERING: + env_manager.render_all_partition_views_grid(partition_states) + else: + env_manager.render_global_environment_with_multiple_agents(partition_states, VISUALIZE_PARTITIONS) + + # After processing all partitions, check if trial is complete + map_metrics = calculate_map_completion_metrics(partition_states) + trial_done = map_metrics['done'] + + # Update done flag for this step + done = jnp.array(trial_done) # Convert to JAX array for consistency + + trial_step += 1 + global_step += 1 + + reward_seq.append(current_step_reward) + + if episode_done_once is None: + episode_done_once = done + if episode_length is None: + episode_length = jnp.zeros_like(done, dtype=jnp.int32) + if move_cumsum is None: + move_cumsum = jnp.zeros_like(done, dtype=jnp.int32) + if do_cumsum is None: + do_cumsum = jnp.zeros_like(done, dtype=jnp.int32) + + episode_done_once = episode_done_once | done + episode_length += ~episode_done_once + + move_cumsum_tmp = jnp.zeros_like(done, dtype=jnp.int32) + for move_action in move_actions: + move_mask = (action_rl == move_action) * (~episode_done_once) + move_cumsum_tmp += move_tiles * tile_size * move_mask.astype(jnp.int32) + for l_action in l_actions: + l_mask = (action_rl == l_action) * (~episode_done_once) + move_cumsum_tmp += 2 * move_tiles * tile_size * l_mask.astype(jnp.int32) + move_cumsum += move_cumsum_tmp + + do_cumsum += (action_rl == do_action) * (~episode_done_once) + + dug_tiles_per_action_map = (env_manager.global_maps['action_map'] == -1).sum() + + if trial_done: + print(f"Trial {current_experiment_number} completed!") + break + + # Add trial data to global collections + all_frames.extend(trial_frames) + all_reward_seq.extend(trial_reward_seq) + all_global_step_rewards.extend(trial_global_step_rewards) + all_obs_seq.extend(trial_obs_seq) + all_action_list.extend(trial_action_list) + + # FINAL SUMMARY FOR THIS TRIAL + print(f"\n{'='*80}") + print(f"TRIAL 
{current_experiment_number} ON MAP {map_idx} COMPLETED") + print(f"{'='*80}") + + if SAVE_VIDEO: + # Save results for this trial + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + safe_model_name = llm_model_name.replace('/', '_') + output_dir = os.path.join("experiments", f"{safe_model_name}_map_{map_idx}_trial_{current_experiment_number}_{current_time}") + os.makedirs(output_dir, exist_ok=True) + + # Save video + video_path = os.path.join(output_dir, f"trial_{current_experiment_number}_gameplay.mp4") + save_video(all_frames, video_path, FPS) + + print(f"\nTrial results saved to: {output_dir}") + + info = { + "episode_done_once": episode_done_once, + "episode_length": episode_length, + "move_cumsum": move_cumsum, + "do_cumsum": do_cumsum, + "areas": areas, + "dig_tiles_per_target_map_init": dig_tiles_per_target_map_init, + "dug_tiles_per_action_map": dug_tiles_per_action_map, + } + + # Print intervention statistics + if ENABLE_INTERVENTION and total_interventions > 0: + print(f"\nπŸ”§ INTERVENTION STATISTICS:") + print(f" Total interventions: {total_interventions}") + print(f" Interventions per partition: {partition_interventions}") + print(f" Intervention rate: {total_interventions/global_step:.1%}") + + # Save intervention stats + info["total_interventions"] = total_interventions + info["partition_interventions"] = partition_interventions + + return info + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run an LLM-based simulation experiment with RL agents - Parallel Version.") + parser.add_argument( + "--model_name", + type=str, + required=True, + choices=["gpt-4o", + "gpt-4.1", + "o4-mini", + "o3", + "o3-mini", + "gemini-1.5-flash-latest", + "gemini-2.0-flash", + "gemini-2.5-pro", + "gemini-2.5-flash", + "claude-3-haiku-20240307", + "claude-3-7-sonnet-20250219", + "claude-opus-4-20250514", + "claude-sonnet-4-20250514", + ], + help="Name of the LLM model to use." + ) + parser.add_argument( + "--model_key", + type=str, + required=True, + choices=["gpt", + "gemini", + "claude"], + help="Name of the LLM model key to use." + ) + parser.add_argument( + "--num_timesteps", + type=int, + default=100, + help="Number of timesteps to run." + ) + parser.add_argument( + "-n", + "--n_maps", + type=int, + default=10, + help="Total number of different maps to process", + ) + parser.add_argument( + "--map_idx", + type=int, + required=True, + help="Index of the specific map to process (0-based, for parallel execution)", + ) + parser.add_argument( + "--n_partitions_per_map", + type=int, + default=100, + help="Number of random partition trials per map", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=0, + help="Random seed for the environment.", + ) + parser.add_argument( + "-run", + "--run_name", + type=str, + default="/home/gioelemo/Documents/terra/tracked-dense.pkl", + help="Policy to use for the experiment. Must be a valid path to a .pkl file containing the policy.", + ) + parser.add_argument( + "--level_index", + type=int, + default=None, + help="Index of the level to run from CurriculumGlobalConfig.levels. If None, runs all levels." 
+ ) + parser.add_argument( + "--use_rendering", + action="store_true", + help="Enable rendering (usually disabled for parallel jobs)" + ) + parser.add_argument( + "--use_display", + action="store_true", + help="Enable display (usually disabled for parallel jobs)" + ) + + args = parser.parse_args() + + # Validate map index + if args.map_idx < 0 or args.map_idx >= args.n_maps: + raise ValueError(f"map_idx ({args.map_idx}) must be between 0 and {args.n_maps-1}") + + if args.level_index is not None: + import os + os.environ['TERRA_LEVEL_INDEX'] = str(args.level_index) + + # Get experiment configuration + (_, _, _, _, _, _, _ , _, _, _, _, _, USE_RENDERING, USE_DISPLAY, + _, _, _, _, _,_, _, _, COMPUTE_BENCH_STATS + ) = setup_experiment_config() + + # Override rendering settings for parallel execution + args.use_rendering = args.use_rendering and USE_RENDERING + args.use_display = args.use_display and USE_DISPLAY + + # Load model configuration once + agent_checkpoint_path = args.run_name + print(f"Loading RL agent configuration from: {agent_checkpoint_path}") + log = load_pkl_object(agent_checkpoint_path) + config = log["train_config"] + model_params = log["model"] + + # Create the original environment configs for the full map + global_env_config = jax.tree_map( + lambda x: x[0][None, ...].repeat(1, 0), log["env_config"] + ) + + config.num_test_rollouts = 1 + config.num_devices = 1 + config.num_embeddings_agent_min = 60 + + print(f"\nRunning map {args.map_idx}/{args.n_maps} with {args.n_partitions_per_map} partition trials") + print(f"Model: {args.model_name}") + print(f"Timesteps: {args.num_timesteps}") + print(f"Seed base: {args.seed}") + print(f"Rendering: {args.use_rendering}") + print(f"Display: {args.use_display}") + + # Run the single map + result = run_single_map_corrected(args.map_idx, args, config, model_params, global_env_config) + + if result: + print(f"\n{'='*80}") + print(f"MAP {args.map_idx} COMPLETED SUCCESSFULLY") + print(f"Best coverage: {result['coverage']:.4f}") + print(f"{'='*80}") + else: + print(f"\n{'='*80}") + print(f"MAP {args.map_idx} FAILED") + print(f"{'='*80}") + exit(1) diff --git a/llm/prompt_manager_llm.py b/llm/prompt_manager_llm.py new file mode 100644 index 0000000..6dbf709 --- /dev/null +++ b/llm/prompt_manager_llm.py @@ -0,0 +1,55 @@ +from pathlib import Path +import json + +class PromptManager: + """Lightweight prompt manager that loads from external files.""" + + def __init__(self, prompts_dir: str = "prompts"): + self.prompts_dir = Path(prompts_dir) + self.prompts_dir.mkdir(exist_ok=True) + self._cache = {} + + def get(self, prompt_name: str, **kwargs) -> str: + """Get a prompt with optional variable substitution.""" + if prompt_name not in self._cache: + self._load_prompt(prompt_name) + + prompt = self._cache[prompt_name] + if kwargs: + try: + return prompt.format(**kwargs) + except KeyError as e: + raise ValueError(f"Missing variable {e} for prompt '{prompt_name}'") + return prompt + + def _load_prompt(self, prompt_name: str): + """Load a prompt from file.""" + # Try .txt file first + txt_file = self.prompts_dir / f"{prompt_name}.txt" + if txt_file.exists(): + with open(txt_file, 'r', encoding='utf-8') as f: + self._cache[prompt_name] = f.read().strip() + return + + # Try .json file + json_file = self.prompts_dir / f"{prompt_name}.json" + if json_file.exists(): + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + if isinstance(data, dict) and 'prompt' in data: + self._cache[prompt_name] = data['prompt'] + elif isinstance(data, 
str):
+ self._cache[prompt_name] = data
+ else:
+ raise ValueError(f"Invalid JSON format in {json_file}")
+ return
+
+ raise FileNotFoundError(f"Prompt file not found: {prompt_name}.txt or {prompt_name}.json")
+
+ def reload(self, prompt_name: str = None):
+ """Reload prompts from files (useful for development)."""
+ if prompt_name:
+ if prompt_name in self._cache:
+ del self._cache[prompt_name]
+ else:
+ self._cache.clear() \ No newline at end of file diff --git a/llm/prompts/delegation.txt b/llm/prompts/delegation.txt new file mode 100644 index 0000000..0f9cad4 --- /dev/null +++ b/llm/prompts/delegation.txt @@ -0,0 +1,14 @@
+You are a master coordinator deciding how to control an excavator agent. Based on the current game state, choose ONE of these options:
+
+1. **delegate_to_rl** - Use the trained RL agent (good for starting the excavation task)
+2. **delegate_to_llm** - Use the LLM agent (good when there is little soil left to dig)
+3. **intervention** - Take direct control to help a stuck agent (use when the agent is stuck, repeating actions, or making no progress)
+
+## When to use intervention:
+- Agent is going in circles or repeating the same actions
+- Agent has made no progress for several steps
+- Agent is consistently getting negative or zero rewards
+- Agent appears stuck against obstacles
+- Agent needs a strategic reset or repositioning
+
+Respond with exactly one word: delegate_to_rl, delegate_to_llm, or intervention \ No newline at end of file diff --git a/llm/prompts/delegation_no_intervention.txt b/llm/prompts/delegation_no_intervention.txt new file mode 100644 index 0000000..2c688c9 --- /dev/null +++ b/llm/prompts/delegation_no_intervention.txt @@ -0,0 +1,6 @@
+You are a master coordinator deciding how to control an excavator agent. Based on the current game state, choose ONE of these options:
+
+1. **delegate_to_rl** - Use the trained RL agent (good for starting the excavation task)
+2. **delegate_to_llm** - Use the LLM agent (good when there is little soil left to dig)
+
+Respond with exactly one word: delegate_to_rl or delegate_to_llm \ No newline at end of file diff --git a/llm/prompts/excavator_action.txt b/llm/prompts/excavator_action.txt new file mode 100644 index 0000000..e995e91 --- /dev/null +++ b/llm/prompts/excavator_action.txt @@ -0,0 +1,5 @@
+Analyze this game frame and the provided local map to select the optimal action.
+The base of the excavator is currently facing {direction}.
+The bucket is currently {bucket_status}.
+The excavator is currently located at {position} (y,x).
+Follow the format: {{\"reasoning\": \"detailed step-by-step analysis\", \"action\": X}} \ No newline at end of file diff --git a/llm/prompts/excavator_llm_simple.txt b/llm/prompts/excavator_llm_simple.txt new file mode 100644 index 0000000..2b0eaff --- /dev/null +++ b/llm/prompts/excavator_llm_simple.txt @@ -0,0 +1,26 @@
+You are an intelligent assistant responsible for selecting the optimal action for an autonomous excavator to efficiently dig and deposit soil.
+
+Inputs:
+(1) Images: (a) Current game state;
+(2) Positions: Excavator base (y, x), using top-down coordinates (X-axis: left to right, Y-axis: top to bottom);
+(3) Bucket state (loaded/empty);
+(4) Base orientation (up, down, left, or right).
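These prompt files are consumed through the `PromptManager` shown earlier in this diff, which loads them from disk, caches them, and fills the `{...}` placeholders via `str.format`. A minimal usage sketch (the `llm/prompts` path and the literal argument values are assumptions for illustration):

```python
from llm.prompt_manager_llm import PromptManager

# Point the manager at the prompt files added in this diff.
pm = PromptManager(prompts_dir="llm/prompts")

# excavator_action.txt has {direction}, {bucket_status} and {position}
# placeholders; a missing kwarg raises ValueError.
msg = pm.get(
    "excavator_action",
    direction="up",
    bucket_status="empty",
    position=(10, 12),
)

# Static prompts such as delegation.txt take no kwargs.
delegation_msg = pm.get("delegation")
```

Because prompts are cached after the first load, `pm.reload()` is needed to pick up on-disk edits during development.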
+
+Critical Constraints:
+(a) The excavator base must maintain a minimum distance from the digging target area;
+(b) Digging needs to be performed only once per target (purple area); repeated digging is inefficient;
+(c) The excavated area (marked in blue) is considered an obstacle and is not traversable.
+
+Objectives:
+(1) Ensure proper reach and spacing for efficient digging;
+(2) After digging, rotate the cabin and use action 6 (DO) to deposit soil, ensuring the deposit is far from the digging area to avoid obstructing later movements.
+
+Actions:
+'-1': DO_NOTHING, '0': FORWARD, '1': BACKWARD, '2': CLOCK, '3': ANTICLOCK, '4': CABIN_CLOCK, '5': CABIN_ANTICLOCK, '6': DO.
+
+Rules:
+- Avoid black obstacles on the map;
+- The brown arrow shows the cabin's direction, and it turns grey when the bucket is loaded;
+- After digging, the base cannot move until the bucket is emptied.
+
+Output: {\"reasoning\": \"Explanation of why this action is optimal.\", \"action\": X} diff --git a/llm/prompts/partitioning.txt b/llm/prompts/partitioning.txt new file mode 100644 index 0000000..07ac3e9 --- /dev/null +++ b/llm/prompts/partitioning.txt @@ -0,0 +1,74 @@
+You are a master excavation coordinator responsible for optimizing excavation operations on a site map. Your task is to analyze the given terrain and intelligently partition it into optimal regions for multiple excavator deployments.
+
+## CRITICAL REQUIREMENTS:
+- You will receive a {map_size}x{map_size} map as input
+- Important: partitions must be at most 64x64; smaller partitions are also acceptable
+- **PURPLE AREAS REPRESENT TARGET EXCAVATION ZONES - ALL PURPLE AREAS MUST BE FULLY COVERED BY PARTITIONS**
+- EACH PARTITION MUST HAVE UNIQUE, NON-IDENTICAL COORDINATES
+- DO NOT CREATE DUPLICATE PARTITIONS - EACH PARTITION MUST COVER A DIFFERENT AREA
+
+RULE HIERARCHY:
+
+1. PRIMARY DIRECTIVE (NON-NEGOTIABLE): You must ensure 100% of all purple pixels are contained within your partitions.
+2. SECONDARY CONSTRAINT: All partitions must be non-empty (each must contain at least part of a target purple area, covering a minimum of 10% of its pixels)
+
+If these rules conflict, the Primary Directive (full coverage) always wins. This means you should merge or expand partitions to cover all purple areas, even if it results in a less elegant layout.
+
+## GUIDELINES FOR PARTITIONING:
+
+### Primary Objectives:
+1. **COMPLETE COVERAGE**: Ensure every purple pixel/area that needs excavation is fully contained within at least one partition
+2. **TRENCH PRIORITIZATION**: If you see multiple trenches or distinct purple excavation areas, create dedicated partitions for each major trench/area
+3. **NO GAPS**: Purple areas must not fall between partition boundaries or be partially covered
+
+### Optimization Considerations:
+1. Analyze the state of the map carefully, identifying all purple excavation zones, terrain features, obstacles, and excavation requirements
+2. Create efficient partitions that maximize excavator productivity and minimize travel time while ensuring complete purple area coverage
+3. Ensure each partition has adequate space for the excavator to maneuver around the purple excavation zones
+4. Designate appropriate soil deposit areas within each partition or create shared deposit zones if more efficient
+5. Position starting points strategically near purple areas to minimize initial travel time to excavation targets
+6. Consider terrain complexity when determining partition size - more complex areas may require smaller partitions
+7.
**Partition boundaries should extend beyond purple areas to provide maneuvering space for excavators** + +### Coverage Strategy: +- When purple areas span multiple potential partitions, either: +- Extend one partition to fully encompass the entire purple area, OR +- Ensure overlapping partitions so no purple area is missed +- Buffer zones around purple areas should be included in partitions for excavator movement +- Prioritize complete excavation task completion over perfect geometric partitioning + +## RESPONSE FORMAT: +Respond with a JSON list of partition objects, each containing: +- 'id': Unique numeric identifier for each partition (starting from 0) +- 'region_coords': MUST BE A TUPLE with parentheses, NOT an array with brackets: (y_start, x_start, y_end, x_end) +- 'start_pos': MUST BE A TUPLE with parentheses, NOT an array with brackets: (y, x) +- 'start_angle': Always use 0 degrees for initial orientation +- 'status': Set to 'pending' for all new partitions + +## COORDINATE FORMAT REQUIREMENTS: +**CRITICAL**: You MUST use Python tuple notation with parentheses () for coordinates, NOT arrays with square brackets []. Failure to use tuple notation will result in errors. + +### CORRECT FORMAT (with tuples): +[{{'id': 0, 'region_coords': (0, 0, 31, 31), 'start_pos': (16, 16), 'start_angle': 0, 'status': 'pending'}}] + +### INCORRECT FORMAT (with arrays): +[{{'id': 0, 'region_coords': [0, 0, 31, 31], 'start_pos': [16, 16], 'start_angle': 0, 'status': 'pending'}}] + +## Example Response: +For partitioning a 64x64 map into 4 equal quadrants (USING TUPLES, NOT ARRAYS): +[{{'id': 0, 'region_coords': (0, 0, 31, 31), 'start_pos': (16, 16), 'start_angle': 0, 'status': 'pending'}}, {{'id': 1, 'region_coords': (0, 32, 31, 63), 'start_pos': (16, 48), 'start_angle': 0, 'status': 'pending'}}, {{'id': 2, 'region_coords': (32, 0, 63, 31), 'start_pos': (48, 16), 'start_angle': 0, 'status': 'pending'}}, {{'id': 3, 'region_coords': (32, 32, 63, 63), 'start_pos': (48, 48), 'start_angle': 0, 'status': 'pending'}}] + +## WHAT TO DO: +- DO NOT use string representations like "(0, 0, 31, 31)" or "[0, 0, 31, 31]" +- DO NOT quote the coordinate tuples - they should be actual tuples +- DO use Python tuple notation with parentheses: (0, 0, 31, 31) +- DO use single quotes for dictionary keys in Python format + +## FINAL REMINDERS: +- Always return a list of partitions even if only creating a single partition +- Ensure each partition has sufficient space for both excavation and soil deposit operations +- Include ample maneuvering space in partitions to prevent excavator from getting stuck +- **VERIFY that all purple excavation areas are completely covered by your partition layout** +- **REMEMBER TO USE TUPLES (PARENTHESES) FOR ALL COORDINATES** +- Starting positions should be optimized for quick access to purple excavation zones within each partition +- NO EMPTY PARTITIONS: Every single partition you create must contain at least one purple pixel. If a partition does not cover any part of a purple excavation zone, it must be deleted. \ No newline at end of file diff --git a/llm/prompts/partitioning_exact.txt b/llm/prompts/partitioning_exact.txt new file mode 100644 index 0000000..8be614b --- /dev/null +++ b/llm/prompts/partitioning_exact.txt @@ -0,0 +1,82 @@ +You are a master excavation coordinator responsible for optimizing excavation operations on a site map. Your task is to analyze the given terrain and intelligently partition it into optimal regions for multiple excavator deployments. 
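The strict tuple-format rules these partitioning prompts insist on exist because the reply is parsed as Python literals rather than JSON. A minimal sketch of that round trip, assuming the model's reply is available as a plain string:

```python
import ast

# A reply in the format the partitioning prompts request: a Python-style
# list of dicts whose coordinates are tuples, not JSON arrays.
reply = (
    "[{'id': 0, 'region_coords': (0, 0, 31, 31), "
    "'start_pos': (16, 16), 'start_angle': 0, 'status': 'pending'}]"
)

partitions = ast.literal_eval(reply)
assert isinstance(partitions[0]["region_coords"], tuple)
```

`ast.literal_eval` preserves the tuples that the downstream validator (`is_valid_region_list` in `utils_llm.py`) requires, whereas `json.loads` would reject this syntax outright.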
+
+## CRITICAL REQUIREMENTS:
+- You will receive a {map_size}x{map_size} map as input
+- Partitions must be at most 64x64; smaller partitions are also acceptable
+- **PURPLE AREAS REPRESENT TARGET EXCAVATION ZONES - ALL PURPLE AREAS MUST BE FULLY COVERED BY PARTITIONS**
+- EACH PARTITION MUST HAVE UNIQUE, NON-IDENTICAL COORDINATES
+- DO NOT CREATE DUPLICATE PARTITIONS - EACH PARTITION MUST COVER A DIFFERENT AREA
+
+RULE HIERARCHY:
+
+1. PRIMARY DIRECTIVE (NON-NEGOTIABLE): You must ensure 100% of all purple pixels are contained within your partitions.
+
+2. SECONDARY CONSTRAINT: You must use exactly {max_partitions} partitions.
+
+3. TERTIARY CONSTRAINT: All partitions must be non-empty (each must contain at least part of a target purple area, covering a minimum of 10% of its pixels)
+
+If these rules conflict, the Primary Directive (full coverage) always wins. This means you should merge or expand partitions to cover all purple areas, even if it results in a less elegant layout, as long as you do not exceed the {max_partitions} limit.
+
+
+## GUIDELINES FOR PARTITIONING:
+
+### Primary Objectives:
+1. **COMPLETE COVERAGE**: Ensure every purple pixel/area that needs excavation is fully contained within at least one partition
+2. **TRENCH PRIORITIZATION**: If you see multiple trenches or distinct purple excavation areas, create dedicated partitions for each major trench/area
+3. **NO GAPS**: Purple areas must not fall between partition boundaries or be partially covered
+
+### Optimization Considerations:
+1. Analyze the state of the map carefully, identifying all purple excavation zones, terrain features, obstacles, and excavation requirements
+2. Create efficient partitions that maximize excavator productivity and minimize travel time while ensuring complete purple area coverage
+3. Ensure each partition has adequate space for the excavator to maneuver around the purple excavation zones
+4. Designate appropriate soil deposit areas within each partition or create shared deposit zones if more efficient
+5. Position starting points strategically near purple areas to minimize initial travel time to excavation targets
+6. Consider terrain complexity when determining partition size - more complex areas may require smaller partitions
+7. **Partition boundaries should extend beyond purple areas to provide maneuvering space for excavators**
+
+### Coverage Strategy:
+- When purple areas span multiple potential partitions, either:
+ - Extend one partition to fully encompass the entire purple area, OR
+ - Ensure overlapping partitions so no purple area is missed
+- Buffer zones around purple areas should be included in partitions for excavator movement
+- Prioritize complete excavation task completion over perfect geometric partitioning
+
+## PARTITION LIMITS:
+Adhere to the SECONDARY CONSTRAINT defined in the RULE HIERARCHY. Create the number of partitions needed for full coverage, but do not exceed {max_partitions}.
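The three-level hierarchy above is mechanical enough to check in code. A simplified sketch of such a validator, assuming the `-1`-for-dig convention used by the target maps elsewhere in this diff (`check_rule_hierarchy` is a hypothetical helper, not part of the repo, and it checks simple non-emptiness rather than the 10% threshold):

```python
import numpy as np

def check_rule_hierarchy(target_map, partitions, max_partitions):
    """Return (full_coverage, exact_count, all_non_empty) for a plan."""
    covered = np.zeros_like(target_map, dtype=bool)
    all_non_empty = True
    for p in partitions:
        # region_coords are inclusive (y_start, x_start, y_end, x_end).
        y0, x0, y1, x1 = p["region_coords"]
        region = target_map[y0:y1 + 1, x0:x1 + 1]
        covered[y0:y1 + 1, x0:x1 + 1] = True
        # Tertiary constraint: every partition holds at least one dig cell.
        all_non_empty &= bool((region == -1).any())
    full_coverage = bool(np.all(covered[target_map == -1]))  # primary
    exact_count = len(partitions) == max_partitions          # secondary
    return full_coverage, exact_count, all_non_empty
```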
+ +## RESPONSE FORMAT: +Respond with a JSON list of partition objects, each containing: +- 'id': Unique numeric identifier for each partition (starting from 0) +- 'region_coords': MUST BE A TUPLE with parentheses, NOT an array with brackets: (y_start, x_start, y_end, x_end) +- 'start_pos': MUST BE A TUPLE with parentheses, NOT an array with brackets: (y, x) +- 'start_angle': Always use 0 degrees for initial orientation +- 'status': Set to 'pending' for all new partitions + +## COORDINATE FORMAT REQUIREMENTS: +**CRITICAL**: You MUST use Python tuple notation with parentheses () for coordinates, NOT arrays with square brackets []. Failure to use tuple notation will result in errors. + +### CORRECT FORMAT (with tuples): +[{{'id': 0, 'region_coords': (0, 0, 31, 31), 'start_pos': (16, 16), 'start_angle': 0, 'status': 'pending'}}] + +### INCORRECT FORMAT (with arrays): +[{{'id': 0, 'region_coords': [0, 0, 31, 31], 'start_pos': [16, 16], 'start_angle': 0, 'status': 'pending'}}] + +## Example Response: +For partitioning a 64x64 map into 4 equal quadrants (USING TUPLES, NOT ARRAYS): +[{{'id': 0, 'region_coords': (0, 0, 31, 31), 'start_pos': (16, 16), 'start_angle': 0, 'status': 'pending'}}, {{'id': 1, 'region_coords': (0, 32, 31, 63), 'start_pos': (16, 48), 'start_angle': 0, 'status': 'pending'}}, {{'id': 2, 'region_coords': (32, 0, 63, 31), 'start_pos': (48, 16), 'start_angle': 0, 'status': 'pending'}}, {{'id': 3, 'region_coords': (32, 32, 63, 63), 'start_pos': (48, 48), 'start_angle': 0, 'status': 'pending'}}] + +## WHAT TO DO: +- DO NOT use string representations like "(0, 0, 31, 31)" or "[0, 0, 31, 31]" +- DO NOT quote the coordinate tuples - they should be actual tuples +- DO use Python tuple notation with parentheses: (0, 0, 31, 31) +- DO use single quotes for dictionary keys in Python format + +## FINAL REMINDERS: +- Always return a list of partitions even if only creating a single partition +- Ensure each partition has sufficient space for both excavation and soil deposit operations +- Include ample maneuvering space in partitions to prevent excavator from getting stuck +- **VERIFY that all purple excavation areas are completely covered by your partition layout** +- **REMEMBER TO USE TUPLES (PARENTHESES) FOR ALL COORDINATES** +- Starting positions should be optimized for quick access to purple excavation zones within each partition +- DOUBLE CHECK IF YOU USE EXACTLY {max_partitions} PARTITIONS! +- NO EMPTY PARTITIONS: Every single partition you create must contain at least one purple pixel. If a partition does not cover any part of a purple excavation zone, it must be deleted. \ No newline at end of file diff --git a/llm/prompts/partitioning_new.txt b/llm/prompts/partitioning_new.txt new file mode 100644 index 0000000..2428cd7 --- /dev/null +++ b/llm/prompts/partitioning_new.txt @@ -0,0 +1,101 @@ +You are a meticulous excavation site planner. Your sole purpose is to generate a JSON list of rectangular partitions that provide 100% coverage for all purple excavation zones on an input map. + +Failure Condition: If even a single purple pixel is left uncovered, the entire plan is a failure. + +Input and Output + +Input: A {map_size}x{map_size} map image. + +Output: A JSON list of partition objects, formatted exactly as specified below. + +Mandatory 4-Step Execution Process + +You must follow these four steps in order. + +Step 1: Analyze Purple Zone Bounding Box +First, identify the single bounding box that contains all purple pixels on the map. 
+ +min_y: The y-coordinate of the topmost purple pixel. + +max_y: The y-coordinate of the bottommost purple pixel. + +min_x: The x-coordinate of the leftmost purple pixel. + +max_x: The x-coordinate of the rightmost purple pixel. + +Step 2: Define the Mandated Coverage Area +Using the bounding box from Step 1, calculate the total rectangular area that MUST be covered. This is done by applying a non-negotiable 15-pixel safety margin on all sides. + +coverage_y_start = min_y - 15 + +coverage_y_end = max_y + 15 + +coverage_x_start = min_x - 15 + +coverage_x_end = max_x + 15 + +Step 3: Design Partitions to Fill the Coverage Area +Create a set of rectangular partitions that, when combined, completely fill the Mandated Coverage Area calculated in Step 2. + +Simple Shapes: If the purple area is a simple rectangle, a single partition is sufficient. + +Complex Shapes (L-shape, T-shape, multiple zones): Use multiple partitions. + +Critical Overlap Rule: Where two partitions meet, they must overlap by 12 to 16 pixels to ensure there are absolutely no gaps. + +Step 4: Format the Final Output +Present the final plan as a JSON list of partition objects. + +region_coords: Must be a TUPLE (y_start, x_start, y_end, x_end). + +start_pos: The center of the partition (center_y, center_x). + +status: Always 'pending'. + +start_angle: Always 0. + +Example of a Correct Execution + +This example shows how to handle a T-shaped purple area. + +Scenario: A T-shaped purple area has the following absolute bounds: min_y=100, max_y=150, min_x=100, max_x=200. The vertical stem runs from y=110 to y=150 at the center, and the horizontal bar is at the top from y=100 to y=110. + +Execution: + +Bounding Box Analysis: The bounds are correctly identified as min_y=100, max_y=150, min_x=100, max_x=200. + +Mandated Coverage Area: Applying the 15px margin, the required coverage is y:[85, 165] and x:[85, 215]. + +Partition Design: This is a complex shape, so two partitions are needed. One for the horizontal bar and one for the vertical stem. They must overlap significantly. 
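Steps 1 and 2 of this execution are purely mechanical; a minimal sketch of the same computation (a hypothetical helper; the clamping to the map bounds is an added safeguard that the prompt itself does not specify):

```python
import numpy as np

def mandated_coverage_area(purple_mask: np.ndarray, margin: int = 15):
    """Bounding box of all purple pixels, expanded by the safety margin."""
    ys, xs = np.nonzero(purple_mask)
    min_y, max_y = int(ys.min()), int(ys.max())
    min_x, max_x = int(xs.min()), int(xs.max())
    h, w = purple_mask.shape
    return (max(min_y - margin, 0), max(min_x - margin, 0),
            min(max_y + margin, h - 1), min(max_x + margin, w - 1))

# For the T-shape scenario above (min_y=100, max_y=150, min_x=100,
# max_x=200), on a large enough map this yields (85, 85, 165, 215),
# i.e. the required coverage y:[85, 165] and x:[85, 215].
```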
+ +Final JSON Output: + +JSON +[ + {{ + "id": 0, + "region_coords": (85, 85, 125, 215), + "start_pos": (105, 150), + "start_angle": 0, + "status": "pending" + }}, + {{ + "id": 1, + "region_coords": (110, 135, 165, 165), + "start_pos": (137, 150), + "start_angle": 0, + "status": "pending" + }} +] +Final Output Format Reminder + +JSON +[ + {{ + "id": 0, + "region_coords": (y_start, x_start, y_end, x_end), + "start_pos": (center_y, center_x), + "start_angle": 0, + "status": "pending" + }} +] \ No newline at end of file diff --git a/llm/results_aggregator.py b/llm/results_aggregator.py new file mode 100644 index 0000000..ea95518 --- /dev/null +++ b/llm/results_aggregator.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +""" +Aggregate results from parallel SLURM jobs +Run this after all parallel jobs complete +""" + +import json +import numpy as np +import argparse +import os +import datetime +from glob import glob + +def load_map_results(results_dir): + """Load all individual map results""" + pattern = os.path.join(results_dir, "map_*_results.json") + result_files = sorted(glob(pattern)) + + if not result_files: + raise ValueError(f"No result files found in {results_dir}") + + print(f"Found {len(result_files)} result files") + + all_results = [] + for file_path in result_files: + try: + with open(file_path, 'r') as f: + data = json.load(f) + all_results.append(data) + print(f"Loaded {file_path}: Map {data['map_idx']}, best coverage {data['best_result']['coverage']:.4f}") + except Exception as e: + print(f"Error loading {file_path}: {e}") + + return all_results + +def compute_final_statistics(all_results): + """Compute final statistics from all maps using the exact same function as the original""" + + # Import the original function + from llm.eval_llm import compute_stats_llm + + # Extract best results from each map + best_results = [result['best_result'] for result in all_results] + + # Create the exact same lists as in the original code + best_episode_done_once_list = [r['episode_done_once'] for r in best_results] + best_episode_length_list = [r['episode_length'] for r in best_results] + best_move_cumsum_list = [r['move_cumsum'] for r in best_results] + best_do_cumsum_list = [r['do_cumsum'] for r in best_results] + best_areas_list = [r['areas'] for r in best_results] + best_dig_tiles_per_target_map_init_list = [r['dig_tiles_per_target_map_init'] for r in best_results] + best_dug_tiles_per_action_map_list = [r['dug_tiles_per_action_map'] for r in best_results] + + # Calculate coverages for additional reporting + coverages = [] + for dug, total in zip(best_dug_tiles_per_action_map_list, best_dig_tiles_per_target_map_init_list): + if total > 0: + coverages.append(dug / total) + else: + coverages.append(0.0) + + # Call the original compute_stats_llm function with the exact same parameters + print(f"\n{'='*80}") + print(f"CALLING ORIGINAL compute_stats_llm FUNCTION") + print(f"{'='*80}") + + compute_stats_llm( + best_episode_done_once_list, + best_episode_length_list, + best_move_cumsum_list, + best_do_cumsum_list, + best_areas_list, + best_dig_tiles_per_target_map_init_list, + best_dug_tiles_per_action_map_list + ) + + # Also compute some additional statistics for saving + stats = { + 'n_maps': len(all_results), + 'coverage_stats': { + 'mean': float(np.mean(coverages)), + 'std': float(np.std(coverages)), + 'min': float(np.min(coverages)), + 'max': float(np.max(coverages)), + 'median': float(np.median(coverages)), + 'q25': float(np.percentile(coverages, 25)), + 'q75': float(np.percentile(coverages, 
75)) + }, + 'episode_stats': { + 'completion_rate': float(np.mean(best_episode_done_once_list)), + 'mean_length': float(np.mean(best_episode_length_list)), + 'std_length': float(np.std(best_episode_length_list)) + }, + 'action_stats': { + 'mean_move_cumsum': float(np.mean(best_move_cumsum_list)), + 'mean_do_cumsum': float(np.mean(best_do_cumsum_list)) + }, + 'intervention_stats': { + 'total_interventions': sum(r.get('total_interventions', 0) for r in best_results), + 'maps_with_interventions': sum(1 for r in best_results if r.get('total_interventions', 0) > 0) + }, + # Store the arrays that were passed to compute_stats_llm + 'compute_stats_arrays': { + 'best_episode_done_once_list': best_episode_done_once_list, + 'best_episode_length_list': best_episode_length_list, + 'best_move_cumsum_list': best_move_cumsum_list, + 'best_do_cumsum_list': best_do_cumsum_list, + 'best_areas_list': best_areas_list, + 'best_dig_tiles_per_target_map_init_list': best_dig_tiles_per_target_map_init_list, + 'best_dug_tiles_per_action_map_list': best_dug_tiles_per_action_map_list + } + } + + return stats, best_results, coverages + +def print_statistics(stats, model_name): + """Print formatted statistics - the original compute_stats_llm already printed detailed stats""" + print(f"\n{'='*80}") + print(f"ADDITIONAL COVERAGE STATISTICS - {model_name}") + print(f"{'='*80}") + print(f"Number of maps processed: {stats['n_maps']}") + + print(f"\nBest Partition Coverage Statistics:") + print(f" Mean coverage: {stats['coverage_stats']['mean']:.4f} Β± {stats['coverage_stats']['std']:.4f}") + print(f" Median coverage: {stats['coverage_stats']['median']:.4f}") + print(f" Min coverage: {stats['coverage_stats']['min']:.4f}") + print(f" Max coverage: {stats['coverage_stats']['max']:.4f}") + print(f" Q25-Q75: {stats['coverage_stats']['q25']:.4f} - {stats['coverage_stats']['q75']:.4f}") + + if stats['intervention_stats']['total_interventions'] > 0: + print(f"\nIntervention Statistics:") + print(f" Total interventions: {stats['intervention_stats']['total_interventions']}") + print(f" Maps with interventions: {stats['intervention_stats']['maps_with_interventions']}/{stats['n_maps']}") + print(f" Intervention rate: {stats['intervention_stats']['total_interventions']/sum(stats['compute_stats_arrays']['best_episode_length_list']):.1%} per step") + +def save_aggregated_results(all_results, stats, coverages, model_name, output_dir): + """Save aggregated results to file""" + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + aggregated_data = { + 'timestamp': timestamp, + 'model_name': model_name, + 'statistics': stats, + 'best_coverages': coverages, + 'individual_map_results': all_results, + 'metadata': { + 'aggregation_script': 'aggregate_parallel_results.py', + 'aggregation_time': datetime.datetime.now().isoformat() + } + } + + # Save main aggregated results + aggregated_filename = os.path.join(output_dir, f"aggregated_results_{model_name.replace('/', '_')}_{timestamp}.json") + with open(aggregated_filename, 'w') as f: + json.dump(aggregated_data, f, indent=2) + + # Save summary statistics in a more readable format + summary_filename = os.path.join(output_dir, f"summary_{model_name.replace('/', '_')}_{timestamp}.txt") + with open(summary_filename, 'w') as f: + f.write(f"Aggregated Results Summary - {model_name}\n") + f.write(f"Generated: {datetime.datetime.now().isoformat()}\n") + f.write("="*80 + "\n\n") + + f.write(f"Number of maps processed: {stats['n_maps']}\n\n") + + f.write("Coverage Statistics:\n") + f.write(f" 
Mean: {stats['coverage_stats']['mean']:.4f} Β± {stats['coverage_stats']['std']:.4f}\n") + f.write(f" Median: {stats['coverage_stats']['median']:.4f}\n") + f.write(f" Min: {stats['coverage_stats']['min']:.4f}\n") + f.write(f" Max: {stats['coverage_stats']['max']:.4f}\n") + f.write(f" Q25-Q75: {stats['coverage_stats']['q25']:.4f} - {stats['coverage_stats']['q75']:.4f}\n\n") + + f.write("Episode Statistics:\n") + f.write(f" Completion rate: {stats['episode_stats']['completion_rate']:.1%}\n") + f.write(f" Mean episode length: {stats['episode_stats']['mean_length']:.1f} Β± {stats['episode_stats']['std_length']:.1f}\n\n") + + f.write("Action Statistics:\n") + f.write(f" Mean move cumsum: {stats['action_stats']['mean_move_cumsum']:.1f}\n") + f.write(f" Mean do cumsum: {stats['action_stats']['mean_do_cumsum']:.1f}\n\n") + + if stats['intervention_stats']['total_interventions'] > 0: + f.write("Intervention Statistics:\n") + f.write(f" Total interventions: {stats['intervention_stats']['total_interventions']}\n") + f.write(f" Maps with interventions: {stats['intervention_stats']['maps_with_interventions']}/{stats['n_maps']}\n\n") + + f.write("Individual Map Coverages:\n") + for i, coverage in enumerate(coverages): + f.write(f" Map {i:2d}: {coverage:.4f}\n") + + return aggregated_filename, summary_filename + +def main(): + parser = argparse.ArgumentParser(description="Aggregate results from parallel SLURM jobs") + parser.add_argument( + "--results_dir", + type=str, + default="results_parallel", + help="Directory containing individual map result files" + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Model name (used for file naming and display)" + ) + parser.add_argument( + "--output_dir", + type=str, + default="aggregated_results", + help="Directory to save aggregated results" + ) + + args = parser.parse_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Build full results directory path + results_dir = os.path.join(args.results_dir, args.model_name.replace('/', '_')) + + if not os.path.exists(results_dir): + print(f"Results directory not found: {results_dir}") + print(f"Available directories in {args.results_dir}:") + if os.path.exists(args.results_dir): + for item in os.listdir(args.results_dir): + item_path = os.path.join(args.results_dir, item) + if os.path.isdir(item_path): + print(f" {item}") + return + + print(f"Loading results from: {results_dir}") + + try: + # Load all map results + all_results = load_map_results(results_dir) + + if not all_results: + print("No valid results found!") + return + + # Compute statistics + stats, best_results, coverages = compute_final_statistics(all_results) + + # Print statistics + print_statistics(stats, args.model_name) + + # Save aggregated results + aggregated_file, summary_file = save_aggregated_results( + all_results, stats, coverages, args.model_name, args.output_dir + ) + + print(f"\nResults saved to:") + print(f" Detailed: {aggregated_file}") + print(f" Summary: {summary_file}") + + print(f"\n{'='*80}") + print("AGGREGATION COMPLETED SUCCESSFULLY") + print(f"{'='*80}") + + except Exception as e: + print(f"Error during aggregation: {e}") + import traceback + traceback.print_exc() + return 1 + + return 0 + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/llm/run_levels.slurm b/llm/run_levels.slurm new file mode 100644 index 0000000..de4c72f --- /dev/null +++ b/llm/run_levels.slurm @@ -0,0 +1,43 @@ +#!/bin/bash +#SBATCH --job-name=llm_levels +#SBATCH 
--array=0-5 # Run up to 6 jobs in parallel
+#SBATCH --cpus-per-task=4
+#SBATCH --mem-per-cpu=8G
+#SBATCH --time=10:00:00
+#SBATCH --output=logs_LLM/job_%A_%a.out
+#SBATCH --error=logs_LLM/job_%A_%a.err
+#SBATCH --mail-type=END,FAIL
+#SBATCH --mail-user=gioelemo@ethz.ch
+
+# Load necessary modules
+module load eth_proxy
+
+# Make sure the SLURM log directory referenced above exists
+mkdir -p logs_LLM
+
+export GOOGLE_API_KEY=...
+
+# Set paths to conda
+CONDA_ROOT=/cluster/home/gioelemo/miniconda3
+CONDA_ENV=terra
+
+# Activate conda environment
+eval "$($CONDA_ROOT/bin/conda shell.bash hook)"
+conda activate $CONDA_ENV
+
+# Set environment variables
+export JAX_PLATFORMS=cpu
+export JAX_PLATFORM_NAME=cpu
+
+export CUDA_VISIBLE_DEVICES=""
+
+export DATASET_PATH=/cluster/home/gioelemo/terra_jax/terra/train4/train
+export DATASET_SIZE=500
+
+# Run the Python script for the given level
+LEVEL_INDEX=$SLURM_ARRAY_TASK_ID
+
+python -m main_llm \
+ --model_name gemini-2.5-pro \
+ --model_key gemini \
+ --num_timesteps 400 \
+ -s 1 -n 100 \
+ -run /cluster/home/gioelemo/terra_jax/terra/tracked-dense.pkl \
+ --level_index $LEVEL_INDEX diff --git a/llm/run_levels_random.slurm b/llm/run_levels_random.slurm new file mode 100644 index 0000000..a31af04 --- /dev/null +++ b/llm/run_levels_random.slurm @@ -0,0 +1,46 @@
+#!/bin/bash
+#SBATCH --job-name=random_parallel_level5
+#SBATCH --array=0-99
+#SBATCH --cpus-per-task=4
+#SBATCH --mem-per-cpu=6G
+#SBATCH --time=02:00:00
+#SBATCH --output=logs_level5/job_%A_%a.out
+#SBATCH --error=logs_level5/job_%A_%a.err
+#SBATCH --mail-type=END,FAIL
+#SBATCH --mail-user=gioelemo@ethz.ch
+
+# Load necessary modules
+module load eth_proxy
+
+# Make sure the SLURM log directory referenced above exists
+mkdir -p logs_level5
+
+export GOOGLE_API_KEY=...
+
+# Set paths to conda
+CONDA_ROOT=/cluster/home/gioelemo/miniconda3
+CONDA_ENV=terra
+
+# Activate conda environment
+eval "$($CONDA_ROOT/bin/conda shell.bash hook)"
+conda activate $CONDA_ENV
+
+# Set environment variables
+export JAX_PLATFORMS=cpu
+export JAX_PLATFORM_NAME=cpu
+
+export CUDA_VISIBLE_DEVICES=""
+
+export DATASET_PATH=/cluster/home/gioelemo/terra_jax/terra/train4/train
+export DATASET_SIZE=500
+
+# Run the Python script for the map index given by the array task ID
+
+python -m main_llm_random \
+ --model_name gemini-2.5-pro \
+ --model_key gemini \
+ --num_timesteps 400 \
+ -s 1 --n_maps 100 --n_partitions_per_map 20 \
+ -run /cluster/home/gioelemo/terra_jax/terra/tracked-dense.pkl \
+ --level_index 5 \
+ --map_idx $SLURM_ARRAY_TASK_ID \
+ --use_rendering
+ diff --git a/llm/session_manager_llm.py b/llm/session_manager_llm.py new file mode 100644 index 0000000..c03a97f --- /dev/null +++ b/llm/session_manager_llm.py @@ -0,0 +1,69 @@
+from google.adk.sessions import InMemorySessionService
+from google.adk.runners import Runner
+
+class SessionManager:
+ """Manages ADK sessions across multiple agents to prevent session loss."""
+
+ def __init__(self):
+ self.session_services = {}
+ self.sessions = {}
+ self.runners = {}
+
+ def create_agent_session(self, agent_name, app_name, user_id, session_id):
+ """Create a new session for an agent."""
+ # Create a unique session service for this agent
+ session_service_key = f"{agent_name}_{app_name}"
+
+ if session_service_key not in self.session_services:
+ self.session_services[session_service_key] = InMemorySessionService()
+
+ session_service = self.session_services[session_service_key]
+
+ # Create session
+ session = session_service.create_session(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ )
+
+ # Store session reference
+ session_key = f"{user_id}_{session_id}"
+ self.sessions[session_key] = {
+ 'session': session,
+ 'service': session_service,
+ 'app_name': app_name,
+ 'user_id': user_id,
+ 'session_id': session_id
+ }
+
+ #print(f"Created session: {session_key} for {agent_name}")
+ return session_service
+
+ def create_runner(self, agent, runner_key, app_name):
+ """Create and store a runner for an agent."""
+ session_service_key = f"{runner_key}_{app_name}"
+ session_service = self.session_services.get(session_service_key)
+
+ if not session_service:
+ raise ValueError(f"No session service found for {session_service_key}")
+
+ runner = Runner(
+ agent=agent,
+ app_name=app_name,
+ session_service=session_service,
+ )
+
+ self.runners[runner_key] = runner
+ #print(f"Created runner: {runner_key}")
+ return runner
+
+ def get_session_info(self, user_id, session_id):
+ """Get session information."""
+ session_key = f"{user_id}_{session_id}"
+ return self.sessions.get(session_key)
+
+ def list_sessions(self):
+ """List all active sessions for debugging."""
+ print("\nActive Sessions:")
+ for key, session_info in self.sessions.items():
+ print(f" {key}: {session_info['app_name']}") \ No newline at end of file diff --git a/llm/utils_llm.py b/llm/utils_llm.py new file mode 100644 index 0000000..778f94c --- /dev/null +++ b/llm/utils_llm.py @@ -0,0 +1,1743 @@
+import base64
+import cv2
+import numpy as np
+import jax
+from google.adk.agents import Agent
+from google.adk.models.lite_llm import LiteLlm
+from google.genai import types
+import jax.numpy as jnp
+from llm.adk_llm import *
+
+import csv
+from utils.models import load_neural_network
+
+import ast
+import io
+import json
+import logging
+import re
+from PIL import Image
+import matplotlib.pyplot as plt
+
+import pygame as pg
+
+from llm.session_manager_llm import SessionManager
+from llm.prompt_manager_llm import PromptManager
+import os
+import yaml
+import datetime
+import random
+
+# extract_python_format_data below uses re and logging, but neither was
+# imported in the original file; the logger is defined here for it.
+logger = logging.getLogger(__name__)
+
+def encode_image(cv_image):
+ _, buffer = cv2.imencode(".jpg", cv_image)
+ return base64.b64encode(buffer).decode("utf-8")
+
+def save_csv(output_file, action_list, cumulative_rewards):
+ with open(output_file, "w", newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(["actions", "cumulative_rewards"])
+ # Iterate through actions and the calculated cumulative rewards
+ for action, cum_reward in zip(action_list, cumulative_rewards):
+ # Assuming action is array-like (e.g., a JAX array) with one element
+ action_value = action[0] if hasattr(action, '__getitem__') and len(action) > 0 else action
+ # cum_reward from np.cumsum is already a scalar number
+ reward_value = cum_reward
+ writer.writerow([action_value, reward_value])
+
+ print(f"Results saved to {output_file}")
+
+def create_sub_task_target_map_64x64(global_target_map_data: jnp.ndarray,
+ region_coords: tuple[int, int, int, int]) -> jnp.ndarray:
+ """
+ Creates a 64x64 target map for an RL agent's sub-task from a 64x64 input.
+
+ Retains both `-1` values (dig targets) and `1` values (dump targets) from
+ the specified region in the global map. Everything outside the region is
+ set to 1 (dump), matching the initialization below.
+
+ Args:
+ global_target_map_data: Target map of size 64x64 (1: dump, 0: free, -1: dig).
+ region_coords: (y_start, x_start, y_end, x_end), inclusive bounds.
+
+ Returns:
+ A new 64x64 map with `-1`s and `1`s from the region at their original positions;
+ everything else is 1 (dump).
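+
+ Example (sketch):
+ >>> gmap = jnp.zeros((64, 64), dtype=jnp.int32).at[10, 10].set(-1)
+ >>> sub = create_sub_task_target_map_64x64(gmap, (0, 0, 31, 31))
+ >>> int(sub[10, 10]) # dig target kept at its original position
+ -1
+ >>> int(sub[50, 50]) # outside the region: dump
+ 1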
+ """ + y_start, x_start, y_end, x_end = region_coords + + # Initialize a 64x64 map with all zeros (free space) + #sub_task_map = jnp.zeros((64, 64), dtype=global_target_map_data.dtype) + # Initialize a 64x64 map with all ones (dump space) + sub_task_map = jnp.ones((64, 64), dtype=global_target_map_data.dtype) + + # Define slice object for region + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + + # Extract region from global map + region_data = global_target_map_data[region_slice] + + # Place region data at the SAME position in the new map + sub_task_map = sub_task_map.at[region_slice].set(region_data) + + return sub_task_map + +def create_sub_task_target_map_64x64_big(global_target_map_data: jnp.ndarray, + region_coords: tuple[int, int, int, int]) -> jnp.ndarray: + """ + FIXED: Creates a 64x64 target map that always returns exactly 64x64 dimensions. + + Args: + global_target_map_data: Target map of any size (1: dump, 0: free, -1: dig). + region_coords: (y_start, x_start, y_end, x_end), inclusive bounds. + + Returns: + A 64x64 map with the extracted region data placed appropriately. + """ + y_start, x_start, y_end, x_end = region_coords + + # Always initialize a 64x64 map + sub_task_map = jnp.ones((64, 64), dtype=global_target_map_data.dtype) + + # Calculate the actual region size + region_height = y_end - y_start + 1 + region_width = x_end - x_start + 1 + + # Extract the region from global map + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + region_data = global_target_map_data[region_slice] + + # Ensure region_data is exactly the expected size + if region_data.shape != (region_height, region_width): + print(f"Warning: Region data shape {region_data.shape} doesn't match expected {(region_height, region_width)}") + # Crop or pad as needed + min_h = min(region_data.shape[0], region_height) + min_w = min(region_data.shape[1], region_width) + region_data = region_data[:min_h, :min_w] + + # Calculate how much of the region we can fit in 64x64 + fit_height = min(region_height, 64) + fit_width = min(region_width, 64) + + # Place the region data in the 64x64 map + # Option 1: Place at origin (0,0) + sub_task_map = sub_task_map.at[:fit_height, :fit_width].set(region_data[:fit_height, :fit_width]) + + # Option 2: Place at the same relative position (if it fits) + # if region_height <= 64 and region_width <= 64: + # # Calculate offset to maintain relative position + # offset_y = min(y_start, 64 - region_height) + # offset_x = min(x_start, 64 - region_width) + # sub_task_map = sub_task_map.at[offset_y:offset_y+region_height, offset_x:offset_x+region_width].set(region_data) + + return sub_task_map + +def create_sub_task_action_map_64x64_big(action_map_data: jnp.ndarray, + region_coords: tuple[int, int, int, int]) -> jnp.ndarray: + """ + FIXED: Creates a 64x64 action map that always returns exactly 64x64 dimensions. 
+ """ + y_start, x_start, y_end, x_end = region_coords + + # Always initialize a 64x64 map + sub_task_map = jnp.zeros((64, 64), dtype=action_map_data.dtype) + + # Calculate the actual region size + region_height = y_end - y_start + 1 + region_width = x_end - x_start + 1 + + # Extract the region from global map + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + region_data = action_map_data[region_slice] + + # Calculate how much of the region we can fit in 64x64 + fit_height = min(region_height, 64) + fit_width = min(region_width, 64) + + # Place the region data in the 64x64 map + sub_task_map = sub_task_map.at[:fit_height, :fit_width].set(region_data[:fit_height, :fit_width]) + + return sub_task_map + +def create_sub_task_mask_64x64_big(mask_data: jnp.ndarray, + region_coords: tuple[int, int, int, int], + default_value: int = 1) -> jnp.ndarray: + """ + FIXED: Generic function to create 64x64 masks that always return exactly 64x64 dimensions. + + Args: + mask_data: Input mask of any size + region_coords: (y_start, x_start, y_end, x_end), inclusive bounds + default_value: Default value for areas outside the region (1 for non-traversable, 0 for traversable) + """ + y_start, x_start, y_end, x_end = region_coords + + # Always initialize a 64x64 map + sub_task_map = jnp.full((64, 64), default_value, dtype=mask_data.dtype) + + # Calculate the actual region size + region_height = y_end - y_start + 1 + region_width = x_end - x_start + 1 + + # Extract the region from global map + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + region_data = mask_data[region_slice] + + # Calculate how much of the region we can fit in 64x64 + fit_height = min(region_height, 64) + fit_width = min(region_width, 64) + + # Place the region data in the 64x64 map + sub_task_map = sub_task_map.at[:fit_height, :fit_width].set(region_data[:fit_height, :fit_width]) + + return sub_task_map + +# Wrapper functions for specific mask types +def create_sub_task_padding_mask_64x64_big(padding_mask_data, region_coords): + return create_sub_task_mask_64x64_big(padding_mask_data, region_coords, default_value=1) + +def create_sub_task_traversability_mask_64x64_big(traversability_mask_data, region_coords): + return create_sub_task_mask_64x64_big(traversability_mask_data, region_coords, default_value=1) + +def create_sub_task_dumpability_mask_64x64_big(dumpability_mask_data, region_coords): + return create_sub_task_mask_64x64_big(dumpability_mask_data, region_coords, default_value=0) + +def extract_python_format_data(llm_response_text): + """ + Extracts Python-formatted data from LLM response, preserving tuples. 
+ + Args: + llm_response_text (str): The raw text response from the LLM + + Returns: + list: The parsed Python list with tuples preserved + + Raises: + ValueError: If no valid Python data could be extracted + """ + # First, check if we have a code block and extract its content + code_block_pattern = r'```(?:json|python)?\s*([\s\S]*?)\s*```' + code_match = re.search(code_block_pattern, llm_response_text, re.DOTALL) + + if code_match: + content = code_match.group(1) + else: + # If no code block, use the whole text + content = llm_response_text + + # Clean up the content to ensure it's valid Python syntax + # Replace double quotes with single quotes for keys (Python style) + content = re.sub(r'"([^"]+)":', r"'\1':", content) + + # Make sure status values are properly quoted + content = re.sub(r"'status':\s*([a-zA-Z_][a-zA-Z0-9_]*)", r"'status': '\1'", content) + + try: + # Use ast.literal_eval to parse the Python literals, which preserves tuples + return ast.literal_eval(content) + except (SyntaxError, ValueError) as e: + logger.warning(f"ast.literal_eval failed: {e}") + + # Try to extract and process each dict individually + results = [] + dict_pattern = r'\{\s*\'id\':\s*(\d+)[\s\S]*?(?=\}\s*,|\}\s*$)' + + for match in re.finditer(dict_pattern, content, re.DOTALL): + try: + dict_str = match.group(0) + '}' + # Make sure all string values are properly quoted + dict_str = re.sub(r"'([^']+)':\s*([a-zA-Z_][a-zA-Z0-9_]*)", r"'\1': '\2'", dict_str) + obj = ast.literal_eval(dict_str) + results.append(obj) + except (SyntaxError, ValueError) as e: + logger.warning(f"Failed to parse dict: {e}") + continue + + if results: + return results + + # If we still couldn't parse it, try a more manual approach + try: + # Extract data manually using regex + result = [] + id_pattern = r"'id':\s*(\d+)" + region_pattern = r"'region_coords':\s*\(([^)]+)\)" + pos_pattern = r"'start_pos':\s*\(([^)]+)\)" + angle_pattern = r"'start_angle':\s*(\d+)" + status_pattern = r"'status':\s*'([^']+)'" + + # Get all IDs + ids = re.findall(id_pattern, content) + region_coords = re.findall(region_pattern, content) + start_positions = re.findall(pos_pattern, content) + start_angles = re.findall(angle_pattern, content) + statuses = re.findall(status_pattern, content) + + # Ensure we have the same number of matches for each field + min_length = min(len(ids), len(region_coords), len(start_positions), + len(start_angles), len(statuses)) + + for i in range(min_length): + # Parse tuple values + region_tuple = tuple(int(x.strip()) for x in region_coords[i].split(',')) + start_pos_tuple = tuple(int(x.strip()) for x in start_positions[i].split(',')) + + obj = { + 'id': int(ids[i]), + 'region_coords': region_tuple, + 'start_pos': start_pos_tuple, + 'start_angle': int(start_angles[i]), + 'status': statuses[i] + } + result.append(obj) + + if result: + return result + except Exception as e: + logger.error(f"Manual extraction failed: {e}") + + raise ValueError("Could not extract valid Python data with tuples from LLM response") + +def is_valid_region_list(var): + """ + Checks if the variable is a list of dictionaries with the required structure. + The structure should be a list containing at least one dictionary with the keys: + 'id', 'region_coords', 'start_pos', 'start_angle', and 'status'. + + 'region_coords' and 'start_pos' should be tuples. + Each region must be maximum 64x64 in size. 
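+
+ Note: as implemented below, width and height are computed as end - start
+ (without the +1 used for inclusive bounds elsewhere in this file), so a
+ full-size region such as (0, 0, 63, 63) measures 63x63 here and passes
+ the 64x64 limit.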
+ + Example of valid structure: + [{'id': 0, 'region_coords': (15, 15, 50, 50), 'start_pos': (42, 36), 'start_angle': 0, 'status': 'pending'}] + + Args: + var: The variable to check + + Returns: + bool: True if the variable has the valid structure, False otherwise + """ + # Check if var is a list + if not isinstance(var, list): + return False + + # Check if list has at least one element + if len(var) == 0: + return False + + # Maximum allowed partition size + MAX_PARTITION_SIZE = 64 + + # Check each element in the list + for item in var: + # Check if item is a dictionary + if not isinstance(item, dict): + return False + + # Check required keys + required_keys = {'id', 'region_coords', 'start_pos', 'start_angle', 'status'} + if set(item.keys()) != required_keys: + return False + + # Check types of specific fields + if not isinstance(item['id'], (int, float)): + return False + + if not isinstance(item['region_coords'], tuple) or len(item['region_coords']) != 4: + return False + + if not isinstance(item['start_pos'], tuple) or len(item['start_pos']) != 2: + return False + + if not isinstance(item['start_angle'], (int, float)): + return False + + if not isinstance(item['status'], str): + return False + + # Check partition size (region_coords format: (y_start, x_start, y_end, x_end)) + y_start, x_start, y_end, x_end = item['region_coords'] + if not all(isinstance(coord, (int, float)) for coord in item['region_coords']): + return False + + # Calculate width and height from coordinates + width = x_end - x_start + height = y_end - y_start + + print(f"Checking partition: {item['id']} with width {width} and height {height}") + + if width > MAX_PARTITION_SIZE or height > MAX_PARTITION_SIZE: + return False + + return True + +def compute_manual_subtasks(ORIGINAL_MAP_SIZE, NUM_PARTITIONS): + """ + UPDATED: Ensures all partitions are exactly 64x64 for RL policy compatibility. + """ + if ORIGINAL_MAP_SIZE not in [64, 128]: + raise ValueError(f"Unsupported ORIGINAL_MAP_SIZE: {ORIGINAL_MAP_SIZE}. 
Must be 64 or 128.") + + if NUM_PARTITIONS == 1: + if ORIGINAL_MAP_SIZE == 64: + sub_tasks_manual = [ + {'id': 0, 'region_coords': (0, 0, 63, 63), 'start_pos': (32, 32), 'start_angle': 0, 'status': 'pending'}, + ] + else: # 128x128 + # Take a 64x64 region from the 128x128 map + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (0, 0, 63, 63), 'start_pos': (32, 32), 'start_angle': 0, 'status': 'pending'}, + # ] + + # Single partition for 128x128 map, taking the center 64x64 region + sub_tasks_manual = [ + {'id': 0, 'region_coords': (32, 32, 95, 95), 'start_pos': (64, 64), 'start_angle': 0, 'status': 'pending'}, + ] + + + elif NUM_PARTITIONS == 2: + if ORIGINAL_MAP_SIZE == 64: + # Horizontal split of 64x64 map + sub_tasks_manual = [ + {'id': 0, 'region_coords': (0, 0, 31, 63), 'start_pos': (16, 32), 'start_angle': 0, 'status': 'pending'}, + {'id': 1, 'region_coords': (32, 0, 63, 63), 'start_pos': (48, 32), 'start_angle': 0, 'status': 'pending'} + ] + + # Horizontal partitioning for 64x64 map (x,y) + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (0, 0, 31, 63), 'start_pos': (32, 16), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (32, 0, 63, 63), 'start_pos': (32, 48), 'start_angle': 0, 'status': 'pending'} + # ] + + # Vertical partitioning for 64x64 map + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (0, 0, 63, 31), 'start_pos': (32, 16), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (0, 32, 63, 63), 'start_pos': (32, 48), 'start_angle': 0, 'status': 'pending'} + # ] + + # Vertical partitioning with overlapping + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (0, 0, 63, 35), 'start_pos': (32, 18), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (0, 28, 63, 63), 'start_pos': (32, 46), 'start_angle': 0, 'status': 'pending'} + # ] + + # Random partitioning for 64x64 map + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (0, 0, 32, 32), 'start_pos': (25, 20), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (0, 33, 63, 63), 'start_pos': (40, 40), 'start_angle': 0, 'status': 'pending'}, + # ] + + else: # 128x128 + # Two 64x64 regions from different parts of the 128x128 map + sub_tasks_manual = [ + {'id': 0, 'region_coords': (0, 0, 63, 63), 'start_pos': (32, 32), 'start_angle': 0, 'status': 'pending'}, + {'id': 1, 'region_coords': (64, 0, 127, 63), 'start_pos': (96, 32), 'start_angle': 0, 'status': 'pending'} + ] + + # Two overlapping 64x64 regions in the center of the 128x128 map + # Left-center region: (16, 16) to (79, 79) + # Right-center region: (48, 48) to (111, 111) + # Overlap area: (48, 48) to (79, 79) = 32x32 overlap + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (16, 16, 79, 79), 'start_pos': (48, 48), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (48, 48, 111, 111), 'start_pos': (80, 80), 'start_angle': 0, 'status': 'pending'} + # ] + + # Two overlapping 64x64 regions in the center of the 128x128 map shifting the start positions + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (32, 0, 95, 63), 'start_pos': (64, 32), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (32, 64, 95, 127), 'start_pos': (64, 96), 'start_angle': 0, 'status': 'pending'} + # ] + # sub_tasks_manual = [ + # {'id': 0, 'region_coords': (32, 0, 95, 63), 'start_pos': (32, 32), 'start_angle': 0, 'status': 'pending'}, + # {'id': 1, 'region_coords': (32, 48, 95, 111), 'start_pos': (32, 80), 'start_angle': 0, 'status': 'pending'} + # ] + 
# sub_tasks_manual = [
+            #     {'id': 0, 'region_coords': (32, 0, 95, 63), 'start_pos': (32, 32), 'start_angle': 0, 'status': 'pending'},
+            #     {'id': 1, 'region_coords': (32, 54, 95, 117), 'start_pos': (32, 80), 'start_angle': 0, 'status': 'pending'}
+            # ]
+
+    elif NUM_PARTITIONS == 4:
+        if ORIGINAL_MAP_SIZE == 64:
+            # 2x2 grid of 32x32 regions (these will be padded to 64x64)
+            sub_tasks_manual = [
+                {'id': 0, 'region_coords': (0, 0, 31, 31), 'start_pos': (16, 16), 'start_angle': 0, 'status': 'pending'},
+                {'id': 1, 'region_coords': (0, 32, 31, 63), 'start_pos': (16, 48), 'start_angle': 0, 'status': 'pending'},
+                {'id': 2, 'region_coords': (32, 0, 63, 31), 'start_pos': (48, 16), 'start_angle': 0, 'status': 'pending'},
+                {'id': 3, 'region_coords': (32, 32, 63, 63), 'start_pos': (48, 48), 'start_angle': 0, 'status': 'pending'}
+            ]
+        else:  # 128x128
+            # Four 64x64 regions covering different corners of the 128x128 map
+            sub_tasks_manual = [
+                {'id': 0, 'region_coords': (0, 0, 63, 63), 'start_pos': (32, 32), 'start_angle': 0, 'status': 'pending'},
+                {'id': 1, 'region_coords': (0, 64, 63, 127), 'start_pos': (32, 96), 'start_angle': 0, 'status': 'pending'},
+                {'id': 2, 'region_coords': (64, 0, 127, 63), 'start_pos': (96, 32), 'start_angle': 0, 'status': 'pending'},
+                {'id': 3, 'region_coords': (64, 64, 127, 127), 'start_pos': (96, 96), 'start_angle': 0, 'status': 'pending'}
+            ]
+            # Four 64x64 regions covering different corners of the 128x128 map, with different start positions
+            # sub_tasks_manual = [
+            #     {'id': 0, 'region_coords': (0, 0, 63, 63), 'start_pos': (20, 20), 'start_angle': 0, 'status': 'pending'},
+            #     {'id': 1, 'region_coords': (0, 64, 63, 127), 'start_pos': (20, 44), 'start_angle': 0, 'status': 'pending'},
+            #     {'id': 2, 'region_coords': (64, 0, 127, 63), 'start_pos': (44, 20), 'start_angle': 0, 'status': 'pending'},
+            #     {'id': 3, 'region_coords': (64, 64, 127, 127), 'start_pos': (44, 44), 'start_angle': 0, 'status': 'pending'}
+            # ]
+    else:
+        raise ValueError("Invalid number of partitions. Must be 1, 2 or 4.")
+
+    # Validate all partitions
+    for partition in sub_tasks_manual:
+        region_coords = partition['region_coords']
+        y_start, x_start, y_end, x_end = region_coords
+        height = y_end - y_start + 1
+        width = x_end - x_start + 1
+        print(f"Partition {partition['id']}: {height}x{width} (will be processed as 64x64)")
+
+        # Ensure we don't exceed the original map boundaries
+        if y_end >= ORIGINAL_MAP_SIZE or x_end >= ORIGINAL_MAP_SIZE:
+            raise ValueError(f"Partition {partition['id']} exceeds map boundaries")
+
+    return sub_tasks_manual
+
+def check_overall_completion(partition_states):
+    """
+    Check if the overall task is complete based on partition completion status.
+    Returns True only when every partition has been completed.
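+
+    Example (illustrative; a hypothetical partition_states dict):
+        >>> check_overall_completion({0: {'status': 'completed'}, 1: {'status': 'active'}})
+        False
+        >>> check_overall_completion({0: {'status': 'completed'}, 1: {'status': 'completed'}})
+        True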
+ """ + if not partition_states: + return False + + completed_partitions = [] + failed_partitions = [] + active_partitions = [] + + for partition_idx, partition_state in partition_states.items(): + status = partition_state.get('status', 'unknown') + if status == 'completed': + completed_partitions.append(partition_idx) + elif status == 'failed': + failed_partitions.append(partition_idx) + elif status == 'active': + active_partitions.append(partition_idx) + + total_partitions = len(partition_states) + + + all_completed = len(completed_partitions) == total_partitions + + is_complete = all_completed + + return is_complete + +def calculate_map_completion_metrics(partition_states): + """ + Calculate completion metrics for the current map based on partition states. + """ + if not partition_states: + return { + 'done': False, + 'completion_rate': 0.0, + 'total_reward': 0.0, + 'completed_partitions': 0, + 'failed_partitions': 0, + 'total_partitions': 0 + } + + completed_count = 0 + failed_count = 0 + total_reward = 0.0 + total_partitions = len(partition_states) + + for partition_idx, partition_state in partition_states.items(): + status = partition_state.get('status', 'unknown') + partition_reward = partition_state.get('total_reward', 0.0) + total_reward += partition_reward + + if status == 'completed': + completed_count += 1 + elif status == 'failed': + failed_count += 1 + + completion_rate = completed_count / total_partitions if total_partitions > 0 else 0.0 + is_done = check_overall_completion(partition_states) + + return { + 'done': is_done, + 'completion_rate': completion_rate, + 'total_reward': total_reward, + 'completed_partitions': completed_count, + 'failed_partitions': failed_count, + 'total_partitions': total_partitions + } + +def wrap_action_llm(action_rl, action_type): + """ + Wrap RL action for the environment. + Ensures correct shape for single environment (non-batched). 
+ """ + # Ensure action_rl is a single integer, not an array + if isinstance(action_rl, jnp.ndarray): + if action_rl.shape == (1,): + action_val = action_rl[0] # Extract single value + elif action_rl.shape == (): + action_val = action_rl # Already scalar + else: + raise ValueError(f"Unexpected action shape: {action_rl.shape}") + else: + action_val = action_rl + + # Convert to proper format for single environment + # Shape should be [1] not [1,1] + wrapped_action = action_type( + type=jnp.array([action_val], dtype=jnp.int8), # Shape: [1] + action=jnp.array([action_val], dtype=jnp.int8) # Shape: [1] + ) + + return wrapped_action + +def add_batch_dimension_to_observation(obs): + """Add batch dimension to all observation components.""" + batched_obs = {} + for key, value in obs.items(): + if isinstance(value, jnp.ndarray): + batched_obs[key] = jnp.expand_dims(value, axis=0) + else: + batched_obs[key] = jnp.array([value]) + return batched_obs + +def reset_to_next_map(map_index, seed, env_manager, global_env_config, + initial_custom_pos=None, initial_custom_angle=None): + """Reset the existing environment to the next map""" + print(f"\n{'='*60}") + print(f"RESETTING TO MAP {map_index}") + print(f"{'='*60}") + + # Create new seed for this map reset + map_seed = seed + map_index * 1000 + map_rng = jax.random.PRNGKey(map_seed) + map_rng, reset_rng = jax.random.split(map_rng) + reset_keys = jax.random.split(reset_rng, 1) + + # Reset the existing environment to get a new map + # The environment will internally cycle through its available maps + env_manager.global_env.timestep = env_manager.global_env.reset( + global_env_config, reset_keys, initial_custom_pos, initial_custom_angle + ) + + # Extract and store the NEW global map data directly + new_timestep = env_manager.global_env.timestep + env_manager.global_maps['target_map'] = new_timestep.state.world.target_map.map[0].copy() + env_manager.global_maps['action_map'] = new_timestep.state.world.action_map.map[0].copy() + env_manager.global_maps['dumpability_mask'] = new_timestep.state.world.dumpability_mask.map[0].copy() + env_manager.global_maps['dumpability_mask_init'] = new_timestep.state.world.dumpability_mask_init.map[0].copy() + env_manager.global_maps['padding_mask'] = new_timestep.state.world.padding_mask.map[0].copy() + env_manager.global_maps['traversability_mask'] = new_timestep.state.world.traversability_mask.map[0].copy() + env_manager.global_maps['trench_axes'] = new_timestep.state.world.trench_axes.copy() + env_manager.global_maps['trench_type'] = new_timestep.state.world.trench_type.copy() + + # Store the new global timestep + env_manager.global_timestep = new_timestep + + # Clear any existing partitions data to ensure fresh start + env_manager.partitions = [] + env_manager.overlap_map = {} + env_manager.overlap_regions = {} + + print(f"Environment reset to map {map_index}") + print(f"New target map has {jnp.sum(env_manager.global_maps['target_map'] < 0)} dig targets") + +def initialize_partitions_for_current_map(env_manager, config, model_params): + """Initialize all partitions for the current map""" + partition_states = {} + partition_models = {} + active_partitions = [] + + num_partitions = len(env_manager.partitions) + print(f"Number of partitions: {num_partitions}") + + # Initialize all partitions + for partition_idx in range(num_partitions): + try: + print(f"Initializing partition {partition_idx}...") + + small_env_timestep = env_manager.initialize_small_environment(partition_idx) + + partition_states[partition_idx] = { + 
'timestep': small_env_timestep, + 'prev_actions_rl': jnp.zeros((1, config.num_prev_actions), dtype=jnp.int32), + 'step_count': 0, + 'status': 'active', + 'rewards': [], + 'actions': [], + 'total_reward': 0.0, + } + active_partitions.append(partition_idx) + + partition_models[partition_idx] = { + 'model': load_neural_network(config, env_manager.small_env), + 'params': model_params.copy(), + 'prev_actions': jnp.zeros((1, config.num_prev_actions), dtype=jnp.int32) + } + + except Exception as e: + print(f"Failed to initialize partition {partition_idx}: {e}") + if partition_idx < len(env_manager.partitions): + env_manager.partitions[partition_idx]['status'] = 'failed' + + if not active_partitions: + print("No partitions could be initialized!") + return None, None, None + + print(f"Successfully initialized {len(active_partitions)} partitions: {active_partitions}") + return partition_states, partition_models, active_partitions + +def init_llms(llm_model_key, llm_model_name, config, action_size, + APP_NAME, USER_ID, SESSION_ID, MAP_SIZE, MAX_NUM_PARTITIONS): + """ + Initialize LLMs using file-based prompts. + """ + # Initialize prompt manager + prompts = PromptManager(prompts_dir="llm/prompts") + + session_manager = SessionManager() + + if llm_model_key == "gpt": + llm_model_name_extended = "openai/{}".format(llm_model_name) + elif llm_model_key == "claude": + llm_model_name_extended = "anthropic/{}".format(llm_model_name) + else: + llm_model_name_extended = llm_model_name + + print("Using model: ", llm_model_name_extended) + + # Load system messages from files + system_message_master = prompts.get("partitioning_new", map_size=MAP_SIZE, max_partitions=MAX_NUM_PARTITIONS) + + system_message_delegation = prompts.get("delegation", observation="See current state") + + system_message_excavator = prompts.get("excavator_llm_simple") + + # Create agents + if llm_model_key == "gemini": + llm_partitioning_agent = Agent( + name="PartitioningAgent", + model=llm_model_name_extended, + description="Master excavation coordinator for partitioning", + instruction=system_message_master, + ) + + llm_delegation_agent = Agent( + name="DelegationAgent", + model=llm_model_name_extended, + description="Task delegation coordinator", + instruction=system_message_delegation, + ) + + llm_excavator_agent = Agent( + name="ExcavatorAgent", + model=llm_model_name_extended, + description="Excavator control agent", + instruction=system_message_excavator, + ) + else: + llm_partitioning_agent = Agent( + name="PartitioningAgent", + model=LiteLlm(model=llm_model_name_extended), + description="Master excavation coordinator for partitioning", + instruction=system_message_master, + ) + + llm_delegation_agent = Agent( + name="DelegationAgent", + model=LiteLlm(model=llm_model_name_extended), + description="Task delegation coordinator", + instruction=system_message_delegation, + ) + + llm_excavator_agent = Agent( + name="ExcavatorAgent", + model=LiteLlm(model=llm_model_name_extended), + description="Excavator control agent", + instruction=system_message_excavator, + ) + + + # Partitioning session + app_name_partitioning = f"{APP_NAME}_partitioning" + user_id_partitioning = f"{USER_ID}_partitioning" + session_id_partitioning = f"{SESSION_ID}_partitioning" + + session_service_partitioning = session_manager.create_agent_session( + "PartitioningAgent", + app_name_partitioning, + user_id_partitioning, + session_id_partitioning + ) + + # Delegation session + app_name_delegation = f"{APP_NAME}_delegation" + user_id_delegation = 
f"{USER_ID}_delegation" + session_id_delegation = f"{SESSION_ID}_delegation" + + session_service_delegation = session_manager.create_agent_session( + "DelegationAgent", + app_name_delegation, + user_id_delegation, + session_id_delegation + ) + + # Excavator session + app_name_excavator = f"{APP_NAME}_excavator" + user_id_excavator = f"{USER_ID}_excavator" + session_id_excavator = f"{SESSION_ID}_excavator" + + session_service_excavator = session_manager.create_agent_session( + "ExcavatorAgent", + app_name_excavator, + user_id_excavator, + session_id_excavator + ) + + #print("All sessions created successfully.") + + # CREATE RUNNERS WITH SESSION MANAGER + runner_partitioning = session_manager.create_runner( + llm_partitioning_agent, + "PartitioningAgent", + app_name_partitioning + ) + + runner_delegation = session_manager.create_runner( + llm_delegation_agent, + "DelegationAgent", + app_name_delegation + ) + + runner_excavator = session_manager.create_runner( + llm_excavator_agent, + "ExcavatorAgent", + app_name_excavator + ) + + #print("All runners created successfully.") + + # Create LLM query object + llm_query = LLM_query( + model_name=llm_model_name_extended, + model=llm_model_key, + system_message=system_message_excavator, + action_size=action_size, + session_id=session_id_excavator, + runner=runner_excavator, + user_id=user_id_excavator, + ) + + # Initialize previous actions + prev_actions = None + if config: + import jax.numpy as jnp + prev_actions = jnp.zeros( + (1, config.num_prev_actions), + dtype=jnp.int32 + ) + else: + print("Warning: rl_config is None, prev_actions will not be initialized.") + + # Debug: List all sessions + #session_manager.list_sessions() + + return (prompts, llm_query, runner_partitioning, runner_delegation, prev_actions, session_manager) + +def get_delegation_prompt(prompts, current_observation, context="", ENABLE_INTERVENTION=False): + """Get delegation prompt with current state.""" + try: + obs_str = json.dumps({k: v.tolist() if hasattr(v, 'tolist') else str(v) + for k, v in current_observation.items()}) if isinstance(current_observation, dict) else str(current_observation) + if ENABLE_INTERVENTION: + prompt = prompts.get("delegation", observation=obs_str) + else: + prompt = prompts.get("delegation_no_intervention", observation=obs_str) + if context: + prompt += f"\n\nAdditional context: {context}" + return prompt + except Exception as e: + print(f"Error generating delegation prompt: {e}") + return "Decide between 'delegate_to_rl' or 'delegate_to_llm'." + +def get_excavator_prompt(prompts, direction, bucket_status, position): + """Get excavator action prompt with current state.""" + try: + return prompts.get("excavator_action", + direction=direction, + bucket_status=bucket_status, + position=position) + except Exception as e: + print(f"Error generating excavator prompt: {e}") + return "Choose the best action (0-6) for the current game state." + +async def call_agent_async_master(query: str, image, runner, user_id, session_id, session_manager=None): + """ + Fixed version of call_agent_async_master with better error handling and session verification. 
+ """ + print(f"\n>>> Calling agent with user_id: {user_id}, session_id: {session_id}") + + # Verify session exists if session_manager is provided + if session_manager: + session_info = session_manager.get_session_info(user_id, session_id) + if not session_info: + print(f"WARNING: Session {user_id}_{session_id} not found in session manager") + # Try to recreate session if possible + # This would require more context about the agent and app_name + + # Prepare the user's message in ADK format + text = types.Part.from_text(text=query) + parts = [text] + + if image is not None: + image_data = encode_image(image) + content_image = types.Part.from_bytes(data=image_data, mime_type="image/jpeg") + parts.append(content_image) + + user_content = types.Content(role='user', parts=parts) + + final_response_text = "Agent did not produce a final response." # Default + + try: + # Execute the agent with proper error handling + async for event in runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=user_content + ): + print(f" [Event] Author: {event.author}, Type: {type(event).__name__}, Final: {event.is_final_response()}") + + if event.is_final_response(): + if event.content and event.content.parts: + final_response_text = event.content.parts[0].text + elif event.actions and event.actions.escalate: + final_response_text = f"Agent escalated: {event.error_message or 'No specific message.'}" + break + + except Exception as e: + print(f"Error during agent execution: {e}") + final_response_text = f"Error: {str(e)}" + + print(f"<<< Agent Response: {final_response_text}") + return final_response_text + +def setup_partitions_and_llm(map_index, ORIGINAL_MAP_SIZE, env_manager, config, llm_model_name, llm_model_key, + APP_NAME, USER_ID, SESSION_ID, screen, USE_MANUAL_PARTITIONING=False, + USE_IMAGE_PROMPT=False,MAX_NUM_PARTITIONS=4,USE_EXACT_NUMBER_OF_PARTITIONS=False, USE_RANDOM_PARTITIONING=False, sub_task_seed=58): + """ + Setup_partitions_and_llm with proper session management. 
+ """ + + action_size = 7 + target_map = env_manager.global_maps['target_map'] + + if USE_MANUAL_PARTITIONING: + sub_tasks_manual = compute_manual_subtasks(ORIGINAL_MAP_SIZE, MAX_NUM_PARTITIONS) + elif USE_RANDOM_PARTITIONING: + sub_tasks_manual = compute_random_subtasks_validated( + ORIGINAL_MAP_SIZE, MAX_NUM_PARTITIONS, target_map, + seed=sub_task_seed * (map_index+1), min_targets=1 + ) + else: + sub_tasks_manual = compute_manual_subtasks(ORIGINAL_MAP_SIZE, MAX_NUM_PARTITIONS) + + # Initialize LLM agent with fixed session management + (prompts, llm_query, runner_partitioning, runner_delegation, prev_actions, session_manager) = init_llms( + llm_model_key, llm_model_name, config, action_size, + APP_NAME, USER_ID, f"{SESSION_ID}_map_{map_index}", ORIGINAL_MAP_SIZE, MAX_NUM_PARTITIONS + ) + + sub_tasks_llm = [] + + # ALWAYS initialize partitions - either manual or LLM-generated + if USE_MANUAL_PARTITIONING or USE_RANDOM_PARTITIONING: + print("Using manually or random defined sub-tasks.") + env_manager.initialize_with_fixed_overlaps(sub_tasks_manual) + else: + print("Calling LLM agent for partitioning decision...") + + game_state_image = capture_screen(screen) + current_observation = env_manager.global_env.timestep.observation + + try: + obs_dict = {k: v.tolist() for k, v in current_observation.items()} + observation_str = json.dumps(obs_dict) + except AttributeError: + observation_str = str(current_observation) + # Use file-based prompt + if USE_IMAGE_PROMPT: + if USE_EXACT_NUMBER_OF_PARTITIONS: + prompt = prompts.get("partitioning_exact", + map_size=ORIGINAL_MAP_SIZE, + max_partitions=MAX_NUM_PARTITIONS) + "\n\nCurrent observation: See image" + else: + prompt = prompts.get("partitioning", + map_size=ORIGINAL_MAP_SIZE, + max_partitions=MAX_NUM_PARTITIONS) + "\n\nCurrent observation: See image" + + else: + try: + obs_dict = {k: v.tolist() for k, v in current_observation.items()} + observation_str = json.dumps(obs_dict) + except AttributeError: + observation_str = str(current_observation) + + if USE_EXACT_NUMBER_OF_PARTITIONS: + prompt = prompts.get("partitioning_exact", + map_size=ORIGINAL_MAP_SIZE, + max_partitions=MAX_NUM_PARTITIONS) + "\n\nCurrent observation: See image" + else: + prompt = prompts.get("partitioning", + map_size=ORIGINAL_MAP_SIZE, + max_partitions=MAX_NUM_PARTITIONS) + "\n\nCurrent observation: See image" + + try: + user_id_partitioning = f"{USER_ID}_partitioning" + session_id_partitioning = f"{SESSION_ID}_map_{map_index}_partitioning" + + if USE_IMAGE_PROMPT: + response = asyncio.run(call_agent_async_master( + prompt, game_state_image, runner_partitioning, + user_id_partitioning, session_id_partitioning, session_manager + )) + else: + response = asyncio.run(call_agent_async_master( + prompt, None, runner_partitioning, + user_id_partitioning, session_id_partitioning, session_manager + )) + + llm_response_text = response + + try: + sub_tasks_llm = extract_python_format_data(llm_response_text) + print("Successfully parsed LLM response with tuples preserved") + except ValueError as e: + print(f"Extraction failed: {e}") + sub_tasks_llm = sub_tasks_manual + + except Exception as adk_err: + print(f"Error during PARTITIONING ADK agent communication: {adk_err}") + sub_tasks_llm = sub_tasks_manual + + partition_validation = is_valid_region_list(sub_tasks_llm) + + if partition_validation: + print("Using LLM-generated sub-tasks.") + env_manager.initialize_with_fixed_overlaps(sub_tasks_llm) + else: + print("LLM-generated partitions invalid, falling back to manually defined sub-tasks.") 
+            env_manager.initialize_with_fixed_overlaps(sub_tasks_manual)
+
+    return llm_query, runner_delegation, session_manager, prompts
+
+def extract_subsurface(screen, x_start, y_start, width, height, ORIGINAL_MAP_SIZE, global_env_config, partition_idx):
+    """Extract a subsurface from the screen."""
+
+    try:
+        screen_width, screen_height = screen.get_size()
+
+        # Get the actual tile size from the environment config
+        env_tile_size = global_env_config.tile_size[0].item()
+
+        # Calculate the rendering scale factor
+        # This depends on how the environment renders to the screen
+        render_scale = screen_width / (ORIGINAL_MAP_SIZE * env_tile_size)
+
+        # Convert game world coordinates to screen pixel coordinates
+        screen_x_start = int(x_start * env_tile_size * render_scale)
+        screen_y_start = int(y_start * env_tile_size * render_scale)
+        screen_width_partition = int(width * env_tile_size * render_scale)
+        screen_height_partition = int(height * env_tile_size * render_scale)
+
+        # Clamp the partition rectangle to the screen bounds
+        screen_x_start = max(0, min(screen_x_start, screen_width - 1))
+        screen_y_start = max(0, min(screen_y_start, screen_height - 1))
+        screen_width_partition = min(screen_width_partition, screen_width - screen_x_start)
+        screen_height_partition = min(screen_height_partition, screen_height - screen_y_start)
+
+        if screen_width_partition <= 0 or screen_height_partition <= 0:
+            print("Warning: Invalid partition size, using fallback")
+            subsurface = screen.subsurface((0, 0, min(64, screen_width), min(64, screen_height)))
+        else:
+            subsurface = screen.subsurface((screen_x_start, screen_y_start, screen_width_partition, screen_height_partition))
+
+    except ValueError as e:
+        print(f"Error extracting subsurface for partition {partition_idx}: {e}")
+        fallback_size = min(64, screen_width, screen_height)
+        subsurface = screen.subsurface((0, 0, fallback_size, fallback_size))
+
+    return subsurface
+
+def capture_screen(surface):
+    """Captures the current screen and converts it to an image format."""
+    img_array = pg.surfarray.array3d(surface)
+    #img_array = np.rot90(img_array, k=3)  # Rotate if needed
+    img_array = np.transpose(img_array, (1, 0, 2))  # Correct rotation
+
+    img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+    return img_array
+
+def save_video(frames, output_path, fps=1):
+    """Saves a list of frames as a video."""
+    if len(frames) == 0:
+        print("No frames to save.")
+        return
+
+    height, width, _ = frames[0].shape
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    for frame in frames:
+        out.write(frame)
+    out.release()
+    print(f"Video saved to {output_path}")
+
+def extract_bucket_status(state):
+    """
+    Extract the bucket status from the state.
+
+    Args:
+        state: The current State object.
+
+    Returns:
+        str: The bucket status ('loaded' or 'empty').
+    """
+    # Access the bucket status from the agent's state
+    bucket_status = state.agent.agent_state.loaded
+
+    # Map the status to a human-readable string
+    return "loaded" if bucket_status else "empty"
+
+def base_orientation_to_direction(angle_base):
+    """
+    Convert the discrete base orientation value to a cardinal direction.
+
+    Args:
+        angle_base (int or JAX array): The base orientation value.
+
+    Returns:
+        str: The corresponding cardinal direction ('up', 'right', 'down', 'left').
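+
+    Example (values taken from the mapping below):
+        >>> base_orientation_to_direction(3)
+        'up'
+        >>> base_orientation_to_direction(1)
+        'unknown'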
+ """ + # Convert JAX array to a Python scalar if necessary + if isinstance(angle_base, jax.Array): + angle_base = angle_base.item() + + # Map orientation to cardinal direction + direction_map = { + 0: "right", + 3: "up", + 7: "left", + 11: "down" + } + return direction_map.get(angle_base, "unknown") # Default to 'unknown' if invalid + +def extract_base_orientation(state): + """ + Extract the excavator's base orientation from the state and convert it to a cardinal direction. + + Args: + state: The current State object. + + Returns: + A dictionary containing the base angle and its corresponding cardinal direction. + """ + # Extract the base angle + angle_base = state.agent.agent_state.angle_base + + # Convert the base angle to a cardinal direction + direction = base_orientation_to_direction(angle_base) + + return { + "angle_base": angle_base, + "direction": direction, + } + +def load_experiment_constants(config_file="llm/config_llm.yaml"): + """ + Load all experiment constants from YAML file. + Returns a namespace object with all your constants as attributes. + + Usage: + Replace this: + FORCE_DELEGATE_TO_RL = True + FORCE_DELEGATE_TO_LLM = False + ... + + With this: + constants = load_experiment_constants() + FORCE_DELEGATE_TO_RL = constants.force_delegate_to_rl + FORCE_DELEGATE_TO_LLM = constants.force_delegate_to_llm + ... + """ + + # Default values (same as your original constants) + defaults = { + 'force_delegate_to_rl': True, + 'force_delegate_to_llm': False, + 'llm_call_frequency': 15, + 'use_manual_partitioning': True, + 'max_num_partitions': 2, + 'visualize_partitions': True, + 'use_image_prompt': True, + 'app_name': "ExcavatorGameApp", + 'user_id': "user_1", + 'session_id': "session_001", + 'grid_rendering': True, + 'original_map_size': 128, + 'use_rendering': True, + 'use_display': True, + 'enable_intervention': True, + 'intervention_check_frequency': 15, + 'stuck_detection_window': 10, + 'min_reward_threshold': 0.001, + 'use_random_partitioning': False, + 'use_exact_number_of_partitions': False, + 'save_video': False, + 'fps': 30, + 'compute_bench_stats': False, + 'use_exclusive_assignment': False, # New constant for exclusive assignment + } + + # Try to load from YAML file + if os.path.exists(config_file): + try: + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + # Update defaults with loaded values + defaults.update(config) + #print(f"βœ… Configuration loaded from {config_file}") + except Exception as e: + print(f"⚠️ Error loading {config_file}: {e}. Using defaults.") + else: + print(f"⚠️ Config file {config_file} not found. Using defaults.") + + # Convert to namespace for easy attribute access + class ConfigNamespace: + def __init__(self, config_dict): + for key, value in config_dict.items(): + setattr(self, key, value) + + def __repr__(self): + attrs = [f"{k}={v}" for k, v in self.__dict__.items()] + return f"ConfigNamespace({', '.join(attrs)})" + + return ConfigNamespace(defaults) + +def setup_experiment_config(config_file="llm/config_llm.yaml"): + """ + Direct replacement for your constant definitions. + + Replace this block in your code: + FORCE_DELEGATE_TO_RL = True + FORCE_DELEGATE_TO_LLM = False + LLM_CALL_FREQUENCY = 15 + # ... etc + + With this single line: + FORCE_DELEGATE_TO_RL, FORCE_DELEGATE_TO_LLM, LLM_CALL_FREQUENCY, ... 
= setup_experiment_config() + """ + constants = load_experiment_constants(config_file) + + return ( + constants.force_delegate_to_rl, # FORCE_DELEGATE_TO_RL + constants.force_delegate_to_llm, # FORCE_DELEGATE_TO_LLM + constants.llm_call_frequency, # LLM_CALL_FREQUENCY + constants.use_manual_partitioning, # USE_MANUAL_PARTITIONING + constants.max_num_partitions, # MAX_NUM_PARTITIONS + constants.visualize_partitions, # VISUALIZE_PARTITIONS + constants.use_image_prompt, # USE_IMAGE_PROMPT + constants.app_name, # APP_NAME + constants.user_id, # USER_ID + constants.session_id, # SESSION_ID + constants.grid_rendering, # GRID_RENDERING + constants.original_map_size, # ORIGINAL_MAP_SIZE + constants.use_rendering, # USE_RENDERING + constants.use_display, # USE_DISPLAY + constants.enable_intervention, + constants.intervention_check_frequency, + constants.stuck_detection_window, + constants.min_reward_threshold, + constants.use_random_partitioning, # USE_RANDOM_PARTITIONING + constants.use_exact_number_of_partitions, + constants.save_video, # SAVE_VIDEO + constants.fps, + constants.compute_bench_stats, + constants.use_exclusive_assignment, # USE_EXCLUSIVE_ASSIGNMENT + ) + +def detect_stuck_excavator(partition_state, threshold_steps=10, min_reward_threshold=0.001): + """ + Detect if an excavator is stuck based on recent performance. + + Args: + partition_state: Current partition state dictionary + threshold_steps: Number of recent steps to analyze + min_reward_threshold: Minimum reward expected in the period + + Returns: + dict: Stuck detection result with details + """ + step_count = partition_state['step_count'] + rewards = partition_state['rewards'] + actions = partition_state['actions'] + + # Need minimum steps to evaluate + if step_count < threshold_steps: + return { + 'is_stuck': False, + 'reason': 'insufficient_data', + 'details': f'Only {step_count} steps, need {threshold_steps}' + } + + # Analyze recent performance + recent_rewards = rewards[-threshold_steps:] if len(rewards) >= threshold_steps else rewards + recent_actions = actions[-threshold_steps:] if len(actions) >= threshold_steps else actions + + # Check 1: Very low or negative rewards + total_recent_reward = sum(recent_rewards) + if total_recent_reward < min_reward_threshold: + return { + 'is_stuck': True, + 'reason': 'low_reward', + 'details': f'Total reward in last {len(recent_rewards)} steps: {total_recent_reward:.4f}' + } + + # Check 2: Repetitive actions (agent going in circles) + if len(recent_actions) >= 6: + action_sequence = [int(a[0]) if hasattr(a, '__getitem__') else int(a) for a in recent_actions[-6:]] + # Check for simple repetitive patterns + if len(set(action_sequence)) <= 2: # Only using 2 or fewer different actions + return { + 'is_stuck': True, + 'reason': 'repetitive_actions', + 'details': f'Recent actions: {action_sequence}' + } + + # Check 3: No progress (consistently getting 0 rewards) + zero_reward_count = sum(1 for r in recent_rewards if abs(r) < 0.0001) + if zero_reward_count >= threshold_steps * 0.8: # 80% of recent steps had no reward + return { + 'is_stuck': True, + 'reason': 'no_progress', + 'details': f'{zero_reward_count}/{len(recent_rewards)} recent steps had zero reward' + } + + return { + 'is_stuck': False, + 'reason': 'performing_well', + 'details': f'Recent reward: {total_recent_reward:.4f}, unique actions: {len(set([int(a[0]) if hasattr(a, "__getitem__") else int(a) for a in recent_actions]))}' + } + +def get_intervention_action(partition_state, stuck_info, action_type): + """ + Get an intervention 
action to help unstuck the agent. + + Args: + partition_state: Current partition state + stuck_info: Result from detect_stuck_excavator() + action_type: Action type class (TrackedAction or WheeledAction) + + Returns: + jnp.array: Action to take + """ + # Import action types + from terra.actions import TrackedActionType, WheeledActionType + + # Determine available actions based on action type + if action_type.__name__ == 'TrackedAction': + FORWARD = TrackedActionType.FORWARD + BACKWARD = TrackedActionType.BACKWARD + DO = TrackedActionType.DO + turn_actions = [] # Tracked vehicles don't have turn actions + else: # WheeledAction + FORWARD = WheeledActionType.FORWARD + BACKWARD = WheeledActionType.BACKWARD + CLOCK = WheeledActionType.CLOCK + ANTICLOCK = WheeledActionType.ANTICLOCK + DO = WheeledActionType.DO + turn_actions = [CLOCK, ANTICLOCK] + + # Get recent actions to avoid repeating them + recent_actions = partition_state['actions'][-5:] if len(partition_state['actions']) >= 5 else partition_state['actions'] + recent_action_values = [int(a[0]) if hasattr(a, '__getitem__') else int(a) for a in recent_actions] + + intervention_action = None + + if stuck_info['reason'] == 'repetitive_actions': + # Agent is repeating actions - try something different + print(f" πŸ”§ INTERVENTION: Breaking repetitive pattern") + + # If mostly moving forward/backward, try turning (for wheeled) or DO action + if FORWARD in recent_action_values or BACKWARD in recent_action_values: + if turn_actions and CLOCK not in recent_action_values: + intervention_action = CLOCK + elif turn_actions and ANTICLOCK not in recent_action_values: + intervention_action = ANTICLOCK + else: + intervention_action = DO # Try digging/dumping + else: + # Try moving if not moving recently + intervention_action = FORWARD + + elif stuck_info['reason'] == 'low_reward' or stuck_info['reason'] == 'no_progress': + # Agent making no progress - try strategic actions + print(f" πŸ”§ INTERVENTION: Addressing low progress") + + # Cycle through: turn -> move -> dig -> turn + step_in_cycle = partition_state['step_count'] % 4 + + if step_in_cycle == 0 and turn_actions: + intervention_action = CLOCK + elif step_in_cycle == 1: + intervention_action = FORWARD + elif step_in_cycle == 2: + intervention_action = DO + else: + if turn_actions: + intervention_action = ANTICLOCK + else: + intervention_action = BACKWARD + + else: + # Default intervention - try a different direction + print(f" πŸ”§ INTERVENTION: General unstuck attempt") + if turn_actions: + intervention_action = CLOCK # Turn to face new direction + else: + intervention_action = BACKWARD # Back up + + # Convert to proper format + intervention_action_val = int(intervention_action) + print(f" πŸ”§ INTERVENTION ACTION: {intervention_action_val} (reason: {stuck_info['reason']})") + + return jnp.array([intervention_action_val], dtype=jnp.int32) + +def should_intervene(partition_state, active_partitions, intervention_frequency=15): + """ + Decide if intervention should be considered for this partition. 
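+    Intervention is only evaluated every `intervention_frequency` steps, never
+    before step 10, and only when detect_stuck_excavator() reports the
+    partition as stuck.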
+ + Args: + partition_state: Current partition state + active_partitions: List of active partitions + intervention_frequency: How often to check for intervention + + Returns: + bool: True if intervention should be considered + """ + # Only check every N steps to avoid over-intervening + if partition_state['step_count'] % intervention_frequency != 0: + return False + + # Don't intervene too early + if partition_state['step_count'] < 10: + return False + + # Check if stuck + stuck_info = detect_stuck_excavator(partition_state) + return stuck_info['is_stuck'] + +def check_partition_has_targets(target_map, region_coords, min_targets=1): + """ + Check if a partition region contains dig targets. + + Args: + target_map: Global target map (jnp.ndarray) + region_coords: (y_start, x_start, y_end, x_end) tuple + min_targets: Minimum number of dig targets required + + Returns: + dict: Information about targets in the partition + """ + y_start, x_start, y_end, x_end = region_coords + + # Extract the region from the target map + region_slice = (slice(y_start, y_end + 1), slice(x_start, x_end + 1)) + region_data = target_map[region_slice] + + # Count dig targets (-1) and dump targets (1) + dig_targets = jnp.sum(region_data == -1) + dump_targets = jnp.sum(region_data == 1) + free_space = jnp.sum(region_data == 0) + + has_enough_targets = dig_targets >= min_targets + + return { + 'has_targets': has_enough_targets, + 'dig_targets': int(dig_targets), + 'dump_targets': int(dump_targets), + 'free_space': int(free_space), + 'total_cells': region_data.size, + 'region_coords': region_coords + } + +def filter_empty_partitions(partitions, target_map, min_targets=1): + """ + Filter out partitions that don't have enough dig targets. + + Args: + partitions: List of partition dictionaries + target_map: Global target map from env_manager + min_targets: Minimum dig targets required per partition + + Returns: + tuple: (filtered_partitions, partition_stats) + """ + valid_partitions = [] + partition_stats = [] + + for partition in partitions: + region_coords = partition['region_coords'] + stats = check_partition_has_targets(target_map, region_coords, min_targets) + partition_stats.append(stats) + + if stats['has_targets']: + valid_partitions.append(partition) + print(f" Partition {partition['id']}: {stats['dig_targets']} dig targets, {stats['dump_targets']} dump targets") + else: + print(f" Partition {partition['id']}: Only {stats['dig_targets']} dig targets (minimum: {min_targets}) - SKIPPED") + + return valid_partitions, partition_stats + +def compute_random_subtasks_validated(ORIGINAL_MAP_SIZE, NUM_PARTITIONS, target_map, + seed=None, min_targets=1, max_attempts=100): + """ + Compute random subtasks and validate they contain dig targets. + Will retry with different random splits if partitions are empty. + + Args: + ORIGINAL_MAP_SIZE: Size of the original map (64 or 128) + NUM_PARTITIONS: Number of partitions (1, 2, or 4) + target_map: Global target map from env_manager + seed: Random seed for reproducibility + min_targets: Minimum dig targets required per partition + max_attempts: Maximum attempts to find valid partitions + + Returns: + List of valid partition dictionaries + """ + if seed is not None: + random.seed(seed) + + if ORIGINAL_MAP_SIZE not in [64, 128]: + raise ValueError(f"Unsupported ORIGINAL_MAP_SIZE: {ORIGINAL_MAP_SIZE}. Must be 64 or 128.") + + if NUM_PARTITIONS not in [1, 2, 4]: + raise ValueError("Invalid number of partitions. 
Must be 1, 2 or 4.") + + min_targets = int(jnp.sum(target_map == -1) * 0.3) + print(f"Minimum targets per partition: {min_targets}") + + print(f"\nGenerating {NUM_PARTITIONS} random partitions with target validation...") + + # For single partition, just return the full map if it has targets + if NUM_PARTITIONS == 1: + full_partition = [{ + 'id': 0, + 'region_coords': (0, 0, ORIGINAL_MAP_SIZE - 1, ORIGINAL_MAP_SIZE - 1), + 'start_pos': (ORIGINAL_MAP_SIZE // 2, ORIGINAL_MAP_SIZE // 2), + 'start_angle': 0, + 'status': 'pending' + }] + + valid_partitions, _ = filter_empty_partitions(full_partition, target_map, min_targets) + if valid_partitions: + return valid_partitions + else: + raise ValueError("The entire map has no dig targets!") + + best_partitions = [] + best_valid_count = 0 + + for attempt in range(max_attempts): + print(f" Attempt {attempt + 1}/{max_attempts}...") + + # Generate random partitions + if NUM_PARTITIONS == 2: + partitions = _generate_random_2_partitions(ORIGINAL_MAP_SIZE) + elif NUM_PARTITIONS == 4: + partitions = _generate_random_4_partitions(ORIGINAL_MAP_SIZE) + + # Validate partitions + valid_partitions, stats = filter_empty_partitions(partitions, target_map, min_targets) + + print(f" Generated {len(partitions)} partitions, {len(valid_partitions)} valid") + + # Keep track of the best result so far + if len(valid_partitions) > best_valid_count: + best_partitions = valid_partitions.copy() + best_valid_count = len(valid_partitions) + + # If we got enough valid partitions, we're done + #if len(valid_partitions) >= max(1, NUM_PARTITIONS // 2): # At least half the requested partitions + if len(valid_partitions) == NUM_PARTITIONS: + print(f" Found {len(valid_partitions)} valid partitions after {attempt + 1} attempts") + return valid_partitions + + # If we couldn't find enough good partitions, return the best we found + if best_partitions: + print(f" Returning best result: {len(best_partitions)} valid partitions (out of {NUM_PARTITIONS} requested)") + return best_partitions + else: + # Fallback: return full map as single partition + print(f" No valid partitions found, falling back to single full-map partition") + return [{ + 'id': 0, + 'region_coords': (0, 0, ORIGINAL_MAP_SIZE - 1, ORIGINAL_MAP_SIZE - 1), + 'start_pos': (ORIGINAL_MAP_SIZE // 2, ORIGINAL_MAP_SIZE // 2), + 'start_angle': 0, + 'status': 'pending' + }] + +def _generate_random_2_partitions(ORIGINAL_MAP_SIZE): + """Generate 2 random partitions.""" + is_vertical = random.choice([True, False]) + + if is_vertical: + # Vertical split + min_split = int(ORIGINAL_MAP_SIZE * 0.3) + max_split = int(ORIGINAL_MAP_SIZE * 0.7) + split_x = random.randint(min_split, max_split) + + start_y = ORIGINAL_MAP_SIZE // 2 + start_x1 = split_x // 2 + start_x2 = split_x + (ORIGINAL_MAP_SIZE - split_x) // 2 + + return [ + { + 'id': 0, + 'region_coords': (0, 0, ORIGINAL_MAP_SIZE - 1, split_x - 1), + 'start_pos': (start_y, start_x1), + 'start_angle': 0, + 'status': 'pending' + }, + { + 'id': 1, + 'region_coords': (0, split_x, ORIGINAL_MAP_SIZE - 1, ORIGINAL_MAP_SIZE - 1), + 'start_pos': (start_y, start_x2), + 'start_angle': 0, + 'status': 'pending' + } + ] + else: + # Horizontal split + min_split = int(ORIGINAL_MAP_SIZE * 0.3) + max_split = int(ORIGINAL_MAP_SIZE * 0.7) + split_y = random.randint(min_split, max_split) + + start_x = ORIGINAL_MAP_SIZE // 2 + start_y1 = split_y // 2 + start_y2 = split_y + (ORIGINAL_MAP_SIZE - split_y) // 2 + + return [ + { + 'id': 0, + 'region_coords': (0, 0, split_y - 1, ORIGINAL_MAP_SIZE - 1), + 'start_pos': 
(start_y1, start_x),
+            'start_angle': 0,
+            'status': 'pending'
+        },
+        {
+            'id': 1,
+            'region_coords': (split_y, 0, ORIGINAL_MAP_SIZE - 1, ORIGINAL_MAP_SIZE - 1),
+            'start_pos': (start_y2, start_x),
+            'start_angle': 0,
+            'status': 'pending'
+        }
+    ]
+
+def _generate_random_4_partitions(ORIGINAL_MAP_SIZE):
+    """Generate 4 random partitions in a 2x2 grid."""
+    min_split = int(ORIGINAL_MAP_SIZE * 0.3)
+    max_split = int(ORIGINAL_MAP_SIZE * 0.7)
+
+    split_x = random.randint(min_split, max_split)
+    split_y = random.randint(min_split, max_split)
+
+    # Calculate start positions for each quadrant
+    start_x1 = split_x // 2
+    start_x2 = split_x + (ORIGINAL_MAP_SIZE - split_x) // 2
+    start_y1 = split_y // 2
+    start_y2 = split_y + (ORIGINAL_MAP_SIZE - split_y) // 2
+
+    return [
+        {
+            'id': 0,
+            'region_coords': (0, 0, split_y - 1, split_x - 1),
+            'start_pos': (start_y1, start_x1),
+            'start_angle': 0,
+            'status': 'pending'
+        },
+        {
+            'id': 1,
+            'region_coords': (0, split_x, split_y - 1, ORIGINAL_MAP_SIZE - 1),
+            'start_pos': (start_y1, start_x2),
+            'start_angle': 0,
+            'status': 'pending'
+        },
+        {
+            'id': 2,
+            'region_coords': (split_y, 0, ORIGINAL_MAP_SIZE - 1, split_x - 1),
+            'start_pos': (start_y2, start_x1),
+            'start_angle': 0,
+            'status': 'pending'
+        },
+        {
+            'id': 3,
+            'region_coords': (split_y, split_x, ORIGINAL_MAP_SIZE - 1, ORIGINAL_MAP_SIZE - 1),
+            'start_pos': (start_y2, start_x2),
+            'start_angle': 0,
+            'status': 'pending'
+        }
+    ]
+
+def extract_positions(state):
+    """
+    Extract the current base position and the dig-target positions from the game state.
+
+    Args:
+        state: The current game state object.
+
+    Returns:
+        A tuple containing:
+        - start: The current base position as an (x, y) tuple.
+        - target_positions: A list of (x, y) cells marked as dig targets,
+          which may be empty if the target map contains none.
+    """
+    # Extract the current base position
+    current_position = {
+        "x": state.agent.agent_state.pos_base[0][0],
+        "y": state.agent.agent_state.pos_base[0][1]
+    }
+
+    # Collect the dig-target positions from the target_map
+    target_positions = []
+
+    for x in range(state.world.target_map.map.shape[1]):  # Iterate over rows
+        for y in range(state.world.target_map.map.shape[2]):  # Iterate over columns
+            if state.world.target_map.map[0, x, y] == -1:  # Access the value at (0, x, y)
+                target_positions.append((x, y))
+
+    # Convert the base position to a tuple of ints
+    start = (int(current_position["x"]), int(current_position["y"]))
+
+    return start, target_positions
+
+def find_nearest_target(start, target_positions):
+    """
+    Find the nearest target position to the starting point.
+
+    Args:
+        start (tuple): The starting position as (x, y).
+        target_positions (list of tuples): A list of target positions as (x, y).
+
+    Returns:
+        tuple: The nearest target position as (x, y), or None if the list is empty.
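+
+    Example:
+        >>> find_nearest_target((0, 0), [(3, 4), (1, 1)])
+        (1, 1)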
+ """ + if not target_positions: + return None + + # Calculate the Euclidean distance to each target and find the nearest one + nearest_target = min(target_positions, key=lambda target: (target[0] - start[0])**2 + (target[1] - start[1])**2) + return nearest_target \ No newline at end of file diff --git a/utils/models.py b/utils/models.py index 2afd483..71427c8 100644 --- a/utils/models.py +++ b/utils/models.py @@ -4,7 +4,9 @@ import flax.linen as nn from typing import Sequence, Union from terra.actions import TrackedAction, WheeledAction -from terra.env import TerraEnvBatch +from terra.env import TerraEnvBatch, TerraEnv +from terra.config import BatchConfig + from functools import partial @@ -58,6 +60,55 @@ def get_model_ready(rng, config, env: TerraEnvBatch, speed=False): return model, params +def get_model_ready(rng, config, env: TerraEnv, speed=False): + """Instantiate a model according to obs shape of environment. Using TerraEnv instead of TerraEnvBatch.""" + + num_embeddings_agent = 64 + batch_cfg = BatchConfig() # Get default config + + # Override with small map dimensions + map_width = 64 + map_height = 64 + action_type = batch_cfg.action_type + num_state_obs = batch_cfg.agent.num_state_obs + angles_cabin = batch_cfg.agent.angles_cabin + + jax.debug.print("num_embeddings_agent = {x}", x=num_embeddings_agent) + map_min_max = ( + tuple(config["maps_net_normalization_bounds"]) + if not config["clip_action_maps"] + else (-1, 1) + ) + jax.debug.print("map normalization min max = {x}", x=map_min_max) + model = SimplifiedCoupledCategoricalNet( + num_prev_actions=config["num_prev_actions"], + num_embeddings_agent=num_embeddings_agent, + map_min_max=map_min_max, + local_map_min_max=tuple(config["local_map_normalization_bounds"]), + loaded_max=config["loaded_max"], + action_type=action_type, + ) + + obs = [ + jnp.zeros((config["num_envs"], num_state_obs)), + jnp.zeros((config["num_envs"], angles_cabin)), + jnp.zeros((config["num_envs"], angles_cabin)), + jnp.zeros((config["num_envs"], angles_cabin)), + jnp.zeros((config["num_envs"], angles_cabin)), + jnp.zeros((config["num_envs"], angles_cabin)), + jnp.zeros((config["num_envs"], angles_cabin)), + jnp.zeros((config["num_envs"], map_width, map_height)), + jnp.zeros((config["num_envs"], map_width, map_height)), + jnp.zeros((config["num_envs"], map_width, map_height)), + jnp.zeros((config["num_envs"], config["num_prev_actions"])), + ] + params = model.init(rng, obs) + + print(f"Model: {sum(x.size for x in jax.tree_leaves(params)):,} parameters") + return model, params + + + def load_neural_network(config, env): """Load neural network model based on config and environment.""" rng = jax.random.PRNGKey(0)