diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index edfb9966..68504e24 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -6,6 +6,8 @@ import { packageInfo } from "./packageInfo.js"; import ConnectionString from "mongodb-connection-string-url"; import { MongoClientOptions } from "mongodb"; import { ErrorCodes, MongoDBError } from "./errors.js"; +import { DeviceId } from "../helpers/deviceId.js"; +import { AppNameComponents } from "../helpers/connectionOptions.js"; import { CompositeLogger, LogId } from "./logger.js"; import { ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; @@ -69,12 +71,15 @@ export interface ConnectionManagerEvents { export class ConnectionManager extends EventEmitter { private state: AnyConnectionState; + private deviceId: DeviceId; + private clientName: string; private bus: EventEmitter; constructor( private userConfig: UserConfig, private driverOptions: DriverOptions, private logger: CompositeLogger, + deviceId: DeviceId, bus?: EventEmitter ) { super(); @@ -84,6 +89,13 @@ export class ConnectionManager extends EventEmitter { this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this)); this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this)); + + this.deviceId = deviceId; + this.clientName = "unknown"; + } + + setClientName(clientName: string): void { + this.clientName = clientName; } async connect(settings: ConnectionSettings): Promise { @@ -98,9 +110,15 @@ export class ConnectionManager extends EventEmitter { try { settings = { ...settings }; - settings.connectionString = setAppNameParamIfMissing({ + const appNameComponents: AppNameComponents = { + appName: `${packageInfo.mcpServerName} ${packageInfo.version}`, + deviceId: this.deviceId.get(), + clientName: this.clientName, + }; + + settings.connectionString = await setAppNameParamIfMissing({ connectionString: settings.connectionString, - defaultAppName: `${packageInfo.mcpServerName} ${packageInfo.version}`, + components: appNameComponents, }); connectionInfo = generateConnectionInfoFromCliArgs({ diff --git a/src/common/logger.ts b/src/common/logger.ts index be738d5b..1bf25fea 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -14,6 +14,7 @@ export const LogId = { serverClosed: mongoLogId(1_000_004), serverCloseFailure: mongoLogId(1_000_005), serverDuplicateLoggers: mongoLogId(1_000_006), + serverMcpClientSet: mongoLogId(1_000_007), atlasCheckCredentials: mongoLogId(1_001_001), atlasDeleteDatabaseUserFailure: mongoLogId(1_001_002), @@ -30,8 +31,8 @@ export const LogId = { telemetryEmitStart: mongoLogId(1_002_003), telemetryEmitSuccess: mongoLogId(1_002_004), telemetryMetadataError: mongoLogId(1_002_005), - telemetryDeviceIdFailure: mongoLogId(1_002_006), - telemetryDeviceIdTimeout: mongoLogId(1_002_007), + deviceIdResolutionError: mongoLogId(1_002_006), + deviceIdTimeout: mongoLogId(1_002_007), toolExecute: mongoLogId(1_003_001), toolExecuteFailure: mongoLogId(1_003_002), diff --git a/src/common/session.ts b/src/common/session.ts index 444a747b..b13c4a7e 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -34,9 +34,10 @@ export class Session extends EventEmitter { readonly exportsManager: ExportsManager; readonly connectionManager: ConnectionManager; readonly apiClient: ApiClient; - agentRunner?: { - name: string; - version: string; + mcpClient?: { + name?: string; + version?: string; + title?: string; }; public logger: CompositeLogger; @@ -69,13 +70,24 @@ export class Session extends EventEmitter { this.connectionManager.on("connection-errored", (error) => this.emit("connection-error", error.errorReason)); } - setAgentRunner(agentRunner: Implementation | undefined): void { - if (agentRunner?.name && agentRunner?.version) { - this.agentRunner = { - name: agentRunner.name, - version: agentRunner.version, - }; + setMcpClient(mcpClient: Implementation | undefined): void { + if (!mcpClient) { + this.connectionManager.setClientName("unknown"); + this.logger.debug({ + id: LogId.serverMcpClientSet, + context: "session", + message: "MCP client info not found", + }); } + + this.mcpClient = { + name: mcpClient?.name || "unknown", + version: mcpClient?.version || "unknown", + title: mcpClient?.title || "unknown", + }; + + // Set the client name on the connection manager for appName generation + this.connectionManager.setClientName(this.mcpClient.name || "unknown"); } async disconnect(): Promise { diff --git a/src/helpers/connectionOptions.ts b/src/helpers/connectionOptions.ts index 10b1ecc8..009996ee 100644 --- a/src/helpers/connectionOptions.ts +++ b/src/helpers/connectionOptions.ts @@ -1,20 +1,59 @@ import { MongoClientOptions } from "mongodb"; import ConnectionString from "mongodb-connection-string-url"; -export function setAppNameParamIfMissing({ +export interface AppNameComponents { + appName: string; + deviceId?: Promise; + clientName?: string; +} + +/** + * Sets the appName parameter with the extended format: appName--deviceId--clientName + * Only sets the appName if it's not already present in the connection string + * @param connectionString - The connection string to modify + * @param components - The components to build the appName from + * @returns The modified connection string + */ +export async function setAppNameParamIfMissing({ connectionString, - defaultAppName, + components, }: { connectionString: string; - defaultAppName?: string; -}): string { + components: AppNameComponents; +}): Promise { const connectionStringUrl = new ConnectionString(connectionString); - const searchParams = connectionStringUrl.typedSearchParams(); - if (!searchParams.has("appName") && defaultAppName !== undefined) { - searchParams.set("appName", defaultAppName); + // Only set appName if it's not already present + if (searchParams.has("appName")) { + return connectionStringUrl.toString(); } + const appName = components.appName || "unknown"; + const deviceId = components.deviceId ? await components.deviceId : "unknown"; + const clientName = components.clientName || "unknown"; + + // Build the extended appName format: appName--deviceId--clientName + const extendedAppName = `${appName}--${deviceId}--${clientName}`; + + searchParams.set("appName", extendedAppName); + return connectionStringUrl.toString(); } + +/** + * Validates the connection string + * @param connectionString - The connection string to validate + * @param looseValidation - Whether to allow loose validation + * @returns void + * @throws Error if the connection string is invalid + */ +export function validateConnectionString(connectionString: string, looseValidation: boolean): void { + try { + new ConnectionString(connectionString, { looseValidation }); + } catch (error) { + throw new Error( + `Invalid connection string with error: ${error instanceof Error ? error.message : String(error)}` + ); + } +} diff --git a/src/helpers/deviceId.ts b/src/helpers/deviceId.ts new file mode 100644 index 00000000..246b0bd1 --- /dev/null +++ b/src/helpers/deviceId.ts @@ -0,0 +1,113 @@ +import { getDeviceId } from "@mongodb-js/device-id"; +import nodeMachineId from "node-machine-id"; +import { LogId, LoggerBase } from "../common/logger.js"; + +export const DEVICE_ID_TIMEOUT = 3000; + +export class DeviceId { + private deviceId: string | undefined = undefined; + private deviceIdPromise: Promise | undefined = undefined; + private abortController: AbortController | undefined = undefined; + private logger: LoggerBase; + private readonly getMachineId: () => Promise; + private timeout: number; + private static instance: DeviceId | undefined = undefined; + + private constructor(logger: LoggerBase, timeout: number = DEVICE_ID_TIMEOUT) { + this.logger = logger; + this.timeout = timeout; + this.getMachineId = (): Promise => nodeMachineId.machineId(true); + } + + public static create(logger: LoggerBase, timeout?: number): DeviceId { + if (this.instance) { + throw new Error("DeviceId instance already exists, use get() to retrieve the device ID"); + } + + const instance = new DeviceId(logger, timeout ?? DEVICE_ID_TIMEOUT); + instance.setup(); + + this.instance = instance; + + return instance; + } + + private setup(): void { + this.deviceIdPromise = this.calculateDeviceId(); + } + + /** + * Closes the device ID calculation promise and abort controller. + */ + public close(): void { + if (this.abortController) { + this.abortController.abort(); + this.abortController = undefined; + } + + this.deviceId = undefined; + this.deviceIdPromise = undefined; + DeviceId.instance = undefined; + } + + /** + * Gets the device ID, waiting for the calculation to complete if necessary. + * @returns Promise that resolves to the device ID string + */ + public get(): Promise { + if (this.deviceId) { + return Promise.resolve(this.deviceId); + } + + if (this.deviceIdPromise) { + return this.deviceIdPromise; + } + + return this.calculateDeviceId(); + } + + /** + * Internal method that performs the actual device ID calculation. + */ + private async calculateDeviceId(): Promise { + if (!this.abortController) { + this.abortController = new AbortController(); + } + + this.deviceIdPromise = getDeviceId({ + getMachineId: this.getMachineId, + onError: (reason, error) => { + this.handleDeviceIdError(reason, String(error)); + }, + timeout: this.timeout, + abortSignal: this.abortController.signal, + }); + + return this.deviceIdPromise; + } + + private handleDeviceIdError(reason: string, error: string): void { + this.deviceIdPromise = Promise.resolve("unknown"); + + switch (reason) { + case "resolutionError": + this.logger.debug({ + id: LogId.deviceIdResolutionError, + context: "deviceId", + message: `Resolution error: ${String(error)}`, + }); + break; + case "timeout": + this.logger.debug({ + id: LogId.deviceIdTimeout, + context: "deviceId", + message: "Device ID retrieval timed out", + noRedaction: true, + }); + break; + case "abort": + // No need to log in the case of 'abort' errors + break; + } + } +} diff --git a/src/index.ts b/src/index.ts index f391a9a7..29f525dc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -50,7 +50,6 @@ async function main(): Promise { assertVersionMode(); const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); - const shutdown = (): void => { transportRunner.logger.info({ id: LogId.serverCloseRequested, diff --git a/src/server.ts b/src/server.ts index bf41b26d..5121a858 100644 --- a/src/server.ts +++ b/src/server.ts @@ -17,6 +17,7 @@ import { } from "@modelcontextprotocol/sdk/types.js"; import assert from "assert"; import { ToolBase } from "./tools/tool.js"; +import { validateConnectionString } from "./helpers/connectionOptions.js"; export interface ServerOptions { session: Session; @@ -97,12 +98,14 @@ export class Server { }); this.mcpServer.server.oninitialized = (): void => { - this.session.setAgentRunner(this.mcpServer.server.getClientVersion()); + this.session.setMcpClient(this.mcpServer.server.getClientVersion()); + // Placed here to start the connection to the config connection string as soon as the server is initialized. + void this.connectToConfigConnectionString(); this.session.logger.info({ id: LogId.serverInitialized, context: "server", - message: `Server started with transport ${transport.constructor.name} and agent runner ${this.session.agentRunner?.name}`, + message: `Server started with transport ${transport.constructor.name} and agent runner ${this.session.mcpClient?.name}`, }); this.emitServerEvent("start", Date.now() - this.startTime); @@ -188,20 +191,20 @@ export class Server { } private async validateConfig(): Promise { + // Validate connection string if (this.userConfig.connectionString) { try { - await this.session.connectToMongoDB({ - connectionString: this.userConfig.connectionString, - }); + validateConnectionString(this.userConfig.connectionString, false); } catch (error) { - console.error( - "Failed to connect to MongoDB instance using the connection string from the config: ", - error + console.error("Connection string validation failed with error: ", error); + throw new Error( + "Connection string validation failed with error: " + + (error instanceof Error ? error.message : String(error)) ); - throw new Error("Failed to connect to MongoDB instance using the connection string from the config"); } } + // Validate API client credentials if (this.userConfig.apiClientId && this.userConfig.apiClientSecret) { try { await this.session.apiClient.validateAccessToken(); @@ -219,4 +222,20 @@ export class Server { } } } + + private async connectToConfigConnectionString(): Promise { + if (this.userConfig.connectionString) { + try { + await this.session.connectToMongoDB({ + connectionString: this.userConfig.connectionString, + }); + } catch (error) { + console.error( + "Failed to connect to MongoDB instance using the connection string from the config: ", + error + ); + throw new Error("Failed to connect to MongoDB instance using the connection string from the config"); + } + } + } } diff --git a/src/telemetry/telemetry.ts b/src/telemetry/telemetry.ts index 1d08f2e1..af75fead 100644 --- a/src/telemetry/telemetry.ts +++ b/src/telemetry/telemetry.ts @@ -5,49 +5,44 @@ import { LogId } from "../common/logger.js"; import { ApiClient } from "../common/atlas/apiClient.js"; import { MACHINE_METADATA } from "./constants.js"; import { EventCache } from "./eventCache.js"; -import nodeMachineId from "node-machine-id"; -import { getDeviceId } from "@mongodb-js/device-id"; import { detectContainerEnv } from "../helpers/container.js"; +import { DeviceId } from "../helpers/deviceId.js"; type EventResult = { success: boolean; error?: Error; }; -export const DEVICE_ID_TIMEOUT = 3000; - export class Telemetry { private isBufferingEvents: boolean = true; /** Resolves when the setup is complete or a timeout occurs */ public setupPromise: Promise<[string, boolean]> | undefined; - private deviceIdAbortController = new AbortController(); private eventCache: EventCache; - private getRawMachineId: () => Promise; + private deviceId: DeviceId; private constructor( private readonly session: Session, private readonly userConfig: UserConfig, private readonly commonProperties: CommonProperties, - { eventCache, getRawMachineId }: { eventCache: EventCache; getRawMachineId: () => Promise } + { eventCache, deviceId }: { eventCache: EventCache; deviceId: DeviceId } ) { this.eventCache = eventCache; - this.getRawMachineId = getRawMachineId; + this.deviceId = deviceId; } static create( session: Session, userConfig: UserConfig, + deviceId: DeviceId, { commonProperties = { ...MACHINE_METADATA }, eventCache = EventCache.getInstance(), - getRawMachineId = (): Promise => nodeMachineId.machineId(true), }: { eventCache?: EventCache; - getRawMachineId?: () => Promise; commonProperties?: CommonProperties; } = {} ): Telemetry { - const instance = new Telemetry(session, userConfig, commonProperties, { eventCache, getRawMachineId }); + const instance = new Telemetry(session, userConfig, commonProperties, { eventCache, deviceId }); void instance.setup(); return instance; @@ -57,46 +52,17 @@ export class Telemetry { if (!this.isTelemetryEnabled()) { return; } - this.setupPromise = Promise.all([ - getDeviceId({ - getMachineId: () => this.getRawMachineId(), - onError: (reason, error) => { - switch (reason) { - case "resolutionError": - this.session.logger.debug({ - id: LogId.telemetryDeviceIdFailure, - context: "telemetry", - message: String(error), - }); - break; - case "timeout": - this.session.logger.debug({ - id: LogId.telemetryDeviceIdTimeout, - context: "telemetry", - message: "Device ID retrieval timed out", - noRedaction: true, - }); - break; - case "abort": - // No need to log in the case of aborts - break; - } - }, - abortSignal: this.deviceIdAbortController.signal, - }), - detectContainerEnv(), - ]); - - const [deviceId, containerEnv] = await this.setupPromise; - - this.commonProperties.device_id = deviceId; + + this.setupPromise = Promise.all([this.deviceId.get(), detectContainerEnv()]); + const [deviceIdValue, containerEnv] = await this.setupPromise; + + this.commonProperties.device_id = deviceIdValue; this.commonProperties.is_container_env = containerEnv; this.isBufferingEvents = false; } public async close(): Promise { - this.deviceIdAbortController.abort(); this.isBufferingEvents = false; await this.emitEvents(this.eventCache.getEvents()); } @@ -136,8 +102,8 @@ export class Telemetry { return { ...this.commonProperties, transport: this.userConfig.transport, - mcp_client_version: this.session.agentRunner?.version, - mcp_client_name: this.session.agentRunner?.name, + mcp_client_version: this.session.mcpClient?.version, + mcp_client_name: this.session.mcpClient?.name, session_id: this.session.sessionId, config_atlas_auth: this.session.apiClient.hasCredentials() ? "true" : "false", config_connection_string: this.userConfig.connectionString ? "true" : "false", diff --git a/src/transports/base.ts b/src/transports/base.ts index b2ca3e1a..50213334 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -7,9 +7,11 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, LoggerBase, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; import { ConnectionManager } from "../common/connectionManager.js"; +import { DeviceId } from "../helpers/deviceId.js"; export abstract class TransportRunnerBase { public logger: LoggerBase; + public deviceId: DeviceId; protected constructor(protected readonly userConfig: UserConfig) { const loggers: LoggerBase[] = []; @@ -28,6 +30,7 @@ export abstract class TransportRunnerBase { } this.logger = new CompositeLogger(...loggers); + this.deviceId = DeviceId.create(this.logger); } protected setupServer(userConfig: UserConfig): Server { @@ -43,7 +46,7 @@ export abstract class TransportRunnerBase { const logger = new CompositeLogger(...loggers); const exportsManager = ExportsManager.init(userConfig, logger); - const connectionManager = new ConnectionManager(userConfig, driverOptions, logger); + const connectionManager = new ConnectionManager(userConfig, driverOptions, logger, this.deviceId); const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, @@ -54,7 +57,7 @@ export abstract class TransportRunnerBase { connectionManager, }); - const telemetry = Telemetry.create(session, userConfig); + const telemetry = Telemetry.create(session, userConfig, this.deviceId); return new Server({ mcpServer, @@ -66,5 +69,13 @@ export abstract class TransportRunnerBase { abstract start(): Promise; - abstract close(): Promise; + abstract closeTransport(): Promise; + + async close(): Promise { + try { + await this.closeTransport(); + } finally { + this.deviceId.close(); + } + } } diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 81141b5f..f74022d2 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -74,7 +74,7 @@ export class StdioRunner extends TransportRunnerBase { } } - async close(): Promise { + async closeTransport(): Promise { await this.server?.close(); } } diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index e6e93ba7..c78638e1 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -147,7 +147,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { }); } - async close(): Promise { + async closeTransport(): Promise { await Promise.all([ this.sessionStore.closeAllSessions(), new Promise((resolve, reject) => { diff --git a/tests/accuracy/export.test.ts b/tests/accuracy/export.test.ts index 235a1fe4..46accd30 100644 --- a/tests/accuracy/export.test.ts +++ b/tests/accuracy/export.test.ts @@ -8,7 +8,6 @@ describeAccuracyTests([ { toolName: "export", parameters: { - exportTitle: Matcher.string(), database: "mflix", collection: "movies", exportTarget: [ @@ -27,7 +26,6 @@ describeAccuracyTests([ { toolName: "export", parameters: { - exportTitle: Matcher.string(), database: "mflix", collection: "movies", exportTarget: [ @@ -50,7 +48,6 @@ describeAccuracyTests([ { toolName: "export", parameters: { - exportTitle: Matcher.string(), database: "mflix", collection: "movies", exportTarget: [ @@ -78,7 +75,6 @@ describeAccuracyTests([ { toolName: "export", parameters: { - exportTitle: Matcher.string(), database: "mflix", collection: "movies", exportTarget: [ diff --git a/tests/integration/common/deviceId.test.ts b/tests/integration/common/deviceId.test.ts new file mode 100644 index 00000000..296d3440 --- /dev/null +++ b/tests/integration/common/deviceId.test.ts @@ -0,0 +1,112 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { DeviceId } from "../../../src/helpers/deviceId.js"; +import { CompositeLogger } from "../../../src/common/logger.js"; +import nodeMachineId from "node-machine-id"; + +describe("Device ID", () => { + let testLogger: CompositeLogger; + let deviceId: DeviceId; + + beforeEach(() => { + testLogger = new CompositeLogger(); + testLogger.debug = vi.fn(); + }); + + afterEach(() => { + deviceId?.close(); + }); + + describe("when resolving device ID", () => { + it("should successfully resolve device ID in real environment", async () => { + deviceId = DeviceId.create(testLogger); + const result = await deviceId.get(); + + expect(result).not.toBe("unknown"); + expect(result).toBeTruthy(); + expect(typeof result).toBe("string"); + expect(result.length).toBeGreaterThan(0); + }); + + it("should cache device ID after first resolution", async () => { + // spy on machineId + const machineIdSpy = vi.spyOn(nodeMachineId, "machineId"); + deviceId = DeviceId.create(testLogger); + + // First call + const result1 = await deviceId.get(); + expect(result1).not.toBe("unknown"); + + // Second call should be cached + const result2 = await deviceId.get(); + expect(result2).toBe(result1); + // check that machineId was called only once + expect(machineIdSpy).toHaveBeenCalledOnce(); + }); + + it("should handle concurrent device ID requests correctly", async () => { + deviceId = DeviceId.create(testLogger); + + const promises = Array.from({ length: 5 }, () => deviceId.get()); + + // All should resolve to the same value + const results = await Promise.all(promises); + const firstResult = results[0]; + expect(firstResult).not.toBe("unknown"); + + // All results should be identical + results.forEach((result) => { + expect(result).toBe(firstResult); + }); + }); + }); + + describe("when resolving device ID fails", () => { + const originalMachineId: typeof nodeMachineId.machineId = nodeMachineId.machineId; + + beforeEach(() => { + // mock the machineId function to throw an abort error + nodeMachineId.machineId = vi.fn(); + }); + + afterEach(() => { + // Restore original implementation + nodeMachineId.machineId = originalMachineId; + }); + + it("should handle resolution errors gracefully", async () => { + // mock the machineId function to throw a resolution error + nodeMachineId.machineId = vi.fn().mockImplementation(() => { + return new Promise((resolve, reject) => { + reject(new Error("Machine ID failed")); + }); + }); + deviceId = DeviceId.create(testLogger); + const handleDeviceIdErrorSpy = vi.spyOn(deviceId, "handleDeviceIdError" as keyof DeviceId); + + const result = await deviceId.get(); + + expect(result).toBe("unknown"); + expect(handleDeviceIdErrorSpy).toHaveBeenCalledWith( + "resolutionError", + expect.stringContaining("Machine ID failed") + ); + }); + + it("should handle abort signal scenarios gracefully", async () => { + // slow down the machineId function to give time to send abort signal + nodeMachineId.machineId = vi.fn().mockImplementation(() => { + return new Promise((resolve) => { + setTimeout(() => resolve("delayed-id"), 1000); + }); + }); + + deviceId = DeviceId.create(testLogger, 100); // Short timeout + const handleDeviceIdErrorSpy = vi.spyOn(deviceId, "handleDeviceIdError" as keyof DeviceId); + + const result = await deviceId.get(); + + expect(result).toBe("unknown"); + expect(handleDeviceIdErrorSpy).toHaveBeenCalledWith("timeout", expect.any(String)); + }); + }); +}); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 6dc86be2..b2ffaeb9 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -1,16 +1,17 @@ +import { CompositeLogger } from "../../src/common/logger.js"; +import { ExportsManager } from "../../src/common/exportsManager.js"; +import { Session } from "../../src/common/session.js"; +import { Server } from "../../src/server.js"; +import { Telemetry } from "../../src/telemetry/telemetry.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { InMemoryTransport } from "./inMemoryTransport.js"; -import { Server } from "../../src/server.js"; -import { DriverOptions, UserConfig } from "../../src/common/config.js"; +import { UserConfig, DriverOptions } from "../../src/common/config.js"; import { McpError, ResourceUpdatedNotificationSchema } from "@modelcontextprotocol/sdk/types.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { Session } from "../../src/common/session.js"; -import { Telemetry } from "../../src/telemetry/telemetry.js"; import { config, driverOptions } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; import { ConnectionManager, ConnectionState } from "../../src/common/connectionManager.js"; -import { CompositeLogger } from "../../src/common/logger.js"; -import { ExportsManager } from "../../src/common/exportsManager.js"; +import { DeviceId } from "../../src/helpers/deviceId.js"; interface ParameterInfo { name: string; @@ -41,6 +42,7 @@ export function setupIntegrationTest( ): IntegrationTest { let mcpClient: Client | undefined; let mcpServer: Server | undefined; + let deviceId: DeviceId | undefined; beforeAll(async () => { const userConfig = getUserConfig(); @@ -48,6 +50,7 @@ export function setupIntegrationTest( const clientTransport = new InMemoryTransport(); const serverTransport = new InMemoryTransport(); + const logger = new CompositeLogger(); await serverTransport.start(); await clientTransport.start(); @@ -65,9 +68,10 @@ export function setupIntegrationTest( } ); - const logger = new CompositeLogger(); const exportsManager = ExportsManager.init(userConfig, logger); - const connectionManager = new ConnectionManager(userConfig, driverOptions, logger); + + deviceId = DeviceId.create(logger); + const connectionManager = new ConnectionManager(userConfig, driverOptions, logger, deviceId); const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, @@ -86,7 +90,7 @@ export function setupIntegrationTest( userConfig.telemetry = "disabled"; - const telemetry = Telemetry.create(session, userConfig); + const telemetry = Telemetry.create(session, userConfig, deviceId); mcpServer = new Server({ session, @@ -114,6 +118,9 @@ export function setupIntegrationTest( await mcpServer?.close(); mcpServer = undefined; + + deviceId?.close(); + deviceId = undefined; }); const getMcpClient = (): Client => { diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index 62d959fa..cc51ed8b 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -1,28 +1,28 @@ -import { createHmac } from "crypto"; import { Telemetry } from "../../src/telemetry/telemetry.js"; import { Session } from "../../src/common/session.js"; import { config, driverOptions } from "../../src/common/config.js"; -import nodeMachineId from "node-machine-id"; +import { DeviceId } from "../../src/helpers/deviceId.js"; import { describe, expect, it } from "vitest"; import { CompositeLogger } from "../../src/common/logger.js"; import { ConnectionManager } from "../../src/common/connectionManager.js"; import { ExportsManager } from "../../src/common/exportsManager.js"; describe("Telemetry", () => { - it("should resolve the actual machine ID", async () => { - const actualId: string = await nodeMachineId.machineId(true); + it("should resolve the actual device ID", async () => { + const logger = new CompositeLogger(); - const actualHashedId = createHmac("sha256", actualId.toUpperCase()).update("atlascli").digest("hex"); + const deviceId = DeviceId.create(logger); + const actualDeviceId = await deviceId.get(); - const logger = new CompositeLogger(); const telemetry = Telemetry.create( new Session({ apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger), + connectionManager: new ConnectionManager(config, driverOptions, logger, deviceId), }), - config + config, + deviceId ); expect(telemetry.getCommonProperties().device_id).toBe(undefined); @@ -30,7 +30,7 @@ describe("Telemetry", () => { await telemetry.setupPromise; - expect(telemetry.getCommonProperties().device_id).toBe(actualHashedId); + expect(telemetry.getCommonProperties().device_id).toBe(actualDeviceId); expect(telemetry["isBufferingEvents"]).toBe(false); }); }); diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index 9753d01f..2720c941 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -1,27 +1,36 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeEach, describe, expect, it, Mocked, vi } from "vitest"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { Session } from "../../../src/common/session.js"; import { config, driverOptions } from "../../../src/common/config.js"; import { CompositeLogger } from "../../../src/common/logger.js"; import { ConnectionManager } from "../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../src/common/exportsManager.js"; +import { DeviceId } from "../../../src/helpers/deviceId.js"; vi.mock("@mongosh/service-provider-node-driver"); + const MockNodeDriverServiceProvider = vi.mocked(NodeDriverServiceProvider); +const MockDeviceId = vi.mocked(DeviceId.create(new CompositeLogger())); describe("Session", () => { let session: Session; + let mockDeviceId: Mocked; + beforeEach(() => { const logger = new CompositeLogger(); + + mockDeviceId = MockDeviceId; + session = new Session({ apiClientId: "test-client-id", apiBaseUrl: "https://api.test.com", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger), + connectionManager: new ConnectionManager(config, driverOptions, logger, mockDeviceId), }); MockNodeDriverServiceProvider.connect = vi.fn().mockResolvedValue({} as unknown as NodeDriverServiceProvider); + MockDeviceId.get = vi.fn().mockResolvedValue("test-device-id"); }); describe("connectToMongoDB", () => { @@ -59,7 +68,9 @@ describe("Session", () => { expect(connectMock).toHaveBeenCalledOnce(); const connectionString = connectMock.mock.calls[0]?.[0]; if (testCase.expectAppName) { - expect(connectionString).toContain("appName=MongoDB+MCP+Server"); + // Check for the extended appName format: appName--deviceId--clientName + expect(connectionString).toContain("appName=MongoDB+MCP+Server+"); + expect(connectionString).toContain("--test-device-id--"); } else { expect(connectionString).not.toContain("appName=MongoDB+MCP+Server"); } @@ -77,5 +88,31 @@ describe("Session", () => { expect(connectionConfig?.proxy).toEqual({ useEnvironmentVariableProxies: true }); expect(connectionConfig?.applyProxyToOIDC).toEqual(true); }); + + it("should include client name when agent runner is set", async () => { + session.setMcpClient({ name: "test-client", version: "1.0.0" }); + + await session.connectToMongoDB({ connectionString: "mongodb://localhost:27017" }); + expect(session.serviceProvider).toBeDefined(); + + const connectMock = MockNodeDriverServiceProvider.connect; + expect(connectMock).toHaveBeenCalledOnce(); + const connectionString = connectMock.mock.calls[0]?.[0]; + + // Should include the client name in the appName + expect(connectionString).toContain("--test-device-id--test-client"); + }); + + it("should use 'unknown' for client name when agent runner is not set", async () => { + await session.connectToMongoDB({ connectionString: "mongodb://localhost:27017" }); + expect(session.serviceProvider).toBeDefined(); + + const connectMock = MockNodeDriverServiceProvider.connect; + expect(connectMock).toHaveBeenCalledOnce(); + const connectionString = connectMock.mock.calls[0]?.[0]; + + // Should use 'unknown' for client name when agent runner is not set + expect(connectionString).toContain("--test-device-id--unknown"); + }); }); }); diff --git a/tests/unit/helpers/connectionOptions.test.ts b/tests/unit/helpers/connectionOptions.test.ts new file mode 100644 index 00000000..5ee3d29f --- /dev/null +++ b/tests/unit/helpers/connectionOptions.test.ts @@ -0,0 +1,103 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { setAppNameParamIfMissing } from "../../../src/helpers/connectionOptions.js"; +import { DeviceId } from "../../../src/helpers/deviceId.js"; +import { CompositeLogger } from "../../../src/common/logger.js"; + +const MockDeviceId = vi.mocked(DeviceId.create(new CompositeLogger())); + +describe("Connection Options", () => { + let testLogger: CompositeLogger; + + beforeEach(() => { + testLogger = new CompositeLogger(); + testLogger.debug = vi.fn(); + MockDeviceId.get = vi.fn().mockResolvedValue("test-device-id"); + }); + + describe("setAppNameParamIfMissing", () => { + it("should set extended appName when no appName is present", async () => { + const connectionString = "mongodb://localhost:27017"; + const result = await setAppNameParamIfMissing({ + connectionString, + components: { + appName: "TestApp", + clientName: "TestClient", + deviceId: MockDeviceId.get(), + }, + }); + + expect(result).toContain("appName=TestApp--test-device-id--TestClient"); + }); + + it("should not modify connection string when appName is already present", async () => { + const connectionString = "mongodb://localhost:27017?appName=ExistingApp"; + const result = await setAppNameParamIfMissing({ + connectionString, + components: { + appName: "TestApp", + clientName: "TestClient", + }, + }); + + // The ConnectionString library normalizes URLs, so we need to check the content rather than exact equality + expect(result).toContain("appName=ExistingApp"); + expect(result).not.toContain("TestApp--test-device-id--TestClient"); + }); + + it("should use provided deviceId when available", async () => { + const connectionString = "mongodb://localhost:27017"; + const result = await setAppNameParamIfMissing({ + connectionString, + components: { + appName: "TestApp", + deviceId: Promise.resolve("custom-device-id"), + clientName: "TestClient", + }, + }); + + expect(result).toContain("appName=TestApp--custom-device-id--TestClient"); + }); + + it("should use 'unknown' for clientName when not provided", async () => { + const connectionString = "mongodb://localhost:27017"; + const result = await setAppNameParamIfMissing({ + connectionString, + components: { + appName: "TestApp", + deviceId: MockDeviceId.get(), + }, + }); + + expect(result).toContain("appName=TestApp--test-device-id--unknown"); + }); + + it("should use deviceId as unknown when deviceId is not provided", async () => { + const connectionString = "mongodb://localhost:27017"; + const result = await setAppNameParamIfMissing({ + connectionString, + components: { + appName: "TestApp", + clientName: "TestClient", + }, + }); + + expect(result).toContain("appName=TestApp--unknown--TestClient"); + }); + + it("should preserve other query parameters", async () => { + const connectionString = "mongodb://localhost:27017?retryWrites=true&w=majority"; + const result = await setAppNameParamIfMissing({ + connectionString, + components: { + appName: "TestApp", + clientName: "TestClient", + deviceId: MockDeviceId.get(), + }, + }); + + expect(result).toContain("retryWrites=true"); + expect(result).toContain("w=majority"); + expect(result).toContain("appName=TestApp--test-device-id--TestClient"); + }); + }); +}); diff --git a/tests/unit/helpers/deviceId.test.ts b/tests/unit/helpers/deviceId.test.ts new file mode 100644 index 00000000..68fd54e0 --- /dev/null +++ b/tests/unit/helpers/deviceId.test.ts @@ -0,0 +1,132 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { DeviceId } from "../../../src/helpers/deviceId.js"; +import { getDeviceId } from "@mongodb-js/device-id"; +import { CompositeLogger } from "../../../src/common/logger.js"; + +// Mock the dependencies +vi.mock("@mongodb-js/device-id"); +vi.mock("node-machine-id"); +const MockGetDeviceId = vi.mocked(getDeviceId); + +describe("deviceId", () => { + let testLogger: CompositeLogger; + let deviceId: DeviceId; + + beforeEach(() => { + vi.clearAllMocks(); + testLogger = new CompositeLogger(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + deviceId.close(); + }); + + it("should fail to create separate instances", () => { + deviceId = DeviceId.create(testLogger); + + // try to create a new device id and see it raises an error + expect(() => DeviceId.create(testLogger)).toThrow("DeviceId instance already exists"); + }); + + it("should successfully retrieve device ID", async () => { + const mockDeviceId = "test-device-id-123"; + MockGetDeviceId.mockResolvedValue(mockDeviceId); + + deviceId = DeviceId.create(testLogger); + const result = await deviceId.get(); + + expect(result).toBe(mockDeviceId); + }); + + it("should cache device ID after first retrieval", async () => { + const mockDeviceId = "test-device-id-123"; + MockGetDeviceId.mockResolvedValue(mockDeviceId); + + deviceId = DeviceId.create(testLogger); + + // First call should trigger calculation + const result1 = await deviceId.get(); + expect(result1).toBe(mockDeviceId); + expect(MockGetDeviceId).toHaveBeenCalledTimes(1); + + // Second call should use cached value + const result2 = await deviceId.get(); + expect(result2).toBe(mockDeviceId); + expect(MockGetDeviceId).toHaveBeenCalledTimes(1); // Still only called once + }); + + it("should allow aborting calculation", async () => { + MockGetDeviceId.mockImplementation((options) => { + // Simulate a long-running operation that can be aborted + return new Promise((resolve, reject) => { + const timeout = setTimeout(() => resolve("device-id"), 1000); + options.abortSignal?.addEventListener("abort", () => { + clearTimeout(timeout); + const abortError = new Error("Aborted"); + abortError.name = "AbortError"; + reject(abortError); + }); + }); + }); + + const deviceId = DeviceId.create(testLogger); + + // Start calculation + const promise = deviceId.get(); + + // Abort the calculation + deviceId.close(); + + // Should reject with AbortError + await expect(promise).rejects.toThrow("Aborted"); + }); + + it("should use custom timeout", async () => { + const mockDeviceId = "test-device-id-123"; + MockGetDeviceId.mockResolvedValue(mockDeviceId); + + const deviceId = DeviceId.create(testLogger, 5000); + const result = await deviceId.get(); + + expect(result).toBe(mockDeviceId); + expect(MockGetDeviceId).toHaveBeenCalledWith( + expect.objectContaining({ + timeout: 5000, + }) + ); + }); + + it("should use default timeout when not specified", async () => { + const mockDeviceId = "test-device-id-123"; + MockGetDeviceId.mockResolvedValue(mockDeviceId); + + deviceId = DeviceId.create(testLogger); + const result = await deviceId.get(); + + expect(result).toBe(mockDeviceId); + expect(MockGetDeviceId).toHaveBeenCalledWith( + expect.objectContaining({ + timeout: 3000, // DEVICE_ID_TIMEOUT + }) + ); + }); + + it("should handle multiple close calls gracefully", () => { + deviceId = DeviceId.create(testLogger); + + // First close should work + expect(() => deviceId.close()).not.toThrow(); + + // Second close should also work without error + expect(() => deviceId.close()).not.toThrow(); + }); + + it("should not throw error when get is called after close", async () => { + deviceId = DeviceId.create(testLogger); + deviceId.close(); + + // undefined should be returned + expect(await deviceId.get()).toBeUndefined(); + }); +}); diff --git a/tests/unit/resources/common/debug.test.ts b/tests/unit/resources/common/debug.test.ts index 4f51e381..0292a726 100644 --- a/tests/unit/resources/common/debug.test.ts +++ b/tests/unit/resources/common/debug.test.ts @@ -6,16 +6,18 @@ import { config, driverOptions } from "../../../../src/common/config.js"; import { CompositeLogger } from "../../../../src/common/logger.js"; import { ConnectionManager } from "../../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../../src/common/exportsManager.js"; +import { DeviceId } from "../../../../src/helpers/deviceId.js"; describe("debug resource", () => { const logger = new CompositeLogger(); + const deviceId = DeviceId.create(logger); const session = new Session({ apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger), + connectionManager: new ConnectionManager(config, driverOptions, logger, deviceId), }); - const telemetry = Telemetry.create(session, { ...config, telemetry: "disabled" }); + const telemetry = Telemetry.create(session, { ...config, telemetry: "disabled" }, deviceId); let debugResource: DebugResource = new DebugResource(session, config, telemetry); diff --git a/tests/unit/telemetry.test.ts b/tests/unit/telemetry.test.ts index 6bc3ec45..e1a159d0 100644 --- a/tests/unit/telemetry.test.ts +++ b/tests/unit/telemetry.test.ts @@ -1,13 +1,13 @@ import { ApiClient } from "../../src/common/atlas/apiClient.js"; import { Session } from "../../src/common/session.js"; -import { DEVICE_ID_TIMEOUT, Telemetry } from "../../src/telemetry/telemetry.js"; +import { Telemetry } from "../../src/telemetry/telemetry.js"; import { BaseEvent, TelemetryResult } from "../../src/telemetry/types.js"; import { EventCache } from "../../src/telemetry/eventCache.js"; import { config } from "../../src/common/config.js"; import { afterEach, beforeEach, describe, it, vi, expect } from "vitest"; -import { LogId, NullLogger } from "../../src/common/logger.js"; -import { createHmac } from "crypto"; +import { NullLogger } from "../../src/common/logger.js"; import type { MockedFunction } from "vitest"; +import { DeviceId } from "../../src/helpers/deviceId.js"; // Mock the ApiClient to avoid real API calls vi.mock("../../src/common/atlas/apiClient.js"); @@ -18,9 +18,6 @@ vi.mock("../../src/telemetry/eventCache.js"); const MockEventCache = vi.mocked(EventCache); describe("Telemetry", () => { - const machineId = "test-machine-id"; - const hashedMachineId = createHmac("sha256", machineId.toUpperCase()).update("atlascli").digest("hex"); - let mockApiClient: { sendEvents: MockedFunction<(events: BaseEvent[]) => Promise>; hasCredentials: MockedFunction<() => boolean>; @@ -118,19 +115,23 @@ describe("Telemetry", () => { mockEventCache.appendEvents = vi.fn().mockResolvedValue(undefined); MockEventCache.getInstance = vi.fn().mockReturnValue(mockEventCache as unknown as EventCache); + const mockDeviceId = { + get: vi.fn().mockResolvedValue("test-device-id"), + } as unknown as DeviceId; + // Create a simplified session with our mocked API client session = { apiClient: mockApiClient as unknown as ApiClient, sessionId: "test-session-id", agentRunner: { name: "test-agent", version: "1.0.0" } as const, + mcpClient: { name: "test-agent", version: "1.0.0" }, close: vi.fn().mockResolvedValue(undefined), setAgentRunner: vi.fn().mockResolvedValue(undefined), logger: new NullLogger(), } as unknown as Session; - telemetry = Telemetry.create(session, config, { + telemetry = Telemetry.create(session, config, mockDeviceId, { eventCache: mockEventCache as unknown as EventCache, - getRawMachineId: () => Promise.resolve(machineId), }); config.telemetry = "enabled"; @@ -205,27 +206,27 @@ describe("Telemetry", () => { session_id: "test-session-id", config_atlas_auth: "true", config_connection_string: expect.any(String) as unknown as string, - device_id: hashedMachineId, + device_id: "test-device-id", }; expect(commonProps).toMatchObject(expectedProps); }); - describe("machine ID resolution", () => { + describe("device ID resolution", () => { beforeEach(() => { vi.clearAllMocks(); - vi.useFakeTimers(); }); afterEach(() => { vi.clearAllMocks(); - vi.useRealTimers(); }); - it("should successfully resolve the machine ID", async () => { - telemetry = Telemetry.create(session, config, { - getRawMachineId: () => Promise.resolve(machineId), - }); + it("should successfully resolve the device ID", async () => { + const mockDeviceId = { + get: vi.fn().mockResolvedValue("test-device-id"), + } as unknown as DeviceId; + + telemetry = Telemetry.create(session, config, mockDeviceId); expect(telemetry["isBufferingEvents"]).toBe(true); expect(telemetry.getCommonProperties().device_id).toBe(undefined); @@ -233,15 +234,15 @@ describe("Telemetry", () => { await telemetry.setupPromise; expect(telemetry["isBufferingEvents"]).toBe(false); - expect(telemetry.getCommonProperties().device_id).toBe(hashedMachineId); + expect(telemetry.getCommonProperties().device_id).toBe("test-device-id"); }); - it("should handle machine ID resolution failure", async () => { - const loggerSpy = vi.spyOn(session.logger, "debug"); + it("should handle device ID resolution failure gracefully", async () => { + const mockDeviceId = { + get: vi.fn().mockResolvedValue("unknown"), + } as unknown as DeviceId; - telemetry = Telemetry.create(session, config, { - getRawMachineId: () => Promise.reject(new Error("Failed to get device ID")), - }); + telemetry = Telemetry.create(session, config, mockDeviceId); expect(telemetry["isBufferingEvents"]).toBe(true); expect(telemetry.getCommonProperties().device_id).toBe(undefined); @@ -249,41 +250,25 @@ describe("Telemetry", () => { await telemetry.setupPromise; expect(telemetry["isBufferingEvents"]).toBe(false); + // Should use "unknown" as fallback when device ID resolution fails expect(telemetry.getCommonProperties().device_id).toBe("unknown"); - - expect(loggerSpy).toHaveBeenCalledWith({ - id: LogId.telemetryDeviceIdFailure, - context: "telemetry", - message: "Error: Failed to get device ID", - }); }); - it("should timeout if machine ID resolution takes too long", async () => { - const loggerSpy = vi.spyOn(session.logger, "debug"); + it("should handle device ID timeout gracefully", async () => { + const mockDeviceId = { + get: vi.fn().mockResolvedValue("unknown"), + } as unknown as DeviceId; - telemetry = Telemetry.create(session, config, { getRawMachineId: () => new Promise(() => {}) }); + telemetry = Telemetry.create(session, config, mockDeviceId); expect(telemetry["isBufferingEvents"]).toBe(true); expect(telemetry.getCommonProperties().device_id).toBe(undefined); - vi.advanceTimersByTime(DEVICE_ID_TIMEOUT / 2); - - // Make sure the timeout doesn't happen prematurely. - expect(telemetry["isBufferingEvents"]).toBe(true); - expect(telemetry.getCommonProperties().device_id).toBe(undefined); - - vi.advanceTimersByTime(DEVICE_ID_TIMEOUT); - await telemetry.setupPromise; - expect(telemetry.getCommonProperties().device_id).toBe("unknown"); expect(telemetry["isBufferingEvents"]).toBe(false); - expect(loggerSpy).toHaveBeenCalledWith({ - id: LogId.telemetryDeviceIdTimeout, - context: "telemetry", - message: "Device ID retrieval timed out", - noRedaction: true, - }); + // Should use "unknown" as fallback when device ID times out + expect(telemetry.getCommonProperties().device_id).toBe("unknown"); }); }); });