Skip to content

Autoconfigure MCP Client with and async HTTP request customizer #3994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.AsyncHttpRequestCustomizer;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.SyncHttpRequestCustomizer;
import io.modelcontextprotocol.spec.McpSchema;

import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport;
Expand All @@ -36,6 +38,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.core.log.LogAccessor;

/**
* Auto-configuration for Server-Sent Events (SSE) HTTP client transport in the Model
Expand Down Expand Up @@ -68,6 +71,8 @@
matchIfMissing = true)
public class SseHttpClientTransportAutoConfiguration {

private static final LogAccessor logger = new LogAccessor(SseHttpClientTransportAutoConfiguration.class);

/**
* Creates a list of HTTP client-based SSE transports for MCP communication.
*
Expand All @@ -77,15 +82,22 @@ public class SseHttpClientTransportAutoConfiguration {
* <li>A new HttpClient instance
* <li>Server URL from properties
* <li>ObjectMapper for JSON processing
* <li>A sync or async HTTP request customizer. Sync takes precedence.
* </ul>
* @param sseProperties the SSE client properties containing server configurations
* @param objectMapperProvider the provider for ObjectMapper or a new instance if not
* available
* @param syncHttpRequestCustomizer provider for {@link SyncHttpRequestCustomizer} if
* available
* @param asyncHttpRequestCustomizer provider fo {@link AsyncHttpRequestCustomizer} if
* available
* @return list of named MCP transports
*/
@Bean
public List<NamedClientMcpTransport> sseHttpClientTransports(McpSseClientProperties sseProperties,
ObjectProvider<ObjectMapper> objectMapperProvider) {
ObjectProvider<ObjectMapper> objectMapperProvider,
ObjectProvider<SyncHttpRequestCustomizer> syncHttpRequestCustomizer,
ObjectProvider<AsyncHttpRequestCustomizer> asyncHttpRequestCustomizer) {

ObjectMapper objectMapper = objectMapperProvider.getIfAvailable(ObjectMapper::new);

Expand All @@ -96,11 +108,21 @@ public List<NamedClientMcpTransport> sseHttpClientTransports(McpSseClientPropert
String baseUrl = serverParameters.getValue().url();
String sseEndpoint = serverParameters.getValue().sseEndpoint() != null
? serverParameters.getValue().sseEndpoint() : "/sse";
var transport = HttpClientSseClientTransport.builder(baseUrl)
HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(baseUrl)
.sseEndpoint(sseEndpoint)
.clientBuilder(HttpClient.newBuilder())
.objectMapper(objectMapper)
.build();
.objectMapper(objectMapper);

asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer);
syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer);
if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) {
logger.warn("Found beans of type %s and %s. Using %s.".formatted(
AsyncHttpRequestCustomizer.class.getSimpleName(),
SyncHttpRequestCustomizer.class.getSimpleName(),
SyncHttpRequestCustomizer.class.getSimpleName()));
}

HttpClientSseClientTransport transport = transportBuilder.build();
sseTransports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.core.log.LogAccessor;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.AsyncHttpRequestCustomizer;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import io.modelcontextprotocol.client.transport.SyncHttpRequestCustomizer;
import io.modelcontextprotocol.spec.McpSchema;

