From 6041587534b2ff44c3e8d4520e292b1bad30dc88 Mon Sep 17 00:00:00 2001 From: Benji Pelletier Date: Thu, 14 Aug 2025 11:49:42 -0700 Subject: [PATCH 1/2] Add sqlite tracing recorder and sql assertions (#779) Summary: New tracer subscriber to be used for testing (e.g., script or simulator) 1. New logging layer for use in tests that writes all log messages to a series of sqlite tables 2. Add capability to do sql based assertions for script tests or simulation tests 3. New trace level logging events on actor lifecycle events Next diffs will: * Get this working for our PAFT simulator tests so we can easily assert * Support custom columns Reviewed By: eliothedeman Differential Revision: D73512355 --- hyperactor/src/mailbox.rs | 2 +- hyperactor_telemetry/Cargo.toml | 2 + hyperactor_telemetry/src/lib.rs | 1 + hyperactor_telemetry/src/sqlite.rs | 367 +++++++++++++++++++++++++++++ 4 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 hyperactor_telemetry/src/sqlite.rs diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 07ed7d624..b1de225cb 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -1043,7 +1043,7 @@ impl MailboxSender for MailboxClient { return_handle: PortHandle>, ) { // tracing::trace!(name = "post", "posting message to {}", envelope.dest); - tracing::event!(target:"message", tracing::Level::DEBUG, "crc"=envelope.data.crc(), "size"=envelope.data.len(), "sender"= %envelope.sender, "dest" = %envelope.dest.0, "port"= envelope.dest.1, "message_type" = envelope.data.typename().unwrap_or("unknown"), "send_message"); + tracing::event!(target:"messages", tracing::Level::DEBUG, "crc"=envelope.data.crc(), "size"=envelope.data.len(), "sender"= %envelope.sender, "dest" = %envelope.dest.0, "port"= envelope.dest.1, "message_type" = envelope.data.typename().unwrap_or("unknown"), "send_message"); if let Err(mpsc::error::SendError((envelope, return_handle))) = self.buffer.send((envelope, return_handle)) diff --git 
a/hyperactor_telemetry/Cargo.toml b/hyperactor_telemetry/Cargo.toml index ce7199a4d..ccb08276c 100644 --- a/hyperactor_telemetry/Cargo.toml +++ b/hyperactor_telemetry/Cargo.toml @@ -20,9 +20,11 @@ lazy_static = "1.5" opentelemetry = "0.29" opentelemetry_sdk = { version = "0.29.0", features = ["rt-tokio"] } rand = { version = "0.8", features = ["small_rng"] } +rusqlite = { version = "0.36.0", features = ["backup", "blob", "bundled", "column_decltype", "functions", "limits", "modern_sqlite", "serde_json"] } scuba = { version = "0.1.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main", optional = true } serde = { version = "1.0.219", features = ["derive", "rc"] } serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "raw_value", "unbounded_depth"] } +serde_rusqlite = "0.39.3" tokio = { version = "1.46.1", features = ["full", "test-util", "tracing"] } tracing = { version = "0.1.41", features = ["attributes", "valuable"] } tracing-appender = "0.2.3" diff --git a/hyperactor_telemetry/src/lib.rs b/hyperactor_telemetry/src/lib.rs index 415ff17e2..e2c19879a 100644 --- a/hyperactor_telemetry/src/lib.rs +++ b/hyperactor_telemetry/src/lib.rs @@ -33,6 +33,7 @@ mod otel; mod pool; pub mod recorder; mod spool; +pub mod sqlite; use std::io::IsTerminal; use std::io::Write; use std::str::FromStr; diff --git a/hyperactor_telemetry/src/sqlite.rs b/hyperactor_telemetry/src/sqlite.rs new file mode 100644 index 000000000..0f02fa148 --- /dev/null +++ b/hyperactor_telemetry/src/sqlite.rs @@ -0,0 +1,367 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; + +use anyhow::Result; +use anyhow::anyhow; +use lazy_static::lazy_static; +use rusqlite::Connection; +use rusqlite::functions::FunctionFlags; +use serde::Serialize; +use serde_json::Value as JValue; +use serde_rusqlite::*; +use tracing::Event; +use tracing::Subscriber; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::Layer; +use tracing_subscriber::filter::Targets; +use tracing_subscriber::prelude::*; + +pub trait TableDef { + fn name(&self) -> &'static str; + fn columns(&self) -> &'static [&'static str]; + fn create_table_stmt(&self) -> String { + let name = self.name(); + let columns = self + .columns() + .iter() + .map(|col| format!("{col} TEXT ")) + .collect::>() + .join(","); + format!("create table if not exists {name} (seq INTEGER primary key, {columns})") + } + fn insert_stmt(&self) -> String { + let name = self.name(); + let columns = self.columns().join(", "); + let params = self + .columns() + .iter() + .map(|c| format!(":{c}")) + .collect::>() + .join(", "); + format!("insert into {name} ({columns}) values ({params})") + } +} + +impl TableDef for (&'static str, &'static [&'static str]) { + fn name(&self) -> &'static str { + self.0 + } + + fn columns(&self) -> &'static [&'static str] { + self.1 + } +} + +#[derive(Clone, Debug)] +pub struct Table { + pub columns: &'static [&'static str], + pub create_table_stmt: String, + pub insert_stmt: String, +} + +impl From<(&'static str, &'static [&'static str])> for Table { + fn from(value: (&'static str, &'static [&'static str])) -> Self { + Self { + columns: value.columns(), + create_table_stmt: value.create_table_stmt(), + insert_stmt: value.insert_stmt(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TableName { + ActorLifecycle, + Messages, + LogEvents, +} + +impl TableName { + pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle"; + pub const MESSAGES_STR: &'static str = 
"messages"; + pub const LOG_EVENTS_STR: &'static str = "log_events"; + + pub fn as_str(&self) -> &'static str { + match self { + TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR, + TableName::Messages => Self::MESSAGES_STR, + TableName::LogEvents => Self::LOG_EVENTS_STR, + } + } + + pub fn get_table(&self) -> &'static Table { + match self { + TableName::ActorLifecycle => &ACTOR_LIFECYCLE, + TableName::Messages => &MESSAGES, + TableName::LogEvents => &LOG_EVENTS, + } + } +} + +lazy_static! { + static ref ACTOR_LIFECYCLE: Table = ( + TableName::ActorLifecycle.as_str(), + [ + "actor_id", + "actor", + "name", + "supervised_actor", + "actor_status", + "module_path", + "line", + "file", + ] + .as_slice() + ) + .into(); + static ref MESSAGES: Table = ( + TableName::Messages.as_str(), + [ + "span_id", + "time_us", + "src", + "dest", + "payload", + "module_path", + "line", + "file", + ] + .as_slice() + ) + .into(); + static ref LOG_EVENTS: Table = ( + TableName::LogEvents.as_str(), + [ + "span_id", + "time_us", + "name", + "message", + "actor_id", + "level", + "line", + "file", + "module_path", + ] + .as_slice() + ) + .into(); + static ref ALL_TABLES: Vec = vec![ + ACTOR_LIFECYCLE.clone(), + MESSAGES.clone(), + LOG_EVENTS.clone() + ]; +} + +pub struct SqliteLayer { + conn: Arc>, +} +use tracing::field::Visit; + +#[derive(Debug, Clone, Default, Serialize)] +struct SqlVisitor(HashMap); + +impl Visit for SqlVisitor { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + self.0.insert( + field.name().to_string(), + JValue::String(format!("{:?}", value)), + ); + } + + fn record_str(&mut self, field: &tracing::field::Field, value: &str) { + self.0 + .insert(field.name().to_string(), JValue::String(value.to_string())); + } + + fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { + self.0 + .insert(field.name().to_string(), JValue::Number(value.into())); + } + + fn record_f64(&mut self, field: &tracing::field::Field, 
value: f64) { + let n = serde_json::Number::from_f64(value).unwrap(); + self.0.insert(field.name().to_string(), JValue::Number(n)); + } + + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + self.0 + .insert(field.name().to_string(), JValue::Number(value.into())); + } + + fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { + self.0.insert(field.name().to_string(), JValue::Bool(value)); + } +} + +macro_rules! insert_event { + ($table:expr, $conn:ident, $event:ident) => { + let mut v: SqlVisitor = Default::default(); + $event.record(&mut v); + let meta = $event.metadata(); + v.0.insert( + "module_path".to_string(), + meta.module_path().map(String::from).into(), + ); + v.0.insert("line".to_string(), meta.line().into()); + v.0.insert("file".to_string(), meta.file().map(String::from).into()); + $conn.prepare_cached(&$table.insert_stmt)?.execute( + serde_rusqlite::to_params_named_with_fields(v, $table.columns)? + .to_slice() + .as_slice(), + )?; + }; +} + +impl SqliteLayer { + pub fn new() -> Result { + let conn = Connection::open_in_memory()?; + + for table in ALL_TABLES.iter() { + conn.execute(&table.create_table_stmt, [])?; + } + conn.create_scalar_function( + "assert", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + move |ctx| { + let condition: bool = ctx.get(0)?; + let message: String = ctx.get(1)?; + + if !condition { + return Err(rusqlite::Error::UserFunctionError( + anyhow!("assertion failed:{condition} {message}",).into(), + )); + } + + Ok(condition) + }, + )?; + + Ok(Self { + conn: Arc::new(Mutex::new(conn)), + }) + } + + fn insert_event(&self, event: &Event<'_>) -> Result<()> { + let conn = self.conn.lock().unwrap(); + match (event.metadata().target(), event.metadata().name()) { + (TableName::MESSAGES_STR, _) => { + insert_event!(TableName::Messages.get_table(), conn, event); + } + (TableName::ACTOR_LIFECYCLE_STR, _) => { + insert_event!(TableName::ActorLifecycle.get_table(), conn, event); + 
} + _ => { + insert_event!(TableName::LogEvents.get_table(), conn, event); + } + } + Ok(()) + } + + pub fn connection(&self) -> Arc> { + self.conn.clone() + } +} + +impl Layer for SqliteLayer { + fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) { + self.insert_event(event).unwrap(); + } +} + +#[allow(dead_code)] +fn print_table(conn: &Connection, table_name: TableName) -> Result<()> { + let table_name_str = table_name.as_str(); + + // Get column names + let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?; + let column_info = stmt.query_map([], |row| { + row.get::<_, String>(1) // Column name is at index 1 + })?; + + let columns: Vec = column_info.collect::, _>>()?; + + // Print header + println!("=== {} ===", table_name_str.to_uppercase()); + println!("{}", columns.join(" | ")); + println!("{}", "-".repeat(columns.len() * 10)); + + // Print rows + let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?; + let rows = stmt.query_map([], |row| { + let mut values = Vec::new(); + for (i, column) in columns.iter().enumerate() { + // Handle different column types properly + let value = if i == 0 && *column == "seq" { + // First column is always the INTEGER seq column + match row.get::<_, Option>(i)? { + Some(v) => v.to_string(), + None => "NULL".to_string(), + } + } else { + // All other columns are TEXT + match row.get::<_, Option>(i)? 
{ + Some(v) => v, + None => "NULL".to_string(), + } + }; + values.push(value); + } + Ok(values.join(" | ")) + })?; + + for row in rows { + println!("{}", row?); + } + println!(); + Ok(()) +} + +pub fn with_tracing_db() -> Arc> { + let layer = SqliteLayer::new().unwrap(); + let conn = layer.connection(); + + let layer = layer.with_filter( + Targets::new() + .with_default(LevelFilter::TRACE) + .with_targets(vec![ + ("tokio", LevelFilter::OFF), + ("opentelemetry", LevelFilter::OFF), + ("runtime", LevelFilter::OFF), + ]), + ); + tracing_subscriber::registry().with(layer).init(); + conn +} + +#[cfg(test)] +mod tests { + use tracing::info; + + use super::*; + + #[test] + fn test_sqlite_layer() -> Result<()> { + let conn = with_tracing_db(); + + info!(target:"messages", test_field = "test_value", "Test msg"); + info!(target:"log_events", test_field = "test_value", "Test event"); + + let count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?; + print_table(&conn.lock().unwrap(), TableName::LogEvents)?; + assert!(count > 0); + Ok(()) + } +} From 8936a0e7f5874e99cd213a54a86cdcb656dc4248 Mon Sep 17 00:00:00 2001 From: Benji Pelletier Date: Thu, 14 Aug 2025 11:49:42 -0700 Subject: [PATCH 2/2] Add tracing sqlite db file support so we can connect from python (#807) Summary: POC of Rust-created sqlite db accessible from native python sqlite connection. This allows us to write python tests that test Monarch and assert user (or BE) events using sql queries. * Uses reloadable layer to inject SqliteLayer into tracing registry on demand. 
* Exposes `with_tracing_db_file` to python to create and get the DB file name so we can connect to it. Reviewed By: eliothedeman, pablorfb-meta Differential Revision: D79761474 --- hyperactor_telemetry/src/lib.rs | 4 + hyperactor_telemetry/src/sqlite.rs | 206 ++++++++++++++++-- monarch_hyperactor/src/telemetry.rs | 68 +++++- .../monarch_hyperactor/telemetry.pyi | 57 ++++- 4 files changed, 310 insertions(+), 25 deletions(-) diff --git a/hyperactor_telemetry/src/lib.rs b/hyperactor_telemetry/src/lib.rs index e2c19879a..cd8415434 100644 --- a/hyperactor_telemetry/src/lib.rs +++ b/hyperactor_telemetry/src/lib.rs @@ -64,6 +64,7 @@ use tracing_subscriber::fmt::format::Writer; use tracing_subscriber::registry::LookupSpan; use crate::recorder::Recorder; +use crate::sqlite::get_reloadable_sqlite_layer; pub trait TelemetryClock { fn now(&self) -> tokio::time::Instant; @@ -563,6 +564,8 @@ pub fn initialize_logging_with_log_prefix( .with_target("opentelemetry", LevelFilter::OFF), // otel has some log span under debug that we don't care about ); + let sqlite_layer = get_reloadable_sqlite_layer().unwrap(); + use tracing_subscriber::Registry; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -574,6 +577,7 @@ pub fn initialize_logging_with_log_prefix( std::env::var(env_var).unwrap_or_default() != "1" } if let Err(err) = Registry::default() + .with(sqlite_layer) .with(if is_layer_enabled(DISABLE_OTEL_TRACING) { Some(otel::tracing_layer()) } else { diff --git a/hyperactor_telemetry/src/sqlite.rs b/hyperactor_telemetry/src/sqlite.rs index 0f02fa148..d46904ff7 100644 --- a/hyperactor_telemetry/src/sqlite.rs +++ b/hyperactor_telemetry/src/sqlite.rs @@ -7,6 +7,8 @@ */ use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; @@ -20,11 +22,19 @@ use serde_json::Value as JValue; use serde_rusqlite::*; use tracing::Event; use tracing::Subscriber; -use tracing::level_filters::LevelFilter; use 
tracing_subscriber::Layer; -use tracing_subscriber::filter::Targets; +use tracing_subscriber::Registry; use tracing_subscriber::prelude::*; +use tracing_subscriber::reload; +pub type SqliteReloadHandle = reload::Handle, Registry>; + +lazy_static! { + // Reload handle allows us to include a no-op layer during init, but load + // the layer dynamically during tests. + static ref RELOAD_HANDLE: Mutex> = + Mutex::new(None); +} pub trait TableDef { fn name(&self) -> &'static str; fn columns(&self) -> &'static [&'static str]; @@ -224,7 +234,15 @@ macro_rules! insert_event { impl SqliteLayer { pub fn new() -> Result { let conn = Connection::open_in_memory()?; + Self::setup_connection(conn) + } + + pub fn new_with_file(db_path: &str) -> Result { + let conn = Connection::open(db_path)?; + Self::setup_connection(conn) + } + fn setup_connection(conn: Connection) -> Result { for table in ALL_TABLES.iter() { conn.execute(&table.create_table_stmt, [])?; } @@ -326,21 +344,89 @@ fn print_table(conn: &Connection, table_name: TableName) -> Result<()> { Ok(()) } -pub fn with_tracing_db() -> Arc> { - let layer = SqliteLayer::new().unwrap(); - let conn = layer.connection(); - - let layer = layer.with_filter( - Targets::new() - .with_default(LevelFilter::TRACE) - .with_targets(vec![ - ("tokio", LevelFilter::OFF), - ("opentelemetry", LevelFilter::OFF), - ("runtime", LevelFilter::OFF), - ]), - ); - tracing_subscriber::registry().with(layer).init(); - conn +fn init_tracing_subscriber(layer: SqliteLayer) { + let handle = RELOAD_HANDLE.lock().unwrap(); + if let Some(reload_handle) = handle.as_ref() { + let _ = reload_handle.reload(layer); + } else { + tracing_subscriber::registry().with(layer).init(); + } +} + +// === API === + +// Creates a new reload handler and no-op layer for initialization +pub fn get_reloadable_sqlite_layer() -> Result, Registry>> { + let (layer, reload_handle) = reload::Layer::new(None); + let mut handle = RELOAD_HANDLE.lock().unwrap(); + *handle = Some(reload_handle); 
+ Ok(layer) +} + +/// RAII guard for SQLite tracing database +pub struct SqliteTracing { + db_path: Option, + connection: Arc>, +} + +impl SqliteTracing { + /// Create a new SqliteTracing with a temporary file + pub fn new() -> Result { + let temp_dir = std::env::temp_dir(); + let file_name = format!("hyperactor_trace_{}.db", std::process::id()); + let db_path = temp_dir.join(file_name); + + let db_path_str = db_path.to_string_lossy(); + let layer = SqliteLayer::new_with_file(&db_path_str)?; + let connection = layer.connection(); + + init_tracing_subscriber(layer); + + Ok(Self { + db_path: Some(db_path), + connection, + }) + } + + /// Create a new SqliteTracing with in-memory database + pub fn new_in_memory() -> Result { + let layer = SqliteLayer::new()?; + let connection = layer.connection(); + + init_tracing_subscriber(layer); + + Ok(Self { + db_path: None, + connection, + }) + } + + /// Get the path to the temporary database file (None for in-memory) + pub fn db_path(&self) -> Option<&PathBuf> { + self.db_path.as_ref() + } + + /// Get a reference to the database connection + pub fn connection(&self) -> Arc> { + self.connection.clone() + } +} + +impl Drop for SqliteTracing { + fn drop(&mut self) { + // Reset the layer to None + let handle = RELOAD_HANDLE.lock().unwrap(); + if let Some(reload_handle) = handle.as_ref() { + let _ = reload_handle.reload(None); + } + + // Delete the temporary file if it exists + if let Some(db_path) = &self.db_path { + if db_path.exists() { + let _ = fs::remove_file(db_path); + } + } + } } #[cfg(test)] @@ -350,8 +436,9 @@ mod tests { use super::*; #[test] - fn test_sqlite_layer() -> Result<()> { - let conn = with_tracing_db(); + fn test_sqlite_tracing_with_file() -> Result<()> { + let tracing = SqliteTracing::new()?; + let conn = tracing.connection(); info!(target:"messages", test_field = "test_value", "Test msg"); info!(target:"log_events", test_field = "test_value", "Test event"); @@ -362,6 +449,87 @@ mod tests { .query_row("SELECT 
COUNT(*) FROM messages", [], |row| row.get(0))?; print_table(&conn.lock().unwrap(), TableName::LogEvents)?; assert!(count > 0); + + // Verify we have a file path + assert!(tracing.db_path().is_some()); + let db_path = tracing.db_path().unwrap(); + assert!(db_path.exists()); + + Ok(()) + } + + #[test] + fn test_sqlite_tracing_in_memory() -> Result<()> { + let tracing = SqliteTracing::new_in_memory()?; + let conn = tracing.connection(); + + info!(target:"messages", test_field = "test_value", "Test event in memory"); + + let count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?; + print_table(&conn.lock().unwrap(), TableName::Messages)?; + assert!(count > 0); + + // Verify we don't have a file path for in-memory + assert!(tracing.db_path().is_none()); + + Ok(()) + } + + #[test] + fn test_sqlite_tracing_cleanup() -> Result<()> { + let db_path = { + let tracing = SqliteTracing::new()?; + let conn = tracing.connection(); + + info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event"); + + let count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM log_events", [], |row| row.get(0))?; + assert!(count > 0); + + tracing.db_path().unwrap().clone() + }; // tracing goes out of scope here, triggering Drop + + // File should be cleaned up after Drop + assert!(!db_path.exists()); + + Ok(()) + } + + #[test] + fn test_sqlite_tracing_different_targets() -> Result<()> { + let tracing = SqliteTracing::new_in_memory()?; + let conn = tracing.connection(); + + // Test different event targets + info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event"); + info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event"); + info!(target:"log_events", test_field = "general_event", "General event"); + + // Check that events went to the right tables + let message_count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT 
COUNT(*) FROM messages", [], |row| row.get(0))?; + assert_eq!(message_count, 1); + + let lifecycle_count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM actor_lifecycle", [], |row| row.get(0))?; + assert_eq!(lifecycle_count, 1); + + let events_count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM log_events", [], |row| row.get(0))?; + assert_eq!(events_count, 1); + Ok(()) } } diff --git a/monarch_hyperactor/src/telemetry.rs b/monarch_hyperactor/src/telemetry.rs index 3979e439e..1ccf8d6b6 100644 --- a/monarch_hyperactor/src/telemetry.rs +++ b/monarch_hyperactor/src/telemetry.rs @@ -13,6 +13,7 @@ use std::cell::Cell; use hyperactor::clock::ClockKind; use hyperactor::clock::RealClock; use hyperactor::clock::SimClock; +use hyperactor_telemetry::sqlite::SqliteTracing; use hyperactor_telemetry::swap_telemetry_clock; use opentelemetry::global; use opentelemetry::metrics; @@ -65,7 +66,6 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> { let file = record.getattr(py, "filename")?; let file: &str = file.extract(py)?; let level: i32 = record.getattr(py, "levelno")?.extract(py)?; - // Map level number to level name match level { 40 | 50 => { @@ -82,6 +82,7 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> { match traceback { Some(traceback) => { tracing::error!( + target:"log_events", file = file, lineno = lineno, stacktrace = traceback, @@ -93,10 +94,10 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> { } } } - 30 => tracing::warn!(file = file, lineno = lineno, message), - 20 => tracing::info!(file = file, lineno = lineno, message), - 10 => tracing::debug!(file = file, lineno = lineno, message), - _ => tracing::info!(file = file, lineno = lineno, message), + 30 => tracing::warn!(target:"log_events", file = file, lineno = lineno, message), + 20 => tracing::info!(target:"log_events", file = file, lineno = lineno, message), + 10 => 
tracing::debug!(target:"log_events", file = file, lineno = lineno, message), + _ => tracing::info!(target:"log_events", file = file, lineno = lineno, message), } Ok(()) } @@ -215,6 +216,62 @@ impl PySpan { } } +#[pyclass( + subclass, + module = "monarch._rust_bindings.monarch_hyperactor.telemetry" +)] +struct PySqliteTracing { + guard: Option, +} + +#[pymethods] +impl PySqliteTracing { + #[new] + #[pyo3(signature = (in_memory = false))] + fn new(in_memory: bool) -> PyResult { + let guard = if in_memory { + SqliteTracing::new_in_memory() + } else { + SqliteTracing::new() + }; + + match guard { + Ok(guard) => Ok(Self { guard: Some(guard) }), + Err(e) => Err(PyErr::new::(format!( + "Failed to create SQLite tracing guard: {}", + e + ))), + } + } + + fn db_path(&self) -> PyResult> { + match &self.guard { + Some(guard) => Ok(guard.db_path().map(|p| p.to_string_lossy().to_string())), + None => Err(PyErr::new::( + "Guard has been closed", + )), + } + } + + fn __enter__(slf: PyRefMut<'_, Self>) -> PyResult> { + Ok(slf) + } + + fn __exit__( + &mut self, + _exc_type: Option, + _exc_value: Option, + _traceback: Option, + ) -> PyResult { + self.guard = None; + Ok(false) // Don't suppress exceptions + } + + fn close(&mut self) { + self.guard = None; + } +} + use pyo3::Bound; use pyo3::types::PyModule; @@ -267,5 +324,6 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add_class::()?; module.add_class::()?; module.add_class::()?; + module.add_class::()?; Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi index 59030beab..878258df3 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi @@ -74,7 +74,6 @@ def get_current_span_id() -> int: def use_real_clock() -> None: """ Convenience function to switch to real-time clock. 
- This switches the telemetry system to use real system time. """ ... @@ -160,3 +159,59 @@ class PyUpDownCounter: - value (int): The value to add to the counter (can be positive or negative). """ ... + +class PySqliteTracing: + def __init__(self, in_memory: bool = False) -> None: + """ + Create a new PySqliteTracing. + + This creates an RAII guard that sets up SQLite tracing collection. + When used as a context manager, it will automatically clean up when exiting. + + Args: + - in_memory (bool): If True, uses an in-memory database. If False, creates a temporary file. + """ + ... + + def db_path(self) -> str | None: + """ + Get the path to the database file. + + Returns: + - str | None: The path to the database file, or None if using in-memory database. + + Raises: + - RuntimeError: If the guard has been closed. + """ + ... + + def close(self) -> None: + """ + Manually close the guard and clean up resources. + + After calling this method, the guard cannot be used anymore. + """ + ... + + def __enter__(self) -> "PySqliteTracing": + """ + Enter the context manager. + + Returns: + - PySqliteTracing: Self for use in the with statement. + """ + ... + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool: + """ + Exit the context manager and clean up resources. + + Args: + - exc_type: Exception type (if any) + - exc_value: Exception value (if any) + - traceback: Exception traceback (if any) + + Returns: + - bool: False (does not suppress exceptions) + """ + ...