diff --git a/monarch_hyperactor/Cargo.toml b/monarch_hyperactor/Cargo.toml index 820cde23..76bd92eb 100644 --- a/monarch_hyperactor/Cargo.toml +++ b/monarch_hyperactor/Cargo.toml @@ -1,4 +1,4 @@ -# @generated by autocargo from //monarch/monarch_hyperactor:[monarch_hyperactor,process_allocator-oss] +# @generated by autocargo from //monarch/monarch_hyperactor:[monarch_hyperactor,process_allocator-oss,test_monarch_hyperactor] [package] name = "monarch_hyperactor" @@ -7,6 +7,10 @@ authors = ["Meta"] edition = "2021" license = "BSD-3-Clause" +[[test]] +name = "test_monarch_hyperactor" +path = "tests/lib.rs" + [dependencies] anyhow = "1.0.98" async-once-cell = "0.4.2" @@ -16,7 +20,6 @@ clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", " erased-serde = "0.3.27" fbinit = { version = "0.2.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main" } futures = { version = "0.3.30", features = ["async-await", "compat"] } -futures-util = { version = "0.3.30", features = ["compat"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" } hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" } diff --git a/monarch_hyperactor/src/code_sync.rs b/monarch_hyperactor/src/code_sync.rs index f57ee3ee..7e318b67 100644 --- a/monarch_hyperactor/src/code_sync.rs +++ b/monarch_hyperactor/src/code_sync.rs @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +pub mod auto_reload; pub mod manager; pub mod rsync; mod workspace; diff --git a/monarch_hyperactor/src/code_sync/auto_reload.rs b/monarch_hyperactor/src/code_sync/auto_reload.rs new file mode 100644 index 00000000..bc8e3453 --- /dev/null +++ b/monarch_hyperactor/src/code_sync/auto_reload.rs @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; +use hyperactor::Actor; +use hyperactor::Context; +use hyperactor::Handler; +use hyperactor::Named; +use hyperactor::PortRef; +use monarch_types::SerializablePyErr; +use pyo3::prelude::*; +use serde::Deserialize; +use serde::Serialize; + +/// Message to trigger module reloading +#[derive(Debug, Clone, Named, Serialize, Deserialize)] +pub struct AutoReloadMessage { + pub result: PortRef>, +} + +/// Parameters for creating an AutoReloadActor +#[derive(Debug, Clone, Named, Serialize, Deserialize)] +pub struct AutoReloadParams {} + +/// Simple Rust Actor that wraps the Python AutoReloader class via pyo3 +#[derive(Debug)] +#[hyperactor::export(spawn = true, handlers = [AutoReloadMessage])] +pub struct AutoReloadActor { + state: Result<(Arc, PyObject), SerializablePyErr>, +} + +#[async_trait] +impl Actor for AutoReloadActor { + type Params = AutoReloadParams; + + async fn new(Self::Params {}: Self::Params) -> Result { + Ok(Self { + state: tokio::task::spawn_blocking(move || { + Python::with_gil(|py| { + Self::create_state(py).map_err(SerializablePyErr::from_fn(py)) + }) + }) + .await?, + }) + } +} + +impl AutoReloadActor { + fn create_state(py: Python) -> PyResult<(Arc, PyObject)> { + // Import the Python AutoReloader class + let auto_reload_module = py.import("monarch._src.actor.code_sync.auto_reload")?; + let auto_reloader_class = auto_reload_module.getattr("AutoReloader")?; + + let reloader = auto_reloader_class.call0()?; + + // Install the audit import hook: SysAuditImportHook.install(reloader.import_callback) + let sys_audit_import_hook_class = auto_reload_module.getattr("SysAuditImportHook")?; + let import_callback = reloader.getattr("import_callback")?; + let hook_guard = sys_audit_import_hook_class.call_method1("install", (import_callback,))?; + + Ok((Arc::new(reloader.into()), hook_guard.into())) + } + + fn reload(py: Python, py_reloader: &PyObject) -> PyResult<()> { + let reloader = py_reloader.bind(py); + let changed_modules: Vec = reloader.call_method0("reload_changes")?.extract()?; + if !changed_modules.is_empty() { + eprintln!("reloaded modules: {:?}", changed_modules); + } + Ok(()) + } +} + +#[async_trait] +impl Handler for AutoReloadActor { + async fn handle( + &mut self, + cx: &Context, + AutoReloadMessage { result }: AutoReloadMessage, + ) -> Result<()> { + // Call the Python reloader's reload_changes method + let res = async { + let py_reloader: Arc<_> = self.state.as_ref().map_err(Clone::clone)?.0.clone(); + tokio::task::spawn_blocking(move || { + Python::with_gil(|py| { + Self::reload(py, py_reloader.as_ref()).map_err(SerializablePyErr::from_fn(py)) + }) + }) + .await??; + anyhow::Ok(()) + } + .await; + result.send(cx, res.map_err(|e| format!("{:#?}", e)))?; + Ok(()) + } +} diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs index d3ea97dc..b137fe11 100644 --- a/monarch_hyperactor/src/code_sync/manager.rs +++ b/monarch_hyperactor/src/code_sync/manager.rs @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +use std::collections::HashSet; use std::path::PathBuf; use anyhow::Result; @@ -46,6 +47,9 @@ use tokio::net::TcpListener; use tokio::net::TcpStream; use crate::code_sync::WorkspaceLocation; +use crate::code_sync::auto_reload::AutoReloadActor; +use crate::code_sync::auto_reload::AutoReloadMessage; +use crate::code_sync::auto_reload::AutoReloadParams; use crate::code_sync::rsync::RsyncActor; use crate::code_sync::rsync::RsyncDaemon; use crate::code_sync::rsync::RsyncMessage; @@ -161,6 +165,7 @@ pub struct CodeSyncManagerParams {} )] pub struct CodeSyncManager { rsync: OnceCell>, + auto_reload: OnceCell>, } #[async_trait] @@ -170,6 +175,7 @@ impl Actor for CodeSyncManager { async fn new(CodeSyncManagerParams {}: Self::Params) -> Result { Ok(Self { rsync: OnceCell::new(), + auto_reload: OnceCell::new(), }) } } @@ -183,6 +189,15 @@ impl CodeSyncManager { .get_or_try_init(RsyncActor::spawn(cx, RsyncParams {})) .await } + + async fn get_auto_reload_actor<'a>( + &'a mut self, + cx: &Context<'a, Self>, + ) -> Result<&'a ActorHandle> { + self.auto_reload + .get_or_try_init(AutoReloadActor::spawn(cx, AutoReloadParams {})) + .await + } } #[async_trait] @@ -216,12 +231,22 @@ impl CodeSyncMessageHandler for CodeSyncManager { if let Some(workspace_shape) = reload { let mesh = workspace_shape.downstream_mesh(cx.self_id())?; let (tx, rx) = cx.open_port::>(); - mesh.cast(cx, sel!(*), CodeSyncMessage::Reload { result: tx.bind() })?; - let _: Vec<()> = rx - .take(mesh.shape().slice().len()) - .map(|res| res?.map_err(anyhow::Error::msg)) - .try_collect() - .await?; + let tx = tx.bind(); + mesh.cast( + cx, + // We make sure to exclude the current rank from the sync, as this actor will + // be blocked here waiting for results. We just manually call `reload` to run + // concurrently below. + sel!(*).without(mesh.shape().slice(), &HashSet::from([cx.self_id().rank()]))?, + CodeSyncMessage::Reload { result: tx.clone() }, + )?; + let _: ((), Vec<()>) = try_join!( + // Run reload for this rank. + self.reload(cx, tx), + rx.take(mesh.shape().slice().len()) + .map(|res| res?.map_err(anyhow::Error::msg)) + .try_collect(), + )?; } anyhow::Ok(()) @@ -236,8 +261,15 @@ impl CodeSyncMessageHandler for CodeSyncManager { cx: &Context, result: PortRef>, ) -> Result<()> { - // TODO(agallagher): Add reload. - let res = async move { anyhow::Ok(()) }.await; + let res = async move { + let (tx, mut rx) = cx.open_port(); + self.get_auto_reload_actor(cx) + .await? + .send(AutoReloadMessage { result: tx.bind() })?; + rx.recv().await?.map_err(anyhow::Error::msg)?; + anyhow::Ok(()) + } + .await; result.send(cx, res.map_err(|e| format!("{:#?}", e)))?; Ok(()) } diff --git a/monarch_hyperactor/tests/code_sync/auto_reload.rs b/monarch_hyperactor/tests/code_sync/auto_reload.rs new file mode 100644 index 00000000..2851e045 --- /dev/null +++ b/monarch_hyperactor/tests/code_sync/auto_reload.rs @@ -0,0 +1,141 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use anyhow::Result; +use anyhow::anyhow; +use hyperactor_mesh::actor_mesh::ActorMesh; +use hyperactor_mesh::alloc::AllocSpec; +use hyperactor_mesh::alloc::Allocator; +use hyperactor_mesh::alloc::local::LocalAllocator; +use hyperactor_mesh::mesh::Mesh; +use hyperactor_mesh::proc_mesh::ProcMesh; +use monarch_hyperactor::code_sync::auto_reload::AutoReloadActor; +use monarch_hyperactor::code_sync::auto_reload::AutoReloadMessage; +use monarch_hyperactor::code_sync::auto_reload::AutoReloadParams; +use ndslice::shape; +use pyo3::ffi::c_str; +use pyo3::prelude::*; +use tempfile::TempDir; +use tokio::fs; + +#[tokio::test] +async fn test_auto_reload_actor() -> Result<()> { + pyo3::prepare_freethreaded_python(); + Python::with_gil(|py| py.run(c_str!("import monarch._rust_bindings"), None, None))?; + + // Create a temporary directory for Python files + let temp_dir = TempDir::new()?; + let py_file_path = temp_dir.path().join("test_module.py"); + + // Create initial Python file content + let initial_content = r#" +# Test module for auto-reload +def get_value(): + return "initial_value" + +CONSTANT = "initial_constant" +"#; + fs::write(&py_file_path, initial_content).await?; + + // Set up a single AutoReloadActor + let alloc = LocalAllocator + .allocate(AllocSpec { + shape: shape! { replica = 1 }, + constraints: Default::default(), + }) + .await?; + + let proc_mesh = ProcMesh::allocate(alloc).await?; + let params = AutoReloadParams {}; + let actor_mesh = proc_mesh + .spawn::("auto_reload_test", ¶ms) + .await?; + + // Get a reference to the single actor + let actor_ref = actor_mesh + .get(0) + .ok_or_else(|| anyhow!("No actor at index 0"))?; + let mailbox = actor_mesh.proc_mesh().client(); + + // First, we need to import the module to get it tracked by the AutoReloader + // We'll do this by running Python code that imports our test module + let temp_path = temp_dir.path().to_path_buf(); + let import_result = tokio::task::spawn_blocking({ + move || { + Python::with_gil(|py| -> PyResult { + // Add the temp directory to Python path + let sys = py.import("sys")?; + let path = sys.getattr("path")?; + let path_list = path.downcast::()?; + path_list.insert(0, temp_path.to_string_lossy())?; + + // Import the test module + let test_module = py.import("test_module")?; + let get_value_func = test_module.getattr("get_value")?; + let initial_value: String = get_value_func.call0()?.extract()?; + + Ok(initial_value) + }) + } + }) + .await??; + + // Verify we got the initial value + assert_eq!(import_result, "initial_value"); + println!("Initial import successful, got: {}", import_result); + + // Now modify the Python file + let modified_content = r#" +# Test module for auto-reload (MODIFIED) +def get_value(): + return "modified_value" + +CONSTANT = "modified_constant" +"#; + fs::write(&py_file_path, modified_content).await?; + println!("Modified Python file"); + + // Send AutoReloadMessage to trigger reload + let (result_tx, mut result_rx) = mailbox.open_port::>(); + actor_ref.send( + &mailbox, + AutoReloadMessage { + result: result_tx.bind(), + }, + )?; + + // Wait for reload to complete + let reload_result = result_rx.recv().await?; + reload_result.map_err(|e| anyhow!("Reload failed: {}", e))?; + println!("Auto-reload completed successfully"); + + // Now import the module again and verify the changes were propagated + let final_result = tokio::task::spawn_blocking({ + move || { + Python::with_gil(|py| -> PyResult { + // Re-import the test module (it should be reloaded now) + let test_module = py.import("test_module")?; + let get_value_func = test_module.getattr("get_value")?; + let final_value: String = get_value_func.call0()?.extract()?; + + Ok(final_value) + }) + } + }) + .await??; + + // Verify that the changes were propagated + assert_eq!(final_result, "modified_value"); + println!("Final import successful, got: {}", final_result); + + // Verify that the module was actually reloaded by checking if we get the new value + assert_ne!(import_result, final_result); + println!("Auto-reload test completed successfully - module was reloaded!"); + + Ok(()) +} diff --git a/monarch_hyperactor/tests/code_sync/mod.rs b/monarch_hyperactor/tests/code_sync/mod.rs new file mode 100644 index 00000000..986e0468 --- /dev/null +++ b/monarch_hyperactor/tests/code_sync/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +mod auto_reload; diff --git a/monarch_hyperactor/tests/lib.rs b/monarch_hyperactor/tests/lib.rs new file mode 100644 index 00000000..b9bb2b86 --- /dev/null +++ b/monarch_hyperactor/tests/lib.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +mod code_sync; diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 600930bc..c1561f80 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -104,7 +104,6 @@ def __init__( self._debug_manager: Optional[DebugManager] = None self._mailbox: Mailbox = self._proc_mesh.client self._code_sync_client: Optional[CodeSyncMeshClient] = None - self._auto_reload_actor: Optional[AutoReloadActor] = None self._logging_mesh_client: Optional[LoggingMeshClient] = None self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh self._stopped = False @@ -247,11 +246,6 @@ async def sync_workspace(self, auto_reload: bool = False) -> None: self._code_sync_client = CodeSyncMeshClient.spawn_blocking( proc_mesh=self._proc_mesh, ) - # TODO(agallagher): Merge this into the `CodeSyncMeshClient` actor. - self._auto_reload_actor = self._spawn_blocking( - "auto_reload", - AutoReloadActor, - ) # TODO(agallagher): We need some way to configure and pass this # in -- right now we're assuming the `gpu` dimension, which isn't # correct. @@ -266,10 +260,8 @@ async def sync_workspace(self, auto_reload: bool = False) -> None: location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"), shape=WorkspaceShape.shared("gpus"), ), + auto_reload=auto_reload, ) - if auto_reload: - assert self._auto_reload_actor is not None - await self._auto_reload_actor.reload.call() async def logging_option(self, stream_to_client: bool = False) -> None: """