Skip to content

[22/n] tensor engine, Move Endpoint to its own file, move propagator to endpoint #624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 16 additions & 157 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,16 @@

# pyre-unsafe

import asyncio
import collections
import contextvars
import functools
import importlib
import inspect
import itertools
import logging
import random
import traceback

from abc import ABC, abstractmethod

from dataclasses import dataclass
from operator import mul
from traceback import extract_tb, StackSummary
from typing import (
Any,
Expand All @@ -34,13 +29,9 @@
Iterable,
Iterator,
List,
Literal,
NamedTuple,
Optional,
overload,
ParamSpec,
Protocol,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Expand Down Expand Up @@ -68,9 +59,15 @@
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError

from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
from monarch._src.actor.endpoint import (
Endpoint,
EndpointProperty,
Extent,
Propagator,
Selection,
)
from monarch._src.actor.future import Future
from monarch._src.actor.pdb_wrapper import PdbWrapper

Expand All @@ -79,6 +76,8 @@
from monarch._src.actor.shape import MeshTrait, NDSlice
from monarch._src.actor.sync_state import fake_sync_state

from monarch._src.actor.tensor_engine_shim import actor_send

if TYPE_CHECKING:
from monarch._src.actor.proc_mesh import ProcMesh

Expand Down Expand Up @@ -144,9 +143,6 @@ def set(debug_context: "DebugContext") -> None:
_load_balancing_seed = random.Random(4)


Selection = Literal["all", "choose"] | int # TODO: replace with real selection objects


# standin class for whatever is the serializable python object we use
# to name an actor mesh. Hacked up today because ActorMesh
# isn't plumbed to non-clients
Expand Down Expand Up @@ -281,122 +277,16 @@ async def stop(self):
await self._actor_mesh.stop()


class Extent(NamedTuple):
labels: Sequence[str]
sizes: Sequence[int]

@property
def nelements(self) -> int:
return functools.reduce(mul, self.sizes, 1)

def __str__(self) -> str:
return str(dict(zip(self.labels, self.sizes)))


class Endpoint(ABC, Generic[P, R]):
@abstractmethod
def _send(
self,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
port: "Optional[Port]" = None,
selection: Selection = "all",
) -> Extent:
"""
Implements sending a message to the endpoint. The return value of the endpoint will
be sent to port if provided. If port is not provided, the return will be dropped,
and any exception will cause the actor to fail.

The return value is the (multi-dimension) size of the actors that were sent a message.
For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
this will be the size of the currently active proc_mesh.
"""
pass

@abstractmethod
def _port(self, once: bool = False) -> "PortTuple[R]":
pass

def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
return r

# the following are all 'adverbs' or different ways to handle the
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
# of the original call. If we want to add syntax sugar for something that needs additional
# arguments, it should be implemented as function indepdendent of endpoint like `send`
# and `Accumulator`
def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
"""
Load balanced sends a message to one chosen actor and awaits a result.

Load balanced RPC-style entrypoint for request/response messaging.
"""
p, r = port(self, once=True)
# pyre-ignore
self._send(args, kwargs, port=p, selection="choose")
return r.recv()

def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
p, r = port(self, once=True)
# pyre-ignore
extent = self._send(args, kwargs, port=p, selection="choose")
if extent.nelements != 1:
raise ValueError(
f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
)
return r.recv()

def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
p, r = ranked_port(self)
# pyre-ignore
extent = self._send(args, kwargs, port=p)

async def process() -> ValueMesh[R]:
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
for _ in range(extent.nelements):
rank, value = await r.recv()
results[rank] = value
call_shape = Shape(
extent.labels,
NDSlice.new_row_major(extent.sizes),
)
return ValueMesh(call_shape, results)

return Future(impl=process, requires_loop=False)

async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
"""
Broadcasts to all actors and yields their responses as a stream / generator.

This enables processing results from multiple actors incrementally as
they become available. Returns an async generator of response values.
"""
p, r = port(self)
# pyre-ignore
extent = self._send(args, kwargs, port=p)
for _ in range(extent.nelements):
yield await r.recv()

def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
"""
Fire-and-forget broadcast to all actors without waiting for actors to
acknowledge receipt.

In other words, the return of this method does not guarrantee the
delivery of the message.
"""
# pyre-ignore
send(self, args, kwargs)


class ActorEndpoint(Endpoint[P, R]):
def __init__(
self,
actor_mesh_ref: _ActorMeshRefImpl,
name: str,
impl: Callable[Concatenate[Any, P], Awaitable[R]],
mailbox: Mailbox,
propagator: Propagator = None,
) -> None:
super().__init__(propagator)
self._actor_mesh = actor_mesh_ref
self._name = name
self._signature: inspect.Signature = inspect.signature(impl)
Expand All @@ -406,6 +296,9 @@ def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
mesh = self._actor_mesh._actor_mesh
return r if mesh is None else mesh.supervise(r)

def _call_name(self) -> Any:
return self._name

def _send(
self,
args: Tuple[Any, ...],
Expand All @@ -430,9 +323,7 @@ def _send(
)
self._actor_mesh.cast(message, selection)
else:
importlib.import_module("monarch." + "mesh_controller").actor_send(
self, bytes, refs, port, selection
)
actor_send(self, bytes, refs, port, selection)
shape = self._actor_mesh._shape
return Extent(shape.labels, shape.ndslice.sizes)

Expand Down Expand Up @@ -521,39 +412,6 @@ def send(
endpoint._send(args, kwargs, port, selection)


class EndpointProperty(Generic[P, R]):
@overload
def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: ...

@overload
def __init__(self, method: Callable[Concatenate[Any, P], R]) -> None: ...

def __init__(self, method: Any) -> None:
self._method = method

def __get__(self, instance, owner) -> Endpoint[P, R]:
# this is a total lie, but we have to actually
# recognize this was defined as an endpoint,
# and also lookup the method
return cast(Endpoint[P, R], self)


@overload
def endpoint(
method: Callable[Concatenate[Any, P], Awaitable[R]],
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
method: Callable[Concatenate[Any, P], R],
) -> EndpointProperty[P, R]: ...


def endpoint(method):
return EndpointProperty(method)


class Port(Generic[R]):
def __init__(
self,
Expand Down Expand Up @@ -919,6 +777,7 @@ def __getattr__(self, name: str) -> Any:
name,
attr._method,
self._mailbox,
propagator=attr._propagator,
)
# Cache it for future use
setattr(self, name, endpoint)
Expand Down
3 changes: 2 additions & 1 deletion python/monarch/_src/actor/code_sync/auto_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from types import ModuleType
from typing import Dict, List, Optional, Tuple

from monarch._src.actor.actor_mesh import Actor, endpoint
from monarch._src.actor.actor_mesh import Actor
from monarch._src.actor.endpoint import endpoint


class SysAuditHookGuard(contextlib.AbstractContextManager):
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_src/actor/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
Actor,
ActorMeshRef,
DebugContext,
endpoint,
MonarchContext,
)
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper
from monarch._src.actor.sync_state import fake_sync_state

Expand Down
Loading
Loading