diff --git a/Cargo.lock b/Cargo.lock index dac2c8de..8c7d6942 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1485,9 +1485,11 @@ dependencies = [ "base64 0.22.1", "bimap", "blst", + "bytes", "cipher 0.4.4", "ctr 0.9.2", "derive_more 2.0.1", + "docker-image", "eth2_keystore", "ethereum_serde_utils", "ethereum_ssz 0.8.3", @@ -1497,9 +1499,11 @@ dependencies = [ "pbkdf2 0.12.2", "rand 0.9.0", "reqwest", + "scopeguard", "serde", "serde_json", "serde_yaml", + "serial_test", "sha2 0.10.8", "ssz_types", "thiserror 2.0.12", @@ -1589,9 +1593,11 @@ dependencies = [ "axum 0.8.1", "cb-common", "cb-pbs", + "cb-signer", "eyre", "reqwest", "serde_json", + "tempfile", "tokio", "tracing", "tracing-subscriber", @@ -2158,6 +2164,16 @@ dependencies = [ "serde_yaml", ] +[[package]] +name = "docker-image" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ed901b8f2157bafce6e96f39217f7b1a4af32d84266d251ed7c22ce001f0b" +dependencies = [ + "lazy_static", + "regex", +] + [[package]] name = "doctest-file" version = "1.0.0" @@ -4331,6 +4347,15 @@ dependencies = [ "cipher 0.3.0", ] +[[package]] +name = "scc" +version = "2.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22b2d775fb28f245817589471dd49c5edf64237f4a19d10ce9a92ff4651a27f4" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.27" @@ -4358,6 +4383,12 @@ dependencies = [ "sha2 0.9.9", ] +[[package]] +name = "sdd" +version = "3.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21" + [[package]] name = "sec1" version = "0.7.3" @@ -4562,6 +4593,31 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b258109f244e1d6891bf1053a55d63a5cd4f8f4c30cf9a1280989f80e7a1fa9" +dependencies = [ + "futures", + "log", + "once_cell", + 
"parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "sha1" version = "0.10.6" @@ -4863,9 +4919,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.19.0" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488960f40a3fd53d72c2a29a58722561dee8afdd175bd88e3db4677d7b2ba600" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", "getrandom 0.3.1", diff --git a/Cargo.toml b/Cargo.toml index 283caf0d..7f21eb0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ base64 = "0.22.1" bimap = { version = "0.6.3", features = ["serde"] } blsful = "2.5" blst = "0.3.11" +bytes = "1.10.1" cb-cli = { path = "crates/cli" } cb-common = { path = "crates/common" } cb-metrics = { path = "crates/metrics" } @@ -36,6 +37,7 @@ color-eyre = "0.6.3" ctr = "0.9.2" derive_more = { version = "2.0.1", features = ["deref", "display", "from", "into"] } docker-compose-types = "0.16.0" +docker-image = "0.2.1" eth2_keystore = { git = "https://github.com/sigp/lighthouse", rev = "8d058e4040b765a96aa4968f4167af7571292be2" } ethereum_serde_utils = "0.7.0" ethereum_ssz = "0.8" @@ -52,11 +54,14 @@ prometheus = "0.13.4" prost = "0.13.4" rand = { version = "0.9", features = ["os_rng"] } reqwest = { version = "0.12.4", features = ["json", "stream"] } +scopeguard = "1.2.0" serde = { version = "1.0.202", features = ["derive"] } serde_json = "1.0.117" serde_yaml = "0.9.33" +serial_test = "3.2.0" sha2 = "0.10.8" ssz_types = "0.10" +tempfile = "3.20.0" thiserror = "2.0.12" tokio = { version = "1.37.0", features = ["full"] } toml = "0.8.13" diff --git 
a/config.example.toml b/config.example.toml index d32dfbf9..f95bd255 100644 --- a/config.example.toml +++ b/config.example.toml @@ -55,6 +55,9 @@ extra_validation_enabled = false # Execution Layer RPC url to use for extra validation # OPTIONAL rpc_url = "https://ethereum-holesky-rpc.publicnode.com" +# Timeout for any HTTP requests sent from the PBS module to other services, in seconds +# OPTIONAL, DEFAULT: 10 +http_timeout_seconds = 10 # The PBS module needs one or more [[relays]] as defined below. [[relays]] diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index df78b046..abc50c73 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -13,9 +13,11 @@ axum.workspace = true base64.workspace = true bimap.workspace = true blst.workspace = true +bytes.workspace = true cipher.workspace = true ctr.workspace = true derive_more.workspace = true +docker-image.workspace = true eth2_keystore.workspace = true ethereum_serde_utils.workspace = true ethereum_ssz.workspace = true @@ -41,3 +43,5 @@ tree_hash_derive.workspace = true unicode-normalization.workspace = true url.workspace = true jsonwebtoken.workspace = true +serial_test.workspace = true +scopeguard.workspace = true diff --git a/crates/common/src/config/constants.rs b/crates/common/src/config/constants.rs index 3f93ce27..309eb15e 100644 --- a/crates/common/src/config/constants.rs +++ b/crates/common/src/config/constants.rs @@ -35,6 +35,11 @@ pub const SIGNER_MODULE_NAME: &str = "signer"; /// Where the signer module should open the server pub const SIGNER_ENDPOINT_ENV: &str = "CB_SIGNER_ENDPOINT"; +// JWT authentication settings +pub const SIGNER_JWT_AUTH_FAIL_LIMIT_ENV: &str = "CB_SIGNER_JWT_AUTH_FAIL_LIMIT"; +pub const SIGNER_JWT_AUTH_FAIL_TIMEOUT_SECONDS_ENV: &str = + "CB_SIGNER_JWT_AUTH_FAIL_TIMEOUT_SECONDS"; + /// Comma separated list module_id=jwt_secret pub const JWTS_ENV: &str = "CB_JWTS"; @@ -67,6 +72,15 @@ pub const PROXY_DIR_KEYS_DEFAULT: &str = "/proxy_keys"; pub const 
PROXY_DIR_SECRETS_ENV: &str = "CB_PROXY_SECRETS_DIR"; pub const PROXY_DIR_SECRETS_DEFAULT: &str = "/proxy_secrets"; +////////////////////////// MUXER ////////////////////////// + +/// Timeout for HTTP requests, in seconds +pub const HTTP_TIMEOUT_SECONDS_ENV: &str = "CB_HTTP_TIMEOUT_SECONDS"; +pub const HTTP_TIMEOUT_SECONDS_DEFAULT: u64 = 10; + +/// Max content length for Muxer HTTP responses, in bytes +pub const MUXER_HTTP_MAX_LENGTH: u64 = 1024 * 1024 * 10; // 10 MiB + ///////////////////////// MODULES ///////////////////////// /// The unique ID of the module @@ -81,3 +95,6 @@ pub const SIGNER_URL_ENV: &str = "CB_SIGNER_URL"; /// Events modules /// Where to receive builder events pub const BUILDER_PORT_ENV: &str = "CB_BUILDER_PORT"; + +///////////////////////// TESTING CONSTANTS ///////////////////////// +pub const CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV: &str = "CB_TEST_HTTP_DISABLE_CONTENT_LENGTH"; diff --git a/crates/common/src/config/mod.rs b/crates/common/src/config/mod.rs index 75fd3c9d..b782999b 100644 --- a/crates/common/src/config/mod.rs +++ b/crates/common/src/config/mod.rs @@ -41,6 +41,9 @@ impl CommitBoostConfig { /// Validate config pub async fn validate(&self) -> Result<()> { self.pbs.pbs_config.validate(self.chain).await?; + if let Some(signer) = &self.signer { + signer.validate().await?; + } Ok(()) } diff --git a/crates/common/src/config/mux.rs b/crates/common/src/config/mux.rs index 28a7f9e1..e8a7851b 100644 --- a/crates/common/src/config/mux.rs +++ b/crates/common/src/config/mux.rs @@ -2,6 +2,7 @@ use std::{ collections::{HashMap, HashSet}, path::{Path, PathBuf}, sync::Arc, + time::Duration, }; use alloy::{ @@ -16,7 +17,11 @@ use tracing::{debug, info}; use url::Url; use super::{load_optional_env_var, PbsConfig, RelayConfig, MUX_PATH_ENV}; -use crate::{pbs::RelayClient, types::Chain}; +use crate::{ + config::{safe_read_http_response, HTTP_TIMEOUT_SECONDS_ENV}, + pbs::RelayClient, + types::Chain, +}; #[derive(Debug, Deserialize, Serialize)]
pub struct PbsMuxes { @@ -38,13 +43,19 @@ impl PbsMuxes { chain: Chain, default_pbs: &PbsConfig, ) -> eyre::Result> { + let http_timeout = match load_optional_env_var(HTTP_TIMEOUT_SECONDS_ENV) { + Some(timeout_str) => Duration::from_secs(timeout_str.parse::()?), + None => Duration::from_secs(default_pbs.http_timeout_seconds), + }; + let mut muxes = self.muxes; for mux in muxes.iter_mut() { ensure!(!mux.relays.is_empty(), "mux config {} must have at least one relay", mux.id); if let Some(loader) = &mux.loader { - let extra_keys = loader.load(&mux.id, chain, default_pbs.rpc_url.clone()).await?; + let extra_keys = + loader.load(&mux.id, chain, default_pbs.rpc_url.clone(), http_timeout).await?; mux.validator_pubkeys.extend(extra_keys); } @@ -163,6 +174,7 @@ impl MuxKeysLoader { mux_id: &str, chain: Chain, rpc_url: Option, + http_timeout: Duration, ) -> eyre::Result> { match self { Self::File(config_path) => { @@ -175,11 +187,15 @@ impl MuxKeysLoader { } Self::HTTP { url } => { - let client = reqwest::Client::new(); + let url = Url::parse(url).wrap_err("failed to parse mux keys URL")?; + if url.scheme() != "https" { + bail!("mux keys URL must use HTTPS"); + } + let client = reqwest::ClientBuilder::new().timeout(http_timeout).build()?; let response = client.get(url).send().await?; - let pubkeys = response.text().await?; + let pubkeys = safe_read_http_response(response).await?; serde_json::from_str(&pubkeys) - .wrap_err("failed to fetch mux keys from http endpoint") + .wrap_err("failed to fetch mux keys from HTTP endpoint") } Self::Registry { registry, node_operator_id } => match registry { @@ -190,7 +206,9 @@ impl MuxKeysLoader { fetch_lido_registry_keys(rpc_url, chain, U256::from(*node_operator_id)).await } - NORegistry::SSV => fetch_ssv_pubkeys(chain, U256::from(*node_operator_id)).await, + NORegistry::SSV => { + fetch_ssv_pubkeys(chain, U256::from(*node_operator_id), http_timeout).await + } }, } } @@ -290,6 +308,7 @@ async fn fetch_lido_registry_keys( async fn 
fetch_ssv_pubkeys( chain: Chain, node_operator_id: U256, + http_timeout: Duration, ) -> eyre::Result> { const MAX_PER_PAGE: usize = 100; @@ -300,22 +319,16 @@ async fn fetch_ssv_pubkeys( _ => bail!("SSV network is not supported for chain: {chain:?}"), }; - let client = reqwest::Client::new(); let mut pubkeys: Vec = vec![]; let mut page = 1; loop { - let response = client - .get(format!( - "https://api.ssv.network/api/v4/{}/validators/in_operator/{}?perPage={}&page={}", - chain_name, node_operator_id, MAX_PER_PAGE, page - )) - .send() - .await - .map_err(|e| eyre::eyre!("Error sending request to SSV network API: {e}"))? - .json::() - .await?; + let url = format!( + "https://api.ssv.network/api/v4/{}/validators/in_operator/{}?perPage={}&page={}", + chain_name, node_operator_id, MAX_PER_PAGE, page + ); + let response = fetch_ssv_pubkeys_from_url(&url, http_timeout).await?; pubkeys.extend(response.validators.iter().map(|v| v.pubkey).collect::>()); page += 1; @@ -336,6 +349,25 @@ async fn fetch_ssv_pubkeys( Ok(pubkeys) } +async fn fetch_ssv_pubkeys_from_url( + url: &str, + http_timeout: Duration, +) -> eyre::Result { + let client = reqwest::ClientBuilder::new().timeout(http_timeout).build()?; + let response = client.get(url).send().await.map_err(|e| { + if e.is_timeout() { + eyre::eyre!("Request to SSV network API timed out: {e}") + } else { + eyre::eyre!("Error sending request to SSV network API: {e}") + } + })?; + + // Parse the response as JSON + let body_string = safe_read_http_response(response).await?; + serde_json::from_slice::(body_string.as_bytes()) + .wrap_err("failed to parse SSV response") +} + #[derive(Deserialize)] struct SSVResponse { validators: Vec, @@ -355,10 +387,22 @@ struct SSVPagination { #[cfg(test)] mod tests { - use alloy::{primitives::U256, providers::ProviderBuilder}; + use std::{env, net::SocketAddr}; + + use alloy::{hex::FromHex, primitives::U256, providers::ProviderBuilder}; + use axum::{response::Response, routing::get}; + use 
scopeguard::defer; + use serial_test::serial; + use tokio::{net::TcpListener, task::JoinHandle}; use url::Url; use super::*; + use crate::config::{ + CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV, HTTP_TIMEOUT_SECONDS_DEFAULT, + MUXER_HTTP_MAX_LENGTH, + }; + + const TEST_HTTP_TIMEOUT: u64 = 2; #[tokio::test] async fn test_lido_registry_address() -> eyre::Result<()> { @@ -393,14 +437,174 @@ mod tests { } #[tokio::test] + /// Tests that a successful SSV network fetch is handled and parsed properly async fn test_ssv_network_fetch() -> eyre::Result<()> { - let chain = Chain::Holesky; - let node_operator_id = U256::from(200); + // Start the mock server + let port = 30100; + let _server_handle = create_mock_server(port).await?; + let url = format!("http://localhost:{port}/ssv"); + let response = + fetch_ssv_pubkeys_from_url(&url, Duration::from_secs(HTTP_TIMEOUT_SECONDS_DEFAULT)) + .await?; + + // Make sure the response is correct + // NOTE: requires that ssv_valid.json doesn't change + assert_eq!(response.validators.len(), 3); + let expected_pubkeys = [ + BlsPublicKey::from_hex( + "0x967ba17a3e7f82a25aa5350ec34d6923e28ad8237b5a41efe2c5e325240d74d87a015bf04634f21900963539c8229b2a", + )?, + BlsPublicKey::from_hex( + "0xac769e8cec802e8ffee34de3253be8f438a0c17ee84bdff0b6730280d24b5ecb77ebc9c985281b41ee3bda8663b6658c", + )?, + BlsPublicKey::from_hex( + "0x8c866a5a05f3d45c49b457e29365259021a509c5daa82e124f9701a960ee87b8902e87175315ab638a3d8b1115b23639", + )?, + ]; + for (i, validator) in response.validators.iter().enumerate() { + assert_eq!(validator.pubkey, expected_pubkeys[i]); + } + + // Clean up the server handle + _server_handle.abort(); + + Ok(()) + } + + #[tokio::test] + #[serial] + /// Tests that the SSV network fetch is handled properly when the response's + /// body is too large + async fn test_ssv_network_fetch_big_data() -> eyre::Result<()> { + // Start the mock server + let port = 30101; + env::remove_var(CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV); + let _server_handle =
create_mock_server(port).await?; + let url = format!("http://localhost:{port}/big_data"); + let response = fetch_ssv_pubkeys_from_url(&url, Duration::from_secs(120)).await; + + // The response should fail due to content length being too big + assert!(response.is_err(), "Expected error due to big content length, but got success"); + if let Err(e) = response { + assert!( + e.to_string().contains("content length") && + e.to_string().contains("exceeds the maximum allowed length"), + "Expected content length error, got: {e}", + ); + } + + // Clean up the server handle + _server_handle.abort(); - let pubkeys = fetch_ssv_pubkeys(chain, node_operator_id).await?; + Ok(()) + } + + #[tokio::test] + /// Tests that the SSV network fetch is handled properly when the request + /// times out + async fn test_ssv_network_fetch_timeout() -> eyre::Result<()> { + // Start the mock server + let port = 30102; + let _server_handle = create_mock_server(port).await?; + let url = format!("http://localhost:{port}/timeout"); + let response = + fetch_ssv_pubkeys_from_url(&url, Duration::from_secs(TEST_HTTP_TIMEOUT)).await; + + // The response should fail due to timeout + assert!(response.is_err(), "Expected timeout error, but got success"); + if let Err(e) = response { + assert!(e.to_string().contains("timed out"), "Expected timeout error, got: {}", e); + } + + // Clean up the server handle + _server_handle.abort(); + + Ok(()) + } + + #[tokio::test] + #[serial] + /// Tests that the SSV network fetch is handled properly when the response's + /// content-length header is missing + async fn test_ssv_network_fetch_big_data_without_content_length() -> eyre::Result<()> { + // Start the mock server + let port = 30103; + env::set_var(CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV, "1"); + defer! 
{ env::remove_var(CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV); } + let _server_handle = create_mock_server(port).await?; + let url = format!("http://localhost:{port}/big_data"); + let response = fetch_ssv_pubkeys_from_url(&url, Duration::from_secs(120)).await; + + // The response should fail due to body size + assert!(response.is_err(), "Expected error due to body size, but got success"); + if let Err(e) = response { + assert!( + e.to_string().contains("Response body exceeds the maximum allowed length "), + "Expected content length error, got: {e}", + ); + } - // Clean up the server handle + _server_handle.abort(); Ok(()) } + + /// Creates a simple mock server to simulate the SSV API endpoint under + /// various conditions for testing + async fn create_mock_server(port: u16) -> Result<JoinHandle<()>, axum::Error> { + let router = axum::Router::new() + .route("/ssv", get(handle_ssv)) + .route("/big_data", get(handle_big_data)) + .route("/timeout", get(handle_timeout)) + .into_make_service(); + + let address = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(address).await.map_err(axum::Error::new)?; + let server = axum::serve(listener, router).with_graceful_shutdown(async { + tokio::signal::ctrl_c().await.expect("Failed to listen for shutdown signal"); + }); + let result = Ok(tokio::spawn(async move { + if let Err(e) = server.await { + eprintln!("Server error: {}", e); + } + })); + info!("Mock server started on http://localhost:{port}/"); + result + } + + /// Sends the good SSV JSON data to the client + async fn handle_ssv() -> Response { + // Read the JSON data + let data = include_str!("../../../../tests/data/ssv_valid.json"); + + // Create a valid response + Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(data.into()) + .unwrap() + } + + /// Sends a response with a large body but no content length + async fn handle_big_data() -> Response { + // Create a response with a large body but
no content length + let body = "f".repeat(2 * MUXER_HTTP_MAX_LENGTH as usize); + Response::builder() + .status(200) + .header("Content-Type", "application/text") + .body(body.into()) + .unwrap() + } + + /// Simulates a timeout by sleeping for a long time + async fn handle_timeout() -> Response { + // Sleep for a long time to simulate a timeout + tokio::time::sleep(std::time::Duration::from_secs(2 * TEST_HTTP_TIMEOUT)).await; + Response::builder() + .status(200) + .header("Content-Type", "application/text") + .body("Timeout response".into()) + .unwrap() + } } diff --git a/crates/common/src/config/pbs.rs b/crates/common/src/config/pbs.rs index 6c993716..363e9d99 100644 --- a/crates/common/src/config/pbs.rs +++ b/crates/common/src/config/pbs.rs @@ -17,7 +17,7 @@ use url::Url; use super::{ constants::PBS_IMAGE_DEFAULT, load_optional_env_var, CommitBoostConfig, RuntimeMuxConfig, - PBS_ENDPOINT_ENV, + HTTP_TIMEOUT_SECONDS_DEFAULT, PBS_ENDPOINT_ENV, }; use crate::{ commit::client::SignerClient, @@ -122,6 +122,9 @@ pub struct PbsConfig { pub extra_validation_enabled: bool, /// Execution Layer RPC url to use for extra validation pub rpc_url: Option, + /// Timeout for HTTP requests in seconds + #[serde(default = "default_u64::")] + pub http_timeout_seconds: u64, } impl PbsConfig { diff --git a/crates/common/src/config/signer.rs b/crates/common/src/config/signer.rs index e5ed6c22..7e5fbd58 100644 --- a/crates/common/src/config/signer.rs +++ b/crates/common/src/config/signer.rs @@ -4,20 +4,25 @@ use std::{ path::PathBuf, }; -use eyre::{bail, OptionExt, Result}; +use docker_image::DockerImage; +use eyre::{bail, ensure, OptionExt, Result}; use serde::{Deserialize, Serialize}; use tonic::transport::{Certificate, Identity}; use url::Url; use super::{ load_jwt_secrets, load_optional_env_var, utils::load_env_var, CommitBoostConfig, - SIGNER_ENDPOINT_ENV, SIGNER_IMAGE_DEFAULT, + SIGNER_ENDPOINT_ENV, SIGNER_IMAGE_DEFAULT, SIGNER_JWT_AUTH_FAIL_LIMIT_ENV, + 
SIGNER_JWT_AUTH_FAIL_TIMEOUT_SECONDS_ENV, }; use crate::{ config::{DIRK_CA_CERT_ENV, DIRK_CERT_ENV, DIRK_DIR_SECRETS_ENV, DIRK_KEY_ENV}, - signer::{ProxyStore, SignerLoader, DEFAULT_SIGNER_PORT}, + signer::{ + ProxyStore, SignerLoader, DEFAULT_JWT_AUTH_FAIL_LIMIT, + DEFAULT_JWT_AUTH_FAIL_TIMEOUT_SECONDS, DEFAULT_SIGNER_PORT, + }, types::{Chain, ModuleId}, - utils::{default_host, default_u16}, + utils::{default_host, default_u16, default_u32}, }; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -32,11 +37,39 @@ pub struct SignerConfig { /// Docker image of the module #[serde(default = "default_signer")] pub docker_image: String, + + /// Number of JWT auth failures before rate limiting an endpoint + /// If set to 0, no rate limiting will be applied + #[serde(default = "default_u32::")] + pub jwt_auth_fail_limit: u32, + + /// Duration in seconds to rate limit an endpoint after the JWT auth failure + /// limit has been reached + #[serde(default = "default_u32::")] + pub jwt_auth_fail_timeout_seconds: u32, + /// Inner type-specific configuration #[serde(flatten)] pub inner: SignerType, } +impl SignerConfig { + /// Validate the signer config + pub async fn validate(&self) -> Result<()> { + // Port must be positive + ensure!(self.port > 0, "Port must be positive"); + + // The Docker tag must parse + ensure!(!self.docker_image.is_empty(), "Docker image is empty"); + ensure!( + DockerImage::parse(&self.docker_image).is_ok(), + format!("Invalid Docker image: {}", self.docker_image) + ); + + Ok(()) + } +} + fn default_signer() -> String { SIGNER_IMAGE_DEFAULT.to_string() } @@ -100,6 +133,8 @@ pub struct StartSignerConfig { pub store: Option, pub endpoint: SocketAddr, pub jwts: HashMap, + pub jwt_auth_fail_limit: u32, + pub jwt_auth_fail_timeout_seconds: u32, pub dirk: Option, } @@ -119,12 +154,31 @@ impl StartSignerConfig { SocketAddr::from((signer_config.host, signer_config.port)) }; + // Load the JWT auth fail limit the same way + let jwt_auth_fail_limit = + if let 
Some(limit) = load_optional_env_var(SIGNER_JWT_AUTH_FAIL_LIMIT_ENV) { + limit.parse()? + } else { + signer_config.jwt_auth_fail_limit + }; + + // Load the JWT auth fail timeout the same way + let jwt_auth_fail_timeout_seconds = if let Some(timeout) = + load_optional_env_var(SIGNER_JWT_AUTH_FAIL_TIMEOUT_SECONDS_ENV) + { + timeout.parse()? + } else { + signer_config.jwt_auth_fail_timeout_seconds + }; + match signer_config.inner { SignerType::Local { loader, store, .. } => Ok(StartSignerConfig { chain: config.chain, loader: Some(loader), endpoint, jwts, + jwt_auth_fail_limit, + jwt_auth_fail_timeout_seconds, store, dirk: None, }), @@ -153,6 +207,8 @@ impl StartSignerConfig { chain: config.chain, endpoint, jwts, + jwt_auth_fail_limit, + jwt_auth_fail_timeout_seconds, loader: None, store, dirk: Some(DirkConfig { diff --git a/crates/common/src/config/utils.rs b/crates/common/src/config/utils.rs index 67c367c5..dd391e17 100644 --- a/crates/common/src/config/utils.rs +++ b/crates/common/src/config/utils.rs @@ -1,10 +1,14 @@ -use std::{collections::HashMap, path::Path}; +use std::{collections::HashMap, env, path::Path}; +use bytes::{BufMut, BytesMut}; use eyre::{bail, Context, Result}; use serde::de::DeserializeOwned; use super::JWTS_ENV; -use crate::types::ModuleId; +use crate::{ + config::{CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV, MUXER_HTTP_MAX_LENGTH}, + types::ModuleId, +}; pub fn load_env_var(env: &str) -> Result { std::env::var(env).wrap_err(format!("{env} is not set")) @@ -30,6 +34,46 @@ pub fn load_jwt_secrets() -> Result> { decode_string_to_map(&jwt_secrets) } +/// Reads an HTTP response safely, erroring out if it failed or if the body is +/// too large. 
+pub async fn safe_read_http_response(mut response: reqwest::Response) -> Result { + // Get the content length from the response headers + let mut content_length = response.content_length(); + if env::var(CB_TEST_HTTP_DISABLE_CONTENT_LENGTH_ENV).is_ok() { + content_length = None; + } + + // Break if content length is provided but it's too big + if let Some(length) = content_length { + if length > MUXER_HTTP_MAX_LENGTH { + bail!("Response content length ({length}) exceeds the maximum allowed length ({MUXER_HTTP_MAX_LENGTH} bytes)"); + } + } + + // Make sure the response is a 200 + if response.status() != reqwest::StatusCode::OK { + bail!("Request failed with status: {}", response.status()); + } + + // Read the response to a buffer in chunks + let mut buffer = BytesMut::with_capacity(1024); + while let Some(chunk) = response.chunk().await? { + if buffer.len() > MUXER_HTTP_MAX_LENGTH as usize { + bail!( + "Response body exceeds the maximum allowed length ({MUXER_HTTP_MAX_LENGTH} bytes)" + ); + } + buffer.put(chunk); + } + + // Convert the buffer to a string + let bytes = buffer.freeze(); + match std::str::from_utf8(&bytes) { + Ok(s) => Ok(s.to_string()), + Err(e) => bail!("Failed to decode response body as UTF-8: {e}"), + } +} + fn decode_string_to_map(raw: &str) -> Result> { // trim the string and split for comma raw.trim() diff --git a/crates/common/src/pbs/event.rs b/crates/common/src/pbs/event.rs index 015de714..266fb68c 100644 --- a/crates/common/src/pbs/event.rs +++ b/crates/common/src/pbs/event.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; use alloy::{primitives::B256, rpc::types::beacon::relay::ValidatorRegistration}; use async_trait::async_trait; @@ -8,7 +8,7 @@ use axum::{ routing::post, Json, }; -use eyre::bail; +use eyre::{bail, Result}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; @@ -19,7 +19,10 @@ use super::{ GetHeaderParams, GetHeaderResponse, 
SignedBlindedBeaconBlock, SubmitBlindedBlockResponse, }; use crate::{ - config::{load_optional_env_var, BUILDER_URLS_ENV}, + config::{ + load_optional_env_var, BUILDER_URLS_ENV, HTTP_TIMEOUT_SECONDS_DEFAULT, + HTTP_TIMEOUT_SECONDS_ENV, + }, pbs::BUILDER_EVENTS_PATH, }; @@ -48,11 +51,24 @@ pub struct BuilderEventPublisher { } impl BuilderEventPublisher { - pub fn new(endpoints: Vec) -> Self { - Self { client: reqwest::Client::new(), endpoints } + pub fn new(endpoints: Vec, http_timeout: Duration) -> Result { + for endpoint in &endpoints { + if endpoint.scheme() != "https" { + bail!("BuilderEventPublisher endpoints must use HTTPS (endpoint {endpoint} is invalid)"); + } + } + Ok(Self { + client: reqwest::ClientBuilder::new().timeout(http_timeout).build().unwrap(), + endpoints, + }) } - pub fn new_from_env() -> eyre::Result> { + pub fn new_from_env() -> Result> { + let http_timeout = match load_optional_env_var(HTTP_TIMEOUT_SECONDS_ENV) { + Some(timeout_str) => Duration::from_secs(timeout_str.parse::()?), + None => Duration::from_secs(HTTP_TIMEOUT_SECONDS_DEFAULT), + }; + load_optional_env_var(BUILDER_URLS_ENV) .map(|joined| { let endpoints = joined @@ -62,9 +78,9 @@ impl BuilderEventPublisher { let url = base.trim().parse::()?.join(BUILDER_EVENTS_PATH)?; Ok(url) }) - .collect::>>()?; + .collect::>>()?; - Ok(Self::new(endpoints)) + Self::new(endpoints, http_timeout) }) .transpose() } diff --git a/crates/common/src/signer/constants.rs b/crates/common/src/signer/constants.rs index aa834f91..45e3ce23 100644 --- a/crates/common/src/signer/constants.rs +++ b/crates/common/src/signer/constants.rs @@ -1 +1,6 @@ pub const DEFAULT_SIGNER_PORT: u16 = 20000; + +// Rate limit signer API requests for 5 minutes after the endpoint has 3 JWT +// auth failures +pub const DEFAULT_JWT_AUTH_FAIL_LIMIT: u32 = 3; +pub const DEFAULT_JWT_AUTH_FAIL_TIMEOUT_SECONDS: u32 = 5 * 60; diff --git a/crates/common/src/utils.rs b/crates/common/src/utils.rs index 37119580..a1dcb7cb 100644 --- 
a/crates/common/src/utils.rs +++ b/crates/common/src/utils.rs @@ -137,6 +137,10 @@ pub const fn default_u64() -> u64 { U } +pub const fn default_u32() -> u32 { + U +} + pub const fn default_u16() -> u16 { U } diff --git a/crates/signer/src/error.rs b/crates/signer/src/error.rs index 477e9e42..a2a113f3 100644 --- a/crates/signer/src/error.rs +++ b/crates/signer/src/error.rs @@ -27,6 +27,9 @@ pub enum SignerModuleError { #[error("internal error: {0}")] Internal(String), + + #[error("rate limited for {0} more seconds")] + RateLimited(f64), } impl IntoResponse for SignerModuleError { @@ -45,6 +48,9 @@ impl IntoResponse for SignerModuleError { (StatusCode::INTERNAL_SERVER_ERROR, "internal error".to_string()) } SignerModuleError::SignerError(err) => (StatusCode::BAD_REQUEST, err.to_string()), + SignerModuleError::RateLimited(duration) => { + (StatusCode::TOO_MANY_REQUESTS, format!("rate limited for {duration:?}")) + } } .into_response() } diff --git a/crates/signer/src/service.rs b/crates/signer/src/service.rs index a965f057..3ca1d5ac 100644 --- a/crates/signer/src/service.rs +++ b/crates/signer/src/service.rs @@ -1,7 +1,12 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::Arc, + time::{Duration, Instant}, +}; use axum::{ - extract::{Request, State}, + extract::{ConnectInfo, Request, State}, http::StatusCode, middleware::{self, Next}, response::{IntoResponse, Response}, @@ -41,13 +46,30 @@ use crate::{ /// Implements the Signer API and provides a service for signing requests pub struct SigningService; +// Tracker for a peer's JWT failures +struct JwtAuthFailureInfo { + // Number of auth failures since the first failure was tracked + failure_count: u32, + + // Time of the last auth failure + last_failure: Instant, +} + #[derive(Clone)] struct SigningState { /// Manager handling different signing methods manager: Arc>, + /// Map of modules ids to JWT secrets. 
This also acts as registry of all /// modules running jwts: Arc>, + + /// Map of JWT failures per peer + jwt_auth_failures: Arc>>, + + // JWT auth failure settings + jwt_auth_fail_limit: u32, + jwt_auth_fail_timeout: Duration, } impl SigningService { @@ -62,12 +84,31 @@ impl SigningService { let state = SigningState { manager: Arc::new(RwLock::new(start_manager(config.clone()).await?)), jwts: config.jwts.into(), + jwt_auth_failures: Arc::new(RwLock::new(HashMap::new())), + jwt_auth_fail_limit: config.jwt_auth_fail_limit, + jwt_auth_fail_timeout: Duration::from_secs(config.jwt_auth_fail_timeout_seconds as u64), }; - let loaded_consensus = state.manager.read().await.available_consensus_signers(); - let loaded_proxies = state.manager.read().await.available_proxy_signers(); + // Get the signer counts + let loaded_consensus: usize; + let loaded_proxies: usize; + { + let manager = state.manager.read().await; + loaded_consensus = manager.available_consensus_signers(); + loaded_proxies = manager.available_proxy_signers(); + } - info!(version = COMMIT_BOOST_VERSION, commit_hash = COMMIT_BOOST_COMMIT, modules =? module_ids, endpoint =? config.endpoint, loaded_consensus, loaded_proxies, "Starting signing service"); + info!( + version = COMMIT_BOOST_VERSION, + commit_hash = COMMIT_BOOST_COMMIT, + modules =? module_ids, + endpoint =? config.endpoint, + loaded_consensus, + loaded_proxies, + jwt_auth_fail_limit =? state.jwt_auth_fail_limit, + jwt_auth_fail_timeout =? 
state.jwt_auth_fail_timeout, + "Starting signing service" + ); SigningService::init_metrics(config.chain)?; @@ -79,7 +120,8 @@ impl SigningService { .route(RELOAD_PATH, post(handle_reload)) .with_state(state.clone()) .route_layer(middleware::from_fn(log_request)) - .route(STATUS_PATH, get(handle_status)); + .route(STATUS_PATH, get(handle_status)) + .into_make_service_with_connect_info::(); let listener = TcpListener::bind(config.endpoint).await?; @@ -95,9 +137,76 @@ impl SigningService { async fn jwt_auth( State(state): State, TypedHeader(auth): TypedHeader>, + addr: ConnectInfo, mut req: Request, next: Next, ) -> Result { + // Check if the request needs to be rate limited + let client_ip = addr.ip().to_string(); + check_jwt_rate_limit(&state, &client_ip).await?; + + // Process JWT authorization + match check_jwt_auth(&auth, &state).await { + Ok(module_id) => { + req.extensions_mut().insert(module_id); + Ok(next.run(req).await) + } + Err(SignerModuleError::Unauthorized) => { + let mut failures = state.jwt_auth_failures.write().await; + let failure_info = failures + .entry(client_ip) + .or_insert(JwtAuthFailureInfo { failure_count: 0, last_failure: Instant::now() }); + failure_info.failure_count += 1; + failure_info.last_failure = Instant::now(); + Err(SignerModuleError::Unauthorized) + } + Err(err) => Err(err), + } +} + +/// Checks if the incoming request needs to be rate limited due to previous JWT +/// authentication failures +async fn check_jwt_rate_limit( + state: &SigningState, + client_ip: &String, +) -> Result<(), SignerModuleError> { + let mut failures = state.jwt_auth_failures.write().await; + + // Ignore clients that don't have any failures + if let Some(failure_info) = failures.get(client_ip) { + // If the last failure was more than the timeout ago, remove this entry so it's + // eligible again + let elapsed = failure_info.last_failure.elapsed(); + if elapsed > state.jwt_auth_fail_timeout { + debug!("Removing {client_ip} from JWT auth failure list"); + 
failures.remove(client_ip); + return Ok(()); + } + + // If the failure threshold hasn't been met yet, don't rate limit + if failure_info.failure_count < state.jwt_auth_fail_limit { + debug!( + "Client {client_ip} has {}/{} JWT auth failures, no rate limit applied", + failure_info.failure_count, state.jwt_auth_fail_limit + ); + return Ok(()); + } + + // Rate limit the request + let remaining = state.jwt_auth_fail_timeout - elapsed; + warn!("Client {client_ip} is rate limited for {remaining:?} more seconds due to JWT auth failures"); + return Err(SignerModuleError::RateLimited(remaining.as_secs_f64())); + } + + debug!("Client {client_ip} has no JWT auth failures, no rate limit applied"); + Ok(()) +} + +/// Checks if a request can successfully authenticate with the JWT secret +async fn check_jwt_auth( + auth: &Authorization, + state: &SigningState, +) -> Result { let jwt: Jwt = auth.token().to_string().into(); // We first need to decode it to get the module id and then validate it @@ -116,10 +225,7 @@ async fn jwt_auth( error!("Unauthorized request. 
Invalid JWT: {e}"); SignerModuleError::Unauthorized })?; - - req.extensions_mut().insert(module_id); - - Ok(next.run(req).await) + Ok(module_id) } /// Requests logging middleware layer diff --git a/tests/Cargo.toml b/tests/Cargo.toml index ce273ae7..f1b5c9d9 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -9,9 +9,11 @@ alloy.workspace = true axum.workspace = true cb-common.workspace = true cb-pbs.workspace = true +cb-signer.workspace = true eyre.workspace = true reqwest.workspace = true serde_json.workspace = true +tempfile.workspace = true tokio.workspace = true tracing.workspace = true tracing-subscriber.workspace = true diff --git a/tests/data/ssv_valid.json b/tests/data/ssv_valid.json new file mode 100644 index 00000000..e19b13e6 --- /dev/null +++ b/tests/data/ssv_valid.json @@ -0,0 +1,99 @@ +{ + "validators": [ + { + "id": 554991, + "public_key": "967ba17a3e7f82a25aa5350ec34d6923e28ad8237b5a41efe2c5e325240d74d87a015bf04634f21900963539c8229b2a", + "cluster": "0xf7c1283eb0c0f76b5fa84c7541d8d4d27751b4083a5e8dcb8ac9e72bb7f559b8", + "owner_address": "0xB2EE025B1d129c61E77223bAb42fc65b29B16243", + "status": "Inactive", + "is_valid": true, + "is_deleted": false, + "is_public_key_valid": true, + "is_shares_valid": true, + "is_operators_valid": true, + "operators": [ + 16, + 27, + 86, + 90, + 200, + 204, + 214 + ], + "validator_info": { + "index": 1476217, + "status": "withdrawal_possible", + "activation_epoch": 4950, + "effective_balance": 32000000000 + }, + "version": "v4", + "network": "holesky" + }, + { + "id": 554992, + "public_key": "ac769e8cec802e8ffee34de3253be8f438a0c17ee84bdff0b6730280d24b5ecb77ebc9c985281b41ee3bda8663b6658c", + "cluster": "0xf7c1283eb0c0f76b5fa84c7541d8d4d27751b4083a5e8dcb8ac9e72bb7f559b8", + "owner_address": "0xB2EE025B1d129c61E77223bAb42fc65b29B16243", + "status": "Inactive", + "is_valid": true, + "is_deleted": false, + "is_public_key_valid": true, + "is_shares_valid": true, + "is_operators_valid": true, + "operators": [ + 16, + 
27, + 86, + 90, + 200, + 204, + 214 + ], + "validator_info": { + "index": 1476218, + "status": "withdrawal_possible", + "activation_epoch": 4950, + "effective_balance": 32000000000 + }, + "version": "v4", + "network": "holesky" + }, + { + "id": 554994, + "public_key": "8c866a5a05f3d45c49b457e29365259021a509c5daa82e124f9701a960ee87b8902e87175315ab638a3d8b1115b23639", + "cluster": "0xf7c1283eb0c0f76b5fa84c7541d8d4d27751b4083a5e8dcb8ac9e72bb7f559b8", + "owner_address": "0xB2EE025B1d129c61E77223bAb42fc65b29B16243", + "status": "Inactive", + "is_valid": true, + "is_deleted": false, + "is_public_key_valid": true, + "is_shares_valid": true, + "is_operators_valid": true, + "operators": [ + 16, + 27, + 86, + 90, + 200, + 204, + 214 + ], + "validator_info": { + "index": 1476222, + "status": "withdrawal_possible", + "activation_epoch": 4950, + "effective_balance": 32000000000 + }, + "version": "v4", + "network": "holesky" + } + ], + "pagination": { + "total": 3, + "pages": 1, + "per_page": 10, + "page": 1, + "current_first": 554991, + "current_last": 554994 + } +} \ No newline at end of file diff --git a/tests/src/utils.rs b/tests/src/utils.rs index f2ae9157..b412efe8 100644 --- a/tests/src/utils.rs +++ b/tests/src/utils.rs @@ -1,13 +1,22 @@ use std::{ + collections::HashMap, net::{Ipv4Addr, SocketAddr}, sync::{Arc, Once}, }; use alloy::{primitives::U256, rpc::types::beacon::BlsPublicKey}; use cb_common::{ - config::{PbsConfig, PbsModuleConfig, RelayConfig}, + config::{ + PbsConfig, PbsModuleConfig, RelayConfig, SignerConfig, SignerType, StartSignerConfig, + SIGNER_IMAGE_DEFAULT, + }, pbs::{RelayClient, RelayEntry}, - types::Chain, + signer::{ + SignerLoader, DEFAULT_JWT_AUTH_FAIL_LIMIT, DEFAULT_JWT_AUTH_FAIL_TIMEOUT_SECONDS, + DEFAULT_SIGNER_PORT, + }, + types::{Chain, ModuleId}, + utils::default_host, }; use eyre::Result; @@ -72,6 +81,7 @@ pub fn get_pbs_static_config(port: u16) -> PbsConfig { late_in_slot_time_ms: u64::MAX, extra_validation_enabled: false, rpc_url: None, + 
http_timeout_seconds: 10, } } @@ -91,3 +101,34 @@ pub fn to_pbs_config( muxes: None, } } + +pub fn get_signer_config(loader: SignerLoader) -> SignerConfig { + SignerConfig { + host: default_host(), + port: DEFAULT_SIGNER_PORT, + docker_image: SIGNER_IMAGE_DEFAULT.to_string(), + jwt_auth_fail_limit: DEFAULT_JWT_AUTH_FAIL_LIMIT, + jwt_auth_fail_timeout_seconds: DEFAULT_JWT_AUTH_FAIL_TIMEOUT_SECONDS, + inner: SignerType::Local { loader, store: None }, + } +} + +pub fn get_start_signer_config( + signer_config: SignerConfig, + chain: Chain, + jwts: HashMap, +) -> StartSignerConfig { + match signer_config.inner { + SignerType::Local { loader, .. } => StartSignerConfig { + chain, + loader: Some(loader), + store: None, + endpoint: SocketAddr::new(signer_config.host.into(), signer_config.port), + jwts, + jwt_auth_fail_limit: signer_config.jwt_auth_fail_limit, + jwt_auth_fail_timeout_seconds: signer_config.jwt_auth_fail_timeout_seconds, + dirk: None, + }, + _ => panic!("Only local signers are supported in tests"), + } +} diff --git a/tests/tests/config.rs b/tests/tests/config.rs index dafd96d9..f6f31d96 100644 --- a/tests/tests/config.rs +++ b/tests/tests/config.rs @@ -37,11 +37,11 @@ async fn test_load_pbs_happy() -> Result<()> { // Docker and general settings assert_eq!(config.pbs.docker_image, "ghcr.io/commit-boost/pbs:latest"); - assert_eq!(config.pbs.with_signer, false); + assert!(!config.pbs.with_signer); assert_eq!(config.pbs.pbs_config.host, "127.0.0.1".parse::().unwrap()); assert_eq!(config.pbs.pbs_config.port, 18550); - assert_eq!(config.pbs.pbs_config.relay_check, true); - assert_eq!(config.pbs.pbs_config.wait_all_registrations, true); + assert!(config.pbs.pbs_config.relay_check); + assert!(config.pbs.pbs_config.wait_all_registrations); // Timeouts assert_eq!(config.pbs.pbs_config.timeout_get_header_ms, 950); @@ -49,12 +49,12 @@ async fn test_load_pbs_happy() -> Result<()> { assert_eq!(config.pbs.pbs_config.timeout_register_validator_ms, 3000); // Bid settings and 
validation - assert_eq!(config.pbs.pbs_config.skip_sigverify, false); + assert!(!config.pbs.pbs_config.skip_sigverify); dbg!(&config.pbs.pbs_config.min_bid_wei); dbg!(&U256::from(0.5)); assert_eq!(config.pbs.pbs_config.min_bid_wei, U256::from((0.5 * WEI_PER_ETH as f64) as u64)); assert_eq!(config.pbs.pbs_config.late_in_slot_time_ms, 2000); - assert_eq!(config.pbs.pbs_config.extra_validation_enabled, false); + assert!(!config.pbs.pbs_config.extra_validation_enabled); assert_eq!( config.pbs.pbs_config.rpc_url, Some("https://ethereum-holesky-rpc.publicnode.com".parse::().unwrap()) @@ -64,7 +64,7 @@ async fn test_load_pbs_happy() -> Result<()> { let relay = &config.relays[0]; assert_eq!(relay.id, Some("example-relay".to_string())); assert_eq!(relay.entry.url, "http://0xa1cec75a3f0661e99299274182938151e8433c61a19222347ea1313d839229cb4ce4e3e5aa2bdeb71c8fcf1b084963c2@abc.xyz".parse::().unwrap()); - assert_eq!(relay.enable_timing_games, false); + assert!(!relay.enable_timing_games); assert_eq!(relay.target_first_request_ms, Some(200)); assert_eq!(relay.frequency_get_header_ms, Some(300)); diff --git a/tests/tests/pbs_get_header.rs b/tests/tests/pbs_get_header.rs index 422a71a3..747d460c 100644 --- a/tests/tests/pbs_get_header.rs +++ b/tests/tests/pbs_get_header.rs @@ -23,7 +23,7 @@ use tree_hash::TreeHash; async fn test_get_header() -> Result<()> { setup_test_env(); let signer = random_secret(); - let pubkey: BlsPublicKey = blst_pubkey_to_alloy(&signer.sk_to_pk()).into(); + let pubkey: BlsPublicKey = blst_pubkey_to_alloy(&signer.sk_to_pk()); let chain = Chain::Holesky; let pbs_port = 3200; diff --git a/tests/tests/signer_jwt_auth.rs b/tests/tests/signer_jwt_auth.rs new file mode 100644 index 00000000..90a0365f --- /dev/null +++ b/tests/tests/signer_jwt_auth.rs @@ -0,0 +1,146 @@ +use std::{collections::HashMap, time::Duration}; + +use alloy::{hex, primitives::FixedBytes}; +use cb_common::{ + commit::{constants::GET_PUBKEYS_PATH, request::GetPubkeysResponse}, + 
config::StartSignerConfig, + signer::{SignerLoader, ValidatorKeysFormat}, + types::{Chain, ModuleId}, + utils::create_jwt, +}; +use cb_signer::service::SigningService; +use cb_tests::utils::{get_signer_config, get_start_signer_config, setup_test_env}; +use eyre::Result; +use reqwest::{Response, StatusCode}; +use tracing::info; + +const JWT_MODULE: &str = "test-module"; +const JWT_SECRET: &str = "test-jwt-secret"; + +#[tokio::test] +async fn test_signer_jwt_auth_success() -> Result<()> { + setup_test_env(); + let module_id = ModuleId(JWT_MODULE.to_string()); + let start_config = start_server(20100).await?; + + // Run a pubkeys request + let jwt = create_jwt(&module_id, JWT_SECRET)?; + let client = reqwest::Client::new(); + let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); + let response = client.get(&url).bearer_auth(&jwt).send().await?; + + // Verify the expected pubkeys are returned + verify_pubkeys(response).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_signer_jwt_auth_fail() -> Result<()> { + setup_test_env(); + let module_id = ModuleId(JWT_MODULE.to_string()); + let start_config = start_server(20200).await?; + + // Run a pubkeys request - this should fail due to invalid JWT + let jwt = create_jwt(&module_id, "incorrect secret")?; + let client = reqwest::Client::new(); + let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::UNAUTHORIZED); + info!( + "Server returned expected error code {} for invalid JWT: {}", + response.status(), + response.text().await.unwrap_or_else(|_| "No response body".to_string()) + ); + Ok(()) +} + +#[tokio::test] +async fn test_signer_jwt_rate_limit() -> Result<()> { + setup_test_env(); + let module_id = ModuleId(JWT_MODULE.to_string()); + let start_config = start_server(20300).await?; + + // Run as many pubkeys requests as the fail limit + let jwt = create_jwt(&module_id, 
"incorrect secret")?; + let client = reqwest::Client::new(); + let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); + for _ in 0..start_config.jwt_auth_fail_limit { + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::UNAUTHORIZED); + } + + // Run another request - this should fail due to rate limiting now + let jwt = create_jwt(&module_id, JWT_SECRET)?; + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::TOO_MANY_REQUESTS); + + // Wait for the rate limit timeout + tokio::time::sleep(Duration::from_secs(start_config.jwt_auth_fail_timeout_seconds as u64)) + .await; + + // Now the next request should succeed + let response = client.get(&url).bearer_auth(&jwt).send().await?; + verify_pubkeys(response).await?; + + Ok(()) +} + +// Starts the signer moduler server on a separate task and returns its +// configuration +async fn start_server(port: u16) -> Result { + setup_test_env(); + let chain = Chain::Hoodi; + + // Mock JWT secrets + let module_id = ModuleId(JWT_MODULE.to_string()); + let mut jwts = HashMap::new(); + jwts.insert(module_id.clone(), JWT_SECRET.to_string()); + + // Create a signer config + let loader = SignerLoader::ValidatorsDir { + keys_path: "data/keystores/keys".into(), + secrets_path: "data/keystores/secrets".into(), + format: ValidatorKeysFormat::Lighthouse, + }; + let mut config = get_signer_config(loader); + config.port = port; + config.jwt_auth_fail_limit = 3; // Set a low fail limit for testing + config.jwt_auth_fail_timeout_seconds = 3; // Set a short timeout for testing + let start_config = get_start_signer_config(config, chain, jwts); + + // Run the Signer + let server_handle = tokio::spawn(SigningService::run(start_config.clone())); + + // Make sure the server is running + tokio::time::sleep(Duration::from_millis(100)).await; + if server_handle.is_finished() { + return Err(eyre::eyre!( + "Signer service 
failed to start: {}", + server_handle.await.unwrap_err() + )); + } + Ok(start_config) +} + +// Verifies that the pubkeys returned by the server match the pubkeys in the +// test data +async fn verify_pubkeys(response: Response) -> Result<()> { + // Verify the expected pubkeys are returned + assert!(response.status() == StatusCode::OK); + let pubkey_json = response.json::().await?; + assert_eq!(pubkey_json.keys.len(), 2); + let expected_pubkeys = vec![ + FixedBytes::new(hex!("883827193f7627cd04e621e1e8d56498362a52b2a30c9a1c72036eb935c4278dee23d38a24d2f7dda62689886f0c39f4")), + FixedBytes::new(hex!("b3a22e4a673ac7a153ab5b3c17a4dbef55f7e47210b20c0cbb0e66df5b36bb49ef808577610b034172e955d2312a61b9")), + ]; + for expected in expected_pubkeys { + assert!( + pubkey_json.keys.iter().any(|k| k.consensus == expected), + "Expected pubkey not found: {:?}", + expected + ); + info!("Server returned expected pubkey: {:?}", expected); + } + Ok(()) +}