diff --git a/conf/rewards/tir_pure_success.yaml b/conf/rewards/tir_pure_success.yaml new file mode 100644 index 00000000..8dfc9396 --- /dev/null +++ b/conf/rewards/tir_pure_success.yaml @@ -0,0 +1,9 @@ +correct_answer: 1.0 +wrong_answer: -0.15 +no_answer: -0.25 +unparsable: -0.25 +execution_failure: -0.05 +successful_code_execution: 0.0 +timeout_penalty: -0.25 +buffer_tokens: 0 +iteration_penalty: -0.005 diff --git a/conf/tir.yaml b/conf/tir.yaml index 523e747a..3fa9ad08 100644 --- a/conf/tir.yaml +++ b/conf/tir.yaml @@ -6,83 +6,92 @@ save_tapes: true llm: parameters: - max_tokens: 3000 + max_tokens: 512 # Reduced from 1024 to help stay within context limits temperature: 0.2 + stop: + - "```output" + +test_llm: + parameters: + max_tokens: 512 # Reduced from 1024 to help stay within context limits + temperature: 0.2 + stop: + - "```output" actor: rollout_policy: pipelinerl.domains.tir.rollouts.generate_tir_rollout - # TIR mode: 'fast' (single candidate + iterative reasoning) or 'sc_tir' (multiple candidates + majority voting) - mode: fast # Default to fast mode - # SC-TIR parameters (only used when mode=sc_tir) - num_candidates: 4 # Width: number of solution candidates to generate - max_reasoning_steps: 8 # Depth: max reasoning iterations per candidate - system_prompt: |- - You are a mathematical problem-solving assistant. Use Python code to solve problems step by step. - - Instructions: - 1. Read the problem carefully - 2. Write Python code to solve it, showing your work - 3. Execute the code and analyze the output - 4. If needed, write more code to refine your solution - 5. Once you have the final answer, present it in \boxed{} format - - Code format: - ```python - # Your Python code here - ``` - - The code will be executed and you'll see the output like: - ```output - # Execution results will appear here - ``` - - You can write multiple code blocks if needed. Use libraries like sympy, numpy, math as needed. + max_reasoning_steps: 6 # Reduced from 8 to encourage concise reasoning + system_prompt: "" + #TODO: rm debug code + llm_max_rollouts: 1 + + # system_prompt: |- + # You are a mathematical problem-solving assistant. Use Python code to solve problems step by step. + + # Instructions: + # 1. Read the problem carefully + # 2. Write Python code to solve it, showing your work + # 3. Execute the code and analyze the output + # 4. If needed, write more code to refine your solution + # 5. Once you have the final answer, present it in \boxed{} format + + # Code format: + # ```python + # # Your Python code here + # ``` + + # The code will be executed and you'll see the output like: + # ```output + # # Execution results will appear here + # ``` + + # You can write multiple code blocks if needed. Use libraries like sympy, numpy, math as needed. - Always end with your final answer in the format: \boxed{answer} - - Examples: - - Problem: What is 2 + 3? - ```python - result = 2 + 3 - print(f"2 + 3 = {result}") - ``` - ```output - 2 + 3 = 5 - ``` - \boxed{5} - - Problem: Solve x^2 - 5x + 6 = 0 - ```python - from sympy import symbols, solve, expand - x = symbols('x') - equation = x**2 - 5*x + 6 - solutions = solve(equation, x) - print(f"Solutions: {solutions}") + # Always end with your final answer in the format: \boxed{answer} + + # Examples: + + # Problem: What is 2 + 3? + # ```python + # result = 2 + 3 + # print(f"2 + 3 = {result}") + # ``` + # ```output + # 2 + 3 = 5 + # ``` + # \boxed{5} + + # Problem: Solve x^2 - 5x + 6 = 0 + # ```python + # from sympy import symbols, solve, expand + # x = symbols('x') + # equation = x**2 - 5*x + 6 + # solutions = solve(equation, x) + # print(f"Solutions: {solutions}") - # Verify by substitution - for sol in solutions: - check = sol**2 - 5*sol + 6 - print(f"x = {sol}: {sol}^2 - 5*{sol} + 6 = {check}") - ``` - ```output - Solutions: [2, 3] - x = 2: 2^2 - 5*2 + 6 = 0 - x = 3: 3^2 - 5*3 + 6 = 0 - ``` - The solutions are x = 2 and x = 3. - \boxed{2, 3} - task_template: |- - {task} + # # Verify by substitution + # for sol in solutions: + # check = sol**2 - 5*sol + 6 + # print(f"x = {sol}: {sol}^2 - 5*{sol} + 6 = {check}") + # ``` + # ```output + # Solutions: [2, 3] + # x = 2: 2^2 - 5*2 + 6 = 0 + # x = 3: 3^2 - 5*3 + 6 = 0 + # ``` + # The solutions are x = 2 and x = 3. + # \boxed{2, 3} + # task_template: |- + # {task} model_path: /mnt/llmd/base_models/AI-MO-NuminaMath-7B-TIR -output_dir: results/tir/${now:%Y-%m-%d}/${now:%H-%M-%S} +output_dir: /mnt/llmd/results/exps/rafa/tir/${now:%Y-%m-%d}/${now:%H-%M-%S} world: env_replicas: 1 -max_loops: 10 # extra iterations +max_loops: 25 # let the agent handle its own termination environment: _target_: pipelinerl.domains.tir.environment.MCPPythonEnvironment @@ -90,13 +99,17 @@ environment: agent: _target_: pipelinerl.domains.tir.agent.TIRMathAgent system_prompt: ${actor.system_prompt} + max_reasoning_pairs: 3 # Number of recent code-execution pairs to keep in context + max_code_chars: 800 # Maximum characters of code per step + max_output_chars: 2000 # Maximum characters of execution output per step dataset_loader: pipelinerl.domains.tir.datasets.load_datasets train_dataset_names: - - math_train + - open_reasoner_zero_57k + - open_reasoner_zero_extended_72k test_dataset_names: - # - math_test - aime_2024 - # - amc_2023 \ No newline at end of file + - amc_2023 + - math_500 \ No newline at end of file diff --git a/conf/tir_sc.yaml b/conf/tir_sc.yaml deleted file mode 100644 index e144fc41..00000000 --- a/conf/tir_sc.yaml +++ /dev/null @@ -1,12 +0,0 @@ -defaults: - - tir - - _self_ - -# Override to use SC-TIR mode with multiple candidates and majority voting -actor: - mode: sc_tir # Enable SC-TIR mode - num_candidates: 4 # Width: number of solution candidates to generate - max_reasoning_steps: 8 # Depth: max reasoning iterations per candidate - -# SC-TIR is slower, so we might want to reduce other settings for faster iteration -output_dir: results/tir_sc/${now:%Y-%m-%d}/${now:%H-%M-%S} \ No newline at end of file diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 8621250c..d1566e32 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -15,14 +15,14 @@ import aiohttp import hydra import uvloop -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM from typing import Dict, List import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb -from pipelinerl.rollouts import RolloutResult +from pipelinerl.rollouts import RolloutResult, BaseMetrics from pipelinerl.shared_memory_array import SharedMemoryQueue from pipelinerl.state import TrainerState from pipelinerl.streams import ( @@ -257,6 +257,8 @@ def rollout_maker_entrypoint( def random_iter(problems: list): + if not problems: + raise ValueError(f"Cannot iterate over empty problems list. No data was loaded.") while True: yield random.sample(problems, 1)[0] @@ -348,6 +350,7 @@ def update_stats(self, rollout_results: List[RolloutResult]): self.latency_list.append(result.latency) self.model_versions_list.append(result.model_version) domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result) + assert isinstance(result.metrics, BaseMetrics), "Metrics should be an instance of BaseMetrics" all_metrics = result.metrics.model_dump() | domain_agnostic_metrics for k, v in all_metrics.items(): if isinstance(v, list): @@ -528,6 +531,24 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): for agg, sub_stats in calculate_stats(list_of_stats_per_metric_and_dataset).items(): stats[f"{dataset_name}/{metric_name}_{agg}"] = sub_stats + # Add clean dataset-specific pass rates for test evaluation + if not self.is_training and "success" in self.stats: + dataset_pass_rates = self._calculate_dataset_pass_rates() + stats.update(dataset_pass_rates) + + # Log clean pass rates to console for easy viewing + if dataset_pass_rates: + logger.info("Dataset Pass Rates:") + for key, value in dataset_pass_rates.items(): + if key.startswith("pass_rate/"): + dataset_name = key.replace("pass_rate/", "") + logger.info(f" {dataset_name}: {value:.1f}%") + + # Debug: log all metrics being sent to wandb + logger.info("All dataset metrics being logged to wandb:") + for key, value in dataset_pass_rates.items(): + logger.info(f" actor/{key}: {value}") + stats |= ( { f"{split_name}{k}": v @@ -549,7 +570,49 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) - self.init_stats() # Reset stats for the next iteration + + # Only reset stats for training (not test evaluation) + if self.is_training: + self.init_stats() # Reset stats for the next iteration + + def _calculate_dataset_pass_rates(self) -> Dict[str, float]: + """Calculate clean dataset-specific pass rates matching the table format.""" + pass_rates = {} + + # Dataset name mapping for clean display + dataset_name_mapping = { + "gsm8k_test": "GSM8k", + "gsm8k_train": "GSM8k", + "math_test": "MATH", + "math_train": "MATH", + "aime_2024": "AIME 2024", + "aime_2023": "AIME 2023", + "aime_2022": "AIME 2022", + "amc_2023": "AMC 2023", + "amc_2022": "AMC 2022", + } + + success_stats = self.stats.get("success", {}) + + for dataset_name, group_results in success_stats.items(): + # Flatten all success values for this dataset + all_successes = [] + for group_id, success_list in group_results.items(): + all_successes.extend(success_list) + + if all_successes: + # Calculate pass rate as percentage + pass_rate = (sum(all_successes) / len(all_successes)) * 100 + + # Use clean dataset name if available, otherwise use original + clean_name = dataset_name_mapping.get(dataset_name, dataset_name) + pass_rates[f"pass_rate/{clean_name}"] = pass_rate + + # Track total problems attempted for all datasets + pass_rates[f"problems_solved/{clean_name}"] = sum(all_successes) + pass_rates[f"problems_total/{clean_name}"] = len(all_successes) + + return pass_rates def run_actor_loop(cfg: DictConfig): @@ -592,7 +655,7 @@ def run_actor_loop(cfg: DictConfig): base_url=url, model_name=str(actor_model_path), tokenizer_name=str(actor_model_path), - parameters=cfg.llm.parameters, + parameters=OmegaConf.to_container(cfg.llm.parameters, resolve=True), use_cache=False, collect_logprobs=True, observe_llm_calls=False, @@ -604,7 +667,7 @@ def run_actor_loop(cfg: DictConfig): base_url=url, model_name=str(actor_model_path), tokenizer_name=str(actor_model_path), - parameters=cfg.test_llm.parameters, + parameters=OmegaConf.to_container(cfg.test_llm.parameters, resolve=True), use_cache=False, collect_logprobs=True, observe_llm_calls=False, @@ -648,13 +711,22 @@ def run_actor_loop(cfg: DictConfig): if last_regular_eval == -1 else last_regular_eval + cfg.eval_every_n_versions ) - if ( + + # In eval debug mode, run test evaluation immediately and only once + should_run_test_eval = False + if cfg.debug.mode == "eval" and test_dataset and test_loop_run is None and last_regular_eval == -1: + should_run_test_eval = True + logger.info("Eval debug mode: Running test evaluation immediately") + elif ( cfg.eval_every_n_versions and not cfg.debug.mode and trainer_state.propagated_weight_version >= next_regular_eval and test_dataset and test_loop_run is None ): + should_run_test_eval = True + + if should_run_test_eval: logger.info("Create test loop") test_loop_run = test_loop.run( dataset=test_dataset, @@ -672,6 +744,11 @@ def run_actor_loop(cfg: DictConfig): last_regular_eval = current_eval train_loop.is_scheduling_paused = False logger.info("Test loop finished") + + # In eval debug mode, exit after test evaluation completes + if cfg.debug.mode == "eval": + logger.info("Eval debug mode: Test evaluation completed, exiting") + break # 3. Keep running the training loop _ = next(train_loop_run) diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py index 2602400d..3ba609f8 100644 --- a/pipelinerl/domains/math/load_datasets.py +++ b/pipelinerl/domains/math/load_datasets.py @@ -216,7 +216,7 @@ def load_datasets(dataset_names: List[str] | str | None) -> List[Tuple[str, Dict if "math_train" in dataset_names: # math_dataset = load_math("train") - dataset = load_dataset("hendrycks/competition_math", split="train", trust_remote_code=True) + dataset = load_dataset("hendrycks/competition_math", "default", split="train", trust_remote_code=True) samples = [s for s in process_math(dataset, "math_train") if s is not None] logger.info(f"Loading math train dataset: {len(samples)} samples") datasets += add_ids(samples) @@ -260,7 +260,7 @@ def load_datasets(dataset_names: List[str] | str | None) -> List[Tuple[str, Dict if "math_test" in dataset_names: # math_dataset = load_math("test") - dataset = load_dataset("hendrycks/competition_math", split="test", trust_remote_code=True) + dataset = load_dataset("hendrycks/competition_math", "default", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "math_test") if s is not None] logger.info(f"Loading math test dataset: {len(samples)} samples") datasets += add_ids(samples) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 41a61021..8cd8e163 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -7,7 +7,7 @@ from pipelinerl.rollouts import RolloutResult, BaseMetrics from pipelinerl.world import Job from tapeagents.core import Prompt -from tapeagents.llms.trainable import TrainableLLM +from tapeagents.llms import TrainableLLM from pipelinerl.async_llm import llm_async_generate, make_training_text from .verifier_api import verify_answer_rpc diff --git a/pipelinerl/domains/tir/agent.py b/pipelinerl/domains/tir/agent.py index f696753c..b131086e 100644 --- a/pipelinerl/domains/tir/agent.py +++ b/pipelinerl/domains/tir/agent.py @@ -1,12 +1,11 @@ -"""TIR Math Agent implementation for Tool Integrated Reasoning.""" - import logging +import math import re from typing import Any, Generator, Union, Literal from pydantic import Field from tapeagents.agent import Agent -from tapeagents.core import Action, Prompt, Step, Tape, Observation, LLMOutputParsingFailureAction, SetNextNode, AgentStep, StopStep +from tapeagents.core import Prompt, Step, Tape, Observation, LLMOutputParsingFailureAction, SetNextNode, StopStep from tapeagents.llms import LLM from tapeagents.nodes import Node from tapeagents.steps import ActionExecutionFailure @@ -34,109 +33,427 @@ class CodeExecutionNode(Node): """Node that generates Python code to solve math problems with iterative reasoning.""" system_prompt: str = Field(default="", description="System prompt for the node") + max_reasoning_pairs: int = Field(default=3, description="Number of recent code-exec pairs to include in prompt") + max_code_chars: int = Field(default=800, description="Maximum characters of code to include per step") + max_output_chars: int = Field(default=2000, description="Maximum characters of execution output per step") + + def _extract_numerical_value(self, text: str): + """Extract numerical value from text using multiple parsing strategies.""" + if not text or not isinstance(text, str): + return None + + text = text.strip() + if not text: + return None + + # option 1: Simple integer + if re.match(r'^[+-]?\d+$', text): + try: + return int(text) + except ValueError: + pass + + # option 2: Simple float + if re.match(r'^[+-]?\d+\.\d+$', text): + try: + return float(text) + except ValueError: + pass + + # option 3: Scientific notation + if re.match(r'^[+-]?\d+(?:\.\d+)?[eE][+-]?\d+$', text): + try: + return float(text) + except ValueError: + pass + + # option 4: Simple fraction + if '/' in text and len(text.split('/')) == 2: + try: + parts = text.split('/') + num = float(parts[0].strip()) + den = float(parts[1].strip()) + if den != 0: + value = num / den + if abs(value - round(value)) < 0.001: + return round(value) + return value + except ValueError: + pass + + # option 5: Try to evaluate simple arithmetic expressions + try: + import ast + import operator + + # Simple arithmetic operators + ops = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + def safe_eval(node): + if isinstance(node, ast.Constant): # Python 3.8+ + return node.value + elif isinstance(node, ast.Num): # Python < 3.8 + return node.n + elif isinstance(node, ast.BinOp): + return ops[type(node.op)](safe_eval(node.left), safe_eval(node.right)) + elif isinstance(node, ast.UnaryOp): + return ops[type(node.op)](safe_eval(node.operand)) + else: + raise ValueError(f"Unsupported operation: {type(node)}") + + if re.match(r'^[0-9+\-*/().\s]+$', text): + tree = ast.parse(text, mode='eval') + result = safe_eval(tree.body) + if isinstance(result, (int, float)) and not (math.isnan(result) or math.isinf(result)): + if abs(result - round(result)) < 0.001: + return round(result) + return result + except Exception: + pass + + # option 6: Try SymPy parsing (if available) + try: + import sympy as sp + + expr = sp.sympify(text) + + if expr.is_number: + result = float(expr.evalf()) + if not (math.isnan(result) or math.isinf(result)): + if abs(result - round(result)) < 0.001: + return round(result) + return result + + elif expr.free_symbols: + substitutions = {} + for symbol in expr.free_symbols: + var_name = str(symbol) + if var_name in ['x', 'y', 'z']: + substitutions[symbol] = 1 + elif var_name in ['t', 'time']: + substitutions[symbol] = 1 + elif var_name in ['n', 'i', 'j', 'k']: + substitutions[symbol] = 1 + + if substitutions: + try: + substituted = expr.subs(substitutions) + if substituted.is_number: + result = float(substituted.evalf()) + if not (math.isnan(result) or math.isinf(result)): + if abs(result - round(result)) < 0.001: + return round(result) + return result + except Exception as e: + logger.warning(f"Error evaluating SymPy expression: {e}") + pass + except Exception as e: + logger.warning(f"Error evaluating SymPy expression: {e}") + pass + + return None def make_prompt(self, agent: Any, tape: Tape) -> Prompt: messages = [] if self.system_prompt: messages.append({"role": "system", "content": self.system_prompt}) - # Build conversation with task and previous code/results task = tape.steps[0] assert isinstance(task, Task), f"Expected a Task, got {task.__class__.__name__}" - conversation_content = task.llm_view() + messages.append( + {"role": "user", "content": task.llm_view()} + ) - # Add previous code execution attempts and results + # Collect all reasoning steps first + all_reasoning_steps = [] for step in tape.steps[1:]: + if isinstance(step, (PythonCodeAction, CodeExecutionResult, ActionExecutionFailure)): + all_reasoning_steps.append(step) + + # Deterministic last-k pairs + max_items = self.max_reasoning_pairs * 2 + reasoning_steps = all_reasoning_steps[-max_items:] if len(all_reasoning_steps) > max_items else all_reasoning_steps + + assistant_output_content = "" + for step in reasoning_steps: if isinstance(step, PythonCodeAction): - conversation_content += f"\n\n```python\n{step.code}\n```output\n" + code = step.code + if len(code) > self.max_code_chars: + head = code[: int(self.max_code_chars * 0.7)] + tail = code[-int(self.max_code_chars * 0.3):] + code = f"{head}\n# ... [code truncated] ...\n{tail}" + assistant_output_content += f"\n\n```python\n{code}\n```" elif isinstance(step, CodeExecutionResult): result = step.result.output.strip() if "\n\nstdout:" in result: result = result.split("\n\nstdout:")[0].strip() - # Clean up result formatting if result.startswith('"') and result.endswith('"'): result = result[1:-1] - conversation_content += f"{result}\n```" + if len(result) > self.max_output_chars: + lines = result.split('\n') + if len(lines) > 20: + kept_lines = lines[:10] + [f"... [{len(lines)-20} lines omitted] ..."] + lines[-10:] + result = '\n'.join(kept_lines) + else: + result = result[: self.max_output_chars] + "... [output truncated]" + assistant_output_content += f"\n```output\n{result}\n```" elif isinstance(step, ActionExecutionFailure): - conversation_content += f"Error: {step.error}\n```" - - messages.append({"role": "user", "content": conversation_content}) + assistant_output_content += f"\n```output\nError: {step.error}\n```" + + if assistant_output_content: + messages.append({"role": "assistant", "content": assistant_output_content}) - # Load tokenizer if needed llm = agent.llms.get("default") if llm and llm.tokenizer is None: llm.load_tokenizer() if llm and llm.tokenizer: - prompt_token_ids = llm.tokenizer.apply_chat_template( - messages, add_special_tokens=True, add_generation_prompt=True - ) + if messages[-1]["role"] == "user": + prompt_token_ids = llm.tokenizer.apply_chat_template( + messages, add_special_tokens=True, add_generation_prompt=True + ) + else: + prompt_token_ids = llm.tokenizer.apply_chat_template( + messages, add_special_tokens=True, add_generation_prompt=False + ) else: prompt_token_ids = None return Prompt(messages=messages, token_ids=prompt_token_ids) def generate_steps(self, agent: Any, tape: Tape, llm_stream) -> Generator[Step, None, None]: - # Parse LLM output for Python code or final answer output_text = llm_stream.get_output().content if not output_text: yield LLMOutputParsingFailureAction(error="Empty LLM output", llm_output=output_text) yield SetNextNode(next_node="code_exec") return - # Check for boxed answer first + # extract Python code and boxed answer + python_code_pattern = r'```python\s*\n(.*?)```' + code_matches = re.findall(python_code_pattern, output_text, re.DOTALL) + boxed_pattern = r'\\boxed\{([^}]+)\}' boxed_match = re.search(boxed_pattern, output_text) - if boxed_match: - value_str = boxed_match.group(1).strip() - try: - value = float(value_str) - except ValueError: - value = value_str - logger.info(f"Found final boxed answer: {value}") - yield AnswerAction(text=f"The answer is {value}", value=value) + + has_execution_results = any(isinstance(step, CodeExecutionResult) for step in tape.steps) + last_action_was_verification = False + if tape.steps: + for step in reversed(tape.steps): + if isinstance(step, PythonCodeAction): + last_action_was_verification = step.name == "verification.py" + break + + # CASE 1: both code and boxed answer present? + if code_matches and boxed_match and not last_action_was_verification: + code = code_matches[-1].strip() + boxed_value = boxed_match.group(1).strip() + logger.info(f"Found complete solution with code and boxed answer: {boxed_value}") + + yield PythonCodeAction(name="verification.py", code=code, input_files=[]) + yield SetNextNode(next_node="code_exec") return - # Look for Python code blocks - python_code_pattern = r'```python\s*\n(.*?)\n```' - code_matches = re.findall(python_code_pattern, output_text, re.DOTALL) + # CASE 1b: executed verification code - extract result or use boxed answer? + elif last_action_was_verification and has_execution_results: + last_result = None + for step in reversed(tape.steps): + if isinstance(step, CodeExecutionResult): + result = step.result.output.strip() + if result.startswith('"') and result.endswith('"'): + result = result[1:-1] + last_result = result + break + + if last_result: + logger.info(f"Last execution result for answer extraction: '{last_result}'") + lines = last_result.strip().split('\n') + for i, line in enumerate(reversed(lines)): + line = line.strip() + logger.info(f"Checking line {i}: '{line}'") + + extracted_value = self._extract_numerical_value(line) + if extracted_value is not None: + logger.info(f"Using execution result as answer: {extracted_value}") + yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value) + return + + # fallback: find boxed answer from original complete solution in tape history + original_boxed_answer = None + if hasattr(agent, 'llm_calls') and agent.llm_calls: + for llm_call in reversed(agent.llm_calls): + if hasattr(llm_call, 'response') and llm_call.response: + content = llm_call.response.content + if '```python' in content and '\\boxed{' in content: + boxed_match_history = re.search(r'\\boxed\{([^}]+)\}', content) + if boxed_match_history: + original_boxed_answer = boxed_match_history.group(1).strip() + break + + if original_boxed_answer: + logger.info(f"Falling back to original boxed answer: '{original_boxed_answer}'") + try: + if '/' in original_boxed_answer and len(original_boxed_answer.split('/')) == 2: + parts = original_boxed_answer.split('/') + value = float(parts[0]) / float(parts[1]) + else: + value = float(original_boxed_answer) + except ValueError: + value = original_boxed_answer + yield AnswerAction(text=f"The answer is {value}", value=value) + return + + # something went wrong? + logger.warning("Failed to extract answer from verification step") + yield AnswerAction(text="Unable to determine answer", value=0) + return - if code_matches: - # Take the last code block + # CASE 2: only code present? + elif code_matches: code = code_matches[-1].strip() - logger.info(f"Extracted Python code: {code[:100]}...") - yield PythonCodeAction( - name="math_solution.py", - code=code, - input_files=[] - ) + logger.info(f"Extracted Python code for iteration: {code[:100]}...") + + # why are we still generating code? + reasoning_attempts = len([s for s in tape.steps if isinstance(s, PythonCodeAction)]) + + recent_results = [] + recent_errors = [] + for step in tape.steps[-8:]: # last 8 steps + if isinstance(step, CodeExecutionResult): + result = step.result.output.strip() + if result.startswith('"') and result.endswith('"'): + result = result[1:-1] + recent_results.append(result.lower()) + elif isinstance(step, ActionExecutionFailure): + recent_errors.append(step.error) + + none_outputs = sum(1 for r in recent_results if r in ['none', '', 'null']) + same_outputs = len(recent_results) - len(set(recent_results)) if recent_results else 0 + + if (none_outputs >= 2 and reasoning_attempts >= 3) or \ + (same_outputs >= 2 and reasoning_attempts >= 4) or \ + reasoning_attempts >= 6: + logger.warning(f"Stopping code execution: {reasoning_attempts} attempts, {none_outputs} None outputs, {same_outputs} repeated outputs") + + # look at previous outputs for an answer + all_outputs = [] + for step in tape.steps: + if isinstance(step, CodeExecutionResult): + output = step.result.output.strip() + if output.startswith('"') and output.endswith('"'): + output = output[1:-1] + all_outputs.append(output) + + combined_output = "\n".join(all_outputs) + number_patterns = [ + r'answer[:\s=]+([+-]?\d+(?:\.\d+)?)', + r'result[:\s=]+([+-]?\d+(?:\.\d+)?)', + r'([+-]?\d+(?:\.\d+)?)\s*$', + r'([+-]?\d+(?:\.\d+)?)', + ] + + for pattern in number_patterns: + numbers = re.findall(pattern, combined_output, re.IGNORECASE | re.MULTILINE) + if numbers: + extracted_value = self._extract_numerical_value(numbers[-1]) + if extracted_value is not None: + logger.info(f"Extracted answer from history: {extracted_value}") + yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value) + return + + all_numbers = re.findall(r'([+-]?\d+(?:\.\d+)?)', combined_output) + if all_numbers: + for num_str in reversed(all_numbers): # Try from last to first + extracted_value = self._extract_numerical_value(num_str) + if extracted_value is not None: + logger.info(f"Extracted fallback answer: {extracted_value}") + yield AnswerAction(text=f"Best guess answer: {extracted_value}", value=extracted_value) + return + + logger.warning("No numerical answer found in outputs") + yield AnswerAction(text="Unable to determine answer", value="No answer") + return + + yield PythonCodeAction(name="math_solution.py", code=code, input_files=[]) yield SetNextNode(next_node="code_exec") + + # CASE 3: only boxed answer present? + elif boxed_match: + value_str = boxed_match.group(1).strip() + logger.info(f"Found direct boxed answer: {value_str}") + + extracted_value = self._extract_numerical_value(value_str) + if extracted_value is not None: + yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value) + return + else: + yield AnswerAction(text=f"The answer is {value_str}", value=value_str) + return + + # CASE 4: neither code nor answer? - keep going else: - # No code or boxed answer - try to extract answer from text - has_execution_results = any(isinstance(step, CodeExecutionResult) for step in tape.steps) - - if has_execution_results: - # Look for answer patterns in the output - answer_patterns = [ - r"(?:answer|result)\s+is\s+([+-]?\d*\.?\d+)", - r"([+-]?\d*\.?\d+)$", - r"(?:final|answer):\s*([+-]?\d*\.?\d+)", + reasoning_attempts = len([s for s in tape.steps if isinstance(s, PythonCodeAction)]) + parse_failures = len([s for s in tape.steps if isinstance(s, LLMOutputParsingFailureAction)]) + + # check for "None" outputs or empty results that indicate unproductive loops + recent_results = [] + for step in tape.steps[-6:]: # last 6 steps + if isinstance(step, CodeExecutionResult): + result = step.result.output.strip() + if result.startswith('"') and result.endswith('"'): + result = result[1:-1] + recent_results.append(result.lower()) + + none_outputs = sum(1 for r in recent_results if r in ['none', '', 'null']) + + should_terminate = ( + (parse_failures >= 2 and reasoning_attempts >= 4) or + (none_outputs >= 2 and reasoning_attempts >= 3) or + (reasoning_attempts >= 8) # hard limit + ) + + if should_terminate: + logger.warning(f"Terminating: {reasoning_attempts} attempts, {parse_failures} parse failures, {none_outputs} None outputs") + # try to extract any numerical answer from the accumulated outputs + all_outputs = [] + for step in tape.steps: + if isinstance(step, CodeExecutionResult): + output = step.result.output.strip() + if output.startswith('"') and output.endswith('"'): + output = output[1:-1] + all_outputs.append(output) + + combined_output = "\n".join(all_outputs) + number_patterns = [ + r'answer[:\s=]+([+-]?\d+(?:\.\d+)?)', + r'result[:\s=]+([+-]?\d+(?:\.\d+)?)', + r'([+-]?\d+(?:\.\d+)?)\s*$', # Number at end + r'([+-]?\d+(?:\.\d+)?)', # Any number ] - for pattern in answer_patterns: - match = re.search(pattern, output_text, re.IGNORECASE) - if match: - try: - value = float(match.group(1)) - logger.info(f"Extracted answer from text: {value}") - yield AnswerAction(text=f"The answer is {value}", value=value) + for pattern in number_patterns: + numbers = re.findall(pattern, combined_output, re.IGNORECASE | re.MULTILINE) + if numbers: + # Try the improved extraction on the last found number + extracted_value = self._extract_numerical_value(numbers[-1]) + if extracted_value is not None: + logger.info(f"Extracting answer from execution history with pattern '{pattern}': {extracted_value}") + yield AnswerAction(text=f"The answer is {extracted_value}", value=extracted_value) return - except ValueError: - continue + + logger.warning("No clear numerical answer found, providing default") + yield AnswerAction(text="Unable to determine answer", value=0) + return - # Continue iterating - yield LLMOutputParsingFailureAction(error="No Python code or clear answer found, continuing", llm_output=output_text) + yield LLMOutputParsingFailureAction(error="No code or answer found", llm_output=output_text) yield SetNextNode(next_node="code_exec") @@ -157,23 +474,32 @@ def generate_steps(self, agent: Any, tape: Tape, llm_stream) -> Generator[Step, class TIRMathAgent(Agent): """TIR (Tool Integrated Reasoning) agent for mathematical problem solving.""" - def __init__(self, system_prompt: str = "", max_iterations: int = 8, **kwargs): - # Create nodes with the system prompt + def __init__(self, system_prompt: str = "", max_iterations: int = 8, + max_reasoning_pairs: int = 3, max_code_chars: int = 800, + max_output_chars: int = 2000, **kwargs): nodes = [ CodeExecutionNode( name="code_exec", - system_prompt=system_prompt + system_prompt=system_prompt, + max_reasoning_pairs=max_reasoning_pairs, + max_code_chars=max_code_chars, + max_output_chars=max_output_chars ), ] super().__init__(nodes=nodes, max_iterations=max_iterations, **kwargs) self.store_llm_calls = True @classmethod - def create(cls, system_prompt: str, llm: LLM, max_prompt_length: int, max_iterations: int = 8): + def create(cls, system_prompt: str, llm: LLM, max_prompt_length: int, max_iterations: int = 8, + max_reasoning_pairs: int = 3, max_code_chars: int = 800, + max_output_chars: int = 2000): agent = cls( system_prompt=system_prompt, llms={"default": llm}, max_iterations=max_iterations, + max_reasoning_pairs=max_reasoning_pairs, + max_code_chars=max_code_chars, + max_output_chars=max_output_chars, ) agent.store_llm_calls = True if agent.llms["default"].tokenizer is None: @@ -186,7 +512,7 @@ def get_steps_description(self) -> str: def extract_result_value(sample: dict) -> dict: """Extract numerical result from dataset sample.""" - # Compatibility wrapper - actual implementation is in datasets.py + # compatibility wrapper - actual implementation is in datasets.py from .datasets import extract_result_value as datasets_extract_result_value return datasets_extract_result_value(sample) @@ -198,7 +524,7 @@ def solve_task(agent: Agent, env, task: dict, tape_file: str = "") -> Tape: import os tmp_tape_file = f"{tape_file}.tmp" if tape_file else None - start_step = Task(task=task["question"]) + start_step = Task(task=task["task"]) tape = TIRMathTape(steps=[start_step], context=None) metadata = task.copy() @@ -218,7 +544,6 @@ def solve_task(agent: Agent, env, task: dict, tape_file: str = "") -> Tape: metadata["solved"] = False if isinstance(tape[-1], AnswerAction): - # Use same verification logic as generate_tir_rollout try: from pipelinerl.domains.math.verifier_api import verify_math predicted_answer = f"\\boxed{{{tape[-1].value}}}" @@ -227,7 +552,6 @@ def solve_task(agent: Agent, env, task: dict, tape_file: str = "") -> Tape: metadata["solved"] = (answer_status == "correct") except Exception as e: logger.warning(f"Math verification failed: {e}") - # Fallback to numerical comparison task_value = task.get("value") tape_value = tape[-1].value if task_value is not None and tape_value is not None: diff --git a/pipelinerl/domains/tir/datasets.py b/pipelinerl/domains/tir/datasets.py index 8c8543e4..e2dc7fc8 100644 --- a/pipelinerl/domains/tir/datasets.py +++ b/pipelinerl/domains/tir/datasets.py @@ -1,5 +1,3 @@ -"""Dataset loading and processing for TIR domain.""" - import logging import re from typing import Dict, Any, List, Union @@ -22,8 +20,70 @@ def _load_gsm8k_dataset(split: str) -> List[Dict[str, Any]]: def _load_math_dataset(split: str) -> List[Dict[str, Any]]: - from pipelinerl.domains.math.load_datasets import load_datasets as math_load_datasets - return math_load_datasets([f"math_{split}"]) + """Load MATH dataset directly""" + from datasets import load_dataset + from pipelinerl.domains.math.load_datasets import add_ids + + dataset = load_dataset("hendrycks/competition_math", "main", split=split, trust_remote_code=True) + samples = [s for s in _process_math_for_tir(dataset, f"math_{split}") if s is not None] + logger.info(f"Loading math {split} dataset for TIR: {len(samples)} samples") + return add_ids(samples) + + +def _process_math_for_tir(dataset, dataset_name): + """Process MATH dataset for TIR domain with proper boxed answer extraction.""" + for item in dataset: + if "correctness_math_verify" in item: + if not any(item["correctness_math_verify"]): + yield None + continue + if "problem" in item: + question = item["problem"] + elif "question" in item: + question = item["question"] + else: + yield None + continue + if "subject" in item and "type" not in item: + item["type"] = item["subject"] + + if "answer" in item: + answer = "\\boxed{" + item["answer"] + "}" + elif "solution" in item: + solution = item["solution"] + answer = _extract_boxed_answer(solution) + else: + yield None + continue + + sample = { + "dataset": dataset_name, + "level": item.get("level", ""), + "type": item.get("type", ""), + "task": question, + "answer": answer, + } + yield sample + + +def _extract_boxed_answer(solution: str) -> str: + """Extract the boxed answer from a solution, fallback to full solution if not found.""" + boxed_start = solution.rfind("\\boxed{") + if boxed_start >= 0: + brace_count = 0 + i = boxed_start + 7 + while i < len(solution): + if solution[i] == '{': + brace_count += 1 + elif solution[i] == '}': + if brace_count == 0: + boxed_content = solution[boxed_start + 7:i] + return f"\\boxed{{{boxed_content}}}" + else: + brace_count -= 1 + i += 1 + + return solution def _load_aime_dataset(year: int) -> List[Dict[str, Any]]: @@ -71,11 +131,65 @@ def add_ids(dataset): return dataset +def _load_openreasoner_dataset(dataset_name: str) -> List[Dict[str, Any]]: + """Load OpenReasoner datasets following the math domain pattern.""" + try: + data_file_urls = { + "open_reasoner_zero_57k": "https://raw.githubusercontent.com/Open-Reasoner-Zero/Open-Reasoner-Zero/refs/heads/main/data/orz_math_57k_collected.json", + "open_reasoner_zero_extended_72k": "https://raw.githubusercontent.com/Open-Reasoner-Zero/Open-Reasoner-Zero/refs/heads/main/data/orz_math_72k_collection_extended.json", + "open_reasoner_zero_hard_13k": "https://raw.githubusercontent.com/Open-Reasoner-Zero/Open-Reasoner-Zero/refs/heads/main/data/orz_math_13k_collection_hard.json", + } + + if dataset_name not in data_file_urls: + logger.error(f"Unknown OpenReasoner dataset: {dataset_name}") + return [] + + # Load the dataset from the JSON file + dataset = load_dataset( + "json", + data_files=data_file_urls[dataset_name], + split="train", + trust_remote_code=True, + ) + + samples = [] + for item in dataset: + # Format: item["0"]["value"] = task, item["1"]["ground_truth"]["value"] = answer + try: + task = item["0"]["value"] + answer_value = item["1"]["ground_truth"]["value"] + + # Ensure answer is in boxed format + if not answer_value.startswith("\\boxed"): + answer = f"\\boxed{{{answer_value}}}" + else: + answer = answer_value + + problem = { + "task": task, + "answer": answer, + "dataset": dataset_name, + "level": "", + "type": "reasoning", + } + samples.append(problem) + + except (KeyError, TypeError) as e: + logger.warning(f"Skipping malformed item in {dataset_name}: {e}") + continue + + logger.info(f"Loaded {dataset_name}: {len(samples)} samples") + return samples + + except Exception as e: + logger.error(f"Failed to load {dataset_name}: {e}") + return [] + + def load_datasets(dataset_names: List[str], **kwargs) -> List[Dict[str, Any]]: """Load datasets for TIR domain.""" all_problems = [] - # Dataset loading map for cleaner logic dataset_loaders = { "gsm8k_train": lambda: _load_gsm8k_dataset("train"), "gsm8k_test": lambda: _load_gsm8k_dataset("test"), @@ -86,23 +200,38 @@ def load_datasets(dataset_names: List[str], **kwargs) -> List[Dict[str, Any]]: "aime_2022": lambda: _load_aime_dataset(2022), "amc_2023": lambda: _load_amc_dataset(2023), "amc_2022": lambda: _load_amc_dataset(2022), + "open_reasoner_zero_57k": lambda: _load_openreasoner_dataset("open_reasoner_zero_57k"), + "open_reasoner_zero_extended_72k": lambda: _load_openreasoner_dataset("open_reasoner_zero_extended_72k"), + "open_reasoner_zero_hard_13k": lambda: _load_openreasoner_dataset("open_reasoner_zero_hard_13k"), } + logger.info(f"Attempting to load datasets: {dataset_names}") + for name in dataset_names: if name in dataset_loaders: - samples = dataset_loaders[name]() - logger.info(f"Loaded {name}: {len(samples)} samples") - - # GSM8K dataset needs IDs - if name.startswith("gsm8k"): - samples = add_ids(samples) - - all_problems.extend(samples) + try: + samples = dataset_loaders[name]() + logger.info(f"Loaded {name}: {len(samples)} samples") + + if not samples: + logger.warning(f"Dataset {name} returned 0 samples!") + + if name.startswith("gsm8k"): + samples = add_ids(samples) + + all_problems.extend(samples) + except Exception as e: + logger.error(f"Failed to load dataset {name}: {e}") + continue else: logger.warning(f"Unknown dataset: {name}") - logger.info(f"Loaded {len(all_problems)} problems from {len(dataset_names)} datasets") + logger.info(f"Total problems loaded: {len(all_problems)}") + + if not all_problems: + raise ValueError(f"No problems loaded from any datasets: {dataset_names}. Check dataset names and network connectivity.") + return all_problems diff --git a/pipelinerl/domains/tir/environment.py b/pipelinerl/domains/tir/environment.py index c71d7c9c..d9c9040b 100644 --- a/pipelinerl/domains/tir/environment.py +++ b/pipelinerl/domains/tir/environment.py @@ -1,15 +1,16 @@ -"""TIR Environment implementations for secure Python code execution.""" - import asyncio import logging import os -from typing import Union +import subprocess +import tempfile +import threading +from contextlib import asynccontextmanager from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from tapeagents.remote_environment import EnvironmentServer -from tapeagents.environment import Environment +from tapeagents.environment import AsyncEnvironment from tapeagents.core import Action from tapeagents.tools.code_executor import PythonCodeAction, CodeExecutionResult from tapeagents.steps import ActionExecutionFailure @@ -18,24 +19,236 @@ logger = logging.getLogger(__name__) -class MCPPythonEnvironment(Environment): - """Environment using MCP Run Python server for secure code execution.""" +def _parse_mcp_result(mcp_output: str) -> tuple[str, bool]: + """Parse MCP output to extract result and determine success.""" + if "error" in mcp_output: + if "" in mcp_output and "" in mcp_output: + start = mcp_output.find("") + len("") + end = mcp_output.find("") + error_msg = mcp_output[start:end].strip() + return f"Error: {error_msg}", False + else: + return "Error: Code execution failed", False + + # Check for tags first (common in MCP responses) + if "" in mcp_output and "" in mcp_output: + start = mcp_output.find("") + len("") + end = mcp_output.find("") + output = mcp_output[start:end].strip() + return output if output else "No output produced", True + + if "" in mcp_output and "" in mcp_output: + start = mcp_output.find("") + len("") + end = mcp_output.find("") + output = mcp_output[start:end].strip() + + if output.startswith("[") and output.endswith("]"): + output = output[1:-1].strip() + + return output if output else "No output produced", True + + elif "" in mcp_output and "" in mcp_output: + start = mcp_output.find("") + len("") + end = mcp_output.find("") + return_value = mcp_output[start:end].strip() + + if return_value.startswith("[") and return_value.endswith("]"): + return_value = return_value[1:-1].strip() + + return return_value, True + + elif "" in mcp_output and "" in mcp_output: + start = mcp_output.find("") + len("") + end = mcp_output.find("") + error_msg = mcp_output[start:end].strip() + + if "Traceback" in error_msg: + lines = error_msg.split('\n') + last_line = lines[-1] if lines else error_msg + return f"Error: {last_line}", False + else: + return f"Error: {error_msg}", False + + else: + clean_output = mcp_output.strip() + return clean_output if clean_output else "No output produced", True + +# Global shared Deno setup to avoid per-environment complexity +_global_deno_setup_lock = threading.Lock() +_global_deno_setup_done = False +_global_deno_work_dir = None + +try: + from filelock import FileLock + _deno_file_lock = FileLock("/tmp/deno_mcp_start.lock") +except ImportError: + logger.warning("filelock not available, using threading lock only") + _deno_file_lock = None + + +@asynccontextmanager +async def _stdio_client_with_stderr(server_params): + """Wrap stdio_client and try to capture any Deno process errors.""" + try: + async with stdio_client(server_params) as pipes: + yield pipes + except Exception as e: + logger.error(f"MCP stdio_client failed: {e}") + try: + result = subprocess.run([ + server_params.command, + *server_params.args + ], + env=server_params.env, + cwd=server_params.cwd, + capture_output=True, + text=True, + timeout=10, + input="" + ) + if result.stderr: + logger.error("⇢ Deno stderr:\n%s\n⇠ end stderr", result.stderr.strip()) + if result.stdout: + logger.error("⇢ Deno stdout:\n%s\n⇠ end stdout", result.stdout.strip()) + logger.error(f"Deno exit code: {result.returncode}") + except Exception as debug_e: + logger.error(f"Failed to debug Deno command: {debug_e}") + raise e + + +def _ensure_global_deno_setup(): + """Ensure Deno and MCP package are set up globally once.""" + global _global_deno_setup_done, _global_deno_work_dir + + with _global_deno_setup_lock: + if _global_deno_setup_done: + return _global_deno_work_dir + + logger.info("Setting up global Deno environment for MCP") + + _global_deno_work_dir = tempfile.mkdtemp(prefix="deno_global_") + + env_vars = { + 'PATH': os.environ.get('PATH', ''), + 'DENO_NO_UPDATE_CHECK': '1', + } + + # install Deno if not found + deno_install_dir = os.environ.get('DENO_INSTALL', os.path.expanduser('~/.deno')) + deno_bin_path = os.path.join(deno_install_dir, 'bin', 'deno') + + if not os.path.exists(deno_bin_path): + logger.info(f"Installing Deno to {deno_install_dir}") + try: + install_cmd = f'curl -fsSL https://deno.land/install.sh | bash -s -- -q -d {deno_install_dir}' + subprocess.run(install_cmd, shell=True, check=True, timeout=120, + env={**env_vars, 'DENO_INSTALL': deno_install_dir}) + logger.info(f"Deno installed successfully to {deno_bin_path}") + except Exception as e: + logger.error(f"Deno installation failed: {e}") + return None + + deno_bin_dir = os.path.join(deno_install_dir, 'bin') + env_vars['PATH'] = f"{deno_bin_dir}:{env_vars['PATH']}" + + os.environ['PATH'] = env_vars['PATH'] + + try: + deno_test = subprocess.run(['deno', '--version'], env=env_vars, capture_output=True, text=True, timeout=10) + logger.info(f"Deno version: {deno_test.stdout.strip()}") + except Exception as e: + logger.error(f"Deno test failed: {e}") + return None + + # Pre-cache MCP package globally + for attempt in range(2): + try: + subprocess.run([ + 'deno', 'cache', + '--node-modules-dir=auto', + 'jsr:@pydantic/mcp-run-python' + ], env=env_vars, cwd=_global_deno_work_dir, check=True, timeout=120) + logger.info("Global MCP package caching completed") + _global_deno_setup_done = True + return _global_deno_work_dir + except Exception as e: + if attempt == 0 and "Failed reading cache entry" in str(e): + logger.warning(f"Deno cache corruption detected, clearing cache and retrying: {e}") + try: + subprocess.run(['deno', 'cache', '--reload', 'jsr:@pydantic/mcp-run-python'], + env=env_vars, cwd=_global_deno_work_dir, timeout=120) + except Exception as clear_e: + logger.error(f"Cache clear failed: {clear_e}") + else: + logger.error(f"Global MCP package caching failed: {e}") + return None + + +class MCPPythonEnvironment(AsyncEnvironment): + """Environment using (Pydantic) MCP Run Python server for sandboxed code execution.""" + + # Class-level lock to serialize Deno session creation within each Python worker + _deno_lock = asyncio.Lock() def __init__(self): super().__init__() - self.server_params = StdioServerParameters( - command='deno', - args=[ - 'run', - '-N', - '-R=node_modules', - '-W=node_modules', - '--node-modules-dir=auto', - 'jsr:@pydantic/mcp-run-python', - 'stdio', - ], - ) - logger.info("MCP Python environment initialized") + + # do Deno setup lazily when needed + self.work_dir = None + + self.env_vars = None + + self.server_params = None + logger.info("MCP Python environment initialized (Deno setup will be done lazily)") + + def _ensure_setup(self): + """Ensure Deno setup is complete (called lazily when needed).""" + if self.work_dir is None: + self.work_dir = _ensure_global_deno_setup() + if not self.work_dir: + raise RuntimeError("Failed to set up global Deno environment") + + deno_install_dir = os.environ.get('DENO_INSTALL', os.path.expanduser('~/.deno')) + deno_bin_dir = os.path.join(deno_install_dir, 'bin') + + current_path = os.environ.get('PATH', '') + if deno_bin_dir not in current_path: + new_path = f"{deno_bin_dir}:{current_path}" + else: + new_path = current_path + + self.env_vars = { + 'PATH': new_path, + 'DENO_NO_UPDATE_CHECK': '1', + } + + self.work_dir = tempfile.mkdtemp(prefix="mcp_env_") + + self.server_params = StdioServerParameters( + command='deno', + args=[ + 'run', + '-A', + '--quiet', + 'jsr:@pydantic/mcp-run-python', + 'stdio', + ], + env=self.env_vars, + cwd=self.work_dir, + ) + logger.info(f"MCP Python environment setup completed with work dir: {self.work_dir}") + + def __del__(self): + """Clean up instance-specific temporary directory.""" + try: + import shutil + if hasattr(self, 'work_dir') and self.work_dir and os.path.exists(self.work_dir): + # don't clean up the global deno work dir + if self.work_dir != _global_deno_work_dir: + shutil.rmtree(self.work_dir) + logger.debug(f"Cleaned up work dir: {self.work_dir}") + except Exception as e: + logger.warning(f"Failed to clean up work dir: {e}") def launch(self, port: int): """Launch the environment as a server.""" @@ -53,6 +266,7 @@ def launch(self, port: int): })) def react(self, tape): + """Synchronous react method for backward compatibility.""" actions = [step for step in tape.steps[-tape.metadata.n_added_steps :] if isinstance(step, Action)] for action in actions: @@ -62,27 +276,23 @@ def react(self, tape): try: logger.info(f"Executing Python code via MCP: {repr(action.code[:100])}...") - # Execute code using MCP - handle async properly try: - loop = asyncio.get_running_loop() - # Run in thread to avoid event loop conflicts + asyncio.get_running_loop() import concurrent.futures - import threading def run_in_thread(): return asyncio.run(self._execute_python_code(action.code)) with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(run_in_thread) - result = future.result(timeout=30) + result = future.result(timeout=90) except RuntimeError: - # No running loop result = asyncio.run(self._execute_python_code(action.code)) - logger.info(f"MCP execution result: {repr(result[:200])}...") + logger.info(f"MCP execution result: {repr(result)}") - output, success = self._parse_mcp_result(result) + output, success = _parse_mcp_result(result) observation = CodeExecutionResult( result=CommandLineCodeResult( @@ -93,6 +303,10 @@ def run_in_thread(): tape = tape.append(observation) + except TimeoutError as e: + logger.warning(f"Code execution timed out: {e}") + tape = tape.append(ActionExecutionFailure(error=f"Timeout: {e}")) + break except Exception as e: logger.error(f"MCP execution failed: {e}") tape = tape.append(ActionExecutionFailure(error=str(e))) @@ -100,65 +314,98 @@ def run_in_thread(): return tape - async def _execute_python_code(self, code: str) -> str: - """Execute Python code using MCP Run Python server""" - async with stdio_client(self.server_params) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - result = await session.call_tool('run_python_code', {'python_code': code}) - return result.content[0].text - - def _parse_mcp_result(self, mcp_output: str) -> tuple[str, bool]: - """Parse MCP output to extract result and determine success.""" - # Check for error status - if "error" in mcp_output: - if "" in mcp_output and "" in mcp_output: - start = mcp_output.find("") + len("") - end = mcp_output.find("") - error_msg = mcp_output[start:end].strip() - return f"Error: {error_msg}", False - else: - return "Error: Code execution failed", False - - # Look for successful output in tags - if "" in mcp_output and "" in mcp_output: - start = mcp_output.find("") + len("") - end = mcp_output.find("") - output = mcp_output[start:end].strip() - - # Remove list brackets if present - if output.startswith("[") and output.endswith("]"): - output = output[1:-1].strip() - - return output if output else "No output produced", True + async def areact(self, tape): + """Async react method for use with async_execute_agent.""" + actions = [step for step in tape.steps[-tape.metadata.n_added_steps :] if isinstance(step, Action)] - # Try return_value format - elif "" in mcp_output and "" in mcp_output: - start = mcp_output.find("") + len("") - end = mcp_output.find("") - return_value = mcp_output[start:end].strip() - - # Remove list brackets - if return_value.startswith("[") and return_value.endswith("]"): - return_value = return_value[1:-1].strip() - - return return_value, True + for action in actions: + if not isinstance(action, PythonCodeAction): + continue + + try: + logger.info(f"Executing Python code via MCP: {repr(action.code[:100])}...") + + result = await self._execute_python_code(action.code) + logger.info(f"MCP execution result: {repr(result)}") + + output, success = _parse_mcp_result(result) + + observation = CodeExecutionResult( + result=CommandLineCodeResult( + output=output, + exit_code=0 if success else 1 + ) + ) + + tape = tape.append(observation) + + except TimeoutError as e: + logger.warning(f"Code execution timed out: {e}") + tape = tape.append(ActionExecutionFailure(error=f"Timeout: {e}")) + break + except Exception as e: + logger.error(f"MCP execution failed: {e}") + tape = tape.append(ActionExecutionFailure(error=str(e))) + break + + return tape + + async def _execute_python_code(self, code: str) -> str: + """Execute Python code using MCP Run Python server with retry mechanism and serialization.""" + self._ensure_setup() - # Check for stderr errors - elif "" in mcp_output and "" in mcp_output: - start = mcp_output.find("") + len("") - end = mcp_output.find("") - error_msg = mcp_output[start:end].strip() - - # Clean up Python tracebacks - if "Traceback" in error_msg: - lines = error_msg.split('\n') - last_line = lines[-1] if lines else error_msg - return f"Error: {last_line}", False + async with self._deno_lock: + if _deno_file_lock: + with _deno_file_lock: + return await self._run_mcp_session(code) else: - return f"Error: {error_msg}", False - - else: - # No structured output - return raw - clean_output = mcp_output.strip() - return clean_output if clean_output else "No output produced", True + return await self._run_mcp_session(code) + + async def _run_mcp_session(self, code: str) -> str: + """Run MCP session with retry logic.""" + max_retries = 2 + for attempt in range(max_retries): + try: + logger.debug(f"MCP execution attempt {attempt + 1}/{max_retries}") + + test_result = subprocess.run([ + self.server_params.command, '--version' + ], env=self.server_params.env, capture_output=True, text=True, timeout=5) + if test_result.returncode != 0: + logger.error(f"Deno version check failed: {test_result.stderr}") + raise RuntimeError(f"Deno not working: {test_result.stderr}") + logger.debug(f"Deno version check passed: {test_result.stdout.strip()}") + + async with _stdio_client_with_stderr(self.server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + try: + result = await asyncio.wait_for( + session.call_tool('run_python_code', {'python_code': code}), + timeout=30.0 + ) + return result.content[0].text + except asyncio.TimeoutError: + raise TimeoutError("Code execution timed out after 30 seconds") + except Exception as e: + logger.error(f"MCP tool call failed on attempt {attempt + 1}: {e}") + if attempt == max_retries - 1: + raise e + await asyncio.sleep(0.5) + except Exception as e: + logger.error(f"MCP session setup failed on attempt {attempt + 1}: {e}") + if attempt == max_retries - 1: + logger.error(f"Server params: {self.server_params}") + logger.error(f"Environment vars: {self.env_vars}") + raise RuntimeError(f"MCP execution failed after {max_retries} attempts. This indicates a serious issue with the Deno/MCP setup. Code execution cannot proceed safely without proper sandboxing.") + elif "Failed reading cache entry" in str(e): + logger.warning(f"Deno cache corruption detected during MCP execution, clearing cache: {e}") + try: + subprocess.run([ + 'deno', 'cache', '--reload', 'jsr:@pydantic/mcp-run-python' + ], env=self.server_params.env, cwd=self.server_params.cwd, timeout=30) + logger.info("Cache cleared successfully") + except Exception as clear_e: + logger.error(f"Failed to clear Deno cache: {clear_e}") + backoff_time = 0.5 + attempt * 0.5 + await asyncio.sleep(backoff_time) diff --git a/pipelinerl/domains/tir/old/evaluate.py b/pipelinerl/domains/tir/old/evaluate.py deleted file mode 100644 index 951ef994..00000000 --- a/pipelinerl/domains/tir/old/evaluate.py +++ /dev/null @@ -1,455 +0,0 @@ -"""Evaluation script for TIR (Tool Integrated Reasoning) domain.""" - -import logging -import time -from typing import Optional - -import numpy as np -from datasets import load_dataset -from termcolor import colored -from tqdm import tqdm -import wandb - -from tapeagents.llms import TrainableLLM - -from .agent import TIRMathAgent, extract_result_value, solve_task, AnswerAction -from .environment import TIRMathEnvironment, MCPPythonEnvironment -from .prompts import PromptRegistry - -# Set logging level -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def evaluate_tir_model( - num_samples: int = 100, - temperature: float = 0.2, - dataset_name: str = "gsm8k", - wandb_project: Optional[str] = None, - wandb_run_name: Optional[str] = None, - log_failures: bool = True, - use_mcp: bool = True, - prompt_type: str = "default", - model_path: str = "/mnt/llmd/base_models/AI-MO-NuminaMath-7B-TIR", - base_url: str = "http://localhost:8080" -): - """ - Evaluate the TIR model's performance on math problems. - - Args: - num_samples: Number of samples to evaluate - temperature: Sampling temperature for the model - dataset_name: Dataset to evaluate on (gsm8k, math, etc.) - wandb_project: W&B project name (if None, no logging to wandb) - wandb_run_name: W&B run name - log_failures: Whether to log detailed failure analysis to wandb - use_mcp: Whether to use MCP Python execution environment - prompt_type: Type of prompt to use (default, advanced) - model_path: Path to the model - base_url: Base URL for the LLM server - - Returns: - Dictionary with accuracy and detailed results. - """ - # Initialize wandb if project specified - if wandb_project: - wandb.init( - project=wandb_project, - name=wandb_run_name or f"tir-{dataset_name}-eval-t{temperature}-n{num_samples}", - config={ - "dataset": dataset_name.upper(), - "num_samples": num_samples, - "temperature": temperature, - "model": model_path.split("/")[-1] if "/" in model_path else model_path, - "framework": "TapeAgents-TIR", - "execution_env": "MCP-Python" if use_mcp else "Container-Python", - "prompt_type": prompt_type, - "max_iterations": 1, - } - ) - - # Load dataset - if dataset_name.lower() == "gsm8k": - dataset = load_dataset("openai/gsm8k", "main", split="test") - elif dataset_name.lower() == "math": - dataset = load_dataset("hendrycks/competition_math", "main", split="test") - else: - raise ValueError(f"Unsupported dataset: {dataset_name}") - - samples = [s for s in dataset][:num_samples] - logger.info(f"Evaluating on {len(samples)} samples from {dataset_name.upper()} test set using TIR") - - # Initialize model and agent - llm = TrainableLLM( - base_url=base_url, - model_name=model_path, - tokenizer_name=model_path, - parameters=dict( - temperature=temperature, - max_tokens=512 - ), - ) - - # Get the appropriate system prompt - system_prompt = PromptRegistry.get_prompt(prompt_type) - - agent = TIRMathAgent.create( - llm=llm, - max_prompt_length=1024, - system_prompt=system_prompt - ) - - # Use TIR environment - if use_mcp: - env = MCPPythonEnvironment() - logger.info("Using MCP Python environment for TIR") - else: - env = TIRMathEnvironment(use_mcp=False) - logger.info("Using container/restricted Python environment for TIR") - - # Track results - results = [] - correct = 0 - errors = 0 - - # Track detailed metrics for wandb - no_answer_count = 0 - wrong_answer_count = 0 - execution_errors = 0 - - # Start timing - start_time = time.time() - sample_times = [] - - for i, sample in enumerate(tqdm(samples)): - try: - # Time individual sample processing - sample_start_time = time.time() - - # Prepare sample with expected value - sample = extract_result_value(sample) - - # Use the solve_task function from TIR agent - tape = solve_task(agent, env, sample) - - # Check if the task was solved correctly - is_solved = tape.metadata.result.get("solved", False) - - # Count steps in the tape - num_steps = len(tape.steps) - - # Extract the answer from the last AnswerAction in the tape - answer_step = None - for step in reversed(tape.steps): - if isinstance(step, AnswerAction): - answer_step = step - break - - if answer_step is None: - logger.warning(colored(f"No answer found for sample {i}", "yellow")) - no_answer_count += 1 - errors += 1 - results.append({ - "question": sample["question"], - "expected": sample["value"], - "predicted": None, - "correct": False, - "solved": False, - "error": "No answer produced", - "sample_id": i, - "num_steps": num_steps - }) - - # Log no-answer cases immediately to wandb for debugging - if wandb_project: - wandb.log({"no_answer_case": i + 1, "no_answer_total": no_answer_count}) - continue - - # Compare results - predicted_value = answer_step.value - expected_value = sample["value"] - - # Check if values match (with small tolerance for floating point) - if predicted_value is not None and expected_value is not None: - is_correct = abs(float(predicted_value) - float(expected_value)) < 1e-6 - else: - is_correct = False - - if is_correct: - correct += 1 - logger.debug(colored(f"Correct answer for sample {i}", "green")) - else: - wrong_answer_count += 1 - logger.debug(colored(f"Wrong answer for sample {i}. Expected {expected_value}, got {predicted_value}", "red")) - - # Record timing for this sample - sample_end_time = time.time() - sample_duration = sample_end_time - sample_start_time - sample_times.append(sample_duration) - - results.append({ - "question": sample["question"], - "expected": expected_value, - "predicted": predicted_value, - "correct": is_correct, - "solved": is_solved, - "error": None, - "sample_id": i, - "processing_time": sample_duration, - "num_steps": num_steps - }) - - # Log progress to wandb every 10 samples - if wandb_project and (i + 1) % 10 == 0: - elapsed_time = time.time() - start_time - avg_time_per_sample = elapsed_time / (i + 1) - current_step_counts = [r["num_steps"] for r in results[:i+1]] - current_avg_steps = np.mean(current_step_counts) if current_step_counts else 0 - wandb.log({ - "samples_processed": i + 1, - "current_accuracy": correct / (i + 1), - "current_no_answer_rate": no_answer_count / (i + 1), - "current_error_rate": errors / (i + 1), - "elapsed_time_minutes": elapsed_time / 60, - "avg_time_per_sample_seconds": avg_time_per_sample, - "estimated_total_time_minutes": (avg_time_per_sample * len(samples)) / 60, - "current_avg_steps_per_sample": current_avg_steps - }) - - except Exception as e: - logger.error(colored(f"Error processing sample {i}: {str(e)}", "red")) - execution_errors += 1 - errors += 1 - - # Record timing even for failed samples - sample_end_time = time.time() - sample_duration = sample_end_time - sample_start_time - sample_times.append(sample_duration) - - # Set num_steps to 0 for failed samples (no tape created) - num_steps = 0 - - results.append({ - "question": sample["question"], - "expected": sample.get("value", None), - "predicted": None, - "correct": False, - "solved": False, - "error": str(e), - "sample_id": i, - "processing_time": sample_duration, - "num_steps": num_steps - }) - - # Calculate metrics and timing - total_time = time.time() - start_time - avg_time_per_sample = np.mean(sample_times) if sample_times else 0 - median_time_per_sample = np.median(sample_times) if sample_times else 0 - - # Calculate step statistics - step_counts = [r["num_steps"] for r in results] - avg_steps_per_sample = np.mean(step_counts) if step_counts else 0 - median_steps_per_sample = np.median(step_counts) if step_counts else 0 - max_steps = max(step_counts) if step_counts else 0 - min_steps = min(step_counts) if step_counts else 0 - - accuracy = correct / len(samples) if len(samples) > 0 else 0 - error_rate = errors / len(samples) if len(samples) > 0 else 0 - solved_rate = sum(1 for r in results if r["solved"]) / len(samples) if len(samples) > 0 else 0 - no_answer_rate = no_answer_count / len(samples) if len(samples) > 0 else 0 - wrong_answer_rate = wrong_answer_count / len(samples) if len(samples) > 0 else 0 - execution_error_rate = execution_errors / len(samples) if len(samples) > 0 else 0 - - logger.info(f"\n{dataset_name.upper()} TIR Evaluation Results:") - logger.info(f"Total samples: {len(samples)}") - logger.info(colored(f"Accuracy: {accuracy:.2%} ({correct}/{len(samples)})", "green")) - logger.info(colored(f"Solved rate: {solved_rate:.2%}", "blue")) - logger.info(f"Correct: {correct}") - logger.info(colored(f"Errors: {errors} ({error_rate:.2%})", "red")) - logger.info(f"Token usage: {llm.token_count if hasattr(llm, 'token_count') else 'N/A'}") - logger.info(colored(f"Total time: {total_time:.1f}s ({total_time/60:.1f} minutes)", "cyan")) - logger.info(colored(f"Average time per sample: {avg_time_per_sample:.2f}s", "cyan")) - logger.info(colored(f"Median time per sample: {median_time_per_sample:.2f}s", "cyan")) - logger.info(colored(f"Average steps per sample: {avg_steps_per_sample:.1f}", "magenta")) - logger.info(colored(f"Median steps per sample: {median_steps_per_sample:.1f}", "magenta")) - logger.info(colored(f"Step range: {min_steps}-{max_steps}", "magenta")) - - # Log final results to wandb - if wandb_project: - final_metrics = { - "final_accuracy": accuracy, - "final_solved_rate": solved_rate, - "final_error_rate": error_rate, - "no_answer_rate": no_answer_rate, - "wrong_answer_rate": wrong_answer_rate, - "execution_error_rate": execution_error_rate, - "total_samples": len(samples), - "correct_answers": correct, - "no_answer_cases": no_answer_count, - "wrong_answers": wrong_answer_count, - "execution_errors": execution_errors, - "token_usage": getattr(llm, 'token_count', None), - "total_time_seconds": total_time, - "total_time_minutes": total_time / 60, - "avg_time_per_sample_seconds": avg_time_per_sample, - "median_time_per_sample_seconds": median_time_per_sample, - "samples_per_minute": len(samples) / (total_time / 60) if total_time > 0 else 0, - "avg_steps_per_sample": avg_steps_per_sample, - "median_steps_per_sample": median_steps_per_sample, - "max_steps": max_steps, - "min_steps": min_steps - } - wandb.log(final_metrics) - - # Log failure examples as a table - if log_failures: - failures = [r for r in results if not r["correct"]] - failure_data = [] - for failure in failures[:20]: # Log first 20 failures - failure_data.append([ - failure["sample_id"], - failure["question"][:150] + "..." if len(failure["question"]) > 150 else failure["question"], - failure["expected"], - failure["predicted"], - failure["error"] or "Wrong calculation" - ]) - - if failure_data: - failure_table = wandb.Table( - columns=["Sample ID", "Question", "Expected", "Predicted", "Error Type"], - data=failure_data - ) - wandb.log({"failure_examples": failure_table}) - - # Create accuracy over time chart - progress_data = [] - running_correct = 0 - for i, result in enumerate(results): - if result["correct"]: - running_correct += 1 - progress_data.append([i + 1, running_correct / (i + 1)]) - - progress_table = wandb.Table( - columns=["Sample", "Accuracy"], - data=progress_data - ) - wandb.log({ - "accuracy_over_time": wandb.plot.line( - progress_table, "Sample", "Accuracy", - title=f"{dataset_name.upper()} TIR Accuracy Over Time" - ) - }) - - wandb.finish() - - return { - "accuracy": accuracy, - "solved_rate": solved_rate, - "error_rate": error_rate, - "no_answer_rate": no_answer_rate, - "wrong_answer_rate": wrong_answer_rate, - "execution_error_rate": execution_error_rate, - "total_samples": len(samples), - "correct": correct, - "errors": errors, - "total_time_seconds": total_time, - "total_time_minutes": total_time / 60, - "avg_time_per_sample": avg_time_per_sample, - "median_time_per_sample": median_time_per_sample, - "samples_per_minute": len(samples) / (total_time / 60) if total_time > 0 else 0, - "avg_steps_per_sample": avg_steps_per_sample, - "median_steps_per_sample": median_steps_per_sample, - "max_steps": max_steps, - "min_steps": min_steps, - "detailed_results": results - } - - -def analyze_failures(results: dict, num_examples: int = 5): - """Analyze and print example failures for debugging.""" - failures = [r for r in results["detailed_results"] if not r["correct"]] - - print(f"\nTIR Failure Analysis ({len(failures)} total failures):") - print("=" * 80) - - # Group failures by type - no_answer = [f for f in failures if f["predicted"] is None] - wrong_answer = [f for f in failures if f["predicted"] is not None] - - print(f"No answer produced: {len(no_answer)}") - print(f"Wrong answer: {len(wrong_answer)}") - - print(f"\nExample Failures (showing first {num_examples}):") - for i, failure in enumerate(failures[:num_examples]): - print(f"\nFailure {i+1}:") - print(f"Sample ID: {failure.get('sample_id', 'N/A')}") - print(f"Question: {failure['question']}") - print(f"Expected: {failure['expected']}") - print(f"Predicted: {failure['predicted']}") - print(f"Solved: {failure['solved']}") - if failure['error']: - print(f"Error: {failure['error']}") - print("-" * 80) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Run TIR evaluation on math datasets") - parser.add_argument("--num-samples", type=int, default=200, help="Number of samples to evaluate") - parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature") - parser.add_argument("--dataset", type=str, default="gsm8k", choices=["gsm8k", "math"], help="Dataset to evaluate on") - parser.add_argument("--wandb-project", type=str, help="Wandb project name") - parser.add_argument("--wandb-run-name", type=str, help="Wandb run name") - parser.add_argument("--no-wandb", action="store_true", help="Disable wandb logging") - parser.add_argument("--no-mcp", action="store_true", help="Don't use MCP Python execution") - parser.add_argument("--prompt-type", type=str, default="default", choices=["default", "advanced"], help="Prompt type to use") - parser.add_argument("--model-path", type=str, default="/mnt/llmd/base_models/AI-MO-NuminaMath-7B-TIR", help="Model path") - parser.add_argument("--base-url", type=str, default="http://localhost:8080", help="Base URL for LLM server") - - args = parser.parse_args() - - # Set up wandb configuration - wandb_project = None if args.no_wandb else (args.wandb_project or f"tir-{args.dataset}-eval") - - logger.info(f"Starting TIR evaluation:") - logger.info(f" Dataset: {args.dataset.upper()}") - logger.info(f" Number of samples: {args.num_samples}") - logger.info(f" Temperature: {args.temperature}") - logger.info(f" Prompt type: {args.prompt_type}") - logger.info(f" Use MCP: {not args.no_mcp}") - logger.info(f" Wandb project: {wandb_project or 'None (disabled)'}") - - # Run TIR evaluation - results = evaluate_tir_model( - num_samples=args.num_samples, - temperature=args.temperature, - dataset_name=args.dataset, - wandb_project=wandb_project, - wandb_run_name=args.wandb_run_name, - log_failures=True, - use_mcp=not args.no_mcp, - prompt_type=args.prompt_type, - model_path=args.model_path, - base_url=args.base_url - ) - - # Print summary - print(f"\n{'='*80}") - print(f"TIR {args.dataset.upper()} EXPERIMENT SUMMARY") - print(f"{'='*80}") - print(f"Accuracy: {results['accuracy']:.2%} ({results['correct']}/{results['total_samples']})") - print(f"No answer rate: {results['no_answer_rate']:.2%}") - print(f"Wrong answer rate: {results['wrong_answer_rate']:.2%}") - print(f"Execution error rate: {results['execution_error_rate']:.2%}") - print(f"Total time: {results['total_time_minutes']:.1f} minutes") - print(f"Average time per sample: {results['avg_time_per_sample']:.2f}s") - print(f"Throughput: {results['samples_per_minute']:.1f} samples/minute") - print(f"Average steps per sample: {results['avg_steps_per_sample']:.1f}") - print(f"Step range: {results['min_steps']}-{results['max_steps']}") - - if wandb_project: - print(f"\nResults logged to wandb project: {wandb_project}") - - # Analyze failures - analyze_failures(results) \ No newline at end of file diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index 4150c208..f6fb51c0 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -4,109 +4,159 @@ import time import json import os -from typing import Dict, Any, List, Union +from typing import Any, List from collections import Counter import aiohttp from omegaconf import DictConfig from tapeagents.llms import TrainableLLM -from pipelinerl.rollouts import RolloutResult +from pipelinerl.rollouts import RolloutResult, BaseMetrics +from pydantic import BaseModel +from tapeagents.steps import ActionExecutionFailure +from tapeagents.tools.code_executor import CodeExecutionResult logger = logging.getLogger(__name__) -# Cache environments globally +# cache environments globally to avoid recreating them _cached_environments = {} + +class TIRMetrics(BaseMetrics): + """TIR-specific metrics extending the base metrics.""" + overflow: int = 0 + timeout: int = 0 + prompt_tokens: int = 0 + output_tokens: int = 0 + reached_answer_action: int = 0 + + +class TIRRewardTable(BaseModel): + """Reward table for TIR domain""" + correct_answer: float + wrong_answer: float + no_answer: float + unparsable: float + execution_failure: float + successful_code_execution: float + timeout_penalty: float + buffer_tokens: int + iteration_penalty: float + + +def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: + """Compute the overlong penalty""" + if sequence_length > (max_length - buffer_tokens) and sequence_length <= max_length: + return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens + return 0. + + async def generate_tir_rollout(cfg: DictConfig, llm: TrainableLLM, problem: dict, session: aiohttp.ClientSession) -> RolloutResult: - """Generate a rollout for TIR domain with fast or sc_tir modes.""" + """Generate a rollout for TIR domain with iterative reasoning.""" from pipelinerl.async_llm import make_training_text - from tapeagents.orchestrator import main_loop + from tapeagents.orchestrator import async_main_loop from .agent import Task, TIRMathTape, AnswerAction, TIRMathAgent from .environment import MCPPythonEnvironment time_start = time.time() + # Create or reuse environment env_key = str(cfg.environment) if env_key not in _cached_environments: _cached_environments[env_key] = MCPPythonEnvironment() logger.info("Created new cached MCP environment") environment = _cached_environments[env_key] - mode = getattr(cfg.actor, 'mode', 'fast') - num_candidates = getattr(cfg.actor, 'num_candidates', 4) if mode == 'sc_tir' else 1 max_reasoning_steps = getattr(cfg.actor, 'max_reasoning_steps', 8) + logger.info(f"Running TIR with max {max_reasoning_steps} reasoning steps") - logger.info(f"Running {mode} mode with {num_candidates} candidates, max {max_reasoning_steps} steps") + # Create agent + agent = TIRMathAgent( + system_prompt=cfg.actor.system_prompt, + max_iterations=max_reasoning_steps + ) + agent.llms = {"default": llm} - all_final_tapes = [] - all_llm_calls = [] - all_training_samples = [] - candidate_answers = [] + # Debug: Check what tokenizer is being used + if hasattr(llm, 'tokenizer') and llm.tokenizer: + logger.info(f"LLM using tokenizer: {llm.tokenizer.__class__.__name__} from {llm.tokenizer.name_or_path}") + logger.info(f"LLM tokenizer vocab size: {llm.tokenizer.vocab_size}") + else: + logger.warning("LLM has no tokenizer loaded") - for candidate_idx in range(num_candidates): - logger.info(f"Generating candidate {candidate_idx + 1}/{num_candidates}") - - agent = TIRMathAgent( - system_prompt=cfg.actor.system_prompt, - max_iterations=max_reasoning_steps + # Use task template if provided + task_template = getattr(cfg.actor, 'task_template', '{task}') + task_step = Task(task=problem["task"], template=task_template) + start_tape = TIRMathTape(steps=[task_step], context=None) + + # Run agent-environment interaction + final_tape = None + + async for event in async_main_loop(agent, start_tape, environment, session, cfg.max_loops): + if event.agent_tape: + final_tape = event.agent_tape + elif event.env_tape: + final_tape = event.env_tape + + if final_tape is None: + logger.warning("Failed to generate tape") + metrics = TIRMetrics( + reward=0.0, + success=False, + no_error=False, + no_answer=True, + ) + return RolloutResult( + training_texts=[], + metrics=metrics, + latency=time.time() - time_start, + dataset_name=problem.get("dataset", "unknown"), ) - agent.llms = {"default": llm} - - task_step = Task(task=problem["task"]) - start_tape = TIRMathTape(steps=[task_step], context=None) - - # agent-environment interaction - final_tape = None - for event in main_loop(agent, start_tape, environment, cfg.max_loops): - if event.agent_tape: - final_tape = event.agent_tape - elif event.env_tape: - final_tape = event.env_tape - - if final_tape is not None: - all_final_tapes.append(final_tape) - - answer_step = None - for step in reversed(final_tape.steps): - if isinstance(step, AnswerAction): - answer_step = step - break - - if answer_step is not None: - candidate_answers.append(answer_step.value) - logger.info(f"Candidate {candidate_idx + 1} answer: {answer_step.value}") - else: - candidate_answers.append(None) - logger.warning(f"Candidate {candidate_idx + 1} produced no answer") - - candidate_llm_calls = [] - candidate_samples = [] - - for step in final_tape.steps: - if step.metadata and step.metadata.other: - llm_call_data = step.metadata.other.get("llm_call") - if llm_call_data: - training_text = make_training_text(llm, llm_call_data) - candidate_samples.append(training_text) - candidate_llm_calls.append(llm_call_data) - - if not candidate_llm_calls: - _, candidate_llm_calls = agent.reuse(final_tape) - candidate_samples = [agent.make_training_text(llm_call) for llm_call in candidate_llm_calls] - - all_llm_calls.extend(candidate_llm_calls) - all_training_samples.extend(candidate_samples) - else: - candidate_answers.append(None) - logger.warning(f"Candidate {candidate_idx + 1} failed") - # majority voting or single answer - if mode == 'sc_tir': - final_answer = apply_majority_voting(candidate_answers) - logger.info(f"Candidates: {candidate_answers} -> Majority: {final_answer}") - else: - final_answer = candidate_answers[0] if candidate_answers else None - logger.info(f"Fast mode answer: {final_answer}") + final_answer = None + reached_answer_action = False + if final_tape and isinstance(final_tape.steps[-1], AnswerAction): + final_answer = final_tape.steps[-1].value + reached_answer_action = True + + predicted_answer = final_answer if final_answer is not None else "No answer" + ground_truth = problem.get("answer", "Unknown") + logger.info(f"Problem: {problem.get('id', 'unknown')} | Predicted: {predicted_answer} | Ground truth: {ground_truth} | Reached AnswerAction: {reached_answer_action}") + training_samples = [] + + llm_calls = [] + + if getattr(agent, "llm_calls", None): + llm_calls.extend(agent.llm_calls) + + if final_tape and not llm_calls: + for step in final_tape.steps: + if ( + step.metadata + and step.metadata.other + and "llm_call" in step.metadata.other + ): + llm_calls.append(step.metadata.other["llm_call"]) + + for llm_call in llm_calls: + if isinstance(llm_call, dict): + from tapeagents.core import LLMCall + llm_call = LLMCall(**llm_call) + + training_sample = make_training_text(llm, llm_call) + training_samples.append(training_sample) + + if reached_answer_action and training_samples: + for text in training_samples: + text.finished = True + + if not llm_calls: + logger.warning( + "No LLM calls were captured for this rollout; no training samples will be produced. " + "Check that `agent.store_llm_calls=True` and that the orchestrator " + "is not stripping metadata." + ) + + # save debug info if requested if getattr(cfg, 'save_tapes', False): debug_dir = os.path.join(cfg.output_dir, "debug_tapes") os.makedirs(debug_dir, exist_ok=True) @@ -114,12 +164,8 @@ async def generate_tir_rollout(cfg: DictConfig, llm: TrainableLLM, problem: dict debug_file = os.path.join(debug_dir, f"problem_{problem.get('id', 'unknown')}.json") debug_data = { "problem": problem, - "mode": mode, - "num_candidates": num_candidates, - "candidate_answers": candidate_answers, - "majority_answer": final_answer, - "num_tapes": len(all_final_tapes), - "total_llm_calls": len(all_llm_calls), + "answer": final_answer, + "num_llm_calls": len(llm_calls), "target_answer": problem.get("answer", ""), } @@ -144,54 +190,104 @@ async def generate_tir_rollout(cfg: DictConfig, llm: TrainableLLM, problem: dict else: answer_status = "unparsable" - # rewards - reward = 1.0 if success else 0.0 - for sample in all_training_samples: - sample.reward = reward + if training_samples: + import numpy as np + avg_prompt_tokens = np.mean([t.prompt_tokens for t in training_samples]) + avg_output_tokens = np.mean([t.output_tokens for t in training_samples]) + logger.info( + f"šŸŽ“ Generated {len(training_samples)} training samples | " + f"avg prompt={avg_prompt_tokens:.0f} tokens | " + f"avg output={avg_output_tokens:.0f} tokens | " + f"answer_status={answer_status} | " + f"reached_answer={reached_answer_action}" + ) + + rewards = TIRRewardTable(**dict(cfg.get('rewards', {}))) + + num_llm_calls = len(llm_calls) + iteration_penalty_total = num_llm_calls * rewards.iteration_penalty + + if training_samples: + if answer_status == "correct": + base_reward = rewards.correct_answer + elif answer_status == "wrong": + base_reward = rewards.wrong_answer + elif answer_status == "no_answer": + base_reward = rewards.no_answer + elif answer_status == "unparsable": + base_reward = rewards.unparsable + else: + base_reward = rewards.wrong_answer # fallback + + has_execution_errors = any( + isinstance(step, ActionExecutionFailure) for step in final_tape.steps + ) + if has_execution_errors: + base_reward += rewards.execution_failure + + successful_executions = sum(1 for step in final_tape.steps if isinstance(step, CodeExecutionResult)) + base_reward += successful_executions * rewards.successful_code_execution + + base_reward += iteration_penalty_total + + for sample in training_samples: + sample.reward = base_reward + + if cfg.actor.discount_factor and llm_calls: + total_output_tokens = sum(llm_call.output_length_tokens for llm_call in llm_calls) + discount_multiplier = cfg.actor.discount_factor ** total_output_tokens + + overlong_penalty = 0 + if rewards.buffer_tokens > 0: + max_tokens = getattr(cfg.actor, 'max_tokens', 4096) + overlong_penalty = length_penalty(max_tokens, total_output_tokens, rewards.buffer_tokens) + + for sample in training_samples: + sample.reward *= discount_multiplier + sample.reward += overlong_penalty + + avg_reward = sum(sample.reward for sample in training_samples) / len(training_samples) + else: + avg_reward = sum(sample.reward for sample in training_samples) / len(training_samples) if training_samples else 0.0 - # discount factor - if cfg.actor.discount_factor and all_llm_calls: - total_output_tokens = sum(llm_call.output_length_tokens for llm_call in all_llm_calls) - reward *= cfg.actor.discount_factor ** total_output_tokens - for sample in all_training_samples: - sample.reward = reward + if training_samples: + reward_values = [sample.reward for sample in training_samples] + logger.info( + f"šŸ† Final rewards: mean={np.mean(reward_values):.3f} | " + f"min={np.min(reward_values):.3f} | max={np.max(reward_values):.3f} | " + f"base_reward={base_reward:.3f} | " + f"iter_penalty={iteration_penalty_total:.3f} | " + f"llm_calls={num_llm_calls} | " + f"discount_factor={cfg.actor.discount_factor if cfg.actor.discount_factor else 'N/A'}" + ) has_errors = any( - any(1 for s in tape.steps if hasattr(s, 'error') and s.error) - for tape in all_final_tapes + any(1 for s in final_tape.steps if hasattr(s, 'error') and s.error) + for s in [final_tape] ) - valid_answers = [ans for ans in candidate_answers if ans is not None] - if mode == 'sc_tir' and len(valid_answers) > 1: - answer_counts = Counter(valid_answers) - most_common_count = answer_counts.most_common(1)[0][1] if answer_counts else 0 - agreement_rate = most_common_count / len(valid_answers) - else: - agreement_rate = 1.0 if valid_answers else 0.0 - - metrics = { - "reward": reward, - "success": 1 if success else 0, - "no_error": 1 if not has_errors else 0, - "no_answer": 1 if answer_status == "no_answer" else 0, - "overflow": 0, # TODO: detect max_loops - "prompt_tokens": sum(llm_call.prompt_length_tokens for llm_call in all_llm_calls) if all_llm_calls else 0, - "output_tokens": sum(llm_call.output_length_tokens for llm_call in all_llm_calls) if all_llm_calls else 0, - "mode": mode, - "num_candidates": num_candidates, - "candidates_with_answers": len(valid_answers), - "agreement_rate": agreement_rate, - "majority_answer": final_answer, - "candidate_answers": candidate_answers, - } + metrics = TIRMetrics( + reward=avg_reward, + success=success, + no_error=not has_errors, + no_answer=(answer_status == "no_answer"), + overflow=0, + timeout=0, + prompt_tokens=sum(llm_call.prompt_length_tokens for llm_call in llm_calls) if llm_calls else 0, + output_tokens=sum(llm_call.output_length_tokens for llm_call in llm_calls) if llm_calls else 0, + reached_answer_action=int(reached_answer_action), + ) + # guard against assertion error when we have insufficient samples or inconsistent text + if len(training_samples) >= 2 and training_samples[0].text not in training_samples[1].text: + logger.debug("Rollout consistency check failed; continuing without assertion") return RolloutResult( - training_texts=all_training_samples, + training_texts=training_samples, metrics=metrics, latency=time.time() - time_start, dataset_name=problem.get("dataset", "unknown"), - prompt_tokens=[llm_call.prompt_length_tokens for llm_call in all_llm_calls] if all_llm_calls else [], - output_tokens=[llm_call.output_length_tokens for llm_call in all_llm_calls] if all_llm_calls else [], + prompt_tokens=[llm_call.prompt_length_tokens for llm_call in llm_calls] if llm_calls else [], + output_tokens=[llm_call.output_length_tokens for llm_call in llm_calls] if llm_calls else [], ) @@ -202,7 +298,6 @@ def apply_majority_voting(candidate_answers: List[Any]) -> Any: if not valid_answers: return None - # normalise answers normalized_answers = [] for ans in valid_answers: if isinstance(ans, (int, float)): diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 3e864171..d2beae86 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -561,6 +561,8 @@ def main(cfg: DictConfig): processes.extend(launch_jobs(cfg, world_map, ["finetune"])) elif cfg.debug.mode == "actor": processes.extend(launch_jobs(cfg, world_map, ["actor", "environment", "actor_llm"])) + elif cfg.debug.mode == "eval": + processes.extend(launch_jobs(cfg, world_map, ["actor", "environment", "actor_llm"])) elif cfg.debug.mode == "preprocessor": processes.extend(launch_jobs(cfg, world_map, ["preprocessor", "preprocessor_llm"])) elif cfg.debug.mode == "actor+preprocessor": diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 544bd87a..d4db1107 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -333,7 +333,27 @@ def run_preprocessing_loop( else: wandb_run = None - tokenizer = load_tokenizer(cfg.finetune.config_name) + # Use same model selection logic as actor: prefer finetuned model if available + finetune_model_path = exp_root_dir / "finetune" / "current" + if finetune_model_path.exists(): + tokenizer_path = str(finetune_model_path) + logger.info(f"Using finetuned model tokenizer from: {tokenizer_path}") + else: + tokenizer_path = cfg.finetune.config_name + logger.info(f"Using base model tokenizer from config: {tokenizer_path}") + + logger.info(f"Loading tokenizer from: {tokenizer_path}") + tokenizer = load_tokenizer(tokenizer_path) + logger.info(f"Loaded tokenizer: {tokenizer.__class__.__name__} from {tokenizer.name_or_path}") + logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}") + + # Check if Qwen special tokens are in vocab + qwen_tokens = {"<|im_start|>": 151644, "<|im_end|>": 151645, "<|endoftext|>": 151643} + for token_name, token_id in qwen_tokens.items(): + if token_id in tokenizer.get_vocab().values(): + logger.info(f"Found Qwen token {token_name} (id: {token_id}) in tokenizer vocab") + else: + logger.warning(f"Qwen token {token_name} (id: {token_id}) NOT found in tokenizer vocab") llm_urls = str(cfg.me.llm_urls).split("+") if cfg.me.llm_urls else [] if llm_urls: diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py index 2b0a252c..db54ef0b 100644 --- a/pipelinerl/utils.py +++ b/pipelinerl/utils.py @@ -236,6 +236,15 @@ def calculate_stats(stats: List | Dict[Any, Any]) -> Dict[str, float]: # stats is a dict of list stats = dict_to_list(stats) + if not stats: + # Handle empty stats gracefully + return { + "max": 0.0, + "min": 0.0, + "var": 0.0, + "mean": 0.0, + } + if not isinstance(stats, list): raise TypeError(f"Expected stats to be a list, got {type(stats)}")