From 11a217ed92e8ee2c4cb7c9b97879f46c0cf9e445 Mon Sep 17 00:00:00 2001 From: James Sun Date: Mon, 18 Aug 2025 23:43:29 -0700 Subject: [PATCH 1/2] sync flush of proc mesh Summary: Provide sync flush so it is guaranteed all the flushed logs on the remote procs will be streamed back and flushed on client's stdout/stderr. Differential Revision: D80051803 --- hyperactor_mesh/src/logging.rs | 246 +++++++++++++++--- monarch_extension/src/logging.rs | 54 ++++ .../monarch_extension/logging.pyi | 1 + python/tests/python_actor_test_binary.py | 7 +- python/tests/test_python_actors.py | 51 +++- 5 files changed, 314 insertions(+), 45 deletions(-) diff --git a/hyperactor_mesh/src/logging.rs b/hyperactor_mesh/src/logging.rs index 496f58144..28b7275f1 100644 --- a/hyperactor_mesh/src/logging.rs +++ b/hyperactor_mesh/src/logging.rs @@ -22,12 +22,15 @@ use chrono::DateTime; use chrono::Local; use hyperactor::Actor; use hyperactor::ActorRef; +use hyperactor::Bind; use hyperactor::Context; use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; +use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::Unbind; use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelRx; @@ -39,9 +42,6 @@ use hyperactor::channel::TxStatus; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::data::Serialized; -use hyperactor::message::Bind; -use hyperactor::message::Bindings; -use hyperactor::message::Unbind; use hyperactor_telemetry::env; use hyperactor_telemetry::log_file_path; use serde::Deserialize; @@ -235,6 +235,24 @@ impl fmt::Display for Aggregator { } } +/// Messages that can be sent to the LogClientActor remotely. +#[derive( + Debug, + Clone, + Serialize, + Deserialize, + Named, + Handler, + HandleClient, + RefClient, + Bind, + Unbind +)] +pub enum LogFlushMessage { + /// Flush the log + ForceSyncFlush { version: u64 }, +} + /// Messages that can be sent to the LogClientActor remotely. #[derive( Debug, @@ -260,7 +278,11 @@ pub enum LogMessage { }, /// Flush the log - Flush {}, + Flush { + /// Indicate if the current flush is synced or non-synced. + /// If synced, a version number is available. Otherwise, none. + sync_version: Option, + }, } /// Messages that can be sent to the LogClient locally. @@ -279,6 +301,16 @@ pub enum LogClientMessage { /// The time window in seconds to aggregate logs. If None, aggregation is disabled. aggregate_window_sec: Option, }, + + /// Synchronously flush all the logs from all the procs. This is for client to call. + StartSyncFlush { + /// Expect these many procs to ack the flush message. + expected_procs: usize, + /// Return once we have received the acks from all the procs + reply: OncePortRef<()>, + /// Return to the caller the current flush version + version: OncePortRef, + }, } /// Trait for sending logs @@ -352,7 +384,7 @@ impl LogSender for LocalLogSender { // send will make sure message is delivered if TxStatus::Active == *self.status.borrow() { // Do not use tx.send, it will block the allocator as the child process state is unknown. - self.tx.post(LogMessage::Flush {}); + self.tx.post(LogMessage::Flush { sync_version: None }); } else { tracing::debug!( "log sender {} is not active, skip sending flush message", @@ -558,7 +590,9 @@ impl Named, Handler, HandleClient, - RefClient + RefClient, + Bind, + Unbind )] pub enum LogForwardMessage { /// Receive the log from the parent process and forward ti to the client. @@ -568,18 +602,6 @@ pub enum LogForwardMessage { SetMode { stream_to_client: bool }, } -impl Bind for LogForwardMessage { - fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> { - Ok(()) - } -} - -impl Unbind for LogForwardMessage { - fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> { - Ok(()) - } -} - /// A log forwarder that receives the log from its parent process and forward it back to the client #[derive(Debug)] #[hyperactor::export( @@ -647,17 +669,32 @@ impl Actor for LogForwardActor { #[hyperactor::forward(LogForwardMessage)] impl LogForwardMessageHandler for LogForwardActor { async fn forward(&mut self, ctx: &Context) -> Result<(), anyhow::Error> { - if let Ok(LogMessage::Log { - hostname, - pid, - output_target, - payload, - }) = self.rx.recv().await - { - if self.stream_to_client { - self.logging_client_ref - .log(ctx, hostname, pid, output_target, payload) - .await?; + match self.rx.recv().await { + Ok(LogMessage::Flush { sync_version }) => { + match sync_version { + None => { + // no need to do anything. The previous messages have already been sent to the client. + // Client will flush based on its own frequency. + } + version => { + self.logging_client_ref.flush(ctx, version).await?; + } + } + } + Ok(LogMessage::Log { + hostname, + pid, + output_target, + payload, + }) => { + if self.stream_to_client { + self.logging_client_ref + .log(ctx, hostname, pid, output_target, payload) + .await?; + } + } + Err(e) => { + return Err(e.into()); } } @@ -696,6 +733,60 @@ fn deserialize_message_lines( anyhow::bail!("Failed to deserialize message as either String or Vec") } +/// An actor that send flush message to the log forwarder actor. +/// The reason we need an extra actor instead of reusing the log forwarder actor +/// is because the log forwarder can be blocked on the rx.recv() that listens on the new log lines. +/// Thus, we need to create anew channel as a tx to send the flush message to the log forwarder +/// So we do not get into a deadlock. +#[derive(Debug)] +#[hyperactor::export( + spawn = true, + handlers = [LogFlushMessage {cast = true}], +)] +pub struct LogFlushActor { + tx: ChannelTx, +} + +#[async_trait] +impl Actor for LogFlushActor { + type Params = (); + + async fn new(_: ()) -> Result { + let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) { + Ok(channel) => channel.parse()?, + Err(err) => { + tracing::debug!( + "log forwarder actor failed to read env var {}: {}", + BOOTSTRAP_LOG_CHANNEL, + err + ); + // TODO: this should error out; it can only happen with local proc; we need to fix it. + ChannelAddr::any(ChannelTransport::Unix) + } + }; + let tx = channel::dial::(log_channel)?; + + Ok(Self { tx }) + } +} + +#[async_trait] +#[hyperactor::forward(LogFlushMessage)] +impl LogFlushMessageHandler for LogFlushActor { + async fn force_sync_flush( + &mut self, + _cx: &Context, + version: u64, + ) -> Result<(), anyhow::Error> { + self.tx + .send(LogMessage::Flush { + sync_version: Some(version), + }) + .await + .map_err(anyhow::Error::from) + } +} + /// A client to receive logs from remote processes #[derive(Debug)] #[hyperactor::export( @@ -707,6 +798,11 @@ pub struct LogClientActor { aggregators: HashMap, last_flush_time: SystemTime, next_flush_deadline: Option, + + // For flush sync barrier + current_flush_version: u64, + current_flush_port: Option>, + current_unflushed_procs: usize, } impl LogClientActor { @@ -736,6 +832,12 @@ impl LogClientActor { OutputTarget::Stderr => eprintln!("{}", message), } } + + fn flush_internal(&mut self) { + self.print_aggregators(); + self.last_flush_time = RealClock.system_time_now(); + self.next_flush_deadline = None; + } } #[async_trait] @@ -754,6 +856,9 @@ impl Actor for LogClientActor { aggregators, last_flush_time: RealClock.system_time_now(), next_flush_deadline: None, + current_flush_version: 0, + current_flush_port: None, + current_unflushed_procs: 0, }) } } @@ -805,20 +910,26 @@ impl LogMessageHandler for LogClientActor { let new_deadline = self.last_flush_time + Duration::from_secs(window); let now = RealClock.system_time_now(); if new_deadline <= now { - self.flush(cx).await?; + self.flush_internal(); } else { let delay = new_deadline.duration_since(now)?; match self.next_flush_deadline { None => { self.next_flush_deadline = Some(new_deadline); - cx.self_message_with_delay(LogMessage::Flush {}, delay)?; + cx.self_message_with_delay( + LogMessage::Flush { sync_version: None }, + delay, + )?; } Some(deadline) => { // Some early log lines have alrady triggered the flush. if new_deadline < deadline { // This can happen if the user has adjusted the aggregation window. self.next_flush_deadline = Some(new_deadline); - cx.self_message_with_delay(LogMessage::Flush {}, delay)?; + cx.self_message_with_delay( + LogMessage::Flush { sync_version: None }, + delay, + )?; } } } @@ -829,10 +940,45 @@ impl LogMessageHandler for LogClientActor { Ok(()) } - async fn flush(&mut self, _cx: &Context) -> Result<(), anyhow::Error> { - self.print_aggregators(); - self.last_flush_time = RealClock.system_time_now(); - self.next_flush_deadline = None; + async fn flush( + &mut self, + cx: &Context, + sync_version: Option, + ) -> Result<(), anyhow::Error> { + match sync_version { + None => { + self.flush_internal(); + } + Some(version) => { + if version != self.current_flush_version { + tracing::error!( + "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully", + version, + self.current_flush_version + ); + return Ok(()); + } + + if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() { + // This is a serious issue; it's better to error out. + anyhow::bail!("found no ongoing flush request"); + } + self.current_unflushed_procs -= 1; + + tracing::debug!( + "ack sync flush: version {}; remaining procs: {}", + self.current_flush_version, + self.current_unflushed_procs + ); + + if self.current_unflushed_procs == 0 { + self.flush_internal(); + let reply = self.current_flush_port.take().unwrap(); + self.current_flush_port = None; + reply.send(cx, ()).map_err(anyhow::Error::from)?; + } + } + } Ok(()) } @@ -853,6 +999,34 @@ impl LogClientMessageHandler for LogClientActor { self.aggregate_window_sec = aggregate_window_sec; Ok(()) } + + async fn start_sync_flush( + &mut self, + cx: &Context, + expected_procs_flushed: usize, + reply: OncePortRef<()>, + version: OncePortRef, + ) -> Result<(), anyhow::Error> { + if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() { + tracing::warn!( + "found unfinished ongoing flush: version {}; {} unflushed procs", + self.current_flush_version, + self.current_unflushed_procs, + ); + } + + self.current_flush_version += 1; + tracing::debug!( + "start sync flush with version {}", + self.current_flush_version + ); + self.current_flush_port = Some(reply.clone()); + self.current_unflushed_procs = expected_procs_flushed; + version + .send(cx, self.current_flush_version) + .map_err(anyhow::Error::from)?; + Ok(()) + } } #[cfg(test)] diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 9ff8b208b..a155471e0 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -13,6 +13,8 @@ use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::logging::LogClientActor; use hyperactor_mesh::logging::LogClientMessage; +use hyperactor_mesh::logging::LogFlushActor; +use hyperactor_mesh::logging::LogFlushMessage; use hyperactor_mesh::logging::LogForwardActor; use hyperactor_mesh::logging::LogForwardMessage; use hyperactor_mesh::selection::Selection; @@ -33,11 +35,49 @@ use pyo3::types::PyModule; pub struct LoggingMeshClient { // handles remote process log forwarding; no python runtime forwarder_mesh: SharedCell>, + // because forwarder mesh keeps listening to the new coming logs, + // the flush mesh is a way to unblock it from busy waiting a log and do sync flush. + flush_mesh: SharedCell>, // handles python logger; has python runtime logger_mesh: SharedCell>, client_actor: ActorHandle, } +impl LoggingMeshClient { + async fn flush_internal( + client_actor: ActorHandle, + flush_mesh: SharedCell>, + ) -> Result<(), anyhow::Error> { + let flush_inner_mesh = flush_mesh.borrow().map_err(anyhow::Error::msg)?; + let (reply_tx, reply_rx) = flush_inner_mesh.proc_mesh().client().open_once_port::<()>(); + let (version_tx, version_rx) = flush_inner_mesh + .proc_mesh() + .client() + .open_once_port::(); + + // First initialize a sync flush. + client_actor.send(LogClientMessage::StartSyncFlush { + expected_procs: flush_inner_mesh.proc_mesh().shape().slice().len(), + reply: reply_tx.bind(), + version: version_tx.bind(), + })?; + + let version = version_rx.recv().await?; + + // Then ask all the flushers to ask the log forwarders to sync flush + flush_inner_mesh.cast( + flush_inner_mesh.proc_mesh().client(), + Selection::True, + LogFlushMessage::ForceSyncFlush { version }, + )?; + + // Finally the forwarder will send sync point back to the client, flush, and return. + reply_rx.recv().await?; + + Ok(()) + } +} + #[pymethods] impl LoggingMeshClient { #[staticmethod] @@ -47,9 +87,11 @@ impl LoggingMeshClient { let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?; let client_actor_ref = client_actor.bind(); let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?; + let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?; let logger_mesh = proc_mesh.spawn("logger", &()).await?; Ok(Self { forwarder_mesh, + flush_mesh, logger_mesh, client_actor, }) @@ -97,6 +139,18 @@ impl LoggingMeshClient { Ok(()) } + + // A sync flush mechanism for the client make sure all the stdout/stderr are streamed back and flushed. + fn flush(&self) -> PyResult { + let flush_mesh = self.flush_mesh.clone(); + let client_actor = self.client_actor.clone(); + + PyPythonTask::new(async move { + Self::flush_internal(client_actor, flush_mesh) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + } } impl Drop for LoggingMeshClient { diff --git a/python/monarch/_rust_bindings/monarch_extension/logging.pyi b/python/monarch/_rust_bindings/monarch_extension/logging.pyi index 5d6f11960..fa3d732af 100644 --- a/python/monarch/_rust_bindings/monarch_extension/logging.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/logging.pyi @@ -21,3 +21,4 @@ class LoggingMeshClient: def set_mode( self, stream_to_client: bool, aggregate_window_sec: int | None, level: int ) -> None: ... + def flush(self) -> PythonTask[None]: ... diff --git a/python/tests/python_actor_test_binary.py b/python/tests/python_actor_test_binary.py index 12a10b0f5..9cff72087 100644 --- a/python/tests/python_actor_test_binary.py +++ b/python/tests/python_actor_test_binary.py @@ -10,6 +10,7 @@ import logging import click +from monarch._src.actor.future import Future from monarch.actor import Actor, endpoint, proc_mesh @@ -40,8 +41,10 @@ async def _flush_logs() -> None: for _ in range(5): await am.print.call("has print streaming") - # TODO: will soon be removed by D80051803 - await asyncio.sleep(2) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() @main.command("flush-logs") diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 269409a99..71835058e 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -27,6 +27,7 @@ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port +from monarch._src.actor.future import Future from monarch.actor import ( Accumulator, @@ -548,8 +549,10 @@ async def test_actor_log_streaming() -> None: await am.print.call("has print streaming too") await am.log.call("has log streaming as level matched") - # Give it some time to reflect and aggregate - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -664,7 +667,11 @@ async def test_logging_option_defaults() -> None: for _ in range(5): await am.print.call("print streaming") await am.log.call("log streaming") - await asyncio.sleep(4) + + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -790,8 +797,10 @@ async def test_flush_on_disable_aggregation() -> None: for _ in range(5): await am.print.call("single log line") - # Wait a bit to ensure flush completes - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -835,6 +844,32 @@ async def test_flush_on_disable_aggregation() -> None: pass +@pytest.mark.timeout(120) +async def test_multiple_ongoing_flushes_no_deadlock() -> None: + """ + The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked. + Because now a flush call is purely sync, it is very easy to get into a deadlock. + So we assert the last flush call will not get into such a state. + """ + pm = await proc_mesh(gpus=4) + am = await pm.spawn("printer", Printer) + + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(10): + await am.print.call("aggregated log line") + + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + futures = [] + for _ in range(5): + # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature. + await asyncio.sleep(0.1) + futures.append(Future(coro=log_mesh.flush().spawn().task())) + + # The last flush should not block + futures[-1].get() + + @pytest.mark.timeout(60) async def test_adjust_aggregation_window() -> None: """Test that the flush deadline is updated when the aggregation window is adjusted. @@ -875,8 +910,10 @@ async def test_adjust_aggregation_window() -> None: for _ in range(3): await am.print.call("second batch of logs") - # Wait just enough time for the shorter window to trigger a flush - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() From 177389215f3e88110087588552944b080dbd36a5 Mon Sep 17 00:00:00 2001 From: James Sun Date: Tue, 19 Aug 2025 00:11:03 -0700 Subject: [PATCH 2/2] disallow logging option for local procs (#825) Summary: Pull Request resolved: https://github.com/meta-pytorch/monarch/pull/825 there is no way to tail local procs; simply disallow it; I don't quite like the new interface introduced. RFC is needed. Differential Revision: D80063700 --- hyperactor_mesh/src/logging.rs | 42 +++++--------------------- python/monarch/_src/actor/allocator.py | 26 +++++++++++++++- python/monarch/_src/actor/proc_mesh.py | 38 +++++++++++++++++------ python/tests/test_python_actors.py | 9 ++++++ 4 files changed, 70 insertions(+), 45 deletions(-) diff --git a/hyperactor_mesh/src/logging.rs b/hyperactor_mesh/src/logging.rs index 28b7275f1..91535ccac 100644 --- a/hyperactor_mesh/src/logging.rs +++ b/hyperactor_mesh/src/logging.rs @@ -34,7 +34,6 @@ use hyperactor::Unbind; use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelRx; -use hyperactor::channel::ChannelTransport; use hyperactor::channel::ChannelTx; use hyperactor::channel::Rx; use hyperactor::channel::Tx; @@ -550,6 +549,10 @@ impl // Since LogSender::send takes &self, we don't need to clone it if let Err(e) = this.log_sender.send(output_target, data_to_send) { tracing::error!("error sending log: {}", e); + return Poll::Ready(Err(io::Error::other(format!( + "error sending write message: {}", + e + )))); } // Return success with the full buffer size Poll::Ready(Ok(buf.len())) @@ -621,15 +624,7 @@ impl Actor for LogForwardActor { async fn new(logging_client_ref: Self::Params) -> Result { let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) { Ok(channel) => channel.parse()?, - Err(err) => { - tracing::debug!( - "log forwarder actor failed to read env var {}: {}", - BOOTSTRAP_LOG_CHANNEL, - err - ); - // TODO: an empty channel to serve - ChannelAddr::any(ChannelTransport::Unix) - } + Err(err) => return Err(err.into()), }; tracing::info!( "log forwarder {} serve at {}", @@ -637,21 +632,7 @@ impl Actor for LogForwardActor { log_channel ); - let rx = match channel::serve(log_channel.clone()).await { - Ok((_, rx)) => rx, - Err(err) => { - // This can happen if we are not spanwed on a separate process like local. - // For local mesh, log streaming anyway is not needed. - tracing::error!( - "log forwarder actor failed to bootstrap on given channel {}: {}", - log_channel, - err - ); - channel::serve(ChannelAddr::any(ChannelTransport::Unix)) - .await? - .1 - } - }; + let (_, rx) = channel::serve(log_channel.clone()).await?; Ok(Self { rx, logging_client_ref, @@ -754,15 +735,7 @@ impl Actor for LogFlushActor { async fn new(_: ()) -> Result { let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) { Ok(channel) => channel.parse()?, - Err(err) => { - tracing::debug!( - "log forwarder actor failed to read env var {}: {}", - BOOTSTRAP_LOG_CHANNEL, - err - ); - // TODO: this should error out; it can only happen with local proc; we need to fix it. - ChannelAddr::any(ChannelTransport::Unix) - } + Err(err) => return Err(err.into()), }; let tx = channel::dial::(log_channel)?; @@ -1036,6 +1009,7 @@ mod tests { use hyperactor::channel; use hyperactor::channel::ChannelAddr; + use hyperactor::channel::ChannelTransport; use hyperactor::channel::ChannelTx; use hyperactor::channel::Tx; use hyperactor::id; diff --git a/python/monarch/_src/actor/allocator.py b/python/monarch/_src/actor/allocator.py index 43415990b..bae5f3a00 100644 --- a/python/monarch/_src/actor/allocator.py +++ b/python/monarch/_src/actor/allocator.py @@ -33,6 +33,7 @@ class AllocHandle(DeprecatedNotAFuture): _hy_alloc: "Shared[Alloc]" _extent: Dict[str, int] + _fork_processes: bool @property def initialized(self) -> Future[Literal[True]]: @@ -48,6 +49,10 @@ async def task() -> Literal[True]: return Future(coro=task()) + @property + def fork_processes(self) -> bool: + return self._fork_processes + class AllocateMixin(abc.ABC): @abc.abstractmethod @@ -63,7 +68,14 @@ def allocate(self, spec: AllocSpec) -> "AllocHandle": Returns: - A future that will be fulfilled when the requested allocation is fulfilled. """ - return AllocHandle(self.allocate_nonblocking(spec).spawn(), spec.extent) + return AllocHandle( + self.allocate_nonblocking(spec).spawn(), + spec.extent, + self._fork_processes(), + ) + + @abc.abstractmethod + def _fork_processes(self) -> bool: ... @final @@ -72,6 +84,9 @@ class ProcessAllocator(ProcessAllocatorBase, AllocateMixin): An allocator that allocates by spawning local processes. """ + def _fork_processes(self) -> bool: + return True + @final class LocalAllocator(LocalAllocatorBase, AllocateMixin): @@ -79,6 +94,9 @@ class LocalAllocator(LocalAllocatorBase, AllocateMixin): An allocator that allocates by spawning actors into the current process. """ + def _fork_processes(self) -> bool: + return False + @final class SimAllocator(SimAllocatorBase, AllocateMixin): @@ -86,6 +104,9 @@ class SimAllocator(SimAllocatorBase, AllocateMixin): An allocator that allocates by spawning actors into the current process using simulated channels for transport """ + def _fork_processes(self) -> bool: + return False + class RemoteAllocInitializer(abc.ABC): """Subclass-able Python interface for `hyperactor_mesh::alloc::remoteprocess:RemoteProcessAllocInitializer`. @@ -219,3 +240,6 @@ class RemoteAllocator(RemoteAllocatorBase, AllocateMixin): An allocator that allocates by spawning actors on a remote host. The remote host must be running hyperactor's remote-process-allocator. """ + + def _fork_processes(self) -> bool: + return True diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 6aee0a973..9ae2316af 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -179,6 +179,7 @@ def __init__( self, hy_proc_mesh: "Shared[HyProcMesh]", shape: Shape, + _fork_processes: bool, _device_mesh: Optional["DeviceMesh"] = None, ) -> None: self._proc_mesh = hy_proc_mesh @@ -193,6 +194,7 @@ def __init__( self._code_sync_client: Optional[CodeSyncMeshClient] = None self._logging_mesh_client: Optional[LoggingMeshClient] = None self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh + self._fork_processes = _fork_processes self._stopped = False self._controller_controller: Optional["_ControllerController"] = None @@ -225,7 +227,12 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh": if self._maybe_device_mesh is None else self._device_mesh._new_with_shape(shape) ) - pm = ProcMesh(self._proc_mesh, shape, _device_mesh=device_mesh) + pm = ProcMesh( + self._proc_mesh, + shape, + _fork_processes=self._fork_processes, + _device_mesh=device_mesh, + ) pm._slice = True return pm @@ -300,7 +307,12 @@ async def task() -> HyProcMesh: hy_proc_mesh = PythonTask.from_coroutine(task()).spawn() - pm = ProcMesh(hy_proc_mesh, shape) + fork_processes: bool = alloc.fork_processes + pm = ProcMesh( + hy_proc_mesh, + shape, + _fork_processes=fork_processes, + ) async def task( pm: "ProcMesh", @@ -309,14 +321,15 @@ async def task( ) -> HyProcMesh: hy_proc_mesh = await hy_proc_mesh_task - pm._logging_mesh_client = await LoggingMeshClient.spawn( - proc_mesh=hy_proc_mesh - ) - pm._logging_mesh_client.set_mode( - stream_to_client=True, - aggregate_window_sec=3, - level=logging.INFO, - ) + if fork_processes: + pm._logging_mesh_client = await LoggingMeshClient.spawn( + proc_mesh=hy_proc_mesh + ) + pm._logging_mesh_client.set_mode( + stream_to_client=True, + aggregate_window_sec=3, + level=logging.INFO, + ) if setup_actor is not None: await setup_actor.setup.call() @@ -483,6 +496,11 @@ async def logging_option( Returns: None """ + if not self._fork_processes: + raise RuntimeError( + "Logging option is only available for allocators that fork processes. Allocators like LocalAllocator are not supported." + ) + if level < 0 or level > 255: raise ValueError("Invalid logging level: {}".format(level)) await self.initialized diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 71835058e..91d096827 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -1101,3 +1101,12 @@ def test_mesh_len(): proc_mesh = local_proc_mesh(gpus=12).get() s = proc_mesh.spawn("sync_actor", SyncActor).get() assert 12 == len(s) + + +async def test_logging_option_on_local_procs() -> None: + proc_mesh = local_proc_mesh(gpus=1) + with pytest.raises( + RuntimeError, + match="Logging option is only available for allocators that fork processes", + ): + await proc_mesh.logging_option(stream_to_client=True)