Skip to content

Commit b131bb1

Browse files
committed
fix(oauth): add CORS headers to token endpoint
Add CORS headers to token endpoint to allow cross-origin requests from browsers during OAuth authorization code exchange flow. Signed-off-by: tanish111 <[email protected]>
1 parent 45af2c2 commit b131bb1

File tree

3 files changed

+60
-42
lines changed

3 files changed

+60
-42
lines changed

crates/rmcp/src/transport/auth.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,10 @@ impl AuthorizationSession {
963963
) -> Result<Self, AuthError> {
964964
let metadata = auth_manager.metadata.as_ref();
965965
let supports_url_based_client_id = metadata
966-
.and_then(|m| m.additional_fields.get("client_id_metadata_document_supported"))
966+
.and_then(|m| {
967+
m.additional_fields
968+
.get("client_id_metadata_document_supported")
969+
})
967970
.and_then(|v| v.as_bool())
968971
.unwrap_or(false);
969972

@@ -1313,7 +1316,7 @@ impl OAuthState {
13131316
mod tests {
13141317
use url::Url;
13151318

1316-
use super::{is_https_url, AuthorizationManager};
1319+
use super::{AuthorizationManager, is_https_url};
13171320

13181321
// SEP-991: URL-based Client IDs
13191322
// Tests adapted from the TypeScript SDK's isHttpsUrl test suite

examples/clients/src/auth/oauth_client.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ async fn main() -> Result<()> {
8080

8181
let addr = SocketAddr::from(([127, 0, 0, 1], CALLBACK_PORT));
8282
tracing::info!("Starting callback server at: http://{}", addr);
83-
tracing::warn!("Note: Callback server may not receive callbacks if redirect URI doesn't match localhost if using CIMD (SEP-991)");
83+
tracing::warn!(
84+
"Note: Callback server may not receive callbacks if redirect URI doesn't match localhost if using CIMD (SEP-991)"
85+
);
8486

8587
// Start server in a separate task
8688
tokio::spawn(async move {

examples/servers/src/cimd_auth_streamhttp.rs

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use rmcp::transport::{
2121
use serde::{Deserialize, Serialize};
2222
use serde_json::Value;
2323
use tokio::sync::RwLock;
24+
use tower_http::cors::{Any, CorsLayer};
2425
use tracing::{error, info};
2526
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
2627
use url::Url;
@@ -71,7 +72,8 @@ fn generate_access_token() -> String {
7172
/// Validate that the client_id is a URL that meets CIMD mandatory requirements.
7273
/// Mirrors the JS validateClientIdUrl helper.
7374
fn validate_client_id_url(raw: &str) -> Result<String, String> {
74-
let url = Url::parse(raw).map_err(|_| "invalid_client_id: client_id must be a valid URL".to_string())?;
75+
let url = Url::parse(raw)
76+
.map_err(|_| "invalid_client_id: client_id must be a valid URL".to_string())?;
7577

7678
// MUST have https scheme
7779
if url.scheme() != "https" {
@@ -94,13 +96,16 @@ fn validate_client_id_url(raw: &str) -> Result<String, String> {
9496

9597
// MUST NOT contain a fragment component
9698
if url.fragment().is_some() {
97-
return Err("invalid_client_id: client_id URL MUST NOT contain a fragment component".to_string());
99+
return Err(
100+
"invalid_client_id: client_id URL MUST NOT contain a fragment component".to_string(),
101+
);
98102
}
99103

100104
// MUST NOT contain a username or password
101105
if !url.username().is_empty() || url.password().is_some() {
102106
return Err(
103-
"invalid_client_id: client_id URL MUST NOT contain a username or password component".to_string(),
107+
"invalid_client_id: client_id URL MUST NOT contain a username or password component"
108+
.to_string(),
104109
);
105110
}
106111

@@ -135,9 +140,9 @@ async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result<Value
135140
}
136141

137142
// MUST contain a client_id property equal to the URL of the document
138-
let client_id_value = json
139-
.get("client_id")
140-
.ok_or_else(|| "invalid_client: client metadata document MUST contain client_id".to_string())?;
143+
let client_id_value = json.get("client_id").ok_or_else(|| {
144+
"invalid_client: client metadata document MUST contain client_id".to_string()
145+
})?;
141146
if client_id_value != client_id_url {
142147
return Err(
143148
"invalid_client: client_id property in metadata document MUST match the document URL"
@@ -148,22 +153,27 @@ async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result<Value
148153
// token_endpoint_auth_method MUST NOT be any shared secret based method
149154
if let Some(method) = json.get("token_endpoint_auth_method") {
150155
if let Some(method_str) = method.as_str() {
151-
let forbidden = ["client_secret_post", "client_secret_basic", "client_secret_jwt"];
152-
if forbidden.contains(&method_str)
153-
|| method_str.starts_with("client_secret_")
154-
{
156+
let forbidden = [
157+
"client_secret_post",
158+
"client_secret_basic",
159+
"client_secret_jwt",
160+
];
161+
if forbidden.contains(&method_str) || method_str.starts_with("client_secret_") {
155162
return Err("invalid_client: token_endpoint_auth_method MUST NOT be a shared secret based method".to_string());
156163
}
157164
}
158165
}
159166

160167
// client_secret and client_secret_expires_at MUST NOT be used
161168
if json.get("client_secret").is_some() {
162-
return Err("invalid_client: client_secret MUST NOT be present in client metadata".to_string());
169+
return Err(
170+
"invalid_client: client_secret MUST NOT be present in client metadata".to_string(),
171+
);
163172
}
164173
if json.get("client_secret_expires_at").is_some() {
165174
return Err(
166-
"invalid_client: client_secret_expires_at MUST NOT be present in client metadata".to_string(),
175+
"invalid_client: client_secret_expires_at MUST NOT be present in client metadata"
176+
.to_string(),
167177
);
168178
}
169179

@@ -172,9 +182,9 @@ async fn fetch_and_validate_client_metadata(client_id_url: &str) -> Result<Value
172182

173183
/// Validate redirect_uri against metadata.redirect_uris (exact match).
174184
fn validate_redirect_uri(requested_redirect_uri: &str, metadata: &Value) -> Result<(), String> {
175-
let redirect_uris = metadata
176-
.get("redirect_uris")
177-
.ok_or_else(|| "invalid_client: client metadata must include redirect_uris array".to_string())?;
185+
let redirect_uris = metadata.get("redirect_uris").ok_or_else(|| {
186+
"invalid_client: client metadata must include redirect_uris array".to_string()
187+
})?;
178188

179189
let arr = redirect_uris
180190
.as_array()
@@ -185,7 +195,8 @@ fn validate_redirect_uri(requested_redirect_uri: &str, metadata: &Value) -> Resu
185195

186196
if !found {
187197
return Err(
188-
"invalid_request: redirect_uri MUST exactly match one of the registered redirect_uris".to_string(),
198+
"invalid_request: redirect_uri MUST exactly match one of the registered redirect_uris"
199+
.to_string(),
189200
);
190201
}
191202

@@ -194,8 +205,8 @@ fn validate_redirect_uri(requested_redirect_uri: &str, metadata: &Value) -> Resu
194205

195206
/// Minimal Authorization Server Metadata with CIMD support.
196207
async fn oauth_metadata() -> impl IntoResponse {
197-
let issuer = std::env::var("CIMD_ISSUER")
198-
.unwrap_or_else(|_| format!("http://{}", BIND_ADDRESS));
208+
let issuer =
209+
std::env::var("CIMD_ISSUER").unwrap_or_else(|_| format!("http://{}", BIND_ADDRESS));
199210

200211
let body = serde_json::json!({
201212
"issuer": issuer,
@@ -288,9 +299,7 @@ fn render_login_form(params: &AuthorizeQuery, error: Option<&str>) -> Html<Strin
288299
Html(html)
289300
}
290301

291-
async fn authorize_get(
292-
Query(params): Query<AuthorizeQuery>,
293-
) -> impl IntoResponse {
302+
async fn authorize_get(Query(params): Query<AuthorizeQuery>) -> impl IntoResponse {
294303
render_login_form(&params, None)
295304
}
296305

@@ -306,7 +315,7 @@ async fn authorize_post(
306315
state: form.state.clone(),
307316
scope: form.scope.clone(),
308317
};
309-
318+
310319
match handle_authorize(&state, &params, &form).await {
311320
Ok(redirect_response) => redirect_response,
312321
Err(error_response) => error_response,
@@ -337,10 +346,10 @@ async fn handle_authorize(
337346
));
338347
}
339348

340-
let client_id_url =
341-
validate_client_id_url(client_id_raw).map_err(|e| bad_request(&e))?;
342-
let metadata =
343-
fetch_and_validate_client_metadata(&client_id_url).await.map_err(|e| bad_request(&e))?;
349+
let client_id_url = validate_client_id_url(client_id_raw).map_err(|e| bad_request(&e))?;
350+
let metadata = fetch_and_validate_client_metadata(&client_id_url)
351+
.await
352+
.map_err(|e| bad_request(&e))?;
344353
validate_redirect_uri(redirect_uri, &metadata).map_err(|e| bad_request(&e))?;
345354

346355
// If this is a login POST, validate credentials
@@ -366,12 +375,11 @@ async fn handle_authorize(
366375
);
367376
}
368377

369-
let mut url =
370-
Url::parse(redirect_uri).map_err(|_| bad_request("invalid_request: redirect_uri is invalid"))?;
378+
let mut url = Url::parse(redirect_uri)
379+
.map_err(|_| bad_request("invalid_request: redirect_uri is invalid"))?;
371380
url.query_pairs_mut().append_pair("code", &code);
372381
if let Some(state_param) = &params.state {
373-
url.query_pairs_mut()
374-
.append_pair("state", state_param);
382+
url.query_pairs_mut().append_pair("state", state_param);
375383
}
376384

377385
Ok(Redirect::to(url.as_str()).into_response())
@@ -396,10 +404,7 @@ struct TokenRequest {
396404
code: Option<String>,
397405
}
398406

399-
async fn token(
400-
State(state): State<AppState>,
401-
Form(form): Form<TokenRequest>,
402-
) -> impl IntoResponse {
407+
async fn token(State(state): State<AppState>, Form(form): Form<TokenRequest>) -> impl IntoResponse {
403408
if form.grant_type.as_deref() != Some("authorization_code") {
404409
let body = serde_json::json!({
405410
"error": "unsupported_grant_type",
@@ -454,7 +459,9 @@ async fn token(
454459
}
455460

456461
async fn index() -> Html<&'static str> {
457-
Html("<html><body><h1>CIMD OAuth + MCP Server</h1><p>This server supports Client ID Metadata Documents (SEP-991) and exposes an MCP endpoint at <code>/mcp</code>.</p></body></html>")
462+
Html(
463+
"<html><body><h1>CIMD OAuth + MCP Server</h1><p>This server supports Client ID Metadata Documents (SEP-991) and exposes an MCP endpoint at <code>/mcp</code>.</p></body></html>",
464+
)
458465
}
459466

460467
#[tokio::main]
@@ -480,11 +487,19 @@ async fn main() -> Result<()> {
480487

481488
let addr = BIND_ADDRESS.parse::<SocketAddr>()?;
482489

490+
let cors_layer = CorsLayer::new()
491+
.allow_origin(Any)
492+
.allow_methods(Any)
493+
.allow_headers(Any);
494+
483495
let app = Router::new()
484496
.route("/", get(index))
485-
.route("/.well-known/oauth-authorization-server", get(oauth_metadata))
497+
.route(
498+
"/.well-known/oauth-authorization-server",
499+
get(oauth_metadata),
500+
)
486501
.route("/authorize", get(authorize_get).post(authorize_post))
487-
.route("/token", post(token))
502+
.route("/token", post(token).layer(cors_layer.clone()))
488503
.nest_service("/mcp", mcp_service)
489504
.with_state(state);
490505

@@ -497,5 +512,3 @@ async fn main() -> Result<()> {
497512

498513
Ok(())
499514
}
500-
501-

0 commit comments

Comments
 (0)