Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 55 additions & 58 deletions src/bin/sadmin/client_daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ use tokio::{
},
process::ChildStdin,
select,
sync::{
mpsc::{UnboundedReceiver, UnboundedSender},
Notify,
},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
time::timeout,
};
use tokio_rustls::{client::TlsStream, rustls, TlsConnector};
Expand All @@ -52,6 +49,7 @@ use crate::{
connection::Config,
persist_daemon,
service_control::DaemonControlMessage,
state::State,
tokio_passfd::{self},
};
use sdnotify::SdNotify;
Expand Down Expand Up @@ -103,16 +101,20 @@ pub struct ClientDaemon {
pub type PersistMessageSender =
tokio::sync::oneshot::Sender<(persist_daemon::Message, Option<OwnedFd>)>;

#[derive(PartialEq, Eq)]
enum ConnectionState {
Good,
Bad,
}

pub struct Client {
connector: TlsConnector,
pub config: Config,
command_tasks: Mutex<HashMap<u64, Arc<dyn TaskBase>>>,
send_failure_notify: Notify,
sender_clear: Notify,
new_send_notify: Notify,
sender: tokio::sync::Mutex<
Option<WriteHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>>,
>,
connection_state: State<ConnectionState>,
script_stdin: Mutex<HashMap<u64, UnboundedSender<DataMessage>>>,
persist_responses: Mutex<HashMap<u64, PersistMessageSender>>,
persist_idc: AtomicU64,
Expand Down Expand Up @@ -141,34 +143,46 @@ impl Client {
message.push(30);
loop {
let mut s = self.sender.lock().await;
if let Some(v) = s.deref_mut() {
let write_all = write_all_and_flush(v, &message);
let sender_clear = self.sender_clear.notified();
let sleep = tokio::time::sleep(Duration::from_secs(40));
tokio::select!(
val = write_all => {
if let Err(e) = val {
// The send errored out, notify the recv half so we can try to initiate a new connection
error!("Failed sending message to backend: {}", e);
self.send_failure_notify.notify_one();
*s = None
}
break
}
_ = sender_clear => {},
_ = sleep => {
// The send timeouted, notify the recv half so we can try to initiate a new connection
error!("Timout sending message to server");
self.send_failure_notify.notify_one();
*s = None
}
);
if *self.connection_state.get() != ConnectionState::Good {
std::mem::drop(s);
// We do not currently have a send socket so lets wait for one
info!("We do not currently have a send socket so lets wait for one");
self.connection_state
.wait(|s| s == &ConnectionState::Good)
.await;
continue;
}
// We do not currently have a send socket so lets wait for one
info!("We do not currently have a send socket so lets wait for one");
std::mem::drop(s);
self.new_send_notify.notified().await;
};
let Some(v) = s.deref_mut() else {
std::mem::drop(s);
// We do not currently have a send socket so lets wait for one
error!("Logic error do not currently have a send socket so lets wait for one");
self.connection_state.set(ConnectionState::Bad);
self.connection_state
.wait(|s| s == &ConnectionState::Good)
.await;
continue;
};
let write_all = write_all_and_flush(v, &message);
let disconnected = self.connection_state.wait(|v| v != &ConnectionState::Good);
let sleep = tokio::time::sleep(Duration::from_secs(40));
tokio::select!(
val = write_all => {
if let Err(e) = val {
// The send errored out, notify the recv half so we can try to initiate a new connection
error!("Failed sending message to backend: {}", e);
self.connection_state.set(ConnectionState::Bad);
}
break
}
_ = disconnected => {
// We are disconnected, wait for reconnect
},
_ = sleep => {
// The send timeouted, notify the recv half so we can try to initiate a new connection
error!("Timout sending message to server");
self.connection_state.set(ConnectionState::Bad);
}
);
}
}

Expand Down Expand Up @@ -940,8 +954,9 @@ impl Client {
auth_message.push(30);
write_all_and_flush(&mut write, &auth_message).await?;

*self.sender.lock().await = Some(write);
self.new_send_notify.notify_one();
let mut l = self.sender.lock().await;
*l = Some(write);
self.connection_state.set(ConnectionState::Good);
Ok(read)
}

Expand Down Expand Up @@ -1005,7 +1020,7 @@ impl Client {
}
let mut start = buffer.len();
let read = read.read_buf(&mut buffer);
let send_failure = self.send_failure_notify.notified();
let disconnect = self.connection_state.wait(|v| v != &ConnectionState::Good);
let sleep = tokio::time::sleep(Duration::from_secs(120));
run_token.set_location(file!(), line!());
tokio::select! {
Expand All @@ -1022,7 +1037,7 @@ impl Client {
}
}
}
_ = send_failure => {
_ = disconnect => {
break
}
_ = sleep => {
Expand Down Expand Up @@ -1059,23 +1074,7 @@ impl Client {
break;
}
}
info!("Trying to take sender for disconnect");
run_token.set_location(file!(), line!());
{
let f = async {
loop {
self.sender_clear.notify_waiters();
self.sender_clear.notify_one();
tokio::time::sleep(Duration::from_millis(1)).await
}
};
tokio::select! {
mut l = self.sender.lock() => {
let _sender = l.take();
}
() = f => {panic!()}
}
}
self.connection_state.set(ConnectionState::Bad);
run_token.set_location(file!(), line!());
info!("Took sender for disconnect");
if let Some(notifier) = &notifier {
Expand Down Expand Up @@ -1632,9 +1631,7 @@ pub async fn client_daemon(config: Config, args: ClientDaemon) -> Result<()> {
config,
db,
command_tasks: Default::default(),
send_failure_notify: Default::default(),
sender_clear: Default::default(),
new_send_notify: Default::default(),
connection_state: State::new(ConnectionState::Bad),
sender: Default::default(),
script_stdin: Default::default(),
persist_responses: Default::default(),
Expand Down
2 changes: 2 additions & 0 deletions src/bin/sadmin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ mod run;
mod service_control;
mod service_deploy;
#[cfg(feature = "daemon")]
mod state;
#[cfg(feature = "daemon")]
mod tokio_passfd;
mod upgrade;

Expand Down
83 changes: 83 additions & 0 deletions src/bin/sadmin/state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::{
future::Future,
ops::Deref,
sync::{Mutex, MutexGuard},
task::Waker,
};

struct StateContent<T> {
state: T,
waiters: Vec<Waker>,
}

/// Store some state T, that can be mutated
/// It is possible to wait for the value to get into a specific state
pub struct State<T> {
content: std::sync::Mutex<StateContent<T>>,
}
pub struct StateValue<'a, T>(MutexGuard<'a, StateContent<T>>);

impl<T> Deref for StateValue<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.0.state
}
}

pub struct StateWaiter<'a, T, P: Fn(&T) -> bool> {
state: &'a State<T>,
p: P,
}

