Skip to content

Commit cb343dd

Browse files
Copilotlarohra
andauthored
Python: Add orchestration ID to durable agent entity state and code refactor (#2484)
* Initial plan * Add orchestration ID to durable agent entity state for Python Co-authored-by: larohra <[email protected]> * Fix type safety checks * Fix tests --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: larohra <[email protected]> Co-authored-by: Laveesh Rohra <[email protected]>
1 parent 0d5d10d commit cb343dd

File tree

8 files changed

+354
-130
lines changed

8 files changed

+354
-130
lines changed

python/packages/azurefunctions/agent_framework_azurefunctions/_app.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ async def mcp_tool_handler(context: str, client: df.DurableOrchestrationClient)
562562
logger.debug("[MCP Tool Trigger] Received invocation for agent: %s", agent_name)
563563
return await self._handle_mcp_tool_invocation(agent_name=agent_name, context=context, client=client)
564564

565+
_ = mcp_tool_handler
565566
logger.debug("[AgentFunctionApp] Registered MCP tool trigger for agent: %s", agent_name)
566567

567568
async def _handle_mcp_tool_invocation(
@@ -587,15 +588,17 @@ async def _handle_mcp_tool_invocation(
587588

588589
# Parse JSON context string
589590
try:
590-
parsed_context = json.loads(context)
591+
parsed_context: Any = json.loads(context)
591592
except json.JSONDecodeError as e:
592593
raise ValueError(f"Invalid MCP context format: {e}") from e
593594

595+
parsed_context = cast(Mapping[str, Any], parsed_context) if isinstance(parsed_context, dict) else {}
596+
594597
# Extract arguments from MCP context
595-
arguments = parsed_context.get("arguments", {}) if isinstance(parsed_context, dict) else {}
598+
arguments: dict[str, Any] = parsed_context.get("arguments", {})
596599

597600
# Validate required 'query' argument
598-
query = arguments.get("query")
601+
query: Any = arguments.get("query")
599602
if not query or not isinstance(query, str):
600603
raise ValueError("MCP Tool invocation is missing required 'query' argument of type string.")
601604

@@ -951,10 +954,9 @@ def _extract_normalized_headers(self, req: func.HttpRequest) -> dict[str, str]:
951954
"""Create a lowercase header mapping from the incoming request."""
952955
headers: dict[str, str] = {}
953956
raw_headers = req.headers
954-
if isinstance(raw_headers, Mapping):
955-
for key, value in raw_headers.items():
956-
if value is not None:
957-
headers[str(key).lower()] = str(value)
957+
for key, value in cast(Mapping[str, str], raw_headers).items():
958+
headers[key.lower()] = value
959+
958960
return headers
959961

960962
@staticmethod

python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py

Lines changed: 153 additions & 119 deletions
Large diffs are not rendered by default.

python/packages/azurefunctions/agent_framework_azurefunctions/_models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ class RunRequest:
287287
thread_id: Optional thread ID for tracking
288288
correlation_id: Optional correlation ID for tracking the response to this specific request
289289
created_at: Optional timestamp when the request was created
290+
orchestration_id: Optional ID of the orchestration that initiated this request
290291
"""
291292

292293
message: str
@@ -297,6 +298,7 @@ class RunRequest:
297298
thread_id: str | None = None
298299
correlation_id: str | None = None
299300
created_at: str | None = None
301+
orchestration_id: str | None = None
300302

301303
def __init__(
302304
self,
@@ -308,6 +310,7 @@ def __init__(
308310
thread_id: str | None = None,
309311
correlation_id: str | None = None,
310312
created_at: str | None = None,
313+
orchestration_id: str | None = None,
311314
) -> None:
312315
self.message = message
313316
self.role = self.coerce_role(role)
@@ -317,6 +320,7 @@ def __init__(
317320
self.thread_id = thread_id
318321
self.correlation_id = correlation_id
319322
self.created_at = created_at
323+
self.orchestration_id = orchestration_id
320324

321325
@staticmethod
322326
def coerce_role(value: Role | str | None) -> Role:
@@ -346,6 +350,8 @@ def to_dict(self) -> dict[str, Any]:
346350
result["correlationId"] = self.correlation_id
347351
if self.created_at:
348352
result["created_at"] = self.created_at
353+
if self.orchestration_id:
354+
result["orchestrationId"] = self.orchestration_id
349355

350356
return result
351357

@@ -361,4 +367,5 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest:
361367
thread_id=data.get("thread_id"),
362368
correlation_id=data.get("correlationId"),
363369
created_at=data.get("created_at"),
370+
orchestration_id=data.get("orchestrationId"),
364371
)

python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,14 @@ def my_orchestration(context):
272272
)
273273

274274
# Prepare the request using RunRequest model
275+
# Include the orchestration's instance_id so it can be stored in the agent's entity state
275276
run_request = RunRequest(
276277
message=message_str,
277278
enable_tool_calls=enable_tool_calls,
278279
correlation_id=correlation_id,
279280
thread_id=session_id.key,
280281
response_format=response_format,
282+
orchestration_id=self.context.instance_id,
281283
)
282284

283285
logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100])

python/packages/azurefunctions/tests/test_entities.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_init_creates_entity(self) -> None:
7979
assert entity.agent == mock_agent
8080
assert len(entity.state.data.conversation_history) == 0
8181
assert entity.state.data.extension_data is None
82-
assert entity.state.schema_version == "1.0.0"
82+
assert entity.state.schema_version == DurableAgentState.SCHEMA_VERSION
8383

8484
def test_init_stores_agent_reference(self) -> None:
8585
"""Test that the agent reference is stored correctly."""
@@ -124,8 +124,7 @@ async def test_run_agent_executes_agent(self) -> None:
124124
# Verify agent.run was called
125125
mock_agent.run.assert_called_once()
126126
_, kwargs = mock_agent.run.call_args
127-
sent_messages = kwargs.get("messages")
128-
assert isinstance(sent_messages, list)
127+
sent_messages: list[Any] = kwargs.get("messages")
129128
assert len(sent_messages) == 1
130129
sent_message = sent_messages[0]
131130
assert isinstance(sent_message, ChatMessage)
@@ -910,5 +909,98 @@ async def test_entity_function_with_run_request_dict(self) -> None:
910909
assert text_found, f"Response text not found in message: {message}"
911910

912911

912+
class TestDurableAgentStateRequestOrchestrationId:
913+
"""Test suite for DurableAgentStateRequest orchestration_id field."""
914+
915+
def test_request_with_orchestration_id(self) -> None:
916+
"""Test creating a request with an orchestration_id."""
917+
request = DurableAgentStateRequest(
918+
correlation_id="corr-123",
919+
created_at=datetime.now(),
920+
messages=[
921+
DurableAgentStateMessage(
922+
role="user",
923+
contents=[DurableAgentStateTextContent(text="test")],
924+
)
925+
],
926+
orchestration_id="orch-456",
927+
)
928+
929+
assert request.orchestration_id == "orch-456"
930+
931+
def test_request_to_dict_includes_orchestration_id(self) -> None:
932+
"""Test that to_dict includes orchestrationId when set."""
933+
request = DurableAgentStateRequest(
934+
correlation_id="corr-123",
935+
created_at=datetime.now(),
936+
messages=[
937+
DurableAgentStateMessage(
938+
role="user",
939+
contents=[DurableAgentStateTextContent(text="test")],
940+
)
941+
],
942+
orchestration_id="orch-789",
943+
)
944+
945+
data = request.to_dict()
946+
947+
assert "orchestrationId" in data
948+
assert data["orchestrationId"] == "orch-789"
949+
950+
def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None:
951+
"""Test that to_dict excludes orchestrationId when not set."""
952+
request = DurableAgentStateRequest(
953+
correlation_id="corr-123",
954+
created_at=datetime.now(),
955+
messages=[
956+
DurableAgentStateMessage(
957+
role="user",
958+
contents=[DurableAgentStateTextContent(text="test")],
959+
)
960+
],
961+
)
962+
963+
data = request.to_dict()
964+
965+
assert "orchestrationId" not in data
966+
967+
def test_request_from_dict_with_orchestration_id(self) -> None:
968+
"""Test from_dict correctly parses orchestrationId."""
969+
data = {
970+
"$type": "request",
971+
"correlationId": "corr-123",
972+
"createdAt": "2024-01-01T00:00:00Z",
973+
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}],
974+
"orchestrationId": "orch-from-dict",
975+
}
976+
977+
request = DurableAgentStateRequest.from_dict(data)
978+
979+
assert request.orchestration_id == "orch-from-dict"
980+
981+
def test_request_from_run_request_with_orchestration_id(self) -> None:
982+
"""Test from_run_request correctly transfers orchestration_id."""
983+
run_request = RunRequest(
984+
message="test message",
985+
correlation_id="corr-run",
986+
orchestration_id="orch-from-run-request",
987+
)
988+
989+
durable_request = DurableAgentStateRequest.from_run_request(run_request)
990+
991+
assert durable_request.orchestration_id == "orch-from-run-request"
992+
993+
def test_request_from_run_request_without_orchestration_id(self) -> None:
994+
"""Test from_run_request correctly handles missing orchestration_id."""
995+
run_request = RunRequest(
996+
message="test message",
997+
correlation_id="corr-run",
998+
)
999+
1000+
durable_request = DurableAgentStateRequest.from_run_request(run_request)
1001+
1002+
assert durable_request.orchestration_id is None
1003+
1004+
9131005
if __name__ == "__main__":
9141006
pytest.main([__file__, "-v", "--tb=short"])

python/packages/azurefunctions/tests/test_models.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,71 @@ def test_round_trip_with_correlationId(self) -> None:
336336
assert restored.correlation_id == original.correlation_id
337337
assert restored.thread_id == original.thread_id
338338

339+
def test_init_with_orchestration_id(self) -> None:
340+
"""Test RunRequest initialization with orchestration_id."""
341+
request = RunRequest(
342+
message="Test message",
343+
thread_id="thread-orch-init",
344+
orchestration_id="orch-123",
345+
)
346+
347+
assert request.message == "Test message"
348+
assert request.orchestration_id == "orch-123"
349+
350+
def test_to_dict_with_orchestration_id(self) -> None:
351+
"""Test to_dict includes orchestrationId."""
352+
request = RunRequest(
353+
message="Test",
354+
thread_id="thread-orch-to-dict",
355+
orchestration_id="orch-456",
356+
)
357+
data = request.to_dict()
358+
359+
assert data["message"] == "Test"
360+
assert data["orchestrationId"] == "orch-456"
361+
362+
def test_to_dict_excludes_orchestration_id_when_none(self) -> None:
363+
"""Test to_dict excludes orchestrationId when not set."""
364+
request = RunRequest(
365+
message="Test",
366+
thread_id="thread-orch-none",
367+
)
368+
data = request.to_dict()
369+
370+
assert "orchestrationId" not in data
371+
372+
def test_from_dict_with_orchestration_id(self) -> None:
373+
"""Test from_dict with orchestrationId."""
374+
data = {
375+
"message": "Test",
376+
"orchestrationId": "orch-789",
377+
"thread_id": "thread-orch-from-dict",
378+
}
379+
request = RunRequest.from_dict(data)
380+
381+
assert request.message == "Test"
382+
assert request.orchestration_id == "orch-789"
383+
assert request.thread_id == "thread-orch-from-dict"
384+
385+
def test_round_trip_with_orchestration_id(self) -> None:
386+
"""Test round-trip to_dict and from_dict with orchestration_id."""
387+
original = RunRequest(
388+
message="Test message",
389+
thread_id="thread-123",
390+
role=Role.SYSTEM,
391+
correlation_id="corr-123",
392+
orchestration_id="orch-123",
393+
)
394+
395+
data = original.to_dict()
396+
restored = RunRequest.from_dict(data)
397+
398+
assert restored.message == original.message
399+
assert restored.role == original.role
400+
assert restored.correlation_id == original.correlation_id
401+
assert restored.orchestration_id == original.orchestration_id
402+
assert restored.thread_id == original.thread_id
403+
339404

340405
class TestModelIntegration:
341406
"""Test suite for integration between models."""

python/packages/azurefunctions/tests/test_orchestration.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,28 @@ def test_run_creates_entity_call(self) -> None:
302302
assert request["correlationId"] == "correlation-guid"
303303
assert "thread_id" in request
304304
assert request["thread_id"] == "thread-guid"
305+
# Verify orchestration ID is set from context.instance_id
306+
assert "orchestrationId" in request
307+
assert request["orchestrationId"] == "test-instance-001"
308+
309+
def test_run_sets_orchestration_id(self) -> None:
310+
"""Test that run() sets the orchestration_id from context.instance_id."""
311+
mock_context = Mock()
312+
mock_context.instance_id = "my-orchestration-123"
313+
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
314+
315+
entity_task = _create_entity_task()
316+
mock_context.call_entity = Mock(return_value=entity_task)
317+
318+
agent = DurableAIAgent(mock_context, "TestAgent")
319+
thread = agent.get_new_thread()
320+
321+
agent.run(messages="Test", thread=thread)
322+
323+
call_args = mock_context.call_entity.call_args
324+
request = call_args[0][2]
325+
326+
assert request["orchestrationId"] == "my-orchestration-123"
305327

306328
def test_run_without_thread(self) -> None:
307329
"""Test that run() works without explicit thread (creates unique session key)."""

python/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)