From 0327e15373edc16b8a5b18ae84bf6c24b6a2e32c Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 27 Aug 2025 18:29:39 +0900 Subject: [PATCH 01/18] Add Anthropic citation support --- dspy/__init__.py | 2 +- dspy/adapters/__init__.py | 4 +- dspy/adapters/base.py | 21 ++- dspy/adapters/types/__init__.py | 4 +- dspy/adapters/types/citation.py | 164 +++++++++++++++++++++++ dspy/adapters/types/document.py | 107 +++++++++++++++ dspy/clients/base_lm.py | 68 ++++++++++ tests/adapters/test_citation.py | 225 ++++++++++++++++++++++++++++++++ tests/adapters/test_document.py | 55 ++++++++ 9 files changed, 646 insertions(+), 4 deletions(-) create mode 100644 dspy/adapters/types/citation.py create mode 100644 dspy/adapters/types/document.py create mode 100644 tests/adapters/test_citation.py create mode 100644 tests/adapters/test_document.py diff --git a/dspy/__init__.py b/dspy/__init__.py index ea4c75a862..8b21238c2f 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -6,7 +6,7 @@ from dspy.evaluate import Evaluate # isort: skip from dspy.clients import * # isort: skip -from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code # isort: skip +from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code, Citations, Document # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.syncify import syncify diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py index 1dea6da47a..79edd1ec53 100644 --- a/dspy/adapters/__init__.py +++ b/dspy/adapters/__init__.py @@ -2,7 +2,7 @@ from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter from dspy.adapters.two_step_adapter import TwoStepAdapter -from dspy.adapters.types import Audio, Code, History, Image, Tool, ToolCalls, Type +from dspy.adapters.types import Audio, Citations, Code, Document, History, Image, Tool, ToolCalls, Type from dspy.adapters.xml_adapter import XMLAdapter __all__ = [ @@ -13,6 +13,8 @@ "Image", "Audio", "Code", + "Citations", + "Document", "JSONAdapter", "XMLAdapter", "TwoStepAdapter", diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index eaea79563c..cf8cfc0d8c 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -4,7 +4,7 @@ import json_repair import litellm -from dspy.adapters.types import History +from dspy.adapters.types import Citations, History from dspy.adapters.types.base_type import split_message_content_for_custom_types from dspy.adapters.types.tool import Tool, ToolCalls from dspy.signatures.signature import Signature @@ -74,16 +74,19 @@ def _call_postprocess( values = [] tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) + citation_output_field_names = self._get_citation_output_field_names(original_signature) for output in outputs: output_logprobs = None tool_calls = None + citations = None text = output if isinstance(output, dict): text = output["text"] output_logprobs = output.get("logprobs") tool_calls = output.get("tool_calls") + citations = output.get("citations") if text: value = self.parse(processed_signature, text) @@ -106,6 +109,14 @@ def _call_postprocess( ] value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) + if citations and citation_output_field_names: + # Convert citation dictionaries to Citations object using from_dict_list + citations_obj = Citations.from_dict_list(citations) + + # Assign to all citation fields found + for field_name in citation_output_field_names: + value[field_name] = citations_obj + if output_logprobs: value["logprobs"] = output_logprobs @@ -390,6 +401,14 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: return name return None + def _get_citation_output_field_names(self, signature: type[Signature]) -> list[str]: + """Find all Citations output fields in the signature.""" + citation_fields = [] + for name, field in signature.output_fields.items(): + if field.annotation == Citations: + citation_fields.append(name) + return citation_fields + def format_conversation_history( self, signature: type[Signature], diff --git a/dspy/adapters/types/__init__.py b/dspy/adapters/types/__init__.py index 11b9faee1b..e7f28faadb 100644 --- a/dspy/adapters/types/__init__.py +++ b/dspy/adapters/types/__init__.py @@ -1,8 +1,10 @@ from dspy.adapters.types.audio import Audio from dspy.adapters.types.base_type import Type +from dspy.adapters.types.citation import Citations from dspy.adapters.types.code import Code +from dspy.adapters.types.document import Document from dspy.adapters.types.history import History from dspy.adapters.types.image import Image from dspy.adapters.types.tool import Tool, ToolCalls -__all__ = ["History", "Image", "Audio", "Type", "Tool", "ToolCalls", "Code"] +__all__ = ["History", "Image", "Audio", "Type", "Tool", "ToolCalls", "Code", "Citations", "Document"] diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py new file mode 100644 index 0000000000..cd62e82ea4 --- /dev/null +++ b/dspy/adapters/types/citation.py @@ -0,0 +1,164 @@ +from typing import Any + +import pydantic + +from dspy.adapters.types.base_type import Type + + +class Citations(Type): + """Citations extracted from an LM response with source references. + + This type represents citations returned by language models that support + citation extraction, particularly Anthropic's Citations API through LiteLLM. + Citations include the quoted text and source information. + + Example: + ```python + import dspy + from dspy.signatures import Signature + + class AnswerWithSources(Signature): + '''Answer questions using provided documents with citations.''' + documents: list[dspy.Document] = dspy.InputField() + question: str = dspy.InputField() + answer: str = dspy.OutputField() + citations: dspy.Citations = dspy.OutputField() + + # Create documents to provide as sources + docs = [ + dspy.Document( + content="The Earth orbits the Sun in an elliptical path.", + title="Basic Astronomy Facts" + ), + dspy.Document( + content="Water boils at 100°C at standard atmospheric pressure.", + title="Physics Fundamentals", + metadata={"author": "Dr. Smith", "year": 2023} + ) + ] + + # Use with a model that supports citations like Claude + lm = dspy.LM("anthropic/claude-3-5-sonnet-20241022") + predictor = dspy.Predict(AnswerWithSources, lm=lm) + result = predictor(documents=docs, question="What temperature does water boil?") + + for citation in result.citations.citations: + print(citation.format()) + ``` + """ + + class Citation(Type): + """Individual citation with text and source information.""" + text: str + source: str | dict[str, Any] | None = None + start: int | None = None + end: int | None = None + + def format(self) -> str: + """Format citation as a readable string. + + Returns: + A formatted citation string showing the quoted text and source. + """ + source_info = "" + if self.source: + if isinstance(self.source, str): + source_info = f" ({self.source})" + elif isinstance(self.source, dict): + title = self.source.get("title", "") + url = self.source.get("url", "") + if title and url: + source_info = f" ({title}: {url})" + elif title: + source_info = f" ({title})" + elif url: + source_info = f" ({url})" + + return f'"{self.text}"{source_info}' + + citations: list[Citation] + + @classmethod + def from_dict_list(cls, citations_dicts: list[dict[str, Any]]) -> "Citations": + """Convert a list of dictionaries to a Citations instance. + + Args: + citations_dicts: A list of dictionaries, where each dictionary should have 'text' key + and optionally 'source', 'start', and 'end' keys. + + Returns: + A Citations instance. + + Example: + ```python + citations_dict = [ + {"text": "The sky is blue", "source": "Weather Guide"}, + {"text": "Water boils at 100°C", "source": {"title": "Physics Book", "url": "http://example.com"}} + ] + citations = Citations.from_dict_list(citations_dict) + ``` + """ + citations = [cls.Citation(**item) for item in citations_dicts] + return cls(citations=citations) + + @classmethod + def description(cls) -> str: + """Description of the citations type for use in prompts.""" + return ( + "Citations with quoted text and source references. " + "Include the exact text being cited and information about its source." + ) + + def format(self) -> list[dict[str, Any]]: + """Format citations as a list of dictionaries.""" + return [ + { + "text": citation.text, + "source": citation.source, + "start": citation.start, + "end": citation.end, + } + for citation in self.citations + ] + + @pydantic.model_validator(mode="before") + @classmethod + def validate_input(cls, data: Any): + if isinstance(data, cls): + return data + + # Handle case where data is a list of dicts with citation info + if isinstance(data, list) and all( + isinstance(item, dict) and "text" in item for item in data + ): + return {"citations": [cls.Citation(**item) for item in data]} + + # Handle case where data is a dict + elif isinstance(data, dict): + if "citations" in data: + # Handle case where data is a dict with "citations" key + citations_data = data["citations"] + if isinstance(citations_data, list): + return { + "citations": [ + cls.Citation(**item) if isinstance(item, dict) else item + for item in citations_data + ] + } + elif "text" in data: + # Handle case where data is a single citation dict + return {"citations": [cls.Citation(**data)]} + + raise ValueError(f"Received invalid value for `dspy.Citations`: {data}") + + def __iter__(self): + """Allow iteration over citations.""" + return iter(self.citations) + + def __len__(self): + """Return the number of citations.""" + return len(self.citations) + + def __getitem__(self, index): + """Allow indexing into citations.""" + return self.citations[index] diff --git a/dspy/adapters/types/document.py b/dspy/adapters/types/document.py new file mode 100644 index 0000000000..57fa257a9b --- /dev/null +++ b/dspy/adapters/types/document.py @@ -0,0 +1,107 @@ +from typing import Any, Literal + +import pydantic + +from dspy.adapters.types.base_type import Type + + +class Document(Type): + """A document type for providing content that can be cited by language models. + + This type represents documents that can be passed to language models for citation-enabled + responses, particularly useful with Anthropic's Citations API. Documents include the content + and metadata that helps the LM understand and reference the source material. + + Attributes: + data: The text content of the document + title: Optional title for the document (used in citations) + media_type: MIME type of the document content (defaults to "text/plain") + metadata: Optional additional metadata about the document + context: Optional context information about the document + + Example: + ```python + import dspy + from dspy.signatures import Signature + + class AnswerWithSources(Signature): + '''Answer questions using provided documents with citations.''' + documents: list[dspy.Document] = dspy.InputField() + question: str = dspy.InputField() + answer: str = dspy.OutputField() + citations: dspy.Citations = dspy.OutputField() + + # Create documents + docs = [ + dspy.Document( + data="The Earth orbits the Sun in an elliptical path.", + title="Basic Astronomy Facts" + ), + dspy.Document( + data="Water boils at 100°C at standard atmospheric pressure.", + title="Physics Fundamentals", + metadata={"author": "Dr. Smith", "year": 2023} + ) + ] + + # Use with a citation-supporting model + lm = dspy.LM("anthropic/claude-opus-4-1-20250805") + predictor = dspy.Predict(AnswerWithSources) + result = predictor(documents=docs, question="What temperature does water boil?", lm=lm) + print(result.citations) + ``` + """ + + data: str + title: str | None = None + media_type: Literal["text/plain", "application/pdf"] = "text/plain" + metadata: dict[str, Any] | None = None + context: str | None = None + + def format(self) -> list[dict[str, Any]]: + """Format document for LM consumption. + + Returns: + A list containing the document block in the format expected by citation-enabled language models. + """ + return { + "type": "document", + "source": { + "type": "text", + "media_type": self.media_type, + "data": self.data + }, + "citations": {"enabled": True}, + "title": self.title + } + + + + @classmethod + def description(cls) -> str: + """Description of the document type for use in prompts.""" + return ( + "A document containing text content that can be referenced and cited. " + "Include the full text content and optionally a title for proper referencing." + ) + + @pydantic.model_validator(mode="before") + @classmethod + def validate_input(cls, data: Any): + if isinstance(data, cls): + return data + + # Handle case where data is just a string (data only) + if isinstance(data, str): + return {"data": data} + + # Handle case where data is a dict + elif isinstance(data, dict): + return data + + raise ValueError(f"Received invalid value for `dspy.Document`: {data}") + + def __str__(self) -> str: + """String representation showing title and content length.""" + title_part = f"'{self.title}': " if self.title else "" + return f"Document({title_part}{len(self.data)} chars)" diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 6f86da5632..97c6d84bbd 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -183,6 +183,12 @@ def _process_completion(self, response, merged_kwargs): output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] if hasattr(c, "message") and getattr(c.message, "tool_calls", None): output["tool_calls"] = c.message.tool_calls + + # Extract citations from LiteLLM response if available + citations = self._extract_citations_from_response(response, c) + if citations: + output["citations"] = citations + outputs.append(output) if all(len(output) == 1 for output in outputs): @@ -191,6 +197,68 @@ def _process_completion(self, response, merged_kwargs): return outputs + def _extract_citations_from_response(self, response, choice): + """Extract citations from LiteLLM response if available. + + Args: + response: The LiteLLM response object + choice: The choice object from response.choices + + Returns: + List of citation dictionaries or None if no citations found + """ + try: + # Check for citations in LiteLLM provider_specific_fields + if hasattr(response, "choices") and hasattr(choice, "message"): + message = choice.message + # Check for citations in provider_specific_fields (Anthropic format) + if hasattr(message, "provider_specific_fields") and message.provider_specific_fields: + provider_fields = message.provider_specific_fields + if isinstance(provider_fields, dict) and "citations" in provider_fields: + citations_data = provider_fields["citations"] + if isinstance(citations_data, list): + citations = [] + for citation_data in citations_data: + citation_dict = { + "text": citation_data.get("quote", ""), + "source": citation_data.get("source"), + "start": citation_data.get("start"), + "end": citation_data.get("end"), + } + citations.append(citation_dict) + return citations + + # Check for citations directly in the message (fallback) + if hasattr(message, "citations") and message.citations: + citations_data = message.citations + if isinstance(citations_data, list): + citations = [] + for citation_data in citations_data: + if hasattr(citation_data, "quote"): + citation_dict = { + "text": citation_data.quote, + "source": getattr(citation_data, "source", None), + "start": getattr(citation_data, "start", None), + "end": getattr(citation_data, "end", None), + } + elif isinstance(citation_data, dict): + citation_dict = { + "text": citation_data.get("quote", citation_data.get("text", "")), + "source": citation_data.get("source"), + "start": citation_data.get("start"), + "end": citation_data.get("end"), + } + else: + continue + citations.append(citation_dict) + return citations + + except Exception: + # If citation extraction fails, just continue without citations + pass + + return None + def _process_response(self, response): """Process the response of OpenAI Response API and extract outputs. diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py new file mode 100644 index 0000000000..838abe40e6 --- /dev/null +++ b/tests/adapters/test_citation.py @@ -0,0 +1,225 @@ +from unittest.mock import MagicMock + +import dspy +from dspy.adapters.types import Citations +from dspy.signatures.signature import Signature + + +class CitationSignature(Signature): + """Test signature with citations.""" + question: str = dspy.InputField() + answer: str = dspy.OutputField() + citations: Citations = dspy.OutputField() + + +def test_citation_type(): + """Test the individual Citation type.""" + citation = Citations.Citation( + text="The sky is blue", + source="Weather Guide", + start=10, + end=25 + ) + + assert citation.text == "The sky is blue" + assert citation.source == "Weather Guide" + assert citation.start == 10 + assert citation.end == 25 + + # Test formatting + formatted = citation.format() + assert '"The sky is blue"' in formatted + assert "Weather Guide" in formatted + + +def test_citation_type_with_dict_source(): + """Test Citation with dict source.""" + citation = Citations.Citation( + text="The sky is blue", + source={"title": "Weather Guide", "url": "http://example.com"} + ) + + formatted = citation.format() + assert '"The sky is blue"' in formatted + assert "Weather Guide" in formatted + assert "http://example.com" in formatted + + +def test_citations_container_type(): + """Test the Citations container type.""" + citations_data = [ + {"text": "The sky is blue", "source": "Weather Guide"}, + {"text": "Water boils at 100°C", "source": "Physics Book"} + ] + + citations = Citations.from_dict_list(citations_data) + + assert len(citations) == 2 + assert citations[0].text == "The sky is blue" + assert citations[1].text == "Water boils at 100°C" + + # Test iteration + citation_texts = [c.text for c in citations] + assert "The sky is blue" in citation_texts + assert "Water boils at 100°C" in citation_texts + + +def test_citations_description(): + """Test Citations description method.""" + desc = Citations.description() + assert "citation" in desc.lower() + assert "source" in desc.lower() + + +def test_citations_format(): + """Test Citations format method.""" + citations_data = [ + {"text": "The sky is blue", "source": "Weather Guide", "start": 10, "end": 25} + ] + citations = Citations.from_dict_list(citations_data) + + formatted = citations.format() + assert isinstance(formatted, list) + assert len(formatted) == 1 + assert formatted[0]["text"] == "The sky is blue" + assert formatted[0]["source"] == "Weather Guide" + + +def test_citations_validation(): + """Test Citations validation with different input formats.""" + # Test with from_dict_list + citations1 = Citations.from_dict_list([{"text": "Hello", "source": "World"}]) + assert len(citations1) == 1 + + # Test direct construction with citations list + citations2 = Citations(citations=[Citations.Citation(text="Hello", source="World")]) + assert len(citations2) == 1 + + +def test_citation_field_detection(): + """Test that Citations fields are properly detected by adapter.""" + from dspy.adapters.chat_adapter import ChatAdapter + + adapter = ChatAdapter() + + # Test citation field detection + citation_fields = adapter._get_citation_output_field_names(CitationSignature) + assert "citations" in citation_fields + + +def test_citation_extraction_from_lm_response(): + """Test citation extraction from mock LM response.""" + from dspy.clients.base_lm import BaseLM + + # Create a mock response with citations + mock_response = MagicMock() + mock_choice = MagicMock() + mock_message = MagicMock() + + # Mock provider_specific_fields with citations (Anthropic format) + mock_message.provider_specific_fields = { + "citations": [ + { + "quote": "The sky is blue", + "source": "Weather Guide", + "start": 10, + "end": 25 + } + ] + } + + mock_choice.message = mock_message + mock_response.choices = [mock_choice] + + # Create BaseLM instance and test citation extraction + lm = BaseLM(model="test") + citations = lm._extract_citations_from_response(mock_response, mock_choice) + + assert citations is not None + assert len(citations) == 1 + assert citations[0]["text"] == "The sky is blue" + assert citations[0]["source"] == "Weather Guide" + assert citations[0]["start"] == 10 + assert citations[0]["end"] == 25 + + +def test_citations_postprocessing(): + """Test that citations are properly processed in adapter postprocessing.""" + from dspy.adapters.chat_adapter import ChatAdapter + + adapter = ChatAdapter() + + # Mock outputs with citations - need valid parsed text in ChatAdapter format + outputs = [{ + "text": "[[ ## answer ## ]]\nThe answer is blue.\n\n[[ ## citations ## ]]\n[]", + "citations": [ + { + "text": "The sky is blue", + "source": "Weather Guide", + "start": 10, + "end": 25 + } + ] + }] + + # Process with citation signature + result = adapter._call_postprocess( + CitationSignature, + CitationSignature, + outputs + ) + + # Should have Citations object in the result + assert len(result) == 1 + assert "citations" in result[0] + assert isinstance(result[0]["citations"], Citations) + assert len(result[0]["citations"]) == 1 + assert result[0]["citations"][0].text == "The sky is blue" + + +def test_citations_without_citations(): + """Test that processing works when no citations are present.""" + from dspy.adapters.chat_adapter import ChatAdapter + + adapter = ChatAdapter() + + # Mock outputs without citations + outputs = [{ + "text": "[[ ## answer ## ]]\nThe answer is blue.\n\n[[ ## citations ## ]]\n[]" + }] + + # Process with citation signature + result = adapter._call_postprocess( + CitationSignature, + CitationSignature, + outputs + ) + + # Should still work, with None for citations (since no citations in LM response) + assert len(result) == 1 + # Note: When no citations are in LM response, citations field should be None + # but the field gets set to empty list from parsing. Let's verify it's empty. + assert result[0]["citations"] is None or ( + isinstance(result[0]["citations"], Citations) and len(result[0]["citations"]) == 0 + ) + + +def test_citation_imports(): + """Test that Citations can be imported from main dspy module.""" + assert hasattr(dspy, "Citations") + assert dspy.Citations is Citations + + # Individual Citation should only be accessible via Citations.Citation + assert not hasattr(dspy, "Citation") + + +def test_citation_access_pattern(): + """Test that Citation class is accessible as Citations.Citation (like ToolCalls pattern).""" + # Test that we can create Citation objects via Citations.Citation + citation = Citations.Citation(text="Hello", source="World") + assert citation.text == "Hello" + assert citation.source == "World" + + # Test that Citations.Citation is the correct class + assert hasattr(Citations, "Citation") + assert isinstance(citation, Citations.Citation) diff --git a/tests/adapters/test_document.py b/tests/adapters/test_document.py new file mode 100644 index 0000000000..9dbb5bfaad --- /dev/null +++ b/tests/adapters/test_document.py @@ -0,0 +1,55 @@ +import pydantic +import pytest + +import dspy + + +def test_document_validate_input(): + # Create a `dspy.Document` instance with valid data. + doc = dspy.Document(data="The Earth orbits the Sun.") + assert doc.data == "The Earth orbits the Sun." + + with pytest.raises(pydantic.ValidationError): + # Try to create a `dspy.Document` instance with invalid type. + dspy.Document(data=123) + + +def test_document_in_nested_type(): + class Wrapper(pydantic.BaseModel): + document: dspy.Document + + doc = dspy.Document(data="Hello, world!") + wrapper = Wrapper(document=doc) + assert wrapper.document.data == "Hello, world!" + + +def test_document_with_all_fields(): + doc = dspy.Document( + data="Water boils at 100°C at standard pressure.", + title="Physics Facts", + media_type="application/pdf", + metadata={"author": "Dr. Smith", "year": 2023}, + context="Laboratory conditions" + ) + assert doc.data == "Water boils at 100°C at standard pressure." + assert doc.title == "Physics Facts" + assert doc.media_type == "application/pdf" + assert doc.metadata == {"author": "Dr. Smith", "year": 2023} + assert doc.context == "Laboratory conditions" + + +def test_document_format(): + doc = dspy.Document( + data="The sky is blue.", + title="Color Facts", + media_type="text/plain" + ) + + formatted = doc.format() + + assert formatted["type"] == "document" + assert formatted["source"]["type"] == "text" + assert formatted["source"]["media_type"] == "text/plain" + assert formatted["source"]["data"] == "The sky is blue." + assert formatted["title"] == "Color Facts" + assert formatted["citations"]["enabled"] is True From 8d7d050ffe360e3c6858198fdace33f15def4d7b Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 27 Aug 2025 18:35:49 +0900 Subject: [PATCH 02/18] simplify --- dspy/adapters/base.py | 2 -- dspy/clients/base_lm.py | 27 +-------------------------- tests/adapters/test_document.py | 4 ++-- 3 files changed, 3 insertions(+), 30 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index cf8cfc0d8c..1b43f21338 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -110,10 +110,8 @@ def _call_postprocess( value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) if citations and citation_output_field_names: - # Convert citation dictionaries to Citations object using from_dict_list citations_obj = Citations.from_dict_list(citations) - # Assign to all citation fields found for field_name in citation_output_field_names: value[field_name] = citations_obj diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 97c6d84bbd..667d5fcdf8 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -199,6 +199,7 @@ def _process_completion(self, response, merged_kwargs): def _extract_citations_from_response(self, response, choice): """Extract citations from LiteLLM response if available. + Reference: https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api Args: response: The LiteLLM response object @@ -227,32 +228,6 @@ def _extract_citations_from_response(self, response, choice): } citations.append(citation_dict) return citations - - # Check for citations directly in the message (fallback) - if hasattr(message, "citations") and message.citations: - citations_data = message.citations - if isinstance(citations_data, list): - citations = [] - for citation_data in citations_data: - if hasattr(citation_data, "quote"): - citation_dict = { - "text": citation_data.quote, - "source": getattr(citation_data, "source", None), - "start": getattr(citation_data, "start", None), - "end": getattr(citation_data, "end", None), - } - elif isinstance(citation_data, dict): - citation_dict = { - "text": citation_data.get("quote", citation_data.get("text", "")), - "source": citation_data.get("source"), - "start": citation_data.get("start"), - "end": citation_data.get("end"), - } - else: - continue - citations.append(citation_dict) - return citations - except Exception: # If citation extraction fails, just continue without citations pass diff --git a/tests/adapters/test_document.py b/tests/adapters/test_document.py index 9dbb5bfaad..dd7d5f2d3f 100644 --- a/tests/adapters/test_document.py +++ b/tests/adapters/test_document.py @@ -44,9 +44,9 @@ def test_document_format(): title="Color Facts", media_type="text/plain" ) - + formatted = doc.format() - + assert formatted["type"] == "document" assert formatted["source"]["type"] == "text" assert formatted["source"]["media_type"] == "text/plain" From 2bb2f02aa4c2be4e702e10f098d87c5c2217b870 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 27 Aug 2025 18:55:36 +0900 Subject: [PATCH 03/18] fix field names --- dspy/adapters/base.py | 17 +- dspy/adapters/types/citation.py | 72 ++++---- dspy/clients/base_lm.py | 10 +- tests/adapters/test_citation.py | 314 +++++++++++++++----------------- 4 files changed, 191 insertions(+), 222 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 1b43f21338..6fa30a5854 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -74,7 +74,7 @@ def _call_postprocess( values = [] tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) - citation_output_field_names = self._get_citation_output_field_names(original_signature) + citation_output_field_name = self._get_citation_output_field_name(original_signature) for output in outputs: output_logprobs = None @@ -109,11 +109,9 @@ def _call_postprocess( ] value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) - if citations and citation_output_field_names: + if citations and citation_output_field_name: citations_obj = Citations.from_dict_list(citations) - - for field_name in citation_output_field_names: - value[field_name] = citations_obj + value[citation_output_field_name] = citations_obj if output_logprobs: value["logprobs"] = output_logprobs @@ -399,13 +397,12 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: return name return None - def _get_citation_output_field_names(self, signature: type[Signature]) -> list[str]: - """Find all Citations output fields in the signature.""" - citation_fields = [] + def _get_citation_output_field_name(self, signature: type[Signature]) -> str | None: + """Find the Citations output field in the signature.""" for name, field in signature.output_fields.items(): if field.annotation == Citations: - citation_fields.append(name) - return citation_fields + return name + return None def format_conversation_history( self, diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py index cd62e82ea4..6d8fa87bdb 100644 --- a/dspy/adapters/types/citation.py +++ b/dspy/adapters/types/citation.py @@ -48,33 +48,32 @@ class AnswerWithSources(Signature): """ class Citation(Type): - """Individual citation with text and source information.""" - text: str - source: str | dict[str, Any] | None = None - start: int | None = None - end: int | None = None + """Individual citation with character location information.""" + type: str = "char_location" + cited_text: str + document_index: int + document_title: str | None = None + start_char_index: int + end_char_index: int - def format(self) -> str: - """Format citation as a readable string. + def format(self) -> dict[str, Any]: + """Format citation as dictionary for LM consumption. Returns: - A formatted citation string showing the quoted text and source. + A dictionary in the format expected by citation APIs. """ - source_info = "" - if self.source: - if isinstance(self.source, str): - source_info = f" ({self.source})" - elif isinstance(self.source, dict): - title = self.source.get("title", "") - url = self.source.get("url", "") - if title and url: - source_info = f" ({title}: {url})" - elif title: - source_info = f" ({title})" - elif url: - source_info = f" ({url})" - - return f'"{self.text}"{source_info}' + citation_dict = { + "type": self.type, + "cited_text": self.cited_text, + "document_index": self.document_index, + "start_char_index": self.start_char_index, + "end_char_index": self.end_char_index + } + + if self.document_title: + citation_dict["document_title"] = self.document_title + + return citation_dict citations: list[Citation] @@ -83,8 +82,8 @@ def from_dict_list(cls, citations_dicts: list[dict[str, Any]]) -> "Citations": """Convert a list of dictionaries to a Citations instance. Args: - citations_dicts: A list of dictionaries, where each dictionary should have 'text' key - and optionally 'source', 'start', and 'end' keys. + citations_dicts: A list of dictionaries, where each dictionary should have 'cited_text' key + and 'document_index', 'start_char_index', 'end_char_index' keys. Returns: A Citations instance. @@ -92,8 +91,13 @@ def from_dict_list(cls, citations_dicts: list[dict[str, Any]]) -> "Citations": Example: ```python citations_dict = [ - {"text": "The sky is blue", "source": "Weather Guide"}, - {"text": "Water boils at 100°C", "source": {"title": "Physics Book", "url": "http://example.com"}} + { + "cited_text": "The sky is blue", + "document_index": 0, + "document_title": "Weather Guide", + "start_char_index": 0, + "end_char_index": 15 + } ] citations = Citations.from_dict_list(citations_dict) ``` @@ -111,15 +115,7 @@ def description(cls) -> str: def format(self) -> list[dict[str, Any]]: """Format citations as a list of dictionaries.""" - return [ - { - "text": citation.text, - "source": citation.source, - "start": citation.start, - "end": citation.end, - } - for citation in self.citations - ] + return [citation.format() for citation in self.citations] @pydantic.model_validator(mode="before") @classmethod @@ -129,7 +125,7 @@ def validate_input(cls, data: Any): # Handle case where data is a list of dicts with citation info if isinstance(data, list) and all( - isinstance(item, dict) and "text" in item for item in data + isinstance(item, dict) and "cited_text" in item for item in data ): return {"citations": [cls.Citation(**item) for item in data]} @@ -145,7 +141,7 @@ def validate_input(cls, data: Any): for item in citations_data ] } - elif "text" in data: + elif "cited_text" in data: # Handle case where data is a single citation dict return {"citations": [cls.Citation(**data)]} diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 667d5fcdf8..513f980273 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -221,10 +221,12 @@ def _extract_citations_from_response(self, response, choice): citations = [] for citation_data in citations_data: citation_dict = { - "text": citation_data.get("quote", ""), - "source": citation_data.get("source"), - "start": citation_data.get("start"), - "end": citation_data.get("end"), + "type": citation_data.get("type", ""), + "cited_text": citation_data.get("cited_text", ""), + "document_index": citation_data.get("document_index", 0), + "document_title": citation_data.get("document_title"), + "start_char_index": citation_data.get("start_char_index", 0), + "end_char_index": citation_data.get("end_char_index", 0), } citations.append(citation_dict) return citations diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py index 838abe40e6..1d7bbedc7d 100644 --- a/tests/adapters/test_citation.py +++ b/tests/adapters/test_citation.py @@ -1,151 +1,131 @@ -from unittest.mock import MagicMock +import pydantic +import pytest import dspy -from dspy.adapters.types import Citations -from dspy.signatures.signature import Signature -class CitationSignature(Signature): - """Test signature with citations.""" - question: str = dspy.InputField() - answer: str = dspy.OutputField() - citations: Citations = dspy.OutputField() - - -def test_citation_type(): - """Test the individual Citation type.""" - citation = Citations.Citation( - text="The sky is blue", - source="Weather Guide", - start=10, - end=25 +def test_citation_validate_input(): + # Create a `dspy.Citations.Citation` instance with valid data. + citation = dspy.Citations.Citation( + cited_text="The Earth orbits the Sun.", + document_index=0, + start_char_index=0, + end_char_index=23 ) - - assert citation.text == "The sky is blue" - assert citation.source == "Weather Guide" - assert citation.start == 10 - assert citation.end == 25 - - # Test formatting - formatted = citation.format() - assert '"The sky is blue"' in formatted - assert "Weather Guide" in formatted - - -def test_citation_type_with_dict_source(): - """Test Citation with dict source.""" - citation = Citations.Citation( - text="The sky is blue", - source={"title": "Weather Guide", "url": "http://example.com"} + assert citation.cited_text == "The Earth orbits the Sun." + assert citation.document_index == 0 + assert citation.start_char_index == 0 + assert citation.end_char_index == 23 + assert citation.type == "char_location" + + with pytest.raises(pydantic.ValidationError): + # Try to create a `dspy.Citations.Citation` instance with missing required field. + dspy.Citations.Citation(cited_text="text") + + +def test_citations_in_nested_type(): + class Wrapper(pydantic.BaseModel): + citations: dspy.Citations + + citation = dspy.Citations.Citation( + cited_text="Hello, world!", + document_index=0, + start_char_index=0, + end_char_index=13 + ) + citations = dspy.Citations(citations=[citation]) + wrapper = Wrapper(citations=citations) + assert wrapper.citations.citations[0].cited_text == "Hello, world!" + + +def test_citation_with_all_fields(): + citation = dspy.Citations.Citation( + cited_text="Water boils at 100°C.", + document_index=1, + document_title="Physics Facts", + start_char_index=10, + end_char_index=31 + ) + assert citation.cited_text == "Water boils at 100°C." + assert citation.document_index == 1 + assert citation.document_title == "Physics Facts" + assert citation.start_char_index == 10 + assert citation.end_char_index == 31 + + +def test_citation_format(): + citation = dspy.Citations.Citation( + cited_text="The sky is blue.", + document_index=0, + document_title="Weather Guide", + start_char_index=5, + end_char_index=21 ) formatted = citation.format() - assert '"The sky is blue"' in formatted - assert "Weather Guide" in formatted - assert "http://example.com" in formatted - -def test_citations_container_type(): - """Test the Citations container type.""" - citations_data = [ - {"text": "The sky is blue", "source": "Weather Guide"}, - {"text": "Water boils at 100°C", "source": "Physics Book"} - ] - - citations = Citations.from_dict_list(citations_data) - - assert len(citations) == 2 - assert citations[0].text == "The sky is blue" - assert citations[1].text == "Water boils at 100°C" - - # Test iteration - citation_texts = [c.text for c in citations] - assert "The sky is blue" in citation_texts - assert "Water boils at 100°C" in citation_texts - - -def test_citations_description(): - """Test Citations description method.""" - desc = Citations.description() - assert "citation" in desc.lower() - assert "source" in desc.lower() + assert formatted["type"] == "char_location" + assert formatted["cited_text"] == "The sky is blue." + assert formatted["document_index"] == 0 + assert formatted["document_title"] == "Weather Guide" + assert formatted["start_char_index"] == 5 + assert formatted["end_char_index"] == 21 def test_citations_format(): - """Test Citations format method.""" - citations_data = [ - {"text": "The sky is blue", "source": "Weather Guide", "start": 10, "end": 25} - ] - citations = Citations.from_dict_list(citations_data) + citations = dspy.Citations(citations=[ + dspy.Citations.Citation( + cited_text="First citation", + document_index=0, + start_char_index=0, + end_char_index=14 + ), + dspy.Citations.Citation( + cited_text="Second citation", + document_index=1, + document_title="Source", + start_char_index=20, + end_char_index=35 + ) + ]) formatted = citations.format() - assert isinstance(formatted, list) - assert len(formatted) == 1 - assert formatted[0]["text"] == "The sky is blue" - assert formatted[0]["source"] == "Weather Guide" - - -def test_citations_validation(): - """Test Citations validation with different input formats.""" - # Test with from_dict_list - citations1 = Citations.from_dict_list([{"text": "Hello", "source": "World"}]) - assert len(citations1) == 1 - - # Test direct construction with citations list - citations2 = Citations(citations=[Citations.Citation(text="Hello", source="World")]) - assert len(citations2) == 1 + assert isinstance(formatted, list) + assert len(formatted) == 2 + assert formatted[0]["cited_text"] == "First citation" + assert formatted[1]["cited_text"] == "Second citation" + assert formatted[1]["document_title"] == "Source" -def test_citation_field_detection(): - """Test that Citations fields are properly detected by adapter.""" - from dspy.adapters.chat_adapter import ChatAdapter - - adapter = ChatAdapter() - - # Test citation field detection - citation_fields = adapter._get_citation_output_field_names(CitationSignature) - assert "citations" in citation_fields - - -def test_citation_extraction_from_lm_response(): - """Test citation extraction from mock LM response.""" - from dspy.clients.base_lm import BaseLM - - # Create a mock response with citations - mock_response = MagicMock() - mock_choice = MagicMock() - mock_message = MagicMock() - - # Mock provider_specific_fields with citations (Anthropic format) - mock_message.provider_specific_fields = { - "citations": [ - { - "quote": "The sky is blue", - "source": "Weather Guide", - "start": 10, - "end": 25 - } - ] - } - mock_choice.message = mock_message - mock_response.choices = [mock_choice] +def test_citations_from_dict_list(): + citations_data = [ + { + "cited_text": "The sky is blue", + "document_index": 0, + "document_title": "Weather Guide", + "start_char_index": 0, + "end_char_index": 15 + } + ] - # Create BaseLM instance and test citation extraction - lm = BaseLM(model="test") - citations = lm._extract_citations_from_response(mock_response, mock_choice) + citations = dspy.Citations.from_dict_list(citations_data) - assert citations is not None - assert len(citations) == 1 - assert citations[0]["text"] == "The sky is blue" - assert citations[0]["source"] == "Weather Guide" - assert citations[0]["start"] == 10 - assert citations[0]["end"] == 25 + assert len(citations.citations) == 1 + assert citations.citations[0].cited_text == "The sky is blue" + assert citations.citations[0].document_title == "Weather Guide" def test_citations_postprocessing(): """Test that citations are properly processed in adapter postprocessing.""" from dspy.adapters.chat_adapter import ChatAdapter + from dspy.signatures.signature import Signature + + class CitationSignature(Signature): + """Test signature with citations.""" + question: str = dspy.InputField() + answer: str = dspy.OutputField() + citations: dspy.Citations = dspy.OutputField() adapter = ChatAdapter() @@ -154,10 +134,11 @@ def test_citations_postprocessing(): "text": "[[ ## answer ## ]]\nThe answer is blue.\n\n[[ ## citations ## ]]\n[]", "citations": [ { - "text": "The sky is blue", - "source": "Weather Guide", - "start": 10, - "end": 25 + "cited_text": "The sky is blue", + "document_index": 0, + "document_title": "Weather Guide", + "start_char_index": 10, + "end_char_index": 25 } ] }] @@ -172,54 +153,47 @@ def test_citations_postprocessing(): # Should have Citations object in the result assert len(result) == 1 assert "citations" in result[0] - assert isinstance(result[0]["citations"], Citations) + assert isinstance(result[0]["citations"], dspy.Citations) assert len(result[0]["citations"]) == 1 - assert result[0]["citations"][0].text == "The sky is blue" - - -def test_citations_without_citations(): - """Test that processing works when no citations are present.""" - from dspy.adapters.chat_adapter import ChatAdapter + assert result[0]["citations"][0].cited_text == "The sky is blue" - adapter = ChatAdapter() - - # Mock outputs without citations - outputs = [{ - "text": "[[ ## answer ## ]]\nThe answer is blue.\n\n[[ ## citations ## ]]\n[]" - }] - - # Process with citation signature - result = adapter._call_postprocess( - CitationSignature, - CitationSignature, - outputs - ) - # Should still work, with None for citations (since no citations in LM response) - assert len(result) == 1 - # Note: When no citations are in LM response, citations field should be None - # but the field gets set to empty list from parsing. Let's verify it's empty. - assert result[0]["citations"] is None or ( - isinstance(result[0]["citations"], Citations) and len(result[0]["citations"]) == 0 - ) +def test_citation_extraction_from_lm_response(): + """Test citation extraction from mock LM response.""" + from unittest.mock import MagicMock + from dspy.clients.base_lm import BaseLM -def test_citation_imports(): - """Test that Citations can be imported from main dspy module.""" - assert hasattr(dspy, "Citations") - assert dspy.Citations is Citations + # Create a mock response with citations in new LiteLLM format + mock_response = MagicMock() + mock_choice = MagicMock() + mock_message = MagicMock() - # Individual Citation should only be accessible via Citations.Citation - assert not hasattr(dspy, "Citation") + # Mock provider_specific_fields with citations (Anthropic format) + mock_message.provider_specific_fields = { + "citations": [ + { + "type": "char_location", + "cited_text": "The sky is blue", + "document_index": 0, + "document_title": "Weather Guide", + "start_char_index": 10, + "end_char_index": 25 + } + ] + } + mock_choice.message = mock_message + mock_response.choices = [mock_choice] -def test_citation_access_pattern(): - """Test that Citation class is accessible as Citations.Citation (like ToolCalls pattern).""" - # Test that we can create Citation objects via Citations.Citation - citation = Citations.Citation(text="Hello", source="World") - assert citation.text == "Hello" - assert citation.source == "World" + # Create BaseLM instance and test citation extraction + lm = BaseLM(model="test") + citations = lm._extract_citations_from_response(mock_response, mock_choice) - # Test that Citations.Citation is the correct class - assert hasattr(Citations, "Citation") - assert isinstance(citation, Citations.Citation) + assert citations is not None + assert len(citations) == 1 + assert citations[0]["cited_text"] == "The sky is blue" + assert citations[0]["document_index"] == 0 + assert citations[0]["document_title"] == "Weather Guide" + assert citations[0]["start_char_index"] == 10 + assert citations[0]["end_char_index"] == 25 From 8ea14f5d909dba5d5616ed6fb99c317241a22728 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 27 Aug 2025 18:57:25 +0900 Subject: [PATCH 04/18] use recent model --- dspy/adapters/types/citation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py index 6d8fa87bdb..7b81ba890d 100644 --- a/dspy/adapters/types/citation.py +++ b/dspy/adapters/types/citation.py @@ -38,7 +38,7 @@ class AnswerWithSources(Signature): ] # Use with a model that supports citations like Claude - lm = dspy.LM("anthropic/claude-3-5-sonnet-20241022") + lm = dspy.LM("anthropic/claude-opus-4-1-20250805") predictor = dspy.Predict(AnswerWithSources, lm=lm) result = predictor(documents=docs, question="What temperature does water boil?") From 322ff59ca5bc0a2b377076520724c24b2c53338c Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 27 Aug 2025 19:05:15 +0900 Subject: [PATCH 05/18] address comments --- dspy/adapters/types/citation.py | 4 ++-- dspy/adapters/types/document.py | 13 ++++++++++--- dspy/clients/base_lm.py | 2 +- tests/adapters/test_document.py | 16 ++++++++++------ 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py index 7b81ba890d..c635d53d76 100644 --- a/dspy/adapters/types/citation.py +++ b/dspy/adapters/types/citation.py @@ -27,11 +27,11 @@ class AnswerWithSources(Signature): # Create documents to provide as sources docs = [ dspy.Document( - content="The Earth orbits the Sun in an elliptical path.", + data="The Earth orbits the Sun in an elliptical path.", title="Basic Astronomy Facts" ), dspy.Document( - content="Water boils at 100°C at standard atmospheric pressure.", + data="Water boils at 100°C at standard atmospheric pressure.", title="Physics Fundamentals", metadata={"author": "Dr. Smith", "year": 2023} ) diff --git a/dspy/adapters/types/document.py b/dspy/adapters/types/document.py index 57fa257a9b..73601d10d8 100644 --- a/dspy/adapters/types/document.py +++ b/dspy/adapters/types/document.py @@ -64,17 +64,24 @@ def format(self) -> list[dict[str, Any]]: Returns: A list containing the document block in the format expected by citation-enabled language models. """ - return { + document_block = { "type": "document", "source": { "type": "text", "media_type": self.media_type, "data": self.data }, - "citations": {"enabled": True}, - "title": self.title + "citations": {"enabled": True} } + if self.title: + document_block["title"] = self.title + + if self.context: + document_block["context"] = self.context + + return [document_block] + @classmethod diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 513f980273..48467d2be2 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -221,7 +221,7 @@ def _extract_citations_from_response(self, response, choice): citations = [] for citation_data in citations_data: citation_dict = { - "type": citation_data.get("type", ""), + "type": citation_data.get("type", "char_location"), "cited_text": citation_data.get("cited_text", ""), "document_index": citation_data.get("document_index", 0), "document_title": citation_data.get("document_title"), diff --git a/tests/adapters/test_document.py b/tests/adapters/test_document.py index dd7d5f2d3f..a4ee91d973 100644 --- a/tests/adapters/test_document.py +++ b/tests/adapters/test_document.py @@ -47,9 +47,13 @@ def test_document_format(): formatted = doc.format() - assert formatted["type"] == "document" - assert formatted["source"]["type"] == "text" - assert formatted["source"]["media_type"] == "text/plain" - assert formatted["source"]["data"] == "The sky is blue." - assert formatted["title"] == "Color Facts" - assert formatted["citations"]["enabled"] is True + assert isinstance(formatted, list) + assert len(formatted) == 1 + + doc_block = formatted[0] + assert doc_block["type"] == "document" + assert doc_block["source"]["type"] == "text" + assert doc_block["source"]["media_type"] == "text/plain" + assert doc_block["source"]["data"] == "The sky is blue." + assert doc_block["title"] == "Color Facts" + assert doc_block["citations"]["enabled"] is True From 23648c1238a54da9d84506409845fc3d18521a08 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 28 Aug 2025 12:03:10 +0900 Subject: [PATCH 06/18] remove metadata --- dspy/adapters/types/document.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/dspy/adapters/types/document.py b/dspy/adapters/types/document.py index 73601d10d8..046822be5b 100644 --- a/dspy/adapters/types/document.py +++ b/dspy/adapters/types/document.py @@ -16,7 +16,6 @@ class Document(Type): data: The text content of the document title: Optional title for the document (used in citations) media_type: MIME type of the document content (defaults to "text/plain") - metadata: Optional additional metadata about the document context: Optional context information about the document Example: @@ -40,7 +39,6 @@ class AnswerWithSources(Signature): dspy.Document( data="Water boils at 100°C at standard atmospheric pressure.", title="Physics Fundamentals", - metadata={"author": "Dr. Smith", "year": 2023} ) ] @@ -55,7 +53,6 @@ class AnswerWithSources(Signature): data: str title: str | None = None media_type: Literal["text/plain", "application/pdf"] = "text/plain" - metadata: dict[str, Any] | None = None context: str | None = None def format(self) -> list[dict[str, Any]]: From a14a90c45268665e61ad5a3af05d8ce20d066e0c Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 28 Aug 2025 12:09:34 +0900 Subject: [PATCH 07/18] fix test --- tests/adapters/test_document.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/adapters/test_document.py b/tests/adapters/test_document.py index a4ee91d973..a840e45f6a 100644 --- a/tests/adapters/test_document.py +++ b/tests/adapters/test_document.py @@ -28,13 +28,11 @@ def test_document_with_all_fields(): data="Water boils at 100°C at standard pressure.", title="Physics Facts", media_type="application/pdf", - metadata={"author": "Dr. Smith", "year": 2023}, context="Laboratory conditions" ) assert doc.data == "Water boils at 100°C at standard pressure." assert doc.title == "Physics Facts" assert doc.media_type == "application/pdf" - assert doc.metadata == {"author": "Dr. Smith", "year": 2023} assert doc.context == "Laboratory conditions" From 8f047dd35b96d921fc5013da6e5e906a0b05dbb9 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 4 Sep 2025 01:02:53 -0700 Subject: [PATCH 08/18] support citation streaming --- dspy/streaming/streaming_listener.py | 22 ++++++- tests/streaming/test_streaming.py | 86 ++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py index 59168a7b40..5dd7563ba3 100644 --- a/dspy/streaming/streaming_listener.py +++ b/dspy/streaming/streaming_listener.py @@ -7,6 +7,7 @@ from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter +from dspy.adapters.types.citation import Citations from dspy.adapters.xml_adapter import XMLAdapter from dspy.dsp.utils.settings import settings from dspy.streaming.messages import StreamResponse @@ -101,6 +102,20 @@ def receive(self, chunk: ModelResponseStream): except Exception: return + # Handle anthropic citations. see https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api + try: + if self._is_citation_type(): + if chunk_citation := chunk.choices[0].delta.provider_specific_fields.get("citation", None): + return StreamResponse( + self.predict_name, + self.signature_field_name, + Citations.from_dict_list([chunk_citation]), + is_last_chunk=False, + ) + except Exception: + pass + + if chunk_message and start_identifier in chunk_message: # If the cache is hit, the chunk_message could be the full response. When it happens we can # directly end the stream listening. In some models like gemini, each stream chunk can be multiple @@ -203,6 +218,11 @@ def flush(self) -> str: f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}" ) + def _is_citation_type(self) -> bool: + """Check if the signature field is a citations field.""" + from dspy.predict import Predict + return isinstance(self.predict, Predict) and getattr(self.predict.signature.output_fields.get(self.signature_field_name, None), "annotation", None) == Citations + def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]): """Find the predictor for each stream listener. @@ -230,7 +250,7 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis "predictor to use for streaming. Please specify the predictor to listen to." ) - if field_info.annotation is not str: + if field_info.annotation not in [str, Citations]: raise ValueError( f"Stream listener can only be applied to string output field, but your field {field_name} is of " f"type {field_info.annotation}." diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py index 2fbd7abd60..d0f10db6ea 100644 --- a/tests/streaming/test_streaming.py +++ b/tests/streaming/test_streaming.py @@ -874,3 +874,89 @@ async def send_to_stream(): assert isinstance(all_chunks[0], CustomChunk) assert isinstance(all_chunks[1], dspy.Prediction) + + +@pytest.mark.anyio +async def test_streaming_with_citations(): + class AnswerWithSources(dspy.Signature): + """Answer questions using provided documents with citations.""" + documents: list[dspy.Document] = dspy.InputField() + question: str = dspy.InputField() + answer: str = dspy.OutputField() + citations: dspy.Citations = dspy.OutputField() + + class MyProgram(dspy.Module): + def __init__(self): + super().__init__() + self.predict = dspy.Predict(AnswerWithSources) + + def forward(self, documents, question, **kwargs): + return self.predict(documents=documents, question=question, **kwargs) + + async def citation_stream(*args, **kwargs): + # Stream chunks with citation data in provider_specific_fields + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="[[ ##"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" answer"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="Water"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" boils"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" at"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" 100°C"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="."))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="\n\n"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="[[ ##"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" citations"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content='[{"type": "char_location", "cited_text": "Water boils at 100°C", "document_index": 0, "document_title": "Physics Facts", "start_char_index": 0, "end_char_index": 19}]'))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta( + content="", + provider_specific_fields={ + "citation": { + "type": "char_location", + "cited_text": "Water boils at 100°C", + "document_index": 0, + "document_title": "Physics Facts", + "start_char_index": 0, + "end_char_index": 19 + } + } + ))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="\n\n"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="[[ ##"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" completed"))]) + yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" ## ]]"))]) + + # Mock the final response choice to include provider_specific_fields with citations + with mock.patch("litellm.acompletion", return_value=citation_stream()): + program = dspy.streamify( + MyProgram(), + stream_listeners=[ + dspy.streaming.StreamListener(signature_field_name="citations"), + ], + ) + + # Create test documents + docs = [dspy.Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")] + + with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False)): + output = program(documents=docs, question="What temperature does water boil?") + citation_chunks = [] + final_prediction = None + async for value in output: + if isinstance(value, dspy.streaming.StreamResponse) and value.signature_field_name == "citations": + citation_chunks.append(value) + elif isinstance(value, dspy.Prediction): + final_prediction = value + + # Test that we received citation chunks from streaming + assert len(citation_chunks) > 0 + citation_chunk = citation_chunks[0] + assert isinstance(citation_chunk.chunk, dspy.Citations) + assert len(citation_chunk.chunk) == 1 + assert citation_chunk.chunk[0].cited_text == "Water boils at 100°C" + assert citation_chunk.chunk[0].document_title == "Physics Facts" + + # Test that prediction contains the expected fields + assert final_prediction is not None + assert hasattr(final_prediction, "answer") + assert hasattr(final_prediction, "citations") From 6e4777d1de5b111bbd9ddaec56a3969ee2ef0470 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 4 Sep 2025 01:09:26 -0700 Subject: [PATCH 09/18] add supported text --- dspy/adapters/types/citation.py | 7 +++++- dspy/clients/base_lm.py | 9 ++++---- tests/adapters/test_citation.py | 40 ++++++++++++++++++--------------- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py index c635d53d76..42c459bf2c 100644 --- a/dspy/adapters/types/citation.py +++ b/dspy/adapters/types/citation.py @@ -55,6 +55,7 @@ class Citation(Type): document_title: str | None = None start_char_index: int end_char_index: int + supported_text: str | None = None def format(self) -> dict[str, Any]: """Format citation as dictionary for LM consumption. @@ -73,6 +74,9 @@ def format(self) -> dict[str, Any]: if self.document_title: citation_dict["document_title"] = self.document_title + if self.supported_text: + citation_dict["supported_text"] = self.supported_text + return citation_dict citations: list[Citation] @@ -96,7 +100,8 @@ def from_dict_list(cls, citations_dicts: list[dict[str, Any]]) -> "Citations": "document_index": 0, "document_title": "Weather Guide", "start_char_index": 0, - "end_char_index": 15 + "end_char_index": 15, + "supported_text": "The sky was blue yesterday." } ] citations = Citations.from_dict_list(citations_dict) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 48467d2be2..83c4b10894 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -222,11 +222,12 @@ def _extract_citations_from_response(self, response, choice): for citation_data in citations_data: citation_dict = { "type": citation_data.get("type", "char_location"), - "cited_text": citation_data.get("cited_text", ""), - "document_index": citation_data.get("document_index", 0), + "cited_text": citation_data.get("cited_text"), + "document_index": citation_data.get("document_index"), "document_title": citation_data.get("document_title"), - "start_char_index": citation_data.get("start_char_index", 0), - "end_char_index": citation_data.get("end_char_index", 0), + "start_char_index": citation_data.get("start_char_index"), + "end_char_index": citation_data.get("end_char_index"), + "supported_text": citation_data.get("supported_text"), } citations.append(citation_dict) return citations diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py index 1d7bbedc7d..abbd97929b 100644 --- a/tests/adapters/test_citation.py +++ b/tests/adapters/test_citation.py @@ -5,21 +5,21 @@ def test_citation_validate_input(): - # Create a `dspy.Citations.Citation` instance with valid data. citation = dspy.Citations.Citation( cited_text="The Earth orbits the Sun.", document_index=0, start_char_index=0, - end_char_index=23 + end_char_index=23, + supported_text="The Earth orbits the Sun." ) assert citation.cited_text == "The Earth orbits the Sun." assert citation.document_index == 0 assert citation.start_char_index == 0 assert citation.end_char_index == 23 assert citation.type == "char_location" + assert citation.supported_text == "The Earth orbits the Sun." with pytest.raises(pydantic.ValidationError): - # Try to create a `dspy.Citations.Citation` instance with missing required field. dspy.Citations.Citation(cited_text="text") @@ -31,7 +31,8 @@ class Wrapper(pydantic.BaseModel): cited_text="Hello, world!", document_index=0, start_char_index=0, - end_char_index=13 + end_char_index=13, + supported_text="Hello, world!" ) citations = dspy.Citations(citations=[citation]) wrapper = Wrapper(citations=citations) @@ -44,13 +45,15 @@ def test_citation_with_all_fields(): document_index=1, document_title="Physics Facts", start_char_index=10, - end_char_index=31 + end_char_index=31, + supported_text="Water boils at 100°C." ) assert citation.cited_text == "Water boils at 100°C." assert citation.document_index == 1 assert citation.document_title == "Physics Facts" assert citation.start_char_index == 10 assert citation.end_char_index == 31 + assert citation.supported_text == "Water boils at 100°C." def test_citation_format(): @@ -59,7 +62,8 @@ def test_citation_format(): document_index=0, document_title="Weather Guide", start_char_index=5, - end_char_index=21 + end_char_index=21, + supported_text="The sky is blue." ) formatted = citation.format() @@ -70,6 +74,7 @@ def test_citation_format(): assert formatted["document_title"] == "Weather Guide" assert formatted["start_char_index"] == 5 assert formatted["end_char_index"] == 21 + assert formatted["supported_text"] == "The sky is blue." def test_citations_format(): @@ -78,14 +83,16 @@ def test_citations_format(): cited_text="First citation", document_index=0, start_char_index=0, - end_char_index=14 + end_char_index=14, + supported_text="First citation" ), dspy.Citations.Citation( cited_text="Second citation", document_index=1, document_title="Source", start_char_index=20, - end_char_index=35 + end_char_index=35, + supported_text="Second citation" ) ]) @@ -105,7 +112,8 @@ def test_citations_from_dict_list(): "document_index": 0, "document_title": "Weather Guide", "start_char_index": 0, - "end_char_index": 15 + "end_char_index": 15, + "supported_text": "The sky was blue yesterday." } ] @@ -129,7 +137,6 @@ class CitationSignature(Signature): adapter = ChatAdapter() - # Mock outputs with citations - need valid parsed text in ChatAdapter format outputs = [{ "text": "[[ ## answer ## ]]\nThe answer is blue.\n\n[[ ## citations ## ]]\n[]", "citations": [ @@ -138,19 +145,18 @@ class CitationSignature(Signature): "document_index": 0, "document_title": "Weather Guide", "start_char_index": 10, - "end_char_index": 25 + "end_char_index": 25, + "supported_text": "The sky is blue" } ] }] - # Process with citation signature result = adapter._call_postprocess( CitationSignature, CitationSignature, outputs ) - # Should have Citations object in the result assert len(result) == 1 assert "citations" in result[0] assert isinstance(result[0]["citations"], dspy.Citations) @@ -164,12 +170,9 @@ def test_citation_extraction_from_lm_response(): from dspy.clients.base_lm import BaseLM - # Create a mock response with citations in new LiteLLM format mock_response = MagicMock() mock_choice = MagicMock() mock_message = MagicMock() - - # Mock provider_specific_fields with citations (Anthropic format) mock_message.provider_specific_fields = { "citations": [ { @@ -178,7 +181,8 @@ def test_citation_extraction_from_lm_response(): "document_index": 0, "document_title": "Weather Guide", "start_char_index": 10, - "end_char_index": 25 + "end_char_index": 25, + "supported_text": "The sky is blue" } ] } @@ -186,7 +190,6 @@ def test_citation_extraction_from_lm_response(): mock_choice.message = mock_message mock_response.choices = [mock_choice] - # Create BaseLM instance and test citation extraction lm = BaseLM(model="test") citations = lm._extract_citations_from_response(mock_response, mock_choice) @@ -197,3 +200,4 @@ def test_citation_extraction_from_lm_response(): assert citations[0]["document_title"] == "Weather Guide" assert citations[0]["start_char_index"] == 10 assert citations[0]["end_char_index"] == 25 + assert citations[0]["supported_text"] == "The sky is blue" From 1d14fdbed30a05e78924e907388ca4aefa6118ff Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 4 Sep 2025 22:17:11 -0700 Subject: [PATCH 10/18] simplify --- dspy/clients/base_lm.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 83c4b10894..3fd61003e7 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -212,25 +212,12 @@ def _extract_citations_from_response(self, response, choice): # Check for citations in LiteLLM provider_specific_fields if hasattr(response, "choices") and hasattr(choice, "message"): message = choice.message - # Check for citations in provider_specific_fields (Anthropic format) if hasattr(message, "provider_specific_fields") and message.provider_specific_fields: provider_fields = message.provider_specific_fields if isinstance(provider_fields, dict) and "citations" in provider_fields: citations_data = provider_fields["citations"] if isinstance(citations_data, list): - citations = [] - for citation_data in citations_data: - citation_dict = { - "type": citation_data.get("type", "char_location"), - "cited_text": citation_data.get("cited_text"), - "document_index": citation_data.get("document_index"), - "document_title": citation_data.get("document_title"), - "start_char_index": citation_data.get("start_char_index"), - "end_char_index": citation_data.get("end_char_index"), - "supported_text": citation_data.get("supported_text"), - } - citations.append(citation_dict) - return citations + return citations_data except Exception: # If citation extraction fails, just continue without citations pass From 4343ebffd3a96b322aaa1e12acde36607ad9890d Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 4 Sep 2025 22:21:25 -0700 Subject: [PATCH 11/18] error message --- dspy/streaming/streaming_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py index 5dd7563ba3..5e539a1f20 100644 --- a/dspy/streaming/streaming_listener.py +++ b/dspy/streaming/streaming_listener.py @@ -252,7 +252,7 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis if field_info.annotation not in [str, Citations]: raise ValueError( - f"Stream listener can only be applied to string output field, but your field {field_name} is of " + f"Stream listener can only be applied to string or Citationsoutput field, but your field {field_name} is of " f"type {field_info.annotation}." ) From bcb3b81d1bf3067cf20f198d8b37b23932a57364 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Fri, 5 Sep 2025 00:01:41 -0700 Subject: [PATCH 12/18] comment --- dspy/adapters/base.py | 4 ++++ dspy/clients/base_lm.py | 7 +++--- dspy/streaming/streaming_listener.py | 1 - tests/adapters/test_citation.py | 34 ++++++++++------------------ 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 6fa30a5854..c3ed40c513 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -63,6 +63,10 @@ def _call_preprocess( return signature_for_native_function_calling + citation_output_field_name = self._get_citation_output_field_name(signature) + if citation_output_field_name: + signature = signature.delete(citation_output_field_name) + return signature def _call_postprocess( diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 3fd61003e7..961d72e6f3 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -185,7 +185,7 @@ def _process_completion(self, response, merged_kwargs): output["tool_calls"] = c.message.tool_calls # Extract citations from LiteLLM response if available - citations = self._extract_citations_from_response(response, c) + citations = self._extract_citations_from_response(c) if citations: output["citations"] = citations @@ -197,12 +197,11 @@ def _process_completion(self, response, merged_kwargs): return outputs - def _extract_citations_from_response(self, response, choice): + def _extract_citations_from_response(self, choice): """Extract citations from LiteLLM response if available. Reference: https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api Args: - response: The LiteLLM response object choice: The choice object from response.choices Returns: @@ -210,7 +209,7 @@ def _extract_citations_from_response(self, response, choice): """ try: # Check for citations in LiteLLM provider_specific_fields - if hasattr(response, "choices") and hasattr(choice, "message"): + if hasattr(choice, "message"): message = choice.message if hasattr(message, "provider_specific_fields") and message.provider_specific_fields: provider_fields = message.provider_specific_fields diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py index 5e539a1f20..e802c0d2f2 100644 --- a/dspy/streaming/streaming_listener.py +++ b/dspy/streaming/streaming_listener.py @@ -115,7 +115,6 @@ def receive(self, chunk: ModelResponseStream): except Exception: pass - if chunk_message and start_identifier in chunk_message: # If the cache is hit, the chunk_message could be the full response. When it happens we can # directly end the stream listening. In some models like gemini, each stream chunk can be multiple diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py index abbd97929b..4003dc9702 100644 --- a/tests/adapters/test_citation.py +++ b/tests/adapters/test_citation.py @@ -125,7 +125,6 @@ def test_citations_from_dict_list(): def test_citations_postprocessing(): - """Test that citations are properly processed in adapter postprocessing.""" from dspy.adapters.chat_adapter import ChatAdapter from dspy.signatures.signature import Signature @@ -165,33 +164,24 @@ class CitationSignature(Signature): def test_citation_extraction_from_lm_response(): - """Test citation extraction from mock LM response.""" from unittest.mock import MagicMock from dspy.clients.base_lm import BaseLM - mock_response = MagicMock() - mock_choice = MagicMock() - mock_message = MagicMock() - mock_message.provider_specific_fields = { - "citations": [ - { - "type": "char_location", - "cited_text": "The sky is blue", - "document_index": 0, - "document_title": "Weather Guide", - "start_char_index": 10, - "end_char_index": 25, - "supported_text": "The sky is blue" - } - ] - } - - mock_choice.message = mock_message - mock_response.choices = [mock_choice] + mock_choice = MagicMock(message=MagicMock(provider_specific_fields={"citations": [ + { + "type": "char_location", + "cited_text": "The sky is blue", + "document_index": 0, + "document_title": "Weather Guide", + "start_char_index": 10, + "end_char_index": 25, + "supported_text": "The sky is blue" + } + ]})) lm = BaseLM(model="test") - citations = lm._extract_citations_from_response(mock_response, mock_choice) + citations = lm._extract_citations_from_response(mock_choice) assert citations is not None assert len(citations) == 1 From e1f13528300c7736408f3cd919398f85eb99bb7a Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Fri, 5 Sep 2025 00:23:31 -0700 Subject: [PATCH 13/18] fix nest --- dspy/clients/base_lm.py | 2 +- tests/adapters/test_citation.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 961d72e6f3..422cfc42d8 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -216,7 +216,7 @@ def _extract_citations_from_response(self, choice): if isinstance(provider_fields, dict) and "citations" in provider_fields: citations_data = provider_fields["citations"] if isinstance(citations_data, list): - return citations_data + return [citation for citations in citations_data for citation in citations] except Exception: # If citation extraction fails, just continue without citations pass diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py index 4003dc9702..17c575ddf5 100644 --- a/tests/adapters/test_citation.py +++ b/tests/adapters/test_citation.py @@ -151,7 +151,7 @@ class CitationSignature(Signature): }] result = adapter._call_postprocess( - CitationSignature, + CitationSignature.delete("citations"), CitationSignature, outputs ) @@ -166,9 +166,7 @@ class CitationSignature(Signature): def test_citation_extraction_from_lm_response(): from unittest.mock import MagicMock - from dspy.clients.base_lm import BaseLM - - mock_choice = MagicMock(message=MagicMock(provider_specific_fields={"citations": [ + mock_choice = MagicMock(message=MagicMock(provider_specific_fields={"citations": [[ { "type": "char_location", "cited_text": "The sky is blue", @@ -178,9 +176,9 @@ def test_citation_extraction_from_lm_response(): "end_char_index": 25, "supported_text": "The sky is blue" } - ]})) + ]]})) - lm = BaseLM(model="test") + lm = dspy.LM(model="test") citations = lm._extract_citations_from_response(mock_choice) assert citations is not None From ed1fe45e7b2cbb71583bc0a1004a0d497bcc9951 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Fri, 5 Sep 2025 13:16:38 -0700 Subject: [PATCH 14/18] improve test --- tests/streaming/test_streaming.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py index d0f10db6ea..1af84dcb82 100644 --- a/tests/streaming/test_streaming.py +++ b/tests/streaming/test_streaming.py @@ -904,9 +904,6 @@ async def citation_stream(*args, **kwargs): yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" 100°C"))]) yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="."))]) yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="\n\n"))]) - yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="[[ ##"))]) - yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" citations"))]) - yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))]) yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content='[{"type": "char_location", "cited_text": "Water boils at 100°C", "document_index": 0, "document_title": "Physics Facts", "start_char_index": 0, "end_char_index": 19}]'))]) yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta( content="", From cc24355c0e833b4c822173c4778717871fdcdc17 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Fri, 5 Sep 2025 13:19:26 -0700 Subject: [PATCH 15/18] simplify --- dspy/clients/base_lm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 422cfc42d8..1aac679931 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -209,14 +209,9 @@ def _extract_citations_from_response(self, choice): """ try: # Check for citations in LiteLLM provider_specific_fields - if hasattr(choice, "message"): - message = choice.message - if hasattr(message, "provider_specific_fields") and message.provider_specific_fields: - provider_fields = message.provider_specific_fields - if isinstance(provider_fields, dict) and "citations" in provider_fields: - citations_data = provider_fields["citations"] - if isinstance(citations_data, list): - return [citation for citations in citations_data for citation in citations] + citations_data = choice.message.provider_specific_fields.get("citations") + if isinstance(citations_data, list): + return [citation for citations in citations_data for citation in citations] except Exception: # If citation extraction fails, just continue without citations pass From bf4e53acf454011603d969a9b715aab9e9ab5ca8 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Fri, 5 Sep 2025 14:11:46 -0700 Subject: [PATCH 16/18] comment --- dspy/streaming/streaming_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py index e802c0d2f2..811125dcb6 100644 --- a/dspy/streaming/streaming_listener.py +++ b/dspy/streaming/streaming_listener.py @@ -251,7 +251,7 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis if field_info.annotation not in [str, Citations]: raise ValueError( - f"Stream listener can only be applied to string or Citationsoutput field, but your field {field_name} is of " + f"Stream listener can only be applied to string or Citations output field, but your field {field_name} is of " f"type {field_info.annotation}." ) From dbb93bb255229ffa500fb030cf40baa9408e4a2c Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Tue, 9 Sep 2025 08:21:31 +0900 Subject: [PATCH 17/18] move citations and document into experimental --- dspy/__init__.py | 2 +- dspy/adapters/__init__.py | 4 +--- dspy/adapters/base.py | 3 ++- dspy/adapters/types/__init__.py | 4 +--- dspy/adapters/types/citation.py | 13 ++++++++----- dspy/adapters/types/document.py | 13 ++++++++----- dspy/clients/base_lm.py | 6 ++---- dspy/experimental/__init__.py | 7 +++++++ dspy/streaming/streaming_listener.py | 4 ++-- tests/adapters/test_citation.py | 27 ++++++++++++++------------- tests/adapters/test_document.py | 18 +++++++++--------- tests/streaming/test_streaming.py | 9 +++++---- 12 files changed, 60 insertions(+), 50 deletions(-) create mode 100644 dspy/experimental/__init__.py diff --git a/dspy/__init__.py b/dspy/__init__.py index 8b21238c2f..ea4c75a862 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -6,7 +6,7 @@ from dspy.evaluate import Evaluate # isort: skip from dspy.clients import * # isort: skip -from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code, Citations, Document # isort: skip +from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.syncify import syncify diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py index 79edd1ec53..1dea6da47a 100644 --- a/dspy/adapters/__init__.py +++ b/dspy/adapters/__init__.py @@ -2,7 +2,7 @@ from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter from dspy.adapters.two_step_adapter import TwoStepAdapter -from dspy.adapters.types import Audio, Citations, Code, Document, History, Image, Tool, ToolCalls, Type +from dspy.adapters.types import Audio, Code, History, Image, Tool, ToolCalls, Type from dspy.adapters.xml_adapter import XMLAdapter __all__ = [ @@ -13,8 +13,6 @@ "Image", "Audio", "Code", - "Citations", - "Document", "JSONAdapter", "XMLAdapter", "TwoStepAdapter", diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index c3ed40c513..f714b338bc 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -4,9 +4,10 @@ import json_repair import litellm -from dspy.adapters.types import Citations, History +from dspy.adapters.types import History from dspy.adapters.types.base_type import split_message_content_for_custom_types from dspy.adapters.types.tool import Tool, ToolCalls +from dspy.experimental import Citations from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks diff --git a/dspy/adapters/types/__init__.py b/dspy/adapters/types/__init__.py index e7f28faadb..11b9faee1b 100644 --- a/dspy/adapters/types/__init__.py +++ b/dspy/adapters/types/__init__.py @@ -1,10 +1,8 @@ from dspy.adapters.types.audio import Audio from dspy.adapters.types.base_type import Type -from dspy.adapters.types.citation import Citations from dspy.adapters.types.code import Code -from dspy.adapters.types.document import Document from dspy.adapters.types.history import History from dspy.adapters.types.image import Image from dspy.adapters.types.tool import Tool, ToolCalls -__all__ = ["History", "Image", "Audio", "Type", "Tool", "ToolCalls", "Code", "Citations", "Document"] +__all__ = ["History", "Image", "Audio", "Type", "Tool", "ToolCalls", "Code"] diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py index 42c459bf2c..b3c613c070 100644 --- a/dspy/adapters/types/citation.py +++ b/dspy/adapters/types/citation.py @@ -3,8 +3,10 @@ import pydantic from dspy.adapters.types.base_type import Type +from dspy.utils.annotation import experimental +@experimental(version="3.0.4") class Citations(Type): """Citations extracted from an LM response with source references. @@ -16,21 +18,22 @@ class Citations(Type): ```python import dspy from dspy.signatures import Signature + from dspy.experimental import Citations, Document class AnswerWithSources(Signature): '''Answer questions using provided documents with citations.''' - documents: list[dspy.Document] = dspy.InputField() + documents: list[Document] = dspy.InputField() question: str = dspy.InputField() answer: str = dspy.OutputField() - citations: dspy.Citations = dspy.OutputField() + citations: Citations = dspy.OutputField() # Create documents to provide as sources docs = [ - dspy.Document( + Document( data="The Earth orbits the Sun in an elliptical path.", title="Basic Astronomy Facts" ), - dspy.Document( + Document( data="Water boils at 100°C at standard atmospheric pressure.", title="Physics Fundamentals", metadata={"author": "Dr. Smith", "year": 2023} @@ -150,7 +153,7 @@ def validate_input(cls, data: Any): # Handle case where data is a single citation dict return {"citations": [cls.Citation(**data)]} - raise ValueError(f"Received invalid value for `dspy.Citations`: {data}") + raise ValueError(f"Received invalid value for `Citations`: {data}") def __iter__(self): """Allow iteration over citations.""" diff --git a/dspy/adapters/types/document.py b/dspy/adapters/types/document.py index 046822be5b..cd492a6ffb 100644 --- a/dspy/adapters/types/document.py +++ b/dspy/adapters/types/document.py @@ -3,8 +3,10 @@ import pydantic from dspy.adapters.types.base_type import Type +from dspy.utils.annotation import experimental +@experimental(version="3.0.4") class Document(Type): """A document type for providing content that can be cited by language models. @@ -22,21 +24,22 @@ class Document(Type): ```python import dspy from dspy.signatures import Signature + from dspy.experimental import Document, Citations class AnswerWithSources(Signature): '''Answer questions using provided documents with citations.''' - documents: list[dspy.Document] = dspy.InputField() + documents: list[Document] = dspy.InputField() question: str = dspy.InputField() answer: str = dspy.OutputField() - citations: dspy.Citations = dspy.OutputField() + citations: Citations = dspy.OutputField() # Create documents docs = [ - dspy.Document( + Document( data="The Earth orbits the Sun in an elliptical path.", title="Basic Astronomy Facts" ), - dspy.Document( + Document( data="Water boils at 100°C at standard atmospheric pressure.", title="Physics Fundamentals", ) @@ -103,7 +106,7 @@ def validate_input(cls, data: Any): elif isinstance(data, dict): return data - raise ValueError(f"Received invalid value for `dspy.Document`: {data}") + raise ValueError(f"Received invalid value for `Document`: {data}") def __str__(self) -> str: """String representation showing title and content length.""" diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 1aac679931..ad241e8d5d 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -205,7 +205,7 @@ def _extract_citations_from_response(self, choice): choice: The choice object from response.choices Returns: - List of citation dictionaries or None if no citations found + A list of citation dictionaries or None if no citations found """ try: # Check for citations in LiteLLM provider_specific_fields @@ -214,9 +214,7 @@ def _extract_citations_from_response(self, choice): return [citation for citations in citations_data for citation in citations] except Exception: # If citation extraction fails, just continue without citations - pass - - return None + return None def _process_response(self, response): """Process the response of OpenAI Response API and extract outputs. diff --git a/dspy/experimental/__init__.py b/dspy/experimental/__init__.py new file mode 100644 index 0000000000..651c7d97f2 --- /dev/null +++ b/dspy/experimental/__init__.py @@ -0,0 +1,7 @@ +from dspy.adapters.types.citation import Citations +from dspy.adapters.types.document import Document + +__all__ = [ + "Citations", + "Document", +] diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py index 811125dcb6..e307593513 100644 --- a/dspy/streaming/streaming_listener.py +++ b/dspy/streaming/streaming_listener.py @@ -104,7 +104,7 @@ def receive(self, chunk: ModelResponseStream): # Handle anthropic citations. see https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api try: - if self._is_citation_type(): + if self._signature_field_is_citation_type(): if chunk_citation := chunk.choices[0].delta.provider_specific_fields.get("citation", None): return StreamResponse( self.predict_name, @@ -217,7 +217,7 @@ def flush(self) -> str: f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}" ) - def _is_citation_type(self) -> bool: + def _signature_field_is_citation_type(self) -> bool: """Check if the signature field is a citations field.""" from dspy.predict import Predict return isinstance(self.predict, Predict) and getattr(self.predict.signature.output_fields.get(self.signature_field_name, None), "annotation", None) == Citations diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py index 17c575ddf5..8d2881b074 100644 --- a/tests/adapters/test_citation.py +++ b/tests/adapters/test_citation.py @@ -2,10 +2,11 @@ import pytest import dspy +from dspy.experimental import Citations def test_citation_validate_input(): - citation = dspy.Citations.Citation( + citation = Citations.Citation( cited_text="The Earth orbits the Sun.", document_index=0, start_char_index=0, @@ -20,27 +21,27 @@ def test_citation_validate_input(): assert citation.supported_text == "The Earth orbits the Sun." with pytest.raises(pydantic.ValidationError): - dspy.Citations.Citation(cited_text="text") + Citations.Citation(cited_text="text") def test_citations_in_nested_type(): class Wrapper(pydantic.BaseModel): - citations: dspy.Citations + citations: Citations - citation = dspy.Citations.Citation( + citation = Citations.Citation( cited_text="Hello, world!", document_index=0, start_char_index=0, end_char_index=13, supported_text="Hello, world!" ) - citations = dspy.Citations(citations=[citation]) + citations = Citations(citations=[citation]) wrapper = Wrapper(citations=citations) assert wrapper.citations.citations[0].cited_text == "Hello, world!" def test_citation_with_all_fields(): - citation = dspy.Citations.Citation( + citation = Citations.Citation( cited_text="Water boils at 100°C.", document_index=1, document_title="Physics Facts", @@ -57,7 +58,7 @@ def test_citation_with_all_fields(): def test_citation_format(): - citation = dspy.Citations.Citation( + citation = Citations.Citation( cited_text="The sky is blue.", document_index=0, document_title="Weather Guide", @@ -78,15 +79,15 @@ def test_citation_format(): def test_citations_format(): - citations = dspy.Citations(citations=[ - dspy.Citations.Citation( + citations = Citations(citations=[ + Citations.Citation( cited_text="First citation", document_index=0, start_char_index=0, end_char_index=14, supported_text="First citation" ), - dspy.Citations.Citation( + Citations.Citation( cited_text="Second citation", document_index=1, document_title="Source", @@ -117,7 +118,7 @@ def test_citations_from_dict_list(): } ] - citations = dspy.Citations.from_dict_list(citations_data) + citations = Citations.from_dict_list(citations_data) assert len(citations.citations) == 1 assert citations.citations[0].cited_text == "The sky is blue" @@ -132,7 +133,7 @@ class CitationSignature(Signature): """Test signature with citations.""" question: str = dspy.InputField() answer: str = dspy.OutputField() - citations: dspy.Citations = dspy.OutputField() + citations: Citations = dspy.OutputField() adapter = ChatAdapter() @@ -158,7 +159,7 @@ class CitationSignature(Signature): assert len(result) == 1 assert "citations" in result[0] - assert isinstance(result[0]["citations"], dspy.Citations) + assert isinstance(result[0]["citations"], Citations) assert len(result[0]["citations"]) == 1 assert result[0]["citations"][0].cited_text == "The sky is blue" diff --git a/tests/adapters/test_document.py b/tests/adapters/test_document.py index a840e45f6a..98c7144407 100644 --- a/tests/adapters/test_document.py +++ b/tests/adapters/test_document.py @@ -1,30 +1,30 @@ import pydantic import pytest -import dspy +from dspy.experimental import Document def test_document_validate_input(): - # Create a `dspy.Document` instance with valid data. - doc = dspy.Document(data="The Earth orbits the Sun.") + # Create a `Document` instance with valid data. + doc = Document(data="The Earth orbits the Sun.") assert doc.data == "The Earth orbits the Sun." with pytest.raises(pydantic.ValidationError): - # Try to create a `dspy.Document` instance with invalid type. - dspy.Document(data=123) + # Try to create a `Document` instance with invalid type. + Document(data=123) def test_document_in_nested_type(): class Wrapper(pydantic.BaseModel): - document: dspy.Document + document: Document - doc = dspy.Document(data="Hello, world!") + doc = Document(data="Hello, world!") wrapper = Wrapper(document=doc) assert wrapper.document.data == "Hello, world!" def test_document_with_all_fields(): - doc = dspy.Document( + doc = Document( data="Water boils at 100°C at standard pressure.", title="Physics Facts", media_type="application/pdf", @@ -37,7 +37,7 @@ def test_document_with_all_fields(): def test_document_format(): - doc = dspy.Document( + doc = Document( data="The sky is blue.", title="Color Facts", media_type="text/plain" diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py index 1af84dcb82..134320d61f 100644 --- a/tests/streaming/test_streaming.py +++ b/tests/streaming/test_streaming.py @@ -9,6 +9,7 @@ from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices import dspy +from dspy.experimental import Citations, Document from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response @@ -880,10 +881,10 @@ async def send_to_stream(): async def test_streaming_with_citations(): class AnswerWithSources(dspy.Signature): """Answer questions using provided documents with citations.""" - documents: list[dspy.Document] = dspy.InputField() + documents: list[Document] = dspy.InputField() question: str = dspy.InputField() answer: str = dspy.OutputField() - citations: dspy.Citations = dspy.OutputField() + citations: Citations = dspy.OutputField() class MyProgram(dspy.Module): def __init__(self): @@ -933,7 +934,7 @@ async def citation_stream(*args, **kwargs): ) # Create test documents - docs = [dspy.Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")] + docs = [Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")] with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False)): output = program(documents=docs, question="What temperature does water boil?") @@ -948,7 +949,7 @@ async def citation_stream(*args, **kwargs): # Test that we received citation chunks from streaming assert len(citation_chunks) > 0 citation_chunk = citation_chunks[0] - assert isinstance(citation_chunk.chunk, dspy.Citations) + assert isinstance(citation_chunk.chunk, Citations) assert len(citation_chunk.chunk) == 1 assert citation_chunk.chunk[0].cited_text == "Water boils at 100°C" assert citation_chunk.chunk[0].document_title == "Physics Facts" From 2b452c3e6eaf919c5770e4b86fb9872d4b0eebf4 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Tue, 9 Sep 2025 08:22:22 +0900 Subject: [PATCH 18/18] comment --- dspy/clients/base_lm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index ad241e8d5d..68fc43cf75 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -213,7 +213,6 @@ def _extract_citations_from_response(self, choice): if isinstance(citations_data, list): return [citation for citations in citations_data for citation in citations] except Exception: - # If citation extraction fails, just continue without citations return None def _process_response(self, response):