diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 407cd00781b..7f1c9b7d87b 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -815,7 +815,7 @@ def _openapi_operation_parameters( from aws_lambda_powertools.event_handler.openapi.compat import ( get_schema_from_model_field, ) - from aws_lambda_powertools.event_handler.openapi.params import Param + from aws_lambda_powertools.event_handler.openapi.params import Form, Header, Param, Query parameters = [] parameter: dict[str, Any] = {} @@ -826,32 +826,78 @@ def _openapi_operation_parameters( if not field_info.include_in_schema: continue - param_schema = get_schema_from_model_field( - field=param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) + # Check if this is a Pydantic model that should be expanded + from pydantic import BaseModel - parameter = { - "name": param.alias, - "in": field_info.in_.value, - "required": param.required, - "schema": param_schema, - } + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass + + if isinstance(field_info, (Query, Header, Form)) and lenient_issubclass(field_info.annotation, BaseModel): + # Expand Pydantic model into individual parameters + model_class = field_info.annotation - if field_info.description: - parameter["description"] = field_info.description + for field_name, field_def in model_class.model_fields.items(): + # Create individual parameter for each model field + param_name = field_def.alias or field_name - if field_info.openapi_examples: - parameter["examples"] = field_info.openapi_examples + # Convert snake_case to kebab-case for headers (HTTP convention) + if isinstance(field_info, Header): + param_name = param_name.replace("_", "-") - if field_info.deprecated: - parameter["deprecated"] = field_info.deprecated + individual_param = { + "name": param_name, + "in": field_info.in_.value, + "required": field_def.is_required() + if hasattr(field_def, "is_required") + else field_def.default is ..., + "schema": Route._get_basic_type_schema(field_def.annotation), + } + + if field_def.description: + individual_param["description"] = field_def.description + + parameters.append(individual_param) + else: + # Regular parameter processing + param_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) - parameters.append(parameter) + parameter = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + if field_info.description: + parameter["description"] = field_info.description + + if field_info.openapi_examples: + parameter["examples"] = field_info.openapi_examples + + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + parameters.append(parameter) return parameters + @staticmethod + def _get_basic_type_schema(param_type: type) -> dict[str, str]: + """ + Get basic OpenAPI schema for simple types + """ + if isinstance(int, param_type): + return {"type": "integer"} + elif isinstance(float, param_type): + return {"type": "number"} + elif isinstance(bool, param_type): + return {"type": "boolean"} + else: + return {"type": "string"} + @staticmethod def _openapi_operation_return( *, diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 6a276de20fb..b8baef145ec 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence from urllib.parse import parse_qs -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -19,7 +19,7 @@ from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError -from aws_lambda_powertools.event_handler.openapi.params import Param +from aws_lambda_powertools.event_handler.openapi.params import Header, Param, Query if TYPE_CHECKING: from aws_lambda_powertools.event_handler import Response @@ -69,8 +69,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> route.dependant.query_params, ) - # Process query values - query_values, query_errors = _request_params_to_args( + # Process query values (with Pydantic model support) + query_values, query_errors = _request_params_to_args_with_pydantic_support( route.dependant.query_params, query_string, ) @@ -81,8 +81,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> route.dependant.header_params, ) - # Process header values - header_values, header_errors = _request_params_to_args( + # Process header values (with Pydantic model support) + header_values, header_errors = _request_params_to_args_with_pydantic_support( route.dependant.header_params, headers, ) @@ -311,6 +311,58 @@ def _prepare_response_content( return res # pragma: no cover +def _request_params_to_args_with_pydantic_support( + required_params: Sequence[ModelField], + received_params: Mapping[str, Any], +) -> tuple[dict[str, Any], list[Any]]: + """ + Convert request params to a dictionary of values with Pydantic model support. + """ + values = {} + errors = [] + + for field in required_params: + field_info = field.field_info + + # Check if this is a Pydantic model in Query/Header + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass + + if isinstance(field_info, (Query, Header)) and lenient_issubclass(field_info.annotation, BaseModel): + # Handle Pydantic model - use the same approach as _request_body_to_args + loc = (field_info.in_.value, field.alias) + + # Get the raw data for the Pydantic model + value = received_params.get(field.alias) + + if value is None: + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + else: + # Regular parameter processing (existing logic) + if not isinstance(field_info, Param): + raise AssertionError(f"Expected Param field_info, got {field_info}") + + value = received_params.get(field.alias) + loc = (field_info.in_.value, field.alias) + + if value is None: + if field.required: + errors.append(get_missing_field_error(loc=loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + # Use _validate_field like _request_body_to_args does + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) + return values, errors + + def _request_params_to_args( required_params: Sequence[ModelField], received_params: Mapping[str, Any], @@ -439,7 +491,7 @@ def _normalize_multi_query_string_with_param( params: Sequence[ModelField], ) -> dict[str, Any]: """ - Extract and normalize resolved_query_string_parameters + Extract and normalize resolved_query_string_parameters with Pydantic model support Parameters ---------- @@ -453,19 +505,41 @@ def _normalize_multi_query_string_with_param( A dictionary containing the processed multi_query_string_parameters. """ resolved_query_string: dict[str, Any] = query_string - for param in filter(is_scalar_field, params): - try: - # if the target parameter is a scalar, we keep the first value of the query string - # regardless if there are more in the payload - resolved_query_string[param.alias] = query_string[param.alias][0] - except KeyError: - pass + + for param in params: + # Handle scalar fields (existing logic) + if is_scalar_field(param): + try: + resolved_query_string[param.alias] = query_string[param.alias][0] + except KeyError: + pass + # Handle Pydantic models + elif isinstance(param.field_info, Query) and hasattr(param.field_info, "annotation"): + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass + + if lenient_issubclass(param.field_info.annotation, BaseModel): + model_class = param.field_info.annotation + model_data = {} + + # Collect all fields for the Pydantic model + for field_name, field_def in model_class.model_fields.items(): + field_alias = field_def.alias or field_name + try: + model_data[field_alias] = query_string[field_alias][0] + except KeyError: + pass + + # Store the collected data under the param alias + resolved_query_string[param.alias] = model_data + return resolved_query_string def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]): """ - Extract and normalize resolved_headers_field + Extract and normalize resolved_headers_field with Pydantic model support Parameters ---------- @@ -479,12 +553,43 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], A dictionary containing the processed headers. """ if headers: - for param in filter(is_scalar_field, params): - try: - if len(headers[param.alias]) == 1: - # if the target parameter is a scalar and the list contains only 1 element - # we keep the first value of the headers regardless if there are more in the payload - headers[param.alias] = headers[param.alias][0] - except KeyError: - pass + for param in params: + # Handle scalar fields (existing logic) + if is_scalar_field(param): + try: + if len(headers[param.alias]) == 1: + headers[param.alias] = headers[param.alias][0] + except KeyError: + pass + # Handle Pydantic models + elif isinstance(param.field_info, Header) and hasattr(param.field_info, "annotation"): + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass + + if lenient_issubclass(param.field_info.annotation, BaseModel): + model_class = param.field_info.annotation + model_data = {} + + # Collect all fields for the Pydantic model + for field_name, field_def in model_class.model_fields.items(): + field_alias = field_def.alias or field_name + + # Convert snake_case to kebab-case for headers (HTTP convention) + header_key = field_alias.replace("_", "-") + + try: + header_value = headers[header_key] + if isinstance(header_value, list): + if len(header_value) == 1: + model_data[field_alias] = header_value[0] + else: + model_data[field_alias] = header_value + else: + model_data[field_alias] = header_value + except KeyError: + pass + + # Store the collected data under the param alias + headers[param.alias] = model_data return headers diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 98a8740a74f..f302300e77e 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -277,6 +277,9 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): return False + elif isinstance(param_field.field_info, (Query, Header)): + # Allow Pydantic models in Query, Header, and Form parameters when explicitly annotated + return False else: if not isinstance(param_field.field_info, Body): raise AssertionError(f"Param: {param_field.name} can only be a request body, use Body()") @@ -307,6 +310,97 @@ def get_flat_params(dependant: Dependant) -> list[ModelField]: ) +def expand_pydantic_model_for_openapi(param_field: ModelField) -> list[ModelField]: + """ + Expands a Pydantic model parameter into individual fields for OpenAPI schema generation. + + Parameters + ---------- + param_field: ModelField + The field containing a Pydantic model + + Returns + ------- + list[ModelField] + List of individual ModelField objects for each field in the Pydantic model + """ + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass + + # Check if this is a Pydantic model in Query, Header, or Form + if not ( + isinstance(param_field.field_info, (Query, Header, Form)) + and lenient_issubclass(param_field.field_info.annotation, BaseModel) + ): + return [param_field] + + # Get the Pydantic model class + model_class = param_field.field_info.annotation + field_info_template = param_field.field_info + + expanded_fields = [] + + # Create individual fields for each field in the Pydantic model + for field_name, field_def in model_class.model_fields.items(): + # Create a new field_info for each model field + individual_field_info = type(field_info_template)( + default=field_def.default if field_def.default is not ... else None, + annotation=field_def.annotation, + alias=field_def.alias or field_name, + title=field_def.title, + description=field_def.description, + ) + + # Create the ModelField using the internal function + from aws_lambda_powertools.event_handler.openapi.params import _create_model_field + + individual_field = _create_model_field( + field_info=individual_field_info, + type_annotation=field_def.annotation, + param_name=field_name, + is_path_param=False, + ) + + if individual_field: + expanded_fields.append(individual_field) + + return expanded_fields + + +def get_flat_params_with_pydantic_expansion(dependant: Dependant) -> list[ModelField]: + """ + Get a list of all parameters from a Dependant object, expanding Pydantic models into individual fields. + This is used specifically for OpenAPI schema generation. + + Parameters + ---------- + dependant : Dependant + The Dependant object containing the parameters. + + Returns + ------- + list[ModelField] + A list of ModelField objects with Pydantic models expanded into individual fields. + """ + flat_dependant = get_flat_dependant(dependant) + all_params = ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + expanded_params = [] + + for param in all_params: + # Expand Pydantic models into individual fields + expanded_fields = expand_pydantic_model_for_openapi(param) + expanded_params.extend(expanded_fields) + + return expanded_params + + def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None: """ Get the Body field for a given Dependant object. diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 19b5287d66a..ccc3b5e4856 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -25,6 +25,145 @@ JSON_CONTENT_TYPE = "application/json" +def test_openapi_pydantic_query_params(): + """Test that Pydantic models in Query parameters are expanded into individual fields in OpenAPI schema""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Number of items to return") + offset: int = Field(default=0, ge=0, description="Number of items to skip") + search: Optional[str] = Field(default=None, description="Search term") + + @app.get("/search") + def search_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/search" in schema.paths + path = schema.paths["/search"] + assert path.get is not None + + # Check that parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 + + # Check individual parameters + param_names = [param.name for param in get_operation.parameters] + assert "limit" in param_names + assert "offset" in param_names + assert "search" in param_names + + # Check parameter details + for param in get_operation.parameters: + assert param.in_ == ParameterInType.query + if param.name == "limit": + assert param.required is False # Has default value + assert param.description == "Number of items to return" + assert param.schema_.type == "integer" + elif param.name == "offset": + assert param.required is False # Has default value + assert param.description == "Number of items to skip" + assert param.schema_.type == "integer" + elif param.name == "search": + assert param.required is False # Optional field + assert param.description == "Search term" + assert param.schema_.type == "string" + + +def test_openapi_pydantic_header_params(): + """Test that Pydantic models in Header parameters are expanded into individual fields in OpenAPI schema""" + app = APIGatewayRestResolver() + + class HeaderParams(BaseModel): + authorization: str = Field(description="Authorization token") + user_agent: str = Field(default="PowerTools/1.0", description="User agent") + accept_language: Optional[str] = Field(default=None, alias="accept-language", description="Language preference") + + @app.get("/protected") + def protected_handler(headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/protected" in schema.paths + path = schema.paths["/protected"] + assert path.get is not None + + # Check that parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 + + # Check individual parameters + param_names = [param.name for param in get_operation.parameters] + assert "authorization" in param_names + assert "user_agent" in param_names + assert "accept-language" in param_names # Should use alias + + # Check parameter details + for param in get_operation.parameters: + assert param.in_ == ParameterInType.header + if param.name == "authorization": + assert param.required is True # No default value + assert param.description == "Authorization token" + assert param.schema_.type == "string" + elif param.name == "user_agent": + assert param.required is False # Has default value + assert param.description == "User agent" + assert param.schema_.type == "string" + elif param.name == "accept-language": + assert param.required is False # Optional field + assert param.description == "Language preference" + assert param.schema_.type == "string" + + +def test_openapi_pydantic_mixed_params(): + """Test that mixed Pydantic models (Query + Header) work together""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + q: str = Field(description="Search query") + limit: int = Field(default=10, description="Number of results") + + class HeaderParams(BaseModel): + authorization: str = Field(description="Bearer token") + + @app.get("/mixed") + def mixed_handler(query: Annotated[QueryParams, Query()], headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/mixed" in schema.paths + path = schema.paths["/mixed"] + assert path.get is not None + + # Check that all parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 # 2 query + 1 header + + # Check parameter types + query_params = [p for p in get_operation.parameters if p.in_ == ParameterInType.query] + header_params = [p for p in get_operation.parameters if p.in_ == ParameterInType.header] + + assert len(query_params) == 2 + assert len(header_params) == 1 + + # Check specific parameters + query_names = [p.name for p in query_params] + assert "q" in query_names + assert "limit" in query_names + + header_names = [p.name for p in header_params] + assert "authorization" in header_names + + def test_openapi_no_params(): app = APIGatewayRestResolver() @@ -776,3 +915,132 @@ def form_edge_cases( assert "required_field" in component_schema.required assert "optional_field" not in component_schema.required # Optional assert "field_with_default" not in component_schema.required # Has default + + +def test_openapi_pydantic_query_with_constraints(): + """Test that Pydantic field constraints are preserved in OpenAPI schema""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + limit: int = Field(ge=1, le=100, description="Number of items") + name: str = Field(min_length=1, max_length=50, description="Name filter") + + @app.get("/items") + def get_items(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/items"] + get_operation = path.get + + # Find the limit parameter + limit_param = next(p for p in get_operation.parameters if p.name == "limit") + assert limit_param.schema_.type == "integer" + assert limit_param.description == "Number of items" + + # Find the name parameter + name_param = next(p for p in get_operation.parameters if p.name == "name") + assert name_param.schema_.type == "string" + assert name_param.description == "Name filter" + + +def test_openapi_pydantic_header_with_alias(): + """Test that Pydantic field aliases work correctly in Header parameters""" + app = APIGatewayRestResolver() + + class HeaderParams(BaseModel): + content_type: str = Field(alias="content-type", description="Content type") + user_agent: str = Field(alias="user-agent", description="User agent") + + @app.get("/test") + def test_handler(headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/test"] + get_operation = path.get + + # Check that aliases are used as parameter names + param_names = [param.name for param in get_operation.parameters] + assert "content-type" in param_names + assert "user-agent" in param_names + assert "content_type" not in param_names # Original field name should not be used + assert "user_agent" not in param_names + + +def test_openapi_pydantic_required_vs_optional(): + """Test that required vs optional fields are correctly identified""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + required_field: str = Field(description="Required field") + optional_with_default: str = Field(default="default", description="Optional with default") + optional_nullable: Optional[str] = Field(default=None, description="Optional nullable") + + @app.get("/test") + def test_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/test"] + get_operation = path.get + + for param in get_operation.parameters: + if param.name == "required_field": + assert param.required is True + elif param.name == "optional_with_default": + assert param.required is False + elif param.name == "optional_nullable": + assert param.required is False + + +def test_openapi_pydantic_backward_compatibility(): + """Test that existing Body parameter behavior is unchanged""" + app = APIGatewayRestResolver() + + class BodyModel(BaseModel): + name: str = Field(description="Name") + age: int = Field(description="Age") + + @app.post("/users") + def create_user(user: BodyModel): # No annotation - should work as Body + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/users"] + post_operation = path.post + + # Should have no parameters (body is handled separately) + assert post_operation.parameters is None or len(post_operation.parameters) == 0 + + # Should have request body + assert post_operation.requestBody is not None + assert "application/json" in post_operation.requestBody.content + + +def test_openapi_pydantic_complex_types(): + """Test that complex types are handled correctly""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + string_field: str = Field(description="String field") + int_field: int = Field(description="Integer field") + float_field: float = Field(description="Float field") + bool_field: bool = Field(description="Boolean field") + + @app.get("/complex") + def complex_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/complex"] + get_operation = path.get + + type_mapping = {} + for param in get_operation.parameters: + type_mapping[param.name] = param.schema_.type + + assert type_mapping["string_field"] == "string" + assert type_mapping["int_field"] == "integer" + assert type_mapping["float_field"] == "number" + assert type_mapping["bool_field"] == "boolean" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 1fd919b7b71..1dbf2ce1e6b 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -6,7 +6,7 @@ from typing import List, Optional, Tuple import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -47,6 +47,282 @@ def handler(user_id: int): assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) +def test_validate_pydantic_query_params(gw_event): + """Test that Pydantic models in Query parameters are validated correctly""" + + app = APIGatewayRestResolver(enable_validation=True) + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Number of items") + search: Optional[str] = Field(default=None, description="Search term") + + @app.get("/search") + def search_handler(params: Annotated[QueryParams, Query()]): + return { + "limit": params.limit, + "search": params.search, + } + + # Test valid request + gw_event["path"] = "/search" + gw_event["queryStringParameters"] = {"limit": "25", "search": "python"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["limit"] == 25 + assert body["search"] == "python" + + # Test with default values + gw_event["queryStringParameters"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["limit"] == 10 # Default value + assert body["search"] is None # Default value + + # Test validation error (limit too high) + gw_event["queryStringParameters"] = {"limit": "150"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("limit" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_query_params_detailed_errors(gw_event): + """Test that Pydantic validation errors include detailed field-level information""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + full_name: str = Field(..., min_length=5, description="Full name with minimum 5 characters") + age: int = Field(..., ge=18, le=100, description="Age between 18 and 100") + + @app.get("/query-model") + def query_model(params: Annotated[QueryParams, Query()]): + return {"full_name": params.full_name, "age": params.age} + + # Test validation error with detailed field information + gw_event["path"] = "/query-model" + gw_event["queryStringParameters"] = {"full_name": "Jo", "age": "15"} # Both invalid + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + + # Check that we get detailed field-level errors + errors = body["detail"] + + # Should have errors for both fields + full_name_error = next((e for e in errors if "full_name" in e["loc"]), None) + age_error = next((e for e in errors if "age" in e["loc"]), None) + + assert full_name_error is not None, "Should have error for full_name field" + assert age_error is not None, "Should have error for age field" + + # Check error details for full_name + assert full_name_error["loc"] == ["query", "params", "full_name"] + assert full_name_error["type"] == "string_too_short" + + # Check error details for age + assert age_error["loc"] == ["query", "params", "age"] + assert age_error["type"] == "greater_than_equal" + + +def test_validate_pydantic_header_params(gw_event): + """Test that Pydantic models in Header parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + class HeaderParams(BaseModel): + authorization: str = Field(description="Authorization token") + user_agent: str = Field(default="PowerTools/1.0", description="User agent") + + @app.get("/protected") + def protected_handler(my_headers: Annotated[HeaderParams, Header()]): + return { + "authorization": my_headers.authorization, + "user_agent": my_headers.user_agent, + } + + # Test valid request + gw_event["path"] = "/protected" + gw_event["headers"] = {"authorization": "Bearer token123", "user-agent": "TestClient/1.0"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["authorization"] == "Bearer token123" + assert body["user_agent"] == "TestClient/1.0" + + # Test with default value + gw_event["headers"] = {"authorization": "Bearer token123"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["authorization"] == "Bearer token123" + assert body["user_agent"] == "PowerTools/1.0" # Default value + + # Test missing required header + gw_event["headers"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_header_snake_case_to_kebab_case_schema(gw_event): + """Test that snake_case header fields are converted to kebab-case in OpenAPI schema and validation""" + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger() + + del gw_event["multiValueHeaders"] + + class HeaderParams(BaseModel): + correlation_id: str = Field(description="Correlation ID header") + user_agent: str = Field(default="PowerTools/1.0", description="User agent header") + + @app.get("/kebab-headers") + def kebab_handler(my_headers: Annotated[HeaderParams, Header()]): + return { + "correlation_id": my_headers.correlation_id, + "user_agent": my_headers.user_agent, + } + + # Test that OpenAPI schema uses kebab-case for headers + openapi_schema = app.get_openapi_schema() + operation = openapi_schema.paths["/kebab-headers"].get + parameters = operation.parameters + + # Find the correlation_id parameter + correlation_param = next((p for p in parameters if p.name == "correlation-id"), None) + assert correlation_param is not None, "Should have correlation-id parameter in kebab-case" + + # Find the user_agent parameter + user_agent_param = next((p for p in parameters if p.name == "user-agent"), None) + assert user_agent_param is not None, "Should have user-agent parameter in kebab-case" + + # Test validation with kebab-case headers + gw_event["path"] = "/kebab-headers" + gw_event["headers"] = {"correlation-id": "test-123", "user-agent": "TestClient/1.0"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["correlation_id"] == "test-123" + assert body["user_agent"] == "TestClient/1.0" + + +def test_validate_pydantic_mixed_params(gw_event): + """Test that mixed Pydantic models (Query + Header) are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + q: str = Field(description="Search query") + limit: int = Field(default=10, description="Number of results") + + class HeaderParams(BaseModel): + authorization: str = Field(description="Bearer token") + + @app.get("/mixed") + def mixed_handler(query: Annotated[QueryParams, Query()], headers: Annotated[HeaderParams, Header()]): + return { + "query": {"q": query.q, "limit": query.limit}, + "headers": {"authorization": headers.authorization}, + } + + # Test valid request + gw_event["path"] = "/mixed" + gw_event["queryStringParameters"] = {"q": "python", "limit": "25"} + gw_event["headers"] = {"authorization": "Bearer token123"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["query"]["q"] == "python" + assert body["query"]["limit"] == 25 + assert body["headers"]["authorization"] == "Bearer token123" + + # Test missing required query parameter + gw_event["queryStringParameters"] = {"limit": "25"} # Missing 'q' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("q" in str(error) for error in body["detail"]) + + # Test missing required header + gw_event["queryStringParameters"] = {"q": "python"} + gw_event["headers"] = {} # Missing 'authorization' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_with_alias(gw_event): + """Test that Pydantic models with field aliases work correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class HeaderParams(BaseModel): + accept_language: str = Field(alias="accept-language", description="Language preference") + + @app.get("/alias") + def alias_handler(headers: Annotated[HeaderParams, Header()]): + return {"accept_language": headers.accept_language} + + # Test with alias in request + gw_event["path"] = "/alias" + gw_event["headers"] = {"accept-language": "en-US"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["accept_language"] == "en-US" + + # Test missing aliased field + gw_event["headers"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("headers" in str(error) for error in body["detail"]) + + def test_validate_scalars_with_default(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1983,3 +2259,131 @@ def get_user(user_id: int) -> UserModel: assert response_body["name"] == "User123" assert response_body["age"] == 143 assert response_body["email"] == "user123@example.com" + + +def test_validate_pydantic_query_params_with_config_dict_and_validators(gw_event): + """Test that Pydantic models with ConfigDict, aliases, and validators work correctly""" + from typing import Any + + from pydantic import UUID4, AfterValidator, Base64UrlStr, ConfigDict, StringConstraints, alias_generators + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger(path="/swagger") + + def _validate_powertools(value: str) -> str: + if not value.startswith("Powertools"): + raise ValueError("Full name must start with 'Powertools'") + return value + + class QuerySimple(BaseModel): + full_name: Annotated[str, StringConstraints(min_length=5), AfterValidator(_validate_powertools)] + next_token: Base64UrlStr + search_id: str + + @app.get("/query-model-simple") + def query_model(params: Annotated[QuerySimple, Query()]) -> dict[str, Any]: + return { + "fullName": params.full_name, + "nextToken": params.next_token, + "searchId": params.search_id, + } + + class QueryAdvanced(BaseModel): + full_name: Annotated[str, StringConstraints(min_length=5)] + next_token: str + search_id: Annotated[str, Field(alias="id")] # Using str instead of UUID4 for simpler testing + + model_config = ConfigDict( + alias_generator=alias_generators.to_camel, + validate_by_alias=True, + serialize_by_alias=True, + ) + + @app.get("/query-model-advanced") + def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> dict[str, Any]: + return params.model_dump() + + # Test QuerySimple with validators + gw_event["path"] = "/query-model-simple" + gw_event["queryStringParameters"] = { + "full_name": "Powertools Lambda", + "next_token": "dGVzdA==", # base64url encoded "test" + "search_id": "search-123", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["fullName"] == "Powertools Lambda" + assert body["nextToken"] == "test" + assert body["searchId"] == "search-123" + + # Test QuerySimple validation error (name doesn't start with "Powertools") + gw_event["queryStringParameters"] = { + "full_name": "Lambda Powertools", # Wrong order + "next_token": "dGVzdA==", + "search_id": "search-123", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + errors = body["detail"] + + # Test QueryAdvanced with ConfigDict and alias_generator + gw_event["path"] = "/query-model-advanced" + gw_event["queryStringParameters"] = { + "fullName": "Advanced Test", # camelCase from alias_generator + "nextToken": "dGVzdA==", # camelCase from alias_generator + "id": "search-456", # explicit alias + } + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + # Should return with camelCase keys due to serialize_by_alias=True + assert body["fullName"] == "Advanced Test" + assert body["nextToken"] == "dGVzdA==" + assert body["id"] == "search-456" + + # Test QueryAdvanced with snake_case field names (should also work due to populate_by_name behavior) + gw_event["queryStringParameters"] = { + "full_name": "Snake Case Test", # snake_case field name + "next_token": "dGVzdA==", # snake_case field name + "id": "search-789", # explicit alias + } + + gw_event["path"] = "/query-model-advanced" + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["fullName"] == "Snake Case Test" + assert body["nextToken"] == "token789" + assert body["id"] == "search-789" + + # Test QueryAdvanced validation error (full_name too short) + gw_event["queryStringParameters"] = { + "fullName": "Bad", # Too short (min_length=5) + "nextToken": "dGVzdA==", + "id": "search-456", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + errors = body["detail"] + + # Should have validation error for full_name with proper location + full_name_error = next((e for e in errors if "full_name" in e["loc"] or "fullName" in e["loc"]), None) + assert full_name_error is not None + assert full_name_error["type"] == "string_too_short"