Skip to content

Allow proper type on AnthropicProvider when using Bedrock #2490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions docs/evals.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ async def double_number(input_value: int) -> int:
# Run evaluation with unlimited concurrency
t0 = time.time()
report_default = dataset.evaluate_sync(double_number)
print(f'Evaluation took less than 0.5s: {time.time() - t0 < 0.5}')
#> Evaluation took less than 0.5s: True
print(f'Evaluation took less than 1s: {time.time() - t0 < 1}')
#> Evaluation took less than 1s: True

report_default.print(include_input=True, include_output=True, include_durations=False) # (1)!
"""
Expand Down
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@
)
from ..profiles import ModelProfileSpec
from ..providers import Provider, infer_provider
from ..providers.anthropic import AsyncAnthropicClient
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

try:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
from anthropic.types.beta import (
BetaBase64PDFBlockParam,
BetaBase64PDFSourceParam,
Expand Down Expand Up @@ -134,7 +135,7 @@ class AnthropicModel(Model):
Apart from `__init__`, all methods are private or match those of the base class.
"""

client: AsyncAnthropic = field(repr=False)
client: AsyncAnthropicClient = field(repr=False)

_model_name: AnthropicModelName = field(repr=False)
_system: str = field(default='anthropic', repr=False)
Expand All @@ -143,7 +144,7 @@ def __init__(
self,
model_name: AnthropicModelName,
*,
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
profile: ModelProfileSpec | None = None,
settings: ModelSettings | None = None,
):
Expand All @@ -153,7 +154,7 @@ def __init__(
model_name: The name of the Anthropic model to use. List of model names available
[here](https://docs.anthropic.com/en/docs/about-claude/models).
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
instance of `Provider[ASYNC_ANTHROPIC_CLIENT]`. If not provided, the other parameters will be used.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
instance of `Provider[ASYNC_ANTHROPIC_CLIENT]`. If not provided, the other parameters will be used.
instance of `Provider[AsyncAnthropicClient]`. If not provided, the other parameters will be used.

profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
settings: Default model settings for this model instance.
"""
Expand Down
19 changes: 11 additions & 8 deletions pydantic_ai_slim/pydantic_ai/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations as _annotations

import os
from typing import overload
from typing import Union, overload

import httpx
from typing_extensions import TypeAlias

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
Expand All @@ -12,15 +13,18 @@
from pydantic_ai.providers import Provider

try:
from anthropic import AsyncAnthropic
except ImportError as _import_error: # pragma: no cover
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
except ImportError as _import_error:
raise ImportError(
'Please install the `anthropic` package to use the Anthropic provider, '
'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
) from _import_error


class AnthropicProvider(Provider[AsyncAnthropic]):
AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]


class AnthropicProvider(Provider[AsyncAnthropicClient]):
"""Provider for Anthropic API."""

@property
Expand All @@ -32,14 +36,14 @@ def base_url(self) -> str:
return str(self._client.base_url)

@property
def client(self) -> AsyncAnthropic:
def client(self) -> AsyncAnthropicClient:
return self._client

def model_profile(self, model_name: str) -> ModelProfile | None:
return anthropic_model_profile(model_name)

@overload
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...

@overload
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
Expand All @@ -48,7 +52,7 @@ def __init__(
self,
*,
api_key: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
anthropic_client: AsyncAnthropicClient | None = None,
http_client: httpx.AsyncClient | None = None,
) -> None:
"""Create a new Anthropic provider.
Expand All @@ -71,7 +75,6 @@ def __init__(
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
'to use the Anthropic provider.'
)

if http_client is not None:
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
else:
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@

def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
assert isinstance(m.client, AsyncAnthropic)
assert m.client.api_key == 'foobar'
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'
Expand Down
30 changes: 11 additions & 19 deletions tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from __future__ import annotations as _annotations

import httpx
import pytest

from pydantic_ai.exceptions import UserError

from ..conftest import TestEnv, try_import
from ..conftest import try_import

with try_import() as imports_successful:
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock

from pydantic_ai.providers.anthropic import AnthropicProvider

Expand All @@ -24,24 +21,19 @@ def test_anthropic_provider():
assert provider.client.api_key == 'api-key'


def test_anthropic_provider_need_api_key(env: TestEnv) -> None:
env.remove('ANTHROPIC_API_KEY')
with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'):
AnthropicProvider()


def test_anthropic_provider_pass_http_client() -> None:
http_client = httpx.AsyncClient()
provider = AnthropicProvider(http_client=http_client, api_key='api-key')
assert isinstance(provider.client, AsyncAnthropic)
# Verify the http_client is being used by the AsyncAnthropic client
assert provider.client._client == http_client # type: ignore[reportPrivateUsage]


def test_anthropic_provider_pass_anthropic_client() -> None:
anthropic_client = AsyncAnthropic(api_key='api-key')
provider = AnthropicProvider(anthropic_client=anthropic_client)
assert provider.client == anthropic_client
bedrock_client = AsyncAnthropicBedrock(
aws_secret_key='aws-secret-key',
aws_access_key='aws-access-key',
aws_region='us-west-2',
aws_profile='default',
aws_session_token='aws-session-token',
)
provider = AnthropicProvider(anthropic_client=bedrock_client)
assert provider.client == bedrock_client


def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down
2 changes: 0 additions & 2 deletions tests/providers/test_provider_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from pydantic_ai.providers.anthropic import AnthropicProvider
from pydantic_ai.providers.azure import AzureProvider
from pydantic_ai.providers.cohere import CohereProvider
from pydantic_ai.providers.deepseek import DeepSeekProvider
from pydantic_ai.providers.fireworks import FireworksProvider
from pydantic_ai.providers.github import GitHubProvider
Expand All @@ -35,7 +34,6 @@

test_infer_provider_params = [
('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'),
('cohere', CohereProvider, 'CO_API_KEY'),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove this?

('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'),
('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'),
('vercel', VercelProvider, 'VERCEL_AI_GATEWAY_API_KEY'),
Expand Down
Loading