From 6051d0f5efba9cf1adff61bae754684ac4d671e8 Mon Sep 17 00:00:00 2001
From: l45k
Date: Tue, 16 Sep 2025 08:34:45 +0200
Subject: [PATCH] feat: nesterov on parameter server

Co-Authored-By: Gemini
---
 crates/messages/src/lib.rs                    |  10 +-
 crates/scheduler/src/bin/hypha-scheduler.rs   |   4 +
 crates/worker/src/executor/mod.rs             |   6 +
 .../worker/src/executor/parameter_server.rs   | 300 ++++++++++++------
 .../hypha-accelerate-driver/src/training.py   |  49 ++-
 5 files changed, 258 insertions(+), 111 deletions(-)

diff --git a/crates/messages/src/lib.rs b/crates/messages/src/lib.rs
index e974372..00d6ad0 100644
--- a/crates/messages/src/lib.rs
+++ b/crates/messages/src/lib.rs
@@ -370,7 +370,11 @@ pub enum Executor {
         config: DiLoCoConfig,
     },
     #[serde(rename = "parameter-server")]
-    ParameterServer { updates: Receive, results: Send },
+    ParameterServer {
+        updates: Receive,
+        results: Send,
+        optimizer: Optimizer,
+    },
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -381,9 +385,9 @@ pub enum Optimizer {
         betas: Option<[f64; 2]>,
         epsilon: Option<f64>,
     },
-    Sgd {
+    Nesterov {
         learning_rate: f64,
-        momentum: Option<f64>,
+        momentum: f64,
     },
 }
 
diff --git a/crates/scheduler/src/bin/hypha-scheduler.rs b/crates/scheduler/src/bin/hypha-scheduler.rs
index 065e4cd..f30311c 100644
--- a/crates/scheduler/src/bin/hypha-scheduler.rs
+++ b/crates/scheduler/src/bin/hypha-scheduler.rs
@@ -267,6 +267,10 @@ async fn run(config: ConfigWithMetadata) -> Result<()> {
                 executor: Executor::ParameterServer {
                     updates: Receive::peers(worker_ids.clone()),
                     results: Send::peers(worker_ids.clone(), SelectionStrategy::All),
+                    optimizer: Optimizer::Nesterov {
+                        learning_rate: 0.7,
+                        momentum: 0.9,
+                    },
                 },
             })
             .await
diff --git a/crates/worker/src/executor/mod.rs b/crates/worker/src/executor/mod.rs
index 37b0246..eac969f 100644
--- a/crates/worker/src/executor/mod.rs
+++ b/crates/worker/src/executor/mod.rs
@@ -10,6 +10,8 @@ mod process;
 pub use parameter_server::ParameterServerExecutor;
 pub use process::ProcessExecutor;
 
+use crate::executor::parameter_server::TensorOpError;
+
 #[derive(Error, Debug)]
 pub enum Error {
     #[error("Bridge error")]
@@ -19,6 +21,10 @@ pub enum Error {
     Io(#[from] std::io::Error),
     #[error("Unsupported job spec")]
     UnsupportedJobSpec(),
+    #[error("Unsupported optimizer")]
+    UnsupportedOptimizer(),
+    #[error("Tensor error")]
+    Tensor(#[from] TensorOpError),
 }
 
 pub trait JobExecutor {
diff --git a/crates/worker/src/executor/parameter_server.rs b/crates/worker/src/executor/parameter_server.rs
index 3664712..36cfc54 100644
--- a/crates/worker/src/executor/parameter_server.rs
+++ b/crates/worker/src/executor/parameter_server.rs
@@ -1,9 +1,13 @@
 use std::{
-    collections::HashMap, fs::Permissions, os::unix::fs::PermissionsExt, path::PathBuf, pin::Pin,
+    collections::HashMap,
+    fs::Permissions,
+    os::unix::fs::PermissionsExt,
+    path::{Path, PathBuf},
+    pin::Pin,
 };
 
 use candle_core::{
-    Device,
+    Device, Tensor,
     safetensors::{Load, MmapedSafetensors},
 };
 use futures_util::StreamExt;
@@ -26,6 +30,18 @@ use crate::{
     network::Network,
 };
 
+#[derive(Debug, thiserror::Error)]
+pub enum TensorOpError {
+    #[error("Candle core error: {0}")]
+    Candle(#[from] candle_core::Error),
+
+    #[error("Safetensors error: {0}")]
+    Safetensor(#[from] safetensors::SafeTensorError),
+
+    #[error("I/O error: {0}")]
+    Io(#[from] std::io::Error),
+}
+
 pub struct ParameterServerExecutor {
     connector: Connector,
     work_dir_base: PathBuf,
@@ -67,11 +83,23 @@ impl JobExecutor for ParameterServerExecutor {
 
         let device = Device::Cpu;
 
-        let (updates, results) = match job.executor {
-            hypha_messages::Executor::ParameterServer { updates, results } => (updates, results),
+        let (updates, results, optimizer) = match job.executor {
+            hypha_messages::Executor::ParameterServer {
+                updates,
+                results,
+                optimizer,
+            } => (updates, results, optimizer),
             _ => return Err(Error::UnsupportedJobSpec()),
         };
 
+        let (learning_rate, momentum) = match optimizer {
+            hypha_messages::Optimizer::Nesterov {
+                learning_rate,
+                momentum,
+            } => (learning_rate, momentum),
+            _ => return Err(Error::UnsupportedOptimizer()),
+        };
+
         let connector = self.connector.clone();
 
         let task_tracker = TaskTracker::new();
@@ -153,6 +181,7 @@ impl JobExecutor for ParameterServerExecutor {
 
                         // NOTE: Prefer continuing with the last good aggregation if averaging fails.
                         // Borrow current aggregation state to avoid moving out of `result_tensor`.
+                        tracing::info!("Received file {:?}", name);
                         let result_tensor_file_name = match current_result_tensor_file_name {
                             None => {
                                 // Create a copy of the initial parameters to avoid overwriting mmaped files.
@@ -164,96 +193,30 @@ impl JobExecutor for ParameterServerExecutor {
                                 temporary_tensor_file_name
                             },
                             Some(result_tensor_file_name) => {
-                                tokio::task::spawn_blocking({
-                                    let device = device.clone();
                                     let work_dir = work_dir.clone();
-                                    move || {
-                                        // Average the new tensor with an existing one
-                                        match unsafe { MmapedSafetensors::new(file_name) } {
-                                            Ok(new_tensor) => {
-                                                if let Err(e) = std::fs::create_dir_all(work_dir.join("avg").join(name.as_str())) {
-                                                    tracing::warn!(error = ?e, "Failed to create avg directory; keeping previous aggregation");
-                                                    result_tensor_file_name.clone()
-                                                } else {
-                                                    let paths = {
-                                                        let result_tensor = unsafe { MmapedSafetensors::new(result_tensor_file_name.as_path()) }.expect("tensor file is readable");
-
-                                                        let mut paths = Vec::new();
-                                                        for (tensor_name, tensor) in result_tensor.tensors() {
-                                                            // Load tensors; skip this tensor on failure
-                                                            let current_tensor = match tensor.load(&device) {
-                                                                Ok(t) => t,
-                                                                Err(e) => {
-                                                                    tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to load current tensor; skipping");
-                                                                    continue;
-                                                                }
-                                                            };
-                                                            let other_tensor = match new_tensor.load(&tensor_name, &device) {
-                                                                Ok(t) => t,
-                                                                Err(e) => {
-                                                                    tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to load new tensor; skipping");
-                                                                    continue;
-                                                                }
-                                                            };
+                                let tmp_dir_name = work_dir.join("tmp_join");
+                                let temporary_tensor_file_name = work_dir.join(format!("tmp_{:?}", Uuid::new_v4()));
 
-                                                            let avg_tensor = match (current_tensor + other_tensor).and_then(|t| t / 2.) {
-                                                                Ok(t) => t,
-                                                                Err(e) => {
-                                                                    tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to average tensors; skipping");
-                                                                    continue;
-                                                                }
-                                                            };
-
-                                                            let avg_tensor_file_name = work_dir.join("avg").join(name.as_str()).join(&tensor_name);
-                                                            if let Err(e) = candle_core::safetensors::save(
-                                                                &HashMap::from([(tensor_name.clone(), avg_tensor)]),
-                                                                avg_tensor_file_name.as_path(),
-                                                            ) {
-                                                                tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to save averaged tensor; skipping");
-                                                                continue;
-                                                            }
-                                                            paths.push(avg_tensor_file_name);
-                                                        }
-
-                                                        paths
-                                                    };
-
-                                                    if paths.is_empty() {
-                                                        tracing::warn!("No tensors averaged; keeping previous aggregation");
-                                                    } else if let Ok(all_tensors) = unsafe { MmapedSafetensors::multi(&paths) } {
-                                                        // If at this point we're still mmapping 'result_tensor_file_name', we could have undefined behaviour!
-                                                        // Make sure that we don't!
-                                                        match serialize_to_file(
-                                                            all_tensors.tensors(),
-                                                            &None,
-                                                            result_tensor_file_name.as_path(),
-                                                        ) {
-                                                            Ok(_) => {
-                                                                tracing::debug!("Updated averaged tensors");
-                                                            }
-                                                            Err(e) => {
-                                                                tracing::warn!(error = ?e, "Failed to serialize averaged tensors; keeping previous aggregation");
-                                                            }
-                                                        }
-                                                    } else {
-                                                        tracing::warn!("Failed to mmap averaged tensors; keeping previous aggregation");
-                                                    }
-
-                                                    if let Err(e) = std::fs::remove_dir_all(work_dir.join("avg").join(name.as_str())) {
-                                                        tracing::warn!(error = ?e, "Failed to cleanup avg directory");
-                                                    }
-
-                                                    result_tensor_file_name.clone()
-                                                }
-                                            }
-                                            Err(e) => {
-                                                tracing::warn!(error = ?e, "Failed to mmap new tensor; keeping previous aggregation");
-
-                                                // Keep previous aggregation file
-                                                result_tensor_file_name.clone()
-                                            }
+                                // Average the new tensor with the existing aggregation.
+                                let average_op = |a: &Tensor, b: &Tensor| {
+                                    // Compute (a + b) / 2.
+                                    (a + b).and_then(|t| t / 2.)
+                                };
+                                if let Err(e) = apply_tensor_op(
+                                    &file_name,
+                                    &result_tensor_file_name,
+                                    &temporary_tensor_file_name,
+                                    &tmp_dir_name,
+                                    &device,
+                                    average_op,
+                                )
+                                .await
+                                {
+                                    tracing::warn!(error = ?e, "Failed to average results; keeping previous aggregation");
+                                } else {
+                                    fs::copy(temporary_tensor_file_name.clone(), result_tensor_file_name.clone()).await.expect("file can be copied");
+                                    fs::remove_file(temporary_tensor_file_name).await.expect("temporary average file can be removed");
                                 }
-                        }}).await.expect("averaging tasks runs to completion")
+
+                                result_tensor_file_name.clone()
                             }
                         };
 
@@ -270,6 +233,9 @@ impl JobExecutor for ParameterServerExecutor {
             let final_tensor_file_name = work_dir.join("avg-final");
             fs::rename(result_file_name.as_path(), final_tensor_file_name.as_path()).await.expect("tensor file can be renamed");
 
+            // Apply the outer optimization step (Nesterov momentum on the averaged pseudo-gradient).
+            let gradient_file = nesterov(final_tensor_file_name.clone(), work_dir.clone(), &device, momentum, learning_rate).await.expect("nesterov update succeeds");
+
             // These need to be reset before sending out the result!
             current_result_tensor_file_name = None;
             current_worker = 0;
@@ -281,7 +247,7 @@ impl JobExecutor for ParameterServerExecutor {
                     Ok(item) => {
                         tracing::info!(peer_id = item.meta.name, "Sending parameter server update");
 
-                        match fs::File::open(final_tensor_file_name.as_path()).await {
+                        match fs::File::open(gradient_file.as_path()).await {
                             Ok(mut file) => {
                                 let mut writer = item.writer.compat_write();
 
@@ -313,7 +279,8 @@ impl JobExecutor for ParameterServerExecutor {
                 }
             }
 
-            fs::remove_file(final_tensor_file_name.as_path()).await.expect("tensor file can be removed");
+            fs::remove_file(gradient_file.as_path()).await.expect("gradient file can be removed");
+            fs::remove_file(final_tensor_file_name.as_path()).await.expect("tensor file can be removed");
         }
     }
 
@@ -336,3 +303,150 @@ impl JobExecutor for ParameterServerExecutor {
         Ok(ParameterServerExecution { task_tracker })
     }
 }
+
+/// Applies a binary operation to corresponding tensors from two safetensor files.
+///
+/// This function memory-maps two input safetensor files, iterates through the tensors
+/// of the first file, finds the corresponding tensor by name in the second file,
+/// and applies the provided operation `op` to the pair. The resulting tensors
+/// are saved into `temp_path` and then combined into a single result file. Only
+/// two tensors are held in memory at the same time.
+///
+/// # Arguments
+/// * `file_a_path` - Path to the first safetensor file.
+/// * `file_b_path` - Path to the second safetensor file.
+/// * `output_path` - Path where the resulting safetensor file will be saved.
+/// * `temp_path` - Path to a temporary directory where the intermediate safetensor files will be saved.
+/// * `device` - The Candle device to perform computations on (e.g., `Device::Cpu`).
+/// * `op` - A closure that takes two tensors and returns a new tensor.
+///
+/// # Returns
+/// A `Result` indicating success or a `TensorOpError` on failure.
+///
+/// # Note
+/// Tensors present in the second file but not the first are ignored. If a tensor
+/// from the first file is not found in the second, it is skipped with a warning.
+/// The `temp_path` directory is created and deleted by this function, so make sure it
+/// does not point to an existing directory that contains important data.
+/// The closure receives the tensors in the same order as the files: its first argument
+/// comes from `file_a_path`, its second from `file_b_path`.
+async fn apply_tensor_op<F>(
+    file_a_path: &Path,
+    file_b_path: &Path,
+    output_path: &Path,
+    temp_path: &Path,
+    device: &Device,
+    op: F,
+) -> Result<(), TensorOpError>
+where
+    F: Fn(&Tensor, &Tensor) -> candle_core::Result<Tensor>,
+{
+    // 1. Open both safetensor files in a memory-mapped way.
+    // SAFETY: The MmapedSafetensors::new function is unsafe because it assumes
+    // the underlying file will not be modified while the memory map is active.
+    let tensors_a = unsafe { candle_core::safetensors::MmapedSafetensors::new(file_a_path)? };
+    let tensors_b = unsafe { candle_core::safetensors::MmapedSafetensors::new(file_b_path)? };
+    fs::create_dir_all(temp_path).await?;
+
+    let mut result_tensors = Vec::new();
+
+    // 2. Iterate through each tensor in the first file.
+    for (name, tensor_view) in tensors_a.tensors() {
+        // Try to load the corresponding tensor from the second file.
+        match tensors_b.load(&name, device) {
+            Ok(tensor_b) => {
+                // If found, load the tensor from the first file.
+                let tensor_a = tensor_view.load(device)?;
+
+                // 3. Apply the provided computation function and serialize the result to disk.
+                let result_tensor = op(&tensor_a, &tensor_b)?;
+                let result_path = temp_path.join(name.clone());
+                candle_core::safetensors::save(
+                    &HashMap::from([(name, result_tensor)]),
+                    result_path.clone(),
+                )?;
+                result_tensors.push(result_path);
+            }
+            Err(_) => {
+                // If a tensor from file A doesn't exist in file B, skip it.
+                tracing::warn!("Tensor '{}' not found in second file, skipping.", name);
+                continue;
+            }
+        }
+    }
+
+    // 4. Write all result tensors to the new file.
+    if result_tensors.is_empty() {
+        tracing::warn!("No matching tensors found to process.");
+    } else {
+        let all_tensors = unsafe { MmapedSafetensors::multi(&result_tensors)? };
+        serialize_to_file(all_tensors.tensors(), &None, output_path)?;
+    }
+
+    fs::remove_dir_all(temp_path).await?;
+
+    Ok(())
+}
+
+async fn update_momentum(
+    work_dir: PathBuf,
+    gradient_file_name: &Path,
+    device: &Device,
+    momentum: f64,
+) -> Result<PathBuf, TensorOpError> {
+    // If we are in the first round, we need to initialize the momentum with the gradient
+    let momentum_file = work_dir.join("momentum");
+    if fs::metadata(momentum_file.clone()).await.is_err() {
+        fs::copy(gradient_file_name, momentum_file.clone()).await?;
+    } else {
+        let tmp_dir_name = work_dir.join("tmp_momentum");
+        let momentum_update_file = work_dir.join("momentum_update");
+        let momentum_op = |g: &Tensor, m: &Tensor| {
+            // Compute: momentum * m + g
+            (momentum * m).and_then(|t| t + g)
+        };
+        apply_tensor_op(
+            gradient_file_name,
+            &momentum_file,
+            &momentum_update_file,
+            &tmp_dir_name,
+            device,
+            momentum_op,
+        )
+        .await?;
+        fs::copy(momentum_update_file, momentum_file.clone()).await?;
+    }
+    Ok(momentum_file)
+}
+
+async fn nesterov(
+    gradient_file: PathBuf,
+    work_dir: PathBuf,
+    device: &Device,
+    momentum: f64,
+    learning_rate: f64,
+) -> Result<PathBuf, TensorOpError> {
+    let momentum_file = update_momentum(work_dir.clone(), &gradient_file, device, momentum).await?;
+
+    let tmp_dir_name = work_dir.join("tmp_nesterov");
+    let result_gradient_name = work_dir.join("gradient_update");
+
+    let nesterov_op = |g: &Tensor, m: &Tensor| {
+        // Compute: learning_rate * ((momentum * m) + g)
+        (momentum * m)
+            .and_then(|t| t + g)
+            .and_then(|t| learning_rate * t)
+    };
+    apply_tensor_op(
+        &gradient_file,
+        &momentum_file,
+        &result_gradient_name,
+        &tmp_dir_name,
+        device,
+        nesterov_op,
+    )
+    .await?;
+
+    Ok(result_gradient_name)
+}
diff --git a/drivers/hypha-accelerate-driver/src/training.py b/drivers/hypha-accelerate-driver/src/training.py
index 6905f92..ee43e8e 100644
--- a/drivers/hypha-accelerate-driver/src/training.py
+++ b/drivers/hypha-accelerate-driver/src/training.py
@@ -10,7 +10,7 @@
 import torch.utils.data
 
 from accelerate import Accelerator
 from safetensors import safe_open
-from safetensors.torch import load, load_file, save_model
+from safetensors.torch import load, load_file, save_file, save_model
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -189,15 +189,23 @@ def get_scheduler(type: str, args: dict[str, float], optmizer: Optimizer) -> tor
     raise RuntimeError(f"Learning rate Scheduler {type} not supported")
 
 
-def merge_models(origin: Module, weight_path: str, alpha: float) -> Module:
-    # All weights need to be on CPU
-    origin.to("cpu")
-    state_dict = origin.state_dict()
-    with safe_open(weight_path, framework="pt", device="cpu") as b:  # type: ignore
-        for name in b.keys():  # noqa: SIM118
-            state_dict[name] += (alpha * (b.get_tensor(name) - state_dict[name])).to(state_dict[name].dtype)
-    origin.load_state_dict(state_dict)
-    return origin
+def merge_models(old_model: str, weight_path: str) -> dict[str, torch.Tensor]:
+    state_dict: dict[str, torch.Tensor] = {}
+    with (
+        safe_open(weight_path, framework="pt", device="cpu") as g,  # type: ignore[no-untyped-call]
+        safe_open(old_model, framework="pt", device="cpu") as m,  # type: ignore[no-untyped-call]
+    ):
+        for name in m.keys():  # noqa: SIM118
+            state_dict[name] = m.get_tensor(name) - g.get_tensor(name)
+    return state_dict
+
+
+def extract_gradients(state_dict: dict[str, torch.Tensor], previous_model_path: str) -> dict[str, torch.Tensor]:
+    with safe_open(previous_model_path, framework="pt", device="cpu") as p:  # type: ignore[no-untyped-call]
+        for name in p.keys():  # noqa: SIM118
+            # Pseudo-gradient: previous global weights minus the locally trained weights.
+            state_dict[name] = p.get_tensor(name).to(state_dict[name].dtype) - state_dict[name]
+    return state_dict
 
 
 def dataset_wrapper(dataset: torch.utils.data.DataLoader) -> Iterator[dict[str, torch.Tensor]]:  # type: ignore[type-arg]
@@ -248,6 +256,10 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
         job_spec["config"]["batch_size"],
     )
 
+    # Serialize the initial global model to disk; pseudo-gradients are computed against this file.
+    previous_model_path = os.path.join(work_dir, "0_global_weights.pt")
+    save_model(model, previous_model_path)
+
     model, optimizer, training_dataloader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)
     training_data_iter = dataset_wrapper(training_dataloader)
     run_epoch = get_training_loop(job_spec["config"], training_data_iter)
@@ -276,9 +288,11 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
                     rel_path = parameters.get("path") if parameters else latest.get("path")
                     if isinstance(rel_path, str):
                         path = os.path.join(work_dir, rel_path)
-                        base_model = accelerator.unwrap_model(model)
-                        merge_models(base_model, path, 0.9)
-                        model = accelerator.prepare(base_model)
+                        # Apply the outer update to the previous global weights.
+                        model.load_state_dict(merge_models(previous_model_path, path))
+                        # Overwrite the stored previous global model.
+                        save_model(model, previous_model_path)
+                        model = accelerator.prepare(model)
                         print("Weights updated from", rel_path, flush=True)
             except Exception as e:
                 print(f"pointer handling error: {e}")
@@ -293,10 +307,15 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
 
         if accelerator.is_main_process and epoch % job_spec["config"]["checkpointing"] == 0:
             # For testing purposes set to global!
-            file_name = f"{epoch}_global_weights.pt"
+            file_name = f"{epoch}_local_gradients.pt"
            result_path = os.path.join(work_dir, file_name)
             # Save unwrapped model to avoid accelerator wrappers interfering
-            save_model(accelerator.unwrap_model(model), result_path)
+            model = accelerator.unwrap_model(model)
+            # All weights need to be on CPU
+            model.to("cpu")
+
+            save_file(extract_gradients(model.state_dict(), previous_model_path), result_path)
             session.send(job_spec["results"], file_name)
 
             # Mark that before the next training epoch we must wait for an update
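
Reviewer note (not part of the patch): the outer update this change wires together -- workers send pseudo-gradients, the parameter server averages them, folds them into a Nesterov momentum buffer, scales by the outer learning rate, and workers subtract the result from the previous global weights -- can be sanity-checked with the small, self-contained sketch below. The names `outer_step`, `mu`, and `lr` are illustrative only; the defaults mirror the scheduler configuration in this patch (learning_rate = 0.7, momentum = 0.9), and the pseudo-gradient is assumed to be "previous global weights minus local weights", the direction that merge_models expects.

    # Illustrative sketch only; plain floats stand in for tensors.
    def outer_step(prev_weight, local_weights, momentum_buf, lr=0.7, mu=0.9):
        """One parameter-server round for a single parameter."""
        # Workers send pseudo-gradients: previous global weight minus local weight.
        grads = [prev_weight - w for w in local_weights]
        # The server averages the incoming pseudo-gradients ...
        g = sum(grads) / len(grads)
        # ... folds them into the momentum buffer (initialized with g on the first round) ...
        m = g if momentum_buf is None else mu * momentum_buf + g
        # ... and produces a Nesterov-style, learning-rate-scaled update.
        update = lr * (mu * m + g)
        # Workers apply the update to the previous global weights.
        return prev_weight - update, m

    # Two workers moved a parameter from 10.0 to 8.0 and 7.0; the outer step
    # moves the global value in the same direction.
    new_weight, momentum_buf = outer_step(10.0, [8.0, 7.0], None)
    print(new_weight, momentum_buf)  # 6.675 2.5

In the Rust code the same computation is split between update_momentum (momentum * m + g, stored in the "momentum" file) and nesterov (learning_rate * (momentum * m + g)), applied tensor by tensor through apply_tensor_op; training.py then applies the returned update in merge_models.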