@@ -21,6 +21,7 @@ use rmcp::transport::{
2121use serde:: { Deserialize , Serialize } ;
2222use serde_json:: Value ;
2323use tokio:: sync:: RwLock ;
24+ use tower_http:: cors:: { Any , CorsLayer } ;
2425use tracing:: { error, info} ;
2526use tracing_subscriber:: { layer:: SubscriberExt , util:: SubscriberInitExt } ;
2627use url:: Url ;
@@ -71,7 +72,8 @@ fn generate_access_token() -> String {
7172/// Validate that the client_id is a URL that meets CIMD mandatory requirements.
7273/// Mirrors the JS validateClientIdUrl helper.
7374fn validate_client_id_url ( raw : & str ) -> Result < String , String > {
74- let url = Url :: parse ( raw) . map_err ( |_| "invalid_client_id: client_id must be a valid URL" . to_string ( ) ) ?;
75+ let url = Url :: parse ( raw)
76+ . map_err ( |_| "invalid_client_id: client_id must be a valid URL" . to_string ( ) ) ?;
7577
7678 // MUST have https scheme
7779 if url. scheme ( ) != "https" {
@@ -94,13 +96,16 @@ fn validate_client_id_url(raw: &str) -> Result<String, String> {
9496
9597 // MUST NOT contain a fragment component
9698 if url. fragment ( ) . is_some ( ) {
97- return Err ( "invalid_client_id: client_id URL MUST NOT contain a fragment component" . to_string ( ) ) ;
99+ return Err (
100+ "invalid_client_id: client_id URL MUST NOT contain a fragment component" . to_string ( ) ,
101+ ) ;
98102 }
99103
100104 // MUST NOT contain a username or password
101105 if !url. username ( ) . is_empty ( ) || url. password ( ) . is_some ( ) {
102106 return Err (
103- "invalid_client_id: client_id URL MUST NOT contain a username or password component" . to_string ( ) ,
107+ "invalid_client_id: client_id URL MUST NOT contain a username or password component"
108+ . to_string ( ) ,
104109 ) ;
105110 }
106111
@@ -135,9 +140,9 @@ async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result<Value
135140 }
136141
137142 // MUST contain a client_id property equal to the URL of the document
138- let client_id_value = json
139- . get ( " client_id")
140- . ok_or_else ( || "invalid_client: client metadata document MUST contain client_id" . to_string ( ) ) ?;
143+ let client_id_value = json. get ( "client_id" ) . ok_or_else ( || {
144+ "invalid_client: client metadata document MUST contain client_id". to_string ( )
145+ } ) ?;
141146 if client_id_value != client_id_url {
142147 return Err (
143148 "invalid_client: client_id property in metadata document MUST match the document URL"
@@ -148,22 +153,27 @@ async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result<Value
148153 // token_endpoint_auth_method MUST NOT be any shared secret based method
149154 if let Some ( method) = json. get ( "token_endpoint_auth_method" ) {
150155 if let Some ( method_str) = method. as_str ( ) {
151- let forbidden = [ "client_secret_post" , "client_secret_basic" , "client_secret_jwt" ] ;
152- if forbidden. contains ( & method_str)
153- || method_str. starts_with ( "client_secret_" )
154- {
156+ let forbidden = [
157+ "client_secret_post" ,
158+ "client_secret_basic" ,
159+ "client_secret_jwt" ,
160+ ] ;
161+ if forbidden. contains ( & method_str) || method_str. starts_with ( "client_secret_" ) {
155162 return Err ( "invalid_client: token_endpoint_auth_method MUST NOT be a shared secret based method" . to_string ( ) ) ;
156163 }
157164 }
158165 }
159166
160167 // client_secret and client_secret_expires_at MUST NOT be used
161168 if json. get ( "client_secret" ) . is_some ( ) {
162- return Err ( "invalid_client: client_secret MUST NOT be present in client metadata" . to_string ( ) ) ;
169+ return Err (
170+ "invalid_client: client_secret MUST NOT be present in client metadata" . to_string ( ) ,
171+ ) ;
163172 }
164173 if json. get ( "client_secret_expires_at" ) . is_some ( ) {
165174 return Err (
166- "invalid_client: client_secret_expires_at MUST NOT be present in client metadata" . to_string ( ) ,
175+ "invalid_client: client_secret_expires_at MUST NOT be present in client metadata"
176+ . to_string ( ) ,
167177 ) ;
168178 }
169179
@@ -172,9 +182,9 @@ async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result<Value
172182
173183/// Validate redirect_uri against metadata.redirect_uris (exact match).
174184fn validate_redirect_uri ( requested_redirect_uri : & str , metadata : & Value ) -> Result < ( ) , String > {
175- let redirect_uris = metadata
176- . get ( " redirect_uris" )
177- . ok_or_else ( || "invalid_client: client metadata must include redirect_uris array" . to_string ( ) ) ?;
185+ let redirect_uris = metadata. get ( "redirect_uris" ) . ok_or_else ( || {
186+ "invalid_client: client metadata must include redirect_uris array" . to_string ( )
187+ } ) ?;
178188
179189 let arr = redirect_uris
180190 . as_array ( )
@@ -185,7 +195,8 @@ fn validate_redirect_uri(requested_redirect_uri: &str, metadata: &Value) -> Resu
185195
186196 if !found {
187197 return Err (
188- "invalid_request: redirect_uri MUST exactly match one of the registered redirect_uris" . to_string ( ) ,
198+ "invalid_request: redirect_uri MUST exactly match one of the registered redirect_uris"
199+ . to_string ( ) ,
189200 ) ;
190201 }
191202
@@ -194,8 +205,8 @@ fn validate_redirect_uri(requested_redirect_uri: &str, metadata: &Value) -> Resu
194205
195206/// Minimal Authorization Server Metadata with CIMD support.
196207async fn oauth_metadata ( ) -> impl IntoResponse {
197- let issuer = std :: env :: var ( "CIMD_ISSUER" )
198- . unwrap_or_else ( |_| format ! ( "http://{}" , BIND_ADDRESS ) ) ;
208+ let issuer =
209+ std :: env :: var ( "CIMD_ISSUER" ) . unwrap_or_else ( |_| format ! ( "http://{}" , BIND_ADDRESS ) ) ;
199210
200211 let body = serde_json:: json!( {
201212 "issuer" : issuer,
@@ -288,9 +299,7 @@ fn render_login_form(params: &AuthorizeQuery, error: Option<&str>) -> Html<Strin
288299 Html ( html)
289300}
290301
291- async fn authorize_get (
292- Query ( params) : Query < AuthorizeQuery > ,
293- ) -> impl IntoResponse {
302+ async fn authorize_get ( Query ( params) : Query < AuthorizeQuery > ) -> impl IntoResponse {
294303 render_login_form ( & params, None )
295304}
296305
@@ -306,7 +315,7 @@ async fn authorize_post(
306315 state : form. state . clone ( ) ,
307316 scope : form. scope . clone ( ) ,
308317 } ;
309-
318+
310319 match handle_authorize ( & state, & params, & form) . await {
311320 Ok ( redirect_response) => redirect_response,
312321 Err ( error_response) => error_response,
@@ -337,10 +346,10 @@ async fn handle_authorize(
337346 ) ) ;
338347 }
339348
340- let client_id_url =
341- validate_client_id_url ( client_id_raw ) . map_err ( |e| bad_request ( & e ) ) ? ;
342- let metadata =
343- fetch_and_validate_client_metadata ( & client_id_url ) . await . map_err ( |e| bad_request ( & e) ) ?;
349+ let client_id_url = validate_client_id_url ( client_id_raw ) . map_err ( |e| bad_request ( & e ) ) ? ;
350+ let metadata = fetch_and_validate_client_metadata ( & client_id_url )
351+ . await
352+ . map_err ( |e| bad_request ( & e) ) ?;
344353 validate_redirect_uri ( redirect_uri, & metadata) . map_err ( |e| bad_request ( & e) ) ?;
345354
346355 // If this is a login POST, validate credentials
@@ -366,12 +375,11 @@ async fn handle_authorize(
366375 ) ;
367376 }
368377
369- let mut url =
370- Url :: parse ( redirect_uri ) . map_err ( |_| bad_request ( "invalid_request: redirect_uri is invalid" ) ) ?;
378+ let mut url = Url :: parse ( redirect_uri )
379+ . map_err ( |_| bad_request ( "invalid_request: redirect_uri is invalid" ) ) ?;
371380 url. query_pairs_mut ( ) . append_pair ( "code" , & code) ;
372381 if let Some ( state_param) = & params. state {
373- url. query_pairs_mut ( )
374- . append_pair ( "state" , state_param) ;
382+ url. query_pairs_mut ( ) . append_pair ( "state" , state_param) ;
375383 }
376384
377385 Ok ( Redirect :: to ( url. as_str ( ) ) . into_response ( ) )
@@ -396,10 +404,7 @@ struct TokenRequest {
396404 code : Option < String > ,
397405}
398406
399- async fn token (
400- State ( state) : State < AppState > ,
401- Form ( form) : Form < TokenRequest > ,
402- ) -> impl IntoResponse {
407+ async fn token ( State ( state) : State < AppState > , Form ( form) : Form < TokenRequest > ) -> impl IntoResponse {
403408 if form. grant_type . as_deref ( ) != Some ( "authorization_code" ) {
404409 let body = serde_json:: json!( {
405410 "error" : "unsupported_grant_type" ,
@@ -454,7 +459,9 @@ async fn token(
454459}
455460
456461async fn index ( ) -> Html < & ' static str > {
457- Html ( "<html><body><h1>CIMD OAuth + MCP Server</h1><p>This server supports Client ID Metadata Documents (SEP-991) and exposes an MCP endpoint at <code>/mcp</code>.</p></body></html>" )
462+ Html (
463+ "<html><body><h1>CIMD OAuth + MCP Server</h1><p>This server supports Client ID Metadata Documents (SEP-991) and exposes an MCP endpoint at <code>/mcp</code>.</p></body></html>" ,
464+ )
458465}
459466
460467#[ tokio:: main]
@@ -480,11 +487,19 @@ async fn main() -> Result<()> {
480487
481488 let addr = BIND_ADDRESS . parse :: < SocketAddr > ( ) ?;
482489
490+ let cors_layer = CorsLayer :: new ( )
491+ . allow_origin ( Any )
492+ . allow_methods ( Any )
493+ . allow_headers ( Any ) ;
494+
483495 let app = Router :: new ( )
484496 . route ( "/" , get ( index) )
485- . route ( "/.well-known/oauth-authorization-server" , get ( oauth_metadata) )
497+ . route (
498+ "/.well-known/oauth-authorization-server" ,
499+ get ( oauth_metadata) ,
500+ )
486501 . route ( "/authorize" , get ( authorize_get) . post ( authorize_post) )
487- . route ( "/token" , post ( token) )
502+ . route ( "/token" , post ( token) . layer ( cors_layer . clone ( ) ) )
488503 . nest_service ( "/mcp" , mcp_service)
489504 . with_state ( state) ;
490505
@@ -497,5 +512,3 @@ async fn main() -> Result<()> {
497512
498513 Ok ( ( ) )
499514}
500-
501-
0 commit comments