@@ -14,6 +14,7 @@ use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHa
1414use tokio:: net:: TcpStream ;
1515use tokio:: sync:: mpsc:: channel;
1616use tokio:: sync:: Mutex ;
17+ use tokio:: time:: { timeout, Duration } ;
1718use tokio_stream:: wrappers:: ReceiverStream ;
1819
1920const UUID : & str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" ;
@@ -25,6 +26,10 @@ const HTTP_ACCEPT_RESPONSE: &str = "HTTP/1.1 101 Switching Protocols\r\n\
2526 Sec-WebSocket-Accept: {}\r \n \
2627 \r \n ";
2728
29+ const HTTP_METHOD : & str = "GET" ;
30+ const SEC_WEBSOCKET_KEY : & str = "Sec-WebSocket-Key" ;
31+ const HOST : & str = "Host" ;
32+
2833pub type Result = std:: result:: Result < WSConnection , Error > ;
2934
3035/// Used for accepting websocket connections as a server.
@@ -148,8 +153,17 @@ async fn parse_handshake(
148153 // connections that sends a lot of data
149154 let mut buffer = vec ! [ 0 ; 1024 ] ;
150155
151- // Read the request into the buffer
152- let n = buf_reader. read ( & mut buffer) . await ?;
156+ // Adding a timeout to the buffer read, since some attackers may only connect to the TCP
157+ // endpoint, and froze without sending the HTTP handshake.
158+ // Therefore, we need to drop all these cases
159+ let read_result = timeout (
160+ Duration :: from_secs ( 2 ) , buf_reader. read ( & mut buffer) ) . await ;
161+
162+ let n = match read_result {
163+ Ok ( Ok ( size) ) => size, // Continue processing the payload
164+ Ok ( Err ( e) ) => Err ( e) ?, // An error occurred while reading
165+ Err ( _e) => Err ( _e) ?, // Reading from the socket timed out
166+ } ;
153167
154168 // Parse the HTTP request from the buffer
155169 let mut headers = [ EMPTY_HEADER ; 16 ] ;
@@ -158,23 +172,15 @@ async fn parse_handshake(
158172 req. parse ( & buffer[ ..n] ) ?;
159173
160174 // Validate the WebSocket handshake
161- if req. method != Some ( "GET" ) || req. version != Some ( 1 ) {
175+ if req. method != Some ( HTTP_METHOD ) || req. version != Some ( 1 ) {
162176 return Err ( Error :: InvalidHTTPHandshake ) ;
163177 }
164178
165- // if req.get_header_value("Connection") != Some(String::from("Upgrade")) {
166- // return Err(Error::NoConnectionHeaderPresent);
167- // }
168- //
169- // if req.get_header_value("Upgrade") != Some(String::from("websocket")) {
170- // return Err(Error::NoUpgradeHeaderPresent);
171- // }
172-
173- if req. get_header_value ( "Host" ) . is_none ( ) {
179+ if req. get_header_value ( HOST ) . is_none ( ) {
174180 return Err ( Error :: NoHostHeaderPresent ) ;
175181 }
176182
177- let sec_websocket_key = match req. get_header_value ( "Sec-WebSocket-Key" ) {
183+ let sec_websocket_key = match req. get_header_value ( SEC_WEBSOCKET_KEY ) {
178184 Some ( key) => key. to_string ( ) ,
179185 None => Err ( Error :: NoSecWebsocketKey ) ?,
180186 } ;
0 commit comments