From fd139e4097fe43c1670aa03be9a9c7e41af52079 Mon Sep 17 00:00:00 2001 From: ShawnZheng Date: Fri, 6 Jun 2025 10:32:51 +0800 Subject: [PATCH] [Fix]: Write a custom JSON Encoder to avoid serialization problems & Add 'try except' to avoid lagging problems & Fix variable 'success' which is not used --- src/mcp_server_milvus/server.py | 331 +++++++++++++++++++++----------- 1 file changed, 214 insertions(+), 117 deletions(-) diff --git a/src/mcp_server_milvus/server.py b/src/mcp_server_milvus/server.py index 713274c..cd5a232 100644 --- a/src/mcp_server_milvus/server.py +++ b/src/mcp_server_milvus/server.py @@ -13,8 +13,23 @@ ) +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + try: + if hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, dict)): + return list(obj) + elif hasattr(obj, '__dict__'): + return obj.__dict__ + else: + return str(obj) + except Exception: + return str(obj) + + class MilvusConnector: - def __init__(self, uri: str, token: Optional[str] = None, db_name: Optional[str] = "default"): + def __init__( + self, uri: str, token: Optional[str] = None, db_name: Optional[str] = "default" + ): self.uri = uri self.token = token self.client = MilvusClient(uri=uri, token=token, db_name=db_name) @@ -222,7 +237,9 @@ async def create_collection( except Exception as e: raise ValueError(f"Failed to create collection: {str(e)}") - async def insert_data(self, collection_name: str, data: list[dict[str, Any]]) -> dict[str, Any]: + async def insert_data( + self, collection_name: str, data: list[dict[str, Any]] + ) -> dict[str, Any]: """ Insert data into a collection. @@ -236,7 +253,9 @@ async def insert_data(self, collection_name: str, data: list[dict[str, Any]]) -> except Exception as e: raise ValueError(f"Insert failed: {str(e)}") - async def delete_entities(self, collection_name: str, filter_expr: str) -> dict[str, Any]: + async def delete_entities( + self, collection_name: str, filter_expr: str + ) -> dict[str, Any]: """ Delete entities from a collection based on filter expression. @@ -245,7 +264,9 @@ async def delete_entities(self, collection_name: str, filter_expr: str) -> dict[ filter_expr: Filter expression to select entities to delete """ try: - result = self.client.delete(collection_name=collection_name, expr=filter_expr) + result = self.client.delete( + collection_name=collection_name, expr=filter_expr + ) return result except Exception as e: raise ValueError(f"Delete failed: {str(e)}") @@ -357,16 +378,22 @@ async def bulk_insert( total_records = len(data[field_names[0]]) for i in range(0, total_records, batch_size): - batch_data = {field: data[field][i : i + batch_size] for field in field_names} + batch_data = { + field: data[field][i : i + batch_size] for field in field_names + } - result = self.client.insert(collection_name=collection_name, data=batch_data) + result = self.client.insert( + collection_name=collection_name, data=batch_data + ) results.append(result) return results except Exception as e: raise ValueError(f"Bulk insert failed: {str(e)}") - async def load_collection(self, collection_name: str, replica_number: int = 1) -> bool: + async def load_collection( + self, collection_name: str, replica_number: int = 1 + ) -> bool: """ Load a collection into memory for search and query. @@ -407,7 +434,9 @@ async def get_query_segment_info(self, collection_name: str) -> dict[str, Any]: except Exception as e: raise ValueError(f"Failed to get query segment info: {str(e)}") - async def upsert_data(self, collection_name: str, data: dict[str, list[Any]]) -> dict[str, Any]: + async def upsert_data( + self, collection_name: str, data: dict[str, list[Any]] + ) -> dict[str, Any]: """ Upsert data into a collection (insert or update if exists). @@ -438,7 +467,9 @@ async def get_index_info( except Exception as e: raise ValueError(f"Failed to get index info: {str(e)}") - async def get_collection_loading_progress(self, collection_name: str) -> dict[str, Any]: + async def get_collection_loading_progress( + self, collection_name: str + ) -> dict[str, Any]: """ Get the loading progress of a collection. @@ -465,7 +496,11 @@ async def use_database(self, db_name: str) -> bool: """ try: # Create a new client with the specified database - self.client = MilvusClient(uri=self.uri, token=self.token, db_name=db_name) + self.client = MilvusClient( + uri=self.uri, + token=self.token, + db_name=db_name + ) return True except Exception as e: raise ValueError(f"Failed to switch database: {str(e)}") @@ -515,28 +550,36 @@ async def milvus_text_search( output_fields: Fields to include in results drop_ratio: Proportion of low-frequency terms to ignore (0.0-1.0) """ - connector = ctx.request_context.lifespan_context.connector - results = await connector.search_collection( - collection_name=collection_name, - query_text=query_text, - limit=limit, - output_fields=output_fields, - drop_ratio=drop_ratio, - ) - - output = f"Search results for '{query_text}' in collection '{collection_name}':\n\n" - for result in results: - output += f"{result}\n\n" - - return output + try: + connector = ctx.request_context.lifespan_context.connector + results = await connector.search_collection( + collection_name=collection_name, + query_text=query_text, + limit=limit, + output_fields=output_fields, + drop_ratio=drop_ratio, + ) + + output = f"Search results for '{query_text}' in collection '{collection_name}':\n\n" + for result in results: + output += f"{result}\n\n" + + return output + except ValueError as e: + return f"Text search error: {str(e)}" + except Exception as e: + return f"Unexpected text search error: {str(e)}" @mcp.tool() async def milvus_list_collections(ctx: Context) -> str: """List all collections in the database.""" - connector = ctx.request_context.lifespan_context.connector - collections = await connector.list_collections() - return f"Collections in database:\n{', '.join(collections)}" + try: + connector = ctx.request_context.lifespan_context.connector + collections = await connector.list_collections() + return f"Collections in database:\n{', '.join(collections)}" + except Exception as e: + return f"Error listing collections: {str(e)}" @mcp.tool() @@ -556,19 +599,24 @@ async def milvus_query( output_fields: Fields to include in results limit: Maximum number of results """ - connector = ctx.request_context.lifespan_context.connector - results = await connector.query_collection( - collection_name=collection_name, - filter_expr=filter_expr, - output_fields=output_fields, - limit=limit, - ) + try: + connector = ctx.request_context.lifespan_context.connector + results = await connector.query_collection( + collection_name=collection_name, + filter_expr=filter_expr, + output_fields=output_fields, + limit=limit, + ) - output = f"Query results for '{filter_expr}' in collection '{collection_name}':\n\n" - for result in results: - output += f"{result}\n\n" + output = f"Query results for '{filter_expr}' in collection '{collection_name}':\n\n" + for result in results: + output += f"{result}\n\n" - return output + return output + except ValueError as e: + return f"Query error: {str(e)}" + except Exception as e: + return f"Unexpected query error: {str(e)}" @mcp.tool() @@ -594,22 +642,27 @@ async def milvus_vector_search( metric_type: Distance metric (COSINE, L2, IP) filter_expr: Optional filter expression """ - connector = ctx.request_context.lifespan_context.connector - results = await connector.vector_search( - collection_name=collection_name, - vector=vector, - vector_field=vector_field, - limit=limit, - output_fields=output_fields, - metric_type=metric_type, - filter_expr=filter_expr, - ) - - output = f"Vector search results for '{collection_name}':\n\n" - for result in results: - output += f"{result}\n\n" - - return output + try: + connector = ctx.request_context.lifespan_context.connector + results = await connector.vector_search( + collection_name=collection_name, + vector=vector, + vector_field=vector_field, + limit=limit, + output_fields=output_fields, + metric_type=metric_type, + filter_expr=filter_expr, + ) + + output = f"Vector search results for '{collection_name}':\n\n" + for result in results: + output += f"{result}\n\n" + + return output + except ValueError as e: + return f"Vector search error: {str(e)}" + except Exception as e: + return f"Unexpected vector search error: {str(e)}" @mcp.tool() @@ -637,24 +690,29 @@ async def milvus_hybrid_search( output_fields: Fields to return in results filter_expr: Optional filter expression """ - connector = ctx.request_context.lifespan_context.connector - - results = await connector.hybrid_search( - collection_name=collection_name, - query_text=query_text, - text_field=text_field, - vector=vector, - vector_field=vector_field, - limit=limit, - output_fields=output_fields, - filter_expr=filter_expr, - ) - - output = f"Hybrid search results for text '{query_text}' in '{collection_name}':\n\n" - for result in results: - output += f"{result}\n\n" - - return output + try: + connector = ctx.request_context.lifespan_context.connector + + results = await connector.hybrid_search( + collection_name=collection_name, + query_text=query_text, + text_field=text_field, + vector=vector, + vector_field=vector_field, + limit=limit, + output_fields=output_fields, + filter_expr=filter_expr, + ) + + output = (f"Hybrid search results for text '{query_text}' in '{collection_name}':\n\n") + for result in results: + output += f"{result}\n\n" + + return output + except ValueError as e: + return f"Hybrid search error: {str(e)}" + except Exception as e: + return f"Unexpected hybrid search error: {str(e)}" @mcp.tool() @@ -672,14 +730,19 @@ async def milvus_create_collection( collection_schema: Collection schema definition index_params: Optional index parameters """ - connector = ctx.request_context.lifespan_context.connector - success = await connector.create_collection( - collection_name=collection_name, - schema=collection_schema, - index_params=index_params, - ) + try: + connector = ctx.request_context.lifespan_context.connector + await connector.create_collection( + collection_name=collection_name, + schema=collection_schema, + index_params=index_params, + ) - return f"Collection '{collection_name}' created successfully" + return f"Collection '{collection_name}' created successfully" + except ValueError as e: + return f"Collection creation error: {str(e)}" + except Exception as e: + return f"Unexpected collection creation error: {str(e)}" @mcp.tool() @@ -693,10 +756,17 @@ async def milvus_insert_data( collection_name: Name of collection data: List of dictionaries, each representing a record """ - connector = ctx.request_context.lifespan_context.connector - result = await connector.insert_data(collection_name=collection_name, data=data) + try: + connector = ctx.request_context.lifespan_context.connector + result = await connector.insert_data(collection_name=collection_name, data=data) - return f"Data inserted into collection '{collection_name}' with result: {str(result)}" + return ( + f"Data inserted into collection '{collection_name}' with result: {str(result)}" + ) + except ValueError as e: + return f"Data insertion error: {str(e)}" + except Exception as e: + return f"Unexpected data insertion error: {str(e)}" @mcp.tool() @@ -710,12 +780,17 @@ async def milvus_delete_entities( collection_name: Name of collection filter_expr: Filter expression to select entities to delete """ - connector = ctx.request_context.lifespan_context.connector - result = await connector.delete_entities( - collection_name=collection_name, filter_expr=filter_expr - ) + try: + connector = ctx.request_context.lifespan_context.connector + result = await connector.delete_entities( + collection_name=collection_name, filter_expr=filter_expr + ) - return f"Entities deleted from collection '{collection_name}' with result: {str(result)}" + return f"Entities deleted from collection '{collection_name}' with result: {str(result)}" + except ValueError as e: + return f"Entity deletion error: {str(e)}" + except Exception as e: + return f"Unexpected entity deletion error: {str(e)}" @mcp.tool() @@ -729,12 +804,17 @@ async def milvus_load_collection( collection_name: Name of collection to load replica_number: Number of replicas """ - connector = ctx.request_context.lifespan_context.connector - success = await connector.load_collection( - collection_name=collection_name, replica_number=replica_number - ) + try: + connector = ctx.request_context.lifespan_context.connector + await connector.load_collection( + collection_name=collection_name, replica_number=replica_number + ) - return f"Collection '{collection_name}' loaded successfully with {replica_number} replica(s)" + return f"Collection '{collection_name}' loaded successfully with {replica_number} replica(s)" + except ValueError as e: + return f"Collection loading error: {str(e)}" + except Exception as e: + return f"Unexpected collection loading error: {str(e)}" @mcp.tool() @@ -745,18 +825,26 @@ async def milvus_release_collection(collection_name: str, ctx: Context = None) - Args: collection_name: Name of collection to release """ - connector = ctx.request_context.lifespan_context.connector - success = await connector.release_collection(collection_name=collection_name) + try: + connector = ctx.request_context.lifespan_context.connector + await connector.release_collection(collection_name=collection_name) - return f"Collection '{collection_name}' released successfully" + return f"Collection '{collection_name}' released successfully" + except ValueError as e: + return f"Collection release error: {str(e)}" + except Exception as e: + return f"Unexpected collection release error: {str(e)}" @mcp.tool() async def milvus_list_databases(ctx: Context = None) -> str: """List all databases in the Milvus instance.""" - connector = ctx.request_context.lifespan_context.connector - databases = await connector.list_databases() - return f"Databases in Milvus instance:\n{', '.join(databases)}" + try: + connector = ctx.request_context.lifespan_context.connector + databases = await connector.list_databases() + return f"Databases in Milvus instance:\n{', '.join(databases)}" + except Exception as e: + return f"Error listing databases: {str(e)}" @mcp.tool() @@ -767,37 +855,46 @@ async def milvus_use_database(db_name: str, ctx: Context = None) -> str: Args: db_name: Name of the database to use """ - connector = ctx.request_context.lifespan_context.connector - success = await connector.use_database(db_name) - - return f"Switched to database '{db_name}' successfully" + try: + connector = ctx.request_context.lifespan_context.connector + await connector.use_database(db_name) + return f"Switched to database '{db_name}' successfully" + except ValueError as e: + return f"Database switch error: {str(e)}" + except Exception as e: + return f"Unexpected database switch error: {str(e)}" @mcp.tool() async def milvus_get_collection_info(collection_name: str, ctx: Context = None) -> str: """ Lists detailed information about a specific collection - + Args: collection_name: Name of collection to load """ - connector = ctx.request_context.lifespan_context.connector - collection_info = await connector.get_collection_info(collection_name) - info_str = json.dumps(collection_info, indent=2) - return f"Collection information:\n{info_str}" - + try: + connector = ctx.request_context.lifespan_context.connector + collection_info = await connector.get_collection_info(collection_name) + info_str = json.dumps(collection_info, indent=2, cls=CustomJSONEncoder) + return f"Collection information:\n{info_str}" + except ValueError as e: + return f"Collection info error: {str(e)}" + except Exception as e: + return f"Unexpected collection info error: {str(e)}" def parse_arguments(): parser = argparse.ArgumentParser(description="Milvus MCP Server") - parser.add_argument( - "--milvus-uri", type=str, default="http://localhost:19530", help="Milvus server URI" - ) - parser.add_argument( - "--milvus-token", type=str, default=None, help="Milvus authentication token" - ) - parser.add_argument("--milvus-db", type=str, default="default", help="Milvus database name") - parser.add_argument("--sse", action="store_true", help="Enable SSE mode") - parser.add_argument("--port", type=int, default=8000, help="Port number for SSE server") + parser.add_argument("--milvus-uri", type=str, + default="http://localhost:19530", help="Milvus server URI") + parser.add_argument("--milvus-token", type=str, + default=None, help="Milvus authentication token") + parser.add_argument("--milvus-db", type=str, + default="default", help="Milvus database name") + parser.add_argument("--sse", action="store_true", + help="Enable SSE mode") + parser.add_argument("--port", type=int, + default=8000, help="Port number for SSE server") return parser.parse_args()