diff --git a/docs/evals.md b/docs/evals.md index 8d3091ad5..e1bd5a814 100644 --- a/docs/evals.md +++ b/docs/evals.md @@ -341,8 +341,8 @@ async def double_number(input_value: int) -> int: # Run evaluation with unlimited concurrency t0 = time.time() report_default = dataset.evaluate_sync(double_number) -print(f'Evaluation took less than 0.5s: {time.time() - t0 < 0.5}') -#> Evaluation took less than 0.5s: True +print(f'Evaluation took less than 1s: {time.time() - t0 < 1}') +#> Evaluation took less than 1s: True report_default.print(include_input=True, include_output=True, include_durations=False) # (1)! """ diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index b65e319b4..723ec690f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -37,12 +37,13 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider +from ..providers.anthropic import AsyncAnthropicClient from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent try: - from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream + from anthropic import NOT_GIVEN, APIStatusError, AsyncStream from anthropic.types.beta import ( BetaBase64PDFBlockParam, BetaBase64PDFSourceParam, @@ -134,7 +135,7 @@ class AnthropicModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - client: AsyncAnthropic = field(repr=False) + client: AsyncAnthropicClient = field(repr=False) _model_name: AnthropicModelName = field(repr=False) _system: str = field(default='anthropic', repr=False) @@ -143,7 +144,7 @@ def __init__( self, model_name: AnthropicModelName, *, - provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic', + provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): @@ -153,7 +154,7 @@ def __init__( model_name: The name of the Anthropic model to use. List of model names available [here](https://docs.anthropic.com/en/docs/about-claude/models). provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an - instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used. + instance of `Provider[ASYNC_ANTHROPIC_CLIENT]`. If not provided, the other parameters will be used. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. settings: Default model settings for this model instance. """ diff --git a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py index 20bc3255e..b596c4d7e 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py @@ -1,9 +1,10 @@ from __future__ import annotations as _annotations import os -from typing import overload +from typing import Union, overload import httpx +from typing_extensions import TypeAlias from pydantic_ai.exceptions import UserError from pydantic_ai.models import cached_async_http_client @@ -12,15 +13,18 @@ from pydantic_ai.providers import Provider try: - from anthropic import AsyncAnthropic -except ImportError as _import_error: # pragma: no cover + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock +except ImportError as _import_error: raise ImportError( 'Please install the `anthropic` package to use the Anthropic provider, ' 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`' ) from _import_error -class AnthropicProvider(Provider[AsyncAnthropic]): +AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock] + + +class AnthropicProvider(Provider[AsyncAnthropicClient]): """Provider for Anthropic API.""" @property @@ -32,14 +36,14 @@ def base_url(self) -> str: return str(self._client.base_url) @property - def client(self) -> AsyncAnthropic: + def client(self) -> AsyncAnthropicClient: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return anthropic_model_profile(model_name) @overload - def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ... + def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ... @overload def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ... @@ -48,7 +52,7 @@ def __init__( self, *, api_key: str | None = None, - anthropic_client: AsyncAnthropic | None = None, + anthropic_client: AsyncAnthropicClient | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Anthropic provider. @@ -71,7 +75,6 @@ def __init__( 'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`' 'to use the Anthropic provider.' ) - if http_client is not None: self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) else: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 2f24b95f1..786292db8 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -95,6 +95,7 @@ def test_init(): m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar')) + assert isinstance(m.client, AsyncAnthropic) assert m.client.api_key == 'foobar' assert m.model_name == 'claude-3-5-haiku-latest' assert m.system == 'anthropic' diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 44f47554b..79871a36c 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -1,14 +1,11 @@ from __future__ import annotations as _annotations -import httpx import pytest -from pydantic_ai.exceptions import UserError - -from ..conftest import TestEnv, try_import +from ..conftest import try_import with try_import() as imports_successful: - from anthropic import AsyncAnthropic + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock from pydantic_ai.providers.anthropic import AnthropicProvider @@ -24,24 +21,19 @@ def test_anthropic_provider(): assert provider.client.api_key == 'api-key' -def test_anthropic_provider_need_api_key(env: TestEnv) -> None: - env.remove('ANTHROPIC_API_KEY') - with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'): - AnthropicProvider() - - -def test_anthropic_provider_pass_http_client() -> None: - http_client = httpx.AsyncClient() - provider = AnthropicProvider(http_client=http_client, api_key='api-key') - assert isinstance(provider.client, AsyncAnthropic) - # Verify the http_client is being used by the AsyncAnthropic client - assert provider.client._client == http_client # type: ignore[reportPrivateUsage] - - def test_anthropic_provider_pass_anthropic_client() -> None: anthropic_client = AsyncAnthropic(api_key='api-key') provider = AnthropicProvider(anthropic_client=anthropic_client) assert provider.client == anthropic_client + bedrock_client = AsyncAnthropicBedrock( + aws_secret_key='aws-secret-key', + aws_access_key='aws-access-key', + aws_region='us-west-2', + aws_profile='default', + aws_session_token='aws-session-token', + ) + provider = AnthropicProvider(anthropic_client=bedrock_client) + assert provider.client == bedrock_client def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/providers/test_provider_names.py b/tests/providers/test_provider_names.py index e09ba2a11..99b6b59b5 100644 --- a/tests/providers/test_provider_names.py +++ b/tests/providers/test_provider_names.py @@ -16,7 +16,6 @@ from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai.providers.azure import AzureProvider - from pydantic_ai.providers.cohere import CohereProvider from pydantic_ai.providers.deepseek import DeepSeekProvider from pydantic_ai.providers.fireworks import FireworksProvider from pydantic_ai.providers.github import GitHubProvider @@ -35,7 +34,6 @@ test_infer_provider_params = [ ('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'), - ('cohere', CohereProvider, 'CO_API_KEY'), ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'), ('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'), ('vercel', VercelProvider, 'VERCEL_AI_GATEWAY_API_KEY'),