Skip to content

Commit 0625ad5

Browse files
committed
fix(openapi): add regex validation and set additionalProperties=False
1 parent fed143b commit 0625ad5

File tree

1 file changed

+213
-57
lines changed

1 file changed

+213
-57
lines changed

mcp_openapi_proxy/openapi.py

Lines changed: 213 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
import json
7+
import re # Import the re module
78
import requests
89
import yaml
910
from typing import Dict, Optional, List, Union
@@ -12,6 +13,9 @@
1213
from mcp_openapi_proxy.utils import normalize_tool_name
1314
from .logging_setup import logger
1415

16+
# Define the required tool name pattern
17+
TOOL_NAME_REGEX = r"^[a-zA-Z0-9_-]{1,64}$"
18+
1519
def fetch_openapi_spec(url: str, retries: int = 3) -> Optional[Dict]:
1620
"""Fetch and parse an OpenAPI specification from a URL with retries."""
1721
logger.debug(f"Fetching OpenAPI spec from URL: {url}")
@@ -47,6 +51,15 @@ def fetch_openapi_spec(url: str, retries: int = 3) -> Optional[Dict]:
4751
if attempt == retries:
4852
logger.error(f"Failed to fetch spec from {url} after {retries} attempts: {e}")
4953
return None
54+
except FileNotFoundError as e:
55+
logger.error(f"Failed to open local file spec {url}: {e}")
56+
return None
57+
except Exception as e:
58+
attempt += 1
59+
logger.warning(f"Unexpected error during fetch attempt {attempt}/{retries}: {e}")
60+
if attempt == retries:
61+
logger.error(f"Failed to process spec from {url} after {retries} attempts due to unexpected error: {e}")
62+
return None
5063
return None
5164

5265
def build_base_url(spec: Dict) -> Optional[str]:
@@ -60,12 +73,32 @@ def build_base_url(spec: Dict) -> Optional[str]:
6073
return url
6174
logger.error(f"No valid URLs found in SERVER_URL_OVERRIDE: {override}")
6275
return None
76+
6377
if "servers" in spec and spec["servers"]:
64-
return spec["servers"][0]["url"]
65-
elif "host" in spec and "schemes" in spec:
66-
scheme = spec["schemes"][0] if spec["schemes"] else "https"
67-
return f"{scheme}://{spec['host']}{spec.get('basePath', '')}"
68-
logger.error("No servers or host/schemes defined in spec and no SERVER_URL_OVERRIDE.")
78+
# Ensure servers is a list and has items before accessing index 0
79+
if isinstance(spec["servers"], list) and len(spec["servers"]) > 0 and isinstance(spec["servers"][0], dict):
80+
server_url = spec["servers"][0].get("url")
81+
if server_url:
82+
logger.debug(f"Using first server URL from spec: {server_url}")
83+
return server_url
84+
else:
85+
logger.warning("First server entry in spec missing 'url' key.")
86+
else:
87+
logger.warning("Spec 'servers' key is not a non-empty list of dictionaries.")
88+
89+
# Fallback for OpenAPI v2 (Swagger)
90+
if "host" in spec and "schemes" in spec:
91+
scheme = spec["schemes"][0] if spec.get("schemes") else "https"
92+
base_path = spec.get("basePath", "")
93+
host = spec.get("host")
94+
if host:
95+
v2_url = f"{scheme}://{host}{base_path}"
96+
logger.debug(f"Using OpenAPI v2 host/schemes/basePath: {v2_url}")
97+
return v2_url
98+
else:
99+
logger.warning("OpenAPI v2 spec missing 'host'.")
100+
101+
logger.error("Could not determine base URL from spec (servers/host/schemes) or SERVER_URL_OVERRIDE.")
69102
return None
70103

