Skip to content

Commit a3f6a97

Browse files
andrewjcgfacebook-github-bot
authored andcommitted
Add auto-reload support to CodeSyncManager (#568)
Summary: Update the `CodeSyncManager` actor to call `AutoReloadActor` when requested, to trigger hot-reloading of Python code after code sync completes. In this setup, only a single `CodeSyncManager` on an e.g. host will perform a code sync operation, and then will cast to all ranks that share the host (e.g. gpu ranks) to trigger all the reloads. Differential Revision: D78358200
1 parent 06634d6 commit a3f6a97

File tree

8 files changed

+310
-19
lines changed

8 files changed

+310
-19
lines changed

monarch_hyperactor/Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @generated by autocargo from //monarch/monarch_hyperactor:[monarch_hyperactor,process_allocator-oss]
1+
# @generated by autocargo from //monarch/monarch_hyperactor:[monarch_hyperactor,process_allocator-oss,test_monarch_hyperactor]
22

33
[package]
44
name = "monarch_hyperactor"
@@ -7,6 +7,10 @@ authors = ["Meta"]
77
edition = "2021"
88
license = "BSD-3-Clause"
99

10+
[[test]]
11+
name = "test_monarch_hyperactor"
12+
path = "tests/lib.rs"
13+
1014
[dependencies]
1115
anyhow = "1.0.98"
1216
async-once-cell = "0.4.2"
@@ -16,7 +20,6 @@ clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "
1620
erased-serde = "0.3.27"
1721
fbinit = { version = "0.2.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main" }
1822
futures = { version = "0.3.30", features = ["async-await", "compat"] }
19-
futures-util = { version = "0.3.30", features = ["compat"] }
2023
hyperactor = { version = "0.0.0", path = "../hyperactor" }
2124
hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
2225
hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }

monarch_hyperactor/src/code_sync.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
pub mod auto_reload;
910
pub mod manager;
1011
pub mod rsync;
1112
mod workspace;
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::sync::Arc;
10+
11+
use anyhow::Result;
12+
use async_trait::async_trait;
13+
use hyperactor::Actor;
14+
use hyperactor::Context;
15+
use hyperactor::Handler;
16+
use hyperactor::Named;
17+
use hyperactor::PortRef;
18+
use monarch_types::SerializablePyErr;
19+
use pyo3::prelude::*;
20+
use serde::Deserialize;
21+
use serde::Serialize;
22+
23+
/// Message to trigger module reloading
24+
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
25+
pub struct AutoReloadMessage {
26+
pub result: PortRef<Result<(), String>>,
27+
}
28+
29+
/// Parameters for creating an AutoReloadActor
30+
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
31+
pub struct AutoReloadParams {}
32+
33+
/// Simple Rust Actor that wraps the Python AutoReloader class via pyo3
34+
#[derive(Debug)]
35+
#[hyperactor::export(spawn = true, handlers = [AutoReloadMessage])]
36+
pub struct AutoReloadActor {
37+
state: Result<(Arc<PyObject>, PyObject), SerializablePyErr>,
38+
}
39+
40+
#[async_trait]
41+
impl Actor for AutoReloadActor {
42+
type Params = AutoReloadParams;
43+
44+
async fn new(Self::Params {}: Self::Params) -> Result<Self> {
45+
Ok(Self {
46+
state: tokio::task::spawn_blocking(move || {
47+
Python::with_gil(|py| {
48+
Self::create_state(py).map_err(SerializablePyErr::from_fn(py))
49+
})
50+
})
51+
.await?,
52+
})
53+
}
54+
}
55+
56+
impl AutoReloadActor {
57+
fn create_state(py: Python) -> PyResult<(Arc<PyObject>, PyObject)> {
58+
// Import the Python AutoReloader class
59+
let auto_reload_module = py.import("monarch._src.actor.code_sync.auto_reload")?;
60+
let auto_reloader_class = auto_reload_module.getattr("AutoReloader")?;
61+
62+
let reloader = auto_reloader_class.call0()?;
63+
64+
// Install the audit import hook: SysAuditImportHook.install(reloader.import_callback)
65+
let sys_audit_import_hook_class = auto_reload_module.getattr("SysAuditImportHook")?;
66+
let import_callback = reloader.getattr("import_callback")?;
67+
let hook_guard = sys_audit_import_hook_class.call_method1("install", (import_callback,))?;
68+
69+
Ok((Arc::new(reloader.into()), hook_guard.into()))
70+
}
71+
72+
fn reload(py: Python, py_reloader: &PyObject) -> PyResult<()> {
73+
let reloader = py_reloader.bind(py);
74+
let changed_modules: Vec<String> = reloader.call_method0("reload_changes")?.extract()?;
75+
if !changed_modules.is_empty() {
76+
eprintln!("reloaded modules: {:?}", changed_modules);
77+
}
78+
Ok(())
79+
}
80+
}
81+
82+
#[async_trait]
83+
impl Handler<AutoReloadMessage> for AutoReloadActor {
84+
async fn handle(
85+
&mut self,
86+
cx: &Context<Self>,
87+
AutoReloadMessage { result }: AutoReloadMessage,
88+
) -> Result<()> {
89+
// Call the Python reloader's reload_changes method
90+
let res = async {
91+
let py_reloader: Arc<_> = self.state.as_ref().map_err(Clone::clone)?.0.clone();
92+
tokio::task::spawn_blocking(move || {
93+
Python::with_gil(|py| {
94+
Self::reload(py, py_reloader.as_ref()).map_err(SerializablePyErr::from_fn(py))
95+
})
96+
})
97+
.await??;
98+
anyhow::Ok(())
99+
}
100+
.await;
101+
result.send(cx, res.map_err(|e| format!("{:#?}", e)))?;
102+
Ok(())
103+
}
104+
}

