Skip to content

Add auto-reload support to CodeSyncManager #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions monarch_hyperactor/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @generated by autocargo from //monarch/monarch_hyperactor:[monarch_hyperactor,process_allocator-oss]
# @generated by autocargo from //monarch/monarch_hyperactor:[monarch_hyperactor,process_allocator-oss,test_monarch_hyperactor]

[package]
name = "monarch_hyperactor"
Expand All @@ -7,6 +7,10 @@ authors = ["Meta"]
edition = "2021"
license = "BSD-3-Clause"

[[test]]
name = "test_monarch_hyperactor"
path = "tests/lib.rs"

[dependencies]
anyhow = "1.0.98"
async-once-cell = "0.4.2"
Expand All @@ -16,7 +20,6 @@ clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "
erased-serde = "0.3.27"
fbinit = { version = "0.2.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main" }
futures = { version = "0.3.30", features = ["async-await", "compat"] }
futures-util = { version = "0.3.30", features = ["compat"] }
hyperactor = { version = "0.0.0", path = "../hyperactor" }
hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }
Expand Down
1 change: 1 addition & 0 deletions monarch_hyperactor/src/code_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

pub mod auto_reload;
pub mod manager;
pub mod rsync;
mod workspace;
Expand Down
104 changes: 104 additions & 0 deletions monarch_hyperactor/src/code_sync/auto_reload.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::Context;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::PortRef;
use monarch_types::SerializablePyErr;
use pyo3::prelude::*;
use serde::Deserialize;
use serde::Serialize;

/// Message asking an [`AutoReloadActor`] to reload changed Python modules.
///
/// The handler replies on `result` with `Ok(())` when the reload succeeded, or with
/// `Err` carrying the debug-formatted error text when it did not.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct AutoReloadMessage {
/// Port on which the outcome of the reload is reported back to the sender.
pub result: PortRef<Result<(), String>>,
}

/// Parameters for creating an [`AutoReloadActor`].
///
/// Currently empty: the actor takes no configuration.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct AutoReloadParams {}

/// Simple Rust Actor that wraps the Python `AutoReloader` class via pyo3.
///
/// Handles [`AutoReloadMessage`] by calling the reloader's `reload_changes` method.
#[derive(Debug)]
#[hyperactor::export(spawn = true, handlers = [AutoReloadMessage])]
pub struct AutoReloadActor {
// Outcome of the one-time Python setup done in `new`.
//
// On success this holds the `AutoReloader` instance (behind an `Arc` so it can be
// moved into blocking tasks) together with the object returned by
// `SysAuditImportHook.install`. The second element is never read in this file;
// presumably it is a guard that keeps the import hook installed -- TODO confirm
// against the Python implementation.
//
// On failure this holds the Python error, which is then reported to the sender of
// every reload request instead of failing actor construction.
state: Result<(Arc<PyObject>, PyObject), SerializablePyErr>,
}

#[async_trait]
impl Actor for AutoReloadActor {
    type Params = AutoReloadParams;

    /// Builds the actor by running the Python-side setup on a blocking thread.
    ///
    /// Only a failure to join the blocking task is returned as an error here. A Python
    /// exception raised during setup is kept in `state` and surfaced later, when a
    /// reload is requested.
    async fn new(Self::Params {}: Self::Params) -> Result<Self> {
        // Acquiring the GIL may block, so keep it off the async executor threads.
        let setup = tokio::task::spawn_blocking(|| {
            Python::with_gil(|py| {
                let created = Self::create_state(py);
                created.map_err(SerializablePyErr::from_fn(py))
            })
        });
        let state = setup.await?;
        Ok(Self { state })
    }
}

impl AutoReloadActor {
    /// Creates the Python `AutoReloader` and installs its import hook.
    ///
    /// Returns the reloader instance together with the object returned by
    /// `SysAuditImportHook.install(reloader.import_callback)`.
    fn create_state(py: Python) -> PyResult<(Arc<PyObject>, PyObject)> {
        let module = py.import("monarch._src.actor.code_sync.auto_reload")?;

        // Instantiate `AutoReloader()`.
        let reloader = module.getattr("AutoReloader")?.call0()?;

        // Equivalent to: SysAuditImportHook.install(reloader.import_callback)
        let hook_class = module.getattr("SysAuditImportHook")?;
        let callback = reloader.getattr("import_callback")?;
        let hook_guard = hook_class.call_method1("install", (callback,))?;

        let reloader: PyObject = reloader.into();
        let hook_guard: PyObject = hook_guard.into();
        Ok((Arc::new(reloader), hook_guard))
    }

    /// Calls `reload_changes()` on the Python reloader and logs the names it returns.
    ///
    /// Nothing is printed when the returned list is empty.
    fn reload(py: Python, py_reloader: &PyObject) -> PyResult<()> {
        let changed_modules = py_reloader
            .bind(py)
            .call_method0("reload_changes")?
            .extract::<Vec<String>>()?;
        if changed_modules.is_empty() {
            return Ok(());
        }
        eprintln!("reloaded modules: {:?}", changed_modules);
        Ok(())
    }
}