/**
Expand All @@ -59,6 +61,7 @@
* connections
* <li>Configures ObjectMapper for JSON serialization/deserialization
* <li>Supports multiple named server connections with different URLs
* <li>Adds a sync or async HTTP request customizer. Sync takes precedence.
* </ul>
*
* @see HttpClientStreamableHttpTransport
Expand All @@ -71,6 +74,8 @@
matchIfMissing = true)
public class StreamableHttpHttpClientTransportAutoConfiguration {

private static final LogAccessor logger = new LogAccessor(StreamableHttpHttpClientTransportAutoConfiguration.class);

/**
* Creates a list of HTTP client-based Streamable HTTP transports for MCP
* communication.
Expand All @@ -86,11 +91,17 @@ public class StreamableHttpHttpClientTransportAutoConfiguration {
* configurations
* @param objectMapperProvider the provider for ObjectMapper or a new instance if not
* available
* @param syncHttpRequestCustomizer provider for {@link SyncHttpRequestCustomizer} if
* available
* @param asyncHttpRequestCustomizer provider fo {@link AsyncHttpRequestCustomizer} if
* available
* @return list of named MCP transports
*/
@Bean
public List<NamedClientMcpTransport> streamableHttpHttpClientTransports(
McpStreamableHttpClientProperties streamableProperties, ObjectProvider<ObjectMapper> objectMapperProvider) {
McpStreamableHttpClientProperties streamableProperties, ObjectProvider<ObjectMapper> objectMapperProvider,
ObjectProvider<SyncHttpRequestCustomizer> syncHttpRequestCustomizer,
ObjectProvider<AsyncHttpRequestCustomizer> asyncHttpRequestCustomizer) {

ObjectMapper objectMapper = objectMapperProvider.getIfAvailable(ObjectMapper::new);

Expand All @@ -103,11 +114,22 @@ public List<NamedClientMcpTransport> streamableHttpHttpClientTransports(
String streamableHttpEndpoint = serverParameters.getValue().endpoint() != null
? serverParameters.getValue().endpoint() : "/mcp";

HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(baseUrl)
HttpClientStreamableHttpTransport.Builder transportBuilder = HttpClientStreamableHttpTransport
.builder(baseUrl)
.endpoint(streamableHttpEndpoint)
.clientBuilder(HttpClient.newBuilder())
.objectMapper(objectMapper)
.build();
.objectMapper(objectMapper);

asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer);
syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer);
if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) {
logger.warn("Found beans of type %s and %s. Using %s.".formatted(
AsyncHttpRequestCustomizer.class.getSimpleName(),
SyncHttpRequestCustomizer.class.getSimpleName(),
SyncHttpRequestCustomizer.class.getSimpleName()));
}

HttpClientStreamableHttpTransport transport = transportBuilder.build();

streamableHttpTransports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
/*
* Copyright 2024-2024 the original author or authors.
* Copyright 2024-2025 the original author or authors.
*/

package org.springframework.ai.mcp.client.autoconfigure;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.util.List;

Expand All @@ -17,12 +23,19 @@
import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration;
import org.springframework.ai.mcp.client.httpclient.autoconfigure.SseHttpClientTransportAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.context.annotation.UserConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.AsyncHttpRequestCustomizer;
import io.modelcontextprotocol.client.transport.SyncHttpRequestCustomizer;
import io.modelcontextprotocol.spec.McpSchema.ListToolsResult;
import reactor.core.publisher.Mono;

@Timeout(15)
public class SseHttpClientTransportAutoConfigurationIT {
Expand Down Expand Up @@ -79,8 +92,69 @@ void streamableHttpTest() {
assertThat(toolsResult.tools()).hasSize(8);

logger.info("tools = {}", toolsResult);

});
}

@Test
void usesSyncRequestCustomizer() {
this.contextRunner
.withConfiguration(UserConfigurations.of(SyncRequestCustomizerConfiguration.class,
AsyncRequestCustomizerConfiguration.class))
.run(context -> {
List<McpSyncClient> mcpClients = (List<McpSyncClient>) context.getBean("mcpSyncClients");

assertThat(mcpClients).isNotNull();
assertThat(mcpClients).hasSize(1);

McpSyncClient mcpClient = mcpClients.get(0);

mcpClient.ping();

verify(context.getBean(SyncHttpRequestCustomizer.class), atLeastOnce()).customize(any(), any(), any(),
any());
verifyNoInteractions(context.getBean(AsyncHttpRequestCustomizer.class));
});
}

@Test
void usesAsyncRequestCustomizer() {
this.contextRunner.withConfiguration(UserConfigurations.of(AsyncRequestCustomizerConfiguration.class))
.run(context -> {
List<McpSyncClient> mcpClients = (List<McpSyncClient>) context.getBean("mcpSyncClients");

assertThat(mcpClients).isNotNull();
assertThat(mcpClients).hasSize(1);

McpSyncClient mcpClient = mcpClients.get(0);

mcpClient.ping();

verify(context.getBean(AsyncHttpRequestCustomizer.class), atLeastOnce()).customize(any(), any(), any(),
any());
});
}

