diff --git a/docs/docs/cheatsheet.md b/docs/docs/cheatsheet.md
index 34f8a83e1d..3a34060895 100644
--- a/docs/docs/cheatsheet.md
+++ b/docs/docs/cheatsheet.md
@@ -6,6 +6,8 @@ sidebar_position: 999
 
 This page will contain snippets for frequent usage patterns.
 
+DSPy supports global retry defaults for LM calls via settings, so you can centralize transient error handling without stacking extra retry layers. Configure once at startup with `dspy.configure(default_num_retries=5, retry_strategy="exponential_backoff_retry")`, override per instance where needed with `dspy.LM(model="openai/gpt-4o-mini", num_retries=1)`, or scope temporary settings to a block using `with dspy.context(default_num_retries=8, retry_strategy="exponential_backoff_retry"):`. The strategy is forwarded to LiteLLM (the default remains `"exponential_backoff_retry"`), and an explicit per-instance `num_retries` always takes precedence over the global default.
+
 ## DSPy Programs
 
 ### dspy.Signature
diff --git a/docs/docs/tutorials/global_retries/index.ipynb b/docs/docs/tutorials/global_retries/index.ipynb
new file mode 100644
index 0000000000..de5481573f
--- /dev/null
+++ b/docs/docs/tutorials/global_retries/index.ipynb
@@ -0,0 +1,64 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7437ca87",
+   "metadata": {},
+   "source": [
+    "# Setting global retry defaults"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "33342724",
+   "metadata": {},
+   "source": [
+    "\n",
+    "Global retry defaults let you configure transient error handling once and keep behavior consistent across modules. Configure with `dspy.configure(default_num_retries=4, retry_strategy=\"exponential_backoff_retry\")`, override per instance with `dspy.LM(..., num_retries=1)`, or scope a block using `with dspy.context(default_num_retries=6):`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1477680",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "import dspy\n",
+    "\n",
+    "dspy.configure(default_num_retries=4, retry_strategy=\"exponential_backoff_retry\")\n",
+    "\n",
+    "lm_default = dspy.LM(model=\"openai/gpt-4o-mini\", cache=False, num_retries=None)\n",
+    "print(lm_default.dump_state()[\"num_retries\"])  # 4 retries by default\n",
+    "\n",
+    "with dspy.context(default_num_retries=6):\n",
+    "    lm_scoped = dspy.LM(model=\"openai/gpt-4o-mini\", cache=False, num_retries=None)\n",
+    "    print(lm_scoped.dump_state()[\"num_retries\"])  # 6 retries in this scope\n",
+    "\n",
+    "lm_override = dspy.LM(model=\"openai/gpt-4o-mini\", cache=False, num_retries=1)\n",
+    "print(lm_override.dump_state()[\"num_retries\"])  # 1 retry for this instance"
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "cell_metadata_filter": "-all",
+   "main_language": "python",
+   "notebook_metadata_filter": "-all"
+  },
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index abc97a845b..52f39043f2 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -35,7 +35,7 @@ def __init__(
         cache: bool = True,
         cache_in_memory: bool = True,
         callbacks: list[BaseCallback] | None = None,
-        num_retries: int = 3,
+        num_retries: int | None = None,
         provider: Provider | None = None,
         finetuning_model: str | None = None,
         launch_kwargs: dict[str, Any] | None = None,
@@ -70,7 +70,9 @@ def __init__(
         self.provider = provider or self.infer_provider()
         self.callbacks = callbacks or []
         self.history = []
-        self.num_retries = num_retries
+        self.num_retries = (
+            num_retries if num_retries is not None else dspy.settings.get("default_num_retries", 3)
+        )
         self.finetuning_model = finetuning_model
         self.launch_kwargs = launch_kwargs or {}
         self.train_kwargs = train_kwargs or {}
@@ -304,10 +306,11 @@ def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[st
     cache = cache or {"no-cache": True, "no-store": True}
     stream_completion = _get_stream_completion_fn(request, cache, sync=True)
     if stream_completion is None:
+        strategy = dspy.settings.get("retry_strategy", "exponential_backoff_retry")
         return litellm.completion(
             cache=cache,
             num_retries=num_retries,
-            retry_strategy="exponential_backoff_retry",
+            retry_strategy=strategy,
             **request,
         )
 
@@ -328,6 +331,7 @@ def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: di
     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
 
+    strategy = dspy.settings.get("retry_strategy", "exponential_backoff_retry")
     return litellm.text_completion(
         cache=cache,
         model=f"text-completion-openai/{model}",
@@ -335,7 +339,7 @@ def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: di
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
-        retry_strategy="exponential_backoff_retry",
+        retry_strategy=strategy,
         **request,
     )
 
@@ -344,10 +348,11 @@ async def alitellm_completion(request: dict[str, Any], num_retries: int, cache:
     cache = cache or {"no-cache": True, "no-store": True}
     stream_completion = _get_stream_completion_fn(request, cache, sync=False)
     if stream_completion is None:
+        strategy = dspy.settings.get("retry_strategy", "exponential_backoff_retry")
         return await litellm.acompletion(
             cache=cache,
             num_retries=num_retries,
-            retry_strategy="exponential_backoff_retry",
+            retry_strategy=strategy,
             **request,
         )
 
@@ -366,6 +371,7 @@ async def alitellm_text_completion(request: dict[str, Any], num_retries: int, ca
     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
 
+    strategy = dspy.settings.get("retry_strategy", "exponential_backoff_retry")
     return await litellm.atext_completion(
         cache=cache,
         model=f"text-completion-openai/{model}",
@@ -373,6 +379,6 @@ async def alitellm_text_completion(request: dict[str, Any], num_retries: int, ca
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
-        retry_strategy="exponential_backoff_retry",
+        retry_strategy=strategy,
         **request,
     )
diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py
index f5319a9f9f..561b810c5b 100644
--- a/dspy/dsp/utils/settings.py
+++ b/dspy/dsp/utils/settings.py
@@ -28,6 +28,8 @@
     allow_tool_async_sync_conversion=False,
     max_history_size=10000,
     max_trace_size=10000,
+    default_num_retries=3,
+    retry_strategy="exponential_backoff_retry",
 )
 
 # Global base configuration and owner tracking
diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
index 8aa4c89dc9..ca1e4d3f09 100644
--- a/tests/clients/test_lm.py
+++ b/tests/clients/test_lm.py
@@ -171,6 +171,47 @@ def test_retry_number_set_correctly():
     assert mock_completion.call_args.kwargs["num_retries"] == 3
 
 
+def test_global_default_num_retries_when_none():
+    from litellm.utils import Choices, Message, ModelResponse
+
+    with dspy.context(default_num_retries=4):
+        lm = dspy.LM(model="openai/dspy-test-model", cache=False, num_retries=None)
+        with mock.patch("litellm.completion") as mock_completion:
+            mock_completion.return_value = ModelResponse(
+                choices=[Choices(message=Message(content="answer"))], model="openai/dspy-test-model"
+            )
+            lm("query")
+
+    assert mock_completion.call_args.kwargs["num_retries"] == 4
+
+
+def test_instance_num_retries_overrides_global():
+    from litellm.utils import Choices, Message, ModelResponse
+
+    with dspy.context(default_num_retries=2):
+        lm = dspy.LM(model="openai/dspy-test-model", cache=False, num_retries=7)
+        with mock.patch("litellm.completion") as mock_completion:
+            mock_completion.return_value = ModelResponse(
+                choices=[Choices(message=Message(content="answer"))], model="openai/dspy-test-model"
+            )
+            lm("query")
+
+    assert mock_completion.call_args.kwargs["num_retries"] == 7
+
+
+def test_retry_strategy_from_settings():
+    from litellm.utils import Choices, Message, ModelResponse
+
+    with dspy.context(retry_strategy="custom_strategy"):
+        lm = dspy.LM(model="openai/dspy-test-model", cache=False)
+        with mock.patch("litellm.completion") as mock_completion:
+            mock_completion.return_value = ModelResponse(
+                choices=[Choices(message=Message(content="answer"))], model="openai/dspy-test-model"
+            )
+            lm("query")
+
+    assert mock_completion.call_args.kwargs["retry_strategy"] == "custom_strategy"
+
+
 def test_retry_made_on_system_errors():
     retry_tracking = [0]  # Using a list to track retries
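
Taken together, these changes resolve retries in a fixed order: an explicit per-instance `num_retries` wins, otherwise `default_num_retries` from the active settings or context applies, falling back to the built-in 3; `retry_strategy`, by contrast, is read from settings at request time rather than stored on the LM instance. Below is a minimal sketch of that resolution order, assuming this diff is applied; the model name is only an illustrative placeholder, and the asserts assume no other settings overrides are active.

```python
import dspy

# Global default: any LM constructed without an explicit num_retries picks this up.
dspy.configure(default_num_retries=5, retry_strategy="exponential_backoff_retry")

lm = dspy.LM(model="openai/gpt-4o-mini", num_retries=None)
assert lm.num_retries == 5  # resolved from settings at construction time

# Context override: scoped to the block, restored on exit.
with dspy.context(default_num_retries=8):
    scoped = dspy.LM(model="openai/gpt-4o-mini", num_retries=None)
    assert scoped.num_retries == 8

# An explicit per-instance value always wins over the global default.
pinned = dspy.LM(model="openai/gpt-4o-mini", num_retries=1)
assert pinned.num_retries == 1

# retry_strategy is looked up at call time inside the litellm_* helpers, so a
# dspy.context(retry_strategy=...) block affects requests made within it even
# for LMs constructed earlier.
```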