Fix PredictModel to return correct output fields #55

Open · wants to merge 3 commits into main
55 changes: 31 additions & 24 deletions cognify/frontends/dspy/connector.py
@@ -1,6 +1,6 @@
import dspy
from dspy.adapters.chat_adapter import ChatAdapter, prepare_instructions
from cognify.llm import Model, StructuredModel, Input, OutputFormat
from cognify.llm import Model, StructuredModel, Input, OutputFormat, OutputLabel
from cognify.llm.model import LMConfig
from pydantic import BaseModel, create_model
from typing import Any, Dict, Type
@@ -41,48 +41,62 @@ def cognify_predictor(

if not isinstance(dspy_predictor, dspy.Predict):
warnings.warn(
"Original module is not a `Predict`. This may result in lossy translation",
"Original module is NOT a `dspy.Predict`. This may result in lossy translation",
UserWarning,
)

if isinstance(dspy_predictor, dspy.Retrieve):
warnings.warn(
"Original module is a `Retrieve`. This will be ignored", UserWarning
"Original module is a `dspy.Retrieve`. This will be ignored", UserWarning
Review comment (Contributor): what does this retrieve mean? why ignore?

Reply (Contributor, author): dspy.Retrieve is their retriever. They still consider retrieval to be a "module", and since we don't optimize retrieval, I just ignore it. This means whenever it gets called in the actual workflow, we will call the original retrieve module.

)
self.ignore_module = True
return None

# initialize cog lm
system_prompt = prepare_instructions(dspy_predictor.signature)
input_names = list(dspy_predictor.signature.input_fields.keys())
input_variables = [Input(name=input_name) for input_name in input_names]

output_fields = dspy_predictor.signature.output_fields
if "reasoning" in output_fields:
del output_fields["reasoning"]
# stripping the reasoning field may crash their workflow, so we warn users instead
warnings.warn(
"Original module contained reasoning. This will be stripped. Add reasoning as a cog instead",
f"Cognify performs its own reasoning prompt optimization automatically. Consider using `dspy.Predict` for module '{name}' instead of `dspy.ChainOfThought`",
UserWarning,
)
output_fields_for_schema = {k: v.annotation for k, v in output_fields.items()}
self.output_schema = generate_pydantic_model(
"OutputData", output_fields_for_schema
)
system_prompt = prepare_instructions(dspy_predictor.signature)

# lm config
lm_client: dspy.LM = dspy.settings.get("lm", None)

assert lm_client, "Expected lm to be configured in dspy"
lm_config = LMConfig(model=lm_client.model, kwargs=lm_client.kwargs)

# always treat as structured to provide compatiblity with forward function
return StructuredModel(
# treat as cognify.Model, allow dspy to handle output parsing
return Model(
agent_name=name,
system_prompt=system_prompt,
input_variables=input_variables,
output_format=OutputFormat(schema=self.output_schema),
lm_config=lm_config,
output=OutputLabel("llm_output"),
lm_config=lm_config
)

def construct_messages(self, inputs):
messages = None
if self.predictor:
messages: APICompatibleMessage = self.chat_adapter.format(
self.predictor.signature, self.predictor.demos, inputs
)
return messages

def parse_output(self, result):
values = []

# from dspy chat adapter __call__
value = self.chat_adapter.parse(self.predictor.signature, result, _parse_values=True)
assert set(value.keys()) == set(self.predictor.signature.output_fields.keys()), f"Expected {self.predictor.signature.output_fields.keys()} but got {value.keys()}"
values.append(value)

return values

def forward(self, **kwargs):
assert (
@@ -95,19 +109,12 @@ def forward(self, **kwargs):
inputs: Dict[str, str] = {
k.name: kwargs[k.name] for k in self.cog_lm.input_variables
}
messages = None
if self.predictor:
messages: APICompatibleMessage = self.chat_adapter.format(
self.predictor.signature, self.predictor.demos, inputs
)
messages = self.construct_messages(inputs)
result = self.cog_lm(
messages, inputs
) # kwargs have already been set when initializing cog_lm
kwargs: dict = result.model_dump()
for k,v in kwargs.items():
if not v:
raise ValueError(f"{self.cog_lm.name} did not generate a value for field `{k}`, consider using a larger model for structured output")
return dspy.Prediction(**kwargs)
completions = self.parse_output(result)
return dspy.Prediction.from_completions(completions, signature=self.predictor.signature)


def as_predict(cog_lm: Model) -> PredictModel:
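
For context, here is a minimal usage sketch of the translated predictor after this change. The signature, the model string, and the PredictModel constructor arguments are assumptions for illustration only; the diff above only establishes that forward() builds messages with the ChatAdapter, calls the underlying cog LM, parses the completion back into the signature's output fields, and returns a dspy.Prediction via from_completions.

import dspy
import cognify

# Hypothetical signature; any input/output fields are handled the same way.
class QA(dspy.Signature):
    """Answer the question in one sentence."""
    question: str = dspy.InputField()
    answer: str = dspy.OutputField()

# The connector reads the LM from dspy.settings, so one must be configured.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # assumed model name

predictor = dspy.Predict(QA)
# The constructor argument names below are assumptions, not taken from this PR.
wrapped = cognify.PredictModel(name="qa", dspy_predictor=predictor)

# forward() returns a dspy.Prediction whose fields mirror the signature's
# output fields, so downstream DSPy code can keep reading prediction.answer.
prediction = wrapped.forward(question="What does dspy.Predict do?")
print(prediction.answer)
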
2 changes: 1 addition & 1 deletion docs/source/user_guide/tutorials/interface/dspy.rst
@@ -11,7 +11,7 @@ In DSPy, the :code:`dspy.Predict` class is the primary abstraction for obtaining

.. tip::

DSPy also contains other, more detailed modules that don't follow the behavior of :code:`dspy.Predict` (e.g., :code:`dspy.ChainOfThought`). In Cognify, we view Chain-of-Thought prompting (and other similar techniques) as possible optimizations to apply to an LLM call on the fly instead of as pre-defined templates. Hence, during the translation process we will strip the "reasoning" step out of the predictor definition and leave it to the optimizer.
DSPy also contains other, more detailed modules that don't follow the behavior of :code:`dspy.Predict` (e.g., :code:`dspy.ChainOfThought`). In Cognify, we view Chain-of-Thought prompting (and other similar techniques) as possible optimizations to apply to an LLM call on the fly instead of as pre-defined templates.

By default, Cognify will translate **all** predictors into valid optimization targets. For more fine-grained control over which predictors should be targeted for optimization, you can manually wrap your predictor with our :code:`cognify.PredictModel` class like so:
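
The code block referenced by "like so" is collapsed in this view. As a rough, hypothetical illustration only, with argument names assumed rather than taken from the documentation, the manual wrap might look like:

predictor = dspy.Predict("question -> answer")
wrapped = cognify.PredictModel(name="qa_predictor", dspy_predictor=predictor)  # assumed argument names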
