Skip to content

Commit 98f7c5b

Browse files
: actor: port receiver supervision (#578)
Summary: this diff moves supervision logic from python into rust, aligning with the goal of eliminating complex supervision wiring in python. the essential change is that: ``` class ActorEndpoint(...): def _port(self, once: bool = False) -> "PortTuple[R]": monitor = ( None if self._actor_mesh._actor_mesh is None else self._actor_mesh._actor_mesh.monitor() ) return PortTuple.create(self._mailbox, monitor, once) ``` becomes: ``` class ActorEndpoint(...): def _port(self, once: bool = False) -> PortTuple[R]: p, r = PortTuple.create(self._mailbox, once) return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver))) ``` `_supervise(...)` dispatches to new Rust helpers: ```python mesh.supervise_port(...) and mesh.supervise_once_port(...) ``` which wrap the receivers with supervision logic (including selection between message arrival and supervision events), completely eliminating the need for python-side constructs like `ActorMeshMonitor`. most of the python complexity introduced in D77434080 is removed. the only meaningful addition is `_supervise(...)`, a small overrideable hook that defaults to a no-op and cleanly delegates to rust when supervision is desired. this is a strict improvement: lower complexity, cleaner override points and supervision is entirely managed in rust. Differential Revision: D78528860
1 parent c463ecc commit 98f7c5b

File tree

5 files changed

+88
-126
lines changed

5 files changed

+88
-126
lines changed

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -175,14 +175,32 @@ impl PythonActorMesh {
175175
.map(PyActorId::from))
176176
}
177177

178-
// Start monitoring the actor mesh by subscribing to its supervision events. For each supervision
179-
// event, it is consumed by PythonActorMesh first, then gets sent to the monitor for user to consume.
180-
fn monitor<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
181-
let receiver = self.user_monitor_sender.subscribe();
182-
let monitor_instance = PyActorMeshMonitor {
183-
receiver: SharedCell::from(Mutex::new(receiver)),
178+
fn supervise_port<'py>(
179+
&self,
180+
py: Python<'py>,
181+
receiver: &PythonPortReceiver,
182+
) -> PyResult<PyObject> {
183+
let rx = MonitoredPythonPortReceiver {
184+
inner: receiver.inner(),
185+
monitor: PyActorMeshMonitor {
186+
receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())),
187+
},
188+
};
189+
Ok(rx.into_py(py))
190+
}
191+
192+
fn supervise_once_port<'py>(
193+
&self,
194+
py: Python<'py>,
195+
receiver: &PythonOncePortReceiver,
196+
) -> PyResult<PyObject> {
197+
let rx = MonitoredPythonOncePortReceiver {
198+
inner: receiver.inner(),
199+
monitor: PyActorMeshMonitor {
200+
receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())),
201+
},
184202
};
185-
Ok(monitor_instance.into_py(py))
203+
Ok(rx.into_py(py))
186204
}
187205

188206
#[pyo3(signature = (**kwargs))]
@@ -334,27 +352,22 @@ impl Drop for PythonActorMesh {
334352
}
335353
}
336354

