Adding tests for guardrail samples #929

Merged · 5 commits · Jun 30, 2025

341 changes: 341 additions & 0 deletions tests/contrib/test_openai.py
@@ -6,6 +6,7 @@
from typing import Any, Optional, Union, no_type_check

import pytest
from pydantic import ConfigDict, Field

from temporalio import activity, workflow
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
@@ -29,7 +30,9 @@
from agents import (
Agent,
AgentOutputSchemaBase,
GuardrailFunctionOutput,
Handoff,
InputGuardrailTripwireTriggered,
ItemHelpers,
MessageOutputItem,
Model,
@@ -38,13 +41,16 @@
ModelSettings,
ModelTracing,
OpenAIResponsesModel,
OutputGuardrailTripwireTriggered,
RunContextWrapper,
Runner,
Tool,
TResponseInputItem,
Usage,
function_tool,
handoff,
input_guardrail,
output_guardrail,
trace,
)
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
@@ -1151,3 +1157,338 @@ async def test_customer_service_workflow(client: Client):
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)


guardrail_response_index: int = 0


class InputGuardrailModel(OpenAIResponsesModel):
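    # Scripted stand-in for the OpenAI Responses model: `responses` feeds the
    # customer-support agent and `guardrail_responses` feeds the guardrail
    # agent, each consumed in order via module-level counters that __init__
    # resets for every test run.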
__test__ = False
responses: list[ModelResponse] = [
ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text="The capital of California is Sacramento.",
annotations=[],
type="output_text",
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text="x=3",
annotations=[],
type="output_text",
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
),
]
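    # Scripted guardrail verdicts: the first clears the geography question,
    # the second flags the equation as math homework and trips the tripwire.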
guardrail_responses = [
ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text='{"is_math_homework":false,"reasoning":"The question asked is about the capital of California, which is a geography-related query, not math."}',
annotations=[],
type="output_text",
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text='{"is_math_homework":true,"reasoning":"The question involves solving an equation for a variable, which is a typical math homework problem."}',
annotations=[],
type="output_text",
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
),
]

def __init__(
self,
model: str,
openai_client: AsyncOpenAI,
) -> None:
global response_index
response_index = 0
global guardrail_response_index
guardrail_response_index = 0
super().__init__(model, openai_client)

async def get_response(
self,
system_instructions: Union[str, None],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Union[AgentOutputSchemaBase, None],
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: Union[str, None],
prompt: Union[ResponsePromptParam, None] = None,
) -> ModelResponse:
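        # The guardrail agent is identified by its system prompt; its calls
        # are served from the scripted verdicts, everything else from the
        # main responses.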
if (
system_instructions
== "Check if the user is asking you to do their math homework."
):
global guardrail_response_index
response = self.guardrail_responses[guardrail_response_index]
guardrail_response_index += 1
return response
else:
global response_index
response = self.responses[response_index]
response_index += 1
return response


### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
class MathHomeworkOutput(BaseModel):
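    # extra="forbid" keeps the generated JSON schema strict
    # (additionalProperties: false), as the OpenAI structured-output
    # format expects.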
reasoning: str
is_math_homework: bool
model_config = ConfigDict(extra="forbid")


guardrail_agent: Agent = Agent(
name="Guardrail check",
instructions="Check if the user is asking you to do their math homework.",
output_type=MathHomeworkOutput,
)


@input_guardrail
async def math_guardrail(
context: RunContextWrapper[None],
agent: Agent,
input: Union[str, list[TResponseInputItem]],
) -> GuardrailFunctionOutput:
"""This is an input guardrail function, which happens to call an agent to check if the input
is a math homework question.
"""
result = await Runner.run(guardrail_agent, input, context=context.context)
final_output = result.final_output_as(MathHomeworkOutput)

return GuardrailFunctionOutput(
output_info=final_output,
tripwire_triggered=final_output.is_math_homework,
)


@workflow.defn
class InputGuardrailWorkflow:
@workflow.run
async def run(self, messages: list[str]) -> list[str]:
agent = Agent(
name="Customer support agent",
instructions="You are a customer support agent. You help customers with their questions.",
input_guardrails=[math_guardrail],
)

input_data: list[TResponseInputItem] = []
results: list[str] = []

for user_input in messages:
input_data.append(
{
"role": "user",
"content": user_input,
}
)

try:
result = await Runner.run(agent, input_data)
results.append(result.final_output)
# If the guardrail didn't trigger, we use the result as the input for the next run
input_data = result.to_input_list()
except InputGuardrailTripwireTriggered:
# If the guardrail triggered, we instead add a refusal message to the input
message = "Sorry, I can't help you with your math homework."
results.append(message)
input_data.append(
{
"role": "assistant",
"content": message,
}
)
return results


async def test_input_guardrail(client: Client):
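    # Swap in the OpenAI agents data converter so pydantic model payloads
    # round-trip through Temporal.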
new_config = client.config()
new_config["data_converter"] = open_ai_data_converter
client = Client(**new_config)

model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
InputGuardrailModel( # type: ignore
"", openai_client=AsyncOpenAI(api_key="Fake key")
)
)
)
async with new_worker(
client,
InputGuardrailWorkflow,
activities=[model_activity.invoke_model_activity],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
workflow_handle = await client.start_workflow(
InputGuardrailWorkflow.run,
[
"What's the capital of California?",
"Can you help me solve for x: 2x + 5 = 11",
],
id=f"input-guardrail-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=10),
)
result = await workflow_handle.result()
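            # The geography question passes the guardrail; the math question
            # is refused.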
assert len(result) == 2
assert result[0] == "The capital of California is Sacramento."
assert result[1] == "Sorry, I can't help you with your math homework."


class OutputGuardrailModel(TestModel):
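    # A single scripted reply that repeats the 650 area code in both its
    # response and its reasoning, so the output guardrail below must trip.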
responses = [
ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text='{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}',
annotations=[],
type="output_text",
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
)
]


# The agent's output type
class MessageOutput(BaseModel):
reasoning: str = Field(
description="Thoughts on how to respond to the user's message"
)
response: str = Field(description="The response to the user's message")
user_name: Optional[str] = Field(
description="The name of the user who sent the message, if known"
)
model_config = ConfigDict(extra="forbid")


@output_guardrail
async def sensitive_data_check(
context: RunContextWrapper, agent: Agent, output: MessageOutput
) -> GuardrailFunctionOutput:
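    # Trip the guardrail if the area code from the user's phone number leaks
    # into either the visible response or the model's reasoning.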
phone_number_in_response = "650" in output.response
phone_number_in_reasoning = "650" in output.reasoning

return GuardrailFunctionOutput(
output_info={
"phone_number_in_response": phone_number_in_response,
"phone_number_in_reasoning": phone_number_in_reasoning,
},
tripwire_triggered=phone_number_in_response or phone_number_in_reasoning,
)


output_guardrail_agent = Agent(
name="Assistant",
instructions="You are a helpful assistant.",
output_type=MessageOutput,
output_guardrails=[sensitive_data_check],
)


@workflow.defn
class OutputGuardrailWorkflow:
@workflow.run
async def run(self) -> bool:
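        # True when the agent answers normally; False when the output
        # guardrail raises on the leaked area code.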
try:
await Runner.run(
output_guardrail_agent,
"My phone number is 650-123-4567. Where do you think I live?",
)
return True
except OutputGuardrailTripwireTriggered:
return False


async def test_output_guardrail(client: Client):
new_config = client.config()
new_config["data_converter"] = open_ai_data_converter
client = Client(**new_config)

model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
OutputGuardrailModel( # type: ignore
"", openai_client=AsyncOpenAI(api_key="Fake key")
)
)
)
async with new_worker(
client,
OutputGuardrailWorkflow,
activities=[model_activity.invoke_model_activity],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
workflow_handle = await client.start_workflow(
OutputGuardrailWorkflow.run,
id=f"output-guardrail-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=10),
)
result = await workflow_handle.result()
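            # The scripted reply leaks the area code, so the guardrail trips.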
assert not result