diff --git a/ddtrace/appsec/_handlers.py b/ddtrace/appsec/_handlers.py index c88b2323e0c..11dbfcbe088 100644 --- a/ddtrace/appsec/_handlers.py +++ b/ddtrace/appsec/_handlers.py @@ -1,14 +1,22 @@ import io import json +from typing import Any +from typing import Dict +from typing import Optional import xmltodict +from ddtrace._trace.span import Span +from ddtrace.appsec._asm_request_context import _call_waf +from ddtrace.appsec._asm_request_context import _call_waf_first from ddtrace.appsec._asm_request_context import get_blocked from ddtrace.appsec._constants import SPAN_DATA_NAMES +from ddtrace.appsec._http_utils import extract_cookies_from_headers +from ddtrace.appsec._http_utils import normalize_headers +from ddtrace.appsec._http_utils import parse_http_body from ddtrace.contrib import trace_utils from ddtrace.contrib.internal.trace_utils_base import _get_request_header_user_agent from ddtrace.contrib.internal.trace_utils_base import _set_url_tag -from ddtrace.ext import SpanTypes from ddtrace.ext import http from ddtrace.internal import core from ddtrace.internal.constants import RESPONSE_HEADERS @@ -53,7 +61,7 @@ def _on_set_http_meta( response_headers, response_cookies, ): - if asm_config._asm_enabled and span.span_type == SpanTypes.WEB: + if asm_config._asm_enabled and span.span_type in asm_config._asm_http_span_types: # avoid circular import from ddtrace.appsec._asm_request_context import set_waf_address @@ -77,6 +85,74 @@ def _on_set_http_meta( set_waf_address(k, v) +# AWS Lambda +def _on_lambda_start_request( + span: Span, + request_headers: Dict[str, str], + request_ip: Optional[str], + body: Optional[str], + is_body_base64: bool, + raw_uri: str, + route: str, + method: str, + parsed_query: Dict[str, Any], +): + if not (asm_config._asm_enabled and span.span_type in asm_config._asm_http_span_types): + return + + headers = normalize_headers(request_headers) + request_body = parse_http_body(headers, body, is_body_base64) + request_cookies = 
# Media types are compared after dropping parameters (e.g. "; charset=utf-8")
# and lowercasing, since media types are case-insensitive.
_JSON_MEDIA_TYPES = frozenset(("application/json", "application/vnd.api+json", "text/json"))
_FORM_MEDIA_TYPES = frozenset(("application/x-url-encoded", "application/x-www-form-urlencoded"))
_XML_MEDIA_TYPES = frozenset(("application/xml", "text/xml"))

# A tuple (not a set) so the lookup order is deterministic when both headers
# are present in the same mapping.
_COOKIE_HEADER_NAMES = ("cookie", "set-cookie")


def normalize_headers(
    request_headers: Dict[str, str],
) -> Dict[str, Optional[str]]:
    """Normalize headers according to the WAF expectations.

    The WAF expects headers to be lowercased and empty values to be None.
    Header names are normalized with ``http_utils.normalize_header_name``;
    values are converted to ``str`` and stripped of surrounding whitespace.
    A value that is falsy, or that is empty once stripped, is stored as None.

    :param request_headers: raw header mapping.
    :return: a new dict keyed by normalized header name.
    """
    headers: Dict[str, Optional[str]] = {}
    for key, value in request_headers.items():
        # ``or None`` maps whitespace-only values to None as well, so the
        # "empty value is None" contract holds after stripping.
        stripped = str(value).strip() if value else ""
        headers[http_utils.normalize_header_name(key)] = stripped or None
    return headers


