From 2bc4b4369389833c216f6b3c9421fc36612a534d Mon Sep 17 00:00:00 2001 From: zdevito Date: Thu, 24 Jul 2025 09:21:19 -0700 Subject: [PATCH 1/2] Add @endpoint(explicit_response_port=True) This option exposes an option that rust actors have always had to get the response port as an argument instead of using the return value of the function as the response. Having this feature allows synchronous actors to defer responding to a message and process other messages, making them as expressive as asynchronous actors which accomplished this by awaiting. Enjoy the giant @overload copypasta to appease the python typecheckers. Differential Revision: [D78901486](https://our.internmc.facebook.com/intern/diff/D78901486/) [ghstack-poisoned] --- monarch_hyperactor/src/actor.rs | 18 +++- .../monarch_hyperactor/actor.pyi | 64 ++++++++++--- python/monarch/_src/actor/actor_mesh.py | 95 +++++++++++++++---- python/monarch/_src/actor/endpoint.py | 68 ++++++++++++- python/monarch/mesh_controller.py | 93 +++--------------- python/tests/_monarch/test_actor.py | 6 +- python/tests/_monarch/test_mailbox.py | 18 +++- python/tests/test_python_actors.py | 12 +++ 8 files changed, 249 insertions(+), 125 deletions(-) diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index d8cd381d4..9c2a01440 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -187,11 +187,22 @@ pub enum UnflattenArg { PyObject, } +#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum MethodSpecifier { + /// Call method 'name', send its return value to the response port. + ReturnsResponse { name: String }, + /// Call method 'name', send the response port as the first argument. + ExplicitPort { name: String }, + /// Construct the object + Init {}, +} + #[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)] pub enum PythonMessageKind { CallMethod { - name: String, + name: MethodSpecifier, response_port: Option, }, Result { @@ -202,7 +213,7 @@ pub enum PythonMessageKind { }, Uninit {}, CallMethodIndirect { - name: String, + name: MethodSpecifier, local_state_broker: (String, usize), id: usize, // specify whether the argument to unflatten the local mailbox, @@ -230,7 +241,7 @@ pub struct PythonMessage { } struct ResolvedCallMethod { - method: String, + method: MethodSpecifier, bytes: Vec, local_state: PyObject, /// Implements PortProtocol @@ -843,6 +854,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index 8827da0a1..1fc4607b2 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -9,7 +9,18 @@ import abc from enum import Enum -from typing import Any, final, Iterable, List, Optional, Protocol, Tuple, Type +from typing import ( + Any, + final, + Generic, + Iterable, + List, + Optional, + Protocol, + Tuple, + Type, + TypeVar, +) from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -126,13 +137,37 @@ class Exception(PythonMessageKind): class CallMethod(PythonMessageKind): def __init__( - self, name: str, response_port: PortRef | OncePortRef | None + self, name: MethodSpecifier, response_port: PortRef | OncePortRef | None ) -> None: ... @property - def name(self) -> str: ... + def name(self) -> MethodSpecifier: ... @property def response_port(self) -> PortRef | OncePortRef | None: ... +class MethodSpecifier: + @classmethod + @property + def ReturnsResponse(cls) -> "Type[ReturnsResponse]": ... + @classmethod + @property + def ExplicitPort(cls) -> "Type[ExplicitPort]": ... + @classmethod + @property + def Init(cls) -> "Type[Init]": ... + +class ReturnsResponse(MethodSpecifier): + def __init__(self, name: str) -> None: ... + @property + def name(self) -> str: ... + +class ExplicitPort(MethodSpecifier): + def __init__(self, name: str) -> None: ... + @property + def name(self) -> str: ... + +class Init(MethodSpecifier): + pass + class UnflattenArg(Enum): Mailbox = 0 PyObject = 1 @@ -140,16 +175,19 @@ class UnflattenArg(Enum): class CallMethodIndirect(PythonMessageKind): def __init__( self, - name: str, + name: MethodSpecifier, broker_id: Tuple[str, int], id: int, unflatten_args: List[UnflattenArg], ) -> None: ... - -class Init(PythonMessageKind): - def __init__(self, response_port: PortRef | OncePortRef | None) -> None: ... @property - def response_port(self) -> PortRef | OncePortRef | None: ... + def name(self) -> MethodSpecifier: ... + @property + def broker_id(self) -> Tuple[str, int]: ... + @property + def id(self) -> int: ... + @property + def unflatten_args(self) -> List[UnflattenArg]: ... class Uninit(PythonMessageKind): pass @@ -219,8 +257,10 @@ class PanicFlag: """ ... -class PortProtocol(Protocol): - def send(self, obj: Any) -> None: ... +R = TypeVar("R") + +class PortProtocol(Generic[R], Protocol): + def send(self, obj: R) -> None: ... def exception(self, obj: Any) -> None: ... class Actor(Protocol): @@ -229,9 +269,9 @@ class Actor(Protocol): mailbox: Mailbox, rank: int, shape: Shape, - method: str, + method: MethodSpecifier, message: bytes, panic_flag: PanicFlag, local_state: Iterable[Any], - response_port: PortProtocol, + response_port: PortProtocol[Any], ) -> None: ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 855c1532f..b07ce7dbf 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -29,8 +29,10 @@ Iterable, Iterator, List, + Literal, NamedTuple, Optional, + overload, ParamSpec, Tuple, Type, @@ -39,6 +41,7 @@ ) from monarch._rust_bindings.monarch_hyperactor.actor import ( + MethodSpecifier, PanicFlag, PythonMessage, PythonMessageKind, @@ -282,16 +285,18 @@ class ActorEndpoint(Endpoint[P, R]): def __init__( self, actor_mesh_ref: _ActorMeshRefImpl, - name: str, + name: MethodSpecifier, impl: Callable[Concatenate[Any, P], Awaitable[R]], mailbox: Mailbox, - propagator: Propagator = None, + propagator: Propagator, + explicit_response_port: bool, ) -> None: super().__init__(propagator) self._actor_mesh = actor_mesh_ref self._name = name self._signature: inspect.Signature = inspect.signature(impl) self._mailbox = mailbox + self._explicit_response_port = explicit_response_port def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any: mesh = self._actor_mesh._actor_mesh @@ -300,6 +305,12 @@ def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any: def _call_name(self) -> Any: return self._name + def _check_arguments(self, args, kwargs): + if self._explicit_response_port: + self._signature.bind(None, None, *args, **kwargs) + else: + self._signature.bind(None, *args, **kwargs) + def _send( self, args: Tuple[Any, ...], @@ -312,7 +323,7 @@ def _send( This sends the message to all actors but does not wait for any result. """ - self._signature.bind(None, *args, **kwargs) + self._check_arguments(args, kwargs) objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox) if all(not hasattr(obj, "__monarch_ref__") for obj in objects): message = PythonMessage( @@ -336,23 +347,50 @@ def _port(self, once: bool = False) -> "PortTuple[R]": return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver))) def _rref(self, args, kwargs): - self._signature.bind(None, *args, **kwargs) + self._check_arguments(args, kwargs) refs, bytes = flatten((args, kwargs), _is_ref_or_mailbox) return actor_rref(self, bytes, refs) +@overload +def as_endpoint( + not_an_endpoint: Callable[P, R], + *, + propagate: Propagator = None, + explicit_response_port: Literal[False] = False, +) -> Endpoint[P, R]: ... + + +@overload +def as_endpoint( + not_an_endpoint: Callable[Concatenate["PortProtocol[R]", P], None], + *, + propagate: Propagator = None, + explicit_response_port: Literal[True], +) -> Endpoint[P, R]: ... + + def as_endpoint( - not_an_endpoint: Callable[P, R], *, propagate: Propagator = None -) -> Endpoint[P, R]: + not_an_endpoint: Any, + *, + propagate: Propagator = None, + explicit_response_port: bool = False, +): if not isinstance(not_an_endpoint, NotAnEndpoint): raise ValueError("expected an method of a spawned actor") + kind = ( + MethodSpecifier.ExplicitPort + if explicit_response_port + else MethodSpecifier.ReturnsResponse + ) return ActorEndpoint( not_an_endpoint._ref._actor_mesh_ref, - not_an_endpoint._name, + kind(not_an_endpoint._name), getattr(not_an_endpoint._ref, not_an_endpoint._name), not_an_endpoint._ref._mailbox, propagate, + explicit_response_port, ) @@ -598,7 +636,7 @@ async def handle( mailbox: Mailbox, rank: int, shape: Shape, - method: str, + method_spec: MethodSpecifier, message: bytes, panic_flag: PanicFlag, local_state: Iterable[Any], @@ -616,17 +654,23 @@ async def handle( args, kwargs = unflatten(message, local_state) - if method == "__init__": - Class, *args = args - try: - self.instance = Class(*args, **kwargs) - except Exception as e: - self._saved_error = ActorError( - e, f"Remote actor {Class}.__init__ call failed." - ) - raise e - port.send(None) - return None + match method_spec: + case MethodSpecifier.Init(): + Class, *args = args + try: + self.instance = Class(*args, **kwargs) + except Exception as e: + self._saved_error = ActorError( + e, f"Remote actor {Class}.__init__ call failed." + ) + raise e + port.send(None) + return None + case MethodSpecifier.ReturnsResponse(name=method): + pass + case MethodSpecifier.ExplicitPort(name=method): + args = (port, *args) + port = DroppingPort() if self.instance is None: # This could happen because of the following reasons. Both @@ -775,15 +819,22 @@ def __init__( for attr_name in dir(self._class): attr_value = getattr(self._class, attr_name, None) if isinstance(attr_value, EndpointProperty): + # Convert string method name to appropriate MethodSpecifier + kind = ( + MethodSpecifier.ExplicitPort + if attr_value._explicit_response_port + else MethodSpecifier.ReturnsResponse + ) setattr( self, attr_name, ActorEndpoint( self._actor_mesh_ref, - attr_name, + kind(attr_name), attr_value._method, self._mailbox, attr_value._propagator, + attr_value._explicit_response_port, ), ) @@ -802,9 +853,11 @@ async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: ep = ActorEndpoint( self._actor_mesh_ref, - "__init__", + MethodSpecifier.Init(), null_func, self._mailbox, + None, + False, ) send(ep, (self._class, *args), kwargs) diff --git a/python/monarch/_src/actor/endpoint.py b/python/monarch/_src/actor/endpoint.py index 411647e18..bb7f5feee 100644 --- a/python/monarch/_src/actor/endpoint.py +++ b/python/monarch/_src/actor/endpoint.py @@ -223,16 +223,23 @@ def __init__( self, method: Callable[Concatenate[Any, P], Awaitable[R]], propagator: Propagator, + explicit_response_port: bool, ) -> None: ... @overload def __init__( - self, method: Callable[Concatenate[Any, P], R], propagator: Propagator + self, + method: Callable[Concatenate[Any, P], R], + propagator: Propagator, + explicit_response_port: bool, ) -> None: ... - def __init__(self, method: Any, propagator: Propagator) -> None: + def __init__( + self, method: Any, propagator: Propagator, explicit_response_port: bool + ) -> None: self._method = method self._propagator = propagator + self._explicit_response_port = explicit_response_port def __get__(self, instance, owner) -> Endpoint[P, R]: # this is a total lie, but we have to actually @@ -274,11 +281,28 @@ def __call__(self, function: Any): pass +class PortedEndpointIfy: + @overload + def __call__( + self, + function: Callable[Concatenate[Any, "Port[R]", P], Awaitable[None]], + ) -> Endpoint[P, R]: ... + + @overload + def __call__( + self, function: Callable[Concatenate[Any, "Port[R]", P], None] + ) -> Endpoint[P, R]: ... + + def __call__(self, function: Any): + pass + + @overload def endpoint( method: Callable[Concatenate[Any, P], Awaitable[R]], *, propagate: Propagator = None, + explicit_response_port: Literal[False] = False, ) -> EndpointProperty[P, R]: ... @@ -287,6 +311,7 @@ def endpoint( method: Callable[Concatenate[Any, P], R], *, propagate: Propagator = None, + explicit_response_port: Literal[False] = False, ) -> EndpointProperty[P, R]: ... @@ -294,10 +319,43 @@ def endpoint( def endpoint( *, propagate: Propagator = None, + explicit_response_port: Literal[False] = False, ) -> EndpointIfy: ... -def endpoint(method=None, *, propagate=None): +@overload +def endpoint( + method: Callable[Concatenate[Any, "Port[R]", P], Awaitable[None]], + *, + propagate: Propagator = None, + explicit_response_port: Literal[True], +) -> EndpointProperty[P, R]: ... + + +@overload +def endpoint( + method: Callable[Concatenate[Any, "Port[R]", P], None], + *, + propagate: Propagator = None, + explicit_response_port: Literal[True], +) -> EndpointProperty[P, R]: ... + + +@overload +def endpoint( + *, + propagate: Propagator = None, + explicit_response_port: Literal[True], +) -> PortedEndpointIfy: ... + + +def endpoint(method=None, *, propagate=None, explicit_response_port: bool = False): if method is None: - return functools.partial(endpoint, propagate=propagate) - return EndpointProperty(method, propagator=propagate) + return functools.partial( + endpoint, + propagate=propagate, + explicit_response_port=explicit_response_port, + ) + return EndpointProperty( + method, propagator=propagate, explicit_response_port=explicit_response_port + ) diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 8b520c5d1..61f3101db 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -30,7 +30,6 @@ WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller -from monarch._rust_bindings.monarch_extension.tensor_worker import Ref from monarch._rust_bindings.monarch_hyperactor.actor import ( PythonMessage, PythonMessageKind, @@ -40,17 +39,14 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) -from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple -from monarch._src.actor.endpoint import Selection +from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple, Selection from monarch._src.actor.shape import NDSlice from monarch.common import device_mesh, messages, stream from monarch.common.controller_api import TController -from monarch.common.function import ResolvableFunction from monarch.common.invocation import Seq from monarch.common.messages import Referenceable, SendResultOfActorCall from monarch.common.stream import StreamRef -from monarch.common.tensor import dtensor_check, InputChecker, Tensor -from monarch.common.tree import flatten +from monarch.common.tensor import InputChecker, Tensor from monarch.tensor_worker_main import _set_trace if TYPE_CHECKING: @@ -268,29 +264,6 @@ def __str__(self): return "" -def _cast_call_method_indirect( - endpoint: ActorEndpoint, - selection: Selection, - client: MeshClient, - seq: Seq, - args_kwargs_tuple: bytes, - refs: Sequence[Any], -) -> Tuple[str, int]: - unflatten_args = [ - UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox - for ref in refs - ] - broker_id: Tuple[str, int] = client._mesh_controller.broker_id - actor_msg = PythonMessage( - PythonMessageKind.CallMethodIndirect( - endpoint._name, broker_id, seq, unflatten_args - ), - args_kwargs_tuple, - ) - endpoint._actor_mesh.cast(actor_msg, selection) - return broker_id - - def actor_send( endpoint: ActorEndpoint, args_kwargs_tuple: bytes, @@ -298,6 +271,10 @@ def actor_send( port: Optional[Port[Any]], selection: Selection, ): + unflatten_args = [ + UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox + for ref in refs + ] tensors = [ref for ref in refs if isinstance(ref, Tensor)] # we have some monarch references, we need to ensure their # proc_mesh matches that of the tensors we sent to it @@ -306,7 +283,7 @@ def actor_send( if hasattr(t, "stream"): chosen_stream = t.stream break - with InputChecker(tensors, lambda x: f"actor_call({x})") as checker: + with InputChecker(refs, lambda x: f"actor_call({x})") as checker: checker.check_mesh_stream_local(device_mesh._active, chosen_stream) # TODO: move propagators into Endpoint abstraction and run the propagator to get the # mutates @@ -322,6 +299,8 @@ def actor_send( client = cast(MeshClient, checker.mesh.client) + broker_id: Tuple[str, int] = client._mesh_controller.broker_id + stream_ref = chosen_stream._to_ref(client) fut = (port, checker.mesh._ndslice) if port is not None else None @@ -336,9 +315,13 @@ def actor_send( # The message to the generic actor tells it to first wait on the broker to get the local arguments # from the stream, then it will run the actor method, and send the result to response port. - broker_id = _cast_call_method_indirect( - endpoint, selection, client, ident, args_kwargs_tuple, refs + actor_msg = PythonMessage( + PythonMessageKind.CallMethodIndirect( + endpoint._name, broker_id, ident, unflatten_args + ), + args_kwargs_tuple, ) + endpoint._actor_mesh.cast(actor_msg, selection) worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref) client.send(checker.mesh._ndslice, worker_msg) # we have to ask for status updates @@ -346,49 +329,3 @@ def actor_send( # enough work to count this future as finished, # and all potential errors have been reported client._request_status() - - -def actor_rref(endpoint, args_kwargs_tuple: bytes, refs: Sequence[Any]): - chosen_stream = stream._active - fake_result, dtensors, mutates, mesh = dtensor_check( - endpoint._propagate, - cast(ResolvableFunction, endpoint._name), - refs, - {}, - device_mesh._active, - chosen_stream, - ) - assert mesh is not None - - fake_result_dtensors, unflatten_result = flatten( - fake_result, lambda x: isinstance(x, torch.Tensor) - ) - result_dtensors = tuple( - Tensor(fake, mesh, chosen_stream) for fake in fake_result_dtensors - ) - seq = mesh.client.new_node(result_dtensors + mutates, dtensors) - assert all(t.ref is not None for t in result_dtensors) - assert all(t.ref is not None for t in mutates) - result = result_msg = unflatten_result(result_dtensors) - if len(result_dtensors) == 0: - result_msg = None - - broker_id = _cast_call_method_indirect( - endpoint, "all", mesh.client, seq, args_kwargs_tuple, refs - ) - # note the device mesh has to be defined regardles so the remote functions - # can invoke mesh.rank("...") - - mesh.define_remotely() - - mesh._send( - messages.CallActorMethod( - seq, - result_msg, - broker_id, - refs, - cast("List[Ref]", mutates), - stream._active._to_ref(mesh.client), - ) - ) - return result diff --git a/python/tests/_monarch/test_actor.py b/python/tests/_monarch/test_actor.py index 31c06519f..c913f4833 100644 --- a/python/tests/_monarch/test_actor.py +++ b/python/tests/_monarch/test_actor.py @@ -9,6 +9,7 @@ import time from monarch._rust_bindings.monarch_hyperactor.actor import ( + MethodSpecifier, PythonMessage, PythonMessageKind, ) @@ -22,6 +23,9 @@ def test_python_message() -> None: payload: str = "a" * 2**30 # 1gb blob: bytes = payload.encode("utf-8") t = time.time() - PythonMessage(PythonMessageKind.CallMethod(method, None), blob) + PythonMessage( + PythonMessageKind.CallMethod(MethodSpecifier.ReturnsResponse(method), None), + blob, + ) t_spent = time.time() - t assert t_spent < 1 diff --git a/python/tests/_monarch/test_mailbox.py b/python/tests/_monarch/test_mailbox.py index 7c82af883..c4f565178 100644 --- a/python/tests/_monarch/test_mailbox.py +++ b/python/tests/_monarch/test_mailbox.py @@ -13,6 +13,7 @@ import monarch from monarch._rust_bindings.monarch_hyperactor.actor import ( + MethodSpecifier, PanicFlag, PythonMessage, PythonMessageKind, @@ -75,7 +76,9 @@ def __call__(self, state: PythonMessage, update: PythonMessage) -> PythonMessage @property def initial_state(self) -> PythonMessage: return PythonMessage( - PythonMessageKind.CallMethod(" @Accumulator.initial_state", None), + PythonMessageKind.CallMethod( + MethodSpecifier.ReturnsResponse(" @Accumulator.initial_state"), None + ), pickle.dumps(self._initial_state), ) @@ -108,7 +111,9 @@ def post_message(value: int) -> None: port_ref.send( mailbox, PythonMessage( - PythonMessageKind.CallMethod("test_accumulator", None), + PythonMessageKind.CallMethod( + MethodSpecifier.ReturnsResponse("test_accumulator"), None + ), pickle.dumps(value), ), ) @@ -135,11 +140,11 @@ async def handle( mailbox: Mailbox, rank: int, shape: Shape, - method: str, + method: MethodSpecifier, message: bytes, panic_flag: PanicFlag, local_state: Iterable[Any], - response_port: "PortProtocol", + response_port: "PortProtocol[Any]", ) -> None: response_port.send(pickle.loads(message)) for i in range(100): @@ -164,7 +169,10 @@ def my_reduce(state: str, update: str) -> str: actor_mesh.cast( Selection.from_string("*"), PythonMessage( - PythonMessageKind.CallMethod("echo", port_ref), pickle.dumps("start") + PythonMessageKind.CallMethod( + MethodSpecifier.ReturnsResponse("echo"), port_ref + ), + pickle.dumps("start"), ), ) diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 5e69caaca..a2bccb59d 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -586,3 +586,15 @@ async def test_actor_mesh_stop(self) -> None: await am_2.print.call("hello 3") await am_2.log.call("hello 4") + + +class PortedActor(Actor): + @endpoint(explicit_response_port=True) + def add(self, port: "Port[int]", b: int) -> None: + port.send(3 + b) + + +def test_ported_actor(): + proc_mesh = local_proc_mesh(gpus=1).get() + a = proc_mesh.spawn("port_actor", PortedActor).get() + assert 5 == a.add.call_one(2).get() From b8e90a97d55e379ff396285f629e91f14fc43798 Mon Sep 17 00:00:00 2001 From: zdevito Date: Thu, 24 Jul 2025 12:18:18 -0700 Subject: [PATCH 2/2] Update on "Add @endpoint(explicit_response_port=True)" This option exposes an option that rust actors have always had to get the response port as an argument instead of using the return value of the function as the response. Having this feature allows synchronous actors to defer responding to a message and process other messages, making them as expressive as asynchronous actors which accomplished this by awaiting. Enjoy the giant overload copypasta to appease the python typecheckers. Differential Revision: [D78901486](https://our.internmc.facebook.com/intern/diff/D78901486/) [ghstack-poisoned] --- monarch_hyperactor/src/actor.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 9c2a01440..b81b3d649 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -885,7 +885,9 @@ mod tests { ); let message = PythonMessage { kind: PythonMessageKind::CallMethod { - name: "test".to_string(), + name: MethodSpecifier::ReturnsResponse { + name: "test".to_string(), + }, response_port: Some(EitherPortRef::Unbounded(port_ref.clone().into())), }, message: vec![1, 2, 3], @@ -906,7 +908,9 @@ mod tests { let no_port_message = PythonMessage { kind: PythonMessageKind::CallMethod { - name: "test".to_string(), + name: MethodSpecifier::ReturnsResponse { + name: "test".to_string(), + }, response_port: None, }, ..message