Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions defog/admin_methods.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import json
from typing import Dict, List, Optional
import pandas as pd
import psycopg2
import mysql.connector
import pyodbc
import snowflake.connector
from defog.metadata_cache import get_global_cache
from defog.local_storage import LocalStorage
from google.cloud import bigquery
from typing import Dict, List, Optional
from databricks import sql


def update_db_schema(self, path_to_csv, dev=False, temp=False):
Expand Down Expand Up @@ -304,7 +310,6 @@ def create_empty_tables(self, dev: bool = False):

try:
if self.db_type == "postgres" or self.db_type == "redshift":
import psycopg2

conn = psycopg2.connect(**self.db_creds)
cur = conn.cursor()
Expand All @@ -314,7 +319,6 @@ def create_empty_tables(self, dev: bool = False):
return True

elif self.db_type == "mysql":
import mysql.connector

conn = mysql.connector.connect(**self.db_creds)
cur = conn.cursor()
Expand All @@ -325,7 +329,6 @@ def create_empty_tables(self, dev: bool = False):
return True

elif self.db_type == "databricks":
from databricks import sql

con = sql.connect(**self.db_creds)
con.execute(ddl)
Expand All @@ -334,7 +337,6 @@ def create_empty_tables(self, dev: bool = False):
return True

elif self.db_type == "snowflake":
import snowflake.connector

