Skip to content

Commit c9d2505

Browse files
actor: move more port receiver supervision to rust (#578)
Summary: Pull Request resolved: #578

This diff moves supervision logic from Python into Rust, aligning with the goal of eliminating complex supervision wiring in Python. The essential change is that:

```python
class ActorEndpoint(...):
    def _port(self, once: bool = False) -> "PortTuple[R]":
        monitor = (
            None
            if self._actor_mesh._actor_mesh is None
            else self._actor_mesh._actor_mesh.monitor()
        )
        return PortTuple.create(self._mailbox, monitor, once)
```

becomes:

```python
class ActorEndpoint(...):
    def _port(self, once: bool = False) -> PortTuple[R]:
        p, r = PortTuple.create(self._mailbox, once)
        return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
```

`_supervise(...)` dispatches to the new Rust helpers `mesh.supervise_port(...)` and `mesh.supervise_once_port(...)`, which wrap the receivers with supervision logic (including selection between message arrival and supervision events), completely eliminating the need for Python-side constructs like `ActorMeshMonitor`.

Most of the Python complexity introduced in D77434080 is removed. The only meaningful addition is `_supervise(...)`, a small overridable hook that defaults to a no-op and cleanly delegates to Rust when supervision is desired.

- The creation and wiring of the monitor stream is now fully in Rust.
- The responsibility of wrapping receivers with supervision is now fully in Rust.
- Python no longer constructs or passes supervision monitors; Rust now owns the full wiring, and Python receives already-wrapped receivers with supervision behavior embedded.

This is a strict improvement: lower complexity, cleaner override points, and supervision is entirely managed in Rust.

Differential Revision: D78528860
1 parent 9bba53b commit c9d2505

File tree

6 files changed

+111
-149
lines changed

6 files changed

+111
-149
lines changed

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,32 @@ impl PythonActorMesh {
176176
.map(PyActorId::from))
177177
}
178178

179-
// Start monitoring the actor mesh by subscribing to its supervision events. For each supervision
180-
// event, it is consumed by PythonActorMesh first, then gets sent to the monitor for user to consume.
181-
fn monitor<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
182-
let receiver = self.user_monitor_sender.subscribe();
183-
let monitor_instance = PyActorMeshMonitor {
184-
receiver: SharedCell::from(Mutex::new(receiver)),
179+
fn supervise_port<'py>(
180+
&self,
181+
py: Python<'py>,
182+
receiver: &PythonPortReceiver,
183+
) -> PyResult<PyObject> {
184+
let rx = MonitoredPythonPortReceiver {
185+
inner: receiver.inner(),
186+
monitor: PyActorMeshMonitor {
187+
receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())),
188+
},
189+
};
190+
rx.into_py_any(py)
191+
}
192+
193+
fn supervise_once_port<'py>(
194+
&self,
195+
py: Python<'py>,
196+
receiver: &PythonOncePortReceiver,
197+
) -> PyResult<PyObject> {
198+
let rx = MonitoredPythonOncePortReceiver {
199+
inner: receiver.inner(),
200+
monitor: PyActorMeshMonitor {
201+
receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())),
202+
},
185203
};
186-
Ok(monitor_instance.into_py(py))
204+
rx.into_py_any(py)
187205
}
188206

189207
#[pyo3(signature = (**kwargs))]
@@ -335,29 +353,43 @@ impl Drop for PythonActorMesh {
335353
}
336354
}
337355

356+
// `PyActorMeshMonitor` is not accessed directly from Python. It is
357+
// marked with `#[pyclass]` so it can be used as a field inside
358+
// `MonitoredPythonPortReceiver`.
338359
#[pyclass(
339360
name = "ActorMeshMonitor",
340361
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
341362
)]
342-
pub struct PyActorMeshMonitor {
363+
struct PyActorMeshMonitor {
343364
receiver: SharedCell<Mutex<tokio::sync::broadcast::Receiver<Option<ActorSupervisionEvent>>>>,
344365
}
345366

346367
#[pymethods]
347368
impl PyActorMeshMonitor {
348-
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
349-
slf
350-
}
351-
352-
pub fn __anext__(&self, py: Python<'_>) -> PyResult<PyObject> {
353-
let receiver = self.receiver.clone();
354-
Ok(pyo3_async_runtimes::tokio::future_into_py(py, get_next(receiver))?.into())
369+
fn __repr__(&self) -> &'static str {
370+
"<ActorMeshMonitor>"
355371
}
356372
}
357373

