11use std:: collections:: { BinaryHeap , HashMap } ;
2- use std:: future:: Future as _;
32use std:: sync:: Arc ;
4- use std:: time:: Duration ;
5- use tokio:: sync:: { RwLock , mpsc} ;
3+ use tokio:: sync:: { mpsc, Notify , RwLock } ;
64use pin_project:: pin_project;
5+ use tracing:: debug;
76use std:: {
87 pin:: Pin ,
98 task:: { Context , Poll } ,
109} ;
1110use anyhow:: { Result , anyhow} ;
12- use tokio:: time:: { Instant , Sleep } ;
11+ use tokio:: time:: { sleep_until , Instant } ;
1312use std:: sync:: atomic:: { AtomicU64 , AtomicU8 , Ordering } ;
1413use std:: cmp:: Ordering as CmpOrdering ;
1514use futures:: stream:: Stream ;
@@ -120,6 +119,7 @@ pub(crate) struct EventScheduler {
120119 event_map : Arc < RwLock < HashMap < EventType , Arc < Event > > > > ,
121120 incoming_events : Arc < SegQueue < TimestampedEvent > > ,
122121 next_event_id : AtomicU64 ,
122+ new_event_notify : Arc < Notify > ,
123123}
124124
125125impl EventScheduler {
@@ -128,6 +128,7 @@ impl EventScheduler {
128128 incoming_events : Arc :: new ( SegQueue :: new ( ) ) ,
129129 event_map : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
130130 next_event_id : AtomicU64 :: new ( 0 ) ,
131+ new_event_notify : Arc :: new ( Notify :: new ( ) ) ,
131132 }
132133 }
133134 /// Schedules a new event with a specified type and deadline.
@@ -154,40 +155,43 @@ impl EventScheduler {
154155 // Keep track of event in a map for interception / cancelling purposes
155156 self . event_map . write ( ) . await . insert ( event_type, event. clone ( ) ) ;
156157
158+ // Notify event stream of new event.
159+ self . new_event_notify . notify_one ( ) ;
160+
157161 Ok ( ( receiver, sender) )
158162 }
159163
160164 /// Intercepts a specified event, changing its state to Intercepted.
161- pub ( crate ) async fn intercept_event ( & self , event_type : & EventType ) -> bool {
165+ pub ( crate ) async fn intercept_event ( & self , event_type : & EventType ) -> Result < bool > {
162166 if let Some ( event) = self . event_map . read ( ) . await . get ( event_type) {
163- event. set_state ( EventState :: Intercepted )
167+ event. sender . send ( EventState :: Intercepted ) . await ?;
168+ Ok ( event. set_state ( EventState :: Intercepted ) )
164169 } else {
165- false
170+ Ok ( false )
166171 }
167172 }
168173 /// Cancels a specified event before it reaches deadline, this can be use for a different case
169174 /// where `intercept_event` is not ideal.
170- pub ( crate ) async fn cancel_event ( & self , event_type : & EventType ) -> bool {
175+ pub ( crate ) async fn cancel_event ( & self , event_type : & EventType ) -> Result < bool > {
171176 if let Some ( event) = self . event_map . read ( ) . await . get ( event_type) {
172- event. set_state ( EventState :: Cancelled )
177+ event. sender . send ( EventState :: Cancelled ) . await ?;
178+ Ok ( event. set_state ( EventState :: Cancelled ) )
173179 } else {
174- false
180+ Ok ( false )
175181 }
176182 }
177183
178184 // Remove event from event mapping.
179- async fn remove_event ( & self , event_type : & EventType ) -> Option < Arc < Event > > {
185+ pub ( crate ) async fn remove_event ( & self , event : & Event ) -> Option < Arc < Event > > {
180186 let mut event_map = self . event_map . write ( ) . await ;
181- event_map. remove ( event_type)
187+ event_map. remove ( & event . event_type )
182188 }
183189}
184190
185191#[ pin_project]
186192pub ( crate ) struct EventStream {
187193 scheduler : Arc < EventScheduler > ,
188194 lobby : BinaryHeap < TimestampedEvent > ,
189- #[ pin]
190- sleep : Pin < Box < Sleep > > ,
191195}
192196
193197impl EventStream {
@@ -196,7 +200,6 @@ impl EventStream {
196200 EventStream {
197201 scheduler,
198202 lobby : BinaryHeap :: new ( ) ,
199- sleep : Box :: pin ( tokio:: time:: sleep ( Duration :: from_secs ( 0 ) ) ) ,
200203 }
201204 }
202205 // Move event stream from events(`SegQueue`) from incoming to lobby for processing
@@ -217,17 +220,9 @@ impl EventStream {
217220 }
218221 moved_events
219222 }
220- // Peek into the lobby and calculate time for the next deadline.
221- fn time_until_next_event ( & self ) -> Option < Duration > {
222- self . lobby . peek ( ) . map ( |event| {
223- let now = Instant :: now ( ) ;
224- if event. deadline > now {
225- event. deadline - now
226- } else {
227- // do it now.
228- Duration :: from_secs ( 0 )
229- }
230- } )
223+ // Peek into the lobby and picks the deadline for the next event.
224+ fn deadline_of_next_event ( & self ) -> Option < Instant > {
225+ self . lobby . peek ( ) . map ( |event| event. deadline )
231226 }
232227}
233228
@@ -236,7 +231,6 @@ impl Stream for EventStream {
236231
237232 fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
238233 let moved_events = self . move_incoming_events_to_lobby ( ) ;
239-
240234 // Process events from the lobby(`BinaryHeap`).
241235 while let Some ( TimestampedEvent { deadline, event } ) = self . lobby . pop ( ) {
242236 // Double-check the deadline in case time has passed since moving events
@@ -256,59 +250,104 @@ impl Stream for EventStream {
256250 }
257251 }
258252
259- if let Some ( next_deadline) = self . time_until_next_event ( ) {
260- let deadline = Instant :: now ( ) + next_deadline;
261- // TODO: we shouldn't be setting a new pin.
262- let new_sleep = Box :: pin ( tokio:: time:: sleep_until ( deadline) ) ;
263- self . as_mut ( ) . project ( ) . sleep . as_mut ( ) . set ( new_sleep) ;
264- }
265253
266- // If we moved events but couldn't process any, we need to try again soon
267- if moved_events {
268- cx. waker ( ) . wake_by_ref ( ) ;
254+ let waker = cx. waker ( ) . clone ( ) ;
255+ let notify = self . scheduler . new_event_notify . clone ( ) ;
256+ if let Some ( next_deadline) = self . deadline_of_next_event ( ) {
257+ tokio:: spawn ( async move {
258+ tokio:: select! {
259+ _ = sleep_until( next_deadline) => { } ,
260+ _ = notify. notified( ) => { } ,
261+ }
262+ waker. wake ( ) ;
263+ } ) ;
269264 return Poll :: Pending ;
270265 }
271266
272- // Check if the sleep future is ready
273- match self . project ( ) . sleep . poll ( cx ) {
274- Poll :: Ready ( _ ) => {
275- cx . waker ( ) . wake_by_ref ( ) ;
276- Poll :: Pending
277- }
278- Poll :: Pending => Poll :: Pending ,
279- }
267+ // Schedule a task that will tell the runtime to check back on me
268+ // when i receive a new event notification.
269+ tokio :: spawn ( async move {
270+ notify . notified ( ) . await ;
271+ waker . wake ( ) ;
272+ } ) ;
273+
274+ return Poll :: Pending ;
280275 }
281276}
282277
283278
284279#[ cfg( test) ]
285280mod tests {
281+ use std:: pin:: pin;
286282 use super :: * ;
287- use tokio:: time:: { sleep , timeout } ;
283+ use tokio:: time:: { timeout , Duration } ;
288284
289285 #[ tokio:: test]
290286 async fn test_intercept_event ( ) {
291- let manager = EventScheduler :: new ( ) ;
287+ let scheduler = EventScheduler :: new ( ) ;
292288 let now = Instant :: now ( ) ;
293289 let event_type = EventType :: SuspectTimeout { node : "node" . to_string ( ) } ;
294290
295- let ( mut receiver, _) = manager . schedule_event ( event_type. clone ( ) , now + Duration :: from_secs ( 1 ) ) . await . unwrap ( ) ;
291+ let ( mut receiver, _) = scheduler . schedule_event ( event_type. clone ( ) , now + Duration :: from_secs ( 1 ) ) . await . unwrap ( ) ;
296292
297- assert ! ( manager . intercept_event( & event_type) . await ) ;
293+ assert ! ( scheduler . intercept_event( & event_type) . await . expect ( "should intercept" ) ) ;
298294
299295 let result = timeout ( Duration :: from_millis ( 100 ) , receiver. recv ( ) ) . await ;
300296 assert ! ( result. is_ok( ) ) ;
301297 assert_eq ! ( result. unwrap( ) , Some ( EventState :: Intercepted ) ) ;
302298 }
299+ #[ tokio:: test]
300+ async fn test_event_stream ( ) {
301+ let scheduler = Arc :: new ( EventScheduler :: new ( ) ) ;
302+ let event_stream = EventStream :: new ( scheduler. clone ( ) ) ;
303+ let mut stream = pin ! ( event_stream) ;
304+
305+ // Schedule multiple events
306+ let now = Instant :: now ( ) ;
307+ let event_type1 = EventType :: SuspectTimeout { node : "node1" . to_string ( ) } ;
308+ let event_type2 = EventType :: Ack { sequence_number : 1 } ;
309+ let event_type3 = EventType :: SuspectTimeout { node : "node2" . to_string ( ) } ;
310+
311+ scheduler. schedule_event ( event_type1. clone ( ) , now + Duration :: from_millis ( 100 ) ) . await . unwrap ( ) ;
312+ scheduler. schedule_event ( event_type2. clone ( ) , now + Duration :: from_millis ( 200 ) ) . await . unwrap ( ) ;
313+ scheduler. schedule_event ( event_type3. clone ( ) , now + Duration :: from_millis ( 300 ) ) . await . unwrap ( ) ;
314+
315+ // Use a timeout to prevent infinite waiting
316+ let timeout_duration = Duration :: from_secs ( 1 ) ;
317+
318+ // Aggregate events
319+ let mut received_events = Vec :: new ( ) ;
320+ while let Ok ( Some ( event) ) = timeout ( timeout_duration, futures:: StreamExt :: next ( & mut stream) ) . await {
321+ received_events. push ( event. event_type . clone ( ) ) ;
322+ if received_events. len ( ) == 3 {
323+ break ;
324+ }
325+ }
326+
327+ // Assert received events
328+ assert_eq ! ( received_events. len( ) , 3 , "Expected to receive 3 events" ) ;
329+ assert_eq ! ( received_events[ 0 ] , event_type1) ;
330+ assert_eq ! ( received_events[ 1 ] , event_type2) ;
331+ assert_eq ! ( received_events[ 2 ] , event_type3) ;
332+
333+ // Check event states
334+ for event_type in & [ event_type1, event_type2, event_type3] {
335+ if let Some ( event) = scheduler. event_map . read ( ) . await . get ( event_type) {
336+ assert_eq ! ( event. get_state( ) , EventState :: ReachedDeadline ) ;
337+ } else {
338+ panic ! ( "Event {:?} not found in the map" , event_type) ;
339+ }
340+ }
341+ }
303342
304343 #[ tokio:: test]
305344 async fn test_duplicate_event_type ( ) {
306- let manager = EventScheduler :: new ( ) ;
345+ let scheduler = EventScheduler :: new ( ) ;
307346 let now = Instant :: now ( ) ;
308347 let event_type = EventType :: SuspectTimeout { node : "node" . to_string ( ) } ;
309348
310- manager . schedule_event ( event_type. clone ( ) , now) . await . unwrap ( ) ;
311- let result = manager . schedule_event ( event_type. clone ( ) , now + Duration :: from_secs ( 1 ) ) . await ;
349+ scheduler . schedule_event ( event_type. clone ( ) , now) . await . unwrap ( ) ;
350+ let result = scheduler . schedule_event ( event_type. clone ( ) , now + Duration :: from_secs ( 1 ) ) . await ;
312351
313352 assert ! ( result. is_err( ) ) ;
314353 }
0 commit comments