diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index a477af8a47a..005957edc59 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -22,6 +22,7 @@ import io.modelcontextprotocol.client.McpSyncClient; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.McpToolFilter; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.beans.factory.ObjectProvider; @@ -45,22 +46,28 @@ public class McpToolCallbackAutoConfiguration { *

* These callbacks enable integration with Spring AI's tool execution framework, * allowing MCP tools to be used as part of AI interactions. + * @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to + * filter the discovered tools * @param syncMcpClients provider of MCP sync clients * @return list of tool callbacks for MCP integration */ @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider> syncMcpClients) { + public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, + ObjectProvider> syncMcpClients) { List mcpClients = syncMcpClients.stream().flatMap(List::stream).toList(); - return new SyncMcpToolCallbackProvider(mcpClients); + return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)), + mcpClients); } @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider> mcpClientsProvider) { + public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider asyncClientsToolFilter, + ObjectProvider> mcpClientsProvider) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return new AsyncMcpToolCallbackProvider(mcpClients); + return new AsyncMcpToolCallbackProvider( + asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), mcpClients); } public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java index 3708e0fa036..23024c2a9a7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java @@ -16,8 +16,21 @@ package org.springframework.ai.mcp.client.common.autoconfigure; -import org.junit.jupiter.api.Test; +import java.lang.reflect.Field; +import java.util.List; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.McpToolFilter; +import org.springframework.ai.mcp.McpClientMetadata; +import org.springframework.ai.mcp.McpMetadata; +import org.springframework.ai.mcp.McpServerMetadata; +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -25,6 +38,8 @@ import org.springframework.context.annotation.Configuration; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link McpToolCallbackAutoConfigurationCondition}. @@ -73,6 +88,66 @@ void doesMatchWhenBothPropertiesAreMissing() { this.contextRunner.run(context -> assertThat(context).hasBean("testBean")); } + @Test + void verifySyncToolCallbackFilterConfiguration() { + this.contextRunner + .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=SYNC") + .run(context -> { + assertThat(context).hasBean("mcpClientFilter"); + SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class); + Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); + field.setAccessible(true); + McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider); + McpSyncClient syncClient1 = mock(McpSyncClient.class); + var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); + when(syncClient1.getClientInfo()).thenReturn(clientInfo1); + McpSchema.Tool tool1 = mock(McpSchema.Tool.class); + when(tool1.name()).thenReturn("tool1"); + McpSchema.Tool tool2 = mock(McpSchema.Tool.class); + when(tool2.name()).thenReturn("tool2"); + McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); + when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); + when(syncClient1.listTools()).thenReturn(listToolsResult1); + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()), + new McpServerMetadata(null)), tool1)) + .isFalse(); + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()), + new McpServerMetadata(null)), tool2)) + .isTrue(); + }); + } + + @Test + void verifyASyncToolCallbackFilterConfiguration() { + this.contextRunner + .withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=ASYNC") + .run(context -> { + assertThat(context).hasBean("mcpClientFilter"); + AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class); + Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter"); + field.setAccessible(true); + McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider); + McpAsyncClient asyncClient1 = mock(McpAsyncClient.class); + var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0"); + when(asyncClient1.getClientInfo()).thenReturn(clientInfo1); + McpSchema.Tool tool1 = mock(McpSchema.Tool.class); + when(tool1.name()).thenReturn("tool1"); + McpSchema.Tool tool2 = mock(McpSchema.Tool.class); + when(tool2.name()).thenReturn("tool2"); + McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class); + when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2)); + when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1)); + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()), + new McpServerMetadata(null)), tool1)) + .isFalse(); + assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()), + new McpServerMetadata(null)), tool2)) + .isTrue(); + }); + } + @Configuration @Conditional(McpToolCallbackAutoConfigurationCondition.class) static class TestConfiguration { @@ -84,4 +159,23 @@ String testBean() { } + @Configuration + static class McpClientFilterConfiguration { + + @Bean + McpToolFilter mcpClientFilter() { + return new McpToolFilter() { + @Override + public boolean test(McpMetadata metadata, McpSchema.Tool tool) { + if (metadata.mcpClientMetadata().clientInfo().name().equals("client1") + && tool.name().contains("tool1")) { + return false; + } + return true; + } + }; + } + + } + } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 3525b9593e3..e7bc3e3ed98 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -18,10 +18,8 @@ import java.util.ArrayList; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpAsyncClient; -import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; @@ -74,21 +72,21 @@ */ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { - private final List mcpClients; + private final McpToolFilter toolFilter; - private final BiPredicate toolFilter; + private final List mcpClients; /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP * clients. - * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool + * @param mcpClients the list of MCP clients to use for discovering tools */ - public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { + public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); - this.mcpClients = mcpClients; this.toolFilter = toolFilter; + this.mcpClients = mcpClients; } /** @@ -106,10 +104,10 @@ public AsyncMcpToolCallbackProvider(List mcpClients) { /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP * clients. - * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool + * @param mcpClients the MCP clients to use for discovering tools */ - public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, McpAsyncClient... mcpClients) { + public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpAsyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } @@ -147,7 +145,9 @@ public ToolCallback[] getToolCallbacks() { ToolCallback[] toolCallbacks = mcpClient.listTools() .map(response -> response.tools() .stream() - .filter(tool -> this.toolFilter.test(mcpClient, tool)) + .filter(tool -> this.toolFilter.test(new McpMetadata( + new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()), + new McpServerMetadata(mcpClient.initialize().block())), tool)) .map(tool -> new AsyncMcpToolCallback(mcpClient, tool)) .toArray(ToolCallback[]::new)) .block(); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java new file mode 100644 index 00000000000..7ceeae32ef1 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpClientMetadata.java @@ -0,0 +1,27 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * MCP client metadata record. + * + * @author Ilayaperumal Gopinathan + */ +public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java new file mode 100644 index 00000000000..660e811ff0d --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpMetadata.java @@ -0,0 +1,25 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +/** + * MCP metadata record containing the client/server specific meta data. + * + * @author Ilayaperumal Gopinathan + */ +public record McpMetadata(McpClientMetadata mcpClientMetadata, McpServerMetadata mcpServermetadata) { +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java new file mode 100644 index 00000000000..57810443f10 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpServerMetadata.java @@ -0,0 +1,27 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * MCP server metadata record. + * + * @author Ilayaperumal Gopinathan + */ +public record McpServerMetadata(McpSchema.InitializeResult initializeResult) { +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java new file mode 100644 index 00000000000..78156dd7d5e --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolFilter.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import java.util.function.BiPredicate; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the + * {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given + * {@link McpMetadata}. + * + * @author Ilayaperumal Gopinathan + */ +public interface McpToolFilter extends BiPredicate { + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index 7d0aa4276a1..e600f1facd1 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -17,10 +17,8 @@ package org.springframework.ai.mcp; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpSyncClient; -import io.modelcontextprotocol.spec.McpSchema.Tool; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; @@ -72,7 +70,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { private final List mcpClients; - private final BiPredicate toolFilter; + private final McpToolFilter toolFilter; /** * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP @@ -80,7 +78,7 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { + public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.mcpClients = mcpClients; @@ -102,7 +100,7 @@ public SyncMcpToolCallbackProvider(List mcpClients) { * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool */ - public SyncMcpToolCallbackProvider(BiPredicate toolFilter, McpSyncClient... mcpClients) { + public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpSyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); } @@ -133,7 +131,9 @@ public ToolCallback[] getToolCallbacks() { .flatMap(mcpClient -> mcpClient.listTools() .tools() .stream() - .filter(tool -> this.toolFilter.test(mcpClient, tool)) + .filter(tool -> this.toolFilter.test(new McpMetadata( + new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()), + new McpServerMetadata(mcpClient.initialize())), tool)) .map(tool -> new SyncMcpToolCallback(mcpClient, tool))) .toArray(ToolCallback[]::new); validateToolCallbacks(array); diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java index d8830d1718a..84fe7553959 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java @@ -17,7 +17,6 @@ package org.springframework.ai.mcp; import java.util.List; -import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema.Implementation; @@ -164,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that rejects all tools - BiPredicate rejectAllFilter = (client, tool) -> false; + McpToolFilter rejectAllFilter = (client, tool) -> false; SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient); @@ -192,8 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() { when(this.mcpClient.listTools()).thenReturn(listToolsResult); // Create a filter that only accepts tools with names containing "2" or "3" - BiPredicate nameFilter = (client, tool) -> tool.name().contains("2") - || tool.name().contains("3"); + McpToolFilter nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient); @@ -228,8 +226,8 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() { when(mcpClient2.getClientInfo()).thenReturn(clientInfo2); // Create a filter that only accepts tools from client1 - BiPredicate clientFilter = (client, - tool) -> client.getClientInfo().name().equals("testClient1"); + McpToolFilter clientFilter = (mcpMetadata, + tool) -> mcpMetadata.mcpClientMetadata().clientInfo().name().equals("testClient1"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2); @@ -256,8 +254,9 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() { when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo); // Create a filter that only accepts weather tools from the weather service - BiPredicate complexFilter = (client, - tool) -> client.getClientInfo().name().equals("weather-service") && tool.name().equals("weather"); + McpToolFilter complexFilter = (mcpMetadata, + tool) -> mcpMetadata.mcpClientMetadata().clientInfo().name().equals("weather-service") + && tool.name().equals("weather"); SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient);