Skip to content

Commit 73d4574

Browse files
committed
fix(oauth): Three oauth discovery and registration issues
1 parent e65ba53 commit 73d4574

File tree

1 file changed

+90
-51
lines changed

1 file changed

+90
-51
lines changed

crates/rmcp/src/transport/auth.rs

Lines changed: 90 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,11 @@ impl AuthorizationManager {
243243

244244
/// discover oauth2 metadata
245245
pub async fn discover_metadata(&self) -> Result<AuthorizationMetadata, AuthError> {
246-
if let Some(metadata) = self.try_discover_from_base(&self.base_url).await? {
246+
if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? {
247247
return Ok(metadata);
248248
}
249249

250-
if let Some(metadata) = self.discover_via_resource_metadata().await? {
250+
if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? {
251251
return Ok(metadata);
252252
}
253253

@@ -499,7 +499,9 @@ impl AuthorizationManager {
499499
}
500500
}
501501
}
502-
Err(e) => Err(AuthError::TokenExchangeFailed(e.to_string())),
502+
Err(e) => {
503+
return Err(AuthError::TokenExchangeFailed(e.to_string()));
504+
}
503505
};
504506

505507
// get expires_in from token response
@@ -596,10 +598,8 @@ impl AuthorizationManager {
596598
Ok(response)
597599
}
598600
}
599-
}
600601

