
feat: Support decode chunk PD serving mode #944


Open · wants to merge 15 commits into main
2 changes: 1 addition & 1 deletion .gitignore
@@ -5,4 +5,4 @@ dist
*.egg-info
.idea
.vscode
tmp/
tmp/
5 changes: 5 additions & 0 deletions docs/CN/source/tutorial/api_server_args_zh.rst
@@ -64,6 +64,11 @@ PD 分离模式参数

配置服务器模式下的端口号


.. option:: --chunked_max_new_token

分块解码最大 token 数量,默认为 ``0`` ,代表不使用分块解码

模型配置参数
-----------

4 changes: 4 additions & 0 deletions docs/EN/source/tutorial/api_server_args_zh.rst
@@ -64,6 +64,10 @@ PD disaggregation Mode Parameters

Port number in configuration server mode

.. option:: --chunked_max_new_token

Maximum number of tokens generated per chunk in chunked decoding; the default ``0`` disables chunked decoding

Model Configuration Parameters
-----------------------------

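
A quick illustration of the flag's semantics (an editor's sketch, not part of this PR): the request's max_new_tokens budget is consumed in chunks of at most chunked_max_new_token tokens, and the default 0 keeps the old single-pass behavior. The helper name below is made up; the real loop in fetch_stream (further down in this diff) counts tokens one at a time and can stop early on a stop token.

def plan_decode_chunks(max_new_tokens: int, chunked_max_new_token: int) -> list[int]:
    """Split a decode budget into per-chunk max_new_tokens values.

    Assumes every chunk runs to its full budget; a stop token inside a
    chunk would end decoding earlier. chunked_max_new_token == 0 means
    "no chunking", so the whole budget goes out as one chunk.
    """
    chunks = []
    remaining = max_new_tokens
    while remaining > 0:
        chunk = min(remaining, chunked_max_new_token) if chunked_max_new_token > 0 else remaining
        chunks.append(chunk)
        remaining -= chunk
    return chunks

# e.g. a 10-token budget with --chunked_max_new_token 4 is served as [4, 4, 2],
# while the default 0 yields a single chunk of [10].
assert plan_decode_chunks(10, 4) == [4, 4, 2]
assert plan_decode_chunks(10, 0) == [10]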
6 changes: 6 additions & 0 deletions lightllm/server/api_cli.py
@@ -60,6 +60,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
default="default_model_name",
help="just help to distinguish internal model name, use 'host:port/get_model_name' to get",
)
parser.add_argument(
"--chunked_max_new_token",
type=int,
default=0,
help="""Specifies the chunk size for pd mode.""",
)

parser.add_argument(
"--model_dir",
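
For reference, a self-contained sketch of how the new argument behaves on its own (the parser name and example values are illustrative; only the --chunked_max_new_token definition mirrors this PR):

import argparse

parser = argparse.ArgumentParser("pd_master_args_sketch")
parser.add_argument(
    "--chunked_max_new_token",
    type=int,
    default=0,
    help="Max new tokens per decode chunk in pd mode; 0 disables chunking.",
)

# an explicit value enables chunked decoding
args = parser.parse_args(["--chunked_max_new_token", "64"])
assert args.chunked_max_new_token == 64 and args.chunked_max_new_token > 0

# omitting the flag keeps the previous, unchunked behavior
assert parser.parse_args([]).chunked_max_new_token == 0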
105 changes: 79 additions & 26 deletions lightllm/server/httpserver_for_pd_master/manager.py
Expand Up @@ -24,7 +24,7 @@
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError
from lightllm.utils.error_utils import ServerBusyError, KVMoveTimeoutError

logger = init_logger(__name__)

@@ -123,6 +123,9 @@ async def generate(
):
start_time = time.time()
group_request_id = self.id_gen.generate_id()
max_retries = 3
retry_count = 0

try:
sampling_params.group_request_id = group_request_id
# Record information about the arriving request
@@ -131,22 +134,41 @@
self.metric_client.counter_inc("lightllm_request_count")
self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)

p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)

results_generator = self._wait_to_token_package(
p_node,
d_node,
start_time,
prompt,
sampling_params,
multimodal_params,
request,
)
async for sub_req_id, request_output, metadata, finish_status in results_generator:
yield sub_req_id, request_output, metadata, finish_status
while retry_count <= max_retries:
try:
p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)

results_generator = self._wait_to_token_package(
p_node,
d_node,
start_time,
prompt,
sampling_params,
multimodal_params,
request,
)
async for sub_req_id, request_output, metadata, finish_status in results_generator:
yield sub_req_id, request_output, metadata, finish_status

break

except KVMoveTimeoutError as e:
retry_count += 1
if retry_count <= max_retries:
logger.warning(f"KV move timeout for group_request_id {group_request_id}, attempt {retry_count}/{max_retries + 1}. Retrying with new nodes...")
# Clean up the current request state before retrying
await self.abort(group_request_id)
# Regenerate group_request_id to avoid collisions
group_request_id = self.id_gen.generate_id()
sampling_params.group_request_id = group_request_id
continue
else:
logger.error(f"KV move timeout after {max_retries + 1} attempts for group_request_id {group_request_id}. Giving up.")
raise ServerBusyError(f"KV move timeout after {max_retries + 1} attempts, server is busy now.")

except BaseException as e:
logger.error(f"has exception {str(e)}")
if not isinstance(e, KVMoveTimeoutError):
logger.error(f"has exception {str(e)}")
await self.abort(group_request_id)
raise e
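
Editor's note on the control flow above, not part of the diff: with max_retries = 3 the request gets up to four attempts in total (the initial one plus three retries), which is what the {max_retries + 1} in the log messages refers to. A stripped-down sketch of the same pattern, with attempt and new_request_id as hypothetical stand-ins for select_p_d_node + _wait_to_token_package and id_gen.generate_id():

from lightllm.utils.error_utils import ServerBusyError, KVMoveTimeoutError

async def generate_with_retry(attempt, new_request_id, max_retries: int = 3):
    """Skeleton of the retry loop in generate(); the real code streams
    tokens instead of returning a single result."""
    retry_count = 0
    group_request_id = new_request_id()
    while retry_count <= max_retries:
        try:
            return await attempt(group_request_id)
        except KVMoveTimeoutError:
            retry_count += 1
            if retry_count > max_retries:
                # reached only after max_retries + 1 failed attempts
                raise ServerBusyError("KV move timeout after retries, server is busy now.")
            # a fresh id avoids colliding with the aborted attempt
            group_request_id = new_request_id()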

@@ -234,22 +256,53 @@ async def fetch_stream(
await asyncio.wait_for(up_status_event.wait(), timeout=60)
except asyncio.TimeoutError:
logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
raise ServerBusyError()
raise KVMoveTimeoutError(f"KV move timeout for group_request_id {group_request_id}")

sampling_params.move_kv_to_decode_node.initialize(None)
sampling_params.max_new_tokens = old_max_new_tokens - 1
sampling_params.suggested_dp_index = up_status_event.upkv_status.dp_index

await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, multimodal_params))))
remaining_tokens = old_max_new_tokens - 1
chunked_max_new_token = self.args.chunked_max_new_token
current_prompt_ids = list(prompt_ids)

while True:
await req_status.wait_to_ready()
if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")
if await req_status.can_read(self.req_id_to_out_inf):
token_list = await req_status.pop_all_tokens()
for sub_req_id, request_output, metadata, finish_status in token_list:
yield sub_req_id, request_output, metadata, finish_status
while remaining_tokens > 0:
chunk_size = min(remaining_tokens, chunked_max_new_token) if chunked_max_new_token > 0 else remaining_tokens
sampling_params.max_new_tokens = chunk_size
await d_node.websocket.send_bytes(
pickle.dumps((ObjType.REQ, (current_prompt_ids, sampling_params, multimodal_params)))
)

chunk_finished = False
while not chunk_finished:
await req_status.wait_to_ready()
if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")

if await req_status.can_read(self.req_id_to_out_inf):
token_list = await req_status.pop_all_tokens()
for sub_req_id, request_output, metadata, finish_status in token_list:
current_prompt_ids.append(metadata.get("id"))
Comment on lines +283 to +284

[Review comment, severity: medium] Consider adding a check to ensure metadata.get("id") is not None before appending it to current_prompt_ids. This will prevent potential None values from being added to the list, which could lead to unexpected behavior later on. A medium severity is assigned because while the code might work in most cases, the absence of this check introduces a risk of runtime errors.

                        if metadata.get("id") is not None:
                            current_prompt_ids.append(metadata.get("id"))

[Author] I think the code itself can ensure metadata.get("id") is not None. This fix is not necessary.

remaining_tokens -= 1

final_finish_status = finish_status

# reach max new tokens, really finished
if remaining_tokens == 0:
final_finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH)
chunk_finished = True
# reach stop token, really finished
elif finish_status == FinishStatus.FINISHED_STOP:
final_finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
chunk_finished = True
# reach chunk size, not really finished
elif finish_status == FinishStatus.FINISHED_LENGTH:
final_finish_status = FinishStatus(FinishStatus.NO_FINISH)
chunk_finished = True

yield sub_req_id, request_output, metadata, final_finish_status

if final_finish_status.is_finished():
break

return

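
The subtle part of the chunked loop is how a chunk's finish status is re-mapped so the client only ever sees a real finish on the last chunk. A minimal restatement of that mapping (the helper name is the editor's; the import path for FinishStatus is an assumption and may differ from what manager.py actually uses):

from lightllm.server.core.objs import FinishStatus  # assumed import path

def map_chunk_finish_status(finish_status, remaining_tokens: int):
    """Mirror of the branches above: an exhausted budget or a stop token means
    the request is really done; hitting only the chunk budget does not."""
    if remaining_tokens == 0:
        return FinishStatus(FinishStatus.FINISHED_LENGTH)  # whole budget used up
    if finish_status == FinishStatus.FINISHED_STOP:
        return FinishStatus(FinishStatus.FINISHED_STOP)    # stop token inside the chunk
    if finish_status == FinishStatus.FINISHED_LENGTH:
        return FinishStatus(FinishStatus.NO_FINISH)        # only the chunk budget ran out
    return finish_status                                    # still generating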
5 changes: 5 additions & 0 deletions lightllm/utils/error_utils.py
@@ -16,3 +16,8 @@ def __init__(self, message="Server is busy, please try again later", status_code
def __str__(self):
"""String representation of the error"""
return f"{self.message} (Status code: {self.status_code})"

class KVMoveTimeoutError(ServerBusyError):
"""KV移动超时错误,用于触发重试机制"""
def __init__(self, message="KV move timeout, please try again later", status_code=503):
super().__init__(message, status_code)
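
Because KVMoveTimeoutError subclasses ServerBusyError, any existing handler that catches ServerBusyError (for example the HTTP 503 path) keeps working, while generate() can single out the timeout case for a retry. A small illustration of how the two handlers interact (the classify helper is illustrative only):

from lightllm.utils.error_utils import ServerBusyError, KVMoveTimeoutError

def classify(exc: BaseException) -> str:
    """Shows which handler each error type reaches; the narrower except comes first."""
    try:
        raise exc
    except KVMoveTimeoutError:
        return "retry with freshly selected p/d nodes"
    except ServerBusyError:
        return "return 503 to the client"

print(classify(KVMoveTimeoutError()))  # retry with freshly selected p/d nodes
print(classify(ServerBusyError()))     # return 503 to the client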