Skip to content

Commit 7543316

Browse files
pzhan9facebook-github-bot
authored andcommitted
Migrate to PythonActorMesh and PythonActorMeshRef
Summary: This diff swaps `_ActorMeshRefImpl` with `PythonActorMesh[Ref]`. The swap itself should be straightforward, since `PythonActorMesh[Ref]` should be a drop-in replacement for `_ActorMeshRefImpl`. Most of the complexity in this diff comes from the toggle I added between them, so that we can quickly switch back to `_ActorMeshRefImpl` in case there are any bugs with `PythonActorMesh[Ref]`. What I did is: 1. Add wrapper classes `EitherPyActorMesh[Ref]`, whose underlying type can be either `PythonActorMesh[Ref]` or `_ActorMeshRefImpl`; 2. Add an env var, `USE_STANDIN_ACTOR_MESH`, which decides which of the two is used when instantiating `EitherPyActorMesh[Ref]`. Once this diff lands, all Python-side mesh API calls should go through the Rust-side `cast` code path, except for several usages of `ActorIdRef`. Differential Revision: D78355743
1 parent 6d41b1f commit 7543316

File tree

2 files changed

+170
-27
lines changed

2 files changed

+170
-27
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 165 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import inspect
1414
import itertools
1515
import logging
16+
import os
1617
import pickle
1718
import random
1819
import traceback
@@ -56,6 +57,7 @@
5657
MonitoredOncePortReceiver,
5758
MonitoredPortReceiver,
5859
PythonActorMesh,
60+
PythonActorMeshRef,
5961
)
6062
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
6163
Mailbox,
@@ -65,9 +67,9 @@
6567
PortRef,
6668
)
6769
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
70+
from monarch._rust_bindings.monarch_hyperactor.selection import Selection as HySelection
6871
from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
6972
from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError
70-
7173
from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span
7274
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
7375
from monarch._src.actor.future import Future
@@ -146,6 +148,28 @@ def set(debug_context: "DebugContext") -> None:
146148
Selection = Literal["all", "choose"] | int # TODO: replace with real selection objects
147149

148150

151+
def to_hy_sel(selection: Selection, shape: Shape) -> HySelection:
    """Translate a Python-side ``Selection`` into a ``HySelection``.

    Args:
        selection: ``"all"`` or ``"choose"``. Any other value (including the
            ``int`` variant allowed by the ``Selection`` alias) is rejected.
        shape: Shape of the target mesh; only ``shape.labels`` is read, and
            only for ``"choose"``.

    Returns:
        ``HySelection.from_string("*")`` for ``"all"``; for ``"choose"``, the
        selection parsed from one ``"?"`` per label in ``shape``, joined by
        commas (e.g. ``"?,?"`` for a two-label shape).

    Raises:
        ValueError: if ``selection`` is neither ``"all"`` nor ``"choose"``.
    """
    if selection == "all":
        return HySelection.from_string("*")
    if selection != "choose":
        raise ValueError(f"invalid selection: {selection}")

    rank = len(shape.labels)
    # A "choose" query needs at least one dimension to place a "?" in.
    assert rank > 0
    return HySelection.from_string(",".join("?" for _ in range(rank)))
161+
162+
163+
# A temporary gate used by the PythonActorMesh/PythonActorMeshRef migration.
164+
# We can use this gate to quickly roll back to using _ActorMeshRefImpl, if we
165+
# encounter any issues with the migration.
166+
#
167+
# This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is
168+
# working correctly in production.
169+
def _use_standin_mesh() -> bool:
    """Report whether the ``USE_STANDIN_ACTOR_MESH`` env var is switched on.

    The variable is read on every call, so changes to the environment take
    effect immediately.

    Returns:
        ``False`` when the variable is unset, empty/whitespace-only, or one of
        ``0``, ``false``, ``no``, ``off`` (case-insensitive); ``True`` for any
        other value (e.g. ``1``, ``true``).
    """
    # NOTE: a plain ``bool(os.getenv(...))`` treats every non-empty string as
    # true, so ``USE_STANDIN_ACTOR_MESH=0`` would *enable* the gate. Parse the
    # value explicitly instead.
    raw = os.getenv("USE_STANDIN_ACTOR_MESH", "")
    return raw.strip().lower() not in ("", "0", "false", "no", "off")
