diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 0198591302..df8e6635b5 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 import logging
 import os
 import re
@@ -300,15 +301,49 @@ async def async_stream_completion():
     return async_stream_completion
 
 
+def _modify_ollama_request(request):
+    """Rewrite an OpenAI-style chat request into the message shape expected by Ollama's chat API."""
+    messages = request["messages"]
+    request["messages"] = []
+    for message in messages:
+        new_message = dict(
+            role=message["role"],
+        )
+        content = message["content"]
+        if isinstance(content, str):
+            new_message["content"] = content
+            request["messages"].append(new_message)
+        elif isinstance(content, list):
+            contents = message["content"]
+            new_message["content"] = []
+            for part in contents:
+                if part["type"] == "text":
+                    new_message["content"].append(part["text"])
+                elif part["type"] == "image_url":
+                    if "images" not in new_message:
+                        new_message["images"] = []
+                    new_message["images"].append(part["image_url"]["url"].split(",", 1)[1])
+            new_message["content"] = " ".join(new_message["content"])
+            request["messages"].append(new_message)
+        else:
+            raise NotImplementedError(f"Unsupported message content type: {type(content)}")
+    return request
+
+
 def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
     stream_completion = _get_stream_completion_fn(request, cache, sync=True)
+    if request["model"].split("/")[0] == "ollama_chat":
+        modified_request = _modify_ollama_request(deepcopy(request))
+    else:
+        modified_request = request
+
     if stream_completion is None:
         return litellm.completion(
             cache=cache,
             num_retries=num_retries,
             retry_strategy="exponential_backoff_retry",
-            **request,
+            **modified_request,
         )
 
     return stream_completion()
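
For context, a minimal sketch of the rewrite `_modify_ollama_request` applies when the model prefix is `ollama_chat`: list-form OpenAI-style content is flattened to a single string, and any data-URL images are moved into an `images` list of bare base64 payloads. The sample model name, prompt text, and base64 string below are illustrative placeholders; only the helper and its module path come from the patch.

```python
# Illustrative sketch (not part of the patch): request shape before and after
# the rewrite applied to "ollama_chat/..." models.
from copy import deepcopy

from dspy.clients.lm import _modify_ollama_request  # private helper added by this patch

request = {
    "model": "ollama_chat/llava",  # placeholder model name
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # placeholder base64 payload in a data URL
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
            ],
        },
    ],
}

modified = _modify_ollama_request(deepcopy(request))
# modified["messages"] is now:
# [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "What is in this image?", "images": ["iVBORw0KGgo..."]},
# ]
```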