Skip to content

Lift the send-to-actor logic from _ActorMeshRefImpl to ActorIdRef #555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 115 additions & 33 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def set(debug_context: "DebugContext") -> None:
# standin class for whatever is the serializable python object we use
# to name an actor mesh. Hacked up today because ActorMesh
# isn't plumbed to non-clients
class _ActorMeshRefImpl:
class _ActorMeshRefImpl(MeshTrait):
def __init__(
self,
mailbox: Mailbox,
Expand Down Expand Up @@ -177,16 +177,24 @@ def from_hyperactor_mesh(
[cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))],
)

@staticmethod
def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id])
def monitor(self) -> Optional[ActorMeshMonitor]:
return self._actor_mesh.monitor() if self._actor_mesh is not None else None

@staticmethod
def from_actor_ref_with_shape(
ref: "_ActorMeshRefImpl", shape: Shape
) -> "_ActorMeshRefImpl":
@property
def shape(self) -> Shape:
return self._shape

@property
def _ndslice(self) -> NDSlice:
return self._shape.ndslice

@property
def _labels(self) -> Iterable[str]:
return self._shape.labels

def _new_with_shape(self, shape: Shape) -> "_ActorMeshRefImpl":
return _ActorMeshRefImpl(
ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids
self._mailbox, None, None, shape, self._please_replace_me_actor_ids
)

def __getstate__(
Expand Down Expand Up @@ -270,6 +278,35 @@ def _name_pid(self):
return actor_id0.actor_name, actor_id0.pid


class ActorIdFakeMesh:
"""
Fake mesh that represents a single actor. This is used to allow sending a
message to a single actor through the mesh API.
"""

def __init__(
self,
actor_id: ActorId,
mailbox: Mailbox,
) -> None:
self._actor_id = actor_id
self._mailbox = mailbox

def cast(
self,
message: PythonMessage,
selection: Selection,
) -> None:
self._mailbox.post(self._actor_id, message)

@property
def shape(self) -> Shape:
return singleton_shape

def monitor(self) -> Optional[ActorMeshMonitor]:
return None


class Extent(NamedTuple):
labels: Sequence[str]
sizes: Sequence[int]
Expand Down Expand Up @@ -389,7 +426,7 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
class ActorEndpoint(Endpoint[P, R]):
def __init__(
self,
actor_mesh_ref: _ActorMeshRefImpl,
actor_mesh_ref: _ActorMeshRefImpl | ActorIdFakeMesh,
name: str,
impl: Callable[Concatenate[Any, P], Awaitable[R]],
mailbox: Mailbox,
Expand Down Expand Up @@ -426,15 +463,11 @@ def _send(
importlib.import_module("monarch." + "mesh_controller").actor_send(
self, bytes, refs, port, selection
)
shape = self._actor_mesh._shape
shape = self._actor_mesh.shape
return Extent(shape.labels, shape.ndslice.sizes)

def _port(self, once: bool = False) -> "PortTuple[R]":
monitor = (
None
if self._actor_mesh._actor_mesh is None
else self._actor_mesh._actor_mesh.monitor()
)
monitor = self._actor_mesh.monitor()
return PortTuple.create(self._mailbox, monitor, once)


Expand Down Expand Up @@ -874,19 +907,22 @@ def _labels(self) -> Tuple[str, ...]:
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)

def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
def _new_with_shape(self, shape: Shape) -> "Actor":
raise NotImplementedError(
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)


class ActorMeshRef(MeshTrait):
class _ActorMeshTrait(MeshTrait):
def __init__(
self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
self,
Class: Type[T],
actor_mesh_ref: _ActorMeshRefImpl | ActorIdFakeMesh,
mailbox: Mailbox,
) -> None:
self.__name__: str = Class.__name__
self._class: Type[T] = Class
self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
self._actor_mesh_ref: _ActorMeshRefImpl | ActorIdFakeMesh = actor_mesh_ref
self._mailbox: Mailbox = mailbox
for attr_name in dir(self._class):
attr_value = getattr(self._class, attr_name, None)
Expand Down Expand Up @@ -928,6 +964,40 @@ def __getattr__(self, name: str) -> Any:
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)

@property
def _ndslice(self) -> NDSlice:
raise NotImplementedError(
"should not be called because def slice is overridden"
)

@property
def _labels(self) -> Iterable[str]:
raise NotImplementedError(
"should not be called because def slice is overridden"
)

def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
raise NotImplementedError(
"should not be called because def slice is overridden"
)


class ActorMeshRef(_ActorMeshTrait, Generic[T]):
def __init__(
self,
Class: Type[T],
actor_mesh: _ActorMeshRefImpl,
mailbox: Mailbox,
) -> None:
super().__init__(Class, actor_mesh, mailbox)

def _inner(self) -> "_ActorMeshRefImpl":
mesh = self._actor_mesh_ref
assert isinstance(
mesh, _ActorMeshRefImpl
), f"mesh type is {mesh.__class__.__name__}"
return mesh

def _create(
self,
args: Iterable[Any],
Expand All @@ -953,23 +1023,35 @@ def __reduce_ex__(
self._mailbox,
)

@property
def _ndslice(self) -> NDSlice:
return self._actor_mesh_ref._shape.ndslice
def slice(self, **kwargs) -> "ActorMeshRef":
sliced = self._inner().slice(**kwargs)
return ActorMeshRef(self._class, sliced, self._mailbox)

@property
def _labels(self) -> Iterable[str]:
return self._actor_mesh_ref._shape.labels
def __repr__(self) -> str:
return f"ActorMeshRef(class={self._class}, shape={self._inner().shape})"

def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
return ActorMeshRef(
self._class,
_ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape),
self._mailbox,
)

class ActorIdRef(_ActorMeshTrait, Generic[T]):
def __init__(
self,
Class: Type[T],
actor_id: ActorId,
mailbox: Mailbox,
) -> None:
super().__init__(Class, ActorIdFakeMesh(actor_id, mailbox), mailbox)

def _inner(self) -> "ActorId":
mesh = self._actor_mesh_ref
assert isinstance(
mesh, ActorIdFakeMesh
), f"mesh type is {mesh.__class__.__name__}"
return mesh._actor_id

def slice(self, **kwargs) -> "ActorIdRef":
raise NotImplementedError("ActorIdRef does not support slicing")

def __repr__(self) -> str:
return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})"
return f"ActorIdRef(class={self._class}, actor_id={self._inner()})"


