From 705f0076540df32a9c0517095293e52386fc7c2c Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 27 Jun 2025 09:30:12 -0700 Subject: [PATCH 1/3] Adding tests for guardrail samples --- tests/contrib/test_openai.py | 341 +++++++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) diff --git a/tests/contrib/test_openai.py b/tests/contrib/test_openai.py index 295a2a855..e7cf6c93c 100644 --- a/tests/contrib/test_openai.py +++ b/tests/contrib/test_openai.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union import pytest +from pydantic import ConfigDict, Field from temporalio import activity, workflow from temporalio.client import Client, WorkflowFailureError, WorkflowHandle @@ -46,6 +47,11 @@ function_tool, handoff, trace, + input_guardrail, + GuardrailFunctionOutput, + InputGuardrailTripwireTriggered, + output_guardrail, + OutputGuardrailTripwireTriggered, ) from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX from agents.items import ( @@ -1146,3 +1152,338 @@ async def test_customer_service_workflow(client: Client): .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) + + +guardrail_response_index: int = 0 + + +class InputGuardrailModel(OpenAIResponsesModel): + __test__ = False + responses: list[ModelResponse] = [ + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text="The capital of California is Sacramento.", + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text="x=3", + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ] + guardrail_responses = [ + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text='{"is_math_homework":false,"reasoning":"The question asked is about the capital of California, which is a geography-related query, not math."}', + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text='{"is_math_homework":true,"reasoning":"The question involves solving an equation for a variable, which is a typical math homework problem."}', + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ] + + def __init__( + self, + model: str, + openai_client: AsyncOpenAI, + ) -> None: + global response_index + response_index = 0 + global guardrail_response_index + guardrail_response_index = 0 + super().__init__(model, openai_client) + + async def get_response( + self, + system_instructions: Union[str, None], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Union[AgentOutputSchemaBase, None], + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: Union[str, None], + prompt: Union[ResponsePromptParam, None] = None, + ) -> ModelResponse: + if ( + system_instructions + == "Check if the user is asking you to do their math homework." + ): + global guardrail_response_index + response = self.guardrail_responses[guardrail_response_index] + guardrail_response_index += 1 + return response + else: + global response_index + response = self.responses[response_index] + response_index += 1 + return response + + +### 1. An agent-based guardrail that is triggered if the user is asking to do math homework +class MathHomeworkOutput(BaseModel): + reasoning: str + is_math_homework: bool + model_config = ConfigDict(extra="forbid") + + +guardrail_agent = Agent( + name="Guardrail check", + instructions="Check if the user is asking you to do their math homework.", + output_type=MathHomeworkOutput, +) + + +@input_guardrail +async def math_guardrail( + context: RunContextWrapper[None], + agent: Agent, + input: str | list[TResponseInputItem], +) -> GuardrailFunctionOutput: + """This is an input guardrail function, which happens to call an agent to check if the input + is a math homework question. + """ + result = await Runner.run(guardrail_agent, input, context=context.context) + final_output = result.final_output_as(MathHomeworkOutput) + + return GuardrailFunctionOutput( + output_info=final_output, + tripwire_triggered=final_output.is_math_homework, + ) + + +@workflow.defn +class InputGuardrailWorkflow: + @workflow.run + async def run(self, messages: list[str]) -> list[str]: + agent = Agent( + name="Customer support agent", + instructions="You are a customer support agent. You help customers with their questions.", + input_guardrails=[math_guardrail], + ) + + input_data: list[TResponseInputItem] = [] + results: list[str] = [] + + for user_input in messages: + input_data.append( + { + "role": "user", + "content": user_input, + } + ) + + try: + result = await Runner.run(agent, input_data) + results.append(result.final_output) + # If the guardrail didn't trigger, we use the result as the input for the next run + input_data = result.to_input_list() + except InputGuardrailTripwireTriggered: + # If the guardrail triggered, we instead add a refusal message to the input + message = "Sorry, I can't help you with your math homework." + results.append(message) + input_data.append( + { + "role": "assistant", + "content": message, + } + ) + return results + + +async def test_input_guardrail(client: Client): + new_config = client.config() + new_config["data_converter"] = open_ai_data_converter + client = Client(**new_config) + + model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10)) + with set_open_ai_agent_temporal_overrides(model_params): + model_activity = ModelActivity( + TestProvider( + InputGuardrailModel( # type: ignore + "", openai_client=AsyncOpenAI(api_key="Fake key") + ) + ) + ) + async with new_worker( + client, + InputGuardrailWorkflow, + activities=[model_activity.invoke_model_activity], + interceptors=[OpenAIAgentsTracingInterceptor()], + ) as worker: + workflow_handle = await client.start_workflow( + InputGuardrailWorkflow.run, + [ + "What's the capital of California?", + "Can you help me solve for x: 2x + 5 = 11", + ], + id=f"input-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + assert len(result) == 2 + assert result[0] == "The capital of California is Sacramento." + assert result[1] == "Sorry, I can't help you with your math homework." + + +class OutputGuardrailModel(TestModel): + responses = [ + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text='{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}', + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ) + ] + + +# The agent's output type +class MessageOutput(BaseModel): + reasoning: str = Field( + description="Thoughts on how to respond to the user's message" + ) + response: str = Field(description="The response to the user's message") + user_name: str | None = Field( + description="The name of the user who sent the message, if known" + ) + model_config = ConfigDict(extra="forbid") + + +@output_guardrail +async def sensitive_data_check( + context: RunContextWrapper, agent: Agent, output: MessageOutput +) -> GuardrailFunctionOutput: + phone_number_in_response = "650" in output.response + phone_number_in_reasoning = "650" in output.reasoning + + return GuardrailFunctionOutput( + output_info={ + "phone_number_in_response": phone_number_in_response, + "phone_number_in_reasoning": phone_number_in_reasoning, + }, + tripwire_triggered=phone_number_in_response or phone_number_in_reasoning, + ) + + +output_guardrail_agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + output_type=MessageOutput, + output_guardrails=[sensitive_data_check], +) + + +@workflow.defn +class OutputGuardrailWorkflow: + @workflow.run + async def run(self) -> bool: + try: + await Runner.run( + output_guardrail_agent, + "My phone number is 650-123-4567. Where do you think I live?", + ) + return True + except OutputGuardrailTripwireTriggered: + return False + + +async def test_output_guardrail(client: Client): + new_config = client.config() + new_config["data_converter"] = open_ai_data_converter + client = Client(**new_config) + + model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10)) + with set_open_ai_agent_temporal_overrides(model_params): + model_activity = ModelActivity( + TestProvider( + OutputGuardrailModel( # type: ignore + "", openai_client=AsyncOpenAI(api_key="Fake key") + ) + ) + ) + async with new_worker( + client, + OutputGuardrailWorkflow, + activities=[model_activity.invoke_model_activity], + interceptors=[OpenAIAgentsTracingInterceptor()], + ) as worker: + workflow_handle = await client.start_workflow( + OutputGuardrailWorkflow.run, + id=f"output-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + assert not result From e5489e49fe8cd2d203f4755934a4d0eee3b1f8c7 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 27 Jun 2025 12:06:31 -0700 Subject: [PATCH 2/3] Fix lint --- tests/contrib/test_openai.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/contrib/test_openai.py b/tests/contrib/test_openai.py index e7cf6c93c..40152f811 100644 --- a/tests/contrib/test_openai.py +++ b/tests/contrib/test_openai.py @@ -30,7 +30,9 @@ from agents import ( Agent, AgentOutputSchemaBase, + GuardrailFunctionOutput, Handoff, + InputGuardrailTripwireTriggered, ItemHelpers, MessageOutputItem, Model, @@ -39,6 +41,7 @@ ModelSettings, ModelTracing, OpenAIResponsesModel, + OutputGuardrailTripwireTriggered, RunContextWrapper, Runner, Tool, @@ -46,12 +49,9 @@ Usage, function_tool, handoff, - trace, input_guardrail, - GuardrailFunctionOutput, - InputGuardrailTripwireTriggered, output_guardrail, - OutputGuardrailTripwireTriggered, + trace, ) from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX from agents.items import ( @@ -1285,7 +1285,7 @@ class MathHomeworkOutput(BaseModel): model_config = ConfigDict(extra="forbid") -guardrail_agent = Agent( +guardrail_agent: Agent = Agent( name="Guardrail check", instructions="Check if the user is asking you to do their math homework.", output_type=MathHomeworkOutput, From dd2b0467bc8df3e5a936211a826b39a33b047383 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 27 Jun 2025 12:17:24 -0700 Subject: [PATCH 3/3] Remove | syntax --- tests/contrib/test_openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/contrib/test_openai.py b/tests/contrib/test_openai.py index 40152f811..4e1452ce5 100644 --- a/tests/contrib/test_openai.py +++ b/tests/contrib/test_openai.py @@ -1296,7 +1296,7 @@ class MathHomeworkOutput(BaseModel): async def math_guardrail( context: RunContextWrapper[None], agent: Agent, - input: str | list[TResponseInputItem], + input: Union[str, list[TResponseInputItem]], ) -> GuardrailFunctionOutput: """This is an input guardrail function, which happens to call an agent to check if the input is a math homework question. @@ -1415,7 +1415,7 @@ class MessageOutput(BaseModel): description="Thoughts on how to respond to the user's message" ) response: str = Field(description="The response to the user's message") - user_name: str | None = Field( + user_name: Optional[str] = Field( description="The name of the user who sent the message, if known" ) model_config = ConfigDict(extra="forbid")