diff --git a/ddtrace/appsec/_handlers.py b/ddtrace/appsec/_handlers.py
index c88b2323e0c..11dbfcbe088 100644
--- a/ddtrace/appsec/_handlers.py
+++ b/ddtrace/appsec/_handlers.py
@@ -1,14 +1,22 @@
import io
import json
+from typing import Any
+from typing import Dict
+from typing import Optional
import xmltodict
+from ddtrace._trace.span import Span
+from ddtrace.appsec._asm_request_context import _call_waf
+from ddtrace.appsec._asm_request_context import _call_waf_first
from ddtrace.appsec._asm_request_context import get_blocked
from ddtrace.appsec._constants import SPAN_DATA_NAMES
+from ddtrace.appsec._http_utils import extract_cookies_from_headers
+from ddtrace.appsec._http_utils import normalize_headers
+from ddtrace.appsec._http_utils import parse_http_body
from ddtrace.contrib import trace_utils
from ddtrace.contrib.internal.trace_utils_base import _get_request_header_user_agent
from ddtrace.contrib.internal.trace_utils_base import _set_url_tag
-from ddtrace.ext import SpanTypes
from ddtrace.ext import http
from ddtrace.internal import core
from ddtrace.internal.constants import RESPONSE_HEADERS
@@ -53,7 +61,7 @@ def _on_set_http_meta(
response_headers,
response_cookies,
):
- if asm_config._asm_enabled and span.span_type == SpanTypes.WEB:
+ if asm_config._asm_enabled and span.span_type in asm_config._asm_http_span_types:
# avoid circular import
from ddtrace.appsec._asm_request_context import set_waf_address
@@ -77,6 +85,74 @@ def _on_set_http_meta(
set_waf_address(k, v)
+# AWS Lambda
+def _on_lambda_start_request(
+ span: Span,
+ request_headers: Dict[str, str],
+ request_ip: Optional[str],
+ body: Optional[str],
+ is_body_base64: bool,
+ raw_uri: str,
+ route: str,
+ method: str,
+ parsed_query: Dict[str, Any],
+):
+ if not (asm_config._asm_enabled and span.span_type in asm_config._asm_http_span_types):
+ return
+
+ headers = normalize_headers(request_headers)
+ request_body = parse_http_body(headers, body, is_body_base64)
+ request_cookies = extract_cookies_from_headers(headers)
+
+ _on_set_http_meta(
+ span,
+ request_ip,
+ raw_uri,
+ route,
+ method,
+ headers,
+ request_cookies,
+ parsed_query,
+ None,
+ request_body,
+ None,
+ None,
+ None,
+ )
+
+ _call_waf_first(("aws_lambda",))
+
+
+def _on_lambda_start_response(
+ span: Span,
+ status_code: str,
+ response_headers: Dict[str, str],
+):
+ if not (asm_config._asm_enabled and span.span_type in asm_config._asm_http_span_types):
+ return
+
+ waf_headers = normalize_headers(response_headers)
+ response_cookies = extract_cookies_from_headers(waf_headers)
+
+ _on_set_http_meta(
+ span,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ status_code,
+ waf_headers,
+ response_cookies,
+ )
+
+ _call_waf(("aws_lambda",))
+
+
# ASGI
@@ -307,6 +383,9 @@ def listen():
core.on("asgi.request.parse.body", _on_asgi_request_parse_body, "await_receive_and_body")
+ core.on("aws_lambda.start_request", _on_lambda_start_request)
+ core.on("aws_lambda.start_response", _on_lambda_start_response)
+
core.on("grpc.server.response.message", _on_grpc_server_response)
core.on("grpc.server.data", _on_grpc_server_data)
diff --git a/ddtrace/appsec/_http_utils.py b/ddtrace/appsec/_http_utils.py
new file mode 100644
index 00000000000..15d7814942b
--- /dev/null
+++ b/ddtrace/appsec/_http_utils.py
@@ -0,0 +1,81 @@
+import base64
+from http.cookies import SimpleCookie
+import json
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Union
+from urllib.parse import parse_qs
+
+import xmltodict
+
+from ddtrace.internal.utils import http as http_utils
+
+
+def normalize_headers(
+ request_headers: Dict[str, str],
+) -> Dict[str, Optional[str]]:
+ """Normalize headers according to the WAF expectations.
+
+ The WAF expects headers to be lowercased and empty values to be None.
+ """
+ headers: Dict[str, Optional[str]] = {}
+ for key, value in request_headers.items():
+ normalized_key = http_utils.normalize_header_name(key)
+ if value:
+ headers[normalized_key] = str(value).strip()
+ else:
+ headers[normalized_key] = None
+ return headers
+
+
+def parse_http_body(
+ normalized_headers: Dict[str, Optional[str]],
+ body: Optional[str],
+ is_body_base64: bool,
+) -> Union[str, Dict[str, Any], None]:
+ """Parse a request body based on the content-type header."""
+ if body is None:
+ return None
+ if is_body_base64:
+ try:
+ body = base64.b64decode(body).decode()
+ except (ValueError, TypeError):
+ return None
+
+ try:
+ content_type = normalized_headers.get("content-type")
+ if not content_type:
+ return None
+
+ if content_type in ("application/json", "application/vnd.api+json", "text/json"):
+ return json.loads(body)
+ elif content_type in ("application/x-url-encoded", "application/x-www-form-urlencoded"):
+ return parse_qs(body)
+ elif content_type in ("application/xml", "text/xml"):
+ return xmltodict.parse(body)
+ elif content_type.startswith("multipart/form-data"):
+ return http_utils.parse_form_multipart(body, normalized_headers)
+ elif content_type == "text/plain":
+ return body
+ else:
+ return None
+
+ except Exception:
+ return None
+
+
+def extract_cookies_from_headers(
+ normalized_headers: Dict[str, Optional[str]],
+) -> Optional[Dict[str, str]]:
+ """Extract cookies from the WAF headers."""
+ cookie_names = {"cookie", "set-cookie"}
+ for name in cookie_names:
+ if name in normalized_headers:
+ cookie = SimpleCookie()
+ header = normalized_headers[name]
+ del normalized_headers[name]
+ if header:
+ cookie.load(header)
+ return {k: v.value for k, v in cookie.items()}
+ return None
diff --git a/ddtrace/appsec/_processor.py b/ddtrace/appsec/_processor.py
index 34472c4dfb6..4e5d029e6ef 100644
--- a/ddtrace/appsec/_processor.py
+++ b/ddtrace/appsec/_processor.py
@@ -35,7 +35,6 @@
from ddtrace.appsec._utils import DDWaf_result
from ddtrace.constants import _ORIGIN_KEY
from ddtrace.constants import _RUNTIME_FAMILY
-from ddtrace.ext import SpanTypes
from ddtrace.internal._unpatched import unpatched_open as open # noqa: A004
from ddtrace.internal.logger import get_logger
from ddtrace.internal.rate_limiter import RateLimiter
@@ -232,9 +231,6 @@ def _waf_action(
be retrieved from the `core`. This can be used when you don't want to store
the value in the `core` before checking the `WAF`.
"""
- if span.span_type not in (SpanTypes.WEB, SpanTypes.HTTP, SpanTypes.GRPC):
- return None
-
if _asm_request_context.get_blocked():
# We still must run the waf if we need to extract schemas for API SECURITY
if not custom_data or not custom_data.get("PROCESSOR_SETTINGS", {}).get("extract-schema", False):
@@ -365,7 +361,7 @@ def _is_needed(self, address: str) -> bool:
return address in self._addresses_to_keep
def on_span_finish(self, span: Span) -> None:
- if span.span_type in {SpanTypes.WEB, SpanTypes.GRPC}:
+ if span.span_type in asm_config._asm_processed_span_types:
_asm_request_context.call_waf_callback_no_instrumentation()
self._ddwaf._at_request_end()
_asm_request_context.end_context(span)
diff --git a/ddtrace/settings/asm.py b/ddtrace/settings/asm.py
index 085c3ec4f9e..73809bef89b 100644
--- a/ddtrace/settings/asm.py
+++ b/ddtrace/settings/asm.py
@@ -66,6 +66,7 @@ class ASMConfig(DDConfig):
if _asm_static_rule_file == "":
_asm_static_rule_file = None
_asm_processed_span_types = {SpanTypes.WEB, SpanTypes.GRPC}
+ _asm_http_span_types = {SpanTypes.WEB}
_iast_enabled = tracer_config._from_endpoint.get("iast_enabled", DDConfig.var(bool, IAST.ENV, default=False))
_iast_request_sampling = DDConfig.var(float, IAST.ENV_REQUEST_SAMPLING, default=30.0)
_iast_debug = DDConfig.var(bool, IAST.ENV_DEBUG, default=False, private=True)
@@ -230,6 +231,7 @@ def __init__(self):
if in_aws_lambda():
self._asm_processed_span_types.add(SpanTypes.SERVERLESS)
+ self._asm_http_span_types.add(SpanTypes.SERVERLESS)
# As a first step, only Threat Management in monitoring mode should be enabled in AWS Lambda
tracer_config._remote_config_enabled = False
diff --git a/tests/appsec/appsec/test_appsec_http_utils.py b/tests/appsec/appsec/test_appsec_http_utils.py
new file mode 100644
index 00000000000..f08098bc6a6
--- /dev/null
+++ b/tests/appsec/appsec/test_appsec_http_utils.py
@@ -0,0 +1,118 @@
+import pytest
+
+from ddtrace.appsec import _http_utils
+
+
+@pytest.mark.parametrize(
+ "input_headers, expected",
+ [
+ ({"Host": "Example.COM"}, {"host": "Example.COM"}),
+ (
+ {"X-Custom-None": "", "Content-Type": "application/json", "X-Custom-Spacing ": " trim spaces "},
+ {"x-custom-none": None, "content-type": "application/json", "x-custom-spacing": "trim spaces"},
+ ),
+ ],
+)
+def test_normalize_headers(input_headers, expected):
+ result = _http_utils.normalize_headers(input_headers)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ "headers, body, is_body_base64, expected_output",
+ [
+ # Body is None
+ ({}, None, False, None),
+ # Base64 encoded body - text/plain
+ (
+ {"content-type": "text/plain"},
+ "dGV4dCBib2R5",
+ True,
+ "text body",
+ ),
+ # Base64 encoded body - application/json
+ (
+ {"content-type": "application/json"},
+ "eyJrZXkiOiAidmFsdWUifQ==",
+ True,
+ {"key": "value"},
+ ),
+ # Base64 decoding failure - text/plain
+ (
+ {"content-type": "text/plain"},
+ "invalid_base64_string",
+ True,
+ None,
+ ),
+ # JSON content types
+ ({"content-type": "application/json"}, '{"key": "value"}', False, {"key": "value"}),
+ ({"content-type": "application/vnd.api+json"}, '{"key": "value"}', False, {"key": "value"}),
+ ({"content-type": "text/json"}, '{"key": "value"}', False, {"key": "value"}),
+ # Form urlencoded
+ (
+ {"content-type": "application/x-www-form-urlencoded"},
+ "key=value&key2=value2",
+ False,
+ {"key": ["value"], "key2": ["value2"]},
+ ),
+ # XML content types
+ ({"content-type": "application/xml"}, "value", False, {"root": {"key": "value"}}),
+ ({"content-type": "text/xml"}, "value", False, {"root": {"key": "value"}}),
+ # Text plain
+ ({"content-type": "text/plain"}, "simple text body", False, "simple text body"),
+ # Unsupported content type
+ ({"content-type": "application/octet-stream"}, "binary data", False, None),
+ # No content type provided
+ ({}, "some body", False, None),
+ # Invalid JSON
+ ({"content-type": "application/json"}, "not a valid json string", False, None),
+ # Invalid XML
+ ({"content-type": "application/xml"}, "value", False, None),
+ # Multipart form data
+ (
+ {"content-type": "multipart/form-data; boundary=boundary"},
+ (
+ "--boundary\r\n"
+ 'Content-Disposition: form-data; name="formPart"\r\n'
+ "content-type: application/x-www-form-urlencoded\r\n"
+ "\r\n"
+ "key=value\r\n"
+ "--boundary--"
+ ),
+ False,
+ {"formPart": {"key": ["value"]}}, # Mocked return value for parse_form_multipart
+ ),
+ # Invalid base64 encoded body (decoding fails)
+ (
+ {"content-type": "application/xml"},
+ "invalid_base64_and_invalid_xml",
+ True,
+ None,
+ ),
+ ],
+)
+def test_parse_http_body(headers, body, is_body_base64, expected_output, mocker):
+ result = _http_utils.parse_http_body(headers, body, is_body_base64)
+ assert result == expected_output
+
+
+@pytest.mark.parametrize(
+ "input_headers, expected",
+ [
+ (
+ {"cookie": "sessionid=abc123; csrftoken=xyz789"},
+ {"sessionid": "abc123", "csrftoken": "xyz789"},
+ ),
+ (
+ {"set-cookie": "sessionid=abc123; Path=/; HttpOnly"},
+ {"sessionid": "abc123"},
+ ),
+ ({"cookie": ""}, {}),
+ ({"cookie": None}, {}),
+ ({"set-cookie": None}, {}),
+ ],
+)
+# Tests for extract_cookies_from_headers
+def test_extract_cookies_from_headers(input_headers, expected):
+ result = _http_utils.extract_cookies_from_headers(input_headers)
+ assert result == expected