From 10767fc32a324b403609413c37490e2ca59ea9bc Mon Sep 17 00:00:00 2001 From: Marcel Hellkamp Date: Sat, 21 Sep 2024 13:01:54 +0200 Subject: [PATCH 1/2] docs: Fix examples in readme --- README.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index f773b2f..600f85d 100644 --- a/README.rst +++ b/README.rst @@ -81,12 +81,12 @@ arrive, instead of waiting for the entire request to be parsed: def wsgi(environ, start_response): assert environ["REQUEST_METHOD"] == "POST" - ctype, copts = mp.parse_options_header(environ.get("CONTENT_TYPE", "")) + ctype, copts = parse_options_header(environ.get("CONTENT_TYPE", "")) boundary = copts.get("boundary") charset = copts.get("charset", "utf8") assert ctype == "multipart/form-data" - parser = mp.MultipartParser(environ["wsgi.input"], boundary, charset) + parser = MultipartParser(environ["wsgi.input"], boundary, charset) for part in parser: if part.filename: print(f"{part.name}: File upload ({part.size} bytes)") @@ -104,20 +104,20 @@ the other parsers in this library: .. 
code-block:: python - from multipart import PushMultipartParser + from multipart import PushMultipartParser, MultipartSegment async def process_multipart(reader: asyncio.StreamReader, boundary: str): with PushMultipartParser(boundary) as parser: while not parser.closed: chunk = await reader.read(1024*46) for event in parser.parse(chunk): - if isinstance(event, list): - print("== Start of segment") - for header, value in event: + if isinstance(event, MultipartSegment): + print(f"== Start of segment: {event.name}") + for header, value in event.headerlist: print(f"{header}: {value}") - elif isinstance(event, bytearray): + elif event: print(f"[{len(event)} bytes of data]") - elif event is None: + else: print("== End of segment") From 8af424637de47c12476fc9afb3de92ce9466b71a Mon Sep 17 00:00:00 2001 From: Marcel Hellkamp Date: Sat, 21 Sep 2024 13:51:46 +0200 Subject: [PATCH 2/2] feat: Added some more typing annotations --- multipart.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/multipart.py b/multipart.py index 47e04e3..822d758 100644 --- a/multipart.py +++ b/multipart.py @@ -17,10 +17,9 @@ import re from io import BytesIO -from typing import Iterator, Union, Optional, Tuple, List +from typing import Iterator, Union, Optional, Tuple, List, MutableMapping, TypeVar from urllib.parse import parse_qs from wsgiref.headers import Headers -from collections.abc import MutableMapping as DictMixin import tempfile @@ -29,8 +28,10 @@ ############################################################################## # Some of these were copied from bottle: https://bottlepy.org +_V = TypeVar("_V") +_D = TypeVar("_D") -class MultiDict(DictMixin): +class MultiDict(MutableMapping[str, _V]): """ A dict that stores multiple values per key. Most dict methods return the last value by default. There are special methods to get all values. 
""" @@ -50,7 +51,7 @@ def __init__(self, *args, **kwargs): def __len__(self): return len(self.dict) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self.dict) def __contains__(self, key): @@ -65,10 +66,10 @@ def __str__(self): def __repr__(self): return repr(self.dict) - def keys(self): + def keys(self) -> Iterator[str]: return self.dict.keys() - def __getitem__(self, key): + def __getitem__(self, key) -> _V: return self.get(key, KeyError, -1) def __setitem__(self, key, value): @@ -80,16 +81,16 @@ def append(self, key, value): def replace(self, key, value): self.dict[key] = [value] - def getall(self, key): + def getall(self, key) -> List[_V]: return self.dict.get(key) or [] - def get(self, key, default=None, index=-1): + def get(self, key, default: _D = None, index=-1) -> Union[_V, _D]: if key not in self.dict and default != KeyError: return [default][index] return self.dict[key][index] - def iterallitems(self): + def iterallitems(self) -> Iterator[Tuple[str, _V]]: """ Yield (key, value) keys, but for all values. 
""" for key, values in self.dict.items(): for value in values: @@ -585,7 +586,7 @@ def __init__( self._done = [] self._part_iter = None - def __iter__(self): + def __iter__(self) -> Iterator["MultipartPart"]: """Iterate over the parts of the multipart message.""" if not self._part_iter: self._part_iter = self._iterparse() @@ -601,7 +602,7 @@ def parts(self): """Returns a list with all parts of the multipart message.""" return list(self) - def get(self, name, default=None): + def get(self, name, default: _D = None): """Return the first part with that name or a default value.""" for part in self: if name == part.name: @@ -737,7 +738,9 @@ def close(self): ############################################################################## -def parse_form_data(environ, charset="utf8", strict=False, **kwargs): +def parse_form_data( + environ, charset="utf8", strict=False, **kwargs +) -> Tuple[MultiDict[str], MultiDict[MultipartPart]]: """ Parses both types of form data (multipart and url-encoded) from a WSGI environment and returns a (forms, files) tuple. Both are instances of :class:`MultiDict` and may contain multiple values per key.