3 changes: 2 additions & 1 deletion run_train.sh
@@ -10,9 +10,10 @@ set -ex
# use envs as local overwrites for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
export WANDB_PROJECT=${WANDB_PROJECT:-"titan-dion-8b"}
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b_muon.toml"}

TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

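For reference, the project name seen by the metrics logger resolves in a simple chain (a minimal sketch, assuming the job is launched through run_train.sh so the exported default is visible to the Python process):

    # Resolution order for the WandB project name:
    # 1. a WANDB_PROJECT already set in the caller's environment wins,
    # 2. otherwise run_train.sh exports the default "titan-dion-8b",
    # 3. metrics.py only falls back to "torchtitan" when the variable is entirely unset.
    import os

    project = os.getenv("WANDB_PROJECT", "torchtitan")
    print(project)  # -> "titan-dion-8b" when launched via run_train.sh with no override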
105 changes: 93 additions & 12 deletions torchtitan/components/metrics.py
@@ -143,23 +143,104 @@ def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None):
# Create logging directory
os.makedirs(log_dir, exist_ok=True)

self.wandb.init(
project=os.getenv("WANDB_PROJECT", "torchtitan"),
dir=log_dir,
config=job_config.to_dict(),
)
logger.info("WandB logging enabled")
# Prepare wandb initialization parameters
wandb_config = {
"project": os.getenv("WANDB_PROJECT", "torchtitan"),
"dir": log_dir,
"config": job_config.to_dict(),
}

# Add optional parameters if environment variables are set
if os.getenv("WANDB_ENTITY"):
wandb_config["entity"] = os.getenv("WANDB_ENTITY")

if os.getenv("WANDB_NAME"):
wandb_config["name"] = os.getenv("WANDB_NAME")

if os.getenv("WANDB_TAGS"):
# Split comma-separated tags
tags = [tag.strip() for tag in os.getenv("WANDB_TAGS").split(",")]
wandb_config["tags"] = tags

if os.getenv("WANDB_NOTES"):
wandb_config["notes"] = os.getenv("WANDB_NOTES")

# Log the configuration being used
logger.info(f"Initializing WandB with config: {wandb_config}")

# Check if wandb is properly authenticated
try:
api_key = self.wandb.api.api_key
if api_key:
logger.info("WandB API key found - authentication OK")
else:
logger.warning(
"No WandB API key found - you may need to run 'wandb login'"
)
except Exception as e:
logger.warning(f"Could not check WandB authentication: {e}")

# Initialize wandb with enhanced configuration
try:
run = self.wandb.init(**wandb_config)

if run is not None:
logger.info(f"WandB logging enabled successfully!")
logger.info(f"Run URL: {run.url}")
logger.info(f"Project: {run.project}")
logger.info(f"Entity: {run.entity}")
logger.info(f"Run name: {run.name}")
logger.info(f"Run ID: {run.id}")
if hasattr(run, "tags") and run.tags:
logger.info(f"Tags: {run.tags}")
else:
logger.warning(
"WandB initialization returned None - logging may not work properly"
)
except Exception as e:
error_msg = str(e)
logger.error(f"Failed to initialize WandB: {error_msg}")

# Provide specific guidance for common authentication errors
if "401" in error_msg or "user is not logged in" in error_msg:
logger.error("Authentication error detected. This usually means:")
logger.error("1. Your WandB login session has expired")
logger.error("2. You need to re-authenticate with WandB")
logger.error("Please run: wandb login --relogin")
logger.error(
"Or set WANDB_MODE=offline to run without uploading to WandB"
)
elif "403" in error_msg or "permission" in error_msg.lower():
logger.error("Permission error detected. Please check:")
logger.error("1. Your WandB entity/team name is correct")
logger.error("2. You have access to the specified project")
logger.error("3. Your API key has the necessary permissions")
elif "network" in error_msg.lower() or "connection" in error_msg.lower():
logger.error("Network error detected. You can:")
logger.error("1. Check your internet connection")
logger.error("2. Set WANDB_MODE=offline to run without uploading")

logger.error("WandB logging will be disabled for this run")
# Fall back to a no-op logger
self.wandb = None

def log(self, metrics: dict[str, Any], step: int) -> None:
wandb_metrics = {
(k if self.tag is None else f"{self.tag}/{k}"): v
for k, v in metrics.items()
}
self.wandb.log(wandb_metrics, step=step)
if self.wandb is not None:
wandb_metrics = {
(k if self.tag is None else f"{self.tag}/{k}"): v
for k, v in metrics.items()
}
self.wandb.log(wandb_metrics, step=step)

def close(self) -> None:
if self.wandb.run is not None:
if (
self.wandb is not None
and hasattr(self.wandb, "run")
and self.wandb.run is not None
):
logger.info("Finishing WandB run...")
self.wandb.finish()
logger.info("WandB run finished successfully")


def ensure_pp_loss_visible(
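The enhanced logger only picks up its optional settings from the environment, so the usual way to exercise this code path is to export the variables before launch. A minimal sketch; the entity, run name, tag, and notes values are placeholders, not part of this change:

    import os

    os.environ["WANDB_PROJECT"] = "titan-dion-8b"          # falls back to "torchtitan" when unset
    os.environ["WANDB_ENTITY"] = "my-team"                  # optional team/entity (placeholder)
    os.environ["WANDB_NAME"] = "deepseek16b-muon-run1"      # optional run name (placeholder)
    os.environ["WANDB_TAGS"] = "muon, deepseek_v3, 16b"     # optional; comma-separated, whitespace is stripped
    os.environ["WANDB_NOTES"] = "Muon optimizer baseline"   # optional free-form notes (placeholder)
    # The error handling above suggests `wandb login --relogin` on auth failures,
    # or running without uploads:
    os.environ.setdefault("WANDB_MODE", "offline")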
187 changes: 183 additions & 4 deletions torchtitan/components/optimizer.py
@@ -18,9 +18,49 @@
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.config import Optimizer as OptimizerConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

# Dion optimizer availability will be checked lazily when needed
DION_AVAILABLE = None
MUON_AVAILABLE = None


def _check_dion_availability():
"""Lazy check for Dion optimizer availability."""
global DION_AVAILABLE
if DION_AVAILABLE is None:
try:
from torchtitan.experiments.dion_optimizer.dion import (
Dion,
DionMixedPrecisionConfig,
)
from torchtitan.experiments.dion_optimizer.titan_dion import (
DionOptimizersContainer,
)

DION_AVAILABLE = True
except ImportError:
DION_AVAILABLE = False
return DION_AVAILABLE


def _check_muon_availability():
"""Lazy check for Muon optimizer availability."""
global MUON_AVAILABLE
if MUON_AVAILABLE is None:
try:
from torchtitan.experiments.dion_optimizer.muon import Muon
from torchtitan.experiments.dion_optimizer.titan_muon import (
MuonOptimizersContainer,
)

MUON_AVAILABLE = True
except ImportError:
MUON_AVAILABLE = False
return MUON_AVAILABLE


__all__ = [
"OptimizersContainer",
"build_optimizers",
@@ -249,8 +289,8 @@ def build_optimizers(
This function creates a ``OptimizersContainer`` for the given model parts.
``optimizer_config`` should define the correct optimizer name and parameters.
This function currently supports creating ``OptimizersContainer`` and
``OptimizersInBackwardContainer``.
This function currently supports creating ``OptimizersContainer``,
``OptimizersInBackwardContainer``, ``DionOptimizersContainer``, and
``MuonOptimizersContainer``.
**Note**
Users who want to customize the optimizer behavior can create their own
@@ -264,6 +304,146 @@
parallel_dims (ParallelDims): Parallel dimensions for the model.
"""
optim_in_bwd = optimizer_config.early_step_in_backward
name = optimizer_config.name

# Handle Dion optimizer
if name == "Dion":
if not _check_dion_availability():
raise ImportError(
"Dion optimizer is not available. Please ensure the dion optimizer files are present in "
"torchtitan/experiments/dion_optimizer/"
)

if optim_in_bwd:
raise NotImplementedError(
"Dion optimizer does not support early step in backward."
)

if ft_manager and ft_manager.enabled:
raise NotImplementedError(
"TorchFT is not yet supported with Dion optimizer."
)

# Import the DionOptimizerConfig and DionOptimizersContainer from titan_dion
from torchtitan.experiments.dion_optimizer.titan_dion import (
DionOptimizerConfig,
DionOptimizersContainer,
)

# Create DionOptimizerConfig from optimizer_config
dion_config = DionOptimizerConfig(
name="dion",
lr=optimizer_config.lr,
weight_decay=optimizer_config.weight_decay,
mu=optimizer_config.mu,
betas=(optimizer_config.beta1, optimizer_config.beta2),
epsilon=optimizer_config.eps,
rank_fraction=optimizer_config.rank_fraction,
rank_multiple_of=optimizer_config.rank_multiple_of,
power_iters=optimizer_config.power_iters,
qr_method=optimizer_config.qr_method,
cqr_warmup_steps=optimizer_config.cqr_warmup_steps,
rcqr_oversample=optimizer_config.rcqr_oversample,
algorithm=optimizer_config.algorithm,
replicate_mesh_grad_sync=optimizer_config.replicate_mesh_grad_sync,
# Parameter-specific optimizer selection
scalar_optimizer=getattr(optimizer_config, "scalar_optimizer", "adamw"),
embedding_optimizer=getattr(
optimizer_config, "embedding_optimizer", "adamw"
),
head_optimizer=getattr(optimizer_config, "head_optimizer", "adamw"),
routing_optimizer=getattr(optimizer_config, "routing_optimizer", "adamw"),
expert_optimizer=getattr(optimizer_config, "expert_optimizer", None),
# Additional optimizer options
head_lr_scaling=getattr(optimizer_config, "head_lr_scaling", True),
# Learning rate scaling factors
scalar_lr_factor=getattr(optimizer_config, "scalar_lr_factor", 1.0),
embedding_lr_factor=getattr(optimizer_config, "embedding_lr_factor", 1.0),
head_lr_factor=getattr(optimizer_config, "head_lr_factor", 1.0),
routing_lr_factor=getattr(optimizer_config, "routing_lr_factor", 1.0),
expert_lr_factor=getattr(optimizer_config, "expert_lr_factor", 1.0),
)

# Set mixed precision dtypes if specified
if optimizer_config.momentum_dtype:
dion_config.momentum_dtype = TORCH_DTYPE_MAP[
optimizer_config.momentum_dtype
]
if optimizer_config.Q_dtype:
dion_config.Q_dtype = TORCH_DTYPE_MAP[optimizer_config.Q_dtype]
if optimizer_config.variance_dtype:
dion_config.variance_dtype = TORCH_DTYPE_MAP[
optimizer_config.variance_dtype
]

return DionOptimizersContainer(
model_parts=model_parts,
dion_config=dion_config,
parallel_dims=parallel_dims,
)

# Handle Muon optimizer
if name == "Muon":
if not _check_muon_availability():
raise ImportError(
"Muon optimizer is not available. Please ensure the muon optimizer files are present in "
"torchtitan/experiments/dion_optimizer/"
)

if optim_in_bwd:
raise NotImplementedError(
"Muon optimizer does not support early step in backward."
)

if ft_manager and ft_manager.enabled:
raise NotImplementedError(
"TorchFT is not yet supported with Muon optimizer."
)

# Import the MuonOptimizerConfig and MuonOptimizersContainer from titan_muon
from torchtitan.experiments.dion_optimizer.titan_muon import (
MuonOptimizerConfig,
MuonOptimizersContainer,
)

# Create MuonOptimizerConfig from optimizer_config
muon_config = MuonOptimizerConfig(
name="muon",
lr=optimizer_config.lr,
weight_decay=optimizer_config.weight_decay,
mu=optimizer_config.mu,
betas=(optimizer_config.beta1, optimizer_config.beta2),
epsilon=optimizer_config.eps,
nesterov=getattr(optimizer_config, "nesterov", False),
adjust_lr=getattr(optimizer_config, "adjust_lr", "spectral_norm"),
flatten=getattr(optimizer_config, "flatten", False),
use_triton=getattr(optimizer_config, "use_triton", False),
algorithm=optimizer_config.algorithm,
# Parameter-specific optimizer selection
scalar_optimizer=getattr(optimizer_config, "scalar_optimizer", "adamw"),
embedding_optimizer=getattr(
optimizer_config, "embedding_optimizer", "adamw"
),
head_optimizer=getattr(optimizer_config, "head_optimizer", "adamw"),
routing_optimizer=getattr(optimizer_config, "routing_optimizer", "adamw"),
expert_optimizer=getattr(optimizer_config, "expert_optimizer", None),
# Additional optimizer options
head_lr_scaling=getattr(optimizer_config, "head_lr_scaling", True),
# Learning rate scaling factors
scalar_lr_factor=getattr(optimizer_config, "scalar_lr_factor", 1.0),
embedding_lr_factor=getattr(optimizer_config, "embedding_lr_factor", 1.0),
head_lr_factor=getattr(optimizer_config, "head_lr_factor", 1.0),
routing_lr_factor=getattr(optimizer_config, "routing_lr_factor", 1.0),
expert_lr_factor=getattr(optimizer_config, "expert_lr_factor", 1.0),
)

return MuonOptimizersContainer(
model_parts=model_parts,
muon_config=muon_config,
parallel_dims=parallel_dims,
)

# Handle standard optimizers (Adam, AdamW)
if optim_in_bwd:
if parallel_dims.ep_enabled:
raise NotImplementedError(
Expand All @@ -278,7 +458,6 @@ def build_optimizers(
"TorchFT is not supported with optimizers in backward."
)

name = optimizer_config.name
lr = optimizer_config.lr
beta1 = optimizer_config.beta1
beta2 = optimizer_config.beta2
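For orientation, the Muon branch above reads the following fields off the optimizer config: direct attribute accesses must exist, while the getattr-guarded ones fall back to the defaults shown in the diff. A stand-in config with placeholder values (not the actual torchtitan Optimizer dataclass) would look like this:

    from types import SimpleNamespace

    # Field names mirror the attribute accesses in the Muon branch; values are hypothetical.
    optimizer_config = SimpleNamespace(
        name="Muon",
        early_step_in_backward=False,  # the Muon path raises if this is True
        lr=2e-2,
        weight_decay=0.1,
        mu=0.95,                       # momentum coefficient
        beta1=0.9,
        beta2=0.95,
        eps=1e-8,
        algorithm="muon",
        # getattr-guarded fields; omitting them yields the defaults used above:
        # nesterov=False, adjust_lr="spectral_norm", flatten=False, use_triton=False,
        # scalar_optimizer="adamw", embedding_optimizer="adamw", head_optimizer="adamw",
        # routing_optimizer="adamw", expert_optimizer=None, head_lr_scaling=True,
        # scalar_lr_factor=1.0, embedding_lr_factor=1.0, head_lr_factor=1.0,
        # routing_lr_factor=1.0, expert_lr_factor=1.0,
    )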
15 changes: 15 additions & 0 deletions torchtitan/components/simple_test.py
@@ -0,0 +1,15 @@
# Import Dion optimizer components
import torch

try:
from torchtitan.experiments.dion_optimizer.dion import (
Dion,
DionMixedPrecisionConfig,
)
from torchtitan.experiments.dion_optimizer.titan_dion import DionOptimizersContainer

DION_AVAILABLE = True
print("✓ Dion optimizer components imported")
except ImportError:
DION_AVAILABLE = False
print("✗ Dion optimizer components not available")