|
1 | 1 | import json
|
2 |
| -from typing import Any, Optional |
| 2 | +from collections.abc import Callable |
| 3 | +from typing import Any, Optional, Union |
3 | 4 | from unittest.mock import patch
|
4 | 5 |
|
5 | 6 | import pytest
|
@@ -51,34 +52,70 @@ def mock_response(
|
51 | 52 | status_code=status_code,
|
52 | 53 | )
|
53 | 54 |
|
54 |
| - def mock_stream(self, task_id: str = "city-to-capital"): |
| 55 | + def mock_stream( |
| 56 | + self, |
| 57 | + task_id: str = "city-to-capital", |
| 58 | + outputs: Optional[list[dict[str, Any]]] = None, |
| 59 | + run_id: str = "1", |
| 60 | + metadata: Optional[dict[str, Any]] = None, |
| 61 | + ): |
| 62 | + outputs = outputs or [ |
| 63 | + {"capital": ""}, |
| 64 | + {"capital": "Tok"}, |
| 65 | + {"capital": "Tokyo"}, |
| 66 | + ] |
| 67 | + if metadata is None: |
| 68 | + metadata = {"cost_usd": 0.01, "duration_seconds": 10.1} |
| 69 | + |
| 70 | + payloads = [{"id": run_id, "task_output": o} for o in outputs] |
| 71 | + |
| 72 | + final_payload = {**payloads[-1], **metadata} |
| 73 | + payloads.append(final_payload) |
| 74 | + streams = [f"data: {json.dumps(p)}\n\n".encode() for p in payloads] |
| 75 | + |
55 | 76 | self.httpx_mock.add_response(
|
56 | 77 | url=f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run",
|
57 |
| - stream=IteratorStream( |
58 |
| - [ |
59 |
| - b'data: {"id":"1","task_output":{"capital":""}}\n\n', |
60 |
| - b'data: {"id":"1","task_output":{"capital":"Tok"}}\n\ndata: {"id":"1","task_output":{"capital":"Tokyo"}}\n\n', # noqa: E501 |
61 |
| - b'data: {"id":"1","task_output":{"capital":"Tokyo"},"cost_usd":0.01,"duration_seconds":10.1}\n\n', |
62 |
| - ], |
63 |
| - ), |
| 78 | + stream=IteratorStream(streams), |
64 | 79 | )
|
65 | 80 |
|
| 81 | + def check_register( |
| 82 | + self, |
| 83 | + task_id: str = "city-to-capital", |
| 84 | + input_schema: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], None]]] = None, |
| 85 | + output_schema: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], None]]] = None, |
| 86 | + ): |
| 87 | + request = self.httpx_mock.get_request(url=self.REGISTER_URL) |
| 88 | + assert request is not None |
| 89 | + assert request.headers["Authorization"] == "Bearer test" |
| 90 | + assert request.headers["Content-Type"] == "application/json" |
| 91 | + assert request.headers["x-workflowai-source"] == "sdk" |
| 92 | + assert request.headers["x-workflowai-language"] == "python" |
| 93 | + |
| 94 | + body = json.loads(request.content) |
| 95 | + assert body["id"] == task_id |
| 96 | + if callable(input_schema): |
| 97 | + input_schema(body["input_schema"]) |
| 98 | + else: |
| 99 | + assert body["input_schema"] == input_schema or {"city": {"type": "string"}} |
| 100 | + if callable(output_schema): |
| 101 | + output_schema(body["output_schema"]) |
| 102 | + else: |
| 103 | + assert body["output_schema"] == output_schema or {"capital": {"type": "string"}} |
| 104 | + |
66 | 105 | def check_request(
|
67 | 106 | self,
|
68 | 107 | version: Any = "production",
|
69 | 108 | task_id: str = "city-to-capital",
|
70 | 109 | task_input: Optional[dict[str, Any]] = None,
|
71 | 110 | **matchers: Any,
|
72 | 111 | ):
|
| 112 | + if not matchers: |
| 113 | + matchers = {"url": f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run"} |
73 | 114 | request = self.httpx_mock.get_request(**matchers)
|
74 | 115 | assert request is not None
|
75 |
| - assert request.url == f"https://run.workflowai.dev/v1/_/agents/{task_id}/schemas/1/run" |
76 | 116 | body = json.loads(request.content)
|
77 |
| - assert body == { |
78 |
| - "task_input": task_input or {"city": "Hello"}, |
79 |
| - "version": version, |
80 |
| - "stream": False, |
81 |
| - } |
| 117 | + assert body["task_input"] == task_input or {"city": "Hello"} |
| 118 | + assert body["version"] == version |
82 | 119 | assert request.headers["Authorization"] == "Bearer test"
|
83 | 120 | assert request.headers["Content-Type"] == "application/json"
|
84 | 121 | assert request.headers["x-workflowai-source"] == "sdk"
|
|
0 commit comments