Skip to content

Add MCP Client tool predicate for filtering the MCP tools #3901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.modelcontextprotocol.client.McpSyncClient;

import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.McpToolFilter;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -45,22 +46,28 @@ public class McpToolCallbackAutoConfiguration {
* <p>
* These callbacks enable integration with Spring AI's tool execution framework,
* allowing MCP tools to be used as part of AI interactions.
* @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to
* filter the discovered tools
* @param syncMcpClients provider of MCP sync clients
* @return list of tool callbacks for MCP integration
*/
@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
matchIfMissing = true)
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<List<McpSyncClient>> syncMcpClients) {
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpToolFilter> syncClientsToolFilter,
ObjectProvider<List<McpSyncClient>> syncMcpClients) {
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
return new SyncMcpToolCallbackProvider(mcpClients);
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
mcpClients);
}

@Bean
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<McpToolFilter> asyncClientsToolFilter,
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
return new AsyncMcpToolCallbackProvider(mcpClients);
return new AsyncMcpToolCallbackProvider(
asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), mcpClients);
}

public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,30 @@

package org.springframework.ai.mcp.client.common.autoconfigure;

import org.junit.jupiter.api.Test;
import java.lang.reflect.Field;
import java.util.List;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.McpToolFilter;
import org.springframework.ai.mcp.McpClientMetadata;
import org.springframework.ai.mcp.McpMetadata;
import org.springframework.ai.mcp.McpServerMetadata;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Configuration;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Tests for {@link McpToolCallbackAutoConfigurationCondition}.
Expand Down Expand Up @@ -73,6 +88,66 @@ void doesMatchWhenBothPropertiesAreMissing() {
this.contextRunner.run(context -> assertThat(context).hasBean("testBean"));
}

@Test
void verifySyncToolCallbackFilterConfiguration() {
this.contextRunner
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
.withPropertyValues("spring.ai.mcp.client.type=SYNC")
.run(context -> {
assertThat(context).hasBean("mcpClientFilter");
SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class);
Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
field.setAccessible(true);
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
McpSyncClient syncClient1 = mock(McpSyncClient.class);
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
when(syncClient1.getClientInfo()).thenReturn(clientInfo1);
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
when(tool1.name()).thenReturn("tool1");
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
when(tool2.name()).thenReturn("tool2");
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
when(syncClient1.listTools()).thenReturn(listToolsResult1);
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()),
new McpServerMetadata(null)), tool1))
.isFalse();
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, syncClient1.getClientInfo()),
new McpServerMetadata(null)), tool2))
.isTrue();
});
}

@Test
void verifyASyncToolCallbackFilterConfiguration() {
this.contextRunner
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
.run(context -> {
assertThat(context).hasBean("mcpClientFilter");
AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class);
Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
field.setAccessible(true);
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
McpAsyncClient asyncClient1 = mock(McpAsyncClient.class);
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
when(asyncClient1.getClientInfo()).thenReturn(clientInfo1);
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
when(tool1.name()).thenReturn("tool1");
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
when(tool2.name()).thenReturn("tool2");
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()),
new McpServerMetadata(null)), tool1))
.isFalse();
assertThat(toolFilter.test(new McpMetadata(new McpClientMetadata(null, asyncClient1.getClientInfo()),
new McpServerMetadata(null)), tool2))
.isTrue();
});
}

