From ac8eb221f75b537f368a06be0f6f6a64ad9115d9 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Wed, 16 Jul 2025 10:39:30 -0700 Subject: [PATCH 1/2] Add def slice to PythonActorMesh and PythonActorMeshRef (#551) Summary: This diff adds a `def slice` method to both `PythonActorMesh` and `PythonActorMeshRef`. With this method, we can: 1. slice a `PythonActorMesh` object into a `PythonActorMeshRef`; 2. slice a `PythonActorMeshRef` into another `PythonActorMeshRef`. Tests are added to demo that we can cast to the sliced mesh ref. Reviewed By: shayne-fletcher Differential Revision: D78292490 --- hyperactor_mesh/src/actor_mesh.rs | 79 +++++++++++------ hyperactor_mesh/src/reference.rs | 74 +++++++++++----- monarch_hyperactor/src/actor_mesh.rs | 76 +++++++++++++++++ .../monarch_hyperactor/actor_mesh.pyi | 24 +++++- python/tests/_monarch/test_actor_mesh.py | 85 +++++++++++++++++-- 5 files changed, 284 insertions(+), 54 deletions(-) diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 0ba8f4956..de54ae755 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -36,6 +36,7 @@ use ndslice::Range; use ndslice::Selection; use ndslice::Shape; use ndslice::ShapeError; +use ndslice::SliceError; use ndslice::selection; use ndslice::selection::EvalOpts; use ndslice::selection::ReifyView; @@ -95,6 +96,47 @@ where Ok(()) } +#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. +pub(crate) fn cast_to_sliced_mesh( + caps: &impl cap::CanSend, + actor_mesh_id: ActorMeshId, + sender: &ActorId, + comm_actor_ref: &ActorRef, + sel_of_sliced: &Selection, + message: M, + sliced_shape: &Shape, + base_shape: &Shape, +) -> Result<(), CastError> +where + A: RemoteActor + RemoteHandles>, + M: Castable + RemoteMessage, +{ + let base_slice = base_shape.slice(); + + // Casting to `*`?
+ let sel_of_base = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True { + // Reify this view into base. + base_slice.reify_view(sliced_shape.slice())? + } else { + // No, fall back on `of_ranks`. + let ranks = sel_of_sliced + .eval(&EvalOpts::strict(), sliced_shape.slice())? + .collect::>(); + Selection::of_ranks(base_slice, &ranks)? + }; + + // Cast. + actor_mesh_cast::( + caps, + actor_mesh_id, + base_shape, + sender, + comm_actor_ref, + sel_of_base, + message, + ) +} + /// A mesh of actors, all of which reside on the same [`ProcMesh`]. pub trait ActorMesh: Mesh { /// The type of actor in the mesh. @@ -350,31 +392,15 @@ impl ActorMesh for SlicedActorMesh<'_, A> { Self::Actor: RemoteHandles>, M: Castable + RemoteMessage, { - let base_shape = self.0.shape(); - let base_slice = base_shape.slice(); - - // Casting to `*`? - let selection = if selection::normalize(&sel) == normal::NormalizedSelection::True { - // Reify this view into base. - base_slice.reify_view(self.shape().slice()).unwrap() - } else { - // No, fall back on `of_ranks`. - let ranks = sel - .eval(&EvalOpts::strict(), self.shape().slice()) - .unwrap() - .collect::>(); - Selection::of_ranks(base_slice, &ranks).unwrap() - }; - - // Cast. 
- actor_mesh_cast::( - self.proc_mesh().client(), // send capability - self.id(), // actor mesh id (destination mesh) - base_shape, // actor mesh shape - self.proc_mesh().client().actor_id(), // sender - self.proc_mesh().comm_actor(), // comm actor - selection, // the selected actors - message, // the message + cast_to_sliced_mesh::( + /*caps=*/ self.proc_mesh().client(), + /*actor_mesh_id=*/ self.id(), + /*sender=*/ self.proc_mesh().client().actor_id(), + /*comm_actor_ref*/ self.proc_mesh().comm_actor(), + /*sel_of_sliced=*/ &sel, + /*message=*/ message, + /*sliced_shape=*/ self.shape(), + /*base_shape=*/ self.0.shape(), ) } } @@ -394,6 +420,9 @@ pub enum CastError { #[error(transparent)] ShapeError(#[from] ShapeError), + #[error(transparent)] + SliceError(#[from] SliceError), + #[error(transparent)] SerializationError(#[from] bincode::Error), diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index 625b50ce0..d97efae74 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -19,14 +19,17 @@ use hyperactor::actor::RemoteActor; use hyperactor::cap; use hyperactor::message::Castable; use hyperactor::message::IndexedErasedUnbound; +use ndslice::Range; use ndslice::Selection; use ndslice::Shape; +use ndslice::ShapeError; use serde::Deserialize; use serde::Serialize; use crate::CommActor; use crate::actor_mesh::CastError; use crate::actor_mesh::actor_mesh_cast; +use crate::actor_mesh::cast_to_sliced_mesh; #[macro_export] macro_rules! mesh_id { @@ -71,10 +74,15 @@ pub struct ProcMeshId(pub String); pub struct ActorMeshId(pub ProcMeshId, pub String); /// Types references to Actor Meshes. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct ActorMeshRef { pub(crate) mesh_id: ActorMeshId, - shape: Shape, + /// The shape of the root mesh. 
+ root: Shape, + /// If `Some`, it means this mesh ref points to a sliced mesh, and this field + /// is this sliced mesh's shape. If `None`, it means this mesh ref points to + /// the root mesh. + sliced: Option, /// The reference to the comm actor of the underlying Proc Mesh. comm_actor_ref: ActorRef, phantom: PhantomData, @@ -87,12 +95,13 @@ impl ActorMeshRef { /// line argument) is a valid reference. pub(crate) fn attest( mesh_id: ActorMeshId, - shape: Shape, + root: Shape, comm_actor_ref: ActorRef, ) -> Self { Self { mesh_id, - shape, + root, + sliced: None, comm_actor_ref, phantom: PhantomData, } @@ -105,7 +114,10 @@ impl ActorMeshRef { /// Shape of the Actor Mesh. pub fn shape(&self) -> &Shape { - &self.shape + match &self.sliced { + Some(s) => s, + None => &self.root, + } } /// Cast an [`M`]-typed message to the ranks selected by `sel` @@ -121,15 +133,38 @@ impl ActorMeshRef { A: RemoteHandles + RemoteHandles>, M: Castable + RemoteMessage, { - actor_mesh_cast::( - caps, - self.mesh_id.clone(), - self.shape(), - caps.mailbox().actor_id(), - &self.comm_actor_ref, - selection, - message, - ) + match &self.sliced { + Some(sliced_shape) => cast_to_sliced_mesh::( + caps, + self.mesh_id.clone(), + caps.mailbox().actor_id(), + &self.comm_actor_ref, + &selection, + message, + sliced_shape, + &self.root, + ), + None => actor_mesh_cast::( + caps, + self.mesh_id.clone(), + &self.root, + caps.mailbox().actor_id(), + &self.comm_actor_ref, + selection, + message, + ), + } + } + + pub fn select>(&self, label: &str, range: R) -> Result { + let sliced = self.shape().select(label, range)?; + Ok(Self { + mesh_id: self.mesh_id.clone(), + root: self.root.clone(), + sliced: Some(sliced), + comm_actor_ref: self.comm_actor_ref.clone(), + phantom: PhantomData, + }) } } @@ -137,21 +172,14 @@ impl Clone for ActorMeshRef { fn clone(&self) -> Self { Self { mesh_id: self.mesh_id.clone(), - shape: self.shape.clone(), + root: self.root.clone(), + sliced: self.sliced.clone(), comm_actor_ref:
self.comm_actor_ref.clone(), phantom: PhantomData, } } } -impl PartialEq for ActorMeshRef { - fn eq(&self, other: &Self) -> bool { - self.mesh_id == other.mesh_id && self.shape == other.shape - } -} - -impl Eq for ActorMeshRef {} - #[cfg(test)] mod tests { use async_trait::async_trait; diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 65ceabfa8..7326e1ac2 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -27,6 +27,8 @@ use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; +use pyo3::types::PyDict; +use pyo3::types::PySlice; use serde::Deserialize; use serde::Serialize; use tokio::sync::Mutex; @@ -178,6 +180,11 @@ impl PythonActorMesh { Ok(monitor_instance.into_py(py)) } + #[pyo3(signature = (**kwargs))] + fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult { + self.bind()?.slice(kwargs) + } + #[getter] pub fn client(&self) -> PyMailbox { self.client.clone() @@ -222,6 +229,75 @@ impl PythonActorMeshRef { Ok(()) } + #[pyo3(signature = (**kwargs))] + fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult { + // When the input type is `int`, convert it into `ndslice::Range`. + fn convert_int(index: isize) -> PyResult { + if index < 0 { + return Err(PyException::new_err(format!( + "does not support negative index in selection: {}", + index + ))); + } + Ok(ndslice::Range::from(index as usize)) + } + + // When the input type is `slice`, convert it into `ndslice::Range`. 
+ fn convert_py_slice<'py>(s: &Bound<'py, PySlice>) -> PyResult { + fn get_attr<'py>(s: &Bound<'py, PySlice>, attr: &str) -> PyResult> { + let v = s.getattr(attr)?.extract::>()?; + if v.is_some() && v.unwrap() < 0 { + return Err(PyException::new_err(format!( + "does not support negative {} in slice: {}", + attr, + v.unwrap(), + ))); + } + Ok(v) + } + + let start = get_attr(s, "start")?.unwrap_or(0); + let stop: Option = get_attr(s, "stop")?; + let step = get_attr(s, "step")?.unwrap_or(1); + Ok(ndslice::Range( + start as usize, + stop.map(|s| s as usize), + step as usize, + )) + } + + if kwargs.is_none() || kwargs.unwrap().is_empty() { + return Err(PyException::new_err("selection cannot be empty")); + } + + let mut sliced = self.inner.clone(); + + for entry in kwargs.unwrap().items() { + let label = entry.get_item(0)?.str()?; + let label_str = label.to_str()?; + + let value = entry.get_item(1)?; + + let range = if let Ok(index) = value.extract::() { + convert_int(index)? + } else if let Ok(s) = value.downcast::() { + convert_py_slice(s)? 
+ } else { + return Err(PyException::new_err( + "selection only supports type int or slice", + )); + }; + sliced = sliced.select(label_str, range).map_err(|err| { + PyException::new_err(format!( + "failed to select label {}; error is: {}", + label_str, err + )) + })?; + } + + Ok(Self { inner: sliced }) + } + #[getter] fn shape(&self) -> PyShape { PyShape::from(self.inner.shape().clone()) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi index 18f3eb0c4..72faf2330 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi @@ -6,7 +6,6 @@ # pyre-strict -from collections.abc import Mapping from typing import AsyncIterator, final from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage @@ -18,6 +17,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import ( from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.selection import Selection from monarch._rust_bindings.monarch_hyperactor.shape import Shape +from typing_extensions import Self @final class PythonActorMeshRef: @@ -31,6 +31,12 @@ class PythonActorMeshRef: """Cast a message to the selected actors in the mesh.""" ... + def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self: + """ + See PythonActorMesh.slice for documentation. + """ + ... + @property def shape(self) -> Shape: """ @@ -53,6 +59,22 @@ class PythonActorMesh: """ Cast a message to the selected actors in the mesh. """ + ... + + def slice( + self, **kwargs: int | slice[int | None, int | None, int | None] + ) -> PythonActorMeshRef: + """ + Slice the mesh into a new mesh ref with the given selection.
The reason + it returns a mesh ref, rather than the mesh object itself, is because + sliced mesh is a view of the original mesh, and does not own the mesh's + resources. + + Arguments: + - `kwargs`: argument name is the label, and argument value is how to + slice the mesh along the dimension of that label. + """ + ... def get_supervision_event(self) -> ActorSupervisionEvent | None: """ diff --git a/python/tests/_monarch/test_actor_mesh.py b/python/tests/_monarch/test_actor_mesh.py index 874cf3c83..eeccc5c29 100644 --- a/python/tests/_monarch/test_actor_mesh.py +++ b/python/tests/_monarch/test_actor_mesh.py @@ -6,6 +6,7 @@ # pyre-unsafe +import asyncio import pickle from typing import Any, List @@ -34,7 +35,7 @@ async def allocate() -> ProcMesh: - spec = AllocSpec(AllocConstraints(), replica=2, gpus=3, hosts=8) + spec = AllocSpec(AllocConstraints(), replicas=3, hosts=8, gpus=8) allocator = monarch.LocalAllocator() alloc = await allocator.allocate(spec) proc_mesh = await ProcMesh.allocate_nonblocking(alloc) @@ -108,15 +109,17 @@ async def verify_cast( assert rank is not None rcv_ranks.append(rank) rcv_ranks.sort() - for i in cast_ranks: - assert rcv_ranks[i] == i + assert rcv_ranks == cast_ranks + # verify no more messages are received + with pytest.raises(asyncio.exceptions.TimeoutError): + await asyncio.wait_for(receiver.recv(), timeout=1) @pytest.mark.timeout(30) async def test_cast_handle() -> None: proc_mesh = await allocate() actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) - await verify_cast(actor_mesh, proc_mesh.client, list(range(2 * 3 * 8))) + await verify_cast(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8))) @pytest.mark.timeout(30) @@ -124,4 +127,76 @@ async def test_cast_ref() -> None: proc_mesh = await allocate() actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) actor_mesh_ref = actor_mesh.bind() - await verify_cast(actor_mesh_ref, proc_mesh.client, list(range(2 * 3 * 8))) + await verify_cast(actor_mesh_ref, 
proc_mesh.client, list(range(3 * 8 * 8))) + + +async def verify_slice( + actor_mesh: PythonActorMesh | PythonActorMeshRef, + mailbox: Mailbox, +) -> None: + sliced_mesh = actor_mesh.slice( + gpus=slice(2, 8, 2), + replicas=slice(None, 2), + hosts=slice(3, 7), + ) + sliced_shape = sliced_mesh.shape + # fmt: off + # turn off formatting to make the following list more readable + replica_0_ranks = [ + # gpus=2,4,6 + 24 + 2, 24 + 4, 24 + 6, # hosts=3 + 32 + 2, 32 + 4, 32 + 6, # hosts=4 + 40 + 2, 40 + 4, 40 + 6, # hosts=5 + 48 + 2, 48 + 4, 48 + 6, # hosts=6 + ] + # fmt: on + replica_1_ranks = [rank + 64 for rank in replica_0_ranks] + assert ( + sliced_shape.ranks() == replica_0_ranks + replica_1_ranks + ), f"left is {sliced_shape.ranks()}" + await verify_cast(sliced_mesh, mailbox, sliced_shape.ranks()) + + assert sliced_shape.labels == ["replicas", "hosts", "gpus"] + assert sliced_shape.ndslice.sizes == [2, 4, 3] + # When slicing a sliced mesh, the user treats this sliced mesh as a + # continuous mesh, and calculates the dimensions based on that assumption, + # without considering the original mesh. + # + # e.g, the following slicing operation selects index 0 and 2 of the hosts + # dimension on the sliced mesh. But corresponding index on the original + # mesh is 3 and 5. 
+ sliced_again = sliced_mesh.slice( + replicas=1, + hosts=slice(None, None, 2), + gpus=slice(1, 3), + ) + again_shape = sliced_again.shape + assert again_shape.labels == ["replicas", "hosts", "gpus"] + assert again_shape.ndslice.sizes == [1, 2, 2] + # fmt: off + # turn off formatting to make the following list more readable + selected_ranks = [ + rank + 64 for rank in + [ + # gpus=4,6 + 24 + 4, 24 + 6, # hosts=3 + 40 + 4, 40 + 6, # hosts=5 + ] + ] + # fmt: on + assert again_shape.ranks() == selected_ranks, f"left is {sliced_shape.ranks()}" + + +@pytest.mark.timeout(30) +async def test_slice_actor_mesh_handle() -> None: + proc_mesh = await allocate() + actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + await verify_slice(actor_mesh, proc_mesh.client) + + +@pytest.mark.timeout(30) +async def test_slice_actor_mesh_ref() -> None: + proc_mesh = await allocate() + actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + actor_mesh_ref = actor_mesh.bind() + await verify_slice(actor_mesh_ref, proc_mesh.client) From cca87c62f6616befa044338a03a4b0e1d2f1df7f Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Wed, 16 Jul 2025 10:39:30 -0700 Subject: [PATCH 2/2] Push MeshTrait implementation to _ActorMeshRefImpl Summary: `_ActorMeshRefImpl` is the stand-in class for `PythonActorMesh` and `PythonActorMeshRef`. This diff stack is working on replacing `_ActorMeshRefImpl` with `PythonActorMesh` and `PythonActorMeshRef`. Compared to `PythonActorMesh` and `PythonActorMeshRef`, the one method `_ActorMeshRefImpl` is missing is the `def slice` method. The lack of this method blocks a drop-in replacement. This diff pushes the `MeshTrait` implementation to `_ActorMeshRefImpl`. In this way, `_ActorMeshRefImpl` will have the `slice` method from `MeshTrait`.
Differential Revision: D78300586 --- python/monarch/_src/actor/actor_mesh.py | 37 ++++++++++++++++--------- python/monarch/rdma.py | 2 ++ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index e9d4d69eb..0d2236ebd 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -147,7 +147,7 @@ def set(debug_context: "DebugContext") -> None: # standin class for whatever is the serializable python object we use # to name an actor mesh. Hacked up today because ActorMesh # isn't plumbed to non-clients -class _ActorMeshRefImpl: +class _ActorMeshRefImpl(MeshTrait): def __init__( self, mailbox: Mailbox, @@ -181,12 +181,17 @@ def from_hyperactor_mesh( def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl": return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id]) - @staticmethod - def from_actor_ref_with_shape( - ref: "_ActorMeshRefImpl", shape: Shape - ) -> "_ActorMeshRefImpl": + @property + def _ndslice(self) -> NDSlice: + return self._shape.ndslice + + @property + def _labels(self) -> Iterable[str]: + return self._shape.labels + + def _new_with_shape(self, shape: Shape) -> "_ActorMeshRefImpl": return _ActorMeshRefImpl( - ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids + self._mailbox, None, None, shape, self._please_replace_me_actor_ids ) def __getstate__( @@ -877,7 +882,7 @@ def _labels(self) -> Tuple[str, ...]: "actor implementations are not meshes, but we can't convince the typechecker of it..." ) - def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": + def _new_with_shape(self, shape: Shape) -> "Actor": raise NotImplementedError( "actor implementations are not meshes, but we can't convince the typechecker of it..." 
) @@ -958,19 +963,25 @@ def __reduce_ex__( @property def _ndslice(self) -> NDSlice: - return self._actor_mesh_ref._shape.ndslice + raise NotImplementedError( + "should not be called because def slice is overridden" + ) @property def _labels(self) -> Iterable[str]: - return self._actor_mesh_ref._shape.labels + raise NotImplementedError( + "should not be called because def slice is overridden" + ) def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": - return ActorMeshRef( - self._class, - _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape), - self._mailbox, + raise NotImplementedError( + "should not be called because def slice is overridden" ) + def slice(self, **kwargs) -> "ActorMeshRef": + sliced = self._actor_mesh_ref.slice(**kwargs) + return ActorMeshRef(self._class, sliced, self._mailbox) + def __repr__(self) -> str: return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})" diff --git a/python/monarch/rdma.py b/python/monarch/rdma.py index b9cc771a0..0ba3fcd90 100644 --- a/python/monarch/rdma.py +++ b/python/monarch/rdma.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import ctypes from dataclasses import dataclass