monarch_hyperactor/src/code_sync/manager.rs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use std::collections::HashSet;
910
use std::path::PathBuf;
1011

1112
use anyhow::Result;
@@ -46,6 +47,9 @@ use tokio::net::TcpListener;
4647
use tokio::net::TcpStream;
4748

4849
use crate::code_sync::WorkspaceLocation;
50+
use crate::code_sync::auto_reload::AutoReloadActor;
51+
use crate::code_sync::auto_reload::AutoReloadMessage;
52+
use crate::code_sync::auto_reload::AutoReloadParams;
4953
use crate::code_sync::rsync::RsyncActor;
5054
use crate::code_sync::rsync::RsyncDaemon;
5155
use crate::code_sync::rsync::RsyncMessage;
@@ -161,6 +165,7 @@ pub struct CodeSyncManagerParams {}
161165
)]
162166
pub struct CodeSyncManager {
163167
rsync: OnceCell<ActorHandle<RsyncActor>>,
168+
auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
164169
}
165170

166171
#[async_trait]
@@ -170,6 +175,7 @@ impl Actor for CodeSyncManager {
170175
async fn new(CodeSyncManagerParams {}: Self::Params) -> Result<Self> {
171176
Ok(Self {
172177
rsync: OnceCell::new(),
178+
auto_reload: OnceCell::new(),
173179
})
174180
}
175181
}
@@ -183,6 +189,15 @@ impl CodeSyncManager {
183189
.get_or_try_init(RsyncActor::spawn(cx, RsyncParams {}))
184190
.await
185191
}
192+
193+
async fn get_auto_reload_actor<'a>(
194+
&'a mut self,
195+
cx: &Context<'a, Self>,
196+
) -> Result<&'a ActorHandle<AutoReloadActor>> {
197+
self.auto_reload
198+
.get_or_try_init(AutoReloadActor::spawn(cx, AutoReloadParams {}))
199+
.await
200+
}
186201
}
187202

