Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ dev = [
hint = [
"sentence-transformers>=5.0.0",
]
retrievers = [
"bm25s>=0.2.14",
"langchain>=0.3.27",
]


[project.scripts]
Expand Down
89 changes: 89 additions & 0 deletions src/agentlab/agents/bm25_agent/README.md
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have the bm25_agent and embedding_agent in the focus_agent subdirectory, as they are related baselines. @recursix Do you have any thoughts about this?

Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# BM25Agent

A retrieval-augmented agent that uses the BM25 (Best Match 25) algorithm to filter and retrieve the most relevant parts of the accessibility tree (AXTree) based on the current goal and task history.

## Overview

``BM25Agent`` extends ``GenericAgent`` with intelligent content retrieval capabilities. Instead of processing the entire accessibility tree, it chunks the content and uses BM25 ranking to retrieve only the most relevant sections, reducing token usage and improving focus on task-relevant elements.

## Key Features

- **BM25-based retrieval**: Uses the BM25 algorithm to rank and retrieve relevant content chunks
- **Token-aware chunking**: Splits accessibility trees using tiktoken for optimal token usage
- **Configurable parameters**: Adjustable chunk size, overlap, and top-k retrieval
- **History integration**: Can optionally include task history in retrieval queries
- **Memory efficient**: Reduces context size by filtering irrelevant content

## Architecture

```text
Query (goal + history) → BM25 Retriever → Top-K Chunks → LLM → Action
AXTree
```

## Usage

### Basic Configuration

```python
from agentlab.agents.bm25_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from agentlab.agents.bm25_agent.bm25_retriever import BM25RetrieverArgs
from agentlab.agents.bm25_agent.bm25_retriever_agent import BM25RetrieverAgentFlags

# Configure retriever parameters
retriever_args = BM25RetrieverArgs(
chunk_size=200, # Tokens per chunk
overlap=10, # Token overlap between chunks
top_k=10, # Number of chunks to retrieve
use_recursive_text_splitter=False # Use Langchain text splitter
)

# Configure agent flags
retriever_flags = BM25RetrieverAgentFlags(
use_history=True # Include task history in queries
)

# Create agent
agent_args = BM25RetrieverAgentArgs(
chat_model_args=your_chat_model_args,
flags=your_flags,
retriever_args=retriever_args,
retriever_flags=retriever_flags
)

agent = agent_args.make_agent()
```

### Pre-configured Agents

```python
from agentlab.agents.bm25_agent.agent_configs import (
BM25_RETRIEVER_AGENT, # Chunk size is 200 tokens
BM25_RETRIEVER_AGENT_100 # Chunk size is 100 tokens
)

# Use default configuration
agent = BM25_RETRIEVER_AGENT.make_agent()
```

## Configuration Parameters

### BM25RetrieverArgs

- `chunk_size` (int, default=100): Number of tokens per chunk
- `overlap` (int, default=10): Token overlap between consecutive chunks
- `top_k` (int, default=5): Number of most relevant chunks to retrieve
- `use_recursive_text_splitter` (bool, default=False): Use LangChain's recursive text splitter. Using this text splitter will override the ``chunk_size`` and ``overlap`` parameters.

### BM25RetrieverAgentFlags

- `use_history` (bool, default=False): Include interaction history in retrieval queries

## Citation

If you use this agent in your work, please consider citing:

```bibtex

```
8 changes: 8 additions & 0 deletions src/agentlab/agents/bm25_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .bm25_retriever_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from .bm25_retriever import BM25RetrieverArgs
from .agent_configs import (
BM25_RETRIEVER_AGENT,
BM25_RETRIEVER_AGENT_100,
BM25_RETRIEVER_AGENT_50,
BM25_RETRIEVER_AGENT_500,
)
68 changes: 68 additions & 0 deletions src/agentlab/agents/bm25_agent/agent_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .bm25_retriever import BM25RetrieverArgs
from .bm25_retriever_agent import BM25RetrieverAgentArgs, BM25RetrieverAgentFlags

# Use a private copy so we don't mutate the shared generic-agent flags.
FLAGS_GPT_4o = FLAGS_GPT_4o.copy()
FLAGS_GPT_4o.obs.use_think_history = True


def _bm25_agent_args(agent_name: str, chunk_size: int, overlap: int) -> BM25RetrieverAgentArgs:
    """Build a BM25 retriever agent config differing only in chunking.

    All configs share the same model (gpt-4.1), ``top_k=10``, history-aware
    queries, and the token-based (non-recursive) text splitter.

    Args:
        agent_name: display name of the agent configuration.
        chunk_size: tokens per chunk fed to the BM25 retriever.
        overlap: token overlap between consecutive chunks.
    """
    return BM25RetrieverAgentArgs(
        agent_name=agent_name,
        chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
        flags=FLAGS_GPT_4o,
        retriever_args=BM25RetrieverArgs(
            top_k=10,
            chunk_size=chunk_size,
            overlap=overlap,
            use_recursive_text_splitter=False,
        ),
        retriever_flags=BM25RetrieverAgentFlags(
            use_history=True,
        ),
    )


# Default configuration: 200-token chunks.
BM25_RETRIEVER_AGENT = _bm25_agent_args("BM25RetrieverAgent-4.1", chunk_size=200, overlap=10)
# Variants sweeping the chunk size (the 50-token variant halves the overlap
# to keep it proportional to the chunk size).
BM25_RETRIEVER_AGENT_100 = _bm25_agent_args("BM25RetrieverAgent-4.1-100", chunk_size=100, overlap=10)
BM25_RETRIEVER_AGENT_50 = _bm25_agent_args("BM25RetrieverAgent-4.1-50", chunk_size=50, overlap=5)
BM25_RETRIEVER_AGENT_500 = _bm25_agent_args("BM25RetrieverAgent-4.1-500", chunk_size=500, overlap=10)
72 changes: 72 additions & 0 deletions src/agentlab/agents/bm25_agent/bm25_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import re
from dataclasses import dataclass

try:
    import bm25s
except ImportError as err:
    # Fix: the original hint was missing the `install` verb.
    raise ImportError(
        "bm25s is not installed. Please install it using `pip install 'agentlab[retrievers]'`."
    ) from err
import tiktoken

from .utils import get_chunks_from_tokenizer


# Lazily-created tiktoken encoding shared across calls; resolving the
# encoding with `encoding_for_model` on every call is needlessly expensive.
_GPT4_ENCODING = None


def count_tokens(text: str) -> int:
    """Return the number of tokens in *text* under the GPT-4 tiktoken encoding."""
    global _GPT4_ENCODING
    if _GPT4_ENCODING is None:
        _GPT4_ENCODING = tiktoken.encoding_for_model("gpt-4")
    return len(_GPT4_ENCODING.encode(text))


@dataclass
class BM25RetrieverArgs:
    """Chunking and retrieval parameters for :class:`BM25SRetriever`."""

    # Number of tokens per chunk when splitting the AXTree.
    chunk_size: int = 100
    # Number of tokens shared by consecutive chunks.
    overlap: int = 10
    # Number of most relevant chunks returned per query.
    top_k: int = 5
    # When True, use LangChain's RecursiveCharacterTextSplitter instead of
    # token-based chunking (chunk_size/overlap are then not applied).
    use_recursive_text_splitter: bool = False


class BM25SRetriever:
    """Rank AXTree chunks with BM25S and return the most relevant ones.

    The tree is split into chunks and indexed once at construction time;
    :meth:`retrieve` then scores the chunks against a textual query.
    """

    def __init__(
        self,
        tree: str,
        chunk_size: int,
        overlap: int,
        top_k: int,
        use_recursive_text_splitter: bool,
    ):
        """Index *tree* for retrieval.

        Args:
            tree: serialized accessibility tree to index.
            chunk_size: tokens per chunk (ignored when the recursive splitter is used).
            overlap: token overlap between consecutive chunks.
            top_k: maximum number of chunks returned by :meth:`retrieve`.
            use_recursive_text_splitter: use LangChain's recursive splitter
                instead of token-based chunking.
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.top_k = top_k
        self.use_recursive_text_splitter = use_recursive_text_splitter
        # Bug fix: chunk via create_text_chunks so the configured splitter,
        # chunk_size and overlap are honored (previously the tokenizer
        # defaults were used and these settings were silently ignored).
        corpus = self.create_text_chunks(tree)
        self.retriever = bm25s.BM25(corpus=corpus)
        tokenized_corpus = bm25s.tokenize(corpus)
        self.retriever.index(tokenized_corpus)

    def retrieve(self, query):
        """Return up to ``top_k`` corpus chunks most relevant to *query*."""
        tokenized_query = bm25s.tokenize(query)
        # bm25s cannot return more results than the corpus holds, so clamp k.
        k = min(self.top_k, len(self.retriever.corpus))
        results, _ = self.retriever.retrieve(query_tokens=tokenized_query, k=k)
        return [str(res) for res in results[0]]

    def create_text_chunks(self, axtree, chunk_size=None, overlap=None):
        """Split *axtree* into chunks.

        Args:
            axtree: text to split.
            chunk_size: tokens per chunk; defaults to the instance setting.
                (Bug fix: this argument was previously accepted but ignored.)
            overlap: token overlap; defaults to the instance setting.

        Raises:
            ImportError: if the recursive splitter is requested but
                langchain is not installed.
        """
        if self.use_recursive_text_splitter:
            try:
                from langchain.text_splitter import (
                    RecursiveCharacterTextSplitter,
                )
            except ImportError as err:
                # Fix: the original hint was missing the `install` verb.
                raise ImportError(
                    "langchain is not installed. Please install it using "
                    "`pip install 'agentlab[retrievers]'`."
                ) from err

            text_splitter = RecursiveCharacterTextSplitter()
            return text_splitter.split_text(axtree)
        chunk_size = self.chunk_size if chunk_size is None else chunk_size
        overlap = self.overlap if overlap is None else overlap
        return get_chunks_from_tokenizer(axtree, chunk_size, overlap)
92 changes: 92 additions & 0 deletions src/agentlab/agents/bm25_agent/bm25_retriever_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from copy import copy
from dataclasses import dataclass

from browsergym.experiments import Agent

import agentlab.agents.dynamic_prompting as dp
from agentlab.agents.generic_agent.generic_agent import GenericAgent, GenericAgentArgs
from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags
from agentlab.llm.chat_api import ChatModelArgs

from .bm25_retriever import BM25RetrieverArgs, BM25SRetriever


@dataclass
class BM25RetrieverAgentFlags:
    """Behavior flags specific to the BM25 retriever agent."""

    # When True, the rendered interaction history is appended to the goal
    # to form the BM25 retrieval query.
    use_history: bool = False


@dataclass
class BM25RetrieverAgentArgs(GenericAgentArgs):
    """Configuration object that knows how to build a :class:`BM25RetrieverAgent`."""

    flags: GenericPromptFlags = None
    chat_model_args: ChatModelArgs = None
    retriever_args: BM25RetrieverArgs = None
    retriever_flags: BM25RetrieverAgentFlags = None
    max_retry: int = 4
    agent_name: str = None

    def __post_init__(self):
        # Derive a default agent name from the model when none was given;
        # slashes are replaced so the name is safe to use in file paths.
        if self.agent_name is None:
            raw_name = f"BM25RetrieverAgent-{self.chat_model_args.model_name}"
            self.agent_name = raw_name.replace("/", "_")

    def make_agent(self) -> Agent:
        """Instantiate the agent described by these arguments."""
        return BM25RetrieverAgent(
            chat_model_args=self.chat_model_args,
            flags=self.flags,
            retriever_args=self.retriever_args,
            retriever_flags=self.retriever_flags,
            max_retry=self.max_retry,
        )


class BM25RetrieverAgent(GenericAgent):
    """GenericAgent that replaces the full AXTree with BM25-retrieved chunks.

    Before each action, the accessibility tree is chunked and ranked with
    BM25 against the goal (optionally concatenated with the interaction
    history); only the top-k chunks are shown to the LLM.
    """

    def __init__(
        self,
        chat_model_args: ChatModelArgs,
        flags,
        retriever_args: BM25RetrieverArgs,
        retriever_flags: BM25RetrieverAgentFlags,
        max_retry: int = 4,
    ):
        super().__init__(chat_model_args, flags, max_retry)
        self.retriever_args = retriever_args
        self.retriever_flags = retriever_flags

    def _build_query(self, obs: dict) -> str:
        """Return the retrieval query: the goal, plus the history when enabled."""
        if self.retriever_flags.use_history:
            return obs["goal"] + "\n" + obs["history"]
        return obs["goal"]

    def get_new_obs(self, obs: dict) -> str:
        """Return a pruned tree assembled from the chunks most relevant to the query."""
        query = self._build_query(obs)
        # NOTE(review): assumes obs carries a "pruned_dom" entry when the
        # AXTree is disabled — confirm against the observation pipeline.
        axtree_txt: str = obs["axtree_txt"] if self.flags.obs.use_ax_tree else obs["pruned_dom"]
        # A fresh retriever is built every step because the tree changes
        # with every page interaction.
        retriever = BM25SRetriever(
            axtree_txt,
            chunk_size=self.retriever_args.chunk_size,
            overlap=self.retriever_args.overlap,
            top_k=self.retriever_args.top_k,
            use_recursive_text_splitter=self.retriever_args.use_recursive_text_splitter,
        )
        relevant_chunks = retriever.retrieve(query)
        return "".join(f"\n\nChunk {i}:\n{chunk}" for i, chunk in enumerate(relevant_chunks))

    def get_action(self, obs: dict):
        # Render the history into the observation so it is available both to
        # the retrieval query and to the prompt.
        obs_history_copy = copy(self.obs_history)
        obs_history_copy.append(obs)
        history = dp.History(
            history_obs=obs_history_copy,
            actions=self.actions,
            memories=self.memories,
            thoughts=self.thoughts,
            flags=self.flags.obs,
        )
        obs["history"] = history.prompt
        obs["axtree_txt"] = self.get_new_obs(obs)
        action, info = super().get_action(obs)
        info.extra_info["pruned_tree"] = obs["axtree_txt"]
        # Bug fix: record the query that was actually used for retrieval
        # (previously the history was always appended here, even when
        # use_history was False, so the logged query could be wrong).
        info.extra_info["retriever_query"] = self._build_query(obs)
        return action, info
14 changes: 14 additions & 0 deletions src/agentlab/agents/bm25_agent/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import tiktoken

# GPT-4o tiktoken encoding used for chunking. ``encoder`` and ``tokenizer``
# resolve to the same encoding; both names are kept for backward compatibility.
encoder = tiktoken.encoding_for_model("gpt-4o")
tokenizer = tiktoken.get_encoding(encoder.name)


def get_chunks_from_tokenizer(axtree, chunk_size=200, overlap=50):
    """Split *axtree* into overlapping chunks of at most ``chunk_size`` tokens.

    Args:
        axtree: text to split (typically a serialized accessibility tree).
        chunk_size: maximum number of tokens per chunk; must be positive.
        overlap: number of tokens shared by consecutive chunks.

    Returns:
        list[str]: decoded text chunks; an empty list for empty input.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # Bug fix: guard against overlap >= chunk_size, which previously made the
    # range step zero (ValueError) or negative (silently produced no chunks).
    step = max(1, chunk_size - overlap)
    token_ids = tokenizer.encode(axtree)
    return [
        tokenizer.decode(token_ids[start : start + chunk_size])
        for start in range(0, len(token_ids), step)
    ]
Loading
Loading