Add focus_agent, embedding and bm25 agents #302
Open: imenelydiaker wants to merge 3 commits into ServiceNow:main from imenelydiaker:add-focus-agent
Changes from all commits
File: BM25Agent README (new file, +89 lines)
# BM25Agent

A retrieval-augmented agent that uses the BM25 (Best Matching 25) algorithm to filter and retrieve the most relevant parts of the accessibility tree (AXTree) based on the current goal and task history.

## Overview

`BM25RetrieverAgent` extends `GenericAgent` with content retrieval capabilities. Instead of processing the entire accessibility tree, it chunks the content and uses BM25 ranking to retrieve only the most relevant sections, reducing token usage and keeping the prompt focused on task-relevant elements.

## Key Features

- **BM25-based retrieval**: Uses the BM25 algorithm to rank and retrieve relevant content chunks
- **Token-aware chunking**: Splits accessibility trees into token-based chunks using tiktoken
- **Configurable parameters**: Adjustable chunk size, overlap, and top-k retrieval
- **History integration**: Can optionally include task history in retrieval queries
- **Memory efficient**: Reduces context size by filtering out irrelevant content

## Architecture

```text
Query (goal + history) → BM25 Retriever → Top-K Chunks → LLM → Action
                               ↑
                            AXTree
```
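
To make the pipeline concrete, here is a minimal, self-contained sketch of the retrieval step in isolation, using the same `bm25s` calls as `bm25_retriever.py` (the chunk strings are illustrative placeholders, not real AXTree output):

```python
import bm25s

# Illustrative "chunks" standing in for pieces of an accessibility tree
corpus = [
    "button 'Place order' clickable",
    "textbox 'Email address' required",
    "link 'Return to cart'",
]

# Index the chunks, then rank them against the goal text
retriever = bm25s.BM25(corpus=corpus)
retriever.index(bm25s.tokenize(corpus))

results, scores = retriever.retrieve(query_tokens=bm25s.tokenize("enter my email"), k=2)
top_chunks = [str(doc) for doc in results[0]]  # the 2 most relevant chunks
```
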
## Usage

### Basic Configuration

```python
from agentlab.agents.bm25_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from agentlab.agents.bm25_agent.bm25_retriever import BM25RetrieverArgs
from agentlab.agents.bm25_agent.bm25_retriever_agent import BM25RetrieverAgentFlags

# Configure retriever parameters
retriever_args = BM25RetrieverArgs(
    chunk_size=200,  # Tokens per chunk
    overlap=10,  # Token overlap between consecutive chunks
    top_k=10,  # Number of chunks to retrieve
    use_recursive_text_splitter=False,  # Set to True to use LangChain's recursive splitter instead
)

# Configure agent flags
retriever_flags = BM25RetrieverAgentFlags(
    use_history=True  # Include task history in retrieval queries
)

# Create the agent
agent_args = BM25RetrieverAgentArgs(
    chat_model_args=your_chat_model_args,  # a ChatModelArgs instance
    flags=your_flags,  # a GenericPromptFlags instance
    retriever_args=retriever_args,
    retriever_flags=retriever_flags,
)

agent = agent_args.make_agent()
```
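
The `your_chat_model_args` and `your_flags` placeholders above are not defined by this package. As a concrete example, the pre-defined configs in `agent_configs.py` (shown further down in this PR) fill them in roughly as follows:

```python
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

your_chat_model_args = CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"]
your_flags = FLAGS_GPT_4o
```
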
### Pre-configured Agents

```python
from agentlab.agents.bm25_agent.agent_configs import (
    BM25_RETRIEVER_AGENT,  # 200-token chunks
    BM25_RETRIEVER_AGENT_100,  # 100-token chunks
    BM25_RETRIEVER_AGENT_50,  # 50-token chunks
    BM25_RETRIEVER_AGENT_500,  # 500-token chunks
)

# Use the default configuration
agent = BM25_RETRIEVER_AGENT.make_agent()
```

## Configuration Parameters

### BM25RetrieverArgs

- `chunk_size` (int, default=100): Number of tokens per chunk
- `overlap` (int, default=10): Token overlap between consecutive chunks
- `top_k` (int, default=5): Number of most relevant chunks to retrieve
- `use_recursive_text_splitter` (bool, default=False): Use LangChain's recursive text splitter. Enabling this splitter overrides the `chunk_size` and `overlap` parameters.

### BM25RetrieverAgentFlags

- `use_history` (bool, default=False): Include interaction history in retrieval queries

## Citation

If you use this agent in your work, please consider citing:

```bibtex

```
File: bm25_agent/__init__.py (new file, +8 lines)
from .bm25_retriever_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from .bm25_retriever import BM25RetrieverArgs
from .agent_configs import (
    BM25_RETRIEVER_AGENT,
    BM25_RETRIEVER_AGENT_100,
    BM25_RETRIEVER_AGENT_50,
    BM25_RETRIEVER_AGENT_500,
)
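
Because the package `__init__` re-exports these names, they can also be imported directly from `agentlab.agents.bm25_agent`; a small sketch:

```python
from agentlab.agents.bm25_agent import BM25_RETRIEVER_AGENT

agent = BM25_RETRIEVER_AGENT.make_agent()
```
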
File: bm25_agent/agent_configs.py (new file, +68 lines)
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .bm25_retriever import BM25RetrieverArgs
from .bm25_retriever_agent import BM25RetrieverAgentArgs, BM25RetrieverAgentFlags

FLAGS_GPT_4o = FLAGS_GPT_4o.copy()
FLAGS_GPT_4o.obs.use_think_history = True

BM25_RETRIEVER_AGENT = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=200,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

BM25_RETRIEVER_AGENT_100 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-100",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=100,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

BM25_RETRIEVER_AGENT_50 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-50",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=50,
        overlap=5,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

BM25_RETRIEVER_AGENT_500 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-500",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=500,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)
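
The four configs above differ only in chunking granularity (200, 100, 50, and 500 tokens). A variant with another chunk size can be derived in the same pattern; a sketch of a hypothetical 300-token config (the name and size are made up for illustration):

```python
from agentlab.agents.bm25_agent.bm25_retriever import BM25RetrieverArgs
from agentlab.agents.bm25_agent.bm25_retriever_agent import (
    BM25RetrieverAgentArgs,
    BM25RetrieverAgentFlags,
)
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

# Same flag tweak as in agent_configs.py
flags = FLAGS_GPT_4o.copy()
flags.obs.use_think_history = True

# Hypothetical 300-token variant, following the pattern of the configs above
BM25_RETRIEVER_AGENT_300 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-300",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=flags,
    retriever_args=BM25RetrieverArgs(top_k=10, chunk_size=300, overlap=10),
    retriever_flags=BM25RetrieverAgentFlags(use_history=True),
)
```
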
File: bm25_agent/bm25_retriever.py (new file, +72 lines)
import re
from dataclasses import dataclass

try:
    import bm25s
except ImportError:
    raise ImportError(
        "bm25s is not installed. Please install it using `pip install agentlab[retrievers]`."
    )
import tiktoken

from .utils import get_chunks_from_tokenizer


def count_tokens(text: str) -> int:
    """Count the number of tokens in the text using tiktoken for GPT-4."""
    encoding = tiktoken.encoding_for_model("gpt-4")
    tokens = encoding.encode(text)
    return len(tokens)


@dataclass
class BM25RetrieverArgs:
    chunk_size: int = 100
    overlap: int = 10
    top_k: int = 5
    use_recursive_text_splitter: bool = False


class BM25SRetriever:
    """Simple retriever using BM25S to retrieve the most relevant chunks of an AXTree."""

    def __init__(
        self,
        tree: str,
        chunk_size: int,
        overlap: int,
        top_k: int,
        use_recursive_text_splitter: bool,
    ):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.top_k = top_k
        self.use_recursive_text_splitter = use_recursive_text_splitter
        # Chunk the tree with the configured splitter, then build the BM25 index
        corpus = self.create_text_chunks(tree)
        self.retriever = bm25s.BM25(corpus=corpus)
        tokenized_corpus = bm25s.tokenize(corpus)
        self.retriever.index(tokenized_corpus)

    def retrieve(self, query):
        tokenized_query = bm25s.tokenize(query)
        # Never ask for more chunks than exist in the corpus
        if self.top_k > len(self.retriever.corpus):
            results, _ = self.retriever.retrieve(
                query_tokens=tokenized_query, k=len(self.retriever.corpus)
            )
        else:
            results, _ = self.retriever.retrieve(query_tokens=tokenized_query, k=self.top_k)
        return [str(res) for res in results[0]]

    def create_text_chunks(self, axtree):
        if self.use_recursive_text_splitter:
            try:
                from langchain.text_splitter import RecursiveCharacterTextSplitter
            except ImportError:
                raise ImportError(
                    "langchain is not installed. Please install it using "
                    "`pip install agentlab[retrievers]`."
                )

            text_splitter = RecursiveCharacterTextSplitter()
            return text_splitter.split_text(axtree)
        else:
            return get_chunks_from_tokenizer(axtree, self.chunk_size, self.overlap)
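
Used on its own, `BM25SRetriever` indexes a single AXTree string and returns the top-k chunks as plain strings; a minimal sketch (the tree text is illustrative):

```python
from agentlab.agents.bm25_agent.bm25_retriever import BM25SRetriever

axtree_txt = (
    "RootWebArea 'Checkout'\n"
    "  textbox 'Email address'\n"
    "  button 'Place order'\n"
    "  link 'Return to cart'\n"
)

retriever = BM25SRetriever(
    axtree_txt,
    chunk_size=50,
    overlap=5,
    top_k=2,
    use_recursive_text_splitter=False,
)

# A list of at most 2 strings, ranked by BM25 relevance to the query
chunks = retriever.retrieve("fill in the email field")
```
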
File: bm25_agent/bm25_retriever_agent.py (new file, +92 lines)
from copy import copy
from dataclasses import dataclass

from browsergym.experiments import Agent

import agentlab.agents.dynamic_prompting as dp
from agentlab.agents.generic_agent.generic_agent import GenericAgent, GenericAgentArgs
from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags
from agentlab.llm.chat_api import ChatModelArgs

from .bm25_retriever import BM25RetrieverArgs, BM25SRetriever


@dataclass
class BM25RetrieverAgentFlags:
    use_history: bool = False


@dataclass
class BM25RetrieverAgentArgs(GenericAgentArgs):
    flags: GenericPromptFlags = None
    chat_model_args: ChatModelArgs = None
    retriever_args: BM25RetrieverArgs = None
    retriever_flags: BM25RetrieverAgentFlags = None
    max_retry: int = 4
    agent_name: str = None

    def __post_init__(self):
        if self.agent_name is None:
            self.agent_name = f"BM25RetrieverAgent-{self.chat_model_args.model_name}".replace(
                "/", "_"
            )

    def make_agent(self) -> Agent:
        return BM25RetrieverAgent(
            self.chat_model_args,
            self.flags,
            self.retriever_args,
            self.retriever_flags,
            self.max_retry,
        )


class BM25RetrieverAgent(GenericAgent):
    def __init__(
        self,
        chat_model_args: ChatModelArgs,
        flags,
        retriever_args: BM25RetrieverArgs,
        retriever_flags: BM25RetrieverAgentFlags,
        max_retry: int = 4,
    ):
        super().__init__(chat_model_args, flags, max_retry)
        self.retriever_args = retriever_args
        self.retriever_flags = retriever_flags

    def get_new_obs(self, obs: dict) -> str:
        query = (
            obs["goal"] + "\n" + obs["history"] if self.retriever_flags.use_history else obs["goal"]
        )
        axtree_txt: str = obs["axtree_txt"] if self.flags.obs.use_ax_tree else obs["pruned_dom"]
        # Initialize BM25 retriever with the current observation
        retriever = BM25SRetriever(
            axtree_txt,
            chunk_size=self.retriever_args.chunk_size,
            overlap=self.retriever_args.overlap,
            top_k=self.retriever_args.top_k,
            use_recursive_text_splitter=self.retriever_args.use_recursive_text_splitter,
        )
        # Retrieve the most relevant chunks
        relevant_chunks = retriever.retrieve(query)
        new_tree = ""
        for i, chunk in enumerate(relevant_chunks):
            new_tree += f"\n\nChunk {i}:\n{chunk}"
        return new_tree

    def get_action(self, obs: dict):
        obs_history_copy = copy(self.obs_history)
        obs_history_copy.append(obs)
        history = dp.History(
            history_obs=obs_history_copy,
            actions=self.actions,
            memories=self.memories,
            thoughts=self.thoughts,
            flags=self.flags.obs,
        )
        obs["history"] = history.prompt
        obs["axtree_txt"] = self.get_new_obs(obs)
        action, info = super().get_action(obs)
        info.extra_info["pruned_tree"] = obs["axtree_txt"]
        info.extra_info["retriever_query"] = obs["goal"] + "\n" + obs["history"]
        return action, info
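
For debugging, `get_action` stores the pruned tree and the retrieval query in `info.extra_info`. The pruned observation itself is just the retrieved chunks concatenated with `Chunk i:` markers, as this small standalone sketch of the formatting loop shows (chunk texts are placeholders):

```python
relevant_chunks = ["textbox 'Email address'", "button 'Place order'"]

# Same formatting loop as in get_new_obs above
new_tree = ""
for i, chunk in enumerate(relevant_chunks):
    new_tree += f"\n\nChunk {i}:\n{chunk}"

print(new_tree)
# Chunk 0:
# textbox 'Email address'
#
# Chunk 1:
# button 'Place order'
```
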
File: bm25_agent/utils.py (new file, +14 lines)
import tiktoken

encoder = tiktoken.encoding_for_model("gpt-4o")
tokenizer = tiktoken.get_encoding(encoder.name)


def get_chunks_from_tokenizer(axtree, chunk_size=200, overlap=50):
    all_text = tokenizer.encode(axtree)
    chunks = []
    for i in range(0, len(all_text), chunk_size - overlap):
        tokens = all_text[i : i + chunk_size]
        chunk = tokenizer.decode(tokens)
        chunks.append(chunk)
    return chunks
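
As a quick sanity check on the chunking behaviour, the helper can be run on any string; a short sketch (exact counts depend on the tokenizer, so the numbers in the comments are indicative only):

```python
from agentlab.agents.bm25_agent.utils import get_chunks_from_tokenizer, tokenizer

text = "link 'Product page'\n" * 200  # repetitive stand-in for a long AXTree dump

chunks = get_chunks_from_tokenizer(text, chunk_size=50, overlap=10)

print(len(tokenizer.encode(text)))  # total tokens in the input
print(len(chunks))                  # number of windows; the step is chunk_size - overlap = 40 tokens
print(chunks[1][:60])               # each chunk is decoded text, overlapping the previous one by ~10 tokens
```
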
Review comment: It would be nice to have the bm25_agent and embedding_agent in the focus_agent subdirectory, as they are related baselines. @recursix Do you have any thoughts about this?