diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index af5669a1..5bcbe1ce 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -1,10 +1,15 @@ # ruff: noqa: B008 import argparse import asyncio +import csv +import io +import json import logging import os import signal import sys +from datetime import date, datetime, timedelta +from decimal import Decimal from enum import Enum from typing import Any from typing import List @@ -72,7 +77,108 @@ async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: def format_text_response(text: Any) -> ResponseType: """Format a text response.""" - return [types.TextContent(type="text", text=str(text))] + def json_serializer(obj): + """Custom JSON serializer for PostgreSQL/Python types.""" + # Handle datetime types + if isinstance(obj, (datetime, date)): + return obj.isoformat() + # Handle time (without date) + elif hasattr(obj, 'isoformat') and callable(obj.isoformat): + return obj.isoformat() + # Handle Decimal + elif isinstance(obj, Decimal): + return float(obj) + # Handle timedelta (PostgreSQL INTERVAL type) + elif isinstance(obj, timedelta): + return obj.total_seconds() # Return as seconds (number) + # Handle bytes (PostgreSQL BYTEA type) + elif isinstance(obj, (bytes, bytearray)): + import base64 + return base64.b64encode(obj).decode('ascii') # Base64 encode binary data + # Handle memoryview + elif isinstance(obj, memoryview): + import base64 + return base64.b64encode(obj.tobytes()).decode('ascii') + # Handle UUID + elif hasattr(obj, 'hex'): # UUID objects have a hex attribute + return str(obj) + # Default: convert to string + return str(obj) + + # Convert lists and dicts to JSON, everything else to string + if isinstance(text, (list, dict)): + text = json.dumps(text, default=json_serializer, ensure_ascii=False) + else: + text = str(text) + + return [types.TextContent(type="text", text=text)] + + +def format_csv_response(data: Any) -> ResponseType: + """Format a response as CSV.""" + def csv_value_converter(obj): + """Convert PostgreSQL/Python types to CSV-friendly strings.""" + # Handle datetime types + if isinstance(obj, (datetime, date)): + return obj.isoformat() + # Handle time (without date) + elif hasattr(obj, 'isoformat') and callable(obj.isoformat): + return obj.isoformat() + # Handle Decimal + elif isinstance(obj, Decimal): + return str(obj) # Keep full precision for CSV + # Handle timedelta (PostgreSQL INTERVAL type) + elif isinstance(obj, timedelta): + return str(obj.total_seconds()) # Return as seconds string + # Handle bytes (PostgreSQL BYTEA type) + elif isinstance(obj, (bytes, bytearray)): + import base64 + return base64.b64encode(obj).decode('ascii') + # Handle memoryview + elif isinstance(obj, memoryview): + import base64 + return base64.b64encode(obj.tobytes()).decode('ascii') + # Handle UUID + elif hasattr(obj, 'hex'): # UUID objects have a hex attribute + return str(obj) + # Handle None + elif obj is None: + return "" + # Handle lists/arrays - convert to pipe-separated string + elif isinstance(obj, list): + return "|".join(str(csv_value_converter(item)) for item in obj) + # Handle dicts - convert to JSON string + elif isinstance(obj, dict): + return json.dumps(obj) + # Default: convert to string + return str(obj) + + if not isinstance(data, list) or not data: + return [types.TextContent(type="text", text="")] + + # Create CSV output + output = io.StringIO() + writer = csv.writer(output) + + # Write header row (column names) + if isinstance(data[0], dict): + headers = list(data[0].keys()) + writer.writerow(headers) + + # Write data rows + for row in data: + converted_row = [csv_value_converter(row.get(header)) for header in headers] + writer.writerow(converted_row) + else: + # Handle non-dict data + writer.writerow(["value"]) + for item in data: + writer.writerow([csv_value_converter(item)]) + + csv_text = output.getvalue() + output.close() + + return [types.TextContent(type="text", text=csv_text)] def format_error_response(error: str) -> ResponseType: @@ -389,14 +495,26 @@ async def explain_query( # Query function declaration without the decorator - we'll add it dynamically based on access mode async def execute_sql( sql: str = Field(description="SQL to run", default="all"), + output_format: Literal["json", "csv"] = Field(description="Output format: 'json' (default) or 'csv'", default="json"), ) -> ResponseType: - """Executes a SQL query against the database.""" + """Executes a SQL query against the database and returns results in JSON or CSV format.""" try: sql_driver = await get_sql_driver() rows = await sql_driver.execute_query(sql) # type: ignore if rows is None: - return format_text_response("No results") - return format_text_response(list([r.cells for r in rows])) + if output_format == "csv": + return format_csv_response([]) + else: + return format_text_response("No results") + + # Convert rows to list of dictionaries + result_data = list([r.cells for r in rows]) + + # Format based on requested output format + if output_format == "csv": + return format_csv_response(result_data) + else: + return format_text_response(result_data) except Exception as e: logger.error(f"Error executing query: {e}") return format_error_response(str(e)) diff --git a/tests/unit/test_execute.py b/tests/unit/test_execute.py new file mode 100644 index 00000000..a06fb79c --- /dev/null +++ b/tests/unit/test_execute.py @@ -0,0 +1,49 @@ +"""Tests for execute_sql function with JSON and CSV output formats.""" + +import json +from datetime import datetime +from decimal import Decimal +from unittest.mock import AsyncMock, patch + +import pytest + +from postgres_mcp.server import execute_sql + + +class MockRow: + def __init__(self, cells): + self.cells = cells + + +@pytest.mark.asyncio +async def test_execute_sql_json_output(): + """Test execute_sql outputs valid JSON (not Python repr format).""" + mock_driver = AsyncMock() + mock_driver.execute_query.return_value = [ + MockRow({"id": 1, "salary": Decimal('50000.00'), "created_at": datetime(2023, 1, 1)}) + ] + + with patch('postgres_mcp.server.get_sql_driver', return_value=mock_driver): + result = await execute_sql("SELECT * FROM users") + + # Should return valid JSON + parsed = json.loads(result[0].text) + assert parsed[0]["salary"] == 50000.0 # Decimal -> float, not repr + assert parsed[0]["created_at"] == "2023-01-01T00:00:00" # ISO format + + +@pytest.mark.asyncio +async def test_execute_sql_csv_output(): + """Test execute_sql outputs CSV format.""" + mock_driver = AsyncMock() + mock_driver.execute_query.return_value = [ + MockRow({"id": 1, "name": "John", "salary": Decimal('50000.00')}) + ] + + with patch('postgres_mcp.server.get_sql_driver', return_value=mock_driver): + result = await execute_sql("SELECT * FROM users", output_format="csv") + + lines = result[0].text.strip().split('\n') + assert len(lines) == 2 # Header + data + assert "id" in lines[0] + assert "50000.00" in lines[1] # Decimal precision preserved \ No newline at end of file