diff --git a/tests/contrib/test_openai.py b/tests/contrib/test_openai.py index 033557d6..323b0589 100644 --- a/tests/contrib/test_openai.py +++ b/tests/contrib/test_openai.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union, no_type_check import pytest +from pydantic import ConfigDict, Field from temporalio import activity, workflow from temporalio.client import Client, WorkflowFailureError, WorkflowHandle @@ -29,7 +30,9 @@ from agents import ( Agent, AgentOutputSchemaBase, + GuardrailFunctionOutput, Handoff, + InputGuardrailTripwireTriggered, ItemHelpers, MessageOutputItem, Model, @@ -38,6 +41,7 @@ ModelSettings, ModelTracing, OpenAIResponsesModel, + OutputGuardrailTripwireTriggered, RunContextWrapper, Runner, Tool, @@ -45,6 +49,8 @@ Usage, function_tool, handoff, + input_guardrail, + output_guardrail, trace, ) from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX @@ -1151,3 +1157,338 @@ async def test_customer_service_workflow(client: Client): .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) + + +guardrail_response_index: int = 0 + + +class InputGuardrailModel(OpenAIResponsesModel): + __test__ = False + responses: list[ModelResponse] = [ + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text="The capital of California is Sacramento.", + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text="x=3", + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ] + guardrail_responses = [ + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text='{"is_math_homework":false,"reasoning":"The question asked is about the capital of California, which is a geography-related query, not math."}', + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text='{"is_math_homework":true,"reasoning":"The question involves solving an equation for a variable, which is a typical math homework problem."}', + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ), + ] + + def __init__( + self, + model: str, + openai_client: AsyncOpenAI, + ) -> None: + global response_index + response_index = 0 + global guardrail_response_index + guardrail_response_index = 0 + super().__init__(model, openai_client) + + async def get_response( + self, + system_instructions: Union[str, None], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Union[AgentOutputSchemaBase, None], + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: Union[str, None], + prompt: Union[ResponsePromptParam, None] = None, + ) -> ModelResponse: + if ( + system_instructions + == "Check if the user is asking you to do their math homework." + ): + global guardrail_response_index + response = self.guardrail_responses[guardrail_response_index] + guardrail_response_index += 1 + return response + else: + global response_index + response = self.responses[response_index] + response_index += 1 + return response + + +### 1. An agent-based guardrail that is triggered if the user is asking to do math homework +class MathHomeworkOutput(BaseModel): + reasoning: str + is_math_homework: bool + model_config = ConfigDict(extra="forbid") + + +guardrail_agent: Agent = Agent( + name="Guardrail check", + instructions="Check if the user is asking you to do their math homework.", + output_type=MathHomeworkOutput, +) + + +@input_guardrail +async def math_guardrail( + context: RunContextWrapper[None], + agent: Agent, + input: Union[str, list[TResponseInputItem]], +) -> GuardrailFunctionOutput: + """This is an input guardrail function, which happens to call an agent to check if the input + is a math homework question. + """ + result = await Runner.run(guardrail_agent, input, context=context.context) + final_output = result.final_output_as(MathHomeworkOutput) + + return GuardrailFunctionOutput( + output_info=final_output, + tripwire_triggered=final_output.is_math_homework, + ) + + +@workflow.defn +class InputGuardrailWorkflow: + @workflow.run + async def run(self, messages: list[str]) -> list[str]: + agent = Agent( + name="Customer support agent", + instructions="You are a customer support agent. You help customers with their questions.", + input_guardrails=[math_guardrail], + ) + + input_data: list[TResponseInputItem] = [] + results: list[str] = [] + + for user_input in messages: + input_data.append( + { + "role": "user", + "content": user_input, + } + ) + + try: + result = await Runner.run(agent, input_data) + results.append(result.final_output) + # If the guardrail didn't trigger, we use the result as the input for the next run + input_data = result.to_input_list() + except InputGuardrailTripwireTriggered: + # If the guardrail triggered, we instead add a refusal message to the input + message = "Sorry, I can't help you with your math homework." + results.append(message) + input_data.append( + { + "role": "assistant", + "content": message, + } + ) + return results + + +async def test_input_guardrail(client: Client): + new_config = client.config() + new_config["data_converter"] = open_ai_data_converter + client = Client(**new_config) + + model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10)) + with set_open_ai_agent_temporal_overrides(model_params): + model_activity = ModelActivity( + TestProvider( + InputGuardrailModel( # type: ignore + "", openai_client=AsyncOpenAI(api_key="Fake key") + ) + ) + ) + async with new_worker( + client, + InputGuardrailWorkflow, + activities=[model_activity.invoke_model_activity], + interceptors=[OpenAIAgentsTracingInterceptor()], + ) as worker: + workflow_handle = await client.start_workflow( + InputGuardrailWorkflow.run, + [ + "What's the capital of California?", + "Can you help me solve for x: 2x + 5 = 11", + ], + id=f"input-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + assert len(result) == 2 + assert result[0] == "The capital of California is Sacramento." + assert result[1] == "Sorry, I can't help you with your math homework." + + +class OutputGuardrailModel(TestModel): + responses = [ + ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text='{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}', + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ) + ] + + +# The agent's output type +class MessageOutput(BaseModel): + reasoning: str = Field( + description="Thoughts on how to respond to the user's message" + ) + response: str = Field(description="The response to the user's message") + user_name: Optional[str] = Field( + description="The name of the user who sent the message, if known" + ) + model_config = ConfigDict(extra="forbid") + + +@output_guardrail +async def sensitive_data_check( + context: RunContextWrapper, agent: Agent, output: MessageOutput +) -> GuardrailFunctionOutput: + phone_number_in_response = "650" in output.response + phone_number_in_reasoning = "650" in output.reasoning + + return GuardrailFunctionOutput( + output_info={ + "phone_number_in_response": phone_number_in_response, + "phone_number_in_reasoning": phone_number_in_reasoning, + }, + tripwire_triggered=phone_number_in_response or phone_number_in_reasoning, + ) + + +output_guardrail_agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + output_type=MessageOutput, + output_guardrails=[sensitive_data_check], +) + + +@workflow.defn +class OutputGuardrailWorkflow: + @workflow.run + async def run(self) -> bool: + try: + await Runner.run( + output_guardrail_agent, + "My phone number is 650-123-4567. Where do you think I live?", + ) + return True + except OutputGuardrailTripwireTriggered: + return False + + +async def test_output_guardrail(client: Client): + new_config = client.config() + new_config["data_converter"] = open_ai_data_converter + client = Client(**new_config) + + model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10)) + with set_open_ai_agent_temporal_overrides(model_params): + model_activity = ModelActivity( + TestProvider( + OutputGuardrailModel( # type: ignore + "", openai_client=AsyncOpenAI(api_key="Fake key") + ) + ) + ) + async with new_worker( + client, + OutputGuardrailWorkflow, + activities=[model_activity.invoke_model_activity], + interceptors=[OpenAIAgentsTracingInterceptor()], + ) as worker: + workflow_handle = await client.start_workflow( + OutputGuardrailWorkflow.run, + id=f"output-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + assert not result