#[async_trait]
impl Handler<AutoReloadMessage> for AutoReloadActor {
    /// Runs a reload and reports the outcome on the message's `result` port.
    ///
    /// Reload failures (including a failed setup recorded in `state`) are sent to the
    /// requester as a debug-formatted string. This method itself only fails when the
    /// reply cannot be sent.
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        AutoReloadMessage { result }: AutoReloadMessage,
    ) -> Result<()> {
        let outcome = async {
            // Surface a stored setup error, otherwise grab a handle to the reloader.
            let (reloader, _hook_guard) = self.state.as_ref().map_err(|err| err.clone())?;
            let reloader = Arc::clone(reloader);

            // The Python call needs the GIL, so run it on a blocking thread.
            let joined = tokio::task::spawn_blocking(move || {
                Python::with_gil(|py| {
                    Self::reload(py, &reloader).map_err(SerializablePyErr::from_fn(py))
                })
            })
            .await;

            // First `?` handles a failed join, second the Python error.
            joined??;
            anyhow::Ok(())
        }
        .await;

        let reply = outcome.map_err(|e| format!("{:#?}", e));
        result.send(cx, reply)?;
        Ok(())
    }
}
48 changes: 40 additions & 8 deletions monarch_hyperactor/src/code_sync/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

use std::collections::HashSet;
use std::path::PathBuf;

use anyhow::Result;
Expand Down Expand Up @@ -46,6 +47,9 @@ use tokio::net::TcpListener;
use tokio::net::TcpStream;

use crate::code_sync::WorkspaceLocation;
use crate::code_sync::auto_reload::AutoReloadActor;
use crate::code_sync::auto_reload::AutoReloadMessage;
use crate::code_sync::auto_reload::AutoReloadParams;
use crate::code_sync::rsync::RsyncActor;
use crate::code_sync::rsync::RsyncDaemon;
use crate::code_sync::rsync::RsyncMessage;
Expand Down Expand Up @@ -161,6 +165,7 @@ pub struct CodeSyncManagerParams {}
)]
/// Actor that holds handles to the helper actors used for code sync.
///
/// Both handles start out empty and are filled in on first use.
pub struct CodeSyncManager {
/// Handle to the `RsyncActor`, spawned on first use.
rsync: OnceCell<ActorHandle<RsyncActor>>,
/// Handle to the `AutoReloadActor`, spawned on first use by `get_auto_reload_actor`.
auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
}

#[async_trait]
Expand All @@ -170,6 +175,7 @@ impl Actor for CodeSyncManager {
async fn new(CodeSyncManagerParams {}: Self::Params) -> Result<Self> {
Ok(Self {
rsync: OnceCell::new(),
auto_reload: OnceCell::new(),
})
}
}
Expand All @@ -183,6 +189,15 @@ impl CodeSyncManager {
.get_or_try_init(RsyncActor::spawn(cx, RsyncParams {}))
.await
}

/// Returns the handle to the `AutoReloadActor`, spawning it if the cell is still empty.
///
/// Any error from spawning the actor is propagated to the caller.
async fn get_auto_reload_actor<'a>(
    &'a mut self,
    cx: &Context<'a, Self>,
) -> Result<&'a ActorHandle<AutoReloadActor>> {
    let spawn_actor = AutoReloadActor::spawn(cx, AutoReloadParams {});
    let handle = self.auto_reload.get_or_try_init(spawn_actor).await;
    handle
}
}

