Skip to content

feat(event_handler): add File parameter support for multipart/form-data uploads in OpenAPI utility #7132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 196 additions & 16 deletions aws_lambda_powertools/event_handler/middlewares/openapi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import dataclasses
import json
import logging
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, Union
from urllib.parse import parse_qs

from pydantic import BaseModel
Expand All @@ -19,7 +20,7 @@
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
from aws_lambda_powertools.event_handler.openapi.params import Param
from aws_lambda_powertools.event_handler.openapi.params import Param, UploadFile

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler import Response
Expand All @@ -35,6 +36,7 @@
CONTENT_DISPOSITION_NAME_PARAM = "name="
APPLICATION_JSON_CONTENT_TYPE = "application/json"
APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"
MULTIPART_FORM_CONTENT_TYPE = "multipart/form-data"


class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler):
Expand Down Expand Up @@ -125,8 +127,12 @@ def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
return self._parse_form_data(app)

# Handle multipart form data
elif content_type.startswith(MULTIPART_FORM_CONTENT_TYPE):
return self._parse_multipart_data(app, content_type)

else:
raise NotImplementedError("Only JSON body or Form() are supported")
raise NotImplementedError(f"Content type '{content_type}' is not supported")

def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
"""Parse JSON data from the request body."""
Expand Down Expand Up @@ -169,6 +175,142 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
],
) from e

def _parse_multipart_data(self, app: EventHandlerInstance, content_type: str) -> dict[str, Any]:
"""Parse multipart/form-data."""
try:
decoded_bytes = self._decode_request_body(app)
boundary_bytes = self._extract_boundary_bytes(content_type)
return self._parse_multipart_sections(decoded_bytes, boundary_bytes)

except Exception as e:
raise RequestValidationError(
[
{
"type": "multipart_invalid",
"loc": ("body",),
"msg": "Invalid multipart form data",
"input": {},
"ctx": {"error": str(e)},
},
],
) from e

def _decode_request_body(self, app: EventHandlerInstance) -> bytes:
"""Decode the request body, handling base64 encoding if necessary."""
import base64

body = app.current_event.body or ""

if app.current_event.is_base64_encoded:
try:
return base64.b64decode(body)
except Exception:
# If decoding fails, use body as-is
return body.encode("utf-8") if isinstance(body, str) else body
else:
return body.encode("utf-8") if isinstance(body, str) else body

def _extract_boundary_bytes(self, content_type: str) -> bytes:
"""Extract and return the boundary bytes from the content type header."""
boundary_match = re.search(r"boundary=([^;,\s]+)", content_type)

if not boundary_match:
# Handle WebKit browsers that may use different boundary formats
webkit_match = re.search(r"WebKitFormBoundary([a-zA-Z0-9]+)", content_type)
if webkit_match:
boundary = "WebKitFormBoundary" + webkit_match.group(1)
else:
raise ValueError("No boundary found in multipart content-type")
else:
boundary = boundary_match.group(1).strip('"')

return ("--" + boundary).encode("utf-8")

def _parse_multipart_sections(self, decoded_bytes: bytes, boundary_bytes: bytes) -> dict[str, Any]:
"""Parse individual multipart sections from the decoded body."""
parsed_data: dict[str, Any] = {}

if not decoded_bytes:
return parsed_data

sections = decoded_bytes.split(boundary_bytes)

for section in sections[1:-1]: # Skip first empty and last closing parts
if not section.strip():
continue

field_name, content = self._parse_multipart_section(section)
if field_name:
parsed_data[field_name] = content

return parsed_data

def _parse_multipart_section(self, section: bytes) -> tuple[str | None, bytes | str | UploadFile]:
"""Parse a single multipart section to extract field name and content."""
headers_part, content = self._split_section_headers_and_content(section)

if headers_part is None:
return None, b""

# Extract field name from Content-Disposition header
name_match = re.search(r'name="([^"]+)"', headers_part)
if not name_match:
return None, b""

field_name = name_match.group(1)

# Check if it's a file field and process accordingly
if "filename=" in headers_part:
# It's a file - extract metadata and create UploadFile
filename_match = re.search(r'filename="([^"]*)"', headers_part)
filename = filename_match.group(1) if filename_match else None

# Extract Content-Type if present
content_type_match = re.search(r"Content-Type:\s*([^\r\n]+)", headers_part, re.IGNORECASE)
content_type = content_type_match.group(1).strip() if content_type_match else None

# Parse all headers from the section
headers = {}
for line_raw in headers_part.split("\n"):
line = line_raw.strip()
if ":" in line and not line.startswith("Content-Disposition"):
key, value = line.split(":", 1)
headers[key.strip()] = value.strip()

# Create UploadFile instance with metadata
upload_file = UploadFile(
file=content,
filename=filename,
content_type=content_type,
headers=headers,
)
return field_name, upload_file
else:
# It's a regular form field - decode as string
return field_name, self._decode_form_field_content(content)

def _split_section_headers_and_content(self, section: bytes) -> tuple[str | None, bytes]:
"""Split a multipart section into headers and content parts."""
header_end = section.find(b"\r\n\r\n")
if header_end == -1:
header_end = section.find(b"\n\n")
if header_end == -1:
return None, b""
content = section[header_end + 2 :].strip()
else:
content = section[header_end + 4 :].strip()

