Add adjust_lr function for learning rate schedules #247

Draft · wants to merge 3 commits into base: main
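This PR adds an `adjust_lr` helper for linear warmup/cooldown learning rate schedules: `DatasetBatch` gains a `total_steps` field so a schedule knows how far along training is, `adjust_lr` is re-exported from `art.utils`, `ProjectPolicyConfig` gains `warmup_length`, `cooldown_length`, and `precalculate_logprobs` options, and the `examples/art-e` training loop now passes the scheduled rate to `model.train`. A minimal sketch of how the helper is meant to be used, with illustrative step counts and learning rate (in the art-e example these values come from `ProjectPolicyConfig`):

```python
from art.utils import adjust_lr
from art.utils.iterate_dataset import DatasetBatch

# Illustrative numbers only; in the art-e example these come from ProjectPolicyConfig.
base_lr = 1.2e-5
for step in (0, 5, 50, 95, 99):
    batch = DatasetBatch(items=[], epoch=0, step=step, epoch_step=step, total_steps=100)
    lr = adjust_lr(
        batch,
        learning_rate=base_lr,
        warmup_length=10,     # int = number of steps; a float means a fraction of total_steps
        cooldown_length=0.1,  # last 10% of steps; a negative int means "start cooldown at this step"
    )
    print(step, lr)  # linear warmup over 10 steps, constant in the middle, linear decay over the last 10
```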
92 changes: 92 additions & 0 deletions dev/yes_no_maybe_skypilot.py
@@ -0,0 +1,92 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "openpipe-art[skypilot]==0.4.3",
# "skypilot[runpod]",
# ]
# ///
import art
from art.skypilot import SkyPilotBackend
from dotenv import load_dotenv
import openai
import asyncio

load_dotenv()


async def main():
    backend = await SkyPilotBackend.initialize_cluster(
        cluster_name="kyle-yes-no-maybe",
        gpu="H100",
    )

    model = art.TrainableModel(
        name="001",
        project="yes-no-maybe",
        base_model="Qwen/Qwen2.5-7B-Instruct",
        _internal_config=art.dev.InternalModelConfig(
            _decouple_vllm_and_unsloth=True,
            engine_args=art.dev.EngineArgs(gpu_memory_utilization=0.7),
        ),
    )
    await model.register(backend)

    async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory:
        messages: art.Messages = [
            {
                "role": "user",
                "content": prompt,
            }
        ]
        chat_completion = await client.chat.completions.create(
            messages=messages, model=model.name, max_tokens=100, timeout=100
        )
        choice = chat_completion.choices[0]
        content = choice.message.content
        assert isinstance(content, str)
        if content == "yes":
            reward = 0.5
        elif content == "no":
            reward = 0.75
        elif content == "maybe":
            reward = 1.0
        else:
            reward = 0.0
        return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)

    def with_quotes(w):
        return f"'{w}'"

    # Prompt variants such as "respond with 'yes', 'no', 'maybe'" or
    # "just respond with no or yes" (quoting only applies to the three-word variants).
    prompts = [
        f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in [
            ["yes", "no", "maybe"],
            ["maybe", "yes", "no"],
            ["no", "yes", "maybe"],
            ["yes", "maybe", "no"],
            ["yes", "no"],
            ["maybe", "no"],
            ["no", "maybe"],
            ["no", "yes"],
            ["yes", "no"],
        ]
    ]

    openai_client = model.openai_client()
    for _ in range(await model.get_step(), 1_000):
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(rollout(openai_client, prompt) for _ in range(32))
                for prompt in prompts
            ),
        )
        await model.train(
            train_groups,
            config=art.TrainConfig(learning_rate=1e-4),
        )


if __name__ == "__main__":
    asyncio.run(main())
13 changes: 13 additions & 0 deletions examples/art-e/all_experiments.py
@@ -205,3 +205,16 @@

models["225"] = models["224"].model_copy(deep=True)
models["225"].name = "email-agent-225"

models["226"] = models["008"].model_copy(deep=True)
models["226"].name = "email-agent-226"
models["226"].config.precalculate_logprobs = True

models["227"] = models["008"].model_copy(deep=True)
models["227"].name = "email-agent-227"
models["227"].base_model = "willcb/Qwen3-14B"

models["228"] = models["008"].model_copy(deep=True)
models["228"].name = "email-agent-228"
models["228"].config.warmup_length = 20
models["228"].config.cooldown_length = -20
5 changes: 4 additions & 1 deletion examples/art-e/art_e/project_types.py
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from typing import Literal
from typing import Literal, Union


class ProjectPolicyConfig(BaseModel):
@@ -13,6 +13,8 @@ class ProjectPolicyConfig(BaseModel):
    trajectories_per_group: int = 6
    groups_per_step: int = 1
    learning_rate: float = 1.2e-5
    warmup_length: Union[int, float] = 0
    cooldown_length: Union[int, float] = 0
    eval_steps: int = 30
    val_set_size: int = 100
    training_dataset_size: int = 4000
@@ -28,3 +30,4 @@ class ProjectPolicyConfig(BaseModel):
    # choose its own default (e.g., derive from the current time).
    training_dataset_seed: int | None = None
    messages_only: bool = False
    precalculate_logprobs: bool = False
17 changes: 12 additions & 5 deletions examples/art-e/art_e/train.py
@@ -7,7 +7,7 @@
from art_e.data.query_iterators import load_synthetic_queries
from art_e.data.types_enron import SyntheticQuery
from art_e.data.local_email_db import generate_database
from art.utils import iterate_dataset
from art.utils import iterate_dataset, adjust_lr
from art_e.project_types import ProjectPolicyConfig
from art_e.evaluate.benchmark import benchmark_model
import os
@@ -139,13 +139,20 @@ async def judge_after_each(
            )
            continue  # Proceed to next batch/epoch without training.

        # Calculate learning rate for this batch
        current_lr = adjust_lr(
            batch,
            learning_rate=model.config.learning_rate,
            warmup_length=model.config.warmup_length,
            cooldown_length=model.config.cooldown_length,
        )

        await model.train(
            groups,
            config=art.TrainConfig(learning_rate=model.config.learning_rate),
            config=art.TrainConfig(learning_rate=current_lr),
            _config=art.dev.TrainConfig(
                allow_training_without_logprobs=True
                if model.config.messages_only
                else False
                allow_training_without_logprobs=model.config.messages_only,
                precalculate_logprobs=model.config.precalculate_logprobs,
            ),
        )

18 changes: 4 additions & 14 deletions examples/art-e/uv.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion src/art/utils/__init__.py
@@ -1,7 +1,7 @@
# Import all utilities to maintain the same interface
from .format_message import format_message
from .retry import retry
from .iterate_dataset import iterate_dataset
from .iterate_dataset import iterate_dataset, adjust_lr
from .limit_concurrency import limit_concurrency
from .log_http_errors import log_http_errors
from .get_model_step import get_model_step
@@ -10,6 +10,7 @@
    "format_message",
    "retry",
    "iterate_dataset",
    "adjust_lr",
    "limit_concurrency",
    "log_http_errors",
    "get_model_step",
71 changes: 69 additions & 2 deletions src/art/utils/iterate_dataset.py
@@ -1,7 +1,7 @@
import math
import random
from dataclasses import dataclass
from typing import List, Generator, TypeVar, Generic
from typing import List, Generator, TypeVar, Generic, Union
from tqdm.auto import tqdm

T = TypeVar("T")
@@ -15,6 +15,69 @@ class DatasetBatch(Generic[T]):
    step: int
    epoch: int
    epoch_step: int
    total_steps: int


def adjust_lr(
    batch: DatasetBatch,
    learning_rate: float,
    warmup_length: Union[int, float] = 0,
    cooldown_length: Union[int, float] = 0,
) -> float:
    """
    Calculate the learning rate for a given batch based on the schedule.

    Args:
        batch: The DatasetBatch containing step and total_steps information.
        learning_rate: The base learning rate.
        warmup_length: Either an int (number of steps) or float (ratio of total steps). Defaults to 0.
        cooldown_length: Either an int (number of steps) or float (ratio of total steps). Defaults to 0.
            If negative, specifies the step at which cooldown starts (e.g., -20 means cooldown starts at step 20).

    Returns:
        The adjusted learning rate for the current batch.
    """
    current_step = batch.step
    total_steps = batch.total_steps

    # Convert warmup_length to steps if it's a ratio
    if isinstance(warmup_length, float):
        warmup_steps = int(warmup_length * total_steps)
    else:
        warmup_steps = warmup_length

    # Handle cooldown_length
    if cooldown_length < 0:
        # Negative value means cooldown starts at that step
        cooldown_start = int(-cooldown_length)
        cooldown_steps = total_steps - cooldown_start
    else:
        # Convert cooldown_length to steps if it's a ratio
        if isinstance(cooldown_length, float):
            cooldown_steps = int(cooldown_length * total_steps)
        else:
            cooldown_steps = cooldown_length
        cooldown_start = total_steps - cooldown_steps

    # Ensure warmup doesn't exceed total steps
    warmup_steps = min(warmup_steps, total_steps)

    # Ensure cooldown_start is after warmup
    cooldown_start = max(cooldown_start, warmup_steps)
    cooldown_steps = total_steps - cooldown_start

    # Warmup phase
    if current_step < warmup_steps:
        return learning_rate * (current_step + 1) / warmup_steps

    # Cooldown phase
    if current_step >= cooldown_start and cooldown_steps > 0:
        steps_into_cooldown = current_step - cooldown_start
        remaining_ratio = 1.0 - (steps_into_cooldown + 1) / cooldown_steps
        return learning_rate * remaining_ratio

    # Main phase (between warmup and cooldown)
    return learning_rate


def iterate_dataset(
@@ -82,7 +82,11 @@ def iterate_dataset(
            batch_indices = indices[i : i + groups_per_step]
            items = [dataset[idx] for idx in batch_indices]
            yield DatasetBatch(
                items=items, epoch=epoch, step=global_step, epoch_step=epoch_step
                items=items,
                epoch=epoch,
                step=global_step,
                epoch_step=epoch_step,
                total_steps=total_steps,
            )

            # Update progress bar after yielding
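One detail worth noting in the implementation above: if the requested warmup and cooldown windows would overlap, warmup takes precedence and the cooldown window shrinks to whatever remains (the `max(cooldown_start, warmup_steps)` clamp). A small sketch with hypothetical numbers showing that behavior:

```python
from art.utils import adjust_lr
from art.utils.iterate_dataset import DatasetBatch

# 80% warmup and 50% cooldown would overlap; cooldown is pushed back to start at step 80.
for step in (79, 80, 90, 99):
    batch = DatasetBatch(items=[], epoch=0, step=step, epoch_step=step, total_steps=100)
    print(step, adjust_lr(batch, learning_rate=1e-4, warmup_length=0.8, cooldown_length=0.5))
# Step 79: end of warmup at the full 1e-4; steps 80-99: linear decay over the remaining 20 steps.
```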