71104
def handle_auth(operation: Dict) -> Dict[str, str]:
@@ -75,98 +108,221 @@ def handle_auth(operation: Dict) -> Dict[str, str]:
75108
auth_type = os.getenv("API_AUTH_TYPE", "Bearer").lower()
76109
if api_key:
77110
if auth_type == "bearer":
78-
logger.debug(f"Using API_KEY as Bearer: {api_key[:5]}...")
111+
logger.debug(f"Using API_KEY as Bearer token.") # Avoid logging key prefix
79112
headers["Authorization"] = f"Bearer {api_key}"
80113
elif auth_type == "basic":
81-
logger.debug("API_AUTH_TYPE is Basic, but Basic Auth not implemented yet.")
114+
logger.warning("API_AUTH_TYPE is Basic, but Basic Auth is not fully implemented yet.")
115+
# Potentially add basic auth implementation here if needed
82116
elif auth_type == "api-key":
83117
key_name = os.getenv("API_AUTH_HEADER", "Authorization")
84118
headers[key_name] = api_key
85-
logger.debug(f"Using API_KEY as API-Key in header {key_name}: {api_key[:5]}...")
119+
logger.debug(f"Using API_KEY as API-Key in header '{key_name}'.") # Avoid logging key prefix
120+
else:
121+
logger.warning(f"Unsupported API_AUTH_TYPE: {auth_type}")
122+
# TODO: Add logic to check operation['security'] and spec['components']['securitySchemes']
123+
# to potentially override or supplement env var based auth.
86124
return headers
87125

88126
# HTTP verbs that may appear as operation keys under a path item.
_OPERATION_METHODS = ('get', 'post', 'put', 'delete', 'patch', 'options', 'head', 'trace')
# JSON Schema types passed through unchanged; anything else is exposed as "string".
_PASSTHROUGH_TYPES = ('string', 'integer', 'boolean', 'number', 'array')
_DEFAULT_DESCRIPTION = "No description available"


def _collect_parameters(path_item: Dict, operation: Dict) -> Dict[str, Dict]:
    """Merge path-level and operation-level parameters, keyed by name (operation wins)."""
    merged: Dict[str, Dict] = {}
    # Path-level first so that operation-level entries with the same name replace them.
    for declared in (path_item.get('parameters'), operation.get('parameters')):
        if not isinstance(declared, list):
            # Missing, null or malformed 'parameters' must not abort the whole tool.
            continue
        for param in declared:
            if isinstance(param, dict) and param.get('name'):
                merged[param['name']] = param
    return merged


def _parameter_property(param_name: str, param_in: str, param_details: Dict) -> Dict:
    """Translate one path/query parameter definition into a JSON Schema property."""
    param_schema = param_details.get('schema')
    if not isinstance(param_schema, dict):
        # Swagger 2.0 keeps type/format/enum/items directly on the parameter object.
        param_schema = param_details

    declared_type = param_schema.get('type', 'string')
    schema_type = declared_type if declared_type in _PASSTHROUGH_TYPES else 'string'

    prop: Dict = {
        "type": schema_type,
        "description": param_details.get('description', f"{param_in} parameter {param_name}"),
    }
    if param_schema.get('format'):
        prop['format'] = param_schema.get('format')
    if param_schema.get('enum'):
        prop['enum'] = param_schema.get('enum')
    if schema_type == 'array':
        # An array schema without 'items' is rejected by strict schema consumers.
        items = param_schema.get('items')
        prop['items'] = items if isinstance(items, dict) and items else {"type": "string"}
    return prop


def _json_body_schema(operation: Dict) -> Optional[Dict]:
    """Return the application/json request-body schema of an operation, if it has one."""
    request_body = operation.get('requestBody')
    if not isinstance(request_body, dict):
        return None
    content = request_body.get('content')
    if not isinstance(content, dict):
        return None
    json_content = content.get('application/json')
    if not isinstance(json_content, dict):
        return None
    body_schema = json_content.get('schema')
    return body_schema if isinstance(body_schema, dict) else None


