Skip to content

Commit f1e1d1f

Browse files
committed
feat: adding constants and timeout on handshake
1 parent 46ba715 commit f1e1d1f

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

src/handshake.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHa
1414
use tokio::net::TcpStream;
1515
use tokio::sync::mpsc::channel;
1616
use tokio::sync::Mutex;
17+
use tokio::time::{timeout, Duration};
1718
use tokio_stream::wrappers::ReceiverStream;
1819

1920
const UUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -25,6 +26,10 @@ const HTTP_ACCEPT_RESPONSE: &str = "HTTP/1.1 101 Switching Protocols\r\n\
2526
Sec-WebSocket-Accept: {}\r\n\
2627
\r\n";
2728

29+
const HTTP_METHOD: &str = "GET";
30+
const SEC_WEBSOCKET_KEY: &str = "Sec-WebSocket-Key";
31+
const HOST: &str = "Host";
32+
2833
pub type Result = std::result::Result<WSConnection, Error>;
2934

3035
/// Used for accepting websocket connections as a server.
@@ -148,8 +153,17 @@ async fn parse_handshake(
148153
// connections that sends a lot of data
149154
let mut buffer = vec![0; 1024];
150155

151-
// Read the request into the buffer
152-
let n = buf_reader.read(&mut buffer).await?;
156+
// Adding a timeout to the buffer read, since some attackers may only connect to the TCP
157+
// endpoint, and froze without sending the HTTP handshake.
158+
// Therefore, we need to drop all these cases
159+
let read_result = timeout(
160+
Duration::from_secs(2), buf_reader.read(&mut buffer)).await;
161+
162+
let n = match read_result {
163+
Ok(Ok(size)) => size, // Continue processing the payload
164+
Ok(Err(e)) => Err(e)?, // An error occurred while reading
165+
Err(_e) => Err(_e)?, // Reading from the socket timed out
166+
};
153167

154168
// Parse the HTTP request from the buffer
155169
let mut headers = [EMPTY_HEADER; 16];
@@ -158,23 +172,15 @@ async fn parse_handshake(
158172
req.parse(&buffer[..n])?;
159173

160174
// Validate the WebSocket handshake
161-
if req.method != Some("GET") || req.version != Some(1) {
175+
if req.method != Some(HTTP_METHOD) || req.version != Some(1) {
162176
return Err(Error::InvalidHTTPHandshake);
163177
}
164178

165-
// if req.get_header_value("Connection") != Some(String::from("Upgrade")) {
166-
// return Err(Error::NoConnectionHeaderPresent);
167-
// }
168-
//
169-
// if req.get_header_value("Upgrade") != Some(String::from("websocket")) {
170-
// return Err(Error::NoUpgradeHeaderPresent);
171-
// }
172-
173-
if req.get_header_value("Host").is_none() {
179+
if req.get_header_value(HOST).is_none() {
174180
return Err(Error::NoHostHeaderPresent);
175181
}
176182

177-
let sec_websocket_key = match req.get_header_value("Sec-WebSocket-Key") {
183+
let sec_websocket_key = match req.get_header_value(SEC_WEBSOCKET_KEY) {
178184
Some(key) => key.to_string(),
179185
None => Err(Error::NoSecWebsocketKey)?,
180186
};

0 commit comments

Comments
 (0)