Skip to content

DRAFT: Add middleware to expose raw headers in contextvar #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/north_mcp_python_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware

from .auth import AuthContextMiddleware, NorthAuthBackend, on_auth_error
from .auth import (
AuthContextMiddleware,
HeadersContextMiddleware,
NorthAuthBackend,
on_auth_error,
)


def is_debug_mode() -> bool:
Expand Down Expand Up @@ -66,5 +71,6 @@ def _add_middleware(self, app: Starlette) -> None:
on_error=on_auth_error,
),
Middleware(AuthContextMiddleware, debug=self._debug),
Middleware(HeadersContextMiddleware, debug=self._debug),
]
app.user_middleware.extend(middleware)
39 changes: 39 additions & 0 deletions src/north_mcp_python_sdk/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import contextvars
import logging
from typing import Any

import jwt
from pydantic import BaseModel, Field, ValidationError
Expand Down Expand Up @@ -35,6 +36,10 @@ def __init__(
"north_auth_context", default=None
)

headers_context_var = contextvars.ContextVar[dict[str, Any] | None](
"north_headers_context", default=None
)


def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse:
return JSONResponse({"error": str(exc)}, status_code=401)
Expand All @@ -48,6 +53,14 @@ def get_authenticated_user() -> AuthenticatedNorthUser:
return user


def get_raw_headers() -> dict[str, Any]:
headers = headers_context_var.get()
if not headers:
raise Exception("headers not found in context")

return headers


class AuthContextMiddleware:
"""
Middleware that extracts the authenticated user from the request
Expand Down Expand Up @@ -83,6 +96,32 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
auth_context_var.reset(token)


class HeadersContextMiddleware:
"""
Middleware that sets the request headers in a contextvar for easy access
throughout the request lifecycle.
"""

def __init__(self, app: ASGIApp, debug: bool = False):
self.app = app
self.debug = debug
self.logger = logging.getLogger("NorthMCP.HeadersContext")
if debug:
self.logger.setLevel(logging.DEBUG)

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "lifespan":
return await self.app(scope, receive, send)

headers = dict(scope.get("headers", {}))
self.logger.debug("Setting request headers in context: %s", headers)
token = headers_context_var.set(headers)
try:
await self.app(scope, receive, send)
finally:
headers_context_var.reset(token)


class NorthAuthBackend(AuthenticationBackend):
"""
Authentication backend that validates Bearer tokens.
Expand Down