Skip to content

Refactor parser logic to better allow subclassing #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 158 additions & 116 deletions multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def parse_options_header(header, options=None, unquote=header_unquote):
##############################################################################


# Constants used by the parser
_HEADER_EXPECTED = frozenset(["Content-Disposition", "Content-Type", "Content-Length"])
# Parser states as constants
_PREAMBLE = "PREAMBLE"
_HEADER = "HEADER"
Expand Down Expand Up @@ -314,7 +316,7 @@ def __init__(
will always trigger exceptions, even in non-strict mode.

The various limits are meant as safeguards and exceeding any of those
limit will trigger a :exc:`ParserLimitReached` exception.
limits will trigger a :exc:`ParserLimitReached` exception while parsing.

:param boundary: The multipart boundary as found in the Content-Type header.
:param content_length: Expected input size in bytes, or -1 if unknown.
Expand All @@ -338,11 +340,14 @@ def __init__(

# Internal parser state
self._parsed = 0
self._fieldcount = 0
self._buffer = bytearray()
self._current = MultipartSegment(self)
self._state = _PREAMBLE

self._segment = None
self._segment_count = 0
self._segment_headerlist = []
self._segment_limit = -1

#: True if the parser reached the end of the multipart stream, stopped
#: parsing due to an :attr:`error`, or :meth:`<close>` was called.
self.closed = False
Expand All @@ -359,20 +364,20 @@ def parse(
self, chunk: Union[bytes, bytearray]
) -> Iterator[Union["MultipartSegment", bytearray, None]]:
"""Parse a chunk of data and yield as many result objects as possible
with the data given.
from the data given.

For each multipart segment, the parser will emit a single instance of
:class:`MultipartSegment` with all headers already present, followed by
zero or more non-empty `bytearray` chunks of the segment payload,
followed by a single `None` signaling the end of the current segment.

For each multipart segment, the parser will emit a single instance
of :class:`MultipartSegment` with all headers already present, followed
by zero or more non-empty `bytearray` instances containing parts of the
segment body, followed by a single `None` signaling the end of the
current segment.
The returned iterator will yield results up to the point where more
data is needed or the end of the multipart stream was detected. The
iterator must be fully consumed before feeding more data to the parser.

The returned iterator will stop if more data is required or if the end
of the multipart stream was detected. The iterator must be fully consumed
before parsing the next chunk. End of input can be signaled by parsing
an empty chunk or closing the parser. This is important to verify the
multipart message was parsed completely and the last segment is actually
complete.
End of input can be signaled by parsing an empty chunk or closing the
parser. This is important to detect incomplete multipart streams
where the last segment is still missing data.

Format errors or exceeded limits will trigger :exc:`MultipartError`.
"""
Expand Down Expand Up @@ -423,6 +428,7 @@ def parse(
tail = buffer[next_start - 2 : next_start]

if tail == b"\r\n": # Normal delimiter found
self._on_segment_start()
self._state = _HEADER
offset = next_start
continue
Expand All @@ -449,12 +455,12 @@ def parse(
nl = buffer.find(b"\r\n", offset)

if nl > offset: # Non-empty header line
self._current._add_headerline(buffer[offset:nl])
self._on_segment_headerline(buffer[offset:nl])
offset = nl + 2
continue
elif nl == offset: # Empty header line -> End of header section
self._current._close_headers()
yield self._current
self._segment = self._create_segment(self._segment_headerlist)
yield self._segment
self._state = _BODY
offset += 2
continue
Expand All @@ -481,27 +487,25 @@ def parse(

if tail == b"\r\n" or tail == b"--":
if index > offset:
self._current._update_size(index - offset)
yield buffer[offset:index]
yield self._on_segment_payload(buffer[offset:index])

offset = next_start
self._current._mark_complete()
yield None # End of segment
self._on_segment_complete()
yield None # end of segment

if tail == b"--": # Last delimiter
self._state = _COMPLETE
break
else: # Normal delimiter
self._current = MultipartSegment(self)
self._on_segment_start()
self._state = _HEADER
continue

# Keep enough in buffer to accout for a partial delimiter at
# the end, but emit the rest.
chunk_end = bufferlen - (d_len + 1)
assert chunk_end > offset # Always true
self._current._update_size(chunk_end - offset)
yield buffer[offset:chunk_end]
yield self._on_segment_payload(buffer[offset:chunk_end])
offset = chunk_end
break # wait for more data

Expand All @@ -519,6 +523,101 @@ def parse(
self.close(check_complete=False)
raise

def _on_segment_start(self):
"""Reset internal state to start a new segment"""
self._segment_count += 1
if self._segment_count > self.max_segment_count:
raise ParserLimitReached("Maximum segment count exceeded")

self._segment = None
self._segment_headerlist = []
self._segment_limit = -1

def _on_segment_headerline(self, line: Union[bytes, bytearray]):
"""Parse a raw segment header line, which may be a continuation of a
previous line in non-strict mode."""
assert line and self._segment is None

# Handle header continuation (headers split into multiple lines)
if line[0] in b" \t": # Multi-line header value
if not self._segment_headerlist or self.strict:
raise StrictParserError("Unexpected segment header continuation")
prev = ": ".join(self._segment_headerlist.pop())
line = prev.encode(self.header_charset) + b" " + line.strip()

# Enforce header limits
if len(line) > self.max_header_size:
raise ParserLimitReached("Maximum segment header length exceeded")
if len(self._segment_headerlist) >= self.max_header_count:
raise ParserLimitReached("Maximum segment header count exceeded")

# Decode headers into header name and value
try:
name, col, value = line.decode(self.header_charset).partition(":")
name = name.strip().title()
if not col or not name:
raise ParserError("Malformed segment header")
if name not in _HEADER_EXPECTED:
if " " in name or not name.isascii() or not name.isprintable():
raise ParserError("Invalid segment header name")
value = value.strip()
except UnicodeDecodeError as err:
raise ParserError("Segment header failed to decode", err)

if name == "Content-Length":
if not value.isdecimal():
raise ParserError("Invalid segment Content-Length header value")
content_length = int(value)
if content_length > self.max_segment_size:
raise ParserLimitReached(
"Segment Content-Length larger than maximum segment size"
)
self._segment_limit = content_length

self._segment_headerlist.append((name, value))

def _create_segment(self, headerlist: List[Tuple[str, str]]):
"""Create a :class:`MultipartSegment` from a list of headers and check
for missing or invalid headers.

This implementation is specific 'multipart/form-data' and will reject
segments with missing or invalid `Content-Disposition` headers or header
options.

Subclasses can override this method to support other multipart stream
types (e.g. multipart/byteranges) with different restrictions.
"""
segment = MultipartSegment(headerlist)

if segment.disposition != "form-data":
if segment.disposition is None:
raise ParserError("Missing Content-Disposition segment header")
raise ParserError("Invalid Content-Disposition segment header: Wrong type")
if segment.name is None:
segment.name = ""
if self.strict:
raise StrictParserError(
"Invalid Content-Disposition segment header: Missing name option"
)

return segment

def _on_segment_payload(self, chunk: Union[bytes, bytearray]):
assert self._segment is not None and not self._segment.complete
self._segment.size += len(chunk)
if self._segment.size > self.max_segment_size:
raise ParserLimitReached("Maximum segment size exceeded")
if -1 < self._segment_limit < self._segment.size:
raise ParserError("Segment Content-Length exceeded")
return chunk

def _on_segment_complete(self):
assert self._segment is not None and not self._segment.complete
if self._segment.size < self._segment_limit:
raise ParserError("Segment size does not match Content-Length header")
self._segment.complete = True
return None

def close(self, check_complete=True):
"""
Close this parser if not already closed.
Expand All @@ -538,119 +637,62 @@ def close(self, check_complete=True):


class MultipartSegment:
"""A :class:`MultipartSegment` represents the header section of a single
multipart part and provides convenient access to part headers and other
details (e.g. :attr:`name` and :attr:`filename`). Each segment also tracks
its own content :attr:`size` while the :class:`PushMultipartParser`
processes more data, and is marked as :attr:`complete` as soon as the next
multipart border is found. Segments do not store or buffer any of their
content data, though.
"""Representation of the header section of a single multipart segment.

:class:`MultipartSegment` instances do not store or buffer any payload data,
but the parser will update the payload :attr:`size` property while parsing,
and mark the segment as :attr:`complete` when done.
"""

#: List of headers as name/value pairs with normalized (Title-Case) names.
#: Ordered list of headers as (name, value) pairs. Header names are
#: normalized (Title-Case) and values are stripped of leading or tailing
#: whitespace.
headerlist: List[Tuple[str, str]]
#: The 'name' option of the `Content-Disposition` header. Always a string,
#: but may be empty.
name: str

#: The cleaned up `Content-Disposition` header value without any header
#: options. This will always be 'form-data' in HTTP multipart contexts.
disposition: Optional[str]
#: The 'name' option of the `Content-Disposition` header. For `form-data`
#: this will always be a string, but the string may be empty.
name: Optional[str]
#: The optional 'filename' option of the `Content-Disposition` header.
filename: Optional[str]
#: The cleaned up `Content-Type` segment header, if present. The value is
#: lower-cased and header options (e.g. charset) are removed.

#: The cleaned up `Content-Type` segment header without any header options.
content_type: Optional[str]
#: The 'charset' option of the `Content-Type` header, if present.
#: The optional 'charset' option of the `Content-Type` header.
charset: Optional[str]

#: Segment body size (so far). Will be updated during parsing.
#: Segment body size (so far). Will be updated for each chunk of payload
#: during parsing.
size: int
#: If true, the segment content was fully parsed and the size value is final.
#: True if the parser detected the end of the segment and no more payload
#: chunks are to be expected.
complete: bool

def __init__(self, parser: PushMultipartParser):
"""Private constructor, used by :class:`PushMultipartParser`"""
self._parser = parser
def __init__(self, headerlist: List[Tuple[str, str]]):
"""Private constructor used by :class:`PushMultipartParser`"""

if parser._fieldcount + 1 > parser.max_segment_count:
raise ParserLimitReached("Maximum segment count exceeded")
parser._fieldcount += 1

self.headerlist = []
self.size = 0
self.complete = False

self.name = None # type: ignore
self.headerlist = headerlist
self.disposition = None
self.name = None
self.filename = None
self.content_type = None
self.charset = None
self._clen = -1
self._size_limit = parser.max_segment_size

def _add_headerline(self, line: Union[bytes, bytearray]):
assert line and self.name is None
parser = self._parser

if line[0] in b" \t": # Multi-line header value
if not self.headerlist or parser.strict:
raise StrictParserError("Unexpected segment header continuation")
prev = ": ".join(self.headerlist.pop())
line = prev.encode(parser.header_charset) + b" " + line.strip()

if len(line) > parser.max_header_size:
raise ParserLimitReached("Maximum segment header length exceeded")
if len(self.headerlist) >= parser.max_header_count:
raise ParserLimitReached("Maximum segment header count exceeded")

try:
name, col, value = line.decode(parser.header_charset).partition(":")
name = name.strip()
if not col or not name:
raise ParserError("Malformed segment header")
if " " in name or not name.isascii() or not name.isprintable():
raise ParserError("Invalid segment header name")
except UnicodeDecodeError as err:
raise ParserError("Segment header failed to decode", err)

self.headerlist.append((name.title(), value.strip()))

def _close_headers(self):
assert self.name is None

for h, v in self.headerlist:
if h == "Content-Disposition":
dtype, args = parse_options_header(
v, unquote=content_disposition_unquote
for name, value in headerlist:
if name == "Content-Disposition":
self.disposition, args = parse_options_header(
value, unquote=content_disposition_unquote
)
if dtype != "form-data":
raise ParserError(
"Invalid Content-Disposition segment header: Wrong type"
)
if "name" not in args and self._parser.strict:
raise StrictParserError(
"Invalid Content-Disposition segment header: Missing name option"
)
self.name = args.get("name", "")
self.name = args.get("name")
self.filename = args.get("filename")
elif h == "Content-Type":
self.content_type, args = parse_options_header(v)
elif name == "Content-Type":
self.content_type, args = parse_options_header(value)
self.charset = args.get("charset")
elif h == "Content-Length" and v.isdecimal():
self._clen = int(v)

if self.name is None:
raise ParserError("Missing Content-Disposition segment header")

def _update_size(self, bytecount: int):
assert self.name is not None and not self.complete
self.size += bytecount
if self._clen >= 0 and self.size > self._clen:
raise ParserError("Segment Content-Length exceeded")
if self.size > self._size_limit:
raise ParserLimitReached("Maximum segment size exceeded")

def _mark_complete(self):
assert self.name is not None and not self.complete
if self._clen >= 0 and self.size != self._clen:
raise ParserError("Segment size does not match Content-Length header")
self.complete = True

def header(self, name: str, default=None):
"""Return the value of a header if present, or a default value."""
Expand Down