def _build_input_schema(path: str, path_item: Dict, operation: Dict, function_name: str) -> Dict:
    """Build the tool's input schema from parameters, the path template and the JSON body."""
    input_schema: Dict = {
        "type": "object",
        "properties": {},
        "required": [],
        "additionalProperties": False,
    }
    properties = input_schema['properties']
    required = input_schema['required']

    # Declared path/query parameters (other locations are not exposed as inputs).
    for param_name, param_details in _collect_parameters(path_item, operation).items():
        param_in = param_details.get('in')
        if param_in not in ('path', 'query'):
            continue
        properties[param_name] = _parameter_property(param_name, param_in, param_details)
        if param_details.get('required', False) and param_name not in required:
            required.append(param_name)

    # Placeholders in the path template (e.g. /users/{id}) are always required;
    # undeclared ones are exposed as plain strings.
    for tp_name in re.findall(r"\{([^}]+)\}", path):
        if tp_name not in properties:
            properties[tp_name] = {
                "type": "string",
                "description": f"Path parameter '{tp_name}'",
            }
        if tp_name not in required:
            required.append(tp_name)

    # Flatten an object-typed JSON request body into top-level properties.
    body_schema = _json_body_schema(operation)
    if (
        body_schema is not None
        and body_schema.get('type') == 'object'
        and isinstance(body_schema.get('properties'), dict)
    ):
        for prop_name, prop_schema in body_schema['properties'].items():
            if prop_name in properties:
                # Keep the path/query definition rather than silently replacing it.
                logger.warning(
                    f"Request body property '{prop_name}' of {function_name} clashes with a "
                    f"path/query parameter; keeping the parameter definition."
                )
                continue
            properties[prop_name] = prop_schema
        body_required = body_schema.get('required')
        if isinstance(body_required, list):
            for req_prop in body_required:
                if req_prop not in required:
                    required.append(req_prop)

    return input_schema


def register_functions(spec: Dict) -> List[types.Tool]:
    """Register one tool per whitelisted operation found in an OpenAPI spec.

    For every whitelisted path and every supported HTTP method under it, a tool
    name is derived with ``normalize_tool_name`` from ``"<METHOD> <path>"``.
    Operations whose name does not fully match ``TOOL_NAME_REGEX`` or duplicates
    an already registered name are skipped. The input schema combines path/query
    parameters, path-template placeholders and the properties of an object-typed
    ``application/json`` request body, with ``additionalProperties`` set to False.

    Args:
        spec: Parsed OpenAPI document.

    Returns:
        The list of registered tools; empty when the spec is missing, has no
        usable ``paths`` mapping, or no path passes the whitelist. When at least
        one path passes the whitelist, ``server_lowlevel.tools`` (if present) is
        replaced with the same tools.
    """
    from .utils import is_tool_whitelisted  # Local import, kept here to avoid a circular dependency.

    tools_list: List[types.Tool] = []
    logger.debug("Starting tool registration from OpenAPI spec.")
    if not spec:
        logger.error("OpenAPI spec is None or empty during registration.")
        return tools_list
    if 'paths' not in spec:
        logger.error("No 'paths' key in OpenAPI spec during registration.")
        return tools_list

    paths = spec['paths']
    if not isinstance(paths, dict):
        logger.error("'paths' in OpenAPI spec is not a mapping during registration.")
        return tools_list

    logger.debug(f"Available paths in spec: {list(paths.keys())}")
    # The whitelist is evaluated on the path string, before any operation is inspected.
    filtered_paths = {path: item for path, item in paths.items() if is_tool_whitelisted(path)}
    logger.debug(f"Paths after whitelist filtering: {list(filtered_paths.keys())}")

    if not filtered_paths:
        logger.warning("No whitelisted paths found in OpenAPI spec after filtering. No tools will be registered.")
        return tools_list

    registered_names = set()  # Names already used, to detect duplicates.

    for path, path_item in filtered_paths.items():
        if not path_item or not isinstance(path_item, dict):
            logger.debug(f"Skipping empty or invalid path item for {path}")
            continue
        for method, operation in path_item.items():
            # Path items also carry non-operation keys (e.g. 'parameters'); skip those.
            if method.lower() not in _OPERATION_METHODS or not isinstance(operation, dict):
                continue
            try:
                raw_name = f"{method.upper()} {path}"
                function_name = normalize_tool_name(raw_name)

                # fullmatch: '$' alone would also accept a name ending in a newline.
                if not re.fullmatch(TOOL_NAME_REGEX, function_name):
                    logger.error(
                        f"Skipping registration for '{raw_name}': "
                        f"Generated name '{function_name}' does not match required pattern '{TOOL_NAME_REGEX}'."
                    )
                    continue

                if function_name in registered_names:
                    logger.warning(
                        f"Skipping registration for '{raw_name}': "
                        f"Duplicate tool name '{function_name}' detected."
                    )
                    continue

                # Fall through to 'description' when 'summary' is absent *or* empty.
                description = (
                    operation.get('summary')
                    or operation.get('description')
                    or _DEFAULT_DESCRIPTION
                )
                if not isinstance(description, str):
                    logger.warning(f"Description for {function_name} is not a string, using default.")
                    description = _DEFAULT_DESCRIPTION

                input_schema = _build_input_schema(path, path_item, operation, function_name)

                tool = types.Tool(
                    name=function_name,
                    description=description,
                    inputSchema=input_schema,
                )
                tools_list.append(tool)
                registered_names.add(function_name)
                logger.debug(f"Registered tool: {function_name} from {raw_name}")
            except Exception as e:
                # Best effort: one malformed operation must not prevent the others from registering.
                logger.error(f"Error registering function for {method.upper()} {path}: {e}", exc_info=True)

    logger.info(f"Successfully registered {len(tools_list)} tools from OpenAPI spec.")

    # Mirror the result into the low-level server's shared tool list when it exists.
    from . import server_lowlevel
    if hasattr(server_lowlevel, 'tools'):
        logger.debug("Updating server_lowlevel.tools list.")
        server_lowlevel.tools.clear()
        server_lowlevel.tools.extend(tools_list)

    return tools_list
