From c2b97e4c9ce65b4348a380292e5064452a74d2f2 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Tue, 15 Jul 2025 17:41:21 -0700 Subject: [PATCH 1/3] Add adjust_lr function to support learning rate schedules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add adjust_lr function to calculate learning rate with warmup and cooldown phases - Update DatasetBatch to include total_steps field for LR calculations - Support warmup_length and cooldown_length as either int (steps) or float (ratio) - Simplified design: constant LR by default, linear decay via cooldown_length 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/art/utils/iterate_dataset.py | 61 ++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index dfa0fbf5..71c4b3ab 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -1,7 +1,7 @@ import math import random from dataclasses import dataclass -from typing import List, Generator, TypeVar, Generic +from typing import List, Generator, TypeVar, Generic, Union from tqdm.auto import tqdm T = TypeVar("T") @@ -15,6 +15,59 @@ class DatasetBatch(Generic[T]): step: int epoch: int epoch_step: int + total_steps: int + + +def adjust_lr( + batch: DatasetBatch, + learning_rate: float, + warmup_length: Union[int, float] = 0, + cooldown_length: Union[int, float] = 0, +) -> float: + """ + Calculate the learning rate for a given batch based on the schedule. + + Args: + batch: The DatasetBatch containing step and total_steps information. + learning_rate: The base learning rate. + warmup_length: Either an int (number of steps) or float (ratio of total steps). Defaults to 0. + cooldown_length: Either an int (number of steps) or float (ratio of total steps). Defaults to 0. + + Returns: + The adjusted learning rate for the current batch. 
+ """ + current_step = batch.step + total_steps = batch.total_steps + + # Convert warmup_length to steps if it's a ratio + if isinstance(warmup_length, float): + warmup_steps = int(warmup_length * total_steps) + else: + warmup_steps = warmup_length + + # Convert cooldown_length to steps if it's a ratio + if isinstance(cooldown_length, float): + cooldown_steps = int(cooldown_length * total_steps) + else: + cooldown_steps = cooldown_length + + # Ensure warmup + cooldown don't exceed total steps + warmup_steps = min(warmup_steps, total_steps) + cooldown_steps = min(cooldown_steps, total_steps - warmup_steps) + + # Warmup phase + if current_step < warmup_steps: + return learning_rate * (current_step + 1) / warmup_steps + + # Cooldown phase + cooldown_start = total_steps - cooldown_steps + if current_step >= cooldown_start and cooldown_steps > 0: + steps_into_cooldown = current_step - cooldown_start + remaining_ratio = 1.0 - (steps_into_cooldown + 1) / cooldown_steps + return learning_rate * remaining_ratio + + # Main phase (between warmup and cooldown) + return learning_rate def iterate_dataset( @@ -82,7 +135,11 @@ def iterate_dataset( batch_indices = indices[i : i + groups_per_step] items = [dataset[idx] for idx in batch_indices] yield DatasetBatch( - items=items, epoch=epoch, step=global_step, epoch_step=epoch_step + items=items, + epoch=epoch, + step=global_step, + epoch_step=epoch_step, + total_steps=total_steps, ) # Update progress bar after yielding From 94668812f5d6112721663136cd39765470537107 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Tue, 15 Jul 2025 17:54:06 -0700 Subject: [PATCH 2/3] Support negative cooldown_length to specify cooldown start step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Negative cooldown_length now specifies when cooldown starts (e.g., -20 means start at step 20) - This enables easy linear decay after warmup: warmup_length=20, cooldown_length=-20 - Ensures cooldown always starts after warmup phase completes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- dev/yes_no_maybe_skypilot.py | 92 +++++++++++++++++++++++++++ examples/art-e/all_experiments.py | 8 +++ examples/art-e/art_e/project_types.py | 1 + examples/art-e/art_e/train.py | 5 +- examples/art-e/uv.lock | 18 ++---- src/art/utils/iterate_dataset.py | 26 +++++--- 6 files changed, 125 insertions(+), 25 deletions(-) create mode 100644 dev/yes_no_maybe_skypilot.py diff --git a/dev/yes_no_maybe_skypilot.py b/dev/yes_no_maybe_skypilot.py new file mode 100644 index 00000000..364c0443 --- /dev/null +++ b/dev/yes_no_maybe_skypilot.py @@ -0,0 +1,92 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "openpipe-art[skypilot]==0.4.3", +# "skypilot[runpod]", +# ] +# /// +import art +from art.skypilot import SkyPilotBackend +from dotenv import load_dotenv +import openai +import asyncio + +load_dotenv() + + +async def main(): + backend = await SkyPilotBackend.initialize_cluster( + cluster_name="kyle-yes-no-maybe", + gpu="H100", + ) + + model = art.TrainableModel( + name="001", + project="yes-no-maybe", + base_model="Qwen/Qwen2.5-7B-Instruct", + _internal_config=art.dev.InternalModelConfig( + _decouple_vllm_and_unsloth=True, + engine_args=art.dev.EngineArgs(gpu_memory_utilization=0.7), + ), + ) + await model.register(backend) + + async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory: + messages: art.Messages = [ + { + "role": "user", + "content": prompt, + } + ] + chat_completion = await 
client.chat.completions.create( + messages=messages, model=model.name, max_tokens=100, timeout=100 + ) + choice = chat_completion.choices[0] + content = choice.message.content + assert isinstance(content, str) + if content == "yes": + reward = 0.5 + elif content == "no": + reward = 0.75 + elif content == "maybe": + reward = 1.0 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + def with_quotes(w): + return f"'{w}'" + + prompts = [ + f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" + for prefix in ["respond", "just respond"] + for use_quotes in [True, False] + for words in [ + ["yes", "no", "maybe"], + ["maybe", "yes", "no"], + ["no", "yes", "maybe"], + ["yes", "maybe", "no"], + ["yes", "no"], + ["maybe", "no"], + ["no", "maybe"], + ["no", "yes"], + ["yes", "no"], + ] + ] + + openai_client = model.openai_client() + for _ in range(await model.get_step(), 1_000): + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup(rollout(openai_client, prompt) for _ in range(32)) + for prompt in prompts + ), + ) + await model.train( + train_groups, + config=art.TrainConfig(learning_rate=1e-4), + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/art-e/all_experiments.py b/examples/art-e/all_experiments.py index c00729e8..c7ca4fa0 100644 --- a/examples/art-e/all_experiments.py +++ b/examples/art-e/all_experiments.py @@ -205,3 +205,11 @@ models["225"] = models["224"].model_copy(deep=True) models["225"].name = "email-agent-225" + +models["226"] = models["008"].model_copy(deep=True) +models["226"].name = "email-agent-226" +models["226"].config.precalculate_logprobs = True + +models["227"] = models["008"].model_copy(deep=True) +models["227"].name = "email-agent-227" +models["220"].base_model = "willcb/Qwen3-14B" diff --git a/examples/art-e/art_e/project_types.py b/examples/art-e/art_e/project_types.py index 97dd16be..2c834451 100644 --- a/examples/art-e/art_e/project_types.py +++ b/examples/art-e/art_e/project_types.py @@ -28,3 +28,4 @@ class ProjectPolicyConfig(BaseModel): # choose its own default (e.g., derive from the current time). 
training_dataset_seed: int | None = None messages_only: bool = False + precalculate_logprobs: bool = False diff --git a/examples/art-e/art_e/train.py b/examples/art-e/art_e/train.py index 4df3089a..41d12225 100644 --- a/examples/art-e/art_e/train.py +++ b/examples/art-e/art_e/train.py @@ -143,9 +143,8 @@ async def judge_after_each( groups, config=art.TrainConfig(learning_rate=model.config.learning_rate), _config=art.dev.TrainConfig( - allow_training_without_logprobs=True - if model.config.messages_only - else False + allow_training_without_logprobs=model.config.messages_only, + precalculate_logprobs=model.config.precalculate_logprobs, ), ) diff --git a/examples/art-e/uv.lock b/examples/art-e/uv.lock index 5a74a901..d43f8807 100644 --- a/examples/art-e/uv.lock +++ b/examples/art-e/uv.lock @@ -3198,7 +3198,7 @@ wheels = [ [[package]] name = "openpipe-art" -version = "0.4.0" +version = "0.4.2" source = { editable = "../../" } dependencies = [ { name = "litellm" }, @@ -3214,7 +3214,6 @@ backend = [ { name = "hf-xet" }, { name = "peft" }, { name = "polars" }, - { name = "semver" }, { name = "setproctitle" }, { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, { name = "setuptools", version = "80.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, @@ -3246,9 +3245,10 @@ requires-dist = [ { name = "peft", marker = "extra == 'backend'", specifier = ">=0.14.0" }, { name = "polars", marker = "extra == 'backend'", specifier = ">=1.26.0" }, { name = "seaborn", marker = "extra == 'plotting'", specifier = ">=0.13.2" }, - { name = "semver", marker = "extra == 'backend'", specifier = ">=3.0.4" }, + { name = "semver", marker = "extra == 'skypilot'", specifier = ">=3.0.4" }, { name = "setproctitle", marker = "extra == 'backend'", specifier = ">=1.3.6" }, { name = "setuptools", marker = "extra == 'backend'", specifier = ">=78.1.0" }, + { name = "skypilot", marker = "extra == 'skypilot'", specifier = "==0.9.3" }, { name = "tblib", marker = "extra == 'backend'", specifier = ">=3.0.0" }, { name = "torch", marker = "extra == 'backend'", specifier = ">=2.7.0" }, { name = "torchao", marker = "extra == 'backend'", specifier = ">=0.9.0" }, @@ -3261,7 +3261,7 @@ requires-dist = [ { name = "wandb", marker = "extra == 'backend'", specifier = ">=0.19.8" }, { name = "weave", marker = "extra == 'backend'", specifier = ">=0.51.51" }, ] -provides-extras = ["plotting", "backend"] +provides-extras = ["plotting", "backend", "skypilot"] [package.metadata.requires-dev] dev = [ @@ -3271,7 +3271,6 @@ dev = [ { name = "ipywidgets", specifier = ">=8.1.5" }, { name = "openpipe", specifier = ">=4.49.0" }, { name = "ruff", specifier = ">=0.12.1" }, - { name = "skypilot", extras = ["cudo", "do", "fluidstack", "gcp", "lambda", "paperspace", "runpod"], specifier = "==0.8.0" }, ] [[package]] @@ -5127,15 +5126,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" }, ] -[[package]] -name = "semver" -version = "3.0.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/d1/d3159231aec234a59dd7d601e9dd9fe96f3afff15efd33c1070019b26132/semver-3.0.4.tar.gz", hash = 
"sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602", size = 269730, upload-time = "2025-01-24T13:19:27.617Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a6/24/4d91e05817e92e3a61c8a21e08fd0f390f5301f1c448b137c57c4bc6e543/semver-3.0.4-py3-none-any.whl", hash = "sha256:9c824d87ba7f7ab4a1890799cec8596f15c1241cb473404ea1cb0c55e4b04746", size = 17912, upload-time = "2025-01-24T13:19:24.949Z" }, -] - [[package]] name = "sentencepiece" version = "0.2.0" diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index 71c4b3ab..f104dd76 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -32,6 +32,7 @@ def adjust_lr( learning_rate: The base learning rate. warmup_length: Either an int (number of steps) or float (ratio of total steps). Defaults to 0. cooldown_length: Either an int (number of steps) or float (ratio of total steps). Defaults to 0. + If negative, specifies the step at which cooldown starts (e.g., -20 means cooldown starts at step 20). Returns: The adjusted learning rate for the current batch. @@ -45,22 +46,31 @@ def adjust_lr( else: warmup_steps = warmup_length - # Convert cooldown_length to steps if it's a ratio - if isinstance(cooldown_length, float): - cooldown_steps = int(cooldown_length * total_steps) + # Handle cooldown_length + if cooldown_length < 0: + # Negative value means cooldown starts at that step + cooldown_start = int(-cooldown_length) + cooldown_steps = total_steps - cooldown_start else: - cooldown_steps = cooldown_length - - # Ensure warmup + cooldown don't exceed total steps + # Convert cooldown_length to steps if it's a ratio + if isinstance(cooldown_length, float): + cooldown_steps = int(cooldown_length * total_steps) + else: + cooldown_steps = cooldown_length + cooldown_start = total_steps - cooldown_steps + + # Ensure warmup doesn't exceed total steps warmup_steps = min(warmup_steps, total_steps) - cooldown_steps = min(cooldown_steps, total_steps - warmup_steps) + + # Ensure cooldown_start is after warmup + cooldown_start = max(cooldown_start, warmup_steps) + cooldown_steps = total_steps - cooldown_start # Warmup phase if current_step < warmup_steps: return learning_rate * (current_step + 1) / warmup_steps # Cooldown phase - cooldown_start = total_steps - cooldown_steps if current_step >= cooldown_start and cooldown_steps > 0: steps_into_cooldown = current_step - cooldown_start remaining_ratio = 1.0 - (steps_into_cooldown + 1) / cooldown_steps From 4c4ed547d5ceb0366246071d352cc75a2d150ddc Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Tue, 15 Jul 2025 17:59:20 -0700 Subject: [PATCH 3/3] Add experiment 228 with learning rate warmup and cooldown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add warmup_length and cooldown_length fields to ProjectPolicyConfig - Update train.py to use adjust_lr function with batch-specific learning rates - Add experiment 228 with 20-step warmup and cooldown starting at step 20 This experiment will test whether warmup/cooldown improves training compared to our baseline constant learning rate approach. 
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- examples/art-e/all_experiments.py | 7 ++++++- examples/art-e/art_e/project_types.py | 4 +++- examples/art-e/art_e/train.py | 12 ++++++++++-- src/art/utils/__init__.py | 3 ++- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/examples/art-e/all_experiments.py b/examples/art-e/all_experiments.py index c7ca4fa0..78f49b92 100644 --- a/examples/art-e/all_experiments.py +++ b/examples/art-e/all_experiments.py @@ -212,4 +212,9 @@ models["227"] = models["008"].model_copy(deep=True) models["227"].name = "email-agent-227" -models["220"].base_model = "willcb/Qwen3-14B" +models["227"].base_model = "willcb/Qwen3-14B" + +models["228"] = models["008"].model_copy(deep=True) +models["228"].name = "email-agent-228" +models["228"].config.warmup_length = 20 +models["228"].config.cooldown_length = -20 diff --git a/examples/art-e/art_e/project_types.py b/examples/art-e/art_e/project_types.py index 2c834451..20183b6b 100644 --- a/examples/art-e/art_e/project_types.py +++ b/examples/art-e/art_e/project_types.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from typing import Literal +from typing import Literal, Union class ProjectPolicyConfig(BaseModel): @@ -13,6 +13,8 @@ class ProjectPolicyConfig(BaseModel): trajectories_per_group: int = 6 groups_per_step: int = 1 learning_rate: float = 1.2e-5 + warmup_length: Union[int, float] = 0 + cooldown_length: Union[int, float] = 0 eval_steps: int = 30 val_set_size: int = 100 training_dataset_size: int = 4000 diff --git a/examples/art-e/art_e/train.py b/examples/art-e/art_e/train.py index 41d12225..874b3630 100644 --- a/examples/art-e/art_e/train.py +++ b/examples/art-e/art_e/train.py @@ -7,7 +7,7 @@ from art_e.data.query_iterators import load_synthetic_queries from art_e.data.types_enron import SyntheticQuery from art_e.data.local_email_db import generate_database -from art.utils import iterate_dataset +from art.utils import iterate_dataset, adjust_lr from art_e.project_types import ProjectPolicyConfig from art_e.evaluate.benchmark import benchmark_model import os @@ -139,9 +139,17 @@ async def judge_after_each( ) continue # Proceed to next batch/epoch without training. + # Calculate learning rate for this batch + current_lr = adjust_lr( + batch, + learning_rate=model.config.learning_rate, + warmup_length=model.config.warmup_length, + cooldown_length=model.config.cooldown_length, + ) + await model.train( groups, - config=art.TrainConfig(learning_rate=model.config.learning_rate), + config=art.TrainConfig(learning_rate=current_lr), _config=art.dev.TrainConfig( allow_training_without_logprobs=model.config.messages_only, precalculate_logprobs=model.config.precalculate_logprobs, diff --git a/src/art/utils/__init__.py b/src/art/utils/__init__.py index 12bd3995..cb372d7d 100644 --- a/src/art/utils/__init__.py +++ b/src/art/utils/__init__.py @@ -1,7 +1,7 @@ # Import all utilities to maintain the same interface from .format_message import format_message from .retry import retry -from .iterate_dataset import iterate_dataset +from .iterate_dataset import iterate_dataset, adjust_lr from .limit_concurrency import limit_concurrency from .log_http_errors import log_http_errors from .get_model_step import get_model_step @@ -10,6 +10,7 @@ "format_message", "retry", "iterate_dataset", + "adjust_lr", "limit_concurrency", "log_http_errors", "get_model_step",
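
Note (not part of the patch series): below is a minimal, self-contained sketch of how the adjust_lr schedule introduced in these patches behaves. The Batch dataclass is a stand-in for art's DatasetBatch with only the two fields adjust_lr reads, and the function body mirrors the final version from patch 2. The total_steps and base learning rate in the demo are illustrative values; warmup_length=20, cooldown_length=-20 matches experiment 228. In the real training loop (patch 3) you would instead import adjust_lr from art.utils and pass each DatasetBatch yielded by iterate_dataset.

from dataclasses import dataclass
from typing import Union


@dataclass
class Batch:
    # Stand-in for art's DatasetBatch: only the fields adjust_lr reads.
    step: int
    total_steps: int


def adjust_lr(
    batch: Batch,
    learning_rate: float,
    warmup_length: Union[int, float] = 0,
    cooldown_length: Union[int, float] = 0,
) -> float:
    # Mirrors the schedule from patch 2: linear warmup, constant middle,
    # linear cooldown; a negative cooldown_length gives the cooldown start step.
    current_step, total_steps = batch.step, batch.total_steps

    warmup_steps = (
        int(warmup_length * total_steps)
        if isinstance(warmup_length, float)
        else warmup_length
    )

    if cooldown_length < 0:
        # Negative value: cooldown starts at that step (e.g. -20 -> step 20).
        cooldown_start = int(-cooldown_length)
    else:
        cooldown_steps = (
            int(cooldown_length * total_steps)
            if isinstance(cooldown_length, float)
            else cooldown_length
        )
        cooldown_start = total_steps - cooldown_steps

    warmup_steps = min(warmup_steps, total_steps)
    cooldown_start = max(cooldown_start, warmup_steps)  # cooldown never overlaps warmup
    cooldown_steps = total_steps - cooldown_start

    if current_step < warmup_steps:
        # Linear ramp from learning_rate / warmup_steps up to learning_rate.
        return learning_rate * (current_step + 1) / warmup_steps
    if current_step >= cooldown_start and cooldown_steps > 0:
        # Linear decay from learning_rate down to 0 at the final step.
        steps_into_cooldown = current_step - cooldown_start
        return learning_rate * (1.0 - (steps_into_cooldown + 1) / cooldown_steps)
    return learning_rate


if __name__ == "__main__":
    # Experiment 228's settings: 20-step warmup, decay starting at step 20.
    total_steps, base_lr = 100, 1.2e-5  # illustrative values, not from the patches
    for step in (0, 10, 19, 20, 60, 99):
        lr = adjust_lr(
            Batch(step, total_steps), base_lr, warmup_length=20, cooldown_length=-20
        )
        print(f"step {step:3d}: lr = {lr:.2e}")

With warmup_length=20 and cooldown_length=-20 the constant middle phase disappears entirely: the rate ramps up over the first 20 steps and then decays linearly toward zero over the remaining steps, which is the shape experiment 228 tests against the constant-LR baseline.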