diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 08943a8312..84cbd16504 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -77,15 +77,43 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema: if schema.get('type') == 'object': return schema elif schema.get('$ref') is not None: - maybe_result = schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/". - - if "'$ref': '#/$defs/" in str(maybe_result): - return schema # We can't remove the $defs because the schema contains other references - return maybe_result + ref = schema['$ref'] + if ref.startswith('#/$defs/'): + ref_name = ref[8:] # Remove "#/$defs/" prefix + defs = schema.get('$defs', {}) + if ref_name in defs: + resolved = defs[ref_name] + # Check if the resolved schema contains nested references. + # This is necessary because if we inline a schema that itself contains + # $ref references, those references won't be resolvable without the $defs. + # The old code used fragile string matching; this uses proper recursive checking. + if _contains_ref(resolved): + # Keep the $defs because they're needed for nested references + return schema + return resolved + # For non-local refs or unresolvable refs, return the schema as-is + return schema else: raise UserError('Schema must be an object') +def _contains_ref(obj: Any) -> bool: + """Recursively check if an object contains any $ref keys.""" + if isinstance(obj, dict): + if '$ref' in obj: + return True + for v in obj.values(): # pyright: ignore[reportUnknownVariableType] + if isinstance(v, (dict, list)) and _contains_ref(v): + return True + return False + elif isinstance(obj, list): + for item in obj: # pyright: ignore[reportUnknownVariableType] + if isinstance(item, (dict, list)) and _contains_ref(item): + return True + return False + return False + + T = TypeVar('T') diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index d61eb61748..eecea9dc3c 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -10,6 +10,7 @@ from typing_extensions import TypeAliasType, TypeVar from . import _utils +from .exceptions import UserError from .messages import ToolCallPart from .tools import RunContext, ToolDefinition @@ -304,6 +305,19 @@ def StructuredDict( """ json_schema = _utils.check_object_json_schema(json_schema) + # If the schema contains $defs, inline them to avoid issues with pydantic's + # JSON schema generator (Issue #2466) + if '$defs' in json_schema: + from .profiles import InlineDefsJsonSchemaTransformer + + try: + transformer = InlineDefsJsonSchemaTransformer(json_schema) + json_schema = transformer.walk() + except UserError: + # If the transformer can't resolve refs (e.g., missing definitions), + # keep the original schema unchanged + pass + if name: json_schema['title'] = name diff --git a/tests/test_agent.py b/tests/test_agent.py index d063b10117..4c5e7d4274 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -46,7 +46,11 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import DeferredToolCalls, StructuredDict, ToolOutput +from pydantic_ai.output import ( + DeferredToolCalls, + StructuredDict, + ToolOutput, +) from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition @@ -1362,6 +1366,64 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_structured_dict_nested(): + """Test StructuredDict with nested JSON schemas using $ref - Issue #2466.""" + # Schema with nested $ref that pydantic's generator can't resolve + CarDict = StructuredDict( + { + '$defs': { + 'Tire': { + 'type': 'object', + 'properties': {'brand': {'type': 'string'}, 'size': {'type': 'integer'}}, + 'required': ['brand', 'size'], + } + }, + 'type': 'object', + 'properties': { + 'make': {'type': 'string'}, + 'model': {'type': 'string'}, + 'tires': {'type': 'array', 'items': {'$ref': '#/$defs/Tire'}}, + }, + 'required': ['make', 'model', 'tires'], + }, + name='Car', + description='A car with tires', + ) + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + # Verify the output tool schema has been properly transformed + # The $refs should be inlined by InlineDefsJsonSchemaTransformer + output_tool = info.output_tools[0] + assert output_tool.parameters_json_schema is not None + schema = output_tool.parameters_json_schema + + # Check that the Tire definition has been inlined in the tires array items + assert 'properties' in schema + assert 'tires' in schema['properties'] + tires_schema = schema['properties']['tires'] + assert tires_schema['type'] == 'array' + + # The $ref should have been resolved to the actual Tire schema + items_schema = tires_schema['items'] + assert '$ref' not in items_schema # Should be inlined, not a ref + assert items_schema['type'] == 'object' + assert 'properties' in items_schema + assert 'brand' in items_schema['properties'] + assert 'size' in items_schema['properties'] + + args_json = '{"make": "Toyota", "model": "Camry", "tires": [{"brand": "Michelin", "size": 17}]}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=CarDict) + + result = agent.run_sync('Generate a car') + + expected = {'make': 'Toyota', 'model': 'Camry', 'tires': [{'brand': 'Michelin', 'size': 17}]} + assert result.output == expected + + def test_default_structured_output_mode(): def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover