diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index 41a41487e..1ec2393f3 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -22,7 +22,6 @@ use super::*; use crate::channel; use crate::clock::Clock; use crate::clock::RealClock; -use crate::clock::SimClock; use crate::data::Serialized; use crate::mailbox::MessageEnvelope; use crate::simnet; @@ -129,7 +128,7 @@ pub(crate) struct MessageDeliveryEvent { src_addr: Option, dest_addr: ChannelAddr, data: Serialized, - duration_ms: u64, + duration: tokio::time::Duration, } impl MessageDeliveryEvent { @@ -139,14 +138,14 @@ impl MessageDeliveryEvent { src_addr, dest_addr, data, - duration_ms: 100, + duration: tokio::time::Duration::from_millis(100), } } } #[async_trait] impl Event for MessageDeliveryEvent { - async fn handle(&self) -> Result<(), SimNetError> { + async fn handle(&mut self) -> Result<(), SimNetError> { // Send the message to the correct receiver. SENDER .send( @@ -158,8 +157,8 @@ impl Event for MessageDeliveryEvent { Ok(()) } - fn duration_ms(&self) -> u64 { - self.duration_ms + fn duration(&self) -> tokio::time::Duration { + self.duration } fn summary(&self) -> String { @@ -178,12 +177,12 @@ impl Event for MessageDeliveryEvent { src: src_addr.clone(), dst: self.dest_addr.clone(), }; - self.duration_ms = topology + self.duration = topology .lock() .await .topology .get(&edge) - .map_or_else(|| 1, |v| v.latency.as_millis() as u64); + .map_or_else(|| tokio::time::Duration::from_millis(1), |v| v.latency); } } } @@ -332,7 +331,7 @@ impl Tx for SimTx { self.dst_addr.clone(), data, )), - time: SimClock.millis_since_start(RealClock.now()), + time: RealClock.now(), }), _ => handle.send_event(Box::new(MessageDeliveryEvent::new( self.src_addr.clone(), @@ -551,7 +550,10 @@ mod tests { .await .unwrap(); - assert_eq!(SimClock.millis_since_start(RealClock.now()), 0); + assert_eq!( + SimClock.duration_since_start(RealClock.now()), + tokio::time::Duration::ZERO + ); // Fast forward real time to 5 seconds tokio::time::advance(tokio::time::Duration::from_secs(5)).await; { diff --git a/hyperactor/src/clock.rs b/hyperactor/src/clock.rs index 84c572ecc..b05152e63 100644 --- a/hyperactor/src/clock.rs +++ b/hyperactor/src/clock.rs @@ -12,25 +12,17 @@ use std::error::Error; use std::fmt; use std::sync::LazyLock; use std::sync::Mutex; -use std::sync::OnceLock; use std::time::SystemTime; +use async_trait::async_trait; use futures::pin_mut; use hyperactor_telemetry::TelemetryClock; use serde::Deserialize; use serde::Serialize; -use crate::Mailbox; use crate::channel::ChannelAddr; -use crate::data::Named; -use crate::id; -use crate::mailbox::DeliveryError; -use crate::mailbox::MailboxSender; -use crate::mailbox::MessageEnvelope; -use crate::mailbox::Undeliverable; -use crate::mailbox::UndeliverableMailboxSender; -use crate::mailbox::monitored_return_handle; -use crate::simnet::SleepEvent; +use crate::simnet::Event; +use crate::simnet::SimNetError; use crate::simnet::simnet_handle; struct SimTime { @@ -183,6 +175,45 @@ impl ClockKind { } } +#[derive(Debug)] +struct SleepEvent { + done_tx: Option>, + duration: tokio::time::Duration, +} + +impl SleepEvent { + pub(crate) fn new( + done_tx: tokio::sync::oneshot::Sender<()>, + duration: tokio::time::Duration, + ) -> Box { + Box::new(Self { + done_tx: Some(done_tx), + duration, + }) + } +} + +#[async_trait] +impl Event for SleepEvent { + async fn handle(&mut self) -> Result<(), SimNetError> { + self.done_tx + .take() + .unwrap() + .send(()) + .map_err(|_| SimNetError::PanickedTask)?; + + Ok(()) + } + + fn duration(&self) -> tokio::time::Duration { + self.duration + } + + fn summary(&self) -> String { + format!("Sleeping for {} ms", self.duration.as_millis()) + } +} + /// Clock to be used in simulator runs that allows the simnet to create a scheduled event for. /// When the wakeup event becomes the next earliest scheduled event, the simnet will advance it's /// time to the wakeup time and use the transmitter to wake up this green thread @@ -192,33 +223,25 @@ pub struct SimClock; impl Clock for SimClock { /// Tell the simnet to wake up this green thread after the specified duration has pass on the simnet async fn sleep(&self, duration: tokio::time::Duration) { - let mailbox = SimClock::mailbox().clone(); - let (tx, rx) = mailbox.open_once_port::<()>(); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); simnet_handle() .unwrap() - .send_event(SleepEvent::new( - tx.bind(), - mailbox, - duration.as_millis() as u64, - )) + .send_event(SleepEvent::new(tx, duration)) .unwrap(); - rx.recv().await.unwrap(); + + rx.await.unwrap(); } async fn non_advancing_sleep(&self, duration: tokio::time::Duration) { - let mailbox = SimClock::mailbox().clone(); - let (tx, rx) = mailbox.open_once_port::<()>(); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); simnet_handle() .unwrap() - .send_nonadvanceable_event(SleepEvent::new( - tx.bind(), - mailbox, - duration.as_millis() as u64, - )) + .send_nonadvanceable_event(SleepEvent::new(tx, duration)) .unwrap(); - rx.recv().await.unwrap(); + + rx.await.unwrap(); } async fn sleep_until(&self, deadline: tokio::time::Instant) { @@ -242,23 +265,18 @@ impl Clock for SimClock { where F: std::future::Future, { - let mailbox = SimClock::mailbox().clone(); - let (tx, deadline_rx) = mailbox.open_once_port::<()>(); + let (tx, deadline_rx) = tokio::sync::oneshot::channel::<()>(); simnet_handle() .unwrap() - .send_event(SleepEvent::new( - tx.bind(), - mailbox, - duration.as_millis() as u64, - )) + .send_event(SleepEvent::new(tx, duration)) .unwrap(); let fut = f; pin_mut!(fut); tokio::select! { - _ = deadline_rx.recv() => { + _ = deadline_rx => { Err(TimeoutError) } res = &mut fut => Ok(res) @@ -267,37 +285,20 @@ impl Clock for SimClock { } impl SimClock { - // TODO (SF, 2025-07-11): Remove this global, thread through a mailbox - // from upstack and handle undeliverable messages properly. - fn mailbox() -> &'static Mailbox { - static SIMCLOCK_MAILBOX: OnceLock = OnceLock::new(); - SIMCLOCK_MAILBOX.get_or_init(|| { - let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone()); - let (undeliverable_messages, mut rx) = - mailbox.open_port::>(); - undeliverable_messages.bind_to(Undeliverable::::port()); - tokio::spawn(async move { - while let Ok(Undeliverable(mut envelope)) = rx.recv().await { - envelope.try_set_error(DeliveryError::BrokenLink( - "message returned to undeliverable port".to_string(), - )); - UndeliverableMailboxSender - .post(envelope, /*unused */ monitored_return_handle()) - } - }); - mailbox - }) - } - /// Advance the sumulator's time to the specified instant - pub fn advance_to(&self, millis: u64) { + pub fn advance_to(&self, time: tokio::time::Instant) { let mut guard = SIM_TIME.now.lock().unwrap(); - *guard = SIM_TIME.start + tokio::time::Duration::from_millis(millis); + *guard = time; } /// Get the number of milliseconds elapsed since the start of the simulation - pub fn millis_since_start(&self, instant: tokio::time::Instant) -> u64 { - instant.duration_since(SIM_TIME.start).as_millis() as u64 + pub fn duration_since_start(&self, instant: tokio::time::Instant) -> tokio::time::Duration { + instant.duration_since(SIM_TIME.start) + } + + /// Instant marking the start of the simulation + pub fn start(&self) -> tokio::time::Instant { + SIM_TIME.start.clone() } } @@ -347,10 +348,16 @@ mod tests { #[tokio::test] async fn test_sim_clock_simple() { let start = SimClock.now(); - assert_eq!(SimClock.millis_since_start(start), 0); - SimClock.advance_to(10000); + assert_eq!( + SimClock.duration_since_start(start), + tokio::time::Duration::ZERO + ); + SimClock.advance_to(SimClock.start() + tokio::time::Duration::from_millis(10000)); let end = SimClock.now(); - assert_eq!(SimClock.millis_since_start(end), 10000); + assert_eq!( + SimClock.duration_since_start(end), + tokio::time::Duration::from_millis(10000) + ); assert_eq!( end.duration_since(start), tokio::time::Duration::from_secs(10) @@ -360,7 +367,7 @@ mod tests { #[tokio::test] async fn test_sim_clock_system_time() { let start = SimClock.system_time_now(); - SimClock.advance_to(10000); + SimClock.advance_to(SimClock.start() + tokio::time::Duration::from_millis(10000)); let end = SimClock.system_time_now(); assert_eq!( end.duration_since(start).unwrap(), diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index def1fc9f5..19db9691a 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -70,7 +70,7 @@ const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8; pub trait Address: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {} impl Address for A {} -type SimulatorTimeInstant = u64; +type SimulatorTimeInstant = tokio::time::Instant; /// The unit of execution for the simulator. /// Using handle(), simnet can schedule executions in the network. @@ -88,18 +88,18 @@ pub trait Event: Send + Sync + Debug { /// For a proc spawn, it will be creating the proc object and instantiating it. /// For any event that manipulates the network (like adding/removing nodes etc.) /// implement handle_network(). - async fn handle(&self) -> Result<(), SimNetError>; + async fn handle(&mut self) -> Result<(), SimNetError>; /// This is the method that will be called when the simulator fires the event /// Unless you need to make changes to the network, you do not have to implement this. /// Only implement handle() method for all non-simnet requirements. - async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> { + async fn handle_network(&mut self, _phantom: &SimNet) -> Result<(), SimNetError> { self.handle().await } /// The latency of the event. This could be network latency, induced latency (sleep), or /// GPU work latency. - fn duration_ms(&self) -> u64; + fn duration(&self) -> tokio::time::Duration; /// Read the simnet config and update self accordingly. async fn read_simnet_config(&mut self, _topology: &Arc>) {} @@ -117,17 +117,17 @@ struct NodeJoinEvent { #[async_trait] impl Event for NodeJoinEvent { - async fn handle(&self) -> Result<(), SimNetError> { + async fn handle(&mut self) -> Result<(), SimNetError> { Ok(()) } - async fn handle_network(&self, simnet: &SimNet) -> Result<(), SimNetError> { + async fn handle_network(&mut self, simnet: &SimNet) -> Result<(), SimNetError> { simnet.bind(self.channel_addr.clone()).await; self.handle().await } - fn duration_ms(&self) -> u64 { - 0 + fn duration(&self) -> tokio::time::Duration { + tokio::time::Duration::ZERO } fn summary(&self) -> String { @@ -135,46 +135,6 @@ impl Event for NodeJoinEvent { } } -#[derive(Debug)] -pub(crate) struct SleepEvent { - done_tx: OncePortRef<()>, - mailbox: Mailbox, - duration_ms: u64, -} - -impl SleepEvent { - pub(crate) fn new(done_tx: OncePortRef<()>, mailbox: Mailbox, duration_ms: u64) -> Box { - Box::new(Self { - done_tx, - mailbox, - duration_ms, - }) - } -} - -#[async_trait] -impl Event for SleepEvent { - async fn handle(&self) -> Result<(), SimNetError> { - Ok(()) - } - - async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> { - self.done_tx - .clone() - .send(&self.mailbox, ()) - .map_err(|_err| SimNetError::Closed("TODO".to_string()))?; - Ok(()) - } - - fn duration_ms(&self) -> u64 { - self.duration_ms - } - - fn summary(&self) -> String { - format!("Sleeping for {} ms", self.duration_ms) - } -} - #[derive(Debug)] /// A pytorch operation pub struct TorchOpEvent { @@ -188,11 +148,11 @@ pub struct TorchOpEvent { #[async_trait] impl Event for TorchOpEvent { - async fn handle(&self) -> Result<(), SimNetError> { + async fn handle(&mut self) -> Result<(), SimNetError> { Ok(()) } - async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> { + async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> { self.done_tx .clone() .send(&self.mailbox, ()) @@ -200,8 +160,8 @@ impl Event for TorchOpEvent { Ok(()) } - fn duration_ms(&self) -> u64 { - 100 + fn duration(&self) -> tokio::time::Duration { + tokio::time::Duration::from_millis(100) } fn summary(&self) -> String { @@ -304,6 +264,10 @@ pub enum SimNetError { /// SimnetHandle being accessed without starting simnet #[error("simnet not started")] NotStarted, + + /// A task has panicked. + #[error("panicked task")] + PanickedTask, } struct State { @@ -561,22 +525,20 @@ impl SimNet { // Get latency event.read_simnet_config(&self.config).await; ScheduledEvent { - time: SimClock.millis_since_start( - SimClock.now() + tokio::time::Duration::from_millis(event.duration_ms()), - ), + time: SimClock.now() + event.duration(), event, } } /// Schedule the event into the network. fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) { - let start_at = SimClock.millis_since_start(SimClock.now()); + let start_at = SimClock.now(); let end_at = scheduled_event.time; self.records.push(SimulatorEventRecord { summary: scheduled_event.event.summary(), - start_at, - end_at, + start_at: SimClock.duration_since_start(start_at).as_millis() as u64, + end_at: SimClock.duration_since_start(end_at).as_millis() as u64, }); if advanceable { @@ -604,7 +566,7 @@ impl SimNet { ) -> Vec { // The simulated number of milliseconds the training script // has spent waiting for the backend to resolve a future - let mut training_script_waiting_time: u64 = 0; + let mut training_script_waiting_time = tokio::time::Duration::from_millis(0); // Duration elapsed while only non_advanceable_events has events let mut debounce_timer: Option = None; 'outer: loop { @@ -638,9 +600,7 @@ impl SimNet { .scheduled_events .first_key_value() .is_some_and(|(time, _)| { - *time - > SimClock.millis_since_start(RealClock.now()) - + training_script_waiting_time + *time > RealClock.now() + training_script_waiting_time }) { tokio::task::yield_now().await; @@ -705,12 +665,11 @@ impl SimNet { continue; }; if training_script_state_rx.borrow().is_waiting() { - let advanced_time = - scheduled_time - SimClock.millis_since_start(SimClock.now()); + let advanced_time = scheduled_time - SimClock.now(); training_script_waiting_time += advanced_time; } SimClock.advance_to(scheduled_time); - for scheduled_event in scheduled_events { + for mut scheduled_event in scheduled_events { self.pending_event_count .fetch_sub(1, std::sync::atomic::Ordering::SeqCst); if scheduled_event.event.handle_network(self).await.is_err() { @@ -749,9 +708,9 @@ pub struct SimulatorEventRecord { /// Event dependent summary for user pub summary: String, /// The time at which the message delivery was started. - pub start_at: SimulatorTimeInstant, + pub start_at: u64, /// The time at which the message was delivered to the receiver. - pub end_at: SimulatorTimeInstant, + pub end_at: u64, } /// A configuration for the network topology. @@ -805,13 +764,13 @@ mod tests { src_addr: SimAddr, dest_addr: SimAddr, data: Serialized, - duration_ms: u64, + duration: tokio::time::Duration, dispatcher: Option, } #[async_trait] impl Event for MessageDeliveryEvent { - async fn handle(&self) -> Result<(), simnet::SimNetError> { + async fn handle(&mut self) -> Result<(), simnet::SimNetError> { if let Some(dispatcher) = &self.dispatcher { dispatcher .send( @@ -823,8 +782,8 @@ mod tests { } Ok(()) } - fn duration_ms(&self) -> u64 { - self.duration_ms + fn duration(&self) -> tokio::time::Duration { + self.duration } fn summary(&self) -> String { @@ -840,12 +799,12 @@ mod tests { src: self.src_addr.addr().clone(), dst: self.dest_addr.addr().clone(), }; - self.duration_ms = config + self.duration = config .lock() .await .topology .get(&edge) - .map_or_else(|| 1, |v| v.latency.as_millis() as u64); + .map_or_else(|| tokio::time::Duration::from_millis(1), |v| v.latency); } } @@ -860,7 +819,7 @@ mod tests { src_addr, dest_addr, data, - duration_ms: 1, + duration: tokio::time::Duration::from_millis(1), dispatcher, } } @@ -1132,12 +1091,18 @@ mod tests { start(); let start = SimClock.now(); - assert_eq!(SimClock.millis_since_start(start), 0); + assert_eq!( + SimClock.duration_since_start(start), + tokio::time::Duration::ZERO + ); SimClock.sleep(tokio::time::Duration::from_secs(10)).await; let end = SimClock.now(); - assert_eq!(SimClock.millis_since_start(end), 10000); + assert_eq!( + SimClock.duration_since_start(end), + tokio::time::Duration::from_secs(10) + ); } #[tokio::test]