diff --git a/client-metadata.json b/client-metadata.json new file mode 100644 index 00000000..e1b78b9f --- /dev/null +++ b/client-metadata.json @@ -0,0 +1,7 @@ +{ + "client_id": "https://raw.githubusercontent.com/modelcontextprotocol/rust-sdk/refs/heads/main/client-metadata.json", + "redirect_uris": ["http://localhost:4000/callback"], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none" +} diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 68dfc208..412d585a 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -239,6 +239,15 @@ struct AuthorizationState { csrf_token: CsrfToken, } +/// SEP-991: URL-based Client IDs +/// Validate that the client_id is a valid URL with https scheme and non-root pathname +fn is_https_url(value: &str) -> bool { + Url::parse(value) + .ok() + .map(|url| url.scheme() == "https" && url.path() != "/" && url.host_str().is_some()) + .unwrap_or(false) +} + impl AuthorizationManager { fn well_known_paths(base_path: &str, resource: &str) -> Vec { let trimmed = base_path.trim_start_matches('/').trim_end_matches('/'); @@ -950,30 +959,57 @@ impl AuthorizationSession { scopes: &[&str], redirect_uri: &str, client_name: Option<&str>, + client_metadata_url: Option<&str>, ) -> Result { - // Default client config - let config = OAuthClientConfig { - client_id: "mcp-client".to_string(), - client_secret: None, - scopes: scopes.iter().map(|s| s.to_string()).collect(), - redirect_uri: redirect_uri.to_string(), - }; - - // try to dynamic register client - let config = match auth_manager - .register_client(client_name.unwrap_or("MCP Client"), redirect_uri) - .await - { - Ok(config) => config, - Err(e) => { - warn!( - "Dynamic registration failed: {}, fallback to default config", - e - ); - // fallback to default config - config + let metadata = auth_manager.metadata.as_ref(); + let supports_url_based_client_id = metadata + .and_then(|m| { + m.additional_fields + .get("client_id_metadata_document_supported") + }) + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let config = if supports_url_based_client_id { + if let Some(client_metadata_url) = client_metadata_url { + if !is_https_url(client_metadata_url) { + return Err(AuthError::RegistrationFailed(format!( + "client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {}", + client_metadata_url + ))); + } + // SEP-991: URL-based Client IDs - use URL as client_id directly + OAuthClientConfig { + client_id: client_metadata_url.to_string(), + client_secret: None, + scopes: scopes.iter().map(|s| s.to_string()).collect(), + redirect_uri: redirect_uri.to_string(), + } + } else { + // Fallback to dynamic registration + auth_manager + .register_client(client_name.unwrap_or("MCP Client"), redirect_uri) + .await + .map_err(|e| { + AuthError::RegistrationFailed(format!("Dynamic registration failed: {}", e)) + })? + } + } else { + // Fallback to dynamic registration + match auth_manager + .register_client(client_name.unwrap_or("MCP Client"), redirect_uri) + .await + { + Ok(config) => config, + Err(e) => { + return Err(AuthError::RegistrationFailed(format!( + "Dynamic registration failed: {}", + e + ))); + } } }; + // reset client config auth_manager.configure_client(config)?; let auth_url = auth_manager.get_authorization_url(scopes).await?; @@ -1125,6 +1161,18 @@ impl OAuthState { scopes: &[&str], redirect_uri: &str, client_name: Option<&str>, + ) -> Result<(), AuthError> { + self.start_authorization_with_metadata_url(scopes, redirect_uri, client_name, None) + .await + } + + /// start authorization with optional client metadata URL (SEP-991) + pub async fn start_authorization_with_metadata_url( + &mut self, + scopes: &[&str], + redirect_uri: &str, + client_name: Option<&str>, + client_metadata_url: Option<&str>, ) -> Result<(), AuthError> { if let OAuthState::Unauthorized(mut manager) = std::mem::replace( self, @@ -1134,8 +1182,14 @@ impl OAuthState { let metadata = manager.discover_metadata().await?; manager.metadata = Some(metadata); debug!("start session"); - let session = - AuthorizationSession::new(manager, scopes, redirect_uri, client_name).await?; + let session = AuthorizationSession::new( + manager, + scopes, + redirect_uri, + client_name, + client_metadata_url, + ) + .await?; *self = OAuthState::Session(session); Ok(()) } else { @@ -1256,7 +1310,31 @@ impl OAuthState { mod tests { use url::Url; - use super::AuthorizationManager; + use super::{AuthorizationManager, is_https_url}; + + // SEP-991: URL-based Client IDs + // Tests adapted from the TypeScript SDK's isHttpsUrl test suite + #[test] + fn test_is_https_url_scenarios() { + // Returns true for valid https url with path + assert!(is_https_url("https://example.com/client-metadata.json")); + // Returns true for https url with query params + assert!(is_https_url("https://example.com/metadata?version=1")); + // Returns false for https url without path + assert!(!is_https_url("https://example.com")); + assert!(!is_https_url("https://example.com/")); + assert!(!is_https_url("https://")); + // Returns false for http url + assert!(!is_https_url("http://example.com/metadata")); + // Returns false for non-url strings + assert!(!is_https_url("not a url")); + // Returns false for empty string + assert!(!is_https_url("")); + // Returns false for javascript scheme + assert!(!is_https_url("javascript:alert(1)")); + // Returns false for data scheme + assert!(!is_https_url("data:text/html,")); + } #[test] fn parses_resource_metadata_parameter() { diff --git a/examples/clients/src/auth/oauth_client.rs b/examples/clients/src/auth/oauth_client.rs index 53fa5cca..9b131652 100644 --- a/examples/clients/src/auth/oauth_client.rs +++ b/examples/clients/src/auth/oauth_client.rs @@ -1,4 +1,4 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{env, net::SocketAddr, sync::Arc}; use anyhow::{Context, Result}; use axum::{ @@ -23,10 +23,11 @@ use tokio::{ }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -const MCP_SERVER_URL: &str = "http://localhost:3000/mcp"; -const MCP_REDIRECT_URI: &str = "http://localhost:8080/callback"; +const MCP_SERVER_URL: &str = "http://127.0.0.1:3000/mcp"; +const MCP_REDIRECT_URI: &str = "http://127.0.0.1:8080/callback"; const CALLBACK_PORT: u16 = 8080; const CALLBACK_HTML: &str = include_str!("callback.html"); +const CLIENT_METADATA_URL: &str = "https://raw.githubusercontent.com/modelcontextprotocol/rust-sdk/refs/heads/main/client-metadata.json"; #[derive(Clone)] struct AppState { @@ -79,6 +80,9 @@ async fn main() -> Result<()> { let addr = SocketAddr::from(([127, 0, 0, 1], CALLBACK_PORT)); tracing::info!("Starting callback server at: http://{}", addr); + tracing::warn!( + "Note: Callback server may not receive callbacks if redirect URI doesn't match localhost if using CIMD (SEP-991)" + ); // Start server in a separate task tokio::spawn(async move { @@ -90,19 +94,37 @@ async fn main() -> Result<()> { } }); - // Get server URL - let server_url = MCP_SERVER_URL.to_string(); + // Get server URL and client metadata URL from CLI (with defaults) + // + // Usage: + // cargo run --example clients_oauth_client -- + let args: Vec = env::args().collect(); + let server_url = args + .get(1) + .cloned() + .unwrap_or_else(|| MCP_SERVER_URL.to_string()); + let client_metadata_url = args + .get(2) + .cloned() + .unwrap_or_else(|| CLIENT_METADATA_URL.to_string()); + tracing::info!("Using MCP server URL: {}", server_url); + tracing::info!( + "Using CIMD (SEP-991) with client metadata URL: {}", + client_metadata_url + ); // Initialize oauth state machine let mut oauth_state = OAuthState::new(&server_url, None) .await .context("Failed to initialize oauth state machine")?; + // Use CIMD (SEP-991) with client metadata URL oauth_state - .start_authorization( + .start_authorization_with_metadata_url( &["mcp", "profile", "email"], MCP_REDIRECT_URI, Some("Test MCP Client"), + Some(&client_metadata_url), ) .await .context("Failed to start authorization")?; diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 72431558..068532ad 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -44,6 +44,7 @@ tower-http = { version = "0.6", features = ["cors"] } hyper = { version = "1" } hyper-util = { version = "0", features = ["server"] } tokio-util = { version = "0.7" } +url = "2.5" [dev-dependencies] tokio-stream = { version = "0.1" } @@ -97,6 +98,10 @@ path = "src/simple_auth_streamhttp.rs" name = "servers_complex_auth_streamhttp" path = "src/complex_auth_streamhttp.rs" +[[example]] +name = "servers_cimd_auth_streamhttp" +path = "src/cimd_auth_streamhttp.rs" + [[example]] name = "servers_calculator_stdio" path = "src/calculator_stdio.rs" diff --git a/examples/servers/src/cimd_auth_streamhttp.rs b/examples/servers/src/cimd_auth_streamhttp.rs new file mode 100644 index 00000000..73eab5e6 --- /dev/null +++ b/examples/servers/src/cimd_auth_streamhttp.rs @@ -0,0 +1,514 @@ +use std::{ + collections::HashMap, + net::SocketAddr, + sync::Arc, + time::{Duration, SystemTime}, +}; + +use anyhow::Result; +use axum::{ + Json, Router, + extract::{Form, Query, State}, + http::StatusCode, + response::{Html, IntoResponse, Redirect, Response}, + routing::{get, post}, +}; +use rand::{Rng, distr::Alphanumeric}; +use rmcp::transport::{ + StreamableHttpServerConfig, + streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::sync::RwLock; +use tower_http::cors::{Any, CorsLayer}; +use tracing::{error, info}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use url::Url; + +// Import Counter tool for MCP service +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:3000"; + +/// In-memory authorization code record +#[derive(Clone, Debug)] +struct AuthCodeRecord { + client_id: String, + redirect_uri: String, + expires_at: SystemTime, +} + +#[derive(Clone)] +struct AppState { + auth_codes: Arc>>, +} + +impl AppState { + fn new() -> Self { + Self { + auth_codes: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +fn generate_authorization_code() -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect() +} + +fn generate_access_token() -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect() +} + +/// Validate that the client_id is a URL that meets CIMD mandatory requirements. +/// Mirrors the JS validateClientIdUrl helper. +fn validate_client_id_url(raw: &str) -> Result { + let url = Url::parse(raw) + .map_err(|_| "invalid_client_id: client_id must be a valid URL".to_string())?; + + // MUST have https scheme + if url.scheme() != "https" { + return Err("invalid_client_id: client_id URL MUST use https scheme".to_string()); + } + + // MUST contain a path component (cannot be empty or just "/") + let path = url.path(); + if path.is_empty() || path == "/" { + return Err("invalid_client_id: client_id URL MUST contain a path component".to_string()); + } + + // MUST NOT contain single-dot or double-dot path segments + if path.split('/').any(|s| s == "." || s == "..") { + return Err( + "invalid_client_id: client_id URL MUST NOT contain single-dot or double-dot path segments" + .to_string(), + ); + } + + // MUST NOT contain a fragment component + if url.fragment().is_some() { + return Err( + "invalid_client_id: client_id URL MUST NOT contain a fragment component".to_string(), + ); + } + + // MUST NOT contain a username or password + if !url.username().is_empty() || url.password().is_some() { + return Err( + "invalid_client_id: client_id URL MUST NOT contain a username or password component" + .to_string(), + ); + } + + Ok(url.to_string()) +} + +/// Fetch and validate the client metadata document from the client_id URL. +/// Implements MUST / MUST NOT rules from CIMD section 4.1. +async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result { + let client = reqwest::Client::new(); + let res = client + .get(client_id_url) + .header( + reqwest::header::ACCEPT, + "application/json, application/*+json", + ) + .send() + .await + .map_err(|_| "invalid_client: failed to fetch client metadata document".to_string())?; + + if !res.status().is_success() { + return Err("invalid_client: failed to fetch client metadata document".to_string()); + } + + let json: Value = res + .json() + .await + .map_err(|_| "invalid_client: client metadata document is not valid JSON".to_string())?; + + if !json.is_object() { + return Err("invalid_client: client metadata document must be a JSON object".to_string()); + } + + // MUST contain a client_id property equal to the URL of the document + let client_id_value = json.get("client_id").ok_or_else(|| { + "invalid_client: client metadata document MUST contain client_id".to_string() + })?; + if client_id_value != client_id_url { + return Err( + "invalid_client: client_id property in metadata document MUST match the document URL" + .to_string(), + ); + } + + // token_endpoint_auth_method MUST NOT be any shared secret based method + if let Some(method) = json.get("token_endpoint_auth_method") { + if let Some(method_str) = method.as_str() { + let forbidden = [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + ]; + if forbidden.contains(&method_str) || method_str.starts_with("client_secret_") { + return Err("invalid_client: token_endpoint_auth_method MUST NOT be a shared secret based method".to_string()); + } + } + } + + // client_secret and client_secret_expires_at MUST NOT be used + if json.get("client_secret").is_some() { + return Err( + "invalid_client: client_secret MUST NOT be present in client metadata".to_string(), + ); + } + if json.get("client_secret_expires_at").is_some() { + return Err( + "invalid_client: client_secret_expires_at MUST NOT be present in client metadata" + .to_string(), + ); + } + + Ok(json) +} + +/// Validate redirect_uri against metadata.redirect_uris (exact match). +fn validate_redirect_uri(requested_redirect_uri: &str, metadata: &Value) -> Result<(), String> { + let redirect_uris = metadata.get("redirect_uris").ok_or_else(|| { + "invalid_client: client metadata must include redirect_uris array".to_string() + })?; + + let arr = redirect_uris + .as_array() + .ok_or_else(|| "invalid_client: redirect_uris must be an array".to_string())?; + + let requested = requested_redirect_uri.to_string(); + let found = arr.iter().any(|u| u.as_str() == Some(&requested)); + + if !found { + return Err( + "invalid_request: redirect_uri MUST exactly match one of the registered redirect_uris" + .to_string(), + ); + } + + Ok(()) +} + +/// Minimal Authorization Server Metadata with CIMD support. +async fn oauth_metadata() -> impl IntoResponse { + let issuer = + std::env::var("CIMD_ISSUER").unwrap_or_else(|_| format!("http://{}", BIND_ADDRESS)); + + let body = serde_json::json!({ + "issuer": issuer, + "authorization_endpoint": format!("{}/authorize", issuer), + "token_endpoint": format!("{}/token", issuer), + "client_id_metadata_document_supported": true, + }); + + Json(body) +} + +#[derive(Debug, Deserialize)] +struct AuthorizeQuery { + client_id: Option, + redirect_uri: Option, + response_type: Option, + state: Option, + scope: Option, +} + +#[derive(Debug, Deserialize)] +struct LoginForm { + username: Option, + password: Option, + // OAuth params come from hidden form fields + client_id: Option, + redirect_uri: Option, + response_type: Option, + state: Option, + scope: Option, +} + +fn render_login_form(params: &AuthorizeQuery, error: Option<&str>) -> Html { + let hidden_fields = [ + ("client_id", params.client_id.as_deref().unwrap_or_default()), + ( + "redirect_uri", + params.redirect_uri.as_deref().unwrap_or_default(), + ), + ( + "response_type", + params.response_type.as_deref().unwrap_or_default(), + ), + ("state", params.state.as_deref().unwrap_or_default()), + ("scope", params.scope.as_deref().unwrap_or_default()), + ] + .iter() + .map(|(k, v)| format!(r#""#)) + .collect::>() + .join("\n "); + + let error_html = error + .map(|e| format!(r#"
{}
"#, e)) + .unwrap_or_default(); + + let html = format!( + r#" + + + + OAuth Login - CIMD Server + + + +

OAuth Login

+ {error_html} +
+ {hidden_fields} + + + + + +
+

+ Demo credentials: admin / admin +

+ + +"# + ); + + Html(html) +} + +async fn authorize_get(Query(params): Query) -> impl IntoResponse { + render_login_form(¶ms, None) +} + +async fn authorize_post( + State(state): State, + Form(form): Form, +) -> impl IntoResponse { + // Convert LoginForm (which includes OAuth params from hidden fields) to AuthorizeQuery + let params = AuthorizeQuery { + client_id: form.client_id.clone(), + redirect_uri: form.redirect_uri.clone(), + response_type: form.response_type.clone(), + state: form.state.clone(), + scope: form.scope.clone(), + }; + + match handle_authorize(&state, ¶ms, &form).await { + Ok(redirect_response) => redirect_response, + Err(error_response) => error_response, + } +} + +async fn handle_authorize( + state: &AppState, + params: &AuthorizeQuery, + form: &LoginForm, +) -> Result { + let client_id_raw = params + .client_id + .as_deref() + .ok_or_else(|| bad_request("invalid_request: client_id is required"))?; + let redirect_uri = params + .redirect_uri + .as_deref() + .ok_or_else(|| bad_request("invalid_request: redirect_uri is required"))?; + let response_type = params + .response_type + .as_deref() + .ok_or_else(|| bad_request("invalid_request: response_type is required"))?; + + if response_type != "code" { + return Err(bad_request( + "unsupported_response_type: only response_type=code is supported", + )); + } + + let client_id_url = validate_client_id_url(client_id_raw).map_err(|e| bad_request(&e))?; + let metadata = fetch_and_validate_client_metadata(&client_id_url) + .await + .map_err(|e| bad_request(&e))?; + validate_redirect_uri(redirect_uri, &metadata).map_err(|e| bad_request(&e))?; + + // If this is a login POST, validate credentials + if let (Some(username), Some(password)) = (&form.username, &form.password) { + if username != "admin" || password != "admin" { + let html = render_login_form(params, Some("Invalid username or password")); + return Err(html.into_response()); + } + + // Login successful - generate authorization code and redirect + let code = generate_authorization_code(); + let expires_at = SystemTime::now() + Duration::from_secs(10 * 60); + + { + let mut codes = state.auth_codes.write().await; + codes.insert( + code.clone(), + AuthCodeRecord { + client_id: client_id_url, + redirect_uri: redirect_uri.to_string(), + expires_at, + }, + ); + } + + let mut url = Url::parse(redirect_uri) + .map_err(|_| bad_request("invalid_request: redirect_uri is invalid"))?; + url.query_pairs_mut().append_pair("code", &code); + if let Some(state_param) = ¶ms.state { + url.query_pairs_mut().append_pair("state", state_param); + } + + Ok(Redirect::to(url.as_str()).into_response()) + } else { + // GET request without credentials: show login form + let html = render_login_form(params, None); + Err(html.into_response()) + } +} + +fn bad_request(message: &str) -> Response { + let body = serde_json::json!({ + "error": "invalid_request", + "error_description": message, + }); + (StatusCode::BAD_REQUEST, Json(body)).into_response() +} + +#[derive(Debug, Deserialize)] +struct TokenRequest { + grant_type: Option, + code: Option, +} + +async fn token(State(state): State, Form(form): Form) -> impl IntoResponse { + if form.grant_type.as_deref() != Some("authorization_code") { + let body = serde_json::json!({ + "error": "unsupported_grant_type", + "error_description": "Only authorization_code is supported in this demo", + }); + return (StatusCode::BAD_REQUEST, Json(body)).into_response(); + } + + let code = match &form.code { + Some(c) => c.clone(), + None => { + let body = serde_json::json!({ + "error": "invalid_request", + "error_description": "Authorization code is required", + }); + return (StatusCode::BAD_REQUEST, Json(body)).into_response(); + } + }; + + let record_opt = { + let mut codes = state.auth_codes.write().await; + codes.remove(&code) + }; + + let record = match record_opt { + Some(r) => r, + None => { + let body = serde_json::json!({ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }); + return (StatusCode::BAD_REQUEST, Json(body)).into_response(); + } + }; + + if SystemTime::now() > record.expires_at { + let body = serde_json::json!({ + "error": "invalid_grant", + "error_description": "Authorization code has expired", + }); + return (StatusCode::BAD_REQUEST, Json(body)).into_response(); + } + + let access_token = generate_access_token(); + let body = serde_json::json!({ + "access_token": access_token, + "token_type": "Bearer", + "expires_in": 3600, + }); + + Json(body).into_response() +} + +async fn index() -> Html<&'static str> { + Html( + "

CIMD OAuth + MCP Server

This server supports Client ID Metadata Documents (SEP-991) and exposes an MCP endpoint at /mcp.

", + ) +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let state = AppState::new(); + + // Create streamable HTTP service for MCP + let mcp_service: StreamableHttpService = + StreamableHttpService::new( + || Ok(Counter::new()), + LocalSessionManager::default().into(), + StreamableHttpServerConfig::default(), + ); + + let addr = BIND_ADDRESS.parse::()?; + + let cors_layer = CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any); + + let app = Router::new() + .route("/", get(index)) + .route( + "/.well-known/oauth-authorization-server", + get(oauth_metadata), + ) + .route("/authorize", get(authorize_get).post(authorize_post)) + .route("/token", post(token).layer(cors_layer.clone())) + .nest_service("/mcp", mcp_service) + .with_state(state); + + let listener = tokio::net::TcpListener::bind(addr).await?; + info!("CIMD OAuth server listening on http://{}", addr); + + if let Err(e) = axum::serve(listener, app).await { + error!("server error: {}", e); + } + + Ok(()) +}