Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dspy/clients/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Any

import cloudpickle
import orjson
import pydantic
import ujson
from cachetools import LRUCache
from diskcache import FanoutCache

Expand Down Expand Up @@ -93,7 +93,7 @@ def transform_value(value):
return value

params = {k: transform_value(v) for k, v in request.items() if k not in ignored_args_for_cache_key}
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
return sha256(orjson.dumps(params, option=orjson.OPT_SORT_KEYS)).hexdigest()

def get(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | None = None) -> Any:
try:
Expand Down
9 changes: 4 additions & 5 deletions dspy/clients/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import time
from typing import TYPE_CHECKING, Any

import orjson
import requests
import ujson

from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import TrainDataFormat, get_finetune_directory
Expand Down Expand Up @@ -265,8 +265,7 @@ def _get_workspace_client() -> "WorkspaceClient":
from databricks.sdk import WorkspaceClient
except ImportError:
raise ImportError(
"To use Databricks finetuning, please install the databricks-sdk package via "
"`pip install databricks-sdk`."
"To use Databricks finetuning, please install the databricks-sdk package via `pip install databricks-sdk`."
)
return WorkspaceClient()

Expand Down Expand Up @@ -311,14 +310,14 @@ def _save_data_to_local_file(train_data: list[dict[str, Any]], data_format: Trai
finetune_dir = get_finetune_directory()
file_path = os.path.join(finetune_dir, file_name)
file_path = os.path.abspath(file_path)
with open(file_path, "w") as f:
with open(file_path, "wb") as f:
for item in train_data:
if data_format == TrainDataFormat.CHAT:
_validate_chat_data(item)
elif data_format == TrainDataFormat.COMPLETION:
_validate_completion_data(item)

f.write(ujson.dumps(item) + "\n")
f.write(orjson.dumps(item) + b"\n")
return file_path


Expand Down
10 changes: 5 additions & 5 deletions dspy/clients/utils_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Any, Literal, TypedDict

import ujson
import orjson

import dspy
from dspy.adapters.base import Adapter
Expand Down Expand Up @@ -58,9 +58,9 @@ def get_finetune_directory() -> str:


def write_lines(file_path, data):
with open(file_path, "w") as f:
with open(file_path, "wb") as f:
for item in data:
f.write(ujson.dumps(item) + "\n")
f.write(orjson.dumps(item) + b"\n")


def save_data(
Expand All @@ -75,9 +75,9 @@ def save_data(
finetune_dir = get_finetune_directory()
file_path = os.path.join(finetune_dir, file_name)
file_path = os.path.abspath(file_path)
with open(file_path, "w") as f:
with open(file_path, "wb") as f:
for item in data:
f.write(ujson.dumps(item) + "\n")
f.write(orjson.dumps(item) + b"\n")
return file_path


Expand Down
7 changes: 5 additions & 2 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def reset(self):
self.train = []
self.demos = []

def dump_state(self):
def dump_state(self, json_mode=True):
state_keys = ["traces", "train"]
state = {k: getattr(self, k) for k in state_keys}

Expand All @@ -42,7 +42,10 @@ def dump_state(self):
# FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access as an object
demo[field] = serialize_object(demo[field])

state["demos"].append(demo)
if isinstance(demo, dict) or not json_mode:
state["demos"].append(demo)
else:
state["demos"].append(demo.toDict())

state["signature"] = self.signature.dump_state()
state["lm"] = self.lm.dump_state() if self.lm else None
Expand Down
7 changes: 3 additions & 4 deletions dspy/predict/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import textwrap
from typing import Callable

import ujson
import orjson

import dspy
from dspy.adapters.utils import get_field_description_string
Expand Down Expand Up @@ -158,10 +158,9 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
}

advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names)
# advise_kwargs = {k: ujson.dumps(recursive_mask(v), indent=2) for k, v in advise_kwargs.items()}
# only dumps if it's a list or dict
advise_kwargs = {
k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
for k, v in advise_kwargs.items()
}
advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice
Expand Down Expand Up @@ -200,7 +199,7 @@ def inspect_modules(program):
def recursive_mask(o):
# If the object is already serializable, return it.
try:
ujson.dumps(o)
orjson.dumps(o)
return o
except TypeError:
pass
Expand Down
32 changes: 17 additions & 15 deletions dspy/primitives/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path

import cloudpickle
import ujson
import orjson

from dspy.utils.saving import get_dependency_versions

Expand Down Expand Up @@ -153,8 +153,8 @@ def reset_copy(self):

return new_instance

def dump_state(self):
return {name: param.dump_state() for name, param in self.named_parameters()}
def dump_state(self, json_mode=True):
return {name: param.dump_state(json_mode=json_mode) for name, param in self.named_parameters()}

def load_state(self, state):
for name, param in self.named_parameters():
Expand All @@ -169,10 +169,10 @@ def save(self, path, save_program=False, modules_to_serialize=None):
- `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
architecture of the model.

If `save_program=True` and `modules_to_serialize` are provided, it will register those modules for serialization
with cloudpickle's `register_pickle_by_value`. This causes cloudpickle to serialize the module by value rather
than by reference, ensuring the module is fully preserved along with the saved program. This is useful
when you have custom modules that need to be serialized alongside your program. If None, then no modules
If `save_program=True` and `modules_to_serialize` are provided, it will register those modules for serialization
with cloudpickle's `register_pickle_by_value`. This causes cloudpickle to serialize the module by value rather
than by reference, ensuring the module is fully preserved along with the saved program. This is useful
when you have custom modules that need to be serialized alongside your program. If None, then no modules
will be registered for serialization.

We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
Expand Down Expand Up @@ -215,24 +215,26 @@ def save(self, path, save_program=False, modules_to_serialize=None):
f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
"or consider using state-only saving by setting `save_program=False`."
)
with open(path / "metadata.json", "w", encoding="utf-8") as f:
ujson.dump(metadata, f, indent=2, ensure_ascii=False)
with open(path / "metadata.json", "wb") as f:
f.write(orjson.dumps(metadata, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE))

return

state = self.dump_state()
state["metadata"] = metadata
if path.suffix == ".json":
state = self.dump_state()
state["metadata"] = metadata
try:
with open(path, "w", encoding="utf-8") as f:
f.write(ujson.dumps(state, indent=2 , ensure_ascii=False))
with open(path, "wb") as f:
f.write(orjson.dumps(state, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE))
except Exception as e:
raise RuntimeError(
f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
"json-serializable objects, please consider saving the state in .pkl by using `path` ending "
"with `.pkl`, or saving the whole program by setting `save_program=True`."
)
elif path.suffix == ".pkl":
state = self.dump_state(json_mode=False)
state["metadata"] = metadata
with open(path, "wb") as f:
cloudpickle.dump(state, f)
else:
Expand All @@ -248,8 +250,8 @@ def load(self, path):
path = Path(path)

if path.suffix == ".json":
with open(path, encoding="utf-8") as f:
state = ujson.loads(f.read())
with open(path, "rb") as f:
state = orjson.loads(f.read())
elif path.suffix == ".pkl":
with open(path, "rb") as f:
state = cloudpickle.load(f)
Expand Down
16 changes: 15 additions & 1 deletion dspy/primitives/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,18 @@ def without(self, *keys):
return copied

def toDict(self): # noqa: N802
return self._store.copy()
def convert_to_serializable(value):
if hasattr(value, "toDict"):
return value.toDict()
elif isinstance(value, list):
return [convert_to_serializable(item) for item in value]
elif isinstance(value, dict):
return {k: convert_to_serializable(v) for k, v in value.items()}
else:
return value

serializable_store = {}
for k, v in self._store.items():
serializable_store[k] = convert_to_serializable(v)

return serializable_store
6 changes: 3 additions & 3 deletions dspy/streaming/streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator

import litellm
import ujson
import orjson
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream
from litellm import ModelResponseStream
Expand Down Expand Up @@ -261,10 +261,10 @@ async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
async for value in streamer:
if isinstance(value, Prediction):
data = {"prediction": dict(value.items(include_dspy=False))}
yield f"data: {ujson.dumps(data)}\n\n"
yield f"data: {orjson.dumps(data).decode()}\n\n"
elif isinstance(value, litellm.ModelResponseStream):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
yield f"data: {orjson.dumps(data).decode()}\n\n"
elif isinstance(value, str) and value.startswith("data:"):
# The chunk value is an OpenAI-compatible streaming chunk value,
# e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}",
Expand Down
8 changes: 4 additions & 4 deletions dspy/teleprompt/simba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import textwrap
from typing import Callable

import ujson
import orjson

import dspy
from dspy.adapters.utils import get_field_description_string
Expand Down Expand Up @@ -120,7 +120,7 @@ def append_a_rule(bucket, system, **kwargs):
"module_names": module_names,
}

kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
for k, v in kwargs.items()}
advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice

Expand Down Expand Up @@ -194,9 +194,9 @@ def inspect_modules(program):
def recursive_mask(o):
# If the object is already serializable, return it.
try:
ujson.dumps(o)
orjson.dumps(o)
return o
except TypeError:
except (TypeError, orjson.JSONEncodeError):
pass

# If it's a dictionary, apply recursively to its values.
Expand Down
4 changes: 2 additions & 2 deletions dspy/utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

import cloudpickle
import ujson
import orjson

if TYPE_CHECKING:
from dspy.primitives.module import Module
Expand Down Expand Up @@ -40,7 +40,7 @@ def load(path: str) -> "Module":
raise FileNotFoundError(f"The path '{path}' does not exist.")

with open(path / "metadata.json") as f:
metadata = ujson.load(f)
metadata = orjson.loads(f.read())

dependency_versions = get_dependency_versions()
saved_dependency_versions = metadata["dependency_versions"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"joblib~=1.3",
"openai>=0.28.1",
"regex>=2023.10.3",
"ujson>=5.8.0",
"orjson>=3.9.0",
"tqdm>=4.66.1",
"requests>=2.31.0",
"optuna>=3.4.0",
Expand Down
18 changes: 9 additions & 9 deletions tests/predict/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from datetime import datetime
from unittest.mock import patch

import orjson
import pydantic
import pytest
import ujson
from litellm import ModelResponse
from pydantic import BaseModel, HttpUrl

Expand Down Expand Up @@ -105,8 +105,8 @@ class TranslateToEnglish(dspy.Signature):
assert len(dumped_state["demos"]) == len(original_instance.demos)
assert dumped_state["demos"][0]["content"] == original_instance.demos[0].content

saved_state = ujson.dumps(dumped_state)
loaded_state = ujson.loads(saved_state)
saved_state = orjson.dumps(dumped_state).decode()
loaded_state = orjson.loads(saved_state)

new_instance = Predict(TranslateToEnglish)
new_instance.load_state(loaded_state)
Expand Down Expand Up @@ -147,8 +147,8 @@ class InventorySignature(dspy.Signature):
assert dumped_state["demos"][0]["items"][0] == {"name": "apple", "quantity": 5}

# Test serialization/deserialization
saved_state = ujson.dumps(dumped_state)
loaded_state = ujson.loads(saved_state)
saved_state = orjson.dumps(dumped_state).decode()
loaded_state = orjson.loads(saved_state)

# Test load_state
new_instance = Predict(InventorySignature)
Expand Down Expand Up @@ -711,17 +711,17 @@ class TestSignature(dspy.Signature):
assert serialized["url"] == "https://www.example.com/"
assert serialized["created_at"] == "2021-01-01T12:00:00"

json_str = ujson.dumps(serialized)
reloaded = ujson.loads(json_str)
json_str = orjson.dumps(serialized).decode()
reloaded = orjson.loads(json_str)
assert reloaded == serialized

predictor = Predict(TestSignature)
demo = {"website_info": website_info, "summary": "This is a test website."}
predictor.demos = [demo]

state = predictor.dump_state()
json_str = ujson.dumps(state)
reloaded_state = ujson.loads(json_str)
json_str = orjson.dumps(state).decode()
reloaded_state = orjson.loads(json_str)

demo_data = reloaded_state["demos"][0]
assert demo_data["website_info"]["url"] == "https://www.example.com/"
Expand Down
12 changes: 11 additions & 1 deletion tests/primitives/test_base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@ def test_save_and_load_with_json(tmp_path):
model = dspy.ChainOfThought(dspy.Signature("q -> a"))
model.predict.signature = model.predict.signature.with_instructions("You are a helpful assistant.")
model.predict.demos = [
dspy.Example(q="What is the capital of France?", a="Paris", reasoning="n/a").with_inputs("q", "a")
dspy.Example(q="What is the capital of France?", a="Paris", reasoning="n/a").with_inputs("q"),
# Nested example
dspy.Example(
q=[
dspy.Example(q="What is the capital of France?"),
dspy.Example(q="What is actually the capital of France?"),
],
a="Paris",
reasoning="n/a",
).with_inputs("q"),
]
save_path = tmp_path / "model.json"
model.save(save_path)
Expand All @@ -71,6 +80,7 @@ def test_save_and_load_with_json(tmp_path):

assert str(new_model.predict.signature) == str(model.predict.signature)
assert new_model.predict.demos[0] == model.predict.demos[0].toDict()
assert new_model.predict.demos[1] == model.predict.demos[1].toDict()


@pytest.mark.extra
Expand Down
Loading