Skip to content

Mattj/augusta #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions src/cookbook/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from olmo_core.data import (
DataMix,
NumpyDataLoaderConfig,
NumpyDatasetConfig,
NumpyDatasetType,
TokenizerConfig,
from cookbook.aliases import SourceInstance, WandbConfig
from cookbook.data.dataset import MixtureBuilder
from cookbook.model.config import (
MODEL_TO_LR_MAP,
DefaultOptimizerProperties,
ModelTrainConfig,
SupportedTokenizers,
WrappedTransformerConfig,
)
from cookbook.model.evaluators import DownstreamEvaluators
from cookbook.model.schedulers import WSD
from olmo_core.data import DataMix, NumpyDataLoaderConfig, NumpyDatasetConfig, NumpyDatasetType, TokenizerConfig
from olmo_core.data.types import NumpyDatasetDType
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride, Scheduler
Expand All @@ -27,18 +32,6 @@
WandBCallback,
)

from cookbook.aliases import SourceInstance, WandbConfig
from cookbook.data.dataset import MixtureBuilder
from cookbook.model.config import (
MODEL_TO_LR_MAP,
DefaultOptimizerProperties,
ModelTrainConfig,
SupportedTokenizers,
WrappedTransformerConfig,
)
from cookbook.model.evaluators import DownstreamEvaluators
from cookbook.model.schedulers import WSD

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -189,8 +182,15 @@ def __init__(
if any(substring in cluster for substring in ["jupiter", "saturn"]) and weka:
self.root_dir = f"/weka/oe-training-default/ai2-llm"
logger.info(f"Using Weka bucket as root dir: {self.root_dir}")
self.checkpoint_dir = f"{self.root_dir}/checkpoints/{self.beaker_user.lower()}/{self.run_name}"

elif "augusta" in cluster:
try:
assert not weka
except AssertionError as e:
logger.info("Can't be on Augusta and weka!")
raise e
self.data_dir = self.root_dir = "gs://ai2-llm"
Comment on lines +188 to +191
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also fail before submitting the job for our own sanity.

Maybe here: https://github.com/allenai/olmo-cookbook/blob/main/src/cookbook/utils/config.py#L165

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


self.checkpoint_dir = f"{self.root_dir}/checkpoints/{self.beaker_user.lower()}/{self.run_name}"
self.dataset_cache = f"{self.root_dir}/{self.beaker_user.lower()}/{self.run_name}/dataset-cache"

def get_tokenizer_config(self, tokenizer) -> TokenizerConfig:
Expand Down
21 changes: 21 additions & 0 deletions src/cookbook/recipes/love2code/train-SAMPLE-augusta.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Sample single-node training recipe that targets the Augusta cluster.
name: "train-SAMPLE-augusta"
description: "Just trying out cookbook + augusta"

# Scheduling / accounting.
budget: "ai2/oe-training"
workspace: "ai2/learn2code"
nodes: 1
gpus: 8
preemptible: true
priority: high
cluster: ai2/augusta-google-1
# Keep this false on Augusta: the config builder raises when an "augusta"
# cluster is combined with weka.
weka: false

# Training run shape.
max_tokens: 113_184_153_600 # 5xC multiplier
sequence_length: 2048
seed: 1337
model: "olmo2_1B"
tokenizer: "dolma2"

# Data mixture. Each source is one list entry; `target_ratio` and `paths`
# belong to the source they follow, so they must be nested under its `- name`.
dataset:
  sources:
    - name: sample
      target_ratio: 1.0
      # gs:// paths are read through the Google Cloud Storage endpoint using
      # the GS_INTEROP_KEY / GS_INTEROP_SECRET environment variables.
      paths:
        - gs://ai2-llm/preprocessed/love2code/python_only/python_tokens/part-000-00000.npy
20 changes: 6 additions & 14 deletions src/cookbook/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,13 @@
from typing import List, Tuple, cast

import yaml
from olmo_core.launch.beaker import (
BeakerEnvSecret,
BeakerLaunchConfig,
BeakerWekaBucket,
)
from olmo_core.train.callbacks import ConfigSaverCallback, WandBCallback
from olmo_core.utils import get_default_device, seed_all

from cookbook.aliases import (
ExperimentConfig,
ExperimentGroup,
ExperimentInstance,
SourceConfig,
SourceInstance,
)
from cookbook.aliases import ExperimentConfig, ExperimentGroup, ExperimentInstance, SourceConfig, SourceInstance
from cookbook.model.builder import TransformerConfigBuilder
from cookbook.utils.data import normalize_source_paths
from olmo_core.launch.beaker import BeakerEnvSecret, BeakerLaunchConfig, BeakerWekaBucket
from olmo_core.train.callbacks import ConfigSaverCallback, WandBCallback
from olmo_core.utils import get_default_device, seed_all

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -190,6 +180,8 @@ def mk_launch_configs(group: ExperimentGroup, beaker_user: str) -> list[BeakerLa
BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"),
BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"),
BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"),
BeakerEnvSecret(name="GS_INTEROP_KEY", secret="GS_INTEROP_KEY"),
BeakerEnvSecret(name="GS_INTEROP_SECRET", secret="GS_INTEROP_SECRET"),
],
setup_steps=[
'git clone "$REPO_URL"',
Expand Down
19 changes: 10 additions & 9 deletions src/cookbook/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import concurrent.futures
import hashlib
import json
import logging
import os
import pathlib
Expand All @@ -7,22 +9,18 @@
from urllib.parse import urlparse

import s3fs
from tqdm import tqdm

from cookbook.aliases import SourceConfig
from olmo_core.aliases import PathOrStr
from olmo_core.data.types import NumpyDatasetDType
from olmo_core.io import get_file_size, is_url, normalize_path
from olmo_core.utils import OLMoEnvironmentError
from tqdm import tqdm

logger = logging.getLogger(__name__)
logging.getLogger("botocore").setLevel(logging.WARNING)


import hashlib
import json

from cookbook.aliases import SourceConfig


def _bytes_to_tokens(num_bytes: int, dtype: NumpyDatasetDType) -> int:
"""
Convert bytes to tokens based on the dtype.
Expand Down Expand Up @@ -65,9 +63,12 @@ def get_token_counts_and_ratios(
parsed = urlparse(path)
if parsed.scheme == "s3":
continue
if parsed.scheme == "weka":
elif parsed.scheme == "weka":
client_kwargs["endpoint_url"] = os.environ.get("WEKA_ENDPOINT_URL")

elif parsed.scheme == "gs":
client_kwargs["endpoint_url"] = "https://storage.googleapis.com"
client_kwargs["key"] = os.environ.get("GS_INTEROP_KEY")
client_kwargs["secret"] = os.environ.get("GS_INTEROP_SECRET")
fs = s3fs.S3FileSystem(client_kwargs={**client_kwargs})

with concurrent.futures.ThreadPoolExecutor(max_workers=64) as executor:
Expand Down