diff --git a/src/mcp_agent/agents/workflow/chain_agent.py b/src/mcp_agent/agents/workflow/chain_agent.py index 70a5604b..d8551de9 100644 --- a/src/mcp_agent/agents/workflow/chain_agent.py +++ b/src/mcp_agent/agents/workflow/chain_agent.py @@ -71,12 +71,21 @@ async def generate( # # Get the original user message (last message in the list) user_message = multipart_messages[-1] if multipart_messages else None + aggregator = getattr(self.context, "response_aggregator", None) + if not self.cumulative: response: PromptMessageMultipart = await self.agents[0].generate(multipart_messages) + if aggregator: + await aggregator.add_agent_response(self.agents[0].name, response.all_text()) # Process the rest of the agents in the chain for agent in self.agents[1:]: next_message = Prompt.user(*response.content) response = await agent.generate([next_message]) + if aggregator: + await aggregator.add_agent_response(agent.name, response.all_text()) + + if aggregator and await aggregator.should_send_response(): + await aggregator.get_aggregated_response() return response @@ -96,6 +105,8 @@ async def generate( chain_messages = multipart_messages.copy() chain_messages.extend(all_responses) current_response = await agent.generate(chain_messages, request_params) + if aggregator: + await aggregator.add_agent_response(agent.name, current_response.all_text()) # Store the response all_responses.append(current_response) @@ -111,10 +122,16 @@ async def generate( # For cumulative mode, return the properly formatted output with XML tags response_text = "\n\n".join(final_results) - return PromptMessageMultipart( + final_message = PromptMessageMultipart( role="assistant", content=[TextContent(type="text", text=response_text)], ) + if aggregator: + await aggregator.add_agent_response(self.name, response_text) + if await aggregator.should_send_response(): + await aggregator.get_aggregated_response() + + return final_message async def structured( self, diff --git a/src/mcp_agent/mcp_server/agent_server.py b/src/mcp_agent/mcp_server/agent_server.py index a0e1f972..cf0c240b 100644 --- a/src/mcp_agent/mcp_server/agent_server.py +++ b/src/mcp_agent/mcp_server/agent_server.py @@ -68,18 +68,29 @@ def register_agent_tools(self, agent_name: str, agent) -> None: ) async def send_message(message: str, ctx: MCPContext) -> str: """Send a message to the agent and return its response.""" - # Get the agent's context + from mcp_agent.agents.workflow.chain_agent import ChainAgent + + # For chain agents, handle execution without SSE aggregation + if isinstance(agent, ChainAgent): + response = await agent.send(message) + + if hasattr(response, "all_text"): + return response.all_text() + elif isinstance(response, dict): + import json + + return json.dumps(response) + return str(response) + + # Non-chain agents use normal flow agent_context = getattr(agent, "context", None) - # Define the function to execute async def execute_send(): return await agent.send(message) - # Execute with bridged context if agent_context and ctx: - return await self.with_bridged_context(agent_context, ctx, execute_send) - else: - return await execute_send() + return await self.with_bridged_context(agent_context, ctx, execute_send) + return await execute_send() # Register a history prompt for this agent @self.mcp_server.prompt( @@ -368,7 +379,14 @@ async def _close_sse_connections(self): except Exception as e: logger.error(f"Error during ASGI lifespan shutdown: {e}") - async def with_bridged_context(self, agent_context, mcp_context, func, *args, **kwargs): + async def with_bridged_context( + self, + agent_context, + mcp_context, + func, + *args, + **kwargs, + ): """ Execute a function with bridged context between MCP and agent @@ -397,6 +415,9 @@ async def bridged_progress(progress, total=None) -> None: if hasattr(agent_context, "progress_reporter"): agent_context.progress_reporter = bridged_progress + if aggregator is not None: + agent_context.response_aggregator = aggregator + try: # Call the function return await func(*args, **kwargs) @@ -408,6 +429,8 @@ async def bridged_progress(progress, total=None) -> None: # Remove MCP context reference if hasattr(agent_context, "mcp_context"): delattr(agent_context, "mcp_context") + if aggregator is not None and hasattr(agent_context, "response_aggregator"): + delattr(agent_context, "response_aggregator") async def _cleanup_stdio(self): """Minimal cleanup for STDIO transport to avoid keeping process alive.""" diff --git a/src/mcp_agent/server/response_aggregator.py b/src/mcp_agent/server/response_aggregator.py new file mode 100644 index 00000000..14d8f161 --- /dev/null +++ b/src/mcp_agent/server/response_aggregator.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict + + +class ChainResponseAggregator: + """Aggregate responses for a multi-agent chain.""" + + def __init__(self, chain_name: str, total_agents: int) -> None: + self.chain_name = chain_name + self.total_agents = total_agents + self.agent_responses: Dict[str, Any] = {} + self.completed_agents = 0 + self._response_sent = False + + async def add_agent_response(self, agent_name: str, response: Any) -> None: + """Record a response from an agent in the chain.""" + self.agent_responses[agent_name] = response + self.completed_agents += 1 + + async def should_send_response(self) -> bool: + """Return ``True`` if the aggregated response should be sent.""" + return not self._response_sent and self.completed_agents >= self.total_agents + + async def get_aggregated_response(self) -> Dict[str, Any]: + """Return the aggregated response for the chain.""" + self._response_sent = True + return {"chain": self.chain_name, "responses": self.agent_responses} + + +class SSEEventType(Enum): + AGENT_START = "agent_start" + AGENT_PROGRESS = "agent_progress" + AGENT_COMPLETE = "agent_complete" + CHAIN_COMPLETE = "chain_complete" + ERROR = "error" + + +async def send_sse_event(event_type: SSEEventType, data: Dict[str, Any], stream: Any) -> None: + """Send an SSE event to the provided stream if possible.""" + if stream is not None and hasattr(stream, "send"): + await stream.send({"event": event_type.value, "data": data}) diff --git a/tests/unit/mcp_agent/server/test_response_aggregator.py b/tests/unit/mcp_agent/server/test_response_aggregator.py new file mode 100644 index 00000000..909808a7 --- /dev/null +++ b/tests/unit/mcp_agent/server/test_response_aggregator.py @@ -0,0 +1,43 @@ +import importlib.util +from pathlib import Path + +import pytest + +MODULE_PATH = ( + Path(__file__).resolve().parents[4] / "src" / "mcp_agent" / "server" / "response_aggregator.py" +) +spec = importlib.util.spec_from_file_location("response_aggregator", MODULE_PATH) +response_aggregator = importlib.util.module_from_spec(spec) +assert spec.loader +spec.loader.exec_module(response_aggregator) + +ChainResponseAggregator = response_aggregator.ChainResponseAggregator +SSEEventType = response_aggregator.SSEEventType +send_sse_event = response_aggregator.send_sse_event + + +@pytest.mark.asyncio +async def test_chain_response_aggregator(): + agg = ChainResponseAggregator("chain", 2) + await agg.add_agent_response("a1", "one") + assert not await agg.should_send_response() + await agg.add_agent_response("a2", "two") + assert await agg.should_send_response() + result = await agg.get_aggregated_response() + assert result["chain"] == "chain" + assert result["responses"] == {"a1": "one", "a2": "two"} + + +class _DummyStream: + def __init__(self) -> None: + self.sent = [] + + async def send(self, data): + self.sent.append(data) + + +@pytest.mark.asyncio +async def test_send_sse_event(): + stream = _DummyStream() + await send_sse_event(SSEEventType.AGENT_START, {"foo": "bar"}, stream) + assert stream.sent == [{"event": "agent_start", "data": {"foo": "bar"}}]