From b1eec857f4e13769f5a45e1a0adb681b4621d815 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Thu, 17 Jul 2025 19:01:27 +0100 Subject: [PATCH 1/4] Add MCP Client tool predicate for filtering the MCP tools - Introduce MCP Sync/Async client BiPredicate interface as a tool filter for the MCP Sync/Async ToolCallbackProvider to use when filtering the MCP tools - Update MCP ToolCallbackAutoConfiguration to use these BiPredicate beans when defined (default is to allow all) - Add test verifying the tool filter configuration on both sync and async toolcallback provider auto-configuration - Update the unit tests for the MCP toolcallback provider Signed-off-by: Ilayaperumal Gopinathan --- .../McpToolCallbackAutoConfiguration.java | 18 +++- ...llbackAutoConfigurationConditionTests.java | 101 ++++++++++++++++++ .../ai/mcp/AsyncMcpToolCallbackProvider.java | 8 +- .../ai/mcp/McpAsyncClientBiPredicate.java | 32 ++++++ .../ai/mcp/McpSyncClientBiPredicate.java | 32 ++++++ .../ai/mcp/SyncMcpToolCallbackProvider.java | 8 +- .../mcp/SyncMcpToolCallbackProviderTests.java | 11 +- 7 files changed, 189 insertions(+), 21 deletions(-) create mode 100644 mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java create mode 100644 mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index a477af8a47a..e8a7259b5e0 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -17,13 +17,18 @@ package org.springframework.ai.mcp.client.common.autoconfigure; import java.util.List; +import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.McpAsyncClientBiPredicate; +import org.springframework.ai.mcp.McpSyncClientBiPredicate; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; +import org.springframework.ai.tool.annotation.Tool; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.AllNestedConditions; @@ -51,16 +56,21 @@ public class McpToolCallbackAutoConfiguration { @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider> syncMcpClients) { + public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, + ObjectProvider> syncMcpClients) { List mcpClients = syncMcpClients.stream().flatMap(List::stream).toList(); - return new SyncMcpToolCallbackProvider(mcpClients); + return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)), + mcpClients); } @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider> mcpClientsProvider) { + public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks( + ObjectProvider asyncClientsToolFilter, + ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return new AsyncMcpToolCallbackProvider(mcpClients); + return new AsyncMcpToolCallbackProvider( + asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), mcpClients); } public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java index 3708e0fa036..82ff854109c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java @@ -16,8 +16,19 @@ package org.springframework.ai.mcp.client.common.autoconfigure; +import java.lang.reflect.Field; +import java.util.List; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.McpAsyncClientBiPredicate; +import org.springframework.ai.mcp.McpSyncClientBiPredicate; +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -25,6 +36,8 @@ import org.springframework.context.annotation.Configuration; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link McpToolCallbackAutoConfigurationCondition}. @@ -73,6 +86,58 @@ void doesMatchWhenBothPropertiesAreMissing() { this.contextRunner.run(context -> assertThat(context).hasBean("testBean")); } + @Test + void verifySyncToolCallbackFilterConfiguration() { + this.contextRunner + .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpSyncClientFilterConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=SYNC") + .run(context -> { + assertThat(context).hasBean("syncClientFilter"); + SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class); + Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); + field.setAccessible(true); + McpSyncClientBiPredicate toolFilter = (McpSyncClientBiPredicate) field.get(toolCallbackProvider); + McpSyncClient syncClient1 = mock(McpSyncClient.class); + var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); + when(syncClient1.getClientInfo()).thenReturn(clientInfo1); + McpSchema.Tool tool1 = mock(McpSchema.Tool.class); + when(tool1.name()).thenReturn("tool1"); + McpSchema.Tool tool2 = mock(McpSchema.Tool.class); + when(tool2.name()).thenReturn("tool2"); + McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); + when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); + when(syncClient1.listTools()).thenReturn(listToolsResult1); + assertThat(toolFilter.test(syncClient1, tool1)).isFalse(); + assertThat(toolFilter.test(syncClient1, tool2)).isTrue(); + }); + } + + @Test + void verifyASyncToolCallbackFilterConfiguration() { + this.contextRunner + .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpAsyncClientFilterConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=ASYNC") + .run(context -> { + assertThat(context).hasBean("asyncClientFilter"); + AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class); + Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); + field.setAccessible(true); + McpAsyncClientBiPredicate toolFilter = (McpAsyncClientBiPredicate) field.get(toolCallbackProvider); + McpAsyncClient asyncClient1 = mock(McpAsyncClient.class); + var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); + when(asyncClient1.getClientInfo()).thenReturn(clientInfo1); + McpSchema.Tool tool1 = mock(McpSchema.Tool.class); + when(tool1.name()).thenReturn("tool1"); + McpSchema.Tool tool2 = mock(McpSchema.Tool.class); + when(tool2.name()).thenReturn("tool2"); + McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); + when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); + when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1)); + assertThat(toolFilter.test(asyncClient1, tool1)).isFalse(); + assertThat(toolFilter.test(asyncClient1, tool2)).isTrue(); + }); + } + @Configuration @Conditional(McpToolCallbackAutoConfigurationCondition.class) static class TestConfiguration { @@ -84,4 +149,40 @@ String testBean() { } + @Configuration + static class McpSyncClientFilterConfiguration { + + @Bean + McpSyncClientBiPredicate syncClientFilter() { + return new McpSyncClientBiPredicate() { + @Override + public boolean test(McpSyncClient mcpSyncClient, McpSchema.Tool tool) { + if (mcpSyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) { + return false; + } + return true; + } + }; + } + + } + + @Configuration + static class McpAsyncClientFilterConfiguration { + + @Bean + McpAsyncClientBiPredicate asyncClientFilter() { + return new McpAsyncClientBiPredicate() { + @Override + public boolean test(McpAsyncClient mcpAsyncClient, McpSchema.Tool tool) { + if (mcpAsyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) { + return false; + } + return true; + } + }; + } + + } + } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 3525b9593e3..a508fcec1d6 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -18,10 +18,8 @@ import java.util.ArrayList; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpAsyncClient; -import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; @@ -76,7 +74,7 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { private final List mcpClients; - private final BiPredicate toolFilter; + private final McpAsyncClientBiPredicate toolFilter; /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP @@ -84,7 +82,7 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { + public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.mcpClients = mcpClients; @@ -109,7 +107,7 @@ public AsyncMcpToolCallbackProvider(List mcpClients) { * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, McpAsyncClient... mcpClients) { + public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, McpAsyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java new file mode 100644 index 00000000000..1e2dfe876f7 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import java.util.function.BiPredicate; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * A {@link BiPredicate} for {@link AsyncMcpToolCallbackProvider} to filter the discovered + * tool for the given {@link McpAsyncClient}. + * + * @author Ilayaperumal Gopinathan + */ +public interface McpAsyncClientBiPredicate extends BiPredicate { + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java new file mode 100644 index 00000000000..2be8ca2405a --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import java.util.function.BiPredicate; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} to filter the discovered + * tool for the given {@link McpSyncClient}. + * + * @author Ilayaperumal Gopinathan + */ +public interface McpSyncClientBiPredicate extends BiPredicate { + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index 7d0aa4276a1..407c50be6ee 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -17,10 +17,8 @@ package org.springframework.ai.mcp; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpSyncClient; -import io.modelcontextprotocol.spec.McpSchema.Tool; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; @@ -72,7 +70,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { private final List mcpClients; - private final BiPredicate toolFilter; + private final McpSyncClientBiPredicate toolFilter; /** * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP @@ -80,7 +78,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { + public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.mcpClients = mcpClients; @@ -102,7 +100,7 @@ public SyncMcpToolCallbackProvider(List mcpClients) { * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(BiPredicate toolFilter, McpSyncClient... mcpClients) { + public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, McpSyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java index d8830d1718a..b82581c5e09 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java @@ -17,7 +17,6 @@ package org.springframework.ai.mcp; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema.Implementation; @@ -164,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that rejects all tools - BiPredicate rejectAllFilter = (client, tool) -> false; + McpSyncClientBiPredicate rejectAllFilter = (client, tool) -> false; SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient); @@ -192,8 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that only accepts tools with names containing "2" or "3" - BiPredicate nameFilter = (client, tool) -> tool.name().contains("2") - || tool.name().contains("3"); + McpSyncClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient); @@ -228,8 +226,7 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() { when(mcpClient2.getClientInfo()).thenReturn(clientInfo2); // Create a filter that only accepts tools from client1 - BiPredicate clientFilter = (client, - tool) -> client.getClientInfo().name().equals("testClient1"); + McpSyncClientBiPredicate clientFilter = (client, tool) -> client.getClientInfo().name().equals("testClient1"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2); @@ -256,7 +253,7 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() { when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo); // Create a filter that only accepts weather tools from the weather service - BiPredicate complexFilter = (client, + McpSyncClientBiPredicate complexFilter = (client, tool) -> client.getClientInfo().name().equals("weather-service") && tool.name().equals("weather"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient); From e5be1133237298ec386197faf01406c1d17bd1e9 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Thu, 24 Jul 2025 23:53:03 +0100 Subject: [PATCH 2/4] Fix javadoc Signed-off-by: Ilayaperumal Gopinathan --- .../autoconfigure/McpToolCallbackAutoConfiguration.java | 2 ++ .../springframework/ai/mcp/AsyncMcpToolCallbackProvider.java | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index e8a7259b5e0..6904cc56f57 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -50,6 +50,8 @@ public class McpToolCallbackAutoConfiguration { *

* These callbacks enable integration with Spring AI's tool execution framework, * allowing MCP tools to be used as part of AI interactions. + * @param syncClientsToolFilter list of {@link McpSyncClientBiPredicate}s for the sync + * client to filter the discovered tools * @param syncMcpClients provider of MCP sync clients * @return list of tool callbacks for MCP integration */ diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index a508fcec1d6..2ff4ddf3d93 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -79,8 +79,8 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP * clients. - * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool + * @param mcpClients the list of MCP clients to use for discovering tools */ public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); @@ -104,8 +104,8 @@ public AsyncMcpToolCallbackProvider(List mcpClients) { /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP * clients. - * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool + * @param mcpClients the MCP clients to use for discovering tools */ public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, McpAsyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); From a28bb411e89a045fa2f069c6bc55bf3267163188 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Fri, 25 Jul 2025 19:27:31 +0100 Subject: [PATCH 3/4] Introduce MCP ClientMetadata and use it for filtering - Add McpClientMetadata record which contains MCP client/server meta data that can be used for filtering the toolcallbacks - This provides a convenient approach to handling just the metadata from the client - Update the auto-configuration and tests Signed-off-by: Ilayaperumal Gopinathan --- .../McpToolCallbackAutoConfiguration.java | 12 ++-- ...llbackAutoConfigurationConditionTests.java | 56 +++++++------------ .../ai/mcp/AsyncMcpToolCallbackProvider.java | 14 +++-- ...edicate.java => McpClientBiPredicate.java} | 7 ++- ...iPredicate.java => McpClientMetadata.java} | 10 +--- .../ai/mcp/SyncMcpToolCallbackProvider.java | 9 +-- .../mcp/SyncMcpToolCallbackProviderTests.java | 11 ++-- 7 files changed, 51 insertions(+), 68 deletions(-) rename mcp/common/src/main/java/org/springframework/ai/mcp/{McpSyncClientBiPredicate.java => McpClientBiPredicate.java} (80%) rename mcp/common/src/main/java/org/springframework/ai/mcp/{McpAsyncClientBiPredicate.java => McpClientMetadata.java} (69%) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index 6904cc56f57..d66c6c778f2 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -17,18 +17,14 @@ package org.springframework.ai.mcp.client.common.autoconfigure; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.McpSyncClient; -import io.modelcontextprotocol.spec.McpSchema; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; -import org.springframework.ai.mcp.McpAsyncClientBiPredicate; -import org.springframework.ai.mcp.McpSyncClientBiPredicate; +import org.springframework.ai.mcp.McpClientBiPredicate; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; -import org.springframework.ai.tool.annotation.Tool; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.AllNestedConditions; @@ -50,7 +46,7 @@ public class McpToolCallbackAutoConfiguration { *

* These callbacks enable integration with Spring AI's tool execution framework, * allowing MCP tools to be used as part of AI interactions. - * @param syncClientsToolFilter list of {@link McpSyncClientBiPredicate}s for the sync + * @param syncClientsToolFilter list of {@link McpClientBiPredicate}s for the sync * client to filter the discovered tools * @param syncMcpClients provider of MCP sync clients * @return list of tool callbacks for MCP integration @@ -58,7 +54,7 @@ public class McpToolCallbackAutoConfiguration { @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, + public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, ObjectProvider> syncMcpClients) { List mcpClients = syncMcpClients.stream().flatMap(List::stream).toList(); return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)), @@ -68,7 +64,7 @@ public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider asyncClientsToolFilter, + ObjectProvider asyncClientsToolFilter, ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); return new AsyncMcpToolCallbackProvider( diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java index 82ff854109c..779f530e7f8 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java @@ -26,8 +26,8 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; -import org.springframework.ai.mcp.McpAsyncClientBiPredicate; -import org.springframework.ai.mcp.McpSyncClientBiPredicate; +import org.springframework.ai.mcp.McpClientBiPredicate; +import org.springframework.ai.mcp.McpClientMetadata; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -89,14 +89,14 @@ void doesMatchWhenBothPropertiesAreMissing() { @Test void verifySyncToolCallbackFilterConfiguration() { this.contextRunner - .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpSyncClientFilterConfiguration.class) + .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class) .withPropertyValues("spring.ai.mcp.client.type=SYNC") .run(context -> { - assertThat(context).hasBean("syncClientFilter"); + assertThat(context).hasBean("mcpClientFilter"); SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class); Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); field.setAccessible(true); - McpSyncClientBiPredicate toolFilter = (McpSyncClientBiPredicate) field.get(toolCallbackProvider); + McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider); McpSyncClient syncClient1 = mock(McpSyncClient.class); var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); when(syncClient1.getClientInfo()).thenReturn(clientInfo1); @@ -107,22 +107,24 @@ void verifySyncToolCallbackFilterConfiguration() { McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); when(syncClient1.listTools()).thenReturn(listToolsResult1); - assertThat(toolFilter.test(syncClient1, tool1)).isFalse(); - assertThat(toolFilter.test(syncClient1, tool2)).isTrue(); + assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool1)) + .isFalse(); + assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool2)) + .isTrue(); }); } @Test void verifyASyncToolCallbackFilterConfiguration() { this.contextRunner - .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpAsyncClientFilterConfiguration.class) + .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class) .withPropertyValues("spring.ai.mcp.client.type=ASYNC") .run(context -> { - assertThat(context).hasBean("asyncClientFilter"); + assertThat(context).hasBean("mcpClientFilter"); AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class); Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); field.setAccessible(true); - McpAsyncClientBiPredicate toolFilter = (McpAsyncClientBiPredicate) field.get(toolCallbackProvider); + McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider); McpAsyncClient asyncClient1 = mock(McpAsyncClient.class); var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); when(asyncClient1.getClientInfo()).thenReturn(clientInfo1); @@ -133,8 +135,10 @@ void verifyASyncToolCallbackFilterConfiguration() { McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1)); - assertThat(toolFilter.test(asyncClient1, tool1)).isFalse(); - assertThat(toolFilter.test(asyncClient1, tool2)).isTrue(); + assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool1)) + .isFalse(); + assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool2)) + .isTrue(); }); } @@ -150,32 +154,14 @@ String testBean() { } @Configuration - static class McpSyncClientFilterConfiguration { + static class McpClientFilterConfiguration { @Bean - McpSyncClientBiPredicate syncClientFilter() { - return new McpSyncClientBiPredicate() { + McpClientBiPredicate mcpClientFilter() { + return new McpClientBiPredicate() { @Override - public boolean test(McpSyncClient mcpSyncClient, McpSchema.Tool tool) { - if (mcpSyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) { - return false; - } - return true; - } - }; - } - - } - - @Configuration - static class McpAsyncClientFilterConfiguration { - - @Bean - McpAsyncClientBiPredicate asyncClientFilter() { - return new McpAsyncClientBiPredicate() { - @Override - public boolean test(McpAsyncClient mcpAsyncClient, McpSchema.Tool tool) { - if (mcpAsyncClient.getClientInfo().name().equals("client1") && tool.name().contains("tool1")) { + public boolean test(McpClientMetadata clientMetadata, McpSchema.Tool tool) { + if (clientMetadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) { return false; } return true; diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 2ff4ddf3d93..4037f2acfe1 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -20,6 +20,7 @@ import java.util.List; import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; @@ -72,9 +73,9 @@ */ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { - private final List mcpClients; + private final McpClientBiPredicate toolFilter; - private final McpAsyncClientBiPredicate toolFilter; + private final List mcpClients; /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP @@ -82,11 +83,11 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param toolFilter a filter to apply to each discovered tool * @param mcpClients the list of MCP clients to use for discovering tools */ - public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, List mcpClients) { + public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); - this.mcpClients = mcpClients; this.toolFilter = toolFilter; + this.mcpClients = mcpClients; } /** @@ -107,7 +108,7 @@ public AsyncMcpToolCallbackProvider(List mcpClients) { * @param toolFilter a filter to apply to each discovered tool * @param mcpClients the MCP clients to use for discovering tools */ - public AsyncMcpToolCallbackProvider(McpAsyncClientBiPredicate toolFilter, McpAsyncClient... mcpClients) { + public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpAsyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } @@ -145,7 +146,8 @@ public ToolCallback[] getToolCallbacks() { ToolCallback[] toolCallbacks = mcpClient.listTools() .map(response -> response.tools() .stream() - .filter(tool -> this.toolFilter.test(mcpClient, tool)) + .filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(), + mcpClient.getClientInfo(), mcpClient.initialize().block()), tool)) .map(tool -> new AsyncMcpToolCallback(mcpClient, tool)) .toArray(ToolCallback[]::new)) .block(); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java similarity index 80% rename from mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java rename to mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java index 2be8ca2405a..1b16267c123 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpSyncClientBiPredicate.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java @@ -22,11 +22,12 @@ import io.modelcontextprotocol.spec.McpSchema; /** - * A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} to filter the discovered - * tool for the given {@link McpSyncClient}. + * A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the + * {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given + * {@link McpClientMetadata}. * * @author Ilayaperumal Gopinathan */ -public interface McpSyncClientBiPredicate extends BiPredicate { +public interface McpClientBiPredicate extends BiPredicate { } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java similarity index 69% rename from mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java rename to mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java index 1e2dfe876f7..127b9b231ac 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpAsyncClientBiPredicate.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java @@ -16,17 +16,13 @@ package org.springframework.ai.mcp; -import java.util.function.BiPredicate; - -import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.spec.McpSchema; /** - * A {@link BiPredicate} for {@link AsyncMcpToolCallbackProvider} to filter the discovered - * tool for the given {@link McpAsyncClient}. + * MCP client metadata record containing the client/server specific data. * * @author Ilayaperumal Gopinathan */ -public interface McpAsyncClientBiPredicate extends BiPredicate { - +public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpSchema.InitializeResult initializeResult) { } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index 407c50be6ee..ac8b31b93be 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -70,7 +70,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { private final List mcpClients; - private final McpSyncClientBiPredicate toolFilter; + private final McpClientBiPredicate toolFilter; /** * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP @@ -78,7 +78,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, List mcpClients) { + public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.mcpClients = mcpClients; @@ -100,7 +100,7 @@ public SyncMcpToolCallbackProvider(List mcpClients) { * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(McpSyncClientBiPredicate toolFilter, McpSyncClient... mcpClients) { + public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpSyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } @@ -131,7 +131,8 @@ public ToolCallback[] getToolCallbacks() { .flatMap(mcpClient -> mcpClient.listTools() .tools() .stream() - .filter(tool -> this.toolFilter.test(mcpClient, tool)) + .filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(), + mcpClient.getClientInfo(), mcpClient.initialize()), tool)) .map(tool -> new SyncMcpToolCallback(mcpClient, tool))) .toArray(ToolCallback[]::new); validateToolCallbacks(array); diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java index b82581c5e09..5220517c7e4 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java @@ -163,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that rejects all tools - McpSyncClientBiPredicate rejectAllFilter = (client, tool) -> false; + McpClientBiPredicate rejectAllFilter = (client, tool) -> false; SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient); @@ -191,7 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that only accepts tools with names containing "2" or "3" - McpSyncClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3"); + McpClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient); @@ -226,7 +226,8 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() { when(mcpClient2.getClientInfo()).thenReturn(clientInfo2); // Create a filter that only accepts tools from client1 - McpSyncClientBiPredicate clientFilter = (client, tool) -> client.getClientInfo().name().equals("testClient1"); + McpClientBiPredicate clientFilter = (clientMetadata, + tool) -> clientMetadata.clientInfo().name().equals("testClient1"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2); @@ -253,8 +254,8 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() { when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo); // Create a filter that only accepts weather tools from the weather service - McpSyncClientBiPredicate complexFilter = (client, - tool) -> client.getClientInfo().name().equals("weather-service") && tool.name().equals("weather"); + McpClientBiPredicate complexFilter = (client, tool) -> client.clientInfo().name().equals("weather-service") + && tool.name().equals("weather"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient); From 523315b10bc67fc41ff08b58d66b9e8ffedc24f1 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Tue, 29 Jul 2025 10:13:04 +0100 Subject: [PATCH 4/4] Add MCP metadata which would contain MCP client and server metadata - This provides a convenience for the filter to operate against metadata Add McpToolFilter which is of type BiPredicate - The filter configuration would look like this: @Configuration static class McpClientFilterConfiguration { @Bean McpToolFilter mcpClientFilter() { return new McpToolFilter() { @Override public boolean test(McpMetadata metadata, McpSchema.Tool tool) { if (metadata.mcpClientMetadata().clientInfo().name().equals("client1") && tool.name().contains("tool1")) { return false; } return true; } }; } } Signed-off-by: Ilayaperumal Gopinathan --- .../McpToolCallbackAutoConfiguration.java | 11 ++++--- ...llbackAutoConfigurationConditionTests.java | 29 ++++++++++++------- .../ai/mcp/AsyncMcpToolCallbackProvider.java | 12 ++++---- .../ai/mcp/McpClientMetadata.java | 5 ++-- .../springframework/ai/mcp/McpMetadata.java | 25 ++++++++++++++++ .../ai/mcp/McpServerMetadata.java | 27 +++++++++++++++++ ...entBiPredicate.java => McpToolFilter.java} | 5 ++-- .../ai/mcp/SyncMcpToolCallbackProvider.java | 11 +++---- .../mcp/SyncMcpToolCallbackProviderTests.java | 13 +++++---- 9 files changed, 98 insertions(+), 40 deletions(-) create mode 100644 mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java create mode 100644 mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java rename mcp/common/src/main/java/org/springframework/ai/mcp/{McpClientBiPredicate.java => McpToolFilter.java} (84%) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index d66c6c778f2..005957edc59 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -22,7 +22,7 @@ import io.modelcontextprotocol.client.McpSyncClient; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; -import org.springframework.ai.mcp.McpClientBiPredicate; +import org.springframework.ai.mcp.McpToolFilter; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.beans.factory.ObjectProvider; @@ -46,15 +46,15 @@ public class McpToolCallbackAutoConfiguration { *

* These callbacks enable integration with Spring AI's tool execution framework, * allowing MCP tools to be used as part of AI interactions. - * @param syncClientsToolFilter list of {@link McpClientBiPredicate}s for the sync - * client to filter the discovered tools + * @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to + * filter the discovered tools * @param syncMcpClients provider of MCP sync clients * @return list of tool callbacks for MCP integration */ @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, + public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, ObjectProvider> syncMcpClients) { List mcpClients = syncMcpClients.stream().flatMap(List::stream).toList(); return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)), @@ -63,8 +63,7 @@ public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider asyncClientsToolFilter, + public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider asyncClientsToolFilter, ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); return new AsyncMcpToolCallbackProvider( diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java index 779f530e7f8..23024c2a9a7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java @@ -26,8 +26,10 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; -import org.springframework.ai.mcp.McpClientBiPredicate; +import org.springframework.ai.mcp.McpToolFilter; import org.springframework.ai.mcp.McpClientMetadata; +import org.springframework.ai.mcp.McpMetadata; +import org.springframework.ai.mcp.McpServerMetadata; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -96,7 +98,7 @@ void verifySyncToolCallbackFilterConfiguration() { SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class); Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); field.setAccessible(true); - McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider); + McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider); McpSyncClient syncClient1 = mock(McpSyncClient.class); var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); when(syncClient1.getClientInfo()).thenReturn(clientInfo1); @@ -107,9 +109,11 @@ void verifySyncToolCallbackFilterConfiguration() { McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); when(syncClient1.listTools()).thenReturn(listToolsResult1); - assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool1)) + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()), + new McpServerMetadata(null)), tool1)) .isFalse(); - assertThat(toolFilter.test(new McpClientMetadata(null, syncClient1.getClientInfo(), null), tool2)) + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()), + new McpServerMetadata(null)), tool2)) .isTrue(); }); } @@ -124,7 +128,7 @@ void verifyASyncToolCallbackFilterConfiguration() { AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class); Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); field.setAccessible(true); - McpClientBiPredicate toolFilter = (McpClientBiPredicate) field.get(toolCallbackProvider); + McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider); McpAsyncClient asyncClient1 = mock(McpAsyncClient.class); var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); when(asyncClient1.getClientInfo()).thenReturn(clientInfo1); @@ -135,9 +139,11 @@ void verifyASyncToolCallbackFilterConfiguration() { McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1)); - assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool1)) + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()), + new McpServerMetadata(null)), tool1)) .isFalse(); - assertThat(toolFilter.test(new McpClientMetadata(null, asyncClient1.getClientInfo(), null), tool2)) + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()), + new McpServerMetadata(null)), tool2)) .isTrue(); }); } @@ -157,11 +163,12 @@ String testBean() { static class McpClientFilterConfiguration { @Bean - McpClientBiPredicate mcpClientFilter() { - return new McpClientBiPredicate() { + McpToolFilter mcpClientFilter() { + return new McpToolFilter() { @Override - public boolean test(McpClientMetadata clientMetadata, McpSchema.Tool tool) { - if (clientMetadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) { + public boolean test(McpMetadata metadata, McpSchema.Tool tool) { + if (metadata.mcpClientMetadata().clientInfo().name().equals("client1") + && tool.name().contains("tool1")) { return false; } return true; diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 4037f2acfe1..e7bc3e3ed98 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -20,7 +20,6 @@ import java.util.List; import io.modelcontextprotocol.client.McpAsyncClient; -import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; @@ -73,7 +72,7 @@ */ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { - private final McpClientBiPredicate toolFilter; + private final McpToolFilter toolFilter; private final List mcpClients; @@ -83,7 +82,7 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param toolFilter a filter to apply to each discovered tool * @param mcpClients the list of MCP clients to use for discovering tools */ - public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List mcpClients) { + public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.toolFilter = toolFilter; @@ -108,7 +107,7 @@ public AsyncMcpToolCallbackProvider(List mcpClients) { * @param toolFilter a filter to apply to each discovered tool * @param mcpClients the MCP clients to use for discovering tools */ - public AsyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpAsyncClient... mcpClients) { + public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpAsyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } @@ -146,8 +145,9 @@ public ToolCallback[] getToolCallbacks() { ToolCallback[] toolCallbacks = mcpClient.listTools() .map(response -> response.tools() .stream() - .filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(), - mcpClient.getClientInfo(), mcpClient.initialize().block()), tool)) + .filter(tool -> this.toolFilter.test(new McpMetadata( + new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()), + new McpServerMetadata(mcpClient.initialize().block())), tool)) .map(tool -> new AsyncMcpToolCallback(mcpClient, tool)) .toArray(ToolCallback[]::new)) .block(); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java index 127b9b231ac..7ceeae32ef1 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java @@ -19,10 +19,9 @@ import io.modelcontextprotocol.spec.McpSchema; /** - * MCP client metadata record containing the client/server specific data. + * MCP client metadata record. * * @author Ilayaperumal Gopinathan */ -public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, - McpSchema.InitializeResult initializeResult) { +public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java new file mode 100644 index 00000000000..660e811ff0d --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java @@ -0,0 +1,25 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +/** + * MCP metadata record containing the client/server specific meta data. + * + * @author Ilayaperumal Gopinathan + */ +public record McpMetadata(McpClientMetadata mcpClientMetadata, McpServerMetadata mcpServermetadata) { +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java new file mode 100644 index 00000000000..57810443f10 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java @@ -0,0 +1,27 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * MCP server metadata record. + * + * @author Ilayaperumal Gopinathan + */ +public record McpServerMetadata(McpSchema.InitializeResult initializeResult) { +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java similarity index 84% rename from mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java rename to mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java index 1b16267c123..78156dd7d5e 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientBiPredicate.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java @@ -18,16 +18,15 @@ import java.util.function.BiPredicate; -import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema; /** * A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the * {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given - * {@link McpClientMetadata}. + * {@link McpMetadata}. * * @author Ilayaperumal Gopinathan */ -public interface McpClientBiPredicate extends BiPredicate { +public interface McpToolFilter extends BiPredicate { } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index ac8b31b93be..e600f1facd1 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -70,7 +70,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { private final List mcpClients; - private final McpClientBiPredicate toolFilter; + private final McpToolFilter toolFilter; /** * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP @@ -78,7 +78,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, List mcpClients) { + public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.mcpClients = mcpClients; @@ -100,7 +100,7 @@ public SyncMcpToolCallbackProvider(List mcpClients) { * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(McpClientBiPredicate toolFilter, McpSyncClient... mcpClients) { + public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpSyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } @@ -131,8 +131,9 @@ public ToolCallback[] getToolCallbacks() { .flatMap(mcpClient -> mcpClient.listTools() .tools() .stream() - .filter(tool -> this.toolFilter.test(new McpClientMetadata(mcpClient.getClientCapabilities(), - mcpClient.getClientInfo(), mcpClient.initialize()), tool)) + .filter(tool -> this.toolFilter.test(new McpMetadata( + new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()), + new McpServerMetadata(mcpClient.initialize())), tool)) .map(tool -> new SyncMcpToolCallback(mcpClient, tool))) .toArray(ToolCallback[]::new); validateToolCallbacks(array); diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java index 5220517c7e4..84fe7553959 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java @@ -163,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that rejects all tools - McpClientBiPredicate rejectAllFilter = (client, tool) -> false; + McpToolFilter rejectAllFilter = (client, tool) -> false; SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient); @@ -191,7 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that only accepts tools with names containing "2" or "3" - McpClientBiPredicate nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3"); + McpToolFilter nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient); @@ -226,8 +226,8 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() { when(mcpClient2.getClientInfo()).thenReturn(clientInfo2); // Create a filter that only accepts tools from client1 - McpClientBiPredicate clientFilter = (clientMetadata, - tool) -> clientMetadata.clientInfo().name().equals("testClient1"); + McpToolFilter clientFilter = (mcpMetadata, + tool) -> mcpMetadata.mcpClientMetadata().clientInfo().name().equals("testClient1"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2); @@ -254,8 +254,9 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() { when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo); // Create a filter that only accepts weather tools from the weather service - McpClientBiPredicate complexFilter = (client, tool) -> client.clientInfo().name().equals("weather-service") - && tool.name().equals("weather"); + McpToolFilter complexFilter = (mcpMetadata, + tool) -> mcpMetadata.mcpClientMetadata().clientInfo().name().equals("weather-service") + && tool.name().equals("weather"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient);