diff --git a/requirements-test.txt b/requirements-test.txt deleted file mode 100644 index 847062e..0000000 --- a/requirements-test.txt +++ /dev/null @@ -1,3 +0,0 @@ -pytest -pytest-cov -responses diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c839394 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +aiohttp>=3.8.1 \ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..5ff2ee2 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,4 @@ +pytest>=6.2.5 +pytest-asyncio>=0.16.0 +pytest-cov>=3.0.0 +responses>=0.16.0 \ No newline at end of file diff --git a/tests/server/cert/cert.pem b/tests/server/cert/cert.pem new file mode 100644 index 0000000..eb7569e --- /dev/null +++ b/tests/server/cert/cert.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUZwtHML3lPY/qQUYGRSsQfhznp/AwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMTEyMjgxNjIyNDBaFw0yMjEy +MjgxNjIyNDBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDBZVwOlVFYTBAAFRPdxjz+ao7s+d0Z648zrUY/E2hQ +0aNW+bmQcKk3/c/vm6Ai7tI8Jr2kMnkgL3/m1m9ntxhSrlny1tmoGpiZQT+TLUsT +oDm4GTe6tNnSmO3+oPfYy7NAL5GCix5lp/JGFqz9d91ZYawkp35BRMUrFgyf7Nl5 +M3HvaWFy6oii6PFb6V+yMDiedVLGc4BdKm/9itphxeFIK0RvvTXmuYyUqQvcGI04 +WwlZW/2tTHrdHEA/6Z5GRU5JyaqH4IIinAZojzfqgEuFubwCJb6TAqmke0Q6JCjC +cNNHZclm2zqi+LG+k+Nh8I2ji4N9YT548ypOPN8ytqYJAgMBAAGjUzBRMB0GA1Ud +DgQWBBQfmg8fmIjBUfvDce+dFBdd1h5f2DAfBgNVHSMEGDAWgBQfmg8fmIjBUfvD +ce+dFBdd1h5f2DAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAc +fZKr4TDQiCt8+bP629rQHiNUQMFOCCSV0+jEmJmgRqwY45d1XKS9qlh8D/DowNk3 +wnrQviBNVZY1VwuWNCUGDYNKR+Io2dnnaYT9Z1GE3glnV+y3lx3g9qcpNxIhHClv +mg1HILCpbv5+GNSM4Y+Ds1GjvnnkkkVEnEzGjqfrdor9jfatpb9zfzykF/fibRfD +HCSGMCn2qbWH6kX25fa1CsN0L04wqhnWAyXfdJREDPmMpO6thPUBKmov6uqImjoj 
+SCyywPoOMYo7eoYTjzV6gJ9u41aIWLRrsWjo294nx04hvDcPTs7D3Mqy9OWWjwoT +90LDY52y0NZ/IfcprJo7 +-----END CERTIFICATE----- diff --git a/tests/server/cert/key.pem b/tests/server/cert/key.pem new file mode 100644 index 0000000..cb1f101 --- /dev/null +++ b/tests/server/cert/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDBZVwOlVFYTBAA +FRPdxjz+ao7s+d0Z648zrUY/E2hQ0aNW+bmQcKk3/c/vm6Ai7tI8Jr2kMnkgL3/m +1m9ntxhSrlny1tmoGpiZQT+TLUsToDm4GTe6tNnSmO3+oPfYy7NAL5GCix5lp/JG +Fqz9d91ZYawkp35BRMUrFgyf7Nl5M3HvaWFy6oii6PFb6V+yMDiedVLGc4BdKm/9 +itphxeFIK0RvvTXmuYyUqQvcGI04WwlZW/2tTHrdHEA/6Z5GRU5JyaqH4IIinAZo +jzfqgEuFubwCJb6TAqmke0Q6JCjCcNNHZclm2zqi+LG+k+Nh8I2ji4N9YT548ypO +PN8ytqYJAgMBAAECggEAGcnNS7iHf1GtNIWa/5Cmo3pMEreCzykFEGwDmPeaimRZ +9ogYQXV0ax3yM046PU5BRHoaAaIlWrBayso/UsIsgSH0ppgVr//T52O2+YYpo7VS +3QSn1MK25Qk1eyW1xvfqsB1nttOcOwv5F3WAnc/0+S04Ci83e7aS4BrxlgW8Phft +bEK9AHB2ugDbge/043gOnMBWboo/sNgZV42k79nHOpSrSgNDM3cGjYRoCzlwmtI6 +BDridbZ307a+sEusLz+1mF2BGqcBIYgBEG7T4cVRfkMPtcAgwNS58R0Bj94TbvLt +eTuVelqdsDHGc3vtjtqrYtr6Ower47r/4XNAvkFEjQKBgQD2XWYZr8Ktc+QC9DI7 +4mo/xwF0dLF7CA/HrXZwkm9Nf0Nd/Maef7IbctJQubKIyTlSj+fxddiNTL5wkUgG +BFW6b6cMC4tCXCfGqubLwxIUo4MdygVXoquVHAY8bv+exZhqTNeP/aIrcJ5moefN +5SArFTNiY4jlo5uVZjPXn0q3QwKBgQDI9aMKvN+GRqvzVSEhbSRiVbYJxBKDjrPF +h3IcQRhTRNoExxRRi1jR7s9VHUIFISf1WHCooQBJ86gAR5gAZmFyoWtQHqXQWsP4 +QDHrTTbAz9vsu4tH8m9uHmKdMkwiucfnuqJqkXDVdr2Uv9uY9xwS1O6cIqpbmsHm +n5ZeCWrawwKBgCHjkSLhaX8gnPHHE43nRERHpKyXTL6myjzmYI91pTfc1LB+D/hH +ioF6FvIhySxFucvgncA6PLKbJusnIOgq+nvt1eWzRNG5CYOriJno2HjcUTHs0zVN +3Bpjw3vWrPTzK1ccAN7+vasKD9AAX3mUFgu3G91h4bfs/H7dky4K7GUvAoGBAJUP +vxRj7NlI6prf9mc48dgPA6xSx/jVjPtj8HyMvGJnm+AXWzbxSbzOivPzc9kiMuWF +6Grsoa45EdDDSjhhuL6yhUs0sIHQEbS+yUhkSczTYapDopiHd6gS2csIV/kaHPIC +Oh8aKrvsC8ueVGEuSqCdWTBvdjXkoRdUINE34w7JAoGAC4K5KPrnVoM2aAcGdrxM +5+rZxWJK+sl55Ui4vyaal7Z7brRnpk+vYiImkAe9aRaI/Bggtw+j3jJgmXfcNQf6 +2haVu1kFIRn3OVmYI+SziVCwWTUQVSK3iTocb4uo6LeATMuxbJMQp4nFkXWno7xy +WujVUfhuhvF0gm7tkH2wAVU= 
import json
import ssl
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from threading import Thread
from urllib.parse import urljoin

# Directory (next to this module) holding the certificate/key pair
# ("cert.pem" / "key.pem") that the mock server presents to clients.
CERT_PATH = (Path(__file__).resolve().parents[0] / "cert").resolve()


class MockServer:
    """Context manager running a local HTTPS server that mirrors POSTed JSON.

    On entry an ``HTTPServer`` bound to ``127.0.0.1:4443`` is wrapped in TLS
    and served from a background thread; on exit the server is stopped and
    its listening socket is closed.
    """

    def __enter__(self):
        self.httpd = HTTPServer(("127.0.0.1", 4443), _RequestHandler)
        try:
            # ``ssl.wrap_socket`` is deprecated (and gone in Python 3.12);
            # an explicit server-side context is the supported replacement.
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            context.load_cert_chain(
                certfile=str(CERT_PATH / "cert.pem"),
                keyfile=str(CERT_PATH / "key.pem"),
            )
            self.httpd.socket = context.wrap_socket(
                self.httpd.socket, server_side=True
            )
        except Exception:
            # Do not leave the port bound if the TLS setup fails.
            self.httpd.server_close()
            raise
        self.address, self.port = self.httpd.server_address
        # Daemon thread: a test that never reaches __exit__ must not keep
        # the interpreter alive.
        self.thread = Thread(target=self.httpd.serve_forever, daemon=True)
        self.thread.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.httpd.shutdown()
        self.thread.join()
        # ``shutdown`` only stops the serve loop; the listening socket has to
        # be closed explicitly so the port is free for the next MockServer.
        self.httpd.server_close()

    def urljoin(self, url: str) -> str:
        """Return *url* resolved against the server's base HTTPS address."""
        # The server only speaks TLS, so the base URL must use ``https``.
        return urljoin("https://{}:{}".format(self.address, self.port), url)


class _RequestHandler(BaseHTTPRequestHandler):
    """Request handler that echoes the JSON body of every POST request."""

    def _send_response(self, status: int, content: str, content_type: str):
        """Send *content* (UTF-8 encoded) with the given status and type."""
        body = content.encode("utf-8")
        self.send_response(status)
        self.send_header("Content-type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):  # NOQA
        """Mirror the posted JSON document, or answer 400 if it is invalid."""
        try:
            # Parsed inside the ``try`` so that a missing or malformed
            # Content-Length header yields a 400 instead of crashing.
            content_length = int(self.headers["Content-Length"])
            post_data = json.loads(self.rfile.read(content_length).decode("utf-8"))  # NOQA
        except (AttributeError, TypeError, ValueError, KeyError) as er:
            self._send_response(400, str(er), "text/html")
            return
        # Mirror data
        self._send_response(200, json.dumps(post_data), "application/json")
import ssl
from typing import Optional, Union

import pytest
from aiohttp import ClientConnectorCertificateError, TCPConnector

from tests.server.mockserver import MockServer
from zyte_api.aio.client import create_session, AsyncClient


@pytest.mark.parametrize(
    "verify_ssl, expected_ssl_mode",
    [
        (None, None),
        (False, False),
        (True, None),
    ],
)
def test_verify_ssl(verify_ssl: Optional[bool], expected_ssl_mode: Optional[bool]):
    """``verify_ssl`` is translated into the connector's ssl mode."""
    session = create_session(verify_ssl=verify_ssl)
    assert session.connector._ssl == expected_ssl_mode  # type: ignore


@pytest.mark.parametrize(
    "ssl_mode",
    [
        None,
        False,
        True,
        # An explicit protocol avoids the deprecated bare ``SSLContext()``.
        ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT),
    ],
)
def test_connector_ssl(ssl_mode: Optional[Union[bool, ssl.SSLContext]]):
    """A caller-supplied connector keeps the ssl mode it was built with."""
    connector = TCPConnector(ssl=ssl_mode)
    session = create_session(connector=connector)
    assert session.connector._ssl == ssl_mode  # type: ignore


@pytest.mark.parametrize(
    "verify_ssl, ssl_mode, error_message",
    [
        (False, None,
         r"Provided `verify_ssl` argument \(False\) conflicts "
         r"with `connector` argument \(connector\._ssl=None\)"),
        # An SSLContext is formatted with the default object repr
        # (``<ssl.SSLContext object at 0x...>``), so the pattern has to
        # allow for it between ``=`` and the closing parenthesis.
        (False, ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT),
         r"Provided `verify_ssl` argument \(False\) conflicts "
         r"with `connector` argument \(connector\._ssl=<.+>\)"),
        (None, ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT),
         r"Provided `verify_ssl` argument \(None\) conflicts "
         r"with `connector` argument \(connector\._ssl=<.+>\)"),
    ],
)
def test_verify_connector_conflict(verify_ssl: Optional[bool],
                                   ssl_mode: Optional[Union[bool, ssl.SSLContext]],
                                   error_message: str):
    """Contradicting ``verify_ssl`` and ``connector`` arguments are rejected."""
    connector = TCPConnector(ssl=ssl_mode)
    with pytest.raises(ValueError, match=error_message):
        create_session(connector=connector, verify_ssl=verify_ssl)


@pytest.mark.asyncio
async def test_disabled_ssl_verification():
    """With verification disabled the request to the mock server succeeds."""
    data = {'url': 'https://example.com', 'browserHtml': True}
    with MockServer():
        client = AsyncClient(api_url="https://127.0.0.1:4443/", api_key="TEST")
        # ``async with`` closes the session (and its connector) afterwards.
        async with create_session(verify_ssl=False) as session:
            resp = await client.request_raw(data,
                                            handle_retries=False,
                                            session=session)
    # Check mirrored data
    assert resp == data


@pytest.mark.asyncio
async def test_enabled_ssl_verification():
    """With default verification the mock server's certificate is rejected."""
    data = {'url': 'https://example.com', 'browserHtml': True}
    with MockServer():
        client = AsyncClient(api_url="https://127.0.0.1:4443/", api_key="TEST")
        async with create_session() as session:
            with pytest.raises((ClientConnectorCertificateError, ssl.SSLError)):  # NOQA
                await client.request_raw(data,
                                         handle_retries=False,
                                         session=session)
# Total client-side timeout: the API timeout plus a 120 second margin.
AIO_API_TIMEOUT = aiohttp.ClientTimeout(total=API_TIMEOUT + 120)

logger = logging.getLogger(__name__)


def create_session(connection_pool_size: int = 100, **kwargs) -> aiohttp.ClientSession:
    """Create a session with parameters suited for Zyte API.

    :param connection_pool_size: connection limit of the ``TCPConnector``
        created when no ``connector`` is supplied.
    :param kwargs: forwarded to ``aiohttp.ClientSession``. ``timeout``
        defaults to ``AIO_API_TIMEOUT``. The extra ``verify_ssl`` keyword is
        consumed here: ``False`` disables certificate verification, ``True``
        or ``None`` keep the default verification.
    :raises ValueError: if ``verify_ssl`` is passed together with a
        ``connector`` whose ssl mode contradicts it.
    """
    kwargs.setdefault('timeout', AIO_API_TIMEOUT)
    verify_ssl_provided = "verify_ssl" in kwargs
    verify_ssl = kwargs.pop("verify_ssl", None)
    ssl_mode = _set_ssl_mode(verify_ssl)

    connector = kwargs.get("connector")
    if connector is None:
        # Covers both a missing and an explicit ``connector=None``.
        kwargs["connector"] = TCPConnector(limit=connection_pool_size, ssl=ssl_mode)
    elif verify_ssl_provided:
        connector_ssl = connector._ssl  # NOQA
        # ``True`` and ``None`` both mean "verify with the default context",
        # so they must not be reported as conflicting with each other.
        effective_ssl = None if connector_ssl is True else connector_ssl
        if effective_ssl != ssl_mode:
            # Report the value the caller actually passed, not the derived mode.
            raise ValueError(f"Provided `verify_ssl` argument ({verify_ssl}) "
                             f"conflicts with `connector` argument (connector._ssl={connector_ssl})")  # NOQA
    return aiohttp.ClientSession(**kwargs)


def _set_ssl_mode(verify_ssl: Optional[bool]) -> Optional[bool]:
    """Translate ``verify_ssl`` into the ``ssl`` argument of ``TCPConnector``.

    Returns ``False`` (verification disabled) only when ``verify_ssl`` is
    exactly ``False``; any other value yields ``None`` (default, enabled).
    """
    # ssl certificate check could be either
    # None (default, enabled), False (disabled), or custom SSL context (not relevant)
    # https://docs.aiohttp.org/en/stable/client_reference.html#tcpconnector
    return False if verify_ssl is False else None
request_parallel_as_completed(self, async def _request(query): async with sem: return await self.request_raw(query, - endpoint=endpoint, - session=session) + endpoint=endpoint, + session=session) return asyncio.as_completed([_request(query) for query in queries])