Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions base_client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
T = typing.TypeVar("T")


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True, slots=True)
class BaseClient:
client: httpx.AsyncClient
retryer: circuit_breaker_box.Retrier[httpx.Response] | None = None
retrier: circuit_breaker_box.Retrier[httpx.Response] | None = None

async def send(self, *, request: httpx.Request) -> httpx.Response:
if self.retryer:
return await self.retryer.retry(self._process_request, request.url.host, request=request)
if self.retrier:
return await self.retrier.retry(self._process_request, request.url.host, request=request)
return await self._process_request(request)

def prepare_request( # noqa: PLR0913
Expand Down
2 changes: 1 addition & 1 deletion examples/example_client_with_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def main() -> None:
)
client = SomeSpecificClient(
client=httpx.AsyncClient(base_url=SOME_HOST, timeout=httpx.Timeout(1)),
retryer=retrier_with_circuit_breaker,
retrier=retrier_with_circuit_breaker,
)
answer = await client.some_method(params={})
logger.debug(answer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def main() -> None:
)
client = SomeSpecificClient(
client=httpx.AsyncClient(base_url=SOME_HOST, timeout=httpx.Timeout(1)),
retryer=retrier_with_circuit_breaker,
retrier=retrier_with_circuit_breaker,
)
answer = await client.some_method(params={})
logger.debug(answer)
Expand Down
6 changes: 3 additions & 3 deletions examples/example_client_with_retry_circuit_breaker_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


@dataclasses.dataclass
class TestRedisConnection(aioredis.Redis): # type: ignore[type-arg]
class FakeRedisConnection(aioredis.Redis): # type: ignore[type-arg]
async def incr(self, host: str | bytes, amount: int = 1) -> int:
logger.debug("host: %s, amount: %d{amount}", host, amount)
return amount
Expand Down Expand Up @@ -62,7 +62,7 @@ async def main() -> None:
circuit_breaker = circuit_breaker_box.CircuitBreakerRedis(
reset_timeout_in_seconds=RESET_TIMEOUT_IN_SECONDS,
max_failure_count=CIRCUIT_BREAKER_MAX_FAILURE_COUNT,
redis_connection=TestRedisConnection(),
redis_connection=FakeRedisConnection(),
)
retrier_with_circuit_breaker = circuit_breaker_box.Retrier[httpx.Response](
circuit_breaker=circuit_breaker,
Expand All @@ -72,7 +72,7 @@ async def main() -> None:
)
client = SomeSpecificClient(
client=httpx.AsyncClient(base_url=SOME_HOST, timeout=httpx.Timeout(1)),
retryer=retrier_with_circuit_breaker,
retrier=retrier_with_circuit_breaker,
)
answer = await client.some_method(params={"foo": "bar"})
logger.debug(answer.model_dump())
Expand Down
22 changes: 11 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import tenacity

import base_client
from examples.example_client_with_retry_circuit_breaker_redis import TestRedisConnection
from examples.example_client_with_retry_circuit_breaker_redis import FakeRedisConnection


TEST_BASE_URL = "http://example.com/"


class TestClient(base_client.BaseClient):
class FakeClient(base_client.BaseClient):
async def fetch_async(self, request: httpx.Request) -> httpx.Response:
return await self.send(request=request)

Expand All @@ -22,9 +22,9 @@ async def fetch_async(self, request: httpx.Request) -> httpx.Response:


@pytest.fixture(name="test_client_with_circuit_breaker_redis")
def test_client_with_circuit_breaker_redis() -> TestClient:
def test_client_with_circuit_breaker_redis() -> FakeClient:
circuit_breaker = circuit_breaker_box.CircuitBreakerRedis(
redis_connection=TestRedisConnection(),
redis_connection=FakeRedisConnection(),
reset_timeout_in_seconds=RESET_TIMEOUT_IN_SECONDS,
max_failure_count=CLIENT_MAX_FAILURE_COUNT,
)
Expand All @@ -34,14 +34,14 @@ def test_client_with_circuit_breaker_redis() -> TestClient:
retry_cause=tenacity.retry_if_exception_type((httpx.RequestError, base_client.errors.HttpStatusError)),
wait_strategy=tenacity.wait_none(),
)
return TestClient(
return FakeClient(
client=httpx.AsyncClient(base_url=TEST_BASE_URL, timeout=httpx.Timeout(1)),
retryer=retrier_with_circuit_breaker,
retrier=retrier_with_circuit_breaker,
)


@pytest.fixture(name="test_client_with_circuit_breaker_in_memory")
def test_client_with_circuit_breaker_in_memory() -> TestClient:
def test_client_with_circuit_breaker_in_memory() -> FakeClient:
circuit_breaker = circuit_breaker_box.CircuitBreakerInMemory(
reset_timeout_in_seconds=RESET_TIMEOUT_IN_SECONDS,
max_cache_size=MAX_CACHE_SIZE,
Expand All @@ -53,12 +53,12 @@ def test_client_with_circuit_breaker_in_memory() -> TestClient:
retry_cause=tenacity.retry_if_exception_type((httpx.RequestError, base_client.errors.HttpStatusError)),
wait_strategy=tenacity.wait_none(),
)
return TestClient(
return FakeClient(
client=httpx.AsyncClient(base_url=TEST_BASE_URL, timeout=httpx.Timeout(1)),
retryer=retrier_with_circuit_breaker,
retrier=retrier_with_circuit_breaker,
)


@pytest.fixture(name="test_client")
def fixture_test_client() -> TestClient:
return TestClient(client=httpx.AsyncClient(base_url=TEST_BASE_URL, timeout=httpx.Timeout(1)))
def fixture_test_client() -> FakeClient:
return FakeClient(client=httpx.AsyncClient(base_url=TEST_BASE_URL, timeout=httpx.Timeout(1)))
14 changes: 7 additions & 7 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import base_client
from base_client.response import response_to_model
from tests.conftest import TEST_BASE_URL, TestClient
from tests.conftest import TEST_BASE_URL, FakeClient


@respx.mock
Expand Down Expand Up @@ -35,15 +35,15 @@
),
],
)
async def test_client_async(test_client: TestClient, expected_request: httpx.Request) -> None:
async def test_client_async(test_client: FakeClient, expected_request: httpx.Request) -> None:
mocked_route = respx.get(expected_request.url).mock(return_value=httpx.Response(status_code=httpx.codes.OK))
response = await test_client.fetch_async(expected_request)
assert mocked_route.called
assert response.status_code == httpx.codes.OK


