Skip to content

Add @endpoint(explicit_response_port=True) #634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,22 @@ pub enum UnflattenArg {
PyObject,
}

#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum MethodSpecifier {
/// Call method 'name', send its return value to the response port.
ReturnsResponse { name: String },
/// Call method 'name', send the response port as the first argument.
ExplicitPort { name: String },
/// Construct the object
Init {},
}

#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)]
pub enum PythonMessageKind {
CallMethod {
name: String,
name: MethodSpecifier,
response_port: Option<EitherPortRef>,
},
Result {
Expand All @@ -202,7 +213,7 @@ pub enum PythonMessageKind {
},
Uninit {},
CallMethodIndirect {
name: String,
name: MethodSpecifier,
local_state_broker: (String, usize),
id: usize,
// specify whether the argument to unflatten the local mailbox,
Expand Down Expand Up @@ -230,7 +241,7 @@ pub struct PythonMessage {
}

struct ResolvedCallMethod {
method: String,
method: MethodSpecifier,
bytes: Vec<u8>,
local_state: PyObject,
/// Implements PortProtocol
Expand Down Expand Up @@ -862,6 +873,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
hyperactor_mod.add_class::<PythonActorHandle>()?;
hyperactor_mod.add_class::<PythonMessage>()?;
hyperactor_mod.add_class::<PythonMessageKind>()?;
hyperactor_mod.add_class::<MethodSpecifier>()?;
hyperactor_mod.add_class::<UnflattenArg>()?;
hyperactor_mod.add_class::<PanicFlag>()?;
hyperactor_mod.add_class::<PyPythonTask>()?;
Expand Down Expand Up @@ -892,7 +904,9 @@ mod tests {
);
let message = PythonMessage {
kind: PythonMessageKind::CallMethod {
name: "test".to_string(),
name: MethodSpecifier::ReturnsResponse {
name: "test".to_string(),
},
response_port: Some(EitherPortRef::Unbounded(port_ref.clone().into())),
},
message: vec![1, 2, 3],
Expand All @@ -913,7 +927,9 @@ mod tests {

let no_port_message = PythonMessage {
kind: PythonMessageKind::CallMethod {
name: "test".to_string(),
name: MethodSpecifier::ReturnsResponse {
name: "test".to_string(),
},
response_port: None,
},
..message
Expand Down
64 changes: 52 additions & 12 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@
import abc
from enum import Enum

from typing import Any, final, Iterable, List, Optional, Protocol, Tuple, Type
from typing import (
Any,
final,
Generic,
Iterable,
List,
Optional,
Protocol,
Tuple,
Type,
TypeVar,
)

from monarch._rust_bindings.monarch_hyperactor.mailbox import (
Mailbox,
Expand Down Expand Up @@ -126,30 +137,57 @@ class Exception(PythonMessageKind):

class CallMethod(PythonMessageKind):
def __init__(
self, name: str, response_port: PortRef | OncePortRef | None
self, name: MethodSpecifier, response_port: PortRef | OncePortRef | None
) -> None: ...
@property
def name(self) -> str: ...
def name(self) -> MethodSpecifier: ...
@property
def response_port(self) -> PortRef | OncePortRef | None: ...

class MethodSpecifier:
@classmethod
@property
def ReturnsResponse(cls) -> "Type[ReturnsResponse]": ...
@classmethod
@property
def ExplicitPort(cls) -> "Type[ExplicitPort]": ...
@classmethod
@property
def Init(cls) -> "Type[Init]": ...

class ReturnsResponse(MethodSpecifier):
def __init__(self, name: str) -> None: ...
@property
def name(self) -> str: ...

class ExplicitPort(MethodSpecifier):
def __init__(self, name: str) -> None: ...
@property
def name(self) -> str: ...

class Init(MethodSpecifier):
pass

class UnflattenArg(Enum):
Mailbox = 0
PyObject = 1

class CallMethodIndirect(PythonMessageKind):
def __init__(
self,
name: str,
name: MethodSpecifier,
broker_id: Tuple[str, int],
id: int,
unflatten_args: List[UnflattenArg],
) -> None: ...

class Init(PythonMessageKind):
def __init__(self, response_port: PortRef | OncePortRef | None) -> None: ...
@property
def response_port(self) -> PortRef | OncePortRef | None: ...
def name(self) -> MethodSpecifier: ...
@property
def broker_id(self) -> Tuple[str, int]: ...
@property
def id(self) -> int: ...
@property
def unflatten_args(self) -> List[UnflattenArg]: ...

class Uninit(PythonMessageKind):
pass
Expand Down Expand Up @@ -219,8 +257,10 @@ class PanicFlag:
"""
...

class PortProtocol(Protocol):
def send(self, obj: Any) -> None: ...
R = TypeVar("R")

class PortProtocol(Generic[R], Protocol):
def send(self, obj: R) -> None: ...
def exception(self, obj: Any) -> None: ...

class Actor(Protocol):
Expand All @@ -229,9 +269,9 @@ class Actor(Protocol):
mailbox: Mailbox,
rank: int,
shape: Shape,
method: str,
method: MethodSpecifier,
message: bytes,
panic_flag: PanicFlag,
local_state: Iterable[Any],
response_port: PortProtocol,
response_port: PortProtocol[Any],
) -> None: ...
95 changes: 74 additions & 21 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
Iterable,
Iterator,
List,
Literal,
NamedTuple,
Optional,
overload,
ParamSpec,
Tuple,
Type,
Expand All @@ -39,6 +41,7 @@
)

from monarch._rust_bindings.monarch_hyperactor.actor import (
MethodSpecifier,
PanicFlag,
PythonMessage,
PythonMessageKind,
Expand Down Expand Up @@ -282,16 +285,18 @@ class ActorEndpoint(Endpoint[P, R]):
def __init__(
self,
actor_mesh_ref: _ActorMeshRefImpl,
name: str,
name: MethodSpecifier,
impl: Callable[Concatenate[Any, P], Awaitable[R]],
mailbox: Mailbox,
propagator: Propagator = None,
propagator: Propagator,
explicit_response_port: bool,
) -> None:
super().__init__(propagator)
self._actor_mesh = actor_mesh_ref
self._name = name
self._signature: inspect.Signature = inspect.signature(impl)
self._mailbox = mailbox
self._explicit_response_port = explicit_response_port

def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
mesh = self._actor_mesh._actor_mesh
Expand All @@ -300,6 +305,12 @@ def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
def _call_name(self) -> Any:
return self._name

def _check_arguments(self, args, kwargs):
if self._explicit_response_port:
self._signature.bind(None, None, *args, **kwargs)
else:
self._signature.bind(None, *args, **kwargs)

def _send(
self,
args: Tuple[Any, ...],
Expand All @@ -312,7 +323,7 @@ def _send(

This sends the message to all actors but does not wait for any result.
"""
self._signature.bind(None, *args, **kwargs)
self._check_arguments(args, kwargs)
objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox)
if all(not hasattr(obj, "__monarch_ref__") for obj in objects):
message = PythonMessage(
Expand All @@ -336,23 +347,50 @@ def _port(self, once: bool = False) -> "PortTuple[R]":
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))

def _rref(self, args, kwargs):
self._signature.bind(None, *args, **kwargs)
self._check_arguments(args, kwargs)
refs, bytes = flatten((args, kwargs), _is_ref_or_mailbox)

return actor_rref(self, bytes, refs)


@overload
def as_endpoint(
not_an_endpoint: Callable[P, R],
*,
propagate: Propagator = None,
explicit_response_port: Literal[False] = False,
) -> Endpoint[P, R]: ...


@overload
def as_endpoint(
not_an_endpoint: Callable[Concatenate["PortProtocol[R]", P], None],
*,
propagate: Propagator = None,
explicit_response_port: Literal[True],
) -> Endpoint[P, R]: ...


def as_endpoint(
not_an_endpoint: Callable[P, R], *, propagate: Propagator = None
) -> Endpoint[P, R]:
not_an_endpoint: Any,
*,
propagate: Propagator = None,
explicit_response_port: bool = False,
):
if not isinstance(not_an_endpoint, NotAnEndpoint):
raise ValueError("expected an method of a spawned actor")
kind = (
MethodSpecifier.ExplicitPort
if explicit_response_port
else MethodSpecifier.ReturnsResponse
)
return ActorEndpoint(
not_an_endpoint._ref._actor_mesh_ref,
not_an_endpoint._name,
kind(not_an_endpoint._name),
getattr(not_an_endpoint._ref, not_an_endpoint._name),
not_an_endpoint._ref._mailbox,
propagate,
explicit_response_port,
)


Expand Down Expand Up @@ -598,7 +636,7 @@ async def handle(
mailbox: Mailbox,
rank: int,
shape: Shape,
method: str,
method_spec: MethodSpecifier,
message: bytes,
panic_flag: PanicFlag,
local_state: Iterable[Any],
Expand All @@ -616,17 +654,23 @@ async def handle(

args, kwargs = unflatten(message, local_state)

if method == "__init__":
Class, *args = args
try:
self.instance = Class(*args, **kwargs)
except Exception as e:
self._saved_error = ActorError(
e, f"Remote actor {Class}.__init__ call failed."
)
raise e
port.send(None)
return None
match method_spec:
case MethodSpecifier.Init():
Class, *args = args
try:
self.instance = Class(*args, **kwargs)
except Exception as e:
self._saved_error = ActorError(
e, f"Remote actor {Class}.__init__ call failed."
)
raise e
port.send(None)
return None
case MethodSpecifier.ReturnsResponse(name=method):
pass
case MethodSpecifier.ExplicitPort(name=method):
args = (port, *args)
port = DroppingPort()

if self.instance is None:
# This could happen because of the following reasons. Both
Expand Down Expand Up @@ -775,15 +819,22 @@ def __init__(
for attr_name in dir(self._class):
attr_value = getattr(self._class, attr_name, None)
if isinstance(attr_value, EndpointProperty):
# Convert string method name to appropriate MethodSpecifier
kind = (
MethodSpecifier.ExplicitPort
if attr_value._explicit_response_port
else MethodSpecifier.ReturnsResponse
)
setattr(
self,
attr_name,
ActorEndpoint(
self._actor_mesh_ref,
attr_name,
kind(attr_name),
attr_value._method,
self._mailbox,
attr_value._propagator,
attr_value._explicit_response_port,
),
)

Expand All @@ -802,9 +853,11 @@ async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:

ep = ActorEndpoint(
self._actor_mesh_ref,
"__init__",
MethodSpecifier.Init(),
null_func,
self._mailbox,
None,
False,
)
send(ep, (self._class, *args), kwargs)

Expand Down
Loading
Loading