Skip to content

Use tokio::oneshot for sim clock #822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions hyperactor/src/channel/sim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use super::*;
use crate::channel;
use crate::clock::Clock;
use crate::clock::RealClock;
use crate::clock::SimClock;
use crate::data::Serialized;
use crate::mailbox::MessageEnvelope;
use crate::simnet;
Expand Down Expand Up @@ -129,7 +128,7 @@ pub(crate) struct MessageDeliveryEvent {
src_addr: Option<ChannelAddr>,
dest_addr: ChannelAddr,
data: Serialized,
duration_ms: u64,
duration: tokio::time::Duration,
}

impl MessageDeliveryEvent {
Expand All @@ -139,14 +138,14 @@ impl MessageDeliveryEvent {
src_addr,
dest_addr,
data,
duration_ms: 100,
duration: tokio::time::Duration::from_millis(100),
}
}
}

#[async_trait]
impl Event for MessageDeliveryEvent {
async fn handle(&self) -> Result<(), SimNetError> {
async fn handle(&mut self) -> Result<(), SimNetError> {
// Send the message to the correct receiver.
SENDER
.send(
Expand All @@ -158,8 +157,8 @@ impl Event for MessageDeliveryEvent {
Ok(())
}

fn duration_ms(&self) -> u64 {
self.duration_ms
fn duration(&self) -> tokio::time::Duration {
self.duration
}

fn summary(&self) -> String {
Expand All @@ -178,12 +177,12 @@ impl Event for MessageDeliveryEvent {
src: src_addr.clone(),
dst: self.dest_addr.clone(),
};
self.duration_ms = topology
self.duration = topology
.lock()
.await
.topology
.get(&edge)
.map_or_else(|| 1, |v| v.latency.as_millis() as u64);
.map_or_else(|| tokio::time::Duration::from_millis(1), |v| v.latency);
}
}
}
Expand Down Expand Up @@ -332,7 +331,7 @@ impl<M: RemoteMessage> Tx<M> for SimTx<M> {
self.dst_addr.clone(),
data,
)),
time: SimClock.millis_since_start(RealClock.now()),
time: RealClock.now(),
}),
_ => handle.send_event(Box::new(MessageDeliveryEvent::new(
self.src_addr.clone(),
Expand Down Expand Up @@ -551,7 +550,10 @@ mod tests {
.await
.unwrap();

assert_eq!(SimClock.millis_since_start(RealClock.now()), 0);
assert_eq!(
SimClock.duration_since_start(RealClock.now()),
tokio::time::Duration::ZERO
);
// Fast forward real time to 5 seconds
tokio::time::advance(tokio::time::Duration::from_secs(5)).await;
{
Expand Down
137 changes: 72 additions & 65 deletions hyperactor/src/clock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,17 @@ use std::error::Error;
use std::fmt;
use std::sync::LazyLock;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::time::SystemTime;

use async_trait::async_trait;
use futures::pin_mut;
use hyperactor_telemetry::TelemetryClock;
use serde::Deserialize;
use serde::Serialize;

use crate::Mailbox;
use crate::channel::ChannelAddr;
use crate::data::Named;
use crate::id;
use crate::mailbox::DeliveryError;
use crate::mailbox::MailboxSender;
use crate::mailbox::MessageEnvelope;
use crate::mailbox::Undeliverable;
use crate::mailbox::UndeliverableMailboxSender;
use crate::mailbox::monitored_return_handle;
use crate::simnet::SleepEvent;
use crate::simnet::Event;
use crate::simnet::SimNetError;
use crate::simnet::simnet_handle;

struct SimTime {
Expand Down Expand Up @@ -183,6 +175,45 @@ impl ClockKind {
}
}

/// Simnet event that fires a one-shot channel when it is handled.
#[derive(Debug)]
struct SleepEvent {
/// One-shot sender fired by `handle`; kept in an `Option` so `handle` can `take()` it through `&mut self`.
done_tx: Option<tokio::sync::oneshot::Sender<()>>,
/// Value reported by `Event::duration` for this event.
duration: tokio::time::Duration,
}

impl SleepEvent {
    /// Builds a boxed `SleepEvent` from the sender to fire and the duration to report.
    ///
    /// The sender is stored as `Some(..)` so that `handle` can later move it out.
    pub(crate) fn new(
        done_tx: tokio::sync::oneshot::Sender<()>,
        duration: tokio::time::Duration,
    ) -> Box<Self> {
        let event = Self {
            done_tx: Some(done_tx),
            duration,
        };
        Box::new(event)
    }
}

#[async_trait]
impl Event for SleepEvent {
    /// Fires the one-shot sender stored in this event.
    ///
    /// Panics if the sender has already been taken (i.e. on a second call).
    /// Returns `SimNetError::PanickedTask` when the send fails.
    // NOTE(review): a oneshot send fails whenever the receiving half is gone,
    // which presumably also covers a `timeout` whose future finished first —
    // confirm the simnet tolerates that error.
    async fn handle(&mut self) -> Result<(), SimNetError> {
        let done_tx = self.done_tx.take().unwrap();
        match done_tx.send(()) {
            Ok(()) => Ok(()),
            Err(_) => Err(SimNetError::PanickedTask),
        }
    }

    /// Returns the duration this event was created with.
    fn duration(&self) -> tokio::time::Duration {
        self.duration
    }

    /// Human-readable description of the event, with the duration in whole milliseconds.
    fn summary(&self) -> String {
        let millis = self.duration.as_millis();
        format!("Sleeping for {} ms", millis)
    }
}

/// Clock to be used in simulator runs; it lets the simnet create a scheduled wakeup event.
/// When the wakeup event becomes the next earliest scheduled event, the simnet will advance its
/// time to the wakeup time and use the transmitter to wake up this green thread
Expand All @@ -192,33 +223,25 @@ pub struct SimClock;
impl Clock for SimClock {
/// Tell the simnet to wake up this green thread after the specified duration has passed on the simnet
async fn sleep(&self, duration: tokio::time::Duration) {
let mailbox = SimClock::mailbox().clone();
let (tx, rx) = mailbox.open_once_port::<()>();
let (tx, rx) = tokio::sync::oneshot::channel::<()>();

simnet_handle()
.unwrap()
.send_event(SleepEvent::new(
tx.bind(),
mailbox,
duration.as_millis() as u64,
))
.send_event(SleepEvent::new(tx, duration))
.unwrap();
rx.recv().await.unwrap();

rx.await.unwrap();
}

async fn non_advancing_sleep(&self, duration: tokio::time::Duration) {
let mailbox = SimClock::mailbox().clone();
let (tx, rx) = mailbox.open_once_port::<()>();
let (tx, rx) = tokio::sync::oneshot::channel::<()>();

simnet_handle()
.unwrap()
.send_nonadvanceable_event(SleepEvent::new(
tx.bind(),
mailbox,
duration.as_millis() as u64,
))
.send_nonadvanceable_event(SleepEvent::new(tx, duration))
.unwrap();
rx.recv().await.unwrap();

rx.await.unwrap();
}

async fn sleep_until(&self, deadline: tokio::time::Instant) {
Expand All @@ -242,23 +265,18 @@ impl Clock for SimClock {
where
F: std::future::Future<Output = T>,
{
let mailbox = SimClock::mailbox().clone();
let (tx, deadline_rx) = mailbox.open_once_port::<()>();
let (tx, deadline_rx) = tokio::sync::oneshot::channel::<()>();

simnet_handle()
.unwrap()
.send_event(SleepEvent::new(
tx.bind(),
mailbox,
duration.as_millis() as u64,
))
.send_event(SleepEvent::new(tx, duration))
.unwrap();

let fut = f;
pin_mut!(fut);

tokio::select! {
_ = deadline_rx.recv() => {
_ = deadline_rx => {
Err(TimeoutError)
}
res = &mut fut => Ok(res)
Expand All @@ -267,37 +285,20 @@ impl Clock for SimClock {
}

impl SimClock {
// TODO (SF, 2025-07-11): Remove this global, thread through a mailbox
// from upstack and handle undeliverable messages properly.
fn mailbox() -> &'static Mailbox {
static SIMCLOCK_MAILBOX: OnceLock<Mailbox> = OnceLock::new();
SIMCLOCK_MAILBOX.get_or_init(|| {
let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
let (undeliverable_messages, mut rx) =
mailbox.open_port::<Undeliverable<MessageEnvelope>>();
undeliverable_messages.bind_to(Undeliverable::<MessageEnvelope>::port());
tokio::spawn(async move {
while let Ok(Undeliverable(mut envelope)) = rx.recv().await {
envelope.try_set_error(DeliveryError::BrokenLink(
"message returned to undeliverable port".to_string(),
));
UndeliverableMailboxSender
.post(envelope, /*unused */ monitored_return_handle())
}
});
mailbox
})
}

/// Advance the simulator's time to the specified instant
pub fn advance_to(&self, millis: u64) {
pub fn advance_to(&self, time: tokio::time::Instant) {
let mut guard = SIM_TIME.now.lock().unwrap();
*guard = SIM_TIME.start + tokio::time::Duration::from_millis(millis);
*guard = time;
}

/// Get the duration elapsed since the start of the simulation
pub fn millis_since_start(&self, instant: tokio::time::Instant) -> u64 {
instant.duration_since(SIM_TIME.start).as_millis() as u64
pub fn duration_since_start(&self, instant: tokio::time::Instant) -> tokio::time::Duration {
instant.duration_since(SIM_TIME.start)
}

/// Instant marking the start of the simulation
pub fn start(&self) -> tokio::time::Instant {
    // `Instant` is `Copy`, so the field is returned by value; no `.clone()` needed.
    SIM_TIME.start
}
}

Expand Down Expand Up @@ -347,10 +348,16 @@ mod tests {
#[tokio::test]
async fn test_sim_clock_simple() {
let start = SimClock.now();
assert_eq!(SimClock.millis_since_start(start), 0);
SimClock.advance_to(10000);
assert_eq!(
SimClock.duration_since_start(start),
tokio::time::Duration::ZERO
);
SimClock.advance_to(SimClock.start() + tokio::time::Duration::from_millis(10000));
let end = SimClock.now();
assert_eq!(SimClock.millis_since_start(end), 10000);
assert_eq!(
SimClock.duration_since_start(end),
tokio::time::Duration::from_millis(10000)
);
assert_eq!(
end.duration_since(start),
tokio::time::Duration::from_secs(10)
Expand All @@ -360,7 +367,7 @@ mod tests {
#[tokio::test]
async fn test_sim_clock_system_time() {
let start = SimClock.system_time_now();
SimClock.advance_to(10000);
SimClock.advance_to(SimClock.start() + tokio::time::Duration::from_millis(10000));
let end = SimClock.system_time_now();
assert_eq!(
end.duration_since(start).unwrap(),
Expand Down
Loading