Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ export class McpManager {
cfg: MCPServerConfig,
authIntent: AuthIntent = AuthIntent.Silent
): Promise<void> {
const DEFAULT_SERVER_INIT_TIMEOUT_MS = 60_000
const DEFAULT_SERVER_INIT_TIMEOUT_MS = 120_000
this.setState(serverName, McpServerStatus.INITIALIZING, 0)

try {
Expand Down Expand Up @@ -373,7 +373,7 @@ export class McpManager {
}

if (needsOAuth) {
OAuthClient.initialize(this.features.workspace, this.features.logging)
OAuthClient.initialize(this.features.workspace, this.features.logging, this.features.lsp)
try {
const bearer = await OAuthClient.getValidAccessToken(base, {
interactive: authIntent === AuthIntent.Interactive,
Expand All @@ -382,7 +382,7 @@ export class McpManager {
headers = { ...headers, Authorization: `Bearer ${bearer}` }
} else if (authIntent === AuthIntent.Silent) {
throw new AgenticChatError(
`MCP: server '${serverName}' requires OAuth. Open "Edit MCP Server" and save to sign in.`,
`Server '${serverName}' requires OAuth. Click on Save to reauthenticate.`,
'MCPServerAuthFailed'
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ const fakeLogger = {
error: () => {},
}

const fakeLsp = {
window: {
showDocument: sinon.stub().resolves({ success: true }),
},
} as any

const fakeWorkspace = {
fs: {
exists: async (_path: string) => false,
Expand Down Expand Up @@ -93,9 +99,10 @@ describe('OAuthClient getValidAccessToken()', () => {

beforeEach(() => {
sinon.restore()
OAuthClient.initialize(fakeWorkspace, fakeLogger as any)
OAuthClient.initialize(fakeWorkspace, fakeLogger as any, fakeLsp)
sinon.stub(OAuthClient as any, 'computeKey').returns('testkey')
stubHttpServer()
;(fakeLsp.window.showDocument as sinon.SinonStub).resetHistory()
})

afterEach(() => sinon.restore())
Expand All @@ -117,6 +124,6 @@ describe('OAuthClient getValidAccessToken()', () => {
interactive: true,
})
expect(token).to.equal('cached_access')
expect((http.createServer as any).calledOnce).to.be.true
expect((fakeLsp.window.showDocument as sinon.SinonStub).called).to.be.false
})
})
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { spawn } from 'child_process'
import { URL, URLSearchParams } from 'url'
import * as http from 'http'
import * as os from 'os'
import { Logger, Workspace } from '@aws/language-server-runtimes/server-interface'
import { Logger, Workspace, Lsp } from '@aws/language-server-runtimes/server-interface'

interface Token {
access_token: string
Expand All @@ -35,10 +35,12 @@ interface Registration {
export class OAuthClient {
private static logger: Logger
private static workspace: Workspace
private static lsp: Lsp

public static initialize(ws: Workspace, logger: Logger): void {
public static initialize(ws: Workspace, logger: Logger, lsp: Lsp): void {
this.workspace = ws
this.logger = logger
this.lsp = lsp
}

/**
Expand Down Expand Up @@ -95,10 +97,11 @@ export class OAuthClient {
const savedReg = await this.read<Registration>(regPath)
if (savedReg) {
const port = Number(new URL(savedReg.redirect_uri).port)
const normalized = `http://127.0.0.1:${port}`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a constant port URL? if so can we move it constants?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

port should be read from the server, not the registration

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be the same as we pass the actual port instead of 0 to the listen() method below

server = http.createServer()
try {
await this.listen(server, port)
redirectUri = savedReg.redirect_uri
await this.listen(server, port, '127.0.0.1')
redirectUri = normalized
this.logger.info(`OAuth: reusing redirect URI ${redirectUri}`)
} catch (e: any) {
if (e.code === 'EADDRINUSE') {
Expand Down Expand Up @@ -182,9 +185,9 @@ export class OAuthClient {
/** Spin up a one‑time HTTP listener on localhost:randomPort */
private static async buildCallbackServer(): Promise<{ server: http.Server; redirectUri: string }> {
const server = http.createServer()
await this.listen(server, 0)
await this.listen(server, 0, '127.0.0.1')
const port = (server.address() as any).port as number
return { server, redirectUri: `http://localhost:${port}` }
return { server, redirectUri: `http://127.0.0.1:${port}` }
}

/** Discover OAuth endpoints by HEAD/WWW‑Authenticate, well‑known, or fallback */
Expand Down Expand Up @@ -334,7 +337,7 @@ export class OAuthClient {
redirectUri: string,
server: http.Server
): Promise<Token> {
const DEFAULT_PKCE_TIMEOUT_MS = 20_000
const DEFAULT_PKCE_TIMEOUT_MS = 90_000
// a) generate PKCE params
const verifier = this.b64url(crypto.randomBytes(32))
const challenge = this.b64url(crypto.createHash('sha256').update(verifier).digest())
Expand All @@ -353,17 +356,7 @@ export class OAuthClient {
state: state,
}).toString()

const opener =
process.platform === 'win32'
? {
cmd: 'cmd',
args: ['/c', 'start', '', `"${authz.toString().replace(/"/g, '""')}"`],
}
: process.platform === 'darwin'
? { cmd: 'open', args: [authz.toString()] }
: { cmd: 'xdg-open', args: [authz.toString()] }

void spawn(opener.cmd, opener.args, { detached: true, stdio: 'ignore' }).unref()
await this.lsp.window.showDocument({ uri: authz.toString(), external: true })

// c) wait for code on our loopback
const waitForFlow = new Promise<{ code: string; rxState: string; err?: string; errDesc?: string }>(resolve => {
Expand Down
Loading