Skip to content

Commit 2e3eb1a

Browse files
authored
Frame ID map for multi-tab support (API) (#174)
1 parent f68e86c commit 2e3eb1a

File tree

5 files changed

+538
-5
lines changed

5 files changed

+538
-5
lines changed

.changeset/vegan-intrepid-mustang.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"stagehand": patch
3+
---
4+
5+
Added `frame_id_map` to support multi-tab handling via the API.

stagehand/context.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ def __init__(self, context: BrowserContext, stagehand):
1414
# Use a weak key dictionary to map Playwright Pages to our StagehandPage wrappers
1515
self.page_map = weakref.WeakKeyDictionary()
1616
self.active_stagehand_page = None
17+
# Map frame IDs to StagehandPage instances
18+
self.frame_id_map = {}
1719

1820
async def new_page(self) -> StagehandPage:
1921
pw_page: Page = await self._context.new_page()
@@ -23,9 +25,13 @@ async def new_page(self) -> StagehandPage:
2325

2426
async def create_stagehand_page(self, pw_page: Page) -> StagehandPage:
2527
# Create a StagehandPage wrapper for the given Playwright page
26-
stagehand_page = StagehandPage(pw_page, self.stagehand)
28+
stagehand_page = StagehandPage(pw_page, self.stagehand, self)
2729
await self.inject_custom_scripts(pw_page)
2830
self.page_map[pw_page] = stagehand_page
31+
32+
# Initialize frame tracking for this page
33+
await self._attach_frame_navigated_listener(pw_page, stagehand_page)
34+
2935
return stagehand_page
3036

3137
async def inject_custom_scripts(self, pw_page: Page):
@@ -69,9 +75,21 @@ def set_active_page(self, stagehand_page: StagehandPage):
6975
def get_active_page(self) -> StagehandPage:
7076
return self.active_stagehand_page
7177

78+
def register_frame_id(self, frame_id: str, page: StagehandPage):
    """
    Record which StagehandPage is associated with a frame ID.

    Args:
        frame_id (str): Frame ID used as the lookup key.
        page (StagehandPage): Wrapper stored under that key. An existing
            entry for the same frame ID is overwritten.
    """
    frame_lookup = self.frame_id_map
    frame_lookup[frame_id] = page
81+
82+
def unregister_frame_id(self, frame_id: str):
    """
    Remove a frame ID from the frame-ID-to-page mapping.

    Frame IDs that are not present are ignored, so repeating the call is
    harmless.

    Args:
        frame_id (str): Frame ID whose entry should be dropped.
    """
    # pop() with a default removes the entry in a single lookup and does
    # not raise when the ID was never registered.
    self.frame_id_map.pop(frame_id, None)
86+
87+
def get_stagehand_page_by_frame_id(self, frame_id: str) -> "StagehandPage | None":
    """
    Look up the StagehandPage registered for a frame ID.

    Args:
        frame_id (str): Frame ID to look up.

    Returns:
        The StagehandPage stored under ``frame_id``, or None when no entry
        exists for it.
    """
    # dict.get() yields None on a miss, which is why the return type is optional
    return self.frame_id_map.get(frame_id)
90+
7291
@classmethod
7392
async def init(cls, context: BrowserContext, stagehand):
74-
stagehand.logger.debug("StagehandContext.init() called", category="context")
7593
instance = cls(context, stagehand)
7694
# Pre-initialize StagehandPages for any existing pages
7795
stagehand.logger.debug(
@@ -150,3 +168,67 @@ async def wrapped_pages():
150168

151169
return wrapped_pages
152170
return attr
171+
172+
async def _attach_frame_navigated_listener(
    self, pw_page: Page, stagehand_page: StagehandPage
):
    """
    Keep a page's root frame ID current by listening to CDP events.

    Opens a CDP session for ``pw_page``, records the current root frame ID
    on ``stagehand_page`` and in the frame ID map, then updates both whenever
    a ``Page.frameNavigated`` event reports a different root frame ID. The
    page's frame ID is unregistered when the page closes. Any exception
    raised during setup is logged and not re-raised.
    This mirrors the TypeScript implementation's frame tracking.

    Args:
        pw_page (Page): Playwright page to observe.
        stagehand_page (StagehandPage): Wrapper whose frame ID is kept
            in sync.
    """

    def handle_frame_navigated(params):
        """Handle Page.frameNavigated events"""
        frame = params.get("frame", {})
        frame_id = frame.get("id")

        # Only root frames are tracked: anything with a parent, or without
        # an ID, is ignored.
        if frame.get("parentId") or not frame_id:
            return

        old_id = stagehand_page.frame_id
        if frame_id == old_id:
            # Same root frame as before; nothing to re-key
            return

        # Drop the stale key before storing the new one
        if old_id:
            self.unregister_frame_id(old_id)

        self.register_frame_id(frame_id, stagehand_page)
        stagehand_page.update_root_frame_id(frame_id)

        self.stagehand.logger.debug(
            f"Frame navigated from {old_id} to {frame_id}",
            category="context",
        )

    def handle_page_close():
        """Remove the page's frame ID from the map once the page is closed."""
        closing_id = stagehand_page.frame_id
        if closing_id:
            self.unregister_frame_id(closing_id)

    try:
        # Create CDP session for the page
        cdp_session = await self._context.new_cdp_session(pw_page)
        await cdp_session.send("Page.enable")

        # Read the current root frame ID out of the frame tree response
        tree_response = await cdp_session.send("Page.getFrameTree")
        root_frame = tree_response.get("frameTree", {}).get("frame", {})
        root_frame_id = root_frame.get("id")

        if root_frame_id:
            # Seed the page and the map with the initial frame ID
            stagehand_page.update_root_frame_id(root_frame_id)
            self.register_frame_id(root_frame_id, stagehand_page)

        # Follow later root-frame changes, and clean up on close
        cdp_session.on("Page.frameNavigated", handle_frame_navigated)
        pw_page.once("close", handle_page_close)

    except Exception as e:
        self.stagehand.logger.error(
            f"Failed to attach frame navigation listener: {str(e)}",
            category="context",
        )

stagehand/page.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import time
13
from typing import Optional, Union
24

35
from playwright.async_api import CDPSession, Page
@@ -26,16 +28,29 @@ class StagehandPage:
2628

2729
_cdp_client: Optional[CDPSession] = None
2830

29-
def __init__(self, page: Page, stagehand_client):
31+
def __init__(self, page: Page, stagehand_client, context=None):
    """
    Set up a StagehandPage around a Playwright page.

    Args:
        page (Page): The underlying Playwright page.
        stagehand_client: The client used to interface with the Stagehand server.
        context: The StagehandContext instance (optional).
    """
    self._page, self._stagehand = page, stagehand_client
    # The root frame ID starts out unset; update_root_frame_id() fills it in
    self._context, self._frame_id = context, None
44+
45+
@property
def frame_id(self) -> Optional[str]:
    """Return the root frame ID currently stored on this page (may be None)."""
    current_id = self._frame_id
    return current_id
49+
50+
def update_root_frame_id(self, new_id: str):
    """
    Store a new root frame ID on this page and log it at debug level.

    Args:
        new_id (str): Frame ID to record as the page's root frame.
    """
    self._frame_id = new_id
    logger = self._stagehand.logger
    logger.debug(f"Updated frame ID to {new_id}", category="page")
3954

4055
# TODO try catch here
4156
async def ensure_injection(self):
@@ -98,6 +113,10 @@ async def goto(
98113
if options:
99114
payload["options"] = options
100115

116+
# Add frame ID if available
117+
if self._frame_id:
118+
payload["frameId"] = self._frame_id
119+
101120
lock = self._stagehand._get_lock_for_session()
102121
async with lock:
103122
result = await self._stagehand._execute("navigate", payload)
@@ -168,6 +187,10 @@ async def act(
168187
result = await self._act_handler.act(payload)
169188
return result
170189

190+
# Add frame ID if available
191+
if self._frame_id:
192+
payload["frameId"] = self._frame_id
193+
171194
lock = self._stagehand._get_lock_for_session()
172195
async with lock:
173196
result = await self._stagehand._execute("act", payload)
@@ -237,6 +260,10 @@ async def observe(
237260

238261
return result
239262

263+
# Add frame ID if available
264+
if self._frame_id:
265+
payload["frameId"] = self._frame_id
266+
240267
lock = self._stagehand._get_lock_for_session()
241268
async with lock:
242269
result = await self._stagehand._execute("observe", payload)
@@ -361,6 +388,10 @@ async def extract(
361388
return result.data
362389

363390
# Use API
391+
# Add frame ID if available
392+
if self._frame_id:
393+
payload["frameId"] = self._frame_id
394+
364395
lock = self._stagehand._get_lock_for_session()
365396
async with lock:
366397
result_dict = await self._stagehand._execute("extract", payload)
@@ -487,8 +518,6 @@ async def _wait_for_settled_dom(self, timeout_ms: int = None):
487518
timeout_ms (int, optional): Maximum time to wait in milliseconds.
488519
If None, uses the stagehand client's dom_settle_timeout_ms.
489520
"""
490-
import asyncio
491-
import time
492521

493522
timeout = timeout_ms or getattr(self._stagehand, "dom_settle_timeout_ms", 30000)
494523
client = await self.get_cdp_client()

0 commit comments

Comments
 (0)