@Configuration
@Conditional(McpToolCallbackAutoConfigurationCondition.class)
static class TestConfiguration {
Expand All @@ -84,4 +159,23 @@ String testBean() {

}

@Configuration
static class McpClientFilterConfiguration {

@Bean
McpToolFilter mcpClientFilter() {
return new McpToolFilter() {
@Override
public boolean test(McpMetadata metadata, McpSchema.Tool tool) {
if (metadata.mcpClientMetadata().clientInfo().name().equals("client1")
&& tool.name().contains("tool1")) {
return false;
}
return true;
}
};
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiPredicate;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Flux;

Expand Down Expand Up @@ -74,21 +72,21 @@
*/
public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {

private final List<McpAsyncClient> mcpClients;
private final McpToolFilter toolFilter;

private final BiPredicate<McpAsyncClient, Tool> toolFilter;
private final List<McpAsyncClient> mcpClients;

/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
* @param mcpClients the list of MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
* @param mcpClients the list of MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, List<McpAsyncClient> mcpClients) {
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpAsyncClient> mcpClients) {
Assert.notNull(mcpClients, "MCP clients must not be null");
Assert.notNull(toolFilter, "Tool filter must not be null");
this.mcpClients = mcpClients;
this.toolFilter = toolFilter;
this.mcpClients = mcpClients;
}

/**
Expand All @@ -106,10 +104,10 @@ public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP
* clients.
* @param mcpClients the MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
* @param mcpClients the MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, McpAsyncClient... mcpClients) {
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpAsyncClient... mcpClients) {
this(toolFilter, List.of(mcpClients));
}

Expand Down Expand Up @@ -147,7 +145,9 @@ public ToolCallback[] getToolCallbacks() {
ToolCallback[] toolCallbacks = mcpClient.listTools()
.map(response -> response.tools()
.stream()
.filter(tool -> this.toolFilter.test(mcpClient, tool))
.filter(tool -> this.toolFilter.test(new McpMetadata(
new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()),
new McpServerMetadata(mcpClient.initialize().block())), tool))
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
.toArray(ToolCallback[]::new))
.block();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.mcp;

import io.modelcontextprotocol.spec.McpSchema;

/**
* MCP client metadata record.
*
* @author Ilayaperumal Gopinathan
*/
public record McpClientMetadata(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.mcp;

/**
* MCP metadata record containing the client/server specific meta data.
*
* @author Ilayaperumal Gopinathan
*/
public record McpMetadata(McpClientMetadata mcpClientMetadata, McpServerMetadata mcpServermetadata) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.mcp;

import io.modelcontextprotocol.spec.McpSchema;

/**
* MCP server metadata record.
*
* @author Ilayaperumal Gopinathan
*/
public record McpServerMetadata(McpSchema.InitializeResult initializeResult) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.mcp;

import java.util.function.BiPredicate;

import io.modelcontextprotocol.spec.McpSchema;

/**
* A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the
* {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given
* {@link McpMetadata}.
*
* @author Ilayaperumal Gopinathan
*/
public interface McpToolFilter extends BiPredicate<McpMetadata, McpSchema.Tool> {

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
package org.springframework.ai.mcp;

import java.util.List;
import java.util.function.BiPredicate;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema.Tool;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
Expand Down Expand Up @@ -72,15 +70,15 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider {

private final List<McpSyncClient> mcpClients;

private final BiPredicate<McpSyncClient, Tool> toolFilter;
private final McpToolFilter toolFilter;

/**
* Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
* @param mcpClients the list of MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
*/
public SyncMcpToolCallbackProvider(BiPredicate<McpSyncClient, Tool> toolFilter, List<McpSyncClient> mcpClients) {
public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpSyncClient> mcpClients) {
Assert.notNull(mcpClients, "MCP clients must not be null");
Assert.notNull(toolFilter, "Tool filter must not be null");
this.mcpClients = mcpClients;
Expand All @@ -102,7 +100,7 @@ public SyncMcpToolCallbackProvider(List<McpSyncClient> mcpClients) {
* @param mcpClients the MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
*/
public SyncMcpToolCallbackProvider(BiPredicate<McpSyncClient, Tool> toolFilter, McpSyncClient... mcpClients) {
public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpSyncClient... mcpClients) {
this(toolFilter, List.of(mcpClients));
}

Expand Down Expand Up @@ -133,7 +131,9 @@ public ToolCallback[] getToolCallbacks() {
.flatMap(mcpClient -> mcpClient.listTools()
.tools()
.stream()
.filter(tool -> this.toolFilter.test(mcpClient, tool))
.filter(tool -> this.toolFilter.test(new McpMetadata(
new McpClientMetadata(mcpClient.getClientCapabilities(), mcpClient.getClientInfo()),
new McpServerMetadata(mcpClient.initialize())), tool))
.map(tool -> new SyncMcpToolCallback(mcpClient, tool)))
.toArray(ToolCallback[]::new);
validateToolCallbacks(array);
Expand Down
Loading