106 changes: 58 additions & 48 deletions dspy/retrievers/databricks_rm.py
@@ -1,31 +1,16 @@
import json
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Any

import requests

import dspy
from dspy.primitives.prediction import Prediction
from dspy.retrievers.retrieve import Document

_databricks_sdk_installed = find_spec("databricks.sdk") is not None


@dataclass
class Document:
page_content: str
metadata: dict[str, Any]
type: str

def to_dict(self) -> dict[str, Any]:
return {
"page_content": self.page_content,
"metadata": self.metadata,
"type": self.type,
}


class DatabricksRM(dspy.Retrieve):
"""
A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
@@ -129,19 +114,30 @@ def __init__(
compatible with the Databricks Mosaic Agent Framework.
"""
super().__init__(k=k)
self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
self.databricks_token = (
databricks_token
if databricks_token is not None
else os.environ.get("DATABRICKS_TOKEN")
)
self.databricks_endpoint = (
databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST")
databricks_endpoint
if databricks_endpoint is not None
else os.environ.get("DATABRICKS_HOST")
)
self.databricks_client_id = (
databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID")
databricks_client_id
if databricks_client_id is not None
else os.environ.get("DATABRICKS_CLIENT_ID")
)
self.databricks_client_secret = (
databricks_client_secret
if databricks_client_secret is not None
else os.environ.get("DATABRICKS_CLIENT_SECRET")
)
if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
if (
not _databricks_sdk_installed
and (self.databricks_token, self.databricks_endpoint).count(None) > 0
):
raise ValueError(
"To retrieve documents with Databricks Vector Search, you must install the"
" databricks-sdk Python library, supply the databricks_token and"
@@ -196,12 +192,23 @@ def _get_extra_columns(self, item: dict[str, Any]) -> dict[str, Any]:
extra_columns = {
k: v
for k, v in item.items()
if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name]
if k
not in [
self.docs_id_column_name,
self.text_column_name,
self.docs_uri_column_name,
]
}
if self.docs_id_column_name == "metadata":
extra_columns = {
**extra_columns,
**{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}},
**{
"metadata": {
k: v
for k, v in json.loads(item["metadata"]).items()
if k != "document_id"
}
},
}
return extra_columns
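
For context, a rough sketch of what `_get_extra_columns` produces for a typical search-result row. The column names and values here are hypothetical, assuming `docs_id_column_name="id"`, `text_column_name="text"`, and no `docs_uri_column_name`:

item = {"id": "doc-1", "text": "some passage", "year": 2024, "source": "wiki"}
# The id/text/uri columns are stripped; every other column is kept:
# _get_extra_columns(item) == {"year": 2024, "source": "wiki"}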

@@ -210,7 +217,7 @@ def forward(
query: str | list[float],
query_type: str = "ANN",
filters_json: str | None = None,
) -> dspy.Prediction | list[dict[str, Any]]:
) -> list[Document]:
"""
Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
specified query.
@@ -286,7 +293,9 @@
)

if self.text_column_name not in col_names:
raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}")
raise Exception(
f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}"
)

# Extracting the results
items = []
@@ -300,27 +309,22 @@
# Sorting results by score in descending order
sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]

if self.use_with_databricks_agent_framework:
return [
Document(
page_content=doc[self.text_column_name],
metadata={
"doc_id": self._extract_doc_ids(doc),
"doc_uri": doc[self.docs_uri_column_name] if self.docs_uri_column_name else None,
}
| self._get_extra_columns(doc),
type="Document",
).to_dict()
for doc in sorted_docs
]
else:
# Returning the prediction
return Prediction(
docs=[doc[self.text_column_name] for doc in sorted_docs],
doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs] if self.docs_uri_column_name else None,
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
)
return [
Document(
page_content=doc[self.text_column_name],
metadata={
"doc_id": self._extract_doc_ids(doc),
"doc_uri": (
doc[self.docs_uri_column_name]
if self.docs_uri_column_name
else None
),
}
| self._get_extra_columns(doc),
type="Document",
).to_dict()
for doc in sorted_docs
]
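
With this change, `DatabricksRM.forward` always returns serialized `Document` dicts rather than branching on `use_with_databricks_agent_framework`. A minimal usage sketch, assuming a reachable Vector Search index (the index and column names below are hypothetical):

rm = DatabricksRM(
    databricks_index_name="main.default.my_index",  # hypothetical index
    text_column_name="text",
    docs_id_column_name="id",
    k=3,
)
for doc in rm("what is mosaic ai vector search?"):
    # each item is a Document serialized via to_dict()
    print(doc["page_content"], doc["metadata"]["doc_id"])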

@staticmethod
def _query_via_databricks_sdk(
@@ -365,15 +369,19 @@ def _query_via_databricks_sdk(
from databricks.sdk import WorkspaceClient

if (query_text, query_vector).count(None) != 1:
raise ValueError("Exactly one of query_text or query_vector must be specified.")
raise ValueError(
"Exactly one of query_text or query_vector must be specified."
)

if databricks_client_secret and databricks_client_id:
# Use client ID and secret for authentication if they are provided
databricks_client = WorkspaceClient(
client_id=databricks_client_id,
client_secret=databricks_client_secret,
)
print("Creating Databricks workspace client using service principal authentication.")
print(
"Creating Databricks workspace client using service principal authentication."
)

else:
# Fallback for token-based authentication
@@ -424,7 +432,9 @@ def _query_via_requests(
dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
"""
if (query_text, query_vector).count(None) != 1:
raise ValueError("Exactly one of query_text or query_vector must be specified.")
raise ValueError(
"Exactly one of query_text or query_vector must be specified."
)

headers = {
"Authorization": f"Bearer {databricks_token}",
44 changes: 41 additions & 3 deletions dspy/retrievers/retrieve.py
@@ -1,10 +1,26 @@
import random
from dataclasses import dataclass
from typing import Any

from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.utils.callback import with_callbacks


@dataclass
class Document:
page_content: str
metadata: dict[str, Any]
type: str

def to_dict(self) -> dict[str, Any]:
return {
"page_content": self.page_content,
"metadata": self.metadata,
"type": self.type,
}
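
For reference, a quick sketch of the new `Document` dataclass round-tripping through `to_dict()`:

doc = Document(page_content="a passage", metadata={"doc_id": 1}, type="Document")
assert doc.to_dict() == {
    "page_content": "a passage",
    "metadata": {"doc_id": 1},
    "type": "Document",
}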


def single_query_passage(passages):
passages_dict = {key: [] for key in list(passages[0].keys())}
for docs in passages:
@@ -37,15 +53,16 @@ def load_state(self, state):
setattr(self, name, value)

@with_callbacks
def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> Prediction | list[Document]:
return self.forward(*args, **kwargs)

def forward(
self,
query: str,
k: int | None = None,
return_documents: bool = False,
**kwargs,
) -> list[str] | Prediction | list[Prediction]:
) -> Prediction | list[Document]:
k = k if k is not None else self.k

import dspy
@@ -56,12 +73,33 @@ def forward(
passages = dspy.settings.rm(query, k=k, **kwargs)

from collections.abc import Iterable

if not isinstance(passages, Iterable):
# it's not an iterable yet; make it one.
# TODO: we should unify the type signatures of dspy.Retriever
passages = [passages]
passages = [psg.long_text for psg in passages]

docs: list[Document] = []
for psg in passages:
if isinstance(psg, Document):
docs.append(psg)
elif isinstance(psg, dict):
page_content = psg.get("page_content", psg.get("long_text", ""))
# support either page_content or long_text keys
metadata = psg.get("metadata", {})
_type = psg.get("type", "Document")
docs.append(
Document(page_content=page_content, metadata=metadata, type=_type)
)
elif isinstance(psg, str):
docs.append(Document(page_content=psg, metadata={}, type="Document"))

if return_documents:
return docs

passages = [psg.page_content for psg in docs]

return Prediction(passages=passages)
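
With the new `return_documents` flag, callers can opt into structured `Document` objects instead of bare passage strings. A minimal sketch, assuming a retrieval model has already been configured (`my_rm` below is hypothetical):

import dspy

dspy.settings.configure(rm=my_rm)  # my_rm: any configured retrieval client
retrieve = dspy.Retrieve(k=3)

pred = retrieve("some query")                         # Prediction(passages=[...])
docs = retrieve("some query", return_documents=True)  # list[Document]
for d in docs:
    print(d.page_content, d.metadata)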


# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too.