impl<'a, T, P: Fn(&T) -> bool> Future for StateWaiter<'a, T, P> {
type Output = StateValue<'a, T>;

fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let mut content = self.state.content.lock().unwrap();
if (self.p)(&content.state) {
std::task::Poll::Ready(StateValue(content))
} else {
content.waiters.push(cx.waker().clone());
std::task::Poll::Pending
}
}
}

unsafe impl<T, P: Fn(&T) -> bool> Send for StateWaiter<'_, T, P> {}

impl<T: Eq> State<T> {
pub fn new(v: T) -> Self {
State {
content: Mutex::new(StateContent {
state: v,
waiters: Vec::new(),
}),
}
}

/// Update the value to v, notify any waiters where v fulfills the predicate
pub fn set(&self, v: T) {
let mut inner = self.content.lock().unwrap();
if inner.state == v {
return;
}
for w in std::mem::take(&mut inner.waiters) {
w.wake();
}
inner.state = v;
}

/// Get the current value, return a wrapper of the mutex lock
pub fn get(&self) -> StateValue<'_, T> {
StateValue(self.content.lock().unwrap())
}

/// Return future waiting for predicate to full some predicate
pub fn wait<P: Fn(&T) -> bool>(&self, p: P) -> StateWaiter<'_, T, P> {
StateWaiter { state: self, p }
}
}
Loading