171+
172+
149173
# standin class for whatever is the serializable python object we use
150174
# to name an actor mesh. Hacked up today because ActorMesh
151175
# isn't plumbed to non-clients
@@ -158,6 +182,10 @@ def __init__(
158182
shape: Shape,
159183
actor_ids: List[ActorId],
160184
) -> None:
185+
if not _use_standin_mesh():
186+
raise ValueError(
187+
"ActorMeshRefImpl should only be used when USE_STANDIN_ACTOR_MESH is set"
188+
)
161189
self._mailbox = mailbox
162190
self._actor_mesh = hy_actor_mesh
163191
# actor meshes do not have a way to look this up at the moment,
@@ -296,8 +324,8 @@ def __init__(
296324

297325
def cast(
298326
self,
299-
message: PythonMessage,
300327
selection: Selection,
328+
message: PythonMessage,
301329
) -> None:
302330
self._mailbox.post(self._actor_id, message)
303331

@@ -309,6 +337,110 @@ def monitor(self) -> Optional[ActorMeshMonitor]:
309337
return None
310338

311339

340+
# A temporary wrapper used by the PythonActorMesh/PythonActorMeshRef migration.
341+
# This wrapper is used to enable switching between PythonActorMeshRef and
342+
# _ActorMeshRefImpl through the `USE_STANDIN_ACTOR_MESH` env var.
343+
#
344+
# This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is
345+
# working correctly in production.
346+
class EitherPyActorMeshRef:
    """Migration shim wrapping either a ``PythonActorMeshRef`` or an
    ``_ActorMeshRefImpl``.

    Which of the two is accepted is decided by ``_use_standin_mesh()`` at
    construction time; the wrapper then forwards calls to the wrapped object,
    adapting argument order where the two APIs differ.
    """

    def __init__(self, inner: PythonActorMeshRef | _ActorMeshRefImpl) -> None:
        """Wrap ``inner`` after checking it matches the active gate setting."""
        standin_enabled = _use_standin_mesh()
        if not standin_enabled:
            assert isinstance(
                inner, PythonActorMeshRef
            ), "expect PythonActorMeshRef because env var USE_STANDIN_ACTOR_MESH is not set"
        else:
            assert isinstance(
                inner, _ActorMeshRefImpl
            ), "expect _ActorMeshRefImpl because env var USE_STANDIN_ACTOR_MESH is set"
        self._inner: PythonActorMeshRef | _ActorMeshRefImpl = inner

    def cast(
        self, mailbox: Mailbox, selection: Selection, message: PythonMessage
    ) -> None:
        """Send ``message`` to the actors picked by ``selection``.

        ``mailbox`` is only forwarded on the ``PythonActorMeshRef`` path; the
        ``_ActorMeshRefImpl`` path does not receive it.

        Raises:
            ValueError: if the wrapped object is of neither supported type.
        """
        inner = self._inner
        if isinstance(inner, _ActorMeshRefImpl):
            # The stand-in takes (message, selection), in that order.
            inner.cast(message, selection)
            return
        if isinstance(inner, PythonActorMeshRef):
            hy_selection = to_hy_sel(selection, self.shape)
            inner.cast(mailbox, hy_selection, message)
            return
        raise ValueError(f"unsupported mesh type: {inner.__class__.__name__}")

    def slice(self, **kwargs) -> "EitherPyActorMeshRef":
        """Slice the wrapped mesh and re-wrap the result."""
        sliced = self._inner.slice(**kwargs)
        return EitherPyActorMeshRef(sliced)

    @property
    def shape(self) -> Shape:
        """Shape reported by the wrapped mesh."""
        return self._inner.shape

    def monitor(self) -> Optional[ActorMeshMonitor]:
        """Always ``None``: this wrapper does not expose a monitor."""
        return None
378+
379+
380+
# A temporary wrapper used by the PythonActorMesh/PythonActorMeshRef migration.
381+
# This wrapper is used to enable switching between PythonActorMesh and
382+
# _ActorMeshRefImpl through the `USE_STANDIN_ACTOR_MESH` env var.
383+
#
384+
# This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is
385+
# working correctly in production.
386+
class EitherPyActorMesh:
    """Migration shim wrapping either a ``PythonActorMesh`` or an
    ``_ActorMeshRefImpl`` built from it.

    The choice is made once, in ``__init__``, from ``_use_standin_mesh()``.
    """

    def __init__(
        self, actor_mesh: PythonActorMesh, mailbox: Mailbox, proc_mesh: "ProcMesh"
    ) -> None:
        """Pick the backing implementation for ``actor_mesh``.

        When the stand-in gate is on, ``actor_mesh`` is converted through
        ``_ActorMeshRefImpl.from_hyperactor_mesh``; otherwise it is stored
        as-is. ``proc_mesh`` is kept for the ``proc_mesh`` property.
        """
        self._inner: PythonActorMesh | _ActorMeshRefImpl = (
            _ActorMeshRefImpl.from_hyperactor_mesh(mailbox, actor_mesh, proc_mesh)
            if _use_standin_mesh()
            else actor_mesh
        )
        self._proc_mesh = proc_mesh

    def bind(self) -> "EitherPyActorMeshRef":
        """Return an ``EitherPyActorMeshRef`` for the wrapped mesh.

        A ``PythonActorMesh`` is bound via its own ``bind()``; an
        ``_ActorMeshRefImpl`` is wrapped directly.

        Raises:
            ValueError: if the wrapped object is of neither supported type.
        """
        inner = self._inner
        if isinstance(inner, PythonActorMesh):
            return EitherPyActorMeshRef(inner.bind())
        if isinstance(inner, _ActorMeshRefImpl):
            return EitherPyActorMeshRef(inner)
        raise ValueError(f"unsupported mesh type: {inner.__class__.__name__}")

    def cast(self, selection: Selection, message: PythonMessage) -> None:
        """Send ``message`` to the actors picked by ``selection``.

        Raises:
            ValueError: if the wrapped object is of neither supported type.
        """
        inner = self._inner
        if isinstance(inner, _ActorMeshRefImpl):
            # The stand-in takes (message, selection), in that order.
            inner.cast(message, selection)
            return
        if isinstance(inner, PythonActorMesh):
            hy_selection = to_hy_sel(selection, self.shape)
            inner.cast(hy_selection, message)
            return
        raise ValueError(f"unsupported mesh type: {inner.__class__.__name__}")

    def slice(self, **kwargs) -> "EitherPyActorMeshRef":
        """Slice the wrapped mesh and wrap the result as a ref."""
        sliced = self._inner.slice(**kwargs)
        return EitherPyActorMeshRef(sliced)

    @property
    def shape(self) -> Shape:
        """Shape reported by the wrapped mesh."""
        return self._inner.shape

    def monitor(self) -> Optional[ActorMeshMonitor]:
        """Delegate to the wrapped mesh's ``monitor()``."""
        return self._inner.monitor()

    @property
    def proc_mesh(self) -> "ProcMesh":
        """The proc mesh object passed to the constructor."""
        return self._proc_mesh

    @property
    def name_pid(self) -> Tuple[str, int]:
        """``(actor_name, pid)`` identifying this mesh.

        Read from ``_name_pid`` on the stand-in, or derived from the actor id
        at index 0 of a ``PythonActorMesh``.

        Raises:
            ValueError: if the wrapped object is of neither supported type.
        """
        inner = self._inner
        if isinstance(inner, _ActorMeshRefImpl):
            return inner._name_pid
        if isinstance(inner, PythonActorMesh):
            first_actor_id = inner.get(0)
            assert first_actor_id is not None
            return first_actor_id.actor_name, first_actor_id.pid
        raise ValueError(f"unsupported mesh type: {inner.__class__.__name__}")
442+
443+
312444
class Extent(NamedTuple):
313445
labels: Sequence[str]
314446
sizes: Sequence[int]
@@ -377,26 +509,28 @@ def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
377509
extent = self._send(args, kwargs, port=p)
378510

379511
async def process() -> ValueMesh[R]:
380-
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
512+
results: Dict[int, R] = dict()
381513
for _ in range(extent.nelements):
382514
rank, value = await r.recv()
383515
results[rank] = value
384516
call_shape = Shape(
385517
extent.labels,
386518
NDSlice.new_row_major(extent.sizes),
387519
)
388-
return ValueMesh(call_shape, results)
520+
sorted_values = [results[rank] for rank in sorted(results)]
521+
return ValueMesh(call_shape, sorted_values)
389522

390523
def process_blocking() -> ValueMesh[R]:
391-
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
524+
results: Dict[int, R] = dict()
392525
for _ in range(extent.nelements):
393526
rank, value = r.recv().get()
394527
results[rank] = value
395528
call_shape = Shape(
396529
extent.labels,
397530
NDSlice.new_row_major(extent.sizes),
398531
)
399-
return ValueMesh(call_shape, results)
532+
sorted_values = [results[rank] for rank in sorted(results)]
533+
return ValueMesh(call_shape, sorted_values)
400534

401535
return Future(process, process_blocking)
402536

@@ -428,7 +562,7 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
428562
class ActorEndpoint(Endpoint[P, R]):
429563
def __init__(
430564
self,
431-
actor_mesh_ref: _ActorMeshRefImpl | ActorIdFakeMesh,
565+
actor_mesh_ref: EitherPyActorMesh | EitherPyActorMeshRef | ActorIdFakeMesh,
432566
name: str,
433567
impl: Callable[Concatenate[Any, P], Awaitable[R]],
434568
mailbox: Mailbox,
@@ -469,7 +603,15 @@ def _send(
469603
),
470604
bytes,
471605
)
472-
self._actor_mesh.cast(message, selection)
606+
mesh = self._actor_mesh
607+
if isinstance(mesh, EitherPyActorMeshRef):
608+
mesh.cast(self._mailbox, selection, message)
609+
elif isinstance(mesh, EitherPyActorMesh) or isinstance(
610+
mesh, ActorIdFakeMesh
611+
):
612+
mesh.cast(selection, message)
613+
else:
614+
raise ValueError(f"unsupported mesh type: {mesh.__class__.__name__}")
473615
else:
474616
importlib.import_module("monarch." + "mesh_controller").actor_send(
475617
self, self._name, bytes, refs, port
@@ -931,12 +1073,14 @@ class _ActorMeshTrait(MeshTrait):
9311073
def __init__(
9321074
self,
9331075
Class: Type[T],
934-
actor_mesh_ref: _ActorMeshRefImpl | ActorIdFakeMesh,
1076+
actor_mesh_ref: EitherPyActorMesh | EitherPyActorMeshRef | ActorIdFakeMesh,
9351077
mailbox: Mailbox,
9361078
) -> None:
9371079
self.__name__: str = Class.__name__
9381080
self._class: Type[T] = Class
939-
self._actor_mesh_ref: _ActorMeshRefImpl | ActorIdFakeMesh = actor_mesh_ref
1081+
self._actor_mesh_ref: (
1082+
EitherPyActorMesh | EitherPyActorMeshRef | ActorIdFakeMesh
1083+
) = actor_mesh_ref
9401084
self._mailbox: Mailbox = mailbox
9411085
for attr_name in dir(self._class):
9421086
attr_value = getattr(self._class, attr_name, None)
@@ -1003,20 +1147,22 @@ class ActorMeshHandle(_ActorMeshTrait, Generic[T]):
10031147
def __init__(
10041148
self,
10051149
Class: Type[T],
1006-
actor_mesh: _ActorMeshRefImpl,
1150+
actor_mesh: PythonActorMesh,
10071151
mailbox: Mailbox,
1152+
proc_mesh: "ProcMesh",
10081153
) -> None:
1009-
super().__init__(Class, actor_mesh, mailbox)
1154+
wrapper = EitherPyActorMesh(actor_mesh, mailbox, proc_mesh)
1155+
super().__init__(Class, wrapper, mailbox)
10101156

1011-
def _inner(self) -> "_ActorMeshRefImpl":
1157+
def _inner(self) -> "EitherPyActorMesh":
10121158
mesh = self._actor_mesh_ref
10131159
assert isinstance(
1014-
mesh, _ActorMeshRefImpl
1160+
mesh, EitherPyActorMesh
10151161
), f"mesh type is {mesh.__class__.__name__}"
10161162
return mesh
10171163

10181164
def bind(self) -> "ActorMeshRef[T]":
1019-
return ActorMeshRef(self._class, self._inner(), self._mailbox)
1165+
return ActorMeshRef(self._class, self._inner().bind(), self._mailbox)
10201166

10211167
def _create(
10221168
self,
@@ -1056,22 +1202,22 @@ def proc_mesh(self) -> "Optional[ProcMesh]":
10561202

10571203
@property
10581204
def name_pid(self) -> Tuple[str, int]:
1059-
return self._inner()._name_pid
1205+
return self._inner().name_pid
10601206

10611207

10621208
class ActorMeshRef(_ActorMeshTrait, Generic[T]):
10631209
def __init__(
10641210
self,
10651211
Class: Type[T],
1066-
actor_mesh: _ActorMeshRefImpl,
1212+
actor_mesh: EitherPyActorMeshRef,
10671213
mailbox: Mailbox,
10681214
) -> None:
10691215
super().__init__(Class, actor_mesh, mailbox)
10701216

1071-
def _inner(self) -> "_ActorMeshRefImpl":
1217+
def _inner(self) -> "EitherPyActorMeshRef":
10721218
mesh = self._actor_mesh_ref
10731219
assert isinstance(
1074-
mesh, _ActorMeshRefImpl
1220+
mesh, EitherPyActorMeshRef
10751221
), f"mesh type is {mesh.__class__.__name__}"
10761222
return mesh
10771223

python/monarch/_src/actor/proc_mesh.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,7 @@
3636
ProcMeshMonitor,
3737
)
3838
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
39-
from monarch._src.actor.actor_mesh import (
40-
_Actor,
41-
_ActorMeshRefImpl,
42-
Actor,
43-
ActorMeshHandle,
44-
)
39+
from monarch._src.actor.actor_mesh import _Actor, Actor, ActorMeshHandle
4540
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator
4641
from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation
4742
from monarch._src.actor.code_sync.auto_reload import AutoReloadActor
@@ -180,8 +175,9 @@ def _spawn_blocking(
180175
actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor)
181176
service = ActorMeshHandle(
182177
Class,
183-
_ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self),
178+
actor_mesh,
184179
self._mailbox,
180+
self._proc_mesh,
185181
)
186182
# useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
187183
# doing `ActorMeshRef(Class, actor_handle)` but not calling _create.
@@ -205,8 +201,9 @@ async def _spawn_nonblocking(
205201
actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor)
206202
service = ActorMeshHandle(
207203
Class,
208-
_ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self),
204+
actor_mesh,
209205
self._mailbox,
206+
self._proc_mesh,
210207
)
211208
# useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
212209
# doing `ActorMeshRef(Class, actor_handle)` but not calling _create.

0 commit comments

Comments
 (0)