feat(mcp): Add MCP sampling support #2239
base: main
Conversation
Implement comprehensive MCP sampling functionality following the MCP specification. Tested against Anthropic's `everything` server. 🤖 Assisted by [Amazon Q Developer](https://aws.amazon.com/q/developer)
@@ -508,6 +514,7 @@ impl ChatSession {
    model_id: Option<String>,
    tool_config: HashMap<String, ToolSpec>,
    interactive: bool,
    sampling_receiver: tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>,
Feel free to just bring these types in with use statements.
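For illustration, a minimal sketch of that suggestion, reusing the field and module paths from the diff above:

```rust
use tokio::sync::mpsc::UnboundedReceiver;

use crate::mcp_client::sampling_ipc::PendingSamplingRequest;

pub struct ChatSession {
    // ... existing fields ...
    /// Channel receiver for incoming sampling requests from MCP servers
    sampling_receiver: UnboundedReceiver<PendingSamplingRequest>,
}
```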
@@ -480,6 +484,8 @@ pub struct ChatSession {
    conversation: ConversationState,
    tool_uses: Vec<QueuedTool>,
    pending_tool_index: Option<usize>,
    /// Channel receiver for incoming sampling requests from MCP servers
    sampling_receiver: tokio::sync::mpsc::UnboundedReceiver<crate::mcp_client::sampling_ipc::PendingSamplingRequest>,
@@ -293,6 +293,9 @@ impl ChatArgs {
        .await?;
    let tool_config = tool_manager.load_tools(os, &mut stderr).await?;

    // Set the ApiClient for MCP clients that have sampling enabled
    tool_manager.set_streaming_client(std::sync::Arc::new(os.client.clone()));
@@ -2360,6 +2379,12 @@ mod tests {
        agents
    }

    #[cfg(test)]
This entire module is already configured to only be compiled during test. No need to annotate again here.
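For reference, a minimal sketch of the point being made (the test name here is illustrative):

```rust
#[cfg(test)]
mod tests {
    // Everything inside this module is already compiled only when testing,
    // so items declared here don't need their own #[cfg(test)] attribute.
    #[test]
    fn example_only_built_for_tests() {
        assert_eq!(1 + 1, 2);
    }
}
```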
// MCP Sampling types
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SamplingMessage {
    pub role: String,
    pub content: SamplingContent,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum SamplingContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image")]
    Image { data: String, mime_type: String },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
    pub hints: Option<Vec<ModelHint>>,
    pub cost_priority: Option<f64>,
    pub speed_priority: Option<f64>,
    pub intelligence_priority: Option<f64>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelHint {
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingRequest {
    pub messages: Vec<SamplingMessage>,
    pub model_preferences: Option<ModelPreferences>,
    pub system_prompt: Option<String>,
    pub include_context: Option<String>, // "none" | "thisServer" | "allServers"
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub stop_sequences: Option<Vec<String>>,
    pub metadata: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingResponse {
    pub role: String,
    pub content: SamplingContent,
    pub model: String,
    pub stop_reason: String,
}
These can go in crates/chat-cli/src/mcp_client/facilitator_types.rs
@@ -131,6 +193,8 @@ pub struct Client<T: Transport> {
    // TODO: move this to tool manager that way all the assets are treated equally
    pub prompt_gets: Arc<SyncRwLock<HashMap<String, PromptGet>>>,
    pub is_prompts_out_of_date: Arc<AtomicBool>,
    sampling_sender: Option<tokio::sync::mpsc::UnboundedSender<crate::mcp_client::sampling_ipc::PendingSamplingRequest>>,
    streaming_client: Arc<Mutex<Option<std::sync::Arc<crate::api_client::ApiClient>>>>,
This is a static shared reference of a lock of a nullable static shared reference of a client, whose inners are themselves static shared references.
Can you help me understand why we are putting the `Arc<ApiClient>` behind a mutex? It's guarding an `Arc<T>`, which in terms of mutability is just analogous to `&T`. That is already a shared reference.
Also, if this is nullable because we might not actually receive a streaming client in some cases, OR because there exists a period of time during the initialization of the application where we won't have enough information to assign the `Client<T>` one, then `Option<Arc<T>>` makes more sense here (as opposed to `Arc<Option<T>>`), not that I think we actually need an `Arc` here.
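Side by side, the two shapes under discussion (struct names here are illustrative, not from the PR):

```rust
use std::sync::{Arc, Mutex};

use crate::api_client::ApiClient;

// Shape currently in the PR: a lock guarding a nullable shared reference.
struct WithLock {
    streaming_client: Arc<Mutex<Option<Arc<ApiClient>>>>,
}

// Shape suggested in the review: still nullable, but without the lock and
// the outer Arc (and possibly without the inner Arc either).
struct WithoutLock {
    streaming_client: Option<Arc<ApiClient>>,
}
```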
I'm not too familiar with Rust's compiler restrictions (hence "Assisted by Q Developer" :) ), but the overall dataflow here is that `streaming_client` is initially empty, and then we update the Mutex to contain the global client from the general `chat_cli` environment.
As I understood, since the Mutex is being passed through different classes and called potentially in parallel (to get its contents or see that it's empty), to avoid race conditions it needs to be inside an `Arc` regardless of whether it's been populated with an atomic `api_client` object. Then the internal `Arc` is just to reflect the type of the `api_client` object passed into the Option.
Is this understanding incorrect?
I don't think there is a race condition here. The `send_message` method takes a shared reference:
pub async fn send_message(&self, conversation: ConversationState) -> Result<SendMessageOutput, ApiClientError> {
There is no need to lock.
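Concretely, a sketch of what a lock-free call could look like, given the `send_message` signature quoted above (the helper name and the `ConversationState` import path are assumptions):

```rust
use std::sync::Arc;

use crate::api_client::{ApiClient, ApiClientError, SendMessageOutput};
use crate::cli::chat::ConversationState; // exact path is an assumption

/// With `Option<Arc<ApiClient>>` there is nothing to lock; `send_message`
/// only needs `&self`, so the Arc can be used directly from any task.
async fn sample_once(
    streaming_client: &Option<Arc<ApiClient>>,
    conversation: ConversationState,
) -> Option<Result<SendMessageOutput, ApiClientError>> {
    match streaming_client {
        Some(client) => Some(client.send_message(conversation).await),
        None => None, // sampling not enabled, or client not yet provided
    }
}
```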
match client_ref.handle_sampling_request(req_clone.clone()).await {
    Ok(response) => {
        let msg = JsonRpcMessage::Response(response);
        if let Err(e) = transport_ref.send(&msg).await {
Admittedly I haven't looked into the use case for sampling a whole lot. I think you mentioned that you are writing a server that will leverage this in some way. Can you tell me a bit about how it leverages what is being sent here? The disconnect I have here is that the content in this JSON-RPC response is unstructured and takes whatever form the LLM produces.
The purpose of sampling is to let the MCP server call an LLM, without needing to have its own internal LLM setup/auth/payments/etc for the user to manage. These can be used either to generate some portion of a final tool output sent to the Agent (basically dynamic summarization/prompt-engineering) or to chain them together in structured data-processing pipelines for more reliable behavior than the general user-side Agent could manage.
For instance, one of my use-cases is an MCP debugging tool for our team's service. There are cases where we know data needs to be processed with an LLM, but it either needs to be prompted in some specific way (e.g. specific reasoning patterns for analyzing debug output) or just takes up a lot of context (e.g. anomaly detection in log/metric data).
For these cases, prompting the LLM Agent to do this itself is:
A) Fragile, and prone to hallucination, deviance, or negative impact on the user's prompt engineering (via strong optimization pressure from our own prompts)
B) Potentially expensive in token-count
C) Difficult to prompt-engineer the outputs for continued investigation, since heuristic summarization rarely generalizes and returning large raw datasets causes token inefficiency and hallucinations.
On the other hand, if we can query the LLM ourselves outside the LLM agent's context and history, we get significantly more control over how it's applied to analysis problems, and can use the LLM to summarize and prompt-engineer our output, providing more reliable results with better context window usage.
(The other use-case I mentioned isn't fleshed out enough to describe in detail, but from exploratory prototypes it definitely needs asynchronous sampling, again to control/parallelize LLM calls for analyzing/summarizing domain-specific non-LLM data better than the agent would)
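To make the data flow concrete, here is a sketch of a sampling request such a server might issue, built with the types added in this PR (the module path, prompt text, and preference values are illustrative):

```rust
use crate::mcp_client::{ModelHint, ModelPreferences, SamplingContent, SamplingMessage, SamplingRequest};

// Hypothetical server-initiated request: summarize log data outside the
// agent's own context window, with a server-controlled system prompt.
fn example_sampling_request() -> SamplingRequest {
    SamplingRequest {
        messages: vec![SamplingMessage {
            role: "user".to_string(),
            content: SamplingContent::Text {
                text: "Summarize the anomalies in this log excerpt: ...".to_string(),
            },
        }],
        model_preferences: Some(ModelPreferences {
            hints: Some(vec![ModelHint { name: "claude".to_string() }]),
            cost_priority: Some(0.3),
            speed_priority: Some(0.2),
            intelligence_priority: Some(0.9),
        }),
        system_prompt: Some("You are a log-analysis assistant.".to_string()),
        include_context: Some("none".to_string()),
        temperature: Some(0.2),
        max_tokens: Some(1024),
        stop_sequences: None,
        metadata: None,
    }
}
```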
).map_err(|e| ClientError::NegotiationError(format!("Invalid sampling request: {}", e)))?;

// Check if sampling is enabled for this server
if self.sampling_sender.is_some() {
I am struggling to see how this is used beyond just being checked for its existence. Can you please help me understand if I have missed it anywhere?
This is a good point. I think we're just using `transport_ref` to send the data now, so probably this can be switched out with a boolean.
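A sketch of what the boolean version might look like (the type, field, and method names here are hypothetical; only the `NegotiationError` variant comes from the PR):

```rust
use crate::mcp_client::ClientError; // path is an assumption

/// Illustrative gate using a plain flag instead of checking the sender for Some(..).
struct SamplingGate {
    /// Set from the server config's `sampling: true` entry.
    sampling_enabled: bool,
}

impl SamplingGate {
    fn ensure_sampling_allowed(&self) -> Result<(), ClientError> {
        if !self.sampling_enabled {
            return Err(ClientError::NegotiationError(
                "sampling is not enabled for this server".to_string(),
            ));
        }
        Ok(())
    }
}
```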
    request.params.unwrap_or_default()
).map_err(|e| ClientError::NegotiationError(format!("Invalid sampling request: {}", e)))?;

// Check if sampling is enabled for this server
Should we perhaps perform all the sanity checks at the call sites of `handle_sampling_request`? Not entirely sure what other logic you have in mind, but if this is the only flow we could do ourselves a favor here and not have to go through the trouble of cloning and deserializing only to find out we don't have what it takes to sample in the first place.
When initializing the MCP servers, we need to know whether we'll tell them that the client supports sampling (since I couldn't figure out a workable UX for approving individual sampling requests outside the MCP config); so we at least need sanity checks before and at MCP initialization.
Not sure what other sanity checks you're proposing to move into `handle_sampling_request`?
let model_name = request.model_preferences
    .as_ref()
    .and_then(|prefs| prefs.hints.as_ref())
    .and_then(|hints| hints.first())
    .map(|hint| hint.name.clone())
    .unwrap_or_else(|| "default-model".to_string());
I think we have already obtained this info when we constructed `user_message`.
Ah, good point, we can just get the name from there.
@@ -850,12 +863,24 @@ impl Clone for ToolManager {
    is_interactive: self.is_interactive,
    mcp_load_record: self.mcp_load_record.clone(),
    disabled_servers: self.disabled_servers.clone(),
    sampling_request_sender: self.sampling_request_sender.clone(),
I don't see how this field is used. Why does `ToolManager` need to own an instance of this mpsc sender?
while let Ok(mut sampling_request) = self.sampling_receiver.try_recv() {
    tracing::info!(target: "mcp", "Auto-approving sampling request from configured server: {}", sampling_request.server_name);

    // Automatically approve the sampling request
    sampling_request.send_approval_result(
        crate::mcp_client::sampling_ipc::SamplingApprovalResult::approved()
    );
}
Not entirely sure what this is doing. Some questions:
- This is sending back an approval every time?
- What is receiving what `send_approval_result` is sending?
- This is being called in the main chat loop under `prompt_user`, which means this will only get run when the user is to be prompted (i.e. after Q CLI has responded). Correct me if I am wrong here, but I think sampling would occur during a tool call (i.e. before the user is to have their turn again). To me this is not the right point in the logic flow to check and respond to a sampling request. Let me know if I had misunderstood anything.
> I think sampling would occur during a tool call (i.e. before the user is to have their turn again).
The docs don't mention anything either way (though the dataflow diagram doesn't mention an ongoing tool-call when the server initiates sampling), but to my understanding this is incorrect.
MCP servers support long-running operations (see for instance FastMCP's doc on progress reporting, which is meant to support such operations); it seems strange to support sampling for tool-calls but not such long-running operations.
There is also at least one real customer use-case on the MCP discussions where it seems the server proactively starts sampling requests even though the client isn't explicitly expecting them.
> What is receiving what `send_approval_result` is sending?

This was missed in a refactoring; as mentioned in this other thread, `transport_ref` is handling sending data to the server now, so we can probably get rid of the `sampling_result` call. Will test that this works when refactoring.

> This is sending back an approval every time?
We're filtering whether a server has access to sampling by whether we provide a sampling callback at all, since right now it's a binary trust/don't-trust situation. If a server is able to make a sampling request and reach this section, it should be allowed.
Supporting manual approve/deny is difficult from a CLI interface given we only have one input-stream (and we can't rely on things like popups working without breaking customer orchestrations, either).
Adding approvals only during tool calls and then rejecting untrusted sampling requests outside that context is maybe doable, but I don't think it should block an initial implementation of the feature.
EDIT: The code-as-written of the python SDK also seems to support sampling at arbitrary times with a similar callback-based approach as the one we're using here (assuming I'm reading it correctly).
@@ -284,7 +284,7 @@ impl ChatArgs {
    info!(?conversation_id, "Generated new conversation id");
    let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::<Option<String>>();
    let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::<Vec<String>>();
-   let mut tool_manager = ToolManagerBuilder::default()
+   let (mut tool_manager, sampling_receiver) = ToolManagerBuilder::default()
I had left my question below on how `sampling_receiver` is used. That question aside, I think `sampling_receiver` can just be owned by `ToolManager`.
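In outline, the alternative being suggested might look roughly like this (only the new field is shown; module paths come from the diff):

```rust
use tokio::sync::mpsc::UnboundedReceiver;

use crate::mcp_client::sampling_ipc::PendingSamplingRequest;

// The receiver lives on ToolManager itself, rather than being returned as a
// second value from ToolManagerBuilder and then threaded through ChatSession.
pub struct ToolManager {
    // ... existing fields ...
    sampling_receiver: UnboundedReceiver<PendingSamplingRequest>,
}
```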
Implement comprehensive MCP sampling functionality following the MCP specification.
Tested against Anthropic's `everything` server. 🤖 Assisted by Amazon Q Developer
Description of changes:
This PR adds support for sampling, which allows MCP servers to make LLM calls using the Agent's own LLM backend.
This is currently supported by very few clients, but has been a feature of the MCP protocol since the initial 2024-11-05 version.
It is also a very useful feature for MCP authors, who often have a use for writing their own subagents, summarizing information without using up chat context, guiding prompts better than the LLM Agent might, prompt-engineering to better control an Agent's response to their output, etc. Sampling can significantly improve the quality of an MCP server, and the server can still fall back to non-sampling approaches if some other client doesn't support it.
UX: To enable sampling, set `sampling: true` in the config for an MCP server, after which any sampling requests from that server are accepted.
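For illustration, a config entry with sampling enabled might look roughly like this (the server name, command, and surrounding keys are assumptions; `sampling: true` is the relevant flag):

```json
{
  "mcpServers": {
    "everything": {
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-everything"],
      "sampling": true
    }
  }
}
```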
Tested against Anthropic's `everything` server. `sampleLLM` successfully generates responses and seems to also respect token limits.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.