diff --git a/dspy/retrievers/databricks_rm.py b/dspy/retrievers/databricks_rm.py
index dc4966bf2a..8f4e86af83 100644
--- a/dspy/retrievers/databricks_rm.py
+++ b/dspy/retrievers/databricks_rm.py
@@ -1,31 +1,16 @@
 import json
 import os
-from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import Any
 
 import requests
 
 import dspy
-from dspy.primitives.prediction import Prediction
+from dspy.retrievers.retrieve import Document
 
 _databricks_sdk_installed = find_spec("databricks.sdk") is not None
 
 
-@dataclass
-class Document:
-    page_content: str
-    metadata: dict[str, Any]
-    type: str
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            "page_content": self.page_content,
-            "metadata": self.metadata,
-            "type": self.type,
-        }
-
-
 class DatabricksRM(dspy.Retrieve):
     """
     A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
@@ -129,19 +114,30 @@ def __init__(
             compatible with the Databricks Mosaic Agent Framework.
         """
         super().__init__(k=k)
-        self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
+        self.databricks_token = (
+            databricks_token
+            if databricks_token is not None
+            else os.environ.get("DATABRICKS_TOKEN")
+        )
         self.databricks_endpoint = (
-            databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST")
+            databricks_endpoint
+            if databricks_endpoint is not None
+            else os.environ.get("DATABRICKS_HOST")
         )
         self.databricks_client_id = (
-            databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID")
+            databricks_client_id
+            if databricks_client_id is not None
+            else os.environ.get("DATABRICKS_CLIENT_ID")
         )
         self.databricks_client_secret = (
             databricks_client_secret
             if databricks_client_secret is not None
             else os.environ.get("DATABRICKS_CLIENT_SECRET")
         )
-        if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
+        if (
+            not _databricks_sdk_installed
+            and (self.databricks_token, self.databricks_endpoint).count(None) > 0
+        ):
             raise ValueError(
                 "To retrieve documents with Databricks Vector Search, you must install the"
                 " databricks-sdk Python library, supply the databricks_token and"
@@ -196,12 +192,23 @@ def _get_extra_columns(self, item: dict[str, Any]) -> dict[str, Any]:
         extra_columns = {
             k: v
             for k, v in item.items()
-            if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name]
+            if k
+            not in [
+                self.docs_id_column_name,
+                self.text_column_name,
+                self.docs_uri_column_name,
+            ]
         }
         if self.docs_id_column_name == "metadata":
             extra_columns = {
                 **extra_columns,
-                **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}},
+                **{
+                    "metadata": {
+                        k: v
+                        for k, v in json.loads(item["metadata"]).items()
+                        if k != "document_id"
+                    }
+                },
             }
         return extra_columns
 
@@ -210,7 +217,7 @@ def forward(
         query: str | list[float],
         query_type: str = "ANN",
         filters_json: str | None = None,
-    ) -> dspy.Prediction | list[dict[str, Any]]:
+    ) -> list[Document]:
         """
         Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
        specified query.
@@ -286,7 +293,9 @@ def forward(
             )
 
         if self.text_column_name not in col_names:
-            raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}")
+            raise Exception(
+                f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}"
+            )
 
         # Extracting the results
         items = []
@@ -300,27 +309,22 @@ def forward(
 
         # Sorting results by score in descending order
         sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]
 
-        if self.use_with_databricks_agent_framework:
-            return [
-                Document(
-                    page_content=doc[self.text_column_name],
-                    metadata={
-                        "doc_id": self._extract_doc_ids(doc),
-                        "doc_uri": doc[self.docs_uri_column_name] if self.docs_uri_column_name else None,
-                    }
-                    | self._get_extra_columns(doc),
-                    type="Document",
-                ).to_dict()
-                for doc in sorted_docs
-            ]
-        else:
-            # Returning the prediction
-            return Prediction(
-                docs=[doc[self.text_column_name] for doc in sorted_docs],
-                doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
-                doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs] if self.docs_uri_column_name else None,
-                extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
-            )
+        return [
+            Document(
+                page_content=doc[self.text_column_name],
+                metadata={
+                    "doc_id": self._extract_doc_ids(doc),
+                    "doc_uri": (
+                        doc[self.docs_uri_column_name]
+                        if self.docs_uri_column_name
+                        else None
+                    ),
+                }
+                | self._get_extra_columns(doc),
+                type="Document",
+            )
+            for doc in sorted_docs
+        ]
 
     @staticmethod
     def _query_via_databricks_sdk(
@@ -365,7 +369,9 @@ def _query_via_databricks_sdk(
         from databricks.sdk import WorkspaceClient
 
         if (query_text, query_vector).count(None) != 1:
-            raise ValueError("Exactly one of query_text or query_vector must be specified.")
+            raise ValueError(
+                "Exactly one of query_text or query_vector must be specified."
+            )
 
         if databricks_client_secret and databricks_client_id:
             # Use client ID and secret for authentication if they are provided
@@ -373,7 +379,9 @@ def _query_via_databricks_sdk(
                 client_id=databricks_client_id,
                 client_secret=databricks_client_secret,
             )
-            print("Creating Databricks workspace client using service principal authentication.")
+            print(
+                "Creating Databricks workspace client using service principal authentication."
+            )
 
         else:
             # Fallback for token-based authentication
@@ -424,7 +432,9 @@ def _query_via_requests(
             dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
         """
         if (query_text, query_vector).count(None) != 1:
-            raise ValueError("Exactly one of query_text or query_vector must be specified.")
+            raise ValueError(
+                "Exactly one of query_text or query_vector must be specified."
+            )
 
         headers = {
             "Authorization": f"Bearer {databricks_token}",
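Note for reviewers: with the agent-framework branch removed, `DatabricksRM.forward` now always returns `Document` objects (the stray `.to_dict()` is dropped so the return value matches the declared `list[Document]`). A minimal sketch of the call site, assuming `DATABRICKS_HOST`/`DATABRICKS_TOKEN` are set; the index and column names below are illustrative, not from this diff:

```python
from dspy.retrievers.databricks_rm import DatabricksRM

# Hypothetical index/column names; authentication comes from the
# DATABRICKS_HOST and DATABRICKS_TOKEN environment variables.
rm = DatabricksRM(
    databricks_index_name="ml.docs.docs_index",
    docs_id_column_name="doc_id",
    text_column_name="text",
    k=3,
)

# forward() now returns list[Document] instead of switching between
# a Prediction and agent-framework dicts.
for doc in rm("what is mosaic ai vector search?"):
    print(doc.metadata["doc_id"], doc.page_content[:80])
```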
diff --git a/dspy/retrievers/retrieve.py b/dspy/retrievers/retrieve.py
index ef69d7a4bb..4bd75fa9f1 100644
--- a/dspy/retrievers/retrieve.py
+++ b/dspy/retrievers/retrieve.py
@@ -1,10 +1,26 @@
 import random
+from dataclasses import dataclass
+from typing import Any
 
 from dspy.predict.parameter import Parameter
 from dspy.primitives.prediction import Prediction
 from dspy.utils.callback import with_callbacks
 
 
+@dataclass
+class Document:
+    page_content: str
+    metadata: dict[str, Any]
+    type: str
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            "page_content": self.page_content,
+            "metadata": self.metadata,
+            "type": self.type,
+        }
+
+
 def single_query_passage(passages):
     passages_dict = {key: [] for key in list(passages[0].keys())}
     for docs in passages:
@@ -37,15 +53,16 @@ def load_state(self, state):
             setattr(self, name, value)
 
     @with_callbacks
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Prediction | list[Document]:
         return self.forward(*args, **kwargs)
 
     def forward(
         self,
         query: str,
         k: int | None = None,
+        return_documents: bool = False,
         **kwargs,
-    ) -> list[str] | Prediction | list[Prediction]:
+    ) -> Prediction | list[Document]:
         k = k if k is not None else self.k
 
         import dspy
@@ -56,12 +73,33 @@ def forward(
         passages = dspy.settings.rm(query, k=k, **kwargs)
 
         from collections.abc import Iterable
+
         if not isinstance(passages, Iterable):
             # it's not an iterable yet; make it one.
             # TODO: we should unify the type signatures of dspy.Retriever
             passages = [passages]
-        passages = [psg.long_text for psg in passages]
+
+        docs: list[Document] = []
+        for psg in passages:
+            if isinstance(psg, Document):
+                docs.append(psg)
+            elif isinstance(psg, dict):
+                # Support either "page_content" or legacy "long_text" keys.
+                page_content = psg.get("page_content", psg.get("long_text", ""))
+                metadata = psg.get("metadata", {})
+                _type = psg.get("type", "Document")
+                docs.append(
+                    Document(page_content=page_content, metadata=metadata, type=_type)
+                )
+            elif isinstance(psg, str):
+                docs.append(Document(page_content=psg, metadata={}, type="Document"))
+
+        if return_documents:
+            return docs
+
+        passages = [psg.page_content for psg in docs]
         return Prediction(passages=passages)
 
 
 # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too.
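The normalization loop above accepts three passage shapes from the configured RM. A self-contained sketch of the behavior; `toy_rm` is a stand-in for illustration, not part of this diff:

```python
import dspy
from dspy.retrievers.retrieve import Document, Retrieve

def toy_rm(query, k=3, **kwargs):
    # Returns the three shapes forward() normalizes: a Document, a dict
    # carrying a legacy "long_text" key, and a bare string.
    return [
        Document(page_content="a doc", metadata={"src": "x"}, type="Document"),
        {"long_text": "a legacy passage"},
        "a bare string",
    ][:k]

dspy.settings.configure(rm=toy_rm)
retrieve = Retrieve(k=3)

print(retrieve("query").passages)
# ['a doc', 'a legacy passage', 'a bare string']

docs = retrieve("query", return_documents=True)
print(docs[0].metadata)
# {'src': 'x'}
```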
diff --git a/dspy/retrievers/weaviate_rm.py b/dspy/retrievers/weaviate_rm.py
index 381919254a..2d5215f634 100644
--- a/dspy/retrievers/weaviate_rm.py
+++ b/dspy/retrievers/weaviate_rm.py
@@ -1,7 +1,4 @@
-
-import dspy
-from dspy.dsp.utils import dotdict
-from dspy.primitives.prediction import Prediction
+from dspy.retrievers.retrieve import Document, Retrieve
 
 try:
     from uuid import uuid4
@@ -14,7 +11,7 @@
     ) from err
 
 
-class WeaviateRM(dspy.Retrieve):
+class WeaviateRM(Retrieve):
     """A retrieval module that uses Weaviate to return the top passages for a given query.
 
     Assumes that a Weaviate collection has been created and populated with the following payload:
@@ -56,7 +53,9 @@ def __init__(
     ):
         self._weaviate_collection_name = weaviate_collection_name
         self._weaviate_client = weaviate_client
-        self._weaviate_collection = self._weaviate_client.collections.get(self._weaviate_collection_name)
+        self._weaviate_collection = self._weaviate_client.collections.get(
+            self._weaviate_collection_name
+        )
         self._weaviate_collection_text_key = weaviate_collection_text_key
         self._tenant_id = tenant_id
 
@@ -70,7 +69,9 @@ def __init__(
 
         super().__init__(k=k)
 
-    def forward(self, query_or_queries: str | list[str], k: int | None = None, **kwargs) -> Prediction:
+    def forward(
+        self, query_or_queries: str | list[str], k: int | None = None, *args, **kwargs
+    ) -> list[Document]:
         """Search with Weaviate for self.k top passages for query or queries.
 
         Args:
@@ -82,57 +83,78 @@ def forward(self, query_or_queries: str | list[str], k: int | None = None, **kwa
             dspy.Prediction: An object containing the retrieved passages.
         """
         k = k if k is not None else self.k
-        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
+        queries = (
+            [query_or_queries]
+            if isinstance(query_or_queries, str)
+            else query_or_queries
+        )
         queries = [q for q in queries if q]
         passages, parsed_results = [], []
 
         tenant = kwargs.pop("tenant_id", self._tenant_id)
         for query in queries:
             if self._client_type == "WeaviateClient":
                 if tenant:
-                    results = self._weaviate_collection.query.with_tenant(tenant).hybrid(query=query, limit=k, **kwargs)
+                    results = self._weaviate_collection.query.with_tenant(
+                        tenant
+                    ).hybrid(query=query, limit=k, **kwargs)
                 else:
-                    results = self._weaviate_collection.query.hybrid(query=query, limit=k, **kwargs)
+                    results = self._weaviate_collection.query.hybrid(
+                        query=query, limit=k, **kwargs
+                    )
 
-                parsed_results = [result.properties[self._weaviate_collection_text_key] for result in results.objects]
+                parsed_results = [
+                    result.properties[self._weaviate_collection_text_key]
+                    for result in results.objects
+                ]
 
             elif self._client_type == "Client":
                 q = self._weaviate_client.query.get(
-                    self._weaviate_collection_name, [self._weaviate_collection_text_key]
-                    )
+                    self._weaviate_collection_name, [self._weaviate_collection_text_key]
+                )
                 if tenant:
                     q = q.with_tenant(tenant)
 
                 results = q.with_hybrid(query=query).with_limit(k).do()
                 results = results["data"]["Get"][self._weaviate_collection_name]
-                parsed_results = [result[self._weaviate_collection_text_key] for result in results]
+                parsed_results = [
+                    result[self._weaviate_collection_text_key] for result in results
+                ]
 
-            passages.extend(dotdict({"long_text": d}) for d in parsed_results)
+            passages.extend(parsed_results)
 
-        return passages
+        return [
+            Document(page_content=text, metadata={}, type="Document")
+            for text in passages
+        ]
 
     def get_objects(self, num_samples: int, fields: list[str]) -> list[dict]:
         """Get objects from Weaviate using the cursor API."""
         if self._client_type == "WeaviateClient":
             objects = []
             counter = 0
-            for item in self._weaviate_collection.iterator():  # TODO: add tenancy scoping
+            for (
+                item
+            ) in self._weaviate_collection.iterator():  # TODO: add tenancy scoping
                 if counter >= num_samples:
                     break
                 new_object = {}
                 for key in item.properties.keys():
                     if key in fields:
-                         new_object[key] = item.properties[key]
+                        new_object[key] = item.properties[key]
                 objects.append(new_object)
                 counter += 1
             return objects
         else:
-            raise ValueError("`get_objects` is not supported for the v3 Weaviate Python client, please upgrade to v4.")
+            raise ValueError(
+                "`get_objects` is not supported for the v3 Weaviate Python client, please upgrade to v4."
+            )
 
     def insert(self, new_object_properties: dict):
         if self._client_type == "WeaviateClient":
             self._weaviate_collection.data.insert(
-                properties=new_object_properties,
-                uuid=get_valid_uuid(uuid4())
-            ) # TODO: add tenancy scoping
+                properties=new_object_properties, uuid=get_valid_uuid(uuid4())
+            )  # TODO: add tenancy scoping
         else:
-            raise AttributeError("`insert` is not supported for the v3 Weaviate Python client, please upgrade to v4.")
+            raise AttributeError(
+                "`insert` is not supported for the v3 Weaviate Python client, please upgrade to v4."
+            )
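`WeaviateRM.forward` now returns `Document` objects rather than `dotdict`s, so callers read `.page_content` instead of `.long_text`. A minimal sketch, assuming a local v4 Weaviate instance; the collection name and text key below are hypothetical:

```python
import weaviate
from dspy.retrievers.weaviate_rm import WeaviateRM

# Assumes a local v4 Weaviate instance with a populated "DSPyDocs"
# collection whose text property is named "content".
client = weaviate.connect_to_local()
try:
    rm = WeaviateRM(
        weaviate_collection_name="DSPyDocs",
        weaviate_client=client,
        weaviate_collection_text_key="content",
        k=5,
    )
    docs = rm("hybrid search")  # list[Document]
    print([d.page_content for d in docs])
finally:
    client.close()
```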
diff --git a/uv.lock b/uv.lock
index 2fd5aa3d6e..4dd0b4834f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -664,7 +664,7 @@ wheels = [
 
 [[package]]
 name = "dspy"
-version = "3.0.0"
+version = "3.0.1"
 source = { editable = "." }
 dependencies = [
     { name = "anyio" },
@@ -740,7 +740,7 @@ requires-dist = [
     { name = "datamodel-code-generator", marker = "extra == 'dev'", specifier = ">=0.26.3" },
     { name = "datasets", marker = "extra == 'test-extras'", specifier = ">=2.14.6" },
     { name = "diskcache", specifier = ">=5.6.0" },
-    { name = "gepa", specifier = "==0.0.2" },
+    { name = "gepa", extras = ["dspy"], specifier = "==0.0.4" },
     { name = "joblib", specifier = "~=1.3" },
     { name = "json-repair", specifier = ">=0.30.0" },
     { name = "langchain-core", marker = "extra == 'langchain'" },
@@ -956,11 +956,11 @@ wheels = [
 
 [[package]]
 name = "gepa"
-version = "0.0.2"
+version = "0.0.4"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/74/8b/2aba83c1a85d1260151c7b0ec57c4b34901ae7a74f27eba4e7532c2f80c8/gepa-0.0.2.tar.gz", hash = "sha256:0fa8ca333e4a69eec68087a68cdd651c9e4b8df0c147009f42a5563f1b9702dc", size = 32253, upload-time = "2025-08-12T02:26:49.112Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/0c/0d/aa6065d7d59b3f10ff6818d527dada5a7179ac5643b666b6b6b71d11dab4/gepa-0.0.4.tar.gz", hash = "sha256:b3e020124c7d8a80c07595aca3b73647ec9151203d7166915ad62492b8459bd6", size = 32957, upload-time = "2025-08-14T05:08:36.792Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/58/9c/a1111dfc06e1f4b3adac51670b76b45b8385ed510f35bca6c002ca25eef8/gepa-0.0.2-py3-none-any.whl", hash = "sha256:9894ceefd53d46a6811cf8737b698c99f626777a773307dafc5f8e0038007e53", size = 34191, upload-time = "2025-08-12T02:26:47.625Z" },
+    { url = "https://files.pythonhosted.org/packages/ce/c0/836c79f05113c96155e8de1bb8bf3631a9e7b3b75238c592d39460141ea8/gepa-0.0.4-py3-none-any.whl", hash = "sha256:53d275490d644855e90adf4eba1e3ace5c414c76ba0c0f22760b99a0e43984f9", size = 35191, upload-time = "2025-08-14T05:08:35.558Z" },
 ]
 
 [[package]]