Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ enum TransportType {
}

export class McpEventHandler {
private static readonly FILE_WATCH_DEBOUNCE_MS = 2000
#features: Features
#eventListenerRegistered: boolean
#currentEditingServerName: string | undefined
Expand All @@ -48,6 +49,12 @@ export class McpEventHandler {
#lastProgrammaticState: boolean = false
#serverNameBeforeUpdate: string | undefined

#releaseProgrammaticAfterDebounce(padMs = 500) {
setTimeout(() => {
this.#isProgrammaticChange = false
}, McpEventHandler.FILE_WATCH_DEBOUNCE_MS + padMs)
}

constructor(features: Features, telemetryService: TelemetryService) {
this.#features = features
this.#eventListenerRegistered = false
Expand Down Expand Up @@ -797,7 +804,7 @@ export class McpEventHandler {
command: selectedTransport === TransportType.STDIO ? params.optionsValues.command : undefined,
url: selectedTransport === TransportType.HTTP ? params.optionsValues.url : undefined,
enabled: true,
numTools: McpManager.instance.getAllToolsWithPermissions(serverName).length,
numTools: McpManager.instance.getAllToolsWithPermissions(sanitizedServerName).length,
scope: params.optionsValues['scope'] === 'global' ? 'global' : 'workspace',
transportType: selectedTransport,
languageServerVersion: this.#features.runtime.serverInfo.version,
Expand All @@ -812,6 +819,7 @@ export class McpEventHandler {

// Stay on add/edit page and show error to user
// Keep isProgrammaticChange true during error handling to prevent file watcher triggers
this.#releaseProgrammaticAfterDebounce()
if (isEditMode) {
params.id = 'edit-mcp'
params.title = sanitizedServerName
Expand All @@ -826,7 +834,7 @@ export class McpEventHandler {
this.#newlyAddedServers.delete(serverName)
}

this.#isProgrammaticChange = false
this.#releaseProgrammaticAfterDebounce()

// Go to tools permissions page
return this.#handleOpenMcpServer({ id: 'open-mcp-server', title: sanitizedServerName })
Expand Down Expand Up @@ -927,9 +935,10 @@ export class McpEventHandler {
perm.__configPath__ = agentPath
await mcpManager.updateServerPermission(serverName, perm)
this.#emitMCPConfigEvent()
this.#releaseProgrammaticAfterDebounce()
} catch (error) {
this.#features.logging.error(`Failed to enable MCP server: ${error}`)
this.#isProgrammaticChange = false
this.#releaseProgrammaticAfterDebounce()
}
return { id: params.id }
}
Expand All @@ -953,9 +962,10 @@ export class McpEventHandler {
perm.__configPath__ = agentPath
await mcpManager.updateServerPermission(serverName, perm)
this.#emitMCPConfigEvent()
this.#releaseProgrammaticAfterDebounce()
} catch (error) {
this.#features.logging.error(`Failed to disable MCP server: ${error}`)
this.#isProgrammaticChange = false
this.#releaseProgrammaticAfterDebounce()
}

return { id: params.id }
Expand All @@ -975,11 +985,11 @@ export class McpEventHandler {

try {
await McpManager.instance.removeServer(serverName)

this.#releaseProgrammaticAfterDebounce()
return { id: params.id }
} catch (error) {
this.#features.logging.error(`Failed to delete MCP server: ${error}`)
this.#isProgrammaticChange = false
this.#releaseProgrammaticAfterDebounce()
return { id: params.id }
}
}
Expand Down Expand Up @@ -1262,10 +1272,11 @@ export class McpEventHandler {
this.#pendingPermissionConfig = undefined

this.#features.logging.info(`Applied permission changes for server: ${serverName}`)
this.#releaseProgrammaticAfterDebounce()
return { id: params.id }
} catch (error) {
this.#features.logging.error(`Failed to save MCP permissions: ${error}`)
this.#isProgrammaticChange = false
this.#releaseProgrammaticAfterDebounce()
return { id: params.id }
}
}
Expand Down Expand Up @@ -1430,7 +1441,8 @@ export class McpEventHandler {
*/
#getServerStatusError(serverName: string): { title: string; icon: string; status: Status } | undefined {
const serverStates = McpManager.instance.getAllServerStates()
const serverState = serverStates.get(serverName)
const key = serverName ? sanitizeName(serverName) : serverName
const serverState = key ? serverStates.get(key) : undefined

if (!serverState) {
return undefined
Expand Down Expand Up @@ -1494,11 +1506,10 @@ export class McpEventHandler {
if (!this.#lastProgrammaticState) {
await this.#handleRefreshMCPList({ id: 'refresh-mcp-list' })
} else {
this.#isProgrammaticChange = false
this.#features.logging.debug('Skipping refresh due to programmatic change')
}
this.#debounceTimer = null
}, 2000)
}, McpEventHandler.FILE_WATCH_DEBOUNCE_MS)
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ export class McpManager {
/**
* Load configurations and initialize each enabled server.
*/
private async discoverAllServers(authIntent: AuthIntent = AuthIntent.Silent): Promise<void> {
private async discoverAllServers(): Promise<void> {
// Load agent config
const result = await loadAgentConfig(this.features.workspace, this.features.logging, this.agentPaths)

Expand Down Expand Up @@ -221,7 +221,7 @@ export class McpManager {
// Process servers in batches
for (let i = 0; i < totalServers; i += MAX_CONCURRENT_SERVERS) {
const batch = serversToInit.slice(i, i + MAX_CONCURRENT_SERVERS)
const batchPromises = batch.map(([name, cfg]) => this.initOneServer(name, cfg, authIntent))
const batchPromises = batch.map(([name, cfg]) => this.initOneServer(name, cfg, AuthIntent.Silent))

this.features.logging.debug(
`MCP: initializing batch of ${batch.length} servers (${i + 1}-${Math.min(i + MAX_CONCURRENT_SERVERS, totalServers)} of ${totalServers})`
Expand Down Expand Up @@ -373,19 +373,29 @@ export class McpManager {

if (needsOAuth) {
OAuthClient.initialize(this.features.workspace, this.features.logging)
const bearer = await OAuthClient.getValidAccessToken(base, {
interactive: authIntent === 'interactive',
})
// add authorization header if we are able to obtain a bearer token
if (bearer) {
headers = { ...headers, Authorization: `Bearer ${bearer}` }
} else if (authIntent === 'silent') {
// In silent mode we never launch a browser. If we cannot obtain a token
// from cache/refresh, surface a clear auth-required error and stop here.
throw new AgenticChatError(
`MCP: server '${serverName}' requires OAuth. Open "Edit MCP Server" and save to sign in.`,
'MCPServerAuthFailed'
)
try {
const bearer = await OAuthClient.getValidAccessToken(base, {
interactive: authIntent === AuthIntent.Interactive,
})
if (bearer) {
headers = { ...headers, Authorization: `Bearer ${bearer}` }
} else if (authIntent === AuthIntent.Silent) {
throw new AgenticChatError(
`MCP: server '${serverName}' requires OAuth. Open "Edit MCP Server" and save to sign in.`,
'MCPServerAuthFailed'
)
}
} catch (e: any) {
const msg = e?.message || ''
const short = /authorization_timed_out/i.test(msg)
? 'Sign-in timed out. Please try again.'
: /Authorization error|PKCE|access_denied|login|consent|token exchange failed/i.test(
msg
)
? 'Sign-in was cancelled or failed. Please try again.'
: `OAuth failed: ${msg}`

throw new AgenticChatError(`MCP: ${short}`, 'MCPServerAuthFailed')
}
}

Expand Down Expand Up @@ -1156,15 +1166,16 @@ export class McpManager {
*/
public async removeServerFromConfigFile(serverName: string): Promise<void> {
try {
const cfg = this.mcpServers.get(serverName)
const sanitized = sanitizeName(serverName)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the issue this fix is for?

const cfg = this.mcpServers.get(sanitized)
if (!cfg || !cfg.__configPath__) {
this.features.logging.warn(
`Cannot remove config for server '${serverName}': Config not found or missing path`
)
return
}

const unsanitizedName = this.serverNameMapping.get(serverName) || serverName
const unsanitizedName = this.serverNameMapping.get(sanitized) || serverName

// Remove from agent config
if (unsanitizedName && this.agentConfig) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ describe('OAuthClient getValidAccessToken()', () => {

stubFileSystem(cachedToken, cachedReg)

const token = await OAuthClient.getValidAccessToken(new URL('https://api.example.com/mcp'))
const token = await OAuthClient.getValidAccessToken(new URL('https://api.example.com/mcp'), {
interactive: true,
})
expect(token).to.equal('cached_access')
expect((http.createServer as any).calledOnce).to.be.true
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import * as path from 'path'
import { spawn } from 'child_process'
import { URL, URLSearchParams } from 'url'
import * as http from 'http'
import * as os from 'os'
import { Logger, Workspace } from '@aws/language-server-runtimes/server-interface'

interface Token {
Expand Down Expand Up @@ -46,9 +47,9 @@ export class OAuthClient {
*/
public static async getValidAccessToken(
mcpBase: URL,
opts: { interactive?: boolean } = { interactive: true }
opts: { interactive?: boolean } = { interactive: false }
): Promise<string | undefined> {
const interactive = opts?.interactive !== false
const interactive = opts?.interactive === true
const key = this.computeKey(mcpBase)
const regPath = path.join(this.cacheDir, `${key}.registration.json`)
const tokPath = path.join(this.cacheDir, `${key}.token.json`)
Expand Down Expand Up @@ -333,6 +334,7 @@ export class OAuthClient {
redirectUri: string,
server: http.Server
): Promise<Token> {
const DEFAULT_PKCE_TIMEOUT_MS = 20_000
// a) generate PKCE params
const verifier = this.b64url(crypto.randomBytes(32))
const challenge = this.b64url(crypto.createHash('sha256').update(verifier).digest())
Expand All @@ -353,25 +355,37 @@ export class OAuthClient {

const opener =
process.platform === 'win32'
? { cmd: 'cmd', args: ['/c', 'start', authz.toString()] }
? {
cmd: 'cmd',
args: ['/c', 'start', '', `"${authz.toString().replace(/"/g, '""')}"`],
}
: process.platform === 'darwin'
? { cmd: 'open', args: [authz.toString()] }
: { cmd: 'xdg-open', args: [authz.toString()] }

void spawn(opener.cmd, opener.args, { detached: true, stdio: 'ignore' }).unref()

// c) wait for code on our loopback
const { code, rxState, err } = await new Promise<{ code: string; rxState: string; err?: string }>(resolve => {
const waitForFlow = new Promise<{ code: string; rxState: string; err?: string; errDesc?: string }>(resolve => {
server.on('request', (req, res) => {
const u = new URL(req.url || '/', redirectUri)
const c = u.searchParams.get('code') || ''
const s = u.searchParams.get('state') || ''
const e = u.searchParams.get('error') || undefined
const ed = u.searchParams.get('error_description') || undefined
res.writeHead(200, { 'content-type': 'text/html' }).end('<h2>You may close this tab.</h2>')
resolve({ code: c, rxState: s, err: e })
resolve({ code: c, rxState: s, err: e, errDesc: ed })
})
})
if (err) throw new Error(`Authorization error: ${err}`)
const { code, rxState, err, errDesc } = await Promise.race([
waitForFlow,
new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error('authorization_timed_out')), DEFAULT_PKCE_TIMEOUT_MS)
),
])
if (err) {
throw new Error(`Authorization error: ${err}${errDesc ? ` - ${errDesc}` : ''}`)
}
if (!code || rxState !== state) throw new Error('Invalid authorization response (state mismatch)')

// d) exchange code for token
Expand Down Expand Up @@ -438,12 +452,7 @@ export class OAuthClient {
}

/** Directory for caching registration + tokens */
private static readonly cacheDir = path.join(
process.env.HOME || process.env.USERPROFILE || '.',
'.aws',
'sso',
'cache'
)
private static readonly cacheDir = path.join(os.homedir(), '.aws', 'sso', 'cache')

/**
* Await server.listen() but reject if it emits 'error' (eg EADDRINUSE),
Expand Down
Loading