@@ -30,10 +30,7 @@ use tokio::{
30
30
} ,
31
31
process:: ChildStdin ,
32
32
select,
33
- sync:: {
34
- mpsc:: { UnboundedReceiver , UnboundedSender } ,
35
- Notify ,
36
- } ,
33
+ sync:: mpsc:: { UnboundedReceiver , UnboundedSender } ,
37
34
time:: timeout,
38
35
} ;
39
36
use tokio_rustls:: { client:: TlsStream , rustls, TlsConnector } ;
@@ -52,6 +49,7 @@ use crate::{
52
49
connection:: Config ,
53
50
persist_daemon,
54
51
service_control:: DaemonControlMessage ,
52
+ state:: State ,
55
53
tokio_passfd:: { self } ,
56
54
} ;
57
55
use sdnotify:: SdNotify ;
@@ -103,16 +101,20 @@ pub struct ClientDaemon {
103
101
pub type PersistMessageSender =
104
102
tokio:: sync:: oneshot:: Sender < ( persist_daemon:: Message , Option < OwnedFd > ) > ;
105
103
104
+ #[ derive( PartialEq , Eq ) ]
105
+ enum ConnectionState {
106
+ Good ,
107
+ Bad ,
108
+ }
109
+
106
110
pub struct Client {
107
111
connector : TlsConnector ,
108
112
pub config : Config ,
109
113
command_tasks : Mutex < HashMap < u64 , Arc < dyn TaskBase > > > ,
110
- send_failure_notify : Notify ,
111
- sender_clear : Notify ,
112
- new_send_notify : Notify ,
113
114
sender : tokio:: sync:: Mutex <
114
115
Option < WriteHalf < tokio_rustls:: client:: TlsStream < tokio:: net:: TcpStream > > > ,
115
116
> ,
117
+ connection_state : State < ConnectionState > ,
116
118
script_stdin : Mutex < HashMap < u64 , UnboundedSender < DataMessage > > > ,
117
119
persist_responses : Mutex < HashMap < u64 , PersistMessageSender > > ,
118
120
persist_idc : AtomicU64 ,
@@ -141,34 +143,46 @@ impl Client {
141
143
message. push ( 30 ) ;
142
144
loop {
143
145
let mut s = self . sender . lock ( ) . await ;
144
- if let Some ( v) = s. deref_mut ( ) {
145
- let write_all = write_all_and_flush ( v, & message) ;
146
- let sender_clear = self . sender_clear . notified ( ) ;
147
- let sleep = tokio:: time:: sleep ( Duration :: from_secs ( 40 ) ) ;
148
- tokio:: select!(
149
- val = write_all => {
150
- if let Err ( e) = val {
151
- // The send errored out, notify the recv half so we can try to initiate a new connection
152
- error!( "Failed sending message to backend: {}" , e) ;
153
- self . send_failure_notify. notify_one( ) ;
154
- * s = None
155
- }
156
- break
157
- }
158
- _ = sender_clear => { } ,
159
- _ = sleep => {
160
- // The send timeouted, notify the recv half so we can try to initiate a new connection
161
- error!( "Timout sending message to server" ) ;
162
- self . send_failure_notify. notify_one( ) ;
163
- * s = None
164
- }
165
- ) ;
146
+ if * self . connection_state . get ( ) != ConnectionState :: Good {
147
+ std:: mem:: drop ( s) ;
148
+ // We do not currently have a send socket so lets wait for one
149
+ info ! ( "We do not currently have a send socket so lets wait for one" ) ;
150
+ self . connection_state
151
+ . wait ( |s| s == & ConnectionState :: Good )
152
+ . await ;
166
153
continue ;
167
- }
168
- // We do not currently have a send socket so lets wait for one
169
- info ! ( "We do not currently have a send socket so lets wait for one" ) ;
170
- std:: mem:: drop ( s) ;
171
- self . new_send_notify . notified ( ) . await ;
154
+ } ;
155
+ let Some ( v) = s. deref_mut ( ) else {
156
+ std:: mem:: drop ( s) ;
157
+ // We do not currently have a send socket so lets wait for one
158
+ error ! ( "Logic error do not currently have a send socket so lets wait for one" ) ;
159
+ self . connection_state . set ( ConnectionState :: Bad ) ;
160
+ self . connection_state
161
+ . wait ( |s| s == & ConnectionState :: Good )
162
+ . await ;
163
+ continue ;
164
+ } ;
165
+ let write_all = write_all_and_flush ( v, & message) ;
166
+ let disconnected = self . connection_state . wait ( |v| v != & ConnectionState :: Good ) ;
167
+ let sleep = tokio:: time:: sleep ( Duration :: from_secs ( 40 ) ) ;
168
+ tokio:: select!(
169
+ val = write_all => {
170
+ if let Err ( e) = val {
171
+ // The send errored out, notify the recv half so we can try to initiate a new connection
172
+ error!( "Failed sending message to backend: {}" , e) ;
173
+ self . connection_state. set( ConnectionState :: Bad ) ;
174
+ }
175
+ break
176
+ }
177
+ _ = disconnected => {
178
+ // We are disconnected, wait for reconnect
179
+ } ,
180
+ _ = sleep => {
181
+ // The send timeouted, notify the recv half so we can try to initiate a new connection
182
+ error!( "Timout sending message to server" ) ;
183
+ self . connection_state. set( ConnectionState :: Bad ) ;
184
+ }
185
+ ) ;
172
186
}
173
187
}
174
188
@@ -940,8 +954,9 @@ impl Client {
940
954
auth_message. push ( 30 ) ;
941
955
write_all_and_flush ( & mut write, & auth_message) . await ?;
942
956
943
- * self . sender . lock ( ) . await = Some ( write) ;
944
- self . new_send_notify . notify_one ( ) ;
957
+ let mut l = self . sender . lock ( ) . await ;
958
+ * l = Some ( write) ;
959
+ self . connection_state . set ( ConnectionState :: Good ) ;
945
960
Ok ( read)
946
961
}
947
962
@@ -1005,7 +1020,7 @@ impl Client {
1005
1020
}
1006
1021
let mut start = buffer. len ( ) ;
1007
1022
let read = read. read_buf ( & mut buffer) ;
1008
- let send_failure = self . send_failure_notify . notified ( ) ;
1023
+ let disconnect = self . connection_state . wait ( |v| v != & ConnectionState :: Good ) ;
1009
1024
let sleep = tokio:: time:: sleep ( Duration :: from_secs ( 120 ) ) ;
1010
1025
run_token. set_location ( file ! ( ) , line ! ( ) ) ;
1011
1026
tokio:: select! {
@@ -1022,7 +1037,7 @@ impl Client {
1022
1037
}
1023
1038
}
1024
1039
}
1025
- _ = send_failure => {
1040
+ _ = disconnect => {
1026
1041
break
1027
1042
}
1028
1043
_ = sleep => {
@@ -1059,23 +1074,7 @@ impl Client {
1059
1074
break ;
1060
1075
}
1061
1076
}
1062
- info ! ( "Trying to take sender for disconnect" ) ;
1063
- run_token. set_location ( file ! ( ) , line ! ( ) ) ;
1064
- {
1065
- let f = async {
1066
- loop {
1067
- self . sender_clear . notify_waiters ( ) ;
1068
- self . sender_clear . notify_one ( ) ;
1069
- tokio:: time:: sleep ( Duration :: from_millis ( 1 ) ) . await
1070
- }
1071
- } ;
1072
- tokio:: select! {
1073
- mut l = self . sender. lock( ) => {
1074
- let _sender = l. take( ) ;
1075
- }
1076
- ( ) = f => { panic!( ) }
1077
- }
1078
- }
1077
+ self . connection_state . set ( ConnectionState :: Bad ) ;
1079
1078
run_token. set_location ( file ! ( ) , line ! ( ) ) ;
1080
1079
info ! ( "Took sender for disconnect" ) ;
1081
1080
if let Some ( notifier) = & notifier {
@@ -1632,9 +1631,7 @@ pub async fn client_daemon(config: Config, args: ClientDaemon) -> Result<()> {
1632
1631
config,
1633
1632
db,
1634
1633
command_tasks : Default :: default ( ) ,
1635
- send_failure_notify : Default :: default ( ) ,
1636
- sender_clear : Default :: default ( ) ,
1637
- new_send_notify : Default :: default ( ) ,
1634
+ connection_state : State :: new ( ConnectionState :: Bad ) ,
1638
1635
sender : Default :: default ( ) ,
1639
1636
script_stdin : Default :: default ( ) ,
1640
1637
persist_responses : Default :: default ( ) ,
0 commit comments