conn = snowflake.connector.connect(
user=self.db_creds["user"],
Expand All @@ -352,7 +354,6 @@ def create_empty_tables(self, dev: bool = False):
return True

elif self.db_type == "bigquery":
from google.cloud import bigquery

client = bigquery.Client.from_service_account_json(
self.db_creds["json_key_path"]
Expand All @@ -362,7 +363,6 @@ def create_empty_tables(self, dev: bool = False):
return True

elif self.db_type == "sqlserver":
import pyodbc

if self.db_creds["database"] != "":
connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
Expand Down
14 changes: 7 additions & 7 deletions defog/async_admin_methods.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import json
from typing import Dict, List, Optional
import pyodbc
import pandas as pd
import asyncio
import psycopg2
import mysql.connector
import snowflake.connector
from defog.metadata_cache import get_global_cache
from defog.local_storage import LocalStorage
from typing import Dict, List, Optional
from databricks import sql
from google.cloud import bigquery


async def update_db_schema(self, path_to_csv, dev=False, temp=False):
Expand Down Expand Up @@ -325,7 +331,6 @@ async def create_empty_tables(self, dev: bool = False):
# The database operations would need to be made async, but for now we'll use asyncio.to_thread
try:
if self.db_type == "postgres" or self.db_type == "redshift":
import psycopg2

def execute_postgres():
conn = psycopg2.connect(**self.db_creds)
Expand All @@ -338,7 +343,6 @@ def execute_postgres():
return await asyncio.to_thread(execute_postgres)

elif self.db_type == "mysql":
import mysql.connector

def execute_mysql():
conn = mysql.connector.connect(**self.db_creds)
Expand All @@ -352,7 +356,6 @@ def execute_mysql():
return await asyncio.to_thread(execute_mysql)

elif self.db_type == "databricks":
from databricks import sql

def execute_databricks():
con = sql.connect(**self.db_creds)
Expand All @@ -364,7 +367,6 @@ def execute_databricks():
return await asyncio.to_thread(execute_databricks)

elif self.db_type == "snowflake":
import snowflake.connector

def execute_snowflake():
conn = snowflake.connector.connect(
Expand All @@ -385,7 +387,6 @@ def execute_snowflake():
return await asyncio.to_thread(execute_snowflake)

elif self.db_type == "bigquery":
from google.cloud import bigquery

def execute_bigquery():
client = bigquery.Client.from_service_account_json(
Expand All @@ -398,7 +399,6 @@ def execute_bigquery():
return await asyncio.to_thread(execute_bigquery)

elif self.db_type == "sqlserver":
import pyodbc

def execute_sqlserver():
if self.db_creds["database"] != "":
Expand Down
5 changes: 2 additions & 3 deletions defog/local_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from typing import Dict, List, Optional, Any

from defog import Defog
from defog import AsyncDefog

def extract_metadata_from_db(
db_type: str,
Expand All @@ -26,7 +27,6 @@ def extract_metadata_from_db(
Returns:
Dictionary mapping table names to column metadata
"""
from defog import Defog

# Create instance with the provided credentials
temp_defog = Defog(api_key=api_key, db_type=db_type, db_creds=db_creds)
Expand Down Expand Up @@ -72,7 +72,6 @@ async def extract_metadata_from_db_async(
Returns:
Dictionary mapping table names to column metadata
"""
from defog import AsyncDefog

# Create async instance with the provided credentials
temp_defog = AsyncDefog(api_key=api_key, db_type=db_type, db_creds=db_creds)
Expand Down
4 changes: 2 additions & 2 deletions defog/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import datetime
import portalocker
import re
import hashlib


class LocalStorage:
Expand All @@ -32,8 +33,7 @@ def _get_project_id(self, api_key: Optional[str] = None, db_type: str = "") -> s
"""Generate a project ID based on db_type or api_key for backward compatibility"""
if api_key:
# Use hash of api_key for backward compatibility
import hashlib


return hashlib.sha256(api_key.encode()).hexdigest()[:16]
elif db_type:
# Validate db_type to prevent path traversal
Expand Down
6 changes: 2 additions & 4 deletions defog/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import os
import httpx
import aiofiles
import asyncio
import argparse

# we use fastmcp 2.0 provider instead of the fastmcp provided by mcp
# this is because this version makes it easier to change multiple variables, like the port
Expand Down Expand Up @@ -734,8 +736,6 @@ def run_server(transport=None, port=None):

# List all registered tools and resources
try:
import asyncio

# List tools
tools = asyncio.run(mcp._list_tools())
tool_names = [tool.name for tool in tools]
Expand Down Expand Up @@ -769,8 +769,6 @@ def run_server(transport=None, port=None):


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(
description="Defog MCP Server - Provides tools for SQL queries, code interpretation, and more"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pytest
import logging
import sys


# Set up colored logging for test skips
Expand Down Expand Up @@ -32,7 +33,6 @@ def format(self, record):

def _should_log_skips() -> bool:
"""Check if we should log test skips based on pytest verbosity."""
import sys

# Check for pytest verbosity flags
verbose_flags = ["-v", "--verbose", "-s", "--capture=no"]
Expand Down
3 changes: 1 addition & 2 deletions tests/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from defog.llm.citations import citations_tool
from defog.llm.llm_providers import LLMProvider
from tests.conftest import skip_if_no_api_key

import asyncio

class TestCitations(unittest.IsolatedAsyncioTestCase):
def setUp(self):
Expand Down Expand Up @@ -144,7 +144,6 @@ def test_unsupported_provider(self):

with self.assertRaises(ValueError) as context:
# Using asyncio.run since this should fail immediately
import asyncio

asyncio.run(
citations_tool(
Expand Down
5 changes: 2 additions & 3 deletions tests/test_code_interp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import warnings

warnings.filterwarnings("ignore")

import unittest
import pytest
from defog.llm.code_interp import code_interpreter_tool
from defog.llm.llm_providers import LLMProvider
from tests.conftest import skip_if_no_api_key

warnings.filterwarnings("ignore")


class TestCodeInterp(unittest.IsolatedAsyncioTestCase):
def setUp(self):
Expand Down
17 changes: 5 additions & 12 deletions tests/test_image_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from PIL import Image, ImageDraw
from typing import Optional
from pydantic import BaseModel, Field
import base64

from defog.llm.providers.gemini_provider import GeminiProvider
from defog.llm.providers.anthropic_provider import AnthropicProvider
from defog.llm.providers.openai_provider import OpenAIProvider
from defog.llm.providers.deepseek_provider import DeepSeekProvider
from defog.llm.utils_image_support import (
detect_image_in_result,
process_tool_results_with_images,
Expand Down Expand Up @@ -174,7 +179,6 @@ class TestProviderMessageCreation:

def test_anthropic_image_message_creation(self):
"""Test that Anthropic provider creates correct image messages."""
from defog.llm.providers.anthropic_provider import AnthropicProvider

provider = AnthropicProvider(api_key="test")
image_data = create_test_image()
Expand All @@ -198,7 +202,6 @@ def test_anthropic_image_message_creation(self):

def test_openai_image_message_creation(self):
"""Test that OpenAI provider creates correct image messages."""
from defog.llm.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key="test")
image_data = create_test_image()
Expand All @@ -222,7 +225,6 @@ def test_openai_image_message_creation(self):

def test_gemini_image_message_creation(self):
"""Test that Gemini provider creates correct image messages."""
from defog.llm.providers.gemini_provider import GeminiProvider

provider = GeminiProvider(api_key="test")
image_data = create_test_image()
Expand All @@ -236,9 +238,6 @@ def test_gemini_image_message_creation(self):
assert hasattr(msg.parts[1], "inline_data") # Image part
assert msg.parts[1].inline_data.mime_type == "image/png"

# Verify the bytes data is correct
import base64

expected_bytes = base64.b64decode(image_data)
assert msg.parts[1].inline_data.data == expected_bytes

Expand Down Expand Up @@ -370,7 +369,6 @@ class TestProviderImageMessageValidation:

def test_anthropic_provider_invalid_image(self):
"""Test Anthropic provider with invalid image data."""
from defog.llm.providers.anthropic_provider import AnthropicProvider

provider = AnthropicProvider(api_key="test")

Expand All @@ -381,7 +379,6 @@ def test_anthropic_provider_invalid_image(self):

def test_openai_provider_invalid_image(self):
"""Test OpenAI provider with invalid image data."""
from defog.llm.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key="test")

Expand All @@ -392,7 +389,6 @@ def test_openai_provider_invalid_image(self):

def test_gemini_provider_invalid_image(self):
"""Test Gemini provider with invalid image data."""
from defog.llm.providers.gemini_provider import GeminiProvider

provider = GeminiProvider(api_key="test")

Expand All @@ -403,7 +399,6 @@ def test_gemini_provider_invalid_image(self):

def test_deepseek_provider_handles_invalid_image(self):
"""Test DeepSeek provider gracefully handles invalid images."""
from defog.llm.providers.deepseek_provider import DeepSeekProvider

provider = DeepSeekProvider(api_key="test")

Expand All @@ -414,7 +409,6 @@ def test_deepseek_provider_handles_invalid_image(self):

def test_provider_partial_validation_success(self):
"""Test provider behavior with mixed valid/invalid images."""
from defog.llm.providers.anthropic_provider import AnthropicProvider

provider = AnthropicProvider(api_key="test")
valid_image = create_test_image()
Expand All @@ -429,7 +423,6 @@ def test_provider_partial_validation_success(self):

def test_openai_provider_invalid_image_detail(self):
"""Test OpenAI provider with invalid image_detail parameter."""
from defog.llm.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key="test")
valid_image = create_test_image()
Expand Down
3 changes: 1 addition & 2 deletions tests/test_llm_post_response_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from tests.conftest import skip_if_no_api_key

from dotenv import load_dotenv

load_dotenv()
import logging

load_dotenv()
logging.basicConfig(level=logging.INFO)

# Mock response hook for testing
Expand Down