diff --git a/hyperactor_mesh/Cargo.toml b/hyperactor_mesh/Cargo.toml index 1d680ad1..8c86bc6d 100644 --- a/hyperactor_mesh/Cargo.toml +++ b/hyperactor_mesh/Cargo.toml @@ -56,8 +56,6 @@ rand = { version = "0.8", features = ["small_rng"] } serde = { version = "1.0.185", features = ["derive", "rc"] } serde_bytes = "0.11" serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] } -signal-hook = "0.3" -signal-hook-tokio = { version = "0.3", features = ["futures-v0_3"] } thiserror = "2.0.12" tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] } tokio-stream = { version = "0.1.17", features = ["fs", "io-util", "net", "signal", "sync", "time"] } diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index 899deebb..428d6ec7 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -6,9 +6,9 @@ * LICENSE file in the root directory of this source tree. */ +use std::sync::Arc; use std::time::Duration; -use futures::StreamExt; use hyperactor::ActorRef; use hyperactor::Named; use hyperactor::ProcId; @@ -20,9 +20,10 @@ use hyperactor::channel::Tx; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::mailbox::MailboxServer; +use hyperactor::proc::Proc; use serde::Deserialize; use serde::Serialize; -use signal_hook::consts::signal::SIGTERM; +use tokio::sync::Mutex; use crate::proc_mesh::mesh_agent::MeshAgent; @@ -125,7 +126,22 @@ async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: Channe /// Use [`bootstrap_or_die`] to implement this behavior directly. pub async fn bootstrap() -> anyhow::Error { pub async fn go() -> Result<(), anyhow::Error> { - let mut signals = signal_hook_tokio::Signals::new([SIGTERM])?; + let procs = Arc::new(Mutex::new(Vec::::new())); + let procs_for_cleanup = procs.clone(); + let _cleanup_guard = hyperactor::register_signal_cleanup_scoped(Box::pin(async move { + for proc_to_stop in procs_for_cleanup.lock().await.iter_mut() { + if let Err(err) = proc_to_stop + .destroy_and_wait(Duration::from_millis(10), None) + .await + { + tracing::error!( + "error while stopping proc {}: {}", + proc_to_stop.proc_id(), + err + ); + } + } + })); let bootstrap_addr: ChannelAddr = std::env::var(BOOTSTRAP_ADDR_ENV) .map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_ADDR_ENV, err))? @@ -145,58 +161,31 @@ pub async fn bootstrap() -> anyhow::Error { tokio::spawn(exit_if_missed_heartbeat(bootstrap_index, bootstrap_addr)); - let mut procs = Vec::new(); - loop { let _ = hyperactor::tracing::info_span!("wait_for_next_message_from_mesh_agent"); - tokio::select! { - msg = rx.recv() => { - match msg? { - Allocator2Process::StartProc(proc_id, listen_transport) => { - let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?; - let (proc_addr, proc_rx) = - channel::serve(ChannelAddr::any(listen_transport)).await?; - // Undeliverable messages get forwarded to the mesh agent. - let handle = proc.clone().serve(proc_rx, mesh_agent.port()); - drop(handle); // linter appeasement; it is safe to drop this future - tx.send(Process2Allocator( - bootstrap_index, - Process2AllocatorMessage::StartedProc( - proc_id.clone(), - mesh_agent.bind(), - proc_addr, - ), - )) - .await?; - procs.push(proc); - } - Allocator2Process::StopAndExit(code) => { - tracing::info!("stopping procs with code {code}"); - for mut proc_to_stop in procs { - if let Err(err) = proc_to_stop - .destroy_and_wait(Duration::from_millis(10), None) - .await - { - tracing::error!( - "error while stopping proc {}: {}", - proc_to_stop.proc_id(), - err - ); - } - } - tracing::info!("exiting with {code}"); - std::process::exit(code); - } - Allocator2Process::Exit(code) => { - tracing::info!("exiting with {code}"); - std::process::exit(code); - } - } + match rx.recv().await? { + Allocator2Process::StartProc(proc_id, listen_transport) => { + let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?; + let (proc_addr, proc_rx) = + channel::serve(ChannelAddr::any(listen_transport)).await?; + // Undeliverable messages get forwarded to the mesh agent. + let handle = proc.clone().serve(proc_rx, mesh_agent.port()); + drop(handle); // linter appeasement; it is safe to drop this future + tx.send(Process2Allocator( + bootstrap_index, + Process2AllocatorMessage::StartedProc( + proc_id.clone(), + mesh_agent.bind(), + proc_addr, + ), + )) + .await?; + procs.lock().await.push(proc); } - signal = signals.next() => { - if signal.is_some_and(|sig| sig == SIGTERM) { - tracing::info!("received SIGTERM, stopping procs"); - for mut proc_to_stop in procs { + Allocator2Process::StopAndExit(code) => { + tracing::info!("stopping procs with code {code}"); + { + for proc_to_stop in procs.lock().await.iter_mut() { if let Err(err) = proc_to_stop .destroy_and_wait(Duration::from_millis(10), None) .await @@ -208,17 +197,13 @@ pub async fn bootstrap() -> anyhow::Error { ); } } - // SAFETY: We're setting the handle to SigDfl (defautl system behaviour) - if let Err(err) = unsafe { - nix::sys::signal::signal(nix::sys::signal::SIGTERM, nix::sys::signal::SigHandler::SigDfl) - } { - tracing::error!("failed to signal SIGTERM: {}", err); - } - if let Err(err) = nix::sys::signal::raise(nix::sys::signal::SIGTERM) { - tracing::error!("failed to raise SIGTERM: {}", err); - } - std::process::exit(128 + SIGTERM); } + tracing::info!("exiting with {code}"); + std::process::exit(code); + } + Allocator2Process::Exit(code) => { + tracing::info!("exiting with {code}"); + std::process::exit(code); } } }