Skip to content

Add Text Endpoint #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from PIL import Image
from transformers import AutoTokenizer

def get_api_key(request: Request):
    """Rate-limit key for slowapi: the caller's API_KEY header if present,
    otherwise their remote address.

    The fallback is computed lazily: `dict.get(key, default)` evaluates its
    default eagerly, so the original form called get_remote_address() on
    every request even when an API_KEY header was supplied.
    """
    api_key = request.headers.get("API_KEY")
    if api_key is not None:
        return api_key
    return get_remote_address(request)
Expand Down Expand Up @@ -83,6 +84,14 @@ class ValidatorInfo(BaseModel):
all_uid_info: dict = {}
sha: str = ""

class ChatCompletion(BaseModel):
    """Request body for the OpenAI-style /api/v1/chat/completions endpoint."""

    # Name of the target model; must appear in the service's model_list and
    # is used as the key into its tokenizer registry.
    model: str
    # Chat history; presumably OpenAI-shaped dicts like
    # {"role": "user", "content": "..."} — TODO confirm against the chat template.
    messages: list[dict]
    # Sampling parameters forwarded verbatim as pipeline_params.
    temperature: float = 1
    top_p: float = 1
    max_tokens: int = 128


class ImageGenerationService:
def __init__(self):
self.subtensor = bt.subtensor("finney")
Expand Down Expand Up @@ -126,6 +135,12 @@ def __init__(self):
Thread(target=self.sync_metagraph_periodically, daemon=True).start()
Thread(target=self.recheck_validators, daemon=True).start()
Thread(target=self.update_model_config, daemon=True).start()
self.tokenizer_config = self.model_config.find_one({"name": "tokenizer"})
print(self.tokenizer_config, flush=True)
self.tokenizers = {
k: AutoTokenizer.from_pretrained(v) for k, v in self.tokenizer_config["data"].items()
}
print(self.tokenizers, flush=True)

def update_model_config(self):
while True:
Expand Down Expand Up @@ -568,6 +583,27 @@ async def controlnet_api(self, request: Request, data: ImageToImage):
generate_data["pipeline_params"][key] = value

return await self.generate(Prompt(**generate_data))

async def chat_completions(self, request: Request, data: ChatCompletion):
    """Handle an OpenAI-style chat completion request.

    Renders the chat messages through the model's tokenizer chat template,
    then proxies the flattened prompt to self.generate() as a TextPrompt.

    Raises:
        HTTPException: 404 if the model is not in model_list, or if no
            tokenizer is registered for it.
    """
    # Authenticate with the API_KEY header (check_auth raises on failure).
    api_key = request.headers.get("API_KEY")
    self.check_auth(api_key)
    if data.model not in self.model_list:
        raise HTTPException(status_code=404, detail="Model not found")
    # model_list and the tokenizer registry are populated independently, so a
    # listed model may lack a tokenizer; fail with a clean 404 instead of an
    # unhandled KeyError (HTTP 500).
    tokenizer = self.tokenizers.get(data.model)
    if tokenizer is None:
        raise HTTPException(status_code=404, detail="Tokenizer not found for model")
    # NOTE(review): add_generation_prompt is left at the tokenizer default —
    # confirm whether it should be True so the assistant turn is opened.
    messages_str = tokenizer.apply_chat_template(data.messages, tokenize=False)
    print(f"Chat message str: {messages_str}", flush=True)
    generate_data = {
        "key": api_key,
        "prompt_input": messages_str,
        "model_name": data.model,
        "pipeline_params": {
            "temperature": data.temperature,
            "top_p": data.top_p,
            "max_tokens": data.max_tokens,
        },
    }
    # generate() is assumed to return a mapping containing 'prompt_output' —
    # TODO confirm; a missing key here would surface as a 500.
    response = await self.generate(TextPrompt(**generate_data))
    return response['prompt_output']

def base64_to_pil_image(self, base64_image):
image = base64.b64decode(base64_image)
Expand Down Expand Up @@ -632,3 +668,8 @@ async def instantid_api(request: Request, data: ImageToImage):
@limiter.limit(API_RATE_LIMIT) # Update the rate limit
async def controlnet_api(request: Request, data: ImageToImage):
return await app.controlnet_api(request, data)

@app.app.post("/api/v1/chat/completions")
@limiter.limit(API_RATE_LIMIT)
async def chat_completions_api(request: Request, data: ChatCompletion):
    """Rate-limited route that delegates to the service's chat handler."""
    result = await app.chat_completions(request, data)
    return result
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ tqdm==4.66.1
httpx==0.27.0
prometheus_fastapi_instrumentator==6.0.0
pymongo==4.7.3
slowapi==0.1.9
slowapi==0.1.9
transformers
jinja2==3.1.0