diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py
index 81f582df8..e2ba8ed39 100644
--- a/temporalio/contrib/openai_agents/temporal_tools.py
+++ b/temporalio/contrib/openai_agents/temporal_tools.py
@@ -1,11 +1,12 @@
 """Support for using Temporal activities as OpenAI agents tools."""
 
+import json
 from datetime import timedelta
 from typing import Any, Callable, Optional
 
 from temporalio import activity, workflow
 from temporalio.common import Priority, RetryPolicy
-from temporalio.exceptions import ApplicationError
+from temporalio.exceptions import ApplicationError, TemporalError
 from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe
 
 with unsafe.imports_passed_through():
@@ -13,6 +14,10 @@
     from agents.function_schema import function_schema
 
 
+class ToolSerializationError(TemporalError):
+    """Error that occurs when a tool output could not be serialized."""
+
+
 def activity_as_tool(
     fn: Callable,
     *,
@@ -69,32 +74,44 @@ def activity_as_tool(
             "Bare function without tool and activity decorators is not supported",
             "invalid_tool",
         )
+    schema = function_schema(fn)
 
     async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
+        """Parse the tool's JSON arguments, run the activity, return its str()."""
         try:
-            return str(
-                await workflow.execute_activity(
-                    fn,
-                    input,
-                    task_queue=task_queue,
-                    schedule_to_close_timeout=schedule_to_close_timeout,
-                    schedule_to_start_timeout=schedule_to_start_timeout,
-                    start_to_close_timeout=start_to_close_timeout,
-                    heartbeat_timeout=heartbeat_timeout,
-                    retry_policy=retry_policy,
-                    cancellation_type=cancellation_type,
-                    activity_id=activity_id,
-                    versioning_intent=versioning_intent,
-                    summary=summary,
-                    priority=priority,
-                )
-            )
-        except Exception:
+            # An empty argument string means "no arguments", so parse it as {}.
+            json_data = json.loads(input) if input else {}
+            # Activities don't support keyword only arguments, so we can ignore the kwargs_dict return
+            args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
+        except Exception as e:
+            # Malformed JSON, a non-object payload and arguments that fail schema
+            # validation are all reported as bad tool input.
             raise ApplicationError(
+                f"Invalid JSON input for tool {schema.name}: {input}"
+            ) from e
+
+        result = await workflow.execute_activity(
+            fn,
+            args=args,
+            task_queue=task_queue,
+            schedule_to_close_timeout=schedule_to_close_timeout,
+            schedule_to_start_timeout=schedule_to_start_timeout,
+            start_to_close_timeout=start_to_close_timeout,
+            heartbeat_timeout=heartbeat_timeout,
+            retry_policy=retry_policy,
+            cancellation_type=cancellation_type,
+            activity_id=activity_id,
+            versioning_intent=versioning_intent,
+            summary=summary,
+            priority=priority,
+        )
+        try:
+            return str(result)
+        except Exception as e:
+            raise ToolSerializationError(
                 "You must return a string representation of the tool output, or something we can call str() on"
-            )
+            ) from e
 
-    schema = function_schema(fn)
     return FunctionTool(
         name=schema.name,
         description=schema.description or "",
diff --git a/tests/contrib/test_openai.py b/tests/contrib/test_openai.py
index 323b05896..a3e83e53a 100644
--- a/tests/contrib/test_openai.py
+++ b/tests/contrib/test_openai.py
@@ -219,6 +219,29 @@ async def get_weather(city: str) -> Weather:
     return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
 
 
+@activity.defn
+async def get_weather_country(city: str, country: str) -> Weather:
+    """
+    Get the weather for a given city in a country.
+    """
+    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
+
+
+@dataclass
+class WeatherInput:
+    city: str
+
+
+@activity.defn
+async def get_weather_object(input: WeatherInput) -> Weather:
+    """
+    Get the weather for a given city.
+    """
+    return Weather(
+        city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
+    )
+
+
 class TestWeatherModel(TestModel):
     responses = [
         ModelResponse(
@@ -235,6 +258,34 @@ class TestWeatherModel(TestModel):
             usage=Usage(),
             response_id=None,
         ),
+        ModelResponse(
+            output=[
+                ResponseFunctionToolCall(
+                    arguments='{"input":{"city":"Tokyo"}}',
+                    call_id="call",
+                    name="get_weather_object",
+                    type="function_call",
+                    id="id",
+                    status="completed",
+                )
+            ],
+            usage=Usage(),
+            response_id=None,
+        ),
+        ModelResponse(
+            output=[
+                ResponseFunctionToolCall(
+                    arguments='{"city":"Tokyo","country":"Japan"}',
+                    call_id="call",
+                    name="get_weather_country",
+                    type="function_call",
+                    id="id",
+                    status="completed",
+                )
+            ],
+            usage=Usage(),
+            response_id=None,
+        ),
         ModelResponse(
             output=[
                 ResponseOutputMessage(
@@ -267,7 +318,13 @@ async def run(self, question: str) -> str:
             tools=[
                 activity_as_tool(
                     get_weather, start_to_close_timeout=timedelta(seconds=10)
-                )
+                ),
+                activity_as_tool(
+                    get_weather_object, start_to_close_timeout=timedelta(seconds=10)
+                ),
+                activity_as_tool(
+                    get_weather_country, start_to_close_timeout=timedelta(seconds=10)
+                ),
             ],
         )  # type: Agent
         result = await Runner.run(starting_agent=agent, input=question)
@@ -291,7 +348,12 @@ async def test_tool_workflow(client: Client):
         async with new_worker(
             client,
             ToolsWorkflow,
-            activities=[model_activity.invoke_model_activity, get_weather],
+            activities=[
+                model_activity.invoke_model_activity,
+                get_weather,
+                get_weather_object,
+                get_weather_country,
+            ],
             interceptors=[OpenAIAgentsTracingInterceptor()],
         ) as worker:
             workflow_handle = await client.start_workflow(
@@ -309,7 +371,7 @@ async def test_tool_workflow(client: Client):
                 if e.HasField("activity_task_completed_event_attributes"):
                     events.append(e)
 
-            assert len(events) == 3
+            assert len(events) == 7
             assert (
                 "function_call"
                 in events[0]
@@ -323,11 +385,35 @@ async def test_tool_workflow(client: Client):
                 .data.decode()
             )
             assert (
-                "Test weather result"
+                "function_call"
                 in events[2]
                 .activity_task_completed_event_attributes.result.payloads[0]
                 .data.decode()
             )
+            assert (
+                "Sunny with wind"
+                in events[3]
+                .activity_task_completed_event_attributes.result.payloads[0]
+                .data.decode()
+            )
+            assert (
+                "function_call"
+                in events[4]
+                .activity_task_completed_event_attributes.result.payloads[0]
+                .data.decode()
+            )
+            assert (
+                "Sunny with wind"
+                in events[5]
+                .activity_task_completed_event_attributes.result.payloads[0]
+                .data.decode()
+            )
+            assert (
+                "Test weather result"
+                in events[6]
+                .activity_task_completed_event_attributes.result.payloads[0]
+                .data.decode()
+            )
 
 
 class TestPlannerModel(OpenAIResponsesModel):