Skip to content

Commit af86e40

Browse files
committed
client-daemon: Fix more race conditions in upstream connection handling
1 parent e2645b2 commit af86e40

File tree

2 files changed

+57
-58
lines changed

2 files changed

+57
-58
lines changed

src/bin/sadmin/client_daemon.rs

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@ use tokio::{
3030
},
3131
process::ChildStdin,
3232
select,
33-
sync::{
34-
mpsc::{UnboundedReceiver, UnboundedSender},
35-
Notify,
36-
},
33+
sync::mpsc::{UnboundedReceiver, UnboundedSender},
3734
time::timeout,
3835
};
3936
use tokio_rustls::{client::TlsStream, rustls, TlsConnector};
@@ -52,6 +49,7 @@ use crate::{
5249
connection::Config,
5350
persist_daemon,
5451
service_control::DaemonControlMessage,
52+
state::State,
5553
tokio_passfd::{self},
5654
};
5755
use sdnotify::SdNotify;
@@ -103,16 +101,20 @@ pub struct ClientDaemon {
103101
pub type PersistMessageSender =
104102
tokio::sync::oneshot::Sender<(persist_daemon::Message, Option<OwnedFd>)>;
105103

104+
#[derive(PartialEq, Eq)]
105+
enum ConnectionState {
106+
Good,
107+
Bad,
108+
}
109+
106110
pub struct Client {
107111
connector: TlsConnector,
108112
pub config: Config,
109113
command_tasks: Mutex<HashMap<u64, Arc<dyn TaskBase>>>,
110-
send_failure_notify: Notify,
111-
sender_clear: Notify,
112-
new_send_notify: Notify,
113114
sender: tokio::sync::Mutex<
114115
Option<WriteHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>>,
115116
>,
117+
connection_state: State<ConnectionState>,
116118
script_stdin: Mutex<HashMap<u64, UnboundedSender<DataMessage>>>,
117119
persist_responses: Mutex<HashMap<u64, PersistMessageSender>>,
118120
persist_idc: AtomicU64,
@@ -141,34 +143,46 @@ impl Client {
141143
message.push(30);
142144
loop {
143145
let mut s = self.sender.lock().await;
144-
if let Some(v) = s.deref_mut() {
145-
let write_all = write_all_and_flush(v, &message);
146-
let sender_clear = self.sender_clear.notified();
147-
let sleep = tokio::time::sleep(Duration::from_secs(40));
148-
tokio::select!(
149-
val = write_all => {
150-
if let Err(e) = val {
151-
// The send errored out, notify the recv half so we can try to initiate a new connection
152-
error!("Failed sending message to backend: {}", e);
153-
self.send_failure_notify.notify_one();
154-
*s = None
155-
}
156-
break
157-
}
158-
_ = sender_clear => {},
159-
_ = sleep => {
160-
// The send timeouted, notify the recv half so we can try to initiate a new connection
161-
error!("Timout sending message to server");
162-
self.send_failure_notify.notify_one();
163-
*s = None
164-
}
165-
);
146+
if *self.connection_state.get() != ConnectionState::Good {
147+
std::mem::drop(s);
148+
// We do not currently have a send socket so lets wait for one
149+
info!("We do not currently have a send socket so lets wait for one");
150+
self.connection_state
151+
.wait(|s| s == &ConnectionState::Good)
152+
.await;
166153
continue;
167-
}
168-
// We do not currently have a send socket so lets wait for one
169-
info!("We do not currently have a send socket so lets wait for one");
170-
std::mem::drop(s);
171-
self.new_send_notify.notified().await;
154+
};
155+
let Some(v) = s.deref_mut() else {
156+
std::mem::drop(s);
157+
// We do not currently have a send socket so lets wait for one
158+
error!("Logic error do not currently have a send socket so lets wait for one");
159+
self.connection_state.set(ConnectionState::Bad);
160+
self.connection_state
161+
.wait(|s| s == &ConnectionState::Good)
162+
.await;
163+
continue;
164+
};
165+
let write_all = write_all_and_flush(v, &message);
166+
let disconnected = self.connection_state.wait(|v| v != &ConnectionState::Good);
167+
let sleep = tokio::time::sleep(Duration::from_secs(40));
168+
tokio::select!(
169+
val = write_all => {
170+
if let Err(e) = val {
171+
// The send errored out, notify the recv half so we can try to initiate a new connection
172+
error!("Failed sending message to backend: {}", e);
173+
self.connection_state.set(ConnectionState::Bad);
174+
}
175+
break
176+
}
177+
_ = disconnected => {
178+
// We are disconnected, wait for reconnect
179+
},
180+
_ = sleep => {
181+
// The send timeouted, notify the recv half so we can try to initiate a new connection
182+
error!("Timout sending message to server");
183+
self.connection_state.set(ConnectionState::Bad);
184+
}
185+
);
172186
}
173187
}
174188

@@ -940,8 +954,9 @@ impl Client {
940954
auth_message.push(30);
941955
write_all_and_flush(&mut write, &auth_message).await?;
942956

943-
*self.sender.lock().await = Some(write);
944-
self.new_send_notify.notify_one();
957+
let mut l = self.sender.lock().await;
958+
*l = Some(write);
959+
self.connection_state.set(ConnectionState::Good);
945960
Ok(read)
946961
}
947962

@@ -1005,7 +1020,7 @@ impl Client {
10051020
}
10061021
let mut start = buffer.len();
10071022
let read = read.read_buf(&mut buffer);
1008-
let send_failure = self.send_failure_notify.notified();
1023+
let disconnect = self.connection_state.wait(|v| v != &ConnectionState::Good);
10091024
let sleep = tokio::time::sleep(Duration::from_secs(120));
10101025
run_token.set_location(file!(), line!());
10111026
tokio::select! {
@@ -1022,7 +1037,7 @@ impl Client {
10221037
}
10231038
}
10241039
}
1025-
_ = send_failure => {
1040+
_ = disconnect => {
10261041
break
10271042
}
10281043
_ = sleep => {
@@ -1059,23 +1074,7 @@ impl Client {
10591074
break;
10601075
}
10611076
}
1062-
info!("Trying to take sender for disconnect");
1063-
run_token.set_location(file!(), line!());
1064-
{
1065-
let f = async {
1066-
loop {
1067-
self.sender_clear.notify_waiters();
1068-
self.sender_clear.notify_one();
1069-
tokio::time::sleep(Duration::from_millis(1)).await
1070-
}
1071-
};
1072-
tokio::select! {
1073-
mut l = self.sender.lock() => {
1074-
let _sender = l.take();
1075-
}
1076-
() = f => {panic!()}
1077-
}
1078-
}
1077+
self.connection_state.set(ConnectionState::Bad);
10791078
run_token.set_location(file!(), line!());
10801079
info!("Took sender for disconnect");
10811080
if let Some(notifier) = &notifier {
@@ -1632,9 +1631,7 @@ pub async fn client_daemon(config: Config, args: ClientDaemon) -> Result<()> {
16321631
config,
16331632
db,
16341633
command_tasks: Default::default(),
1635-
send_failure_notify: Default::default(),
1636-
sender_clear: Default::default(),
1637-
new_send_notify: Default::default(),
1634+
connection_state: State::new(ConnectionState::Bad),
16381635
sender: Default::default(),
16391636
script_stdin: Default::default(),
16401637
persist_responses: Default::default(),

src/bin/sadmin/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ mod run;
3232
mod service_control;
3333
mod service_deploy;
3434
#[cfg(feature = "daemon")]
35+
mod state;
36+
#[cfg(feature = "daemon")]
3537
mod tokio_passfd;
3638
mod upgrade;
3739

0 commit comments

Comments
 (0)