Skip to content

Support for more activity tool inputs #923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions temporalio/contrib/openai_agents/temporal_tools.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
"""Support for using Temporal activities as OpenAI agents tools."""

import json
from datetime import timedelta
from typing import Any, Callable, Optional

from temporalio import activity, workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.exceptions import ApplicationError
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe

with unsafe.imports_passed_through():
from agents import FunctionTool, RunContextWrapper, Tool
from agents.function_schema import function_schema


class ToolSerializationError(TemporalError):
"""Error that occurs when a tool output could not be serialized."""


def activity_as_tool(
fn: Callable,
*,
Expand Down Expand Up @@ -69,32 +74,40 @@ def activity_as_tool(
"Bare function without tool and activity decorators is not supported",
"invalid_tool",
)
schema = function_schema(fn)

async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
return str(
await workflow.execute_activity(
fn,
input,
task_queue=task_queue,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
cancellation_type=cancellation_type,
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
priority=priority,
)
)
except Exception:
json_data = json.loads(input)
except Exception as e:
raise ApplicationError(
f"Invalid JSON input for tool {schema.name}: {input}"
) from e

# Activities don't support keyword only arguments, so we can ignore the kwargs_dict return
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
result = await workflow.execute_activity(
fn,
args=args,
task_queue=task_queue,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
cancellation_type=cancellation_type,
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
priority=priority,
)
try:
return str(result)
except Exception as e:
raise ToolSerializationError(
"You must return a string representation of the tool output, or something we can call str() on"
)
) from e

schema = function_schema(fn)
return FunctionTool(
name=schema.name,
description=schema.description or "",
Expand Down
94 changes: 90 additions & 4 deletions tests/contrib/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,29 @@ async def get_weather(city: str) -> Weather:
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")


@activity.defn
async def get_weather_country(city: str, country: str) -> Weather:
"""
Get the weather for a given city in a country.
"""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")


@dataclass
class WeatherInput:
city: str


@activity.defn
async def get_weather_object(input: WeatherInput) -> Weather:
"""
Get the weather for a given city.
"""
return Weather(
city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
)


class TestWeatherModel(TestModel):
responses = [
ModelResponse(
Expand All @@ -235,6 +258,34 @@ class TestWeatherModel(TestModel):
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseFunctionToolCall(
arguments='{"input":{"city":"Tokyo"}}',
call_id="call",
name="get_weather_object",
type="function_call",
id="id",
status="completed",
)
],
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseFunctionToolCall(
arguments='{"city":"Tokyo","country":"Japan"}',
call_id="call",
name="get_weather_country",
type="function_call",
id="id",
status="completed",
)
],
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseOutputMessage(
Expand Down Expand Up @@ -267,7 +318,13 @@ async def run(self, question: str) -> str:
tools=[
activity_as_tool(
get_weather, start_to_close_timeout=timedelta(seconds=10)
)
),
activity_as_tool(
get_weather_object, start_to_close_timeout=timedelta(seconds=10)
),
activity_as_tool(
get_weather_country, start_to_close_timeout=timedelta(seconds=10)
),
],
) # type: Agent
result = await Runner.run(starting_agent=agent, input=question)
Expand All @@ -291,7 +348,12 @@ async def test_tool_workflow(client: Client):
async with new_worker(
client,
ToolsWorkflow,
activities=[model_activity.invoke_model_activity, get_weather],
activities=[
model_activity.invoke_model_activity,
get_weather,
get_weather_object,
get_weather_country,
],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
workflow_handle = await client.start_workflow(
Expand All @@ -309,7 +371,7 @@ async def test_tool_workflow(client: Client):
if e.HasField("activity_task_completed_event_attributes"):
events.append(e)

assert len(events) == 3
assert len(events) == 7
assert (
"function_call"
in events[0]
Expand All @@ -323,11 +385,35 @@ async def test_tool_workflow(client: Client):
.data.decode()
)
assert (
"Test weather result"
"function_call"
in events[2]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"Sunny with wind"
in events[3]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"function_call"
in events[4]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"Sunny with wind"
in events[5]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"Test weather result"
in events[6]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)


class TestPlannerModel(OpenAIResponsesModel):
Expand Down