Skip to content

Fixes #2513: Add support for Free-Form Function Calling and Context Free Grammar constraints over tools #2572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ToolOutput,
_OutputSpecItem, # type: ignore[reportPrivateUsage]
)
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
from .tools import FunctionTextFormat, GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
from .toolsets.abstract import AbstractToolset, ToolsetTool

if TYPE_CHECKING:
Expand Down Expand Up @@ -591,6 +591,7 @@ class OutputObjectDefinition:
name: str | None = None
description: str | None = None
strict: bool | None = None
text_format: Literal['text'] | FunctionTextFormat | None = None


@dataclass(init=False)
Expand Down Expand Up @@ -621,6 +622,7 @@ def __init__(
name: str | None = None,
description: str | None = None,
strict: bool | None = None,
text_format: Literal['text'] | FunctionTextFormat | None = None,
):
if inspect.isfunction(output) or inspect.ismethod(output):
self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
Expand Down Expand Up @@ -663,6 +665,7 @@ def __init__(
description=description,
json_schema=json_schema,
strict=strict,
text_format=text_format,
)

async def process(
Expand Down Expand Up @@ -920,19 +923,23 @@ def build(
name = None
description = None
strict = None
text_format = None
if isinstance(output, ToolOutput):
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
name = output.name
description = output.description
strict = output.strict
text_format = output.text_format

output = output.output

description = description or default_description
if strict is None:
strict = default_strict

processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
processor = ObjectOutputProcessor(
output=output, description=description, strict=strict, text_format=text_format
)
object_def = processor.object_def

if name is None:
Expand All @@ -957,6 +964,7 @@ def build(
description=description,
parameters_json_schema=object_def.json_schema,
strict=object_def.strict,
text_format=object_def.text_format,
outer_typed_dict_key=processor.outer_typed_dict_key,
kind='output',
)
Expand Down
13 changes: 12 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast, overload

from opentelemetry.trace import NoOpTracer, use_span
from pydantic.json_schema import GenerateJsonSchema
Expand Down Expand Up @@ -39,6 +39,7 @@
from ..tools import (
AgentDepsT,
DocstringFormat,
FunctionTextFormat,
GenerateToolJsonSchema,
RunContext,
Tool,
Expand Down Expand Up @@ -963,6 +964,7 @@ def tool(
require_parameter_descriptions: bool = False,
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
strict: bool | None = None,
text_format: Literal['text'] | FunctionTextFormat | None = None,
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

def tool(
Expand All @@ -977,6 +979,7 @@ def tool(
require_parameter_descriptions: bool = False,
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
strict: bool | None = None,
text_format: Literal['text'] | FunctionTextFormat | None = None,
) -> Any:
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

Expand Down Expand Up @@ -1021,6 +1024,8 @@ async def spam(ctx: RunContext[str], y: float) -> float:
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
text_format: Used to invoke the function using free-form function calling (only affects OpenAI).
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
"""

def tool_decorator(
Expand All @@ -1037,6 +1042,7 @@ def tool_decorator(
require_parameter_descriptions,
schema_generator,
strict,
text_format,
)
return func_

Expand All @@ -1057,6 +1063,7 @@ def tool_plain(
require_parameter_descriptions: bool = False,
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
strict: bool | None = None,
text_format: Literal['text'] | FunctionTextFormat | None = None,
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

def tool_plain(
Expand All @@ -1071,6 +1078,7 @@ def tool_plain(
require_parameter_descriptions: bool = False,
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
strict: bool | None = None,
text_format: Literal['text'] | FunctionTextFormat | None = None,
) -> Any:
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

Expand Down Expand Up @@ -1115,6 +1123,8 @@ async def spam(ctx: RunContext[str]) -> float:
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
text_format: Used to invoke the function using free-form function calling (only affects OpenAI).
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
"""

def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
Expand All @@ -1129,6 +1139,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
require_parameter_descriptions,
schema_generator,
strict,
text_format,
)
return func_

Expand Down
65 changes: 57 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from openai.types.responses.response_input_param import FunctionCallOutput, Message
from openai.types.shared import ReasoningEffort
from openai.types.shared_params import Reasoning
from openai.types.shared_params.custom_tool_input_format import CustomToolInputFormat
except ImportError as _import_error:
raise ImportError(
'Please install `openai` to use the OpenAI model, '
Expand Down Expand Up @@ -762,7 +763,7 @@ async def request(
response = await self._responses_create(
messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
)
return self._process_response(response)
return self._process_response(response, model_request_parameters)

@asynccontextmanager
async def request_stream(
Expand All @@ -779,7 +780,11 @@ async def request_stream(
async with response:
yield await self._process_streamed_response(response, model_request_parameters)

def _process_response(self, response: responses.Response) -> ModelResponse:
def _process_response(
self,
response: responses.Response,
model_request_parameters: ModelRequestParameters,
) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
timestamp = number_to_datetime(response.created_at)
items: list[ModelResponsePart] = []
Expand All @@ -795,6 +800,16 @@ def _process_response(self, response: responses.Response) -> ModelResponse:
items.append(TextPart(content.text))
elif item.type == 'function_call':
items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
elif item.type == 'custom_tool_call':
if item.name not in model_request_parameters.tool_defs:
raise UnexpectedModelBehavior(f'Unknown tool called: {item.name}')
tool = model_request_parameters.tool_defs[item.name]
argument_name = tool.single_string_argument_name
if argument_name is None:
raise UnexpectedModelBehavior(
f'Custom tool call made to function {item.name} which has unexpected arguments'
)
items.append(ToolCallPart(item.name, {argument_name: item.input}, tool_call_id=item.call_id))
return ModelResponse(
items,
usage=_map_usage(response),
Expand Down Expand Up @@ -893,11 +908,14 @@ async def _responses_create(
try:
extra_headers = model_settings.get('extra_headers', {})
extra_headers.setdefault('User-Agent', get_user_agent())
parallel_tool_calls = self._get_parallel_tool_calling(
model_settings=model_settings, model_request_parameters=model_request_parameters
)
return await self.client.responses.create(
input=openai_messages,
model=self._model_name,
instructions=instructions,
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
parallel_tool_calls=parallel_tool_calls,
tools=tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
Expand Down Expand Up @@ -937,7 +955,18 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason
return NOT_GIVEN
return Reasoning(effort=reasoning_effort, summary=reasoning_summary)

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
def _get_parallel_tool_calling(
self, model_settings: OpenAIResponsesModelSettings, model_request_parameters: ModelRequestParameters
) -> bool | NotGiven:
if any(tool_definition.text_format for tool_definition in model_request_parameters.tool_defs.values()):
return False
if any(tool_definition.text_format for tool_definition in model_request_parameters.output_tools):
return False
return model_settings.get('parallel_tool_calls', NOT_GIVEN)

def _get_tools(
self, model_request_parameters: ModelRequestParameters
) -> list[responses.FunctionToolParam | responses.CustomToolParam]:
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]:
Expand All @@ -960,15 +989,35 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -
)
return tools

def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam | responses.CustomToolParam:
    """Convert a `ToolDefinition` into an OpenAI Responses API tool param.

    Tools with a `text_format` are mapped to a `type='custom'` tool (free-form
    function calling, optionally grammar-constrained); all other tools become
    regular `type='function'` params carrying their JSON schema.

    Raises:
        UserError: If the tool requests free-form calling but the model profile
            does not support it, or if the function does not take a single
            string argument (required for free-form calling).
    """
    model_profile = OpenAIModelProfile.from_profile(self.profile)
    if f.text_format:
        if not model_profile.openai_supports_freeform_function_calling:
            raise UserError(
                f'`{f.name}` uses free-form function calling but {self._model_name} does not support free-form function calling.'
            )
        if not f.only_takes_string_argument:
            raise UserError(
                f'`{f.name}` is set as a free-form function but does not take a single string argument.'
            )
        if f.text_format == 'text':
            format: CustomToolInputFormat = {'type': 'text'}
        else:
            # Grammar-constrained input: `syntax` names the grammar flavor and
            # `definition` carries the grammar itself.
            format = {'type': 'grammar', 'syntax': f.text_format.syntax, 'definition': f.text_format.grammar}
        tool_param: responses.CustomToolParam = {
            'name': f.name,
            'type': 'custom',
            'description': f.description or '',
            'format': format,
        }
        return tool_param

    return {
        'name': f.name,
        'parameters': f.parameters_json_schema,
        'type': 'function',
        'description': f.description,
        'strict': bool(f.strict and model_profile.openai_supports_strict_tool_definition),
    }

async def _map_messages(
Expand Down
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import _utils
from .messages import ToolCallPart
from .tools import RunContext, ToolDefinition
from .tools import FunctionTextFormat, RunContext, ToolDefinition

__all__ = (
# classes
Expand Down Expand Up @@ -112,6 +112,8 @@ class Vehicle(BaseModel):
"""The maximum number of retries for the tool."""
strict: bool | None
"""Whether to use strict mode for the tool."""
text_format: Literal['text'] | FunctionTextFormat | None = None
"""Whether to invoke the function with free-form function calling for tool calls."""

def __init__(
self,
Expand All @@ -121,12 +123,14 @@ def __init__(
description: str | None = None,
max_retries: int | None = None,
strict: bool | None = None,
text_format: Literal['text'] | FunctionTextFormat | None = None,
):
self.output = type_
self.name = name
self.description = description
self.max_retries = max_retries
self.strict = strict
self.text_format = text_format


@dataclass(init=False)
Expand Down
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,16 @@ class OpenAIModelProfile(ModelProfile):
openai_system_prompt_role: OpenAISystemPromptRole | None = None
"""The role to use for the system prompt message. If not provided, defaults to `'system'`."""

# GPT-5 introduced support for directly calling a function with a string.
openai_supports_freeform_function_calling: bool = False
"""Whether the provider accepts the value ``type='custom'`` for tools in the
request payload."""


def openai_model_profile(model_name: str) -> ModelProfile:
"""Get the model profile for an OpenAI model."""
is_reasoning_model = model_name.startswith('o') or model_name.startswith('gpt-5')
is_freeform_function_calling_model = model_name.startswith('gpt-5')
# Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later.
# We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
# when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
Expand All @@ -50,6 +56,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
supports_json_schema_output=True,
supports_json_object_output=True,
openai_supports_sampling_settings=not is_reasoning_model,
openai_supports_freeform_function_calling=is_freeform_function_calling_model,
openai_system_prompt_role=openai_system_prompt_role,
)

Expand Down
Loading
Loading