Skip to content

[Draft][develop] Upgrade connexion to 2.15.0rc3, upgrade Werkzeug to 3.1.3 to address CVE #6932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cli/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ aws-cdk.core~=1.164
aws_cdk.aws-cloudwatch~=1.164
aws_cdk.aws-lambda~=1.164
boto3>=1.16.14
connexion~=2.13.0
flask>=2.2.5,<2.3
jinja2~=3.0
jmespath~=0.10
jsii==1.85.0
marshmallow~=3.10
PyYAML>=5.3.1,!=5.4
tabulate>=0.8.8,<=0.8.10
werkzeug~=2.0
connexion~=2.15.0rc3
werkzeug~=3.0
flask~=3.0
requests
jsonschema
inflection
packaging~=25.0
10 changes: 7 additions & 3 deletions cli/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,15 @@ def readme():
"aws-cdk.aws-ssm~=" + CDK_VERSION,
"aws-cdk.aws-sqs~=" + CDK_VERSION,
"aws-cdk.aws-cloudformation~=" + CDK_VERSION,
"werkzeug~=2.0",
"connexion~=2.13.0",
"flask>=2.2.5,<2.3",
"connexion~=2.15.0rc3",
"jmespath~=0.10",
"jsii==1.85.0",
"werkzeug~=3.0",
"flask~=3.0",
"requests",
"jsonschema",
"inflection",
"packaging~=25.0",
]

LAMBDA_REQUIRES = [
Expand Down
71 changes: 46 additions & 25 deletions cli/src/pcluster/api/awslambda/serverless_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import json
import os
import sys
from urllib.parse import unquote, unquote_plus, urlencode

from werkzeug.datastructures import Headers, MultiDict, iter_multi_items
from werkzeug.datastructures import Headers, iter_multi_items
from werkzeug.http import HTTP_STATUS_CODES
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
from werkzeug.wrappers import Response

# List of MIME types that should not be base64 encoded. MIME types within `text/*`
Expand Down Expand Up @@ -95,8 +95,8 @@ def encode_query_string(event):
if not params:
params = ""
if is_alb_event(event):
params = MultiDict((url_unquote_plus(k), url_unquote_plus(v)) for k, v in iter_multi_items(params))
return url_encode(params)
params = [(unquote_plus(k), unquote_plus(v)) for k, v in iter_multi_items(params)]
return urlencode(params, doseq=True)


def get_script_name(headers, request_context):
Expand All @@ -108,7 +108,7 @@ def get_script_name(headers, request_context):
"1",
]

if headers.get("Host", "").endswith(".amazonaws.com") and not strip_stage_path:
if "amazonaws.com" in headers.get("Host", "") and not strip_stage_path:
script_name = "/{}".format(request_context.get("stage", ""))
else:
script_name = ""
Expand Down Expand Up @@ -138,7 +138,7 @@ def setup_environ_items(environ, headers):
def generate_response(response, event):
returndict = {"statusCode": response.status_code}

if "multiValueHeaders" in event:
if "multiValueHeaders" in event and event["multiValueHeaders"]:
returndict["multiValueHeaders"] = group_headers(response.headers)
else:
returndict["headers"] = split_headers(response.headers)
Expand All @@ -164,12 +164,27 @@ def generate_response(response, event):
return returndict


def strip_express_gateway_query_params(path):
"""Contrary to regular AWS lambda HTTP events, Express Gateway
(https://github.com/ExpressGateway/express-gateway-plugin-lambda)
adds query parameters to the path, which we need to strip.
"""
if "?" in path:
path = path.split("?")[0]
return path


def handle_request(app, event, context):
if event.get("source") in ["aws.events", "serverless-plugin-warmup"]:
print("Lambda warming event received, skipping handler")
return {}

if event.get("version") is None and event.get("isBase64Encoded") is None and not is_alb_event(event):
if (
event.get("version") is None
and event.get("isBase64Encoded") is None
and event.get("requestPath") is not None
and not is_alb_event(event)
):
return handle_lambda_integration(app, event, context)

if event.get("version") == "2.0":
Expand All @@ -179,7 +194,7 @@ def handle_request(app, event, context):


def handle_payload_v1(app, event, context):
if "multiValueHeaders" in event:
if "multiValueHeaders" in event and event["multiValueHeaders"]:
headers = Headers(event["multiValueHeaders"])
else:
headers = Headers(event["headers"])
Expand All @@ -189,35 +204,35 @@ def handle_payload_v1(app, event, context):
# If a user is using a custom domain on API Gateway, they may have a base
# path in their URL. This allows us to strip it out via an optional
# environment variable.
path_info = event["path"]
path_info = strip_express_gateway_query_params(event["path"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name) :] # noqa: E203

body = event["body"] or ""
body = event.get("body") or ""
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": encode_query_string(event),
"REMOTE_ADDR": event.get("requestContext", {}).get("identity", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
"REMOTE_USER": (event.get("requestContext", {}).get("authorizer") or {}).get("principalId", ""),
"REQUEST_METHOD": event.get("httpMethod", {}),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
Expand All @@ -237,31 +252,37 @@ def handle_payload_v2(app, event, context):

script_name = get_script_name(headers, event.get("requestContext", {}))

path_info = event["rawPath"]
path_info = strip_express_gateway_query_params(event["rawPath"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name) :] # noqa: E203

body = event.get("body", "")
body = get_body_bytes(event, body)

headers["Cookie"] = "; ".join(event.get("cookies", []))

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": event.get("rawQueryString", ""),
"REMOTE_ADDR": event.get("requestContext", {}).get("http", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
"REQUEST_METHOD": event.get("requestContext", {}).get("http", {}).get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
Expand All @@ -282,7 +303,7 @@ def handle_lambda_integration(app, event, context):

script_name = get_script_name(headers, event)

path_info = event["requestPath"]
path_info = strip_express_gateway_query_params(event["requestPath"])

for key, value in event.get("path", {}).items():
path_info = path_info.replace("{%s}" % key, value)
Expand All @@ -293,23 +314,23 @@ def handle_lambda_integration(app, event, context):
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"QUERY_STRING": url_encode(event.get("query", {})),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": urlencode(event.get("query", {}), doseq=True),
"REMOTE_ADDR": event.get("identity", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("principalId", ""),
"REQUEST_METHOD": event.get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("enhancedAuthContext"),
"serverless.event": event,
Expand Down
28 changes: 25 additions & 3 deletions cli/src/pcluster/api/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
# Generated by OpenAPI Generator (python-flask)

import datetime
from json import JSONEncoder

import six
from connexion.apps.flask_app import FlaskJSONEncoder
from flask.json.provider import DefaultJSONProvider

from pcluster.api.models.base_model_ import Model
from pcluster.utils import to_iso_timestr


class JSONEncoder(FlaskJSONEncoder):
class JSONEncoderForCli(JSONEncoder):
"""Make the model objects JSON serializable."""

include_nulls = False
Expand All @@ -35,4 +36,25 @@ def default(self, obj): # pylint: disable=arguments-renamed
return dikt
elif isinstance(obj, datetime.date):
return to_iso_timestr(obj)
return FlaskJSONEncoder.default(self, obj)
return JSONEncoder.default(self, obj)


class JSONEncoder(DefaultJSONProvider):
"""Make the model objects JSON serializable."""

include_nulls = False

def default(self, obj): # pylint: disable=arguments-renamed
"""Override the base method to add support for model objects serialization."""
if isinstance(obj, Model):
dikt = {}
for attr, _ in six.iteritems(obj.openapi_types):
value = getattr(obj, attr)
if value is None and not self.include_nulls:
continue
attr = obj.attribute_map[attr]
dikt[attr] = value
return dikt
elif isinstance(obj, datetime.date):
return to_iso_timestr(obj)
return super().default(obj)
2 changes: 1 addition & 1 deletion cli/src/pcluster/api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from connexion import ProblemException
from connexion.exceptions import ProblemException
from werkzeug.exceptions import HTTPException

from pcluster.api.models import (
Expand Down
9 changes: 5 additions & 4 deletions cli/src/pcluster/api/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import functools
import logging

import connexion
from connexion import ProblemException
from connexion.apps.flask_app import FlaskApp
from connexion.decorators.validation import ParameterValidator
from connexion.exceptions import ProblemException
from flask import Response, jsonify, request
from werkzeug.exceptions import HTTPException

Expand Down Expand Up @@ -74,9 +74,10 @@ def __init__(self, swagger_ui: bool = False, validate_responses=False):
assert_valid_node_js()
options = {"swagger_ui": swagger_ui}

self.app = connexion.FlaskApp(__name__, specification_dir="openapi/", skip_error_handlers=True)
self.app = FlaskApp(__name__, specification_dir="openapi/", skip_error_handlers=True)
self.flask_app = self.app.app
self.flask_app.json_encoder = encoder.JSONEncoder
self.flask_app.json_provider_class = encoder.JSONEncoder
self.flask_app.json = encoder.JSONEncoder(self.flask_app)
self.app.add_api(
"openapi.yaml",
arguments={"title": "ParallelCluster"},
Expand Down
4 changes: 2 additions & 2 deletions cli/src/pcluster/cli/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ def _run_operation(model, args, extra_args):
except Exception as e:
# format exception messages in the same manner as the api
message = pcluster.api.errors.exception_message(e)
error_encoded = encoder.JSONEncoder().encode(message)
error_encoded = encoder.JSONEncoderForCli().encode(message)
raise APIOperationException(json.loads(error_encoded))
else:
try:
return args.func(args, extra_args)
except pcluster.api.errors.ParallelClusterApiException as e:
# Format exception messages in the same manner as the api
message = pcluster.api.errors.exception_message(e)
error_encoded = encoder.JSONEncoder().encode(message)
error_encoded = encoder.JSONEncoderForCli().encode(message)
raise APIOperationException(json.loads(error_encoded))
except Exception as e:
raise e
Expand Down
4 changes: 2 additions & 2 deletions cli/src/pcluster/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def call(func_str, *args, **kwargs):
if isinstance(ret, tuple):
ret, status_code = ret
if status_code >= 400:
data = json.loads(encoder.JSONEncoder().encode(ret))
data = json.loads(encoder.JSONEncoderForCli().encode(ret))
raise APIOperationException(data)
data = json.loads(encoder.JSONEncoder().encode(ret))
data = json.loads(encoder.JSONEncoderForCli().encode(ret))
return jmespath.search(query, data) if query else data
2 changes: 1 addition & 1 deletion cli/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ def read_text(path):


def wire_translate(data):
return json.loads(encoder.JSONEncoder().encode(data))
return json.loads(encoder.JSONEncoderForCli().encode(data))
Loading