headers_part = section[:header_end].decode("utf-8", errors="ignore")
return headers_part, content

def _decode_form_field_content(self, content: bytes) -> str | bytes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm I guess this method returns the decode content of the file and while this is nice, I think developers must also have access to filename, headers, content-type to reconstruct the file in the Lambda..

I'm talking about something like FastAPI is doing with UploadFile class - https://fastapi.tiangolo.com/reference/uploadfile/#fastapi.UploadFile.file.

Can you investigate this, please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will look in to it and see how we can have that part of the implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leandrodamascena! Great feedback! I've implemented exactly what you requested - a FastAPI-inspired UploadFile class that provides developers with complete access to filename, headers, content-type, and all metadata needed to reconstruct files in Lambda functions.

UploadFile Response:
{
"filename": "important-document.pdf",
"content_type": "application/pdf",
"size": 52,
"headers": {
"Content-Type": "application/pdf",
"X-Upload-ID": "12345",
"X-File-Hash": "abc123def456"
},
"content_preview": "PDF file content with metadata for reconstruction.",
"can_reconstruct_file": true
}

Backward Compatibility Response:
{
"message": "Existing code works!",
"size": 27
}

"""Decode form field content as string, falling back to bytes if decoding fails."""
try:
return content.decode("utf-8")
except UnicodeDecodeError:
# If can't decode as text, keep as bytes
return content


class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler):
"""
Expand Down Expand Up @@ -346,6 +488,47 @@ def _request_params_to_args(
return values, errors


def _get_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]:
"""Get the location tuple for a field based on whether alias is omitted."""
if field_alias_omitted:
return ("body",)
return ("body", field.alias)


def _get_field_value(received_body: dict[str, Any] | None, field: ModelField) -> Any | None:
"""Extract field value from received body, returning None if not found or on error."""
if received_body is None:
return None

try:
return received_body.get(field.alias)
except AttributeError:
return None


def _resolve_field_type(field_type: type) -> type:
"""Resolve the actual field type, handling Union types by returning the first non-None type."""
from typing import get_args, get_origin

if get_origin(field_type) is Union:
union_args = get_args(field_type)
non_none_types = [arg for arg in union_args if arg is not type(None)]
if non_none_types:
return non_none_types[0]
return field_type


def _convert_value_type(value: Any, field_type: type) -> Any:
"""Convert value between UploadFile and bytes for type compatibility."""
if isinstance(value, UploadFile) and field_type is bytes:
# Convert UploadFile to bytes for backward compatibility
return value.file
elif isinstance(value, bytes) and field_type == UploadFile:
# Convert bytes to UploadFile if that's what's expected
return UploadFile(file=value)
return value


def _request_body_to_args(
required_params: list[ModelField],
received_body: dict[str, Any] | None,
Expand All @@ -363,32 +546,29 @@ def _request_body_to_args(
)

for field in required_params:
# This sets the location to:
# { "user": { object } } if field.alias == user
# { { object } if field_alias is omitted
loc: tuple[str, ...] = ("body", field.alias)
if field_alias_omitted:
loc = ("body",)

value: Any | None = None
loc = _get_field_location(field, field_alias_omitted)
value = _get_field_value(received_body, field)

# Now that we know what to look for, try to get the value from the received body
if received_body is not None:
# Handle AttributeError from _get_field_value
if received_body is not None and value is None:
try:
value = received_body.get(field.alias)
# Double-check with direct access to distinguish None value from AttributeError
received_body.get(field.alias)
except AttributeError:
errors.append(get_missing_field_error(loc))
continue

# Determine if the field is required
# Handle missing values
if value is None:
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = deepcopy(field.default)
continue

# MAINTENANCE: Handle byte and file fields
# Handle type conversions for UploadFile/bytes compatibility
field_type = _resolve_field_type(field.type_)
value = _convert_value_type(value, field_type)

# Finally, validate the value
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
Expand Down
20 changes: 15 additions & 5 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Dependant,
File,
Form,
Header,
Param,
ParamTypes,
Query,
_File,
analyze_param,
create_response_field,
get_flat_dependant,
Expand Down Expand Up @@ -367,13 +367,23 @@ def get_body_field_info(
if not required:
body_field_info_kwargs["default"] = None

if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params):
# MAINTENANCE: body_field_info: type[Body] = _File
raise NotImplementedError("_File fields are not supported in request bodies")
elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params):
# Check for File parameters
has_file_params = any(isinstance(f.field_info, File) for f in flat_dependant.body_params)
# Check for Form parameters
has_form_params = any(isinstance(f.field_info, Form) for f in flat_dependant.body_params)

if has_file_params:
# File parameters use multipart/form-data
body_field_info = Body
body_field_info_kwargs["media_type"] = "multipart/form-data"
body_field_info_kwargs["embed"] = True
elif has_form_params:
# Form parameters use application/x-www-form-urlencoded
body_field_info = Body
body_field_info_kwargs["media_type"] = "application/x-www-form-urlencoded"
body_field_info_kwargs["embed"] = True
else:
# Regular JSON body parameters
body_field_info = Body

body_param_media_types = [
Expand Down
Loading