@respx.mock
async def test_client_request_404(test_client: TestClient) -> None:
async def test_client_request_404(test_client: FakeClient) -> None:
mocked_route = respx.get(TEST_BASE_URL).mock(return_value=httpx.Response(status_code=httpx.codes.NOT_FOUND))
response = await test_client.fetch_async(test_client.prepare_request(method="GET", url=TEST_BASE_URL))
assert mocked_route.called
Expand Down Expand Up @@ -75,7 +75,7 @@ async def test_client_request_404(test_client: TestClient) -> None:
(httpx.TooManyRedirects("TooManyRedirects message"), httpx.TooManyRedirects),
],
)
async def test_retries(side_effect: type[Exception], expected_raise: type[Exception], test_client: TestClient) -> None:
async def test_retries(side_effect: type[Exception], expected_raise: type[Exception], test_client: FakeClient) -> None:
mocked_route = respx.get(TEST_BASE_URL).mock(side_effect=side_effect)
with pytest.raises(expected_raise):
await test_client.fetch_async(test_client.prepare_request(method="GET", url=TEST_BASE_URL))
Expand All @@ -98,7 +98,7 @@ async def test_retries(side_effect: type[Exception], expected_raise: type[Except
],
)
async def test_wont_retry(
side_effect: type[Exception], expected_raise: type[Exception], test_client: TestClient
side_effect: type[Exception], expected_raise: type[Exception], test_client: FakeClient
) -> None:
mocked_route = respx.get(TEST_BASE_URL).mock(side_effect=side_effect)

Expand All @@ -115,7 +115,7 @@ async def test_wont_retry(
(599, base_client.HttpServerError),
],
)
async def test_validate_response(status_code: int, side_effect: type[Exception], test_client: TestClient) -> None:
async def test_validate_response(status_code: int, side_effect: type[Exception], test_client: FakeClient) -> None:
response = httpx.Response(
status_code=status_code,
content=b"",
Expand Down Expand Up @@ -145,6 +145,6 @@ class TestModel(pydantic.BaseModel):
(httpx.URL(TEST_BASE_URL + "?1=2"), [("3", "4")], TEST_BASE_URL + "?1=2&3=4"),
],
)
async def test_prepare_request(url: str, params: dict[str, str], expected_url: str, test_client: TestClient) -> None:
async def test_prepare_request(url: str, params: dict[str, str], expected_url: str, test_client: FakeClient) -> None:
request = test_client.prepare_request(method="GET", url=url, params=params)
assert request.url == expected_url
10 changes: 5 additions & 5 deletions tests/test_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest
import respx

from examples.example_client_with_retry_circuit_breaker_redis import TestRedisConnection
from tests.conftest import CLIENT_MAX_FAILURE_COUNT, TEST_BASE_URL, TestClient
from examples.example_client_with_retry_circuit_breaker_redis import FakeRedisConnection
from tests.conftest import CLIENT_MAX_FAILURE_COUNT, TEST_BASE_URL, FakeClient


@pytest.mark.parametrize(
Expand Down Expand Up @@ -35,7 +35,7 @@
)
@respx.mock
async def test_circuit_breaker_in_memory(
test_client_with_circuit_breaker_in_memory: TestClient, side_effect: Exception
test_client_with_circuit_breaker_in_memory: FakeClient, side_effect: Exception
) -> None:
mocked_route = respx.get(TEST_BASE_URL).mock(side_effect=side_effect)

Expand Down Expand Up @@ -64,13 +64,13 @@ async def test_circuit_breaker_redis(
side_effect: type[Exception],
expected_raise: type[Exception],
errors_by_host_in_redis: int,
test_client_with_circuit_breaker_redis: TestClient,
test_client_with_circuit_breaker_redis: FakeClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def mock_return(*args: typing.Any, **kwargs: typing.Any) -> int: # noqa: ARG001, ANN401
return errors_by_host_in_redis

monkeypatch.setattr(TestRedisConnection, "get", mock_return)
monkeypatch.setattr(FakeRedisConnection, "get", mock_return)

respx.get(TEST_BASE_URL).mock(side_effect=side_effect)
with pytest.raises(expected_raise):
Expand Down