
Langchain Integration #89


Open · wants to merge 9 commits into master
Empty file added api/backend/ai/__init__.py
Empty file.
105 changes: 58 additions & 47 deletions api/backend/ai/ai_router.py
@@ -1,66 +1,73 @@
# STL
import logging
from collections.abc import Iterable, AsyncGenerator
import asyncio
from typing import AsyncGenerator, List, Dict, Any, cast

# PDM
from ollama import Message
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain_core.exceptions import LangChainException

# LOCAL
from api.backend.models import AI as AIRequestModel
from api.backend.ai.clients import (
llama_model,
open_ai_key,
llama_client,
open_ai_model,
openai_client,
llm_instance,
provider_info,
AI_PROVIDER_BACKEND,
convert_to_langchain_messages,
)
from api.backend.ai.schemas import AI
from api.backend.routers.handle_exceptions import handle_exceptions

LOG = logging.getLogger("AI")

ai_router = APIRouter()


async def llama_chat(chat_messages: list[Message]) -> AsyncGenerator[str, None]:
if llama_client and llama_model:
async def langchain_chat(messages: List[BaseMessage]) -> AsyncGenerator[str, None]:
if not llm_instance:
LOG.error("LLM instance not available")
yield "An error occurred: LLM not configured."
return

callback_handler = AsyncIteratorCallbackHandler()
run_config = RunnableConfig(callbacks=[callback_handler])

async def stream_llm_task():
try:
async for part in await llama_client.chat(
model=llama_model, messages=chat_messages, stream=True
):
yield part["message"]["content"]
async for _ in llm_instance.astream(messages, config=run_config):
pass # Callback handler processes the chunks
except LangChainException as e:
LOG.error(f"LangChain error during streaming: {e}")
raise
except Exception as e:
LOG.error(f"Error during chat: {e}")
yield "An error occurred while processing your request."


async def openai_chat(
chat_messages: Iterable[ChatCompletionMessageParam],
) -> AsyncGenerator[str, None]:
if openai_client and not open_ai_model:
LOG.error("OpenAI model is not set")
yield "An error occurred while processing your request."

if not openai_client:
LOG.error("OpenAI client is not set")
yield "An error occurred while processing your request."

if openai_client and open_ai_model:
try:
response = openai_client.chat.completions.create(
model=open_ai_model, messages=chat_messages, stream=True
)
for part in response:
yield part.choices[0].delta.content or ""
except Exception as e:
LOG.error(f"Error during OpenAI chat: {e}")
yield "An error occurred while processing your request."


chat_function = llama_chat if llama_client else openai_chat

LOG.error(f"Unexpected error during LLM streaming: {e}", exc_info=True)
raise
finally:
if not callback_handler.done.is_set():
callback_handler.done.set()

stream_task = asyncio.create_task(stream_llm_task())

try:
async for token in callback_handler.aiter():
yield token
except Exception as e:
LOG.error(f"Error in streaming response: {e}", exc_info=True)
yield f"Streaming error: {str(e)}"
finally:
if not stream_task.done():
stream_task.cancel()
try:
await stream_task
except asyncio.CancelledError:
LOG.debug("Stream task cancelled successfully")
except Exception as e:
LOG.error(f"Error during stream task cleanup: {e}")


chat_function = langchain_chat if llm_instance else None

@ai_router.post("/ai")
@handle_exceptions(logger=LOG)
@@ -73,4 +80,8 @@ async def ai(c: AI):
@ai_router.get("/ai/check")
@handle_exceptions(logger=LOG)
async def check():
return JSONResponse(content={"ai_enabled": bool(open_ai_key or llama_model)})
return JSONResponse(content={
"ai_system_enabled": bool(llm_instance and provider_info.get("configured", False)),
"configured_backend_provider": AI_PROVIDER_BACKEND,
"active_provider_details": provider_info
})
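
For reference, a minimal sketch of how a client might consume the updated `/ai/check` endpoint. The response keys (`ai_system_enabled`, `configured_backend_provider`, `active_provider_details`) come from the diff above; the base URL and the use of `httpx` are assumptions (the compose files in this PR expose the API on port 8000).

```python
# Minimal sketch: query /ai/check to see whether the LangChain-backed AI is enabled.
import httpx


def check_ai_status(base_url: str = "http://localhost:8000") -> dict:
    # GET /ai/check returns the provider status payload added in this PR
    response = httpx.get(f"{base_url}/ai/check", timeout=10.0)
    response.raise_for_status()
    data = response.json()
    if data.get("ai_system_enabled"):
        print("AI enabled via", data.get("configured_backend_provider"))
    else:
        # active_provider_details carries the "error" string set in clients.py
        details = data.get("active_provider_details", {})
        print("AI disabled:", details.get("error"))
    return data


if __name__ == "__main__":
    check_ai_status()
```
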
205 changes: 180 additions & 25 deletions api/backend/ai/clients.py
@@ -1,39 +1,194 @@
# STL
import os
import logging
from typing import Optional, Dict, Any, List

# PDM
from ollama import AsyncClient
from openai import OpenAI
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama

# Load environment variables
open_ai_key = os.getenv("OPENAI_KEY")
open_ai_model = os.getenv("OPENAI_MODEL")
llama_url = os.getenv("OLLAMA_URL")
llama_model = os.getenv("OLLAMA_MODEL")
LOG = logging.getLogger(__name__)

# Initialize clients
openai_client = OpenAI(api_key=open_ai_key) if open_ai_key else None
llama_client = AsyncClient(host=llama_url) if llama_url else None
# Environment variables
AI_PROVIDER_BACKEND = os.getenv("AI_PROVIDER_BACKEND")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
OLLAMA_MODEL_NAME = os.getenv("OLLAMA_MODEL_NAME")
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_MODEL_NAME = os.getenv("OPENROUTER_MODEL_NAME")

# Global state
llm_instance: Optional[BaseChatModel] = None
provider_info: Dict[str, Any] = {
"name": "None",
"model": None,
"configured": False,
"error": None
}

async def ask_open_ai(prompt: str) -> str:
if not openai_client:
raise ValueError("OpenAI client not initialized")

response = openai_client.chat.completions.create(
model=open_ai_model or "gpt-4.1-mini",
messages=[{"role": "user", "content": prompt}],
)
class ChatOpenRouter(ChatOpenAI):
"""Custom OpenRouter client extending ChatOpenAI."""

def __init__(self, openai_api_key: Optional[str] = None, **kwargs):
api_key = openai_api_key or os.environ.get("OPENROUTER_API_KEY")
super().__init__(
base_url="https://openrouter.ai/api/v1",
openai_api_key=api_key,
**kwargs
)

return response.choices[0].message.content or ""

def create_openai_provider() -> tuple[Optional[BaseChatModel], Dict[str, Any]]:
"""Create OpenAI provider instance."""
if not OPENAI_API_KEY or not OPENAI_MODEL_NAME:
error_msg = "OpenAI API key or model name not provided."
return None, {"name": "OpenAI", "configured": False, "error": error_msg}

try:
llm = ChatOpenAI(
model=OPENAI_MODEL_NAME,
api_key=OPENAI_API_KEY,
streaming=True,
temperature=0.7,
)
info = {"name": "OpenAI", "model": OPENAI_MODEL_NAME, "configured": True}
LOG.info(f"Initialized OpenAI provider. Model: {OPENAI_MODEL_NAME}")
return llm, info
except Exception as e:
error_msg = f"Failed to initialize OpenAI provider: {e}"
LOG.error(error_msg)
return None, {"name": "OpenAI", "configured": False, "error": error_msg}

async def ask_ollama(prompt: str) -> str:
if not llama_client:
raise ValueError("Ollama client not initialized")

response = await llama_client.chat(
model=llama_model or "", messages=[{"role": "user", "content": prompt}]
)
def create_ollama_provider() -> tuple[Optional[BaseChatModel], Dict[str, Any]]:
if not OLLAMA_BASE_URL or not OLLAMA_MODEL_NAME:
error_msg = "Ollama base URL or model name not provided."
return None, {"name": "Ollama", "configured": False, "error": error_msg}

try:
llm = ChatOllama(
base_url=OLLAMA_BASE_URL,
model=OLLAMA_MODEL_NAME,
temperature=0.7,
)
info = {"name": "Ollama", "model": OLLAMA_MODEL_NAME, "configured": True}
LOG.info(f"Initialized Ollama provider. Model: {OLLAMA_MODEL_NAME}, URL: {OLLAMA_BASE_URL}")
return llm, info
except Exception as e:
error_msg = f"Failed to initialize Ollama provider: {e}"
LOG.error(error_msg)
return None, {"name": "Ollama", "configured": False, "error": error_msg}

return response.message.content or ""

def create_openrouter_provider() -> tuple[Optional[BaseChatModel], Dict[str, Any]]:
if not OPENROUTER_API_KEY or not OPENROUTER_MODEL_NAME:
error_msg = "OpenRouter API key or model name not provided."
return None, {"name": "OpenRouter", "configured": False, "error": error_msg}

try:
llm = ChatOpenRouter(
model=OPENROUTER_MODEL_NAME,
openai_api_key=OPENROUTER_API_KEY,
streaming=True,
temperature=0.7,
)
info = {"name": "OpenRouter", "model": OPENROUTER_MODEL_NAME, "configured": True}
LOG.info(f"Initialized OpenRouter provider. Model: {OPENROUTER_MODEL_NAME}")
return llm, info
except Exception as e:
error_msg = f"Failed to initialize OpenRouter provider: {e}"
LOG.error(error_msg)
return None, {"name": "OpenRouter", "configured": False, "error": error_msg}


def initialize_ai_provider() -> None:
global llm_instance, provider_info

if not AI_PROVIDER_BACKEND:
provider_info.update({"configured": False, "error": "No AI provider specified"})
return

LOG.info(f"Initializing AI provider: {AI_PROVIDER_BACKEND}")

provider_factories = {
"openai": create_openai_provider,
"ollama": create_ollama_provider,
"openrouter": create_openrouter_provider,
}

factory = provider_factories.get(AI_PROVIDER_BACKEND)
if not factory:
error_msg = f"Unsupported AI provider: {AI_PROVIDER_BACKEND}"
LOG.error(error_msg)
provider_info.update({"configured": False, "error": error_msg})
return

try:
llm_instance, provider_info = factory()
except ImportError as e:
error_msg = f"Missing dependencies for {AI_PROVIDER_BACKEND}: {e}"
LOG.error(error_msg)
provider_info.update({"configured": False, "error": error_msg})
except Exception as e:
error_msg = f"Unexpected error initializing {AI_PROVIDER_BACKEND}: {e}"
LOG.error(error_msg, exc_info=True)
provider_info.update({"configured": False, "error": error_msg})


def convert_to_langchain_messages(messages: List[Dict[str, Any]]) -> List[BaseMessage]:
lc_messages: List[BaseMessage] = []

for msg_dict in messages:
role = str(msg_dict.get("role", "user")).lower()
content = str(msg_dict.get("content", ""))

if role == "user":
lc_messages.append(HumanMessage(content=content))
elif role in ("assistant", "ai"):
lc_messages.append(AIMessage(content=content))
elif role == "system":
lc_messages.append(SystemMessage(content=content))
else:
LOG.warning(f"Unknown message role '{role}', treating as user message")
lc_messages.append(HumanMessage(content=content))

return lc_messages


async def ask_llm(prompt: str) -> str:
"""Simple non-streaming LLM query (similar to your old ask_open_ai/ask_ollama)."""
if not llm_instance:
raise ValueError("LLM client not initialized")

try:
messages = [HumanMessage(content=prompt)]
response = await llm_instance.ainvoke(messages)

content = response.content
if not content:
return ""

if not isinstance(content, list):
return str(content)

text_parts: List[str] = []
for item in content:
if not isinstance(item, dict):
text_parts.append(str(item))
else:
text_value = item.get("text")
if text_value is not None:
text_parts.append(str(text_value))

return " ".join(text_parts) if text_parts else ""

except Exception as e:
LOG.error(f"Error in LLM query: {e}")
raise


initialize_ai_provider()
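
A small usage sketch for the new helpers in `clients.py`: converting dict-style chat history with `convert_to_langchain_messages` and issuing a one-shot query with `ask_llm`. The function names and behaviour are taken from the diff above; the message contents and the `asyncio` entry point are illustrative. Note that `initialize_ai_provider()` runs at import time, so the `AI_PROVIDER_BACKEND` and matching key/model variables must be set before this module is imported.

```python
# Illustrative only: exercising the new clients.py helpers from this PR.
import asyncio

from api.backend.ai.clients import (
    ask_llm,
    convert_to_langchain_messages,
    llm_instance,
    provider_info,
)


async def main() -> None:
    if not llm_instance:
        # provider_info["error"] is populated when initialization fails
        print("No provider configured:", provider_info.get("error"))
        return

    # Dict-style chat history -> LangChain message objects
    history = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarise this page in one sentence."},
    ]
    lc_messages = convert_to_langchain_messages(history)
    print(f"Converted {len(lc_messages)} messages")

    # One-shot, non-streaming query against the configured provider
    answer = await ask_llm("Which providers does this backend support?")
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())
```
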
7 changes: 7 additions & 0 deletions docker-compose.dev.yml
@@ -18,6 +18,13 @@ services:
dockerfile: docker/api/Dockerfile
environment:
- LOG_LEVEL=INFO
- AI_PROVIDER_BACKEND= # openai | ollama | openrouter | (leave empty to disable ai)
- OPENAI_API_KEY=''
- OPENAI_MODEL_NAME='gpt4o'
- OLLAMA_BASE_URL=http://localhost:11434
- OLLAMA_MODEL_NAME=llama2
- OPENROUTER_API_KEY=''
- OPENROUTER_MODEL_NAME='openai/gpt-4'
volumes:
- "$PWD/api:/project/app/api"
ports:
8 changes: 7 additions & 1 deletion docker-compose.yml
@@ -15,7 +15,13 @@ services:
image: jpyles0524/scraperr_api:latest
environment:
- LOG_LEVEL=INFO
- OPENAI_KEY=${OPENAI_KEY}
- AI_PROVIDER_BACKEND= # openai | ollama | openrouter | (leave empty to disable ai)
- OPENAI_API_KEY=''
- OPENAI_MODEL_NAME='gpt4o'
- OLLAMA_BASE_URL=http://localhost:11434
- OLLAMA_MODEL_NAME=llama2
- OPENROUTER_API_KEY=''
- OPENROUTER_MODEL_NAME='openai/gpt-4'
container_name: scraperr_api
ports:
- 8000:8000
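
A hypothetical sanity-check script (not part of this PR) for the environment variables named in the compose files above. Variable names match the compose entries and `clients.py`; note that `initialize_ai_provider()` matches `AI_PROVIDER_BACKEND` exactly against the lowercase keys `openai`, `ollama`, and `openrouter`.

```python
# Hypothetical helper: verify the compose environment before starting the API.
import os

REQUIRED = {
    "openai": ("OPENAI_API_KEY", "OPENAI_MODEL_NAME"),
    "ollama": ("OLLAMA_BASE_URL", "OLLAMA_MODEL_NAME"),
    "openrouter": ("OPENROUTER_API_KEY", "OPENROUTER_MODEL_NAME"),
}


def check_env() -> None:
    # clients.py compares this value verbatim against lowercase provider keys
    backend = (os.getenv("AI_PROVIDER_BACKEND") or "").strip()
    if not backend:
        print("AI disabled: AI_PROVIDER_BACKEND is empty")
        return
    if backend not in REQUIRED:
        print(f"Unsupported provider: {backend!r} (expected openai, ollama, or openrouter)")
        return
    missing = [name for name in REQUIRED[backend] if not os.getenv(name)]
    if missing:
        print(f"{backend} selected but missing: {', '.join(missing)}")
    else:
        print(f"{backend} looks configured")


if __name__ == "__main__":
    check_env()
```
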