159297

160298
def lookup_operation_details(function_name: str, spec: Dict) -> Union[Dict, None]:
    """Find the OpenAPI operation that a tool name was generated from.

    Each operation's name is regenerated with ``normalize_tool_name`` from
    ``"<METHOD> <path>"`` and compared with ``function_name``. Candidates whose
    regenerated name does not match ``TOOL_NAME_REGEX`` are ignored.

    Args:
        function_name: Tool name to resolve.
        spec: Parsed OpenAPI document.

    Returns:
        A dict with the keys ``path``, ``method`` (upper-cased), ``operation``
        and ``original_path`` for the first matching operation, or ``None`` when
        the spec has no ``paths`` or nothing matches.
    """
    if not spec or 'paths' not in spec:
        logger.warning("Spec is missing or has no 'paths' key in lookup_operation_details.")
        return None

    supported_methods = ['get', 'post', 'put', 'delete', 'patch', 'options', 'head', 'trace']

    for route, route_item in spec['paths'].items():
        if not isinstance(route_item, dict):
            # Malformed path item: nothing to inspect.
            continue
        for verb, candidate_op in route_item.items():
            # Verb check first, then the operation type check.
            is_operation = verb.lower() in supported_methods and isinstance(candidate_op, dict)
            if not is_operation:
                continue

            verb_upper = verb.upper()
            # Same "<METHOD> <path>" -> normalize_tool_name derivation as the registration step.
            candidate_name = normalize_tool_name(f"{verb_upper} {route}")

            # Names that fail the pattern are never considered a match.
            if re.match(TOOL_NAME_REGEX, candidate_name) is None:
                continue
            if candidate_name != function_name:
                continue

            logger.debug(f"Found operation details for '{function_name}' at {verb_upper} {route}")
            return {
                "path": route,
                "method": verb_upper,
                "operation": candidate_op,
                "original_path": route,
            }

    logger.warning(f"Could not find operation details for function name: '{function_name}'")
    return None

0 commit comments

Comments
 (0)