From e8d2a6e9c76a4d0d7975d0c18b23a2f5b353579b Mon Sep 17 00:00:00 2001 From: Tyson Nottingham Date: Thu, 11 Apr 2024 10:54:44 -0700 Subject: [PATCH] Resize headers on partial parse The main benefit of this is that it enables you to get the parsed headers from partial results when using uninitialized headers. Prior to this, on a partial result, the headers in a request or response would be replaced with the original slice passed when creating the request or response (typically an empty slice when using uninitialized headers), rather than the initialized portion of the uninitialized headers passed into the parse function call. It also makes handling partial parse results more convenient in general since you can iterate over the headers in the request or response without having to manually determine the last element. This is technically a breaking change, though it seems unlikely that any client would depend on the previous behavior in a way that the change would break their code. --- src/lib.rs | 76 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3232852..2d27674 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -547,7 +547,7 @@ impl<'h, 'b> Request<'h, 'b> { newline!(bytes); let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + let parse_headers_status = parse_headers_iter_uninit( &mut headers, &mut bytes, &HeaderParserConfig { @@ -556,11 +556,14 @@ impl<'h, 'b> Request<'h, 'b> { allow_space_before_first_header_name: config.allow_space_before_first_header_name, ignore_invalid_headers: config.ignore_invalid_headers_in_requests }, - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - Ok(Status::Complete(len + headers_len)) + match parse_headers_status { + Status::Complete(headers_len) => Ok(Status::Complete(len + headers_len)), + Status::Partial => Ok(Status::Partial), + } } /// Try to parse a buffer of bytes into the Request, @@ -583,7 +586,7 @@ impl<'h, 'b> Request<'h, 'b> { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), + Ok(status) => Ok(status), other => { // put the original headers back self.headers = &mut *(headers as *mut [Header<'_>]); @@ -687,7 +690,7 @@ impl<'h, 'b> Response<'h, 'b> { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), + Ok(status) => Ok(status), other => { // put the original headers back self.headers = &mut *(headers as *mut [Header<'_>]); @@ -745,7 +748,7 @@ impl<'h, 'b> Response<'h, 'b> { let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + let parse_headers_status = parse_headers_iter_uninit( &mut headers, &mut bytes, &HeaderParserConfig { @@ -754,10 +757,14 @@ impl<'h, 'b> Response<'h, 'b> { allow_space_before_first_header_name: config.allow_space_before_first_header_name, ignore_invalid_headers: config.ignore_invalid_headers_in_responses } - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - Ok(Status::Complete(len + headers_len)) + + match parse_headers_status { + Status::Complete(headers_len) => Ok(Status::Complete(len + headers_len)), + Status::Partial => Ok(Status::Partial), + } } } @@ -1386,6 +1393,7 @@ pub fn parse_chunk_size(buf: &[u8]) #[cfg(test)] mod tests { + use core::mem::MaybeUninit; use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size}; const NUM_OF_HEADERS: usize = 4; @@ -1559,7 +1567,7 @@ mod tests { assert_eq!(req.method.unwrap(), "GET"); assert_eq!(req.path.unwrap(), "/"); assert_eq!(req.version.unwrap(), 1); - assert_eq!(req.headers.len(), NUM_OF_HEADERS); // doesn't slice since not Complete + assert_eq!(req.headers.len(), 1); assert_eq!(req.headers[0].name, "Host"); assert_eq!(req.headers[0].value, b"yolo"); } @@ -1760,7 +1768,7 @@ mod tests { assert_eq!(res.version.unwrap(), 1); assert_eq!(res.code.unwrap(), 200); assert_eq!(res.reason.unwrap(), "OK"); - assert_eq!(res.headers.len(), NUM_OF_HEADERS); // doesn't slice since not Complete + assert_eq!(res.headers.len(), 1); assert_eq!(res.headers[0].name, "Server"); assert_eq!(res.headers[0].value, b"yolo"); } @@ -2591,4 +2599,52 @@ mod tests { assert_eq!(response.headers[0].name, "Space-Before-Header"); assert_eq!(response.headers[0].value, &b"hello there"[..]); } + + #[test] + fn test_request_partial_with_uninit_headers() { + const REQUEST: &[u8] = b"GET / HTTP/1.1\r\nFoo: bar\r\nBaz: quux\r\n"; + + let mut headers = unsafe { + MaybeUninit::<[MaybeUninit>; 4]>::uninit().assume_init() + }; + + let mut request = Request::new(&mut []); + + let result = crate::ParserConfig::default() + .parse_request_with_uninit_headers(&mut request, REQUEST, &mut headers); + + assert_eq!(result, Ok(Status::Partial)); + assert_eq!(request.method.unwrap(), "GET"); + assert_eq!(request.path.unwrap(), "/"); + assert_eq!(request.version.unwrap(), 1); + assert_eq!(request.headers.len(), 2); + assert_eq!(request.headers[0].name, "Foo"); + assert_eq!(request.headers[0].value, &b"bar"[..]); + assert_eq!(request.headers[1].name, "Baz"); + assert_eq!(request.headers[1].value, &b"quux"[..]); + } + + #[test] + fn test_response_partial_with_uninit_headers() { + const RESPONSE: &[u8] = b"HTTP/1.1 200 OK\r\nFoo: bar\r\nBaz: quux\r\n"; + + let mut headers = unsafe { + MaybeUninit::<[MaybeUninit>; 4]>::uninit().assume_init() + }; + + let mut response = Response::new(&mut []); + + let result = crate::ParserConfig::default() + .parse_response_with_uninit_headers(&mut response, RESPONSE, &mut headers); + + assert_eq!(result, Ok(Status::Partial)); + assert_eq!(response.version.unwrap(), 1); + assert_eq!(response.code.unwrap(), 200); + assert_eq!(response.reason.unwrap(), "OK"); + assert_eq!(response.headers.len(), 2); + assert_eq!(response.headers[0].name, "Foo"); + assert_eq!(response.headers[0].value, &b"bar"[..]); + assert_eq!(response.headers[1].name, "Baz"); + assert_eq!(response.headers[1].value, &b"quux"[..]); + } }