Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

from ..exceptions import AgentExecutionException
from ._agent_executor import AgentExecutor
from ._checkpoint import CheckpointStorage
from ._events import (
AgentRunUpdateEvent,
Expand Down Expand Up @@ -141,7 +142,8 @@ async def run(
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
used to load and restore the checkpoint. When provided without checkpoint_id,
enables checkpointing for this run.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments passed through to underlying workflow
and ai_function tools.

Returns:
The final workflow response as an AgentRunResponse.
Expand All @@ -153,7 +155,7 @@ async def run(
response_id = str(uuid.uuid4())

async for update in self._run_stream_impl(
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs
):
response_updates.append(update)

Expand Down Expand Up @@ -187,7 +189,8 @@ async def run_stream(
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
used to load and restore the checkpoint. When provided without checkpoint_id,
enables checkpointing for this run.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments passed through to underlying workflow
and ai_function tools.

Yields:
AgentRunResponseUpdate objects representing the workflow execution progress.
Expand All @@ -198,7 +201,7 @@ async def run_stream(
response_id = str(uuid.uuid4())

async for update in self._run_stream_impl(
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs
):
response_updates.append(update)
yield update
Expand All @@ -216,6 +219,7 @@ async def _run_stream_impl(
thread: AgentThread,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
"""Internal implementation of streaming execution.

Expand All @@ -225,6 +229,8 @@ async def _run_stream_impl(
thread: The conversation thread containing message history.
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and ai_function tools.

Yields:
AgentRunResponseUpdate objects representing the workflow execution progress.
Expand Down Expand Up @@ -255,6 +261,7 @@ async def _run_stream_impl(
message=None,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
**kwargs,
)
else:
# Execute workflow with streaming (initial run or no function responses)
Expand All @@ -268,6 +275,7 @@ async def _run_stream_impl(
event_stream = self.workflow.run_stream(
message=conversation_messages,
checkpoint_storage=checkpoint_storage,
**kwargs,
)

# Process events from the stream
Expand All @@ -286,22 +294,38 @@ def _convert_workflow_event_to_agent_update(

AgentRunUpdateEvent, RequestInfoEvent, and WorkflowOutputEvent are processed.
Other workflow events are ignored as they are workflow-internal.

For AgentRunUpdateEvent from AgentExecutor instances, only events from executors
with output_response=True are converted to agent updates. This prevents agent
responses from executors that were not explicitly marked to surface their output.
Non-AgentExecutor executors that emit AgentRunUpdateEvent directly are allowed
through since they explicitly chose to emit the event.
"""
match event:
case AgentRunUpdateEvent(data=update):
# Direct pass-through of update in an agent streaming event
case AgentRunUpdateEvent(data=update, executor_id=executor_id):
# For AgentExecutor instances, only pass through if output_response=True.
# Non-AgentExecutor executors that emit AgentRunUpdateEvent are allowed through.
executor = self.workflow.executors.get(executor_id)
if isinstance(executor, AgentExecutor) and not executor.output_response:
return None
if update:
return update
return None

case WorkflowOutputEvent(data=data, source_executor_id=source_executor_id):
# Convert workflow output to an agent response update.
# Handle different data types appropriately.

# Skip AgentRunResponse from AgentExecutor with output_response=True
# since streaming events already surfaced the content.
if isinstance(data, AgentRunResponse):
executor = self.workflow.executors.get(source_executor_id)
if isinstance(executor, AgentExecutor) and executor.output_response:
return None

if isinstance(data, AgentRunResponseUpdate):
# Already an update, pass through
return data
if isinstance(data, ChatMessage):
# Convert ChatMessage to update
return AgentRunResponseUpdate(
contents=list(data.contents),
role=data.role,
Expand All @@ -311,15 +335,9 @@ def _convert_workflow_event_to_agent_update(
created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
raw_representation=data,
)
# Determine contents based on data type
if isinstance(data, BaseContent):
# Already a content type (TextContent, ImageContent, etc.)
contents: list[Contents] = [cast(Contents, data)]
elif isinstance(data, str):
contents = [TextContent(text=data)]
else:
# Fallback: convert to string representation
contents = [TextContent(text=str(data))]
contents = self._extract_contents(data)
if not contents:
return None
return AgentRunResponseUpdate(
contents=contents,
role=Role.ASSISTANT,
Expand Down Expand Up @@ -405,6 +423,18 @@ def _extract_function_responses(self, input_messages: list[ChatMessage]) -> dict
raise AgentExecutionException("Unexpected content type while awaiting request info responses.")
return function_responses

def _extract_contents(self, data: Any) -> list[Contents]:
"""Recursively extract Contents from workflow output data."""
if isinstance(data, ChatMessage):
return list(data.contents)
if isinstance(data, list):
return [c for item in data for c in self._extract_contents(item)]
if isinstance(data, BaseContent):
return [cast(Contents, data)]
if isinstance(data, str):
return [TextContent(text=data)]
return [TextContent(text=str(data))]

class _ResponseState(TypedDict):
"""State for grouping response updates by message_id."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def __init__(
self._output_response = output_response
self._cache: list[ChatMessage] = []

@property
def output_response(self) -> bool:
"""Whether this executor yields AgentRunResponse as workflow output when complete."""
return self._output_response

@property
def workflow_output_types(self) -> list[type[Any]]:
# Override to declare AgentRunResponse as a possible output type only if enabled.
Expand Down
180 changes: 180 additions & 0 deletions python/packages/core/tests/workflow/test_workflow_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.

import uuid
from collections.abc import AsyncIterable
from typing import Any

import pytest

from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
AgentRunUpdateEvent,
Expand Down Expand Up @@ -422,6 +424,48 @@ async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContex
assert isinstance(updates[2].raw_representation, CustomData)
assert updates[2].raw_representation.value == 42

async def test_workflow_as_agent_yield_output_with_list_of_chat_messages(self) -> None:
"""Test that yield_output with list[ChatMessage] extracts contents from all messages.

Note: TextContent items are coalesced by _finalize_response, so multiple text contents
become a single merged TextContent in the final response.
"""

@executor
async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
# Yield a list of ChatMessages (as SequentialBuilder does)
msg_list = [
ChatMessage(role=Role.USER, contents=[TextContent(text="first message")]),
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="second message")]),
ChatMessage(
role=Role.ASSISTANT,
contents=[TextContent(text="third"), TextContent(text="fourth")],
),
]
await ctx.yield_output(msg_list)

workflow = WorkflowBuilder().set_start_executor(list_yielding_executor).build()
agent = workflow.as_agent("list-msg-agent")

# Verify streaming returns the update with all 4 contents before coalescing
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream("test"):
updates.append(update)

assert len(updates) == 1
assert len(updates[0].contents) == 4
texts = [c.text for c in updates[0].contents if isinstance(c, TextContent)]
assert texts == ["first message", "second message", "third", "fourth"]

# Verify run() coalesces text contents (expected behavior)
result = await agent.run("test")

assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
# TextContent items are coalesced into one
assert len(result.messages[0].contents) == 1
assert result.messages[0].text == "first messagesecond messagethirdfourth"

async def test_thread_conversation_history_included_in_workflow_run(self) -> None:
"""Test that conversation history from thread is included when running WorkflowAgent.

Expand Down Expand Up @@ -521,6 +565,142 @@ async def test_checkpoint_storage_passed_to_workflow(self) -> None:
checkpoints = await checkpoint_storage.list_checkpoints(workflow.id)
assert len(checkpoints) > 0, "Checkpoints should have been created when checkpoint_storage is provided"

async def test_agent_executor_output_response_false_filters_streaming_events(self):
"""Test that AgentExecutor with output_response=False does not surface streaming events."""

class MockAgent(AgentProtocol):
"""Mock agent for testing."""

def __init__(self, name: str, response_text: str) -> None:
self._name = name
self._response_text = response_text
self._description: str | None = None

@property
def name(self) -> str | None:
return self._name

@property
def description(self) -> str | None:
return self._description

def get_new_thread(self) -> AgentThread:
return AgentThread()

async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse:
return AgentRunResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text=self._response_text)],
text=self._response_text,
)

async def run_stream(
self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any
) -> AsyncIterable[AgentRunResponseUpdate]:
for word in self._response_text.split():
yield AgentRunResponseUpdate(
contents=[TextContent(text=word + " ")],
role=Role.ASSISTANT,
author_name=self._name,
)

@executor
async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
from agent_framework import AgentExecutorRequest

await ctx.yield_output("Start output")
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))

# Build workflow: start -> agent1 (no output) -> agent2 (output_response=True)
workflow = (
WorkflowBuilder()
.register_executor(lambda: start_executor, "start")
.register_agent(lambda: MockAgent("agent1", "Agent1 output - should NOT appear"), "agent1")
.register_agent(
lambda: MockAgent("agent2", "Agent2 output - SHOULD appear"), "agent2", output_response=True
)
.set_start_executor("start")
.add_edge("start", "agent1")
.add_edge("agent1", "agent2")
.build()
)

agent = WorkflowAgent(workflow=workflow, name="Test Agent")
result = await agent.run("Test input")

# Collect all message texts
texts = [msg.text for msg in result.messages if msg.text]

# Start output should appear (from yield_output)
assert any("Start output" in t for t in texts), "Start output should appear"

# Agent1 output should NOT appear (output_response=False)
assert not any("Agent1" in t for t in texts), "Agent1 output should NOT appear"

# Agent2 output should appear (output_response=True)
assert any("Agent2" in t for t in texts), "Agent2 output should appear"

async def test_agent_executor_output_response_no_duplicate_from_workflow_output_event(self):
"""Test that AgentExecutor with output_response=True does not duplicate content."""

class MockAgent(AgentProtocol):
"""Mock agent for testing."""

def __init__(self, name: str, response_text: str) -> None:
self._name = name
self._response_text = response_text
self._description: str | None = None

@property
def name(self) -> str | None:
return self._name

@property
def description(self) -> str | None:
return self._description

def get_new_thread(self) -> AgentThread:
return AgentThread()

async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse:
return AgentRunResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text=self._response_text)],
text=self._response_text,
)

async def run_stream(
self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any
) -> AsyncIterable[AgentRunResponseUpdate]:
yield AgentRunResponseUpdate(
contents=[TextContent(text=self._response_text)],
role=Role.ASSISTANT,
author_name=self._name,
)

@executor
async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
from agent_framework import AgentExecutorRequest

await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))

# Build workflow with single agent that has output_response=True
workflow = (
WorkflowBuilder()
.register_executor(lambda: start_executor, "start")
.register_agent(lambda: MockAgent("agent", "Unique response text"), "agent", output_response=True)
.set_start_executor("start")
.add_edge("start", "agent")
.build()
)

agent = WorkflowAgent(workflow=workflow, name="Test Agent")
result = await agent.run("Test input")

# Count occurrences of the unique response text
unique_text_count = sum(1 for msg in result.messages if msg.text and "Unique response text" in msg.text)

# Should appear exactly once (not duplicated from both streaming and WorkflowOutputEvent)
assert unique_text_count == 1, f"Response should appear exactly once, but appeared {unique_text_count} times"


class TestWorkflowAgentMergeUpdates:
"""Test cases specifically for the WorkflowAgent.merge_updates static method."""
Expand Down
Loading
Loading