diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 07ed7d624..b1de225cb 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -1043,7 +1043,7 @@ impl MailboxSender for MailboxClient { return_handle: PortHandle>, ) { // tracing::trace!(name = "post", "posting message to {}", envelope.dest); - tracing::event!(target:"message", tracing::Level::DEBUG, "crc"=envelope.data.crc(), "size"=envelope.data.len(), "sender"= %envelope.sender, "dest" = %envelope.dest.0, "port"= envelope.dest.1, "message_type" = envelope.data.typename().unwrap_or("unknown"), "send_message"); + tracing::event!(target:"messages", tracing::Level::DEBUG, "crc"=envelope.data.crc(), "size"=envelope.data.len(), "sender"= %envelope.sender, "dest" = %envelope.dest.0, "port"= envelope.dest.1, "message_type" = envelope.data.typename().unwrap_or("unknown"), "send_message"); if let Err(mpsc::error::SendError((envelope, return_handle))) = self.buffer.send((envelope, return_handle)) diff --git a/hyperactor_telemetry/Cargo.toml b/hyperactor_telemetry/Cargo.toml index ce7199a4d..ccb08276c 100644 --- a/hyperactor_telemetry/Cargo.toml +++ b/hyperactor_telemetry/Cargo.toml @@ -20,9 +20,11 @@ lazy_static = "1.5" opentelemetry = "0.29" opentelemetry_sdk = { version = "0.29.0", features = ["rt-tokio"] } rand = { version = "0.8", features = ["small_rng"] } +rusqlite = { version = "0.36.0", features = ["backup", "blob", "bundled", "column_decltype", "functions", "limits", "modern_sqlite", "serde_json"] } scuba = { version = "0.1.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main", optional = true } serde = { version = "1.0.219", features = ["derive", "rc"] } serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "raw_value", "unbounded_depth"] } +serde_rusqlite = "0.39.3" tokio = { version = "1.46.1", features = ["full", "test-util", "tracing"] } tracing = { version = "0.1.41", features = ["attributes", "valuable"] } tracing-appender = "0.2.3" diff --git a/hyperactor_telemetry/src/lib.rs b/hyperactor_telemetry/src/lib.rs index 415ff17e2..cd8415434 100644 --- a/hyperactor_telemetry/src/lib.rs +++ b/hyperactor_telemetry/src/lib.rs @@ -33,6 +33,7 @@ mod otel; mod pool; pub mod recorder; mod spool; +pub mod sqlite; use std::io::IsTerminal; use std::io::Write; use std::str::FromStr; @@ -63,6 +64,7 @@ use tracing_subscriber::fmt::format::Writer; use tracing_subscriber::registry::LookupSpan; use crate::recorder::Recorder; +use crate::sqlite::get_reloadable_sqlite_layer; pub trait TelemetryClock { fn now(&self) -> tokio::time::Instant; @@ -562,6 +564,8 @@ pub fn initialize_logging_with_log_prefix( .with_target("opentelemetry", LevelFilter::OFF), // otel has some log span under debug that we don't care about ); + let sqlite_layer = get_reloadable_sqlite_layer().unwrap(); + use tracing_subscriber::Registry; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -573,6 +577,7 @@ pub fn initialize_logging_with_log_prefix( std::env::var(env_var).unwrap_or_default() != "1" } if let Err(err) = Registry::default() + .with(sqlite_layer) .with(if is_layer_enabled(DISABLE_OTEL_TRACING) { Some(otel::tracing_layer()) } else { diff --git a/hyperactor_telemetry/src/sqlite.rs b/hyperactor_telemetry/src/sqlite.rs new file mode 100644 index 000000000..d46904ff7 --- /dev/null +++ b/hyperactor_telemetry/src/sqlite.rs @@ -0,0 +1,535 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex; + +use anyhow::Result; +use anyhow::anyhow; +use lazy_static::lazy_static; +use rusqlite::Connection; +use rusqlite::functions::FunctionFlags; +use serde::Serialize; +use serde_json::Value as JValue; +use serde_rusqlite::*; +use tracing::Event; +use tracing::Subscriber; +use tracing_subscriber::Layer; +use tracing_subscriber::Registry; +use tracing_subscriber::prelude::*; +use tracing_subscriber::reload; + +pub type SqliteReloadHandle = reload::Handle, Registry>; + +lazy_static! { + // Reload handle allows us to include a no-op layer during init, but load + // the layer dynamically during tests. + static ref RELOAD_HANDLE: Mutex> = + Mutex::new(None); +} +pub trait TableDef { + fn name(&self) -> &'static str; + fn columns(&self) -> &'static [&'static str]; + fn create_table_stmt(&self) -> String { + let name = self.name(); + let columns = self + .columns() + .iter() + .map(|col| format!("{col} TEXT ")) + .collect::>() + .join(","); + format!("create table if not exists {name} (seq INTEGER primary key, {columns})") + } + fn insert_stmt(&self) -> String { + let name = self.name(); + let columns = self.columns().join(", "); + let params = self + .columns() + .iter() + .map(|c| format!(":{c}")) + .collect::>() + .join(", "); + format!("insert into {name} ({columns}) values ({params})") + } +} + +impl TableDef for (&'static str, &'static [&'static str]) { + fn name(&self) -> &'static str { + self.0 + } + + fn columns(&self) -> &'static [&'static str] { + self.1 + } +} + +#[derive(Clone, Debug)] +pub struct Table { + pub columns: &'static [&'static str], + pub create_table_stmt: String, + pub insert_stmt: String, +} + +impl From<(&'static str, &'static [&'static str])> for Table { + fn from(value: (&'static str, &'static [&'static str])) -> Self { + Self { + columns: value.columns(), + create_table_stmt: value.create_table_stmt(), + insert_stmt: value.insert_stmt(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TableName { + ActorLifecycle, + Messages, + LogEvents, +} + +impl TableName { + pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle"; + pub const MESSAGES_STR: &'static str = "messages"; + pub const LOG_EVENTS_STR: &'static str = "log_events"; + + pub fn as_str(&self) -> &'static str { + match self { + TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR, + TableName::Messages => Self::MESSAGES_STR, + TableName::LogEvents => Self::LOG_EVENTS_STR, + } + } + + pub fn get_table(&self) -> &'static Table { + match self { + TableName::ActorLifecycle => &ACTOR_LIFECYCLE, + TableName::Messages => &MESSAGES, + TableName::LogEvents => &LOG_EVENTS, + } + } +} + +lazy_static! { + static ref ACTOR_LIFECYCLE: Table = ( + TableName::ActorLifecycle.as_str(), + [ + "actor_id", + "actor", + "name", + "supervised_actor", + "actor_status", + "module_path", + "line", + "file", + ] + .as_slice() + ) + .into(); + static ref MESSAGES: Table = ( + TableName::Messages.as_str(), + [ + "span_id", + "time_us", + "src", + "dest", + "payload", + "module_path", + "line", + "file", + ] + .as_slice() + ) + .into(); + static ref LOG_EVENTS: Table = ( + TableName::LogEvents.as_str(), + [ + "span_id", + "time_us", + "name", + "message", + "actor_id", + "level", + "line", + "file", + "module_path", + ] + .as_slice() + ) + .into(); + static ref ALL_TABLES: Vec = vec![ + ACTOR_LIFECYCLE.clone(), + MESSAGES.clone(), + LOG_EVENTS.clone() + ]; +} + +pub struct SqliteLayer { + conn: Arc>, +} +use tracing::field::Visit; + +#[derive(Debug, Clone, Default, Serialize)] +struct SqlVisitor(HashMap); + +impl Visit for SqlVisitor { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + self.0.insert( + field.name().to_string(), + JValue::String(format!("{:?}", value)), + ); + } + + fn record_str(&mut self, field: &tracing::field::Field, value: &str) { + self.0 + .insert(field.name().to_string(), JValue::String(value.to_string())); + } + + fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { + self.0 + .insert(field.name().to_string(), JValue::Number(value.into())); + } + + fn record_f64(&mut self, field: &tracing::field::Field, value: f64) { + let n = serde_json::Number::from_f64(value).unwrap(); + self.0.insert(field.name().to_string(), JValue::Number(n)); + } + + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + self.0 + .insert(field.name().to_string(), JValue::Number(value.into())); + } + + fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { + self.0.insert(field.name().to_string(), JValue::Bool(value)); + } +} + +macro_rules! insert_event { + ($table:expr, $conn:ident, $event:ident) => { + let mut v: SqlVisitor = Default::default(); + $event.record(&mut v); + let meta = $event.metadata(); + v.0.insert( + "module_path".to_string(), + meta.module_path().map(String::from).into(), + ); + v.0.insert("line".to_string(), meta.line().into()); + v.0.insert("file".to_string(), meta.file().map(String::from).into()); + $conn.prepare_cached(&$table.insert_stmt)?.execute( + serde_rusqlite::to_params_named_with_fields(v, $table.columns)? + .to_slice() + .as_slice(), + )?; + }; +} + +impl SqliteLayer { + pub fn new() -> Result { + let conn = Connection::open_in_memory()?; + Self::setup_connection(conn) + } + + pub fn new_with_file(db_path: &str) -> Result { + let conn = Connection::open(db_path)?; + Self::setup_connection(conn) + } + + fn setup_connection(conn: Connection) -> Result { + for table in ALL_TABLES.iter() { + conn.execute(&table.create_table_stmt, [])?; + } + conn.create_scalar_function( + "assert", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + move |ctx| { + let condition: bool = ctx.get(0)?; + let message: String = ctx.get(1)?; + + if !condition { + return Err(rusqlite::Error::UserFunctionError( + anyhow!("assertion failed:{condition} {message}",).into(), + )); + } + + Ok(condition) + }, + )?; + + Ok(Self { + conn: Arc::new(Mutex::new(conn)), + }) + } + + fn insert_event(&self, event: &Event<'_>) -> Result<()> { + let conn = self.conn.lock().unwrap(); + match (event.metadata().target(), event.metadata().name()) { + (TableName::MESSAGES_STR, _) => { + insert_event!(TableName::Messages.get_table(), conn, event); + } + (TableName::ACTOR_LIFECYCLE_STR, _) => { + insert_event!(TableName::ActorLifecycle.get_table(), conn, event); + } + _ => { + insert_event!(TableName::LogEvents.get_table(), conn, event); + } + } + Ok(()) + } + + pub fn connection(&self) -> Arc> { + self.conn.clone() + } +} + +impl Layer for SqliteLayer { + fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) { + self.insert_event(event).unwrap(); + } +} + +#[allow(dead_code)] +fn print_table(conn: &Connection, table_name: TableName) -> Result<()> { + let table_name_str = table_name.as_str(); + + // Get column names + let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?; + let column_info = stmt.query_map([], |row| { + row.get::<_, String>(1) // Column name is at index 1 + })?; + + let columns: Vec = column_info.collect::, _>>()?; + + // Print header + println!("=== {} ===", table_name_str.to_uppercase()); + println!("{}", columns.join(" | ")); + println!("{}", "-".repeat(columns.len() * 10)); + + // Print rows + let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?; + let rows = stmt.query_map([], |row| { + let mut values = Vec::new(); + for (i, column) in columns.iter().enumerate() { + // Handle different column types properly + let value = if i == 0 && *column == "seq" { + // First column is always the INTEGER seq column + match row.get::<_, Option>(i)? { + Some(v) => v.to_string(), + None => "NULL".to_string(), + } + } else { + // All other columns are TEXT + match row.get::<_, Option>(i)? { + Some(v) => v, + None => "NULL".to_string(), + } + }; + values.push(value); + } + Ok(values.join(" | ")) + })?; + + for row in rows { + println!("{}", row?); + } + println!(); + Ok(()) +} + +fn init_tracing_subscriber(layer: SqliteLayer) { + let handle = RELOAD_HANDLE.lock().unwrap(); + if let Some(reload_handle) = handle.as_ref() { + let _ = reload_handle.reload(layer); + } else { + tracing_subscriber::registry().with(layer).init(); + } +} + +// === API === + +// Creates a new reload handler and no-op layer for initialization +pub fn get_reloadable_sqlite_layer() -> Result, Registry>> { + let (layer, reload_handle) = reload::Layer::new(None); + let mut handle = RELOAD_HANDLE.lock().unwrap(); + *handle = Some(reload_handle); + Ok(layer) +} + +/// RAII guard for SQLite tracing database +pub struct SqliteTracing { + db_path: Option, + connection: Arc>, +} + +impl SqliteTracing { + /// Create a new SqliteTracing with a temporary file + pub fn new() -> Result { + let temp_dir = std::env::temp_dir(); + let file_name = format!("hyperactor_trace_{}.db", std::process::id()); + let db_path = temp_dir.join(file_name); + + let db_path_str = db_path.to_string_lossy(); + let layer = SqliteLayer::new_with_file(&db_path_str)?; + let connection = layer.connection(); + + init_tracing_subscriber(layer); + + Ok(Self { + db_path: Some(db_path), + connection, + }) + } + + /// Create a new SqliteTracing with in-memory database + pub fn new_in_memory() -> Result { + let layer = SqliteLayer::new()?; + let connection = layer.connection(); + + init_tracing_subscriber(layer); + + Ok(Self { + db_path: None, + connection, + }) + } + + /// Get the path to the temporary database file (None for in-memory) + pub fn db_path(&self) -> Option<&PathBuf> { + self.db_path.as_ref() + } + + /// Get a reference to the database connection + pub fn connection(&self) -> Arc> { + self.connection.clone() + } +} + +impl Drop for SqliteTracing { + fn drop(&mut self) { + // Reset the layer to None + let handle = RELOAD_HANDLE.lock().unwrap(); + if let Some(reload_handle) = handle.as_ref() { + let _ = reload_handle.reload(None); + } + + // Delete the temporary file if it exists + if let Some(db_path) = &self.db_path { + if db_path.exists() { + let _ = fs::remove_file(db_path); + } + } + } +} + +#[cfg(test)] +mod tests { + use tracing::info; + + use super::*; + + #[test] + fn test_sqlite_tracing_with_file() -> Result<()> { + let tracing = SqliteTracing::new()?; + let conn = tracing.connection(); + + info!(target:"messages", test_field = "test_value", "Test msg"); + info!(target:"log_events", test_field = "test_value", "Test event"); + + let count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?; + print_table(&conn.lock().unwrap(), TableName::LogEvents)?; + assert!(count > 0); + + // Verify we have a file path + assert!(tracing.db_path().is_some()); + let db_path = tracing.db_path().unwrap(); + assert!(db_path.exists()); + + Ok(()) + } + + #[test] + fn test_sqlite_tracing_in_memory() -> Result<()> { + let tracing = SqliteTracing::new_in_memory()?; + let conn = tracing.connection(); + + info!(target:"messages", test_field = "test_value", "Test event in memory"); + + let count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?; + print_table(&conn.lock().unwrap(), TableName::Messages)?; + assert!(count > 0); + + // Verify we don't have a file path for in-memory + assert!(tracing.db_path().is_none()); + + Ok(()) + } + + #[test] + fn test_sqlite_tracing_cleanup() -> Result<()> { + let db_path = { + let tracing = SqliteTracing::new()?; + let conn = tracing.connection(); + + info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event"); + + let count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM log_events", [], |row| row.get(0))?; + assert!(count > 0); + + tracing.db_path().unwrap().clone() + }; // tracing goes out of scope here, triggering Drop + + // File should be cleaned up after Drop + assert!(!db_path.exists()); + + Ok(()) + } + + #[test] + fn test_sqlite_tracing_different_targets() -> Result<()> { + let tracing = SqliteTracing::new_in_memory()?; + let conn = tracing.connection(); + + // Test different event targets + info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event"); + info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event"); + info!(target:"log_events", test_field = "general_event", "General event"); + + // Check that events went to the right tables + let message_count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?; + assert_eq!(message_count, 1); + + let lifecycle_count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM actor_lifecycle", [], |row| row.get(0))?; + assert_eq!(lifecycle_count, 1); + + let events_count: i64 = + conn.lock() + .unwrap() + .query_row("SELECT COUNT(*) FROM log_events", [], |row| row.get(0))?; + assert_eq!(events_count, 1); + + Ok(()) + } +} diff --git a/monarch_hyperactor/src/telemetry.rs b/monarch_hyperactor/src/telemetry.rs index 3979e439e..1ccf8d6b6 100644 --- a/monarch_hyperactor/src/telemetry.rs +++ b/monarch_hyperactor/src/telemetry.rs @@ -13,6 +13,7 @@ use std::cell::Cell; use hyperactor::clock::ClockKind; use hyperactor::clock::RealClock; use hyperactor::clock::SimClock; +use hyperactor_telemetry::sqlite::SqliteTracing; use hyperactor_telemetry::swap_telemetry_clock; use opentelemetry::global; use opentelemetry::metrics; @@ -65,7 +66,6 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> { let file = record.getattr(py, "filename")?; let file: &str = file.extract(py)?; let level: i32 = record.getattr(py, "levelno")?.extract(py)?; - // Map level number to level name match level { 40 | 50 => { @@ -82,6 +82,7 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> { match traceback { Some(traceback) => { tracing::error!( + target:"log_events", file = file, lineno = lineno, stacktrace = traceback, @@ -93,10 +94,10 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> { } } } - 30 => tracing::warn!(file = file, lineno = lineno, message), - 20 => tracing::info!(file = file, lineno = lineno, message), - 10 => tracing::debug!(file = file, lineno = lineno, message), - _ => tracing::info!(file = file, lineno = lineno, message), + 30 => tracing::warn!(target:"log_events", file = file, lineno = lineno, message), + 20 => tracing::info!(target:"log_events", file = file, lineno = lineno, message), + 10 => tracing::debug!(target:"log_events", file = file, lineno = lineno, message), + _ => tracing::info!(target:"log_events", file = file, lineno = lineno, message), } Ok(()) } @@ -215,6 +216,62 @@ impl PySpan { } } +#[pyclass( + subclass, + module = "monarch._rust_bindings.monarch_hyperactor.telemetry" +)] +struct PySqliteTracing { + guard: Option, +} + +#[pymethods] +impl PySqliteTracing { + #[new] + #[pyo3(signature = (in_memory = false))] + fn new(in_memory: bool) -> PyResult { + let guard = if in_memory { + SqliteTracing::new_in_memory() + } else { + SqliteTracing::new() + }; + + match guard { + Ok(guard) => Ok(Self { guard: Some(guard) }), + Err(e) => Err(PyErr::new::(format!( + "Failed to create SQLite tracing guard: {}", + e + ))), + } + } + + fn db_path(&self) -> PyResult> { + match &self.guard { + Some(guard) => Ok(guard.db_path().map(|p| p.to_string_lossy().to_string())), + None => Err(PyErr::new::( + "Guard has been closed", + )), + } + } + + fn __enter__(slf: PyRefMut<'_, Self>) -> PyResult> { + Ok(slf) + } + + fn __exit__( + &mut self, + _exc_type: Option, + _exc_value: Option, + _traceback: Option, + ) -> PyResult { + self.guard = None; + Ok(false) // Don't suppress exceptions + } + + fn close(&mut self) { + self.guard = None; + } +} + use pyo3::Bound; use pyo3::types::PyModule; @@ -267,5 +324,6 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add_class::()?; module.add_class::()?; module.add_class::()?; + module.add_class::()?; Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi index 59030beab..878258df3 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi @@ -74,7 +74,6 @@ def get_current_span_id() -> int: def use_real_clock() -> None: """ Convenience function to switch to real-time clock. - This switches the telemetry system to use real system time. """ ... @@ -160,3 +159,59 @@ class PyUpDownCounter: - value (int): The value to add to the counter (can be positive or negative). """ ... + +class PySqliteTracing: + def __init__(self, in_memory: bool = False) -> None: + """ + Create a new PySqliteTracing. + + This creates an RAII guard that sets up SQLite tracing collection. + When used as a context manager, it will automatically clean up when exiting. + + Args: + - in_memory (bool): If True, uses an in-memory database. If False, creates a temporary file. + """ + ... + + def db_path(self) -> str | None: + """ + Get the path to the database file. + + Returns: + - str | None: The path to the database file, or None if using in-memory database. + + Raises: + - RuntimeError: If the guard has been closed. + """ + ... + + def close(self) -> None: + """ + Manually close the guard and clean up resources. + + After calling this method, the guard cannot be used anymore. + """ + ... + + def __enter__(self) -> "PySqliteTracing": + """ + Enter the context manager. + + Returns: + - PySqliteTracing: Self for use in the with statement. + """ + ... + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool: + """ + Exit the context manager and clean up resources. + + Args: + - exc_type: Exception type (if any) + - exc_value: Exception value (if any) + - traceback: Exception traceback (if any) + + Returns: + - bool: False (does not suppress exceptions) + """ + ...