def parse_http_body(
    normalized_headers: Dict[str, Optional[str]],
    body: Optional[str],
    is_body_base64: bool,
) -> Union[str, Dict[str, Any], None]:
    """Parse a request body based on the content-type header.

    :param normalized_headers: headers keyed by normalized name; only the
        ``content-type`` entry is read here (the whole mapping is forwarded to
        the multipart parser, which needs the boundary parameter).
    :param body: raw body, or None when the request has no body.
    :param is_body_base64: when true, ``body`` is base64-decoded and then
        decoded to text before parsing.
    :return: the parsed body (decoded JSON, ``parse_qs`` dict, ``xmltodict``
        dict, multipart dict, or the plain text itself), or None when the body
        is missing, cannot be decoded or parsed, or has an unsupported or
        missing content type.
    """
    if body is None:
        return None

    if is_body_base64:
        try:
            body = base64.b64decode(body).decode()
        except (ValueError, TypeError):
            # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
            return None

    # Deliberately best-effort: the parsers below raise heterogeneous
    # exception types, and a malformed body must never break the request.
    try:
        content_type = normalized_headers.get("content-type")
        if not content_type:
            return None

        # Keep only the media type: drop parameters such as "; charset=utf-8"
        # or "; boundary=..." and ignore case.
        media_type = content_type.split(";", 1)[0].strip().lower()

        if media_type in _JSON_MEDIA_TYPES:
            return json.loads(body)
        if media_type in _FORM_MEDIA_TYPES:
            return parse_qs(body)
        if media_type in _XML_MEDIA_TYPES:
            return xmltodict.parse(body)
        if media_type == "multipart/form-data":
            return http_utils.parse_form_multipart(body, normalized_headers)
        if media_type == "text/plain":
            return body
        return None
    except Exception:
        return None


def extract_cookies_from_headers(
    normalized_headers: Dict[str, Optional[str]],
) -> Optional[Dict[str, str]]:
    """Extract cookies from the WAF headers.

    Looks for a ``cookie`` header first, then ``set-cookie``. The first one
    found is removed from ``normalized_headers`` (the mapping is mutated in
    place) and parsed with ``SimpleCookie``.

    :param normalized_headers: headers keyed by normalized name.
    :return: a ``{cookie name: value}`` dict (empty when the header value is
        empty or None), or None when neither header is present.
    """
    for name in _COOKIE_HEADER_NAMES:
        if name in normalized_headers:
            header = normalized_headers.pop(name)
            cookie = SimpleCookie()
            if header:
                cookie.load(header)
            return {key: morsel.value for key, morsel in cookie.items()}
    return None
ddtrace.constants import _ORIGIN_KEY from ddtrace.constants import _RUNTIME_FAMILY -from ddtrace.ext import SpanTypes from ddtrace.internal._unpatched import unpatched_open as open # noqa: A004 from ddtrace.internal.logger import get_logger from ddtrace.internal.rate_limiter import RateLimiter @@ -232,9 +231,6 @@ def _waf_action( be retrieved from the `core`. This can be used when you don't want to store the value in the `core` before checking the `WAF`. """ - if span.span_type not in (SpanTypes.WEB, SpanTypes.HTTP, SpanTypes.GRPC): - return None - if _asm_request_context.get_blocked(): # We still must run the waf if we need to extract schemas for API SECURITY if not custom_data or not custom_data.get("PROCESSOR_SETTINGS", {}).get("extract-schema", False): @@ -365,7 +361,7 @@ def _is_needed(self, address: str) -> bool: return address in self._addresses_to_keep def on_span_finish(self, span: Span) -> None: - if span.span_type in {SpanTypes.WEB, SpanTypes.GRPC}: + if span.span_type in asm_config._asm_processed_span_types: _asm_request_context.call_waf_callback_no_instrumentation() self._ddwaf._at_request_end() _asm_request_context.end_context(span) diff --git a/ddtrace/settings/asm.py b/ddtrace/settings/asm.py index 085c3ec4f9e..73809bef89b 100644 --- a/ddtrace/settings/asm.py +++ b/ddtrace/settings/asm.py @@ -66,6 +66,7 @@ class ASMConfig(DDConfig): if _asm_static_rule_file == "": _asm_static_rule_file = None _asm_processed_span_types = {SpanTypes.WEB, SpanTypes.GRPC} + _asm_http_span_types = {SpanTypes.WEB} _iast_enabled = tracer_config._from_endpoint.get("iast_enabled", DDConfig.var(bool, IAST.ENV, default=False)) _iast_request_sampling = DDConfig.var(float, IAST.ENV_REQUEST_SAMPLING, default=30.0) _iast_debug = DDConfig.var(bool, IAST.ENV_DEBUG, default=False, private=True) @@ -230,6 +231,7 @@ def __init__(self): if in_aws_lambda(): self._asm_processed_span_types.add(SpanTypes.SERVERLESS) + self._asm_http_span_types.add(SpanTypes.SERVERLESS) # As a first step, 
@pytest.mark.parametrize(
    "input_headers, expected",
    [
        # Header names are normalized; values keep their case.
        ({"Host": "Example.COM"}, {"host": "Example.COM"}),
        # Empty values become None; names and values are stripped.
        (
            {"X-Custom-None": "", "Content-Type": "application/json", "X-Custom-Spacing ": " trim spaces "},
            {"x-custom-none": None, "content-type": "application/json", "x-custom-spacing": "trim spaces"},
        ),
    ],
)
def test_normalize_headers(input_headers, expected):
    """normalize_headers returns the normalized mapping expected by the WAF."""
    result = _http_utils.normalize_headers(input_headers)
    assert result == expected


