Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"dependencies": {
"ajv": "^8.17.1",
"ajv-formats": "^3.0.1",
"bowser": "^2.12.0",
"content-type": "^1.0.5",
"cors": "^2.8.5",
"cross-spawn": "^7.0.5",
Expand Down
237 changes: 152 additions & 85 deletions src/client/auth.test.ts

Large diffs are not rendered by default.

63 changes: 45 additions & 18 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
UnauthorizedClientError
} from '../server/auth/errors.js';
import { FetchLike } from '../shared/transport.js';
import { UserAgentProvider } from '../shared/userAgent.js';

/**
* Implements an end-to-end OAuth client to be used with one MCP server.
Expand Down Expand Up @@ -296,6 +297,7 @@ export async function auth(
scope?: string;
resourceMetadataUrl?: URL;
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<AuthResult> {
try {
Expand All @@ -322,19 +324,21 @@ async function authInternal(
authorizationCode,
scope,
resourceMetadataUrl,
fetchFn
fetchFn,
userAgentProvider
}: {
serverUrl: string | URL;
authorizationCode?: string;
scope?: string;
resourceMetadataUrl?: URL;
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<AuthResult> {
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
let authorizationServerUrl: string | URL | undefined;
try {
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }, fetchFn);
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, userAgentProvider, { resourceMetadataUrl }, fetchFn);
if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) {
authorizationServerUrl = resourceMetadata.authorization_servers[0];
}
Expand All @@ -352,7 +356,7 @@ async function authInternal(

const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata);

const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, {
const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, userAgentProvider, {
fetchFn
});

Expand All @@ -370,6 +374,7 @@ async function authInternal(
const fullInformation = await registerClient(authorizationServerUrl, {
metadata,
clientMetadata: provider.clientMetadata,
userAgentProvider,
fetchFn
});

Expand All @@ -388,7 +393,8 @@ async function authInternal(
redirectUri: provider.redirectUrl,
resource,
addClientAuthentication: provider.addClientAuthentication,
fetchFn: fetchFn
fetchFn: fetchFn,
userAgentProvider
});

await provider.saveTokens(tokens);
Expand All @@ -407,7 +413,8 @@ async function authInternal(
refreshToken: tokens.refresh_token,
resource,
addClientAuthentication: provider.addClientAuthentication,
fetchFn
fetchFn,
userAgentProvider
});

await provider.saveTokens(newTokens);
Expand Down Expand Up @@ -500,10 +507,11 @@ export function extractResourceMetadataUrl(res: Response): URL | undefined {
*/
export async function discoverOAuthProtectedResourceMetadata(
serverUrl: string | URL,
userAgentProvider: UserAgentProvider,
opts?: { protocolVersion?: string; resourceMetadataUrl?: string | URL },
fetchFn: FetchLike = fetch
): Promise<OAuthProtectedResourceMetadata> {
const response = await discoverMetadataWithFallback(serverUrl, 'oauth-protected-resource', fetchFn, {
const response = await discoverMetadataWithFallback(serverUrl, 'oauth-protected-resource', userAgentProvider, fetchFn, {
protocolVersion: opts?.protocolVersion,
metadataUrl: opts?.resourceMetadataUrl
});
Expand Down Expand Up @@ -557,9 +565,15 @@ function buildWellKnownPath(
/**
* Tries to discover OAuth metadata at a specific URL
*/
async function tryMetadataDiscovery(url: URL, protocolVersion: string, fetchFn: FetchLike = fetch): Promise<Response | undefined> {
async function tryMetadataDiscovery(
url: URL,
protocolVersion: string,
userAgentProvider: UserAgentProvider,
fetchFn: FetchLike = fetch
): Promise<Response | undefined> {
const headers = {
'MCP-Protocol-Version': protocolVersion
'MCP-Protocol-Version': protocolVersion,
'User-Agent': await userAgentProvider()
};
return await fetchWithCorsRetry(url, headers, fetchFn);
}
Expand All @@ -577,6 +591,7 @@ function shouldAttemptFallback(response: Response | undefined, pathname: string)
async function discoverMetadataWithFallback(
serverUrl: string | URL,
wellKnownType: 'oauth-authorization-server' | 'oauth-protected-resource',
userAgentProvider: UserAgentProvider,
fetchFn: FetchLike,
opts?: { protocolVersion?: string; metadataUrl?: string | URL; metadataServerUrl?: string | URL }
): Promise<Response | undefined> {
Expand All @@ -593,12 +608,12 @@ async function discoverMetadataWithFallback(
url.search = issuer.search;
}

let response = await tryMetadataDiscovery(url, protocolVersion, fetchFn);
let response = await tryMetadataDiscovery(url, protocolVersion, userAgentProvider, fetchFn);

// If path-aware discovery fails with 404 and we're not already at root, try fallback to root discovery
if (!opts?.metadataUrl && shouldAttemptFallback(response, issuer.pathname)) {
const rootUrl = new URL(`/.well-known/${wellKnownType}`, issuer);
response = await tryMetadataDiscovery(rootUrl, protocolVersion, fetchFn);
response = await tryMetadataDiscovery(rootUrl, protocolVersion, userAgentProvider, fetchFn);
}

return response;
Expand All @@ -614,6 +629,7 @@ async function discoverMetadataWithFallback(
*/
export async function discoverOAuthMetadata(
issuer: string | URL,
userAgentProvider: UserAgentProvider,
{
authorizationServerUrl,
protocolVersion
Expand All @@ -634,7 +650,7 @@ export async function discoverOAuthMetadata(
}
protocolVersion ??= LATEST_PROTOCOL_VERSION;

const response = await discoverMetadataWithFallback(authorizationServerUrl, 'oauth-authorization-server', fetchFn, {
const response = await discoverMetadataWithFallback(authorizationServerUrl, 'oauth-authorization-server', userAgentProvider, fetchFn, {
protocolVersion,
metadataServerUrl: authorizationServerUrl
});
Expand Down Expand Up @@ -730,6 +746,7 @@ export function buildDiscoveryUrls(authorizationServerUrl: string | URL): { url:
*/
export async function discoverAuthorizationServerMetadata(
authorizationServerUrl: string | URL,
userAgentProvider: UserAgentProvider,
{
fetchFn = fetch,
protocolVersion = LATEST_PROTOCOL_VERSION
Expand All @@ -740,7 +757,8 @@ export async function discoverAuthorizationServerMetadata(
): Promise<AuthorizationServerMetadata | undefined> {
const headers = {
'MCP-Protocol-Version': protocolVersion,
Accept: 'application/json'
Accept: 'application/json',
'User-Agent': await userAgentProvider()
};

// Get the list of URLs to try
Expand Down Expand Up @@ -873,7 +891,8 @@ export async function exchangeAuthorization(
redirectUri,
resource,
addClientAuthentication,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformation;
Expand All @@ -883,6 +902,7 @@ export async function exchangeAuthorization(
resource?: URL;
addClientAuthentication?: OAuthClientProvider['addClientAuthentication'];
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<OAuthTokens> {
const grantType = 'authorization_code';
Expand All @@ -896,7 +916,8 @@ export async function exchangeAuthorization(
// Exchange code for tokens
const headers = new Headers({
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json'
Accept: 'application/json',
'User-Agent': await userAgentProvider()
});
const params = new URLSearchParams({
grant_type: grantType,
Expand Down Expand Up @@ -952,14 +973,16 @@ export async function refreshAuthorization(
refreshToken,
resource,
addClientAuthentication,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformation;
refreshToken: string;
resource?: URL;
addClientAuthentication?: OAuthClientProvider['addClientAuthentication'];
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<OAuthTokens> {
const grantType = 'refresh_token';
Expand All @@ -977,7 +1000,8 @@ export async function refreshAuthorization(

// Exchange refresh token
const headers = new Headers({
'Content-Type': 'application/x-www-form-urlencoded'
'Content-Type': 'application/x-www-form-urlencoded',
'User-Agent': await userAgentProvider()
});
const params = new URLSearchParams({
grant_type: grantType,
Expand Down Expand Up @@ -1018,11 +1042,13 @@ export async function registerClient(
{
metadata,
clientMetadata,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
clientMetadata: OAuthClientMetadata;
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<OAuthClientInformationFull> {
let registrationUrl: URL;
Expand All @@ -1040,7 +1066,8 @@ export async function registerClient(
const response = await (fetchFn ?? fetch)(registrationUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
'Content-Type': 'application/json',
'User-Agent': await userAgentProvider()
},
body: JSON.stringify(clientMetadata)
});
Expand Down
12 changes: 8 additions & 4 deletions src/client/middleware.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ describe('withOAuth', () => {
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://api.example.com',
resourceMetadataUrl: mockResourceUrl,
fetchFn: mockFetch
fetchFn: mockFetch,
userAgentProvider: expect.any(Function)
});

// Verify the retry used the new token
Expand Down Expand Up @@ -186,7 +187,8 @@ describe('withOAuth', () => {
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://api.example.com', // Should be extracted from request URL
resourceMetadataUrl: mockResourceUrl,
fetchFn: mockFetch
fetchFn: mockFetch,
userAgentProvider: expect.any(Function)
});

// Verify the retry used the new token
Expand Down Expand Up @@ -357,7 +359,8 @@ describe('withOAuth', () => {
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://api.example.com', // Should extract origin from URL object
resourceMetadataUrl: undefined,
fetchFn: mockFetch
fetchFn: mockFetch,
userAgentProvider: expect.any(Function)
});
});
});
Expand Down Expand Up @@ -896,7 +899,8 @@ describe('Integration Tests', () => {
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://mcp-server.example.com',
resourceMetadataUrl: new URL('https://auth.example.com/.well-known/oauth-protected-resource'),
fetchFn: mockFetch
fetchFn: mockFetch,
userAgentProvider: expect.any(Function)
});
});
});
Expand Down
8 changes: 6 additions & 2 deletions src/client/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { auth, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from './auth.js';
import { FetchLike } from '../shared/transport.js';
import { createUserAgentProvider, UserAgentProvider } from '../shared/userAgent.js';

/**
* Middleware function that wraps and enhances fetch functionality.
Expand Down Expand Up @@ -31,11 +32,13 @@ export type Middleware = (next: FetchLike) => FetchLike;
*
* @param provider - OAuth client provider for authentication
* @param baseUrl - Base URL for OAuth server discovery (defaults to request URL domain)
* @param userAgentProvider - User agent provider for the connection.
* @returns A fetch middleware function
*/
export const withOAuth =
(provider: OAuthClientProvider, baseUrl?: string | URL): Middleware =>
(provider: OAuthClientProvider, baseUrl?: string | URL, userAgentProvider?: UserAgentProvider): Middleware =>
next => {
const uaProvider = userAgentProvider ?? createUserAgentProvider();
return async (input, init) => {
const makeRequest = async (): Promise<Response> => {
const headers = new Headers(init?.headers);
Expand All @@ -62,7 +65,8 @@ export const withOAuth =
const result = await auth(provider, {
serverUrl,
resourceMetadataUrl,
fetchFn: next
fetchFn: next,
userAgentProvider: uaProvider
});

if (result === 'REDIRECT') {
Expand Down
Loading