diff --git a/defog/admin_methods.py b/defog/admin_methods.py
index dd5a33a..6d56d7d 100644
--- a/defog/admin_methods.py
+++ b/defog/admin_methods.py
@@ -1,8 +1,14 @@
 import json
-from typing import Dict, List, Optional
 import pandas as pd
+import psycopg2
+import mysql.connector
+import pyodbc
+import snowflake.connector
 from defog.metadata_cache import get_global_cache
 from defog.local_storage import LocalStorage
+from google.cloud import bigquery
+from typing import Dict, List, Optional
+from databricks import sql
 
 
 def update_db_schema(self, path_to_csv, dev=False, temp=False):
@@ -304,7 +310,6 @@ def create_empty_tables(self, dev: bool = False):
 
     try:
         if self.db_type == "postgres" or self.db_type == "redshift":
-            import psycopg2
 
             conn = psycopg2.connect(**self.db_creds)
             cur = conn.cursor()
@@ -314,7 +319,6 @@ def create_empty_tables(self, dev: bool = False):
             return True
 
         elif self.db_type == "mysql":
-            import mysql.connector
 
             conn = mysql.connector.connect(**self.db_creds)
             cur = conn.cursor()
@@ -325,7 +329,6 @@ def create_empty_tables(self, dev: bool = False):
             return True
 
         elif self.db_type == "databricks":
-            from databricks import sql
 
             con = sql.connect(**self.db_creds)
             con.execute(ddl)
@@ -334,7 +337,6 @@ def create_empty_tables(self, dev: bool = False):
             return True
 
         elif self.db_type == "snowflake":
-            import snowflake.connector
 
             conn = snowflake.connector.connect(
                 user=self.db_creds["user"],
@@ -352,7 +354,6 @@ def create_empty_tables(self, dev: bool = False):
             return True
 
         elif self.db_type == "bigquery":
-            from google.cloud import bigquery
 
             client = bigquery.Client.from_service_account_json(
                 self.db_creds["json_key_path"]
@@ -362,7 +363,6 @@ def create_empty_tables(self, dev: bool = False):
             return True
 
         elif self.db_type == "sqlserver":
-            import pyodbc
 
             if self.db_creds["database"] != "":
                 connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
diff --git a/defog/async_admin_methods.py b/defog/async_admin_methods.py
index 5a5a8f5..b3e970a 100644
--- a/defog/async_admin_methods.py
+++ b/defog/async_admin_methods.py
@@ -1,9 +1,15 @@
 import json
-from typing import Dict, List, Optional
+import pyodbc
 import pandas as pd
 import asyncio
+import psycopg2
+import mysql.connector
+import snowflake.connector
 from defog.metadata_cache import get_global_cache
 from defog.local_storage import LocalStorage
+from typing import Dict, List, Optional
+from databricks import sql
+from google.cloud import bigquery
 
 
 async def update_db_schema(self, path_to_csv, dev=False, temp=False):
@@ -325,7 +331,6 @@ async def create_empty_tables(self, dev: bool = False):
     # The database operations would need to be made async, but for now we'll use asyncio.to_thread
     try:
         if self.db_type == "postgres" or self.db_type == "redshift":
-            import psycopg2
 
             def execute_postgres():
                 conn = psycopg2.connect(**self.db_creds)
@@ -338,7 +343,6 @@ def execute_postgres():
             return await asyncio.to_thread(execute_postgres)
 
         elif self.db_type == "mysql":
-            import mysql.connector
 
             def execute_mysql():
                 conn = mysql.connector.connect(**self.db_creds)
@@ -352,7 +356,6 @@ def execute_mysql():
             return await asyncio.to_thread(execute_mysql)
 
         elif self.db_type == "databricks":
-            from databricks import sql
 
             def execute_databricks():
                 con = sql.connect(**self.db_creds)
@@ -364,7 +367,6 @@ def execute_databricks():
             return await asyncio.to_thread(execute_databricks)
 
         elif self.db_type == "snowflake":
-            import snowflake.connector
 
             def execute_snowflake():
                 conn = snowflake.connector.connect(
@@ -385,7 +387,6 @@ def execute_snowflake():
             return await asyncio.to_thread(execute_snowflake)
 
         elif self.db_type == "bigquery":
-            from google.cloud import bigquery
 
             def execute_bigquery():
                 client = bigquery.Client.from_service_account_json(
@@ -398,7 +399,6 @@ def execute_bigquery():
             return await asyncio.to_thread(execute_bigquery)
 
        elif self.db_type == "sqlserver":
-            import pyodbc
 
             def execute_sqlserver():
                 if self.db_creds["database"] != "":
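In the async variant above, each branch keeps the blocking driver call inside a local closure and hands it to `asyncio.to_thread`, so the hoisted imports resolve at module load while the I/O still runs off the event loop. A minimal sketch of that bridge, assuming `psycopg2` and a plain credentials dict; the helper name is illustrative, not the repo's API:

```python
import asyncio
import psycopg2


async def run_ddl_async(db_creds: dict, ddl: str) -> bool:
    # Blocking driver work stays inside a synchronous closure.
    def execute_postgres() -> bool:
        conn = psycopg2.connect(**db_creds)
        cur = conn.cursor()
        cur.execute(ddl)
        conn.commit()
        conn.close()
        return True

    # asyncio.to_thread runs the closure on a worker thread,
    # keeping the event loop free while the DDL executes.
    return await asyncio.to_thread(execute_postgres)
```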
== "snowflake": - import snowflake.connector def execute_snowflake(): conn = snowflake.connector.connect( @@ -385,7 +387,6 @@ def execute_snowflake(): return await asyncio.to_thread(execute_snowflake) elif self.db_type == "bigquery": - from google.cloud import bigquery def execute_bigquery(): client = bigquery.Client.from_service_account_json( @@ -398,7 +399,6 @@ def execute_bigquery(): return await asyncio.to_thread(execute_bigquery) elif self.db_type == "sqlserver": - import pyodbc def execute_sqlserver(): if self.db_creds["database"] != "": diff --git a/defog/local_metadata_extractor.py b/defog/local_metadata_extractor.py index 8cb19e6..92e8c31 100644 --- a/defog/local_metadata_extractor.py +++ b/defog/local_metadata_extractor.py @@ -4,7 +4,8 @@ """ from typing import Dict, List, Optional, Any - +from defog import Defog +from defog import AsyncDefog def extract_metadata_from_db( db_type: str, @@ -26,7 +27,6 @@ def extract_metadata_from_db( Returns: Dictionary mapping table names to column metadata """ - from defog import Defog # Create instance with the provided credentials temp_defog = Defog(api_key=api_key, db_type=db_type, db_creds=db_creds) @@ -72,7 +72,6 @@ async def extract_metadata_from_db_async( Returns: Dictionary mapping table names to column metadata """ - from defog import AsyncDefog # Create async instance with the provided credentials temp_defog = AsyncDefog(api_key=api_key, db_type=db_type, db_creds=db_creds) diff --git a/defog/local_storage.py b/defog/local_storage.py index 7ffdb4e..c9f51d1 100644 --- a/defog/local_storage.py +++ b/defog/local_storage.py @@ -8,6 +8,7 @@ import datetime import portalocker import re +import hashlib class LocalStorage: @@ -32,8 +33,7 @@ def _get_project_id(self, api_key: Optional[str] = None, db_type: str = "") -> s """Generate a project ID based on db_type or api_key for backward compatibility""" if api_key: # Use hash of api_key for backward compatibility - import hashlib - + return hashlib.sha256(api_key.encode()).hexdigest()[:16] elif db_type: # Validate db_type to prevent path traversal diff --git a/defog/mcp_server.py b/defog/mcp_server.py index 7226c0f..7bc3464 100644 --- a/defog/mcp_server.py +++ b/defog/mcp_server.py @@ -10,6 +10,8 @@ import os import httpx import aiofiles +import asyncio +import argparse # we use fastmcp 2.0 provider instead of the fastmcp provided by mcp # this is because this version makes it easier to change multiple variables, like the port @@ -734,8 +736,6 @@ def run_server(transport=None, port=None): # List all registered tools and resources try: - import asyncio - # List tools tools = asyncio.run(mcp._list_tools()) tool_names = [tool.name for tool in tools] @@ -769,8 +769,6 @@ def run_server(transport=None, port=None): if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser( description="Defog MCP Server - Provides tools for SQL queries, code interpretation, and more" ) diff --git a/tests/conftest.py b/tests/conftest.py index 45f82a9..4598946 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import os import pytest import logging +import sys # Set up colored logging for test skips @@ -32,7 +33,6 @@ def format(self, record): def _should_log_skips() -> bool: """Check if we should log test skips based on pytest verbosity.""" - import sys # Check for pytest verbosity flags verbose_flags = ["-v", "--verbose", "-s", "--capture=no"] diff --git a/tests/test_citations.py b/tests/test_citations.py index eacbb46..8a16e22 100644 --- a/tests/test_citations.py +++ 
diff --git a/tests/test_citations.py b/tests/test_citations.py
index eacbb46..8a16e22 100644
--- a/tests/test_citations.py
+++ b/tests/test_citations.py
@@ -4,7 +4,7 @@
 from defog.llm.citations import citations_tool
 from defog.llm.llm_providers import LLMProvider
 from tests.conftest import skip_if_no_api_key
-
+import asyncio
 
 class TestCitations(unittest.IsolatedAsyncioTestCase):
     def setUp(self):
@@ -144,7 +144,6 @@ def test_unsupported_provider(self):
 
         with self.assertRaises(ValueError) as context:
             # Using asyncio.run since this should fail immediately
-            import asyncio
 
             asyncio.run(
                 citations_tool(
diff --git a/tests/test_code_interp.py b/tests/test_code_interp.py
index 264b1f9..f869612 100644
--- a/tests/test_code_interp.py
+++ b/tests/test_code_interp.py
@@ -1,13 +1,12 @@
 import warnings
-
-warnings.filterwarnings("ignore")
-
 import unittest
 import pytest
 from defog.llm.code_interp import code_interpreter_tool
 from defog.llm.llm_providers import LLMProvider
 from tests.conftest import skip_if_no_api_key
+warnings.filterwarnings("ignore")
+
 
 class TestCodeInterp(unittest.IsolatedAsyncioTestCase):
     def setUp(self):
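The `test_unsupported_provider` hunk drives an async tool with `asyncio.run` inside `assertRaises`, so the validation error surfaces synchronously. A self-contained sketch of that shape; the coroutine here is a stand-in, not `citations_tool`:

```python
import asyncio
import unittest


async def fake_tool(provider):
    # Stand-in for an async tool that validates its provider up front.
    if provider is None:
        raise ValueError("Unsupported provider")
    return provider


class TestProviderValidation(unittest.TestCase):
    def test_unsupported_provider(self):
        # asyncio.run drives the coroutine to completion synchronously,
        # so the ValueError is raised inside the assertRaises block.
        with self.assertRaises(ValueError):
            asyncio.run(fake_tool(None))


if __name__ == "__main__":
    unittest.main()
```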
diff --git a/tests/test_image_support.py b/tests/test_image_support.py
index 863359f..fa5d16f 100644
--- a/tests/test_image_support.py
+++ b/tests/test_image_support.py
@@ -12,7 +12,12 @@
 from PIL import Image, ImageDraw
 from typing import Optional
 from pydantic import BaseModel, Field
+import base64
 
+from defog.llm.providers.gemini_provider import GeminiProvider
+from defog.llm.providers.anthropic_provider import AnthropicProvider
+from defog.llm.providers.openai_provider import OpenAIProvider
+from defog.llm.providers.deepseek_provider import DeepSeekProvider
 from defog.llm.utils_image_support import (
     detect_image_in_result,
     process_tool_results_with_images,
@@ -174,7 +179,6 @@ class TestProviderMessageCreation:
 
     def test_anthropic_image_message_creation(self):
         """Test that Anthropic provider creates correct image messages."""
-        from defog.llm.providers.anthropic_provider import AnthropicProvider
 
         provider = AnthropicProvider(api_key="test")
         image_data = create_test_image()
@@ -198,7 +202,6 @@ def test_anthropic_image_message_creation(self):
 
     def test_openai_image_message_creation(self):
         """Test that OpenAI provider creates correct image messages."""
-        from defog.llm.providers.openai_provider import OpenAIProvider
 
         provider = OpenAIProvider(api_key="test")
         image_data = create_test_image()
@@ -222,7 +225,6 @@ def test_openai_image_message_creation(self):
 
     def test_gemini_image_message_creation(self):
         """Test that Gemini provider creates correct image messages."""
-        from defog.llm.providers.gemini_provider import GeminiProvider
 
         provider = GeminiProvider(api_key="test")
         image_data = create_test_image()
@@ -236,9 +238,6 @@ def test_gemini_image_message_creation(self):
         assert hasattr(msg.parts[1], "inline_data")  # Image part
         assert msg.parts[1].inline_data.mime_type == "image/png"
 
-        # Verify the bytes data is correct
-        import base64
-
         expected_bytes = base64.b64decode(image_data)
         assert msg.parts[1].inline_data.data == expected_bytes
 
@@ -370,7 +369,6 @@ class TestProviderImageMessageValidation:
 
     def test_anthropic_provider_invalid_image(self):
         """Test Anthropic provider with invalid image data."""
-        from defog.llm.providers.anthropic_provider import AnthropicProvider
 
         provider = AnthropicProvider(api_key="test")
 
@@ -381,7 +379,6 @@ def test_anthropic_provider_invalid_image(self):
 
     def test_openai_provider_invalid_image(self):
         """Test OpenAI provider with invalid image data."""
-        from defog.llm.providers.openai_provider import OpenAIProvider
 
         provider = OpenAIProvider(api_key="test")
 
@@ -392,7 +389,6 @@ def test_openai_provider_invalid_image(self):
 
     def test_gemini_provider_invalid_image(self):
         """Test Gemini provider with invalid image data."""
-        from defog.llm.providers.gemini_provider import GeminiProvider
 
         provider = GeminiProvider(api_key="test")
 
@@ -403,7 +399,6 @@ def test_gemini_provider_invalid_image(self):
 
     def test_deepseek_provider_handles_invalid_image(self):
         """Test DeepSeek provider gracefully handles invalid images."""
-        from defog.llm.providers.deepseek_provider import DeepSeekProvider
 
         provider = DeepSeekProvider(api_key="test")
 
@@ -414,7 +409,6 @@ def test_deepseek_provider_handles_invalid_image(self):
 
     def test_provider_partial_validation_success(self):
         """Test provider behavior with mixed valid/invalid images."""
-        from defog.llm.providers.anthropic_provider import AnthropicProvider
 
         provider = AnthropicProvider(api_key="test")
         valid_image = create_test_image()
@@ -429,7 +423,6 @@ def test_provider_partial_validation_success(self):
 
     def test_openai_provider_invalid_image_detail(self):
         """Test OpenAI provider with invalid image_detail parameter."""
-        from defog.llm.providers.openai_provider import OpenAIProvider
 
         provider = OpenAIProvider(api_key="test")
         valid_image = create_test_image()
diff --git a/tests/test_llm_post_response_hook.py b/tests/test_llm_post_response_hook.py
index f21b61a..a74da46 100644
--- a/tests/test_llm_post_response_hook.py
+++ b/tests/test_llm_post_response_hook.py
@@ -9,10 +9,9 @@
 from tests.conftest import skip_if_no_api_key
 from dotenv import load_dotenv
-
-load_dotenv()
 
 import logging
 
+load_dotenv()
 logging.basicConfig(level=logging.INFO)
 
 # Mock response hook for testing
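For context on the image tests: `create_test_image()` produces a small base64-encoded PNG that the provider tests decode (via `base64.b64decode`) and embed in messages. A self-contained sketch of such a fixture, assuming only PIL; the size and drawing are illustrative:

```python
import base64
import io

from PIL import Image, ImageDraw


def create_test_image_b64(size: int = 64) -> str:
    # Draw a small PNG in memory and return it base64-encoded,
    # the shape of fixture the provider tests rely on.
    img = Image.new("RGB", (size, size), "white")
    ImageDraw.Draw(img).rectangle([8, 8, size - 8, size - 8], outline="black")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


# Round-trip check: decoding yields the raw PNG bytes a provider would embed.
raw = base64.b64decode(create_test_image_b64())
assert raw[:8] == b"\x89PNG\r\n\x1a\n"
```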