@Configuration
static class SyncRequestCustomizerConfiguration {

@Bean
SyncHttpRequestCustomizer syncHttpRequestCustomizer() {
return mock(SyncHttpRequestCustomizer.class);
}

}

@Configuration
static class AsyncRequestCustomizerConfiguration {

@Bean
AsyncHttpRequestCustomizer asyncHttpRequestCustomizer() {
AsyncHttpRequestCustomizer requestCustomizerMock = mock(AsyncHttpRequestCustomizer.class);
when(requestCustomizerMock.customize(any(), any(), any(), any()))
.thenAnswer(invocation -> Mono.just(invocation.getArguments()[0]));
return requestCustomizerMock;
}

}

}
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
/*
* Copyright 2024-2024 the original author or authors.
* Copyright 2024-2025 the original author or authors.
*/

package org.springframework.ai.mcp.client.autoconfigure;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.util.List;

Expand All @@ -17,12 +23,19 @@
import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration;
import org.springframework.ai.mcp.client.httpclient.autoconfigure.StreamableHttpHttpClientTransportAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.context.annotation.UserConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.AsyncHttpRequestCustomizer;
import io.modelcontextprotocol.client.transport.SyncHttpRequestCustomizer;
import io.modelcontextprotocol.spec.McpSchema.ListToolsResult;
import reactor.core.publisher.Mono;

@Timeout(15)
public class StreamableHttpHttpClientTransportAutoConfigurationIT {
Expand Down Expand Up @@ -80,8 +93,69 @@ void streamableHttpTest() {
assertThat(toolsResult.tools()).hasSize(8);

logger.info("tools = {}", toolsResult);

});
}

@Test
void usesSyncRequestCustomizer() {
this.contextRunner
.withConfiguration(UserConfigurations.of(SyncRequestCustomizerConfiguration.class,
AsyncRequestCustomizerConfiguration.class))
.run(context -> {
List<McpSyncClient> mcpClients = (List<McpSyncClient>) context.getBean("mcpSyncClients");

assertThat(mcpClients).isNotNull();
assertThat(mcpClients).hasSize(1);

McpSyncClient mcpClient = mcpClients.get(0);

mcpClient.ping();

verify(context.getBean(SyncHttpRequestCustomizer.class), atLeastOnce()).customize(any(), any(), any(),
any());
verifyNoInteractions(context.getBean(AsyncHttpRequestCustomizer.class));
});
}

@Test
void usesAsyncRequestCustomizer() {
this.contextRunner.withConfiguration(UserConfigurations.of(AsyncRequestCustomizerConfiguration.class))
.run(context -> {
List<McpSyncClient> mcpClients = (List<McpSyncClient>) context.getBean("mcpSyncClients");

assertThat(mcpClients).isNotNull();
assertThat(mcpClients).hasSize(1);

McpSyncClient mcpClient = mcpClients.get(0);

mcpClient.ping();

verify(context.getBean(AsyncHttpRequestCustomizer.class), atLeastOnce()).customize(any(), any(), any(),
any());
});
}

@Configuration
static class SyncRequestCustomizerConfiguration {

@Bean
SyncHttpRequestCustomizer syncHttpRequestCustomizer() {
return mock(SyncHttpRequestCustomizer.class);
}

}

@Configuration
static class AsyncRequestCustomizerConfiguration {

@Bean
AsyncHttpRequestCustomizer asyncHttpRequestCustomizer() {
AsyncHttpRequestCustomizer requestCustomizerMock = mock(AsyncHttpRequestCustomizer.class);
when(requestCustomizerMock.customize(any(), any(), any(), any()))
.thenAnswer(invocation -> Mono.just(invocation.getArguments()[0]));
return requestCustomizerMock;
}

}

}