#[async_trait]
Expand Down Expand Up @@ -216,12 +231,22 @@ impl CodeSyncMessageHandler for CodeSyncManager {
if let Some(workspace_shape) = reload {
let mesh = workspace_shape.downstream_mesh(cx.self_id())?;
let (tx, rx) = cx.open_port::<Result<(), String>>();
mesh.cast(cx, sel!(*), CodeSyncMessage::Reload { result: tx.bind() })?;
let _: Vec<()> = rx
.take(mesh.shape().slice().len())
.map(|res| res?.map_err(anyhow::Error::msg))
.try_collect()
.await?;
let tx = tx.bind();
mesh.cast(
cx,
// Exclude the current rank from the cast: this actor is blocked here waiting for
// results, so instead of messaging itself it calls `reload` directly below, which
// runs concurrently with collecting the results from the other ranks.
sel!(*).without(mesh.shape().slice(), &HashSet::from([cx.self_id().rank()]))?,
CodeSyncMessage::Reload { result: tx.clone() },
)?;
let _: ((), Vec<()>) = try_join!(
// Run reload for this rank.
self.reload(cx, tx),
rx.take(mesh.shape().slice().len())
.map(|res| res?.map_err(anyhow::Error::msg))
.try_collect(),
)?;
}

anyhow::Ok(())
Expand All @@ -236,8 +261,15 @@ impl CodeSyncMessageHandler for CodeSyncManager {
cx: &Context<Self>,
result: PortRef<Result<(), String>>,
) -> Result<()> {
// TODO(agallagher): Add reload.
let res = async move { anyhow::Ok(()) }.await;
let res = async move {
let (tx, mut rx) = cx.open_port();
self.get_auto_reload_actor(cx)
.await?
.send(AutoReloadMessage { result: tx.bind() })?;
rx.recv().await?.map_err(anyhow::Error::msg)?;
anyhow::Ok(())
}
.await;
result.send(cx, res.map_err(|e| format!("{:#?}", e)))?;
Ok(())
}
Expand Down
141 changes: 141 additions & 0 deletions monarch_hyperactor/tests/code_sync/auto_reload.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

use anyhow::Result;
use anyhow::anyhow;
use hyperactor_mesh::actor_mesh::ActorMesh;
use hyperactor_mesh::alloc::AllocSpec;
use hyperactor_mesh::alloc::Allocator;
use hyperactor_mesh::alloc::local::LocalAllocator;
use hyperactor_mesh::mesh::Mesh;
use hyperactor_mesh::proc_mesh::ProcMesh;
use monarch_hyperactor::code_sync::auto_reload::AutoReloadActor;
use monarch_hyperactor::code_sync::auto_reload::AutoReloadMessage;
use monarch_hyperactor::code_sync::auto_reload::AutoReloadParams;
use ndslice::shape;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use tempfile::TempDir;
use tokio::fs;

#[tokio::test]
async fn test_auto_reload_actor() -> Result<()> {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| py.run(c_str!("import monarch._rust_bindings"), None, None))?;

// Create a temporary directory for Python files
let temp_dir = TempDir::new()?;
let py_file_path = temp_dir.path().join("test_module.py");

// Create initial Python file content
let initial_content = r#"
# Test module for auto-reload
def get_value():
return "initial_value"

CONSTANT = "initial_constant"
"#;
fs::write(&py_file_path, initial_content).await?;

// Set up a single AutoReloadActor
let alloc = LocalAllocator
.allocate(AllocSpec {
shape: shape! { replica = 1 },
constraints: Default::default(),
})
.await?;

let proc_mesh = ProcMesh::allocate(alloc).await?;
let params = AutoReloadParams {};
let actor_mesh = proc_mesh
.spawn::<AutoReloadActor>("auto_reload_test", &params)
.await?;

// Get a reference to the single actor
let actor_ref = actor_mesh
.get(0)
.ok_or_else(|| anyhow!("No actor at index 0"))?;
let mailbox = actor_mesh.proc_mesh().client();

// First, we need to import the module to get it tracked by the AutoReloader
// We'll do this by running Python code that imports our test module
let temp_path = temp_dir.path().to_path_buf();
let import_result = tokio::task::spawn_blocking({
move || {
Python::with_gil(|py| -> PyResult<String> {
// Add the temp directory to Python path
let sys = py.import("sys")?;
let path = sys.getattr("path")?;
let path_list = path.downcast::<pyo3::types::PyList>()?;
path_list.insert(0, temp_path.to_string_lossy())?;

// Import the test module
let test_module = py.import("test_module")?;
let get_value_func = test_module.getattr("get_value")?;
let initial_value: String = get_value_func.call0()?.extract()?;

Ok(initial_value)
})
}
})
.await??;

// Verify we got the initial value
assert_eq!(import_result, "initial_value");
println!("Initial import successful, got: {}", import_result);

// Now modify the Python file
let modified_content = r#"
# Test module for auto-reload (MODIFIED)
def get_value():
return "modified_value"

CONSTANT = "modified_constant"
"#;
fs::write(&py_file_path, modified_content).await?;
println!("Modified Python file");

// Send AutoReloadMessage to trigger reload
let (result_tx, mut result_rx) = mailbox.open_port::<Result<(), String>>();
actor_ref.send(
&mailbox,
AutoReloadMessage {
result: result_tx.bind(),
},
)?;

// Wait for reload to complete
let reload_result = result_rx.recv().await?;
reload_result.map_err(|e| anyhow!("Reload failed: {}", e))?;
println!("Auto-reload completed successfully");

// Now import the module again and verify the changes were propagated
let final_result = tokio::task::spawn_blocking({
move || {
Python::with_gil(|py| -> PyResult<String> {
// Re-import the test module (it should be reloaded now)
let test_module = py.import("test_module")?;
let get_value_func = test_module.getattr("get_value")?;
let final_value: String = get_value_func.call0()?.extract()?;

Ok(final_value)
})
}
})
.await??;

// Verify that the changes were propagated
assert_eq!(final_result, "modified_value");
println!("Final import successful, got: {}", final_result);

// Verify that the module was actually reloaded by checking if we get the new value
assert_ne!(import_result, final_result);
println!("Auto-reload test completed successfully - module was reloaded!");

Ok(())
}
9 changes: 9 additions & 0 deletions monarch_hyperactor/tests/code_sync/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

mod auto_reload;
9 changes: 9 additions & 0 deletions monarch_hyperactor/tests/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

mod code_sync;
Loading
Loading