@pytest.mark.parametrize(
    "headers, body, is_body_base64, expected_output",
    [
        # Body is None
        ({}, None, False, None),
        # Base64 encoded body - text/plain
        (
            {"content-type": "text/plain"},
            "dGV4dCBib2R5",
            True,
            "text body",
        ),
        # Base64 encoded body - application/json
        (
            {"content-type": "application/json"},
            "eyJrZXkiOiAidmFsdWUifQ==",
            True,
            {"key": "value"},
        ),
        # Base64 decoding failure - text/plain
        (
            {"content-type": "text/plain"},
            "invalid_base64_string",
            True,
            None,
        ),
        # JSON content types
        ({"content-type": "application/json"}, '{"key": "value"}', False, {"key": "value"}),
        ({"content-type": "application/vnd.api+json"}, '{"key": "value"}', False, {"key": "value"}),
        ({"content-type": "text/json"}, '{"key": "value"}', False, {"key": "value"}),
        # Form urlencoded
        (
            {"content-type": "application/x-www-form-urlencoded"},
            "key=value&key2=value2",
            False,
            {"key": ["value"], "key2": ["value2"]},
        ),
        # XML content types: the body must be real markup for xmltodict to
        # produce the expected nested dict.
        ({"content-type": "application/xml"}, "<root><key>value</key></root>", False, {"root": {"key": "value"}}),
        ({"content-type": "text/xml"}, "<root><key>value</key></root>", False, {"root": {"key": "value"}}),
        # Text plain
        ({"content-type": "text/plain"}, "simple text body", False, "simple text body"),
        # Unsupported content type
        ({"content-type": "application/octet-stream"}, "binary data", False, None),
        # No content type provided
        ({}, "some body", False, None),
        # Invalid JSON
        ({"content-type": "application/json"}, "not a valid json string", False, None),
        # Invalid XML (mismatched closing tag)
        ({"content-type": "application/xml"}, "<root><key>value</root>", False, None),
        # Multipart form data
        (
            {"content-type": "multipart/form-data; boundary=boundary"},
            (
                "--boundary\r\n"
                'Content-Disposition: form-data; name="formPart"\r\n'
                "content-type: application/x-www-form-urlencoded\r\n"
                "\r\n"
                "key=value\r\n"
                "--boundary--"
            ),
            False,
            {"formPart": {"key": ["value"]}},  # Result of the multipart parser; nothing is mocked
        ),
        # Invalid base64 encoded body (decoding fails)
        (
            {"content-type": "application/xml"},
            "invalid_base64_and_invalid_xml",
            True,
            None,
        ),
    ],
)
def test_parse_http_body(headers, body, is_body_base64, expected_output):
    """parse_http_body decodes and parses the body according to content-type."""
    result = _http_utils.parse_http_body(headers, body, is_body_base64)
    assert result == expected_output


# Tests for extract_cookies_from_headers
@pytest.mark.parametrize(
    "input_headers, expected",
    [
        (
            {"cookie": "sessionid=abc123; csrftoken=xyz789"},
            {"sessionid": "abc123", "csrftoken": "xyz789"},
        ),
        (
            {"set-cookie": "sessionid=abc123; Path=/; HttpOnly"},
            {"sessionid": "abc123"},
        ),
        ({"cookie": ""}, {}),
        ({"cookie": None}, {}),
        ({"set-cookie": None}, {}),
    ],
)
def test_extract_cookies_from_headers(input_headers, expected):
    """extract_cookies_from_headers parses the cookie header into a dict."""
    result = _http_utils.extract_cookies_from_headers(input_headers)
    assert result == expected