diff --git a/README.md b/README.md index 5eacd1dec..03ab3c0ce 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ informal introduction to the features and their implementation. - [Heartbeating and Cancellation](#heartbeating-and-cancellation) - [Worker Shutdown](#worker-shutdown) - [Testing](#testing-1) + - [Nexus](#nexus) - [Workflow Replay](#workflow-replay) - [Observability](#observability) - [Metrics](#metrics) @@ -1308,6 +1309,7 @@ affect calls activity code might make to functions on the `temporalio.activity` * `cancel()` can be invoked to simulate a cancellation of the activity * `worker_shutdown()` can be invoked to simulate a worker shutdown during execution of the activity + ### Workflow Replay Given a workflow's history, it can be replayed locally to check for things like non-determinism errors. For example, diff --git a/pyproject.toml b/pyproject.toml index 391ea0abc..072ce19c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ keywords = [ "workflow", ] dependencies = [ + "nexus-rpc", "protobuf>=3.20,<6", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", @@ -44,7 +45,7 @@ dev = [ "psutil>=5.9.3,<6", "pydocstyle>=6.3.0,<7", "pydoctor>=24.11.1,<25", - "pyright==1.1.377", + "pyright==1.1.400", "pytest~=7.4", "pytest-asyncio>=0.21,<0.22", "pytest-timeout~=2.2", @@ -53,6 +54,8 @@ dev = [ "twine>=4.0.1,<5", "ruff>=0.5.0,<0.6", "maturin>=1.8.2", + "pytest-cov>=6.1.1", + "httpx>=0.28.1", "pytest-pretty>=1.3.0", ] @@ -162,6 +165,7 @@ exclude = [ "tests/worker/workflow_sandbox/testmodules/proto", "temporalio/bridge/worker.py", "temporalio/contrib/opentelemetry.py", + "temporalio/contrib/pydantic.py", "temporalio/converter.py", "temporalio/testing/_workflow.py", "temporalio/worker/_activity.py", @@ -173,6 +177,10 @@ exclude = [ "tests/api/test_grpc_stub.py", "tests/conftest.py", "tests/contrib/test_opentelemetry.py", + "tests/contrib/pydantic/models.py", + "tests/contrib/pydantic/models_2.py", + "tests/contrib/pydantic/test_pydantic.py", + "tests/contrib/pydantic/workflows.py", "tests/test_converter.py", "tests/test_service.py", "tests/test_workflow.py", @@ -208,3 +216,6 @@ exclude = [ [tool.uv] # Prevent uv commands from building the package by default package = false + +[tool.uv.sources] +nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python" } diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 9dfca82c9..130389259 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -20,7 +20,7 @@ use temporal_sdk_core_api::worker::{ }; use temporal_sdk_core_api::Worker; use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion; -use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion}; +use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion, nexus::NexusTaskCompletion}; use temporal_sdk_core_protos::temporal::api::history::v1::History; use tokio::sync::mpsc::{channel, Sender}; use tokio_stream::wrappers::ReceiverStream; @@ -60,6 +60,7 @@ pub struct WorkerConfig { graceful_shutdown_period_millis: u64, nondeterminism_as_workflow_fail: bool, nondeterminism_as_workflow_fail_for_types: HashSet, + nexus_task_poller_behavior: PollerBehavior, } #[derive(FromPyObject)] @@ -565,6 +566,18 @@ impl WorkerRef { }) } + fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult> { + let worker = self.worker.as_ref().unwrap().clone(); + self.runtime.future_into_py(py, async move { + let bytes = match 
worker.poll_nexus_task().await { + Ok(task) => task.encode_to_vec(), + Err(PollError::ShutDown) => return Err(PollShutdownError::new_err(())), + Err(err) => return Err(PyRuntimeError::new_err(format!("Poll failure: {}", err))), + }; + Ok(bytes) + }) + } + fn complete_workflow_activation<'p>( &self, py: Python<'p>, @@ -599,6 +612,22 @@ impl WorkerRef { }) } + fn complete_nexus_task<'p>(&self, + py: Python<'p>, + proto: &Bound<'_, PyBytes>, +) -> PyResult> { + let worker = self.worker.as_ref().unwrap().clone(); + let completion = NexusTaskCompletion::decode(proto.as_bytes()) + .map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?; + self.runtime.future_into_py(py, async move { + worker + .complete_nexus_task(completion) + .await + .context("Completion failure") + .map_err(Into::into) + }) + } + fn record_activity_heartbeat(&self, proto: &Bound<'_, PyBytes>) -> PyResult<()> { enter_sync!(self.runtime); let heartbeat = ActivityHeartbeat::decode(proto.as_bytes()) @@ -696,6 +725,7 @@ fn convert_worker_config( }) .collect::>>(), ) + .nexus_task_poller_behavior(conf.nexus_task_poller_behavior) .build() .map_err(|err| PyValueError::new_err(format!("Invalid worker config: {err}"))) } diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 74cf55bfd..e97563bf1 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -26,6 +26,7 @@ import temporalio.bridge.client import temporalio.bridge.proto import temporalio.bridge.proto.activity_task +import temporalio.bridge.proto.nexus import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_completion import temporalio.bridge.runtime @@ -35,7 +36,7 @@ from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) -from temporalio.bridge.temporal_sdk_bridge import PollShutdownError +from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore @dataclass @@ -60,6 +61,7 @@ class WorkerConfig: graceful_shutdown_period_millis: int nondeterminism_as_workflow_fail: bool nondeterminism_as_workflow_fail_for_types: Set[str] + nexus_task_poller_behavior: PollerBehavior @dataclass @@ -216,6 +218,14 @@ async def poll_activity_task( await self._ref.poll_activity_task() ) + async def poll_nexus_task( + self, + ) -> temporalio.bridge.proto.nexus.NexusTask: + """Poll for a nexus task.""" + return temporalio.bridge.proto.nexus.NexusTask.FromString( + await self._ref.poll_nexus_task() + ) + async def complete_workflow_activation( self, comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, @@ -229,6 +239,12 @@ async def complete_activity_task( """Complete an activity task.""" await self._ref.complete_activity_task(comp.SerializeToString()) + async def complete_nexus_task( + self, comp: temporalio.bridge.proto.nexus.NexusTaskCompletion + ) -> None: + """Complete a nexus task.""" + await self._ref.complete_nexus_task(comp.SerializeToString()) + def record_activity_heartbeat( self, comp: temporalio.bridge.proto.ActivityHeartbeat ) -> None: diff --git a/temporalio/client.py b/temporalio/client.py index f46297eb9..2d091626a 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -464,9 +464,17 @@ async def start_workflow( rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, request_eager_start: bool = False, - stack_level: int = 2, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] 
= None, + # The following options should not be considered part of the public API. They + # are deliberately not exposed in overloads, and are not subject to any + # backwards compatibility guarantees. + nexus_completion_callbacks: Sequence[NexusCompletionCallback] = [], + workflow_event_links: Sequence[ + temporalio.api.common.v1.Link.WorkflowEvent + ] = [], + request_id: Optional[str] = None, + stack_level: int = 2, ) -> WorkflowHandle[Any, Any]: """Start a workflow and return its handle. @@ -529,7 +537,6 @@ async def start_workflow( name, result_type_from_type_hint = ( temporalio.workflow._Definition.get_name_and_result_type(workflow) ) - return await self._impl.start_workflow( StartWorkflowInput( workflow=name, @@ -557,6 +564,9 @@ async def start_workflow( rpc_timeout=rpc_timeout, request_eager_start=request_eager_start, priority=priority, + nexus_completion_callbacks=nexus_completion_callbacks, + workflow_event_links=workflow_event_links, + request_id=request_id, ) ) @@ -5193,6 +5203,10 @@ class StartWorkflowInput: rpc_timeout: Optional[timedelta] request_eager_start: bool priority: temporalio.common.Priority + # The following options are experimental and unstable. + nexus_completion_callbacks: Sequence[NexusCompletionCallback] + workflow_event_links: Sequence[temporalio.api.common.v1.Link.WorkflowEvent] + request_id: Optional[str] versioning_override: Optional[temporalio.common.VersioningOverride] = None @@ -5807,8 +5821,26 @@ async def _build_start_workflow_execution_request( self, input: StartWorkflowInput ) -> temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest: req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() - req.request_eager_execution = input.request_eager_start await self._populate_start_workflow_execution_request(req, input) + # _populate_start_workflow_execution_request is used for both StartWorkflowInput + # and UpdateWithStartStartWorkflowInput. UpdateWithStartStartWorkflowInput does + # not have the following two fields so they are handled here. + req.request_eager_execution = input.request_eager_start + if input.request_id: + req.request_id = input.request_id + + req.completion_callbacks.extend( + temporalio.api.common.v1.Callback( + nexus=temporalio.api.common.v1.Callback.Nexus( + url=callback.url, header=callback.header + ) + ) + for callback in input.nexus_completion_callbacks + ) + req.links.extend( + temporalio.api.common.v1.Link(workflow_event=link) + for link in input.workflow_event_links + ) return req async def _build_signal_with_start_workflow_execution_request( @@ -7231,6 +7263,21 @@ def api_key(self, value: Optional[str]) -> None: self.service_client.update_api_key(value) +@dataclass(frozen=True) +class NexusCompletionCallback: + """Nexus callback to attach to events such as workflow completion. + + .. warning:: + This option is experimental and unstable. 
+ """ + + url: str + """Callback URL.""" + + header: Mapping[str, str] + """Header to attach to callback request.""" + + async def _encode_user_metadata( converter: temporalio.converter.DataConverter, summary: Optional[Union[str, temporalio.api.common.v1.Payload]], diff --git a/temporalio/common.py b/temporalio/common.py index 3349f70e9..b9b088e86 100644 --- a/temporalio/common.py +++ b/temporalio/common.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta -from enum import Enum, IntEnum +from enum import IntEnum from typing import ( Any, Callable, diff --git a/temporalio/converter.py b/temporalio/converter.py index 6a6d0e12b..43dbe305b 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -16,6 +16,7 @@ from datetime import datetime from enum import IntEnum from itertools import zip_longest +from logging import getLogger from typing import ( Any, Awaitable, @@ -40,6 +41,7 @@ import google.protobuf.json_format import google.protobuf.message import google.protobuf.symbol_database +import nexusrpc import typing_extensions import temporalio.api.common.v1 @@ -60,6 +62,8 @@ if sys.version_info >= (3, 10): from types import UnionType +logger = getLogger(__name__) + class PayloadConverter(ABC): """Base payload converter to/from multiple payloads/values.""" @@ -911,6 +915,12 @@ def _error_to_failure( failure.child_workflow_execution_failure_info.retry_state = ( temporalio.api.enums.v1.RetryState.ValueType(error.retry_state or 0) ) + # TODO(nexus-prerelease): test coverage for this + elif isinstance(error, temporalio.exceptions.NexusOperationError): + failure.nexus_operation_execution_failure_info.SetInParent() + failure.nexus_operation_execution_failure_info.operation_token = ( + error.operation_token + ) def from_failure( self, @@ -1006,6 +1016,33 @@ def from_failure( if child_info.retry_state else None, ) + elif failure.HasField("nexus_handler_failure_info"): + nexus_handler_failure_info = failure.nexus_handler_failure_info + try: + _type = nexusrpc.HandlerErrorType[nexus_handler_failure_info.type] + except KeyError: + logger.warning( + f"Unknown Nexus HandlerErrorType: {nexus_handler_failure_info.type}" + ) + _type = nexusrpc.HandlerErrorType.INTERNAL + return nexusrpc.HandlerError( + failure.message or "Nexus handler error", + type=_type, + retryable={ + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: True, + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: False, + }.get(nexus_handler_failure_info.retry_behavior), + ) + elif failure.HasField("nexus_operation_execution_failure_info"): + nexus_op_failure_info = failure.nexus_operation_execution_failure_info + err = temporalio.exceptions.NexusOperationError( + failure.message or "Nexus operation error", + scheduled_event_id=nexus_op_failure_info.scheduled_event_id, + endpoint=nexus_op_failure_info.endpoint, + service=nexus_op_failure_info.service, + operation=nexus_op_failure_info.operation, + operation_token=nexus_op_failure_info.operation_token, + ) else: err = temporalio.exceptions.FailureError(failure.message or "Failure error") err._failure = failure diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f045b36a0..0a1cd9a1d 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -362,6 +362,53 @@ def retry_state(self) -> Optional[RetryState]: return self._retry_state +class NexusOperationError(FailureError): + 
"""Error raised on Nexus operation failure.""" + + def __init__( + self, + message: str, + *, + scheduled_event_id: int, + endpoint: str, + service: str, + operation: str, + operation_token: str, + ): + """Initialize a Nexus operation error.""" + super().__init__(message) + self._scheduled_event_id = scheduled_event_id + self._endpoint = endpoint + self._service = service + self._operation = operation + self._operation_token = operation_token + + @property + def scheduled_event_id(self) -> int: + """The NexusOperationScheduled event ID for the failed operation.""" + return self._scheduled_event_id + + @property + def endpoint(self) -> str: + """The endpoint name for the failed operation.""" + return self._endpoint + + @property + def service(self) -> str: + """The service name for the failed operation.""" + return self._service + + @property + def operation(self) -> str: + """The name of the failed operation.""" + return self._operation + + @property + def operation_token(self) -> str: + """The operation token returned by the failed operation.""" + return self._operation_token + + def is_cancelled_exception(exception: BaseException) -> bool: """Check whether the given exception is considered a cancellation exception according to Temporal. diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py new file mode 100644 index 000000000..217e37565 --- /dev/null +++ b/temporalio/nexus/__init__.py @@ -0,0 +1,17 @@ +from ._decorators import workflow_run_operation as workflow_run_operation +from ._operation_context import Info as Info +from ._operation_context import LoggerAdapter as LoggerAdapter +from ._operation_context import ( + WorkflowRunOperationContext as WorkflowRunOperationContext, +) +from ._operation_context import ( + _TemporalCancelOperationContext as _TemporalCancelOperationContext, +) +from ._operation_context import ( + _TemporalStartOperationContext as _TemporalStartOperationContext, +) +from ._operation_context import client as client +from ._operation_context import in_operation as in_operation +from ._operation_context import info as info +from ._operation_context import logger as logger +from ._token import WorkflowHandle as WorkflowHandle diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py new file mode 100644 index 000000000..b1a30f93c --- /dev/null +++ b/temporalio/nexus/_decorators.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from typing import ( + Awaitable, + Callable, + Optional, + TypeVar, + Union, + overload, +) + +import nexusrpc +from nexusrpc import InputT, OutputT +from nexusrpc.handler import ( + OperationHandler, + StartOperationContext, +) + +from temporalio.nexus._operation_context import ( + WorkflowRunOperationContext, +) +from temporalio.nexus._operation_handlers import ( + WorkflowRunOperationHandler, +) +from temporalio.nexus._token import ( + WorkflowHandle, +) +from temporalio.nexus._util import ( + get_callable_name, + get_workflow_run_start_method_input_and_output_type_annotations, + set_operation_factory, +) + +ServiceHandlerT = TypeVar("ServiceHandlerT") + + +@overload +def workflow_run_operation( + start: Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], +) -> Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], +]: ... 
+ + +@overload +def workflow_run_operation( + *, + name: Optional[str] = None, +) -> Callable[ + [ + Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ] + ], + Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], +]: ... + + +def workflow_run_operation( + start: Optional[ + Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ] + ] = None, + *, + name: Optional[str] = None, +) -> Union[ + Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], + Callable[ + [ + Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ] + ], + Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], + ], +]: + """ + Decorator marking a method as the start method for a workflow-backed operation. + """ + + def decorator( + start: Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], + ) -> Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ]: + ( + input_type, + output_type, + ) = get_workflow_run_start_method_input_and_output_type_annotations(start) + + def operation_handler_factory( + self: ServiceHandlerT, + ) -> OperationHandler[InputT, OutputT]: + async def _start( + ctx: StartOperationContext, input: InputT + ) -> WorkflowHandle[OutputT]: + return await start( + self, + WorkflowRunOperationContext.from_start_operation_context(ctx), + input, + ) + + _start.__doc__ = start.__doc__ + return WorkflowRunOperationHandler(_start, input_type, output_type) + + method_name = get_callable_name(start) + nexusrpc.set_operation_definition( + operation_handler_factory, + nexusrpc.Operation( + name=name or method_name, + method_name=method_name, + input_type=input_type, + output_type=output_type, + ), + ) + + set_operation_factory(start, operation_handler_factory) + return start + + if start is None: + return decorator + + return decorator(start) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py new file mode 100644 index 000000000..47425ebbb --- /dev/null +++ b/temporalio/nexus/_operation_context.py @@ -0,0 +1,583 @@ +from __future__ import annotations + +import dataclasses +import logging +import re +import urllib.parse +from contextvars import ContextVar +from dataclasses import dataclass +from datetime import timedelta +from typing import ( + Any, + Awaitable, + Callable, + Mapping, + MutableMapping, + Optional, + Sequence, + Type, + Union, + overload, +) + +import nexusrpc.handler +from nexusrpc.handler import CancelOperationContext, StartOperationContext +from typing_extensions import Concatenate + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.client +import temporalio.common +from temporalio.nexus._token import WorkflowHandle +from temporalio.types import ( + MethodAsyncNoParam, + MethodAsyncSingleParam, + MultiParamSpec, + ParamType, + ReturnType, + SelfType, +) + +# The Temporal Nexus worker always builds a nexusrpc StartOperationContext or +# CancelOperationContext and passes it as the first parameter to the nexusrpc operation +# handler. In addition, it sets one of the following context vars. 
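As a sketch of how the decorator and the context machinery above fit together: a service handler method decorated with `@workflow_run_operation` receives a `WorkflowRunOperationContext` and starts a workflow whose result the server delivers back to the Nexus caller. Everything in the block below other than the SDK imports is hypothetical (the `GreetingWorkflow`/`GreetingServiceHandler` names and the workflow ID scheme), and the `@nexusrpc.handler.service_handler` class decorator is assumed from nexusrpc conventions rather than defined in this diff.

```python
import nexusrpc.handler

from temporalio import nexus, workflow
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation


@workflow.defn
class GreetingWorkflow:
    @workflow.run
    async def run(self, name: str) -> str:
        return f"Hello, {name}!"


@nexusrpc.handler.service_handler  # assumed nexusrpc decorator; not part of this diff
class GreetingServiceHandler:
    @workflow_run_operation
    async def greet(
        self, ctx: WorkflowRunOperationContext, name: str
    ) -> nexus.WorkflowHandle[str]:
        # The Temporal Nexus worker has set the start-operation context var, so
        # nexus.info(), nexus.client(), and nexus.logger are all usable here.
        nexus.logger.info("starting greeting workflow")
        # start_workflow uses the worker's namespace and task queue and attaches the
        # caller's completion callback, request ID, and event links automatically.
        return await ctx.start_workflow(
            GreetingWorkflow.run,
            name,
            id=f"greet-{ctx.temporal_context.nexus_context.request_id}",  # hypothetical ID scheme
        )
```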
+ +_temporal_start_operation_context: ContextVar[_TemporalStartOperationContext] = ( + ContextVar("temporal-start-operation-context") +) + +_temporal_cancel_operation_context: ContextVar[_TemporalCancelOperationContext] = ( + ContextVar("temporal-cancel-operation-context") +) + + +@dataclass(frozen=True) +class Info: + """Information about the running Nexus operation. + + Retrieved inside a Nexus operation handler via :py:func:`info`. + """ + + task_queue: str + """The task queue of the worker handling this Nexus operation.""" + + +def in_operation() -> bool: + """ + Whether the current code is inside a Nexus operation. + """ + return _try_temporal_context() is not None + + +def info() -> Info: + """ + Get the current Nexus operation information. + """ + return _temporal_context().info() + + +def client() -> temporalio.client.Client: + """ + Get the Temporal client used by the worker handling the current Nexus operation. + """ + return _temporal_context().client + + +def _temporal_context() -> ( + Union[_TemporalStartOperationContext, _TemporalCancelOperationContext] +): + ctx = _try_temporal_context() + if ctx is None: + raise RuntimeError("Not in Nexus operation context.") + return ctx + + +def _try_temporal_context() -> ( + Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]] +): + start_ctx = _temporal_start_operation_context.get(None) + cancel_ctx = _temporal_cancel_operation_context.get(None) + if start_ctx and cancel_ctx: + raise RuntimeError("Cannot be in both start and cancel operation contexts.") + return start_ctx or cancel_ctx + + +@dataclass +class _TemporalStartOperationContext: + """ + Context for a Nexus start operation being handled by a Temporal Nexus Worker. + """ + + nexus_context: StartOperationContext + """Nexus-specific start operation context.""" + + info: Callable[[], Info] + """Temporal information about the running Nexus operation.""" + + client: temporalio.client.Client + """The Temporal client in use by the worker handling this Nexus operation.""" + + @classmethod + def get(cls) -> _TemporalStartOperationContext: + ctx = _temporal_start_operation_context.get(None) + if ctx is None: + raise RuntimeError("Not in Nexus operation context.") + return ctx + + def set(self) -> None: + _temporal_start_operation_context.set(self) + + def _get_completion_callbacks( + self, + ) -> list[temporalio.client.NexusCompletionCallback]: + ctx = self.nexus_context + return ( + [ + # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus + # request, it needs to copy the links to the callback in + # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links + # (for backwards compatibility). 
PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1945 + temporalio.client.NexusCompletionCallback( + url=ctx.callback_url, + header=ctx.callback_headers, + ) + ] + if ctx.callback_url + else [] + ) + + def _get_workflow_event_links( + self, + ) -> list[temporalio.api.common.v1.Link.WorkflowEvent]: + event_links = [] + for inbound_link in self.nexus_context.inbound_links: + if link := _nexus_link_to_workflow_event(inbound_link): + event_links.append(link) + return event_links + + def _add_outbound_links( + self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] + ): + try: + link = _workflow_event_to_nexus_link( + _workflow_handle_to_workflow_execution_started_event_link( + workflow_handle + ) + ) + except Exception as e: + logger.warning( + f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}" + ) + else: + self.nexus_context.outbound_links.append( + # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference + # link to send back to the caller. Now, it checks if the server returned + # the link in the StartWorkflowExecutionResponse, and if so, send the link + # from the response to the caller. Fallback to generating the link for + # backwards compatibility. PR reference in Go SDK: + # https://github.com/temporalio/sdk-go/pull/1934 + link + ) + return workflow_handle + + +@dataclass(frozen=True) +class WorkflowRunOperationContext(StartOperationContext): + _temporal_context: Optional[_TemporalStartOperationContext] = None + + @property + def temporal_context(self) -> _TemporalStartOperationContext: + if not self._temporal_context: + raise RuntimeError("Temporal context not set") + return self._temporal_context + + @classmethod + def from_start_operation_context( + cls, ctx: StartOperationContext + ) -> WorkflowRunOperationContext: + return cls( + _temporal_context=_TemporalStartOperationContext.get(), + **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, + ) + + # Overload for no-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... 
+ + # Overload for single-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... + + # Overload for multi-param workflow + @overload + async def start_workflow( + self, + workflow: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: Optional[str] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... 
+ + # Overload for string-name workflow + @overload + async def start_workflow( + self, + workflow: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: Optional[str] = None, + result_type: Optional[Type[ReturnType]] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: ... + + async def start_workflow( + self, + workflow: Union[str, Callable[..., Awaitable[ReturnType]]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: Optional[str] = None, + result_type: Optional[Type] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + start_signal: Optional[str] = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: Optional[temporalio.common.VersioningOverride] = None, + ) -> WorkflowHandle[ReturnType]: + """Start a workflow that will deliver the result of the Nexus operation. + + The workflow will be started in the same namespace as the Nexus worker, using + the same client as the worker. If task queue is not specified, the worker's task + queue will be used. + + See :py:meth:`temporalio.client.Client.start_workflow` for all arguments. + + The return value is :py:class:`temporalio.nexus.WorkflowHandle`. + + The workflow will be started as usual, with the following modifications: + + - On workflow completion, Temporal server will deliver the workflow result to + the Nexus operation caller, using the callback from the Nexus operation start + request. 
+ + - The request ID from the Nexus operation start request will be used as the + request ID for the start workflow request. + + - Inbound links to the caller that were submitted in the Nexus start operation + request will be attached to the started workflow and, outbound links to the + started workflow will be added to the Nexus start operation response. If the + Nexus caller is itself a workflow, this means that the workflow in the caller + namespace web UI will contain links to the started workflow, and vice versa. + """ + # TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this: + # if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') { + # internalOptions.onConflictOptions = { + # attachLinks: true, + # attachCompletionCallbacks: true, + # attachRequestId: true, + # }; + # } + + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + wf_handle = await self.temporal_context.client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue or self.temporal_context.info().task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + nexus_completion_callbacks=self.temporal_context._get_completion_callbacks(), + workflow_event_links=self.temporal_context._get_workflow_event_links(), + request_id=self.temporal_context.nexus_context.request_id, + ) + + self.temporal_context._add_outbound_links(wf_handle) + + return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) + + +@dataclass +class _TemporalCancelOperationContext: + """ + Context for a Nexus cancel operation being handled by a Temporal Nexus Worker. + """ + + nexus_context: CancelOperationContext + """Nexus-specific cancel operation context.""" + + info: Callable[[], Info] + """Temporal information about the running Nexus cancel operation.""" + + client: temporalio.client.Client + """The Temporal client in use by the worker handling the current Nexus operation.""" + + @classmethod + def get(cls) -> _TemporalCancelOperationContext: + ctx = _temporal_cancel_operation_context.get(None) + if ctx is None: + raise RuntimeError("Not in Nexus cancel operation context.") + return ctx + + def set(self) -> None: + _temporal_cancel_operation_context.set(self) + + +def _workflow_handle_to_workflow_execution_started_event_link( + handle: temporalio.client.WorkflowHandle[Any, Any], +) -> temporalio.api.common.v1.Link.WorkflowEvent: + if handle.first_execution_run_id is None: + raise ValueError( + f"Workflow handle {handle} has no first execution run ID. " + "Cannot create WorkflowExecutionStarted event link." 
+ ) + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=handle._client.namespace, + workflow_id=handle.id, + run_id=handle.first_execution_run_id, + event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + # TODO(nexus-prerelease): confirm that it is correct not to use event_id. + # Should the proto say explicitly that it's optional or how it behaves when it's missing? + event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ), + # TODO(nexus-prerelease): RequestIdReference? + ) + + +def _workflow_event_to_nexus_link( + workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, +) -> nexusrpc.Link: + scheme = "temporal" + namespace = urllib.parse.quote(workflow_event.namespace) + workflow_id = urllib.parse.quote(workflow_event.workflow_id) + run_id = urllib.parse.quote(workflow_event.run_id) + path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history" + query_params = urllib.parse.urlencode( + { + "eventType": temporalio.api.enums.v1.EventType.Name( + workflow_event.event_ref.event_type + ), + "referenceType": "EventReference", + } + ) + return nexusrpc.Link( + url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")), + type=workflow_event.DESCRIPTOR.full_name, + ) + + +_LINK_URL_PATH_REGEX = re.compile( + r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" +) + + +def _nexus_link_to_workflow_event( + link: nexusrpc.Link, +) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]: + url = urllib.parse.urlparse(link.url) + match = _LINK_URL_PATH_REGEX.match(url.path) + if not match: + logger.warning( + f"Invalid Nexus link: {link}. Expected path to match {_LINK_URL_PATH_REGEX.pattern}" + ) + return None + try: + query_params = urllib.parse.parse_qs(url.query) + [reference_type] = query_params.get("referenceType", []) + if reference_type != "EventReference": + raise ValueError( + f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}" + ) + [event_type_name] = query_params.get("eventType", []) + event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + # TODO(nexus-prerelease): confirm that it is correct not to use event_id. + # Should the proto say explicitly that it's optional or how it behaves when it's missing? 
+ event_type=temporalio.api.enums.v1.EventType.Value(event_type_name) + ) + except ValueError as err: + logger.warning( + f"Failed to parse event type from Nexus link URL query parameters: {link} ({err})" + ) + event_ref = None + + groups = match.groupdict() + return temporalio.api.common.v1.Link.WorkflowEvent( + namespace=urllib.parse.unquote(groups["namespace"]), + workflow_id=urllib.parse.unquote(groups["workflow_id"]), + run_id=urllib.parse.unquote(groups["run_id"]), + event_ref=event_ref, + ) + + +class LoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]): + super().__init__(logger, extra or {}) + + def process( + self, msg: Any, kwargs: MutableMapping[str, Any] + ) -> tuple[Any, MutableMapping[str, Any]]: + extra = dict(self.extra or {}) + if tctx := _try_temporal_context(): + extra["service"] = tctx.nexus_context.service + extra["operation"] = tctx.nexus_context.operation + extra["task_queue"] = tctx.info().task_queue + kwargs["extra"] = extra | kwargs.get("extra", {}) + return msg, kwargs + + +logger = LoggerAdapter(logging.getLogger("temporalio.nexus"), None) +"""Logger that emits additional data describing the current Nexus operation.""" diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py new file mode 100644 index 000000000..ecc286719 --- /dev/null +++ b/temporalio/nexus/_operation_handlers.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import ( + Any, + Awaitable, + Callable, + Optional, + Type, +) + +from nexusrpc import ( + HandlerError, + HandlerErrorType, + InputT, + OperationInfo, + OutputT, +) +from nexusrpc.handler import ( + CancelOperationContext, + FetchOperationInfoContext, + FetchOperationResultContext, + OperationHandler, + StartOperationContext, + StartOperationResultAsync, +) + +from temporalio import client +from temporalio.nexus._operation_context import ( + _temporal_cancel_operation_context, +) +from temporalio.nexus._token import WorkflowHandle + +from ._util import ( + is_async_callable, +) + + +class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]): + """ + Operation handler for Nexus operations that start a workflow. + + Use this class to create an operation handler that starts a workflow by passing your + ``start`` method to the constructor. Your ``start`` method must use + :py:func:`temporalio.nexus.WorkflowRunOperationContext.start_workflow` to start the + workflow. + """ + + def __init__( + self, + start: Callable[ + [StartOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], + input_type: Optional[Type[InputT]], + output_type: Optional[Type[OutputT]], + ) -> None: + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "WorkflowRunOperationHandler must be initialized with an " + "`async def` start method." + ) + self._start = start + if start.__doc__: + self.start.__func__.__doc__ = start.__doc__ + self._input_type = input_type + self._output_type = output_type + + async def start( + self, ctx: StartOperationContext, input: InputT + ) -> StartOperationResultAsync: + """ + Start the operation, by starting a workflow and completing asynchronously. + """ + handle = await self._start(ctx, input) + if not isinstance(handle, WorkflowHandle): + if isinstance(handle, client.WorkflowHandle): + raise RuntimeError( + f"Expected {handle} to be a nexus.WorkflowHandle, but got a client.WorkflowHandle. 
" + f"You must use WorkflowRunOperationContext.start_workflow " + "to start a workflow that will deliver the result of the Nexus operation, " + "not client.Client.start_workflow." + ) + raise RuntimeError( + f"Expected {handle} to be a nexus.WorkflowHandle, but got {type(handle)}. " + ) + return StartOperationResultAsync(handle.to_token()) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + """Cancel the operation, by cancelling the workflow.""" + await _cancel_workflow(token) + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: + raise NotImplementedError( + "Temporal Nexus operation handlers do not support fetching operation info." + ) + + async def fetch_result( + self, ctx: FetchOperationResultContext, token: str + ) -> OutputT: + raise NotImplementedError( + "Temporal Nexus operation handlers do not support fetching the operation result." + ) + # An implementation is provided for future reference: + # TODO: honor `wait` param and Request-Timeout header + # try: + # nexus_handle = WorkflowHandle[OutputT].from_token(token) + # except Exception as err: + # raise HandlerError( + # "Failed to decode operation token as workflow operation token. " + # "Fetching result for non-workflow operations is not supported.", + # type=HandlerErrorType.NOT_FOUND, + # ) from err + # ctx = _temporal_fetch_operation_context.get() + # try: + # client_handle = nexus_handle.to_workflow_handle( + # ctx.client, result_type=self._output_type + # ) + # except Exception as err: + # raise HandlerError( + # "Failed to construct workflow handle from workflow operation token", + # type=HandlerErrorType.NOT_FOUND, + # ) from err + # return await client_handle.result() + + +async def _cancel_workflow( + token: str, + **kwargs: Any, +) -> None: + """ + Cancel a workflow that is backing a Nexus operation. + + This function is used by the Nexus worker to cancel a workflow that is backing a + Nexus operation, i.e. started by a + :py:func:`temporalio.nexus.workflow_run_operation`-decorated method. + + Args: + token: The token of the workflow to cancel. kwargs: Additional keyword arguments + to pass to the workflow cancel method. + """ + try: + nexus_workflow_handle = WorkflowHandle[Any].from_token(token) + except Exception as err: + raise HandlerError( + "Failed to decode operation token as a workflow operation token. 
" + "Canceling non-workflow operations is not supported.", + type=HandlerErrorType.NOT_FOUND, + ) from err + + ctx = _temporal_cancel_operation_context.get() + try: + client_workflow_handle = nexus_workflow_handle._to_client_workflow_handle( + ctx.client + ) + except Exception as err: + raise HandlerError( + "Failed to construct workflow handle from workflow operation token", + type=HandlerErrorType.NOT_FOUND, + ) from err + await client_workflow_handle.cancel(**kwargs) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py new file mode 100644 index 000000000..480a404b1 --- /dev/null +++ b/temporalio/nexus/_token.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass +from typing import Any, Generic, Literal, Optional, Type + +from nexusrpc import OutputT + +from temporalio import client + +OperationTokenType = Literal[1] +OPERATION_TOKEN_TYPE_WORKFLOW: OperationTokenType = 1 + + +@dataclass(frozen=True) +class WorkflowHandle(Generic[OutputT]): + """A handle to a workflow that is backing a Nexus operation.""" + + namespace: str + workflow_id: str + # Version of the token. Treated as v1 if missing. This field is not included in the + # serialized token; it's only used to reject newer token versions on load. + version: Optional[int] = None + + def _to_client_workflow_handle( + self, client: client.Client, result_type: Optional[Type[OutputT]] = None + ) -> client.WorkflowHandle[Any, OutputT]: + """Create a :py:class:`temporalio.client.WorkflowHandle` from the token.""" + if client.namespace != self.namespace: + raise ValueError( + f"Client namespace {client.namespace} does not match " + f"operation token namespace {self.namespace}" + ) + return client.get_workflow_handle(self.workflow_id, result_type=result_type) + + # TODO(nexus-preview): The return type here should be dictated by the input workflow + # handle type. + @classmethod + def _unsafe_from_client_workflow_handle( + cls, workflow_handle: client.WorkflowHandle[Any, OutputT] + ) -> WorkflowHandle[OutputT]: + """Create a :py:class:`WorkflowHandle` from a :py:class:`temporalio.client.WorkflowHandle`. + + This is a private method not intended to be used by users. It does not check + that the supplied client.WorkflowHandle references a workflow that has been + instrumented to supply the result of a Nexus operation. 
+ """ + return cls( + namespace=workflow_handle._client.namespace, + workflow_id=workflow_handle.id, + ) + + def to_token(self) -> str: + return _base64url_encode_no_padding( + json.dumps( + { + "t": OPERATION_TOKEN_TYPE_WORKFLOW, + "ns": self.namespace, + "wid": self.workflow_id, + }, + separators=(",", ":"), + ).encode("utf-8") + ) + + @classmethod + def from_token(cls, token: str) -> WorkflowHandle[OutputT]: + """Decodes and validates a token from its base64url-encoded string representation.""" + if not token: + raise TypeError("invalid workflow token: token is empty") + try: + decoded_bytes = _base64url_decode_no_padding(token) + except Exception as err: + raise TypeError("failed to decode token as base64url") from err + try: + workflow_operation_token = json.loads(decoded_bytes.decode("utf-8")) + except Exception as err: + raise TypeError("failed to unmarshal workflow operation token") from err + + if not isinstance(workflow_operation_token, dict): + raise TypeError( + f"invalid workflow token: expected dict, got {type(workflow_operation_token)}" + ) + + token_type = workflow_operation_token.get("t") + if token_type != OPERATION_TOKEN_TYPE_WORKFLOW: + raise TypeError( + f"invalid workflow token type: {token_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" + ) + + version = workflow_operation_token.get("v") + if version is not None and version != 0: + raise TypeError( + "invalid workflow token: 'v' field, if present, must be 0 or null/absent" + ) + + workflow_id = workflow_operation_token.get("wid") + if not workflow_id or not isinstance(workflow_id, str): + raise TypeError( + "invalid workflow token: missing, empty, or non-string workflow ID (wid)" + ) + + namespace = workflow_operation_token.get("ns") + if namespace is None or not isinstance(namespace, str): + # Allow empty string for ns, but it must be present and a string + raise TypeError( + "invalid workflow token: missing or non-string namespace (ns)" + ) + + return cls( + namespace=namespace, + workflow_id=workflow_id, + version=version, + ) + + +def _base64url_encode_no_padding(data: bytes) -> str: + return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") + + +_base64_url_alphabet = set( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-" +) + + +def _base64url_decode_no_padding(s: str) -> bytes: + if invalid_chars := set(s) - _base64_url_alphabet: + raise ValueError( + f"invalid base64URL encoded string: contains invalid characters: {invalid_chars}" + ) + padding = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s + padding) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py new file mode 100644 index 000000000..c0a1b8464 --- /dev/null +++ b/temporalio/nexus/_util.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import functools +import inspect +import typing +import warnings +from typing import ( + Any, + Awaitable, + Callable, + Optional, + Type, + TypeVar, + Union, +) + +import nexusrpc +from nexusrpc import ( + InputT, + OutputT, +) + +from temporalio.nexus._operation_context import WorkflowRunOperationContext + +from ._token import ( + WorkflowHandle as WorkflowHandle, +) + +ServiceHandlerT = TypeVar("ServiceHandlerT") + + +def get_workflow_run_start_method_input_and_output_type_annotations( + start: Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Awaitable[WorkflowHandle[OutputT]], + ], +) -> tuple[ + Optional[Type[InputT]], + Optional[Type[OutputT]], +]: + """Return operation input and output types. 
+ + `start` must be a type-annotated start method that returns a + :py:class:`temporalio.nexus.WorkflowHandle`. + """ + + input_type, output_type = _get_start_method_input_and_output_type_annotations(start) + origin_type = typing.get_origin(output_type) + if not origin_type: + output_type = None + elif not issubclass(origin_type, WorkflowHandle): + warnings.warn( + f"Expected return type of {start.__name__} to be a subclass of WorkflowHandle, " + f"but is {output_type}" + ) + output_type = None + + if output_type: + args = typing.get_args(output_type) + if len(args) != 1: + suffix = f": {args}" if args else "" + warnings.warn( + f"Expected return type {output_type} of {start.__name__} to have exactly one type parameter, " + f"but has {len(args)}{suffix}." + ) + output_type = None + else: + [output_type] = args + return input_type, output_type + + +def _get_start_method_input_and_output_type_annotations( + start: Callable[ + [ServiceHandlerT, WorkflowRunOperationContext, InputT], + Union[OutputT, Awaitable[OutputT]], + ], +) -> tuple[ + Optional[Type[InputT]], + Optional[Type[OutputT]], +]: + """Return operation input and output types. + + `start` must be a type-annotated start method that returns a synchronous result. + """ + try: + type_annotations = typing.get_type_hints(start) + except TypeError: + warnings.warn( + f"Expected decorated start method {start} to have type annotations" + ) + return None, None + output_type = type_annotations.pop("return", None) + + if len(type_annotations) != 2: + suffix = f": {type_annotations}" if type_annotations else "" + warnings.warn( + f"Expected decorated start method {start} to have exactly 2 " + f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" + f"{suffix}." + ) + input_type = None + else: + ctx_type, input_type = type_annotations.values() + if not issubclass(ctx_type, WorkflowRunOperationContext): + warnings.warn( + f"Expected first parameter of {start} to be an instance of " + f"WorkflowRunOperationContext, but is {ctx_type}." + ) + input_type = None + + return input_type, output_type + + +def get_callable_name(fn: Callable[..., Any]) -> str: + method_name = getattr(fn, "__name__", None) + if not method_name and callable(fn) and hasattr(fn, "__call__"): + method_name = fn.__class__.__name__ + if not method_name: + raise TypeError( + f"Could not determine callable name: " + f"expected {fn} to be a function or callable instance." + ) + return method_name + + +# TODO(nexus-preview) Copied from nexusrpc +def get_operation_factory( + obj: Any, +) -> tuple[ + Optional[Callable[[Any], Any]], + Optional[nexusrpc.Operation[Any, Any]], +]: + """Return the :py:class:`Operation` for the object along with the factory function. + + ``obj`` should be a decorated operation start method. + """ + op_defn = nexusrpc.get_operation_definition(obj) + if op_defn: + factory = obj + else: + if factory := getattr(obj, "__nexus_operation_factory__", None): + op_defn = nexusrpc.get_operation_definition(factory) + if not isinstance(op_defn, nexusrpc.Operation): + return None, None + return factory, op_defn + + +# TODO(nexus-preview) Copied from nexusrpc +def set_operation_factory( + obj: Any, + operation_factory: Callable[[Any], Any], +) -> None: + """Set the :py:class:`OperationHandler` factory for this object. + + ``obj`` should be an operation start method. + """ + setattr(obj, "__nexus_operation_factory__", operation_factory) + + +# Copied from https://github.com/modelcontextprotocol/python-sdk +# +# Copyright (c) 2024 Anthropic, PBC. 
+# +# This file is licensed under the MIT License. +def is_async_callable(obj: Any) -> bool: + """ + Return True if `obj` is an async callable. + + Supports partials of async callable class instances. + """ + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index fe18d1f18..9bc373022 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -201,8 +201,9 @@ async def drain_poll_queue(self) -> None: # Only call this after run()/drain_poll_queue() have returned. This will not # raise an exception. - # TODO(dan): based on the comment above it looks like the intention may have been to use - # return_exceptions=True + # TODO(nexus-preview): based on the comment above it looks like the intention may have been to use + # return_exceptions=True. Change this for nexus and activity and change call sites to consume entire + # stream and then raise first exception async def wait_all_completed(self) -> None: running_tasks = [v.task for v in self._running_activities.values() if v.task] if running_tasks: diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index a3146200e..692721ad1 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -3,15 +3,14 @@ from __future__ import annotations import concurrent.futures -from dataclasses import dataclass +from collections.abc import Callable, Mapping, MutableMapping +from dataclasses import dataclass, field from datetime import timedelta from typing import ( Any, Awaitable, - Callable, + Generic, List, - Mapping, - MutableMapping, NoReturn, Optional, Sequence, @@ -19,9 +18,14 @@ Union, ) +import nexusrpc.handler +from nexusrpc import InputT, OutputT + import temporalio.activity import temporalio.api.common.v1 import temporalio.common +import temporalio.nexus +import temporalio.nexus._util import temporalio.workflow from temporalio.workflow import VersioningIntent @@ -285,6 +289,52 @@ class StartChildWorkflowInput: ret_type: Optional[Type] +@dataclass +class StartNexusOperationInput(Generic[InputT, OutputT]): + """Input for :py:meth:`WorkflowOutboundInterceptor.start_nexus_operation`.""" + + endpoint: str + service: str + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]] + input: InputT + schedule_to_close_timeout: Optional[timedelta] + headers: Optional[Mapping[str, str]] + output_type: Optional[Type[OutputT]] = None + + _operation_name: str = field(init=False, repr=False) + _input_type: Optional[Type[InputT]] = field(init=False, repr=False) + + def __post_init__(self) -> None: + if isinstance(self.operation, nexusrpc.Operation): + self._operation_name = self.operation.name + self._input_type = self.operation.input_type + self.output_type = self.operation.output_type + elif isinstance(self.operation, str): + self._operation_name = self.operation + self._input_type = None + elif callable(self.operation): + _, op = temporalio.nexus._util.get_operation_factory(self.operation) + if isinstance(op, nexusrpc.Operation): + self._operation_name = op.name + self._input_type = op.input_type + self.output_type = op.output_type + else: + raise ValueError( + f"Operation callable is not a Nexus operation: {self.operation}" + ) + else: + raise ValueError(f"Operation is not a Nexus operation: {self.operation}") + + @property + def 
operation_name(self) -> str: + return self._operation_name + + # TODO(nexus-prerelease) contravariant type in output + @property + def input_type(self) -> Optional[Type[InputT]]: + return self._input_type + + @dataclass class StartLocalActivityInput: """Input for :py:meth:`WorkflowOutboundInterceptor.start_local_activity`.""" @@ -409,3 +459,9 @@ def start_local_activity( and :py:func:`temporalio.workflow.execute_local_activity` call. """ return self.next.start_local_activity(input) + + async def start_nexus_operation( + self, input: StartNexusOperationInput[InputT, OutputT] + ) -> temporalio.workflow.NexusOperationHandle[OutputT]: + """Called for every :py:func:`temporalio.workflow.start_nexus_operation` call.""" + return await self.next.start_nexus_operation(input) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py new file mode 100644 index 000000000..1adb67b3e --- /dev/null +++ b/temporalio/worker/_nexus.py @@ -0,0 +1,467 @@ +"""Nexus worker""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import json +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Mapping, + NoReturn, + Optional, + Sequence, + Type, +) + +import google.protobuf.json_format +import nexusrpc.handler +from nexusrpc import LazyValue +from nexusrpc.handler import CancelOperationContext, Handler, StartOperationContext + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.failure.v1 +import temporalio.api.nexus.v1 +import temporalio.bridge.proto.nexus +import temporalio.bridge.worker +import temporalio.client +import temporalio.common +import temporalio.converter +import temporalio.nexus +from temporalio.exceptions import ApplicationError +from temporalio.nexus import ( + Info, + _TemporalCancelOperationContext, + _TemporalStartOperationContext, + logger, +) +from temporalio.service import RPCError, RPCStatusCode + +from ._interceptor import Interceptor + + +class _NexusWorker: + def __init__( + self, + *, + bridge_worker: Callable[[], temporalio.bridge.worker.Worker], + client: temporalio.client.Client, + task_queue: str, + service_handlers: Sequence[Any], + data_converter: temporalio.converter.DataConverter, + interceptors: Sequence[Interceptor], + metric_meter: temporalio.common.MetricMeter, + executor: Optional[concurrent.futures.Executor], + ) -> None: + # TODO: make it possible to query task queue of bridge worker instead of passing + # unused task_queue into _NexusWorker, _ActivityWorker, etc? + self._bridge_worker = bridge_worker + self._client = client + self._task_queue = task_queue + self._handler = Handler(service_handlers, executor) + self._data_converter = data_converter + # TODO(nexus-preview): interceptors + self._interceptors = interceptors + # TODO(nexus-preview): metric_meter + self._metric_meter = metric_meter + self._running_tasks: dict[bytes, asyncio.Task[Any]] = {} + self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() + + async def run(self) -> None: + """ + Continually poll for Nexus tasks and dispatch to handlers. 
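+
+        Start and cancel requests are each dispatched as an asyncio task, keyed by
+        task token in ``_running_tasks``. A ``concurrent.futures.BrokenExecutor``
+        raised while handling a start request is pushed onto the failure queue,
+        which fails the worker.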
+ """ + + async def raise_from_exception_queue() -> NoReturn: + raise await self._fail_worker_exception_queue.get() + + exception_task = asyncio.create_task(raise_from_exception_queue()) + + while True: + try: + poll_task = asyncio.create_task(self._bridge_worker().poll_nexus_task()) + await asyncio.wait( + [poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED + ) + if exception_task.done(): + poll_task.cancel() + await exception_task + nexus_task = await poll_task + + if nexus_task.HasField("task"): + task = nexus_task.task + if task.request.HasField("start_operation"): + self._running_tasks[task.task_token] = asyncio.create_task( + self._handle_start_operation_task( + task.task_token, + task.request.start_operation, + dict(task.request.header), + ) + ) + elif task.request.HasField("cancel_operation"): + self._running_tasks[task.task_token] = asyncio.create_task( + self._handle_cancel_operation_task( + task.task_token, + task.request.cancel_operation, + dict(task.request.header), + ) + ) + else: + raise NotImplementedError( + f"Invalid Nexus task request: {task.request}" + ) + elif nexus_task.HasField("cancel_task"): + if running_task := self._running_tasks.get( + nexus_task.cancel_task.task_token + ): + # TODO(nexus-prerelease): when do we remove the entry from _running_operations? + running_task.cancel() + else: + logger.debug( + f"Received cancel_task but no running task exists for " + f"task token: {nexus_task.cancel_task.task_token.decode()}" + ) + else: + raise NotImplementedError(f"Invalid Nexus task: {nexus_task}") + + except temporalio.bridge.worker.PollShutdownError: + exception_task.cancel() + return + + except Exception as err: + raise RuntimeError("Nexus worker failed") from err + + # Only call this if run() raised an error + async def drain_poll_queue(self) -> None: + while True: + try: + # Take all tasks and say we can't handle them + task = await self._bridge_worker().poll_nexus_task() + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task.task.task_token + ) + completion.error.failure.message = "Worker shutting down" + await self._bridge_worker().complete_nexus_task(completion) + except temporalio.bridge.worker.PollShutdownError: + return + + # Only call this after run()/drain_poll_queue() have returned. This will not + # raise an exception. + async def wait_all_completed(self) -> None: + await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) + + # TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute + # "Any call up to this function and including this one will be trimmed out of stack traces."" + + async def _handle_cancel_operation_task( + self, + task_token: bytes, + request: temporalio.api.nexus.v1.CancelOperationRequest, + headers: Mapping[str, str], + ) -> None: + """ + Handle a cancel operation task. + + Attempt to execute the user cancel_operation method. Handle errors and send the + task completion. 
+ """ + # TODO(nexus-prerelease): headers + ctx = CancelOperationContext( + service=request.service, + operation=request.operation, + headers=headers, + ) + _TemporalCancelOperationContext( + info=lambda: Info(task_queue=self._task_queue), + nexus_context=ctx, + client=self._client, + ).set() + try: + await self._handler.cancel_operation(ctx, request.operation_token) + except BaseException as err: + logger.warning("Failed to execute Nexus cancel operation method") + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + error=await self._handler_error_to_proto( + _exception_to_handler_error(err) + ), + ) + else: + # TODO(nexus-preview): ack_cancel completions? + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response( + cancel_operation=temporalio.api.nexus.v1.CancelOperationResponse() + ), + ) + try: + await self._bridge_worker().complete_nexus_task(completion) + except Exception: + logger.exception("Failed to send Nexus task completion") + + async def _handle_start_operation_task( + self, + task_token: bytes, + start_request: temporalio.api.nexus.v1.StartOperationRequest, + headers: Mapping[str, str], + ) -> None: + """ + Handle a start operation task. + + Attempt to execute the user start_operation method and invoke the data converter + on the result. Handle errors and send the task completion. + """ + + try: + start_response = await self._start_operation(start_request, headers) + except BaseException as err: + logger.warning("Failed to execute Nexus start operation method") + handler_err = _exception_to_handler_error(err) + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + error=await self._handler_error_to_proto(handler_err), + ) + if isinstance(err, concurrent.futures.BrokenExecutor): + self._fail_worker_exception_queue.put_nowait(err) + else: + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + completed=temporalio.api.nexus.v1.Response( + start_operation=start_response + ), + ) + + try: + await self._bridge_worker().complete_nexus_task(completion) + except Exception: + logger.exception("Failed to send Nexus task completion") + finally: + try: + del self._running_tasks[task_token] + except KeyError: + logger.exception("Failed to remove completed Nexus operation") + + async def _start_operation( + self, + start_request: temporalio.api.nexus.v1.StartOperationRequest, + headers: Mapping[str, str], + ) -> temporalio.api.nexus.v1.StartOperationResponse: + """ + Invoke the Nexus handler's start_operation method and construct the StartOperationResponse. + + OperationError is handled by this function, since it results in a StartOperationResponse. + + All other exceptions are handled by a caller of this function. 
+ """ + ctx = StartOperationContext( + service=start_request.service, + operation=start_request.operation, + headers=headers, + request_id=start_request.request_id, + callback_url=start_request.callback, + inbound_links=[ + nexusrpc.Link(url=link.url, type=link.type) + for link in start_request.links + ], + callback_headers=dict(start_request.callback_header), + ) + _TemporalStartOperationContext( + nexus_context=ctx, + client=self._client, + info=lambda: Info(task_queue=self._task_queue), + ).set() + input = LazyValue( + serializer=_DummyPayloadSerializer( + data_converter=self._data_converter, + payload=start_request.payload, + ), + headers={}, + stream=None, + ) + try: + result = await self._handler.start_operation(ctx, input) + links = [ + temporalio.api.nexus.v1.Link(url=link.url, type=link.type) + for link in ctx.outbound_links + ] + if isinstance(result, nexusrpc.handler.StartOperationResultAsync): + return temporalio.api.nexus.v1.StartOperationResponse( + async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( + operation_token=result.token, + links=links, + ) + ) + elif isinstance(result, nexusrpc.handler.StartOperationResultSync): + [payload] = await self._data_converter.encode([result.value]) + return temporalio.api.nexus.v1.StartOperationResponse( + sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( + payload=payload, + links=links, + ) + ) + else: + raise _exception_to_handler_error( + TypeError( + "Operation start method must return either " + "nexusrpc.handler.StartOperationResultSync or " + "nexusrpc.handler.StartOperationResultAsync." + ) + ) + except nexusrpc.OperationError as err: + return temporalio.api.nexus.v1.StartOperationResponse( + operation_error=await self._operation_error_to_proto(err), + ) + + async def _exception_to_failure_proto( + self, + err: BaseException, + ) -> temporalio.api.nexus.v1.Failure: + try: + api_failure = temporalio.api.failure.v1.Failure() + await self._data_converter.encode_failure(err, api_failure) + _api_failure = google.protobuf.json_format.MessageToDict(api_failure) + return temporalio.api.nexus.v1.Failure( + message=_api_failure.pop("message", ""), + metadata={"type": "temporal.api.failure.v1.Failure"}, + details=json.dumps(_api_failure).encode("utf-8"), + ) + except BaseException as err: + return temporalio.api.nexus.v1.Failure( + message=f"{err.__class__.__name__}: {err}", + metadata={"type": "temporal.api.failure.v1.Failure"}, + ) + + async def _operation_error_to_proto( + self, + err: nexusrpc.OperationError, + ) -> temporalio.api.nexus.v1.UnsuccessfulOperationError: + return temporalio.api.nexus.v1.UnsuccessfulOperationError( + operation_state=err.state.value, + failure=await self._exception_to_failure_proto(err), + ) + + async def _handler_error_to_proto( + self, err: nexusrpc.HandlerError + ) -> temporalio.api.nexus.v1.HandlerError: + return temporalio.api.nexus.v1.HandlerError( + error_type=err.type.value, + failure=await self._exception_to_failure_proto(err), + retry_behavior=( + temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE + if err.retryable + else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE + ), + ) + + +@dataclass +class _DummyPayloadSerializer: + data_converter: temporalio.converter.DataConverter + payload: temporalio.api.common.v1.Payload + + async def serialize(self, value: Any) -> nexusrpc.Content: + raise NotImplementedError( + "The serialize method of the Serializer is not used 
by handlers" + ) + + async def deserialize( + self, + content: nexusrpc.Content, + as_type: Optional[Type[Any]] = None, + ) -> Any: + try: + [input] = await self.data_converter.decode( + [self.payload], + type_hints=[as_type] if as_type else None, + ) + return input + except Exception as err: + raise nexusrpc.HandlerError( + "Data converter failed to decode Nexus operation input", + type=nexusrpc.HandlerErrorType.BAD_REQUEST, + retryable=False, + ) from err + + +# TODO(nexus-prerelease): tests for this function +def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: + # Based on sdk-typescript's convertKnownErrors: + # https://github.com/temporalio/sdk-typescript/blob/nexus/packages/worker/src/nexus.ts + if isinstance(err, nexusrpc.HandlerError): + return err + elif isinstance(err, ApplicationError): + handler_err = nexusrpc.HandlerError( + # TODO(nexus-prerelease): what should message be? + err.message, + type=nexusrpc.HandlerErrorType.INTERNAL, + retryable=not err.non_retryable, + ) + elif isinstance(err, RPCError): + if err.status == RPCStatusCode.INVALID_ARGUMENT: + handler_err = nexusrpc.HandlerError( + err.message, + type=nexusrpc.HandlerErrorType.BAD_REQUEST, + ) + elif err.status in [ + RPCStatusCode.ALREADY_EXISTS, + RPCStatusCode.FAILED_PRECONDITION, + RPCStatusCode.OUT_OF_RANGE, + ]: + handler_err = nexusrpc.HandlerError( + err.message, + type=nexusrpc.HandlerErrorType.INTERNAL, + retryable=False, + ) + elif err.status in [RPCStatusCode.ABORTED, RPCStatusCode.UNAVAILABLE]: + handler_err = nexusrpc.HandlerError( + err.message, + type=nexusrpc.HandlerErrorType.UNAVAILABLE, + ) + elif err.status in [ + RPCStatusCode.CANCELLED, + RPCStatusCode.DATA_LOSS, + RPCStatusCode.INTERNAL, + RPCStatusCode.UNKNOWN, + RPCStatusCode.UNAUTHENTICATED, + RPCStatusCode.PERMISSION_DENIED, + ]: + # Note that UNAUTHENTICATED and PERMISSION_DENIED have Nexus error types but + # we convert to internal because this is not a client auth error and happens + # when the handler fails to auth with Temporal and should be considered + # retryable. 
+ handler_err = nexusrpc.HandlerError( + err.message, type=nexusrpc.HandlerErrorType.INTERNAL + ) + elif err.status == RPCStatusCode.NOT_FOUND: + handler_err = nexusrpc.HandlerError( + err.message, type=nexusrpc.HandlerErrorType.NOT_FOUND + ) + elif err.status == RPCStatusCode.RESOURCE_EXHAUSTED: + handler_err = nexusrpc.HandlerError( + err.message, + type=nexusrpc.HandlerErrorType.RESOURCE_EXHAUSTED, + ) + elif err.status == RPCStatusCode.UNIMPLEMENTED: + handler_err = nexusrpc.HandlerError( + err.message, + type=nexusrpc.HandlerErrorType.NOT_IMPLEMENTED, + ) + elif err.status == RPCStatusCode.DEADLINE_EXCEEDED: + handler_err = nexusrpc.HandlerError( + err.message, + type=nexusrpc.HandlerErrorType.UPSTREAM_TIMEOUT, + ) + else: + handler_err = nexusrpc.HandlerError( + f"Unhandled RPC error status: {err.status}", + type=nexusrpc.HandlerErrorType.INTERNAL, + ) + else: + handler_err = nexusrpc.HandlerError( + str(err), type=nexusrpc.HandlerErrorType.INTERNAL + ) + handler_err.__cause__ = err + return handler_err diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 7600118d6..c016495c7 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -261,6 +261,9 @@ def on_eviction_hook( activity_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 1 ), + nexus_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( + 1 + ), ), ) # Start worker diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 4793f675e..188d80080 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -41,6 +41,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor +from ._nexus import _NexusWorker from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -106,9 +107,11 @@ def __init__( *, task_queue: str, activities: Sequence[Callable] = [], + nexus_service_handlers: Sequence[Any] = [], workflows: Sequence[Type] = [], activity_executor: Optional[concurrent.futures.Executor] = None, workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, + nexus_task_executor: Optional[concurrent.futures.Executor] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(), interceptors: Sequence[Interceptor] = [], @@ -143,6 +146,9 @@ def __init__( activity_task_poller_behavior: PollerBehavior = PollerBehaviorSimpleMaximum( maximum=5 ), + nexus_task_poller_behavior: PollerBehavior = PollerBehaviorSimpleMaximum( + maximum=5 + ), ) -> None: """Create a worker to process workflows and/or activities. @@ -153,10 +159,12 @@ def __init__( client's underlying service client. This client cannot be "lazy". task_queue: Required task queue for this worker. - activities: Set of activity callables decorated with + activities: Activity callables decorated with :py:func:`@activity.defn`. Activities may be async functions or non-async functions. - workflows: Set of workflow classes decorated with + nexus_service_handlers: Instances of Nexus service handler classes + decorated with :py:func:`@nexusrpc.handler.service_handler`. + workflows: Workflow classes decorated with :py:func:`@workflow.defn`. activity_executor: Concurrent executor to use for non-async activities. This is required if any activities are non-async. @@ -175,6 +183,10 @@ def __init__( otherwise. 
The default one will be properly shutdown, but if one is provided, the caller is responsible for shutting it down after the worker is shut down. + nexus_operation_executor: Executor to use for non-async + Nexus operations. This is required if any operation start methods + are non-`async def`. :py:class:`concurrent.futures.ThreadPoolExecutor` + is recommended. workflow_runner: Runner for workflows. unsandboxed_workflow_runner: Runner for workflows that opt-out of sandboxing. @@ -195,9 +207,9 @@ def __init__( tasks that will ever be given to this worker at one time. Mutually exclusive with ``tuner``. Must be set to at least two if ``max_cached_workflows`` is nonzero. max_concurrent_activities: Maximum number of activity tasks that - will ever be given to this worker concurrently. Mutually exclusive with ``tuner``. + will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``. max_concurrent_local_activities: Maximum number of local activity - tasks that will ever be given to this worker concurrently. Mutually exclusive with ``tuner``. + tasks that will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``. tuner: Provide a custom :py:class:`WorkerTuner`. Mutually exclusive with the ``max_concurrent_workflow_tasks``, ``max_concurrent_activities``, and ``max_concurrent_local_activities`` arguments. @@ -296,9 +308,14 @@ def __init__( Defaults to a 5-poller maximum. activity_task_poller_behavior: Specify the behavior of activity task polling. Defaults to a 5-poller maximum. + nexus_task_poller_behavior: Specify the behavior of Nexus task polling. + Defaults to a 5-poller maximum. """ - if not activities and not workflows: - raise ValueError("At least one activity or workflow must be specified") + # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support + if not (activities or nexus_service_handlers or workflows): + raise ValueError( + "At least one activity, Nexus service, or workflow must be specified" + ) if use_worker_versioning and not build_id: raise ValueError( "build_id must be specified when use_worker_versioning is True" @@ -327,6 +344,7 @@ def __init__( workflows=workflows, activity_executor=activity_executor, workflow_task_executor=workflow_task_executor, + nexus_task_executor=nexus_task_executor, workflow_runner=workflow_runner, unsandboxed_workflow_runner=unsandboxed_workflow_runner, interceptors=interceptors, @@ -361,7 +379,6 @@ def __init__( self._async_context_run_task: Optional[asyncio.Task] = None self._async_context_run_exception: Optional[BaseException] = None - # Create activity and workflow worker self._activity_worker: Optional[_ActivityWorker] = None self._runtime = ( bridge_client.config.runtime or temporalio.runtime.Runtime.default() @@ -393,6 +410,18 @@ def __init__( interceptors=interceptors, metric_meter=self._runtime.metric_meter, ) + self._nexus_worker: Optional[_NexusWorker] = None + if nexus_service_handlers: + self._nexus_worker = _NexusWorker( + bridge_worker=lambda: self._bridge_worker, + client=client, + task_queue=task_queue, + service_handlers=nexus_service_handlers, + data_converter=client_config["data_converter"], + interceptors=interceptors, + metric_meter=self._runtime.metric_meter, + executor=nexus_task_executor, + ) self._workflow_worker: Optional[_WorkflowWorker] = None if workflows: should_enforce_versioning_behavior = ( @@ -524,6 +553,7 @@ def check_activity(activity): versioning_strategy=versioning_strategy, 
workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(), activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(), + nexus_task_poller_behavior=nexus_task_poller_behavior._to_bridge(), ), ) @@ -617,21 +647,30 @@ async def raise_on_shutdown(): except asyncio.CancelledError: pass - tasks: List[asyncio.Task] = [asyncio.create_task(raise_on_shutdown())] + tasks: dict[ + Union[None, _ActivityWorker, _WorkflowWorker, _NexusWorker], asyncio.Task + ] = {None: asyncio.create_task(raise_on_shutdown())} # Create tasks for workers if self._activity_worker: - tasks.append(asyncio.create_task(self._activity_worker.run())) + tasks[self._activity_worker] = asyncio.create_task( + self._activity_worker.run() + ) if self._workflow_worker: - tasks.append(asyncio.create_task(self._workflow_worker.run())) + tasks[self._workflow_worker] = asyncio.create_task( + self._workflow_worker.run() + ) + if self._nexus_worker: + tasks[self._nexus_worker] = asyncio.create_task(self._nexus_worker.run()) # Wait for either worker or shutdown requested - wait_task = asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + wait_task = asyncio.wait(tasks.values(), return_when=asyncio.FIRST_EXCEPTION) try: await asyncio.shield(wait_task) - # If any of the last two tasks failed, we want to re-raise that as - # the exception - exception = next((t.exception() for t in tasks[1:] if t.done()), None) + # If any of the worker tasks failed, re-raise that as the exception + exception = next( + (t.exception() for w, t in tasks.items() if w and t.done()), None + ) if exception: logger.error("Worker failed, shutting down", exc_info=exception) if self._config["on_fatal_error"]: @@ -646,7 +685,7 @@ async def raise_on_shutdown(): exception = user_cancel_err # Cancel the shutdown task (safe if already done) - tasks[0].cancel() + tasks[None].cancel() graceful_timeout = self._config["graceful_shutdown_timeout"] logger.info( f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities" @@ -655,18 +694,10 @@ async def raise_on_shutdown(): # Initiate core worker shutdown self._bridge_worker.initiate_shutdown() - # If any worker task had an exception, replace that task with a queue - # drain (task at index 1 can be activity or workflow worker, task at - # index 2 must be workflow worker if present) - if tasks[1].done() and tasks[1].exception(): - if self._activity_worker: - tasks[1] = asyncio.create_task(self._activity_worker.drain_poll_queue()) - else: - assert self._workflow_worker - tasks[1] = asyncio.create_task(self._workflow_worker.drain_poll_queue()) - if len(tasks) > 2 and tasks[2].done() and tasks[2].exception(): - assert self._workflow_worker - tasks[2] = asyncio.create_task(self._workflow_worker.drain_poll_queue()) + # If any worker task had an exception, replace that task with a queue drain + for worker, task in tasks.items(): + if worker and task.done() and task.exception(): + tasks[worker] = asyncio.create_task(worker.drain_poll_queue()) # Notify shutdown occurring if self._activity_worker: @@ -675,20 +706,23 @@ async def raise_on_shutdown(): self._workflow_worker.notify_shutdown() # Wait for all tasks to complete (i.e. for poller loops to stop) - await asyncio.wait(tasks) + await asyncio.wait(tasks.values()) # Sometimes both workers throw an exception and since we only take the # first, Python may complain with "Task exception was never retrieved" # if we don't get the others. Therefore we call cancel on each task # which suppresses this. 
- for task in tasks: + for task in tasks.values(): task.cancel() - # If there's an activity worker, we have to let all activity completions - # finish. We cannot guarantee that because poll shutdown completed - # (which means activities completed) that they got flushed to the - # server. + # Let all activity / nexus operations completions finish. We cannot guarantee that + # because poll shutdown completed (which means activities/operations completed) + # that they got flushed to the server. if self._activity_worker: await self._activity_worker.wait_all_completed() + if self._nexus_worker: + await self._nexus_worker.wait_all_completed() + + # TODO(nexus-preview): check that we do all appropriate things for nexus worker that we do for activity worker # Do final shutdown try: @@ -770,6 +804,7 @@ class WorkerConfig(TypedDict, total=False): workflows: Sequence[Type] activity_executor: Optional[concurrent.futures.Executor] workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] + nexus_task_executor: Optional[concurrent.futures.Executor] workflow_runner: WorkflowRunner unsandboxed_workflow_runner: WorkflowRunner interceptors: Sequence[Interceptor] diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 528b42197..a7fe73d67 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -43,6 +43,8 @@ cast, ) +import nexusrpc.handler +from nexusrpc import InputT, OutputT from typing_extensions import Self, TypeAlias, TypedDict import temporalio.activity @@ -58,6 +60,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus import temporalio.workflow from temporalio.service import __version__ @@ -72,6 +75,7 @@ StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, + StartNexusOperationInput, WorkflowInboundInterceptor, WorkflowOutboundInterceptor, ) @@ -228,6 +232,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._pending_timers: Dict[int, _TimerHandle] = {} self._pending_activities: Dict[int, _ActivityHandle] = {} self._pending_child_workflows: Dict[int, _ChildWorkflowHandle] = {} + self._pending_nexus_operations: Dict[int, _NexusOperationHandle] = {} self._pending_external_signals: Dict[int, asyncio.Future] = {} self._pending_external_cancels: Dict[int, asyncio.Future] = {} # Keyed by type @@ -507,6 +512,10 @@ def _apply( self._apply_resolve_child_workflow_execution_start( job.resolve_child_workflow_execution_start ) + elif job.HasField("resolve_nexus_operation_start"): + self._apply_resolve_nexus_operation_start(job.resolve_nexus_operation_start) + elif job.HasField("resolve_nexus_operation"): + self._apply_resolve_nexus_operation(job.resolve_nexus_operation) elif job.HasField("resolve_request_cancel_external_workflow"): self._apply_resolve_request_cancel_external_workflow( job.resolve_request_cancel_external_workflow @@ -770,7 +779,6 @@ def _apply_resolve_child_workflow_execution( self, job: temporalio.bridge.proto.workflow_activation.ResolveChildWorkflowExecution, ) -> None: - # No matter the result, we know we want to pop handle = self._pending_child_workflows.pop(job.seq, None) if not handle: raise RuntimeError( @@ -839,6 +847,72 @@ def _apply_resolve_child_workflow_execution_start( else: raise RuntimeError("Child workflow start did not have a known status") + def _apply_resolve_nexus_operation_start( + self, + job: temporalio.bridge.proto.workflow_activation.ResolveNexusOperationStart, + ) -> None: + 
handle = self._pending_nexus_operations.get(job.seq) + if not handle: + raise RuntimeError( + f"Failed to find nexus operation handle for job sequence number {job.seq}" + ) + if job.HasField("operation_token"): + # The Nexus operation started asynchronously. A `ResolveNexusOperation` job + # will follow in a future activation. + handle._resolve_start_success(job.operation_token) + elif job.HasField("started_sync"): + # The Nexus operation 'started' in the sense that it's already resolved. A + # `ResolveNexusOperation` job will be in the same activation. + handle._resolve_start_success(None) + elif job.HasField("cancelled_before_start"): + # The operation was cancelled before it was ever sent to server (same WFT). + # Note that core will still send a `ResolveNexusOperation` job in the same + # activation, so there does not need to be an exceptional case for this in + # lang. + # TODO(nexus-preview): confirm appropriate to take no action here + pass + else: + raise ValueError(f"Unknown Nexus operation start status: {job}") + + def _apply_resolve_nexus_operation( + self, + job: temporalio.bridge.proto.workflow_activation.ResolveNexusOperation, + ) -> None: + handle = self._pending_nexus_operations.get(job.seq) + if not handle: + raise RuntimeError( + f"Failed to find nexus operation handle for job sequence number {job.seq}" + ) + + # Handle the four oneof variants of NexusOperationResult + result = job.result + if result.HasField("completed"): + [output] = self._convert_payloads( + [result.completed], + [handle._input.output_type] if handle._input.output_type else None, + ) + handle._resolve_success(output) + elif result.HasField("failed"): + handle._resolve_failure( + self._failure_converter.from_failure( + result.failed, self._payload_converter + ) + ) + elif result.HasField("cancelled"): + handle._resolve_failure( + self._failure_converter.from_failure( + result.cancelled, self._payload_converter + ) + ) + elif result.HasField("timed_out"): + handle._resolve_failure( + self._failure_converter.from_failure( + result.timed_out, self._payload_converter + ) + ) + else: + raise RuntimeError("Nexus operation did not have a result") + def _apply_resolve_request_cancel_external_workflow( self, job: temporalio.bridge.proto.workflow_activation.ResolveRequestCancelExternalWorkflow, @@ -1299,6 +1373,7 @@ def workflow_start_activity( ) ) + # workflow_start_child_workflow ret_type async def workflow_start_child_workflow( self, workflow: Any, @@ -1333,7 +1408,7 @@ async def workflow_start_child_workflow( if isinstance(workflow, str): name = workflow elif callable(workflow): - defn = temporalio.workflow._Definition.must_from_run_fn(workflow) + defn = temporalio.workflow._Definition.must_from_run_fn(workflow) # pyright: ignore if not defn.name: raise TypeError("Cannot invoke dynamic workflow explicitly") name = defn.name @@ -1418,6 +1493,29 @@ def workflow_start_local_activity( ) ) + async def workflow_start_nexus_operation( + self, + endpoint: str, + service: str, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: Any, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> temporalio.workflow.NexusOperationHandle[OutputT]: + # start_nexus_operation + return await self._outbound.start_nexus_operation( + StartNexusOperationInput( + endpoint=endpoint, + service=service, + operation=operation, + input=input, + output_type=output_type, + 
schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + ) + def workflow_time_ns(self) -> int: return self._time_ns @@ -1722,6 +1820,47 @@ async def run_child() -> Any: except asyncio.CancelledError: apply_child_cancel_error() + async def _outbound_start_nexus_operation( + self, input: StartNexusOperationInput[Any, OutputT] + ) -> _NexusOperationHandle[OutputT]: + # A Nexus operation handle contains two futures: self._start_fut is resolved as a + # result of the Nexus operation starting (activation job: + # resolve_nexus_operation_start), and self._result_fut is resolved as a result of + # the Nexus operation completing (activation job: resolve_nexus_operation). The + # handle itself corresponds to an asyncio.Task which waits on self.result_fut, + # handling CancelledError by emitting a RequestCancelNexusOperation command. We do + # not return the handle until we receive resolve_nexus_operation_start, like + # ChildWorkflowHandle and unlike ActivityHandle. Note that a Nexus operation may + # complete synchronously (in which case both jobs will be sent in the same + # activation, and start will be resolved without an operation token), or + # asynchronously (in which case start they may be sent in separate activations, + # and start will be resolved with an operation token). See comments in + # tests/worker/test_nexus.py for worked examples of the evolution of the resulting + # handle state machine in the sync and async Nexus response cases. + handle: _NexusOperationHandle[OutputT] + + async def operation_handle_fn() -> OutputT: + while True: + try: + return cast(OutputT, await asyncio.shield(handle._result_fut)) + except asyncio.CancelledError: + cancel_command = self._add_command() + handle._apply_cancel_command(cancel_command) + + handle = _NexusOperationHandle( + self, self._next_seq("nexus_operation"), input, operation_handle_fn() + ) + handle._apply_schedule_command() + self._pending_nexus_operations[handle._seq] = handle + + while True: + try: + await asyncio.shield(handle._start_fut) + return handle + except asyncio.CancelledError: + cancel_command = self._add_command() + handle._apply_cancel_command(cancel_command) + #### Miscellaneous helpers #### # These are in alphabetical order. @@ -2458,6 +2597,11 @@ async def start_child_workflow( ) -> temporalio.workflow.ChildWorkflowHandle[Any, Any]: return await self._instance._outbound_start_child_workflow(input) + async def start_nexus_operation( + self, input: StartNexusOperationInput[Any, OutputT] + ) -> _NexusOperationHandle[OutputT]: + return await self._instance._outbound_start_nexus_operation(input) + def start_local_activity( self, input: StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle[Any]: @@ -2844,6 +2988,94 @@ async def cancel(self) -> None: await self._instance._cancel_external_workflow(command) +# TODO(nexus-preview): are we sure we don't want to inherit from asyncio.Task as +# ActivityHandle and ChildWorkflowHandle do? I worry that we should provide .done(), +# .result(), .exception() etc for consistency. 
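+#
+# Illustrative sketch of how the handle's two futures evolve, per
+# _apply_resolve_nexus_operation_start / _apply_resolve_nexus_operation above:
+#
+#   Synchronous completion (both jobs in the same activation):
+#     resolve_nexus_operation_start(started_sync)        -> _start_fut.set_result(None)
+#     resolve_nexus_operation(completed=<payload>)       -> _result_fut.set_result(output)
+#
+#   Asynchronous completion (jobs in separate activations):
+#     resolve_nexus_operation_start(operation_token=...) -> _start_fut.set_result(token)
+#     resolve_nexus_operation(completed | failed |
+#                             cancelled | timed_out)     -> _result_fut resolved with the
+#                                                           output or exception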
+class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[OutputT]): + def __init__( + self, + instance: _WorkflowInstanceImpl, + seq: int, + input: StartNexusOperationInput[Any, OutputT], + fn: Coroutine[Any, Any, OutputT], + ): + self._instance = instance + self._seq = seq + self._input = input + self._task = asyncio.Task(fn) + self._start_fut: asyncio.Future[Optional[str]] = instance.create_future() + self._result_fut: asyncio.Future[Optional[OutputT]] = instance.create_future() + + @property + def operation_token(self) -> Optional[str]: + # TODO(nexus-preview): How should this behave? + # Java has a separate class that only exists if the operation token exists: + # https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26 + # And Go similar: + # https://github.com/temporalio/sdk-go/blob/master/internal/workflow.go#L2770-L2771 + try: + return self._start_fut.result() + except BaseException: + return None + + async def result(self) -> OutputT: + return await self._task + + def __await__(self) -> Generator[Any, Any, OutputT]: + return self._task.__await__() + + def __repr__(self) -> str: + return ( + f"{self._start_fut} " + f"{self._result_fut} " + f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" # type: ignore + ) + + def cancel(self) -> bool: + return self._task.cancel() + + def _resolve_start_success(self, operation_token: Optional[str]) -> None: + # We intentionally let this error if already done + self._start_fut.set_result(operation_token) + + def _resolve_success(self, result: Any) -> None: + # We intentionally let this error if already done + self._result_fut.set_result(result) + + def _resolve_failure(self, err: BaseException) -> None: + if self._start_fut.done(): + # We intentionally let this error if already done + self._result_fut.set_exception(err) + else: + self._start_fut.set_exception(err) + # Set null result to avoid warning about unhandled future + self._result_fut.set_result(None) + + def _apply_schedule_command(self) -> None: + command = self._instance._add_command() + v = command.schedule_nexus_operation + v.seq = self._seq + v.endpoint = self._input.endpoint + v.service = self._input.service + v.operation = self._input.operation_name + v.input.CopyFrom( + self._instance._payload_converter.to_payload(self._input.input) + ) + if self._input.schedule_to_close_timeout is not None: + v.schedule_to_close_timeout.FromTimedelta( + self._input.schedule_to_close_timeout + ) + if self._input.headers: + for key, val in self._input.headers.items(): + v.nexus_header[key] = val + + def _apply_cancel_command( + self, + command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, + ) -> None: + command.request_cancel_nexus_operation.seq = self._seq + + class _ContinueAsNewError(temporalio.workflow.ContinueAsNewError): def __init__( self, instance: _WorkflowInstanceImpl, input: ContinueAsNewInput diff --git a/temporalio/worker/workflow_sandbox/_restrictions.py b/temporalio/worker/workflow_sandbox/_restrictions.py index fdc126809..32f7ba012 100644 --- a/temporalio/worker/workflow_sandbox/_restrictions.py +++ b/temporalio/worker/workflow_sandbox/_restrictions.py @@ -471,6 +471,7 @@ def with_child_unrestricted(self, *child_path: str) -> SandboxMatcher: # https://wrapt.readthedocs.io/en/latest/issues.html#using-issubclass-on-abstract-classes "asyncio", "abc", + "nexusrpc", "temporalio", # Due to pkg_resources use of base classes caused by the 
ABC issue # above, and Otel's use of pkg_resources, we pass it through diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 409c8d690..34cc4a55c 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4,7 +4,6 @@ import asyncio import contextvars -import dataclasses import inspect import logging import threading @@ -23,6 +22,7 @@ Awaitable, Callable, Dict, + Generator, Generic, Iterable, Iterator, @@ -40,6 +40,9 @@ overload, ) +import nexusrpc +import nexusrpc.handler +from nexusrpc import InputT, OutputT from typing_extensions import ( Concatenate, Literal, @@ -54,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus import temporalio.workflow from .types import ( @@ -846,6 +850,18 @@ def workflow_start_local_activity( activity_id: Optional[str], ) -> ActivityHandle[Any]: ... + @abstractmethod + async def workflow_start_nexus_operation( + self, + endpoint: str, + service: str, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: Any, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[OutputT]: ... + @abstractmethod def workflow_time_ns(self) -> int: ... @@ -1967,7 +1983,7 @@ class _AsyncioTask(asyncio.Task[AnyType]): pass else: - + # TODO: inherited classes should be other way around? class _AsyncioTask(Generic[AnyType], asyncio.Task): pass @@ -4368,6 +4384,21 @@ async def execute_child_workflow( return await handle +class NexusOperationHandle(Generic[OutputT]): + def cancel(self) -> bool: + """ + Request cancellation of the operation. + """ + raise NotImplementedError + + def __await__(self) -> Generator[Any, Any, OutputT]: + raise NotImplementedError + + @property + def operation_token(self) -> Optional[str]: + raise NotImplementedError + + class ExternalWorkflowHandle(Generic[SelfType]): """Handle for interacting with an external workflow. @@ -5074,3 +5105,151 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType elif self == VersioningIntent.DEFAULT: return temporalio.bridge.proto.common.VersioningIntent.DEFAULT return temporalio.bridge.proto.common.VersioningIntent.UNSPECIFIED + + +# Nexus + +ServiceT = TypeVar("ServiceT") + + +class NexusClient(ABC, Generic[ServiceT]): + """ + A client for invoking Nexus operations. + + example: + ```python + nexus_client = workflow.create_nexus_client( + endpoint=my_nexus_endpoint, + service=MyService, + ) + handle = await nexus_client.start_operation( + operation=MyService.my_operation, + input=MyOperationInput(value="hello"), + schedule_to_close_timeout=timedelta(seconds=10), + ) + result = await handle.result() + ``` + """ + + # TODO(nexus-prerelease): overloads: no-input, ret type + # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? + @abstractmethod + async def start_operation( + self, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[OutputT]: + """Start a Nexus operation and return its handle. + + Args: + operation: The Nexus operation. + input: The Nexus operation input. + output_type: The Nexus operation output type. 
+ schedule_to_close_timeout: Timeout for the entire operation attempt. + headers: Headers to send with the Nexus HTTP request. + + Returns: + A handle to the Nexus operation. The result can be obtained as + ```python + await handle.result() + ``` + """ + ... + + # TODO(nexus-prerelease): overloads: no-input, ret type + @abstractmethod + async def execute_operation( + self, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> OutputT: ... + + +class _NexusClient(NexusClient[ServiceT]): + def __init__( + self, + *, + endpoint: str, + service: Union[Type[ServiceT], str], + ) -> None: + """Create a Nexus client. + + Args: + service: The Nexus service. + endpoint: The Nexus endpoint. + """ + # If service is not a str, then it must be a service interface or implementation + # class. + if isinstance(service, str): + self.service_name = service + elif service_defn := nexusrpc.get_service_definition(service): + self.service_name = service_defn.name + else: + raise ValueError( + f"`service` may be a name (str), or a class decorated with either " + f"@nexusrpc.handler.service_handler or @nexusrpc.service. " + f"Invalid service type: {type(service)}" + ) + self.endpoint = endpoint + + # TODO(nexus-prerelease): overloads: no-input, ret type + # TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied? + async def start_operation( + self, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> temporalio.workflow.NexusOperationHandle[OutputT]: + return ( + await temporalio.workflow._Runtime.current().workflow_start_nexus_operation( + endpoint=self.endpoint, + service=self.service_name, + operation=operation, + input=input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + ) + + # TODO(nexus-prerelease): overloads: no-input, ret type + async def execute_operation( + self, + operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> OutputT: + handle = await self.start_operation( + operation, + input, + output_type=output_type, + schedule_to_close_timeout=schedule_to_close_timeout, + headers=headers, + ) + return await handle + + +def create_nexus_client( + endpoint: str, service: Union[Type[ServiceT], str] +) -> NexusClient[ServiceT]: + """Create a Nexus client. + + Args: + endpoint: The Nexus endpoint. + service: The Nexus service. 
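+
+    example (a sketch; ``MyService`` and ``MyOperationInput`` are illustrative):
+    ```python
+    nexus_client = workflow.create_nexus_client(
+        endpoint=my_nexus_endpoint,
+        service=MyService,
+    )
+    result = await nexus_client.execute_operation(
+        operation=MyService.my_operation,
+        input=MyOperationInput(value="hello"),
+        schedule_to_close_timeout=timedelta(seconds=10),
+    )
+    ```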
+ """ + return _NexusClient(endpoint=endpoint, service=service) diff --git a/tests/conftest.py b/tests/conftest.py index 37b1fe89c..7d9f0157d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,6 +99,7 @@ def env_type(request: pytest.FixtureRequest) -> str: @pytest_asyncio.fixture(scope="session") async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: if env_type == "local": + http_port = 7243 env = await WorkflowEnvironment.start_local( dev_server_extra_args=[ "--dynamic-config-value", @@ -117,13 +118,18 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "system.enableDeploymentVersions=true", "--dynamic-config-value", "frontend.activityAPIsEnabled=true", + "--http-port", + str(http_port), ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) + # TODO(nexus-preview): expose this in a more principled way + env._http_port = http_port # type: ignore elif env_type == "time-skipping": env = await WorkflowEnvironment.start_time_skipping() else: env = WorkflowEnvironment.from_client(await Client.connect(env_type)) + yield env await env.shutdown() diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py new file mode 100644 index 000000000..4452944da --- /dev/null +++ b/tests/helpers/nexus.py @@ -0,0 +1,119 @@ +import dataclasses +from dataclasses import dataclass +from typing import Any, Mapping, Optional + +import temporalio.api +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice.v1 +import temporalio.workflow +from temporalio.client import Client +from temporalio.converter import FailureConverter, PayloadConverter + +with temporalio.workflow.unsafe.imports_passed_through(): + import httpx + from google.protobuf import json_format + + +def make_nexus_endpoint_name(task_queue: str) -> str: + # Create endpoints for different task queues without name collisions. + return f"nexus-endpoint-{task_queue}" + + +# TODO(nexus-preview): How do we recommend that users create endpoints in their own tests? +# See https://github.com/temporalio/sdk-typescript/pull/1708/files?show-viewed-files=true&file-filters%5B%5D=&w=0#r2082549085 +async def create_nexus_endpoint( + task_queue: str, client: Client +) -> temporalio.api.operatorservice.v1.CreateNexusEndpointResponse: + name = make_nexus_endpoint_name(task_queue) + return await client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=client.namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) + + +@dataclass +class ServiceClient: + server_address: str # E.g. http://127.0.0.1:7243 + endpoint: str + service: str + + async def start_operation( + self, + operation: str, + body: Optional[dict[str, Any]] = None, + headers: Mapping[str, str] = {}, + ) -> httpx.Response: + """ + Start a Nexus operation. 
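+
+        The request is sent as a POST to
+        ``{server_address}/nexus/endpoints/{endpoint}/services/{service}/{operation}``.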
+ """ + # TODO(nexus-preview): Support callback URL as query param + async with httpx.AsyncClient() as http_client: + return await http_client.post( + f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}", + json=body, + headers=headers, + ) + + async def cancel_operation( + self, + operation: str, + token: str, + ) -> httpx.Response: + async with httpx.AsyncClient() as http_client: + return await http_client.post( + f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel", + # Token can also be sent as "Nexus-Operation-Token" header + params={"token": token}, + ) + + +def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: + """ + Return a shallow dict of the dataclass's fields. + + dataclasses.as_dict goes too far (attempts to pickle values) + """ + return { + field.name: getattr(dataclass, field.name) + for field in dataclasses.fields(dataclass) + } + + +@dataclass +class Failure: + """A Nexus Failure object, with details parsed into an exception. + + https://github.com/nexus-rpc/api/blob/main/SPEC.md#failure + """ + + message: str = "" + metadata: Optional[dict[str, str]] = None + details: Optional[dict[str, Any]] = None + + exception_from_details: Optional[BaseException] = dataclasses.field( + init=False, default=None + ) + + def __post_init__(self) -> None: + if self.metadata and (error_type := self.metadata.get("type")): + self.exception_from_details = self._instantiate_exception( + error_type, self.details + ) + + def _instantiate_exception( + self, error_type: str, details: Optional[dict[str, Any]] + ) -> BaseException: + proto = { + "temporal.api.failure.v1.Failure": temporalio.api.failure.v1.Failure, + }[error_type]() + json_format.ParseDict(self.details, proto, ignore_unknown_fields=True) + return FailureConverter.default.from_failure(proto, PayloadConverter.default) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py new file mode 100644 index 000000000..26f94c122 --- /dev/null +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -0,0 +1,169 @@ +import uuid + +import httpx +import nexusrpc.handler +import pytest +from nexusrpc.handler import sync_operation + +from temporalio import nexus, workflow +from temporalio.client import Client +from temporalio.nexus._util import get_operation_factory +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint + +HTTP_PORT = 7243 + + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input: int) -> int: + return input + 1 + + +@nexusrpc.service +class MyService: + increment: nexusrpc.Operation[int, int] + + +class MyIncrementOperationHandler(nexusrpc.handler.OperationHandler[int, int]): + async def start( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: int, + ) -> nexusrpc.handler.StartOperationResultAsync: + wrctx = nexus.WorkflowRunOperationContext.from_start_operation_context(ctx) + wf_handle = await wrctx.start_workflow( + MyWorkflow.run, input, id=str(uuid.uuid4()) + ) + return nexusrpc.handler.StartOperationResultAsync(token=wf_handle.to_token()) + + async def cancel( + self, + ctx: nexusrpc.handler.CancelOperationContext, + token: str, + ) -> None: + raise NotImplementedError + + async def fetch_info( + self, + ctx: nexusrpc.handler.FetchOperationInfoContext, + token: str, + ) -> nexusrpc.OperationInfo: + raise NotImplementedError + + async def fetch_result( + self, + ctx: 
nexusrpc.handler.FetchOperationResultContext, + token: str, + ) -> int: + raise NotImplementedError + + +@nexusrpc.handler.service_handler +class MyServiceHandlerWithWorkflowRunOperation: + @nexusrpc.handler._decorators.operation_handler + def increment(self) -> nexusrpc.handler.OperationHandler[int, int]: + return MyIncrementOperationHandler() + + +async def test_run_nexus_service_from_programmatically_created_service_handler( + client: Client, +): + task_queue = str(uuid.uuid4()) + + service_handler = nexusrpc.handler._core.ServiceHandler( + service=nexusrpc.ServiceDefinition( + name="MyService", + operations={ + "increment": nexusrpc.Operation[int, int]( + name="increment", + method_name="increment", + input_type=int, + output_type=int, + ), + }, + ), + operation_handlers={ + "increment": MyIncrementOperationHandler(), + }, + ) + + service_name = service_handler.service.name + + endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + async with Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + ): + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + json=1, + ) + assert response.status_code == 201 + + +def make_incrementer_user_service_definition_and_service_handler_classes( + op_names: list[str], +) -> tuple[type, type]: + # + # service contract + # + + ops = {name: nexusrpc.Operation[int, int] for name in op_names} + service_cls: type = nexusrpc.service(type("ServiceContract", (), ops)) + + # + # service handler + # + @sync_operation + async def _increment_op( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: int, + ) -> int: + return input + 1 + + op_handler_factories = {} + for name in op_names: + op_handler_factory, _ = get_operation_factory(_increment_op) + assert op_handler_factory + op_handler_factories[name] = op_handler_factory + + handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)( + type("ServiceImpl", (), op_handler_factories) + ) + + return service_cls, handler_cls + + +@pytest.mark.skip( + reason="Dynamic creation of service contract using type() is not supported" +) +async def test_dynamic_creation_of_user_handler_classes(client: Client): + task_queue = str(uuid.uuid4()) + + service_cls, handler_cls = ( + make_incrementer_user_service_definition_and_service_handler_classes( + ["increment"] + ) + ) + + assert (service_defn := nexusrpc.get_service_definition(service_cls)) + service_name = service_defn.name + + endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + async with Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[handler_cls()], + ): + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"http://127.0.0.1:{HTTP_PORT}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + json=1, + ) + assert response.status_code == 200 + assert response.json() == 2 diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py new file mode 100644 index 000000000..75227d745 --- /dev/null +++ b/tests/nexus/test_handler.py @@ -0,0 +1,1180 @@ +""" +See https://github.com/nexus-rpc/api/blob/main/SPEC.md + +This file contains test coverage for Nexus StartOperation and CancelOperation +operations issued by a caller directly via HTTP. + +The response to StartOperation may indicate a protocol-level failure (400 +BAD_REQUEST, 520 UPSTREAM_TIMEOUT, etc). 
In this case the body is a valid +Failure object. + + +(https://github.com/nexus-rpc/api/blob/main/SPEC.md#predefined-handler-errors) + +""" + +import asyncio +import concurrent.futures +import json +import logging +import pprint +import uuid +from concurrent.futures.thread import ThreadPoolExecutor +from dataclasses import dataclass +from types import MappingProxyType +from typing import Any, Callable, Mapping, Optional, Type, Union + +import httpx +import nexusrpc +import pytest +from nexusrpc import ( + HandlerError, + HandlerErrorType, + OperationError, + OperationErrorState, + OperationInfo, +) +from nexusrpc.handler import ( + CancelOperationContext, + FetchOperationInfoContext, + FetchOperationResultContext, + OperationHandler, + StartOperationContext, + StartOperationResultSync, + service_handler, + sync_operation, +) +from nexusrpc.handler._decorators import operation_handler +from nexusrpc.syncio.handler import sync_operation as syncio_sync_operation + +from temporalio import nexus, workflow +from temporalio.client import Client +from temporalio.common import WorkflowIDReusePolicy +from temporalio.exceptions import ApplicationError +from temporalio.nexus import ( + WorkflowRunOperationContext, + workflow_run_operation, +) +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import ( + Failure, + ServiceClient, + create_nexus_endpoint, + dataclass_as_dict, +) + +HTTP_PORT = 7243 + + +@dataclass +class Input: + value: str + + +@dataclass +class Output: + value: str + + +@dataclass +class NonSerializableOutput: + callable: Callable[[], Any] = lambda: None + + +# TODO(nexus-prelease): Test attaching multiple callers to the same operation. +# TODO(nexus-preview): type check nexus implementation under mypy +# TODO(nexus-preview): test malformed inbound_links and outbound_links + +# TODO(nexus-prerelease): 2025-07-02T23:29:20.000489Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. 
details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } + + +@nexusrpc.service +class MyService: + echo: nexusrpc.Operation[Input, Output] + echo_renamed: nexusrpc.Operation[Input, Output] = nexusrpc.Operation( + name="echo-renamed" + ) + hang: nexusrpc.Operation[Input, Output] + log: nexusrpc.Operation[Input, Output] + workflow_run_operation_happy_path: nexusrpc.Operation[Input, Output] + sync_operation_with_non_async_def: nexusrpc.Operation[Input, Output] + operation_returning_unwrapped_result_at_runtime_error: nexusrpc.Operation[ + Input, Output + ] + non_retryable_application_error: nexusrpc.Operation[Input, Output] + retryable_application_error: nexusrpc.Operation[Input, Output] + check_operation_timeout_header: nexusrpc.Operation[Input, Output] + workflow_run_op_link_test: nexusrpc.Operation[Input, Output] + handler_error_internal: nexusrpc.Operation[Input, Output] + operation_error_failed: nexusrpc.Operation[Input, Output] + idempotency_check: nexusrpc.Operation[None, Output] + non_serializable_output: nexusrpc.Operation[Input, NonSerializableOutput] + + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input: Input) -> Output: + return Output(value=f"from workflow: {input.value}") + + +@workflow.defn +class WorkflowWithoutTypeAnnotations: + @workflow.run + async def run(self, input): # type: ignore + return Output(value=f"from workflow without type annotations: {input}") + + +@workflow.defn +class MyLinkTestWorkflow: + @workflow.run + async def run(self, input: Input) -> Output: + return Output(value=f"from link test workflow: {input.value}") + + +# The service_handler decorator is applied by the test +class MyServiceHandler: + @sync_operation + async def echo(self, ctx: StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + assert nexus.in_operation() + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + # The name override is present in the service definition. But the test below submits + # the same operation name in the request whether using a service definition or now. + # The name override here is necessary when the test is not using the service + # definition. It should be permitted when the service definition is in effect, as + # long as the name override is the same as that in the service definition. + @sync_operation(name="echo-renamed") + async def echo_renamed(self, ctx: StartOperationContext, input: Input) -> Output: + return await self.echo(ctx, input) + + @sync_operation + async def hang(self, ctx: StartOperationContext, input: Input) -> Output: + await asyncio.Future() + return Output(value="won't reach here") + + @sync_operation + async def non_retryable_application_error( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise ApplicationError( + "non-retryable application error", + "details arg", + # TODO(nexus-preview): what values of `type` should be tested? 
+ type="TestFailureType", + non_retryable=True, + ) + + @sync_operation + async def retryable_application_error( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise ApplicationError( + "retryable application error", + "details arg", + type="TestFailureType", + non_retryable=False, + ) + + @sync_operation + async def handler_error_internal( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise HandlerError( + message="deliberate internal handler error", + type=HandlerErrorType.INTERNAL, + retryable=False, + ) from RuntimeError("cause message") + + @sync_operation + async def operation_error_failed( + self, ctx: StartOperationContext, input: Input + ) -> Output: + raise OperationError( + message="deliberate operation error", + state=OperationErrorState.FAILED, + ) + + @sync_operation + async def check_operation_timeout_header( + self, ctx: StartOperationContext, input: Input + ) -> Output: + assert "operation-timeout" in ctx.headers + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + @sync_operation + async def log(self, ctx: StartOperationContext, input: Input) -> Output: + nexus.logger.info( + "Logging from start method", extra={"input_value": input.value} + ) + return Output(value=f"logged: {input.value}") + + @workflow_run_operation + async def workflow_run_operation_happy_path( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: + assert nexus.in_operation() + return await ctx.start_workflow( + MyWorkflow.run, + input, + id=str(uuid.uuid4()), + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + + @sync_operation + async def sync_operation_with_non_async_def( + self, ctx: StartOperationContext, input: Input + ) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + @workflow_run_operation + async def workflow_run_op_link_test( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: + assert any( + link.url == "http://inbound-link/" for link in ctx.inbound_links + ), "Inbound link not found" + assert ctx.request_id == "test-request-id-123", "Request ID mismatch" + ctx.outbound_links.extend(ctx.inbound_links) + + return await ctx.start_workflow( + MyLinkTestWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + class OperationHandlerReturningUnwrappedResult(OperationHandler[Input, Output]): + async def start( + self, + ctx: StartOperationContext, + input: Input, + # This return type is a type error, but VSCode doesn't flag it unless + # "python.analysis.typeCheckingMode" is set to "strict" + ) -> Output: + # Invalid: start method must wrap result as StartOperationResultSync + # or StartOperationResultAsync + return Output(value="unwrapped result error") + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: + raise NotImplementedError + + async def fetch_result( + self, ctx: FetchOperationResultContext, token: str + ) -> Output: + raise NotImplementedError + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + raise NotImplementedError + + @operation_handler + def operation_returning_unwrapped_result_at_runtime_error( + self, + ) -> OperationHandler[Input, Output]: + return MyServiceHandler.OperationHandlerReturningUnwrappedResult() + + @sync_operation + async def idempotency_check( + self, ctx: StartOperationContext, input: None + ) -> Output: + return Output(value=f"request_id: {ctx.request_id}") + + 
@sync_operation + async def non_serializable_output( + self, ctx: StartOperationContext, input: Input + ) -> NonSerializableOutput: + return NonSerializableOutput() + + +# Immutable dicts that can be used as dataclass field defaults + +SUCCESSFUL_RESPONSE_HEADERS = MappingProxyType( + { + "content-type": "application/json", + } +) + +UNSUCCESSFUL_RESPONSE_HEADERS = MappingProxyType( + { + "content-type": "application/json", + "temporal-nexus-failure-source": "worker", + } +) + + +@dataclass +class SuccessfulResponse: + status_code: int + body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None + headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS + + +@dataclass +class UnsuccessfulResponse: + status_code: int + # Expected value of Nexus-Request-Retryable header + retryable_header: Optional[bool] + failure_message: Union[str, Callable[[str], bool]] + # Is the Nexus Failure expected to have the details field populated? + failure_details: bool = True + # Expected value of inverse of non_retryable attribute of exception. + retryable_exception: bool = True + body_json: Optional[Callable[[dict[str, Any]], bool]] = None + headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS + + +class _TestCase: + operation: str + service_defn: str = "MyService" + input: Input = Input("") + headers: dict[str, str] = {} + expected: SuccessfulResponse + expected_without_service_definition: Optional[SuccessfulResponse] = None + skip = "" + + @classmethod + def check_response( + cls, + response: httpx.Response, + with_service_definition: bool, + ) -> None: + assert response.status_code == cls.expected.status_code, ( + f"expected status code {cls.expected.status_code} " + f"but got {response.status_code} for response content" + f"{pprint.pformat(response.content.decode())}" + ) + if not with_service_definition and cls.expected_without_service_definition: + expected = cls.expected_without_service_definition + else: + expected = cls.expected + if expected.body_json is not None: + body = response.json() + assert isinstance(body, dict) + if isinstance(expected.body_json, dict): + assert body == expected.body_json + else: + assert expected.body_json(body) + assert response.headers.items() >= cls.expected.headers.items() + + +class _FailureTestCase(_TestCase): + expected: UnsuccessfulResponse # type: ignore[assignment] + + @classmethod + def check_response( + cls, response: httpx.Response, with_service_definition: bool + ) -> None: + super().check_response(response, with_service_definition) + failure = Failure(**response.json()) + + if isinstance(cls.expected.failure_message, str): + assert failure.message == cls.expected.failure_message + else: + assert cls.expected.failure_message(failure.message) + + # retryability assertions + if ( + retryable_header := response.headers.get("nexus-request-retryable") + ) is not None: + assert json.loads(retryable_header) == cls.expected.retryable_header + else: + assert cls.expected.retryable_header is None + + if cls.expected.failure_details: + assert ( + failure.exception_from_details is not None + ), "Expected exception details, but found none." 
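+            # exception_from_details is the exception reconstructed from the Failure
+            # details; when it represents a HandlerError wrapper, the retryability
+            # flag we want to check lives on its __cause__, which is unwrapped below.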
+ assert isinstance(failure.exception_from_details, ApplicationError) + + exception_from_failure_details = failure.exception_from_details + if ( + exception_from_failure_details.type == "HandlerError" + and exception_from_failure_details.__cause__ + ): + cause = exception_from_failure_details.__cause__ + assert isinstance(cause, ApplicationError) + exception_from_failure_details = cause + + assert exception_from_failure_details.non_retryable == ( + not cls.expected.retryable_exception + ) + + +class SyncHandlerHappyPath(_TestCase): + operation = "echo" + input = Input("hello") + # TODO(nexus-prerelease): why is application/json randomly scattered around these tests? + headers = { + "Content-Type": "application/json", + "Test-Header-Key": "test-header-value", + "Nexus-Link": '; type="test"', + } + expected = SuccessfulResponse( + status_code=200, + body_json={"value": "from start method on MyServiceHandler: hello"}, + ) + # TODO(nexus-prerelease): headers should be lower-cased + assert ( + headers.get("Nexus-Link") == '; type="test"' + ), "Nexus-Link header not echoed correctly." + + +class SyncHandlerHappyPathRenamed(SyncHandlerHappyPath): + operation = "echo-renamed" + + +class SyncHandlerHappyPathNonAsyncDef(_TestCase): + operation = "sync_operation_with_non_async_def" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={"value": "from start method on MyServiceHandler: hello"}, + ) + + +class AsyncHandlerHappyPath(_TestCase): + operation = "workflow_run_operation_happy_path" + input = Input("hello") + headers = {"Operation-Timeout": "777s"} + expected = SuccessfulResponse( + status_code=201, + ) + + +class WorkflowRunOpLinkTestHappyPath(_TestCase): + # TODO(nexus-prerelease): fix this test + skip = "Yields invalid link" + operation = "workflow_run_op_link_test" + input = Input("link-test-input") + headers = { + "Nexus-Link": '; type="test"', + "Nexus-Request-Id": "test-request-id-123", + } + expected = SuccessfulResponse( + status_code=201, + ) + + @classmethod + def check_response( + cls, response: httpx.Response, with_service_definition: bool + ) -> None: + super().check_response(response, with_service_definition) + nexus_link = response.headers.get("nexus-link") + assert nexus_link is not None, "nexus-link header not found in response" + assert nexus_link.startswith( + " None: + super().check_response(response, with_service_definition) + failure = Failure(**response.json()) + assert failure.exception_from_details + assert isinstance(failure.exception_from_details, ApplicationError) + err = failure.exception_from_details.__cause__ + assert isinstance(err, ApplicationError) + assert err.type == "TestFailureType" + assert err.details == ("details arg",) + + +class NonRetryableApplicationError(_ApplicationErrorTestCase): + operation = "non_retryable_application_error" + expected = UnsuccessfulResponse( + status_code=500, + retryable_header=False, + retryable_exception=False, + failure_message="non-retryable application error", + ) + + +class RetryableApplicationError(_ApplicationErrorTestCase): + operation = "retryable_application_error" + expected = UnsuccessfulResponse( + status_code=500, + retryable_header=True, + failure_message="retryable application error", + ) + + +class HandlerErrorInternal(_FailureTestCase): + operation = "handler_error_internal" + expected = UnsuccessfulResponse( + status_code=500, + # TODO(nexus-prerelease): check this assertion + retryable_header=False, + failure_message="deliberate internal handler error", + ) + + +class 
OperationErrorFailed(_FailureTestCase): + operation = "operation_error_failed" + expected = UnsuccessfulResponse( + status_code=424, + # TODO(nexus-prerelease): check that OperationError should not set retryable header + retryable_header=None, + failure_message="deliberate operation error", + headers=UNSUCCESSFUL_RESPONSE_HEADERS | {"nexus-operation-state": "failed"}, + ) + + +class UnknownService(_FailureTestCase): + service_defn = "NonExistentService" + operation = "" + expected = UnsuccessfulResponse( + status_code=404, + retryable_header=False, + failure_message="No handler for service 'NonExistentService'.", + ) + + +class UnknownOperation(_FailureTestCase): + operation = "NonExistentOperation" + expected = UnsuccessfulResponse( + status_code=404, + retryable_header=False, + failure_message=lambda s: s.startswith( + "Nexus service definition 'MyService' has no operation 'NonExistentOperation'." + ), + ) + + +class NonSerializableOutputFailure(_FailureTestCase): + operation = "non_serializable_output" + expected = UnsuccessfulResponse( + status_code=500, + retryable_header=False, + failure_message="Object of type function is not JSON serializable", + ) + + +@pytest.mark.parametrize( + "test_case", + [ + SyncHandlerHappyPath, + SyncHandlerHappyPathRenamed, + SyncHandlerHappyPathNonAsyncDef, + AsyncHandlerHappyPath, + WorkflowRunOpLinkTestHappyPath, + ], +) +@pytest.mark.parametrize("with_service_definition", [True, False]) +async def test_start_operation_happy_path( + test_case: Type[_TestCase], + with_service_definition: bool, + env: WorkflowEnvironment, +): + if with_service_definition: + await _test_start_operation_with_service_definition(test_case, env) + else: + await _test_start_operation_without_service_definition(test_case, env) + + +@pytest.mark.parametrize( + "test_case", + [ + OperationHandlerReturningUnwrappedResultError, + UpstreamTimeoutViaRequestTimeout, + OperationTimeoutHeader, + BadRequest, + HandlerErrorInternal, + UnknownService, + UnknownOperation, + NonSerializableOutputFailure, + ], +) +async def test_start_operation_protocol_level_failures( + test_case: Type[_TestCase], env: WorkflowEnvironment +): + await _test_start_operation_with_service_definition(test_case, env) + + +@pytest.mark.parametrize( + "test_case", + [ + NonRetryableApplicationError, + RetryableApplicationError, + OperationErrorFailed, + ], +) +async def test_start_operation_operation_failures( + test_case: Type[_TestCase], env: WorkflowEnvironment +): + await _test_start_operation_with_service_definition(test_case, env) + + +async def _test_start_operation_with_service_definition( + test_case: Type[_TestCase], + env: WorkflowEnvironment, +): + if test_case.skip: + pytest.skip(test_case.skip) + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=(test_case.service_defn), + ) + + with pytest.WarningsRecorder() as warnings: + decorator = service_handler(service=MyService) + user_service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition=True) + + assert not any(warnings), [w.message 
for w in warnings] + + +async def _test_start_operation_without_service_definition( + test_case: Type[_TestCase], + env: WorkflowEnvironment, +): + if test_case.skip: + pytest.skip(test_case.skip) + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyServiceHandler.__name__, + ) + + with pytest.WarningsRecorder() as warnings: + decorator = service_handler + user_service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition=False) + + assert not any(warnings), [w.message for w in warnings] + + +@nexusrpc.service +class MyServiceWithOperationsWithoutTypeAnnotations: + workflow_run_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + sync_operation_without_type_annotations: nexusrpc.Operation[Input, Output] + + +class MyServiceHandlerWithOperationsWithoutTypeAnnotations: + @sync_operation + async def sync_operation_without_type_annotations(self, ctx, input): + # Despite the lack of type annotations, the input type from the op definition in + # the service definition is used to deserialize the input. + return Output( + value=f"from start method on {self.__class__.__name__} without type annotations: {input}" + ) + + @workflow_run_operation + async def workflow_run_operation_without_type_annotations(self, ctx, input): + return await ctx.start_workflow( + WorkflowWithoutTypeAnnotations.run, + input, + id=str(uuid.uuid4()), + ) + + +class SyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): + operation = "sync_operation_without_type_annotations" + input = Input("hello") + expected = SuccessfulResponse( + status_code=200, + body_json={ + "value": "from start method on MyServiceHandlerWithOperationsWithoutTypeAnnotations without type annotations: Input(value='hello')" + }, + ) + + +class AsyncHandlerHappyPathWithoutTypeAnnotations(_TestCase): + operation = "workflow_run_operation_without_type_annotations" + input = Input("hello") + expected = SuccessfulResponse( + status_code=201, + ) + + +# Attempting to use the service_handler decorator on a class containing an operation +# without type annotations is a validation error (test coverage in nexusrpc) +@pytest.mark.parametrize( + "test_case", + [ + SyncHandlerHappyPathWithoutTypeAnnotations, + AsyncHandlerHappyPathWithoutTypeAnnotations, + ], +) +async def test_start_operation_without_type_annotations( + test_case: Type[_TestCase], env: WorkflowEnvironment +): + if test_case.skip: + pytest.skip(test_case.skip) + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyServiceWithOperationsWithoutTypeAnnotations.__name__, + ) + + with pytest.WarningsRecorder() as warnings: + decorator = service_handler( + service=MyServiceWithOperationsWithoutTypeAnnotations + ) + user_service_handler = decorator( + MyServiceHandlerWithOperationsWithoutTypeAnnotations + )() + + async with Worker( + env.client, + task_queue=task_queue, + 
nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + test_case.operation, + dataclass_as_dict(test_case.input), + test_case.headers, + ) + test_case.check_response(response, with_service_definition=True) + + assert not any(warnings), [w.message for w in warnings] + + +def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): + with pytest.raises( + ValueError, + match=r"has no input type.+has no output type", + ): + service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) + + +async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: Any): + task_queue = str(uuid.uuid4()) + service_name = MyService.__name__ + operation_name = "log" + resp = await create_nexus_endpoint(task_queue, env.client) + endpoint = resp.endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=service_name, + ) + caplog.set_level(logging.INFO) + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[MyServiceHandler()], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + response = await service_client.start_operation( + operation_name, + dataclass_as_dict(Input("test_log")), + { + "Content-Type": "application/json", + "Test-Log-Header": "test-log-header-value", + }, + ) + assert response.is_success + response.raise_for_status() + output_json = response.json() + assert output_json == {"value": "logged: test_log"} + + record = next( + ( + record + for record in caplog.records + if record.name == "temporalio.nexus" + and record.getMessage() == "Logging from start method" + ), + None, + ) + assert record is not None, "Expected log message not found" + assert record.levelname == "INFO" + assert getattr(record, "input_value", None) == "test_log" + assert getattr(record, "service", None) == service_name + assert getattr(record, "operation", None) == operation_name + + +class _InstantiationCase: + executor: bool + handler: Callable[[], Any] + exception: Optional[Type[Exception]] + match: Optional[str] + + +@nexusrpc.service +class EchoService: + echo: nexusrpc.Operation[Input, Output] + + +@service_handler(service=EchoService) +class SyncStartHandler: + @syncio_sync_operation + def echo(self, ctx: StartOperationContext, input: Input) -> Output: + assert ctx.headers["test-header-key"] == "test-header-value" + ctx.outbound_links.extend(ctx.inbound_links) + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + +@service_handler(service=EchoService) +class DefaultCancelHandler: + @sync_operation + async def echo(self, ctx: StartOperationContext, input: Input) -> Output: + return Output( + value=f"from start method on {self.__class__.__name__}: {input.value}" + ) + + +@service_handler(service=EchoService) +class SyncCancelHandler: + class SyncCancel(OperationHandler[Input, Output]): + async def start( + self, + ctx: StartOperationContext, + input: Input, + # This return type is a type error, but VSCode doesn't flag it unless + # "python.analysis.typeCheckingMode" is set to "strict" + ) -> StartOperationResultSync[Output]: + # Invalid: start method must wrap result as StartOperationResultSync + # or StartOperationResultAsync + return StartOperationResultSync(Output(value="Hello")) # type: ignore + + def cancel(self, ctx: CancelOperationContext, token: str) -> None: + return None # type: ignore + + 
def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: + raise NotImplementedError + + def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: + raise NotImplementedError + + @operation_handler + def echo(self) -> OperationHandler[Input, Output]: + return SyncCancelHandler.SyncCancel() + + +class SyncHandlerNoExecutor(_InstantiationCase): + handler = SyncStartHandler + executor = False + exception = RuntimeError + match = "Use nexusrpc.syncio.handler.Handler instead" + + +class DefaultCancel(_InstantiationCase): + handler = DefaultCancelHandler + executor = False + exception = None + + +class SyncCancel(_InstantiationCase): + handler = SyncCancelHandler + executor = False + exception = RuntimeError + match = "Use nexusrpc.syncio.handler.Handler instead" + + +@pytest.mark.parametrize( + "test_case", + [SyncHandlerNoExecutor, DefaultCancel, SyncCancel], +) +async def test_handler_instantiation( + test_case: Type[_InstantiationCase], client: Client +): + task_queue = str(uuid.uuid4()) + + if test_case.exception is not None: + with pytest.raises(test_case.exception, match=test_case.match): + Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[test_case.handler()], + nexus_task_executor=ThreadPoolExecutor() + if test_case.executor + else None, + ) + else: + Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[test_case.handler()], + nexus_task_executor=ThreadPoolExecutor() if test_case.executor else None, + ) + + +async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): + """Verify that canceling an operation with an invalid token fails correctly.""" + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyService.__name__, + ) + + decorator = service_handler(service=MyService) + user_service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + cancel_response = await service_client.cancel_operation( + "workflow_run_operation_happy_path", + token="this-is-not-a-valid-token", + ) + assert cancel_response.status_code == 404 + failure = Failure(**cancel_response.json()) + assert "failed to decode operation token" in failure.message.lower() + + +async def test_request_id_is_received_by_sync_operation( + env: WorkflowEnvironment, +): + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=MyService.__name__, + ) + + decorator = service_handler(service=MyService) + user_service_handler = decorator(MyServiceHandler)() + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[user_service_handler], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + request_id = str(uuid.uuid4()) + resp = await service_client.start_operation( + "idempotency_check", None, {"Nexus-Request-Id": request_id} + ) + assert resp.status_code == 200 + assert resp.json() == {"value": f"request_id: {request_id}"} + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: Input) -> Output: + return Output(value=input.value) + + +@service_handler +class 
ServiceHandlerForRequestIdTest: + @workflow_run_operation + async def operation_backed_by_a_workflow( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: + return await ctx.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + + @workflow_run_operation + async def operation_that_executes_a_workflow_before_starting_the_backing_workflow( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: + await nexus.client().start_workflow( + EchoWorkflow.run, + input, + id=input.value, + task_queue=nexus.info().task_queue, + ) + # This should fail. It will not fail if the Nexus request ID was incorrectly + # propagated to both StartWorkflow requests. + return await ctx.start_workflow( + EchoWorkflow.run, + input, + id=input.value, + id_reuse_policy=WorkflowIDReusePolicy.REJECT_DUPLICATE, + ) + + +async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnvironment): + # We send two Nexus requests that would start a workflow with the same workflow ID, + # using reuse_policy=REJECT_DUPLICATE. This would fail if they used different + # request IDs. However, when we use the same request ID, it does not fail, + # demonstrating that the Nexus Start Operation request ID has become the + # StartWorkflow request ID. + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=ServiceHandlerForRequestIdTest.__name__, + ) + + async def start_two_workflows_with_conflicting_workflow_ids( + request_ids: tuple[tuple[str, int, str], tuple[str, int, str]], + ): + workflow_id = str(uuid.uuid4()) + for request_id, status_code, error_message in request_ids: + resp = await service_client.start_operation( + "operation_backed_by_a_workflow", + dataclass_as_dict(Input(workflow_id)), + {"Nexus-Request-Id": request_id}, + ) + assert resp.status_code == status_code, ( + f"expected status code {status_code} " + f"but got {resp.status_code} for response content " + f"{pprint.pformat(resp.content.decode())}" + ) + if not error_message: + assert status_code == 201 + op_info = resp.json() + assert op_info["token"] + assert op_info["state"] == nexusrpc.OperationState.RUNNING.value + else: + assert status_code >= 400 + failure = Failure(**resp.json()) + assert failure.message == error_message + + async def start_two_workflows_in_a_single_operation( + request_id: str, status_code: int, error_message: str + ): + resp = await service_client.start_operation( + "operation_that_executes_a_workflow_before_starting_the_backing_workflow", + dataclass_as_dict(Input("test-workflow-id")), + {"Nexus-Request-Id": request_id}, + ) + assert resp.status_code == status_code + if error_message: + failure = Failure(**resp.json()) + assert failure.message == error_message + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[ServiceHandlerForRequestIdTest()], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + request_id_1, request_id_2 = str(uuid.uuid4()), str(uuid.uuid4()) + # Reusing the same request ID does not fail + await start_two_workflows_with_conflicting_workflow_ids( + ((request_id_1, 201, ""), (request_id_1, 201, "")) + ) + # Using a different request ID does fail + # TODO(nexus-prerelease) I think that this should be a 409 per the spec. Go and + # Java are not doing that. 
+ await start_two_workflows_with_conflicting_workflow_ids( + ( + (request_id_1, 201, ""), + (request_id_2, 500, "Workflow execution already started"), + ) + ) + # Two workflows started in the same operation should fail, since the Nexus + # request ID should be propagated to the backing workflow only. + await start_two_workflows_in_a_single_operation( + request_id_1, 500, "Workflow execution already started" + ) + + +def server_address(env: WorkflowEnvironment) -> str: + http_port = getattr(env, "_http_port", 7243) + return f"http://127.0.0.1:{http_port}" diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py new file mode 100644 index 000000000..82280f5bd --- /dev/null +++ b/tests/nexus/test_handler_async_operation.py @@ -0,0 +1,261 @@ +""" +Test that the Nexus SDK can be used to define an operation that responds asynchronously. +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import dataclasses +import uuid +from collections.abc import Coroutine +from dataclasses import dataclass, field +from typing import Any, Type, Union + +import nexusrpc +import nexusrpc.handler +import pytest +from nexusrpc import OperationInfo +from nexusrpc.handler import ( + CancelOperationContext, + FetchOperationInfoContext, + FetchOperationResultContext, + OperationHandler, + StartOperationContext, + StartOperationResultAsync, + service_handler, +) +from nexusrpc.handler._decorators import operation_handler + +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import ServiceClient, create_nexus_endpoint + + +@dataclass +class Input: + value: str + + +@dataclass +class Output: + value: str + + +@dataclass +class AsyncOperationWithAsyncDefs(OperationHandler[Input, Output]): + executor: TaskExecutor + + async def start( + self, ctx: StartOperationContext, input: Input + ) -> StartOperationResultAsync: + async def task() -> Output: + await asyncio.sleep(0.1) + return Output("Hello from async operation!") + + task_id = str(uuid.uuid4()) + await self.executor.add_task(task_id, task()) + return StartOperationResultAsync(token=task_id) + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> OperationInfo: + # status = self.executor.get_task_status(task_id=token) + # return OperationInfo(token=token, status=status) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" + ) + + async def fetch_result( + self, ctx: FetchOperationResultContext, token: str + ) -> Output: + # return await self.executor.get_task_result(task_id=token) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" + ) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self.executor.request_cancel_task(task_id=token) + + +@dataclass +class AsyncOperationWithNonAsyncDefs(OperationHandler[Input, Output]): + executor: TaskExecutor + + def start( + self, ctx: StartOperationContext, input: Input + ) -> StartOperationResultAsync: + async def task() -> Output: + await asyncio.sleep(0.1) + return Output("Hello from async operation!") + + task_id = str(uuid.uuid4()) + self.executor.add_task_sync(task_id, task()) + return StartOperationResultAsync(token=task_id) + + def fetch_info(self, ctx: FetchOperationInfoContext, token: str) -> OperationInfo: + # status = 
self.executor.get_task_status(task_id=token) + # return OperationInfo(token=token, status=status) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" + ) + + def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: + # return self.executor.get_task_result_sync(task_id=token) + raise NotImplementedError( + "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" + ) + + def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self.executor.request_cancel_task(task_id=token) + + +@dataclass +@service_handler +class MyServiceHandlerWithAsyncDefs: + executor: TaskExecutor + + @operation_handler + def async_operation(self) -> OperationHandler[Input, Output]: + return AsyncOperationWithAsyncDefs(self.executor) + + +@dataclass +@service_handler +class MyServiceHandlerWithNonAsyncDefs: + executor: TaskExecutor + + @operation_handler + def async_operation(self) -> OperationHandler[Input, Output]: + return AsyncOperationWithNonAsyncDefs(self.executor) + + +@pytest.mark.parametrize( + "service_handler_cls", + [ + MyServiceHandlerWithAsyncDefs, + MyServiceHandlerWithNonAsyncDefs, + ], +) +async def test_async_operation_lifecycle( + env: WorkflowEnvironment, + service_handler_cls: Union[ + Type[MyServiceHandlerWithAsyncDefs], + Type[MyServiceHandlerWithNonAsyncDefs], + ], +): + task_executor = await TaskExecutor.connect() + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + service_client = ServiceClient( + f"http://127.0.0.1:{env._http_port}", # type: ignore + endpoint, + service_handler_cls.__name__, + ) + + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler_cls(task_executor)], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + start_response = await service_client.start_operation( + "async_operation", + body=dataclass_as_dict(Input(value="Hello from test")), + ) + assert start_response.status_code == 201 + assert start_response.json()["token"] + assert start_response.json()["state"] == "running" + + # Cancel it + cancel_response = await service_client.cancel_operation( + "async_operation", + token=start_response.json()["token"], + ) + assert cancel_response.status_code == 202 + + # get_info and get_result not implemented by server + + +@dataclass +class TaskExecutor: + """ + This class represents the task execution platform being used by the team operating the + Nexus operation. + """ + + event_loop: asyncio.AbstractEventLoop + tasks: dict[str, asyncio.Task[Any]] = field(default_factory=dict) + + @classmethod + async def connect(cls) -> TaskExecutor: + return cls(event_loop=asyncio.get_running_loop()) + + async def add_task(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: + """ + Add a task to the task execution platform. + """ + if task_id in self.tasks: + raise RuntimeError(f"Task with id {task_id} already exists") + + # This function is async def because in reality this step will often write to + # durable storage. + self.tasks[task_id] = asyncio.create_task(coro) + + def add_task_sync(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: + """ + Add a task to the task execution platform from a sync context. 
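+
+        The coroutine is scheduled onto the executor's event loop with
+        asyncio.run_coroutine_threadsafe and this call blocks until it completes.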
+ """ + asyncio.run_coroutine_threadsafe( + self.add_task(task_id, coro), self.event_loop + ).result() + + def get_task_status(self, task_id: str) -> nexusrpc.OperationState: + task = self.tasks[task_id] + if not task.done(): + return nexusrpc.OperationState.RUNNING + elif task.cancelled(): + return nexusrpc.OperationState.CANCELED + elif task.exception(): + return nexusrpc.OperationState.FAILED + else: + return nexusrpc.OperationState.SUCCEEDED + + async def get_task_result(self, task_id: str) -> Any: + """ + Get the result of a task from the task execution platform. + """ + task = self.tasks.get(task_id) + if not task: + raise RuntimeError(f"Task not found with id {task_id}") + return await task + + def get_task_result_sync(self, task_id: str) -> Any: + """ + Get the result of a task from the task execution platform from a sync context. + """ + return asyncio.run_coroutine_threadsafe( + self.get_task_result(task_id), self.event_loop + ).result() + + def request_cancel_task(self, task_id: str) -> None: + """ + Request cancellation of a task on the task execution platform. + """ + task = self.tasks.get(task_id) + if not task: + raise RuntimeError(f"Task not found with id {task_id}") + task.cancel() + # Not implemented: cancellation confirmation, deletion on cancellation + + +def dataclass_as_dict(dataclass: Any) -> dict[str, Any]: + """ + Return a shallow dict of the dataclass's fields. + + dataclasses.as_dict goes too far (attempts to pickle values) + """ + return { + field.name: getattr(dataclass, field.name) + for field in dataclasses.fields(dataclass) + } diff --git a/tests/nexus/test_handler_interface_implementation.py b/tests/nexus/test_handler_interface_implementation.py new file mode 100644 index 000000000..be98ff6d6 --- /dev/null +++ b/tests/nexus/test_handler_interface_implementation.py @@ -0,0 +1,62 @@ +from typing import Any, Optional, Type + +import nexusrpc +import nexusrpc.handler +import pytest +from nexusrpc.handler import StartOperationContext, sync_operation + +from temporalio import nexus +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation + +HTTP_PORT = 7243 + + +class _InterfaceImplementationTestCase: + Interface: Type[Any] + Impl: Type[Any] + error_message: Optional[str] + + +class ValidImpl(_InterfaceImplementationTestCase): + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[None, None] + + class Impl: + @sync_operation + async def op(self, ctx: StartOperationContext, input: None) -> None: ... + + error_message = None + + +class ValidWorkflowRunImpl(_InterfaceImplementationTestCase): + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[str, int] + + class Impl: + @workflow_run_operation + async def op( + self, ctx: WorkflowRunOperationContext, input: str + ) -> nexus.WorkflowHandle[int]: ... 
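+
+    # A conforming @workflow_run_operation implementation: error_message = None
+    # below means the service_handler decorator is expected to accept Impl as-is.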
+ + error_message = None + + +@pytest.mark.parametrize( + "test_case", + [ + ValidImpl, + ValidWorkflowRunImpl, + ], +) +def test_service_decorator_enforces_interface_conformance( + test_case: Type[_InterfaceImplementationTestCase], +): + if test_case.error_message: + with pytest.raises(Exception) as ei: + nexusrpc.handler.service_handler(test_case.Interface)(test_case.Impl) + err = ei.value + assert test_case.error_message in str(err) + else: + nexusrpc.handler.service_handler(service=test_case.Interface)(test_case.Impl) diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py new file mode 100644 index 000000000..ce124a8b0 --- /dev/null +++ b/tests/nexus/test_handler_operation_definitions.py @@ -0,0 +1,100 @@ +""" +Test that operation_handler decorator results in operation definitions with the correct name +and input/output types. +""" + +from dataclasses import dataclass +from typing import Any, Type + +import nexusrpc.handler +import pytest + +from temporalio import nexus +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus._util import get_operation_factory + + +@dataclass +class Input: + pass + + +@dataclass +class Output: + pass + + +@dataclass +class _TestCase: + Service: Type[Any] + expected_operations: dict[str, nexusrpc.Operation] + + +class NotCalled(_TestCase): + @nexusrpc.handler.service_handler + class Service: + @workflow_run_operation + async def my_workflow_run_operation_handler( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: ... + + expected_operations = { + "my_workflow_run_operation_handler": nexusrpc.Operation( + name="my_workflow_run_operation_handler", + method_name="my_workflow_run_operation_handler", + input_type=Input, + output_type=Output, + ), + } + + +class CalledWithoutArgs(_TestCase): + @nexusrpc.handler.service_handler + class Service: + @workflow_run_operation + async def my_workflow_run_operation_handler( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: ... + + expected_operations = NotCalled.expected_operations + + +class CalledWithNameOverride(_TestCase): + @nexusrpc.handler.service_handler + class Service: + @workflow_run_operation(name="operation-name") + async def workflow_run_operation_with_name_override( + self, ctx: WorkflowRunOperationContext, input: Input + ) -> nexus.WorkflowHandle[Output]: ... 
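+
+    # The expected wire-level operation name is the override ("operation-name"),
+    # while method_name still records the Python method it is defined on.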
+ + expected_operations = { + "workflow_run_operation_with_name_override": nexusrpc.Operation( + name="operation-name", + method_name="workflow_run_operation_with_name_override", + input_type=Input, + output_type=Output, + ), + } + + +@pytest.mark.parametrize( + "test_case", + [ + NotCalled, + CalledWithoutArgs, + CalledWithNameOverride, + ], +) +@pytest.mark.asyncio +async def test_collected_operation_names( + test_case: Type[_TestCase], +): + service_defn = nexusrpc.get_service_definition(test_case.Service) + assert isinstance(service_defn, nexusrpc.ServiceDefinition) + assert service_defn.name == "Service" + for method_name, expected_op in test_case.expected_operations.items(): + _, actual_op = get_operation_factory(getattr(test_case.Service, method_name)) + assert isinstance(actual_op, nexusrpc.Operation) + assert actual_op.name == expected_op.name + assert actual_op.input_type == expected_op.input_type + assert actual_op.output_type == expected_op.output_type diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py new file mode 100644 index 000000000..cded6daf7 --- /dev/null +++ b/tests/nexus/test_workflow_caller.py @@ -0,0 +1,1760 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from enum import IntEnum +from itertools import zip_longest +from typing import Any, Awaitable, Callable, Literal, Union + +import nexusrpc +import nexusrpc.handler +import pytest +from nexusrpc.handler import ( + CancelOperationContext, + FetchOperationInfoContext, + FetchOperationResultContext, + OperationHandler, + StartOperationContext, + StartOperationResultAsync, + StartOperationResultSync, + service_handler, + sync_operation, +) +from nexusrpc.handler._decorators import operation_handler + +import temporalio.api +import temporalio.api.common +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.nexus +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice +import temporalio.api.operatorservice.v1 +import temporalio.exceptions +import temporalio.nexus._operation_handlers +from temporalio import nexus, workflow +from temporalio.client import ( + Client, + WithStartWorkflowOperation, + WorkflowExecutionStatus, + WorkflowFailureError, + WorkflowHandle, +) +from temporalio.common import WorkflowIDConflictPolicy +from temporalio.exceptions import ( + ApplicationError, + CancelledError, + NexusOperationError, + TimeoutError, +) +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.service import RPCError, RPCStatusCode +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name + +# TODO(nexus-prerelease): test availability of Temporal client etc in async context set by worker +# TODO(nexus-preview): test worker shutdown, wait_all_completed, drain etc + +# ----------------------------------------------------------------------------- +# Test definition +# + + +class CallerReference(IntEnum): + IMPL_WITHOUT_INTERFACE = 0 + IMPL_WITH_INTERFACE = 1 + INTERFACE = 2 + + +class OpDefinitionType(IntEnum): + SHORTHAND = 0 + LONGHAND = 1 + + +@dataclass +class SyncResponse: + op_definition_type: OpDefinitionType + use_async_def: bool + exception_in_operation_start: bool + + +@dataclass +class AsyncResponse: + operation_workflow_id: str + block_forever_waiting_for_cancellation: bool + op_definition_type: OpDefinitionType + exception_in_operation_start: bool + + +# The order of the two 
types in this union is critical since the data converter matches +# eagerly, ignoring unknown fields, and so would identify an AsyncResponse as a +# SyncResponse if SyncResponse came first. +ResponseType = Union[AsyncResponse, SyncResponse] + +# ----------------------------------------------------------------------------- +# Service interface +# + + +@dataclass +class OpInput: + response_type: ResponseType + headers: dict[str, str] + caller_reference: CallerReference + + +@dataclass +class OpOutput: + value: str + + +@dataclass +class HandlerWfInput: + op_input: OpInput + + +@dataclass +class HandlerWfOutput: + value: str + + +@nexusrpc.service +class ServiceInterface: + sync_or_async_operation: nexusrpc.Operation[OpInput, OpOutput] + sync_operation: nexusrpc.Operation[OpInput, OpOutput] + async_operation: nexusrpc.Operation[OpInput, HandlerWfOutput] + + +# ----------------------------------------------------------------------------- +# Service implementation +# + + +@workflow.defn +class HandlerWorkflow: + @workflow.run + async def run( + self, + input: HandlerWfInput, + ) -> HandlerWfOutput: + assert isinstance(input.op_input.response_type, AsyncResponse) + if input.op_input.response_type.block_forever_waiting_for_cancellation: + await asyncio.Future() + return HandlerWfOutput( + value="workflow result", + ) + + +# TODO(nexus-prerelease): check type-checking passing in CI + + +class SyncOrAsyncOperation(OperationHandler[OpInput, OpOutput]): + async def start( # type: ignore[override] + self, ctx: StartOperationContext, input: OpInput + ) -> Union[ + StartOperationResultSync[OpOutput], + StartOperationResultAsync, + ]: + if input.response_type.exception_in_operation_start: + # TODO(nexus-prerelease): don't think RPCError should be used here + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + if isinstance(input.response_type, SyncResponse): + return StartOperationResultSync(value=OpOutput(value="sync response")) + elif isinstance(input.response_type, AsyncResponse): + # TODO(nexus-preview): what do we want the DX to be for a user who is + # starting a Nexus backing workflow from a custom start method? (They may + # need to do this in order to customize the cancel method). 
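+            # Convert the generic nexusrpc start context into the Temporal-specific
+            # context so the operation-backing workflow can be started via
+            # start_workflow() below.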
+ tctx = WorkflowRunOperationContext.from_start_operation_context(ctx) + handle = await tctx.start_workflow( + HandlerWorkflow.run, + HandlerWfInput(op_input=input), + id=input.response_type.operation_workflow_id, + ) + return StartOperationResultAsync(handle.to_token()) + else: + raise TypeError + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + return await temporalio.nexus._operation_handlers._cancel_workflow(token) + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> nexusrpc.OperationInfo: + raise NotImplementedError + + async def fetch_result( + self, ctx: FetchOperationResultContext, token: str + ) -> OpOutput: + raise NotImplementedError + + +@service_handler(service=ServiceInterface) +class ServiceImpl: + @operation_handler + def sync_or_async_operation( + self, + ) -> OperationHandler[OpInput, OpOutput]: + return SyncOrAsyncOperation() + + @sync_operation + async def sync_operation( + self, ctx: StartOperationContext, input: OpInput + ) -> OpOutput: + assert isinstance(input.response_type, SyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + return OpOutput(value="sync response") + + @workflow_run_operation + async def async_operation( + self, ctx: WorkflowRunOperationContext, input: OpInput + ) -> nexus.WorkflowHandle[HandlerWfOutput]: + assert isinstance(input.response_type, AsyncResponse) + if input.response_type.exception_in_operation_start: + raise RPCError( + "RPCError INVALID_ARGUMENT in Nexus operation", + RPCStatusCode.INVALID_ARGUMENT, + b"", + ) + return await ctx.start_workflow( + HandlerWorkflow.run, + HandlerWfInput(op_input=input), + id=input.response_type.operation_workflow_id, + ) + + +# ----------------------------------------------------------------------------- +# Caller workflow +# + + +@dataclass +class CallerWfInput: + op_input: OpInput + + +@dataclass +class CallerWfOutput: + op_output: OpOutput + + +@workflow.defn +class CallerWorkflow: + """ + A workflow that executes a Nexus operation, specifying whether it should return + synchronously or asynchronously. + """ + + @workflow.init + def __init__( + self, + input: CallerWfInput, + request_cancel: bool, + task_queue: str, + ) -> None: + self.nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(task_queue), + service={ + CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, + CallerReference.INTERFACE: ServiceInterface, + }[input.op_input.caller_reference], + ) + self._nexus_operation_started = False + self._proceed = False + + @workflow.run + async def run( + self, + input: CallerWfInput, + request_cancel: bool, + task_queue: str, + ) -> CallerWfOutput: + op_input = input.op_input + op_handle = await self.nexus_client.start_operation( + self._get_operation(op_input), + op_input, + headers=op_input.headers, + ) + self._nexus_operation_started = True + if not input.op_input.response_type.exception_in_operation_start: + if isinstance(input.op_input.response_type, SyncResponse): + assert ( + op_handle.operation_token is None + ), "operation_token should be absent after a sync response" + else: + assert ( + op_handle.operation_token + ), "operation_token should be present after an async response" + + if request_cancel: + # Even for SyncResponse, the op_handle future is not done at this point; that + # transition doesn't happen until the handle is awaited. 
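+            # Request cancellation of the in-flight operation. For a sync response
+            # this cannot take effect (the result is already determined); for an
+            # async response the backing workflow is expected to be cancelled (see
+            # test_async_response).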
+ assert op_handle.cancel() + op_output = await op_handle + return CallerWfOutput(op_output=OpOutput(value=op_output.value)) + + @workflow.update + async def wait_nexus_operation_started(self) -> None: + await workflow.wait_condition(lambda: self._nexus_operation_started) + + @staticmethod + def _get_operation( + op_input: OpInput, + ) -> Union[ + nexusrpc.Operation[OpInput, OpOutput], + Callable[[Any], OperationHandler[OpInput, OpOutput]], + ]: + return { # type: ignore[return-value] + ( + SyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.sync_operation, + ( + SyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.sync_operation, + ( + SyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.sync_or_async_operation, + ( + SyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.sync_or_async_operation, + ( + AsyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.async_operation, + ( + AsyncResponse, + OpDefinitionType.SHORTHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.async_operation, + ( + AsyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.IMPL_WITH_INTERFACE, + True, + ): ServiceImpl.sync_or_async_operation, + ( + AsyncResponse, + OpDefinitionType.LONGHAND, + CallerReference.INTERFACE, + True, + ): ServiceInterface.sync_or_async_operation, + }[ + {True: SyncResponse, False: AsyncResponse}[ + isinstance(op_input.response_type, SyncResponse) + ], + op_input.response_type.op_definition_type, + op_input.caller_reference, + ( + op_input.response_type.use_async_def + if isinstance(op_input.response_type, SyncResponse) + else True + ), + ] + + +@workflow.defn +class UntypedCallerWorkflow: + @workflow.init + def __init__( + self, input: CallerWfInput, request_cancel: bool, task_queue: str + ) -> None: + # TODO(nexus-preview): untyped caller cannot reference name of implementation. I think this is as it should be. + service_name = "ServiceInterface" + self.nexus_client: workflow.NexusClient[Any] = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(task_queue), + service=service_name, + ) + + @workflow.run + async def run( + self, input: CallerWfInput, request_cancel: bool, task_queue: str + ) -> CallerWfOutput: + op_input = input.op_input + if op_input.response_type.op_definition_type == OpDefinitionType.LONGHAND: + op_name = "sync_or_async_operation" + elif isinstance(op_input.response_type, AsyncResponse): + op_name = "async_operation" + elif isinstance(op_input.response_type, SyncResponse): + op_name = "sync_operation" + else: + raise TypeError + + arbitrary_condition = isinstance(op_input.response_type, SyncResponse) + + if arbitrary_condition: + op_handle = await self.nexus_client.start_operation( + op_name, + op_input, + headers=op_input.headers, + output_type=OpOutput, + ) + op_output = await op_handle + else: + op_output = await self.nexus_client.execute_operation( + op_name, + op_input, + headers=op_input.headers, + output_type=OpOutput, + ) + return CallerWfOutput(op_output=OpOutput(value=op_output.value)) + + +# ----------------------------------------------------------------------------- +# Tests +# + + +# TODO(nexus-preview): cross-namespace tests +# TODO(nexus-preview): nexus endpoint pytest fixture? 
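+# The tests below parametrize over the caller reference (implementation vs
+# interface), the operation definition style (shorthand vs longhand), whether
+# cancellation is requested, and whether the handler raises during operation start.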
+# TODO(nexus-prerelease): test headers +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize("request_cancel", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) +@pytest.mark.parametrize( + "caller_reference", + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +) +async def test_sync_response( + client: Client, + exception_in_operation_start: bool, + request_cancel: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ServiceImpl()], + workflows=[CallerWorkflow, HandlerWorkflow], + task_queue=task_queue, + workflow_failure_exception_types=[Exception], + ): + await create_nexus_endpoint(task_queue, client) + caller_wf_handle = await client.start_workflow( + CallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=SyncResponse( + op_definition_type=op_definition_type, + use_async_def=True, + exception_in_operation_start=exception_in_operation_start, + ), + headers={"header-key": "header-value"}, + caller_reference=caller_reference, + ), + ), + request_cancel, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + + # TODO(nexus-prerelease): check bidi links for sync operation + + # The operation result is returned even when request_cancel=True, because the + # response was synchronous and it could not be cancelled. See explanation below. + if exception_in_operation_start: + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) + # ID of first command + assert e.__cause__.scheduled_event_id == 5 + assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) + assert e.__cause__.service == "ServiceInterface" + assert ( + e.__cause__.operation == "sync_operation" + if op_definition_type == OpDefinitionType.SHORTHAND + else "sync_or_async_operation" + ) + else: + result = await caller_wf_handle.result() + assert result.op_output.value == "sync response" + + +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize("request_cancel", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) +@pytest.mark.parametrize( + "caller_reference", + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +) +async def test_async_response( + client: Client, + exception_in_operation_start: bool, + request_cancel: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ServiceImpl()], + workflows=[CallerWorkflow, HandlerWorkflow], + task_queue=task_queue, + workflow_failure_exception_types=[Exception], + ): + caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op( + client, + task_queue, + exception_in_operation_start, + request_cancel, + op_definition_type, + caller_reference, + ) + if exception_in_operation_start: + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, 
nexusrpc.HandlerError) + # ID of first command after update accepted + assert e.__cause__.scheduled_event_id == 6 + assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) + assert e.__cause__.service == "ServiceInterface" + assert ( + e.__cause__.operation == "async_operation" + if op_definition_type == OpDefinitionType.SHORTHAND + else "sync_or_async_operation" + ) + return + + # TODO(nexus-prerelease): race here? How do we know it hasn't been canceled already? + handler_wf_info = await handler_wf_handle.describe() + assert handler_wf_info.status in [ + WorkflowExecutionStatus.RUNNING, + WorkflowExecutionStatus.COMPLETED, + ] + await assert_caller_workflow_has_link_to_handler_workflow( + caller_wf_handle, handler_wf_handle, handler_wf_info.run_id + ) + await assert_handler_workflow_has_link_to_caller_workflow( + caller_wf_handle, handler_wf_handle + ) + + if request_cancel: + # The operation response was asynchronous and so request_cancel is honored. See + # explanation below. + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, CancelledError) + # ID of first command after update accepted + assert e.__cause__.scheduled_event_id == 6 + assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue) + assert e.__cause__.service == "ServiceInterface" + assert ( + e.__cause__.operation == "async_operation" + if op_definition_type == OpDefinitionType.SHORTHAND + else "sync_or_async_operation" + ) + assert nexus.WorkflowHandle.from_token( + e.__cause__.operation_token + ) == nexus.WorkflowHandle( + namespace=handler_wf_handle._client.namespace, + workflow_id=handler_wf_handle.id, + ) + # Check that the handler workflow was canceled + handler_wf_info = await handler_wf_handle.describe() + assert handler_wf_info.status == WorkflowExecutionStatus.CANCELED + else: + handler_wf_info = await handler_wf_handle.describe() + assert handler_wf_info.status == WorkflowExecutionStatus.COMPLETED + result = await caller_wf_handle.result() + assert result.op_output.value == "workflow result" + + +async def _start_wf_and_nexus_op( + client: Client, + task_queue: str, + exception_in_operation_start: bool, + request_cancel: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, +) -> tuple[ + WorkflowHandle[CallerWorkflow, CallerWfOutput], + WorkflowHandle[HandlerWorkflow, HandlerWfOutput], +]: + """ + Start the caller workflow and wait until the Nexus operation has started. + """ + await create_nexus_endpoint(task_queue, client) + operation_workflow_id = str(uuid.uuid4()) + + # Start the caller workflow and wait until it confirms the Nexus operation has started. 
+ block_forever_waiting_for_cancellation = request_cancel + start_op = WithStartWorkflowOperation( + CallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=AsyncResponse( + operation_workflow_id, + block_forever_waiting_for_cancellation, + op_definition_type, + exception_in_operation_start=exception_in_operation_start, + ), + headers={"header-key": "header-value"}, + caller_reference=caller_reference, + ), + ), + request_cancel, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + ) + + await client.execute_update_with_start_workflow( + CallerWorkflow.wait_nexus_operation_started, + start_workflow_operation=start_op, + ) + caller_wf_handle = await start_op.workflow_handle() + + # check that the operation-backing workflow now exists, and that (a) the handler + # workflow accepted the link to the calling Nexus event, and that (b) the caller + # workflow NexusOperationStarted event received in return a link to the + # operation-backing workflow. + handler_wf_handle: WorkflowHandle[HandlerWorkflow, HandlerWfOutput] = ( + client.get_workflow_handle(operation_workflow_id) + ) + return caller_wf_handle, handler_wf_handle + + +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) +@pytest.mark.parametrize( + "caller_reference", + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +) +@pytest.mark.parametrize("response_type", [SyncResponse, AsyncResponse]) +async def test_untyped_caller( + client: Client, + exception_in_operation_start: bool, + op_definition_type: OpDefinitionType, + caller_reference: CallerReference, + response_type: ResponseType, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + workflows=[UntypedCallerWorkflow, HandlerWorkflow], + nexus_service_handlers=[ServiceImpl()], + task_queue=task_queue, + workflow_failure_exception_types=[Exception], + ): + if response_type == SyncResponse: + response_type = SyncResponse( + op_definition_type=op_definition_type, + use_async_def=True, + exception_in_operation_start=exception_in_operation_start, + ) + else: + response_type = AsyncResponse( + operation_workflow_id=str(uuid.uuid4()), + block_forever_waiting_for_cancellation=False, + op_definition_type=op_definition_type, + exception_in_operation_start=exception_in_operation_start, + ) + await create_nexus_endpoint(task_queue, client) + caller_wf_handle = await client.start_workflow( + UntypedCallerWorkflow.run, + args=[ + CallerWfInput( + op_input=OpInput( + response_type=response_type, + headers={}, + caller_reference=caller_reference, + ), + ), + False, + task_queue, + ], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + if exception_in_operation_start: + with pytest.raises(WorkflowFailureError) as ei: + await caller_wf_handle.result() + e = ei.value + assert isinstance(e, WorkflowFailureError) + assert isinstance(e.__cause__, NexusOperationError) + assert isinstance(e.__cause__.__cause__, nexusrpc.HandlerError) + else: + result = await caller_wf_handle.result() + assert result.op_output.value == ( + "sync response" + if isinstance(response_type, SyncResponse) + else "workflow result" + ) + + +# +# Test routing of workflow calls +# + + +@dataclass +class ServiceClassNameOutput: + name: str + + +# TODO(nexus-prerelease): async and non-async cancel methods + + +@nexusrpc.service +class ServiceInterfaceWithoutNameOverride: + op: 
nexusrpc.Operation[None, ServiceClassNameOutput] + + +@nexusrpc.service(name="service-interface-🌈") +class ServiceInterfaceWithNameOverride: + op: nexusrpc.Operation[None, ServiceClassNameOutput] + + +@service_handler +class ServiceImplInterfaceWithNeitherInterfaceNorNameOverride: + @sync_operation + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +@service_handler(service=ServiceInterfaceWithoutNameOverride) +class ServiceImplInterfaceWithoutNameOverride: + @sync_operation + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +@service_handler(service=ServiceInterfaceWithNameOverride) +class ServiceImplInterfaceWithNameOverride: + @sync_operation + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +@service_handler(name="service-impl-🌈") +class ServiceImplWithNameOverride: + @sync_operation + async def op( + self, ctx: StartOperationContext, input: None + ) -> ServiceClassNameOutput: + return ServiceClassNameOutput(self.__class__.__name__) + + +class NameOverride(IntEnum): + NO = 0 + YES = 1 + + +@workflow.defn +class ServiceInterfaceAndImplCallerWorkflow: + @workflow.run + async def run( + self, + caller_reference: CallerReference, + name_override: NameOverride, + task_queue: str, + ) -> ServiceClassNameOutput: + C, N = CallerReference, NameOverride + service_cls: type + if (caller_reference, name_override) == (C.INTERFACE, N.YES): + service_cls = ServiceInterfaceWithNameOverride + elif (caller_reference, name_override) == (C.INTERFACE, N.NO): + service_cls = ServiceInterfaceWithoutNameOverride + elif (caller_reference, name_override) == (C.IMPL_WITH_INTERFACE, N.YES): + service_cls = ServiceImplWithNameOverride + elif (caller_reference, name_override) == (C.IMPL_WITH_INTERFACE, N.NO): + service_cls = ServiceImplInterfaceWithoutNameOverride + elif (caller_reference, name_override) == (C.IMPL_WITHOUT_INTERFACE, N.NO): + service_cls = ServiceImplInterfaceWithNameOverride + service_cls = ServiceImplInterfaceWithNeitherInterfaceNorNameOverride + else: + raise ValueError( + f"Invalid combination of caller_reference ({caller_reference}) and name_override ({name_override})" + ) + + nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(task_queue), + service=service_cls, + ) + + return await nexus_client.execute_operation(service_cls.op, None) # type: ignore + + +# TODO(nexus-prerelease): check missing decorator behavior + + +async def test_service_interface_and_implementation_names(client: Client): + # Note that: + # - The caller can specify the service & operation via a reference to either the + # interface or implementation class. + # - An interface class may optionally override its name. + # - An implementation class may either override its name or specify an interface that + # it is implementing, but not both. + # - On registering a service implementation with a worker, the name by which the + # service is addressed in requests is the interface name if the implementation + # supplies one, or else the name override made by the impl class, or else the impl + # class name. + # + # This test checks that the request is routed to the expected service under a variety + # of scenarios related to the above considerations. 
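+    # A rough sketch of that name-resolution precedence (illustrative pseudo-logic only;
+    # the helper names below are hypothetical, not the SDK's actual internals):
+    #
+    #   def addressed_service_name(impl_cls) -> str:
+    #       interface = interface_of(impl_cls)       # from @service_handler(service=...)
+    #       if interface is not None:
+    #           return interface_name(interface)     # interface name, incl. any override
+    #       override = name_override_of(impl_cls)    # from @service_handler(name=...)
+    #       if override is not None:
+    #           return override
+    #       return impl_cls.__name__                 # fall back to the impl class name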
+ task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ + ServiceImplWithNameOverride(), + ServiceImplInterfaceWithNameOverride(), + ServiceImplInterfaceWithoutNameOverride(), + ServiceImplInterfaceWithNeitherInterfaceNorNameOverride(), + ], + workflows=[ServiceInterfaceAndImplCallerWorkflow], + task_queue=task_queue, + workflow_failure_exception_types=[Exception], + ): + await create_nexus_endpoint(task_queue, client) + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=(CallerReference.INTERFACE, NameOverride.YES, task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplInterfaceWithNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=(CallerReference.INTERFACE, NameOverride.NO, task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplInterfaceWithoutNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=( + CallerReference.IMPL_WITH_INTERFACE, + NameOverride.YES, + task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplWithNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=( + CallerReference.IMPL_WITH_INTERFACE, + NameOverride.NO, + task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput("ServiceImplInterfaceWithoutNameOverride") + assert await client.execute_workflow( + ServiceInterfaceAndImplCallerWorkflow.run, + args=( + CallerReference.IMPL_WITHOUT_INTERFACE, + NameOverride.NO, + task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) == ServiceClassNameOutput( + "ServiceImplInterfaceWithNeitherInterfaceNorNameOverride" + ) + + +@nexusrpc.service +class ServiceWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: + my_workflow_run_operation: nexusrpc.Operation[None, None] + my_manual_async_operation: nexusrpc.Operation[None, None] + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: str) -> str: + return input + + +@service_handler +class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow: + @workflow_run_operation + async def my_workflow_run_operation( + self, ctx: WorkflowRunOperationContext, input: None + ) -> nexus.WorkflowHandle[str]: + result_1 = await nexus.client().execute_workflow( + EchoWorkflow.run, + "result-1", + id=str(uuid.uuid4()), + task_queue=nexus.info().task_queue, + ) + # In case result_1 is incorrectly being delivered to the caller as the operation + # result, give time for that incorrect behavior to occur. 
+ await asyncio.sleep(0.5) + return await ctx.start_workflow( + EchoWorkflow.run, + f"{result_1}-result-2", + id=str(uuid.uuid4()), + ) + + +@workflow.defn +class WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow: + @workflow.run + async def run(self, input: str, task_queue: str) -> str: + nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(task_queue), + service=ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow, + ) + return await nexus_client.execute_operation( + ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow.my_workflow_run_operation, + None, + ) + + +async def test_workflow_run_operation_can_execute_workflow_before_starting_backing_workflow( + client: Client, +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + workflows=[ + EchoWorkflow, + WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow, + ], + nexus_service_handlers=[ + ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow(), + ], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + result = await client.execute_workflow( + WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow.run, + args=("result-1", task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert result == "result-1-result-2" + + +# TODO(nexus-prerelease): test invalid service interface implementations +# TODO(nexus-prerelease): test caller passing output_type + + +async def assert_caller_workflow_has_link_to_handler_workflow( + caller_wf_handle: WorkflowHandle, + handler_wf_handle: WorkflowHandle, + handler_wf_run_id: str, +): + caller_history = await caller_wf_handle.fetch_history() + op_started_event = next( + e + for e in caller_history.events + if ( + e.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_STARTED + ) + ) + if not len(op_started_event.links) == 1: + pytest.fail( + f"Expected 1 link on NexusOperationStarted event, got {len(op_started_event.links)}" + ) + [link] = op_started_event.links + assert link.workflow_event.namespace == handler_wf_handle._client.namespace + assert link.workflow_event.workflow_id == handler_wf_handle.id + assert link.workflow_event.run_id + assert link.workflow_event.run_id == handler_wf_run_id + assert ( + link.workflow_event.event_ref.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ) + + +async def assert_handler_workflow_has_link_to_caller_workflow( + caller_wf_handle: WorkflowHandle, + handler_wf_handle: WorkflowHandle, +): + handler_history = await handler_wf_handle.fetch_history() + wf_started_event = next( + e + for e in handler_history.events + if ( + e.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ) + ) + if not len(wf_started_event.links) == 1: + pytest.fail( + f"Expected 1 link on WorkflowExecutionStarted event, got {len(wf_started_event.links)}" + ) + [link] = wf_started_event.links + assert link.workflow_event.namespace == caller_wf_handle._client.namespace + assert link.workflow_event.workflow_id == caller_wf_handle.id + assert link.workflow_event.run_id + assert link.workflow_event.run_id == caller_wf_handle.first_execution_run_id + assert ( + link.workflow_event.event_ref.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED + ) + + +# When request_cancel is True, the NexusOperationHandle in the workflow evolves +# through the 
following states: +# start_fut result_fut handle_task w/ fut_waiter (task._must_cancel) +# +# Case 1: Sync Nexus operation response w/ cancellation of NexusOperationHandle +# ----------------------------------------------------------------------------- +# >>>>>>>>>>>> WFT 1 +# after await start : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False) +# before op_handle.cancel : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False) +# Future_8240[FINISHED].cancel() -> False # no state transition; fut_waiter is already finished +# cancel returned : True +# before await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (True) +# --> Despite cancel having been requested, this await on the nexus op handle does not +# raise CancelledError, because the task's underlying fut_waiter is already finished. +# after await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[FINISHED] fut_waiter = None) (False) +# +# +# Case 2: Async Nexus operation response w/ cancellation of NexusOperationHandle +# ------------------------------------------------------------------------------ +# >>>>>>>>>>>> WFT 1 +# after await start : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False) +# >>>>>>>>>>>> WFT 2 +# >>>>>>>>>>>> WFT 3 +# after await proceed : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False) +# before op_handle.cancel : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False) +# Future_7952[PENDING].cancel() -> True # transition to cancelled state; fut_waiter was not finished +# cancel returned : True +# before await op_handle : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[CANCELLED]) (False) +# --> This await on the nexus op handle raises CancelledError, because the task's underlying fut_waiter is cancelled. +# +# Thus in the sync case, although the caller workflow attempted to cancel the +# NexusOperationHandle, this did not result in a CancelledError when the handle was +# awaited, because both resolve_nexus_operation_start and resolve_nexus_operation jobs +# were sent in the same activation and hence the task's fut_waiter was already finished. +# +# But in the async case, at the time that we cancel the NexusOperationHandle, only the +# resolve_nexus_operation_start job had been sent; the result_fut was unresolved. Thus +# when the handle was awaited, CancelledError was raised. 
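+#
+# A minimal plain-asyncio sketch (outside the SDK, for illustration) of the key
+# mechanism above: calling cancel() on a Future that is already FINISHED is a no-op
+# that returns False, so no cancellation can propagate through it.
+#
+#   import asyncio
+#
+#   async def main() -> None:
+#       fut = asyncio.get_running_loop().create_future()
+#       fut.set_result("done")
+#       assert fut.cancel() is False    # already finished: no state transition
+#       assert not fut.cancelled()
+#       assert await fut == "done"      # resolves normally; no CancelledError
+#
+#   asyncio.run(main())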
+# +# To create output like that above, set the following __repr__s: +# asyncio.Future: +# def __repr__(self): +# return f"{self.__class__.__name__}_{str(id(self))[-4:]}[{self._state}]" +# _NexusOperationHandle: +# def __repr__(self) -> str: +# return ( +# f"{self._start_fut} " +# f"{self._result_fut} " +# f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" +# ) + + +# Handler + +# @OperationImpl +# public OperationHandler testError() { +# return OperationHandler.sync( +# (ctx, details, input) -> { +# switch (input.getAction()) { +# case RAISE_APPLICATION_ERROR: +# throw ApplicationFailure.newNonRetryableFailure( +# "application error 1", "APPLICATION_ERROR"); +# case RAISE_CUSTOM_ERROR: +# throw new MyCustomException("Custom error 1"); +# case RAISE_CUSTOM_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# // ** THIS DOESN'T WORK **: CHAINED CUSTOM EXCEPTIONS DON'T SERIALIZE +# MyCustomException customError = new MyCustomException("Custom error 1"); +# customError.initCause(new MyCustomException("Custom error 2")); +# throw customError; +# case RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# throw ApplicationFailure.newNonRetryableFailureWithCause( +# "application error 1", +# "APPLICATION_ERROR", +# new MyCustomException("Custom error 2")); +# case RAISE_NEXUS_HANDLER_ERROR: +# throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); +# case RAISE_NEXUS_HANDLER_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# // ** THIS DOESN'T WORK ** +# // Can't overwrite cause with +# // io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException: Custom error +# // 2 +# HandlerException handlerErr = +# new HandlerException(HandlerException.ErrorType.NOT_FOUND, "Handler error 1"); +# handlerErr.initCause(new MyCustomException("Custom error 2")); +# throw handlerErr; +# case RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# throw OperationException.failure( +# ApplicationFailure.newNonRetryableFailureWithCause( +# "application error 1", +# "APPLICATION_ERROR", +# new MyCustomException("Custom error 2"))); +# } +# return new NexusService.ErrorTestOutput("Unreachable"); +# }); +# } + +# 🌈 RAISE_APPLICATION_ERROR: +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.nexusrpc.handler.HandlerException(message="handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", type="INTERNAL", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="application error 1", type="my-application-error-type", nonRetryable=true) + + +# 🌈 RAISE_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.temporal.failure.TimeoutFailure(message="message='operation timed out', timeoutType=TIMEOUT_TYPE_SCHEDULE_TO_CLOSE") + + +# 🌈 RAISE_APPLICATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. 
scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.nexusrpc.handler.HandlerException(message="handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", type="INTERNAL", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="application error 1", type="my-application-error-type", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="Custom error 2", type="io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", nonRetryable=false) + + +# 🌈 RAISE_NEXUS_HANDLER_ERROR: +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.nexusrpc.handler.HandlerException(message="handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false", type="NOT_FOUND", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="Handler error 1", type="java.lang.RuntimeException", nonRetryable=false) + + +# 🌈 RAISE_NEXUS_OPERATION_ERROR_WITH_CAUSE_OF_CUSTOM_ERROR: +# io.temporal.failure.NexusOperationFailure(message="Nexus Operation with operation='testErrorservice='NexusService' endpoint='my-nexus-endpoint-name' failed: 'nexus operation completed unsuccessfully'. scheduledEventId=5, operationToken=", scheduledEventId=scheduledEventId, operationToken="operationToken") +# io.temporal.failure.ApplicationFailure(message="application error 1", type="my-application-error-type", nonRetryable=true) +# io.temporal.failure.ApplicationFailure(message="Custom error 2", type="io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", nonRetryable=false) + +ActionInSyncOp = Literal[ + "application_error_non_retryable", + "custom_error", + "custom_error_from_custom_error", + "application_error_non_retryable_from_custom_error", + "nexus_handler_error_not_found", + "nexus_handler_error_not_found_from_custom_error", + "nexus_operation_error_from_application_error_non_retryable_from_custom_error", +] + + +@dataclass +class ErrorConversionTestCase: + name: ActionInSyncOp + java_behavior: list[tuple[type[Exception], dict[str, Any]]] + + @staticmethod + def parse_exception( + exception: BaseException, + ) -> tuple[type[BaseException], dict[str, Any]]: + if isinstance(exception, NexusOperationError): + return NexusOperationError, {} + return type(exception), { + "message": getattr(exception, "message", None), + "type": getattr(exception, "type", None), + "non_retryable": getattr(exception, "non_retryable", None), + } + + +error_conversion_test_cases: list[ErrorConversionTestCase] = [] + + +# application_error_non_retryable: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="application_error_non_retryable", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", + "type": "INTERNAL", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "application error 1", + "type": "my-application-error-type", + "non_retryable": True, + }, + ), + ], + ) +) + +# custom_error: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="custom_error", + java_behavior=[], # [Not possible] + ) +) + + +# 
custom_error_from_custom_error: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="custom_error_from_custom_error", + java_behavior=[], # [Not possible] + ) +) + + +# application_error_non_retryable_from_custom_error: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="application_error_non_retryable_from_custom_error", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "handler error: message='application error 1', type='my-application-error-type', nonRetryable=true", + "type": "INTERNAL", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "application error 1", + "type": "my-application-error-type", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "Custom error 2", + "type": "io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", + "non_retryable": False, + }, + ), + ], + ) +) + +# nexus_handler_error_not_found: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="nexus_handler_error_not_found", + java_behavior=[ + (NexusOperationError, {}), + ( + nexusrpc.HandlerError, + { + "message": "handler error: message='Handler error 1', type='java.lang.RuntimeException', nonRetryable=false", + "type": "NOT_FOUND", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "Handler error 1", + "type": "java.lang.RuntimeException", + "non_retryable": False, + }, + ), + ], + ) +) + +# nexus_handler_error_not_found_from_custom_error: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="nexus_handler_error_not_found_from_custom_error", + java_behavior=[], # [Not possible] + ) +) + + +# nexus_operation_error_from_application_error_non_retryable_from_custom_error: +error_conversion_test_cases.append( + ErrorConversionTestCase( + name="nexus_operation_error_from_application_error_non_retryable_from_custom_error", + java_behavior=[ + (NexusOperationError, {}), + ( + ApplicationError, + { + "message": "application error 1", + "type": "my-application-error-type", + "non_retryable": True, + }, + ), + ( + ApplicationError, + { + "message": "Custom error 2", + "type": "io.temporal.samples.nexus.handler.NexusServiceImpl$MyCustomException", + "non_retryable": False, + }, + ), + ], + ) +) + + +class CustomError(Exception): + pass + + +@dataclass +class ErrorTestInput: + task_queue: str + action_in_sync_op: ActionInSyncOp + + +@nexusrpc.handler.service_handler +class ErrorTestService: + @sync_operation + async def op(self, ctx: StartOperationContext, input: ErrorTestInput) -> None: + if input.action_in_sync_op == "application_error_non_retryable": + raise ApplicationError("application error in nexus op", non_retryable=True) + elif input.action_in_sync_op == "custom_error": + raise CustomError("custom error in nexus op") + elif input.action_in_sync_op == "custom_error_from_custom_error": + raise CustomError("custom error 1 in nexus op") from CustomError( + "custom error 2 in nexus op" + ) + elif ( + input.action_in_sync_op + == "application_error_non_retryable_from_custom_error" + ): + raise ApplicationError( + "application error in nexus op", non_retryable=True + ) from CustomError("custom error in nexus op") + elif input.action_in_sync_op == "nexus_handler_error_not_found": + raise nexusrpc.HandlerError( + "test", + type=nexusrpc.HandlerErrorType.NOT_FOUND, + ) + elif ( + input.action_in_sync_op == "nexus_handler_error_not_found_from_custom_error" + ): + raise nexusrpc.HandlerError( + "test", + 
type=nexusrpc.HandlerErrorType.NOT_FOUND, + ) from CustomError("custom error in nexus op") + elif ( + input.action_in_sync_op + == "nexus_operation_error_from_application_error_non_retryable_from_custom_error" + ): + try: + raise ApplicationError( + "application error in nexus op", non_retryable=True + ) from CustomError("custom error in nexus op") + except ApplicationError as err: + raise nexusrpc.OperationError( + "operation error in nexus op", + state=nexusrpc.OperationErrorState.FAILED, + ) from err + else: + raise NotImplementedError( + f"Unhandled action_in_sync_op: {input.action_in_sync_op}" + ) + + +# Caller + + +@workflow.defn(sandboxed=False) +class ErrorTestCallerWorkflow: + @workflow.init + def __init__(self, input: ErrorTestInput): + self.nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(input.task_queue), + service=ErrorTestService, + ) + self.test_cases = {t.name: t for t in error_conversion_test_cases} + + @workflow.run + async def run(self, input: ErrorTestInput) -> None: + try: + await self.nexus_client.execute_operation( + # TODO(nexus-prerelease): why wasn't this a type error? + # ErrorTestService.op, ErrorTestCallerWfInput() + ErrorTestService.op, + # TODO(nexus-prerelease): why wasn't this a type error? + # None + input, + ) + except BaseException as err: + errs = [err] + while err.__cause__: + errs.append(err.__cause__) + err = err.__cause__ + actual = [ErrorConversionTestCase.parse_exception(err) for err in errs] + results = list( + zip_longest( + self.test_cases[input.action_in_sync_op].java_behavior, + actual, + fillvalue=None, + ) + ) + print(f""" + +{input.action_in_sync_op} +{'-' * 80} +""") + for java_behavior, actual in results: # type: ignore[assignment] + print(f"Java: {java_behavior}") + print(f"Python: {actual}") + print() + print("-" * 80) + return None + + assert False, "Unreachable" + + +@pytest.mark.parametrize( + "action_in_sync_op", + [ + "application_error_non_retryable", + "custom_error", + "custom_error_from_custom_error", + "application_error_non_retryable_from_custom_error", + "nexus_handler_error_not_found", + "nexus_handler_error_not_found_from_custom_error", + "nexus_operation_error_from_application_error_non_retryable_from_custom_error", + ], +) +async def test_errors_raised_by_nexus_operation( + client: Client, action_in_sync_op: ActionInSyncOp +): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[ErrorTestService()], + workflows=[ErrorTestCallerWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + await client.execute_workflow( + ErrorTestCallerWorkflow.run, + ErrorTestInput( + task_queue=task_queue, + action_in_sync_op=action_in_sync_op, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + + +# Start timeout test +@service_handler +class StartTimeoutTestService: + @sync_operation + async def op_handler_that_never_returns( + self, ctx: StartOperationContext, input: None + ) -> None: + await asyncio.Future() + + +@workflow.defn +class StartTimeoutTestCallerWorkflow: + @workflow.init + def __init__(self): + self.nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=StartTimeoutTestService, + ) + + @workflow.run + async def run(self) -> None: + await self.nexus_client.execute_operation( + StartTimeoutTestService.op_handler_that_never_returns, + None, + schedule_to_close_timeout=timedelta(seconds=0.1), + ) + + +async def 
test_error_raised_by_timeout_of_nexus_start_operation(client: Client): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[StartTimeoutTestService()], + workflows=[StartTimeoutTestCallerWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + try: + await client.execute_workflow( + StartTimeoutTestCallerWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail("Expected exception due to timeout of nexus start operation") + + +# Cancellation timeout test + + +class OperationWithCancelMethodThatNeverReturns(OperationHandler[None, None]): + async def start( + self, ctx: StartOperationContext, input: None + ) -> StartOperationResultAsync: + return StartOperationResultAsync("fake-token") + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + await asyncio.Future() + + async def fetch_info( + self, ctx: FetchOperationInfoContext, token: str + ) -> nexusrpc.OperationInfo: + raise NotImplementedError("Not implemented") + + async def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> None: + raise NotImplementedError("Not implemented") + + +@service_handler +class CancellationTimeoutTestService: + @nexusrpc.handler._decorators.operation_handler + def op_with_cancel_method_that_never_returns( + self, + ) -> OperationHandler[None, None]: + return OperationWithCancelMethodThatNeverReturns() + + +@workflow.defn +class CancellationTimeoutTestCallerWorkflow: + @workflow.init + def __init__(self): + self.nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=CancellationTimeoutTestService, + ) + + @workflow.run + async def run(self) -> None: + op_handle = await self.nexus_client.start_operation( + CancellationTimeoutTestService.op_with_cancel_method_that_never_returns, + None, + schedule_to_close_timeout=timedelta(seconds=0.1), + ) + op_handle.cancel() + await op_handle + + +async def test_error_raised_by_timeout_of_nexus_cancel_operation(client: Client): + pytest.skip("TODO(nexus-prerelease): finish writing this test") + task_queue = str(uuid.uuid4()) + async with Worker( + client, + nexus_service_handlers=[CancellationTimeoutTestService()], + workflows=[CancellationTimeoutTestCallerWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + try: + await client.execute_workflow( + CancellationTimeoutTestCallerWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail("Expected exception due to timeout of nexus cancel operation") + + +# Test overloads + + +@dataclass +class OverloadTestValue: + value: int + + +@workflow.defn +class OverloadTestHandlerWorkflow: + @workflow.run + async def run(self, input: OverloadTestValue) -> OverloadTestValue: + return OverloadTestValue(value=input.value * 2) + + +@workflow.defn +class OverloadTestHandlerWorkflowNoParam: + @workflow.run + async def run(self) -> OverloadTestValue: + return OverloadTestValue(value=0) + + +@nexusrpc.handler.service_handler +class OverloadTestServiceHandler: + @workflow_run_operation + async def 
no_param( + self, + ctx: WorkflowRunOperationContext, + _: OverloadTestValue, + ) -> nexus.WorkflowHandle[OverloadTestValue]: + return await ctx.start_workflow( + OverloadTestHandlerWorkflowNoParam.run, + id=str(uuid.uuid4()), + ) + + @workflow_run_operation + async def single_param( + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: + return await ctx.start_workflow( + OverloadTestHandlerWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + @workflow_run_operation + async def multi_param( + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: + return await ctx.start_workflow( + OverloadTestHandlerWorkflow.run, + args=[input], + id=str(uuid.uuid4()), + ) + + @workflow_run_operation + async def by_name( + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: + return await ctx.start_workflow( + "OverloadTestHandlerWorkflow", + input, + id=str(uuid.uuid4()), + result_type=OverloadTestValue, + ) + + @workflow_run_operation + async def by_name_multi_param( + self, ctx: WorkflowRunOperationContext, input: OverloadTestValue + ) -> nexus.WorkflowHandle[OverloadTestValue]: + return await ctx.start_workflow( + "OverloadTestHandlerWorkflow", + args=[input], + id=str(uuid.uuid4()), + ) + + +@dataclass +class OverloadTestInput: + op: Callable[ + [Any, WorkflowRunOperationContext, Any], + Awaitable[temporalio.nexus.WorkflowHandle[Any]], + ] + input: Any + output: Any + + +@workflow.defn +class OverloadTestCallerWorkflow: + @workflow.run + async def run(self, op: str, input: OverloadTestValue) -> OverloadTestValue: + nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=OverloadTestServiceHandler, + ) + if op == "no_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.no_param, input + ) + elif op == "single_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.single_param, input + ) + elif op == "multi_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.multi_param, input + ) + elif op == "by_name": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.by_name, input + ) + elif op == "by_name_multi_param": + return await nexus_client.execute_operation( + OverloadTestServiceHandler.by_name_multi_param, input + ) + else: + raise ValueError(f"Unknown op: {op}") + + +@pytest.mark.parametrize( + "op", + [ + "no_param", + "single_param", + "multi_param", + "by_name", + "by_name_multi_param", + ], +) +async def test_workflow_run_operation_overloads(client: Client, op: str): + task_queue = str(uuid.uuid4()) + async with Worker( + client, + task_queue=task_queue, + workflows=[ + OverloadTestCallerWorkflow, + OverloadTestHandlerWorkflow, + OverloadTestHandlerWorkflowNoParam, + ], + nexus_service_handlers=[OverloadTestServiceHandler()], + ): + await create_nexus_endpoint(task_queue, client) + res = await client.execute_workflow( + OverloadTestCallerWorkflow.run, + args=[op, OverloadTestValue(value=2)], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert res == ( + OverloadTestValue(value=4) + if op != "no_param" + else OverloadTestValue(value=0) + ) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py new file mode 100644 index 000000000..217316412 --- /dev/null +++ 
b/tests/nexus/test_workflow_run_operation.py @@ -0,0 +1,121 @@ +import re +import uuid +from dataclasses import dataclass +from typing import Any, Type + +import nexusrpc +import pytest +from nexusrpc import Operation, service +from nexusrpc.handler import ( + OperationHandler, + StartOperationContext, + StartOperationResultAsync, + service_handler, +) +from nexusrpc.handler._decorators import operation_handler + +from temporalio import workflow +from temporalio.nexus import WorkflowRunOperationContext +from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import ( + Failure, + ServiceClient, + create_nexus_endpoint, + dataclass_as_dict, +) + +HTTP_PORT = 7243 + + +@dataclass +class Input: + value: str + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: str) -> str: + return input + + +class MyOperation(WorkflowRunOperationHandler): + def __init__(self): + pass + + async def start( + self, ctx: StartOperationContext, input: Input + ) -> StartOperationResultAsync: + tctx = WorkflowRunOperationContext.from_start_operation_context(ctx) + handle = await tctx.start_workflow( + EchoWorkflow.run, + input.value, + id=str(uuid.uuid4()), + ) + return StartOperationResultAsync(handle.to_token()) + + +@service_handler +class SubclassingHappyPath: + @operation_handler + def op(self) -> OperationHandler[Input, str]: + return MyOperation() + + +@service +class Service: + op: Operation[Input, str] + + +@service_handler(service=Service) +class SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition: + # Despite the lack of annotations on the service impl, the service definition + # provides the type needed to deserialize the input into Input so that input.value + # succeeds. 
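+    # (Concretely, the type comes from the Input parameter of "op: Operation[Input, str]"
+    # on the Service definition above.)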
+ @operation_handler + def op(self) -> OperationHandler: + return MyOperation() + + +@pytest.mark.parametrize( + "service_handler_cls", + [ + SubclassingHappyPath, + SubclassingNoInputOutputTypeAnnotationsWithServiceDefinition, + ], +) +async def test_workflow_run_operation( + env: WorkflowEnvironment, + service_handler_cls: Type[Any], +): + task_queue = str(uuid.uuid4()) + endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + assert (service_defn := nexusrpc.get_service_definition(service_handler_cls)) + service_client = ServiceClient( + server_address=server_address(env), + endpoint=endpoint, + service=service_defn.name, + ) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler_cls()], + ): + resp = await service_client.start_operation( + "op", + dataclass_as_dict(Input(value="test")), + ) + if hasattr(service_handler_cls, "__expected__error__"): + status_code, message = service_handler_cls.__expected__error__ + assert resp.status_code == status_code + failure = Failure(**resp.json()) + assert re.search(message, failure.message) + else: + assert resp.status_code == 201 + + +def server_address(env: WorkflowEnvironment) -> str: + http_port = getattr(env, "_http_port", 7243) + return f"http://127.0.0.1:{http_port}" diff --git a/uv.lock b/uv.lock index f753830c7..08dd46baf 100644 --- a/uv.lock +++ b/uv.lock @@ -287,6 +287,85 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/40/c199d095151addf69efdb4b9ca3a4f20f70e20508d6222bffb9b76f58573/constantly-23.10.4-py3-none-any.whl", hash = "sha256:3fd9b4d1c3dc1ec9757f3c52aef7e53ad9323dbe39f51dfd4c43853b68dfa3f9", size = 13547, upload-time = "2023-10-28T23:18:23.038Z" }, ] +[[package]] +name = "coverage" +version = "7.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/e0/98670a80884f64578f0c22cd70c5e81a6e07b08167721c7487b4d70a7ca0/coverage-7.9.1.tar.gz", hash = "sha256:6cf43c78c4282708a28e466316935ec7489a9c487518a77fa68f716c67909cec", size = 813650, upload-time = "2025-06-13T13:02:28.627Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/78/1c1c5ec58f16817c09cbacb39783c3655d54a221b6552f47ff5ac9297603/coverage-7.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cc94d7c5e8423920787c33d811c0be67b7be83c705f001f7180c7b186dcf10ca", size = 212028, upload-time = "2025-06-13T13:00:29.293Z" }, + { url = "https://files.pythonhosted.org/packages/98/db/e91b9076f3a888e3b4ad7972ea3842297a52cc52e73fd1e529856e473510/coverage-7.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16aa0830d0c08a2c40c264cef801db8bc4fc0e1892782e45bcacbd5889270509", size = 212420, upload-time = "2025-06-13T13:00:34.027Z" }, + { url = "https://files.pythonhosted.org/packages/0e/d0/2b3733412954576b0aea0a16c3b6b8fbe95eb975d8bfa10b07359ead4252/coverage-7.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf95981b126f23db63e9dbe4cf65bd71f9a6305696fa5e2262693bc4e2183f5b", size = 241529, upload-time = "2025-06-13T13:00:35.786Z" }, + { url = "https://files.pythonhosted.org/packages/b3/00/5e2e5ae2e750a872226a68e984d4d3f3563cb01d1afb449a17aa819bc2c4/coverage-7.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f05031cf21699785cd47cb7485f67df619e7bcdae38e0fde40d23d3d0210d3c3", size = 239403, upload-time = "2025-06-13T13:00:37.399Z" }, + { url = 
"https://files.pythonhosted.org/packages/37/3b/a2c27736035156b0a7c20683afe7df498480c0dfdf503b8c878a21b6d7fb/coverage-7.9.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4fbcab8764dc072cb651a4bcda4d11fb5658a1d8d68842a862a6610bd8cfa3", size = 240548, upload-time = "2025-06-13T13:00:39.647Z" }, + { url = "https://files.pythonhosted.org/packages/98/f5/13d5fc074c3c0e0dc80422d9535814abf190f1254d7c3451590dc4f8b18c/coverage-7.9.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0f16649a7330ec307942ed27d06ee7e7a38417144620bb3d6e9a18ded8a2d3e5", size = 240459, upload-time = "2025-06-13T13:00:40.934Z" }, + { url = "https://files.pythonhosted.org/packages/36/24/24b9676ea06102df824c4a56ffd13dc9da7904478db519efa877d16527d5/coverage-7.9.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cea0a27a89e6432705fffc178064503508e3c0184b4f061700e771a09de58187", size = 239128, upload-time = "2025-06-13T13:00:42.343Z" }, + { url = "https://files.pythonhosted.org/packages/be/05/242b7a7d491b369ac5fee7908a6e5ba42b3030450f3ad62c645b40c23e0e/coverage-7.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e980b53a959fa53b6f05343afbd1e6f44a23ed6c23c4b4c56c6662bbb40c82ce", size = 239402, upload-time = "2025-06-13T13:00:43.634Z" }, + { url = "https://files.pythonhosted.org/packages/73/e0/4de7f87192fa65c9c8fbaeb75507e124f82396b71de1797da5602898be32/coverage-7.9.1-cp310-cp310-win32.whl", hash = "sha256:70760b4c5560be6ca70d11f8988ee6542b003f982b32f83d5ac0b72476607b70", size = 214518, upload-time = "2025-06-13T13:00:45.622Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ab/5e4e2fe458907d2a65fab62c773671cfc5ac704f1e7a9ddd91996f66e3c2/coverage-7.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a66e8f628b71f78c0e0342003d53b53101ba4e00ea8dabb799d9dba0abbbcebe", size = 215436, upload-time = "2025-06-13T13:00:47.245Z" }, + { url = "https://files.pythonhosted.org/packages/60/34/fa69372a07d0903a78ac103422ad34db72281c9fc625eba94ac1185da66f/coverage-7.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:95c765060e65c692da2d2f51a9499c5e9f5cf5453aeaf1420e3fc847cc060582", size = 212146, upload-time = "2025-06-13T13:00:48.496Z" }, + { url = "https://files.pythonhosted.org/packages/27/f0/da1894915d2767f093f081c42afeba18e760f12fdd7a2f4acbe00564d767/coverage-7.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba383dc6afd5ec5b7a0d0c23d38895db0e15bcba7fb0fa8901f245267ac30d86", size = 212536, upload-time = "2025-06-13T13:00:51.535Z" }, + { url = "https://files.pythonhosted.org/packages/10/d5/3fc33b06e41e390f88eef111226a24e4504d216ab8e5d1a7089aa5a3c87a/coverage-7.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37ae0383f13cbdcf1e5e7014489b0d71cc0106458878ccde52e8a12ced4298ed", size = 245092, upload-time = "2025-06-13T13:00:52.883Z" }, + { url = "https://files.pythonhosted.org/packages/0a/39/7aa901c14977aba637b78e95800edf77f29f5a380d29768c5b66f258305b/coverage-7.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:69aa417a030bf11ec46149636314c24c8d60fadb12fc0ee8f10fda0d918c879d", size = 242806, upload-time = "2025-06-13T13:00:54.571Z" }, + { url = "https://files.pythonhosted.org/packages/43/fc/30e5cfeaf560b1fc1989227adedc11019ce4bb7cce59d65db34fe0c2d963/coverage-7.9.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a4be2a28656afe279b34d4f91c3e26eccf2f85500d4a4ff0b1f8b54bf807338", size = 
244610, upload-time = "2025-06-13T13:00:56.932Z" }, + { url = "https://files.pythonhosted.org/packages/bf/15/cca62b13f39650bc87b2b92bb03bce7f0e79dd0bf2c7529e9fc7393e4d60/coverage-7.9.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:382e7ddd5289f140259b610e5f5c58f713d025cb2f66d0eb17e68d0a94278875", size = 244257, upload-time = "2025-06-13T13:00:58.545Z" }, + { url = "https://files.pythonhosted.org/packages/cd/1a/c0f2abe92c29e1464dbd0ff9d56cb6c88ae2b9e21becdb38bea31fcb2f6c/coverage-7.9.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e5532482344186c543c37bfad0ee6069e8ae4fc38d073b8bc836fc8f03c9e250", size = 242309, upload-time = "2025-06-13T13:00:59.836Z" }, + { url = "https://files.pythonhosted.org/packages/57/8d/c6fd70848bd9bf88fa90df2af5636589a8126d2170f3aade21ed53f2b67a/coverage-7.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a39d18b3f50cc121d0ce3838d32d58bd1d15dab89c910358ebefc3665712256c", size = 242898, upload-time = "2025-06-13T13:01:02.506Z" }, + { url = "https://files.pythonhosted.org/packages/c2/9e/6ca46c7bff4675f09a66fe2797cd1ad6a24f14c9c7c3b3ebe0470a6e30b8/coverage-7.9.1-cp311-cp311-win32.whl", hash = "sha256:dd24bd8d77c98557880def750782df77ab2b6885a18483dc8588792247174b32", size = 214561, upload-time = "2025-06-13T13:01:04.012Z" }, + { url = "https://files.pythonhosted.org/packages/a1/30/166978c6302010742dabcdc425fa0f938fa5a800908e39aff37a7a876a13/coverage-7.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:6b55ad10a35a21b8015eabddc9ba31eb590f54adc9cd39bcf09ff5349fd52125", size = 215493, upload-time = "2025-06-13T13:01:05.702Z" }, + { url = "https://files.pythonhosted.org/packages/60/07/a6d2342cd80a5be9f0eeab115bc5ebb3917b4a64c2953534273cf9bc7ae6/coverage-7.9.1-cp311-cp311-win_arm64.whl", hash = "sha256:6ad935f0016be24c0e97fc8c40c465f9c4b85cbbe6eac48934c0dc4d2568321e", size = 213869, upload-time = "2025-06-13T13:01:09.345Z" }, + { url = "https://files.pythonhosted.org/packages/68/d9/7f66eb0a8f2fce222de7bdc2046ec41cb31fe33fb55a330037833fb88afc/coverage-7.9.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8de12b4b87c20de895f10567639c0797b621b22897b0af3ce4b4e204a743626", size = 212336, upload-time = "2025-06-13T13:01:10.909Z" }, + { url = "https://files.pythonhosted.org/packages/20/20/e07cb920ef3addf20f052ee3d54906e57407b6aeee3227a9c91eea38a665/coverage-7.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5add197315a054e92cee1b5f686a2bcba60c4c3e66ee3de77ace6c867bdee7cb", size = 212571, upload-time = "2025-06-13T13:01:12.518Z" }, + { url = "https://files.pythonhosted.org/packages/78/f8/96f155de7e9e248ca9c8ff1a40a521d944ba48bec65352da9be2463745bf/coverage-7.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600a1d4106fe66f41e5d0136dfbc68fe7200a5cbe85610ddf094f8f22e1b0300", size = 246377, upload-time = "2025-06-13T13:01:14.87Z" }, + { url = "https://files.pythonhosted.org/packages/3e/cf/1d783bd05b7bca5c10ded5f946068909372e94615a4416afadfe3f63492d/coverage-7.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a876e4c3e5a2a1715a6608906aa5a2e0475b9c0f68343c2ada98110512ab1d8", size = 243394, upload-time = "2025-06-13T13:01:16.23Z" }, + { url = "https://files.pythonhosted.org/packages/02/dd/e7b20afd35b0a1abea09fb3998e1abc9f9bd953bee548f235aebd2b11401/coverage-7.9.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81f34346dd63010453922c8e628a52ea2d2ccd73cb2487f7700ac531b247c8a5", size = 245586, 
upload-time = "2025-06-13T13:01:17.532Z" }, + { url = "https://files.pythonhosted.org/packages/4e/38/b30b0006fea9d617d1cb8e43b1bc9a96af11eff42b87eb8c716cf4d37469/coverage-7.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:888f8eee13f2377ce86d44f338968eedec3291876b0b8a7289247ba52cb984cd", size = 245396, upload-time = "2025-06-13T13:01:19.164Z" }, + { url = "https://files.pythonhosted.org/packages/31/e4/4d8ec1dc826e16791f3daf1b50943e8e7e1eb70e8efa7abb03936ff48418/coverage-7.9.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9969ef1e69b8c8e1e70d591f91bbc37fc9a3621e447525d1602801a24ceda898", size = 243577, upload-time = "2025-06-13T13:01:22.433Z" }, + { url = "https://files.pythonhosted.org/packages/25/f4/b0e96c5c38e6e40ef465c4bc7f138863e2909c00e54a331da335faf0d81a/coverage-7.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:60c458224331ee3f1a5b472773e4a085cc27a86a0b48205409d364272d67140d", size = 244809, upload-time = "2025-06-13T13:01:24.143Z" }, + { url = "https://files.pythonhosted.org/packages/8a/65/27e0a1fa5e2e5079bdca4521be2f5dabf516f94e29a0defed35ac2382eb2/coverage-7.9.1-cp312-cp312-win32.whl", hash = "sha256:5f646a99a8c2b3ff4c6a6e081f78fad0dde275cd59f8f49dc4eab2e394332e74", size = 214724, upload-time = "2025-06-13T13:01:25.435Z" }, + { url = "https://files.pythonhosted.org/packages/9b/a8/d5b128633fd1a5e0401a4160d02fa15986209a9e47717174f99dc2f7166d/coverage-7.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:30f445f85c353090b83e552dcbbdad3ec84c7967e108c3ae54556ca69955563e", size = 215535, upload-time = "2025-06-13T13:01:27.861Z" }, + { url = "https://files.pythonhosted.org/packages/a3/37/84bba9d2afabc3611f3e4325ee2c6a47cd449b580d4a606b240ce5a6f9bf/coverage-7.9.1-cp312-cp312-win_arm64.whl", hash = "sha256:af41da5dca398d3474129c58cb2b106a5d93bbb196be0d307ac82311ca234342", size = 213904, upload-time = "2025-06-13T13:01:29.202Z" }, + { url = "https://files.pythonhosted.org/packages/d0/a7/a027970c991ca90f24e968999f7d509332daf6b8c3533d68633930aaebac/coverage-7.9.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:31324f18d5969feef7344a932c32428a2d1a3e50b15a6404e97cba1cc9b2c631", size = 212358, upload-time = "2025-06-13T13:01:30.909Z" }, + { url = "https://files.pythonhosted.org/packages/f2/48/6aaed3651ae83b231556750280682528fea8ac7f1232834573472d83e459/coverage-7.9.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0c804506d624e8a20fb3108764c52e0eef664e29d21692afa375e0dd98dc384f", size = 212620, upload-time = "2025-06-13T13:01:32.256Z" }, + { url = "https://files.pythonhosted.org/packages/6c/2a/f4b613f3b44d8b9f144847c89151992b2b6b79cbc506dee89ad0c35f209d/coverage-7.9.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef64c27bc40189f36fcc50c3fb8f16ccda73b6a0b80d9bd6e6ce4cffcd810bbd", size = 245788, upload-time = "2025-06-13T13:01:33.948Z" }, + { url = "https://files.pythonhosted.org/packages/04/d2/de4fdc03af5e4e035ef420ed26a703c6ad3d7a07aff2e959eb84e3b19ca8/coverage-7.9.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4fe2348cc6ec372e25adec0219ee2334a68d2f5222e0cba9c0d613394e12d86", size = 243001, upload-time = "2025-06-13T13:01:35.285Z" }, + { url = "https://files.pythonhosted.org/packages/f5/e8/eed18aa5583b0423ab7f04e34659e51101135c41cd1dcb33ac1d7013a6d6/coverage-7.9.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34ed2186fe52fcc24d4561041979a0dec69adae7bce2ae8d1c49eace13e55c43", size = 244985, upload-time = 
"2025-06-13T13:01:36.712Z" }, + { url = "https://files.pythonhosted.org/packages/17/f8/ae9e5cce8885728c934eaa58ebfa8281d488ef2afa81c3dbc8ee9e6d80db/coverage-7.9.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:25308bd3d00d5eedd5ae7d4357161f4df743e3c0240fa773ee1b0f75e6c7c0f1", size = 245152, upload-time = "2025-06-13T13:01:39.303Z" }, + { url = "https://files.pythonhosted.org/packages/5a/c8/272c01ae792bb3af9b30fac14d71d63371db227980682836ec388e2c57c0/coverage-7.9.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:73e9439310f65d55a5a1e0564b48e34f5369bee943d72c88378f2d576f5a5751", size = 243123, upload-time = "2025-06-13T13:01:40.727Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d0/2819a1e3086143c094ab446e3bdf07138527a7b88cb235c488e78150ba7a/coverage-7.9.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37ab6be0859141b53aa89412a82454b482c81cf750de4f29223d52268a86de67", size = 244506, upload-time = "2025-06-13T13:01:42.184Z" }, + { url = "https://files.pythonhosted.org/packages/8b/4e/9f6117b89152df7b6112f65c7a4ed1f2f5ec8e60c4be8f351d91e7acc848/coverage-7.9.1-cp313-cp313-win32.whl", hash = "sha256:64bdd969456e2d02a8b08aa047a92d269c7ac1f47e0c977675d550c9a0863643", size = 214766, upload-time = "2025-06-13T13:01:44.482Z" }, + { url = "https://files.pythonhosted.org/packages/27/0f/4b59f7c93b52c2c4ce7387c5a4e135e49891bb3b7408dcc98fe44033bbe0/coverage-7.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:be9e3f68ca9edb897c2184ad0eee815c635565dbe7a0e7e814dc1f7cbab92c0a", size = 215568, upload-time = "2025-06-13T13:01:45.772Z" }, + { url = "https://files.pythonhosted.org/packages/09/1e/9679826336f8c67b9c39a359352882b24a8a7aee48d4c9cad08d38d7510f/coverage-7.9.1-cp313-cp313-win_arm64.whl", hash = "sha256:1c503289ffef1d5105d91bbb4d62cbe4b14bec4d13ca225f9c73cde9bb46207d", size = 213939, upload-time = "2025-06-13T13:01:47.087Z" }, + { url = "https://files.pythonhosted.org/packages/bb/5b/5c6b4e7a407359a2e3b27bf9c8a7b658127975def62077d441b93a30dbe8/coverage-7.9.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0b3496922cb5f4215bf5caaef4cf12364a26b0be82e9ed6d050f3352cf2d7ef0", size = 213079, upload-time = "2025-06-13T13:01:48.554Z" }, + { url = "https://files.pythonhosted.org/packages/a2/22/1e2e07279fd2fd97ae26c01cc2186e2258850e9ec125ae87184225662e89/coverage-7.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9565c3ab1c93310569ec0d86b017f128f027cab0b622b7af288696d7ed43a16d", size = 213299, upload-time = "2025-06-13T13:01:49.997Z" }, + { url = "https://files.pythonhosted.org/packages/14/c0/4c5125a4b69d66b8c85986d3321520f628756cf524af810baab0790c7647/coverage-7.9.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2241ad5dbf79ae1d9c08fe52b36d03ca122fb9ac6bca0f34439e99f8327ac89f", size = 256535, upload-time = "2025-06-13T13:01:51.314Z" }, + { url = "https://files.pythonhosted.org/packages/81/8b/e36a04889dda9960be4263e95e777e7b46f1bb4fc32202612c130a20c4da/coverage-7.9.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bb5838701ca68b10ebc0937dbd0eb81974bac54447c55cd58dea5bca8451029", size = 252756, upload-time = "2025-06-13T13:01:54.403Z" }, + { url = "https://files.pythonhosted.org/packages/98/82/be04eff8083a09a4622ecd0e1f31a2c563dbea3ed848069e7b0445043a70/coverage-7.9.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30a25f814591a8c0c5372c11ac8967f669b97444c47fd794926e175c4047ece", size = 254912, upload-time = 
"2025-06-13T13:01:56.769Z" }, + { url = "https://files.pythonhosted.org/packages/0f/25/c26610a2c7f018508a5ab958e5b3202d900422cf7cdca7670b6b8ca4e8df/coverage-7.9.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2d04b16a6062516df97969f1ae7efd0de9c31eb6ebdceaa0d213b21c0ca1a683", size = 256144, upload-time = "2025-06-13T13:01:58.19Z" }, + { url = "https://files.pythonhosted.org/packages/c5/8b/fb9425c4684066c79e863f1e6e7ecebb49e3a64d9f7f7860ef1688c56f4a/coverage-7.9.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7931b9e249edefb07cd6ae10c702788546341d5fe44db5b6108a25da4dca513f", size = 254257, upload-time = "2025-06-13T13:01:59.645Z" }, + { url = "https://files.pythonhosted.org/packages/93/df/27b882f54157fc1131e0e215b0da3b8d608d9b8ef79a045280118a8f98fe/coverage-7.9.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:52e92b01041151bf607ee858e5a56c62d4b70f4dac85b8c8cb7fb8a351ab2c10", size = 255094, upload-time = "2025-06-13T13:02:01.37Z" }, + { url = "https://files.pythonhosted.org/packages/41/5f/cad1c3dbed8b3ee9e16fa832afe365b4e3eeab1fb6edb65ebbf745eabc92/coverage-7.9.1-cp313-cp313t-win32.whl", hash = "sha256:684e2110ed84fd1ca5f40e89aa44adf1729dc85444004111aa01866507adf363", size = 215437, upload-time = "2025-06-13T13:02:02.905Z" }, + { url = "https://files.pythonhosted.org/packages/99/4d/fad293bf081c0e43331ca745ff63673badc20afea2104b431cdd8c278b4c/coverage-7.9.1-cp313-cp313t-win_amd64.whl", hash = "sha256:437c576979e4db840539674e68c84b3cda82bc824dd138d56bead1435f1cb5d7", size = 216605, upload-time = "2025-06-13T13:02:05.638Z" }, + { url = "https://files.pythonhosted.org/packages/1f/56/4ee027d5965fc7fc126d7ec1187529cc30cc7d740846e1ecb5e92d31b224/coverage-7.9.1-cp313-cp313t-win_arm64.whl", hash = "sha256:18a0912944d70aaf5f399e350445738a1a20b50fbea788f640751c2ed9208b6c", size = 214392, upload-time = "2025-06-13T13:02:07.642Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d6/c41dd9b02bf16ec001aaf1cbef665537606899a3db1094e78f5ae17540ca/coverage-7.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f424507f57878e424d9a95dc4ead3fbdd72fd201e404e861e465f28ea469951", size = 212029, upload-time = "2025-06-13T13:02:09.058Z" }, + { url = "https://files.pythonhosted.org/packages/f8/c0/40420d81d731f84c3916dcdf0506b3e6c6570817bff2576b83f780914ae6/coverage-7.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:535fde4001b2783ac80865d90e7cc7798b6b126f4cd8a8c54acfe76804e54e58", size = 212407, upload-time = "2025-06-13T13:02:11.151Z" }, + { url = "https://files.pythonhosted.org/packages/9b/87/f0db7d62d0e09f14d6d2f6ae8c7274a2f09edf74895a34b412a0601e375a/coverage-7.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02532fd3290bb8fa6bec876520842428e2a6ed6c27014eca81b031c2d30e3f71", size = 241160, upload-time = "2025-06-13T13:02:12.864Z" }, + { url = "https://files.pythonhosted.org/packages/a9/b7/3337c064f058a5d7696c4867159651a5b5fb01a5202bcf37362f0c51400e/coverage-7.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56f5eb308b17bca3bbff810f55ee26d51926d9f89ba92707ee41d3c061257e55", size = 239027, upload-time = "2025-06-13T13:02:14.294Z" }, + { url = "https://files.pythonhosted.org/packages/7e/a9/5898a283f66d1bd413c32c2e0e05408196fd4f37e206e2b06c6e0c626e0e/coverage-7.9.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfa447506c1a52271f1b0de3f42ea0fa14676052549095e378d5bff1c505ff7b", size = 240145, upload-time = 
"2025-06-13T13:02:15.745Z" }, + { url = "https://files.pythonhosted.org/packages/e0/33/d96e3350078a3c423c549cb5b2ba970de24c5257954d3e4066e2b2152d30/coverage-7.9.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9ca8e220006966b4a7b68e8984a6aee645a0384b0769e829ba60281fe61ec4f7", size = 239871, upload-time = "2025-06-13T13:02:17.344Z" }, + { url = "https://files.pythonhosted.org/packages/1d/6e/6fb946072455f71a820cac144d49d11747a0f1a21038060a68d2d0200499/coverage-7.9.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:49f1d0788ba5b7ba65933f3a18864117c6506619f5ca80326b478f72acf3f385", size = 238122, upload-time = "2025-06-13T13:02:18.849Z" }, + { url = "https://files.pythonhosted.org/packages/e4/5c/bc43f25c8586840ce25a796a8111acf6a2b5f0909ba89a10d41ccff3920d/coverage-7.9.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:68cd53aec6f45b8e4724c0950ce86eacb775c6be01ce6e3669fe4f3a21e768ed", size = 239058, upload-time = "2025-06-13T13:02:21.423Z" }, + { url = "https://files.pythonhosted.org/packages/11/d8/ce2007418dd7fd00ff8c8b898bb150bb4bac2d6a86df05d7b88a07ff595f/coverage-7.9.1-cp39-cp39-win32.whl", hash = "sha256:95335095b6c7b1cc14c3f3f17d5452ce677e8490d101698562b2ffcacc304c8d", size = 214532, upload-time = "2025-06-13T13:02:22.857Z" }, + { url = "https://files.pythonhosted.org/packages/20/21/334e76fa246e92e6d69cab217f7c8a70ae0cc8f01438bd0544103f29528e/coverage-7.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:e1b5191d1648acc439b24721caab2fd0c86679d8549ed2c84d5a7ec1bedcc244", size = 215439, upload-time = "2025-06-13T13:02:24.268Z" }, + { url = "https://files.pythonhosted.org/packages/3e/e5/c723545c3fd3204ebde3b4cc4b927dce709d3b6dc577754bb57f63ca4a4a/coverage-7.9.1-pp39.pp310.pp311-none-any.whl", hash = "sha256:db0f04118d1db74db6c9e1cb1898532c7dcc220f1d2718f058601f7c3f499514", size = 204009, upload-time = "2025-06-13T13:02:25.787Z" }, + { url = "https://files.pythonhosted.org/packages/08/b8/7ddd1e8ba9701dea08ce22029917140e6f66a859427406579fd8d0ca7274/coverage-7.9.1-py3-none-any.whl", hash = "sha256:66b974b145aa189516b6bf2d8423e888b742517d37872f6ee4c5be0073bd9a3c", size = 204000, upload-time = "2025-06-13T13:02:27.173Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + [[package]] name = "cryptography" version = "45.0.4" @@ -962,6 +1041,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/73/d6b999782ae22f16971cc05378b3b33f6a89ede3b9619e8366aa23484bca/mypy_protobuf-3.6.0-py3-none-any.whl", hash = "sha256:56176e4d569070e7350ea620262478b49b7efceba4103d468448f1d21492fd6c", size = 16434, upload-time = "2024-04-01T20:24:40.583Z" }, ] +[[package]] +name = "nexus-rpc" +version = "1.1.0" +source = { git = "https://github.com/nexus-rpc/sdk-python#94a1267cb5baabf2d3609aedb7f6cf81587be6df" } +dependencies = [ + { name = "typing-extensions" }, +] + [[package]] name = "nh3" version = "0.2.21" @@ -1336,14 +1423,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.377" +version = "1.1.400" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/49/f0/25b0db363d6888164adb7c828b877bbf2c30936955fb9513922ae03e70e4/pyright-1.1.377.tar.gz", hash = "sha256:aabc30fedce0ded34baa0c49b24f10e68f4bfc8f68ae7f3d175c4b0f256b4fcf", size = 17484, upload-time = "2024-08-21T02:25:15.74Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/6c/cb/c306618a02d0ee8aed5fb8d0fe0ecfed0dbf075f71468f03a30b5f4e1fe0/pyright-1.1.400.tar.gz", hash = "sha256:b8a3ba40481aa47ba08ffb3228e821d22f7d391f83609211335858bf05686bdb", size = 3846546, upload-time = "2025-04-24T12:55:18.907Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/c9/89c40c4de44fe9463e77dddd0c4e2d2dd7a93e8ddc6858dfe7d5f75d263d/pyright-1.1.377-py3-none-any.whl", hash = "sha256:af0dd2b6b636c383a6569a083f8c5a8748ae4dcde5df7914b3f3f267e14dd162", size = 18223, upload-time = "2024-08-21T02:25:14.585Z" }, + { url = "https://files.pythonhosted.org/packages/c8/a5/5d285e4932cf149c90e3c425610c5efaea005475d5f96f1bfdb452956c62/pyright-1.1.400-py3-none-any.whl", hash = "sha256:c80d04f98b5a4358ad3a35e241dbf2a408eee33a40779df365644f8054d2517e", size = 5563460, upload-time = "2025-04-24T12:55:17.002Z" }, ] [[package]] @@ -1375,6 +1463,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/ce/1e4b53c213dce25d6e8b163697fbce2d43799d76fa08eea6ad270451c370/pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b", size = 13368, upload-time = "2024-04-29T13:23:23.126Z" }, ] +[[package]] +name = "pytest-cov" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/99/668cade231f434aaa59bbfbf49469068d2ddd945000621d3d165d2e7dd7b/pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2", size = 69432, upload-time = "2025-06-12T10:47:47.684Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" }, +] + [[package]] name = "pytest-pretty" version = "1.3.0" @@ -1606,6 +1708,7 @@ name = "temporalio" version = "1.13.0" source = { virtual = "." 
} dependencies = [ + { name = "nexus-rpc" }, { name = "protobuf" }, { name = "python-dateutil", marker = "python_full_version < '3.11'" }, { name = "types-protobuf" }, @@ -1632,6 +1735,7 @@ pydantic = [ dev = [ { name = "cibuildwheel" }, { name = "grpcio-tools" }, + { name = "httpx" }, { name = "maturin" }, { name = "mypy" }, { name = "mypy-protobuf" }, @@ -1641,6 +1745,7 @@ dev = [ { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-cov" }, { name = "pytest-pretty" }, { name = "pytest-timeout" }, { name = "ruff" }, @@ -1652,6 +1757,7 @@ dev = [ requires-dist = [ { name = "eval-type-backport", marker = "python_full_version < '3.10' and extra == 'openai-agents'", specifier = ">=0.2.2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, + { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.0.19,<0.1" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, @@ -1667,15 +1773,17 @@ provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents"] dev = [ { name = "cibuildwheel", specifier = ">=2.22.0,<3" }, { name = "grpcio-tools", specifier = ">=1.48.2,<2" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "maturin", specifier = ">=1.8.2" }, { name = "mypy", specifier = "==1.4.1" }, { name = "mypy-protobuf", specifier = ">=3.3.0,<4" }, { name = "psutil", specifier = ">=5.9.3,<6" }, { name = "pydocstyle", specifier = ">=6.3.0,<7" }, { name = "pydoctor", specifier = ">=24.11.1,<25" }, - { name = "pyright", specifier = "==1.1.377" }, + { name = "pyright", specifier = "==1.1.400" }, { name = "pytest", specifier = "~=7.4" }, { name = "pytest-asyncio", specifier = ">=0.21,<0.22" }, + { name = "pytest-cov", specifier = ">=6.1.1" }, { name = "pytest-pretty", specifier = ">=1.3.0" }, { name = "pytest-timeout", specifier = "~=2.2" }, { name = "ruff", specifier = ">=0.5.0,<0.6" },