diff --git a/conf/rewards/tir_pure_success.yaml b/conf/rewards/tir_pure_success.yaml
new file mode 100644
index 00000000..8dfc9396
--- /dev/null
+++ b/conf/rewards/tir_pure_success.yaml
@@ -0,0 +1,9 @@
+correct_answer: 1.0
+wrong_answer: -0.15
+no_answer: -0.25
+unparsable: -0.25
+execution_failure: -0.05
+successful_code_execution: 0.0
+timeout_penalty: -0.25
+buffer_tokens: 0
+iteration_penalty: -0.005
diff --git a/conf/tir.yaml b/conf/tir.yaml
index 523e747a..3fa9ad08 100644
--- a/conf/tir.yaml
+++ b/conf/tir.yaml
@@ -6,83 +6,92 @@ save_tapes: true
llm:
parameters:
- max_tokens: 3000
+    max_tokens: 512 # Reduced from 3000 to help stay within context limits
temperature: 0.2
+ stop:
+ - "```output"
+
+test_llm:
+ parameters:
+ max_tokens: 512 # Reduced from 1024 to help stay within context limits
+ temperature: 0.2
+ stop:
+ - "```output"
actor:
rollout_policy: pipelinerl.domains.tir.rollouts.generate_tir_rollout
- # TIR mode: 'fast' (single candidate + iterative reasoning) or 'sc_tir' (multiple candidates + majority voting)
- mode: fast # Default to fast mode
- # SC-TIR parameters (only used when mode=sc_tir)
- num_candidates: 4 # Width: number of solution candidates to generate
- max_reasoning_steps: 8 # Depth: max reasoning iterations per candidate
- system_prompt: |-
- You are a mathematical problem-solving assistant. Use Python code to solve problems step by step.
-
- Instructions:
- 1. Read the problem carefully
- 2. Write Python code to solve it, showing your work
- 3. Execute the code and analyze the output
- 4. If needed, write more code to refine your solution
- 5. Once you have the final answer, present it in \boxed{} format
-
- Code format:
- ```python
- # Your Python code here
- ```
-
- The code will be executed and you'll see the output like:
- ```output
- # Execution results will appear here
- ```
-
- You can write multiple code blocks if needed. Use libraries like sympy, numpy, math as needed.
+ max_reasoning_steps: 6 # Reduced from 8 to encourage concise reasoning
+ system_prompt: ""
+  # TODO: remove debug code (presumably the llm_max_rollouts override below)
+ llm_max_rollouts: 1
+
+ # system_prompt: |-
+ # You are a mathematical problem-solving assistant. Use Python code to solve problems step by step.
+
+ # Instructions:
+ # 1. Read the problem carefully
+ # 2. Write Python code to solve it, showing your work
+ # 3. Execute the code and analyze the output
+ # 4. If needed, write more code to refine your solution
+ # 5. Once you have the final answer, present it in \boxed{} format
+
+ # Code format:
+ # ```python
+ # # Your Python code here
+ # ```
+
+ # The code will be executed and you'll see the output like:
+ # ```output
+ # # Execution results will appear here
+ # ```
+
+ # You can write multiple code blocks if needed. Use libraries like sympy, numpy, math as needed.
- Always end with your final answer in the format: \boxed{answer}
-
- Examples:
-
- Problem: What is 2 + 3?
- ```python
- result = 2 + 3
- print(f"2 + 3 = {result}")
- ```
- ```output
- 2 + 3 = 5
- ```
- \boxed{5}
-
- Problem: Solve x^2 - 5x + 6 = 0
- ```python
- from sympy import symbols, solve, expand
- x = symbols('x')
- equation = x**2 - 5*x + 6
- solutions = solve(equation, x)
- print(f"Solutions: {solutions}")
+ # Always end with your final answer in the format: \boxed{answer}
+
+ # Examples:
+
+ # Problem: What is 2 + 3?
+ # ```python
+ # result = 2 + 3
+ # print(f"2 + 3 = {result}")
+ # ```
+ # ```output
+ # 2 + 3 = 5
+ # ```
+ # \boxed{5}
+
+ # Problem: Solve x^2 - 5x + 6 = 0
+ # ```python
+ # from sympy import symbols, solve, expand
+ # x = symbols('x')
+ # equation = x**2 - 5*x + 6
+ # solutions = solve(equation, x)
+ # print(f"Solutions: {solutions}")
- # Verify by substitution
- for sol in solutions:
- check = sol**2 - 5*sol + 6
- print(f"x = {sol}: {sol}^2 - 5*{sol} + 6 = {check}")
- ```
- ```output
- Solutions: [2, 3]
- x = 2: 2^2 - 5*2 + 6 = 0
- x = 3: 3^2 - 5*3 + 6 = 0
- ```
- The solutions are x = 2 and x = 3.
- \boxed{2, 3}
- task_template: |-
- {task}
+ # # Verify by substitution
+ # for sol in solutions:
+ # check = sol**2 - 5*sol + 6
+ # print(f"x = {sol}: {sol}^2 - 5*{sol} + 6 = {check}")
+ # ```
+ # ```output
+ # Solutions: [2, 3]
+ # x = 2: 2^2 - 5*2 + 6 = 0
+ # x = 3: 3^2 - 5*3 + 6 = 0
+ # ```
+ # The solutions are x = 2 and x = 3.
+ # \boxed{2, 3}
+ # task_template: |-
+ # {task}
model_path: /mnt/llmd/base_models/AI-MO-NuminaMath-7B-TIR
-output_dir: results/tir/${now:%Y-%m-%d}/${now:%H-%M-%S}
+output_dir: /mnt/llmd/results/exps/rafa/tir/${now:%Y-%m-%d}/${now:%H-%M-%S}
world:
env_replicas: 1
-max_loops: 10 # extra iterations
+max_loops: 25 # let the agent handle its own termination
environment:
_target_: pipelinerl.domains.tir.environment.MCPPythonEnvironment
@@ -90,13 +99,17 @@ environment:
agent:
_target_: pipelinerl.domains.tir.agent.TIRMathAgent
system_prompt: ${actor.system_prompt}
+ max_reasoning_pairs: 3 # Number of recent code-execution pairs to keep in context
+ max_code_chars: 800 # Maximum characters of code per step
+ max_output_chars: 2000 # Maximum characters of execution output per step
dataset_loader: pipelinerl.domains.tir.datasets.load_datasets
train_dataset_names:
- - math_train
+ - open_reasoner_zero_57k
+ - open_reasoner_zero_extended_72k
test_dataset_names:
- # - math_test
- aime_2024
- # - amc_2023
\ No newline at end of file
+ - amc_2023
+ - math_500
\ No newline at end of file
diff --git a/conf/tir_sc.yaml b/conf/tir_sc.yaml
deleted file mode 100644
index e144fc41..00000000
--- a/conf/tir_sc.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-defaults:
- - tir
- - _self_
-
-# Override to use SC-TIR mode with multiple candidates and majority voting
-actor:
- mode: sc_tir # Enable SC-TIR mode
- num_candidates: 4 # Width: number of solution candidates to generate
- max_reasoning_steps: 8 # Depth: max reasoning iterations per candidate
-
-# SC-TIR is slower, so we might want to reduce other settings for faster iteration
-output_dir: results/tir_sc/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index 8621250c..d1566e32 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -15,14 +15,14 @@
import aiohttp
import hydra
import uvloop
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field
from tapeagents.llms import TrainableLLM
from typing import Dict, List
import wandb
from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
-from pipelinerl.rollouts import RolloutResult
+from pipelinerl.rollouts import RolloutResult, BaseMetrics
from pipelinerl.shared_memory_array import SharedMemoryQueue
from pipelinerl.state import TrainerState
from pipelinerl.streams import (
@@ -257,6 +257,8 @@ def rollout_maker_entrypoint(
def random_iter(problems: list):
+ if not problems:
+ raise ValueError(f"Cannot iterate over empty problems list. No data was loaded.")
while True:
yield random.sample(problems, 1)[0]
@@ -348,6 +350,7 @@ def update_stats(self, rollout_results: List[RolloutResult]):
self.latency_list.append(result.latency)
self.model_versions_list.append(result.model_version)
domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result)
+ assert isinstance(result.metrics, BaseMetrics), "Metrics should be an instance of BaseMetrics"
all_metrics = result.metrics.model_dump() | domain_agnostic_metrics
for k, v in all_metrics.items():
if isinstance(v, list):
@@ -528,6 +531,24 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
for agg, sub_stats in calculate_stats(list_of_stats_per_metric_and_dataset).items():
stats[f"{dataset_name}/{metric_name}_{agg}"] = sub_stats
+ # Add clean dataset-specific pass rates for test evaluation
+ if not self.is_training and "success" in self.stats:
+ dataset_pass_rates = self._calculate_dataset_pass_rates()
+ stats.update(dataset_pass_rates)
+
+ # Log clean pass rates to console for easy viewing
+ if dataset_pass_rates:
+ logger.info("Dataset Pass Rates:")
+ for key, value in dataset_pass_rates.items():
+ if key.startswith("pass_rate/"):
+ dataset_name = key.replace("pass_rate/", "")
+ logger.info(f" {dataset_name}: {value:.1f}%")
+
+ # Debug: log all metrics being sent to wandb
+ logger.info("All dataset metrics being logged to wandb:")
+ for key, value in dataset_pass_rates.items():
+ logger.info(f" actor/{key}: {value}")
+
stats |= (
{
f"{split_name}{k}": v
@@ -549,7 +570,49 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
if self.cfg.wandb.use_wandb:
wandb.log({f"actor/{k}": v for k, v in stats.items()})
stats_writer.write(stats)
- self.init_stats() # Reset stats for the next iteration
+
+ # Only reset stats for training (not test evaluation)
+ if self.is_training:
+ self.init_stats() # Reset stats for the next iteration
+
+    def _calculate_dataset_pass_rates(self) -> Dict[str, float]:
+        """Compute per-dataset pass rates from the accumulated "success" stats.
+
+        Reads ``self.stats["success"]``, a mapping of dataset name -> group id ->
+        list of success values, and pools every group of a dataset together.
+        Datasets whose names map to the same display name (``gsm8k_test`` and
+        ``gsm8k_train`` both map to ``GSM8k``) are pooled as well, so one of
+        them can no longer silently overwrite the other in the result.
+
+        Returns:
+            A dict holding, for every display name with at least one recorded
+            result, ``pass_rate/<name>`` (sum / count * 100),
+            ``problems_solved/<name>`` (sum of the success values) and
+            ``problems_total/<name>`` (number of success values).
+        """
+        # Dataset name mapping for clean display
+        dataset_name_mapping = {
+            "gsm8k_test": "GSM8k",
+            "gsm8k_train": "GSM8k",
+            "math_test": "MATH",
+            "math_train": "MATH",
+            "aime_2024": "AIME 2024",
+            "aime_2023": "AIME 2023",
+            "aime_2022": "AIME 2022",
+            "amc_2023": "AMC 2023",
+            "amc_2022": "AMC 2022",
+        }
+
+        # display name -> [sum of success values, number of success values]
+        totals: Dict[str, List[float]] = {}
+        for dataset_name, group_results in self.stats.get("success", {}).items():
+            # Use clean dataset name if available, otherwise use original
+            clean_name = dataset_name_mapping.get(dataset_name, dataset_name)
+            counts = totals.setdefault(clean_name, [0, 0])
+            for success_list in group_results.values():
+                counts[0] += sum(success_list)
+                counts[1] += len(success_list)
+
+        pass_rates = {}
+        for clean_name, (solved, total) in totals.items():
+            if not total:
+                # Nothing recorded for this dataset: skip it (avoids dividing by zero)
+                continue
+            pass_rates[f"pass_rate/{clean_name}"] = (solved / total) * 100
+            pass_rates[f"problems_solved/{clean_name}"] = solved
+            pass_rates[f"problems_total/{clean_name}"] = total
+
+        return pass_rates
def run_actor_loop(cfg: DictConfig):
@@ -592,7 +655,7 @@ def run_actor_loop(cfg: DictConfig):
base_url=url,
model_name=str(actor_model_path),
tokenizer_name=str(actor_model_path),
- parameters=cfg.llm.parameters,
+ parameters=OmegaConf.to_container(cfg.llm.parameters, resolve=True),
use_cache=False,
collect_logprobs=True,
observe_llm_calls=False,
@@ -604,7 +667,7 @@ def run_actor_loop(cfg: DictConfig):
base_url=url,
model_name=str(actor_model_path),
tokenizer_name=str(actor_model_path),
- parameters=cfg.test_llm.parameters,
+ parameters=OmegaConf.to_container(cfg.test_llm.parameters, resolve=True),
use_cache=False,
collect_logprobs=True,
observe_llm_calls=False,
@@ -648,13 +711,22 @@ def run_actor_loop(cfg: DictConfig):
if last_regular_eval == -1
else last_regular_eval + cfg.eval_every_n_versions
)
- if (
+
+ # In eval debug mode, run test evaluation immediately and only once
+ should_run_test_eval = False
+ if cfg.debug.mode == "eval" and test_dataset and test_loop_run is None and last_regular_eval == -1:
+ should_run_test_eval = True
+ logger.info("Eval debug mode: Running test evaluation immediately")
+ elif (
cfg.eval_every_n_versions
and not cfg.debug.mode
and trainer_state.propagated_weight_version >= next_regular_eval
and test_dataset
and test_loop_run is None
):
+ should_run_test_eval = True
+
+ if should_run_test_eval:
logger.info("Create test loop")
test_loop_run = test_loop.run(
dataset=test_dataset,
@@ -672,6 +744,11 @@ def run_actor_loop(cfg: DictConfig):
last_regular_eval = current_eval
train_loop.is_scheduling_paused = False
logger.info("Test loop finished")
+
+ # In eval debug mode, exit after test evaluation completes
+ if cfg.debug.mode == "eval":
+ logger.info("Eval debug mode: Test evaluation completed, exiting")
+ break
# 3. Keep running the training loop
_ = next(train_loop_run)
diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py
index 2602400d..3ba609f8 100644
--- a/pipelinerl/domains/math/load_datasets.py
+++ b/pipelinerl/domains/math/load_datasets.py
@@ -216,7 +216,7 @@ def load_datasets(dataset_names: List[str] | str | None) -> List[Tuple[str, Dict
if "math_train" in dataset_names:
# math_dataset = load_math("train")
- dataset = load_dataset("hendrycks/competition_math", split="train", trust_remote_code=True)
+ dataset = load_dataset("hendrycks/competition_math", "default", split="train", trust_remote_code=True)
samples = [s for s in process_math(dataset, "math_train") if s is not None]
logger.info(f"Loading math train dataset: {len(samples)} samples")
datasets += add_ids(samples)
@@ -260,7 +260,7 @@ def load_datasets(dataset_names: List[str] | str | None) -> List[Tuple[str, Dict
if "math_test" in dataset_names:
# math_dataset = load_math("test")
- dataset = load_dataset("hendrycks/competition_math", split="test", trust_remote_code=True)
+ dataset = load_dataset("hendrycks/competition_math", "default", split="test", trust_remote_code=True)
samples = [s for s in process_math(dataset, "math_test") if s is not None]
logger.info(f"Loading math test dataset: {len(samples)} samples")
datasets += add_ids(samples)
diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py
index 41a61021..8cd8e163 100644
--- a/pipelinerl/domains/math/rollouts.py
+++ b/pipelinerl/domains/math/rollouts.py
@@ -7,7 +7,7 @@
from pipelinerl.rollouts import RolloutResult, BaseMetrics
from pipelinerl.world import Job
from tapeagents.core import Prompt
-from tapeagents.llms.trainable import TrainableLLM
+from tapeagents.llms import TrainableLLM
from pipelinerl.async_llm import llm_async_generate, make_training_text
from .verifier_api import verify_answer_rpc
diff --git a/pipelinerl/domains/tir/agent.py b/pipelinerl/domains/tir/agent.py
index f696753c..b131086e 100644
--- a/pipelinerl/domains/tir/agent.py
+++ b/pipelinerl/domains/tir/agent.py
@@ -1,12 +1,11 @@
-"""TIR Math Agent implementation for Tool Integrated Reasoning."""
-
import logging
+import math
import re
from typing import Any, Generator, Union, Literal
from pydantic import Field
from tapeagents.agent import Agent
-from tapeagents.core import Action, Prompt, Step, Tape, Observation, LLMOutputParsingFailureAction, SetNextNode, AgentStep, StopStep
+from tapeagents.core import Prompt, Step, Tape, Observation, LLMOutputParsingFailureAction, SetNextNode, StopStep
from tapeagents.llms import LLM
from tapeagents.nodes import Node
from tapeagents.steps import ActionExecutionFailure
@@ -34,109 +33,427 @@ class CodeExecutionNode(Node):
"""Node that generates Python code to solve math problems with iterative reasoning."""
system_prompt: str = Field(default="", description="System prompt for the node")
+ max_reasoning_pairs: int = Field(default=3, description="Number of recent code-exec pairs to include in prompt")
+ max_code_chars: int = Field(default=800, description="Maximum characters of code to include per step")
+ max_output_chars: int = Field(default=2000, description="Maximum characters of execution output per step")
+
+    def _extract_numerical_value(self, text: str):
+        """Parse ``text`` as a single number, trying progressively looser formats.
+
+        Strategies, in order: plain integer, plain decimal, scientific notation,
+        a simple ``a/b`` fraction, an arithmetic expression made only of digits
+        and operators (evaluated by a restricted AST walker), and finally SymPy
+        for anything else.
+
+        Args:
+            text: Candidate answer text.
+
+        Returns:
+            An ``int`` or ``float`` when ``text`` denotes a finite number, else
+            ``None``. Fraction/expression/SymPy results that are integral up to
+            floating-point noise are returned as ``int``. Expressions that still
+            contain free symbols yield ``None``: substituting arbitrary values
+            for the symbols would fabricate an answer.
+        """
+        if not text or not isinstance(text, str):
+            return None
+
+        text = text.strip()
+        if not text:
+            return None
+
+        def normalize(value):
+            """Return a finite real number (int when integral up to float noise) or None."""
+            if isinstance(value, bool) or not isinstance(value, (int, float)):
+                return None
+            if isinstance(value, int):
+                # Keep only ints that fit the float range, like the finite check did before.
+                try:
+                    float(value)
+                except OverflowError:
+                    return None
+                return value
+            if math.isnan(value) or math.isinf(value):
+                return None
+            nearest = round(value)
+            # A purely relative tolerance absorbs float noise (0.1 * 30) without
+            # collapsing real answers such as 1/2000 or 1999/2000 onto an integer.
+            if math.isclose(value, nearest, rel_tol=1e-9, abs_tol=0.0):
+                return nearest
+            return value
+
+        # option 1: Simple integer
+        if re.match(r'^[+-]?\d+$', text):
+            try:
+                return int(text)
+            except ValueError:
+                pass
+
+        # option 2: Simple float
+        if re.match(r'^[+-]?\d+\.\d+$', text):
+            try:
+                return float(text)
+            except ValueError:
+                pass
+
+        # option 3: Scientific notation
+        if re.match(r'^[+-]?\d+(?:\.\d+)?[eE][+-]?\d+$', text):
+            try:
+                return float(text)
+            except ValueError:
+                pass
+
+        # option 4: Simple fraction
+        if text.count('/') == 1:
+            numerator, denominator = text.split('/')
+            try:
+                num = float(numerator.strip())
+                den = float(denominator.strip())
+            except ValueError:
+                pass
+            else:
+                if den != 0:
+                    # normalize() rejects inf/nan instead of letting round() raise.
+                    value = normalize(num / den)
+                    if value is not None:
+                        return value
+
+        # option 5: Try to evaluate simple arithmetic expressions
+        if re.match(r'^[0-9+\-*/().\s]+$', text):
+            import ast
+            import operator
+
+            def bounded_pow(base, exponent):
+                # The text comes from model output: "9**9**9" must not pin the CPU.
+                if isinstance(base, int) and isinstance(exponent, int) and exponent > 0:
+                    if base.bit_length() * exponent > 2048:
+                        raise ValueError("power result too large")
+                return operator.pow(base, exponent)
+
+            # Simple arithmetic operators
+            ops = {
+                ast.Add: operator.add,
+                ast.Sub: operator.sub,
+                ast.Mult: operator.mul,
+                ast.Div: operator.truediv,
+                ast.Pow: bounded_pow,
+                ast.USub: operator.neg,
+                ast.UAdd: operator.pos,
+            }
+
+            def safe_eval(node):
+                # Literals parse to ast.Constant on Python 3.8+; the deprecated
+                # ast.Num alias is deliberately not referenced.
+                if isinstance(node, ast.Constant):
+                    return node.value
+                if isinstance(node, ast.BinOp):
+                    return ops[type(node.op)](safe_eval(node.left), safe_eval(node.right))
+                if isinstance(node, ast.UnaryOp):
+                    return ops[type(node.op)](safe_eval(node.operand))
+                raise ValueError(f"Unsupported operation: {type(node)}")
+
+            try:
+                tree = ast.parse(text, mode='eval')
+                return normalize(safe_eval(tree.body))
+            except (SyntaxError, ValueError, TypeError, KeyError,
+                    ArithmeticError, RecursionError, MemoryError):
+                # Pure digits-and-operators text gains nothing from SymPy, and
+                # handing it over would re-expose the unbounded power problem.
+                return None
+
+        # option 6: Try SymPy parsing (if available)
+        try:
+            import sympy as sp
+        except ImportError:
+            return None
+
+        try:
+            # NOTE(security): sympify() evaluates its input with eval(); `text`
+            # originates from model/tool output, so this call should be reviewed.
+            expr = sp.sympify(text)
+            if expr.is_number:
+                return normalize(float(expr.evalf()))
+        except Exception as e:
+            # sympify can raise almost anything on arbitrary text; treat it as "not a number".
+            logger.warning(f"Error evaluating SymPy expression: {e}")
+
+        return None
def make_prompt(self, agent: Any, tape: Tape) -> Prompt:
messages = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
- # Build conversation with task and previous code/results
task = tape.steps[0]
assert isinstance(task, Task), f"Expected a Task, got {task.__class__.__name__}"
- conversation_content = task.llm_view()
+ messages.append(
+ {"role": "user", "content": task.llm_view()}
+ )
- # Add previous code execution attempts and results
+ # Collect all reasoning steps first
+ all_reasoning_steps = []
for step in tape.steps[1:]:
+ if isinstance(step, (PythonCodeAction, CodeExecutionResult, ActionExecutionFailure)):
+ all_reasoning_steps.append(step)
+
+ # Deterministic last-k pairs
+ max_items = self.max_reasoning_pairs * 2
+ reasoning_steps = all_reasoning_steps[-max_items:] if len(all_reasoning_steps) > max_items else all_reasoning_steps
+
+ assistant_output_content = ""
+ for step in reasoning_steps:
if isinstance(step, PythonCodeAction):
- conversation_content += f"\n\n```python\n{step.code}\n```output\n"
+ code = step.code
+ if len(code) > self.max_code_chars:
+ head = code[: int(self.max_code_chars * 0.7)]
+ tail = code[-int(self.max_code_chars * 0.3):]
+ code = f"{head}\n# ... [code truncated] ...\n{tail}"
+ assistant_output_content += f"\n\n```python\n{code}\n```"
elif isinstance(step, CodeExecutionResult):
result = step.result.output.strip()
if "\n\nstdout:" in result:
result = result.split("\n\nstdout:")[0].strip()
- # Clean up result formatting
if result.startswith('"') and result.endswith('"'):
result = result[1:-1]
- conversation_content += f"{result}\n```"
+ if len(result) > self.max_output_chars:
+ lines = result.split('\n')
+ if len(lines) > 20:
+ kept_lines = lines[:10] + [f"... [{len(lines)-20} lines omitted] ..."] + lines[-10:]
+ result = '\n'.join(kept_lines)
+ else:
+ result = result[: self.max_output_chars] + "... [output truncated]"
+ assistant_output_content += f"\n```output\n{result}\n```"
elif isinstance(step, ActionExecutionFailure):
- conversation_content += f"Error: {step.error}\n```"
-
- messages.append({"role": "user", "content": conversation_content})
+ assistant_output_content += f"\n```output\nError: {step.error}\n```"
+
+ if assistant_output_content:
+ messages.append({"role": "assistant", "content": assistant_output_content})
- # Load tokenizer if needed
llm = agent.llms.get("default")
if llm and llm.tokenizer is None:
llm.load_tokenizer()
if llm and llm.tokenizer:
- prompt_token_ids = llm.tokenizer.apply_chat_template(
- messages, add_special_tokens=True, add_generation_prompt=True
- )
+ if messages[-1]["role"] == "user":
+ prompt_token_ids = llm.tokenizer.apply_chat_template(
+ messages, add_special_tokens=True, add_generation_prompt=True
+ )
+ else:
+ prompt_token_ids = llm.tokenizer.apply_chat_template(
+ messages, add_special_tokens=True, add_generation_prompt=False
+ )
else:
prompt_token_ids = None
return Prompt(messages=messages, token_ids=prompt_token_ids)
def generate_steps(self, agent: Any, tape: Tape, llm_stream) -> Generator[Step, None, None]:
- # Parse LLM output for Python code or final answer
output_text = llm_stream.get_output().content
if not output_text:
yield LLMOutputParsingFailureAction(error="Empty LLM output", llm_output=output_text)
yield SetNextNode(next_node="code_exec")
return
- # Check for boxed answer first
+ # extract Python code and boxed answer
+ python_code_pattern = r'```python\s*\n(.*?)```'
+ code_matches = re.findall(python_code_pattern, output_text, re.DOTALL)
+
boxed_pattern = r'\\boxed\{([^}]+)\}'
boxed_match = re.search(boxed_pattern, output_text)
- if boxed_match:
- value_str = boxed_match.group(1).strip()
- try:
- value = float(value_str)
- except ValueError:
- value = value_str
- logger.info(f"Found final boxed answer: {value}")
- yield AnswerAction(text=f"The answer is {value}", value=value)
+
+ has_execution_results = any(isinstance(step, CodeExecutionResult) for step in tape.steps)
+ last_action_was_verification = False
+ if tape.steps:
+ for step in reversed(tape.steps):
+ if isinstance(step, PythonCodeAction):
+ last_action_was_verification = step.name == "verification.py"
+ break
+
+ # CASE 1: both code and boxed answer present?
+ if code_matches and boxed_match and not last_action_was_verification:
+ code = code_matches[-1].strip()
+ boxed_value = boxed_match.group(1).strip()
+ logger.info(f"Found complete solution with code and boxed answer: {boxed_value}")
+
+ yield PythonCodeAction(name="verification.py", code=code, input_files=[])
+ yield SetNextNode(next_node="code_exec")
return
- # Look for Python code blocks
- python_code_pattern = r'```python\s*\n(.*?)\n```'
- code_matches = re.findall(python_code_pattern, output_text, re.DOTALL)
+ # CASE 1b: executed verification code - extract result or use boxed answer?
+ elif last_action_was_verification and has_execution_results:
+ last_result = None
+ for step in reversed(tape.steps):
+ if isinstance(step, CodeExecutionResult):
+ result = step.result.output.strip()
+ if result.startswith('"') and result.endswith('"'):
+ result = result[1:-1]
+ last_result = result
+ break
+
+ if last_result:
+ logger.info(f"Last execution result for answer extraction: '{last_result}'")
+ lines = last_result.strip().split('\n')
+ for i, line in enumerate(reversed(lines)):
+ line = line.strip()
+ logger.info(f"Checking line {i}: '{line}'")
+
+ extracted_value = self._extract_numerical_value(line)
+ if extracted_value is not None:
+ logger.info(f"Using execution result as answer: {extracted_value}")
+ yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value)
+ return
+
+ # fallback: find boxed answer from original complete solution in tape history
+ original_boxed_answer = None
+ if hasattr(agent, 'llm_calls') and agent.llm_calls:
+ for llm_call in reversed(agent.llm_calls):
+ if hasattr(llm_call, 'response') and llm_call.response:
+ content = llm_call.response.content
+ if '```python' in content and '\\boxed{' in content:
+ boxed_match_history = re.search(r'\\boxed\{([^}]+)\}', content)
+ if boxed_match_history:
+ original_boxed_answer = boxed_match_history.group(1).strip()
+ break
+
+ if original_boxed_answer:
+ logger.info(f"Falling back to original boxed answer: '{original_boxed_answer}'")
+ try:
+ if '/' in original_boxed_answer and len(original_boxed_answer.split('/')) == 2:
+ parts = original_boxed_answer.split('/')
+ value = float(parts[0]) / float(parts[1])
+ else:
+ value = float(original_boxed_answer)
+ except ValueError:
+ value = original_boxed_answer
+ yield AnswerAction(text=f"The answer is {value}", value=value)
+ return
+
+ # something went wrong?
+ logger.warning("Failed to extract answer from verification step")
+ yield AnswerAction(text="Unable to determine answer", value=0)
+ return
- if code_matches:
- # Take the last code block
+ # CASE 2: only code present?
+ elif code_matches:
code = code_matches[-1].strip()
- logger.info(f"Extracted Python code: {code[:100]}...")
- yield PythonCodeAction(
- name="math_solution.py",
- code=code,
- input_files=[]
- )
+ logger.info(f"Extracted Python code for iteration: {code[:100]}...")
+
+ # why are we still generating code?
+ reasoning_attempts = len([s for s in tape.steps if isinstance(s, PythonCodeAction)])
+
+ recent_results = []
+ recent_errors = []
+ for step in tape.steps[-8:]: # last 8 steps
+ if isinstance(step, CodeExecutionResult):
+ result = step.result.output.strip()
+ if result.startswith('"') and result.endswith('"'):
+ result = result[1:-1]
+ recent_results.append(result.lower())
+ elif isinstance(step, ActionExecutionFailure):
+ recent_errors.append(step.error)
+
+ none_outputs = sum(1 for r in recent_results if r in ['none', '', 'null'])
+ same_outputs = len(recent_results) - len(set(recent_results)) if recent_results else 0
+
+ if (none_outputs >= 2 and reasoning_attempts >= 3) or \
+ (same_outputs >= 2 and reasoning_attempts >= 4) or \
+ reasoning_attempts >= 6:
+ logger.warning(f"Stopping code execution: {reasoning_attempts} attempts, {none_outputs} None outputs, {same_outputs} repeated outputs")
+
+ # look at previous outputs for an answer
+ all_outputs = []
+ for step in tape.steps:
+ if isinstance(step, CodeExecutionResult):
+ output = step.result.output.strip()
+ if output.startswith('"') and output.endswith('"'):
+ output = output[1:-1]
+ all_outputs.append(output)
+
+ combined_output = "\n".join(all_outputs)
+ number_patterns = [
+ r'answer[:\s=]+([+-]?\d+(?:\.\d+)?)',
+ r'result[:\s=]+([+-]?\d+(?:\.\d+)?)',
+ r'([+-]?\d+(?:\.\d+)?)\s*$',
+ r'([+-]?\d+(?:\.\d+)?)',
+ ]
+
+ for pattern in number_patterns:
+ numbers = re.findall(pattern, combined_output, re.IGNORECASE | re.MULTILINE)
+ if numbers:
+ extracted_value = self._extract_numerical_value(numbers[-1])
+ if extracted_value is not None:
+ logger.info(f"Extracted answer from history: {extracted_value}")
+ yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value)
+ return
+
+ all_numbers = re.findall(r'([+-]?\d+(?:\.\d+)?)', combined_output)
+ if all_numbers:
+ for num_str in reversed(all_numbers): # Try from last to first
+ extracted_value = self._extract_numerical_value(num_str)
+ if extracted_value is not None:
+ logger.info(f"Extracted fallback answer: {extracted_value}")
+ yield AnswerAction(text=f"Best guess answer: {extracted_value}", value=extracted_value)
+ return
+
+ logger.warning("No numerical answer found in outputs")
+ yield AnswerAction(text="Unable to determine answer", value="No answer")
+ return
+
+ yield PythonCodeAction(name="math_solution.py", code=code, input_files=[])
yield SetNextNode(next_node="code_exec")
+
+ # CASE 3: only boxed answer present?
+ elif boxed_match:
+ value_str = boxed_match.group(1).strip()
+ logger.info(f"Found direct boxed answer: {value_str}")
+
+ extracted_value = self._extract_numerical_value(value_str)
+ if extracted_value is not None:
+ yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value)
+ return
+ else:
+ yield AnswerAction(text=f"The answer is {value_str}", value=value_str)
+ return
+
+ # CASE 4: neither code nor answer? - keep going
else:
- # No code or boxed answer - try to extract answer from text
- has_execution_results = any(isinstance(step, CodeExecutionResult) for step in tape.steps)
-
- if has_execution_results:
- # Look for answer patterns in the output
- answer_patterns = [
- r"(?:answer|result)\s+is\s+([+-]?\d*\.?\d+)",
- r"([+-]?\d*\.?\d+)$",
- r"(?:final|answer):\s*([+-]?\d*\.?\d+)",
+ reasoning_attempts = len([s for s in tape.steps if isinstance(s, PythonCodeAction)])
+ parse_failures = len([s for s in tape.steps if isinstance(s, LLMOutputParsingFailureAction)])
+
+ # check for "None" outputs or empty results that indicate unproductive loops
+ recent_results = []
+ for step in tape.steps[-6:]: # last 6 steps
+ if isinstance(step, CodeExecutionResult):
+ result = step.result.output.strip()
+ if result.startswith('"') and result.endswith('"'):
+ result = result[1:-1]
+ recent_results.append(result.lower())
+
+ none_outputs = sum(1 for r in recent_results if r in ['none', '', 'null'])
+
+ should_terminate = (
+ (parse_failures >= 2 and reasoning_attempts >= 4) or
+ (none_outputs >= 2 and reasoning_attempts >= 3) or
+ (reasoning_attempts >= 8) # hard limit
+ )
+
+ if should_terminate:
+ logger.warning(f"Terminating: {reasoning_attempts} attempts, {parse_failures} parse failures, {none_outputs} None outputs")
+ # try to extract any numerical answer from the accumulated outputs
+ all_outputs = []
+ for step in tape.steps:
+ if isinstance(step, CodeExecutionResult):
+ output = step.result.output.strip()
+ if output.startswith('"') and output.endswith('"'):
+ output = output[1:-1]
+ all_outputs.append(output)
+
+ combined_output = "\n".join(all_outputs)
+ number_patterns = [
+ r'answer[:\s=]+([+-]?\d+(?:\.\d+)?)',
+ r'result[:\s=]+([+-]?\d+(?:\.\d+)?)',
+ r'([+-]?\d+(?:\.\d+)?)\s*$', # Number at end
+ r'([+-]?\d+(?:\.\d+)?)', # Any number
]
- for pattern in answer_patterns:
- match = re.search(pattern, output_text, re.IGNORECASE)
- if match:
- try:
- value = float(match.group(1))
- logger.info(f"Extracted answer from text: {value}")
- yield AnswerAction(text=f"The answer is {value}", value=value)
+ for pattern in number_patterns:
+ numbers = re.findall(pattern, combined_output, re.IGNORECASE | re.MULTILINE)
+ if numbers:
+ # Try the improved extraction on the last found number
+ extracted_value = self._extract_numerical_value(numbers[-1])
+ if extracted_value is not None:
+ logger.info(f"Extracting answer from execution history with pattern '{pattern}': {extracted_value}")
+ yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value)
return
- except ValueError:
- continue
+
+ logger.warning("No clear numerical answer found, providing default")
+ yield AnswerAction(text="Unable to determine answer", value=0)
+ return
- # Continue iterating
- yield LLMOutputParsingFailureAction(error="No Python code or clear answer found, continuing", llm_output=output_text)
+ yield LLMOutputParsingFailureAction(error="No code or answer found", llm_output=output_text)
yield SetNextNode(next_node="code_exec")
@@ -157,23 +474,32 @@ def generate_steps(self, agent: Any, tape: Tape, llm_stream) -> Generator[Step,
class TIRMathAgent(Agent):
"""TIR (Tool Integrated Reasoning) agent for mathematical problem solving."""
- def __init__(self, system_prompt: str = "", max_iterations: int = 8, **kwargs):
- # Create nodes with the system prompt
+ def __init__(self, system_prompt: str = "", max_iterations: int = 8,
+ max_reasoning_pairs: int = 3, max_code_chars: int = 800,
+ max_output_chars: int = 2000, **kwargs):
nodes = [
CodeExecutionNode(
name="code_exec",
- system_prompt=system_prompt
+ system_prompt=system_prompt,
+ max_reasoning_pairs=max_reasoning_pairs,
+ max_code_chars=max_code_chars,
+ max_output_chars=max_output_chars
),
]
super().__init__(nodes=nodes, max_iterations=max_iterations, **kwargs)
self.store_llm_calls = True
@classmethod
- def create(cls, system_prompt: str, llm: LLM, max_prompt_length: int, max_iterations: int = 8):
+ def create(cls, system_prompt: str, llm: LLM, max_prompt_length: int, max_iterations: int = 8,
+ max_reasoning_pairs: int = 3, max_code_chars: int = 800,
+ max_output_chars: int = 2000):
agent = cls(
system_prompt=system_prompt,
llms={"default": llm},
max_iterations=max_iterations,
+ max_reasoning_pairs=max_reasoning_pairs,
+ max_code_chars=max_code_chars,
+ max_output_chars=max_output_chars,
)
agent.store_llm_calls = True
if agent.llms["default"].tokenizer is None:
@@ -186,7 +512,7 @@ def get_steps_description(self) -> str:
def extract_result_value(sample: dict) -> dict:
"""Extract numerical result from dataset sample."""
- # Compatibility wrapper - actual implementation is in datasets.py
+ # Compatibility wrapper: the actual implementation lives in datasets.py.
+ # It is imported inside the function body and called with `sample` unchanged.
from .datasets import extract_result_value as datasets_extract_result_value
return datasets_extract_result_value(sample)
@@ -198,7 +524,7 @@ def solve_task(agent: Agent, env, task: dict, tape_file: str = "") -> Tape:
import os
tmp_tape_file = f"{tape_file}.tmp" if tape_file else None
- start_step = Task(task=task["question"])
+ start_step = Task(task=task["task"])
tape = TIRMathTape(steps=[start_step], context=None)
metadata = task.copy()
@@ -218,7 +544,6 @@ def solve_task(agent: Agent, env, task: dict, tape_file: str = "") -> Tape:
metadata["solved"] = False
if isinstance(tape[-1], AnswerAction):
- # Use same verification logic as generate_tir_rollout
try:
from pipelinerl.domains.math.verifier_api import verify_math
predicted_answer = f"\\boxed{{{tape[-1].value}}}"
@@ -227,7 +552,6 @@ def solve_task(agent: Agent, env, task: dict, tape_file: str = "") -> Tape:
metadata["solved"] = (answer_status == "correct")
except Exception as e:
logger.warning(f"Math verification failed: {e}")
- # Fallback to numerical comparison
task_value = task.get("value")
tape_value = tape[-1].value
if task_value is not None and tape_value is not None:
diff --git a/pipelinerl/domains/tir/datasets.py b/pipelinerl/domains/tir/datasets.py
index 8c8543e4..e2dc7fc8 100644
--- a/pipelinerl/domains/tir/datasets.py
+++ b/pipelinerl/domains/tir/datasets.py
@@ -1,5 +1,3 @@
-"""Dataset loading and processing for TIR domain."""
-
import logging
import re
from typing import Dict, Any, List, Union
@@ -22,8 +20,70 @@ def _load_gsm8k_dataset(split: str) -> List[Dict[str, Any]]:
def _load_math_dataset(split: str) -> List[Dict[str, Any]]:
- from pipelinerl.domains.math.load_datasets import load_datasets as math_load_datasets
- return math_load_datasets([f"math_{split}"])
+ """Load one split of the MATH dataset and convert it to TIR samples.
+
+ Loads ``hendrycks/competition_math`` (config ``main``) for ``split``,
+ converts every record with ``_process_math_for_tir`` under the dataset
+ name ``math_<split>``, drops the records that conversion rejected
+ (``None``), logs the resulting count, and returns the samples after
+ passing them through ``add_ids``.
+ """
+ from datasets import load_dataset
+ from pipelinerl.domains.math.load_datasets import add_ids
+
+ # NOTE(review): trust_remote_code=True presumably allows the dataset
+ # repository's own loading script to run - confirm this is intended.
+ dataset = load_dataset("hendrycks/competition_math", "main", split=split, trust_remote_code=True)
+ # _process_math_for_tir yields None for unusable records; filter them out.
+ samples = [s for s in _process_math_for_tir(dataset, f"math_{split}") if s is not None]
+ logger.info(f"Loading math {split} dataset for TIR: {len(samples)} samples")
+ return add_ids(samples)
+
+
+def _process_math_for_tir(dataset, dataset_name):
+    """Convert raw MATH-style records into TIR problem dicts.
+
+    Yields exactly one value per input record so callers can filter: a dict
+    with the keys ``dataset``, ``level``, ``type``, ``task`` and ``answer``
+    for usable records, or ``None`` for records that must be skipped.
+
+    A record is skipped when it carries a ``correctness_math_verify`` list
+    with no truthy entry, when it has neither ``problem`` nor ``question``,
+    or when it offers neither a non-null ``answer`` nor a ``solution``.
+
+    Args:
+        dataset: Iterable of dict-like records.
+        dataset_name: Value stored under the ``dataset`` key of each sample.
+
+    Yields:
+        The converted sample dict, or ``None`` for a skipped record.
+    """
+    for item in dataset:
+        if "correctness_math_verify" in item and not any(item["correctness_math_verify"]):
+            yield None
+            continue
+
+        if "problem" in item:
+            question = item["problem"]
+        elif "question" in item:
+            question = item["question"]
+        else:
+            yield None
+            continue
+
+        # Resolve the category into a local instead of writing "type" back
+        # into the record: the input dataset must not be mutated.
+        if "type" in item:
+            problem_type = item["type"]
+        else:
+            problem_type = item.get("subject", "")
+
+        raw_answer = item.get("answer")
+        if raw_answer is not None:
+            # Answers may be numeric; coerce so concatenation cannot raise
+            # TypeError and abort the whole generator.
+            answer_text = str(raw_answer)
+            # Do not wrap an answer that is already boxed a second time
+            # (same rule as _load_openreasoner_dataset).
+            if answer_text.startswith("\\boxed"):
+                answer = answer_text
+            else:
+                answer = "\\boxed{" + answer_text + "}"
+        elif "solution" in item:
+            answer = _extract_boxed_answer(item["solution"])
+        else:
+            yield None
+            continue
+
+        yield {
+            "dataset": dataset_name,
+            "level": item.get("level", ""),
+            "type": problem_type,
+            "task": question,
+            "answer": answer,
+        }
+
+
+def _extract_boxed_answer(solution: str) -> str:
+    """Return the last ``\\boxed{...}`` expression found in ``solution``.
+
+    Nested braces inside the box are matched so that content such as
+    ``\\frac{1}{2}`` is kept whole. When there is no ``\\boxed{`` marker, or
+    its closing brace is never found, the full ``solution`` is returned.
+    """
+    marker = "\\boxed{"
+    marker_pos = solution.rfind(marker)
+    if marker_pos < 0:
+        return solution
+
+    content_start = marker_pos + len(marker)
+    # Number of braces opened inside the box that are still unclosed.
+    open_inner = 0
+    for offset, char in enumerate(solution[content_start:]):
+        if char == '{':
+            open_inner += 1
+        elif char == '}':
+            if open_inner == 0:
+                # This brace closes the \boxed{ itself.
+                inner = solution[content_start:content_start + offset]
+                return f"\\boxed{{{inner}}}"
+            open_inner -= 1
+
+    return solution
def _load_aime_dataset(year: int) -> List[Dict[str, Any]]:
@@ -71,11 +131,65 @@ def add_ids(dataset):
return dataset
+def _load_openreasoner_dataset(dataset_name: str) -> List[Dict[str, Any]]:
+    """Load an Open-Reasoner-Zero math collection as a list of TIR problems.
+
+    Each record is expected to look like ``item["0"]["value"]`` (the task)
+    and ``item["1"]["ground_truth"]["value"]`` (the answer). Answers are
+    wrapped in ``\\boxed{...}`` unless they already start with ``\\boxed``.
+    Records that do not match this layout, or whose answer is null, are
+    skipped with a warning and do not affect the remaining records.
+
+    Args:
+        dataset_name: One of ``open_reasoner_zero_57k``,
+            ``open_reasoner_zero_extended_72k`` or
+            ``open_reasoner_zero_hard_13k``.
+
+    Returns:
+        A list of problem dicts with the keys ``task``, ``answer``,
+        ``dataset``, ``level`` and ``type``. An empty list is returned (and
+        an error logged) for an unknown name or when loading fails; this
+        function does not raise.
+    """
+    try:
+        # Imported locally, as _load_math_dataset does, so the name is in
+        # scope regardless of the module-level imports.
+        from datasets import load_dataset
+
+        data_file_urls = {
+            "open_reasoner_zero_57k": "https://raw.githubusercontent.com/Open-Reasoner-Zero/Open-Reasoner-Zero/refs/heads/main/data/orz_math_57k_collected.json",
+            "open_reasoner_zero_extended_72k": "https://raw.githubusercontent.com/Open-Reasoner-Zero/Open-Reasoner-Zero/refs/heads/main/data/orz_math_72k_collection_extended.json",
+            "open_reasoner_zero_hard_13k": "https://raw.githubusercontent.com/Open-Reasoner-Zero/Open-Reasoner-Zero/refs/heads/main/data/orz_math_13k_collection_hard.json",
+        }
+
+        if dataset_name not in data_file_urls:
+            logger.error(f"Unknown OpenReasoner dataset: {dataset_name}")
+            return []
+
+        # Load the dataset from the JSON file
+        dataset = load_dataset(
+            "json",
+            data_files=data_file_urls[dataset_name],
+            split="train",
+            trust_remote_code=True,
+        )
+
+        samples = []
+        for item in dataset:
+            # Format: item["0"]["value"] = task, item["1"]["ground_truth"]["value"] = answer
+            try:
+                task = item["0"]["value"]
+                answer_value = item["1"]["ground_truth"]["value"]
+            except (KeyError, TypeError) as e:
+                logger.warning(f"Skipping malformed item in {dataset_name}: {e}")
+                continue
+
+            if answer_value is None:
+                logger.warning(f"Skipping malformed item in {dataset_name}: missing ground truth value")
+                continue
+
+            # Ground truths may be numeric. Coerce before calling startswith:
+            # an AttributeError here would escape the per-item handling and
+            # make the outer handler discard the entire dataset.
+            answer_text = str(answer_value)
+
+            # Ensure answer is in boxed format
+            if answer_text.startswith("\\boxed"):
+                answer = answer_text
+            else:
+                answer = f"\\boxed{{{answer_text}}}"
+
+            samples.append(
+                {
+                    "task": task,
+                    "answer": answer,
+                    "dataset": dataset_name,
+                    "level": "",
+                    "type": "reasoning",
+                }
+            )
+
+        logger.info(f"Loaded {dataset_name}: {len(samples)} samples")
+        return samples
+
+    except Exception as e:
+        # Deliberate best-effort boundary: callers treat [] as "not loaded".
+        logger.error(f"Failed to load {dataset_name}: {e}")
+        return []
+
+
def load_datasets(dataset_names: List[str], **kwargs) -> List[Dict[str, Any]]:
"""Load datasets for TIR domain."""
all_problems = []
- # Dataset loading map for cleaner logic
dataset_loaders = {
"gsm8k_train": lambda: _load_gsm8k_dataset("train"),
"gsm8k_test": lambda: _load_gsm8k_dataset("test"),
@@ -86,23 +200,38 @@ def load_datasets(dataset_names: List[str], **kwargs) -> List[Dict[str, Any]]:
"aime_2022": lambda: _load_aime_dataset(2022),
"amc_2023": lambda: _load_amc_dataset(2023),
"amc_2022": lambda: _load_amc_dataset(2022),
+ "open_reasoner_zero_57k": lambda: _load_openreasoner_dataset("open_reasoner_zero_57k"),
+ "open_reasoner_zero_extended_72k": lambda: _load_openreasoner_dataset("open_reasoner_zero_extended_72k"),
+ "open_reasoner_zero_hard_13k": lambda: _load_openreasoner_dataset("open_reasoner_zero_hard_13k"),
}
+ logger.info(f"Attempting to load datasets: {dataset_names}")
+
for name in dataset_names:
if name in dataset_loaders:
- samples = dataset_loaders[name]()
- logger.info(f"Loaded {name}: {len(samples)} samples")
-
- # GSM8K dataset needs IDs
- if name.startswith("gsm8k"):
- samples = add_ids(samples)
-
- all_problems.extend(samples)
+ try:
+ samples = dataset_loaders[name]()
+ logger.info(f"Loaded {name}: {len(samples)} samples")
+
+ if not samples:
+ logger.warning(f"Dataset {name} returned 0 samples!")
+
+ if name.startswith("gsm8k"):
+ samples = add_ids(samples)
+
+ all_problems.extend(samples)
+ except Exception as e:
+ logger.error(f"Failed to load dataset {name}: {e}")
+ continue
else:
logger.warning(f"Unknown dataset: {name}")
- logger.info(f"Loaded {len(all_problems)} problems from {len(dataset_names)} datasets")
+ logger.info(f"Total problems loaded: {len(all_problems)}")
+
+ if not all_problems:
+ raise ValueError(f"No problems loaded from any datasets: {dataset_names}. Check dataset names and network connectivity.")
+
return all_problems
diff --git a/pipelinerl/domains/tir/environment.py b/pipelinerl/domains/tir/environment.py
index c71d7c9c..d9c9040b 100644
--- a/pipelinerl/domains/tir/environment.py
+++ b/pipelinerl/domains/tir/environment.py
@@ -1,15 +1,16 @@
-"""TIR Environment implementations for secure Python code execution."""
-
import asyncio
import logging
import os
-from typing import Union
+import subprocess
+import tempfile
+import threading
+from contextlib import asynccontextmanager
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from tapeagents.remote_environment import EnvironmentServer
-from tapeagents.environment import Environment
+from tapeagents.environment import AsyncEnvironment
from tapeagents.core import Action
from tapeagents.tools.code_executor import PythonCodeAction, CodeExecutionResult
from tapeagents.steps import ActionExecutionFailure
@@ -18,24 +19,236 @@
logger = logging.getLogger(__name__)
-class MCPPythonEnvironment(Environment):
- """Environment using MCP Run Python server for secure code execution."""
+def _parse_mcp_result(mcp_output: str) -> tuple[str, bool]:
+ """Parse MCP output to extract result and determine success."""
+ if "error" in mcp_output:
+ if "" in mcp_output and "" in mcp_output:
+ start = mcp_output.find("") + len("")
+ end = mcp_output.find("")
+ error_msg = mcp_output[start:end].strip()
+ return f"Error: {error_msg}", False
+ else:
+ return "Error: Code execution failed", False
+
+ # Check for