feat: Add checkpoint forking functionality #253

Merged · 6 commits · Jul 17, 2025

Changes from all commits
16 changes: 16 additions & 0 deletions CLAUDE.md
@@ -13,3 +13,19 @@ This project uses the `uv` package manager.
## Releases

- If asked to help with a release, refer to the checklist in CONTRIBUTING.md

## Documentation

- All documentation is in the `docs` directory.
- If you add a new page, be sure to add it to the sidebar in `docs/docs.json`.
- If you move a page, be sure to update the sidebar in `docs/docs.json` and check for any broken links.

### Adding images

- Add images to the `docs/images` directory
- If the image is a png, first convert it to webp using `magick <input.png> <output.webp>`. Do not include the original png in the repo.
- Use the `<Frame>` tag to add images with captions as seen in the page `checkpoint-forking.mdx`.

### Adding notes

- Add notes using the `<Note>` tag as seen in the page `ruler.mdx`
7 changes: 7 additions & 0 deletions docs/docs.json
@@ -60,6 +60,13 @@
"fundamentals/ruler"
]
},
{
"group": "Features",
"pages": [
"features/checkpoint-forking",
"features/additional-histories"
]
},
{
"group": "Tutorials",
"pages": ["tutorials/summarizer"]
132 changes: 132 additions & 0 deletions docs/features/checkpoint-forking.mdx
@@ -0,0 +1,132 @@
---
title: Checkpoint Forking
description: Learn how to fork training from existing model checkpoints
---

# Checkpoint Forking

<Frame caption="Run 206 had a catastrophic failure. We fixed it by forking into run 230 before the point of collapse.">
<img
src="/images/forked-run.webp"
alt="Checkpoint forking example"
style={{ maxWidth: "100%", height: "auto" }}
/>
</Frame>

Checkpoint forking allows you to create a new training run that starts from an existing model's checkpoint. This is particularly useful when:

- Training has gone off track and you want to restart from a known good checkpoint
- You want to experiment with different hyperparameters from a specific point
- You need to branch off multiple experiments from the same checkpoint

<Note>
This feature is marked as experimental because we're still refining the API shape. However, the core functionality will remain stable.
</Note>

## Basic Usage

The simplest way to fork a checkpoint is to specify it when creating your model:

```python
import art
from art.local import LocalBackend

async def train():
with LocalBackend() as backend:
# Create a new model that will fork from an existing checkpoint
model = art.TrainableModel(
name="my-model-v2",
project="my-project",
base_model="Qwen/Qwen2.5-14B-Instruct",
)

# Copy the checkpoint from another model
await backend._experimental_fork_checkpoint(
model,
from_model="my-model-v1",
not_after_step=500, # Use checkpoint at or before step 500
verbose=True,
)

# Register and continue training
await model.register(backend)
# ... rest of training code
```

## Forking from S3

If your checkpoints are stored in S3, you can fork directly from there:

```python
await backend._experimental_fork_checkpoint(
model,
from_model="my-model-v1",
from_s3_bucket="my-backup-bucket",
not_after_step=500,
verbose=True,
)
```

## Parameters

### `from_model` (required)
The name of the model to fork from.

### `from_project` (optional)
The project containing the model to fork from. Defaults to the current model's project.

### `from_s3_bucket` (optional)
S3 bucket to pull the checkpoint from. If not provided, the checkpoint is looked up locally.

### `not_after_step` (optional)
The maximum step number to use. The function selects the latest checkpoint whose step is less than or equal to this value. If not provided, the latest available checkpoint is used.

### `verbose` (optional)
Whether to print detailed progress information during the forking process.
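
Taken together, a call that uses every parameter might look like the sketch below. The model, project, and bucket names are placeholders, not values from this repository:

```python
# All names below are hypothetical placeholders.
await backend._experimental_fork_checkpoint(
    model,
    from_model="my-model-v1",           # required: model to fork from
    from_project="my-other-project",    # optional: defaults to the current model's project
    from_s3_bucket="my-backup-bucket",  # optional: pull the checkpoint from S3 instead of local disk
    not_after_step=500,                 # optional: latest checkpoint with step <= 500
    verbose=True,                       # optional: print progress while forking
)
```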

## How It Works

1. **Checkpoint Selection**: The system finds the appropriate checkpoint based on your `not_after_step` parameter
2. **S3 Pull** (if needed): If forking from S3, only the specific checkpoint is downloaded, not the entire model history
3. **Checkpoint Copy**: The checkpoint is copied to your new model's directory at the same step number
4. **Training Continuation**: Your model can now continue training from this checkpoint
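
To make step 1 concrete, the selection rule picks the highest saved step that does not exceed `not_after_step`. The snippet below is a simplified sketch of that rule, not ART's actual implementation:

```python
# Simplified sketch of checkpoint selection; not the library's actual code.
available_steps = [100, 200, 500, 750]  # steps that have saved checkpoints

def select_step(not_after_step=None):
    if not_after_step is None:
        return max(available_steps)  # no limit: use the latest checkpoint
    eligible = [s for s in available_steps if s <= not_after_step]
    return max(eligible) if eligible else None

assert select_step(500) == 500  # `<=` comparison, so step 500 itself is included
assert select_step(499) == 200
assert select_step() == 750
```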

## Example: Lowering the Learning Rate

Here's a practical example of using checkpoint forking to test a lower learning rate:

```python
# Original model trained with lr=1e-5
base_model = art.TrainableModel(
name="summarizer-base",
project="experiments",
base_model="Qwen/Qwen2.5-14B-Instruct",
)

# Fork at step 1000 to try lower learning rate
low_lr_model = art.TrainableModel(
name="summarizer-low-lr",
project="experiments",
base_model="Qwen/Qwen2.5-14B-Instruct",
)

async def experiment():
with LocalBackend() as backend:
# Fork the model from the base model
await backend._experimental_fork_checkpoint(
low_lr_model,
from_model="summarizer-base",
not_after_step=1000,
verbose=True,
)
        await low_lr_model.register(backend)

# Now train with a lower learning rate
# ... training code with different configs
```

## Notes

- Checkpoints are forked at the same step number they had in the source model
- The `not_after_step` parameter uses `<=` comparison, so specifying 500 will include step 500 if it exists
- Only checkpoint files are copied - training logs and trajectories are not included in the fork
13 changes: 2 additions & 11 deletions docs/fundamentals/ruler.mdx
@@ -8,22 +8,13 @@ description: "Learn how to use RULER to automatically reward your agents."

RULER (Relative Universal LLM-Elicited Rewards) is a general-purpose reward function that uses an LLM-as-judge to rank multiple agent trajectories. It requires no labeled data, expert feedback, or hand-crafted reward functions, yet reliably improves agent performance.

<div align="center">
<Frame caption="RULER performance across multiple tasks at launch. In 3 out of 4 tasks, models trained with RULER slightly outperform those trained with hand-crafted reward functions. See the full launch announcement for details.">
<img
src="/images/ruler-results.png"
alt="RULER Performance Results"
style={{ maxWidth: "100%", height: "auto" }}
/>
<p>
<em>
RULER performance across multiple tasks at launch. In 3 out of 4 tasks,
models trained with RULER slightly outperform those trained with
hand-crafted reward functions. See the full{" "}
<a href="https://openpipe.ai/blog/ruler">launch announcement</a> for
details.
</em>
</p>
</div>
</Frame>

## Key Benefits

Binary file added docs/images/forked-run.webp
2 changes: 1 addition & 1 deletion docs/resources/glossary.mdx
@@ -6,7 +6,7 @@

## Additional Histories

A feature that allows a trajectory to contain multiple separate conversation histories. Used for training agents with non-linear conversation flows, preserving special tokens across turns, or handling sub-agent interactions. See [Additional Histories](/fundamentals/additional-histories) for details.
A feature that allows a trajectory to contain multiple separate conversation histories. Used for training agents with non-linear conversation flows, preserving special tokens across turns, or handling sub-agent interactions. See [Additional Histories](/features/additional-histories) for details.

## Agent

2 changes: 1 addition & 1 deletion docs/resources/models.mdx
@@ -25,6 +25,6 @@ Here are additional models that we've tested and found to work well with ART:
- [Llama 3.2 3B Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)
- [Llama 3.3 70B Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
- [Qwen 2.5 72B Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)
- Additionally, the [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) family of models is well supported for single-turn workflows. For multi-turn workflows the Qwen 3 chat template removes the `<think>` tokens from previous turns, which makes training more complicated. It is still possible to use for multi-turn workflows by splitting each turn into a separate message history with our `additional_histories` trajectory parameter (see [Additional Histories](/fundamentals/additional-histories)).
- Additionally, the [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) family of models is well supported for single-turn workflows. For multi-turn workflows the Qwen 3 chat template removes the `<think>` tokens from previous turns, which makes training more complicated. It is still possible to use for multi-turn workflows by splitting each turn into a separate message history with our `additional_histories` trajectory parameter (see [Additional Histories](/features/additional-histories)).

If you're curious about a model that is not listed above, ask in the Discord [#support](https://discord.com/channels/1359674493949448375/1359674622965973185) channel.
48 changes: 48 additions & 0 deletions examples/art-e/CLAUDE.md
@@ -0,0 +1,48 @@
## Adding New Models to the Email Agent

When adding a new model to the email agent project, follow these steps:

1. **Add the model definition to `all_experiments.py`**:

```python
# Model XXX: Description of what makes this model unique
models["XXX"] = models["BASE_MODEL_ID"].model_copy(deep=True)
models["XXX"].name = "email-agent-XXX"
# Add any custom configuration here
```

2. **If you need a new configuration option**:

a. Add it to `ProjectPolicyConfig` in `project_types.py`:

```python
class ProjectPolicyConfig(BaseModel):
# ... existing fields ...
my_new_option: bool = True # Add description and default value
```

b. If it affects training, update `train.py` to pass it to the training function:

```python
await model.train(
groups,
config=art.TrainConfig(learning_rate=model.config.learning_rate),
_config=art.dev.TrainConfig(
# ... existing parameters ...
my_new_option=model.config.my_new_option,
),
)
```

c. If it affects rollouts, update `rollout.py` to use the new option (see the sketch after this list).

3. **Common model variations**:
- Base model change: `models["XXX"].base_model = "new-base-model"`
- Learning rate: `models["XXX"].config.learning_rate = 1e-5`
- Training epochs: `models["XXX"].config.num_epochs = 3`
- Judge model: `models["XXX"].config.group_judge_model = "model-name"`
- Fork from checkpoint:
```python
models["XXX"].config.fork_from_model = "email-agent-YYY"
models["XXX"].config.fork_not_after_step = 90
```
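
For step 2c above, a minimal sketch of how `rollout.py` might read a new option is shown below; the function signature and the `my_new_option` field are illustrative assumptions, not the project's actual code:

```python
# Hypothetical sketch: consuming a new ProjectPolicyConfig option in rollout.py.
async def rollout(model, scenario):
    if model.config.my_new_option:  # hypothetical option added in step 2a
        # Branch rollout behavior when the option is enabled.
        ...
    # ... rest of the rollout logic
```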
43 changes: 24 additions & 19 deletions examples/art-e/all_experiments.py
@@ -61,29 +61,17 @@
models["201"] = models["008"].model_copy(deep=True)
models["201"].name = "email-agent-201"

# Model 202: like 008 but with judge-group rescoring during training
models["202"] = models["008"].model_copy(deep=True)
models["202"].name = "email-agent-202"
# Enable the new flag
models["202"].config.use_judge_group_variant = "v1"

# Model 204: like 202 but with judge-group rescoring variant v2
models["204"] = models["202"].model_copy(deep=True)
models["204"].name = "email-agent-204"
# Enable the v2 flag
models["204"].config.use_judge_group_variant = "v2"

# Model 205: like 204 but using Gemini 2.5 Flash as the judge-group model
models["205"] = models["204"].model_copy(deep=True)
# Model 205: like 201 but using Gemini 2.5 Flash as the judge-group model
models["205"] = models["201"].model_copy(deep=True)
models["205"].name = "email-agent-205"
# Set the judge group model
models["205"].config.group_judge_model = "gemini/gemini-2.5-flash"
models["205"].config.ruler_judge_model = "gemini/gemini-2.5-flash"

# Model 206: like 204 but using Qwen3 32B as the judge-group model
models["206"] = models["204"].model_copy(deep=True)
models["206"].name = "email-agent-206"
# Set the judge group model
models["206"].config.group_judge_model = "openrouter/qwen/qwen3-32b"
models["206"].config.ruler_judge_model = "openrouter/qwen/qwen3-32b"

# Model 207: like 205 but only uses 12 training examples total
models["207"] = models["205"].model_copy(deep=True)
@@ -133,7 +121,7 @@

models["213"] = models["206"].model_copy(deep=True)
models["213"].name = "email-agent-213"
models["213"].config.group_judge_model = "openai/o3"
models["213"].config.ruler_judge_model = "openai/o3"

models["215"] = models["008"].model_copy(deep=True)
models["215"].name = "email-agent-215"
@@ -150,7 +138,7 @@
models["218"] = models["206"].model_copy(deep=True)
models["218"].name = "email-agent-218-5"
models["218"].base_model = "Qwen/Qwen3-32B"
models["218"].config.group_judge_model = "base_model"
models["218"].config.ruler_judge_model = "base_model"
models["218"].config.include_qwen3_nothink = True

# Model 219: like 008 but with custom internal config (low max_grad_norm) and high learning rate
@@ -185,7 +173,7 @@
num_scheduler_steps=1,
)
)
models["222"].config.group_judge_model = "base_model"
models["222"].config.ruler_judge_model = "base_model"
models["222"].config.include_qwen3_nothink = True

models["223"] = models["206"].model_copy(deep=True)
@@ -205,3 +193,20 @@

models["225"] = models["224"].model_copy(deep=True)
models["225"].name = "email-agent-225"

# Model 229: Fork from 224 not after step 1381
models["229"] = models["224"].model_copy(deep=True)
models["229"].name = "email-agent-229"
models["229"].config.fork_from_model = "email-agent-224"
models["229"].config.fork_not_after_step = 1381

# Model 230: Fork from 206 not after step 90
models["230"] = models["206"].model_copy(deep=True)
models["230"].name = "email-agent-230"
models["230"].config.fork_from_model = "email-agent-206"
models["230"].config.fork_not_after_step = 90

# Model 231: Like 206 but with scale_rewards=False
models["231"] = models["206"].model_copy(deep=True)
models["231"].name = "email-agent-231"
models["231"].config.scale_rewards = False
17 changes: 10 additions & 7 deletions examples/art-e/art_e/project_types.py
@@ -1,5 +1,4 @@
from pydantic import BaseModel
from typing import Literal


class ProjectPolicyConfig(BaseModel):
@@ -17,14 +16,18 @@ class ProjectPolicyConfig(BaseModel):
val_set_size: int = 100
training_dataset_size: int = 4000
num_epochs: int = 4
use_judge_group_variant: Literal["v1"] | Literal["v2"] | None = (
None # e.g., "v1", "v2"; None disables judge-group rescoring
)
# Model name to use for judge-group rescoring (LLM-as-a-judge). Defaults to
# OpenAI's o3 model. You can override this per-training run.
group_judge_model: str = "openai/o3"
# Model name to use for RULER rescoring (LLM-as-a-judge). Defaults to None, which disables RULER rescoring.
ruler_judge_model: str | None = None
minimum_reward_std_dev: float = 0.0
# Random seed to control which subset of the training data is sampled. When None, the sampler can
# choose its own default (e.g., derive from the current time).
training_dataset_seed: int | None = None
messages_only: bool = False

# Fork configuration
fork_from_model: str | None = None
fork_from_project: str | None = None
fork_not_after_step: int | None = None

# Training configuration
scale_rewards: bool = True # Whether to scale rewards during training
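
The new `fork_from_model`, `fork_from_project`, and `fork_not_after_step` fields are presumably consumed by the training script before registration; the snippet below is a sketch of that wiring under that assumption, not the PR's actual `train.py` code:

```python
# Sketch only: how the fork_* config fields could be wired to the backend call.
if model.config.fork_from_model is not None:
    await backend._experimental_fork_checkpoint(
        model,
        from_model=model.config.fork_from_model,
        from_project=model.config.fork_from_project,      # None falls back to the current project
        not_after_step=model.config.fork_not_after_step,  # None uses the latest checkpoint
        verbose=True,
    )
await model.register(backend)
```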