diff --git a/docs/visualization.md b/docs/visualization.md index 00f3126d..00816f3b 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -1,6 +1,10 @@ # Agent Visualization -Agent visualization allows you to generate a structured graphical representation of agents and their relationships using **Graphviz**. This is useful for understanding how agents, tools, and handoffs interact within an application. +Agent visualization allows you to generate a structured graphical representation of agents and their relationships. Two rendering options are available: +- **Graphviz** (offline): Default renderer that generates graphs locally +- **Mermaid** (online): Alternative renderer that uses mermaid.ink API + +This is useful for understanding how agents, tools, and handoffs interact within an application. ## Installation @@ -18,6 +22,15 @@ You can generate an agent visualization using the `draw_graph` function. This fu - **Tools** are represented as green ellipses. - **Handoffs** are directed edges from one agent to another. +The renderer can be specified using the `renderer` parameter: +```python +# Using Graphviz (default) +draw_graph(agent, renderer="graphviz") + +# Using Mermaid API +draw_graph(agent, renderer="mermaid") +``` + ### Example Usage ```python @@ -82,5 +95,3 @@ draw_graph(triage_agent, filename="agent_graph.png") ``` This will generate `agent_graph.png` in the working directory. - - diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 5fb35062..c66e2499 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,137 +1,436 @@ -from typing import Optional +import abc +import base64 +import warnings +from dataclasses import dataclass +from enum import Enum +from typing import Generic, Literal, Optional, TypeVar -import graphviz # type: ignore +import graphviz +import requests from agents import Agent from agents.handoffs import Handoff -from agents.tool import Tool -def get_main_graph(agent: Agent) -> str: - """ - Generates the main graph structure in DOT format for the given agent. +class NodeType(Enum): + START = "start" + END = "end" + AGENT = "agent" + TOOL = "tool" + HANDOFF = "handoff" - Args: - agent (Agent): The agent for which the graph is to be generated. - Returns: - str: The DOT format string representing the graph. - """ - parts = [ +class EdgeType(Enum): + HANDOFF = "handoff" + TOOL = "tool" + + +@dataclass(frozen=True) +class Node: + id: str + label: str + type: NodeType + + +@dataclass(frozen=True) +class Edge: + source: Node + target: Node + type: EdgeType + + +class Graph: + def __init__(self): + self.nodes: dict[str, Node] = {} + self.edges: list[Edge] = [] + + def add_node(self, node: Node) -> None: + self.nodes[node.id] = node + + def add_edge(self, edge: Edge) -> None: + """Add an edge to the graph. + + Args: + edge (Edge): The edge to add. + + Raises: + ValueError: If the source or target node does not exist in the graph. + """ + if edge.source.id not in self.nodes: + raise ValueError(f"Source node '{edge.source.id}' does not exist in the graph") + if edge.target.id not in self.nodes: + raise ValueError(f"Target node '{edge.target.id}' does not exist in the graph") + self.edges.append(edge) + + def has_node(self, node_id: str) -> bool: + """Check if a node exists in the graph. + + Args: + node_id (str): The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + """ + return node_id in self.nodes + + def get_node(self, node_id: str) -> Optional[Node]: + """Get a node from the graph. + + Args: + node_id (str): The ID of the node to get. + + Returns: + Optional[Node]: The node if it exists, None otherwise. + """ + return self.nodes.get(node_id) + + +class GraphBuilder: + def __init__(self): + self._visited: set[int] = set() + + def build_from_agent(self, agent: Agent) -> Graph: + """Build a graph from an agent. + + Args: + agent (Agent): The agent to build the graph from. + + Returns: + Graph: The built graph. + """ + self._visited.clear() + graph = Graph() + + # Add start and end nodes + graph.add_node(Node("__start__", "__start__", NodeType.START)) + graph.add_node(Node("__end__", "__end__", NodeType.END)) + + self._add_agent_nodes_and_edges(agent, None, graph) + return graph + + def _add_agent_nodes_and_edges( + self, + agent: Agent | None, + parent: Optional[Agent], + graph: Graph, + ) -> None: + if agent is None: + return + + start_node = graph.get_node("__start__") + end_node = graph.get_node("__end__") + + # Add agent node + agent_id = str(id(agent)) + agent_node = Node(agent_id, agent.name, NodeType.AGENT) + graph.add_node(agent_node) + self._visited.add(agent_id) + + # Connect start node if root agent + if not parent: + graph.add_edge(Edge(start_node, agent_node, EdgeType.HANDOFF)) + + # Add tool nodes and edges + for tool in agent.tools: + tool_id = str(id(tool)) + tool_node = Node(tool_id, tool.name, NodeType.TOOL) + graph.add_node(tool_node) + graph.add_edge(Edge(agent_node, tool_node, EdgeType.TOOL)) + graph.add_edge(Edge(tool_node, agent_node, EdgeType.TOOL)) + + # Process handoffs + for handoff in agent.handoffs: + handoff_id = str(id(handoff)) + if isinstance(handoff, Handoff): + handoff_node = Node(handoff_id, handoff.agent_name, NodeType.HANDOFF) + graph.add_node(handoff_node) + graph.add_edge(Edge(agent_node, handoff_node, EdgeType.HANDOFF)) + elif isinstance(handoff, Agent): + handoff_node = Node(handoff_id, handoff.name, NodeType.AGENT) + graph.add_node(handoff_node) + graph.add_edge(Edge(agent_node, handoff_node, EdgeType.HANDOFF)) + if handoff_id not in self._visited: + self._add_agent_nodes_and_edges(handoff, agent, graph) + + # Connect to end node if no handoffs + if not agent.handoffs: + graph.add_edge(Edge(agent_node, end_node, EdgeType.HANDOFF)) + + +T = TypeVar("T") + + +class GraphRenderer(Generic[T], abc.ABC): + """Abstract base class for graph renderers.""" + + @abc.abstractmethod + def render(self, graph: Graph) -> T: + """Render the graph in the specific format. + + Args: + graph (Graph): The graph to render. + + Returns: + T: The rendered graph in the format specific to the renderer. """ + pass + + @abc.abstractmethod + def save(self, rendered: T, filename: str) -> None: + """Save the rendered graph to a file. + + Args: + rendered (T): The rendered graph returned by render(). + filename (str): The name of the file to save the graph as. + """ + pass + + +class GraphvizRenderer(GraphRenderer[str]): + """Renderer that outputs graphs in Graphviz DOT format.""" + + def render(self, graph: Graph) -> str: + parts = [ + """ digraph G { graph [splines=true]; node [fontname="Arial"]; edge [penwidth=1.5]; """ - ] - parts.append(get_all_nodes(agent)) - parts.append(get_all_edges(agent)) - parts.append("}") - return "".join(parts) + ] + # Add nodes + for node in graph.nodes.values(): + parts.append(self._render_node(node)) -def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: - """ - Recursively generates the nodes for the given agent and its handoffs in DOT format. + # Add edges + for edge in graph.edges: + parts.append(self._render_edge(edge)) - Args: - agent (Agent): The agent for which the nodes are to be generated. + parts.append("}") + return "".join(parts) - Returns: - str: The DOT format string representing the nodes. - """ - parts = [] - - # Start and end the graph - parts.append( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - ) - # Ensure parent agent node is colored - if not parent: - parts.append( - f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" - ) - - for tool in agent.tools: - parts.append( - f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, ' - f"fillcolor=lightgreen, width=0.5, height=0.3];" - ) - - for handoff in agent.handoffs: - if isinstance(handoff, Handoff): - parts.append( - f'"{handoff.agent_name}" [label="{handoff.agent_name}", ' - f"shape=box, style=filled, style=rounded, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - if isinstance(handoff, Agent): - parts.append( - f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, style=rounded, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - parts.append(get_all_nodes(handoff)) - - return "".join(parts) - - -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: + def save(self, rendered: str, filename: str) -> None: + """Save the rendered graph as a PNG file using graphviz. + + Args: + rendered (str): The DOT format string. + filename (str): The name of the file to save the graph as. + """ + graphviz.Source(rendered).render(filename, format="png") + + def _render_node(self, node: Node) -> str: + style_map = { + NodeType.START: ( + "shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3" + ), + NodeType.END: ( + "shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3" + ), + NodeType.AGENT: ( + "shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8" + ), + NodeType.TOOL: ( + "shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3" + ), + NodeType.HANDOFF: ( + "shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8" + ), + } + return f'"{node.id}" [label="{node.label}", {style_map[node.type]}];' + + def _render_edge(self, edge: Edge) -> str: + if edge.type == EdgeType.TOOL: + return f'"{edge.source.id}" -> "{edge.target.id}" [style=dotted, penwidth=1.5];' + return f'"{edge.source.id}" -> "{edge.target.id}";' + + +class MermaidRenderer(GraphRenderer[str]): + """Renderer that outputs graphs in Mermaid flowchart syntax.""" + + def render(self, graph: Graph) -> str: + parts = ["graph TD\n"] + + # Add nodes with styles + for node in graph.nodes.values(): + parts.append(self._render_node(node)) + + # Add edges + for edge in graph.edges: + parts.append(self._render_edge(edge)) + + return "".join(parts) + + def save(self, rendered: str, filename: str) -> None: + """Save the rendered graph as a PNG file using mermaid.ink API. + + Args: + rendered (str): The Mermaid syntax string. + filename (str): The name of the file to save the graph as. + """ + # Encode the graph to base64 + graphbytes = rendered.encode("utf8") + base64_bytes = base64.urlsafe_b64encode(graphbytes) + base64_string = base64_bytes.decode("ascii") + + # Get the image from mermaid.ink + response = requests.get(f"https://mermaid.ink/img/{base64_string}") + response.raise_for_status() + + # Save the image directly from response content + with open(f"{filename}.png", "wb") as f: + f.write(response.content) + + def _render_node(self, node: Node) -> str: + # Map node types to Mermaid shapes + style_map = { + NodeType.START: ["(", ")", "lightblue"], + NodeType.END: ["(", ")", "lightblue"], + NodeType.AGENT: ["[", "]", "lightyellow"], + NodeType.TOOL: ["((", "))", "lightgreen"], + NodeType.HANDOFF: ["[", "]", "lightyellow"], + } + + start, end, color = style_map[node.type] + node_id = self._sanitize_id(node.id) + # Use sanitized ID and original label + return f"{node_id}{start}{node.label}{end}\nstyle {node_id} fill:{color}\n" + + def _render_edge(self, edge: Edge) -> str: + source = self._sanitize_id(edge.source.id) + target = self._sanitize_id(edge.target.id) + if edge.type == EdgeType.TOOL: + return f"{source} -.-> {target}\n" + return f"{source} --> {target}\n" + + def _sanitize_id(self, id: str) -> str: + """Sanitize node IDs to work with Mermaid's stricter ID requirements.""" + return id.replace(" ", "_").replace("-", "_") + + +class GraphView: + def __init__( + self, + rendered_graph: str, + renderer: GraphRenderer, + filename: Optional[str] = None, + ): + self.rendered_graph = rendered_graph + self.renderer = renderer + self.filename = filename + + def view(self) -> None: + """Opens the rendered graph in a separate window.""" + import os + import tempfile + import webbrowser + + if self.filename: + webbrowser.open(f"file://{os.path.abspath(self.filename)}.png") + else: + temp_dir = tempfile.gettempdir() + temp_path = os.path.join(temp_dir, next(tempfile._get_candidate_names())) + self.renderer.save(self.rendered_graph, temp_path) + webbrowser.open(f"file://{os.path.abspath(temp_path)}.png") + + +def draw_graph( + agent: Agent, + filename: str | None = None, + renderer: Literal["graphviz", "mermaid"] = "graphviz", +) -> GraphView: """ - Recursively generates the edges for the given agent and its handoffs in DOT format. + Draws the graph for the given agent using the specified renderer. Args: - agent (Agent): The agent for which the edges are to be generated. - parent (Agent, optional): The parent agent. Defaults to None. + agent (Agent): The agent for which the graph is to be drawn. + filename (str | None): The name of the file to save the graph as PNG. Defaults to None. + renderer (Literal["graphviz", "mermaid"]): The renderer to use. Defaults to "graphviz". Returns: - str: The DOT format string representing the edges. - """ - parts = [] + GraphView: A view object that can be used to display the graph. - if not parent: - parts.append(f'"__start__" -> "{agent.name}";') + Raises: + ValueError: If the specified renderer is not supported. + """ + builder = GraphBuilder() + graph = builder.build_from_agent(agent) - for tool in agent.tools: - parts.append(f""" - "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; - "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""") + if renderer == "graphviz": + renderer_instance = GraphvizRenderer() + elif renderer == "mermaid": + renderer_instance = MermaidRenderer() + else: + raise ValueError(f"Unsupported renderer: {renderer}") - for handoff in agent.handoffs: - if isinstance(handoff, Handoff): - parts.append(f""" - "{agent.name}" -> "{handoff.agent_name}";""") - if isinstance(handoff, Agent): - parts.append(f""" - "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + rendered = renderer_instance.render(graph) - if not agent.handoffs and not isinstance(agent, Tool): # type: ignore - parts.append(f'"{agent.name}" -> "__end__";') + if filename: + filename = filename.rsplit(".", 1)[0] + renderer_instance.save(rendered, filename) - return "".join(parts) + return GraphView(rendered, renderer_instance, filename) -def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: +def get_main_graph(agent: Agent) -> str: """ - Draws the graph for the given agent and optionally saves it as a PNG file. + Generates the main graph structure in DOT format for the given agent. Args: - agent (Agent): The agent for which the graph is to be drawn. - filename (str): The name of the file to save the graph as a PNG. + agent (Agent): The agent for which the graph is to be generated. Returns: - graphviz.Source: The graphviz Source object representing the graph. + str: The DOT format string representing the graph. + + Deprecated: + This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. """ - dot_code = get_main_graph(agent) - graph = graphviz.Source(dot_code) + warnings.warn( + "get_main_graph is deprecated. Use GraphBuilder and GraphvizRenderer instead.", + DeprecationWarning, + stacklevel=2, + ) + builder = GraphBuilder() + renderer = GraphvizRenderer() + graph = builder.build_from_agent(agent) + return renderer.render(graph) - if filename: - graph.render(filename, format="png") - return graph +def get_all_nodes( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[int]] = None +) -> str: + """ + Recursively generates the nodes for the given agent and its handoffs in DOT format. + + Deprecated: + This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. + """ + warnings.warn( + "get_all_nodes is deprecated. Use GraphBuilder and GraphvizRenderer instead.", + DeprecationWarning, + stacklevel=2, + ) + builder = GraphBuilder() + renderer = GraphvizRenderer() + graph = builder.build_from_agent(agent) + return "\n".join(renderer._render_node(node) for node in graph.nodes.values()) + + +def get_all_edges( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[int]] = None +) -> str: + """ + Recursively generates the edges for the given agent and its handoffs in DOT format. + + Deprecated: + This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. + """ + warnings.warn( + "get_all_edges is deprecated. Use GraphBuilder and GraphvizRenderer instead.", + DeprecationWarning, + stacklevel=2, + ) + builder = GraphBuilder() + renderer = GraphvizRenderer() + graph = builder.build_from_agent(agent) + return "\n".join(renderer._render_edge(edge) for edge in graph.edges) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6aa86774..aa00c81c 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,10 +1,19 @@ -from unittest.mock import Mock +import dataclasses +from unittest.mock import Mock, patch -import graphviz # type: ignore import pytest from agents import Agent from agents.extensions.visualization import ( + Edge, + EdgeType, + Graph, + GraphBuilder, + GraphView, + GraphvizRenderer, + MermaidRenderer, + Node, + NodeType, draw_graph, get_all_edges, get_all_nodes, @@ -31,106 +40,441 @@ def mock_agent(): return agent +@pytest.fixture +def mock_recursive_agents(): + agent1 = Mock(spec=Agent) + agent1.name = "Agent1" + agent1.tools = [] + agent2 = Mock(spec=Agent) + agent2.name = "Agent2" + agent2.tools = [] + agent1.handoffs = [agent2] + agent2.handoffs = [agent1] + return agent1 + + +# Tests for the new graph abstraction +def test_graph_builder(mock_agent): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + + # Check nodes + assert "__start__" in graph.nodes + assert "__end__" in graph.nodes + + # Find nodes by name + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node types + assert graph.nodes["__start__"].type == NodeType.START + assert graph.nodes["__end__"].type == NodeType.END + assert agent_node.type == NodeType.AGENT + assert tool1_node.type == NodeType.TOOL + assert tool2_node.type == NodeType.TOOL + assert handoff_node.type == NodeType.HANDOFF + + # Check edges + start_node = graph.nodes["__start__"] + + start_to_agent = Edge(start_node, agent_node, EdgeType.HANDOFF) + agent_to_tool1 = Edge(agent_node, tool1_node, EdgeType.TOOL) + tool1_to_agent = Edge(tool1_node, agent_node, EdgeType.TOOL) + agent_to_tool2 = Edge(agent_node, tool2_node, EdgeType.TOOL) + tool2_to_agent = Edge(tool2_node, agent_node, EdgeType.TOOL) + agent_to_handoff = Edge(agent_node, handoff_node, EdgeType.HANDOFF) + + assert any( + e.source.id == start_to_agent.source.id and e.target.id == start_to_agent.target.id + for e in graph.edges + ) + assert any( + e.source.id == agent_to_tool1.source.id and e.target.id == agent_to_tool1.target.id + for e in graph.edges + ) + assert any( + e.source.id == tool1_to_agent.source.id and e.target.id == tool1_to_agent.target.id + for e in graph.edges + ) + assert any( + e.source.id == agent_to_tool2.source.id and e.target.id == agent_to_tool2.target.id + for e in graph.edges + ) + assert any( + e.source.id == tool2_to_agent.source.id and e.target.id == tool2_to_agent.target.id + for e in graph.edges + ) + assert any( + e.source.id == agent_to_handoff.source.id and e.target.id == agent_to_handoff.target.id + for e in graph.edges + ) + + +def test_graphviz_renderer(mock_agent): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + renderer = GraphvizRenderer() + dot_code = renderer.render(graph) + + assert "digraph G" in dot_code + assert "graph [splines=true];" in dot_code + assert 'node [fontname="Arial"];' in dot_code + assert "edge [penwidth=1.5];" in dot_code + + # Find nodes by name in rendered output + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent_style in dot_code + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool1_style in dot_code + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool2_style in dot_code + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert handoff_style in dot_code + + +def test_recursive_graph_builder(mock_recursive_agents): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_recursive_agents) + + # Find nodes by name + agent1_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + agent2_node = next(node for node in graph.nodes.values() if node.label == "Agent2") + + # Check node types + assert agent1_node.type == NodeType.AGENT + assert agent2_node.type == NodeType.AGENT + + # Check edges + agent1_to_agent2 = Edge(agent1_node, agent2_node, EdgeType.HANDOFF) + agent2_to_agent1 = Edge(agent2_node, agent1_node, EdgeType.HANDOFF) + + assert any( + e.source.id == agent1_to_agent2.source.id and e.target.id == agent1_to_agent2.target.id + for e in graph.edges + ) + assert any( + e.source.id == agent2_to_agent1.source.id and e.target.id == agent2_to_agent1.target.id + for e in graph.edges + ) + + +def test_graph_validation(): + graph = Graph() + + # Test adding valid nodes and edges + node1 = Node("1", "Node 1", NodeType.AGENT) + node2 = Node("2", "Node 2", NodeType.TOOL) + graph.add_node(node1) + graph.add_node(node2) + + valid_edge = Edge(node1, node2, EdgeType.TOOL) + graph.add_edge(valid_edge) + + # Test adding edge with non-existent source + node3 = Node("3", "Node 3", NodeType.TOOL) + invalid_edge1 = Edge(node3, node2, EdgeType.TOOL) + with pytest.raises(ValueError, match="Source node '3' does not exist in the graph"): + graph.add_edge(invalid_edge1) + + # Test adding edge with non-existent target + invalid_edge2 = Edge(node1, node3, EdgeType.TOOL) + with pytest.raises(ValueError, match="Target node '3' does not exist in the graph"): + graph.add_edge(invalid_edge2) + + # Test helper methods + assert graph.has_node("1") + assert graph.has_node("2") + assert not graph.has_node("3") + + assert graph.get_node("1") == node1 + assert graph.get_node("2") == node2 + assert graph.get_node("3") is None + + +def test_node_immutability(): + node = Node("1", "Node 1", NodeType.AGENT) + with pytest.raises(dataclasses.FrozenInstanceError): + node.id = "2" + with pytest.raises(dataclasses.FrozenInstanceError): + node.label = "Node 2" + with pytest.raises(dataclasses.FrozenInstanceError): + node.type = NodeType.TOOL + + +def test_edge_immutability(): + edge = Edge("1", "2", EdgeType.TOOL) + with pytest.raises(dataclasses.FrozenInstanceError): + edge.source = "3" + with pytest.raises(dataclasses.FrozenInstanceError): + edge.target = "3" + with pytest.raises(dataclasses.FrozenInstanceError): + edge.type = EdgeType.HANDOFF + + +def test_draw_graph_with_invalid_renderer(mock_agent): + with pytest.raises(ValueError, match="Unsupported renderer: invalid"): + draw_graph(mock_agent, renderer="invalid") + + +def test_draw_graph_default_renderer(mock_agent): + result = draw_graph(mock_agent) + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + + +def test_draw_graph_with_filename(mock_agent, tmp_path): + filename = tmp_path / "test_graph" + result = draw_graph(mock_agent, filename=str(filename)) + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + assert (tmp_path / "test_graph.png").exists() + + +def test_draw_graph_with_graphviz(mock_agent): + result = draw_graph(mock_agent, renderer="graphviz") + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + assert "graph [splines=true];" in result.rendered_graph + assert 'node [fontname="Arial"];' in result.rendered_graph + assert "edge [penwidth=1.5];" in result.rendered_graph + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent_style in result.rendered_graph + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool1_style in result.rendered_graph + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool2_style in result.rendered_graph + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert handoff_style in result.rendered_graph + + +def test_draw_graph_with_mermaid(mock_agent): + result = draw_graph(mock_agent, renderer="mermaid") + assert isinstance(result, GraphView) + assert "graph TD" in result.rendered_graph + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + + assert f"{agent_node.id}[Agent1]" in result.rendered_graph + assert f"style {agent_node.id} fill:lightyellow" in result.rendered_graph + + +def test_draw_graph_with_filename_graphviz(mock_agent, tmp_path): + filename = tmp_path / "test_graph" + result = draw_graph(mock_agent, filename=str(filename), renderer="graphviz") + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + assert (tmp_path / "test_graph.png").exists() + + +def test_draw_graph_with_filename_mermaid(mock_agent, tmp_path): + filename = tmp_path / "test_graph" + mock_response = Mock() + mock_response.content = b"mock image data" + mock_response.raise_for_status = Mock() + + with patch("requests.get", return_value=mock_response): + result = draw_graph(mock_agent, filename=str(filename), renderer="mermaid") + assert isinstance(result, GraphView) + assert "graph TD" in result.rendered_graph + assert (tmp_path / "test_graph.png").exists() + with open(tmp_path / "test_graph.png", "rb") as f: + assert f.read() == b"mock image data" + + +def test_draw_graph(mock_agent): + result = draw_graph(mock_agent) + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + + +# Legacy function tests def test_get_main_graph(mock_agent): - result = get_main_graph(mock_agent) - print(result) + with pytest.warns(DeprecationWarning): + result = get_main_graph(mock_agent) assert "digraph G" in result assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result assert "edge [penwidth=1.5];" in result - assert ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) - assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result + assert agent_style in result + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" ) - assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result + assert tool1_style in result + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" ) - assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result + assert tool2_style in result + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) + assert handoff_style in result def test_get_all_nodes(mock_agent): - result = get_all_nodes(mock_agent) - assert ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result + with pytest.warns(DeprecationWarning): + result = get_all_nodes(mock_agent) + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) - assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result + assert agent_style in result + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" ) - assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result + assert tool1_style in result + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" ) - assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result + assert tool2_style in result + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) + assert handoff_style in result def test_get_all_edges(mock_agent): - result = get_all_edges(mock_agent) - assert '"__start__" -> "Agent1";' in result - assert '"Agent1" -> "__end__";' - assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result - assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result - assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result - assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result - assert '"Agent1" -> "Handoff1";' in result + with pytest.warns(DeprecationWarning): + result = get_all_edges(mock_agent) + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + start_node = graph.nodes["__start__"] + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") -def test_draw_graph(mock_agent): - graph = draw_graph(mock_agent) - assert isinstance(graph, graphviz.Source) - assert "digraph G" in graph.source - assert "graph [splines=true];" in graph.source - assert 'node [fontname="Arial"];' in graph.source - assert "edge [penwidth=1.5];" in graph.source - assert ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source - ) - assert ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source - ) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source - ) - assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source - ) - assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source - ) - assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source + # Check edge definitions + assert f'"{start_node.id}" -> "{agent_node.id}";' in result + assert f'"{agent_node.id}" -> "{tool1_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{tool1_node.id}" -> "{agent_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{agent_node.id}" -> "{tool2_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{tool2_node.id}" -> "{agent_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{agent_node.id}" -> "{handoff_node.id}";' in result + + +def test_recursive_handoff_loop(mock_recursive_agents): + with pytest.warns(DeprecationWarning): + dot = get_main_graph(mock_recursive_agents) + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_recursive_agents) + agent1_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + agent2_node = next(node for node in graph.nodes.values() if node.label == "Agent2") + + # Check node and edge definitions + agent1_style = ( + f'"{agent1_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent1_style in dot + agent2_style = ( + f'"{agent2_node.id}" [label="Agent2", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) + assert agent2_style in dot + assert f'"{agent1_node.id}" -> "{agent2_node.id}";' in dot + assert f'"{agent2_node.id}" -> "{agent1_node.id}";' in dot + + +def test_mermaid_renderer(mock_agent): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + renderer = MermaidRenderer() + mermaid_code = renderer.render(graph) + + # Test flowchart header + assert "graph TD" in mermaid_code + + # Find nodes by name + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Test node rendering + assert f"{agent_node.id}[Agent1]" in mermaid_code + assert f"style {agent_node.id} fill:lightyellow" in mermaid_code + assert f"{tool1_node.id}((Tool1))" in mermaid_code + assert f"style {tool1_node.id} fill:lightgreen" in mermaid_code + assert f"{tool2_node.id}((Tool2))" in mermaid_code + assert f"style {tool2_node.id} fill:lightgreen" in mermaid_code + assert f"{handoff_node.id}[Handoff1]" in mermaid_code + assert f"style {handoff_node.id} fill:lightyellow" in mermaid_code diff --git a/uv.lock b/uv.lock index e443c009..a6018eeb 100644 --- a/uv.lock +++ b/uv.lock @@ -1087,7 +1087,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.7" +version = "0.0.8" source = { editable = "." } dependencies = [ { name = "griffe" },