601-
impl AuthorizationManager {
602-
async fn try_discover_from_base(
602+
async fn try_discover_oauth_server(
603603
&self,
604604
base_url: &Url,
605605
) -> Result<Option<AuthorizationMetadata>, AuthError> {
@@ -648,15 +648,15 @@ impl AuthorizationManager {
648648
Ok(Some(metadata))
649649
}
650650

651-
async fn discover_via_resource_metadata(
651+
async fn discover_oauth_server_via_resource_metadata(
652652
&self,
653653
) -> Result<Option<AuthorizationMetadata>, AuthError> {
654654
let Some(resource_metadata_url) = self.fetch_resource_metadata_url().await? else {
655655
return Ok(None);
656656
};
657657

658658
let Some(resource_metadata) = self
659-
.fetch_resource_metadata_document(&resource_metadata_url)
659+
.fetch_resource_metadata_from_url(&resource_metadata_url)
660660
.await?
661661
else {
662662
return Ok(None);
@@ -695,14 +695,16 @@ impl AuthorizationManager {
695695
continue;
696696
}
697697

698-
if let Some(metadata) = self.try_discover_from_base(&candidate_url).await? {
698+
if let Some(metadata) = self.try_discover_oauth_server(&candidate_url).await? {
699699
return Ok(Some(metadata));
700700
}
701701
}
702702

703703
Ok(None)
704704
}
705705

706+
/// Extract the resource metadata url from the WWW-Authenticate header value.
707+
/// https://www.rfc-editor.org/rfc/rfc9728.html#name-use-of-www-authenticate-for
706708
async fn fetch_resource_metadata_url(&self) -> Result<Option<Url>, AuthError> {
707709
let response = match self
708710
.http_client
@@ -731,7 +733,9 @@ impl AuthorizationManager {
731733
let Ok(value_str) = value.to_str() else {
732734
continue;
733735
};
734-
if let Some(url) = Self::extract_resource_metadata_url(value_str, &self.base_url) {
736+
if let Some(url) =
737+
Self::extract_resource_metadata_url_from_header(value_str, &self.base_url)
738+
{
735739
parsed_url = Some(url);
736740
break;
737741
}
@@ -740,7 +744,7 @@ impl AuthorizationManager {
740744
Ok(parsed_url)
741745
}
742746

743-
async fn fetch_resource_metadata_document(
747+
async fn fetch_resource_metadata_from_url(
744748
&self,
745749
resource_metadata_url: &Url,
746750
) -> Result<Option<ResourceServerMetadata>, AuthError> {
@@ -779,15 +783,16 @@ impl AuthorizationManager {
779783
Ok(Some(metadata))
780784
}
781785

782-
fn extract_resource_metadata_url(header: &str, base_url: &Url) -> Option<Url> {
783-
let lower_header = header.to_ascii_lowercase();
784-
let needle = "resource_metadata=";
786+
/// Extracts a url following `resource_metadata=` in a header value
787+
fn extract_resource_metadata_url_from_header(header: &str, base_url: &Url) -> Option<Url> {
788+
let header_lowercase = header.to_ascii_lowercase();
789+
let fragment_key = "resource_metadata=";
785790
let mut search_offset = 0;
786791

787-
while let Some(pos) = lower_header[search_offset..].find(needle) {
788-
let global_pos = search_offset + pos + needle.len();
792+
while let Some(pos) = header_lowercase[search_offset..].find(fragment_key) {
793+
let global_pos = search_offset + pos + fragment_key.len();
789794
let value_slice = &header[global_pos..];
790-
if let Some((value, consumed)) = Self::parse_auth_param_value(value_slice) {
795+
if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) {
791796
if let Ok(url) = Url::parse(&value) {
792797
return Some(url);
793798
}
@@ -805,14 +810,23 @@ impl AuthorizationManager {
805810
None
806811
}
807812

808-
fn parse_auth_param_value(value: &str) -> Option<(String, usize)> {
809-
let trimmed = value.trim_start();
810-
let leading_ws = value.len() - trimmed.len();
811-
812-
if trimmed.starts_with('"') {
813+
/// Parses an authentication parameter value from a `WWW-Authenticate` header fragment.
814+
/// The header fragment should start with the header value after the `=` character and then
815+
/// reads until the value ends.
816+
///
817+
/// Returns the extracted value together with the number of bytes consumed from the provided
818+
/// fragment. Quoted values support escaped characters (e.g. `\"`). The parser skips leading
819+
/// whitespace before reading either a quoted or token value. If no well-formed value is found,
820+
/// `None` is returned.
821+
fn parse_next_header_value(header_fragment: &str) -> Option<(String, usize)> {
822+
let trimmed = header_fragment.trim_start();
823+
let leading_ws = header_fragment.len() - trimmed.len();
824+
825+
if let Some(stripped) = trimmed.strip_prefix('"') {
813826
let mut escaped = false;
814827
let mut result = String::new();
815-
for (idx, ch) in trimmed[1..].char_indices() {
828+
#[allow(clippy::manual_strip)]
829+
for (idx, ch) in stripped.char_indices() {
816830
if escaped {
817831
result.push(ch);
818832
escaped = false;
@@ -834,34 +848,6 @@ impl AuthorizationManager {
834848
}
835849
}
836850

837-
#[cfg(test)]
838-
mod tests {
839-
use super::AuthorizationManager;
840-
use url::Url;
841-
842-
#[test]
843-
fn parses_resource_metadata_parameter() {
844-
let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#;
845-
let base = Url::parse("https://example.com/api").unwrap();
846-
let parsed = AuthorizationManager::extract_resource_metadata_url(header, &base);
847-
assert_eq!(
848-
parsed.unwrap().as_str(),
849-
"https://example.com/.well-known/oauth-protected-resource/api"
850-
);
851-
}
852-
853-
#[test]
854-
fn parses_relative_resource_metadata_parameter() {
855-
let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#;
856-
let base = Url::parse("https://example.com/api").unwrap();
857-
let parsed = AuthorizationManager::extract_resource_metadata_url(header, &base);
858-
assert_eq!(
859-
parsed.unwrap().as_str(),
860-
"https://example.com/.well-known/oauth-protected-resource/api"
861-
);
862-
}
863-
}
864-
865851
/// oauth2 authorization session, for guiding user to complete the authorization process
866852
pub struct AuthorizationSession {
867853
pub auth_manager: AuthorizationManager,
@@ -1180,6 +1166,59 @@ impl OAuthState {
11801166
#[cfg(test)]
11811167
mod tests {
11821168
use super::AuthorizationManager;
1169+
use url::Url;
1170+
1171+
#[test]
1172+
fn parses_resource_metadata_parameter() {
1173+
let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#;
1174+
let base = Url::parse("https://example.com/api").unwrap();
1175+
let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base);
1176+
assert_eq!(
1177+
parsed.unwrap().as_str(),
1178+
"https://example.com/.well-known/oauth-protected-resource/api"
1179+
);
1180+
}
1181+
1182+
#[test]
1183+
fn parses_relative_resource_metadata_parameter() {
1184+
let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#;
1185+
let base = Url::parse("https://example.com/api").unwrap();
1186+
let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base);
1187+
assert_eq!(
1188+
parsed.unwrap().as_str(),
1189+
"https://example.com/.well-known/oauth-protected-resource/api"
1190+
);
1191+
}
1192+
1193+
#[test]
1194+
fn parse_auth_param_value_handles_quoted_string() {
1195+
let fragment = r#""example", realm="foo""#;
1196+
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
1197+
assert_eq!(parsed.0, "example");
1198+
assert_eq!(parsed.1, 9);
1199+
}
1200+
1201+
#[test]
1202+
fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() {
1203+
let fragment = r#" "a\"b\\c" ,next=value"#;
1204+
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
1205+
assert_eq!(parsed.0, r#"a"b\c"#);
1206+
assert_eq!(parsed.1, 12);
1207+
}
1208+
1209+
#[test]
1210+
fn parse_auth_param_value_handles_token_values() {
1211+
let fragment = " token,next";
1212+
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
1213+
assert_eq!(parsed.0, "token");
1214+
assert_eq!(parsed.1, 7);
1215+
}
1216+
1217+
#[test]
1218+
fn parse_auth_param_value_returns_none_for_unterminated_quotes() {
1219+
let fragment = r#""unterminated,value"#;
1220+
assert!(AuthorizationManager::parse_next_header_value(fragment).is_none());
1221+
}
11831222

11841223
#[test]
11851224
fn well_known_paths_root() {

0 commit comments

Comments
 (0)