diff --git a/CLAUDE.md b/CLAUDE.md
index 6c068179..d7f193a0 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -13,3 +13,19 @@ This project uses the `uv` package manager.
 ## Releases
 
 - If asked to help with a release, refer to the checklist in CONTRIBUTING.md
+
+## Documentation
+
+- All documentation is in the `docs` directory.
+- If you add a new page, be sure to add it to the sidebar in `docs/docs.json`.
+- If you move a page, be sure to update the sidebar in `docs/docs.json` and check for any broken links.
+
+### Adding images
+
+- Add images to the `docs/images` directory
+- If the image is a png, first convert it to webp using `magick input.png output.webp`. Do not include the original png in the repo.
+- Use the `<Frame>` tag to add images with captions as seen in the page `checkpoint-forking.mdx`.
+
+### Adding notes
+
+- Add notes using the `<Note>` tag as seen in the page `ruler.mdx`
diff --git a/docs/docs.json b/docs/docs.json
index 13de20fb..850dd480 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -60,6 +60,13 @@
             "fundamentals/ruler"
           ]
         },
+        {
+          "group": "Features",
+          "pages": [
+            "features/checkpoint-forking",
+            "features/additional-histories"
+          ]
+        },
         {
           "group": "Tutorials",
           "pages": ["tutorials/summarizer"]
diff --git a/docs/fundamentals/additional-histories.mdx b/docs/features/additional-histories.mdx
similarity index 100%
rename from docs/fundamentals/additional-histories.mdx
rename to docs/features/additional-histories.mdx
diff --git a/docs/features/checkpoint-forking.mdx b/docs/features/checkpoint-forking.mdx
new file mode 100644
index 00000000..0e28e789
--- /dev/null
+++ b/docs/features/checkpoint-forking.mdx
@@ -0,0 +1,132 @@
+---
+title: Checkpoint Forking
+description: Learn how to fork training from existing model checkpoints
+---
+
+# Checkpoint Forking
+
+<Frame>
+  <img src="/images/forked-run.webp" alt="Checkpoint forking example" />
+</Frame>
+
+Checkpoint forking allows you to create a new training run that starts from an existing model's checkpoint. This is particularly useful when:
+
+- Training has gone off track and you want to restart from a known good checkpoint
+- You want to experiment with different hyperparameters from a specific point
+- You need to branch off multiple experiments from the same checkpoint
+
+<Note>
+This feature is marked as experimental because we're still refining the API shape. However, the core functionality will remain stable.
+</Note>
+
+## Basic Usage
+
+The simplest way to fork a checkpoint is to specify it when creating your model:
+
+```python
+import art
+from art.local import LocalBackend
+
+async def train():
+    with LocalBackend() as backend:
+        # Create a new model that will fork from an existing checkpoint
+        model = art.TrainableModel(
+            name="my-model-v2",
+            project="my-project",
+            base_model="Qwen/Qwen2.5-14B-Instruct",
+        )
+
+        # Copy the checkpoint from another model
+        await backend._experimental_fork_checkpoint(
+            model,
+            from_model="my-model-v1",
+            not_after_step=500,  # Use checkpoint at or before step 500
+            verbose=True,
+        )
+
+        # Register and continue training
+        await model.register(backend)
+        # ... rest of training code
+```
+
+## Forking from S3
+
+If your checkpoints are stored in S3, you can fork directly from there:
+
+```python
+await backend._experimental_fork_checkpoint(
+    model,
+    from_model="my-model-v1",
+    from_s3_bucket="my-backup-bucket",
+    not_after_step=500,
+    verbose=True,
+)
+```
+
+## Parameters
+
+### `from_model` (required)
+The name of the model to fork from.
+
+### `from_project` (optional)
+The project containing the model to fork from. Defaults to the current model's project.
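+
+For example, here is a minimal sketch of a cross-project fork ("other-project" is a hypothetical project name used for illustration):
+
+```python
+await backend._experimental_fork_checkpoint(
+    model,
+    from_model="my-model-v1",
+    from_project="other-project",  # hypothetical: project containing the source model
+    verbose=True,
+)
+```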
+
+### `from_s3_bucket` (optional)
+S3 bucket to pull the checkpoint from. If not provided, will look for the checkpoint locally.
+
+### `not_after_step` (optional)
+The maximum step number to use. The function will use the latest checkpoint that is less than or equal to this step. If not provided, uses the latest available checkpoint.
+
+### `verbose` (optional)
+Whether to print detailed progress information during the forking process.
+
+## How It Works
+
+1. **Checkpoint Selection**: The system finds the appropriate checkpoint based on your `not_after_step` parameter
+2. **S3 Pull** (if needed): If forking from S3, only the specific checkpoint is downloaded, not the entire model history
+3. **Checkpoint Copy**: The checkpoint is copied to your new model's directory at the same step number
+4. **Training Continuation**: Your model can now continue training from this checkpoint
+
+## Example: Lowering the Learning Rate
+
+Here's a practical example of using checkpoint forking to test a lower learning rate:
+
+```python
+# Original model trained with lr=1e-5
+base_model = art.TrainableModel(
+    name="summarizer-base",
+    project="experiments",
+    base_model="Qwen/Qwen2.5-14B-Instruct",
+)
+
+# Fork at step 1000 to try lower learning rate
+low_lr_model = art.TrainableModel(
+    name="summarizer-low-lr",
+    project="experiments",
+    base_model="Qwen/Qwen2.5-14B-Instruct",
+)
+
+async def experiment():
+    with LocalBackend() as backend:
+        # Fork the model from the base model
+        await backend._experimental_fork_checkpoint(
+            low_lr_model,
+            from_model="summarizer-base",
+            not_after_step=1000,
+            verbose=True,
+        )
+        await low_lr_model.register(backend)
+
+        # Now train with a lower learning rate
+        # ... training code with different configs
+```
+
+## Notes
+
+- Checkpoints are forked at the same step number they had in the source model
+- The `not_after_step` parameter uses `<=` comparison, so specifying 500 will include step 500 if it exists
+- Only checkpoint files are copied - training logs and trajectories are not included in the fork
\ No newline at end of file
diff --git a/docs/fundamentals/ruler.mdx b/docs/fundamentals/ruler.mdx
index 81cd3566..1fd22132 100644
--- a/docs/fundamentals/ruler.mdx
+++ b/docs/fundamentals/ruler.mdx
@@ -8,22 +8,13 @@ description: "Learn how to use RULER to automatically reward your agents."
 
 RULER (Relative Universal LLM-Elicited Rewards) is a general-purpose reward function that uses an LLM-as-judge to rank multiple agent trajectories. It requires no labeled data, expert feedback, or hand-crafted reward functions, yet reliably improves agent performance.
 
-<div align="center">
-  <img
-    src="..."
-    alt="RULER Performance Results"
-  />
-  <p>
-    RULER performance across multiple tasks at launch. In 3 out of 4 tasks,
-    models trained with RULER slightly outperform those trained with
-    hand-crafted reward functions. See the full{" "}
-    <a href="...">launch announcement</a> for
-    details.
-  </p>
-</div>
+<Frame caption="RULER performance across multiple tasks at launch. In 3 out of 4 tasks, models trained with RULER slightly outperform those trained with hand-crafted reward functions. See the full launch announcement for details.">
+  <img src="..." alt="RULER Performance Results" />
+</Frame>
+
 
 ## Key Benefits
diff --git a/docs/images/forked-run.webp b/docs/images/forked-run.webp
new file mode 100644
index 00000000..c5b306b0
Binary files /dev/null and b/docs/images/forked-run.webp differ
diff --git a/docs/resources/glossary.mdx b/docs/resources/glossary.mdx
index 6825a9f9..4a08b5fb 100644
--- a/docs/resources/glossary.mdx
+++ b/docs/resources/glossary.mdx
@@ -6,7 +6,7 @@ icon: "circle-info"
 
 ## Additional Histories
 
-A feature that allows a trajectory to contain multiple separate conversation histories. Used for training agents with non-linear conversation flows, preserving special tokens across turns, or handling sub-agent interactions. See [Additional Histories](/fundamentals/additional-histories) for details.
+A feature that allows a trajectory to contain multiple separate conversation histories. Used for training agents with non-linear conversation flows, preserving special tokens across turns, or handling sub-agent interactions. See [Additional Histories](/features/additional-histories) for details.
 
 ## Agent
 
diff --git a/docs/resources/models.mdx b/docs/resources/models.mdx
index a9b62219..faf70167 100644
--- a/docs/resources/models.mdx
+++ b/docs/resources/models.mdx
@@ -25,6 +25,6 @@ Here are additional models that we've tested and found to work well with ART:
 - [Llama 3.2 3B Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)
 - [Llama 3.3 70B Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
 - [Qwen 2.5 72B Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)
-- Additionally, the [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) family of models is well supported for single-turn workflows. For multi-turn workflows the Qwen 3 chat template removes the `<think>` tokens from previous turns, which makes training more complicated. It is still possible to use for multi-turn workflows by splitting each turn into a separate message history with our `additional_histories` trajectory parameter (see [Additional Histories](/fundamentals/additional-histories)).
+- Additionally, the [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) family of models is well supported for single-turn workflows. For multi-turn workflows the Qwen 3 chat template removes the `<think>` tokens from previous turns, which makes training more complicated. It is still possible to use for multi-turn workflows by splitting each turn into a separate message history with our `additional_histories` trajectory parameter (see [Additional Histories](/features/additional-histories)).
 
 If you're curious about a model that is not listed above, ask in the Discord [#support](https://discord.com/channels/1359674493949448375/1359674622965973185) channel.
diff --git a/examples/art-e/CLAUDE.md b/examples/art-e/CLAUDE.md
new file mode 100644
index 00000000..e0f35b66
--- /dev/null
+++ b/examples/art-e/CLAUDE.md
@@ -0,0 +1,48 @@
+## Adding New Models to the Email Agent
+
+When adding a new model to the email agent project, follow these steps:
+
+1. **Add the model definition to `all_experiments.py`**:
+
+   ```python
+   # Model XXX: Description of what makes this model unique
+   models["XXX"] = models["BASE_MODEL_ID"].model_copy(deep=True)
+   models["XXX"].name = "email-agent-XXX"
+   # Add any custom configuration here
+   ```
+
+2. **If you need a new configuration option**:
+
+   a. Add it to `ProjectPolicyConfig` in `project_types.py`:
+
+   ```python
+   class ProjectPolicyConfig(BaseModel):
+       # ... existing fields ...
+       my_new_option: bool = True  # Add description and default value
+   ```
+
+   b. If it affects training, update `train.py` to pass it to the training function:
+
+   ```python
+   await model.train(
+       groups,
+       config=art.TrainConfig(learning_rate=model.config.learning_rate),
+       _config=art.dev.TrainConfig(
+           # ... existing parameters ...
+           my_new_option=model.config.my_new_option,
+       ),
+   )
+   ```
+
+   c. If it affects rollouts, update `rollout.py` to use the new option.
+
+3. **Common model variations**:
+   - Base model change: `models["XXX"].base_model = "new-base-model"`
+   - Learning rate: `models["XXX"].config.learning_rate = 1e-5`
+   - Training epochs: `models["XXX"].config.num_epochs = 3`
+   - Judge model: `models["XXX"].config.ruler_judge_model = "model-name"`
+   - Fork from checkpoint:
+     ```python
+     models["XXX"].config.fork_from_model = "email-agent-YYY"
+     models["XXX"].config.fork_not_after_step = 90
+     ```
diff --git a/examples/art-e/all_experiments.py b/examples/art-e/all_experiments.py
index c00729e8..dd89e43d 100644
--- a/examples/art-e/all_experiments.py
+++ b/examples/art-e/all_experiments.py
@@ -61,29 +61,17 @@
 models["201"] = models["008"].model_copy(deep=True)
 models["201"].name = "email-agent-201"
 
-# Model 202: like 008 but with judge-group rescoring during training
-models["202"] = models["008"].model_copy(deep=True)
-models["202"].name = "email-agent-202"
-# Enable the new flag
-models["202"].config.use_judge_group_variant = "v1"
-
-# Model 204: like 202 but with judge-group rescoring variant v2
-models["204"] = models["202"].model_copy(deep=True)
-models["204"].name = "email-agent-204"
-# Enable the v2 flag
-models["204"].config.use_judge_group_variant = "v2"
-
-# Model 205: like 204 but using Gemini 2.5 Flash as the judge-group model
-models["205"] = models["204"].model_copy(deep=True)
+# Model 205: like 201 but using Gemini 2.5 Flash as the judge-group model
+models["205"] = models["201"].model_copy(deep=True)
 models["205"].name = "email-agent-205"
 # Set the judge group model
-models["205"].config.group_judge_model = "gemini/gemini-2.5-flash"
+models["205"].config.ruler_judge_model = "gemini/gemini-2.5-flash"
 
 # Model 206: like 204 but using Qwen3 32B as the judge-group model
 models["206"] = models["204"].model_copy(deep=True)
 models["206"].name = "email-agent-206"
 # Set the judge group model
-models["206"].config.group_judge_model = "openrouter/qwen/qwen3-32b"
+models["206"].config.ruler_judge_model = "openrouter/qwen/qwen3-32b"
 
 # Model 207: like 205 but only uses 12 training examples total
 models["207"] = models["205"].model_copy(deep=True)
@@ -133,7 +121,7 @@
 
 models["213"] = models["206"].model_copy(deep=True)
 models["213"].name = "email-agent-213"
-models["213"].config.group_judge_model = "openai/o3"
+models["213"].config.ruler_judge_model = "openai/o3"
 
 models["215"] = models["008"].model_copy(deep=True)
 models["215"].name = "email-agent-215"
@@ -150,7 +138,7 @@
 models["218"] = models["206"].model_copy(deep=True)
 models["218"].name = "email-agent-218-5"
 models["218"].base_model = "Qwen/Qwen3-32B"
-models["218"].config.group_judge_model = "base_model"
+models["218"].config.ruler_judge_model = "base_model"
 models["218"].config.include_qwen3_nothink = True
 
 # Model 219: like 008 but with custom internal config (low max_grad_norm) and high learning rate
@@ -185,7 +173,7 @@
         num_scheduler_steps=1,
     )
 )
-models["222"].config.group_judge_model = "base_model"
+models["222"].config.ruler_judge_model = "base_model"
 models["222"].config.include_qwen3_nothink = True
 
 models["223"] = models["206"].model_copy(deep=True)
@@ -205,3 +193,20 @@
 
 models["225"] = models["224"].model_copy(deep=True)
 models["225"].name = "email-agent-225"
+
+# Model 229: Fork from 224 not after step 1381
+models["229"] = models["224"].model_copy(deep=True)
+models["229"].name = "email-agent-229"
+models["229"].config.fork_from_model = "email-agent-224"
+models["229"].config.fork_not_after_step = 1381
+
+# Model 230: Fork from 206 not after step 90
+models["230"] = models["206"].model_copy(deep=True)
+models["230"].name = "email-agent-230"
+models["230"].config.fork_from_model = "email-agent-206"
+models["230"].config.fork_not_after_step = 90
+
+# Model 231: Like 206 but with scale_rewards=False
+models["231"] = models["206"].model_copy(deep=True)
+models["231"].name = "email-agent-231"
+models["231"].config.scale_rewards = False
diff --git a/examples/art-e/art_e/project_types.py b/examples/art-e/art_e/project_types.py
index 97dd16be..b4419ae4 100644
--- a/examples/art-e/art_e/project_types.py
+++ b/examples/art-e/art_e/project_types.py
@@ -1,5 +1,4 @@
 from pydantic import BaseModel
-from typing import Literal
 
 
 class ProjectPolicyConfig(BaseModel):
@@ -17,14 +16,18 @@ class ProjectPolicyConfig(BaseModel):
     val_set_size: int = 100
     training_dataset_size: int = 4000
     num_epochs: int = 4
-    use_judge_group_variant: Literal["v1"] | Literal["v2"] | None = (
-        None  # e.g., "v1", "v2"; None disables judge-group rescoring
-    )
-    # Model name to use for judge-group rescoring (LLM-as-a-judge). Defaults to
-    # OpenAI's o3 model. You can override this per-training run.
-    group_judge_model: str = "openai/o3"
+    # Model name to use for RULER rescoring (LLM-as-a-judge). Defaults to
+    # None, which disables RULER rescoring.
+    ruler_judge_model: str | None = None
     minimum_reward_std_dev: float = 0.0
     # Random seed to control which subset of the training data is sampled. When None, the sampler can
     # choose its own default (e.g., derive from the current time).
     training_dataset_seed: int | None = None
     messages_only: bool = False
+
+    # Fork configuration
+    fork_from_model: str | None = None
+    fork_from_project: str | None = None
+    fork_not_after_step: int | None = None
+
+    # Training configuration
+    scale_rewards: bool = True  # Whether to scale rewards during training
diff --git a/examples/art-e/art_e/rollout.py b/examples/art-e/art_e/rollout.py
index 1ee30f90..52d7ac7f 100644
--- a/examples/art-e/art_e/rollout.py
+++ b/examples/art-e/art_e/rollout.py
@@ -316,25 +316,31 @@ async def return_final_answer(answer: str, sources: list[str]):
         choice = llm_response.choices[0]
         assert isinstance(choice, Choices)
 
-        # Our rollout is only set up to handle one tool call at a time, so just ignore any parallel tool calls.
-        if choice.message.tool_calls is not None and len(choice.message.tool_calls) > 1:
-            choice.message.tool_calls = choice.message.tool_calls[:1]
-
         traj.messages_and_choices.append(convert_litellm_choice_to_openai(choice))  # type: ignore
 
         if choice.message.tool_calls is None:
             rubric.bad_tool_call_name = True
+            traj.messages_and_choices.append(
+                {
+                    "role": "user",
+                    "content": "You did not call any tools. This is not allowed.",
+                }
+            )
             break
 
         for tool_call in choice.message.tool_calls:
-            if tool_call is None:
-                rubric.bad_tool_call_args = True
-                break
             try:
                 tool_args = json.loads(tool_call.function.arguments)
                 assert isinstance(tool_args, dict)
             except Exception:
                 rubric.bad_tool_call_args = True
+                traj.messages_and_choices.append(
+                    {
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "content": f"Error parsing arguments for {tool_call.function.name}. Cannot continue.",
+                    }
+                )
                 break
 
         for tool_fn in tools:
@@ -354,6 +360,13 @@
                     traj.logs.append(
                         f"Invalid args for {tool_call.function.name}: {e}"
                     )
+                    traj.messages_and_choices.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": tool_call.id,
+                            "content": f"Invalid args for {tool_call.function.name}: {e}",
+                        }
+                    )
                     break
             break
         else:
diff --git a/examples/art-e/art_e/train.py b/examples/art-e/art_e/train.py
index 4df3089a..2254925c 100644
--- a/examples/art-e/art_e/train.py
+++ b/examples/art-e/art_e/train.py
@@ -22,12 +22,28 @@ async def train(model: art.TrainableModel[ProjectPolicyConfig]):
     generate_database()
 
     with LocalBackend() as backend:
-        print(f"Pulling from S3 bucket: `{os.environ['BACKUP_BUCKET']}`")
+        print(
+            f"Pulling latest checkpoint from S3 bucket: `{os.environ['BACKUP_BUCKET']}`"
+        )
         await backend._experimental_pull_from_s3(
             model,
             s3_bucket=os.environ["BACKUP_BUCKET"],
             verbose=True,
+            only_step="latest",  # Only pull the latest checkpoint
+            exclude=["trajectories"],  # Exclude trajectories to save space/time
         )
+
+        # Handle fork configuration if specified
+        if model.config.fork_from_model:
+            print(f"Forking from model: {model.config.fork_from_model}")
+            await backend._experimental_fork_checkpoint(
+                model,
+                from_model=model.config.fork_from_model,
+                from_s3_bucket=os.environ["BACKUP_BUCKET"],
+                not_after_step=model.config.fork_not_after_step,
+                verbose=True,
+            )
+
         await model.register(backend)
 
         print("Loading training data...")
@@ -78,12 +94,12 @@ async def judge_after_each(
 
         If no judge is configured, simply return the group as-is.
         """
-        if model.config.group_judge_model is None:
+        if model.config.ruler_judge_model is None:
             return group
 
         return await ruler_score_group(
             group,
-            model.config.group_judge_model,
+            model.config.ruler_judge_model,
             swallow_exceptions=True,
         )
 
@@ -143,9 +159,8 @@ async def judge_after_each(
             groups,
             config=art.TrainConfig(learning_rate=model.config.learning_rate),
             _config=art.dev.TrainConfig(
-                allow_training_without_logprobs=True
-                if model.config.messages_only
-                else False
+                allow_training_without_logprobs=model.config.messages_only,
+                scale_rewards=model.config.scale_rewards,
             ),
         )
 
diff --git a/examples/art-e/pyproject.toml b/examples/art-e/pyproject.toml
index e358979e..f8802c0c 100644
--- a/examples/art-e/pyproject.toml
+++ b/examples/art-e/pyproject.toml
@@ -51,4 +51,7 @@ art-e = { workspace = true }
 openpipe-art = { path = "../../", editable = true }
 
 [dependency-groups]
-dev = ["art-e"]
+dev = [
+    "art-e",
+    "ruff>=0.12.3",
+]
diff --git a/examples/art-e/uv.lock b/examples/art-e/uv.lock
index 5a74a901..76cbb22e 100644
--- a/examples/art-e/uv.lock
+++ b/examples/art-e/uv.lock
@@ -311,6 +311,7 @@ dependencies = [
 [package.dev-dependencies]
 dev = [
     { name = "art-e" },
+    { name = "ruff" },
 ]
 
 [package.metadata]
@@ -343,7 +344,10 @@ requires-dist = [
 ]
 
 [package.metadata.requires-dev]
-dev = [{ name = "art-e", editable = "." }]
+dev = [
+    { name = "art-e", editable = "." },
+    { name = "ruff", specifier = ">=0.12.3" },
+]
 
 [[package]]
 name = "astor"
@@ -3198,7 +3202,7 @@ wheels = [
 
 [[package]]
 name = "openpipe-art"
-version = "0.4.0"
+version = "0.4.2"
 source = { editable = "../../" }
 dependencies = [
     { name = "litellm" },
@@ -3214,7 +3218,6 @@ backend = [
     { name = "hf-xet" },
     { name = "peft" },
     { name = "polars" },
-    { name = "semver" },
     { name = "setproctitle" },
     { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" },
     { name = "setuptools", version = "80.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" },
@@ -3246,9 +3249,10 @@ requires-dist = [
     { name = "peft", marker = "extra == 'backend'", specifier = ">=0.14.0" },
     { name = "polars", marker = "extra == 'backend'", specifier = ">=1.26.0" },
     { name = "seaborn", marker = "extra == 'plotting'", specifier = ">=0.13.2" },
-    { name = "semver", marker = "extra == 'backend'", specifier = ">=3.0.4" },
+    { name = "semver", marker = "extra == 'skypilot'", specifier = ">=3.0.4" },
     { name = "setproctitle", marker = "extra == 'backend'", specifier = ">=1.3.6" },
     { name = "setuptools", marker = "extra == 'backend'", specifier = ">=78.1.0" },
+    { name = "skypilot", extras = ["cudo", "do", "fluidstack", "gcp", "lambda", "paperspace", "runpod"], marker = "extra == 'skypilot'", specifier = "==0.9.3" },
     { name = "tblib", marker = "extra == 'backend'", specifier = ">=3.0.0" },
     { name = "torch", marker = "extra == 'backend'", specifier = ">=2.7.0" },
     { name = "torchao", marker = "extra == 'backend'", specifier = ">=0.9.0" },
@@ -3261,7 +3265,7 @@ requires-dist = [
     { name = "wandb", marker = "extra == 'backend'", specifier = ">=0.19.8" },
     { name = "weave", marker = "extra == 'backend'", specifier = ">=0.51.51" },
 ]
-provides-extras = ["plotting", "backend"]
+provides-extras = ["plotting", "backend", "skypilot"]
 
 [package.metadata.requires-dev]
 dev = [
@@ -3271,7 +3275,6 @@ dev = [
     { name = "ipywidgets", specifier = ">=8.1.5" },
     { name = "openpipe", specifier = ">=4.49.0" },
     { name = "ruff", specifier = ">=0.12.1" },
-    { name = "skypilot", extras = ["cudo", "do", "fluidstack", "gcp", "lambda", "paperspace", "runpod"], specifier = "==0.8.0" },
 ]
 
 [[package]]
@@ -4935,6 +4938,31 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e9/93/0c0f002031f18b53af7a6166103c02b9c0667be528944137cc954ec921b3/rsa-4.7.2-py3-none-any.whl", hash = "sha256:78f9a9bf4e7be0c5ded4583326e7461e3a3c5aae24073648b4bdfa797d78c9d2", size = 34505, upload-time = "2021-02-24T10:55:03.55Z" },
 ]
 
+[[package]]
+name = "ruff"
+version = "0.12.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" },
+    { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" },
+    { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" },
+    { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" },
+    { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" },
+    { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" },
+    { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" },
+    { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" },
+    { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" },
+    { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" },
+    { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" },
+    { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" },
+    { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" },
+    { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" },
+    { url = "https://files.pythonhosted.org/packages/fe/84/7cc7bd73924ee6be4724be0db5414a4a2ed82d06b30827342315a1be9e9c/ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d", size = 10589263, upload-time = "2025-07-11T13:21:07.148Z" },
+    { url = "https://files.pythonhosted.org/packages/07/87/c070f5f027bd81f3efee7d14cb4d84067ecf67a3a8efb43aadfc72aa79a6/ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7", size = 11695072, upload-time = "2025-07-11T13:21:11.004Z" },
+    { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" },
+]
+
 [[package]]
 name = "runpod"
 version = "1.7.12"
@@ -5127,15 +5155,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" },
 ]
 
-[[package]]
-name = "semver"
-version = "3.0.4"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/72/d1/d3159231aec234a59dd7d601e9dd9fe96f3afff15efd33c1070019b26132/semver-3.0.4.tar.gz", hash = "sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602", size = 269730, upload-time = "2025-01-24T13:19:27.617Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/a6/24/4d91e05817e92e3a61c8a21e08fd0f390f5301f1c448b137c57c4bc6e543/semver-3.0.4-py3-none-any.whl", hash = "sha256:9c824d87ba7f7ab4a1890799cec8596f15c1241cb473404ea1cb0c55e4b04746", size = 17912, upload-time = "2025-01-24T13:19:24.949Z" },
-]
-
 [[package]]
 name = "sentencepiece"
 version = "0.2.0"
diff --git a/scripts/migrate-s3-checkpoints.py b/scripts/migrate-s3-checkpoints.py
new file mode 100755
index 00000000..7348c63c
--- /dev/null
+++ b/scripts/migrate-s3-checkpoints.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+"""
+Script to migrate model checkpoints in S3 from old to new structure.
+
+Old structure: s3://bucket/prefix/project/models/model_name/0001/
+New structure: s3://bucket/prefix/project/models/model_name/checkpoints/0001/
+
+Usage:
+    python scripts/migrate-s3-checkpoints.py --project myproject --model mymodel
+    python scripts/migrate-s3-checkpoints.py --project myproject --model mymodel --dry-run
+    python scripts/migrate-s3-checkpoints.py --project myproject --model mymodel --bucket custom-bucket --prefix custom-prefix
+"""
+
+import argparse
+import asyncio
+import sys
+from pathlib import Path
+
+# Add the src directory to the Python path
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+
+from art.utils.s3_checkpoint_utils import migrate_s3_checkpoints_to_new_structure
+
+
+async def main():
+    parser = argparse.ArgumentParser(
+        description="Migrate model checkpoints in S3 from old to new structure"
+    )
+    parser.add_argument(
+        "--project",
+        required=True,
+        help="Project name",
+    )
+    parser.add_argument(
+        "--model",
+        required=True,
+        help="Model name",
+    )
+    parser.add_argument(
+        "--bucket",
+        help="S3 bucket name (defaults to BACKUP_BUCKET env var)",
+    )
+    parser.add_argument(
+        "--prefix",
+        help="S3 prefix",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Only show what would be done without making changes",
+    )
+
+    args = parser.parse_args()
+
+    await migrate_s3_checkpoints_to_new_structure(
+        model_name=args.model,
+        project=args.project,
+        s3_bucket=args.bucket,
+        prefix=args.prefix,
+        dry_run=args.dry_run,
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/art/backend.py b/src/art/backend.py
index 27c87eae..78555d07 100644
--- a/src/art/backend.py
+++ b/src/art/backend.py
@@ -1,7 +1,7 @@
 import httpx
 import json
 from tqdm import auto as tqdm
-from typing import AsyncIterator, TYPE_CHECKING
+from typing import AsyncIterator, TYPE_CHECKING, Literal
 
 from art.utils import log_http_errors
 from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider
@@ -138,8 +138,14 @@ async def _experimental_pull_from_s3(
         prefix: str | None = None,
         verbose: bool = False,
         delete: bool = False,
+        only_step: int | Literal["latest"] | None = None,
     ) -> None:
-        """Download the model directory from S3 into file system where the LocalBackend is running. Right now this can be used to pull trajectory logs for processing or model checkpoints."""
+        """Download the model directory from S3 into file system where the LocalBackend is running. Right now this can be used to pull trajectory logs for processing or model checkpoints.
+
+        Args:
+            only_step: If specified, only pull this specific step. Can be an int for a specific step,
+                or "latest" to pull only the latest checkpoint. If None, pulls all steps.
+        """
         response = await self._client.post(
             "/_experimental_pull_from_s3",
             json={
@@ -148,6 +154,7 @@ async def _experimental_pull_from_s3(
                 "prefix": prefix,
                 "verbose": verbose,
                 "delete": delete,
+                "only_step": only_step,
             },
             timeout=600,
         )
@@ -177,6 +184,45 @@ async def _experimental_push_to_s3(
         )
         response.raise_for_status()
 
+    @log_http_errors
+    async def _experimental_fork_checkpoint(
+        self,
+        model: "Model",
+        from_model: str,
+        from_project: str | None = None,
+        from_s3_bucket: str | None = None,
+        not_after_step: int | None = None,
+        verbose: bool = False,
+        prefix: str | None = None,
+    ) -> None:
+        """Fork a checkpoint from another model to initialize this model.
+
+        Args:
+            model: The model to fork to.
+            from_model: The name of the model to fork from.
+            from_project: The project of the model to fork from. Defaults to model.project.
+            from_s3_bucket: Optional S3 bucket to pull the checkpoint from. If provided,
+                will pull from S3 first. Otherwise, will fork from local disk.
+            not_after_step: Optional step number. If provided, will copy the last saved
+                checkpoint that is <= this step. Otherwise, copies the latest checkpoint.
+            verbose: Whether to print verbose output.
+            prefix: Optional S3 prefix for the bucket.
+        """
+        response = await self._client.post(
+            "/_experimental_fork_checkpoint",
+            json={
+                "model": model.model_dump(),
+                "from_model": from_model,
+                "from_project": from_project,
+                "from_s3_bucket": from_s3_bucket,
+                "not_after_step": not_after_step,
+                "verbose": verbose,
+                "prefix": prefix,
+            },
+            timeout=600,
+        )
+        response.raise_for_status()
+
     @log_http_errors
     async def _experimental_deploy(
         self,
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 5757798d..345c071f 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -10,6 +10,8 @@ from art.utils.output_dirs import (
     get_default_art_path,
     get_model_dir,
+    get_output_dir_from_model_properties,
+    get_step_checkpoint_dir,
     get_trajectories_split_dir,
 )
 from art.utils.trajectory_logging import serialize_trajectory_groups
@@ -22,7 +24,7 @@
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from tqdm import auto as tqdm
-from typing import AsyncIterator, cast
+from typing import AsyncIterator, cast, Literal
 import wandb
 from wandb.sdk.wandb_run import Run
 import weave
@@ -455,18 +457,55 @@ async def _experimental_pull_from_s3(
         verbose: bool = False,
         delete: bool = False,
         exclude: list[ExcludableOption] | None = None,
+        latest_only: bool = False,
+        only_step: int | Literal["latest"] | None = None,
     ) -> None:
         """Download the model directory from S3 into local Backend storage. Right now this can be used to pull trajectory logs for processing or model checkpoints.
 
         Args:
             model: The model to pull from S3.
-            step: A specific step to pull from S3. If None, all steps will be pulled.
+            step: DEPRECATED. Use only_step instead.
             s3_bucket: The S3 bucket to pull from. If None, the default bucket will be used.
             prefix: The prefix to pull from S3. If None, the model name will be used.
             verbose: Whether to print verbose output.
             delete: Whether to delete the local model directory.
             exclude: List of directories to exclude from sync. Valid options: "checkpoints", "logs", "trajectories".
+            latest_only: DEPRECATED. Use only_step="latest" instead.
+            only_step: If specified, only pull this specific step. Can be an int for a specific step,
+                or "latest" to pull only the latest checkpoint. If None, pulls all steps.
         """
+        # Handle backward compatibility and new only_step parameter
+        if only_step is None and latest_only:
+            only_step = "latest"
+
+        # Handle the only_step parameter
+        if only_step is not None and step is None:
+            if only_step == "latest":
+                from art.utils.s3_checkpoint_utils import (
+                    get_latest_checkpoint_step_from_s3,
+                )
+
+                latest_step = await get_latest_checkpoint_step_from_s3(
+                    model_name=model.name,
+                    project=model.project,
+                    s3_bucket=s3_bucket,
+                    prefix=prefix,
+                )
+
+                if latest_step is not None:
+                    step = latest_step
+                    if verbose:
+                        print(f"Found latest checkpoint at step {step}")
+                else:
+                    if verbose:
+                        print("No checkpoints found in S3")
+                    return
+            else:
+                # only_step is an int
+                step = only_step
+                if verbose:
+                    print(f"Pulling specific checkpoint at step {step}")
+
         await pull_model_from_s3(
             model_name=model.name,
             project=model.project,
@@ -498,6 +537,181 @@ async def _experimental_push_to_s3(
             art_path=self._path,
         )
 
+    async def _experimental_fork_checkpoint(
+        self,
+        model: Model,
+        from_model: str,
+        from_project: str | None = None,
+        from_s3_bucket: str | None = None,
+        not_after_step: int | None = None,
+        verbose: bool = False,
+        prefix: str | None = None,
+    ) -> None:
+        """Fork a checkpoint from another model to initialize this model.
+
+        Args:
+            model: The model to fork to.
+            from_model: The name of the model to fork from.
+            from_project: The project of the model to fork from. Defaults to model.project.
+            from_s3_bucket: Optional S3 bucket to pull the checkpoint from. If provided,
+                will pull from S3 first. Otherwise, will fork from local disk.
+            not_after_step: Optional step number. If provided, will copy the last saved
+                checkpoint that is <= this step. Otherwise, copies the latest checkpoint.
+            verbose: Whether to print verbose output.
+            prefix: Optional S3 prefix for the bucket.
+        """
+        # Default from_project to model.project if not provided
+        from_project = from_project or model.project
+
+        # Get source and destination directories
+        source_model_dir = get_output_dir_from_model_properties(
+            project=from_project,
+            name=from_model,
+            art_path=self._path,
+        )
+        dest_model_dir = get_output_dir_from_model_properties(
+            project=model.project,
+            name=model.name,
+            art_path=self._path,
+        )
+
+        # If S3 bucket is provided, pull from S3 first
+        if from_s3_bucket is not None:
+            if verbose:
+                print(
+                    f"DEBUG: Fork checkpoint - from_s3_bucket={from_s3_bucket}, not_after_step={not_after_step}"
+                )
+
+            # Determine which checkpoint to pull
+            if not_after_step is None:
+                # Pull only the latest checkpoint
+                if verbose:
+                    print(
+                        f"Pulling latest checkpoint for model {from_model} from S3 bucket {from_s3_bucket}..."
+                    )
+                await self._experimental_pull_from_s3(
+                    Model(name=from_model, project=from_project),
+                    s3_bucket=from_s3_bucket,
+                    verbose=verbose,
+                    exclude=["logs", "trajectories"],
+                    only_step="latest",
+                )
+            else:
+                # Find the right checkpoint not after the specified step
+                from art.utils.s3_checkpoint_utils import (
+                    get_checkpoint_step_not_after_from_s3,
+                )
+
+                if verbose:
+                    print(
+                        f"Finding checkpoint not after step {not_after_step} for model {from_model} in S3..."
+                    )
+
+                # Find which step to pull
+                target_step = await get_checkpoint_step_not_after_from_s3(
+                    model_name=from_model,
+                    project=from_project,
+                    not_after_step=not_after_step,
+                    s3_bucket=from_s3_bucket,
+                    prefix=prefix,
+                )
+
+                if target_step is None:
+                    raise ValueError(
+                        f"No checkpoints found not after step {not_after_step} for model {from_model} in S3"
+                    )
+
+                if verbose:
+                    print(
+                        f"Found checkpoint at step {target_step}, pulling only this checkpoint..."
+                    )
+
+                # Pull only the specific checkpoint we need
+                await pull_model_from_s3(
+                    model_name=from_model,
+                    project=from_project,
+                    step=target_step,
+                    s3_bucket=from_s3_bucket,
+                    verbose=verbose,
+                    art_path=self._path,
+                    exclude=["logs", "trajectories"],  # Only need checkpoints
+                )
+
+        # Find the checkpoint to fork
+        checkpoint_base_dir = os.path.join(source_model_dir, "checkpoints")
+        if not os.path.exists(checkpoint_base_dir):
+            raise FileNotFoundError(
+                f"No checkpoints found for model {from_model} in project {from_project}"
+            )
+
+        if verbose:
+            print(f"DEBUG: Checkpoint base dir: {checkpoint_base_dir}")
+            print(
+                f"DEBUG: Contents: {os.listdir(checkpoint_base_dir) if os.path.exists(checkpoint_base_dir) else 'Does not exist'}"
+            )
+
+        # Get all available checkpoint steps
+        available_steps = sorted(
+            int(d)
+            for d in os.listdir(checkpoint_base_dir)
+            if os.path.isdir(os.path.join(checkpoint_base_dir, d)) and d.isdigit()
+        )
+
+        if not available_steps:
+            raise FileNotFoundError(
+                f"No checkpoint directories found for model {from_model}"
+            )
+
+        # Determine which step to use
+        if not_after_step is None:
+            # Use the latest checkpoint
+            selected_step = available_steps[-1]
+        else:
+            # Find the last checkpoint not after the specified step
+            valid_steps = [s for s in available_steps if s <= not_after_step]
+            if not valid_steps:
+                raise ValueError(
+                    f"No checkpoints found not after step {not_after_step}. "
+                    f"Available steps: {available_steps}"
+                )
+            selected_step = valid_steps[-1]
+
+        # Create destination checkpoint directory
+        dest_checkpoint_dir = get_step_checkpoint_dir(dest_model_dir, selected_step)
+        os.makedirs(os.path.dirname(dest_checkpoint_dir), exist_ok=True)
+
+        # Copy the checkpoint
+        source_checkpoint_dir = os.path.join(
+            checkpoint_base_dir, f"{selected_step:04d}"
+        )
+        if verbose:
+            print(
+                f"Copying checkpoint from {source_checkpoint_dir} to {dest_checkpoint_dir}"
+            )
+            print(f"DEBUG: Source dir exists: {os.path.exists(source_checkpoint_dir)}")
+            if os.path.exists(source_checkpoint_dir):
+                print(
+                    f"DEBUG: Source dir contents: {os.listdir(source_checkpoint_dir)}"
+                )
+                print(
+                    f"DEBUG: Source dir is empty: {len(os.listdir(source_checkpoint_dir)) == 0}"
+                )
+
+        import shutil
+
+        # Remove destination if it already exists (empty directory from previous attempts)
+        if os.path.exists(dest_checkpoint_dir):
+            if verbose:
+                print("DEBUG: Destination already exists, removing it first")
+            shutil.rmtree(dest_checkpoint_dir)
+
+        shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)
+
+        if verbose:
+            print(
+                f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
+            )
+
     async def _experimental_deploy(
         self,
         deploy_to: LoRADeploymentProvider,
diff --git a/src/art/model.py b/src/art/model.py
index 4c95e50f..37655fe9 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -308,7 +308,7 @@ async def register(
             self, _openai_client_config
         )
 
-        # Populate the new top-level inference fields so that the rest of the
+        # Populate the top-level inference fields so that the rest of the
         # code (and any user code) can create an OpenAI client immediately.
         self.inference_base_url = base_url
         self.inference_api_key = api_key
diff --git a/src/art/utils/s3.py b/src/art/utils/s3.py
index 543ecfbc..3f57996e 100644
--- a/src/art/utils/s3.py
+++ b/src/art/utils/s3.py
@@ -39,7 +39,8 @@ def build_s3_path(
     prefix_part = f"{prefix.strip('/')}/" if prefix else ""
     path = f"s3://{s3_bucket}/{prefix_part}{project}/models/{model_name}"
     if step is not None:
-        path += f"/{step:04d}"
+        # Use the new checkpoint structure in S3
+        path += f"/checkpoints/{step:04d}"
     return path
 
@@ -191,12 +192,12 @@ async def pull_model_from_s3(
         art_path=art_path,
     )
     os.makedirs(local_model_dir, exist_ok=True)
-    # When pulling a specific step, we need to handle the old S3 structure
+    # Use the new checkpoint structure
     if step is not None:
-        # First, try to pull to the old structure location since that's what S3 has
-        old_step_dir = os.path.join(local_model_dir, f"{step:04d}")
-        os.makedirs(old_step_dir, exist_ok=True)
-        local_dir = old_step_dir
+        # Pull directly to the new checkpoint structure
+        checkpoint_dir = get_step_checkpoint_dir(local_model_dir, step)
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        local_dir = checkpoint_dir
     else:
         local_dir = local_model_dir
 
@@ -208,26 +209,13 @@ async def pull_model_from_s3(
         prefix=prefix,
     )
     await ensure_bucket_exists(s3_bucket)
+    if verbose:
+        print(f"DEBUG: S3 sync from {s3_path} to {local_dir}")
     await s3_sync(s3_path, local_dir, verbose=verbose, delete=delete, exclude=exclude)
-
-    # After pulling, migrate to new structure if needed
-    if step is not None:
-        # Check if we need to migrate this specific step
-        old_step_dir = os.path.join(local_model_dir, f"{step:04d}")
-        new_step_dir = get_step_checkpoint_dir(local_model_dir, step)
-
-        if os.path.exists(old_step_dir) and not os.path.exists(new_step_dir):
-            # The checkpoint exists in old structure, migrate it
-            print(f"Migrating pulled checkpoint {step:04d} to new structure...")
print(f"Migrating pulled checkpoint {step:04d} to new structure...") - os.makedirs(os.path.dirname(new_step_dir), exist_ok=True) - import shutil - - shutil.move(old_step_dir, new_step_dir) - else: - # If pulling all steps, run the full migration - from ..local.checkpoints import migrate_checkpoints_to_new_structure - - migrate_checkpoints_to_new_structure(local_model_dir) + if verbose: + print( + f"DEBUG: After sync, local_dir contents: {os.listdir(local_dir) if os.path.exists(local_dir) else 'Does not exist'}" + ) return local_model_dir diff --git a/src/art/utils/s3_checkpoint_utils.py b/src/art/utils/s3_checkpoint_utils.py new file mode 100644 index 00000000..c67e1b40 --- /dev/null +++ b/src/art/utils/s3_checkpoint_utils.py @@ -0,0 +1,244 @@ +"""Utilities for working with S3 checkpoints.""" + +import asyncio +from asyncio.subprocess import PIPE + + +async def get_latest_checkpoint_step_from_s3( + model_name: str, + project: str, + s3_bucket: str | None = None, + prefix: str | None = None, +) -> int | None: + """ + Get the latest checkpoint step number from S3 without downloading files. + + Returns: + The latest step number, or None if no checkpoints exist. + """ + from .s3 import build_s3_path + + s3_path = build_s3_path( + model_name=model_name, + project=project, + s3_bucket=s3_bucket, + prefix=prefix, + ) + + # List checkpoint directories in S3 + cmd = ["aws", "s3", "ls", f"{s3_path}/checkpoints/"] + + process = await asyncio.create_subprocess_exec(*cmd, stdout=PIPE, stderr=PIPE) + stdout, stderr = await process.communicate() + + if process.returncode != 0: + # No checkpoints found or error + return None + + # Parse output to find checkpoint directories + lines = stdout.decode().strip().split("\n") + checkpoint_steps = [] + + for line in lines: + if line.strip(): + # AWS S3 ls output format: "PRE 0001/" + parts = line.split() + if len(parts) >= 2 and parts[0] == "PRE": + dirname = parts[1].rstrip("/") + if dirname.isdigit(): + checkpoint_steps.append(int(dirname)) + + return max(checkpoint_steps) if checkpoint_steps else None + + +async def get_checkpoint_step_not_after_from_s3( + model_name: str, + project: str, + not_after_step: int, + s3_bucket: str | None = None, + prefix: str | None = None, +) -> int | None: + """ + Get the latest checkpoint step number that is not after the specified step from S3. + + Args: + not_after_step: Find the latest checkpoint <= this step. + + Returns: + The step number, or None if no suitable checkpoint exists. 
+ """ + from .s3 import build_s3_path + + s3_path = build_s3_path( + model_name=model_name, + project=project, + s3_bucket=s3_bucket, + prefix=prefix, + ) + + # List checkpoint directories in S3 + cmd = ["aws", "s3", "ls", f"{s3_path}/checkpoints/"] + + process = await asyncio.create_subprocess_exec(*cmd, stdout=PIPE, stderr=PIPE) + stdout, stderr = await process.communicate() + + if process.returncode != 0: + # No checkpoints found or error + return None + + # Parse output to find checkpoint directories + lines = stdout.decode().strip().split("\n") + valid_steps = [] + + for line in lines: + if line.strip(): + # AWS S3 ls output format: "PRE 0001/" + parts = line.split() + if len(parts) >= 2 and parts[0] == "PRE": + dirname = parts[1].rstrip("/") + if dirname.isdigit(): + step = int(dirname) + if step <= not_after_step: + valid_steps.append(step) + + return max(valid_steps) if valid_steps else None + + +async def migrate_s3_checkpoints_to_new_structure( + model_name: str, + project: str, + s3_bucket: str | None = None, + prefix: str | None = None, + dry_run: bool = False, +) -> None: + """ + Migrate existing checkpoints in S3 from the old structure to the new structure. + + Old: s3://bucket/prefix/project/models/model_name/0001/ + New: s3://bucket/prefix/project/models/model_name/checkpoints/0001/ + + Args: + model_name: The name of the model to migrate. + project: The project name. + s3_bucket: The S3 bucket. If None, uses BACKUP_BUCKET env var. + prefix: Optional prefix for the S3 path. + dry_run: If True, only print what would be done without making changes. + """ + import os + from .s3 import build_s3_path + + if s3_bucket is None: + s3_bucket = os.environ.get("BACKUP_BUCKET") + if not s3_bucket: + raise ValueError( + "BACKUP_BUCKET environment variable not set and no bucket provided" + ) + + s3_path = build_s3_path( + model_name=model_name, + project=project, + s3_bucket=s3_bucket, + prefix=prefix, + ) + + print(f"Checking for checkpoints to migrate in {s3_path}") + + # List all directories in the model path + cmd = ["aws", "s3", "ls", f"{s3_path}/"] + process = await asyncio.create_subprocess_exec(*cmd, stdout=PIPE, stderr=PIPE) + stdout, stderr = await process.communicate() + + if process.returncode != 0: + print(f"Error listing S3 path: {stderr.decode()}") + return + + # Parse output to find checkpoint directories + lines = stdout.decode().strip().split("\n") + checkpoint_dirs = [] + + for line in lines: + if line.strip(): + # AWS S3 ls output format: "PRE 0001/" + parts = line.split() + if len(parts) >= 2 and parts[0] == "PRE": + dirname = parts[1].rstrip("/") + # Check if it's a 4-digit checkpoint directory (old format) + if dirname.isdigit() and len(dirname) == 4: + checkpoint_dirs.append(dirname) + + if not checkpoint_dirs: + print("No checkpoints found in old format to migrate") + return + + print( + f"Found {len(checkpoint_dirs)} checkpoint(s) to migrate: {', '.join(checkpoint_dirs)}" + ) + + if dry_run: + print("DRY RUN: Would migrate the following checkpoints:") + for checkpoint in checkpoint_dirs: + print(f" {s3_path}/{checkpoint}/ -> {s3_path}/checkpoints/{checkpoint}/") + return + + # Perform migration + migrated_count = 0 + for checkpoint in checkpoint_dirs: + old_path = f"{s3_path}/{checkpoint}/" + new_path = f"{s3_path}/checkpoints/{checkpoint}/" + + print(f"Migrating checkpoint {checkpoint}...") + + # Check if already exists in new location + check_cmd = ["aws", "s3", "ls", new_path] + check_process = await asyncio.create_subprocess_exec( + *check_cmd, stdout=PIPE, 
stderr=PIPE + ) + check_stdout, _ = await check_process.communicate() + + if check_process.returncode == 0 and check_stdout.decode().strip(): + print(f" Checkpoint {checkpoint} already exists in new location, skipping") + continue + + # Copy checkpoint to new location (using sync to preserve structure) + sync_cmd = ["aws", "s3", "sync", old_path, new_path] + sync_process = await asyncio.create_subprocess_exec( + *sync_cmd, stdout=PIPE, stderr=PIPE + ) + _, sync_stderr = await sync_process.communicate() + + if sync_process.returncode != 0: + print(f" Error copying checkpoint {checkpoint}: {sync_stderr.decode()}") + continue + + # Verify copy was successful by checking if files exist in new location + verify_cmd = ["aws", "s3", "ls", new_path, "--recursive"] + verify_process = await asyncio.create_subprocess_exec( + *verify_cmd, stdout=PIPE, stderr=PIPE + ) + verify_stdout, _ = await verify_process.communicate() + + if verify_process.returncode != 0 or not verify_stdout.decode().strip(): + print( + f" Error: Checkpoint {checkpoint} not found in new location after copy" + ) + continue + + # Remove old checkpoint directory + rm_cmd = ["aws", "s3", "rm", old_path, "--recursive"] + rm_process = await asyncio.create_subprocess_exec( + *rm_cmd, stdout=PIPE, stderr=PIPE + ) + _, rm_stderr = await rm_process.communicate() + + if rm_process.returncode != 0: + print( + f" Warning: Failed to remove old checkpoint {checkpoint}: {rm_stderr.decode()}" + ) + print( + " Checkpoint was successfully copied to new location but old files remain" + ) + else: + print(f" Successfully migrated checkpoint {checkpoint}") + migrated_count += 1 + + print(f"\nMigration complete. Successfully migrated {migrated_count} checkpoint(s)")