From 93f0d91d0b7b332f79816f3574e4b797e7abd17f Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sat, 4 Oct 2025 14:10:22 +0200 Subject: [PATCH 1/3] Tokens can be cached beyond the lifetime of the (http) transport. --- .../Authentication/ClientOAuthOptions.cs | 6 +++++ .../Authentication/ClientOAuthProvider.cs | 19 +++++++------ .../Authentication/ITokenCache.cs | 17 ++++++++++++ .../Authentication/InMemoryTokenCache.cs | 27 +++++++++++++++++++ .../Authentication/TokenContainer.cs | 4 +-- 5 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Authentication/ITokenCache.cs create mode 100644 src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index cc6a8952e..ecb57df0a 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions /// /// public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); + + /// + /// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport. + /// If none is provided, tokens will be cached with the transport. + /// + public ITokenCache? TokenCache { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 468728982..e59fc22e8 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -43,7 +43,7 @@ internal sealed partial class ClientOAuthProvider private string? _clientId; private string? _clientSecret; - private TokenContainer? _token; + private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; /// @@ -85,6 +85,7 @@ public ClientOAuthProvider( _dcrClientUri = options.DynamicClientRegistration?.ClientUri; _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; + _tokenCache = options.TokenCache ?? new InMemoryTokenCache(); } /// @@ -138,20 +139,22 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); + var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + // Return the token if it's valid - if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) { - return _token.AccessToken; + return token.AccessToken; } // Try to refresh the token if we have a refresh token - if (_token?.RefreshToken != null && _authServerMetadata != null) + if (token?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - _token = newToken; - return _token.AccessToken; + await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false); + return newToken.AccessToken; } } @@ -237,7 +240,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - _token = token; + await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs new file mode 100644 index 000000000..3619286b3 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Allows the client to cache access tokens beyond the lifetime of the transport. +/// +public interface ITokenCache +{ + /// + /// Cache the token. + /// + Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken); + + /// + /// Get the cached token. + /// + Task GetTokenAsync(CancellationToken cancellationToken); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs new file mode 100644 index 000000000..529d56269 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -0,0 +1,27 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Caches the token in-memory within this instance. +/// +internal class InMemoryTokenCache : ITokenCache +{ + private TokenContainer? _token; + + /// + /// Cache the token. + /// + public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken) + { + _token = token; + return Task.CompletedTask; + } + + /// + /// Get the cached token. + /// + public Task GetTokenAsync(CancellationToken cancellationToken) + { + return Task.FromResult(_token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index dc55292b9..7ffe05372 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a token response from the OAuth server. /// -internal sealed class TokenContainer +public sealed class TokenContainer { /// /// Gets or sets the access token. @@ -46,7 +46,7 @@ internal sealed class TokenContainer /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonIgnore] + [JsonPropertyName("obtained_at")] public DateTimeOffset ObtainedAt { get; set; } /// From 26f80ef6dbab2215f095a3c15cb72d417b727d0c Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sat, 11 Oct 2025 18:45:31 +0200 Subject: [PATCH 2/3] Tests, ValueTasks, and dedicated type for caching. --- .../Authentication/ClientOAuthProvider.cs | 7 +- .../Authentication/ITokenCache.cs | 8 +- .../Authentication/InMemoryTokenCache.cs | 10 +- .../Authentication/TokenContainer.cs | 4 +- .../Authentication/TokenContainerCacheable.cs | 42 ++++ .../Authentication/TokenContainerConvert.cs | 26 ++ .../Client/CustomTokenCacheTests.cs | 233 ++++++++++++++++++ 7 files changed, 316 insertions(+), 14 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs create mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs create mode 100644 tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index e59fc22e8..bb411eae8 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -139,7 +139,8 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); - var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var token = cachedToken?.ForUse(); // Return the token if it's valid if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) @@ -153,7 +154,7 @@ public ClientOAuthProvider( var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false); return newToken.AccessToken; } } @@ -240,7 +241,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs index 3619286b3..46d4cc37b 100644 --- a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -6,12 +6,12 @@ namespace ModelContextProtocol.Authentication; public interface ITokenCache { /// - /// Cache the token. + /// Cache the token. After a new access token is acquired, this method is invoked to store it. /// - Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken); + ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken); /// - /// Get the cached token. + /// Get the cached token. This method is invoked for every request. /// - Task GetTokenAsync(CancellationToken cancellationToken); + ValueTask GetTokenAsync(CancellationToken cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs index 529d56269..56346f731 100644 --- a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -6,22 +6,22 @@ namespace ModelContextProtocol.Authentication; /// internal class InMemoryTokenCache : ITokenCache { - private TokenContainer? _token; + private TokenContainerCacheable? _token; /// /// Cache the token. /// - public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken) + public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken) { _token = token; - return Task.CompletedTask; + return default; } /// /// Get the cached token. /// - public Task GetTokenAsync(CancellationToken cancellationToken) + public ValueTask GetTokenAsync(CancellationToken cancellationToken) { - return Task.FromResult(_token); + return new ValueTask(_token); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index 7ffe05372..dc55292b9 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a token response from the OAuth server. /// -public sealed class TokenContainer +internal sealed class TokenContainer { /// /// Gets or sets the access token. @@ -46,7 +46,7 @@ public sealed class TokenContainer /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonPropertyName("obtained_at")] + [JsonIgnore] public DateTimeOffset ObtainedAt { get; set; } /// diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs new file mode 100644 index 000000000..5f6bf0e5c --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs @@ -0,0 +1,42 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a cacheable token representation. +/// +public class TokenContainerCacheable +{ + /// + /// Gets or sets the access token. + /// + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + public DateTimeOffset ObtainedAt { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs new file mode 100644 index 000000000..6e2c8e9cd --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs @@ -0,0 +1,26 @@ +namespace ModelContextProtocol.Authentication; + +internal static class TokenContainerConvert +{ + internal static TokenContainer ForUse(this TokenContainerCacheable token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; + + internal static TokenContainerCacheable ForCache(this TokenContainer token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; +} diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs new file mode 100644 index 000000000..3ea1262ae --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs @@ -0,0 +1,233 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Authentication; +using System.Text.Json; +using Moq; +using Moq.Protected; +using System.Net; +using System.Text.Json.Nodes; +using System.Linq.Expressions; + +namespace ModelContextProtocol.Tests.Client; + +public class CustomTokenCacheTests +{ + [Fact] + public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() + { + // Arrange + var cachedAccessToken = $"my_access_token_{Guid.NewGuid()}"; + + var tokenCacheMock = new Mock(); + MockCachedAccessToken(tokenCacheMock, cachedAccessToken); + + var httpMessageHandlerMock = new Mock(); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.AtLeastOnce(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && req.Headers.Authorization.Parameter == cachedAccessToken + ), ItExpr.IsAny()); + + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.Never(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && (req.Headers.Authorization == null || req.Headers.Authorization.Parameter != cachedAccessToken) + ), ItExpr.IsAny()); + } + + [Fact] + public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() + { + // Arrange + var tokenCacheMock = new Mock(); + MockNoAccessTokenUntilStored(tokenCacheMock); + + var newAccessToken = $"new_access_token_{Guid.NewGuid()}"; + + var httpMessageHandlerMock = new Mock(); + MockUnauthorizedResponse(httpMessageHandlerMock); + MockProtectedResourceMetadataResponse(httpMessageHandlerMock); + MockAuthorizationServerMetadataResponse(httpMessageHandlerMock); + MockAccessTokenResponse(httpMessageHandlerMock, newAccessToken); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + tokenCacheMock + .Verify(tc => tc.StoreTokenAsync( + It.Is(token => token.AccessToken == newAccessToken), + It.IsAny()), Times.Once); + } + + static HttpClientTransportOptions NewHttpClientTransportOptions(ITokenCache? tokenCache = null) => new() + { + Name = "MCP Server", + Endpoint = new Uri("http://localhost:1337/"), + TransportMode = HttpTransportMode.StreamableHttp, + OAuth = new() + { + ClientId = "mcp_inspector", + RedirectUri = new Uri("http://localhost:6274/oauth/callback"), + Scopes = ["openid", "profile", "offline_access"], + AuthorizationRedirectDelegate = (authorizationUrl, redirectUri, cancellationToken) => Task.FromResult($"auth_code_{Guid.NewGuid()}"), + TokenCache = tokenCache, + }, + }; + + static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) + { + tokenCache + .Setup(tc => tc.GetTokenAsync(It.IsAny())) + .ReturnsAsync(new TokenContainerCacheable + { + AccessToken = cachedAccessToken, + ObtainedAt = DateTimeOffset.UtcNow, + ExpiresIn = (int)TimeSpan.FromHours(1).TotalSeconds, + }); + } + + static void MockNoAccessTokenUntilStored(Mock tokenCache) + { + tokenCache + .Setup(tc => tc.StoreTokenAsync(It.IsAny(), It.IsAny())) + .Callback((token, ct) => + { + // Simulate that the token is now cached + MockCachedAccessToken(tokenCache, token.AccessToken); + }) + .Returns(default(ValueTask)); + } + + static void MockUnauthorizedResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && (req.Headers.Authorization == null || string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter)), + response: new HttpResponseMessage(HttpStatusCode.Unauthorized) + { + Headers = { + { "WWW-Authenticate", "Bearer realm=\"Bearer\", resource_metadata=\"http://localhost:1337/.well-known/oauth-protected-resource\"" } + }, + }); + } + + static void MockProtectedResourceMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/.well-known/oauth-protected-resource"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + resource = "http://localhost:1337/", + authorization_servers = new[] { "http://localhost:1336/" }, + }) + }); + } + + static void MockAuthorizationServerMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/.well-known/openid-configuration"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + authorization_endpoint = "http://localhost:1336/connect/authorize", + token_endpoint = "http://localhost:1336/connect/token", + }) + }); + } + + static void MockAccessTokenResponse(Mock httpMessageHandler, string accessToken) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/connect/token"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + access_token = accessToken, + }) + }); + } + + static void MockInitializeResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && !string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new JsonRpcResponse + { + Id = new RequestId(1), + Result = ToJson(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new ServerCapabilities + { + Prompts = new PromptsCapability { ListChanged = true }, + Resources = new ResourcesCapability { Subscribe = true, ListChanged = true }, + Tools = new ToolsCapability { ListChanged = true }, + Logging = new LoggingCapability(), + Completions = new CompletionsCapability(), + }, + ServerInfo = new Implementation + { + Name = "mcp-test-server", + Version = "1.0.0" + }, + Instructions = "This server provides weather information and file system access." + }) + }), + }); + } + + static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) + { + httpMessageHandler + .Protected() + .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync(response ?? new HttpResponseMessage()); + } + + static StringContent ToJsonContent(T content) => new( + content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions), + encoding: System.Text.Encoding.UTF8, + mediaType: "application/json"); + + static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( + value: content, + options: McpJsonUtilities.DefaultOptions); +} From bd4f0ff95f916f67a4072388c7f63674dd3b3e02 Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sun, 26 Oct 2025 23:23:30 +0100 Subject: [PATCH 3/3] Type rename; alignment test --- .../Authentication/ClientOAuthProvider.cs | 39 +++++++++++-------- .../Authentication/ITokenCache.cs | 4 +- .../Authentication/InMemoryTokenCache.cs | 10 ++--- .../Authentication/TokenContainer.cs | 14 +------ .../Authentication/TokenContainerConvert.cs | 26 ------------- ...ContainerCacheable.cs => TokenResponse.cs} | 17 ++++---- .../McpJsonUtilities.cs | 2 +- .../Client/CustomTokenCacheTests.cs | 34 +++++++++++----- 8 files changed, 68 insertions(+), 78 deletions(-) delete mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs rename src/ModelContextProtocol.Core/Authentication/{TokenContainerCacheable.cs => TokenResponse.cs} (72%) diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index bb411eae8..503c6e402 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -139,23 +139,22 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); - var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); - var token = cachedToken?.ForUse(); - + var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false); + // Return the token if it's valid - if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (tokens != null && tokens.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) { - return token.AccessToken; + return tokens.AccessToken; } // Try to refresh the token if we have a refresh token - if (token?.RefreshToken != null && _authServerMetadata != null) + if (tokens?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); - if (newToken != null) + var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + if (newTokens != null) { - await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false); - return newToken.AccessToken; + await _tokenCache.StoreTokensAsync(newTokens, cancellationToken).ConfigureAwait(false); + return newTokens.AccessToken; } } @@ -234,14 +233,14 @@ private async Task PerformOAuthAuthorizationAsync( } // Perform the OAuth flow - var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + var tokens = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); - if (token is null) + if (tokens is null) { ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } @@ -413,15 +412,23 @@ private async Task FetchTokenAsync(HttpRequestMessage request, C httpResponse.EnsureSuccessStatusCode(); using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false); + var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false); if (tokenResponse is null) { ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); } - tokenResponse.ObtainedAt = DateTimeOffset.UtcNow; - return tokenResponse; + return new() + { + AccessToken = tokenResponse.AccessToken, + RefreshToken = tokenResponse.RefreshToken, + ExpiresIn = tokenResponse.ExpiresIn, + ExtExpiresIn = tokenResponse.ExtExpiresIn, + TokenType = tokenResponse.TokenType, + Scope = tokenResponse.Scope, + ObtainedAt = DateTimeOffset.UtcNow, + }; } /// diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs index 46d4cc37b..3dc6e6351 100644 --- a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -8,10 +8,10 @@ public interface ITokenCache /// /// Cache the token. After a new access token is acquired, this method is invoked to store it. /// - ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken); + ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken); /// /// Get the cached token. This method is invoked for every request. /// - ValueTask GetTokenAsync(CancellationToken cancellationToken); + ValueTask GetTokensAsync(CancellationToken cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs index 56346f731..977cb6f88 100644 --- a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -6,22 +6,22 @@ namespace ModelContextProtocol.Authentication; /// internal class InMemoryTokenCache : ITokenCache { - private TokenContainerCacheable? _token; + private TokenContainer? _tokens; /// /// Cache the token. /// - public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken) + public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken) { - _token = token; + _tokens = tokens; return default; } /// /// Get the cached token. /// - public ValueTask GetTokenAsync(CancellationToken cancellationToken) + public ValueTask GetTokensAsync(CancellationToken cancellationToken) { - return new ValueTask(_token); + return new ValueTask(_tokens); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index dc55292b9..5503c96f1 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -1,57 +1,47 @@ -using System.Text.Json.Serialization; - namespace ModelContextProtocol.Authentication; /// -/// Represents a token response from the OAuth server. +/// Represents a cacheable combination of tokens ready to be used for authentication. /// -internal sealed class TokenContainer +public class TokenContainer { /// /// Gets or sets the access token. /// - [JsonPropertyName("access_token")] public string AccessToken { get; set; } = string.Empty; /// /// Gets or sets the refresh token. /// - [JsonPropertyName("refresh_token")] public string? RefreshToken { get; set; } /// /// Gets or sets the number of seconds until the access token expires. /// - [JsonPropertyName("expires_in")] public int ExpiresIn { get; set; } /// /// Gets or sets the extended expiration time in seconds. /// - [JsonPropertyName("ext_expires_in")] public int ExtExpiresIn { get; set; } /// /// Gets or sets the token type (typically "Bearer"). /// - [JsonPropertyName("token_type")] public string TokenType { get; set; } = string.Empty; /// /// Gets or sets the scope of the access token. /// - [JsonPropertyName("scope")] public string Scope { get; set; } = string.Empty; /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonIgnore] public DateTimeOffset ObtainedAt { get; set; } /// /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. /// - [JsonIgnore] public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); } diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs deleted file mode 100644 index 6e2c8e9cd..000000000 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs +++ /dev/null @@ -1,26 +0,0 @@ -namespace ModelContextProtocol.Authentication; - -internal static class TokenContainerConvert -{ - internal static TokenContainer ForUse(this TokenContainerCacheable token) => new() - { - AccessToken = token.AccessToken, - RefreshToken = token.RefreshToken, - ExpiresIn = token.ExpiresIn, - ExtExpiresIn = token.ExtExpiresIn, - TokenType = token.TokenType, - Scope = token.Scope, - ObtainedAt = token.ObtainedAt, - }; - - internal static TokenContainerCacheable ForCache(this TokenContainer token) => new() - { - AccessToken = token.AccessToken, - RefreshToken = token.RefreshToken, - ExpiresIn = token.ExpiresIn, - ExtExpiresIn = token.ExtExpiresIn, - TokenType = token.TokenType, - Scope = token.Scope, - ObtainedAt = token.ObtainedAt, - }; -} diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs similarity index 72% rename from src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs rename to src/ModelContextProtocol.Core/Authentication/TokenResponse.cs index 5f6bf0e5c..9eba5ffbf 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs @@ -1,42 +1,45 @@ +using System.Text.Json.Serialization; + namespace ModelContextProtocol.Authentication; /// -/// Represents a cacheable token representation. +/// Represents a token response from the OAuth server. /// -public class TokenContainerCacheable +internal sealed class TokenResponse { /// /// Gets or sets the access token. /// + [JsonPropertyName("access_token")] public string AccessToken { get; set; } = string.Empty; /// /// Gets or sets the refresh token. /// + [JsonPropertyName("refresh_token")] public string? RefreshToken { get; set; } /// /// Gets or sets the number of seconds until the access token expires. /// + [JsonPropertyName("expires_in")] public int ExpiresIn { get; set; } /// /// Gets or sets the extended expiration time in seconds. /// + [JsonPropertyName("ext_expires_in")] public int ExtExpiresIn { get; set; } /// /// Gets or sets the token type (typically "Bearer"). /// + [JsonPropertyName("token_type")] public string TokenType { get; set; } = string.Empty; /// /// Gets or sets the scope of the access token. /// + [JsonPropertyName("scope")] public string Scope { get; set; } = string.Empty; - - /// - /// Gets or sets the timestamp when the token was obtained. - /// - public DateTimeOffset ObtainedAt { get; set; } } diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 8bc9e21b0..a6cb2e13e 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -158,7 +158,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(ProtectedResourceMetadata))] [JsonSerializable(typeof(AuthorizationServerMetadata))] - [JsonSerializable(typeof(TokenContainer))] + [JsonSerializable(typeof(TokenResponse))] [JsonSerializable(typeof(DynamicClientRegistrationRequest))] [JsonSerializable(typeof(DynamicClientRegistrationResponse))] diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs index 3ea1262ae..fd16d3073 100644 --- a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs @@ -2,6 +2,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Authentication; using System.Text.Json; +using System.Text.Json.Serialization.Metadata; using Moq; using Moq.Protected; using System.Net; @@ -12,6 +13,16 @@ namespace ModelContextProtocol.Tests.Client; public class CustomTokenCacheTests { + [Fact] + public void TokenContainerIsAlignedWithTokenResponse() + { + var tokenResponseType = Type.GetType("ModelContextProtocol.Authentication.TokenResponse, ModelContextProtocol.Core"); + Assert.NotNull(tokenResponseType); + var tokenResponseProperties = tokenResponseType.GetProperties().Select(p => p.Name); + var tokenContainerProperties = typeof(TokenContainer).GetProperties().Select(p => p.Name); + Assert.Equivalent(tokenResponseProperties, tokenContainerProperties); + } + [Fact] public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() { @@ -80,8 +91,8 @@ public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() // Assert tokenCacheMock - .Verify(tc => tc.StoreTokenAsync( - It.Is(token => token.AccessToken == newAccessToken), + .Verify(tc => tc.StoreTokensAsync( + It.Is(token => token.AccessToken == newAccessToken), It.IsAny()), Times.Once); } @@ -103,8 +114,8 @@ public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) { tokenCache - .Setup(tc => tc.GetTokenAsync(It.IsAny())) - .ReturnsAsync(new TokenContainerCacheable + .Setup(tc => tc.GetTokensAsync(It.IsAny())) + .ReturnsAsync(new TokenContainer { AccessToken = cachedAccessToken, ObtainedAt = DateTimeOffset.UtcNow, @@ -115,8 +126,8 @@ static void MockCachedAccessToken(Mock tokenCache, string cachedAcc static void MockNoAccessTokenUntilStored(Mock tokenCache) { tokenCache - .Setup(tc => tc.StoreTokenAsync(It.IsAny(), It.IsAny())) - .Callback((token, ct) => + .Setup(tc => tc.StoreTokensAsync(It.IsAny(), It.IsAny())) + .Callback((token, ct) => { // Simulate that the token is now cached MockCachedAccessToken(tokenCache, token.AccessToken); @@ -216,18 +227,23 @@ static void MockInitializeResponse(Mock httpMessageHandler) static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) { - httpMessageHandler + _ = httpMessageHandler .Protected() .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(response ?? new HttpResponseMessage()); } static StringContent ToJsonContent(T content) => new( - content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions), + content: JsonSerializer.Serialize(content, GetReflectionCapableJsonOptions()), encoding: System.Text.Encoding.UTF8, mediaType: "application/json"); static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( value: content, - options: McpJsonUtilities.DefaultOptions); + options: GetReflectionCapableJsonOptions()); + + static JsonSerializerOptions GetReflectionCapableJsonOptions() => new(JsonSerializerDefaults.Web) + { + TypeInfoResolver = new DefaultJsonTypeInfoResolver() + }; }