diff --git a/intel_extension_for_transformers/neural_chat/chatbot.py b/intel_extension_for_transformers/neural_chat/chatbot.py
index 06d9965a700..1ebf59feff9 100644
--- a/intel_extension_for_transformers/neural_chat/chatbot.py
+++ b/intel_extension_for_transformers/neural_chat/chatbot.py
@@ -277,6 +277,8 @@ def build_chatbot(config: PipelineConfig=None):
     parameters["optimization_config"] = config.optimization_config
     parameters["hf_access_token"] = config.hf_access_token
     parameters["assistant_model"] = config.assistant_model
+    parameters["assistant_host"] = config.assistant_host
+    parameters["assistant_port"] = config.assistant_port
     if config.serving_config and config.serving_config.framework == "vllm":
         parameters["use_vllm"] = True
         parameters["vllm_engine_params"] = config.serving_config.framework_config
diff --git a/intel_extension_for_transformers/neural_chat/config.py b/intel_extension_for_transformers/neural_chat/config.py
index 5b745b1c999..2ada32fe862 100644
--- a/intel_extension_for_transformers/neural_chat/config.py
+++ b/intel_extension_for_transformers/neural_chat/config.py
@@ -457,6 +457,8 @@ def __init__(self,
                  loading_config=None,
                  optimization_config=None,
                  assistant_model=None,
+                 assistant_host=None,
+                 assistant_port=None,
                  serving_config=None):
         self.model_name_or_path = model_name_or_path
         self.tokenizer_name_or_path = tokenizer_name_or_path
@@ -482,4 +484,6 @@ def __init__(self,
             f"Expect optimization_config be an object of MixedPrecisionConfig, WeightOnlyQuantConfig" + \
             " or BitsAndBytesConfig,got {type(self.optimization_config)}."
         self.assistant_model = assistant_model
+        self.assistant_host = assistant_host
+        self.assistant_port = assistant_port
         self.serving_config = serving_config
diff --git a/intel_extension_for_transformers/neural_chat/examples/deployment/assisted_generation/assisted_gen.yaml b/intel_extension_for_transformers/neural_chat/examples/deployment/assisted_generation/assisted_gen.yaml
index 56de861ca05..43b759c16d9 100644
--- a/intel_extension_for_transformers/neural_chat/examples/deployment/assisted_generation/assisted_gen.yaml
+++ b/intel_extension_for_transformers/neural_chat/examples/deployment/assisted_generation/assisted_gen.yaml
@@ -27,6 +27,10 @@
 model_name_or_path: "facebook/opt-13b"
 device: "cpu"
 assistant_model: "facebook/opt-350m"
 
+# multi-node
+assistant_host: "0.0.0.0"
+assistant_port: 80
+
 # task choices = ['textchat', 'voicechat', 'retrieval', 'text2image', 'finetune', 'codegen']
 tasks_list: ['textchat']
diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py
index 4d473d21aaf..9ae8c437579 100644
--- a/intel_extension_for_transformers/neural_chat/models/base_model.py
+++ b/intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -123,6 +123,8 @@ def load_model(self, kwargs: dict):
         self.use_cache = kwargs["use_cache"]
         self.ipex_int8 = kwargs["ipex_int8"]
         self.assistant_model = kwargs["assistant_model"]
+        self.assistant_host = kwargs["assistant_host"]
+        self.assistant_port = kwargs["assistant_port"]
         load_model(model_name=kwargs["model_name"],
                    tokenizer_name=kwargs["tokenizer_name"],
                    device=kwargs["device"],
diff --git a/intel_extension_for_transformers/neural_chat/models/model_utils.py b/intel_extension_for_transformers/neural_chat/models/model_utils.py
index b46b106538c..f75041b2983 100644
--- a/intel_extension_for_transformers/neural_chat/models/model_utils.py
+++ b/intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -487,13 +487,13 @@ def load_model(
             from transformers import AutoModelForCausalLM
             assistant_model_class = AutoModelForCausalLM
         print(f"Loading assistant model via {assistant_model_class}")
-        assis_model = assistant_model_class.from_pretrained(
+        assist_model = assistant_model_class.from_pretrained(
             assistant_model, low_cpu_mem_usage=True, torch_dtype=torch_dtype)
-        assis_model = assis_model.eval().to(device)
-        assis_model = assis_model.to(memory_format=torch.channels_last)
-        MODELS[model_name]["assistant_model"] = assis_model
+        assist_model = assist_model.eval().to(device)
+        assist_model = assist_model.to(memory_format=torch.channels_last)
+        MODELS[model_name]["assistant_model"] = assist_model
     else:
         MODELS[model_name]["assistant_model"] = None
diff --git a/intel_extension_for_transformers/neural_chat/server/neuralchat_server.py b/intel_extension_for_transformers/neural_chat/server/neuralchat_server.py
index 4ac69b21e8e..b6d9e9164dc 100644
--- a/intel_extension_for_transformers/neural_chat/server/neuralchat_server.py
+++ b/intel_extension_for_transformers/neural_chat/server/neuralchat_server.py
@@ -111,6 +111,8 @@ def init(self, config):
         peft_model_path = config.get("peft_model_path", "")
         plugin_as_service = config.get("plugin_as_service", False)
         assistant_model = config.get("assistant_model", None)
+        assistant_host = config.get("assistant_host", "0.0.0.0")
+        assistant_port = config.get("assistant_port", 80)
         serving = config.get("serving", None)
         serving_config = None
 
@@ -270,6 +272,8 @@ def init(self, config):
            "loading_config": loading_config,
            "optimization_config": optimization_config,
            "assistant_model": assistant_model,
+           "assistant_host": assistant_host,
+           "assistant_port": assistant_port,
            "serving_config": serving_config,
            "task": "chat"
         }
diff --git a/intel_extension_for_transformers/neural_chat/server/restful/api.py b/intel_extension_for_transformers/neural_chat/server/restful/api.py
index 37844148f4e..6ae8dee9476 100644
--- a/intel_extension_for_transformers/neural_chat/server/restful/api.py
+++ b/intel_extension_for_transformers/neural_chat/server/restful/api.py
@@ -32,6 +32,7 @@
 from .plugin_image2image_api import router as plugin_image2image_router
 from .codegen_api import router as codegen_router
 from .tgi_api import router as tgi_router
+from .assisted_gen_api import router as assist_router
 
 _router = APIRouter()
 
@@ -47,7 +48,8 @@
     'plugin_audio': plugin_audio_router,
     "image2image": plugin_image2image_router,
     'codegen': codegen_router,
-    'tgi': tgi_router
+    'tgi': tgi_router,
+    'assist_generation': assist_router
 }
 
 def setup_router(api_list, chatbot=None, enable_llm=True, use_deepspeed=False, world_size=1, host="0.0.0.0", port=80):
diff --git a/intel_extension_for_transformers/neural_chat/server/restful/assisted_gen_api.py b/intel_extension_for_transformers/neural_chat/server/restful/assisted_gen_api.py
new file mode 100644
index 00000000000..1427819bcfe
--- /dev/null
+++ b/intel_extension_for_transformers/neural_chat/server/restful/assisted_gen_api.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import httpx
+from fastapi import APIRouter
+from ...cli.log import logger
+from .openai_protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+)
+
+
+class AssistedGenerationAPIRouter(APIRouter):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def set_chatbot(self, chatbot, use_deepspeed=False, world_size=1, host="0.0.0.0", port=80) -> None:
+        self.chatbot = chatbot
+        self.use_deepspeed = use_deepspeed
+        self.world_size = world_size
+        self.host = host
+        self.port = port
+        assistant_host = chatbot.assistant_host
+        assistant_port = chatbot.assistant_port
+        self.assistant_prefix = f"http://{assistant_host}:{assistant_port}"
+
+    def get_chatbot(self):
+        if self.chatbot is None:
+            logger.error("Chatbot instance is not found.")
+            raise RuntimeError("Chatbot instance has not been set.")
+        return self.chatbot
+
+    async def handle_assist_chat(self, request: ChatCompletionRequest):
+        # forward the chat request to the decode endpoint on the assistant node
+        async with httpx.AsyncClient() as client:
+            response = await client.post(self.assistant_prefix + "/v1/assist/decode", json=request.dict())
+        return response.json()
+
+    async def handle_assist_decode(self, request: CompletionRequest):
+        chatbot = self.get_chatbot()
+        # TODO: complete model inferencing process for assisted model
+        pass
+
+    async def handle_assist_data_transfer(self, request: CompletionRequest):
+        async with httpx.AsyncClient() as client:
+            response = await client.post(self.assistant_prefix + "/v1/assist/data_transfer", json=request.dict())
+        return response.json()
+
+
+router = AssistedGenerationAPIRouter()
+
+
+# route for the small model to do inference
+@router.post("/v1/assist/chat")
+async def assist_chat(request: ChatCompletionRequest):
+    return await router.handle_assist_chat(request)
+
+
+# route for the assistant model to do inference
+@router.post("/v1/assist/decode")
+async def assist_decode(request: CompletionRequest):
+    return await router.handle_assist_decode(request)
+
+
+# route for the assistant model to do data transfer
+@router.post("/v1/assist/data_transfer")
+async def assist_data_transfer(request: CompletionRequest):
+    return await router.handle_assist_data_transfer(request)
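
For reference, a minimal sketch (not part of the patch) of how the new assistant_host/assistant_port options could be exercised end to end. It assumes the __init__ shown in config.py belongs to PipelineConfig (as the build_chatbot hunk suggests), that the NeuralChat server started from assisted_gen.yaml is reachable at 127.0.0.1:80, and that the request body follows the OpenAI-style ChatCompletionRequest; the model names, server address, and message content are illustrative only.

import httpx

from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

# Main node: build the chatbot and point it at the node serving the assistant (draft) model.
config = PipelineConfig(
    model_name_or_path="facebook/opt-13b",
    assistant_model="facebook/opt-350m",
    assistant_host="0.0.0.0",   # address of the assistant-model service (from assisted_gen.yaml)
    assistant_port=80,
)
chatbot = build_chatbot(config)

# Client side: call the new assisted-generation route exposed by the server.
payload = {
    "model": "facebook/opt-13b",  # assumed OpenAI-style fields
    "messages": [{"role": "user", "content": "Tell me about Intel Xeon Scalable processors."}],
}
# Replace 127.0.0.1:80 with the address/port the NeuralChat server actually listens on.
response = httpx.post("http://127.0.0.1:80/v1/assist/chat", json=payload, timeout=300.0)
print(response.json())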