Skip to content

Spawn signal handler out of loop into background task #579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 68 additions & 59 deletions hyperactor_mesh/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

use std::sync::Arc;
use std::time::Duration;

use futures::StreamExt;
Expand All @@ -20,9 +21,11 @@ use hyperactor::channel::Tx;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::mailbox::MailboxServer;
use hyperactor::proc::Proc;
use serde::Deserialize;
use serde::Serialize;
use signal_hook::consts::signal::SIGTERM;
use tokio::sync::Mutex;

use crate::proc_mesh::mesh_agent::MeshAgent;

Expand Down Expand Up @@ -126,6 +129,43 @@ async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: Channe
pub async fn bootstrap() -> anyhow::Error {
pub async fn go() -> Result<(), anyhow::Error> {
let mut signals = signal_hook_tokio::Signals::new([SIGTERM])?;
let procs = Arc::new(Mutex::new(Vec::<Proc>::new()));
{
let procs = procs.clone();
tokio::spawn(async move {
if let Some(signal) = signals.next().await {
if let Ok(signal) = nix::sys::signal::Signal::try_from(signal) {
{
for proc_to_stop in procs.lock().await.iter_mut() {
if let Err(err) = proc_to_stop
.destroy_and_wait(Duration::from_millis(10), None)
.await
{
tracing::error!(
"error while stopping proc {}: {}",
proc_to_stop.proc_id(),
err
);
}
}
}
// SAFETY: We're setting the handle to SigDfl (default system behaviour)
if let Err(err) = unsafe {
nix::sys::signal::signal(signal, nix::sys::signal::SigHandler::SigDfl)
} {
tracing::error!(
"failed to reset {:?} to default signal handler: {}",
signal,
err
);
}
if let Err(err) = nix::sys::signal::raise(signal) {
tracing::error!("failed to raise {:?}: {}", signal, err);
}
}
}
});
}

let bootstrap_addr: ChannelAddr = std::env::var(BOOTSTRAP_ADDR_ENV)
.map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_ADDR_ENV, err))?
Expand All @@ -145,58 +185,31 @@ pub async fn bootstrap() -> anyhow::Error {

tokio::spawn(exit_if_missed_heartbeat(bootstrap_index, bootstrap_addr));

let mut procs = Vec::new();

loop {
let _ = hyperactor::tracing::info_span!("wait_for_next_message_from_mesh_agent");
tokio::select! {
msg = rx.recv() => {
match msg? {
Allocator2Process::StartProc(proc_id, listen_transport) => {
let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?;
let (proc_addr, proc_rx) =
channel::serve(ChannelAddr::any(listen_transport)).await?;
// Undeliverable messages get forwarded to the mesh agent.
let handle = proc.clone().serve(proc_rx, mesh_agent.port());
drop(handle); // linter appeasement; it is safe to drop this future
tx.send(Process2Allocator(
bootstrap_index,
Process2AllocatorMessage::StartedProc(
proc_id.clone(),
mesh_agent.bind(),
proc_addr,
),
))
.await?;
procs.push(proc);
}
Allocator2Process::StopAndExit(code) => {
tracing::info!("stopping procs with code {code}");
for mut proc_to_stop in procs {
if let Err(err) = proc_to_stop
.destroy_and_wait(Duration::from_millis(10), None)
.await
{
tracing::error!(
"error while stopping proc {}: {}",
proc_to_stop.proc_id(),
err
);
}
}
tracing::info!("exiting with {code}");
std::process::exit(code);
}
Allocator2Process::Exit(code) => {
tracing::info!("exiting with {code}");
std::process::exit(code);
}
}
match rx.recv().await? {
Allocator2Process::StartProc(proc_id, listen_transport) => {
let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?;
let (proc_addr, proc_rx) =
channel::serve(ChannelAddr::any(listen_transport)).await?;
// Undeliverable messages get forwarded to the mesh agent.
let handle = proc.clone().serve(proc_rx, mesh_agent.port());
drop(handle); // linter appeasement; it is safe to drop this future
tx.send(Process2Allocator(
bootstrap_index,
Process2AllocatorMessage::StartedProc(
proc_id.clone(),
mesh_agent.bind(),
proc_addr,
),
))
.await?;
procs.lock().await.push(proc);
}
signal = signals.next() => {
if signal.is_some_and(|sig| sig == SIGTERM) {
tracing::info!("received SIGTERM, stopping procs");
for mut proc_to_stop in procs {
Allocator2Process::StopAndExit(code) => {
tracing::info!("stopping procs with code {code}");
{
for proc_to_stop in procs.lock().await.iter_mut() {
if let Err(err) = proc_to_stop
.destroy_and_wait(Duration::from_millis(10), None)
.await
Expand All @@ -208,17 +221,13 @@ pub async fn bootstrap() -> anyhow::Error {
);
}
}
// SAFETY: We're setting the handle to SigDfl (defautl system behaviour)
if let Err(err) = unsafe {
nix::sys::signal::signal(nix::sys::signal::SIGTERM, nix::sys::signal::SigHandler::SigDfl)
} {
tracing::error!("failed to signal SIGTERM: {}", err);
}
if let Err(err) = nix::sys::signal::raise(nix::sys::signal::SIGTERM) {
tracing::error!("failed to raise SIGTERM: {}", err);
}
std::process::exit(128 + SIGTERM);
}
tracing::info!("exiting with {code}");
std::process::exit(code);
}
Allocator2Process::Exit(code) => {
tracing::info!("exiting with {code}");
std::process::exit(code);
}
}
}
Expand Down
Loading