Draft
Changes from all commits
52 commits
812aafc  first mcp (AlexPiche, Aug 15, 2025)
ca8516b  fix the env server (AlexPiche, Aug 16, 2025)
f3af1bc  tweak prompt (AlexPiche, Aug 16, 2025)
5b10c33  upd (AlexPiche, Aug 16, 2025)
d2e6d09  clean up (AlexPiche, Aug 18, 2025)
228cb42  hard code dino (AlexPiche, Aug 18, 2025)
fdf3c83  less envs (AlexPiche, Aug 18, 2025)
1165397  less envs (AlexPiche, Aug 18, 2025)
40a144a  longer timeout (AlexPiche, Aug 18, 2025)
2d25d88  longer seq length (AlexPiche, Aug 18, 2025)
2036167  more envs (AlexPiche, Aug 18, 2025)
664b539  more llms per actor (AlexPiche, Aug 18, 2025)
4b0db03  even more envs (AlexPiche, Aug 18, 2025)
63d4092  longer timeout and revert prompt (AlexPiche, Aug 18, 2025)
6d81456  retry task (AlexPiche, Aug 18, 2025)
373b0ac  pid deno module (AlexPiche, Aug 18, 2025)
e2de768  diff deno tmp dir (AlexPiche, Aug 18, 2025)
763b594  none node modules (AlexPiche, Aug 18, 2025)
0783570  bigger timeout (AlexPiche, Aug 19, 2025)
b284fcb  diff temp dir for each mcp (AlexPiche, Aug 19, 2025)
eb48d90  0.0.0.0 (AlexPiche, Aug 19, 2025)
efa2717  filter based on port (AlexPiche, Aug 19, 2025)
3d86a28  change port to 7778 (AlexPiche, Aug 21, 2025)
96a75c1  mcp and verify server (AlexPiche, Aug 21, 2025)
0b4c992  use custom parser (AlexPiche, Aug 21, 2025)
471d28d  relative path (AlexPiche, Aug 21, 2025)
8e0eeff  test apth (AlexPiche, Aug 21, 2025)
f93d756  typo (AlexPiche, Aug 21, 2025)
32e3eb6  clean up (AlexPiche, Aug 21, 2025)
5a3ab0e  clean up (AlexPiche, Aug 21, 2025)
436e233  rename domain to mcp (AlexPiche, Aug 22, 2025)
366263b  more envs (AlexPiche, Aug 22, 2025)
371be6e  less env replicas (AlexPiche, Aug 22, 2025)
1045868  Merge remote-tracking branch 'origin/debug_miniwob' into mcp_tir (AlexPiche, Aug 22, 2025)
05f7667  Merge remote-tracking branch 'origin/debug_miniwob' into mcp_tir (AlexPiche, Aug 22, 2025)
46b39d1  clean up tmp (AlexPiche, Aug 22, 2025)
af63f51  change mcp dir (AlexPiche, Aug 22, 2025)
55a96e5  bigger model len (AlexPiche, Aug 22, 2025)
dd0ea2b  typo (AlexPiche, Aug 22, 2025)
dc4052d  typo (AlexPiche, Aug 23, 2025)
bb4d0c5  clean up (AlexPiche, Aug 26, 2025)
ccdcd32  center reward (AlexPiche, Aug 26, 2025)
7f5ed95  running avg reward (AlexPiche, Aug 26, 2025)
88a0ee7  start from real mean (AlexPiche, Aug 26, 2025)
66bcfbd  Fix paths (rafapi, Aug 28, 2025)
3fcb847  Use relative path (rafapi, Aug 28, 2025)
9f239c6  Fix path (rafapi, Aug 28, 2025)
020a021  revert mktemp changes (rafapi, Aug 28, 2025)
4323f57  Fix deno paths (rafapi, Aug 29, 2025)
2b5e9f5  udt (rafapi, Aug 29, 2025)
565d25c  make the cache tag stable across all processes (rafapi, Aug 29, 2025)
fc17df7  fix (rafapi, Aug 30, 2025)
.gitignore (3 changes: 2 additions & 1 deletion)
@@ -120,6 +120,7 @@ celerybeat.pid

 # SageMath parsed files
 *.sage.py
+node_modules/

 # Environments
 .env
@@ -185,4 +186,4 @@ results
 results/
 data/
 cache/
-dump.rdb
+dump.rdb
conf/base.yaml (2 changes: 1 addition & 1 deletion)
@@ -67,7 +67,7 @@ vllm_config:
   tensor-parallel-size: 1
   pipeline-parallel-size: 1
   generation-config: vllm
-  max_model_len: 10000
+  max_model_len: 16000

 world:
   replicas: 1
conf/finetune/base.yaml (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@ learning_rate: 1e-6
 # How much to clip the gradient (no clipping if null)
 gradient_clipping_threshold: 0.3
 # Learning rate scheduler type (indexed by completed_steps).
-lr_scheduler_type: cosine # could be cosine, constant_with_warmup
+lr_scheduler_type: constant # could be cosine, constant_with_warmup
 # Number of warmup (completed) steps in the learning rate schedule.
 num_warmup_steps: 50
 # Number of gradient accumulation steps.
conf/mcp.yaml (133 changes: 133 additions & 0 deletions)
@@ -0,0 +1,133 @@
defaults:
- base
- _self_


actor:
rollout_policy: pipelinerl.domains.mcp.generate_mcp_rollout
system_prompt: Please reason step by step, and put your final answer within \boxed{}.
llm_max_rollouts: 64
task_template: |-
{task}

finetune:
seq_length: 48000
seq_parallel: 4

dataset_loader: pipelinerl.domains.math.load_datasets
train_dataset_names:
- open_reasoner_zero_57k
- open_reasoner_zero_extended_72k
test_dataset_names:
- aime_2024
- amc_2023
- math_500

vllm_config:
use_v1: false
vllm_kwargs:
enable-auto-tool-choice: ""
tool-call-parser: rl_tool
tool-parser-plugin: ${hydra:runtime.cwd}/pipelinerl/rl_tool_parser_plugin.py
max_model_len: 40960

environment:
_target_: pipelinerl.domains.mcp.MCPEnvironmentServer
n_envs: 8
host: "0.0.0.0"
exp_path: ${output_dir}/env_server
mcp_target: tapeagents.mcp.MCPEnvironment
mcp_config_path: ${hydra:runtime.cwd}/conf/mcp/python.json
mcp_tools_whitelist:
- run_python_code
env_call_timeout: 600 # Increased from default 60s to 10 minutes
mcp_read_timeout_seconds: 3000


world:
env_replicas_per_actor: 8

agent_max_loops: 3
agent:
_target_: tapeagents.agent.Agent
name: mcp_agent
max_iterations: 3
store_llm_calls: true
templates:
system_prompt: |
You are an expert AI Agent trained to assist users with complex information processing tasks.
Your role is to understand user queries and respond in a helpful and accurate manner.
Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration.
Do not express emotions or opinions about user questions.
allowed_tools: |
You have access to the following tools:
{tools_description}
thought_format: |
Important! Respond with the plain text, do not include any JSON or code.
Do not output anything besides what I asked in this message.
allowed_steps: |
You have access to the following tools:
{tools_description}
format: >
Output only a single JSON dict.
Do not repeat the last thought again.
If the last action does not change the observation, do not repeat it!
DO NOT OUTPUT ANYTHING BESIDES THE JSON! DO NOT PLACE ANY COMMENTS INSIDE THE JSON.
It will break the system that processes the output.


nodes:
- _target_: tapeagents.nodes.StandardNode
name: plan
system_prompt: ${agent.templates.system_prompt}
guidance: |
Write a concise multi-step plan explaining which steps should be performed to find the answer for the given task.
Be specific about how each step should be performed. Only describe the intended actions here, do not perform them yet.
Consider that next steps may depend on results of previous steps, so include conditional branching using "if" statements where needed.
Start with the title "Plan". Every step should have a short name and description.
${agent.templates.thought_format}
steps_prompt: ${agent.templates.allowed_tools}

- _target_: tapeagents.nodes.StandardNode
name: select
system_prompt: ${agent.templates.system_prompt}
trim_obs_except_last_n: 100
guidance: |
Select the next step to do to move forward with the plan. Describe the expected effect of the proposed action.
${agent.templates.thought_format}
steps_prompt: ${agent.templates.allowed_tools}

- _target_: tapeagents.nodes.StandardNode
name: act
system_prompt: ${agent.templates.system_prompt}
trim_obs_except_last_n: 100
guidance: Then produce a single function call for the next step. If the answer is ready, call MathAnswer. Put your final answer within \boxed{}.
steps:
- pipelinerl.domains.mcp.steps.MathAnswer
use_known_actions: true
use_function_calls: true

- _target_: tapeagents.nodes.StandardNode
name: summarize
system_prompt: ${agent.templates.system_prompt}
trim_obs_except_last_n: 100
guidance: |
Summarize the last observation. If it's an image, thoroughly describe it with all details.
Describe the results of the last action and the observed changes.
Do not hallucinate or make up any information, only describe what you see in the observation.
Do not guess or assume action effects, describe only visible changes.
${agent.templates.thought_format}

- _target_: tapeagents.nodes.StandardNode
name: reflect
system_prompt: ${agent.templates.system_prompt}
trim_obs_except_last_n: 100
guidance: |
1. Evaluate the action's success, explain its effect on current step, overall plan and task solution.
2. If the last action was not successful, describe errors and the possible reasons for failure.
3. Check if the current plan step is finished.
4. If the step is finished, update the following steps of the plan with new information and choose the next step.
${agent.templates.thought_format}
next_node: select

model_path: Qwen/Qwen3-8B
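
The `environment` block above is an ordinary `_target_` node, so it can be built with Hydra's `instantiate` helper; its keyword arguments line up with the `MCPEnvironmentServer.__init__` signature added later in this PR. The sketch below shows that wiring only as an illustration: the resolved values, the use of `hydra.utils.instantiate`, and the port number are assumptions, not code from the PR.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirrors the `environment` section of conf/mcp.yaml, with interpolations resolved by hand.
env_cfg = OmegaConf.create({
    "_target_": "pipelinerl.domains.mcp.MCPEnvironmentServer",
    "n_envs": 8,
    "host": "0.0.0.0",
    "exp_path": "results/mcp/env_server",          # placeholder for ${output_dir}/env_server
    "mcp_target": "tapeagents.mcp.MCPEnvironment",
    "mcp_config_path": "conf/mcp/python.json",
    "mcp_tools_whitelist": ["run_python_code"],
    "env_call_timeout": 600,
    "mcp_read_timeout_seconds": 3000,
})

server = instantiate(env_cfg)   # equivalent to MCPEnvironmentServer(**kwargs)
server.launch(port=7778)        # blocks; 7778 is a placeholder taken from the commit history
```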
conf/mcp/python.json (11 changes: 11 additions & 0 deletions)
@@ -0,0 +1,11 @@
{
"mcpServers": {
"python_exec": {
"command": "bash",
"args": [
"-c",
"JOB_TAG=${MCP_JOB_TAG:-${JOB_ID:-$HOSTNAME}} && BASE=/home/toolkit/.cache && mkdir -p \"$BASE/mcp_tmp/$JOB_TAG\" \"$BASE/deno_mcp/$JOB_TAG\" \"$BASE/tmp/$JOB_TAG\" && export DENO_DIR=\"$BASE/deno_mcp/$JOB_TAG\" TMPDIR=\"$BASE/tmp/$JOB_TAG\" && /home/toolkit/.deno/bin/deno cache jsr:@pydantic/mcp-run-python >/dev/null 2>&1 || true; DIR=$(mktemp -d -p \"$BASE/mcp_tmp/$JOB_TAG\" mcp_XXXXXXXX) && cd \"$DIR\" && /home/toolkit/.deno/bin/deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python stdio; EC=$?; cd /; rm -rf \"$DIR\"; exit $EC"
]
}
}
}
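
The single `bash -c` command packs a fair amount of setup: it derives a job tag from `MCP_JOB_TAG`, `JOB_ID`, or the hostname, gives each job its own Deno cache (`DENO_DIR`) and temp dir (`TMPDIR`) under `/home/toolkit/.cache`, pre-warms the `jsr:@pydantic/mcp-run-python` cache, and then runs every MCP server instance from a disposable `mktemp -d` scratch directory that is removed on exit. A rough Python rendering of the same directory layout, for illustration only (not part of the PR):

```python
import os
import socket
import tempfile
from pathlib import Path

base = Path("/home/toolkit/.cache")
# JOB_TAG=${MCP_JOB_TAG:-${JOB_ID:-$HOSTNAME}}
job_tag = os.environ.get("MCP_JOB_TAG") or os.environ.get("JOB_ID") or socket.gethostname()

deno_dir = base / "deno_mcp" / job_tag   # exported as DENO_DIR: per-job Deno cache
tmp_dir = base / "tmp" / job_tag         # exported as TMPDIR
mcp_tmp = base / "mcp_tmp" / job_tag     # parent of the per-process scratch dirs

for d in (deno_dir, tmp_dir, mcp_tmp):
    d.mkdir(parents=True, exist_ok=True)

# Each MCP server process then works inside its own throwaway directory,
# like `mktemp -d -p "$BASE/mcp_tmp/$JOB_TAG" mcp_XXXXXXXX` in the shell command.
scratch = tempfile.mkdtemp(prefix="mcp_", dir=mcp_tmp)
print(scratch)
```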
pipelinerl/domains/math/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,3 +1,3 @@
 from .load_datasets import load_datasets
-from .rollouts import generate_math_rollout, RewardTable
+from .rollouts import generate_math_rollout, RewardTable, get_reward
 from .verifier_api import MathEnvironment, verify_answer, verify_answer_rpc
pipelinerl/domains/math/rollouts.py (47 changes: 25 additions & 22 deletions)
@@ -26,6 +26,28 @@ class RewardTable(BaseModel):
     correct_answer_finished: float
     buffer_tokens: int = 0 # 0 means no overlong reward shaping

+def get_reward(answer_status: str, finished: bool, reward_table: RewardTable) -> float:
+    match (answer_status, finished):
+        case ("wrong", False):
+            return reward_table.wrong_answer_not_finished
+        case ("wrong", True):
+            return reward_table.wrong_answer_finished
+        case ("no_answer", False):
+            return reward_table.no_answer_not_finished
+        case ("no_answer", True):
+            return reward_table.no_answer_finished
+        case ("unparsable", False):
+            return reward_table.unparsable_not_finished
+        case ("unparsable", True):
+            return reward_table.unparsable_finished
+        case ("correct", False):
+            return reward_table.correct_answer_not_finished
+        case ("correct", True):
+            return reward_table.correct_answer_finished
+        case _:
+            raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}")
+
+
 def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float:
     """
     Compute the overlong penalty
@@ -51,7 +73,7 @@ async def generate_math_rollout(
     latency = time.time() - time_start

     assert llm_call.output.content is not None
-    rewards = RewardTable(**dict(cfg.rewards))
+    reward_table = RewardTable(**dict(cfg.rewards))
     discount_factor = cfg.actor.discount_factor

     # math_verify is a fast environment, no support for environment replicas for now
@@ -70,30 +92,11 @@

     trace = make_training_text(llm, llm_call)
     # Determine reward based on answer status and finished state
-    match (answer_status, trace.finished):
-        case ("wrong", False):
-            reward = rewards.wrong_answer_not_finished
-        case ("wrong", True):
-            reward = rewards.wrong_answer_finished
-        case ("no_answer", False):
-            reward = rewards.no_answer_not_finished
-        case ("no_answer", True):
-            reward = rewards.no_answer_finished
-        case ("unparsable", False):
-            reward = rewards.unparsable_not_finished
-        case ("unparsable", True):
-            reward = rewards.unparsable_finished
-        case ("correct", False):
-            reward = rewards.correct_answer_not_finished
-        case ("correct", True):
-            reward = rewards.correct_answer_finished
-        case _:
-            raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{trace.finished}")
-
+    reward = get_reward(answer_status, trace.finished, reward_table)
     # Apply discount factor based on output length
     reward *= discount_factor**llm_call.output_length_tokens
     overlong_penalty = 0
-    if rewards.buffer_tokens > 0:
+    if reward_table.buffer_tokens > 0:
         overlong_penalty = length_penalty(llm.parameters['max_tokens'], llm_call.output_length_tokens, rewards.buffer_tokens)
     reward += overlong_penalty
     trace.reward = reward
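
Pulling the reward lookup out into `get_reward` (and re-exporting it from `pipelinerl.domains.math`) lets the new MCP rollout reuse the same status-to-reward mapping as the math rollout. A minimal usage sketch follows; the numeric values are illustrative (the real ones come from `cfg.rewards`), and it assumes `RewardTable` has exactly the eight status fields referenced in `get_reward`.

```python
from pipelinerl.domains.math import RewardTable, get_reward

# Illustrative values only; in training these come from cfg.rewards.
table = RewardTable(
    wrong_answer_not_finished=-1.0,
    wrong_answer_finished=-1.0,
    no_answer_not_finished=-1.0,
    no_answer_finished=-1.0,
    unparsable_not_finished=-1.0,
    unparsable_finished=-1.0,
    correct_answer_not_finished=0.5,
    correct_answer_finished=1.0,
)

assert get_reward("correct", True, table) == 1.0
assert get_reward("no_answer", False, table) == table.no_answer_not_finished
```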
pipelinerl/domains/mcp/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -0,0 +1,2 @@
from .rollouts import generate_mcp_rollout
from .env_server import MCPEnvironmentServer
pipelinerl/domains/mcp/env_server.py (101 changes: 101 additions & 0 deletions)
@@ -0,0 +1,101 @@
import os
from tapeagents.remote_environment import EnvironmentServer
from omegaconf import OmegaConf
from typing import List
from fastapi import HTTPException
from pydantic import BaseModel
import logging
import asyncio
from concurrent.futures import ProcessPoolExecutor
from functools import partial

from pipelinerl.domains.math.verifier_api import verify_answer

logger = logging.getLogger(__name__)


class EnvironmentServerWithVerifier(EnvironmentServer):
"""Environment server that includes the verify_answer endpoint."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.process_pool = ProcessPoolExecutor(max_workers=4)

def create_app(self):
app = super().create_app()

class VerifyAnswerRequest(BaseModel):
prediction: str
gold: str
strict: bool = True
max_prediction_length: int = 1000

@app.post("/verify_answer")
async def verify_answer_endpoint(request: VerifyAnswerRequest):
try:
# Run verification in the process pool to avoid blocking the main thread
loop = asyncio.get_event_loop()
answer_status = await loop.run_in_executor(
self.process_pool,
partial(
verify_answer,
request.prediction,
request.gold,
request.strict,
request.max_prediction_length
)
)
return {"answer_status": answer_status}
except Exception as e:
logger.exception(f"Error in verify_answer: {e}")
raise HTTPException(status_code=500, detail=f"Error verifying answer: {str(e)}")

return app

def shutdown(self):
super().shutdown()
if hasattr(self, 'process_pool'):
self.process_pool.shutdown(wait=True)


class MCPEnvironmentServer:

def __init__(self,
n_envs: int,
host: str,
mcp_target: str,
mcp_config_path: str,
mcp_tools_whitelist: List[str],
exp_path: str,
env_call_timeout: int = 60,
mcp_read_timeout_seconds: int = 10,
):
# Remote environment server configuration
self.n_envs = n_envs
self.host = host
self.env_call_timeout = env_call_timeout
# Individual web environment configuration
self.mcp_target = mcp_target
self.mcp_config_path = mcp_config_path
self.mcp_tools_whitelist = mcp_tools_whitelist
self.exp_path = exp_path
self.mcp_read_timeout_seconds = mcp_read_timeout_seconds


def launch(self, port: int):
"""
Serve the environment in TapeAgent with verify_answer endpoint.
"""
env_server = EnvironmentServerWithVerifier(
n_envs=self.n_envs,
host=self.host,
port=port,
env_call_timeout=self.env_call_timeout
)
env_server.launch(OmegaConf.create({
"_target_": self.mcp_target,
"config_path": self.mcp_config_path,
"tools_whitelist": self.mcp_tools_whitelist,
"read_timeout_seconds": self.mcp_read_timeout_seconds,
}))
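
Because the verifier now rides on the same FastAPI app as the remote environment, rollout workers can check math answers over HTTP against the env server they already talk to for tool calls. A sketch of calling the new route; the host and port are placeholders, and the actual rollout code may well use an async HTTP client instead of `requests`.

```python
import requests

# POST body mirrors VerifyAnswerRequest; host/port are placeholders.
resp = requests.post(
    "http://localhost:7778/verify_answer",
    json={
        "prediction": r"\boxed{42}",
        "gold": "42",
        "strict": True,
        "max_prediction_length": 1000,
    },
    timeout=60,
)
resp.raise_for_status()
# answer_status is one of the values get_reward understands,
# e.g. "correct", "wrong", "no_answer", or "unparsable".
print(resp.json()["answer_status"])
```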
