Skip to content

Add timeout for connection from client to remote process allocator #583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions hyperactor_mesh/src/alloc/remoteprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ pub struct RemoteProcessAllocator {
cancel_token: CancellationToken,
}

/// Awaits the optional timer future `t`, or waits forever when there is none.
///
/// * `Some(timer)` — completes exactly when `timer` completes.
/// * `None` — awaits a future that is permanently pending, so this function
///   never returns; a caller racing it against other futures will simply
///   never see this one finish.
async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
    if let Some(timer) = t {
        timer.await
    } else {
        // No timer configured: park on a future that never resolves.
        futures::future::pending::<()>().await
    }
}

impl RemoteProcessAllocator {
/// Create a new allocator. It will not start until start() is called.
pub fn new() -> Arc<Self> {
Expand All @@ -125,19 +132,27 @@ impl RemoteProcessAllocator {
/// 4. Allocator sends Done message to bootstrap_addr when Alloc is done.
///
/// At any point, client can send Stop message to serve_addr to stop the allocator.
pub async fn start(&self, cmd: Command, serve_addr: ChannelAddr) -> Result<(), anyhow::Error> {
pub async fn start(
&self,
cmd: Command,
serve_addr: ChannelAddr,
timeout: Option<Duration>,
) -> Result<(), anyhow::Error> {
let process_allocator = ProcessAllocator::new(cmd);
self.start_with_allocator(serve_addr, process_allocator)
self.start_with_allocator(serve_addr, process_allocator, timeout)
.await
}

/// Start a remote process allocator with given allocator listening for
/// RemoteProcessAllocatorMessage on serve_addr.
/// If timeout is Some, the allocator will exit if no client connects within
/// that timeout, and no child allocation is running.
/// Used for testing.
pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
&self,
serve_addr: ChannelAddr,
mut process_allocator: A,
timeout: Option<Duration>,
) -> Result<(), anyhow::Error>
where
<A as Allocator>::Alloc: Send,
Expand Down Expand Up @@ -166,6 +181,9 @@ impl RemoteProcessAllocator {

let mut active_allocation: Option<ActiveAllocation> = None;
loop {
// Refresh each loop iteration so the timer updates whenever a message
// is received.
let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
tokio::select! {
msg = rx.recv() => {
match msg {
Expand Down Expand Up @@ -218,6 +236,16 @@ impl RemoteProcessAllocator {

break;
}
_ = sleep => {
// If there are any active allocations, reset the timeout.
if active_allocation.is_some() {
continue;
}
// Else, exit the loop as a client hasn't connected in a reasonable
// amount of time.
tracing::warn!("timeout elapsed without any allocations, exiting");
break;
}
}
}

Expand Down Expand Up @@ -1143,7 +1171,7 @@ mod test {
let remote_allocator = remote_allocator.clone();
async move {
remote_allocator
.start_with_allocator(serve_addr, allocator)
.start_with_allocator(serve_addr, allocator, None)
.await
}
});
Expand Down Expand Up @@ -1280,7 +1308,7 @@ mod test {
let remote_allocator = remote_allocator.clone();
async move {
remote_allocator
.start_with_allocator(serve_addr, allocator)
.start_with_allocator(serve_addr, allocator, None)
.await
}
});
Expand Down Expand Up @@ -1376,7 +1404,7 @@ mod test {
let remote_allocator = remote_allocator.clone();
async move {
remote_allocator
.start_with_allocator(serve_addr, allocator)
.start_with_allocator(serve_addr, allocator, None)
.await
}
});
Expand Down Expand Up @@ -1483,7 +1511,7 @@ mod test {
let remote_allocator = remote_allocator.clone();
async move {
remote_allocator
.start_with_allocator(serve_addr, allocator)
.start_with_allocator(serve_addr, allocator, None)
.await
}
});
Expand Down Expand Up @@ -1566,7 +1594,7 @@ mod test {
let remote_allocator = remote_allocator.clone();
async move {
remote_allocator
.start_with_allocator(serve_addr, allocator)
.start_with_allocator(serve_addr, allocator, None)
.await
}
});
Expand Down Expand Up @@ -1640,14 +1668,14 @@ mod test_alloc {
let task1_allocator_handle = tokio::spawn(async move {
tracing::info!("spawning task1");
task1_allocator_copy
.start(task1_cmd, task1_addr)
.start(task1_cmd, task1_addr, None)
.await
.unwrap();
});
let task2_allocator_copy = task2_allocator.clone();
let task2_allocator_handle = tokio::spawn(async move {
task2_allocator_copy
.start(task2_cmd, task2_addr)
.start(task2_cmd, task2_addr, None)
.await
.unwrap();
});
Expand Down Expand Up @@ -1763,15 +1791,15 @@ mod test_alloc {
let task1_allocator_handle = tokio::spawn(async move {
tracing::info!("spawning task1");
task1_allocator_copy
.start(task1_cmd, task1_addr)
.start(task1_cmd, task1_addr, None)
.await
.unwrap();
tracing::info!("task1 terminated");
});
let task2_allocator_copy = task2_allocator.clone();
let task2_allocator_handle = tokio::spawn(async move {
task2_allocator_copy
.start(task2_cmd, task2_addr)
.start(task2_cmd, task2_addr, None)
.await
.unwrap();
tracing::info!("task2 terminated");
Expand Down Expand Up @@ -1884,14 +1912,14 @@ mod test_alloc {
let task1_allocator_handle = tokio::spawn(async move {
tracing::info!("spawning task1");
task1_allocator_copy
.start(task1_cmd, task1_addr)
.start(task1_cmd, task1_addr, None)
.await
.unwrap();
});
let task2_allocator_copy = task2_allocator.clone();
let task2_allocator_handle = tokio::spawn(async move {
task2_allocator_copy
.start(task2_cmd, task2_addr)
.start(task2_cmd, task2_addr, None)
.await
.unwrap();
});
Expand Down
Loading
Loading