diff --git a/app.py b/app.py
index b1cb3ec..3258da7 100644
--- a/app.py
+++ b/app.py
@@ -21,6 +21,7 @@
 from slowapi.util import get_remote_address
 from prometheus_fastapi_instrumentator import Instrumentator
 from PIL import Image
+from transformers import AutoTokenizer
 
 def get_api_key(request: Request):
     return request.headers.get("API_KEY", get_remote_address(request))
@@ -83,6 +84,14 @@ class ValidatorInfo(BaseModel):
     all_uid_info: dict = {}
     sha: str = ""
 
+class ChatCompletion(BaseModel):
+    model: str
+    messages: list[dict]
+    temperature: float = 1
+    top_p: float = 1
+    max_tokens: int = 128
+
+
 class ImageGenerationService:
     def __init__(self):
         self.subtensor = bt.subtensor("finney")
@@ -126,6 +135,12 @@ def __init__(self):
         Thread(target=self.sync_metagraph_periodically, daemon=True).start()
         Thread(target=self.recheck_validators, daemon=True).start()
         Thread(target=self.update_model_config, daemon=True).start()
+        self.tokenizer_config = self.model_config.find_one({"name": "tokenizer"})
+        print(self.tokenizer_config, flush=True)
+        self.tokenizers = {
+            k: AutoTokenizer.from_pretrained(v) for k, v in self.tokenizer_config["data"].items()
+        }
+        print(self.tokenizers, flush=True)
 
     def update_model_config(self):
         while True:
@@ -568,6 +583,27 @@ async def controlnet_api(self, request: Request, data: ImageToImage):
             generate_data["pipeline_params"][key] = value
         return await self.generate(Prompt(**generate_data))
+
+    async def chat_completions(self, request: Request, data: ChatCompletion):
+        # Get API_KEY from header
+        api_key = request.headers.get("API_KEY")
+        self.check_auth(api_key)
+        if data.model not in self.model_list:
+            raise HTTPException(status_code=404, detail="Model not found")
+        messages_str = self.tokenizers[data.model].apply_chat_template(data.messages, tokenize=False)
+        print(f"Chat message str: {messages_str}", flush=True)
+        generate_data = {
+            "key": api_key,
+            "prompt_input": messages_str,
+            "model_name": data.model,
+            "pipeline_params": {
+                "temperature": data.temperature,
+                "top_p": data.top_p,
+                "max_tokens": data.max_tokens
+            }
+        }
+        response = await self.generate(TextPrompt(**generate_data))
+        return response['prompt_output']
 
     def base64_to_pil_image(self, base64_image):
         image = base64.b64decode(base64_image)
@@ -632,3 +668,8 @@ async def instantid_api(request: Request, data: ImageToImage):
 @limiter.limit(API_RATE_LIMIT)  # Update the rate limit
 async def controlnet_api(request: Request, data: ImageToImage):
     return await app.controlnet_api(request, data)
+
+@app.app.post("/api/v1/chat/completions")
+@limiter.limit(API_RATE_LIMIT)
+async def chat_completions_api(request: Request, data: ChatCompletion):
+    return await app.chat_completions(request, data)
diff --git a/requirements.txt b/requirements.txt
index 259b400..94eaeff 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,6 @@ tqdm==4.66.1
 httpx==0.27.0
 prometheus_fastapi_instrumentator==6.0.0
 pymongo==4.7.3
-slowapi==0.1.9
\ No newline at end of file
+slowapi==0.1.9
+transformers
+jinja2==3.1.0
\ No newline at end of file
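
A minimal client sketch for the new `/api/v1/chat/completions` route added above, using `httpx` (already in requirements.txt). The base URL, model name, and API key are placeholders, not values from this diff; the only facts assumed from the change are that authentication goes through the `API_KEY` header, the request body matches the `ChatCompletion` model, and the endpoint returns the raw `prompt_output` string rather than an OpenAI-style completion object.

```python
import httpx

# Placeholder values -- substitute a real host, key, and a model name that
# exists in the service's model_list / tokenizer config.
BASE_URL = "http://localhost:8000"
API_KEY = "<your-api-key>"

payload = {
    "model": "example-model",  # must match a configured tokenizer/model
    "messages": [
        {"role": "user", "content": "Hello, who are you?"}
    ],
    "temperature": 0.7,
    "top_p": 0.95,
    "max_tokens": 128,
}

response = httpx.post(
    f"{BASE_URL}/api/v1/chat/completions",
    headers={"API_KEY": API_KEY},  # validated by check_auth() on the server
    json=payload,
    timeout=60,
)
response.raise_for_status()

# The endpoint returns the generated text (prompt_output) as the JSON body.
print(response.json())
```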