Skip to content

Commit 05b4481

Browse files
committed
add unit tests
1 parent ae84338 commit 05b4481

File tree

1 file changed

+240
-0
lines changed

1 file changed

+240
-0
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import asyncio
2+
from datetime import datetime
3+
from typing import Any
4+
5+
import pytest
6+
7+
from agents import Agent, GuardrailFunctionOutput, InputGuardrail, Runner, RunContextWrapper
8+
from agents.items import TResponseInputItem
9+
from agents.exceptions import InputGuardrailTripwireTriggered
10+
11+
from .fake_model import FakeModel
12+
from openai.types.responses import ResponseCompletedEvent
13+
from .test_responses import get_text_message
14+
15+
16+
def make_input_guardrail(delay_seconds: float, *, trip: bool) -> InputGuardrail[Any]:
17+
async def guardrail(
18+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
19+
) -> GuardrailFunctionOutput:
20+
# Simulate variable guardrail completion timing.
21+
if delay_seconds > 0:
22+
await asyncio.sleep(delay_seconds)
23+
return GuardrailFunctionOutput(
24+
output_info={"delay": delay_seconds}, tripwire_triggered=trip
25+
)
26+
27+
# Name helps assertions/debugging and ensures deterministic identity.
28+
name = "tripping_input_guardrail" if trip else "delayed_input_guardrail"
29+
return InputGuardrail(guardrail_function=guardrail, name=name)
30+
31+
32+
@pytest.mark.asyncio
33+
@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2])
34+
async def test_run_streamed_input_guardrail_timing_is_consistent(guardrail_delay: float):
35+
"""Ensure streaming behavior matches whether input guardrail finishes before or after LLM stream.
36+
37+
We verify that:
38+
- The sequence of streamed event types is identical.
39+
- Final output matches.
40+
- Exactly one input guardrail result is recorded and does not trigger.
41+
"""
42+
43+
# Arrange: Agent with a single text output and a delayed input guardrail
44+
model = FakeModel()
45+
model.set_next_output([get_text_message("Final response")])
46+
47+
agent = Agent(
48+
name="TimingAgent",
49+
model=model,
50+
input_guardrails=[make_input_guardrail(guardrail_delay, trip=False)],
51+
)
52+
53+
# Act: Run streamed and collect event types
54+
result = Runner.run_streamed(agent, input="Hello")
55+
event_types: list[str] = []
56+
57+
async for event in result.stream_events():
58+
event_types.append(event.type)
59+
60+
# Assert: Guardrail results populated and identical behavioral outcome
61+
assert len(result.input_guardrail_results) == 1, "Expected exactly one input guardrail result"
62+
assert result.input_guardrail_results[0].guardrail.get_name() == "delayed_input_guardrail", (
63+
"Guardrail name mismatch"
64+
)
65+
assert result.input_guardrail_results[0].output.tripwire_triggered is False, (
66+
"Guardrail should not trigger in this test"
67+
)
68+
69+
# Final output should be the text from the model's single message
70+
assert result.final_output == "Final response"
71+
72+
# Minimal invariants on event sequence to ensure stability across timing
73+
# Must start with agent update and include raw response events
74+
assert len(event_types) >= 3, f"Unexpectedly few events: {event_types}"
75+
assert event_types[0] == "agent_updated_stream_event"
76+
# Ensure we observed raw response events in the stream irrespective of guardrail timing
77+
assert any(t == "raw_response_event" for t in event_types)
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_run_streamed_input_guardrail_sequences_match_between_fast_and_slow():
82+
"""Run twice with fast vs slow input guardrail and compare event sequences exactly."""
83+
84+
async def run_once(delay: float) -> list[str]:
85+
model = FakeModel()
86+
model.set_next_output([get_text_message("Final response")])
87+
agent = Agent(
88+
name="TimingAgent",
89+
model=model,
90+
input_guardrails=[make_input_guardrail(delay, trip=False)],
91+
)
92+
result = Runner.run_streamed(agent, input="Hello")
93+
events: list[str] = []
94+
async for ev in result.stream_events():
95+
events.append(ev.type)
96+
return events
97+
98+
events_fast = await run_once(0.0)
99+
events_slow = await run_once(0.2)
100+
101+
assert events_fast == events_slow, (
102+
f"Event sequences differ between guardrail timings:\nfast={events_fast}\nslow={events_slow}"
103+
)
104+
105+
106+
# make_tripping_input_guardrail merged into make_input_guardrail
107+
108+
109+
@pytest.mark.asyncio
110+
@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2])
111+
async def test_run_streamed_input_guardrail_tripwire_raises(guardrail_delay: float):
112+
"""Guardrail tripwire must raise from stream_events regardless of timing."""
113+
114+
model = FakeModel()
115+
model.set_next_output([get_text_message("Final response")])
116+
117+
agent = Agent(
118+
name="TimingAgentTrip",
119+
model=model,
120+
input_guardrails=[make_input_guardrail(guardrail_delay, trip=True)],
121+
)
122+
123+
result = Runner.run_streamed(agent, input="Hello")
124+
125+
with pytest.raises(InputGuardrailTripwireTriggered) as excinfo:
126+
async for _ in result.stream_events():
127+
pass
128+
129+
# Exception contains the guardrail result and run data
130+
exc = excinfo.value
131+
assert exc.guardrail_result.output.tripwire_triggered is True
132+
assert exc.run_data is not None
133+
assert len(exc.run_data.input_guardrail_results) == 1
134+
assert (
135+
exc.run_data.input_guardrail_results[0].guardrail.get_name() == "tripping_input_guardrail"
136+
)
137+
138+
139+
class SlowCompleteFakeModel(FakeModel):
140+
"""A FakeModel that delays just before emitting ResponseCompletedEvent in streaming."""
141+
142+
def __init__(self, delay_seconds: float, tracing_enabled: bool = True):
143+
super().__init__(tracing_enabled=tracing_enabled)
144+
self._delay_seconds = delay_seconds
145+
146+
async def stream_response(self, *args, **kwargs): # type: ignore[override]
147+
async for ev in super().stream_response(*args, **kwargs):
148+
if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0:
149+
await asyncio.sleep(self._delay_seconds)
150+
yield ev
151+
152+
153+
def _get_span_by_type(spans, span_type: str):
154+
for s in spans:
155+
exported = s.export()
156+
if not exported:
157+
continue
158+
if exported.get("span_data", {}).get("type") == span_type:
159+
return s
160+
return None
161+
162+
163+
def _iso(s: str | None) -> datetime:
164+
assert s is not None
165+
return datetime.fromisoformat(s)
166+
167+
168+
@pytest.mark.asyncio
169+
async def test_parent_span_and_trace_finish_after_slow_input_guardrail():
170+
"""Agent span and trace finish after guardrail when guardrail completes last."""
171+
172+
model = FakeModel(tracing_enabled=True)
173+
model.set_next_output([get_text_message("Final response")])
174+
agent = Agent(
175+
name="TimingAgentTrace",
176+
model=model,
177+
input_guardrails=[make_input_guardrail(0.2, trip=False)], # guardrail slower than model
178+
)
179+
180+
result = Runner.run_streamed(agent, input="Hello")
181+
async for _ in result.stream_events():
182+
pass
183+
184+
from .testing_processor import fetch_ordered_spans
185+
186+
spans = fetch_ordered_spans()
187+
agent_span = _get_span_by_type(spans, "agent")
188+
guardrail_span = _get_span_by_type(spans, "guardrail")
189+
generation_span = _get_span_by_type(spans, "generation")
190+
191+
assert agent_span and guardrail_span and generation_span, (
192+
"Expected agent, guardrail, generation spans"
193+
)
194+
195+
# Agent span must finish last
196+
assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at)
197+
assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at)
198+
199+
# Trace should end after all spans end
200+
from .testing_processor import fetch_events
201+
202+
events = fetch_events()
203+
assert events[-1] == "trace_end"
204+
205+
206+
@pytest.mark.asyncio
207+
async def test_parent_span_and_trace_finish_after_slow_model():
208+
"""Agent span and trace finish after model when model completes last."""
209+
210+
model = SlowCompleteFakeModel(delay_seconds=0.2, tracing_enabled=True)
211+
model.set_next_output([get_text_message("Final response")])
212+
agent = Agent(
213+
name="TimingAgentTrace",
214+
model=model,
215+
input_guardrails=[make_input_guardrail(0.0, trip=False)], # guardrail faster than model
216+
)
217+
218+
result = Runner.run_streamed(agent, input="Hello")
219+
async for _ in result.stream_events():
220+
pass
221+
222+
from .testing_processor import fetch_ordered_spans
223+
224+
spans = fetch_ordered_spans()
225+
agent_span = _get_span_by_type(spans, "agent")
226+
guardrail_span = _get_span_by_type(spans, "guardrail")
227+
generation_span = _get_span_by_type(spans, "generation")
228+
229+
assert agent_span and guardrail_span and generation_span, (
230+
"Expected agent, guardrail, generation spans"
231+
)
232+
233+
# Agent span must finish last
234+
assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at)
235+
assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at)
236+
237+
from .testing_processor import fetch_events
238+
239+
events = fetch_events()
240+
assert events[-1] == "trace_end"

0 commit comments

Comments
 (0)