diff --git a/.gitignore b/.gitignore
index 6049c2cdb..d07ab8183 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,4 @@ dist
 *.egg-info
 .idea
 .vscode
-tmp/
+tmp/
\ No newline at end of file
diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst
index d7c055ef4..7e87b9c62 100755
--- a/docs/CN/source/tutorial/api_server_args_zh.rst
+++ b/docs/CN/source/tutorial/api_server_args_zh.rst
@@ -64,6 +64,11 @@ PD 分离模式参数
 
    配置服务器模式下的端口号
 
+
+.. option:: --chunked_max_new_token
+
+   分块解码最大 token 数量,默认为 ``0`` ,代表不使用分块解码
+
 模型配置参数
 -----------
 
diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst
index 3b25ae85c..d7350485b 100755
--- a/docs/EN/source/tutorial/api_server_args_zh.rst
+++ b/docs/EN/source/tutorial/api_server_args_zh.rst
@@ -64,6 +64,10 @@ PD disaggregation Mode Parameters
 
    Port number in configuration server mode
 
+.. option:: --chunked_max_new_token
+
+   Maximum token number for chunked decoding, default is ``0``, representing no chunked decoding
+
 Model Configuration Parameters
 -----------------------------
 
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 601b2a48a..a2d92cb15 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -60,6 +60,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default="default_model_name",
         help="just help to distinguish internal model name, use 'host:port/get_model_name' to get",
     )
+    parser.add_argument(
+        "--chunked_max_new_token",
+        type=int,
+        default=0,
+        help="""Specifies the chunk size for pd mode.""",
+    )
 
     parser.add_argument(
         "--model_dir",
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index 05b2d987c..862c139a2 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -24,7 +24,7 @@ from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.statics_utils import MovingAverage
 from lightllm.server.httpserver.manager import AsyncQueue
-from lightllm.utils.error_utils import ServerBusyError
+from lightllm.utils.error_utils import ServerBusyError, KVMoveTimeoutError
 
 logger = init_logger(__name__)
 
@@ -123,6 +123,9 @@ async def generate(
     ):
         start_time = time.time()
         group_request_id = self.id_gen.generate_id()
+        max_retries = 3
+        retry_count = 0
+
         try:
             sampling_params.group_request_id = group_request_id
             # 记录请求到达的相关信息
@@ -131,22 +134,41 @@ async def generate(
         self.metric_client.counter_inc("lightllm_request_count")
         self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
 
-        p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)
-
-        results_generator = self._wait_to_token_package(
-            p_node,
-            d_node,
-            start_time,
-            prompt,
-            sampling_params,
-            multimodal_params,
-            request,
-        )
-        async for sub_req_id, request_output, metadata, finish_status in results_generator:
-            yield sub_req_id, request_output, metadata, finish_status
+        while retry_count <= max_retries:
+            try:
+                p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)
+
+                results_generator = self._wait_to_token_package(
+                    p_node,
+                    d_node,
+                    start_time,
+                    prompt,
+                    sampling_params,
+                    multimodal_params,
+                    request,
+                )
+                async for sub_req_id, request_output, metadata, finish_status in results_generator:
+                    yield sub_req_id, request_output, metadata, finish_status
+
+                break
+
+            except KVMoveTimeoutError as e:
+                retry_count += 1
+                if retry_count <= max_retries:
+                    logger.warning(f"KV move timeout for group_request_id {group_request_id}, attempt {retry_count}/{max_retries + 1}. Retrying with new nodes...")
+                    # 清理当前请求状态,准备重试
+                    await self.abort(group_request_id)
+                    # 重新生成group_request_id避免冲突
+                    group_request_id = self.id_gen.generate_id()
+                    sampling_params.group_request_id = group_request_id
+                    continue
+                else:
+                    logger.error(f"KV move timeout after {max_retries + 1} attempts for group_request_id {group_request_id}. Giving up.")
+                    raise ServerBusyError(f"KV move timeout after {max_retries + 1} attempts, server is busy now.")
 
         except BaseException as e:
-            logger.error(f"has exception {str(e)}")
+            if not isinstance(e, KVMoveTimeoutError):
+                logger.error(f"has exception {str(e)}")
             await self.abort(group_request_id)
             raise e
 
@@ -234,22 +256,53 @@ async def fetch_stream(
             await asyncio.wait_for(up_status_event.wait(), timeout=60)
         except asyncio.TimeoutError:
             logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
-            raise ServerBusyError()
+            raise KVMoveTimeoutError(f"KV move timeout for group_request_id {group_request_id}")
 
         sampling_params.move_kv_to_decode_node.initialize(None)
-        sampling_params.max_new_tokens = old_max_new_tokens - 1
         sampling_params.suggested_dp_index = up_status_event.upkv_status.dp_index
-        await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, multimodal_params))))
+        remaining_tokens = old_max_new_tokens - 1
+        chunked_max_new_token = self.args.chunked_max_new_token
+        current_prompt_ids = list(prompt_ids)
 
-        while True:
-            await req_status.wait_to_ready()
-            if await request.is_disconnected():
-                raise Exception(f"req_id {group_request_id} disconnected")
-            if await req_status.can_read(self.req_id_to_out_inf):
-                token_list = await req_status.pop_all_tokens()
-                for sub_req_id, request_output, metadata, finish_status in token_list:
-                    yield sub_req_id, request_output, metadata, finish_status
+        while remaining_tokens > 0:
+            chunk_size = min(remaining_tokens, chunked_max_new_token) if chunked_max_new_token > 0 else remaining_tokens
+            sampling_params.max_new_tokens = chunk_size
+            await d_node.websocket.send_bytes(
+                pickle.dumps((ObjType.REQ, (current_prompt_ids, sampling_params, multimodal_params)))
+            )
+
+            chunk_finished = False
+            while not chunk_finished:
+                await req_status.wait_to_ready()
+                if await request.is_disconnected():
+                    raise Exception(f"req_id {group_request_id} disconnected")
+
+                if await req_status.can_read(self.req_id_to_out_inf):
+                    token_list = await req_status.pop_all_tokens()
+                    for sub_req_id, request_output, metadata, finish_status in token_list:
+                        current_prompt_ids.append(metadata.get("id"))
+                        remaining_tokens -= 1
+
+                        final_finish_status = finish_status
+
+                        # reach max new tokens, really finished
+                        if remaining_tokens == 0:
+                            final_finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH)
+                            chunk_finished = True
+                        # reach stop token, really finished
+                        elif finish_status == FinishStatus.FINISHED_STOP:
+                            final_finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
+                            chunk_finished = True
+                        # reach chunk size, not really finished
+                        elif finish_status == FinishStatus.FINISHED_LENGTH:
+                            final_finish_status = FinishStatus(FinishStatus.NO_FINISH)
+                            chunk_finished = True
+
+                        yield sub_req_id, request_output, metadata, final_finish_status
+
+            if final_finish_status.is_finished():
+                break
 
         return
diff --git a/lightllm/utils/error_utils.py b/lightllm/utils/error_utils.py
index 4424fc17d..d1ffe2b7c 100644
--- a/lightllm/utils/error_utils.py
+++ b/lightllm/utils/error_utils.py
@@ -16,3 +16,8 @@ def __init__(self, message="Server is busy, please try again later", status_code
     def __str__(self):
         """String representation of the error"""
         return f"{self.message} (Status code: {self.status_code})"
+
+class KVMoveTimeoutError(ServerBusyError):
+    """KV移动超时错误,用于触发重试机制"""
+    def __init__(self, message="KV move timeout, please try again later", status_code=503):
+        super().__init__(message, status_code)
\ No newline at end of file