diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index bcf0d7f13..2c9ad5317 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -6,21 +6,16 @@ # pyre-unsafe -import asyncio import collections import contextvars import functools -import importlib import inspect import itertools import logging import random import traceback -from abc import ABC, abstractmethod - from dataclasses import dataclass -from operator import mul from traceback import extract_tb, StackSummary from typing import ( Any, @@ -34,13 +29,9 @@ Iterable, Iterator, List, - Literal, NamedTuple, Optional, - overload, ParamSpec, - Protocol, - Sequence, Tuple, Type, TYPE_CHECKING, @@ -68,9 +59,15 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError - from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator +from monarch._src.actor.endpoint import ( + Endpoint, + EndpointProperty, + Extent, + Propagator, + Selection, +) from monarch._src.actor.future import Future from monarch._src.actor.pdb_wrapper import PdbWrapper @@ -79,6 +76,8 @@ from monarch._src.actor.shape import MeshTrait, NDSlice from monarch._src.actor.sync_state import fake_sync_state +from monarch._src.actor.tensor_engine_shim import actor_send + if TYPE_CHECKING: from monarch._src.actor.proc_mesh import ProcMesh @@ -144,9 +143,6 @@ def set(debug_context: "DebugContext") -> None: _load_balancing_seed = random.Random(4) -Selection = Literal["all", "choose"] | int # TODO: replace with real selection objects - - # standin class for whatever is the serializable python object we use # to name an actor mesh. Hacked up today because ActorMesh # isn't plumbed to non-clients @@ -281,114 +277,6 @@ async def stop(self): await self._actor_mesh.stop() -class Extent(NamedTuple): - labels: Sequence[str] - sizes: Sequence[int] - - @property - def nelements(self) -> int: - return functools.reduce(mul, self.sizes, 1) - - def __str__(self) -> str: - return str(dict(zip(self.labels, self.sizes))) - - -class Endpoint(ABC, Generic[P, R]): - @abstractmethod - def _send( - self, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - port: "Optional[Port]" = None, - selection: Selection = "all", - ) -> Extent: - """ - Implements sending a message to the endpoint. The return value of the endpoint will - be sent to port if provided. If port is not provided, the return will be dropped, - and any exception will cause the actor to fail. - - The return value is the (multi-dimension) size of the actors that were sent a message. - For ActorEndpoints this will be the actor_meshes size. For free-function endpoints, - this will be the size of the currently active proc_mesh. - """ - pass - - @abstractmethod - def _port(self, once: bool = False) -> "PortTuple[R]": - pass - - def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any: - return r - - # the following are all 'adverbs' or different ways to handle the - # return values of this endpoint. Adverbs should only ever take *args, **kwargs - # of the original call. 
If we want to add syntax sugar for something that needs additional - # arguments, it should be implemented as function indepdendent of endpoint like `send` - # and `Accumulator` - def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - """ - Load balanced sends a message to one chosen actor and awaits a result. - - Load balanced RPC-style entrypoint for request/response messaging. - """ - p, r = port(self, once=True) - # pyre-ignore - self._send(args, kwargs, port=p, selection="choose") - return r.recv() - - def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - p, r = port(self, once=True) - # pyre-ignore - extent = self._send(args, kwargs, port=p, selection="choose") - if extent.nelements != 1: - raise ValueError( - f"Can only use 'call_one' on a single Actor but this actor has shape {extent}" - ) - return r.recv() - - def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": - p, r = ranked_port(self) - # pyre-ignore - extent = self._send(args, kwargs, port=p) - - async def process() -> ValueMesh[R]: - results: List[R] = [None] * extent.nelements # pyre-fixme[9] - for _ in range(extent.nelements): - rank, value = await r.recv() - results[rank] = value - call_shape = Shape( - extent.labels, - NDSlice.new_row_major(extent.sizes), - ) - return ValueMesh(call_shape, results) - - return Future(impl=process, requires_loop=False) - - async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]: - """ - Broadcasts to all actors and yields their responses as a stream / generator. - - This enables processing results from multiple actors incrementally as - they become available. Returns an async generator of response values. - """ - p, r = port(self) - # pyre-ignore - extent = self._send(args, kwargs, port=p) - for _ in range(extent.nelements): - yield await r.recv() - - def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: - """ - Fire-and-forget broadcast to all actors without waiting for actors to - acknowledge receipt. - - In other words, the return of this method does not guarrantee the - delivery of the message. - """ - # pyre-ignore - send(self, args, kwargs) - - class ActorEndpoint(Endpoint[P, R]): def __init__( self, @@ -396,7 +284,9 @@ def __init__( name: str, impl: Callable[Concatenate[Any, P], Awaitable[R]], mailbox: Mailbox, + propagator: Propagator = None, ) -> None: + super().__init__(propagator) self._actor_mesh = actor_mesh_ref self._name = name self._signature: inspect.Signature = inspect.signature(impl) @@ -406,6 +296,9 @@ def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any: mesh = self._actor_mesh._actor_mesh return r if mesh is None else mesh.supervise(r) + def _call_name(self) -> Any: + return self._name + def _send( self, args: Tuple[Any, ...], @@ -430,9 +323,7 @@ def _send( ) self._actor_mesh.cast(message, selection) else: - importlib.import_module("monarch." + "mesh_controller").actor_send( - self, bytes, refs, port, selection - ) + actor_send(self, bytes, refs, port, selection) shape = self._actor_mesh._shape return Extent(shape.labels, shape.ndslice.sizes) @@ -521,39 +412,6 @@ def send( endpoint._send(args, kwargs, port, selection) -class EndpointProperty(Generic[P, R]): - @overload - def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: ... - - @overload - def __init__(self, method: Callable[Concatenate[Any, P], R]) -> None: ... 
- - def __init__(self, method: Any) -> None: - self._method = method - - def __get__(self, instance, owner) -> Endpoint[P, R]: - # this is a total lie, but we have to actually - # recognize this was defined as an endpoint, - # and also lookup the method - return cast(Endpoint[P, R], self) - - -@overload -def endpoint( - method: Callable[Concatenate[Any, P], Awaitable[R]], -) -> EndpointProperty[P, R]: ... - - -@overload -def endpoint( - method: Callable[Concatenate[Any, P], R], -) -> EndpointProperty[P, R]: ... - - -def endpoint(method): - return EndpointProperty(method) - - class Port(Generic[R]): def __init__( self, @@ -919,6 +777,7 @@ def __getattr__(self, name: str) -> Any: name, attr._method, self._mailbox, + propagator=attr._propagator, ) # Cache it for future use setattr(self, name, endpoint) diff --git a/python/monarch/_src/actor/code_sync/auto_reload.py b/python/monarch/_src/actor/code_sync/auto_reload.py index 7fb8e3017..d2e7dcc33 100644 --- a/python/monarch/_src/actor/code_sync/auto_reload.py +++ b/python/monarch/_src/actor/code_sync/auto_reload.py @@ -17,7 +17,8 @@ from types import ModuleType from typing import Dict, List, Optional, Tuple -from monarch._src.actor.actor_mesh import Actor, endpoint +from monarch._src.actor.actor_mesh import Actor +from monarch._src.actor.endpoint import endpoint class SysAuditHookGuard(contextlib.AbstractContextManager): diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index cb5386b62..c7b07739b 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -20,9 +20,9 @@ Actor, ActorMeshRef, DebugContext, - endpoint, MonarchContext, ) +from monarch._src.actor.endpoint import endpoint from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper from monarch._src.actor.sync_state import fake_sync_state diff --git a/python/monarch/_src/actor/endpoint.py b/python/monarch/_src/actor/endpoint.py new file mode 100644 index 000000000..2028f9d0c --- /dev/null +++ b/python/monarch/_src/actor/endpoint.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-unsafe
+
+import functools
+from abc import ABC, abstractmethod
+from operator import mul
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    cast,
+    Concatenate,
+    Dict,
+    Generic,
+    List,
+    Literal,
+    Optional,
+    overload,
+    ParamSpec,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    TypeVar,
+)
+
+from monarch._src.actor.future import Future
+from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
+
+if TYPE_CHECKING:
+    from monarch._src.actor.actor_mesh import (
+        HyPortReceiver,
+        OncePortReceiver,
+        Port,
+        PortTuple,
+        ValueMesh,
+    )
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+Selection = Literal["all", "choose"] | int
+
+
+class Extent:
+    def __init__(self, labels: Sequence[str], sizes: Sequence[int]) -> None:
+        self.labels = labels
+        self.sizes = sizes
+
+    @property
+    def nelements(self) -> int:
+        return functools.reduce(mul, self.sizes, 1)
+
+    def __str__(self) -> str:
+        return str(dict(zip(self.labels, self.sizes)))
+
+
+Propagator = Any
+
+
+class Endpoint(ABC, Generic[P, R]):
+    def __init__(self, propagator: Propagator) -> None:
+        self._propagator_arg = propagator
+        self._cache: Optional[dict] = None
+
+    @abstractmethod
+    def _send(
+        self,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        port: "Optional[Port]" = None,
+        selection: Selection = "all",
+    ) -> Extent:
+        """
+        Implements sending a message to the endpoint. The return value of the endpoint will
+        be sent to port if provided. If port is not provided, the return will be dropped,
+        and any exception will cause the actor to fail.
+
+        The return value is the (multi-dimensional) size of the actors that were sent a message.
+        For ActorEndpoints this will be the actor mesh's size. For free-function endpoints,
+        this will be the size of the currently active proc_mesh.
+        """
+        pass
+
+    @abstractmethod
+    def _port(self, once: bool = False) -> "PortTuple[R]":
+        pass
+
+    @abstractmethod
+    def _call_name(self) -> Any:
+        """
+        Something to use in InputChecker to represent calling this endpoint.
+        """
+        pass
+
+    def _supervise(self, r: "HyPortReceiver | OncePortReceiver") -> Any:
+        return r
+
+    # The following are all 'adverbs', i.e. different ways to handle the
+    # return values of this endpoint. Adverbs should only ever take *args, **kwargs
+    # of the original call. If we want to add syntax sugar for something that needs additional
+    # arguments, it should be implemented as a function independent of endpoint, like `send`
+    # and `Accumulator`.
+    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
+        """
+        Load-balanced send of a message to one chosen actor, awaiting the result.
+
+        Load-balanced RPC-style entrypoint for request/response messaging.
+        """
+        from monarch._src.actor.actor_mesh import port
+
+        p, r = port(self, once=True)
+        # pyre-ignore
+        self._send(args, kwargs, port=p, selection="choose")
+        return r.recv()
+
+    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
+        from monarch._src.actor.actor_mesh import port
+
+        p, r = port(self, once=True)
+        # pyre-ignore
+        extent = self._send(args, kwargs, port=p, selection="choose")
+        if extent.nelements != 1:
+            raise ValueError(
+                f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
+            )
+        return r.recv()
+
+    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
+        from monarch._src.actor.actor_mesh import ranked_port, ValueMesh
+
+        p, r = ranked_port(self)
+        # pyre-ignore
+        extent = self._send(args, kwargs, port=p)
+
+        async def process() -> "ValueMesh[R]":
+            from monarch._rust_bindings.monarch_hyperactor.shape import Shape
+            from monarch._src.actor.shape import NDSlice
+
+            results: List[R] = [None] * extent.nelements  # pyre-fixme[9]
+            for _ in range(extent.nelements):
+                rank, value = await r.recv()
+                results[rank] = value
+            call_shape = Shape(
+                extent.labels,
+                NDSlice.new_row_major(extent.sizes),
+            )
+            return ValueMesh(call_shape, results)
+
+        return Future(impl=process, requires_loop=False)
+
+    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
+        """
+        Broadcasts to all actors and yields their responses as a stream / generator.
+
+        This enables processing results from multiple actors incrementally as
+        they become available. Returns an async generator of response values.
+        """
+        from monarch._src.actor.actor_mesh import port
+
+        p, r = port(self)
+        # pyre-ignore
+        extent = self._send(args, kwargs, port=p)
+        for _ in range(extent.nelements):
+            yield await r.recv()
+
+    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
+        """
+        Fire-and-forget broadcast to all actors without waiting for actors to
+        acknowledge receipt.
+
+        In other words, the return of this method does not guarantee the
+        delivery of the message.
+        """
+        from monarch._src.actor.actor_mesh import send
+
+        # pyre-ignore
+        send(self, args, kwargs)
+
+    def _propagate(self, args, kwargs, fake_args, fake_kwargs):
+        if self._propagator_arg is None or self._propagator_arg == "cached":
+            if self._cache is None:
+                self._cache = {}
+            return _cached_propagation(self._cache, self._resolvable, args, kwargs)
+        elif self._propagator_arg == "inspect":
+            return None
+        elif self._propagator_arg == "mocked":
+            raise NotImplementedError("mocked propagation")
+        else:
+            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
+
+    def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
+        if self._propagator_arg is None:
+            return  # no propagator provided, so we just assume no mutations
+        return self._propagate(args, kwargs, fake_args, fake_kwargs)
+
+    def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
+        if not callable(self._propagator_arg):
+            raise ValueError("Must specify explicit callable for pipe")
+        return self._propagate(args, kwargs, fake_args, fake_kwargs)
+
+
+class EndpointProperty(Generic[P, R]):
+    @overload
+    def __init__(
+        self,
+        method: Callable[Concatenate[Any, P], Awaitable[R]],
+        propagator: Propagator,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self, method: Callable[Concatenate[Any, P], R], propagator: Propagator
+    ) -> None: ...
+
+    def __init__(self, method: Any, propagator: Propagator) -> None:
+        self._method = method
+        self._propagator = propagator
+
+    def __get__(self, instance, owner) -> Endpoint[P, R]:
+        # this is a total lie, but we have to actually
+        # recognize this was defined as an endpoint,
+        # and also lookup the method
+        return cast(Endpoint[P, R], self)
+
+
+# This can't just be Callable because otherwise we are not
+# allowed to use type arguments in the return value.
class EndpointIfy:
+    @overload
+    def __call__(self, function: Callable[P, Awaitable[R]]) -> Endpoint[P, R]: ...
+    @overload
+    def __call__(self, function: Callable[P, R]) -> Endpoint[P, R]: ...
+
+    def __call__(self, function: Any):
+        pass
+
+
+@overload
+def endpoint(
+    method: Callable[Concatenate[Any, P], Awaitable[R]],
+    *,
+    propagate: Propagator = None,
+) -> EndpointProperty[P, R]: ...
+
+
+@overload
+def endpoint(
+    method: Callable[Concatenate[Any, P], R],
+    *,
+    propagate: Propagator = None,
+) -> EndpointProperty[P, R]: ...
+
+
+@overload
+def endpoint(
+    *,
+    propagate: Propagator = None,
+) -> EndpointIfy: ...
+
+
+def endpoint(method=None, *, propagate=None):
+    if method is None:
+        return functools.partial(endpoint, propagate=propagate)
+    return EndpointProperty(method, propagator=propagate)
diff --git a/python/monarch/_src/actor/tensor_engine_shim.py b/python/monarch/_src/actor/tensor_engine_shim.py
new file mode 100644
index 000000000..85c63ca07
--- /dev/null
+++ b/python/monarch/_src/actor/tensor_engine_shim.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+from functools import partial
+from typing import Any, Optional, Sequence, TYPE_CHECKING
+
+"""
+This file provides a type-annotated shim for using tensor engine functions
+from within the actor module, which only optionally includes the tensor engine.
+
+Each function that is needed should have a @shim entry below which gives the name,
+module, and type of the function. Each function is resolved dynamically the first
+time it is used.
+"""
+
+if TYPE_CHECKING:
+    from monarch._src.actor.actor_mesh import ActorEndpoint, Port, Selection
+    from monarch._src.actor.endpoint import Endpoint
+
+
+def shim(fn=None, *, module=None):
+    if fn is None:
+        return partial(shim, module=module)
+
+    impl = None
+    name = fn.__name__
+
+    def wrap(*args, **kwargs):
+        nonlocal impl
+        if impl is None:
+            impl = getattr(importlib.import_module(module), name)
+        return impl(*args, **kwargs)
+
+    return wrap
+
+
+@shim(module="monarch.mesh_controller")
+def actor_send(
+    endpoint: "ActorEndpoint",
+    args_kwargs_tuple: bytes,
+    refs: "Sequence[Any]",
+    port: "Optional[Port[Any]]",
+    selection: "Selection",
+) -> None: ...
+
+
+@shim(module="monarch.common.remote")
+def _cached_propagation(_cache, rfunction: "Endpoint", args, kwargs) -> Any: ...
+
+
+@shim(module="monarch.common.fake")
+def fake_call(fn, *args, **kwargs): ...
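
The @shim indirection above resolves each declared function lazily: the first call imports the named module and caches the attribute matching the decorated function's name, so the actor module never imports the tensor engine at module load time. A minimal sketch of that behavior, using the standard-library json module as a stand-in target (the real entries point at monarch modules that may be unimportable when the tensor engine is absent):

    from monarch._src.actor.tensor_engine_shim import shim

    @shim(module="json")
    def dumps(obj) -> str: ...

    print(dumps({"a": 1}))  # first call imports json and binds json.dumps (looked up by the decorated name)
    print(dumps({"b": 2}))  # later calls reuse the cached implementation; no further import machinery
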
diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index 9f0ee2b62..9c85fc9dd 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -15,13 +15,13 @@ current_actor_name, current_rank, current_size, - endpoint, MonarchContext, Point, port, send, ValueMesh, ) +from monarch._src.actor.endpoint import endpoint from monarch._src.actor.future import Future from monarch._src.actor.proc_mesh import ( debug_client, diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py index 1c19a360b..befb644c5 100644 --- a/python/monarch/common/remote.py +++ b/python/monarch/common/remote.py @@ -30,7 +30,8 @@ import torch from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.shape import Shape -from monarch._src.actor.actor_mesh import Extent, Port, PortTuple, Selection +from monarch._src.actor.actor_mesh import Port, PortTuple +from monarch._src.actor.endpoint import Extent, Selection from monarch.common import _coalescing, device_mesh, stream from monarch.common.future import Future as OldFuture @@ -38,7 +39,7 @@ if TYPE_CHECKING: from monarch.common.client import Client -from monarch._src.actor.actor_mesh import Endpoint +from monarch._src.actor.endpoint import Endpoint from monarch.common.device_mesh import RemoteProcessGroup from monarch.common.fake import fake_call @@ -70,9 +71,11 @@ class Remote(Generic[P, R], Endpoint[P, R]): def __init__(self, impl: Any, propagator_arg: Propagator): + super().__init__(propagator_arg) self._remote_impl = impl - self._propagator_arg = propagator_arg - self._cache: Optional[dict] = None + + def _call_name(self) -> Any: + return self._remote_impl def _send( self, @@ -154,28 +157,6 @@ def _resolvable(self): def _maybe_resolvable(self): return None if self._remote_impl is None else self._resolvable - def _propagate(self, args, kwargs, fake_args, fake_kwargs): - if self._propagator_arg is None or self._propagator_arg == "cached": - if self._cache is None: - self._cache = {} - return _cached_propagation(self._cache, self._resolvable, args, kwargs) - elif self._propagator_arg == "inspect": - return None - elif self._propagator_arg == "mocked": - raise NotImplementedError("mocked propagation") - else: - return fake_call(self._propagator_arg, *fake_args, **fake_kwargs) - - def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs): - if self._propagator_arg is None: - return # no propgator provided, so we just assume no mutations - return self._propagate(args, kwargs, fake_args, fake_kwargs) - - def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs): - if not callable(self._propagator_arg): - raise ValueError("Must specify explicit callable for pipe") - return self._propagate(args, kwargs, fake_args, fake_kwargs) - def rref(self, *args: P.args, **kwargs: P.kwargs) -> R: return dtensor_dispatch( self._resolvable, @@ -230,7 +211,7 @@ def remote(function: Any = None, *, propagate: Propagator = None) -> Any: def call_on_shard_and_fetch( - remote: Remote[P, R], *args, shard: Dict[str, int] | None = None, **kwargs + remote: Endpoint[P, R], *args, shard: Dict[str, int] | None = None, **kwargs ) -> OldFuture[R]: # We have to flatten the tensors twice: first to discover # which mesh we are working on to shard it, and then again when doing the @@ -238,13 +219,13 @@ def call_on_shard_and_fetch( # implicit inference of the mesh from the tensors. 
dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor)) with InputChecker.from_flat_args( - remote._remote_impl, dtensors, unflatten + remote._call_name(), dtensors, unflatten ) as checker: checker.check_mesh_stream_local(device_mesh._active, stream._active) if not hasattr(checker.mesh.client, "_mesh_controller"): return _old_call_on_shard_and_fetch( - remote, + cast("Remote[P, R]", remote), *args, shard=shard, **kwargs, @@ -371,8 +352,9 @@ def _mock_pgs(x): _hit = 0 -def _cached_propagation(_cache, rfunction, args, kwargs): +def _cached_propagation(_cache, rfunction: Endpoint, args, kwargs): tensors, shape_key = hashable_tensor_flatten(args, kwargs) + # pyre-ignore inputs_group = TensorGroup([t._fake for t in tensors]) requires_grads = tuple(t.requires_grad for t in tensors) key = (shape_key, inputs_group.pattern, requires_grads) diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 61f3101db..7334ded62 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -39,7 +39,8 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) -from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple, Selection +from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple +from monarch._src.actor.endpoint import Selection from monarch._src.actor.shape import NDSlice from monarch.common import device_mesh, messages, stream from monarch.common.controller_api import TController diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py index a4750d70e..b1071ad36 100644 --- a/python/tests/test_debugger.py +++ b/python/tests/test_debugger.py @@ -17,7 +17,7 @@ import torch -from monarch._src.actor.actor_mesh import Actor, endpoint, MonarchContext +from monarch._src.actor.actor_mesh import Actor, MonarchContext from monarch._src.actor.debugger import ( Attach, Cast, @@ -29,6 +29,7 @@ ListCommand, Quit, ) +from monarch._src.actor.endpoint import endpoint from monarch._src.actor.proc_mesh import proc_mesh
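
Taken together, the endpoint() overloads in endpoint.py admit both decorator forms. A hypothetical usage sketch (Counter and its methods are invented for illustration and are not part of this patch):

    from monarch._src.actor.actor_mesh import Actor
    from monarch._src.actor.endpoint import endpoint

    class Counter(Actor):
        def __init__(self) -> None:
            self.value = 0

        @endpoint  # bare form: the method is passed directly; propagate defaults to None
        def incr(self) -> None:
            self.value += 1

        @endpoint(propagate="inspect")  # parenthesized form: re-enters endpoint() via functools.partial
        def read(self) -> int:
            return self.value

ActorMeshRef.__getattr__ then forwards attr._propagator into each ActorEndpoint, which is how a propagate argument given at declaration time reaches Endpoint._propagate at call time.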