358374
impl PyActorMeshMonitor {
359375
pub async fn next(&self) -> PyResult<PyObject> {
360-
get_next(self.receiver.clone()).await
376+
let receiver = self.receiver.clone();
377+
let receiver = receiver
378+
.borrow()
379+
.expect("`Actor mesh receiver` is shutdown");
380+
let mut receiver = receiver.lock().await;
381+
let event = receiver.recv().await.unwrap();
382+
let supervision_event = match event {
383+
None => PyActorSupervisionEvent {
384+
// Dummy actor as place holder to indicate the whole mesh is stopped
385+
// TODO(albertli): remove this when pushing all supervision logic to rust.
386+
actor_id: id!(default[0].actor[0]).into(),
387+
actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(),
388+
},
389+
Some(event) => PyActorSupervisionEvent::from(event.clone()),
390+
};
391+
392+
Python::with_gil(|py| supervision_event.into_py_any(py))
361393
}
362394
}
363395

@@ -369,49 +401,21 @@ impl Clone for PyActorMeshMonitor {
369401
}
370402
}
371403

372-
async fn get_next(
373-
receiver: SharedCell<Mutex<tokio::sync::broadcast::Receiver<Option<ActorSupervisionEvent>>>>,
374-
) -> PyResult<PyObject> {
375-
let receiver = receiver.clone();
376-
377-
let receiver = receiver
378-
.borrow()
379-
.expect("`Actor mesh receiver` is shutdown");
380-
let mut receiver = receiver.lock().await;
381-
let event = receiver.recv().await.unwrap();
382-
383-
let supervision_event = match event {
384-
None => PyActorSupervisionEvent {
385-
// Dummy actor as place holder to indicate the whole mesh is stopped
386-
// TODO(albertli): remove this when pushing all supervision logic to rust.
387-
actor_id: id!(default[0].actor[0]).into(),
388-
actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(),
389-
},
390-
Some(event) => PyActorSupervisionEvent::from(event.clone()),
391-
};
392-
393-
Python::with_gil(|py| supervision_event.into_py_any(py))
394-
}
395-
396-
// TODO(albertli): this is temporary remove this when pushing all supervision logic to rust.
404+
// Values of this (private) type can only be created by calling
405+
// `PythonActorMesh::supervise_port()`.
397406
#[pyclass(
398407
name = "MonitoredPortReceiver",
399408
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
400409
)]
401-
pub(super) struct MonitoredPythonPortReceiver {
410+
struct MonitoredPythonPortReceiver {
402411
inner: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
403412
monitor: PyActorMeshMonitor,
404413
}
405414

406415
#[pymethods]
407416
impl MonitoredPythonPortReceiver {
408-
#[new]
409-
fn new(receiver: &PythonPortReceiver, monitor: &PyActorMeshMonitor) -> Self {
410-
let inner = receiver.inner();
411-
MonitoredPythonPortReceiver {
412-
inner,
413-
monitor: monitor.clone(),
414-
}
417+
fn __repr__(&self) -> &'static str {
418+
"<MonitoredPortReceiver>"
415419
}
416420

417421
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
@@ -447,24 +451,21 @@ impl MonitoredPythonPortReceiver {
447451
}
448452
}
449453

454+
// Values of this (private) type can only be created by calling
455+
// `PythonActorMesh::supervise_once_port()`.
450456
#[pyclass(
451457
name = "MonitoredOncePortReceiver",
452458
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
453459
)]
454-
pub(super) struct MonitoredPythonOncePortReceiver {
460+
struct MonitoredPythonOncePortReceiver {
455461
inner: Arc<std::sync::Mutex<Option<OncePortReceiver<PythonMessage>>>>,
456462
monitor: PyActorMeshMonitor,
457463
}
458464

459465
#[pymethods]
460466
impl MonitoredPythonOncePortReceiver {
461-
#[new]
462-
fn new(receiver: &PythonOncePortReceiver, monitor: &PyActorMeshMonitor) -> Self {
463-
let inner = receiver.inner();
464-
MonitoredPythonOncePortReceiver {
465-
inner,
466-
monitor: monitor.clone(),
467-
}
467+
fn __repr__(&self) -> &'static str {
468+
"<MonitoredOncePortReceiver>"
468469
}
469470

470471
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,15 @@ class PythonActorMesh:
8888
"""
8989
...
9090

91-
# TODO(albertli): remove this when pushing all supervision logic to Rust
92-
def monitor(self) -> ActorMeshMonitor:
91+
def supervise_port(self, r: PortReceiver) -> MonitoredPortReceiver:
9392
"""
94-
Returns a supervision monitor for this mesh.
93+
Return a monitored port receiver.
94+
"""
95+
...
96+
97+
def supervise_once_port(self, r: OncePortReceiver) -> MonitoredOncePortReceiver:
98+
"""
99+
Return a monitored once port receiver.
95100
"""
96101
...
97102

@@ -113,31 +118,11 @@ class PythonActorMesh:
113118
"""
114119
...
115120

116-
@final
117-
class ActorMeshMonitor:
118-
def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]:
119-
"""
120-
Returns an async iterator for this monitor.
121-
"""
122-
...
123-
124-
async def __anext__(self) -> "ActorSupervisionEvent":
125-
"""
126-
Returns the next proc event in the proc mesh.
127-
"""
128-
...
129-
130121
@final
131122
class MonitoredPortReceiver:
123+
"""A monitored receiver to which PythonMessages are sent. Values
124+
of this type cannot be constructed directly in Python.
132125
"""
133-
A monitored receiver to which PythonMessages are sent.
134-
"""
135-
136-
def __init__(self, receiver: PortReceiver, monitor: ActorMeshMonitor) -> None:
137-
"""
138-
Create a new monitored receiver from a PortReceiver.
139-
"""
140-
...
141126

142127
async def recv(self) -> PythonMessage:
143128
"""Receive a PythonMessage from the port's sender."""
@@ -148,15 +133,9 @@ class MonitoredPortReceiver:
148133

149134
@final
150135
class MonitoredOncePortReceiver:
136+
"""A monitored once receiver to which PythonMessages are sent.
137+
Values of this type cannot be constructed directly in Python.
151138
"""
152-
A variant of monitored PortReceiver that can only receive a single message.
153-
"""
154-
155-
def __init__(self, receiver: OncePortReceiver, monitor: ActorMeshMonitor) -> None:
156-
"""
157-
Create a new monitored receiver from a PortReceiver.
158-
"""
159-
...
160139

161140
async def recv(self) -> PythonMessage:
162141
"""Receive a single PythonMessage from the port's sender."""

python/monarch/_src/actor/actor_mesh.py

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
Optional,
3939
overload,
4040
ParamSpec,
41+
Protocol,
42+
runtime_checkable,
4143
Sequence,
4244
Tuple,
4345
Type,
@@ -50,12 +52,7 @@
5052
PythonMessage,
5153
PythonMessageKind,
5254
)
53-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
54-
ActorMeshMonitor,
55-
MonitoredOncePortReceiver,
56-
MonitoredPortReceiver,
57-
PythonActorMesh,
58-
)
55+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
5956
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
6057
Mailbox,
6158
OncePortReceiver,
@@ -306,6 +303,9 @@ def _send(
306303
def _port(self, once: bool = False) -> "PortTuple[R]":
307304
pass
308305

306+
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
307+
return r
308+
309309
# the following are all 'adverbs' or different ways to handle the
310310
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
311311
# of the original call. If we want to add syntax sugar for something that needs additional
@@ -399,6 +399,16 @@ def __init__(
399399
self._signature: inspect.Signature = inspect.signature(impl)
400400
self._mailbox = mailbox
401401

402+
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
403+
mesh = self._actor_mesh._actor_mesh
404+
if mesh is None:
405+
return r
406+
return (
407+
mesh.supervise_once_port(r)
408+
if isinstance(r, OncePortReceiver)
409+
else mesh.supervise_port(r)
410+
)
411+
402412
def _send(
403413
self,
404414
args: Tuple[Any, ...],
@@ -430,12 +440,11 @@ def _send(
430440
return Extent(shape.labels, shape.ndslice.sizes)
431441

432442
def _port(self, once: bool = False) -> "PortTuple[R]":
433-
monitor = (
434-
None
435-
if self._actor_mesh._actor_mesh is None
436-
else self._actor_mesh._actor_mesh.monitor()
437-
)
438-
return PortTuple.create(self._mailbox, monitor, once)
443+
p, r = PortTuple.create(self._mailbox, once)
444+
assert isinstance(
445+
r._receiver, (HyPortReceiver, OncePortReceiver)
446+
), "unexpected receiver type"
447+
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
439448

440449

441450
class Accumulator(Generic[P, R, A]):
@@ -589,21 +598,11 @@ class PortTuple(NamedTuple, Generic[R]):
589598
receiver: "PortReceiver[R]"
590599

591600
@staticmethod
592-
def create(
593-
mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False
594-
) -> "PortTuple[Any]":
601+
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
595602
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
596603
port_ref = handle.bind()
597-
if monitor is not None:
598-
receiver = (
599-
MonitoredOncePortReceiver(receiver, monitor)
600-
if isinstance(receiver, OncePortReceiver)
601-
else MonitoredPortReceiver(receiver, monitor)
602-
)
603-
604604
return PortTuple(
605-
Port(port_ref, mailbox, rank=None),
606-
PortReceiver(mailbox, receiver),
605+
Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver)
607606
)
608607
else:
609608

@@ -612,21 +611,11 @@ class PortTuple(NamedTuple):
612611
receiver: "PortReceiver[Any]"
613612

614613
@staticmethod
615-
def create(
616-
mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False
617-
) -> "PortTuple[Any]":
614+
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
618615
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
619616
port_ref = handle.bind()
620-
if monitor is not None:
621-
receiver = (
622-
MonitoredOncePortReceiver(receiver, monitor)
623-
if isinstance(receiver, OncePortReceiver)
624-
else MonitoredPortReceiver(receiver, monitor)
625-
)
626-
627617
return PortTuple(
628-
Port(port_ref, mailbox, rank=None),
629-
PortReceiver(mailbox, receiver),
618+
Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver)
630619
)
631620

632621

@@ -644,22 +633,19 @@ def ranked_port(
644633
return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver)
645634

646635

636+
R = TypeVar("R")
637+
638+
639+
@runtime_checkable
640+
class ReceiverLike(Protocol[R]):
641+
def blocking_recv(self) -> R: ...
642+
async def recv(self) -> R: ...
643+
644+
647645
class PortReceiver(Generic[R]):
648-
def __init__(
649-
self,
650-
mailbox: Mailbox,
651-
receiver: MonitoredPortReceiver
652-
| MonitoredOncePortReceiver
653-
| HyPortReceiver
654-
| OncePortReceiver,
655-
) -> None:
656-
self._mailbox: Mailbox = mailbox
657-
self._receiver: (
658-
MonitoredPortReceiver
659-
| MonitoredOncePortReceiver
660-
| HyPortReceiver
661-
| OncePortReceiver
662-
) = receiver
646+
def __init__(self, mailbox: Mailbox, receiver: ReceiverLike) -> None:
647+
self._mailbox = mailbox
648+
self._receiver: ReceiverLike = receiver
663649

664650
async def _recv(self) -> R:
665651
return self._process(await self._receiver.recv())

0 commit comments

Comments
 (0)