diff --git a/Cargo.lock b/Cargo.lock
index cfe9093..1d2d49c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2009,6 +2009,8 @@ dependencies = [
  "safetensors",
  "serde",
  "serde_json",
+ "sha2",
+ "tempfile",
  "thiserror 2.0.16",
  "tokio",
  "tokio-util",
diff --git a/Cargo.toml b/Cargo.toml
index f1b088b..7622cd3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,7 +30,6 @@ ciborium = "0.2.2"
 clap = { version = "4.5.31", features = ["derive"] }
 documented = "0.9.2"
 figment = { version = "0.10", features = ["toml", "env"] }
-futures-rustls = { version = "0.26.0", default-features = false }
 futures-util = "0.3"
 http-body-util = "0.1.3"
 hypha-config = { path = "crates/config" }
@@ -87,8 +86,6 @@ tracing = "0.1.41"
 tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
 utoipa = { version = "5.4.0", features = ["axum_extras", "url", "uuid"] }
 uuid = { version = "1.17.0", features = ["serde", "v4"] }
-webpki = { version = "0.103", package = "rustls-webpki", features = ["std"] }
-x509-parser = "0.16"
 
 [profile.profiling]
 inherits = "release"
diff --git a/crates/messages/src/lib.rs b/crates/messages/src/lib.rs
index e974372..00d6ad0 100644
--- a/crates/messages/src/lib.rs
+++ b/crates/messages/src/lib.rs
@@ -370,7 +370,11 @@ pub enum Executor {
         config: DiLoCoConfig,
     },
     #[serde(rename = "parameter-server")]
-    ParameterServer { updates: Receive, results: Send },
+    ParameterServer {
+        updates: Receive,
+        results: Send,
+        optimizer: Optimizer,
+    },
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -381,9 +385,9 @@ pub enum Optimizer {
         betas: Option<[f64; 2]>,
         epsilon: Option<f64>,
     },
-    Sgd {
+    Nesterov {
         learning_rate: f64,
-        momentum: Option<f64>,
+        momentum: f64,
     },
 }
diff --git a/crates/scheduler/src/bin/hypha-scheduler.rs b/crates/scheduler/src/bin/hypha-scheduler.rs
index 065e4cd..f30311c 100644
--- a/crates/scheduler/src/bin/hypha-scheduler.rs
+++ b/crates/scheduler/src/bin/hypha-scheduler.rs
@@ -267,6 +267,10 @@ async fn run(config: ConfigWithMetadata) -> Result<()> {
             executor: Executor::ParameterServer {
                 updates: Receive::peers(worker_ids.clone()),
                 results: Send::peers(worker_ids.clone(), SelectionStrategy::All),
+                optimizer: Optimizer::Nesterov {
+                    learning_rate: 0.7,
+                    momentum: 0.9,
+                },
             },
         })
         .await
diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml
index 9fd82e5..db5acaf 100644
--- a/crates/worker/Cargo.toml
+++ b/crates/worker/Cargo.toml
@@ -27,6 +27,8 @@ reqwest = { version = "0.12", default-features = false, features = ["stream", "j
 safetensors.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+sha2.workspace = true
+tempfile.workspace = true
 thiserror.workspace = true
 tokio.workspace = true
 tokio-util.workspace = true
diff --git a/crates/worker/src/executor/mod.rs b/crates/worker/src/executor/mod.rs
index 37b0246..eac969f 100644
--- a/crates/worker/src/executor/mod.rs
+++ b/crates/worker/src/executor/mod.rs
@@ -10,6 +10,8 @@ mod process;
 pub use parameter_server::ParameterServerExecutor;
 pub use process::ProcessExecutor;
 
+use crate::executor::parameter_server::TensorOpError;
+
 #[derive(Error, Debug)]
 pub enum Error {
     #[error("Bridge error")]
@@ -19,6 +21,10 @@
     Io(#[from] std::io::Error),
     #[error("Unsupported job spec")]
     UnsupportedJobSpec(),
+    #[error("Unsupported optimizer")]
+    UnsupportedOptimizer(),
+    #[error("Tensor error")]
+    Tensor(#[from] TensorOpError),
 }
 
 pub trait JobExecutor {
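Note on the message change above: the `Sgd` variant is replaced by `Optimizer::Nesterov { learning_rate, momentum }`, which the scheduler now attaches to `Executor::ParameterServer` and which selects the outer optimizer implemented in the worker further down in this diff. Assuming the usual Nesterov-SGD form (the test below compares against `torch.optim.SGD(..., nesterov=True)`), the per-round outer step applied to the averaged worker delta g is: v <- momentum * v + g, followed by update = learning_rate * (g + momentum * v); the scheduler defaults here are learning_rate = 0.7 and momentum = 0.9.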
diff --git a/crates/worker/src/executor/parameter_server.rs b/crates/worker/src/executor/parameter_server.rs
index 3664712..5c22959 100644
--- a/crates/worker/src/executor/parameter_server.rs
+++ b/crates/worker/src/executor/parameter_server.rs
@@ -1,13 +1,18 @@
 use std::{
-    collections::HashMap, fs::Permissions, os::unix::fs::PermissionsExt, path::PathBuf, pin::Pin,
+    collections::HashMap,
+    fs::Permissions,
+    os::unix::fs::PermissionsExt,
+    path::{Path, PathBuf},
+    pin::Pin,
 };
 
 use candle_core::{
-    Device,
+    Device, Tensor,
     safetensors::{Load, MmapedSafetensors},
 };
 use futures_util::StreamExt;
 use safetensors::serialize_to_file;
+use sha2::{Digest, Sha256};
 use tokio::{
     fs,
     io::{self, AsyncWriteExt},
@@ -26,6 +31,18 @@ use crate::{
     network::Network,
 };
 
+#[derive(Debug, thiserror::Error)]
+pub enum TensorOpError {
+    #[error("Candle core error: {0}")]
+    Candle(#[from] candle_core::Error),
+
+    #[error("Safetensors error: {0}")]
+    Safetensor(#[from] safetensors::SafeTensorError),
+
+    #[error("I/O error: {0}")]
+    Io(#[from] std::io::Error),
+}
+
 pub struct ParameterServerExecutor {
     connector: Connector,
     work_dir_base: PathBuf,
@@ -67,11 +84,23 @@ impl JobExecutor for ParameterServerExecutor {
 
         let device = Device::Cpu;
 
-        let (updates, results) = match job.executor {
-            hypha_messages::Executor::ParameterServer { updates, results } => (updates, results),
+        let (updates, results, optimizer) = match job.executor {
+            hypha_messages::Executor::ParameterServer {
+                updates,
+                results,
+                optimizer,
+            } => (updates, results, optimizer),
             _ => return Err(Error::UnsupportedJobSpec()),
         };
 
+        let (learning_rate, momentum) = match optimizer {
+            hypha_messages::Optimizer::Nesterov {
+                learning_rate,
+                momentum,
+            } => (learning_rate, momentum),
+            _ => return Err(Error::UnsupportedOptimizer()),
+        };
+
         let connector = self.connector.clone();
 
         let task_tracker = TaskTracker::new();
@@ -105,9 +134,11 @@ impl JobExecutor for ParameterServerExecutor {
             match item {
                 Ok(item) => {
                     let name = item.meta.name.clone();
-                    tracing::info!(peer_id = ?name, "Received parameter server update (start)");
                     let mut reader = item.reader.compat();
-                    let file_name = work_dir.join(name.clone());
+                    // Don't use the name from the metadata as it could be an arbitrary path.
+                    let hex_digest = format!("{:X}", Sha256::digest(item.meta.name));
+                    tracing::info!(peer_id = ?name, file_name = ?hex_digest, "Received parameter server update (start)");
+                    let file_name = work_dir.join(hex_digest);
 
                     match fs::File::create(&file_name).await {
                         Ok(mut file) => {
@@ -153,111 +184,39 @@
                             // NOTE: Prefer continuing with the last good aggregation if averaging fails.
                             // Borrow current aggregation state to avoid moving out of `result_tensor`.
-                            let result_tensor_file_name = match current_result_tensor_file_name {
+                            tracing::info!("Received file {:?}", name);
+                            current_result_tensor_file_name = match current_result_tensor_file_name {
                                 None => {
-                                    // Create a copy of the initial parameters to avoid overwriting mmaped files.
-                                    // In edge-cases we might receive updates while still sending out results.
-                                    // If the name of the result file would be the same as one of incoming updates,
-                                    // we would overwrite an mmapped file.
-                                    let temporary_tensor_file_name = work_dir.join("avg-temporary");
-                                    fs::copy(file_name.as_path(), temporary_tensor_file_name.as_path()).await.expect("file can be copied");
-                                    temporary_tensor_file_name
+                                    // Just use the first file as the current name, so no copy required.
+                                    Some(file_name.to_path_buf())
                                 },
                                 Some(result_tensor_file_name) => {
-                                    tokio::task::spawn_blocking({
-                                        let device = device.clone(); let work_dir = work_dir.clone();
-                                        move || {
-                                            // Average the new tensor with an existing one
-                                            match unsafe { MmapedSafetensors::new(file_name) } {
-                                                Ok(new_tensor) => {
-                                                    if let Err(e) = std::fs::create_dir_all(work_dir.join("avg").join(name.as_str())) {
-                                                        tracing::warn!(error = ?e, "Failed to create avg directory; keeping previous aggregation");
-                                                        result_tensor_file_name.clone()
-                                                    } else {
-                                                        let paths = {
-                                                            let result_tensor = unsafe { MmapedSafetensors::new(result_tensor_file_name.as_path()) }.expect("tensor file is readable");
-
-                                                            let mut paths = Vec::new();
-                                                            for (tensor_name, tensor) in result_tensor.tensors() {
-                                                                // Load tensors; skip this tensor on failure
-                                                                let current_tensor = match tensor.load(&device) {
-                                                                    Ok(t) => t,
-                                                                    Err(e) => {
-                                                                        tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to load current tensor; skipping");
-                                                                        continue;
-                                                                    }
-                                                                };
-                                                                let other_tensor = match new_tensor.load(&tensor_name, &device) {
-                                                                    Ok(t) => t,
-                                                                    Err(e) => {
-                                                                        tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to load new tensor; skipping");
-                                                                        continue;
-                                                                    }
-                                                                };
-
-                                                                let avg_tensor = match (current_tensor + other_tensor).and_then(|t| t / 2.) {
-                                                                    Ok(t) => t,
-                                                                    Err(e) => {
-                                                                        tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to average tensors; skipping");
-                                                                        continue;
-                                                                    }
-                                                                };
-
-                                                                let avg_tensor_file_name = work_dir.join("avg").join(name.as_str()).join(&tensor_name);
-                                                                if let Err(e) = candle_core::safetensors::save(
-                                                                    &HashMap::from([(tensor_name.clone(), avg_tensor)]),
-                                                                    avg_tensor_file_name.as_path(),
-                                                                ) {
-                                                                    tracing::warn!(error = ?e, tensor = %tensor_name, "Failed to save averaged tensor; skipping");
-                                                                    continue;
-                                                                }
-                                                                paths.push(avg_tensor_file_name);
-                                                            }
-
-                                                            paths
-                                                        };
-
-                                                        if paths.is_empty() {
-                                                            tracing::warn!("No tensors averaged; keeping previous aggregation");
-                                                        } else if let Ok(all_tensors) = unsafe { MmapedSafetensors::multi(&paths) } {
-                                                            // If at this point we're still mmapping 'result_tensor_file_name', we could have undefined behaviour!
-                                                            // Make sure that we don't!
-                                                            match serialize_to_file(
-                                                                all_tensors.tensors(),
-                                                                &None,
-                                                                result_tensor_file_name.as_path(),
-                                                            ) {
-                                                                Ok(_) => {
-                                                                    tracing::debug!("Updated averaged tensors");
-                                                                }
-                                                                Err(e) => {
-                                                                    tracing::warn!(error = ?e, "Failed to serialize averaged tensors; keeping previous aggregation");
-                                                                }
-                                                            }
-                                                        } else {
-                                                            tracing::warn!("Failed to mmap averaged tensors; keeping previous aggregation");
-                                                        }
-
-                                                        if let Err(e) = std::fs::remove_dir_all(work_dir.join("avg").join(name.as_str())) {
-                                                            tracing::warn!(error = ?e, "Failed to cleanup avg directory");
-                                                        }
-
-                                                        result_tensor_file_name.clone()
-                                                    }
-                                                }
-                                                Err(e) => {
-                                                    tracing::warn!(error = ?e, "Failed to mmap new tensor; keeping previous aggregation");
+                                    let tmp_dir_name = work_dir.join("tmp_join");
+                                    let resulting_tensor_file_name = work_dir.join(format!("joined_{:?}", Uuid::new_v4()));
-                                                    // Keep previous aggregation file
-                                                    result_tensor_file_name.clone()
-                                                }
+                                    // TODO: This isn't correct. We need a weighting with the number of samples processed by
+                                    // the worker. Until we have this information, let's assume we train with two workers.
+                                    // Average the new tensor with an existing one
+                                    let average_op = |a: &Tensor, b: &Tensor| {
+                                        // Compute (a + b) / 2.
+                                        (a + b).and_then(|t| t / 2.)
+                                    };
+                                    if let Err(e) = apply_tensor_op(
+                                        &file_name,
+                                        &result_tensor_file_name,
+                                        &resulting_tensor_file_name,
+                                        &tmp_dir_name,
+                                        &device,
+                                        average_op,
+                                    )
+                                    .await
+                                    {
+                                        tracing::warn!(error = ?e, "Failed to average results");
+                                    }
-                                    }}).await.expect("averaging tasks runs to completion")
+                                    Some(resulting_tensor_file_name)
                                 }
                             };
-                            current_result_tensor_file_name = Some(result_tensor_file_name);
                             current_worker += 1;
 
                             // We assume that each worker sends their parameters, then waits to receive updates.
@@ -270,6 +229,9 @@
                     let final_tensor_file_name = work_dir.join("avg-final");
                     fs::rename(result_file_name.as_path(), final_tensor_file_name.as_path()).await.expect("tensor file can be renamed");
 
+                    // Do outer optimization
+                    let gradient_file = nesterov(final_tensor_file_name.clone(), work_dir.clone(), &device, momentum, learning_rate).await.expect("nesterov");
+
                     // These need to be reset before sending out the result!
                     current_result_tensor_file_name = None;
                     current_worker = 0;
@@ -281,7 +243,7 @@
                         Ok(item) => {
                             tracing::info!(peer_id = item.meta.name, "Sending parameter server update");
 
-                            match fs::File::open(final_tensor_file_name.as_path()).await {
+                            match fs::File::open(gradient_file.as_path()).await {
                                 Ok(mut file) => {
                                     let mut writer = item.writer.compat_write();
 
@@ -313,7 +275,8 @@
                             }
                         }
 
-                    fs::remove_file(final_tensor_file_name.as_path()).await.expect("tensor file can be removed");
+                    fs::remove_file(gradient_file.as_path()).await.expect("gradient file can be removed");
+                    fs::remove_file(final_tensor_file_name.as_path()).await.expect("tensor file can be removed")
                 }
             }
         };
@@ -336,3 +299,224 @@
         Ok(ParameterServerExecution { task_tracker })
     }
 }
+
+/// Applies a binary operation to corresponding tensors from two safetensor files.
+///
+/// This function memory-maps two input safetensor files, iterates through the tensors
+/// of the first file, finds the corresponding tensor by name in the second file,
+/// and applies the provided operation `op` to the pair. The resulting tensors
+/// are saved into `temp_path` and combined into a single result file. Only
+/// two tensors will be held in memory at the same time.
+///
+/// # Arguments
+/// * `file_a_path` - Path to the first safetensor file.
+/// * `file_b_path` - Path to the second safetensor file.
+/// * `output_path` - Path where the resulting safetensor file will be saved.
+/// * `temp_path` - Path to a temporary directory where the intermediate safetensor files will be saved.
+/// * `device` - The Candle device to perform computations on (e.g., `Device::Cpu`).
+/// * `op` - A closure that takes two tensors and returns a new tensor.
+///
+/// # Returns
+/// A `Result` indicating success or a `TensorOpError` on failure.
+///
+/// # Note
+/// Tensors present in the second file but not the first are ignored. If a tensor
+/// from the first file is not found in the second, it is skipped with a warning.
+/// The `temp_path` will be created and deleted by the function. Make sure it doesn't
+/// point to an existing directory that contains important data.
+/// The tensors are passed to `op` in the same order as the files: the first argument
+/// comes from `file_a_path`, the second from `file_b_path`.
+async fn apply_tensor_op<F>(
+    file_a_path: &Path,
+    file_b_path: &Path,
+    output_path: &Path,
+    temp_path: &Path,
+    device: &Device,
+    op: F,
+) -> Result<(), TensorOpError>
+where
+    F: Fn(&Tensor, &Tensor) -> Result<Tensor, candle_core::Error>,
+{
+    // 1. Open both safetensor files in a memory-mapped way.
+    // SAFETY: The MmapedSafetensors::new function is unsafe because it assumes
+    // the underlying file will not be modified while the memory map is active.
+    let tensors_a = unsafe { candle_core::safetensors::MmapedSafetensors::new(file_a_path)? };
+    let tensors_b = unsafe { candle_core::safetensors::MmapedSafetensors::new(file_b_path)? };
+
+    let mut result_tensors = Vec::new();
+
+    // 2. Iterate through each tensor in the first file.
+    for (name, tensor_view) in tensors_a.tensors() {
+        // Try to load the corresponding tensor from the second file.
+        match tensors_b.load(&name, device) {
+            Ok(tensor_b) => {
+                // If found, load the tensor from the first file.
+                let tensor_a = tensor_view.load(device)?;
+
+                // 3. Apply the provided computation function and serialize the result to disk.
+                let result_tensor = op(&tensor_a, &tensor_b)?;
+                let result_path = temp_path.join(format!("{:X}", Sha256::digest(name.clone())));
+                candle_core::safetensors::save(
+                    &HashMap::from([(name, result_tensor)]),
+                    result_path.clone(),
+                )?;
+                result_tensors.push(result_path);
+            }
+            Err(_) => {
+                // If a tensor from file A doesn't exist in file B, skip it.
+                tracing::warn!("Tensor '{}' not found in second file, skipping.", name);
+                continue;
+            }
+        }
+    }
+
+    // 4. Write all result tensors to the new file.
+    if result_tensors.is_empty() {
+        tracing::warn!("No matching tensors found to process.");
+    } else {
+        let all_tensors = unsafe { MmapedSafetensors::multi(&result_tensors)? };
+        serialize_to_file(all_tensors.tensors(), &None, output_path)?;
+    }
+
+    Ok(())
+}
+
+async fn update_momentum(
+    work_dir: PathBuf,
+    gradient_file_name: &Path,
+    device: &Device,
+    momentum: f64,
+) -> Result<PathBuf, TensorOpError> {
+    // If we are in the first round, we need to initialize the momentum with the gradient
+    let momentum_file = work_dir.join("momentum");
+    if fs::metadata(momentum_file.clone()).await.is_err() {
+        fs::copy(gradient_file_name.to_path_buf(), momentum_file.clone())
+            .await
+            .expect("copy gradients to momentum");
+    } else {
+        let momentum_update_file = work_dir.join("momentum_update");
+        let momentum_op = |g: &Tensor, m: &Tensor| {
+            // Compute: momentum * m + g
+            (momentum * m).and_then(|t| t + g)
+        };
+        apply_tensor_op(
+            gradient_file_name,
+            &momentum_file,
+            &momentum_update_file,
+            &work_dir,
+            device,
+            momentum_op,
+        )
+        .await?;
+        fs::copy(momentum_update_file, momentum_file.clone()).await?;
+    }
+    Ok(momentum_file)
+}
+
+async fn nesterov(
+    gradient_file: PathBuf,
+    work_dir: PathBuf,
+    device: &Device,
+    momentum: f64,
+    learning_rate: f64,
+) -> Result<PathBuf, TensorOpError> {
+    let momentum_file = update_momentum(work_dir.clone(), &gradient_file, device, momentum).await?;
+
+    let result_gradient_name = work_dir.join("gradient_update");
+
+    let nesterov_op = |g: &Tensor, m: &Tensor| {
+        // Compute: learning_rate * ((momentum * m) + g)
+        (momentum * m)
+            .and_then(|t| t + g)
+            .and_then(|t| learning_rate * t)
+    };
+    apply_tensor_op(
+        &gradient_file,
+        &momentum_file,
+        &result_gradient_name,
+        &work_dir,
+        device,
+        nesterov_op,
+    )
+    .await?;
+
+    Ok(result_gradient_name)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // The test below is based on the following Python code:
+    // import torch
+    // param = [torch.Tensor([1,1,1,1,1])]
+    // optim = torch.optim.SGD(param, lr = 0.1, momentum=.7, nesterov=True)
+    // param[0].grad = torch.Tensor([.5, .5, .5, .5, .5])
+    // optim.step()
+    // print(1-param[0])
+    // optim.zero_grad()
+    // param[0].grad = torch.Tensor([.1, .2, .3, .4, .5])
+    // optim.step()
+    // print(0.915-param[0])
+    // tensor([0.0850, 0.0850, 0.0850, 0.0850, 0.0850])
+    // tensor([0.0415, 0.0585, 0.0755, 0.0925, 0.1095])
+    #[tokio::test]
+    async fn test_nesterov() {
+        use tempfile::TempDir;
+        // Create a tmp dir for the test
+        let tmp_dir = TempDir::new().unwrap();
+
+        let device = Device::Cpu;
+        let gradient_tensor =
+            candle_core::Tensor::from_vec(vec![0.5, 0.5, 0.5, 0.5, 0.5], 5, &device).unwrap();
+        let gradient_file_name = tmp_dir.path().join("gradient_file");
+        safetensors::serialize_to_file(
+            vec![("gradient", &gradient_tensor)],
+            &None,
+            &gradient_file_name.clone(),
+        )
+        .unwrap();
+
+        let result = nesterov(
+            gradient_file_name.clone(),
+            tmp_dir.path().to_path_buf(),
+            &device,
+            0.7,
+            0.1,
+        )
+        .await
+        .unwrap();
+        let update = candle_core::safetensors::load(result, &device).unwrap();
+        assert_eq!(
+            update.get("gradient").unwrap().to_vec1::<f64>().unwrap(),
+            vec![0.085, 0.085, 0.085, 0.085, 0.085]
+        );
+
+        let gradient_tensor =
+            candle_core::Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &device).unwrap();
+        safetensors::serialize_to_file(
+            vec![("gradient", &gradient_tensor)],
+            &None,
+            &gradient_file_name,
+        )
+        .unwrap();
+        let result = nesterov(
+            gradient_file_name,
+            tmp_dir.path().to_path_buf(),
+            &device,
+            0.7,
+            0.1,
+        )
+        .await
+        .unwrap();
+        let update = candle_core::safetensors::load(result, &device).unwrap();
+        let difference = update
+            .get("gradient")
+            .unwrap()
+            .to_vec1::<f64>()
+            .unwrap()
+            .into_iter()
+            .zip(vec![0.0415, 0.0585, 0.0755, 0.0925, 0.1095])
+            .fold(0f64, |acc, (a, b)| acc + (a - b).abs());
+        assert!(difference < 0.000001)
+    }
+}
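The fixed (a + b) / 2 in the aggregation arm above assumes exactly two workers, as the TODO notes. Below is a minimal sketch of how `apply_tensor_op` could take a sample-count weighting instead, once that information is available; `merge_weighted`, `n_agg`, and `n_new` are hypothetical names, not part of this diff, and the scalar-times-tensor operators are assumed to behave as in the `momentum * m` closures above.

use std::path::Path;

use candle_core::{Device, Tensor};

/// Hypothetical helper: fold one incoming update into the running aggregate,
/// weighting each side by the number of samples it represents.
async fn merge_weighted(
    aggregate: &Path,
    incoming: &Path,
    output: &Path,
    tmp_dir: &Path,
    n_agg: u64,
    n_new: u64,
    device: &Device,
) -> Result<(), TensorOpError> {
    let total = (n_agg + n_new) as f64;
    let (w_agg, w_new) = (n_agg as f64 / total, n_new as f64 / total);
    // Weighted mean instead of the fixed (a + b) / 2. used above.
    let weighted_avg = move |a: &Tensor, b: &Tensor| {
        (w_agg * a).and_then(|ta| (w_new * b).and_then(|tb| ta + tb))
    };
    apply_tensor_op(aggregate, incoming, output, tmp_dir, device, weighted_avg).await
}

The closure keeps the same `Fn(&Tensor, &Tensor)` shape as `average_op`, so the call site in the receive loop would only need to pass the counts.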
diff --git a/drivers/hypha-accelerate-driver/src/training.py b/drivers/hypha-accelerate-driver/src/training.py
index 6905f92..cf9f01e 100644
--- a/drivers/hypha-accelerate-driver/src/training.py
+++ b/drivers/hypha-accelerate-driver/src/training.py
@@ -10,7 +10,7 @@ import torch.utils.data
 from accelerate import Accelerator
 from safetensors import safe_open
-from safetensors.torch import load, load_file, save_model
+from safetensors.torch import load, load_file, save_file, save_model
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -189,15 +189,23 @@ def get_scheduler(type: str, args: dict[str, float], optmizer: Optimizer) -> tor
     raise RuntimeError(f"Learning rate Scheduler {type} not supported")
 
 
-def merge_models(origin: Module, weight_path: str, alpha: float) -> Module:
-    # All weights need to be on CPU
-    origin.to("cpu")
-    state_dict = origin.state_dict()
-    with safe_open(weight_path, framework="pt", device="cpu") as b:  # type: ignore
-        for name in b.keys():  # noqa: SIM118
-            state_dict[name] += (alpha * (b.get_tensor(name) - state_dict[name])).to(state_dict[name].dtype)
-    origin.load_state_dict(state_dict)
-    return origin
+def merge_models(old_model: str, weight_path: str) -> dict[str, torch.Tensor]:
+    state_dict: dict[str, torch.Tensor] = {}
+    with (
+        safe_open(weight_path, framework="pt", device="cpu") as g,  # type: ignore[no-untyped-call]
+        safe_open(old_model, framework="pt", device="cpu") as m,  # type: ignore[no-untyped-call]
+    ):
+        for name in m.keys():  # noqa: SIM118
+            # state_dict[name] += (alpha * (b.get_tensor(name) - state_dict[name])).to(state_dict[name].dtype)
+            state_dict[name] = m.get_tensor(name) - g.get_tensor(name)
+    return state_dict
+
+
+def extract_gradients(state_dict: dict[str, torch.Tensor], previous_model_path: str) -> dict[str, torch.Tensor]:
+    with safe_open(previous_model_path, framework="pt", device="cpu") as p:  # type: ignore[no-untyped-call]
+        for name in p.keys():  # noqa: SIM118
+            state_dict[name] -= p.get_tensor(name).to(state_dict[name].dtype)
+    return state_dict
 
 
 def dataset_wrapper(dataset: torch.utils.data.DataLoader) -> Iterator[dict[str, torch.Tensor]]:  # type: ignore[type-arg]
@@ -248,6 +256,10 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
         job_spec["config"]["batch_size"],
     )
 
+    # Serialize the model to disk
+    previous_model_path = os.path.join(work_dir, "0_global_weights.pt")
+    save_model(model, previous_model_path)
+
     model, optimizer, training_dataloader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)
     training_data_iter = dataset_wrapper(training_dataloader)
     run_epoch = get_training_loop(job_spec["config"], training_data_iter)
@@ -276,9 +288,11 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
                     rel_path = parameters.get("path") if parameters else latest.get("path")
                     if isinstance(rel_path, str):
                         path = os.path.join(work_dir, rel_path)
-                        base_model = accelerator.unwrap_model(model)
-                        merge_models(base_model, path, 0.9)
-                        model = accelerator.prepare(base_model)
+                        # Load new model with outer gradients
+                        model.load_state_dict(merge_models(previous_model_path, path))
+                        # Overwrite the previous model
+                        save_model(model, previous_model_path)
+                        model = accelerator.prepare(model)
                         print("Weights updated from", rel_path, flush=True)
             except Exception as e:
                 print(f"pointer handling error: {e}")
@@ -293,10 +307,15 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
 
         if accelerator.is_main_process and epoch % job_spec["config"]["checkpointing"] == 0:
             # For testing purposes set to global!
-            file_name = f"{epoch}_global_weights.pt"
+            file_name = f"{epoch}_local_gradients.pt"
            result_path = os.path.join(work_dir, file_name)
             # Save unwrapped model to avoid accelerator wrappers interfering
-            save_model(accelerator.unwrap_model(model), result_path)
+            #
+            model = accelerator.unwrap_model(model)
+            # All weights need to be on CPU
+            model.to("cpu")
+
+            save_file(extract_gradients(model.state_dict(), previous_model_path), result_path)
             session.send(job_spec["results"], file_name)
 
             # Mark that before the next training epoch we must wait for an update
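A quick hand check of the expected values in `test_nesterov`, assuming torch's `SGD(..., nesterov=True)` semantics (v <- momentum * v + g, step = lr * (g + momentum * v), with v initialized to the first gradient): round one, g = 0.5, v = 0.5, step = 0.1 * (0.5 + 0.7 * 0.5) = 0.085; round two, first element, g = 0.1, v = 0.7 * 0.5 + 0.1 = 0.45, step = 0.1 * (0.1 + 0.7 * 0.45) = 0.0415, matching the asserted vectors. On the driver side, as far as this diff shows, the round trip is: `extract_gradients` sends delta = W_local - W_prev, the parameter server averages the deltas and returns lr * (g + momentum * v), and `merge_models` applies W = W_prev - update before overwriting `previous_model_path`.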