class ActorError(Exception):
Expand Down
12 changes: 3 additions & 9 deletions python/monarch/_src/actor/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._src.actor.actor_mesh import (
_ActorMeshRefImpl,
Actor,
ActorMeshRef,
ActorIdRef,
DebugContext,
endpoint,
MonarchContext,
Expand Down Expand Up @@ -511,14 +510,9 @@ def ref() -> "DebugManager":
ctx = MonarchContext.get()
return cast(
DebugManager,
ActorMeshRef(
ActorIdRef(
DebugManager,
_ActorMeshRefImpl.from_actor_id(
ctx.mailbox,
ActorId.from_string(
f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]"
),
),
ActorId.from_string(f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]"),
ctx.mailbox,
),
)
Expand Down
17 changes: 5 additions & 12 deletions python/monarch/rdma.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import ctypes

from dataclasses import dataclass
Expand All @@ -12,13 +14,7 @@
import torch

from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._src.actor.actor_mesh import (
_ActorMeshRefImpl,
Actor,
ActorMeshRef,
endpoint,
MonarchContext,
)
from monarch._src.actor.actor_mesh import Actor, ActorIdRef, endpoint, MonarchContext


@dataclass
Expand Down Expand Up @@ -54,12 +50,9 @@ def on_proc(proc_id: str) -> "RDMAManager":
ctx = MonarchContext.get()
return cast(
RDMAManager,
ActorMeshRef(
ActorIdRef(
RDMAManager,
_ActorMeshRefImpl.from_actor_id(
ctx.mailbox,
ActorId.from_string(f"{proc_id}.rdma_manager[0]"),
),
ActorId.from_string(f"{proc_id}.rdma_manager[0]"),
ctx.mailbox,
),
)
Expand Down
Loading