diff --git a/Cargo.lock b/Cargo.lock index ac895cb7e..42e224734 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,12 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "async-broadcast" version = "0.7.2" @@ -3093,13 +3099,16 @@ dependencies = [ name = "stackable-webhook" version = "0.3.1" dependencies = [ + "arc-swap", "axum", + "clap", "futures-util", "hyper", "hyper-util", "k8s-openapi", "kube", "opentelemetry", + "rand 0.9.1", "serde_json", "snafu 0.8.6", "stackable-certs", @@ -3111,6 +3120,7 @@ dependencies = [ "tower-http", "tracing", "tracing-opentelemetry", + "x509-cert", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1bc0bc32f..e50a67c40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/stackabletech/operator-rs" [workspace.dependencies] product-config = { git = "https://github.com/stackabletech/product-config.git", tag = "0.7.0" } +arc-swap = "1.7" axum = { version = "0.8.1", features = ["http2"] } chrono = { version = "0.4.38", default-features = false } clap = { version = "4.5.17", features = ["derive", "cargo", "env"] } diff --git a/crates/stackable-certs/src/ca/consts.rs b/crates/stackable-certs/src/ca/consts.rs index 125a63a05..bcd080cd4 100644 --- a/crates/stackable-certs/src/ca/consts.rs +++ b/crates/stackable-certs/src/ca/consts.rs @@ -1,6 +1,6 @@ use stackable_operator::time::Duration; -/// The default CA validity time span of one hour (3600 seconds). +/// The default CA validity time span pub const DEFAULT_CA_VALIDITY: Duration = Duration::from_hours_unchecked(1); /// The root CA subject name containing only the common name. 
diff --git a/crates/stackable-certs/src/ca/mod.rs b/crates/stackable-certs/src/ca/mod.rs index b2e464b45..f9c2a2f26 100644 --- a/crates/stackable-certs/src/ca/mod.rs +++ b/crates/stackable-certs/src/ca/mod.rs @@ -38,7 +38,7 @@ pub enum Error { #[snafu(display("failed to generate RSA signing key"))] GenerateRsaSigningKey { source: rsa::Error }, - #[snafu(display("failed to generate ECDSA signign key"))] + #[snafu(display("failed to generate ECDSA signing key"))] GenerateEcdsaSigningKey { source: ecdsa::Error }, #[snafu(display("failed to parse {subject:?} as subject"))] diff --git a/crates/stackable-operator/CHANGELOG.md b/crates/stackable-operator/CHANGELOG.md index 709d7d816..92da2a1a3 100644 --- a/crates/stackable-operator/CHANGELOG.md +++ b/crates/stackable-operator/CHANGELOG.md @@ -34,7 +34,12 @@ All notable changes to this project will be documented in this file. ### Changed - Update `kube` to `1.1.0` ([#1049]). -- BREAKING: Return type for `ListenerOperatorVolumeSourceBuilder::new()` is no onger a `Result` ([#1058]). +- BREAKING: Return type for `ListenerOperatorVolumeSourceBuilder::new()` is no longer a `Result` ([#1058]). +- BREAKING: Require two new CLI arguments: `--operator-namespace` and `--operator-service-name`. + These are required, so that the operator knows what Service it needs to enter as CRD conversion webhook ([#1066]). +- BREAKING: The `ProductOperatorRun` used for CLI arguments has some fields renamed for consistency ([#1066]): + - `telemetry_arguments` -> `telemetry` + - `cluster_info_opts` -> `cluster_info` ### Fixed @@ -50,6 +55,7 @@ All notable changes to this project will be documented in this file. 
[#1058]: https://github.com/stackabletech/operator-rs/pull/1058 [#1060]: https://github.com/stackabletech/operator-rs/pull/1060 [#1064]: https://github.com/stackabletech/operator-rs/pull/1064 +[#1066]: https://github.com/stackabletech/operator-rs/pull/1066 [#1068]: https://github.com/stackabletech/operator-rs/pull/1068 [#1071]: https://github.com/stackabletech/operator-rs/pull/1071 diff --git a/crates/stackable-operator/src/cli.rs b/crates/stackable-operator/src/cli.rs index d9dafeb07..b19fb3693 100644 --- a/crates/stackable-operator/src/cli.rs +++ b/crates/stackable-operator/src/cli.rs @@ -116,7 +116,7 @@ use product_config::ProductConfigManager; use snafu::{ResultExt, Snafu}; use stackable_telemetry::tracing::TelemetryOptions; -use crate::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOpts}; +use crate::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOptions}; pub const AUTHOR: &str = "Stackable GmbH - info@stackable.tech"; @@ -163,10 +163,10 @@ pub enum Command { /// Can be embedded into an extended argument set: /// /// ```rust -/// # use stackable_operator::cli::{Command, ProductOperatorRun, ProductConfigPath}; +/// # use stackable_operator::cli::{Command, OperatorEnvironmentOptions, ProductOperatorRun, ProductConfigPath}; +/// # use stackable_operator::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOptions}; +/// # use stackable_telemetry::tracing::TelemetryOptions; /// use clap::Parser; -/// use stackable_operator::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOpts}; -/// use stackable_telemetry::tracing::TelemetryOptions; /// /// #[derive(clap::Parser, Debug, PartialEq, Eq)] /// struct Run { @@ -176,17 +176,36 @@ pub enum Command { /// common: ProductOperatorRun, /// } /// -/// let opts = Command::::parse_from(["foobar-operator", "run", "--name", "foo", "--product-config", "bar", "--watch-namespace", "foobar", "--kubernetes-node-name", "baz"]); +/// let opts = 
Command::::parse_from([ +/// "foobar-operator", +/// "run", +/// "--name", +/// "foo", +/// "--product-config", +/// "bar", +/// "--watch-namespace", +/// "foobar", +/// "--operator-namespace", +/// "stackable-operators", +/// "--operator-service-name", +/// "foo-operator", +/// "--kubernetes-node-name", +/// "baz", +/// ]); /// assert_eq!(opts, Command::Run(Run { /// name: "foo".to_string(), /// common: ProductOperatorRun { /// product_config: ProductConfigPath::from("bar".as_ref()), /// watch_namespace: WatchNamespace::One("foobar".to_string()), -/// telemetry_arguments: TelemetryOptions::default(), -/// cluster_info_opts: KubernetesClusterInfoOpts { +/// telemetry: TelemetryOptions::default(), +/// cluster_info: KubernetesClusterInfoOptions { /// kubernetes_cluster_domain: None, /// kubernetes_node_name: "baz".to_string(), /// }, +/// operator_environment: OperatorEnvironmentOptions { +/// operator_namespace: "stackable-operators".to_string(), +/// operator_service_name: "foo-operator".to_string(), +/// }, /// }, /// })); /// ``` @@ -220,10 +239,13 @@ pub struct ProductOperatorRun { pub watch_namespace: WatchNamespace, #[command(flatten)] - pub telemetry_arguments: TelemetryOptions, + pub operator_environment: OperatorEnvironmentOptions, + + #[command(flatten)] + pub telemetry: TelemetryOptions, #[command(flatten)] - pub cluster_info_opts: KubernetesClusterInfoOpts, + pub cluster_info: KubernetesClusterInfoOptions, } /// A path to a [`ProductConfigManager`] spec file @@ -281,9 +303,25 @@ impl ProductConfigPath { } } +#[derive(clap::Parser, Debug, PartialEq, Eq)] +pub struct OperatorEnvironmentOptions { + /// The namespace the operator is running in, usually `stackable-operators`. + /// + /// Note that when running the operator on Kubernetes we recommend to use the + /// [downward API](https://kubernetes.io/docs/concepts/workloads/pods/downward-api/) + /// to let Kubernetes mount the namespace as the `OPERATOR_NAMESPACE` env variable. 
+ #[arg(long, env)] + pub operator_namespace: String, + + /// The name of the service the operator is reachable at, usually + /// something like `-operator`. + #[arg(long, env)] + pub operator_service_name: String, +} + #[cfg(test)] mod tests { - use std::{env, fs::File}; + use std::fs::File; use clap::Parser; use rstest::*; @@ -294,7 +332,6 @@ mod tests { const USER_PROVIDED_PATH: &str = "user_provided_path_properties.yaml"; const DEPLOY_FILE_PATH: &str = "deploy_config_spec_properties.yaml"; const DEFAULT_FILE_PATH: &str = "default_file_path_properties.yaml"; - const WATCH_NAMESPACE: &str = "WATCH_NAMESPACE"; #[test] fn verify_cli() { @@ -381,9 +418,6 @@ mod tests { #[test] fn product_operator_run_watch_namespace() { - // clean env var to not interfere if already set - unsafe { env::remove_var(WATCH_NAMESPACE) }; - // cli with namespace let opts = ProductOperatorRun::parse_from([ "run", @@ -391,6 +425,10 @@ mod tests { "bar", "--watch-namespace", "foo", + "--operator-namespace", + "stackable-operators", + "--operator-service-name", + "foo-operator", "--kubernetes-node-name", "baz", ]); @@ -399,11 +437,15 @@ mod tests { ProductOperatorRun { product_config: ProductConfigPath::from("bar".as_ref()), watch_namespace: WatchNamespace::One("foo".to_string()), - cluster_info_opts: KubernetesClusterInfoOpts { + cluster_info: KubernetesClusterInfoOptions { kubernetes_cluster_domain: None, kubernetes_node_name: "baz".to_string() }, - telemetry_arguments: Default::default(), + telemetry: Default::default(), + operator_environment: OperatorEnvironmentOptions { + operator_namespace: "stackable-operators".to_string(), + operator_service_name: "foo-operator".to_string(), + } } ); @@ -412,6 +454,10 @@ mod tests { "run", "--product-config", "bar", + "--operator-namespace", + "stackable-operators", + "--operator-service-name", + "foo-operator", "--kubernetes-node-name", "baz", ]); @@ -420,33 +466,15 @@ mod tests { ProductOperatorRun { product_config: 
ProductConfigPath::from("bar".as_ref()), watch_namespace: WatchNamespace::All, - cluster_info_opts: KubernetesClusterInfoOpts { - kubernetes_cluster_domain: None, - kubernetes_node_name: "baz".to_string() - }, - telemetry_arguments: Default::default(), - } - ); - - // env with namespace - unsafe { env::set_var(WATCH_NAMESPACE, "foo") }; - let opts = ProductOperatorRun::parse_from([ - "run", - "--product-config", - "bar", - "--kubernetes-node-name", - "baz", - ]); - assert_eq!( - opts, - ProductOperatorRun { - product_config: ProductConfigPath::from("bar".as_ref()), - watch_namespace: WatchNamespace::One("foo".to_string()), - cluster_info_opts: KubernetesClusterInfoOpts { + cluster_info: KubernetesClusterInfoOptions { kubernetes_cluster_domain: None, kubernetes_node_name: "baz".to_string() }, - telemetry_arguments: Default::default(), + telemetry: Default::default(), + operator_environment: OperatorEnvironmentOptions { + operator_namespace: "stackable-operators".to_string(), + operator_service_name: "foo-operator".to_string(), + } } ); } diff --git a/crates/stackable-operator/src/client.rs b/crates/stackable-operator/src/client.rs index 5d493866e..f79a1eb91 100644 --- a/crates/stackable-operator/src/client.rs +++ b/crates/stackable-operator/src/client.rs @@ -21,7 +21,7 @@ use tracing::trace; use crate::{ kvp::LabelSelectorExt, - utils::cluster_info::{KubernetesClusterInfo, KubernetesClusterInfoOpts}, + utils::cluster_info::{KubernetesClusterInfo, KubernetesClusterInfoOptions}, }; pub type Result = std::result::Result; @@ -529,13 +529,13 @@ impl Client { /// use k8s_openapi::api::core::v1::Pod; /// use stackable_operator::{ /// client::{Client, initialize_operator}, - /// utils::cluster_info::KubernetesClusterInfoOpts, + /// utils::cluster_info::KubernetesClusterInfoOptions, /// }; /// /// #[tokio::main] /// async fn main() { - /// let cluster_info_opts = KubernetesClusterInfoOpts::parse(); - /// let client = initialize_operator(None, &cluster_info_opts) + /// let 
cluster_info_options = KubernetesClusterInfoOptions::parse(); + /// let client = initialize_operator(None, &cluster_info_options) /// .await /// .expect("Unable to construct client."); /// let watcher_config: watcher::Config = @@ -652,7 +652,7 @@ where pub async fn initialize_operator( field_manager: Option, - cluster_info_opts: &KubernetesClusterInfoOpts, + cluster_info_opts: &KubernetesClusterInfoOptions, ) -> Result { let kubeconfig: Config = kube::Config::infer() .await @@ -687,10 +687,10 @@ mod tests { }; use tokio::time::error::Elapsed; - use crate::utils::cluster_info::KubernetesClusterInfoOpts; + use crate::utils::cluster_info::KubernetesClusterInfoOptions; - async fn test_cluster_info_opts() -> KubernetesClusterInfoOpts { - KubernetesClusterInfoOpts { + async fn test_cluster_info_opts() -> KubernetesClusterInfoOptions { + KubernetesClusterInfoOptions { // We have to hard-code a made-up cluster domain, // since kubernetes_node_name (probably) won't be a valid Node that we can query. kubernetes_cluster_domain: Some( diff --git a/crates/stackable-operator/src/utils/cluster_info.rs b/crates/stackable-operator/src/utils/cluster_info.rs index 56c718f9e..0cc92e9e9 100644 --- a/crates/stackable-operator/src/utils/cluster_info.rs +++ b/crates/stackable-operator/src/utils/cluster_info.rs @@ -16,13 +16,17 @@ pub struct KubernetesClusterInfo { } #[derive(clap::Parser, Debug, PartialEq, Eq)] -pub struct KubernetesClusterInfoOpts { +pub struct KubernetesClusterInfoOptions { /// Kubernetes cluster domain, usually this is `cluster.local`. // We are not using a default value here, as we query the cluster if it is not specified. #[arg(long, env)] pub kubernetes_cluster_domain: Option, /// Name of the Kubernetes Node that the operator is running on. 
+ /// + /// Note that when running the operator on Kubernetes we recommend to use the + /// [downward API](https://kubernetes.io/docs/concepts/workloads/pods/downward-api/) + /// to let Kubernetes mount the node name as the `KUBERNETES_NODE_NAME` env variable. #[arg(long, env)] pub kubernetes_node_name: String, } @@ -30,10 +34,10 @@ pub struct KubernetesClusterInfoOpts { impl KubernetesClusterInfo { pub async fn new( client: &Client, - cluster_info_opts: &KubernetesClusterInfoOpts, + cluster_info_opts: &KubernetesClusterInfoOptions, ) -> Result { let cluster_domain = match cluster_info_opts { - KubernetesClusterInfoOpts { + KubernetesClusterInfoOptions { kubernetes_cluster_domain: Some(cluster_domain), .. } => { @@ -41,7 +45,7 @@ impl KubernetesClusterInfo { cluster_domain.clone() } - KubernetesClusterInfoOpts { + KubernetesClusterInfoOptions { kubernetes_node_name: node_name, .. } => { diff --git a/crates/stackable-telemetry/src/instrumentation/axum/mod.rs b/crates/stackable-telemetry/src/instrumentation/axum/mod.rs index b72450c99..14e7928cd 100644 --- a/crates/stackable-telemetry/src/instrumentation/axum/mod.rs +++ b/crates/stackable-telemetry/src/instrumentation/axum/mod.rs @@ -73,22 +73,6 @@ const OTEL_TRACE_ID_TO: &str = "opentelemetry.trace_id.to"; /// # let _: Router = router; /// ``` /// -/// ### Example with Webhook -/// -/// The usage is even simpler when combined with the `stackable_webhook` crate. -/// The webhook server has built-in support to automatically emit HTTP spans on -/// every incoming request. -/// -/// ``` -/// use stackable_webhook::{WebhookServer, Options}; -/// use axum::Router; -/// -/// let router = Router::new(); -/// let server = WebhookServer::new(router, Options::default()); -/// -/// # let _: WebhookServer = server; -/// ``` -/// /// This layer is implemented based on [this][1] official Tower guide. 
/// /// [1]: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md diff --git a/crates/stackable-webhook/CHANGELOG.md b/crates/stackable-webhook/CHANGELOG.md index a7354362f..c34b32519 100644 --- a/crates/stackable-webhook/CHANGELOG.md +++ b/crates/stackable-webhook/CHANGELOG.md @@ -4,6 +4,15 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Added + +- BREAKING: Re-write the `ConversionWebhookServer`. + It can now do CRD conversions, handle multiple CRDs and takes care of reconciling the CRDs ([#1066]). +- BREAKING: The `TlsServer` can now handle certificate rotation. + To achieve this, a new `CertificateResolver` was added. + Also, `TlsServer::new` now returns an additional `mpsc::Receiver`, so that the caller + can get notified about certificate rotations happening ([#1066]). + ### Fixed - Don't pull in the `aws-lc-rs` crate, as this currently fails to build in `make run-dev` ([#1043]). @@ -18,10 +27,11 @@ All notable changes to this project will be documented in this file. [#1043]: https://github.com/stackabletech/operator-rs/pull/1043 [#1045]: https://github.com/stackabletech/operator-rs/pull/1045 +[#1066]: https://github.com/stackabletech/operator-rs/pull/1066 ## [0.3.1] - 2024-07-10 -## Changed +### Changed - Remove instrumentation of long running functions, add more granular instrumentation of futures. Adjust span and event levels ([#811]). - Bump rust-toolchain to 1.79.0 ([#822]). 
diff --git a/crates/stackable-webhook/Cargo.toml b/crates/stackable-webhook/Cargo.toml index da553f188..ece1480da 100644 --- a/crates/stackable-webhook/Cargo.toml +++ b/crates/stackable-webhook/Cargo.toml @@ -11,6 +11,7 @@ stackable-certs = { path = "../stackable-certs", features = ["rustls"] } stackable-telemetry = { path = "../stackable-telemetry" } stackable-operator = { path = "../stackable-operator" } +arc-swap.workspace = true axum.workspace = true futures-util.workspace = true hyper-util.workspace = true @@ -18,6 +19,7 @@ hyper.workspace = true k8s-openapi.workspace = true kube.workspace = true opentelemetry.workspace = true +rand.workspace = true serde_json.workspace = true snafu.workspace = true tokio-rustls.workspace = true @@ -26,3 +28,7 @@ tower-http.workspace = true tower.workspace = true tracing.workspace = true tracing-opentelemetry.workspace = true +x509-cert.workspace = true + +[dev-dependencies] +clap.workspace = true diff --git a/crates/stackable-webhook/src/lib.rs b/crates/stackable-webhook/src/lib.rs index 186f19e12..c37c423f2 100644 --- a/crates/stackable-webhook/src/lib.rs +++ b/crates/stackable-webhook/src/lib.rs @@ -1,6 +1,6 @@ //! Utility types and functions to easily create ready-to-use webhook servers //! which can handle different tasks, for example CRD conversions. All webhook -//! servers use HTTPS by defaultThis library is fully compatible with the +//! servers use HTTPS by default. This library is fully compatible with the //! [`tracing`] crate and emits debug level tracing data. //! //! Most users will only use the top-level exported generic [`WebhookServer`] @@ -11,24 +11,31 @@ //! use stackable_webhook::{WebhookServer, Options}; //! use axum::Router; //! +//! # async fn test() { //! let router = Router::new(); -//! let server = WebhookServer::new(router, Options::default()); +//! let (server, cert_rx) = WebhookServer::new(router, Options::default()) +//! .await +//! .expect("failed to create WebhookServer"); +//! # } //! ``` //! 
//! For some usages, complete end-to-end [`WebhookServer`] implementations -//! exist. One such implementation is the [`ConversionWebhookServer`][1]. The -//! only required parameters are a conversion handler function and [`Options`]. +//! exist. One such implementation is the [`ConversionWebhookServer`][1]. //! //! This library additionally also exposes lower-level structs and functions to -//! enable complete controll over these details if needed. +//! enable complete control over these details if needed. //! //! [1]: crate::servers::ConversionWebhookServer use axum::{Router, routing::get}; use futures_util::{FutureExt as _, pin_mut, select}; use snafu::{ResultExt, Snafu}; use stackable_telemetry::AxumTraceLayer; -use tokio::signal::unix::{SignalKind, signal}; +use tokio::{ + signal::unix::{SignalKind, signal}, + sync::mpsc, +}; use tower::ServiceBuilder; +use x509_cert::Certificate; // use tower_http::trace::TraceLayer; use crate::tls::TlsServer; @@ -41,10 +48,6 @@ pub mod tls; // Selected re-exports pub use crate::options::Options; -/// A result type alias with the library-level [`Error`] type as teh default -/// error type. -pub type Result = std::result::Result; - /// A generic webhook handler receiving a request and sending back a response. /// /// This trait is not intended to be implemented by external crates and this @@ -56,25 +59,16 @@ pub trait WebhookHandler { fn call(self, req: Req) -> Res; } -/// A generic webhook handler receiving a request and state and sending back -/// a response. -/// -/// This trait is not intended to be implemented by external crates and this -/// library provides various ready-to-use implementations for it. One such an -/// implementation is part of the [`ConversionWebhookServer`][1]. -/// -/// [1]: crate::servers::ConversionWebhookServer -pub trait StatefulWebhookHandler { - fn call(self, req: Req, state: S) -> Res; -} +/// A result type alias with the [`WebhookError`] type as the default error type. 
+pub type Result = std::result::Result; #[derive(Debug, Snafu)] -pub enum Error { +pub enum WebhookError { #[snafu(display("failed to create TLS server"))] - CreateTlsServer { source: tls::Error }, + CreateTlsServer { source: tls::TlsServerError }, #[snafu(display("failed to run TLS server"))] - RunTlsServer { source: tls::Error }, + RunTlsServer { source: tls::TlsServerError }, } /// A ready-to-use webhook server. @@ -88,8 +82,7 @@ pub enum Error { /// /// [1]: crate::servers::ConversionWebhookServer pub struct WebhookServer { - options: Options, - router: Router, + tls_server: TlsServer, } impl WebhookServer { @@ -109,8 +102,12 @@ impl WebhookServer { /// use stackable_webhook::{WebhookServer, Options}; /// use axum::Router; /// + /// # async fn test() { /// let router = Router::new(); - /// let server = WebhookServer::new(router, Options::default()); + /// let (server, cert_rx) = WebhookServer::new(router, Options::default()) + /// .await + /// .expect("failed to create WebhookServer"); + /// # } /// ``` /// /// ### Example with Custom Options @@ -119,16 +116,53 @@ impl WebhookServer { /// use stackable_webhook::{WebhookServer, Options}; /// use axum::Router; /// + /// # async fn test() { /// let options = Options::builder() /// .bind_address([127, 0, 0, 1], 8080) + /// .add_subject_alterative_dns_name("my-san-entry") /// .build(); /// /// let router = Router::new(); - /// let server = WebhookServer::new(router, options); + /// let (server, cert_rx) = WebhookServer::new(router, options) + /// .await + /// .expect("failed to create WebhookServer"); + /// # } /// ``` - pub fn new(router: Router, options: Options) -> Self { + pub async fn new( + router: Router, + options: Options, + ) -> Result<(Self, mpsc::Receiver)> { tracing::trace!("create new webhook server"); - Self { options, router } + + // TODO (@Techassi): Make opt-in configurable from the outside + // Create an OpenTelemetry tracing layer + tracing::trace!("create tracing service (layer)"); + let 
trace_layer = AxumTraceLayer::new().with_opt_in(); + + // Use a service builder to provide multiple layers at once. Recommended + // by the Axum project. + // + // See https://docs.rs/axum/latest/axum/middleware/index.html#applying-multiple-middleware + // TODO (@NickLarsenNZ): rename this server_builder and keep it specific to tracing, since it's placement in the chain is important + let service_builder = ServiceBuilder::new().layer(trace_layer); + + // Create the root router and merge the provided router into it. + tracing::debug!("create core router and merge provided router"); + let router = router + .layer(service_builder) + // The health route is below the AxumTraceLayer so as not to be instrumented + .route("/health", get(|| async { "ok" })); + + tracing::debug!("create TLS server"); + let (tls_server, cert_rx) = TlsServer::new( + options.socket_addr, + router, + options.subject_alterative_dns_names, + ) + .await + .context(CreateTlsServerSnafu)?; + + Ok((Self { tls_server }, cert_rx)) } /// Runs the Webhook server and sets up signal handlers for shutting down. @@ -170,33 +204,6 @@ impl WebhookServer { async fn run_server(self) -> Result<()> { tracing::debug!("run webhook server"); - // TODO (@Techassi): Make opt-in configurable from the outside - // Create an OpenTelemetry tracing layer - tracing::trace!("create tracing service (layer)"); - let trace_layer = AxumTraceLayer::new().with_opt_in(); - - // Use a service builder to provide multiple layers at once. Recommended - // by the Axum project. - // - // See https://docs.rs/axum/latest/axum/middleware/index.html#applying-multiple-middleware - // TODO (@NickLarsenNZ): rename this server_builder and keep it specific to tracing, since it's placement in the chain is important - let service_builder = ServiceBuilder::new().layer(trace_layer); - - // Create the root router and merge the provided router into it. 
- tracing::debug!("create core router and merge provided router"); - let router = self - .router - .layer(service_builder) - // The health route is below the AxumTraceLayer so as not to be instrumented - .route("/health", get(|| async { "ok" })); - - // Create server for TLS termination - tracing::debug!("create TLS server"); - let tls_server = TlsServer::new(self.options.socket_addr, router) - .await - .context(CreateTlsServerSnafu)?; - - tracing::info!("running TLS server"); - tls_server.run().await.context(RunTlsServerSnafu) + self.tls_server.run().await.context(RunTlsServerSnafu) } } diff --git a/crates/stackable-webhook/src/options.rs b/crates/stackable-webhook/src/options.rs index 99a01133e..666395ea0 100644 --- a/crates/stackable-webhook/src/options.rs +++ b/crates/stackable-webhook/src/options.rs @@ -41,6 +41,10 @@ pub struct Options { /// The default HTTPS socket address the [`TcpListener`][tokio::net::TcpListener] /// binds to. pub socket_addr: SocketAddr, + + /// The subject alternative DNS names that should be added to the certificates generated for this + /// webhook. + pub subject_alterative_dns_names: Vec, } impl Default for Options { @@ -66,6 +70,7 @@ impl Options { #[derive(Debug, Default)] pub struct OptionsBuilder { socket_addr: Option, + subject_alterative_dns_names: Vec, } impl OptionsBuilder { @@ -91,11 +96,32 @@ impl OptionsBuilder { self } + /// Sets the subject alternative DNS names that should be added to the certificates generated for + /// this webhook. + pub fn subject_alterative_dns_names( + mut self, + subject_alterative_dns_name: Vec, + ) -> Self { + self.subject_alterative_dns_names = subject_alterative_dns_name; + self + } + + /// Adds the subject alternative DNS name to the list of names. 
+ pub fn add_subject_alterative_dns_name( + mut self, + subject_alterative_dns_name: impl Into, + ) -> Self { + self.subject_alterative_dns_names + .push(subject_alterative_dns_name.into()); + self + } + /// Builds the final [`Options`] by using default values for any not /// explicitly set option. pub fn build(self) -> Options { Options { socket_addr: self.socket_addr.unwrap_or(DEFAULT_SOCKET_ADDRESS), + subject_alterative_dns_names: self.subject_alterative_dns_names, } } } diff --git a/crates/stackable-webhook/src/servers/conversion.rs b/crates/stackable-webhook/src/servers/conversion.rs index 9b1ff197b..03fc3b23c 100644 --- a/crates/stackable-webhook/src/servers/conversion.rs +++ b/crates/stackable-webhook/src/servers/conversion.rs @@ -1,14 +1,61 @@ use std::fmt::Debug; -use axum::{Json, Router, extract::State, routing::post}; +use axum::{Json, Router, routing::post}; +use k8s_openapi::{ + ByteString, + apiextensions_apiserver::pkg::apis::apiextensions::v1::{ + CustomResourceConversion, CustomResourceDefinition, ServiceReference, WebhookClientConfig, + WebhookConversion, + }, +}; // Re-export this type because users of the conversion webhook server require // this type to write the handler function. Instead of importing this type from // kube directly, consumers can use this type instead. This also eliminates // keeping the kube dependency version in sync between here and the operator. 
pub use kube::core::conversion::ConversionReview; +use kube::{ + Api, Client, ResourceExt, + api::{Patch, PatchParams}, +}; +use snafu::{OptionExt, ResultExt, Snafu}; +use stackable_operator::cli::OperatorEnvironmentOptions; +use tokio::{sync::mpsc, try_join}; use tracing::instrument; +use x509_cert::{ + Certificate, + der::{EncodePem, pem::LineEnding}, +}; -use crate::{StatefulWebhookHandler, WebhookHandler, WebhookServer, options::Options}; +use crate::{ + WebhookError, WebhookHandler, WebhookServer, constants::DEFAULT_HTTPS_PORT, options::Options, +}; + +#[derive(Debug, Snafu)] +pub enum ConversionWebhookError { + #[snafu(display("failed to create webhook server"))] + CreateWebhookServer { source: WebhookError }, + + #[snafu(display("failed to run webhook server"))] + RunWebhookServer { source: WebhookError }, + + #[snafu(display("failed to receive certificate from channel"))] + ReceiverCertificateFromChannel, + + #[snafu(display("failed to convert CA certificate into PEM format"))] + ConvertCaToPem { source: x509_cert::der::Error }, + + #[snafu(display("failed to reconcile CRDs"))] + ReconcileCRDs { + #[snafu(source(from(ConversionWebhookError, Box::new)))] + source: Box, + }, + + #[snafu(display("failed to update CRD {crd_name:?}"))] + UpdateCRD { + source: stackable_operator::kube::Error, + crd_name: String, + }, +} impl WebhookHandler for F where @@ -19,141 +66,246 @@ where } } -impl StatefulWebhookHandler for F -where - F: FnOnce(ConversionReview, S) -> ConversionReview, -{ - fn call(self, req: ConversionReview, state: S) -> ConversionReview { - self(req, state) - } -} - /// A ready-to-use CRD conversion webhook server. /// -/// See [`ConversionWebhookServer::new()`] and [`ConversionWebhookServer::new_with_state()`] -/// for usage examples. +/// See [`ConversionWebhookServer::new()`] for usage examples. 
pub struct ConversionWebhookServer { - options: Options, - router: Router, + server: WebhookServer, + cert_rx: mpsc::Receiver, + client: Client, + field_manager: String, + crds: Vec, + operator_environment: OperatorEnvironmentOptions, } impl ConversionWebhookServer { - /// Creates a new conversion webhook server **without** state which expects - /// POST requests being made to the `/convert` endpoint. + /// Creates a new conversion webhook server, which expects POST requests being made to the + /// `/convert/{crd name}` endpoint. + /// + /// You need to provide two things for every CRD passed in via the `crds_and_handlers` argument: /// - /// Each request is handled by the provided `handler` function. Any function - /// with the signature `(ConversionReview) -> ConversionReview` can be - /// provided. The [`ConversionReview`] type can be imported via a re-export at - /// [`crate::servers::ConversionReview`]. + /// 1. The CRD + /// 2. A conversion function to convert between CRD versions. Typically you would use + /// the auto-generated `try_convert` function on CRD spec definition structs for this. + /// + /// The [`ConversionWebhookServer`] takes care of reconciling the CRDs into the Kubernetes + /// cluster and takes care of adding itself as conversion webhook. This includes TLS + /// certificates and CA bundles. 
/// /// # Example /// - /// ``` + /// ```no_run + /// use clap::Parser; /// use stackable_webhook::{ /// servers::{ConversionReview, ConversionWebhookServer}, /// Options /// }; + /// use stackable_operator::cli::OperatorEnvironmentOptions; + /// use stackable_operator::kube::Client; + /// use stackable_operator::crd::s3::{S3Connection, S3ConnectionVersion}; + /// + /// # async fn test() { + /// let crds_and_handlers = [ + /// ( + /// S3Connection::merged_crd(S3ConnectionVersion::V1Alpha1) + /// .expect("failed to merge S3Connection CRD"), + /// S3Connection::try_convert as fn(_) -> _, + /// ), + /// ]; + /// + /// const OPERATOR_NAME: &str = "PRODUCT_OPERATOR"; + /// let client = Client::try_default().await.expect("failed to create Kubernetes client"); + /// let operator_environment = OperatorEnvironmentOptions::parse(); /// /// // Construct the conversion webhook server - /// let server = ConversionWebhookServer::new(handler, Options::default()); + /// let conversion_webhook = ConversionWebhookServer::new( + /// crds_and_handlers, + /// stackable_webhook::Options::default(), + /// client, + /// OPERATOR_NAME, + /// operator_environment, + /// ) + /// .await + /// .expect("failed to create ConversionWebhookServer"); /// - /// // Define the handler function - /// fn handler(req: ConversionReview) -> ConversionReview { - /// // In here we can do the CRD conversion - /// req - /// } + /// conversion_webhook.run().await.expect("failed to run ConversionWebhookServer"); + /// # } /// ``` - #[instrument(name = "create_conversion_webhook_server", skip(handler))] - pub fn new(handler: H, options: Options) -> Self + #[instrument( + name = "create_conversion_webhook_server", + skip(crds_and_handlers, client) + )] + pub async fn new( + crds_and_handlers: impl IntoIterator, + mut options: Options, + client: Client, + field_manager: impl Into + Debug, + operator_environment: OperatorEnvironmentOptions, + ) -> Result where H: WebhookHandler + Clone + Send + Sync + 'static, { 
tracing::debug!("create new conversion webhook server"); + let field_manager: String = field_manager.into(); - let handler_fn = |Json(review): Json| async { - let review = handler.call(review); - Json(review) - }; + let mut router = Router::new(); + let mut crds = Vec::new(); + for (crd, handler) in crds_and_handlers { + let crd_name = crd.name_any(); + let handler_fn = |Json(review): Json| async { + let review = handler.call(review); + Json(review) + }; - let router = Router::new().route("/convert", post(handler_fn)); - Self { router, options } - } + let route = format!("/convert/{crd_name}"); + router = router.route(&route, post(handler_fn)); + crds.push(crd); + } - /// Creates a new conversion webhook server **with** state which expects - /// POST requests being made to the `/convert` endpoint. - /// - /// Each request is handled by the provided `handler` function. Any function - /// with the signature `(ConversionReview, S) -> ConversionReview` can be - /// provided. The [`ConversionReview`] type can be imported via a re-export at - /// [`crate::servers::ConversionReview`]. - /// - /// It is recommended to wrap the state in an [`Arc`][std::sync::Arc] if it - /// needs to be mutable, see - /// . 
- /// - /// # Example - /// - /// ``` - /// use std::sync::Arc; - /// - /// use stackable_webhook::{ - /// servers::{ConversionReview, ConversionWebhookServer}, - /// Options - /// }; - /// - /// #[derive(Debug, Clone)] - /// struct State {} - /// - /// let shared_state = Arc::new(State {}); - /// let server = ConversionWebhookServer::new_with_state( - /// handler, - /// shared_state, - /// Options::default(), - /// ); - /// - /// // Define the handler function - /// fn handler(req: ConversionReview, state: Arc) -> ConversionReview { - /// // In here we can do the CRD conversion - /// req - /// } - /// ``` - #[instrument(name = "create_conversion_webhook_server_with_state", skip(handler))] - pub fn new_with_state(handler: H, state: S, options: Options) -> Self - where - H: StatefulWebhookHandler - + Clone - + Send - + Sync - + 'static, - S: Clone + Debug + Send + Sync + 'static, - { - tracing::debug!("create new conversion webhook server with state"); - - // NOTE (@Techassi): Initially, after adding the state extractor, the - // compiler kept throwing a trait error at me stating that the closure - // below doesn't implement the Handler trait from Axum. This had nothing - // to do with the state itself, but rather the order of extractors. All - // body consuming extractors, like the JSON extractor need to come last - // in the handler. - // https://docs.rs/axum/latest/axum/extract/index.html#the-order-of-extractors - let handler_fn = |State(state): State, Json(review): Json| async { - let review = handler.call(review, state); - Json(review) - }; - - let router = Router::new() - .route("/convert", post(handler_fn)) - .with_state(state); - - Self { router, options } + // This is how Kubernetes calls us, so it decides about the naming. + // AFAIK we can not influence this, so this is the only SAN entry needed. 
+ let subject_alterative_dns_name = format!( +        "{service_name}.{operator_namespace}.svc", +        service_name = operator_environment.operator_service_name, +        operator_namespace = operator_environment.operator_namespace, +    ); +    options +        .subject_alterative_dns_names +        .push(subject_alterative_dns_name); + +    let (server, mut cert_rx) = WebhookServer::new(router, options) +        .await +        .context(CreateWebhookServerSnafu)?; + +    // We block the ConversionWebhookServer creation until the certificates have been generated. +    // This way we +    // 1. Are able to apply the CRDs before we start the actual controllers relying on them +    // 2. Avoid updating them shortly after as certs have been generated. Doing so would cause +    //    unnecessary "too old resource version" errors in the controllers as the CRD was updated. +    let current_cert = cert_rx +        .recv() +        .await +        .context(ReceiverCertificateFromChannelSnafu)?; +    Self::reconcile_crds( +        &client, +        &field_manager, +        &crds, +        &operator_environment, +        &current_cert, +    ) +    .await +    .context(ReconcileCRDsSnafu)?; + +    Ok(Self { +        server, +        cert_rx, +        client, +        field_manager, +        crds, +        operator_environment, +    }) } -    /// Starts the conversion webhook server by starting the underlying -    /// [`WebhookServer`].
- pub async fn run(self) -> Result<(), crate::Error> { +    pub async fn run(self) -> Result<(), ConversionWebhookError> { tracing::info!("starting conversion webhook server"); -        let server = WebhookServer::new(self.router, self.options); -        server.run().await +        let Self { +            server, +            cert_rx, +            client, +            field_manager, +            crds, +            operator_environment, +        } = self; + +        try_join!( +            Self::run_webhook_server(server), +            Self::run_crd_reconciliation_loop( +                cert_rx, +                &client, +                &field_manager, +                &crds, +                &operator_environment +            ), +        )?; + +        Ok(()) +    } + +    async fn run_webhook_server(server: WebhookServer) -> Result<(), ConversionWebhookError> { +        server.run().await.context(RunWebhookServerSnafu) +    } + +    async fn run_crd_reconciliation_loop( +        mut cert_rx: mpsc::Receiver<Certificate>, +        client: &Client, +        field_manager: &str, +        crds: &[CustomResourceDefinition], +        operator_environment: &OperatorEnvironmentOptions, +    ) -> Result<(), ConversionWebhookError> { +        while let Some(current_cert) = cert_rx.recv().await { +            Self::reconcile_crds( +                client, +                field_manager, +                crds, +                operator_environment, +                &current_cert, +            ) +            .await +            .context(ReconcileCRDsSnafu)?; +        } +        Ok(()) +    } + +    #[instrument(skip_all)] +    async fn reconcile_crds( +        client: &Client, +        field_manager: &str, +        crds: &[CustomResourceDefinition], +        operator_environment: &OperatorEnvironmentOptions, +        current_cert: &Certificate, +    ) -> Result<(), ConversionWebhookError> { +        tracing::info!( +            crds = ?crds.iter().map(CustomResourceDefinition::name_any).collect::<Vec<_>>(), +            "Reconciling CRDs" +        ); +        let ca_bundle = current_cert +            .to_pem(LineEnding::LF) +            .context(ConvertCaToPemSnafu)?; + +        let crd_api: Api<CustomResourceDefinition> = Api::all(client.clone()); +        for mut crd in crds.iter().cloned() { +            let crd_name = crd.name_any(); + +            crd.spec.conversion = Some(CustomResourceConversion { +                strategy: "Webhook".to_string(), +                webhook: Some(WebhookConversion { +                    // conversionReviewVersions indicates what ConversionReview versions are understood/preferred by the webhook.
+ // The first version in the list understood by the API server is sent to the webhook. + // The webhook must respond with a ConversionReview object in the same version it received. + conversion_review_versions: vec!["v1".to_string()], + client_config: Some(WebhookClientConfig { + service: Some(ServiceReference { + name: operator_environment.operator_service_name.clone(), + namespace: operator_environment.operator_namespace.clone(), + path: Some(format!("/convert/{crd_name}")), + port: Some(DEFAULT_HTTPS_PORT.into()), + }), + ca_bundle: Some(ByteString(ca_bundle.as_bytes().to_vec())), + url: None, + }), + }), + }); + + let patch = Patch::Apply(&crd); + let patch_params = PatchParams::apply(field_manager); + crd_api + .patch(&crd_name, &patch_params, &patch) + .await + .with_context(|_| UpdateCRDSnafu { + crd_name: crd_name.to_string(), + })?; + } + Ok(()) } } diff --git a/crates/stackable-webhook/src/servers/mod.rs b/crates/stackable-webhook/src/servers/mod.rs index b242df779..2d87b6cad 100644 --- a/crates/stackable-webhook/src/servers/mod.rs +++ b/crates/stackable-webhook/src/servers/mod.rs @@ -2,4 +2,5 @@ //! purposes. mod conversion; -pub use conversion::*; +pub use conversion::{ConversionWebhookError, ConversionWebhookServer}; +pub use kube::core::conversion::ConversionReview; diff --git a/crates/stackable-webhook/src/tls.rs b/crates/stackable-webhook/src/tls.rs deleted file mode 100644 index 2aad52ee4..000000000 --- a/crates/stackable-webhook/src/tls.rs +++ /dev/null @@ -1,283 +0,0 @@ -//! This module contains structs and functions to easily create a TLS termination -//! server, which can be used in combination with an Axum [`Router`]. 
-use std::{net::SocketAddr, sync::Arc}; - -use axum::{Router, extract::Request}; -use futures_util::pin_mut; -use hyper::{body::Incoming, service::service_fn}; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use opentelemetry::trace::{FutureExt, SpanKind}; -use snafu::{ResultExt, Snafu}; -use stackable_certs::{ - CertificatePairError, - ca::{CertificateAuthority, DEFAULT_CA_VALIDITY}, - keys::rsa, -}; -use tokio::net::TcpListener; -use tokio_rustls::{ - TlsAcceptor, - rustls::{ - ServerConfig, - crypto::ring::default_provider, - version::{TLS12, TLS13}, - }, -}; -use tower::{Service, ServiceExt}; -use tracing::{Instrument, Span, field::Empty, instrument}; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -pub type Result = std::result::Result; - -#[derive(Debug, Snafu)] -pub enum Error { - #[snafu(display("failed to construct TLS server config, bad certificate/key"))] - InvalidTlsPrivateKey { source: tokio_rustls::rustls::Error }, - - #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))] - BindTcpListener { - source: std::io::Error, - socket_addr: SocketAddr, - }, - - #[snafu(display("failed to create CA to generate and sign webhook leaf certificate"))] - CreateCertificateAuthority { source: stackable_certs::ca::Error }, - - #[snafu(display("failed to generate webhook leaf certificate"))] - GenerateLeafCertificate { source: stackable_certs::ca::Error }, - - #[snafu(display("failed to encode leaf certificate as DER"))] - EncodeCertificateDer { - source: CertificatePairError, - }, - - #[snafu(display("failed to encode private key as DER"))] - EncodePrivateKeyDer { - source: CertificatePairError, - }, - - #[snafu(display("failed to set safe TLS protocol versions"))] - SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error }, - - #[snafu(display("failed to run task in blocking thread"))] - TokioSpawnBlocking { source: tokio::task::JoinError }, -} - -/// Custom implementation of [`std::cmp::PartialEq`] because some 
inner types -/// don't implement it. -/// -/// Note that this implementation is restritced to testing because there are -/// variants that use [`stackable_certs::ca::Error`] which only implements -/// [`PartialEq`] for tests. -#[cfg(test)] -impl PartialEq for Error { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - ( - Self::BindTcpListener { - source: lhs_source, - socket_addr: lhs_socket_addr, - }, - Self::BindTcpListener { - source: rhs_source, - socket_addr: rhs_socket_addr, - }, - ) => lhs_socket_addr == rhs_socket_addr && lhs_source.kind() == rhs_source.kind(), - (lhs, rhs) => lhs == rhs, - } - } -} - -/// A server which terminates TLS connections and allows clients to commnunicate -/// via HTTPS with the underlying HTTP router. -pub struct TlsServer { - config: Arc, - socket_addr: SocketAddr, - router: Router, -} - -impl TlsServer { - #[instrument(name = "create_tls_server", skip(router))] - pub async fn new(socket_addr: SocketAddr, router: Router) -> Result { - // NOTE(@NickLarsenNZ): This code is not async, and does take some - // non-negligable amount of time to complete (moreso in debug ). - // We run this in a thread reserved for blocking code so that the Tokio - // executor is able to make progress on other futures instead of being - // blocked. 
- // See https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html - let task = tokio::task::spawn_blocking(move || { - let mut certificate_authority = - CertificateAuthority::new_rsa().context(CreateCertificateAuthoritySnafu)?; - - let leaf_certificate = certificate_authority - .generate_rsa_leaf_certificate("Leaf", "webhook", [], DEFAULT_CA_VALIDITY) - .context(GenerateLeafCertificateSnafu)?; - - let certificate_der = leaf_certificate - .certificate_der() - .context(EncodeCertificateDerSnafu)?; - - let private_key_der = leaf_certificate - .private_key_der() - .context(EncodePrivateKeyDerSnafu)?; - - let tls_provider = default_provider(); - let mut config = ServerConfig::builder_with_provider(tls_provider.into()) - .with_protocol_versions(&[&TLS12, &TLS13]) - .context(SetSafeTlsProtocolVersionsSnafu)? - .with_no_client_auth() - .with_single_cert(vec![certificate_der], private_key_der) - .context(InvalidTlsPrivateKeySnafu)?; - - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let config = Arc::new(config); - - Ok(Self { - socket_addr, - config, - router, - }) - }) - .await - .context(TokioSpawnBlockingSnafu)??; - - Ok(task) - } - - /// Runs the TLS server by listening for incoming TCP connections on the - /// bound socket address. It only accepts TLS connections. Internally each - /// TLS stream get handled by a Hyper service, which in turn is an Axum - /// router. - pub async fn run(self) -> Result<()> { - let tls_acceptor = TlsAcceptor::from(self.config); - let tcp_listener = - TcpListener::bind(self.socket_addr) - .await - .context(BindTcpListenerSnafu { - socket_addr: self.socket_addr, - })?; - - // To be able to extract the connect info from incoming requests, it is - // required to turn the router into a Tower service which is capable of - // doing that. Calling `into_make_service_with_connect_info` returns a - // new struct `IntoMakeServiceWithConnectInfo` which implements the - // Tower Service trait. 
This service is called after the TCP connection - // has been accepted. - // - // Inspired by: - // - https://github.com/tokio-rs/axum/discussions/2397 - // - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98 - - let mut router = self - .router - .into_make_service_with_connect_info::(); - - pin_mut!(tcp_listener); - loop { - let tls_acceptor = tls_acceptor.clone(); - - // Wait for new tcp connection - let (tcp_stream, remote_addr) = match tcp_listener.accept().await { - Ok((stream, addr)) => (stream, addr), - Err(err) => { - tracing::trace!(%err, "failed to accept incoming TCP connection"); - continue; - } - }; - - // Here, the connect info is extracted by calling Tower's Service - // trait function on `IntoMakeServiceWithConnectInfo` - let tower_service = router.call(remote_addr).await.unwrap(); - - let span = tracing::debug_span!("accept tcp connection"); - tokio::spawn( - async move { - let span = tracing::trace_span!( - "accept tls connection", - "otel.kind" = ?SpanKind::Server, - "otel.status_code" = Empty, - "otel.status_message" = Empty, - "client.address" = remote_addr.ip().to_string(), - "client.port" = remote_addr.port() as i64, - "server.address" = Empty, - "server.port" = Empty, - "network.peer.address" = remote_addr.ip().to_string(), - "network.peer.port" = remote_addr.port() as i64, - "network.local.address" = Empty, - "network.local.port" = Empty, - "network.transport" = "tcp", - "network.type" = self.socket_addr.semantic_convention_network_type(), - ); - - if let Ok(local_addr) = tcp_stream.local_addr() { - let addr = &local_addr.ip().to_string(); - let port = local_addr.port(); - span.record("server.address", addr) - .record("server.port", port as i64) - .record("network.local.address", addr) - .record("network.local.port", port as i64); - } - - // Wait for tls handshake to happen - let tls_stream = match tls_acceptor - .accept(tcp_stream) - .instrument(span.clone()) - .await - 
{ - Ok(tls_stream) => tls_stream, - Err(err) => { - span.record("otel.status_code", "Error") - .record("otel.status_message", err.to_string()); - tracing::trace!(%remote_addr, "error during tls handshake connection"); - return; - } - }; - - // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. - // `TokioIo` converts between them. - let tls_stream = TokioIo::new(tls_stream); - - // Hyper also has its own `Service` trait and doesn't use tower. We can use - // `hyper::service::service_fn` to create a hyper `Service` that calls our app through - // `tower::Service::call`. - let hyper_service = service_fn(move |request: Request| { - // This carries the current context with the trace id so that the TraceLayer can use that as a parent - let otel_context = Span::current().context(); - // We need to clone here, because oneshot consumes self - tower_service - .clone() - .oneshot(request) - .with_context(otel_context) - }); - - let span = tracing::debug_span!("serve connection"); - hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(tls_stream, hyper_service) - .instrument(span.clone()) - .await - .unwrap_or_else(|err| { - span.record("otel.status_code", "Error") - .record("otel.status_message", err.to_string()); - tracing::warn!(%err, %remote_addr, "failed to serve connection"); - }) - } - .instrument(span), - ); - } - } -} - -pub trait SocketAddrExt { - fn semantic_convention_network_type(&self) -> &'static str; -} - -impl SocketAddrExt for SocketAddr { - fn semantic_convention_network_type(&self) -> &'static str { - match self { - SocketAddr::V4(_) => "ipv4", - SocketAddr::V6(_) => "ipv6", - } - } -} - -// TODO (@NickLarsenNZ): impl record_error(err: impl Error) for Span as a shortcut to set otel.status_* fields -// TODO (@NickLarsenNZ): wrap tracing::span macros to automatically add otel fields diff --git a/crates/stackable-webhook/src/tls/cert_resolver.rs 
b/crates/stackable-webhook/src/tls/cert_resolver.rs new file mode 100644 index 000000000..e8417f9bf --- /dev/null +++ b/crates/stackable-webhook/src/tls/cert_resolver.rs @@ -0,0 +1,156 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use snafu::{ResultExt, Snafu}; +use stackable_certs::{CertificatePairError, ca::CertificateAuthority, keys::ecdsa}; +use tokio::sync::mpsc; +use tokio_rustls::rustls::{ + crypto::ring::default_provider, server::ResolvesServerCert, sign::CertifiedKey, +}; +use x509_cert::Certificate; + +use super::{WEBHOOK_CA_LIFETIME, WEBHOOK_CERTIFICATE_LIFETIME}; + +type Result = std::result::Result; + +#[derive(Debug, Snafu)] +pub enum CertificateResolverError { + #[snafu(display("failed send certificate to channel"))] + SendCertificateToChannel, + + #[snafu(display("failed to generate ECDSA signing key"))] + GenerateEcdsaSigningKey { source: ecdsa::Error }, + + #[snafu(display("failed to generate new certificate"))] + GenerateNewCertificate { + #[snafu(source(from(CertificateResolverError, Box::new)))] + source: Box, + }, + + #[snafu(display("failed to create CA to generate and sign webhook leaf certificate"))] + CreateCertificateAuthority { source: stackable_certs::ca::Error }, + + #[snafu(display("failed to generate webhook leaf certificate"))] + GenerateLeafCertificate { source: stackable_certs::ca::Error }, + + #[snafu(display("failed to encode leaf certificate as DER"))] + EncodeCertificateDer { + source: CertificatePairError, + }, + + #[snafu(display("failed to encode private key as DER"))] + EncodePrivateKeyDer { + source: CertificatePairError, + }, + + #[snafu(display("failed to decode CertifiedKey from DER"))] + DecodeCertifiedKeyFromDer { source: tokio_rustls::rustls::Error }, + + #[snafu(display("failed to run task in blocking thread"))] + TokioSpawnBlocking { source: tokio::task::JoinError }, +} + +/// This struct serves as [`ResolvesServerCert`] to always hand out the current certificate for TLS +/// client connections. 
+/// +/// It offers the [`Self::rotate_certificate`] function to create a fresh certificate and basically +/// hot-reload the certificate in the running webhook. +#[derive(Debug)] +pub struct CertificateResolver { + /// Using a [`ArcSwap`] (over e.g. [`tokio::sync::RwLock`]), so that we can easily + /// (and performant) bridge between async write and sync write. + current_certified_key: ArcSwap, + subject_alterative_dns_names: Arc>, + + cert_tx: mpsc::Sender, +} + +impl CertificateResolver { + pub async fn new( + subject_alterative_dns_names: Vec, + cert_tx: mpsc::Sender, + ) -> Result { + let subject_alterative_dns_names = Arc::new(subject_alterative_dns_names); + let (cert, certified_key) = Self::generate_new_cert(subject_alterative_dns_names.clone()) + .await + .context(GenerateNewCertificateSnafu)?; + + cert_tx + .send(cert) + .await + .map_err(|_err| CertificateResolverError::SendCertificateToChannel)?; + + Ok(Self { + subject_alterative_dns_names, + current_certified_key: ArcSwap::new(certified_key), + cert_tx, + }) + } + + pub async fn rotate_certificate(&self) -> Result<()> { + let (cert, certified_key) = + Self::generate_new_cert(self.subject_alterative_dns_names.clone()) + .await + .context(GenerateNewCertificateSnafu)?; + + // TODO: Sign the new cert somehow with the old cert. See https://github.com/stackabletech/decisions/issues/56 + + self.cert_tx + .send(cert) + .await + .map_err(|_err| CertificateResolverError::SendCertificateToChannel)?; + + self.current_certified_key.store(certified_key); + + Ok(()) + } + + /// FIXME: This should *not* construct a CA cert and cert, but only a cert! + /// This needs some changes in stackable-certs though. 
+ /// See [the relevant decision](https://github.com/stackabletech/decisions/issues/56) + async fn generate_new_cert( + subject_alterative_dns_names: Arc>, + ) -> Result<(Certificate, Arc)> { + // The certificate generations can take a while, so we use `spawn_blocking` + tokio::task::spawn_blocking(move || { + let tls_provider = default_provider(); + + let ca_key = ecdsa::SigningKey::new().context(GenerateEcdsaSigningKeySnafu)?; + let mut ca = + CertificateAuthority::new_with(ca_key, rand::random::(), WEBHOOK_CA_LIFETIME) + .context(CreateCertificateAuthoritySnafu)?; + + let certificate = ca + .generate_ecdsa_leaf_certificate( + "Leaf", + "webhook", + subject_alterative_dns_names.iter().map(|san| san.as_str()), + WEBHOOK_CERTIFICATE_LIFETIME, + ) + .context(GenerateLeafCertificateSnafu)?; + + let certificate_der = certificate + .certificate_der() + .context(EncodeCertificateDerSnafu)?; + let private_key_der = certificate + .private_key_der() + .context(EncodePrivateKeyDerSnafu)?; + let certificate_key = + CertifiedKey::from_der(vec![certificate_der], private_key_der, &tls_provider) + .context(DecodeCertifiedKeyFromDerSnafu)?; + + Ok((certificate.certificate().clone(), Arc::new(certificate_key))) + }) + .await + .context(TokioSpawnBlockingSnafu)? + } +} + +impl ResolvesServerCert for CertificateResolver { + fn resolve( + &self, + _client_hello: tokio_rustls::rustls::server::ClientHello<'_>, + ) -> Option> { + Some(self.current_certified_key.load().clone()) + } +} diff --git a/crates/stackable-webhook/src/tls/mod.rs b/crates/stackable-webhook/src/tls/mod.rs new file mode 100644 index 000000000..d1cb10b95 --- /dev/null +++ b/crates/stackable-webhook/src/tls/mod.rs @@ -0,0 +1,277 @@ +//! This module contains structs and functions to easily create a TLS termination +//! server, which can be used in combination with an Axum [`Router`]. 
+use std::{convert::Infallible, net::SocketAddr, sync::Arc}; + +use axum::{ + Router, + extract::{ConnectInfo, Request}, + middleware::AddExtension, +}; +use hyper::{body::Incoming, service::service_fn}; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use opentelemetry::trace::{FutureExt, SpanKind}; +use snafu::{ResultExt, Snafu}; +use stackable_operator::time::Duration; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::mpsc, +}; +use tokio_rustls::{ + TlsAcceptor, + rustls::{ + ServerConfig, + crypto::ring::default_provider, + version::{TLS12, TLS13}, + }, +}; +use tower::{Service, ServiceExt}; +use tracing::{Instrument, Span, field::Empty, instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; +use x509_cert::Certificate; + +mod cert_resolver; + +pub use cert_resolver::{CertificateResolver, CertificateResolverError}; + +pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_minutes_unchecked(3); +pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_minutes_unchecked(2); +pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_minutes_unchecked(1); + +pub type Result = std::result::Result; + +#[derive(Debug, Snafu)] +pub enum TlsServerError { + #[snafu(display("failed to create certificate resolver"))] + CreateCertificateResolver { source: CertificateResolverError }, + + #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))] + BindTcpListener { + source: std::io::Error, + socket_addr: SocketAddr, + }, + + #[snafu(display("failed to rotate certificate"))] + RotateCertificate { source: CertificateResolverError }, + + #[snafu(display("failed to set safe TLS protocol versions"))] + SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error }, +} + +/// A server which terminates TLS connections and allows clients to communicate +/// via HTTPS with the underlying HTTP router. +/// +/// It also rotates the generated certificates as needed. 
+pub struct TlsServer { + config: ServerConfig, + cert_resolver: Arc, + + socket_addr: SocketAddr, + router: Router, +} + +impl TlsServer { + /// Create a new [`TlsServer`]. + /// + /// This create a [`CertificateResolver`] with the provided `subject_alterative_dns_names`, + /// which takes care of the certificate rotation. Afterwards it create the [`ServerConfig`], + /// which let's the [`CertificateResolver`] provide the needed certificates. + #[instrument(name = "create_tls_server", skip(router))] + pub async fn new( + socket_addr: SocketAddr, + router: Router, + subject_alterative_dns_names: Vec, + ) -> Result<(Self, mpsc::Receiver)> { + let (cert_tx, cert_rx) = mpsc::channel(1); + + let cert_resolver = CertificateResolver::new(subject_alterative_dns_names, cert_tx) + .await + .context(CreateCertificateResolverSnafu)?; + let cert_resolver = Arc::new(cert_resolver); + + let tls_provider = default_provider(); + let mut config = ServerConfig::builder_with_provider(tls_provider.into()) + .with_protocol_versions(&[&TLS12, &TLS13]) + .context(SetSafeTlsProtocolVersionsSnafu)? + .with_no_client_auth() + .with_cert_resolver(cert_resolver.clone()); + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + let tls_server = Self { + config, + cert_resolver, + socket_addr, + router, + }; + + Ok((tls_server, cert_rx)) + } + + /// Runs the TLS server by listening for incoming TCP connections on the + /// bound socket address. It only accepts TLS connections. Internally each + /// TLS stream get handled by a Hyper service, which in turn is an Axum + /// router. + /// + /// It also starts a background task to rotate the certificate as needed. 
+ pub async fn run(self) -> Result<()> { + let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL; + let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL); + + let tls_acceptor = TlsAcceptor::from(Arc::new(self.config)); + let tcp_listener = + TcpListener::bind(self.socket_addr) + .await + .context(BindTcpListenerSnafu { + socket_addr: self.socket_addr, + })?; + + // To be able to extract the connect info from incoming requests, it is + // required to turn the router into a Tower service which is capable of + // doing that. Calling `into_make_service_with_connect_info` returns a + // new struct `IntoMakeServiceWithConnectInfo` which implements the + // Tower Service trait. This service is called after the TCP connection + // has been accepted. + // + // Inspired by: + // - https://github.com/tokio-rs/axum/discussions/2397 + // - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98 + + let mut router = self + .router + .into_make_service_with_connect_info::(); + + loop { + let tls_acceptor = tls_acceptor.clone(); + + // Wait for either a new TCP connection or the certificate rotation interval tick + tokio::select! { + // We opt for a biased execution of arms to make sure we always check if the + // certificate needs rotation based on the interval. This ensures, we always use + // a valid certificate for the TLS connection. + biased; + + // This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed. + // As such, we will not miss rotating the certificate. + _ = interval.tick() => { + self.cert_resolver + .rotate_certificate() + .await + .context(RotateCertificateSnafu)? + } + + // This is cancellation-safe. If cancelled, no new connections are accepted. 
+ tcp_connection = tcp_listener.accept() => { + let (tcp_stream, remote_addr) = match tcp_connection { + Ok((stream, addr)) => (stream, addr), + Err(err) => { + tracing::trace!(%err, "failed to accept incoming TCP connection"); + continue; + } + }; + + // Here, the connect info is extracted by calling Tower's Service + // trait function on `IntoMakeServiceWithConnectInfo` + let tower_service: Result<_, Infallible> = router.call(remote_addr).await; + let tower_service = tower_service.expect("Infallible error can never happen"); + + let span = tracing::debug_span!("accept tcp connection"); + tokio::spawn(async move { + Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr) + }.instrument(span)); + } + }; + } + } + + async fn handle_request( + tcp_stream: TcpStream, + remote_addr: SocketAddr, + tls_acceptor: TlsAcceptor, + tower_service: AddExtension>, + socket_addr: SocketAddr, + ) { + let span = tracing::trace_span!( + "accept tls connection", + "otel.kind" = ?SpanKind::Server, + "otel.status_code" = Empty, + "otel.status_message" = Empty, + "client.address" = remote_addr.ip().to_string(), + "client.port" = remote_addr.port() as i64, + "server.address" = Empty, + "server.port" = Empty, + "network.peer.address" = remote_addr.ip().to_string(), + "network.peer.port" = remote_addr.port() as i64, + "network.local.address" = Empty, + "network.local.port" = Empty, + "network.transport" = "tcp", + "network.type" = socket_addr.semantic_convention_network_type(), + ); + + if let Ok(local_addr) = tcp_stream.local_addr() { + let addr = &local_addr.ip().to_string(); + let port = local_addr.port(); + span.record("server.address", addr) + .record("server.port", port as i64) + .record("network.local.address", addr) + .record("network.local.port", port as i64); + } + + // Wait for tls handshake to happen + let tls_stream = match tls_acceptor + .accept(tcp_stream) + .instrument(span.clone()) + .await + { + Ok(tls_stream) => tls_stream, + Err(err) 
=> { + span.record("otel.status_code", "Error") + .record("otel.status_message", err.to_string()); + tracing::trace!(%remote_addr, "error during tls handshake connection"); + return; + } + }; + + // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. + // `TokioIo` converts between them. + let tls_stream = TokioIo::new(tls_stream); + + // Hyper also has its own `Service` trait and doesn't use tower. We can use + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through + // `tower::Service::call`. + let hyper_service = service_fn(move |request: Request| { + // This carries the current context with the trace id so that the TraceLayer can use that as a parent + let otel_context = Span::current().context(); + // We need to clone here, because oneshot consumes self + tower_service + .clone() + .oneshot(request) + .with_context(otel_context) + }); + + let span = tracing::debug_span!("serve connection"); + hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(tls_stream, hyper_service) + .instrument(span.clone()) + .await + .unwrap_or_else(|err| { + span.record("otel.status_code", "Error") + .record("otel.status_message", err.to_string()); + tracing::warn!(%err, %remote_addr, "failed to serve connection"); + }) + } +} + +pub trait SocketAddrExt { + fn semantic_convention_network_type(&self) -> &'static str; +} + +impl SocketAddrExt for SocketAddr { + fn semantic_convention_network_type(&self) -> &'static str { + match self { + SocketAddr::V4(_) => "ipv4", + SocketAddr::V6(_) => "ipv6", + } + } +} + +// TODO (@NickLarsenNZ): impl record_error(err: impl Error) for Span as a shortcut to set otel.status_* fields +// TODO (@NickLarsenNZ): wrap tracing::span macros to automatically add otel fields