355+
// `PyActorMeshMonitor` is not accessed directly from Python. It is
356+
// marked with `#[pyclass]` so it can be used as a field inside
357+
// `MonitoredPythonPortReceiver`.
337358
#[pyclass(
338359
name = "ActorMeshMonitor",
339360
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
340361
)]
341-
pub struct PyActorMeshMonitor {
362+
struct PyActorMeshMonitor {
342363
receiver: SharedCell<Mutex<tokio::sync::broadcast::Receiver<Option<ActorSupervisionEvent>>>>,
343364
}
344365

345-
#[pymethods]
346366
impl PyActorMeshMonitor {
347-
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
348-
slf
367+
fn __repr__(&self) -> &'static str {
368+
"<ActorMeshMonitor>"
349369
}
350370

351-
pub fn __anext__(&self, py: Python<'_>) -> PyResult<PyObject> {
352-
let receiver = self.receiver.clone();
353-
Ok(pyo3_async_runtimes::tokio::future_into_py(py, get_next(receiver))?.into())
354-
}
355-
}
356-
357-
impl PyActorMeshMonitor {
358371
pub async fn next(&self) -> PyResult<PyObject> {
359372
get_next(self.receiver.clone()).await
360373
}
@@ -392,25 +405,21 @@ async fn get_next(
392405
Ok(Python::with_gil(|py| supervision_event.into_py(py)))
393406
}
394407

395-
// TODO(albertli): this is temporary remove this when pushing all supervision logic to rust.
408+
// Values of this (private) type can only be created by calling
409+
// `PythonActorMesh::supervise_port()`.
396410
#[pyclass(
397411
name = "MonitoredPortReceiver",
398412
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
399413
)]
400-
pub(super) struct MonitoredPythonPortReceiver {
414+
struct MonitoredPythonPortReceiver {
401415
inner: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
402416
monitor: PyActorMeshMonitor,
403417
}
404418

405419
#[pymethods]
406420
impl MonitoredPythonPortReceiver {
407-
#[new]
408-
fn new(receiver: &PythonPortReceiver, monitor: &PyActorMeshMonitor) -> Self {
409-
let inner = receiver.inner();
410-
MonitoredPythonPortReceiver {
411-
inner,
412-
monitor: monitor.clone(),
413-
}
421+
fn __repr__(&self) -> &'static str {
422+
"<MonitoredPortReceiver>"
414423
}
415424

416425
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
@@ -446,24 +455,21 @@ impl MonitoredPythonPortReceiver {
446455
}
447456
}
448457

458+
// Values of this (private) type can only be created by calling
459+
// `PythonActorMesh::supervise_once_port()`.
449460
#[pyclass(
450461
name = "MonitoredOncePortReceiver",
451462
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
452463
)]
453-
pub(super) struct MonitoredPythonOncePortReceiver {
464+
struct MonitoredPythonOncePortReceiver {
454465
inner: Arc<std::sync::Mutex<Option<OncePortReceiver<PythonMessage>>>>,
455466
monitor: PyActorMeshMonitor,
456467
}
457468

458469
#[pymethods]
459470
impl MonitoredPythonOncePortReceiver {
460-
#[new]
461-
fn new(receiver: &PythonOncePortReceiver, monitor: &PyActorMeshMonitor) -> Self {
462-
let inner = receiver.inner();
463-
MonitoredPythonOncePortReceiver {
464-
inner,
465-
monitor: monitor.clone(),
466-
}
471+
fn __repr__(&self) -> &'static str {
472+
"<MonitoredOncePortReceiver>"
467473
}
468474

469475
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,15 @@ class PythonActorMesh:
8888
"""
8989
...
9090

91-
# TODO(albertli): remove this when pushing all supervision logic to Rust
92-
def monitor(self) -> ActorMeshMonitor:
91+
def supervise_port(self, r: PortReceiver) -> MonitoredPortReceiver:
9392
"""
94-
Returns a supervision monitor for this mesh.
93+
Return a monitored port receiver.
94+
"""
95+
...
96+
97+
def supervise_once_port(self, r: OncePortReceiver) -> MonitoredOncePortReceiver:
98+
"""
99+
Return a monitored once port receiver.
95100
"""
96101
...
97102

@@ -113,31 +118,11 @@ class PythonActorMesh:
113118
"""
114119
...
115120

116-
@final
117-
class ActorMeshMonitor:
118-
def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]:
119-
"""
120-
Returns an async iterator for this monitor.
121-
"""
122-
...
123-
124-
async def __anext__(self) -> "ActorSupervisionEvent":
125-
"""
126-
Returns the next proc event in the proc mesh.
127-
"""
128-
...
129-
130121
@final
131122
class MonitoredPortReceiver:
123+
"""A monitored receiver to which PythonMessages are sent. Values
124+
of this type cannot be constructed directly in Python.
132125
"""
133-
A monitored receiver to which PythonMessages are sent.
134-
"""
135-
136-
def __init__(self, receiver: PortReceiver, monitor: ActorMeshMonitor) -> None:
137-
"""
138-
Create a new monitored receiver from a PortReceiver.
139-
"""
140-
...
141126

142127
async def recv(self) -> PythonMessage:
143128
"""Receive a PythonMessage from the port's sender."""
@@ -148,15 +133,9 @@ class MonitoredPortReceiver:
148133

149134
@final
150135
class MonitoredOncePortReceiver:
136+
"""A monitored once receiver to which PythonMessages are sent.
137+
Values of this type cannot be constructed directly in Python.
151138
"""
152-
A variant of monitored PortReceiver that can only receive a single message.
153-
"""
154-
155-
def __init__(self, receiver: OncePortReceiver, monitor: ActorMeshMonitor) -> None:
156-
"""
157-
Create a new monitored receiver from a PortReceiver.
158-
"""
159-
...
160139

161140
async def recv(self) -> PythonMessage:
162141
"""Receive a single PythonMessage from the port's sender."""

python/monarch/_src/actor/actor_mesh.py

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
Optional,
3939
overload,
4040
ParamSpec,
41+
Protocol,
42+
runtime_checkable,
4143
Sequence,
4244
Tuple,
4345
Type,
@@ -50,12 +52,7 @@
5052
PythonMessage,
5153
PythonMessageKind,
5254
)
53-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
54-
ActorMeshMonitor,
55-
MonitoredOncePortReceiver,
56-
MonitoredPortReceiver,
57-
PythonActorMesh,
58-
)
55+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
5956
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
6057
Mailbox,
6158
OncePortReceiver,
@@ -306,6 +303,9 @@ def _send(
306303
def _port(self, once: bool = False) -> "PortTuple[R]":
307304
pass
308305

306+
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
307+
return r
308+
309309
# the following are all 'adverbs' or different ways to handle the
310310
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
311311
# of the original call. If we want to add syntax sugar for something that needs additional
@@ -399,6 +399,14 @@ def __init__(
399399
self._signature: inspect.Signature = inspect.signature(impl)
400400
self._mailbox = mailbox
401401

402+
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
403+
mesh = self._actor_mesh._actor_mesh
404+
return (
405+
mesh.supervise_once_port(r)
406+
if isinstance(r, OncePortReceiver)
407+
else mesh.supervise_port(r)
408+
)
409+
402410
def _send(
403411
self,
404412
args: Tuple[Any, ...],
@@ -430,12 +438,8 @@ def _send(
430438
return Extent(shape.labels, shape.ndslice.sizes)
431439

432440
def _port(self, once: bool = False) -> "PortTuple[R]":
433-
monitor = (
434-
None
435-
if self._actor_mesh._actor_mesh is None
436-
else self._actor_mesh._actor_mesh.monitor()
437-
)
438-
return PortTuple.create(self._mailbox, monitor, once)
441+
p, r = PortTuple.create(self._mailbox, once)
442+
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
439443

440444

441445
class Accumulator(Generic[P, R, A]):
@@ -589,21 +593,11 @@ class PortTuple(NamedTuple, Generic[R]):
589593
receiver: "PortReceiver[R]"
590594

591595
@staticmethod
592-
def create(
593-
mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False
594-
) -> "PortTuple[Any]":
596+
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
595597
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
596598
port_ref = handle.bind()
597-
if monitor is not None:
598-
receiver = (
599-
MonitoredOncePortReceiver(receiver, monitor)
600-
if isinstance(receiver, OncePortReceiver)
601-
else MonitoredPortReceiver(receiver, monitor)
602-
)
603-
604599
return PortTuple(
605-
Port(port_ref, mailbox, rank=None),
606-
PortReceiver(mailbox, receiver),
600+
Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver)
607601
)
608602
else:
609603

@@ -612,21 +606,11 @@ class PortTuple(NamedTuple):
612606
receiver: "PortReceiver[Any]"
613607

614608
@staticmethod
615-
def create(
616-
mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False
617-
) -> "PortTuple[Any]":
609+
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
618610
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
619611
port_ref = handle.bind()
620-
if monitor is not None:
621-
receiver = (
622-
MonitoredOncePortReceiver(receiver, monitor)
623-
if isinstance(receiver, OncePortReceiver)
624-
else MonitoredPortReceiver(receiver, monitor)
625-
)
626-
627612
return PortTuple(
628-
Port(port_ref, mailbox, rank=None),
629-
PortReceiver(mailbox, receiver),
613+
Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver)
630614
)
631615

632616

@@ -644,22 +628,19 @@ def ranked_port(
644628
return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver)
645629

646630

631+
R = TypeVar("R")
632+
633+
634+
@runtime_checkable
635+
class ReceiverLike(Protocol[R]):
636+
def blocking_recv(self) -> R: ...
637+
async def recv(self) -> R: ...
638+
639+
647640
class PortReceiver(Generic[R]):
648-
def __init__(
649-
self,
650-
mailbox: Mailbox,
651-
receiver: MonitoredPortReceiver
652-
| MonitoredOncePortReceiver
653-
| HyPortReceiver
654-
| OncePortReceiver,
655-
) -> None:
656-
self._mailbox: Mailbox = mailbox
657-
self._receiver: (
658-
MonitoredPortReceiver
659-
| MonitoredOncePortReceiver
660-
| HyPortReceiver
661-
| OncePortReceiver
662-
) = receiver
641+
def __init__(self, mailbox: Mailbox, receiver: ReceiverLike) -> None:
642+
self._mailbox = mailbox
643+
self._receiver: ReceiverLike = receiver
663644

664645
async def _recv(self) -> R:
665646
return self._process(await self._receiver.recv())

python/monarch/common/remote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _port(self, once: bool = False) -> "PortTuple[R]":
144144
"Cannot create raw port objects with an old-style tensor engine controller."
145145
)
146146
mailbox: Mailbox = mesh_controller._mailbox
147-
return PortTuple.create(mailbox, None, once)
147+
return PortTuple.create(mailbox, once)
148148

149149
@property
150150
def _resolvable(self):

python/monarch/mesh_controller.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,7 @@ def fetch(
149149
defs: Tuple["Tensor", ...],
150150
uses: Tuple["Tensor", ...],
151151
) -> "OldFuture": # the OldFuture is a lie
152-
sender, receiver = PortTuple.create(
153-
self._mesh_controller._mailbox, None, once=True
154-
)
152+
sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
155153

156154
ident = self.new_node(defs, uses, cast("OldFuture", sender))
157155
process = mesh._process(shard)
@@ -187,9 +185,7 @@ def shutdown(
187185
atexit.unregister(self._atexit)
188186
self._shutdown = True
189187

190-
sender, receiver = PortTuple.create(
191-
self._mesh_controller._mailbox, None, once=True
192-
)
188+
sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
193189
assert sender._port_ref is not None
194190
self._mesh_controller.sync_at_exit(sender._port_ref.port_id)
195191
receiver.recv().get(timeout=60)

0 commit comments

Comments
 (0)