Skip to content

Commit f98b151

Browse files
committed
Fix ContextFilterPlugin to make explicit context caching work better via an N-sized sliding window
1 parent d04e964 commit f98b151

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

src/google/adk/plugins/context_filter_plugin.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def __init__(
3636
num_invocations_to_keep: Optional[int] = None,
3737
custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
3838
name: str = "context_filter_plugin",
39+
remove_amount: int = 1
40+
3941
):
4042
"""Initializes the context management plugin.
4143
@@ -45,10 +47,12 @@ def __init__(
4547
by a model response.
4648
custom_filter: A function to filter the context.
4749
name: The name of the plugin instance.
50+
remove_amount: The number of additional invocations to drop from the context once the sliding-window threshold is reached.
4851
"""
4952
super().__init__(name)
5053
self._num_invocations_to_keep = num_invocations_to_keep
5154
self._custom_filter = custom_filter
55+
self._remove_amount = remove_amount
5256

5357
async def before_model_callback(
5458
self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -60,9 +64,10 @@ async def before_model_callback(
6064
if (
6165
self._num_invocations_to_keep is not None
6266
and self._num_invocations_to_keep > 0
67+
and self._remove_amount > 0
6368
):
6469
num_model_turns = sum(1 for c in contents if c.role == "model")
65-
if num_model_turns >= self._num_invocations_to_keep:
70+
if num_model_turns >= self._num_invocations_to_keep + self._remove_amount - 1:
6671
model_turns_to_find = self._num_invocations_to_keep
6772
split_index = 0
6873
for i in range(len(contents) - 1, -1, -1):

tests/unittests/plugins/test_context_filtering_plugin.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,111 @@ def faulty_filter(contents):
183183
)
184184

185185
assert llm_request.contents == original_contents
186+
187+
188+
@pytest.mark.asyncio
async def test_filter_with_remove_amount():
  """Tests that remove_amount correctly removes additional invocations."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
  # Three full user/model invocations feed the filter.
  turns = [
      ("user", "user_prompt_1"),
      ("model", "model_response_1"),
      ("user", "user_prompt_2"),
      ("model", "model_response_2"),
      ("user", "user_prompt_3"),
      ("model", "model_response_3"),
  ]
  llm_request = LlmRequest(
      contents=[_create_content(role, text) for role, text in turns]
  )

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # With num_invocations_to_keep=2 and remove_amount=1, only the last two
  # invocations survive the filter.
  surviving_texts = [c.parts[0].text for c in llm_request.contents]
  assert surviving_texts == [
      "user_prompt_2",
      "model_response_2",
      "user_prompt_3",
      "model_response_3",
  ]
212+
213+
214+
@pytest.mark.asyncio
async def test_filter_with_higher_remove_amount():
  """Tests remove_amount with a higher value to remove more invocations."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=3, remove_amount=2)
  # Five full user/model invocations so the trim threshold is crossed.
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
      _create_content("user", "user_prompt_3"),
      _create_content("model", "model_response_3"),
      _create_content("user", "user_prompt_4"),
      _create_content("model", "model_response_4"),
      _create_content("user", "user_prompt_5"),
      _create_content("model", "model_response_5"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # Filtering triggers once model turns >= keep + remove - 1 = 4; with 5
  # model turns present, the filter keeps the last
  # num_invocations_to_keep = 3 invocations (invocations 3, 4 and 5).
  assert len(llm_request.contents) == 6
  assert llm_request.contents[0].parts[0].text == "user_prompt_3"
  assert llm_request.contents[1].parts[0].text == "model_response_3"
  assert llm_request.contents[2].parts[0].text == "user_prompt_4"
  assert llm_request.contents[3].parts[0].text == "model_response_4"
  assert llm_request.contents[4].parts[0].text == "user_prompt_5"
  assert llm_request.contents[5].parts[0].text == "model_response_5"
245+
246+
247+
@pytest.mark.asyncio
async def test_filter_with_zero_remove_amount():
  """Tests that remove_amount=0 disables the filtering logic."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1, remove_amount=0)
  turns = [
      ("user", "user_prompt_1"),
      ("model", "model_response_1"),
      ("user", "user_prompt_2"),
      ("model", "model_response_2"),
  ]
  llm_request = LlmRequest(
      contents=[_create_content(role, text) for role, text in turns]
  )
  # Snapshot the request contents so we can verify nothing was removed.
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # A remove_amount of 0 disables filtering entirely, so the context is
  # passed through untouched.
  assert llm_request.contents == original_contents
266+
267+
268+
@pytest.mark.asyncio
async def test_filter_remove_amount_with_multiple_user_turns():
  """Tests remove_amount with multiple user turns in invocations."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
  # The second invocation contains two consecutive user turns (2a and 2b)
  # before its model response.
  turns = [
      ("user", "user_prompt_1"),
      ("model", "model_response_1"),
      ("user", "user_prompt_2a"),
      ("user", "user_prompt_2b"),
      ("model", "model_response_2"),
      ("user", "user_prompt_3"),
      ("model", "model_response_3"),
  ]
  llm_request = LlmRequest(
      contents=[_create_content(role, text) for role, text in turns]
  )

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # The last 2 invocations are kept, and the multi-user-turn invocation is
  # retained in full rather than being split.
  surviving_texts = [c.parts[0].text for c in llm_request.contents]
  assert surviving_texts == [
      "user_prompt_2a",
      "user_prompt_2b",
      "model_response_2",
      "user_prompt_3",
      "model_response_3",
  ]

0 commit comments

Comments
 (0)