188203
#[async_trait]
@@ -216,12 +231,22 @@ impl CodeSyncMessageHandler for CodeSyncManager {
216231
if let Some(workspace_shape) = reload {
217232
let mesh = workspace_shape.downstream_mesh(cx.self_id())?;
218233
let (tx, rx) = cx.open_port::<Result<(), String>>();
219-
mesh.cast(cx, sel!(*), CodeSyncMessage::Reload { result: tx.bind() })?;
220-
let _: Vec<()> = rx
221-
.take(mesh.shape().slice().len())
222-
.map(|res| res?.map_err(anyhow::Error::msg))
223-
.try_collect()
224-
.await?;
234+
let tx = tx.bind();
235+
mesh.cast(
236+
cx,
237+
// We make sure to exclude the current rank from the sync, as this actor will
238+
// be blocked here waiting for results. We just manually call `reload` to run
239+
// concurrently below.
240+
sel!(*).without(mesh.shape().slice(), &HashSet::from([cx.self_id().rank()]))?,
241+
CodeSyncMessage::Reload { result: tx.clone() },
242+
)?;
243+
let _: ((), Vec<()>) = try_join!(
244+
// Run reload for this rank.
245+
self.reload(cx, tx),
246+
rx.take(mesh.shape().slice().len())
247+
.map(|res| res?.map_err(anyhow::Error::msg))
248+
.try_collect(),
249+
)?;
225250
}
226251

227252
anyhow::Ok(())
@@ -236,8 +261,15 @@ impl CodeSyncMessageHandler for CodeSyncManager {
236261
cx: &Context<Self>,
237262
result: PortRef<Result<(), String>>,
238263
) -> Result<()> {
239-
// TODO(agallagher): Add reload.
240-
let res = async move { anyhow::Ok(()) }.await;
264+
let res = async move {
265+
let (tx, mut rx) = cx.open_port();
266+
self.get_auto_reload_actor(cx)
267+
.await?
268+
.send(AutoReloadMessage { result: tx.bind() })?;
269+
rx.recv().await?.map_err(anyhow::Error::msg)?;
270+
anyhow::Ok(())
271+
}
272+
.await;
241273
result.send(cx, res.map_err(|e| format!("{:#?}", e)))?;
242274
Ok(())
243275
}
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use anyhow::Result;
10+
use anyhow::anyhow;
11+
use hyperactor_mesh::actor_mesh::ActorMesh;
12+
use hyperactor_mesh::alloc::AllocSpec;
13+
use hyperactor_mesh::alloc::Allocator;
14+
use hyperactor_mesh::alloc::local::LocalAllocator;
15+
use hyperactor_mesh::mesh::Mesh;
16+
use hyperactor_mesh::proc_mesh::ProcMesh;
17+
use monarch_hyperactor::code_sync::auto_reload::AutoReloadActor;
18+
use monarch_hyperactor::code_sync::auto_reload::AutoReloadMessage;
19+
use monarch_hyperactor::code_sync::auto_reload::AutoReloadParams;
20+
use ndslice::shape;
21+
use pyo3::ffi::c_str;
22+
use pyo3::prelude::*;
23+
use tempfile::TempDir;
24+
use tokio::fs;
25+
26+
#[tokio::test]
27+
async fn test_auto_reload_actor() -> Result<()> {
28+
pyo3::prepare_freethreaded_python();
29+
Python::with_gil(|py| py.run(c_str!("import monarch._rust_bindings"), None, None))?;
30+
31+
// Create a temporary directory for Python files
32+
let temp_dir = TempDir::new()?;
33+
let py_file_path = temp_dir.path().join("test_module.py");
34+
35+
// Create initial Python file content
36+
let initial_content = r#"
37+
# Test module for auto-reload
38+
def get_value():
39+
return "initial_value"
40+
41+
CONSTANT = "initial_constant"
42+
"#;
43+
fs::write(&py_file_path, initial_content).await?;
44+
45+
// Set up a single AutoReloadActor
46+
let alloc = LocalAllocator
47+
.allocate(AllocSpec {
48+
shape: shape! { replica = 1 },
49+
constraints: Default::default(),
50+
})
51+
.await?;
52+
53+
let proc_mesh = ProcMesh::allocate(alloc).await?;
54+
let params = AutoReloadParams {};
55+
let actor_mesh = proc_mesh
56+
.spawn::<AutoReloadActor>("auto_reload_test", &params)
57+
.await?;
58+
59+
// Get a reference to the single actor
60+
let actor_ref = actor_mesh
61+
.get(0)
62+
.ok_or_else(|| anyhow!("No actor at index 0"))?;
63+
let mailbox = actor_mesh.proc_mesh().client();
64+
65+
// First, we need to import the module to get it tracked by the AutoReloader
66+
// We'll do this by running Python code that imports our test module
67+
let temp_path = temp_dir.path().to_path_buf();
68+
let import_result = tokio::task::spawn_blocking({
69+
move || {
70+
Python::with_gil(|py| -> PyResult<String> {
71+
// Add the temp directory to Python path
72+
let sys = py.import("sys")?;
73+
let path = sys.getattr("path")?;
74+
let path_list = path.downcast::<pyo3::types::PyList>()?;
75+
path_list.insert(0, temp_path.to_string_lossy())?;
76+
77+
// Import the test module
78+
let test_module = py.import("test_module")?;
79+
let get_value_func = test_module.getattr("get_value")?;
80+
let initial_value: String = get_value_func.call0()?.extract()?;
81+
82+
Ok(initial_value)
83+
})
84+
}
85+
})
86+
.await??;
87+
88+
// Verify we got the initial value
89+
assert_eq!(import_result, "initial_value");
90+
println!("Initial import successful, got: {}", import_result);
91+
92+
// Now modify the Python file
93+
let modified_content = r#"
94+
# Test module for auto-reload (MODIFIED)
95+
def get_value():
96+
return "modified_value"
97+
98+
CONSTANT = "modified_constant"
99+
"#;
100+
fs::write(&py_file_path, modified_content).await?;
101+
println!("Modified Python file");
102+
103+
// Send AutoReloadMessage to trigger reload
104+
let (result_tx, mut result_rx) = mailbox.open_port::<Result<(), String>>();
105+
actor_ref.send(
106+
&mailbox,
107+
AutoReloadMessage {
108+
result: result_tx.bind(),
109+
},
110+
)?;
111+
112+
// Wait for reload to complete
113+
let reload_result = result_rx.recv().await?;
114+
reload_result.map_err(|e| anyhow!("Reload failed: {}", e))?;
115+
println!("Auto-reload completed successfully");
116+
117+
// Now import the module again and verify the changes were propagated
118+
let final_result = tokio::task::spawn_blocking({
119+
move || {
120+
Python::with_gil(|py| -> PyResult<String> {
121+
// Re-import the test module (it should be reloaded now)
122+
let test_module = py.import("test_module")?;
123+
let get_value_func = test_module.getattr("get_value")?;
124+
let final_value: String = get_value_func.call0()?.extract()?;
125+
126+
Ok(final_value)
127+
})
128+
}
129+
})
130+
.await??;
131+
132+
// Verify that the changes were propagated
133+
assert_eq!(final_result, "modified_value");
134+
println!("Final import successful, got: {}", final_result);
135+
136+
// Verify that the module was actually reloaded by checking if we get the new value
137+
assert_ne!(import_result, final_result);
138+
println!("Auto-reload test completed successfully - module was reloaded!");
139+
140+
Ok(())
141+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
mod auto_reload;

monarch_hyperactor/tests/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
mod code_sync;

0 commit comments

Comments
 (0)