adding resample benchmark objects #226
@@ -0,0 +1,135 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import benchmarks
import numpy as np
import pandas as pd
from bgym import Benchmark, EnvArgs, HighLevelActionSetArgs
from browsergym.experiments.benchmark.base import BenchmarkBackend
from browsergym.experiments.benchmark.utils import make_env_args_list_from_repeat_tasks
from dataclasses_json import DataClassJsonMixin, config

from agentlab.analyze.inspect_results import load_result_df
from agentlab.experiments.study import Study


@dataclass
class ResampleBenchmark(Benchmark):
    """Benchmark rebuilt from a previous study: subclasses score each task (`evaluate`)
    and decide which env_args to keep (`select`)."""

    exp_dir: Path = None
    name: str = None
    high_level_action_set_args: HighLevelActionSetArgs = None
    is_multi_tab: bool = None
    supports_parallel_seeds: bool = None
    env_args_list: list[EnvArgs] = None
    backends: list[BenchmarkBackend] = None
    task_metadata: Optional[pd.DataFrame] = field(
        default_factory=lambda: None,
        metadata=config(
            encoder=lambda df: df.to_dict(orient="records") if df is not None else None,
            decoder=lambda items: pd.DataFrame(items) if items is not None else None,
        ),
    )

    def __post_init__(self):
        assert self.exp_dir is not None
        study = Study.load(self.exp_dir)
        benchmark = study.benchmark

        self.name = f"resample-{benchmark.name}"
        self.high_level_action_set_args = benchmark.high_level_action_set_args
        self.is_multi_tab = benchmark.is_multi_tab
        self.supports_parallel_seeds = benchmark.supports_parallel_seeds
        self.backends = benchmark.backends
        # we discard the task_metadata to create new ones in post_init

        values = self.evaluate(study, benchmark.env_args_list)
        selected_env_args = self.select(values, benchmark.env_args_list)

        if len(selected_env_args) == 0:
            raise ValueError("No env_args selected, lower restrictions")

        self.env_args_list = selected_env_args

        super().__post_init__()

    @abstractmethod
    def evaluate(self, study, env_args_list):
        pass

    @abstractmethod
    def select(self, values, env_args_list):
        pass


@dataclass
class AllTasksBenchmark(ResampleBenchmark):
    """Keeps every task from the original study."""

    def evaluate(self, study, env_args_list):
        return [0] * len(env_args_list)

    def select(self, values, env_args_list):
        return env_args_list


@dataclass
class HighVarianceBenchmark(ResampleBenchmark):
    """Keeps only tasks whose cum_reward standard deviation exceeds `threshold`."""

    threshold: float = 0.2
Review comment (korbit-ai): Threshold Value Validation Missing

What is the issue? The hardcoded threshold value of 0.2 for variance selection lacks validation to ensure it is a reasonable value.

Why this matters: Invalid threshold values (negative or extremely high) could lead to unintended task filtering behavior.

Suggested change: add threshold validation in __post_init__:

    def __post_init__(self):
        if not 0 <= self.threshold <= 1:
            raise ValueError(f"Threshold must be between 0 and 1, got {self.threshold}")
        super().__post_init__()
    def evaluate(self, study: Study, env_args_list):
        result_df = load_result_df(study.dir)
        return dict(result_df.groupby("env.task_name")["cum_reward"].std())
Review comment (korbit-ai) on lines +80 to +82: Missing Task Name Error Handling

What is the issue? The evaluate method in HighVarianceBenchmark may fail if a task_name in env_args_list doesn't exist in the result_df from the study.

Why this matters: This will cause a KeyError when trying to access non-existent task names in the select method, potentially crashing the benchmark creation.

Suggested change: add error handling to safely handle missing task names:

    def evaluate(self, study: Study, env_args_list):
        result_df = load_result_df(study.dir)
        std_dict = dict(result_df.groupby("env.task_name")["cum_reward"].std())
        # Return 0 variance for missing tasks to exclude them
        return {task.task_name: std_dict.get(task.task_name, 0) for task in env_args_list}
Review comment (korbit-ai) on lines +81 to +82: Inefficient DataFrame Loading and Processing

What is the issue? Loading and processing the entire result DataFrame for each evaluation is inefficient, especially when dealing with large datasets.

Why this matters: This approach requires loading the complete dataset into memory and performing groupby operations each time evaluate() is called, which can be memory-intensive and slow for large experiment results.

Suggested change: cache the processed results or pass pre-computed statistics to avoid reloading and recomputing. Consider implementing as:

    def __init__(self, exp_dir: Path, threshold: float = 0.2):
        self._cached_stats = None
        super().__init__(exp_dir=exp_dir, threshold=threshold)

    def _compute_stats(self, study: Study):
        if self._cached_stats is None:
            result_df = load_result_df(study.dir)
            self._cached_stats = dict(result_df.groupby("env.task_name")["cum_reward"].std())
        return self._cached_stats

    def evaluate(self, study: Study, env_args_list):
        return self._compute_stats(study)
    def select(self, values, env_args_list):
        selected_env_args = []
        for env_args in env_args_list:
            if values[env_args.task_name] > self.threshold:
                selected_env_args.append(env_args)
        return selected_env_args


@dataclass
class StochasticHighVarianceBenchmark(ResampleBenchmark):
    """Resamples seeds per task, drawing roughly more seeds for tasks with higher reward variance."""

    regulation_threshold: float = 0.1
    total_seeds = 600
    min_seeds = 2
    random_seed = 42

    def evaluate(self, study: Study, env_args_list):
        result_df = load_result_df(study.dir)
        var = result_df.groupby("env.task_name")["cum_reward"].var()
        probs = dict((var + self.regulation_threshold) / (var + self.regulation_threshold).sum())
        return probs

    def select(self, values, env_args_list: list[EnvArgs]):
        selected_env_args = []
        max_steps = env_args_list[0].max_steps
        for task_name, p in values.items():
            # ceil to avoid missing seeds
            n_seeds = np.random.RandomState(self.random_seed).poisson(p * self.total_seeds)
            n_seeds = max(n_seeds, self.min_seeds)
            for seed in np.random.RandomState(self.random_seed).randint(0, 2**32, n_seeds):
                selected_env_args.append(
                    EnvArgs(
                        task_name=task_name,
                        task_seed=int(seed),
                        max_steps=max_steps,
                        headless=True,
                        record_video=False,
                        wait_for_user_message=False,
                        viewport=None,
                        slow_mo=None,
                        storage_state=None,
                        task_kwargs=None,
                    )
                )
        return selected_env_args
if __name__ == "__main__":
    exp_dir = Path(
        "/Users/t.lesellierdechezell/agentlab_results/2025-03-04_14-43-48_genericagent-gpt-4o-mini-2024-07-18-on-miniwob"
    )
    benchmark = StochasticHighVarianceBenchmark(exp_dir=exp_dir)
    print(benchmark.env_args_list)
Review comment: simplify this
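A possible simplification of the EnvArgs construction above, sketched under the assumption that bgym's EnvArgs is a plain dataclass (it is instantiated as one in this file) and that reusing the rest of each task's original settings is acceptable: keep one template EnvArgs per task and override only the seed with dataclasses.replace. The sketch also shares a single RandomState across tasks instead of re-creating it with the same seed inside the loop; it is an illustration, not part of the PR.

    from dataclasses import replace

    def select(self, values, env_args_list: list[EnvArgs]):
        # reuse the original settings (max_steps, headless, ...) from one of each task's EnvArgs
        template = {env_args.task_name: env_args for env_args in env_args_list}
        rng = np.random.RandomState(self.random_seed)
        selected_env_args = []
        for task_name, p in values.items():
            n_seeds = max(rng.poisson(p * self.total_seeds), self.min_seeds)
            for seed in rng.randint(0, 2**32, n_seeds):
                # copy the task's original EnvArgs, changing only the seed
                selected_env_args.append(replace(template[task_name], task_seed=int(seed)))
        return selected_env_args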