diff --git a/eslint.config.js b/eslint.config.js index a8e093eb..c619edd3 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -64,6 +64,7 @@ export default defineConfig([ "eslint.config.js", "vitest.config.ts", "src/types/*.d.ts", + "tests/integration/fixtures/", ]), eslintPluginPrettierRecommended, ]); diff --git a/package-lock.json b/package-lock.json index 87cc8dc5..648b9f8d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -40,11 +40,13 @@ "@ai-sdk/openai": "^1.3.23", "@eslint/js": "^9.30.1", "@modelcontextprotocol/inspector": "^0.16.0", + "@mongodb-js/oidc-mock-provider": "^0.11.3", "@redocly/cli": "^1.34.4", "@types/express": "^5.0.1", "@types/http-proxy": "^1.17.16", "@types/node": "^24.0.12", "@types/proper-lockfile": "^4.1.4", + "@types/semver": "^7.7.0", "@types/simple-oauth2": "^5.0.7", "@types/yargs-parser": "^21.0.3", "@vitest/coverage-v8": "^3.2.4", @@ -60,6 +62,7 @@ "openapi-typescript": "^7.8.0", "prettier": "^3.6.2", "proper-lockfile": "^4.1.2", + "semver": "^7.7.2", "simple-git": "^3.28.0", "tsx": "^4.20.3", "typescript": "^5.8.3", @@ -2238,17 +2241,74 @@ "integrity": "sha512-ZR/IZi/jI81TRas5X9kzN9t2GZI6u9JdawKctdCoXCrtyvQmRU6ktviCcvXGLwjcZnIWEWbZM7bkpnEdITYSCw==", "license": "Apache-2.0" }, + "node_modules/@mongodb-js/oidc-mock-provider": { + "version": "0.11.3", + "resolved": "https://registry.npmjs.org/@mongodb-js/oidc-mock-provider/-/oidc-mock-provider-0.11.3.tgz", + "integrity": "sha512-U1bCNOKAWQevd5vObXB58Dt+Fw1G21YZ31MmrRZSkfX3JlWT+YTTSot9lgzWs58PdFr3RhAa8VMrudThMDqbgA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "yargs": "^17.7.2" + }, + "bin": { + "oidc-mock-provider": "bin/oidc-mock-provider.js" + } + }, + "node_modules/@mongodb-js/oidc-mock-provider/node_modules/cliui": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", + "integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.1", + "wrap-ansi": "^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@mongodb-js/oidc-mock-provider/node_modules/yargs": { + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", + "dev": true, + "license": "MIT", + "dependencies": { + "cliui": "^8.0.1", + "escalade": "^3.1.1", + "get-caller-file": "^2.0.5", + "require-directory": "^2.1.1", + "string-width": "^4.2.3", + "y18n": "^5.0.5", + "yargs-parser": "^21.1.1" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@mongodb-js/oidc-mock-provider/node_modules/yargs-parser": { + "version": "21.1.1", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", + "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=12" + } + }, "node_modules/@mongodb-js/oidc-plugin": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@mongodb-js/oidc-plugin/-/oidc-plugin-2.0.1.tgz", - "integrity": "sha512-P9UwfwKHTH5qtycZUxSmYCXaxB5FVodEmQAp2QiktBA8jTy3uoX5tjuvTlOUT0gJxoPMHstSRaFIgW/ZhToKWw==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@mongodb-js/oidc-plugin/-/oidc-plugin-2.0.2.tgz", + "integrity": "sha512-E+xStW+3qtA8Da9h/cBUDGBd0RmbOwyNEncEbhAf2ZJpTEwHxgAhVO/STmxiaRqw0u4w8EmXrGqDdyGagRhx+A==", "license": "Apache-2.0", "peer": true, "dependencies": { "express": "^5.1.0", "node-fetch": "^3.3.2", "open": "^10.1.2", - "openid-client": "^6.5.1" + "openid-client": "^6.6.3" }, "engines": { "node": ">= 20.19.2" @@ -5423,6 +5483,13 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/semver": { + "version": "7.7.0", + "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.7.0.tgz", + "integrity": "sha512-k107IF4+Xr7UHjwDc7Cfd6PRQfbdkiRabXGRjo07b4WyPahFBZCZ1sE+BNxYIJPPg73UkfOsVOLwqVc/6ETrIA==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/send": { "version": "0.17.5", "resolved": "https://registry.npmjs.org/@types/send/-/send-0.17.5.tgz", @@ -9288,9 +9355,9 @@ } }, "node_modules/jose": { - "version": "6.0.11", - "resolved": "https://registry.npmjs.org/jose/-/jose-6.0.11.tgz", - "integrity": "sha512-QxG7EaliDARm1O1S8BGakqncGT9s25bKL1WSf6/oa17Tkqwi8D2ZNglqCF+DsYF88/rV66Q/Q2mFAy697E1DUg==", + "version": "6.0.12", + "resolved": "https://registry.npmjs.org/jose/-/jose-6.0.12.tgz", + "integrity": "sha512-T8xypXs8CpmiIi78k0E+Lk7T2zlK4zDyg+o1CZ4AkOHgDg98ogdP2BeZ61lTFKFyoEwJ9RgAgN+SdM3iPgNonQ==", "license": "MIT", "peer": true, "funding": { @@ -10547,9 +10614,9 @@ } }, "node_modules/oauth4webapi": { - "version": "3.6.0", - "resolved": "https://registry.npmjs.org/oauth4webapi/-/oauth4webapi-3.6.0.tgz", - "integrity": "sha512-OwXPTXjKPOldTpAa19oksrX9TYHA0rt+VcUFTkJ7QKwgmevPpNm9Cn5vFZUtIo96FiU6AfPuUUGzoXqgOzibWg==", + "version": "3.7.0", + "resolved": "https://registry.npmjs.org/oauth4webapi/-/oauth4webapi-3.7.0.tgz", + "integrity": "sha512-Q52wTPUWPsVLVVmTViXPQFMW2h2xv2jnDGxypjpelCFKaOjLsm7AxYuOk1oQgFm95VNDbuggasu9htXrz6XwKw==", "license": "MIT", "funding": { "url": "https://github.com/sponsors/panva" @@ -10782,14 +10849,14 @@ } }, "node_modules/openid-client": { - "version": "6.6.2", - "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-6.6.2.tgz", - "integrity": "sha512-Xya5TNMnnZuTM6DbHdB4q0S3ig2NTAELnii/ASie1xDEr8iiB8zZbO871OWBdrw++sd3hW6bqWjgcmSy1RTWHA==", + "version": "6.6.4", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-6.6.4.tgz", + "integrity": "sha512-PLWVhRksRnNH05sqeuCX/PR+1J70NyZcAcPske+FeF732KKONd3v0p5Utx1ro1iLfCglH8B3/+dA1vqIHDoIiA==", "license": "MIT", "peer": true, "dependencies": { - "jose": "^6.0.11", - "oauth4webapi": "^3.5.4" + "jose": "^6.0.12", + "oauth4webapi": "^3.7.0" }, "funding": { "url": "https://github.com/sponsors/panva" diff --git a/package.json b/package.json index 7bba9bf6..c3d6e9d4 100644 --- a/package.json +++ b/package.json @@ -61,11 +61,13 @@ "@ai-sdk/openai": "^1.3.23", "@eslint/js": "^9.30.1", "@modelcontextprotocol/inspector": "^0.16.0", + "@mongodb-js/oidc-mock-provider": "^0.11.3", "@redocly/cli": "^1.34.4", "@types/express": "^5.0.1", "@types/http-proxy": "^1.17.16", "@types/node": "^24.0.12", "@types/proper-lockfile": "^4.1.4", + "@types/semver": "^7.7.0", "@types/simple-oauth2": "^5.0.7", "@types/yargs-parser": "^21.0.3", "@vitest/coverage-v8": "^3.2.4", @@ -81,6 +83,7 @@ "openapi-typescript": "^7.8.0", "prettier": "^3.6.2", "proper-lockfile": "^4.1.2", + "semver": "^7.7.2", "simple-git": "^3.28.0", "tsx": "^4.20.3", "typescript": "^5.8.3", diff --git a/src/common/config.ts b/src/common/config.ts index f5c6a079..aebd6e73 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -151,7 +151,8 @@ function getLocalDataPath(): string { : path.join(os.homedir(), ".mongodb"); } -export const defaultDriverOptions: ConnectionInfo["driverOptions"] = { +export type DriverOptions = ConnectionInfo["driverOptions"]; +export const defaultDriverOptions: DriverOptions = { readConcern: { level: "local", }, @@ -345,8 +346,8 @@ export function setupDriverConfig({ defaults, }: { config: UserConfig; - defaults: ConnectionInfo["driverOptions"]; -}): ConnectionInfo["driverOptions"] { + defaults: Partial; +}): DriverOptions { const { driverOptions } = generateConnectionInfoFromCliArgs(config); return { ...defaults, diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index 6c2cb277..2d9bb838 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -1,4 +1,4 @@ -import { driverOptions } from "./config.js"; +import { UserConfig, DriverOptions } from "./config.js"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import EventEmitter from "events"; import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; @@ -6,6 +6,8 @@ import { packageInfo } from "./packageInfo.js"; import ConnectionString from "mongodb-connection-string-url"; import { MongoClientOptions } from "mongodb"; import { ErrorCodes, MongoDBError } from "./errors.js"; +import { CompositeLogger, LogId } from "./logger.js"; +import { ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; export interface AtlasClusterConnectionInfo { username: string; @@ -67,11 +69,21 @@ export interface ConnectionManagerEvents { export class ConnectionManager extends EventEmitter { private state: AnyConnectionState; + private bus: EventEmitter; - constructor() { + constructor( + private userConfig: UserConfig, + private driverOptions: DriverOptions, + private logger: CompositeLogger, + bus?: EventEmitter + ) { super(); + this.bus = bus ?? new EventEmitter(); this.state = { tag: "disconnected" }; + + this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this)); + this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this)); } async connect(settings: ConnectionSettings): Promise { @@ -82,6 +94,8 @@ export class ConnectionManager extends EventEmitter { } let serviceProvider: NodeDriverServiceProvider; + let connectionInfo: ConnectionInfo; + try { settings = { ...settings }; settings.connectionString = setAppNameParamIfMissing({ @@ -89,11 +103,30 @@ export class ConnectionManager extends EventEmitter { defaultAppName: `${packageInfo.mcpServerName} ${packageInfo.version}`, }); - serviceProvider = await NodeDriverServiceProvider.connect(settings.connectionString, { - productDocsLink: "https://github.com/mongodb-js/mongodb-mcp-server/", - productName: "MongoDB MCP", - ...driverOptions, + connectionInfo = generateConnectionInfoFromCliArgs({ + ...this.userConfig, + ...this.driverOptions, + connectionSpecifier: settings.connectionString, }); + + if (connectionInfo.driverOptions.oidc) { + connectionInfo.driverOptions.oidc.allowedFlows ??= ["auth-code"]; + connectionInfo.driverOptions.oidc.notifyDeviceFlow ??= this.onOidcNotifyDeviceFlow.bind(this); + } + + connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true }; + connectionInfo.driverOptions.applyProxyToOIDC ??= true; + + serviceProvider = await NodeDriverServiceProvider.connect( + connectionInfo.connectionString, + { + productDocsLink: "https://github.com/mongodb-js/mongodb-mcp-server/", + productName: "MongoDB MCP", + ...connectionInfo.driverOptions, + }, + undefined, + this.bus + ); } catch (error: unknown) { const errorReason = error instanceof Error ? error.message : `${error as string}`; this.changeState("connection-errored", { @@ -105,13 +138,28 @@ export class ConnectionManager extends EventEmitter { } try { + const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo); + if (connectionType.startsWith("oidc")) { + // The error here is irrelevant, we only use this ping to force the connection + // Errors will be handled by the auth flow. + void serviceProvider?.runCommand?.("admin", { hello: 1 }).catch(() => {}); + + return this.changeState("connection-requested", { + tag: "connecting", + connectedAtlasCluster: settings.atlas, + serviceProvider, + connectionStringAuthType: connectionType, + oidcConnectionType: connectionType as OIDCConnectionAuthType, + }); + } + await serviceProvider?.runCommand?.("admin", { hello: 1 }); return this.changeState("connection-succeeded", { tag: "connected", connectedAtlasCluster: settings.atlas, serviceProvider, - connectionStringAuthType: ConnectionManager.inferConnectionTypeFromSettings(settings), + connectionStringAuthType: connectionType, }); } catch (error: unknown) { const errorReason = error instanceof Error ? error.message : `${error as string}`; @@ -157,13 +205,62 @@ export class ConnectionManager extends EventEmitter { return newState; } - static inferConnectionTypeFromSettings(settings: ConnectionSettings): ConnectionStringAuthType { + private onOidcAuthFailed(error: unknown): void { + if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + void this.disconnect().then(() => { + this.changeState("connection-errored", { tag: "errored", errorReason: String(error) }); + }); + } + } + + private onOidcAuthSucceeded(): void { + if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + this.changeState("connection-succeeded", { ...this.state, tag: "connected" }); + } + + this.logger.info({ + id: LogId.oidcFlow, + context: "mongodb-oidc-plugin:auth-succeeded", + message: "Authenticated successfully.", + }); + } + + private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void { + if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + this.changeState("connection-requested", { + ...this.state, + tag: "connecting", + connectionStringAuthType: "oidc-device-flow", + oidcLoginUrl: flowInfo.verificationUrl, + oidcUserCode: flowInfo.userCode, + }); + } + + this.logger.info({ + id: LogId.oidcFlow, + context: "mongodb-oidc-plugin:notify-device-flow", + message: "OIDC Flow changed automatically to device flow.", + }); + } + + static inferConnectionTypeFromSettings( + config: UserConfig, + settings: { connectionString: string } + ): ConnectionStringAuthType { const connString = new ConnectionString(settings.connectionString); const searchParams = connString.typedSearchParams(); switch (searchParams.get("authMechanism")) { case "MONGODB-OIDC": { - return "oidc-auth-flow"; // TODO: depending on if we don't have a --browser later it can be oidc-device-flow + if (config.transport === "stdio" && config.browser) { + return "oidc-auth-flow"; + } + + if (config.transport === "http" && config.httpHost === "127.0.0.1" && config.browser) { + return "oidc-auth-flow"; + } + + return "oidc-device-flow"; } case "MONGODB-X509": return "x.509"; @@ -181,4 +278,32 @@ export class ConnectionManager extends EventEmitter { return "scram"; } } + + static async waitUntil( + tag: T["tag"], + cm: ConnectionManager, + signal: AbortSignal, + additionalCondition?: (state: T) => boolean + ): Promise { + let ts: NodeJS.Timeout | undefined; + + return new Promise((resolve, reject) => { + ts = setInterval(() => { + if (signal.aborted) { + return reject(new Error(`Aborted: ${signal.reason}`)); + } + + const status = cm.currentConnectionState; + if (status.tag === tag) { + if (!additionalCondition || (additionalCondition && additionalCondition(status as T))) { + return resolve(status as T); + } + } + }, 100); + }).finally(() => { + if (ts !== undefined) { + clearInterval(ts); + } + }); + } } diff --git a/src/common/logger.ts b/src/common/logger.ts index 0add105c..be738d5b 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -58,6 +58,8 @@ export const LogId = { exportedDataListError: mongoLogId(1_007_006), exportedDataAutoCompleteError: mongoLogId(1_007_007), exportLockError: mongoLogId(1_007_008), + + oidcFlow: mongoLogId(1_008_001), } as const; interface LogPayload { diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index 2cd1a060..9beaf54a 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -68,6 +68,16 @@ export abstract class MongoDBToolBase extends ToolBase { : "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string."; const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", "); + const connectionStatus = this.session.connectionManager.currentConnectionState; + const additionalPromptForOidc: { type: "text"; text: string }[] = []; + + if (connectionStatus.tag === "connecting" && connectionStatus.oidcConnectionType === "oidc-device-flow") { + additionalPromptForOidc.push({ + type: "text", + text: `The user needs to finish their OIDC connection by opening '${connectionStatus.oidcLoginUrl}' in the browser and use the following user code: '${connectionStatus.oidcUserCode}'`, + }); + } + switch (error.code) { case ErrorCodes.NotConnectedToMongoDB: return { @@ -76,6 +86,7 @@ export abstract class MongoDBToolBase extends ToolBase { type: "text", text: "You need to connect to a MongoDB instance before you can access its data.", }, + ...additionalPromptForOidc, { type: "text", text: connectToolsNames diff --git a/src/transports/base.ts b/src/transports/base.ts index 22a000cc..b2ca3e1a 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -1,4 +1,4 @@ -import { UserConfig } from "../common/config.js"; +import { driverOptions, UserConfig } from "../common/config.js"; import { packageInfo } from "../common/packageInfo.js"; import { Server } from "../server.js"; import { Session } from "../common/session.js"; @@ -43,7 +43,7 @@ export abstract class TransportRunnerBase { const logger = new CompositeLogger(...loggers); const exportsManager = ExportsManager.init(userConfig, logger); - const connectionManager = new ConnectionManager(); + const connectionManager = new ConnectionManager(userConfig, driverOptions, logger); const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, diff --git a/tests/accuracy/sdk/describeAccuracyTests.ts b/tests/accuracy/sdk/describeAccuracyTests.ts index bd5b5c0d..9818c857 100644 --- a/tests/accuracy/sdk/describeAccuracyTests.ts +++ b/tests/accuracy/sdk/describeAccuracyTests.ts @@ -61,7 +61,7 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[]) eachModel(`$displayName`, function (model) { const configsWithDescriptions = getConfigsWithDescriptions(accuracyTestConfigs); const accuracyRunId = `${process.env.MDB_ACCURACY_RUN_ID}`; - const mdbIntegration = setupMongoDBIntegrationTest(); + const mdbIntegration = setupMongoDBIntegrationTest({}, []); const { populateTestData, cleanupTestDatabases } = prepareTestData(mdbIntegration); let commitSHA: string; diff --git a/tests/integration/common/connectionManager.oidc.test.ts b/tests/integration/common/connectionManager.oidc.test.ts new file mode 100644 index 00000000..69aa8674 --- /dev/null +++ b/tests/integration/common/connectionManager.oidc.test.ts @@ -0,0 +1,256 @@ +import { describe, beforeEach, afterAll, it, expect, TestContext } from "vitest"; +import semver from "semver"; +import process from "process"; +import { + describeWithMongoDB, + isCommunityServer, + getServerVersion, + MongoDBIntegrationTestCase, +} from "../tools/mongodb/mongodbHelpers.js"; +import { defaultTestConfig, responseAsText, timeout } from "../helpers.js"; +import { + ConnectionManager, + ConnectionStateConnected, + ConnectionStateConnecting, +} from "../../../src/common/connectionManager.js"; +import { setupDriverConfig, UserConfig } from "../../../src/common/config.js"; +import path from "path"; +import type { OIDCMockProviderConfig } from "@mongodb-js/oidc-mock-provider"; +import { OIDCMockProvider } from "@mongodb-js/oidc-mock-provider"; + +const DEFAULT_TIMEOUT = 10000; + +// OIDC is only supported on Linux servers +describe.skipIf(process.platform !== "linux")("ConnectionManager OIDC Tests", async () => { + function setParameter(param: string): ["--setParameter", string] { + return ["--setParameter", param]; + } + + const defaultOidcConfig = { + issuer: "mockta", + clientId: "mocktaTestServer", + requestScopes: ["mongodbGroups"], + authorizationClaim: "groups", + audience: "resource-server-audience-value", + authNamePrefix: "dev", + } as const; + + const fetchBrowserFixture = `"${path.resolve(__dirname, "../fixtures/curl.mjs")}"`; + + let tokenFetches: number = 0; + let getTokenPayload: OIDCMockProviderConfig["getTokenPayload"]; + const oidcMockProviderConfig: OIDCMockProviderConfig = { + getTokenPayload(metadata) { + return getTokenPayload(metadata); + }, + }; + const oidcMockProvider: OIDCMockProvider = await OIDCMockProvider.create(oidcMockProviderConfig); + + afterAll(async () => { + await oidcMockProvider.close(); + }, DEFAULT_TIMEOUT); + + beforeEach(() => { + tokenFetches = 0; + getTokenPayload = ((metadata) => { + tokenFetches++; + return { + expires_in: 1, + payload: { + // Define the user information stored inside the access tokens + groups: [`${metadata.client_id}-group`], + sub: "testuser", + aud: "resource-server-audience-value", + }, + }; + }) as OIDCMockProviderConfig["getTokenPayload"]; + }); + + /** + * We define a test function for the OIDC tests because we will run the test suite on different MongoDB Versions, to make sure + * we don't break compatibility with older or newer versions. So this is kind of a test factory for a single server version. + **/ + type OidcTestParameters = { + defaultTests: boolean; + additionalConfig: Partial; + additionalServerParams: string[]; + }; + + type OidcIt = ( + name: string, + callback: (context: TestContext, integration: MongoDBIntegrationTestCase) => Promise + ) => void; + type OidcTestCases = (it: OidcIt) => void; + + function describeOidcTest( + mongodbVersion: string, + context: string, + args?: Partial, + addCb?: OidcTestCases + ): void { + const serverOidcConfig = { ...defaultOidcConfig, issuer: oidcMockProvider.issuer }; + const serverArgs = [ + ...setParameter(`oidcIdentityProviders=${JSON.stringify([serverOidcConfig])}`), + ...setParameter("authenticationMechanisms=SCRAM-SHA-256,MONGODB-OIDC"), + ...setParameter("enableTestCommands=true"), + ...(args?.additionalServerParams ?? []), + ]; + + const oidcConfig = { + ...defaultTestConfig, + oidcRedirectURi: "http://localhost:0/", + authenticationMechanism: "MONGODB-OIDC", + maxIdleTimeMS: "1", + minPoolSize: "0", + username: "testuser", + browser: fetchBrowserFixture, + ...args?.additionalConfig, + }; + + describeWithMongoDB( + `${mongodbVersion} Enterprise :: ${context}`, + (integration) => { + function oidcIt(name: string, cb: Parameters[1]): void { + /* eslint-disable vitest/expect-expect */ + it(name, { timeout: DEFAULT_TIMEOUT }, async (context) => { + context.skip( + await isCommunityServer(integration), + "OIDC is not supported in MongoDB Community" + ); + context.skip( + semver.satisfies(await getServerVersion(integration), "< 7", { includePrerelease: true }), + "OIDC is only supported on MongoDB newer than 7.0" + ); + + await cb?.(context, integration); + }); + /* eslint-enable vitest/expect-expect */ + } + + beforeEach(async () => { + const connectionManager = integration.mcpServer().session.connectionManager; + // disconnect on purpose doesn't change the state if it was failed to avoid losing + // information in production. + await connectionManager.disconnect(); + // for testing, force disconnecting AND setting the connection to closed to reset the + // state of the connection manager + connectionManager.changeState("connection-closed", { tag: "disconnected" }); + + await integration.connectMcpClient(); + }, DEFAULT_TIMEOUT); + + addCb?.(oidcIt); + }, + () => oidcConfig, + () => ({ + ...setupDriverConfig({ + config: oidcConfig, + defaults: {}, + }), + }), + { enterprise: true, version: mongodbVersion }, + serverArgs + ); + } + + const baseTestMatrix = [ + { version: "8.0.12", nonce: false }, + { version: "8.0.12", nonce: true }, + ] as const; + + for (const { version, nonce } of baseTestMatrix) { + describeOidcTest(version, `auth-flow;nonce=${nonce}`, { additionalConfig: { oidcNoNonce: !nonce } }, (it) => { + it("can connect with the expected user", async ({ signal }, integration) => { + const state = await ConnectionManager.waitUntil( + "connected", + integration.mcpServer().session.connectionManager, + signal + ); + + type ConnectionStatus = { + authInfo: { + authenticatedUsers: { user: string; db: string }[]; + authenticatedUserRoles: { role: string; db: string }[]; + }; + }; + + const status: ConnectionStatus = (await state.serviceProvider.runCommand("admin", { + connectionStatus: 1, + })) as unknown as ConnectionStatus; + + expect(status.authInfo.authenticatedUsers[0]).toEqual({ user: "dev/testuser", db: "$external" }); + expect(status.authInfo.authenticatedUserRoles[0]).toEqual({ + role: "dev/mocktaTestServer-group", + db: "admin", + }); + }); + + it("can list existing databases", async ({ signal }, integration) => { + const state = await ConnectionManager.waitUntil( + "connected", + integration.mcpServer().session.connectionManager, + signal + ); + + const listDbResult = await state.serviceProvider.listDatabases("admin"); + const databases = listDbResult.databases as unknown[]; + expect(databases.length).toBeGreaterThan(0); + }); + + it("can refresh a token once expired", async ({ signal }, integration) => { + const state = await ConnectionManager.waitUntil( + "connected", + integration.mcpServer().session.connectionManager, + signal + ); + + await timeout(2000); + await state.serviceProvider.listDatabases("admin"); + expect(tokenFetches).toBeGreaterThan(1); + }); + }); + } + + // just infer from all the versions in the base test matrix, so it doesn't need to be maintained separately + const deviceAuthMatrix = new Set(baseTestMatrix.map((base) => base.version)); + + for (const version of deviceAuthMatrix) { + describeOidcTest( + version, + "device-flow", + { additionalConfig: { oidcFlows: "device-auth", browser: false } }, + (it) => { + it("gets requested by the agent to connect", async ({ signal }, integration) => { + const state = await ConnectionManager.waitUntil( + "connecting", + integration.mcpServer().session.connectionManager, + signal, + (state) => !!state.oidcLoginUrl && !!state.oidcUserCode + ); + + const response = responseAsText( + await integration.mcpClient().callTool({ name: "list-databases", arguments: {} }) + ); + + expect(response).toContain("The user needs to finish their OIDC connection by opening"); + expect(response).toContain(state.oidcLoginUrl); + expect(response).toContain(state.oidcUserCode); + + await ConnectionManager.waitUntil( + "connected", + integration.mcpServer().session.connectionManager, + signal + ); + + const connectedResponse = responseAsText( + await integration.mcpClient().callTool({ name: "list-databases", arguments: {} }) + ); + + expect(connectedResponse).toContain("admin"); + expect(connectedResponse).toContain("config"); + expect(connectedResponse).toContain("local"); + }); + } + ); + } +}); diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts index 4536361a..5af4582d 100644 --- a/tests/integration/common/connectionManager.test.ts +++ b/tests/integration/common/connectionManager.test.ts @@ -4,6 +4,7 @@ import { ConnectionStateConnected, ConnectionStringAuthType, } from "../../../src/common/connectionManager.js"; +import type { UserConfig } from "../../../src/common/config.js"; import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; @@ -136,23 +137,52 @@ describeWithMongoDB("Connection Manager", (integration) => { describe("Connection Manager connection type inference", () => { const testCases = [ - { connectionString: "mongodb://localhost:27017", connectionType: "scram" }, - { connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-X509", connectionType: "x.509" }, - { connectionString: "mongodb://localhost:27017?authMechanism=GSSAPI", connectionType: "kerberos" }, + { userConfig: {}, connectionString: "mongodb://localhost:27017", connectionType: "scram" }, { + userConfig: {}, + connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-X509", + connectionType: "x.509", + }, + { + userConfig: {}, + connectionString: "mongodb://localhost:27017?authMechanism=GSSAPI", + connectionType: "kerberos", + }, + { + userConfig: {}, connectionString: "mongodb://localhost:27017?authMechanism=PLAIN&authSource=$external", connectionType: "ldap", }, - { connectionString: "mongodb://localhost:27017?authMechanism=PLAIN", connectionType: "scram" }, - { connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-OIDC", connectionType: "oidc-auth-flow" }, + { userConfig: {}, connectionString: "mongodb://localhost:27017?authMechanism=PLAIN", connectionType: "scram" }, + { + userConfig: { transport: "stdio", browser: "firefox" }, + connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-OIDC", + connectionType: "oidc-auth-flow", + }, + { + userConfig: { transport: "http", httpHost: "127.0.0.1", browser: "ie6" }, + connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-OIDC", + connectionType: "oidc-auth-flow", + }, + { + userConfig: { transport: "http", httpHost: "0.0.0.0", browser: "ie6" }, + connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-OIDC", + connectionType: "oidc-device-flow", + }, + { + userConfig: { transport: "stdio" }, + connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-OIDC", + connectionType: "oidc-device-flow", + }, ] as { + userConfig: Partial; connectionString: string; connectionType: ConnectionStringAuthType; }[]; - for (const { connectionString, connectionType } of testCases) { + for (const { userConfig, connectionString, connectionType } of testCases) { it(`infers ${connectionType} from ${connectionString}`, () => { - const actualConnectionType = ConnectionManager.inferConnectionTypeFromSettings({ + const actualConnectionType = ConnectionManager.inferConnectionTypeFromSettings(userConfig as UserConfig, { connectionString, }); diff --git a/tests/integration/fixtures/curl.mjs b/tests/integration/fixtures/curl.mjs new file mode 100755 index 00000000..c8432c5e --- /dev/null +++ b/tests/integration/fixtures/curl.mjs @@ -0,0 +1,15 @@ +#!/usr/bin/env node +import fetch from "node-fetch"; + +if (process.env.MONGOSH_E2E_TEST_CURL_ALLOW_INVALID_TLS) { + process.env.NODE_TLS_REJECT_UNAUTHORIZED = "0"; +} + +// fetch() an URL and ignore the response body +(async function () { + (await fetch(process.argv[2])).body?.resume(); +})().catch((err) => { + process.nextTick(() => { + throw err; + }); +}); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 738cbdfd..ebac3f42 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -1,12 +1,12 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { InMemoryTransport } from "./inMemoryTransport.js"; import { Server } from "../../src/server.js"; -import { UserConfig } from "../../src/common/config.js"; +import { DriverOptions, UserConfig } from "../../src/common/config.js"; import { McpError, ResourceUpdatedNotificationSchema } from "@modelcontextprotocol/sdk/types.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { Session } from "../../src/common/session.js"; import { Telemetry } from "../../src/telemetry/telemetry.js"; -import { config } from "../../src/common/config.js"; +import { config, driverOptions } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; import { ConnectionManager } from "../../src/common/connectionManager.js"; import { CompositeLogger } from "../../src/common/logger.js"; @@ -31,12 +31,21 @@ export const defaultTestConfig: UserConfig = { loggers: ["stderr"], }; -export function setupIntegrationTest(getUserConfig: () => UserConfig): IntegrationTest { +export const defaultDriverOptions: DriverOptions = { + ...driverOptions, +}; + +export function setupIntegrationTest( + getUserConfig: () => UserConfig, + getDriverOptions: () => DriverOptions +): IntegrationTest { let mcpClient: Client | undefined; let mcpServer: Server | undefined; beforeAll(async () => { const userConfig = getUserConfig(); + const driverOptions = getDriverOptions(); + const clientTransport = new InMemoryTransport(); const serverTransport = new InMemoryTransport(); @@ -58,7 +67,7 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati const logger = new CompositeLogger(); const exportsManager = ExportsManager.init(userConfig, logger); - const connectionManager = new ConnectionManager(); + const connectionManager = new ConnectionManager(userConfig, driverOptions, logger); const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, @@ -291,3 +300,7 @@ export function resourceChangedNotification(client: Client, uri: string): Promis }); }); } + +export function responseAsText(response: Awaited>): string { + return JSON.stringify(response.content, undefined, 2); +} diff --git a/tests/integration/server.test.ts b/tests/integration/server.test.ts index c9409831..ef98075a 100644 --- a/tests/integration/server.test.ts +++ b/tests/integration/server.test.ts @@ -1,4 +1,4 @@ -import { defaultTestConfig, expectDefined, setupIntegrationTest } from "./helpers.js"; +import { defaultDriverOptions, defaultTestConfig, expectDefined, setupIntegrationTest } from "./helpers.js"; import { describeWithMongoDB } from "./tools/mongodb/mongodbHelpers.js"; import { describe, expect, it } from "vitest"; @@ -19,15 +19,19 @@ describe("Server integration test", () => { ...defaultTestConfig, apiClientId: undefined, apiClientSecret: undefined, - }) + }), + () => defaultDriverOptions ); describe("with atlas", () => { - const integration = setupIntegrationTest(() => ({ - ...defaultTestConfig, - apiClientId: "test", - apiClientSecret: "test", - })); + const integration = setupIntegrationTest( + () => ({ + ...defaultTestConfig, + apiClientId: "test", + apiClientSecret: "test", + }), + () => defaultDriverOptions + ); describe("list capabilities", () => { it("should return positive number of tools and have some atlas tools", async () => { @@ -59,12 +63,15 @@ describe("Server integration test", () => { }); describe("with read-only mode", () => { - const integration = setupIntegrationTest(() => ({ - ...defaultTestConfig, - readOnly: true, - apiClientId: "test", - apiClientSecret: "test", - })); + const integration = setupIntegrationTest( + () => ({ + ...defaultTestConfig, + readOnly: true, + apiClientId: "test", + apiClientSecret: "test", + }), + () => defaultDriverOptions + ); it("should only register read and metadata operation tools when read-only mode is enabled", async () => { const tools = await integration.mcpClient().listTools(); diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index 95bc79c2..62d959fa 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -1,7 +1,7 @@ import { createHmac } from "crypto"; import { Telemetry } from "../../src/telemetry/telemetry.js"; import { Session } from "../../src/common/session.js"; -import { config } from "../../src/common/config.js"; +import { config, driverOptions } from "../../src/common/config.js"; import nodeMachineId from "node-machine-id"; import { describe, expect, it } from "vitest"; import { CompositeLogger } from "../../src/common/logger.js"; @@ -18,9 +18,9 @@ describe("Telemetry", () => { const telemetry = Telemetry.create( new Session({ apiBaseUrl: "", - logger: new CompositeLogger(), + logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(), + connectionManager: new ConnectionManager(config, driverOptions, logger), }), config ); diff --git a/tests/integration/tools/atlas/atlasHelpers.ts b/tests/integration/tools/atlas/atlasHelpers.ts index 622a99dc..57cd9811 100644 --- a/tests/integration/tools/atlas/atlasHelpers.ts +++ b/tests/integration/tools/atlas/atlasHelpers.ts @@ -1,18 +1,21 @@ import { ObjectId } from "mongodb"; import { Group } from "../../../../src/common/atlas/openapi.js"; import { ApiClient } from "../../../../src/common/atlas/apiClient.js"; -import { setupIntegrationTest, IntegrationTest, defaultTestConfig } from "../../helpers.js"; +import { setupIntegrationTest, IntegrationTest, defaultTestConfig, defaultDriverOptions } from "../../helpers.js"; import { afterAll, beforeAll, describe, SuiteCollector } from "vitest"; export type IntegrationTestFunction = (integration: IntegrationTest) => void; export function describeWithAtlas(name: string, fn: IntegrationTestFunction): SuiteCollector { const testDefinition = (): void => { - const integration = setupIntegrationTest(() => ({ - ...defaultTestConfig, - apiClientId: process.env.MDB_MCP_API_CLIENT_ID, - apiClientSecret: process.env.MDB_MCP_API_CLIENT_SECRET, - })); + const integration = setupIntegrationTest( + () => ({ + ...defaultTestConfig, + apiClientId: process.env.MDB_MCP_API_CLIENT_ID, + apiClientSecret: process.env.MDB_MCP_API_CLIENT_SECRET, + }), + () => defaultDriverOptions + ); describe(name, () => { fn(integration); diff --git a/tests/integration/tools/mongodb/connect/connect.test.ts b/tests/integration/tools/mongodb/connect/connect.test.ts index 7dd275d3..26d65ba4 100644 --- a/tests/integration/tools/mongodb/connect/connect.test.ts +++ b/tests/integration/tools/mongodb/connect/connect.test.ts @@ -1,5 +1,6 @@ import { describeWithMongoDB } from "../mongodbHelpers.js"; import { + defaultDriverOptions, getResponseContent, getResponseElements, validateThrowsForInvalidArguments, @@ -138,10 +139,13 @@ describeWithMongoDB( ); describe("Connect tool when disabled", () => { - const integration = setupIntegrationTest(() => ({ - ...defaultTestConfig, - disabledTools: ["connect"], - })); + const integration = setupIntegrationTest( + () => ({ + ...defaultTestConfig, + disabledTools: ["connect"], + }), + () => defaultDriverOptions + ); it("is not suggested when querying MongoDB disconnected", async () => { const response = await integration.mcpClient().callTool({ diff --git a/tests/integration/tools/mongodb/mongodbHelpers.ts b/tests/integration/tools/mongodb/mongodbHelpers.ts index bdf5065a..6ca7b9fd 100644 --- a/tests/integration/tools/mongodb/mongodbHelpers.ts +++ b/tests/integration/tools/mongodb/mongodbHelpers.ts @@ -1,10 +1,16 @@ -import { MongoCluster } from "mongodb-runner"; +import { MongoCluster, MongoClusterOptions } from "mongodb-runner"; import path from "path"; import { fileURLToPath } from "url"; import fs from "fs/promises"; import { Document, MongoClient, ObjectId } from "mongodb"; -import { getResponseContent, IntegrationTest, setupIntegrationTest, defaultTestConfig } from "../../helpers.js"; -import { UserConfig } from "../../../../src/common/config.js"; +import { + getResponseContent, + IntegrationTest, + setupIntegrationTest, + defaultTestConfig, + defaultDriverOptions, +} from "../../helpers.js"; +import { UserConfig, DriverOptions } from "../../../../src/common/config.js"; import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it } from "vitest"; const __dirname = path.dirname(fileURLToPath(import.meta.url)); @@ -45,16 +51,27 @@ interface MongoDBIntegrationTest { randomDbName: () => string; } +export type MongoDBIntegrationTestCase = IntegrationTest & + MongoDBIntegrationTest & { connectMcpClient: () => Promise }; + export function describeWithMongoDB( name: string, - fn: (integration: IntegrationTest & MongoDBIntegrationTest & { connectMcpClient: () => Promise }) => void, - getUserConfig: (mdbIntegration: MongoDBIntegrationTest) => UserConfig = () => defaultTestConfig + fn: (integration: MongoDBIntegrationTestCase) => void, + getUserConfig: (mdbIntegration: MongoDBIntegrationTest) => UserConfig = () => defaultTestConfig, + getDriverOptions: (mdbIntegration: MongoDBIntegrationTest) => DriverOptions = () => defaultDriverOptions, + downloadOptions: MongoClusterOptions["downloadOptions"] = { enterprise: false }, + serverArgs: string[] = [] ): void { describe(name, () => { - const mdbIntegration = setupMongoDBIntegrationTest(); - const integration = setupIntegrationTest(() => ({ - ...getUserConfig(mdbIntegration), - })); + const mdbIntegration = setupMongoDBIntegrationTest(downloadOptions, serverArgs); + const integration = setupIntegrationTest( + () => ({ + ...getUserConfig(mdbIntegration), + }), + () => ({ + ...getDriverOptions(mdbIntegration), + }) + ); fn({ ...integration, @@ -72,7 +89,10 @@ export function describeWithMongoDB( }); } -export function setupMongoDBIntegrationTest(): MongoDBIntegrationTest { +export function setupMongoDBIntegrationTest( + downloadOptions: MongoClusterOptions["downloadOptions"], + serverArgs: string[] +): MongoDBIntegrationTest { let mongoCluster: MongoCluster | undefined; let mongoClient: MongoClient | undefined; let randomDbName: string; @@ -101,7 +121,9 @@ export function setupMongoDBIntegrationTest(): MongoDBIntegrationTest { tmpDir: dbsDir, logDir: path.join(tmpDir, "mongodb-runner", "logs"), topology: "standalone", - version: "8.0.10", + version: downloadOptions?.version ?? "8.0.12", + downloadOptions, + args: serverArgs, }); return; @@ -252,3 +274,17 @@ export function getDocsFromUntrustedContent(content: string): unknown[] { const json = lines.slice(startIdx, endIdx + 1).join("\n"); return JSON.parse(json) as unknown[]; } + +export async function isCommunityServer(integration: MongoDBIntegrationTestCase): Promise { + const client = integration.mongoClient(); + const buildInfo = await client.db("_").command({ buildInfo: 1 }); + const modules: string[] = buildInfo.modules as string[]; + + return !modules.includes("enterprise"); +} + +export async function getServerVersion(integration: MongoDBIntegrationTestCase): Promise { + const client = integration.mongoClient(); + const serverStatus = await client.db("admin").admin().serverStatus(); + return serverStatus.version as string; +} diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index add1cac5..9753d01f 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -1,7 +1,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { Session } from "../../../src/common/session.js"; -import { config } from "../../../src/common/config.js"; +import { config, driverOptions } from "../../../src/common/config.js"; import { CompositeLogger } from "../../../src/common/logger.js"; import { ConnectionManager } from "../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../src/common/exportsManager.js"; @@ -18,7 +18,7 @@ describe("Session", () => { apiBaseUrl: "https://api.test.com", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(), + connectionManager: new ConnectionManager(config, driverOptions, logger), }); MockNodeDriverServiceProvider.connect = vi.fn().mockResolvedValue({} as unknown as NodeDriverServiceProvider); diff --git a/tests/unit/resources/common/debug.test.ts b/tests/unit/resources/common/debug.test.ts index 8e798827..4f51e381 100644 --- a/tests/unit/resources/common/debug.test.ts +++ b/tests/unit/resources/common/debug.test.ts @@ -2,7 +2,7 @@ import { beforeEach, describe, expect, it } from "vitest"; import { DebugResource } from "../../../../src/resources/common/debug.js"; import { Session } from "../../../../src/common/session.js"; import { Telemetry } from "../../../../src/telemetry/telemetry.js"; -import { config } from "../../../../src/common/config.js"; +import { config, driverOptions } from "../../../../src/common/config.js"; import { CompositeLogger } from "../../../../src/common/logger.js"; import { ConnectionManager } from "../../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../../src/common/exportsManager.js"; @@ -13,7 +13,7 @@ describe("debug resource", () => { apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(), + connectionManager: new ConnectionManager(config, driverOptions, logger), }); const telemetry = Telemetry.create(session, { ...config, telemetry: "disabled" });