diff --git a/app.py b/app.py
index c2dbcb2..da0d1ae 100644
--- a/app.py
+++ b/app.py
@@ -19,7 +19,7 @@
 from slowapi.util import get_remote_address
 from prometheus_fastapi_instrumentator import Instrumentator
 from PIL import Image
-from utils.data_types import APIKey, Prompt, TextPrompt, TextToImage, ImageToImage, UserSigninInfo, ValidatorInfo, ChatCompletion
+from utils.data_types import APIKey, Prompt, TextPrompt, TextToImage, ImageToImage, UserSigninInfo, ValidatorInfo, ChatCompletion, MultimodalPrompt
 from utils.db_client import MongoDBHandler
 from fastapi.middleware.cors import CORSMiddleware
 from transformers import AutoTokenizer
@@ -52,6 +52,8 @@ def pil_image_to_base64(image: Image.Image, format="JPEG") -> str:
 MONGOHOST = os.getenv("MONGOHOST", "localhost")
 MONGOPORT = os.getenv("MONGOPORT", 27017)
 
+SUPPORTED_MULTIMODAL_MODELS = ["Pixtral_12b"]
+
 # Define a list of allowed origins (domains)
 allowed_origins = [
     "http://localhost:3000",  # Change this to the domain you want to allow
@@ -586,20 +588,47 @@ async def chat_completions(self, request: Request, data: ChatCompletion):
         model_list = self.dbhandler.model_config.find_one({"name": "model_list"})["data"]
         if data.model not in model_list:
             raise HTTPException(status_code=404, detail="Model not found")
-        messages_str = self.tokenizers[data.model].apply_chat_template(data.messages, tokenize=False)
-        print(f"Chat message str: {messages_str}", flush=True)
-        generate_data = {
-            "key": api_key,
-            "prompt_input": messages_str,
-            "model_name": data.model,
-            "pipeline_params": {
-                "temperature": data.temperature,
-                "top_p": data.top_p,
-                "max_tokens": data.max_tokens
+        elif data.model in SUPPORTED_MULTIMODAL_MODELS:
+            image_url, prompt = "", ""
+            last_message = data.messages[-1]
+            content = last_message["content"]
+            if isinstance(content, list):
+                for cnt in content:
+                    if cnt["type"] == "image_url":
+                        image_url = cnt["image_url"]["url"]
+                    elif cnt["type"] == "text":
+                        prompt = cnt["text"]
+            else:
+                prompt = content
+            generate_data = {
+                "key": api_key,
+                "prompt": prompt,
+                "image_url": image_url,
+                "model_name": data.model,
+                "pipeline_params": {
+                    "temperature": data.temperature,
+                    "top_p": data.top_p,
+                    "max_tokens": data.max_tokens,
+                    "logprobs": 1
+                }
             }
-        }
-        response = await self.generate(TextPrompt(**generate_data))
-        return response['prompt_output']
+            response = await self.generate(MultimodalPrompt(**generate_data))
+            return response['prompt_output']
+        else:
+            messages_str = self.tokenizers[data.model].apply_chat_template(data.messages, tokenize=False)
+            print(f"Chat message str: {messages_str}", flush=True)
+            generate_data = {
+                "key": api_key,
+                "prompt_input": messages_str,
+                "model_name": data.model,
+                "pipeline_params": {
+                    "temperature": data.temperature,
+                    "top_p": data.top_p,
+                    "max_tokens": data.max_tokens
+                }
+            }
+            response = await self.generate(TextPrompt(**generate_data))
+            return response['prompt_output']
 
     def base64_to_pil_image(self, base64_image):
         image = base64.b64decode(base64_image)
diff --git a/utils/data_types.py b/utils/data_types.py
index ac69764..bce87b4 100644
--- a/utils/data_types.py
+++ b/utils/data_types.py
@@ -18,6 +18,14 @@ class TextPrompt(BaseModel):
     pipeline_params: dict = {}
     seed: int = 0
 
+class MultimodalPrompt(BaseModel):
+    key: str = ""
+    prompt: str = ""
+    image_url: str = ""
+    model_name: str = ""
+    pipeline_params: dict = {}
+    seed: int = 0
+
 class TextToImage(BaseModel):
     prompt: str
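
A minimal client-side sketch of the request shape the new multimodal branch expects. This is an illustration, not part of the diff: the base URL, endpoint path (`/chat/completions`), and auth header are assumptions about the deployment, while the content-list format (OpenAI-style `text` / `image_url` parts) mirrors exactly what the handler parses out of `data.messages[-1]["content"]`.

```python
import requests  # assumed available client-side; any HTTP client works

payload = {
    "model": "Pixtral_12b",  # must appear in SUPPORTED_MULTIMODAL_MODELS
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 512,
    "messages": [
        {
            "role": "user",
            "content": [
                # The handler reads cnt["text"] for the prompt...
                {"type": "text", "text": "Describe this image."},
                # ...and cnt["image_url"]["url"] for the image.
                {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
            ],
        }
    ],
}

# Hypothetical URL and API-key header; match these to the actual route and auth scheme.
resp = requests.post(
    "http://localhost:8000/chat/completions",
    headers={"Authorization": "Bearer <api_key>"},
    json=payload,
    timeout=60,
)
print(resp.json())
```

Note that the handler inspects only the last message (`data.messages[-1]`), so earlier turns of a multimodal conversation are not forwarded to the model, and a plain-string `content` falls back to a text-only `MultimodalPrompt` with an empty `image_url`.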