@@ -183,3 +183,111 @@ def faulty_filter(contents):
183183 )
184184
185185 assert llm_request .contents == original_contents
186+
187+
@pytest.mark.asyncio
async def test_filter_with_remove_amount():
  """Verifies remove_amount drops an extra invocation beyond the keep count."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
  turns = []
  for i in (1, 2, 3):
    turns.append(_create_content("user", f"user_prompt_{i}"))
    turns.append(_create_content("model", f"model_response_{i}"))
  llm_request = LlmRequest(contents=turns)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # With num_invocations_to_keep=2 and remove_amount=1, only the last two
  # invocations (4 contents) should survive the filter.
  kept_texts = [content.parts[0].text for content in llm_request.contents]
  assert kept_texts == [
      "user_prompt_2",
      "model_response_2",
      "user_prompt_3",
      "model_response_3",
  ]
212+
213+
@pytest.mark.asyncio
async def test_filter_with_higher_remove_amount():
  """Tests remove_amount with a higher value to remove more invocations."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=3, remove_amount=2)
  # Five user/model invocation pairs, two more than num_invocations_to_keep.
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
      _create_content("user", "user_prompt_3"),
      _create_content("model", "model_response_3"),
      _create_content("user", "user_prompt_4"),
      _create_content("model", "model_response_4"),
      _create_content("user", "user_prompt_5"),
      _create_content("model", "model_response_5"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # The filter drops the two oldest invocations, keeping the last 3
  # invocations (6 contents) — consistent with num_invocations_to_keep=3.
  # NOTE(review): remove_amount=2 presumably controls how many invocations
  # are removed per filtering pass rather than reducing the keep count —
  # confirm against ContextFilterPlugin's implementation.
  assert len(llm_request.contents) == 6
  assert llm_request.contents[0].parts[0].text == "user_prompt_3"
  assert llm_request.contents[1].parts[0].text == "model_response_3"
  assert llm_request.contents[2].parts[0].text == "user_prompt_4"
  assert llm_request.contents[3].parts[0].text == "model_response_4"
  assert llm_request.contents[4].parts[0].text == "user_prompt_5"
  assert llm_request.contents[5].parts[0].text == "model_response_5"
245+
246+
@pytest.mark.asyncio
async def test_filter_with_zero_remove_amount():
  """Checks that remove_amount=0 turns the filtering logic off entirely."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1, remove_amount=0)
  history = [
      _create_content(role, text)
      for role, text in (
          ("user", "user_prompt_1"),
          ("model", "model_response_1"),
          ("user", "user_prompt_2"),
          ("model", "model_response_2"),
      )
  ]
  llm_request = LlmRequest(contents=history)
  snapshot = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # Even though the history exceeds num_invocations_to_keep=1, nothing is
  # dropped when remove_amount is zero.
  assert llm_request.contents == snapshot
266+
267+
@pytest.mark.asyncio
async def test_filter_remove_amount_with_multiple_user_turns():
  """Exercises remove_amount when an invocation has multiple user turns."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
  llm_request = LlmRequest(
      contents=[
          _create_content("user", "user_prompt_1"),
          _create_content("model", "model_response_1"),
          _create_content("user", "user_prompt_2a"),
          _create_content("user", "user_prompt_2b"),
          _create_content("model", "model_response_2"),
          _create_content("user", "user_prompt_3"),
          _create_content("model", "model_response_3"),
      ]
  )

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # The last two invocations survive; the second invocation's back-to-back
  # user turns are kept together as a single invocation.
  surviving = [content.parts[0].text for content in llm_request.contents]
  assert surviving == [
      "user_prompt_2a",
      "user_prompt_2b",
      "model_response_2",
      "user_prompt_3",
      "model_response_3",
  ]
0 commit comments