From d8e2bb0ffd5320297252dcd43739553cef157c20 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Fri, 11 Apr 2025 16:51:30 +0800
Subject: [PATCH 01/19] WIP.

---
 lightllm/server/httpserver/pd_loop.py         |  40 ++-
 .../httpserver_for_pd_master/manager.py       |  61 +---
 lightllm/server/pd_io_struct.py               |   5 +
 .../server/router/model_infer/infer_batch.py  |   4 +
 .../mode_backend/generic_pre_process.py       |  39 +++
 .../pd_disaggregation/impl_for_pd_decode.py   | 134 ++++++++
 .../pd_disaggregation/impl_for_pd_prefill.py  | 141 +++++++++
 .../pd_disaggregation/pd_remote_prefill.py    | 295 ++++++++++++++++++
 .../pd_remote_prefill_obj.py                  |  67 ++++
 9 files changed, 728 insertions(+), 58 deletions(-)
 create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py
 create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py
 create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py
 create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py

diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py
index 10a4a8ec5..45312b108 100644
--- a/lightllm/server/httpserver/pd_loop.py
+++ b/lightllm/server/httpserver/pd_loop.py
@@ -5,6 +5,8 @@ import socket
 import httpx
 import base64
+import zmq
+import zmq.asyncio
 from typing import Dict, Optional
 from lightllm.server.pd_io_struct import NodeRole, ObjType
 from lightllm.server.httpserver.async_queue import AsyncQueue
@@ -33,6 +35,7 @@ async def pd_handle_loop(manager: HttpServerManager):
     manager.host_ip = manager.args.host

     asyncio.create_task(timer_log(manager))
+    asyncio.create_task(pd_handle_loop_from_d(manager))

     id_to_handle_task: Dict[int, asyncio.Task] = {}

@@ -92,7 +95,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_Obj
         logger.info(f"Sent registration JSON: {regist_json}")

         # token forwarding task
-        forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))
+        if manager.pd_mode == NodeRole.D:
+            forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))

         # receive requests from the pd master, run inference, and forward the generated tokens back to the pd master.
         while True:
@@ -182,3 +186,37 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
         handle_list = await forwarding_queue.wait_to_get_all_data()
         if handle_list:
             await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))
+
+
+async def pd_handle_loop_from_d(manager: HttpServerManager):
+    if manager.pd_mode != NodeRole.P:
+        return
+
+    context = zmq.asyncio.Context(2)
+    manager.recv_from_d = context.socket(zmq.PULL)
+    manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_remote_prefill_port}")
+
+    while True:
+        try:
+            (
+                prompt,
+                sampling_params,
+                multimodal_params,
+            ) = await manager.recv_from_d.recv_pyobj()
+
+            # task that triggers inference
+            async def pd_process_generate(
+                manager: "HttpServerManager", prompt, sampling_params, multimodal_params
+            ):
+                try:
+                    async for _, _, _, _ in manager.generate(
+                        prompt, sampling_params, multimodal_params, None
+                    ):
+                        pass
+                except BaseException as e:
+                    logger.error(str(e))
+
+            asyncio.create_task(pd_process_generate(manager, prompt, sampling_params, multimodal_params))
+
+        except Exception as e:
+            logger.exception(f"pd loop generate error: {str(e)}")
\ No newline at end of file
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index 05b2d987c..3f963c520 100644
---
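# [editor's note] On the control path just added above: the P node PULL-binds
# pd_remote_prefill_port and fires one generate() task per received 3-tuple.
# The matching sender appears later in this patch (PDRemotePrefillServer.
# trigger_prefill PUSHes the same tuple). A minimal standalone sketch of such
# a sender, assuming pyzmq; the host, port and variable names here are
# illustrative, not part of the patch:
#
#     import zmq
#
#     ctx = zmq.Context()
#     send_to_p = ctx.socket(zmq.PUSH)
#     send_to_p.connect("tcp://prefill-host:42002")  # the P node's pd_remote_prefill_port
#     # the same 3-tuple pd_handle_loop_from_d unpacks
#     send_to_p.send_pyobj((prompt, sampling_params, multimodal_params))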
a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -1,13 +1,8 @@ import sys -import zmq -import zmq.asyncio import asyncio import uvloop -import rpyc import time -import hashlib import datetime -import aiohttp import ujson as json import pickle @@ -78,15 +73,6 @@ async def remove_pd(self, pd_info_json): logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return - async def update_req_status(self, upkv_status: UpKVStatus): - try: - group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) - up_status_event = self.req_id_to_out_inf[group_request_id].up_status_event - up_status_event.upkv_status = upkv_status - up_status_event.set() - except: - pass - return def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): kwargs = {} if kwargs is None else kwargs @@ -192,55 +178,18 @@ async def fetch_stream( up_status_event = req_status.up_status_event d_start_args = d_node.start_args - decode_node_dict = { + prefill_node_dict = { "node_id": d_start_args["pd_node_id"], "ip": d_start_args["host"], - "rpyc_port": d_start_args["pd_decode_rpyc_port"], + "rpyc_port": d_start_args["pd_prefill_rpyc_port"], "max_new_tokens": sampling_params.max_new_tokens - 1, "pd_master_node_id": self.args.pd_node_id, } - old_max_new_tokens = sampling_params.max_new_tokens - sampling_params.max_new_tokens = 1 - sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None) + sampling_params.move_kv_to_decode_node.initialize(prefill_node_dict) sampling_params.suggested_dp_index = -1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) - - while True: - await req_status.wait_to_ready() - if await request.is_disconnected(): - raise Exception(f"req_id {group_request_id} disconnected") - - if await req_status.can_read(self.req_id_to_out_inf): - token_list = await req_status.pop_all_tokens() - for sub_req_id, request_output, metadata, finish_status in token_list: - if old_max_new_tokens != 1: - finish_status = FinishStatus(FinishStatus.NO_FINISH) - else: - finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH) - # 得到 p 节点返回的 prompt_ids 信息 - if metadata.get("prompt_ids", None) is not None: - prompt_ids = metadata.get("prompt_ids") - prompt_ids.append(metadata.get("id")) - yield sub_req_id, request_output, metadata, finish_status - break - - # 如果只需要一个输出 token,prefill 完就直接结束掉吧 - if old_max_new_tokens == 1: - return - - try: - await asyncio.wait_for(up_status_event.wait(), timeout=60) - except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") - raise ServerBusyError() - - sampling_params.move_kv_to_decode_node.initialize(None) - sampling_params.max_new_tokens = old_max_new_tokens - 1 - sampling_params.suggested_dp_index = up_status_event.upkv_status.dp_index - - await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, multimodal_params)))) + await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) while True: await req_status.wait_to_ready() @@ -251,8 +200,6 @@ async def fetch_stream( for sub_req_id, request_output, metadata, finish_status in token_list: yield sub_req_id, request_output, metadata, finish_status - return - async def _wait_to_token_package( self, p_node: PD_Client_Obj, diff --git a/lightllm/server/pd_io_struct.py 
b/lightllm/server/pd_io_struct.py index 414e3c74a..cb0bb8a1e 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -113,6 +113,11 @@ class PDTransJoinInfo: # 一次连接,使用一个 uuid 为其标识 connect_id: str +@dataclass +class RemotePrefillServerInfo: + perfill_server_id: int + prefill_server_ip: str + prefill_server_port: int @dataclass class PDTransLeaveInfo: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 09f649943..d94ea5445 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -257,6 +257,8 @@ def __init__( self.vocab_size = vocab_size self.initialized = False self.paused = False + self.remote_prefilling = False + self.kv_transfering = False def init_all(self): if self.initialized is False: @@ -300,6 +302,8 @@ def init_all(self): self.initialized = True self.paused = False + self.remote_prefilling = False + self.kv_transfering = False return def is_uninitialized(self): diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 9f174e8bc..84926d484 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -4,6 +4,45 @@ from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock +def prepare_remote_prefill_inputs(req_objs: List[InferReq]): + run_reqs = [] + start_loc = 0 + input_ids = [] + nopad_b_req_idx = [] + nopad_b_start_loc = [] + nopad_b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + nopad_b_req_idx.append(req.req_idx) + nopad_b_start_loc.append(start_loc) + + input_token_ids = req.get_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + nopad_b_seq_len.append(seq_len) + input_ids.append(input_id) + start_loc += input_token_len + + nopad_b_start_loc.append(start_loc) # last request + + input_ids = np.concatenate(input_ids, dtype=np.int64) + g_infer_state_lock.acquire() # I don't think it's needed + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda() + g_infer_state_lock.release() + kwargs = { + "batch_size": len(run_reqs), + "input_ids": input_ids, + "mem_indexes": mem_indexes, + "b_req_idx": nopad_b_req_idx, + "b_start_loc": nopad_b_start_loc, + "b_seq_len": nopad_b_seq_len, + } + return kwargs, run_reqs + def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal=False): run_reqs = [] diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py new file mode 100644 index 000000000..856b30138 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py @@ -0,0 +1,134 @@ +import torch +import torch.multiprocessing as mp +import threading +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from typing import List, Tuple, Dict +from lightllm.utils.infer_utils import set_random_seed +from 
lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs import FinishStatus +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs, prepare_remote_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample + +from .pd_remote_prefill_obj import ( + RemotePrefillTask, + RemotePrefillServerInfo, + RemotePrefillRequest) + +logger = init_logger(__name__) + + +class PDBackendForDecodeNode(ModeBackend): + def __init__(self, + prefill_task_queue: mp.Queue, + prefill_done_queue: mp.Queue, + mem_queue: mp.Queue) -> None: + super().__init__() + self.prefill_task_queue = prefill_task_queue + self.prefill_done_queue = prefill_done_queue + self.mem_queue = mem_queue + self.remote_prefilled_reqs: Dict[str, InferReq] = {} + + def wait_prefill_done_loop(self): + while True: + prefill_done_id = self.prefill_done_queue.get() + if prefill_done_id is None: # None means exit + logger.info("wait prefill done loop exits") + break + if run_req := self.remote_prefilled_reqs.get(prefill_done_id, None): + # remote prefill and transfer done, we need set kv cache to prompt len + + run_req.remote_prefilling = False + self.remote_prefilled_reqs.pop(prefill_done_id) + else: + logger.warning(f"wait prefill done loop: cannot find run_req with id {prefill_done_id}") + + + def init_custom(self): + + self.mem_queue.put((self.rank_in_dp, self.model.mem_manager.kv_buffer)) + + threading.Thread(target=self.wait_prefill_done_loop, daemon=True).start() + + return + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): + prefill_node = req.shm_req.sample_params.move_kv_to_decode_node.to_dict() + prefill_node_info = RemotePrefillServerInfo( + perfill_server_id=prefill_node['node_id'], + prefill_server_ip=prefill_node['ip'], + prefill_server_port=prefill_node['rpyc_port'], + ) + + mem_indexes = kwargs.get('mem_indexes') + b_start_loc = kwargs.get('b_start_loc') + + prefill_request = RemotePrefillRequest( + group_request_id=req.shm_req.group_req_id, + prompt = req.shm_req.get_prompt_ids(), + sampling_params=req.shm_req.sample_params, + multimodal_params=req.multimodal_params, + local_cached_len=req.cur_kv_len, + token_ids=mem_indexes[b_start_loc[index]: b_start_loc[index+1]], + ) + return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request) + + def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( + req_ids, + no_decode, + ) + new_prefill_reqs = [] + # filter remote prefill requests + for r in prefill_reqs: + if r.remote_prefilling: + continue + new_prefill_reqs.append(r) + return uninit_reqs, aborted_reqs, ok_finished_reqs, new_prefill_reqs, decode_reqs + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + + self._filter_reqs(aborted_reqs) + + # allocate kv cache, do remote prefill + if prefill_reqs: + # TODO: we could allocate cache later after remote prefill done and get a signal from remote + # but it will have a risk to not have enough cache for this request. 
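+            # [editor's note] a worked example of the layout prepare_remote_prefill_inputs
+            # returns, using the patch's own definitions: three prefill reqs needing 4, 2
+            # and 3 new KV slots give b_start_loc == [0, 4, 6, 9] (note the trailing
+            # sentinel appended after the loop), so request i owns the slice
+            # mem_indexes[b_start_loc[i]:b_start_loc[i + 1]] -- exactly the token_ids
+            # slice that _build_remote_prefill_task ships to the prefill node as the
+            # transfer destination.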
+ kwargs, run_reqs = prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + # forward each req to remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + self.prefill_task_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) + + run_req.remote_prefilling = True + self.remote_prefilled_reqs[run_req.req_id] = run_req + + if decode_reqs: + kwargs, run_reqs = prepare_decode_inputs(decode_reqs) + logits = self.model.forward(**kwargs) + + self._overlap_req_init_and_filter( + uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True + ) + + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py new file mode 100644 index 000000000..579190132 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py @@ -0,0 +1,141 @@ +import time +import threading +import torch +import torch.multiprocessing as mp +from typing import List, Tuple, Dict +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.utils.infer_utils import set_random_seed +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context +from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample + +logger = init_logger(__name__) + + +class ChunckedPrefillForPrefillNode(ModeBackend): + def __init__(self, + transfer_task_queue: mp.Queue, + transfer_done_queue: mp.Queue, + mem_queue: mp.Queue) -> None: + super().__init__() + self.transfer_task_queue = transfer_task_queue + self.transfer_done_queue = transfer_done_queue + self.mem_queue = mem_queue + self.inflight_transfer_request: Dict[str, InferReq] = {} + + def wait_transfer_done(self): + while True: + transfer_done = self.transfer_done_queue.get() + logger.debug(f"Transfer done: {transfer_done}") + self.inflight_transfer_request[transfer_done].wait() + del self.inflight_transfer_request[transfer_done] + + def init_custom(self): + threading.Thread(target=self.wait_transfer_done, daemon=True).start() + return + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs) + return + + def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( + req_ids, + no_decode + ) + new_ok_finished = [] + # filter remote prefill requests + for r in ok_finished_reqs: + if r.kv_transfering: + continue + new_ok_finished.append(r) + return uninit_reqs, aborted_reqs, 
new_ok_finished, prefill_reqs, decode_reqs
+
+    def decode(self):
+        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
+            g_infer_context.infer_req_ids,
+            no_decode=True,
+        )
+        assert len(uninit_reqs) == 0
+        assert len(decode_reqs) == 0
+
+        self._filter_reqs(aborted_reqs)
+
+        if ok_finished_reqs:
+            self.prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(ok_finished_reqs)
+            self._filter_reqs(ok_finished_reqs)
+
+        if prefill_reqs:
+            kwargs, run_reqs = prepare_prefill_inputs(
+                prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
+            )
+
+            logits = self.model.forward(**kwargs)
+            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+
+            self._post_handle(
+                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+            )
+        return
+
+    def prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, run_reqs: List[InferReq]):
+        # reclaim the related info into the radix cache ahead of time and add reference info
+        if self.is_master_in_dp:
+            logger.info("prefill_req_handle_and_frozen_tokens")
+        try:
+            for req in run_reqs:
+                req: InferReq = req
+                key = req.get_input_token_ids()[0 : req.cur_kv_len]
+                key = torch.tensor(key, dtype=torch.int64, device="cpu")
+                value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()
+                prefix_len = self.radix_cache.insert(key, value)
+                old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+                self.model.mem_manager.free(
+                    self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]
+                )
+                if req.shared_kv_node is not None:
+                    self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
+                    req.shared_kv_node = None
+
+                req.cur_kv_len = 0
+                req.shm_req.shm_cur_kv_len = 0
+
+                if req.shm_req.sample_params.move_kv_to_decode_node.exists:
+                    # note: keep this compatible with both pure tp and mixed tp + dp modes
+                    if self.is_master_in_dp:
+                        g_router_lock.acquire()
+                        self.shared_token_load.add_frozened_token_count(len(key), self.dp_rank_in_node)
+                        g_router_lock.release()
+
+                    share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=True)
+                    assert len(key) == len(value)
+                    # put the request below into the task queue; note that the value
+                    # returned by the radix cache must be used
+                    decode_node_info = DecodeNodeInfo(**req.shm_req.sample_params.move_kv_to_decode_node.to_dict())
+                    task = KVMoveTask(
+                        group_request_id=req.shm_req.group_req_id,
+                        input_tokens=key.tolist(),
+                        prefill_token_indexes=value.tolist(),
+                        decode_token_indexes=None,
+                        prefill_node_id=self.args.pd_node_id,
+                        decode_node=decode_node_info,
+                        move_kv_len=None,
+                        prefill_dp_index=self.dp_rank_in_node,
+                        decode_dp_index=None,
+                        mark_start_time=time.time(),
+                    )
+                    g_kv_move_task_cache[task.group_request_id] = (task, share_node)
+
+                    # note: keep this compatible with both pure tp and mixed tp + dp modes
+                    if self.is_master_in_dp:
+                        self.info_queue.put(task)
+        except BaseException as e:
+            logger.exception(str(e))
+
+        if self.is_master_in_dp:
+            logger.info("prefill_req_handle_and_frozen_tokens end")
+        return
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py
new file mode 100644
index 000000000..9b6eb3342
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py
@@ -0,0 +1,295 @@
+import asyncio
+from typing import List, Tuple, Union
+from enum import Enum
+import zmq
+from collections import defaultdict
+
+from torch import Tensor
+import torch.multiprocessing as mp
+
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.net_utils import get_hostname_ip
+
+from .pd_remote_prefill_obj import (
+    ConnectRequest,
+    RemoteRequest,
+    RemoteAgent,
+    RemoteRequstType,
+    PrefillRequest,
+    RemotePrefillRequest,
+    RemotePrefillServerInfo,
+    KVMoveRequest,
+    RemotePrefillTask,
+)
+
+
+logger = init_logger(__name__)
+
+try:
+    from nixl._api import nixl_agent
+
+except ImportError:
+    logger.error("nixl is not installed, which is required for pd disaggregation!!!")
+    raise
+
+
+class PDRemotePrefillBase:
+    def __init__(self,
+                 device_index: int,
+                 kv_cache_queue: mp.Queue,  # kv cache tensors are sent to this process and registered with nixl
+                 tp_size: int,):
+        # create local nixl agent
+        self.nixl_agent_name = f'nixl_agent_{get_unique_server_name()}_{device_index}'
+        self.nixl_agent = nixl_agent(self.nixl_agent_name, None)
+
+        # metadata that must be sent to the remote server to make a connection
+        self.nixl_agent_metadata = self.nixl_agent.get_agent_metadata()
+
+        self.reg_descs = [None] * tp_size
+        self.local_xfer_handles = [None] * tp_size
+
+        self.device_index = device_index
+        self.kv_cache_queue = kv_cache_queue
+        self.tp_size = tp_size
+
+        self.num_layers = -1
+        self.num_tokens = -1
+        self.num_heads = -1
+        self.head_dim = -1
+        self.token_len = -1
+        self.layer_len = -1
+
+    def _create_xfer_handles(self, reg_descs, agent_name: str = ""):
+        base_addr, _, device_id = reg_descs[0]
+        tokens_data = []
+        for layer_id in range(self.num_layers):
+            for token_id in range(self.num_tokens):
+                tokens_data.append((base_addr + layer_id * self.layer_len + token_id * self.token_len, self.token_len, device_id))
+
+        descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True)
+        return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True)
+
+    def _get_token_desc_ids(self, token_ids: List[int]):
+        descs_ids = []
+        for layer_id in range(self.num_layers):
+            for token_id in token_ids:
+                descs_ids.append(layer_id * self.num_tokens + token_id)
+        return descs_ids
+
+    def local_init(self):
+        for _ in range(self.tp_size):
+            idx, tensor = self.kv_cache_queue.get(timeout=60)
+            if self.num_layers == -1:
+                self.num_layers, self.num_tokens, self.num_heads, self.head_dim = tensor.shape
+                self.token_len = self.num_heads * self.head_dim * tensor.element_size()
+                self.layer_len = self.num_tokens * self.token_len
+
+            self.reg_descs[idx] = self.nixl_agent.register_memory(tensor)
+            self.local_xfer_handles[idx] = self._create_xfer_handles(self.reg_descs[idx])
+
+        logger.info("All local kv cache registered.")
+
+
+class PDRemotePrefillServer(PDRemotePrefillBase):
+    def __init__(self,
+                 http_server_port: int,
+                 server_port: int,
+                 device_index: int,
+                 kvmove_request_queue: mp.Queue,
+                 kvmove_done_queue: mp.Queue,
+                 kv_cache_queue: mp.Queue,
+                 tp_size: int):
+        super().__init__(device_index, kv_cache_queue, tp_size)
+        # map from client id to decode server info
+        self.remote_decode_clients = {}
+
+        # build control path (PULL, so that the decode clients' PUSH sockets can connect)
+        _ctx = zmq.Context()
+        self.recv_from_decode = _ctx.socket(zmq.PULL)
+        self.host_ip = get_hostname_ip()
+        self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}")
+
+        # build the path that triggers remote prefill
+        self.send_to_httpserver = _ctx.socket(zmq.PUSH)
+        self.send_to_httpserver.connect(f"tcp://{self.host_ip}:{http_server_port}")
+
+        self.prefill_requests = {}
+
+        self.kvmove_request_queue = kvmove_request_queue
+        self.kvmove_done_queue = kvmove_done_queue
+
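+        # [editor's note] bookkeeping added below: inflight_transfer maps a group
+        # request id to the list of nixl transfer handles issued for it (one per tp
+        # rank, see trigger_kvcache_write); get_done_transfers() later drains this
+        # map, releasing handles as they reach the "DONE" state.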
+        self.inflight_transfer = defaultdict(list)
+
+    def add_remote_agent(self, request: ConnectRequest):
+        peer_name = self.nixl_agent.add_remote_agent(request.agent_metadata)
+        mem_desc = self.nixl_agent.deserialize_descs(request.agent_mem_desc)
+        kv_xfer_handles = []
+        for desc in mem_desc:
+            kv_xfer_handles.append(self._create_xfer_handles(desc, agent_name=peer_name))
+
+        self.remote_decode_clients[request.decode_id] = RemoteAgent(
+            name=peer_name,
+            kv_mem_desc=mem_desc,
+            kv_xfer_handles=kv_xfer_handles)
+
+    def main_loop(self):
+        self.local_init()
+        self.transfer_task = asyncio.create_task(self.transfer_loop())
+        self.wait_transfer_task = asyncio.create_task(self.wait_transfer_loop())
+        while True:
+            request: RemoteRequest = self.recv_from_decode.recv_pyobj()
+            if request.type == RemoteRequstType.REMOTE_CONNECT:
+                request: ConnectRequest = request
+                self.add_remote_agent(request)
+            elif request.type == RemoteRequstType.REMOTE_PREFILL:
+                request: PrefillRequest = request
+                self.trigger_prefill(request)
+
+    def trigger_prefill(self, request: PrefillRequest):
+        self.send_to_httpserver.send_pyobj((request.data.prompt, request.data.sampling_params, request.data.multimodal_params))
+        self.prefill_requests[request.data.sampling_params.group_request_id] = request
+
+    async def transfer_loop(self):
+        while True:
+            request: KVMoveRequest = self.kvmove_request_queue.get()
+            await self.trigger_kvcache_write(request)
+
+    async def trigger_kvcache_write(self, request: KVMoveRequest):
+        group_request_id = request.group_req_id
+        prefill_request: PrefillRequest = self.prefill_requests[group_request_id]
+        skip_kv_move_len = prefill_request.data.local_cached_len
+        src_token_ids = request.token_ids[skip_kv_move_len:]
+        dst_token_ids = prefill_request.data.token_ids[skip_kv_move_len:]
+        remote_agent = self.remote_decode_clients[prefill_request.decode_id]
+        if len(src_token_ids) > 0:
+            assert len(src_token_ids) == len(dst_token_ids)
+            src_token_descs = self._get_token_desc_ids(src_token_ids)
+            dst_token_descs = self._get_token_desc_ids(dst_token_ids)
+
+            for i in range(self.tp_size):  # TODO make this a single transfer
+                src_handle = self.local_xfer_handles[i]
+                dst_handle = remote_agent.kv_xfer_handles[i]
+                handle = self.nixl_agent.make_prepped_xfer("WRITE",
+                                                           src_handle, src_token_descs,
+                                                           dst_handle, dst_token_descs, str(group_request_id))
+                self.inflight_transfer[group_request_id].append(handle)
+                status = self.nixl_agent.transfer(handle)
+
+            # completion is polled asynchronously in wait_transfer_loop via get_done_transfers()
+
+    def get_done_transfers(self) -> Tuple[List[str], List[str]]:
+        done_req_ids = []
+        failed_req_ids = []
+        for req_id, handles in self.inflight_transfer.items():
+            running_reqs = []
+            failed_reqs = []
+            for handle in handles:
+                xfer_state = self.nixl_agent.check_xfer_state(handle)
+                if xfer_state == "DONE":
+                    self.nixl_agent.release_xfer_handle(handle)  # TODO ptarasiewicz: why abort is throwing errors?
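+                    # [editor's note] polling contract as used by this loop: nixl reports
+                    # each handle as "DONE" (release it and drop it from the running set),
+                    # "PROC" (still in flight, poll again next round) or anything else
+                    # (treated as a failed transfer, failing the whole request). The desc
+                    # ids were built layer-major by _get_token_desc_ids: with num_tokens
+                    # = N, token slot t of layer l is desc l * N + t, so one request
+                    # touches num_layers * len(token_ids) descriptors per tp rank.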
+                    continue
+                if xfer_state == "PROC":
+                    running_reqs.append(handle)
+                else:
+                    logger.warning(f"Transfer failed with state {xfer_state}")
+                    failed_reqs.append(handle)
+                    break
+
+            if failed_reqs:
+                failed_req_ids.append(req_id)
+                continue
+
+            if len(running_reqs) == 0:
+                done_req_ids.append(req_id)
+            else:
+                self.inflight_transfer[req_id] = running_reqs
+
+        return done_req_ids, failed_req_ids
+
+    async def wait_transfer_loop(self):
+        while True:
+            done_ids, failed_ids = self.get_done_transfers()
+            # handle successfully completed transfers
+            pass
+
+            # handle failed transfers
+            pass
+
+            # remove ids from inflight transfers and cancel inflight transfers if failed
+
+
+class PDRemotePrefillClient(PDRemotePrefillBase):
+
+    def __init__(self,
+                 prefill_request_queue: mp.Queue,  # only tp0 will trigger prefill
+                 prefill_done_queue: List[mp.Queue],  # one to many done queue
+                 device_index: int,
+                 kv_cache_queue: mp.Queue,  # kv cache tensors are sent to this process and registered with nixl
+                 tp_size: int,
+                 my_id: int,
+                 ):
+        super().__init__(device_index, kv_cache_queue, tp_size)
+        # map from server id to prefill server info
+        self.remote_prefill_servers = {}
+        self.prefill_request_queue = prefill_request_queue
+        self.prefill_done_queue = prefill_done_queue
+        self.remote_prefill_requests = {}
+        self.my_id = my_id
+
+    def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo):
+        # build the control path if it does not exist yet
+        if server_info.perfill_server_id not in self.remote_prefill_servers:
+            _ctx = zmq.Context()
+            _socket = _ctx.socket(zmq.PUSH)
+            connect_str = f"tcp://{server_info.prefill_server_ip}:{server_info.prefill_server_port}"
+            _socket.connect(connect_str)
+            _socket.send_pyobj(ConnectRequest(
+                type=RemoteRequstType.REMOTE_CONNECT,
+                decode_id=self.my_id,
+                agent_metadata=self.nixl_agent_metadata,
+                agent_mem_desc=self.nixl_agent.get_serialized_descs(self.reg_descs)))
+            self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info)
+
+    def main_loop(self):
+        self.local_init()
+        asyncio.create_task(self.prefill_wait_loop())
+        while True:
+            # the decode backend enqueues one RemotePrefillTask at a time
+            prefill_task: RemotePrefillTask = self.prefill_request_queue.get()
+            # connect first
+            self.connect_to_prefill_server(prefill_task.server_info)
+            # do prefill
+            self.remote_prefill(prefill_task.server_info.perfill_server_id, prefill_task.prefill_request)
+
+    # hand the request to the server to do remote prefill
+    def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest):
+        socket, _ = self.remote_prefill_servers[server_id]
+        group_req_id = str(prefill_request.sampling_params.group_request_id)
+        socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.my_id, data=prefill_request))
+        self.remote_prefill_requests[group_req_id] = prefill_request
+
+    async def prefill_wait_loop(self):
+        while True:
+            notifies = self.nixl_agent.get_new_notifs()
+            for agent_name, msgs in notifies.items():
+                for msg in msgs:
+                    # we got a finished prefill msg
+                    req_id = msg.decode()
+                    for pdq in self.prefill_done_queue:
+                        pdq.put(req_id)
+                    del self.remote_prefill_requests[req_id]
+                    logger.info(f"prefill request: {req_id} done")
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py
new file mode 100644
index 000000000..77b327b03
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py
@@ -0,0 +1,67 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Union
+
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.core.objs import SamplingParams
+from lightllm.server.multimodal_params import MultimodalParams
+from lightllm.server.pd_io_struct import RemotePrefillServerInfo
+
+logger = init_logger(__name__)
+
+try:
+    from nixl._api import nixl_agent, nixlBind, nixl_prepped_dlist_handle
+
+except ImportError:
+    logger.error("nixl is not installed, which is required for pd disaggregation!!!")
+    raise
+
+
+class RemoteRequstType(Enum):
+    REMOTE_CONNECT = 1
+    REMOTE_PREFILL = 2
+
+@dataclass
+class RemotePrefillRequest:
+    prompt: Union[str, List[int]]
+    sampling_params: SamplingParams
+    multimodal_params: MultimodalParams
+    local_cached_len: int  # this many leading tokens are already cached locally and skip the transfer
+    token_ids: List[int]  # mem cache indexes
+
+
+@dataclass
+class RemotePrefillTask:
+    server_info: RemotePrefillServerInfo
+    prefill_request: RemotePrefillRequest
+
+
+@dataclass
+class RemoteRequest:
+    type: RemoteRequstType
+
+
+@dataclass
+class ConnectRequest(RemoteRequest):
+    decode_id: int
+    agent_metadata: str
+    agent_mem_desc: str
+
+
+@dataclass
+class PrefillRequest(RemoteRequest):
+    decode_id: int
+    data: RemotePrefillRequest
+
+
+@dataclass
+class KVMoveRequest:
+    group_req_id: int
+    token_ids: List[int]
+
+
+@dataclass
+class RemoteAgent:
+    name: str
+    kv_mem_desc: List[nixlBind.nixlRegDList]
+    kv_xfer_handles: List[nixl_prepped_dlist_handle]
\ No newline at end of file

From e2d24627a0a417e55658d7d1ca774f80d6629b02 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Mon, 21 Apr 2025 17:10:49 +0800
Subject: [PATCH 02/19] working.

---
 lightllm/server/api_cli.py                    |  14 +
 lightllm/server/core/objs/__init__.py         |   2 +-
 lightllm/server/core/objs/req.py              |  42 +++
 lightllm/server/core/objs/shm_req_manager.py  |   6 +-
 lightllm/server/httpserver/pd_loop.py         |   2 +-
 .../httpserver_for_pd_master/manager.py       |  13 +-
 lightllm/server/multimodal_params.py          |   6 +
 lightllm/server/router/batch.py               |   6 +-
 lightllm/server/router/manager.py             |  45 ++-
 .../server/router/model_infer/infer_batch.py  |  15 +-
 .../model_infer/mode_backend/__init__.py      |   2 +
 .../mode_backend/generic_pre_process.py       |  41 +--
 .../pd_disaggregation/impl_for_pd_base.py     | 278 +++++++++++++++
 .../pd_disaggregation/impl_for_pd_decode.py   |  83 ++---
 .../pd_disaggregation/impl_for_pd_prefill.py  | 116 ++-----
 .../pd_disaggregation/nixl_kv_transporter.py  | 154 +++++++++
 .../pd_disaggregation/pd_remote_prefill.py    | 327 +++++++-----------
 .../pd_remote_prefill_obj.py                  |  13 +-
 .../server/router/model_infer/model_rpc.py    |  21 +-
 19 files changed, 756 insertions(+), 430 deletions(-)
 create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py
 create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 1aca6672c..70d892bd3 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -54,6 +54,20 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=None,
         help="The port number for the config server in config_server mode.",
     )
+    parser.add_argument(
+        "--pd_remote_prefill_http_port",
+        type=int,
+        default=42001,
+        help="pd disaggregation mode, port on which the prefill node's http server receives remote prefill triggers",
+    )
+
+    parser.add_argument(
+        "--pd_remote_prefill_port",
+        type=int,
+        default=42002,
+        help="pd disaggregation mode, port on which the prefill node's remote prefill server accepts decode-node connections",
+    )
+
    parser.add_argument(
"--model_name", type=str, diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index d4e69fd36..2bdd9a51c 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,4 +1,4 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus +from .req import Req, FinishStatus, PDChunkedPrefillReq from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index c8d8476e5..5f81cf0b4 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -8,6 +8,7 @@ from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.dist_utils import get_dp_world_size from typing import List, Any, Union @@ -102,6 +103,7 @@ def get_str(self): f"shm_cur_kv_len:{self.shm_cur_kv_len}," f"shm_cur_output_len:{self.shm_cur_output_len}," f"finish_status:{self.finish_status.is_finished()}" + f"group_id: {self.group_req_id}" ) def init( @@ -333,3 +335,43 @@ def post_init( # 错误问题。 self.sample_params.max_new_tokens = self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6 return + + +class PDChunkedPrefillReq(ChunkedPrefillReq): + _pack_ = 4 + _MAX_TP_SIZE = 128 + + def post_init(self): + super().post_init() + self.create_pd_req_state_shm_array() + + def create_pd_req_state_shm_array(self): + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_pd_req_state_{self.index_in_shm_mem}" + # self.dp_world_size = PDChunkedPrefillReq._MAX_TP_SIZE # get_dp_world_size() + self.pd_req_state_shm = ShmArray(name, (PDChunkedPrefillReq._MAX_TP_SIZE + 1,), dtype=np.int8) + self.pd_req_state_shm.create_shm() + self.pd_req_state_shm.arr.fill(0) + return + + def link_pd_req_state_shm_array(self): + service_uni_name = get_unique_server_name() + # self.dp_world_size = PDChunkedPrefillReq._MAX_TP_SIZE #get_dp_world_size() + name = f"{service_uni_name}_shm_pd_req_state_{self.index_in_shm_mem}" + self.pd_req_state_shm = ShmArray(name, (PDChunkedPrefillReq._MAX_TP_SIZE + 1,), dtype=np.int8) + self.pd_req_state_shm.link_shm() + return + + # called by each tp rank, no contention + def set_pd_req_rank_state(self, tp_id: int, state: int): + self.pd_req_state_shm.arr[tp_id] = state + + # state: -1 for failed, 0 for in progress, 1 for success + # set by router + def set_pd_req_state(self, dp_world_size: int): + unique_state = np.unique(self.pd_req_state_shm.arr[:dp_world_size]) + self.pd_req_state_shm.arr[dp_world_size] = unique_state[0] + + # read by all rank + def get_pd_req_state(self, dp_world_size: int): + return self.pd_req_state_shm.arr[dp_world_size] diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index 315eb938e..ca7d9bebe 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -3,7 +3,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger -from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq +from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq, PDChunkedPrefillReq from .shm_array import ShmArray from .atomic_array_lock import AtomicShmArrayLock, AtomicLockItem from 
.atomic_lock import AtomicShmLock @@ -33,11 +33,15 @@ def get_req_class_type(self): if args.token_healing_mode: return TokenHealingReq + if args.run_mode == 'prefill' or args.run_mode == 'decode': + return PDChunkedPrefillReq + if args.disable_chunked_prefill: return NormalReq else: return ChunkedPrefillReq + def get_max_req_num(self): args: StartArgs = get_env_start_args() return args.running_max_req_size diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 45312b108..18579fc5f 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -193,7 +193,7 @@ async def pd_handle_loop_from_d(manager: HttpServerManager): context = zmq.asyncio.Context(2) manager.recv_from_d = context.socket(zmq.PULL) - manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_remote_prefill_port}") + manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_remote_prefill_http_port}") while True: try: diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 3f963c520..597d3d3b1 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -177,11 +177,11 @@ async def fetch_stream( up_status_event = req_status.up_status_event - d_start_args = d_node.start_args + p_start_args = p_node.start_args prefill_node_dict = { - "node_id": d_start_args["pd_node_id"], - "ip": d_start_args["host"], - "rpyc_port": d_start_args["pd_prefill_rpyc_port"], + "node_id": p_start_args["pd_node_id"], + "ip": p_start_args["host"], + "rpyc_port": p_start_args["pd_remote_prefill_port"], "max_new_tokens": sampling_params.max_new_tokens - 1, "pd_master_node_id": self.args.pd_node_id, } @@ -271,11 +271,6 @@ async def abort(self, group_request_id): except: pass - try: - await req_status.p_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) - except: - pass - try: await req_status.d_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) except: diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index bf320e199..c57edfe4e 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -154,6 +154,12 @@ def to_dict(self): ret["audios"] = [a.to_dict() for a in self.audios] return ret + @classmethod + def from_dict(cls, data: dict): + if 'images' not in data: + return cls() + return cls(images=data["images"]) + def to_origin_dict(self): """ 将内容转换为原始请求的形式,主要用于请求转发 diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 15d04b208..641d3ab57 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -1,8 +1,9 @@ import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union -from lightllm.server.core.objs import ShmReqManager, Req +from lightllm.server.core.objs import ShmReqManager, Req, PDChunkedPrefillReq from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_dp_world_size logger = init_logger(__name__) @@ -53,6 +54,9 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): req = None else: unfinished_req_ids.append(req.request_id) + if isinstance(req, PDChunkedPrefillReq): + req.link_pd_req_state_shm_array() + req.set_pd_req_state(get_dp_world_size()) self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} diff --git 
a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 2811e4228..b8b4ee583 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -118,12 +118,16 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] + self.result_queues: List[mp.Queue] = [ + mp.Queue() for _ in range(self.world_size) + ] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() assert (self.world_size % self.nnodes) == 0 node_world_size = self.world_size // self.nnodes for rank_id in range(self.node_rank * node_world_size, (self.node_rank + 1) * node_world_size): + rpc_model = await start_model_process( args=self.args, rank=rank_id, @@ -132,7 +136,8 @@ async def wait_to_model_ready(self): rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, info_queue=self.info_queue, - mem_queue=self.mem_queues[(rank_id % node_world_size)], + result_queue=self.result_queues[rank_id], + mem_queue=self.mem_queues[rank_id], router_lock=self.router_lock, ) self.model_rpc_servers.append(rpc_model) @@ -194,19 +199,41 @@ async def wait_to_model_ready(self): if self.args.run_mode == "prefill": # 启动 prefill kv move 管理进程 - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import ( - start_prefill_kv_move_manager_process, - ) + # from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import ( + # start_prefill_kv_move_manager_process, + # ) - start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + # start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + + from lightllm.server.router.model_infer.mode_backend.pd_disaggregation.pd_remote_prefill import ( + start_pd_remote_prefill_server_process + ) + start_pd_remote_prefill_server_process( + self.args.pd_node_id, + http_server_port=self.args.pd_remote_prefill_http_port, + server_port=self.args.pd_remote_prefill_port, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues + ) if self.args.run_mode == "decode": # 启动 decode kv move 管理进程 - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( - start_decode_kv_move_manager_process, - ) + # from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( + # start_decode_kv_move_manager_process, + # ) + + # start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) - start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + from lightllm.server.router.model_infer.mode_backend.pd_disaggregation.pd_remote_prefill import ( + start_pd_remote_prefill_client_process + ) + start_pd_remote_prefill_client_process( + self.args.pd_node_id, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index d94ea5445..2c32d61e7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -10,7 +10,7 @@ from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager from lightllm.utils.infer_utils import mark_start, mark_end -from 
lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager +from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager, PDChunkedPrefillReq from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id @@ -118,7 +118,7 @@ def _save_promptcache_kvbuffer(self): https://arxiv.org/abs/2403.01241 """ prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key - print(f"prompt_cache_token_id : {prompt_cache_token_id}") + # print(f"prompt_cache_token_id : {prompt_cache_token_id}") index = range(len(prompt_cache_token_id)) prompt_cache_kv_buffer = self.radix_cache.mem_manager.get_index_kv_buffer(index) torch.save(prompt_cache_kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt") @@ -257,14 +257,17 @@ def __init__( self.vocab_size = vocab_size self.initialized = False self.paused = False - self.remote_prefilling = False - self.kv_transfering = False + self.in_prefill_or_transfer = False def init_all(self): if self.initialized is False: self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() + if isinstance(self.shm_req, PDChunkedPrefillReq): + self.shm_req.link_pd_req_state_shm_array() + self.in_prefill_or_transfer = False + self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) if self.sampling_param.shm_param.input_penalty: self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids()) @@ -283,6 +286,10 @@ def init_all(self): self.cur_output_len = 0 self.finish_status = FinishStatus() + + # print(f"[INFO] init_all: {self.shm_req.group_req_id} {self.shm_req.get_pd_req_state()} {self.remote_prefilling}", + # f"{self.cur_kv_len} {self.get_cur_total_len()}") + if self.paused or not self.initialized: # 如果是具有 prompt_cache 的使用特性则需要进行提前的填充和恢复操作。 if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1: diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 4594eec28..f2ef008ad 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -12,3 +12,5 @@ from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode +from .pd_disaggregation.impl_for_pd_prefill import PDNIXLBackendForPrefillNode +from .pd_disaggregation.impl_for_pd_decode import PDNIXLBackendForDecodeNode \ No newline at end of file diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 84926d484..66d056709 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -4,45 +4,6 @@ from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock -def prepare_remote_prefill_inputs(req_objs: List[InferReq]): - run_reqs = [] - start_loc = 0 - input_ids = [] - nopad_b_req_idx = [] - 
nopad_b_start_loc = [] - nopad_b_seq_len = [] - - for req in req_objs: - run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) - nopad_b_start_loc.append(start_loc) - - input_token_ids = req.get_input_token_ids() - seq_len = len(input_token_ids) - input_token_len = seq_len - req.cur_kv_len - input_id = input_token_ids[req.cur_kv_len :] - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - start_loc += input_token_len - - nopad_b_start_loc.append(start_loc) # last request - - input_ids = np.concatenate(input_ids, dtype=np.int64) - g_infer_state_lock.acquire() # I don't think it's needed - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda() - g_infer_state_lock.release() - kwargs = { - "batch_size": len(run_reqs), - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_start_loc": nopad_b_start_loc, - "b_seq_len": nopad_b_seq_len, - } - return kwargs, run_reqs - def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal=False): run_reqs = [] @@ -118,7 +79,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]): nopad_b_req_idx.append(req.req_idx) input_id = req.get_last_gen_token() seq_len = req.get_cur_total_len() - assert req.cur_kv_len == seq_len - 1 + assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" nopad_b_seq_len.append(seq_len) input_ids.append(input_id) nopad_total_token_num += seq_len diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py new file mode 100644 index 000000000..49099d275 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py @@ -0,0 +1,278 @@ + +import time +import torch.multiprocessing as mp +from typing import Dict, List +import numpy as np + + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.req import PDChunkedPrefillReq +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq + +from .nixl_kv_transporter import NixlMetadata, NixlKVTransporter +from .pd_remote_prefill_obj import (PrefillRequest, + RemoteRequest, + RemoteRequstType, + ConnectRequest, + KVMoveRequest) + +logger = init_logger(__name__) + + + +class PDNIXLBackendBase(ModeBackend): + def __init__(self, + to_remote_queue: mp.Queue, + from_remote_queue: mp.Queue, + nixl_meta_queue: mp.Queue): + super().__init__() + self.to_remote_queue = to_remote_queue + self.from_remote_queue = from_remote_queue + self.nixl_meta_queue = nixl_meta_queue + + # for decode + self.remote_prefilled_reqs: Dict[int, InferReq] = {} + + # for prefill + self.remote_prefill_requests: Dict[str, PrefillRequest] = {} + self.inflght_transfer_requests: Dict[str, InferReq] = {} + + + def init_custom(self): + self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.tp_rank) + self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer) + self.nixl_meta_queue.put((self.nixl_agent.agent_metadata, + self.nixl_agent.num_tokens, + self.nixl_agent.local_mem_desc)) + + + def _prefill_wait_loop(self): + while True: + notifies = self.nixl_agent.get_new_notifs() + for agent_name, req_idxs in notifies.items(): + for req_id in req_idxs: + 
                    group_req_id = int(req_id.decode())
+                    if run_req := self.remote_prefilled_reqs.get(group_req_id, None):
+                        shm_req: PDChunkedPrefillReq = run_req.shm_req
+                        shm_req.set_pd_req_rank_state(self.rank_in_dp, 1)
+                        self.remote_prefilled_reqs.pop(group_req_id)
+                        logger.info(f"remote prefill request: {group_req_id} done")
+                    else:
+                        logger.warning(f"remote prefill request: {group_req_id} not found")
+            time.sleep(0)
+
+    def _wait_transfer_loop(self):
+        while True:
+            done_req_ids = self.nixl_agent.get_done_tranfers()
+
+            for req_id, state in done_req_ids:
+                logger.info(f"wait transfer done: {req_id} state: {state}")
+                if req_id not in self.inflght_transfer_requests:
+                    logger.warning(f"{req_id} not found in inflght_transfer_requests")
+                    continue
+
+                req: InferReq = self.inflght_transfer_requests[req_id]
+                shm_req: PDChunkedPrefillReq = req.shm_req
+                shm_req.set_pd_req_rank_state(self.rank_in_dp, state)
+                del self.inflght_transfer_requests[req_id]
+            time.sleep(0)
+
+    def _handle_prefill_loop(self):
+        while True:
+            request: RemoteRequest = self.from_remote_queue.get()
+            if request.type == RemoteRequstType.REMOTE_CONNECT:
+                request: ConnectRequest
+                logger.info(f"connect request received from: {request.decode_id}")
+                self.nixl_agent.add_remote_agent(NixlMetadata(
+                    id = request.decode_id,
+                    num_tokens = request.num_tokens,
+                    agent_metadatas = request.agent_metadatas,
+                    agent_mem_descs = request.agent_mem_descs
+                ))
+
+            if request.type == RemoteRequstType.REMOTE_PREFILL:
+                request: PrefillRequest
+                group_request_id = request.data.sampling_params.group_request_id
+                logger.info(f"prefill request received from decode: {request.decode_id} "
+                            f"and group request id: {group_request_id}")
+                self.remote_prefill_requests[group_request_id] = request
+
+    def _transfer_kv_to_remote(self, req: InferReq):
+        group_req_id = req.shm_req.group_req_id
+        # set state
+        if group_req_id not in self.remote_prefill_requests:
+            logger.info(f"remote prefill request {group_req_id} not found")
+            return
+
+        # kick off kv transfer
+        if req.finish_status.is_finished():
+            kv_transfer_req = KVMoveRequest(
+                group_req_id=group_req_id,
+                token_ids=self.model.req_manager.req_to_token_indexs[req.req_idx][:req.cur_kv_len].tolist()
+            )
+            remote_request = self.remote_prefill_requests[group_req_id]
+            self.nixl_agent.write_blocks(kv_transfer_req, remote_request)
+            shm_req: PDChunkedPrefillReq = req.shm_req
+            shm_req.set_pd_req_rank_state(self.rank_in_dp, 0)
+            req.in_prefill_or_transfer = True
+            self.inflght_transfer_requests[group_req_id] = req
+
+    def _decode_filter_reqs(self, prefill_reqs: List[InferReq],
+                            aborted_reqs: List[InferReq], decode_reqs: List[InferReq]):
+        new_prefill_reqs: List[InferReq] = []
+        remote_prefill_reqs: List[InferReq] = []
+
+        for req in prefill_reqs:
+            if req.in_prefill_or_transfer:
+                shm_req: PDChunkedPrefillReq = req.shm_req
+                # state is updated by the router
+                state = shm_req.get_pd_req_state(self.dp_world_size)
+                if state == 1:  # success
+                    req.cur_kv_len = req.get_cur_total_len() - 1
+                    decode_reqs.append(req)
+                    req.in_prefill_or_transfer = False
+                elif state == -1:  # failure
+                    aborted_reqs.append(req)
+                    req.in_prefill_or_transfer = False
+                elif state == 0:  # in progress
+                    remote_prefill_reqs.append(req)
+                else:
+                    logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}")
+                continue
+
+            new_prefill_reqs.append(req)
+
+        return new_prefill_reqs, aborted_reqs, decode_reqs, remote_prefill_reqs
+
+    def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: List[InferReq]):
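+        # [editor's note] this helper and _decode_filter_reqs above both read the
+        # aggregate produced by PDChunkedPrefillReq: every tp rank reports into its
+        # own shm slot via set_pd_req_rank_state(), and the router reduces
+        # arr[:dp_world_size] with np.unique, storing unique_state[0] -- i.e. the
+        # minimum -- into slot dp_world_size. Worked example: ranks reporting
+        # [1, 0, 1, 1] aggregate to 0 (still in progress); any -1 drags the
+        # aggregate to -1 (failed); only all-1 yields 1 (success).
+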
new_ok_finished_reqs = [] + kv_transfer_reqs = [] + + for req in ok_finished_reqs: + if req.in_prefill_or_transfer: + shm_req: PDChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state(self.dp_world_size) + if state == 1: # success + new_ok_finished_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == -1: # failure + aborted_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == 0: + kv_transfer_reqs.append(req) + else: + logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") + continue + + new_ok_finished_reqs.append(req) + + return new_ok_finished_reqs, aborted_reqs, kv_transfer_reqs + + + def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): + run_reqs = [] + start_loc = 0 + input_ids = [] + nopad_b_req_idx = [] + nopad_b_start_loc = [] + nopad_b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + nopad_b_req_idx.append(req.req_idx) + nopad_b_start_loc.append(start_loc) + + input_token_ids = req.get_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + nopad_b_seq_len.append(seq_len) + input_ids.append(input_id) + start_loc += input_token_len + + nopad_b_start_loc.append(start_loc) # last request + + input_ids = np.concatenate(input_ids, dtype=np.int64) + # g_infer_state_lock.acquire() # I don't think it's needed + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + # g_infer_state_lock.release() + kwargs = { + "batch_size": len(run_reqs), + "input_ids": input_ids, + "mem_indexes": mem_indexes.tolist(), + "b_req_idx": nopad_b_req_idx, + "b_start_loc": nopad_b_start_loc, + "b_seq_len": nopad_b_seq_len, + } + return kwargs, run_reqs + + # def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False, prefill=True): + # uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( + # req_ids, + # no_decode + # ) + + # progressing_reqs = [] + # new_ok_or_prefill_reqs = [] + + # ok_or_prefill_reqs = ok_finished_reqs if prefill else prefill_reqs + # success_reqs = progressing_reqs if prefill else decode_reqs + + # # filter remote prefill requests + # for r in ok_or_prefill_reqs: + # r: InferReq + # if r.in_prefill_or_transfer: + # shm_req: PDChunkedPrefillReq = r.shm_req + # # state is updated by router + # state = shm_req.get_pd_req_state() + # if state == 1: + # success_reqs.append(r) + # r.in_prefill_or_transfer = False + # elif state == -1: + # aborted_reqs.append(r) + # r.in_prefill_or_transfer = False + # elif state == 0: # in progress + # progressing_reqs.append(r) + # else: + # logger.warning(f"remote prefill request {r.req_id} unexpected state {state}") + # continue + + # new_ok_or_prefill_reqs.append(r) + + # if prefill: + # return uninit_reqs, aborted_reqs, new_ok_or_prefill_reqs, prefill_reqs, decode_reqs, progressing_reqs + # else: + # return uninit_reqs, aborted_reqs, ok_finished_reqs, new_ok_or_prefill_reqs, decode_reqs, progressing_reqs + + # def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False): + # uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( + # req_ids, + # no_decode + # ) + # new_ok_finished = [] + # transfer_reqs = [] + # # filter remote prefill requests + # for r in ok_finished_reqs: + # r: InferReq 
+ # if r.kv_transfering: + # shm_req: PDChunkedPrefillReq = r.shm_req + # state = shm_req.get_pd_req_state() # state is updated by last post_handle, change is reflected here + # if state == 1: + # new_ok_finished.append(r) + # r.kv_transfering = False + # elif state == -1: + # aborted_reqs.append(r) + # r.kv_transfering = False + # elif state == 0: # in progress + # transfer_reqs.append(r) + # else: + # logger.warning(f"remote prefill request {r.req_id} unexpected state {state}") + # continue + # new_ok_finished.append(r) + + # return uninit_reqs, aborted_reqs, new_ok_finished, prefill_reqs, decode_reqs, transfer_reqs \ No newline at end of file diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py index 856b30138..d9002dc78 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py @@ -3,58 +3,37 @@ import threading from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from typing import List, Tuple, Dict -from lightllm.utils.infer_utils import set_random_seed from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.core.objs import FinishStatus +from lightllm.server.core.objs.req import PDChunkedPrefillReq from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs, prepare_remote_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.multimodal_params import MultimodalParams from .pd_remote_prefill_obj import ( RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest) +from .impl_for_pd_base import PDNIXLBackendBase + logger = init_logger(__name__) -class PDBackendForDecodeNode(ModeBackend): +class PDNIXLBackendForDecodeNode(PDNIXLBackendBase): def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, - mem_queue: mp.Queue) -> None: - super().__init__() - self.prefill_task_queue = prefill_task_queue - self.prefill_done_queue = prefill_done_queue - self.mem_queue = mem_queue - self.remote_prefilled_reqs: Dict[str, InferReq] = {} - - def wait_prefill_done_loop(self): - while True: - prefill_done_id = self.prefill_done_queue.get() - if prefill_done_id is None: # None means exit - logger.info("wait prefill done loop exits") - break - if run_req := self.remote_prefilled_reqs.get(prefill_done_id, None): - # remote prefill and transfer done, we need set kv cache to prompt len - - run_req.remote_prefilling = False - self.remote_prefilled_reqs.pop(prefill_done_id) - else: - logger.warning(f"wait prefill done loop: cannot find run_req with id {prefill_done_id}") + nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) def init_custom(self): - - self.mem_queue.put((self.rank_in_dp, self.model.mem_manager.kv_buffer)) - - threading.Thread(target=self.wait_prefill_done_loop, daemon=True).start() - + super().init_custom() + self.wait_prefill_thread = threading.Thread(target=self._prefill_wait_loop, daemon=True) + self.wait_prefill_thread.start() return - def prefill(self, reqs: List[Tuple]): - 
self._init_reqs(reqs, init_req_obj=False) - return def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): prefill_node = req.shm_req.sample_params.move_kv_to_decode_node.to_dict() @@ -66,35 +45,29 @@ def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): mem_indexes = kwargs.get('mem_indexes') b_start_loc = kwargs.get('b_start_loc') - + logger.info(req.shm_req.get_str()) prefill_request = RemotePrefillRequest( - group_request_id=req.shm_req.group_req_id, prompt = req.shm_req.get_prompt_ids(), sampling_params=req.shm_req.sample_params, - multimodal_params=req.multimodal_params, + multimodal_params=MultimodalParams.from_dict(req.multimodal_params), local_cached_len=req.cur_kv_len, token_ids=mem_indexes[b_start_loc[index]: b_start_loc[index+1]], ) return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request) - def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False): - uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( - req_ids, - no_decode, - ) - new_prefill_reqs = [] - # filter remote prefill requests - for r in prefill_reqs: - if r.remote_prefilling: - continue - new_prefill_reqs.append(r) - return uninit_reqs, aborted_reqs, ok_finished_reqs, new_prefill_reqs, decode_reqs + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( g_infer_context.infer_req_ids, no_decode=False, ) + # filter out remote prefilling reqs + prefill_reqs, aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs) self._filter_reqs(aborted_reqs) @@ -102,18 +75,21 @@ def decode(self): if prefill_reqs: # TODO: we could allocate cache later after remote prefill done and get a signal from remote # but it will have a risk to not have enough cache for this request. 
- kwargs, run_reqs = prepare_remote_prefill_inputs(prefill_reqs) + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) for idx, run_req in enumerate(run_reqs): run_req: InferReq = run_req + shm_req: PDChunkedPrefillReq = run_req.shm_req # forward each req to remote prefill # since the token index are the same across TPs, we only need to trigger prefill on master if self.is_master_in_dp: - self.prefill_task_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) + self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) - run_req.remote_prefilling = True - self.remote_prefilled_reqs[run_req.req_id] = run_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req if decode_reqs: + # print(f"decode req: {self.rank_in_dp}: {len(decode_reqs)}") kwargs, run_reqs = prepare_decode_inputs(decode_reqs) logits = self.model.forward(**kwargs) @@ -126,9 +102,8 @@ def decode(self): next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - return + return \ No newline at end of file diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py index 579190132..6710854e7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py @@ -1,71 +1,56 @@ -import time import threading import torch import torch.multiprocessing as mp -from typing import List, Tuple, Dict -from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.utils.infer_utils import set_random_seed +from typing import List, Tuple from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context -from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from .impl_for_pd_base import PDNIXLBackendBase logger = init_logger(__name__) -class ChunckedPrefillForPrefillNode(ModeBackend): +class PDNIXLBackendForPrefillNode(PDNIXLBackendBase): def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, - mem_queue: mp.Queue) -> None: - super().__init__() - self.transfer_task_queue = transfer_task_queue - self.transfer_done_queue = transfer_done_queue - self.mem_queue = mem_queue - self.inflight_transfer_request: Dict[str, InferReq] = {} - - def wait_transfer_done(self): - while True: - transfer_done = self.transfer_done_queue.get() - logger.debug(f"Transfer done: {transfer_done}") - self.inflight_transfer_request[transfer_done].wait() - del 
self.inflight_transfer_request[transfer_done] + nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + def init_custom(self): - threading.Thread(target=self.wait_transfer_done, daemon=True).start() + super().init_custom() + self.handle_prefill_loop_thread = threading.Thread(target=self._handle_prefill_loop, daemon=True) + self.wait_transfer_loop_thread = threading.Thread(target=self._wait_transfer_loop, daemon=True) + self.handle_prefill_loop_thread.start() + self.wait_transfer_loop_thread.start() return + def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs) return - def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False): - uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( - req_ids, - no_decode - ) - new_ok_finished = [] - # filter remote prefill requests - for r in ok_finished_reqs: - if r.kv_transfering: - continue - new_ok_finished.append(r) - return uninit_reqs, aborted_reqs, new_ok_finished, prefill_reqs, decode_reqs - def decode(self): uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( g_infer_context.infer_req_ids, no_decode=True, ) + + ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) + # print(f"{self.rank_in_dp}: {len(uinit_reqs)} uninit, {len(aborted_reqs)} aborted, {len(ok_finished_reqs)} ok finished, " + # f"{len(prefill_reqs)} new prefill, {len(decode_reqs)} decode, {len(transfer_reqs)} transfer_reqs") + assert len(uinit_reqs) == 0 assert len(decode_reqs) == 0 self._filter_reqs(aborted_reqs) if ok_finished_reqs: - self.prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(ok_finished_reqs) + for req in ok_finished_reqs: + self._transfer_kv_to_remote(req) self._filter_reqs(ok_finished_reqs) if prefill_reqs: @@ -79,63 +64,10 @@ def decode(self): next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False + run_reqs, next_token_ids, next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req) ) return - def prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, run_reqs: List[InferReq]): - # 提前在radix cache中回收相关的信息,并添加引用信息 - if self.is_master_in_dp: - logger.info("prefill_req_handle_and_frozen_tokens") - try: - for req in run_reqs: - req: InferReq = req - key = req.get_input_token_ids()[0 : req.cur_kv_len] - key = torch.tensor(key, dtype=torch.int64, device="cpu") - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len = self.radix_cache.insert(key, value) - old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len - self.model.mem_manager.free( - self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] - ) - if req.shared_kv_node is not None: - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - req.shared_kv_node = None - - req.cur_kv_len = 0 - req.shm_req.shm_cur_kv_len = 0 - - if req.shm_req.sample_params.move_kv_to_decode_node.exists: - # 注意兼容纯tp 和 tp dp 混合模式的逻辑 - if self.is_master_in_dp: - g_router_lock.acquire() - self.shared_token_load.add_frozened_token_count(len(key), self.dp_rank_in_node) - g_router_lock.release() - - share_node, kv_len, value = self.radix_cache.match_prefix(key, 
update_refs=True) - assert len(key) == len(value) - # 将下面的请求放入到任务队列中, 注意要使用raidx cache 返回的value - decode_node_info = DecodeNodeInfo(**req.shm_req.sample_params.move_kv_to_decode_node.to_dict()) - task = KVMoveTask( - group_request_id=req.shm_req.group_req_id, - input_tokens=key.tolist(), - prefill_token_indexes=value.tolist(), - decode_token_indexes=None, - prefill_node_id=self.args.pd_node_id, - decode_node=decode_node_info, - move_kv_len=None, - prefill_dp_index=self.dp_rank_in_node, - decode_dp_index=None, - mark_start_time=time.time(), - ) - g_kv_move_task_cache[task.group_request_id] = (task, share_node) - - # 注意兼容纯 tp 和 tp dp 混合模式的逻辑 - if self.is_master_in_dp: - self.info_queue.put(task) - except BaseException as e: - logger.exception(str(e)) - - if self.is_master_in_dp: - logger.info("prefill_req_handle_and_frozen_tokens end") - return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py new file mode 100644 index 000000000..56bfd0e38 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py @@ -0,0 +1,154 @@ + +from collections import defaultdict +from typing import Dict, List, Any +from torch import Tensor +from dataclasses import dataclass + +from lightllm.utils.log_utils import init_logger + +from .pd_remote_prefill_obj import RemoteAgent, KVMoveRequest, PrefillRequest + + +logger = init_logger(__name__) + +try: + from nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixlBind + logger.info("Nixl is available") +except ImportError: + logger.warning("nixl is not installed, which is required for pd disaggregation!!!") + NixlWrapper = None + +@dataclass +class NixlMetadata: + id: str + num_tokens: list[int] + agent_metadatas: list[bytes] + agent_mem_descs: list[bytes] + +class NixlKVTransporter: + def __init__(self, node_id: int, tp_idx: int): + self.node_id = node_id + self.tp_idx = tp_idx + self.nixl_agent = NixlWrapper(self.agent_name, None) + + self.num_layers = -1 + self.num_tokens = -1 + self.num_heads = -1 + self.head_dim = -1 + self.token_len = -1 + self.layer_len = -1 + + self.reg_desc = None + self.local_xfer_handles = None + + self.remote_agents = defaultdict(list) + self.inflight_transfers: Dict[str, Any] = {} + + + @property + def agent_name(self) -> str: + return f"{self.node_id}_{self.tp_idx}" + + @property + def agent_metadata(self): + return self.nixl_agent.get_agent_metadata() + + @property + def local_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.reg_desc) + + def get_new_notifs(self): + return self.nixl_agent.get_new_notifs() + + + def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + tokens_data = [] + for layer_id in range(self.num_layers): + for token_id in range(num_tokens): + tokens_data.append((base_addr + layer_id * self.layer_len + token_id * self.token_len, self.token_len, device_id)) + descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) + + def register_kv_buffer(self, kv_buffer: Tensor): + self.num_layers, self.num_tokens, self.num_heads, self.head_dim = kv_buffer.shape + self.token_len = self.num_heads * self.head_dim * kv_buffer.element_size() + self.layer_len = self.num_tokens * self.token_len + + self.reg_desc =
self.nixl_agent.register_memory(kv_buffer) + self.local_xfer_handles = self._create_xfer_handles(self.reg_desc, self.num_tokens) + + + def add_remote_agent(self, remote_agent: NixlMetadata): + for agent_metadata, num_tokens, agent_mem_desc in zip(remote_agent.agent_metadatas, + remote_agent.num_tokens, + remote_agent.agent_mem_descs): + peer_name = self.nixl_agent.add_remote_agent(agent_metadata) + mem_desc = self.nixl_agent.deserialize_descs(agent_mem_desc) + logger.info("Added remote agent %s with mem desc %s", peer_name, mem_desc) + kv_xfer_handles = self._create_xfer_handles(mem_desc, num_tokens, agent_name=peer_name) + self.remote_agents[remote_agent.id].append( + RemoteAgent(name=peer_name, kv_mem_desc=mem_desc, num_tokens=num_tokens, kv_xfer_handles=kv_xfer_handles)) + + def _get_token_desc_ids(self, token_ids: List[int]): + descs_ids = [] + for layer_id in range(self.num_layers): + for token_id in token_ids: + descs_ids.append(layer_id * self.num_tokens + token_id) + return descs_ids + + def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): + group_reqeust_id = request.group_req_id + skip_kv_move_len = prefill_request.data.local_cached_len + src_token_ids = request.token_ids[skip_kv_move_len:] + dst_token_ids = prefill_request.data.token_ids + remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][self.tp_idx] #TODO one-one mapping now + + if len(src_token_ids) > 0: + assert len(src_token_ids) == len(dst_token_ids), f"{len(src_token_ids)} {len(dst_token_ids)}" + src_token_descs = self._get_token_desc_ids(src_token_ids) + dst_token_descs = self._get_token_desc_ids(dst_token_ids) + + src_handle = self.local_xfer_handles + dst_handle = remote_agent.kv_xfer_handles + handle = self.nixl_agent.make_prepped_xfer("WRITE", + src_handle, src_token_descs, + dst_handle, dst_token_descs, + str(group_reqeust_id).encode()) + + status = self.nixl_agent.transfer(handle) + assert status != 'ERR' + + self.inflight_transfers[group_reqeust_id] = handle + + return handle + + return None + + def get_done_tranfers(self): + done_req_ids = [] + for req_id, handle in self.inflight_transfers.items(): + xfer_state = self.nixl_agent.check_xfer_state(handle) + if xfer_state == "DONE": + done_req_ids.append((req_id, 1)) + elif xfer_state == "PROC": + continue + else: + logger.warning(f"{req_id} Transfer failed with state {xfer_state}") + done_req_ids.append((req_id, -1)) + + for req_id, _ in done_req_ids: + self.nixl_agent.release_xfer_handle(self.inflight_transfers[req_id]) + del self.inflight_transfers[req_id] + + return done_req_ids + + def shutdonw(self): + self.nixl_agent.deregister_memory(self.reg_desc) + self.nixl_agent.release_dlist_handle(self.local_xfer_handles) + for id, agents in self.remote_agents.items(): + for agent in agents: + self.nixl_agent.remove_remote_agent(agent.name) + self.nixl_agent.release_xfer_handle(agent.kv_xfer_handles) + diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py index 9b6eb3342..309213773 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py @@ -1,113 +1,67 @@ -import asyncio -from typing import List, Union -from enum import Enum +from typing import List import zmq -from collections import defaultdict -from torch import Tensor import torch.multiprocessing 
as mp from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.net_utils import get_hostname_ip from .pd_remote_prefill_obj import ( ConnectRequest, RemoteRequest, - RemoteAgent, RemoteRequstType, PrefillRequest, RemotePrefillRequest, RemotePrefillServerInfo, - KVMoveRequest, RemotePrefillTask, ) +from .nixl_kv_transporter import NixlMetadata logger = init_logger(__name__) -try: - from nixl._api import nixl_agent - -except ImportError: - logger.error("nixl is not installed, which is required for pd disagreggation!!!") - raise - - class PDRemotePrefillBase: def __init__(self, - device_index: int, - kv_cache_queue: mp.Queue, # need send kv cache to this process and register with nixl - tp_size: int,): - # create local nixl agent - self.nixl_agent_name = f'nixl_agent_{get_unique_server_name}_{device_index}' - self.nixl_agent = nixl_agent(self.nixl_agent_name, None) - - # metadata need to send to remote server to make connection - self.nixl_agent_metadata = self.nixl_agent.get_agent_metadata() - - self.reg_descs = [None] * tp_size - self.local_xfer_handles = [None] * tp_size - - self.device_index = device_index - self.kv_cache_queue = kv_cache_queue - self.tp_size = tp_size - - self.num_layers = -1 - self.num_tokens = -1 - self.num_heads = -1 - self.head_dims= -1 - self.token_len = -1 - self.layer_len = -1 - - - def _create_xfer_handles(self, idx, reg_descs): - base_addr, _, device_id = reg_descs[0] - tokens_data = [] - for layer_id in range(self.num_layers): - for token_id in range(self.num_tokens): - tokens_data.append(base_addr + layer_id * self.layer_len + token_id * self.token_len, self.token_len, device_id) - - descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) - self.local_xfer_handles[idx] = self.nixl_agent.prep_xfer_dlist("", descs, is_sorted=True) - - def _get_token_desc_ids(self, token_ids: List[int]): - descs_ids = [] - for layer_id in range(self.num_layers): - for token_id in token_ids: - descs_ids.append(layer_id * self.num_tokens + token_id) - return descs_ids + id: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl + ): + self.id = id + self.tp_size = len(agent_meta_queues) + self.agent_meta_queues = agent_meta_queues + self.from_backend_queue = from_backend_queue + self.to_backend_queues = to_backend_queues + self.local_agent_meta = None def local_init(self): - for _ in range(self.tp_size): - idx, tensor = self.kv_cache_queue.get(timeout=60) - if self.num_layers == -1: - self.num_layers, self.num_tokens, self.num_heads, self.head_dim = tensor.shape - self.token_len = self.num_heads * self.head_dim * tensor.element_size() - self.layer_len = self.num_tokens * self.token_len - - self.reg_descs[idx] = self.nixl_agent.register_memory(tensor) - self._create_xfer_handles(idx, self.reg_descs[idx]) - + agent_metas = NixlMetadata(id=self.id, agent_metadatas=[], num_tokens=[], agent_mem_descs=[]) + for tp in range(self.tp_size): + agent_metadata, num_tokens, mem_desc = self.agent_meta_queues[tp].get(timeout=60) + logger.info(f"Received agent_metadata from {tp} with mem reg: {mem_desc}") + agent_metas.num_tokens.append(num_tokens) + agent_metas.agent_metadatas.append(agent_metadata) + agent_metas.agent_mem_descs.append(mem_desc) + + self.local_agent_meta = agent_metas logger.info("All local kv cache registered.") class PDRemotePrefillServer(PDRemotePrefillBase): def __init__(self, + 
id: int, http_server_port: int, server_port: int, - device_index: int, - kvmove_request_queue: mp.Queue, - kvmove_done_queue: mp.Queue, - kv_cache_queue: mp.Queue, - tp_size: int): - super().__init__(device_index, kv_cache_queue, tp_size) + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue]): + super().__init__(id, from_backend_queue, to_backend_queues, agent_meta_queues) # map from client id to decode server info self.remote_decode_clients = {} # build control path _ctx = zmq.Context() - self.recv_from_decode = _ctx.socket(zmq.PAIR) + self.recv_from_decode = _ctx.socket(zmq.PULL) self.host_ip = get_hostname_ip() self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}") @@ -117,137 +71,41 @@ def __init__(self, self.prefill_requests = {} - self.kvmove_request_queue = kvmove_request_queue - self.kvmove_done_queue = kvmove_done_queue - - self.inflight_transfer = defaultdict(list) - - def add_remote_agent(self, request: ConnectRequest): - peer_name = self.nixl_agent.add_remote_agent(request.agent_metadata) - mem_desc = self.nixl_agent.deserialize_descs(request.agent_mem_desc) - kv_xfer_handles = [] - for idx, desc in enumerate(mem_desc): - kv_xfer_handles.append(self._create_xfer_handles(idx, desc)) - - self.remote_decode_clients[request.decode_id] = RemoteAgent( - name=peer_name, - kv_mem_desc=mem_desc, - kv_xfer_handles=kv_xfer_handles) def main_loop(self): self.local_init() - self.transfer_task = asyncio.create_task(self.transfer_loop()) - self.wait_transfer_task = asyncio.create_task(self.wait_transfer_task()) while True: request: RemoteRequest = self.recv_from_decode.recv_pyobj() - if request.type == RemoteRequstType.REMOTE_CONNECT: - request: ConnectRequest = request - self.add_remote_agent(request) - elif request.type == RemoteRequstType.REMOTE_PREFILL: + logger.info(f"recevied request from decode, type: {request.type}") + + if request.type == RemoteRequstType.REMOTE_PREFILL: request: PrefillRequest = request self.trigger_prefill(request) + # forward request to prefill server + for queue in self.to_backend_queues: + queue.put(request) + def trigger_prefill(self, request: PrefillRequest): self.send_to_httpserver.send_pyobj((request.data.prompt, request.data.sampling_params, request.data.multimodal_params)) self.prefill_requests[request.data.sampling_params.group_request_id] = request - async def transfer_loop(self): - while True: - request: KVMoveRequest = self.kv_cache_queue.get() - await self.trigger_kvcache_write(request) - - - async def trigger_kvcache_write(self, request: KVMoveRequest): - group_reqeust_id = request.group_req_id - prefill_request: PrefillRequest = self.prefill_requests[group_reqeust_id] - skip_kv_move_len = prefill_request.data.local_cached_len - src_token_ids = request.token_ids[skip_kv_move_len:] - dst_token_ids = prefill_request.data.token_ids[skip_kv_move_len:] - remote_agent = self.remote_decode_clients[prefill_request.decode_id] - if len(src_token_ids) > 0: - assert len(src_token_ids) == len(dst_token_ids) - src_token_descs = self._get_token_desc_ids(src_token_ids) - dst_token_descs = self._get_token_desc_ids(dst_token_ids) - - for i in range(self.tp_size): #TODO make this a single transfer - src_handle = self.local_xfer_handles[i] - dst_handle = remote_agent.remote_xfer_handles[i] - handle = self.nixl_agent.make_prepped_xfer("WRITE", - src_handle, src_token_descs, - dst_handle, dst_token_descs, group_reqeust_id) - self.inflight_transfer[group_reqeust_id].append(handle) - status = 
self.nixl_agent.transfer(handle) - - - await self.kv_cache_queue.put({"src": src_token_descs, "dst": dst_token_descs}) - - - def get_done_tranfers(self) -> List[str]: - done_req_ids = [] - failed_req_ids = [] - for req_id, handles in self.inflight_transfer.items(): - running_reqs = [] - failed_reqs = [] - for handle in handles: - xfer_state = self.nixl_agent.check_xfer_state(handle) - if xfer_state == "DONE": - self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors? - continue - if xfer_state == "PROC": - running_reqs.append(handle) - else: - logger.warning(f"Transfer failed with state {xfer_state}") - failed_reqs.append(handle) - break - - if failed_reqs: - failed_req_ids.append(req_id) - continue - - if len(running_reqs) == 0: - done_req_ids.append(req_id) - else: - self.inflight_transfer[req_id] = running_reqs - - return done_req_ids, failed_req_ids - - - async def wait_transfer_loop(self): - while True: - done_ids, failed_ids = self.get_done_transfers() - # handle successfully completed transfers - pass - - # handle failed transfers - pass - - # remote ids from inflight transfers and cancle inflight transfers if failed - - - - - class PDRemotePrefillClient(PDRemotePrefillBase): def __init__(self, - prefill_request_queue: mp.Queue, # only tp0 will trigger prefill - prefill_done_queue: List[mp.Queue], # one to many done queue - device_index: int, - kv_cache_queue: mp.Queue, # need send kv cache to this process and register with nixl - tp_size: int, - my_id: int, + id: int, + from_backend_queue: mp.Queue, # only tp0 will trigger prefill + to_backend_queues: List[mp.Queue], # one to many done queue + agent_meta_queues: List[mp.Queue] ): - super().__init__(device_index, kv_cache_queue, tp_size) + super().__init__(id, from_backend_queue, to_backend_queues, agent_meta_queues) # map from server id to prefill server info + self.remote_prefill_servers = {} - self.prefill_request_queue = prefill_request_queue - self.prefill_done_queue = prefill_done_queue - self.remote_prefill_requests = {} - self.my_id = my_id def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): # build control path if not exist @@ -258,38 +116,91 @@ def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): _socket.connect(connect_str) _socket.send_pyobj(ConnectRequest( type=RemoteRequstType.REMOTE_CONNECT, - decode_id=self.my_id, - agent_metadata=self.nixl_agent_metadata, - agent_mem_desc=self.nixl_agent.get_serialized_descs(self.reg_descs))) + decode_id=self.id, + num_tokens=self.local_agent_meta.num_tokens, + agent_metadatas=self.local_agent_meta.agent_metadatas, + agent_mem_descs=self.local_agent_meta.agent_mem_descs)) + self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info) def main_loop(self): self.local_init() - asyncio.create_task(self.prefill_wait_loop()) while True: - prefill_tasks: List[RemotePrefillTask] = self.prefill_request_queue.get() - for task in prefill_tasks: - # connect first - self.connect_to_prefill_server(task.server_info) - # do prefill - self.remote_prefill(task.server_info.perfill_server_id, task.prefill_request) + prefill_tasks: RemotePrefillTask = self.from_backend_queue.get() + + # connect first + self.connect_to_prefill_server(prefill_tasks.server_info) + # do prefill + self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request) # place request to server do remote prefill def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): 
socket, _ = self.remote_prefill_servers[server_id] - group_req_id = str(prefill_request.sampling_params.group_request_id) - socket.send_pyobj(RemoteRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.my_id, data=prefill_request)) - self.remote_prefill_requests[group_req_id] = prefill_request - - - async def prefill_wait_loop(self): - while True: - notifies = self.nixl_agent.get_new_notifs() - for agent_name, msgs in notifies.items(): - for msg in msgs: - # we got a finished prefill msg - for pdq in self.prefill_done_queue: - pdq.put(msg) - del self.remote_prefill_requests[msg] - logger.info(f"prefill reqeust: {msg} done") + prefill_request.sampling_params.max_new_tokens = 1 + socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request)) + + + +def remote_prefill_server_loop( + id: int, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + server = PDRemotePrefillServer(id, http_server_port, server_port, + from_backend_queue, to_backend_queues, agent_meta_queues) + server.main_loop() + + +def start_pd_remote_prefill_server_process( + id: int, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + proc = mp.Process( + target=remote_prefill_server_loop, + args=( + id, http_server_port, server_port, + from_backend_queue, to_backend_queues, agent_meta_queues) + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill server with id: {id} started!") + return proc + + +def remote_prefill_client_loop( + id: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue]): + + client = PDRemotePrefillClient( + id, + from_backend_queue, + to_backend_queues, + agent_meta_queues, + ) + client.main_loop() + +def start_pd_remote_prefill_client_process( + id: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue] +): + + proc = mp.Process( + target=remote_prefill_client_loop, + args=(id, from_backend_queue, to_backend_queues, agent_meta_queues) + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill client with id: {id} started!") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py index 77b327b03..152a336cd 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py @@ -10,13 +10,12 @@ logger = init_logger(__name__) try: - from nixl._api import nixl_agent, nixlBind, nixl_prepped_dlist_handle + from nixl._api import nixlBind, nixl_prepped_dlist_handle except ImportError: logger.error("nixl is not installed, which is required for pd disagreggation!!!") raise - class RemoteRequstType(Enum): REMOTE_CONNECT = 1 REMOTE_PREFILL = 2 @@ -44,8 +43,9 @@ class RemoteRequest: @dataclass class ConnectRequest(RemoteRequest): decode_id: int - agent_metadata: str - agent_mem_desc: str + num_tokens: List[int] + agent_metadatas: List[bytes] + agent_mem_descs: List[bytes] @dataclass @@ -63,5 +63,6 @@ class KVMoveRequest: @dataclass class RemoteAgent: name: str - kv_mem_desc: List[nixlBind.nixlRegDList] - kv_xfer_handles: 
List[nixl_prepped_dlist_handle] \ No newline at end of file + num_tokens: int + kv_mem_desc: nixlBind.nixlRegDList + kv_xfer_handles: nixl_prepped_dlist_handle \ No newline at end of file diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index ed1470fba..7f427c84f 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -21,6 +21,8 @@ DPForDecodeNode, ChunckedPrefillForPrefillNode, DPChunkedForPrefillNode, + PDNIXLBackendForPrefillNode, + PDNIXLBackendForDecodeNode ) from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray from lightllm.utils.log_utils import init_logger @@ -40,12 +42,14 @@ def __init__( rpc_event: multiprocessing.Event, rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, ): super().__init__() self.args = args self.node_world_size = node_world_size self.info_queue = info_queue + self.result_queue = result_queue self.mem_queue = mem_queue self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -130,12 +134,18 @@ def init_model(self, kvargs): if kvargs.get("args", None).dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) else: - self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + # self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, + self.result_queue, + self.mem_queue) elif is_decode_node: if kvargs.get("args", None).dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) else: - self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + # self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, + self.result_queue, + self.mem_queue) elif kvargs.get("dp_size", 1) > 1: self.backend = DPChunkedPrefillBackend() elif use_reward_model: @@ -274,6 +284,7 @@ def _init_env( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event: mp.Event, @@ -292,7 +303,7 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, result_queue, mem_queue ) success_event.set() @@ -308,11 +319,11 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, ): import lightllm.utils.rpyc_fix_utils as _ - # 单卡单机时不使用 rpc if node_world_size == 1 and args.nnodes == 1: return ModelRpcServer( @@ -323,6 +334,7 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue, + result_queue, mem_queue, ) @@ -335,6 +347,7 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event, From 7eb270ff6e989e91d21f6c144c5354d74d69b86a Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 22 Apr 2025 16:22:30 +0800 Subject: [PATCH 03/19] prefill abort -> notify decode. 
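
Prefill completion is no longer signalled with a bare request id: every NIXL notify now carries a JSON-encoded RemotePrefillStatus, so the decode side can distinguish success (1) from abort/failure (-1) and write the result straight into the request's shm state. A minimal round trip of the wire format, mirroring the dataclass this patch adds to pd_remote_prefill_obj.py:

    import json
    from dataclasses import dataclass, asdict

    @dataclass
    class RemotePrefillStatus:
        group_req_id: int
        status: int  # 1 = success, -1 = aborted/failed

        def serialize(self) -> bytes:
            return json.dumps(asdict(self)).encode()

        @classmethod
        def deserialize(cls, data: bytes) -> "RemotePrefillStatus":
            return cls(**json.loads(data.decode()))

    # what send_notif puts on the wire, and what _prefill_wait_loop reads back
    payload = RemotePrefillStatus(group_req_id=8, status=-1).serialize()
    assert RemotePrefillStatus.deserialize(payload) == RemotePrefillStatus(8, -1)
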
--- .../pd_disaggregation/impl_for_pd_base.py | 48 ++++++++++--- .../pd_disaggregation/impl_for_pd_prefill.py | 1 + .../pd_disaggregation/nixl_kv_transporter.py | 34 +++++++-- .../pd_remote_prefill_obj.py | 71 ++++++++++++++++++- 4 files changed, 134 insertions(+), 20 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py index 49099d275..c760551ca 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py @@ -15,7 +15,9 @@ RemoteRequest, RemoteRequstType, ConnectRequest, - KVMoveRequest) + KVMoveRequest, + RemotePrefillStatus, + ThreadSafeDict) logger = init_logger(__name__) @@ -32,11 +34,11 @@ def __init__(self, self.nixl_meta_queue = nixl_meta_queue # for decode - self.remote_prefilled_reqs: Dict[int, InferReq] = {} + self.remote_prefilled_reqs: ThreadSafeDict = ThreadSafeDict() # for prefill - self.remote_prefill_requests: Dict[str, PrefillRequest] = {} - self.inflght_transfer_requests: Dict[str, InferReq] = {} + self.remote_prefill_requests: ThreadSafeDict = ThreadSafeDict() + self.inflght_transfer_requests: ThreadSafeDict = ThreadSafeDict() def init_custom(self): @@ -50,14 +52,16 @@ def init_custom(self): def _prefill_wait_loop(self): while True: notifies = self.nixl_agent.get_new_notifs() - for agent_name, req_idxs in notifies.items(): - for req_id in req_idxs: - group_req_id = int(req_id.decode()) + for agent_name, req_statuses in notifies.items(): + for req_status in req_statuses: + prefill_status = RemotePrefillStatus.deserialize(req_status) + group_req_id = prefill_status.group_req_id + status = prefill_status.status if run_req := self.remote_prefilled_reqs.get(group_req_id, None): shm_req: PDChunkedPrefillReq = run_req.shm_req - shm_req.set_pd_req_rank_state(self.rank_in_dp, 1) + shm_req.set_pd_req_rank_state(self.rank_in_dp, status) self.remote_prefilled_reqs.pop(group_req_id) - logger.info(f"remote prefill reqeust: {group_req_id} done") + logger.info(f"remote prefill reqeust: {group_req_id} done with status: {status}") else: logger.warning(f"remote prefill reqeust: {group_req_id} not found") time.sleep(0) @@ -123,8 +127,21 @@ def _transfer_kv_to_remote(self, req: InferReq): def _decode_filter_reqs(self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq]): new_prefill_reqs: List[InferReq] = [] + new_aborted_reqs: List[InferReq] = [] remote_prefill_reqs: List[InferReq] = [] + # filter out aborted requests + for req in aborted_reqs: + if req.in_prefill_or_transfer: + shm_req: PDChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state(self.dp_world_size) + if state != 0: + new_aborted_reqs.append(req) + req.in_prefill_or_transfer = False + else: + #TODO trigger remote abort + remote_prefill_reqs.append(req) + for req in prefill_reqs: if req.in_prefill_or_transfer: shm_req: PDChunkedPrefillReq = req.shm_req @@ -145,7 +162,7 @@ def _decode_filter_reqs(self, prefill_reqs: List[InferReq], new_prefill_reqs.append(req) - return new_prefill_reqs, aborted_reqs, decode_reqs, remote_prefill_reqs + return new_prefill_reqs, new_aborted_reqs, decode_reqs, remote_prefill_reqs def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: List[InferReq]): new_ok_finished_reqs = [] @@ -171,7 +188,6 @@ def _prefill_filter_reqs(self, ok_finished_reqs: 
List[InferReq], aborted_reqs: L return new_ok_finished_reqs, aborted_reqs, kv_transfer_reqs - def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): run_reqs = [] start_loc = 0 @@ -211,6 +227,16 @@ def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): } return kwargs, run_reqs + def _prefill_abort_remote(self, req_objs: List[InferReq]): + for req_obj in req_objs: + group_req_id = req_obj.shm_req.group_req_id + if group_req_id in self.remote_prefill_requests: + self.nixl_agent.send_abort_notify( + self.remote_prefill_requests[group_req_id].decode_id, + group_req_id) + del self.remote_prefill_requests[group_req_id] + + # def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False, prefill=True): # uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( # req_ids, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py index 6710854e7..fc3437131 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py @@ -46,6 +46,7 @@ def decode(self): assert len(uinit_reqs) == 0 assert len(decode_reqs) == 0 + self._prefill_abort_remote(aborted_reqs) self._filter_reqs(aborted_reqs) if ok_finished_reqs: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py index 56bfd0e38..6b737a37f 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py @@ -3,10 +3,11 @@ from typing import Dict, List, Any from torch import Tensor from dataclasses import dataclass +import threading from lightllm.utils.log_utils import init_logger -from .pd_remote_prefill_obj import RemoteAgent, KVMoveRequest, PrefillRequest +from .pd_remote_prefill_obj import RemoteAgent, KVMoveRequest, PrefillRequest, RemotePrefillStatus, ThreadSafeDict logger = init_logger(__name__) @@ -43,8 +44,8 @@ def __init__(self, node_id: int, tp_idx: int): self.local_xfer_handles = None self.remote_agents = defaultdict(list) - self.inflight_transfers: Dict[str, Any] = {} + self.inflight_transfers: ThreadSafeDict = ThreadSafeDict() @property def agent_name(self) -> str: @@ -112,23 +113,41 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): src_handle = self.local_xfer_handles dst_handle = remote_agent.kv_xfer_handles + notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=1) handle = self.nixl_agent.make_prepped_xfer("WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, - str(group_reqeust_id).encode()) + notify_status.serialize()) status = self.nixl_agent.transfer(handle) assert status != 'ERR' - self.inflight_transfers[group_reqeust_id] = handle + self.inflight_transfers[group_reqeust_id] = (handle, remote_agent, False) return handle return None + + def send_abort_notify(self, remote_id: int, group_reqeust_id): + remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] + notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1) + self.nixl_agent.send_notif(remote_agent.name, notify_status.serialize()) + + if group_reqeust_id in 
self.inflight_transfers: + self.inflight_transfers[group_reqeust_id] = self.inflight_transfers[group_reqeust_id][:2] + (True,) + + def get_done_tranfers(self): done_req_ids = [] - for req_id, handle in self.inflight_transfers.items(): + + for req_id, (handle, remote_agent, is_abort) in self.inflight_transfers.items(): + if is_abort: + logger.warning(f"{req_id} Transfer aborted") + done_req_ids.append((req_id, -1)) + continue + + remote_agent: RemoteAgent xfer_state = self.nixl_agent.check_xfer_state(handle) if xfer_state == "DONE": done_req_ids.append((req_id, 1)) @@ -137,9 +156,12 @@ def get_done_tranfers(self): else: logger.warning(f"{req_id} Transfer failed with state {xfer_state}") done_req_ids.append((req_id, -1)) + notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1) + self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) for req_id, _ in done_req_ids: - self.nixl_agent.release_xfer_handle(self.inflight_transfers[req_id]) + # release will abort inflight transfer + self.nixl_agent.release_xfer_handle(self.inflight_transfers[req_id][0]) del self.inflight_transfers[req_id] return done_req_ids diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py index 152a336cd..871780d90 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass +from dataclasses import dataclass, asdict from enum import Enum -from typing import List, Union +import json +from typing import List, Union, Optional, Any +from threading import Lock from lightllm.utils.log_utils import init_logger from lightllm.server.core.objs import SamplingParams @@ -65,4 +67,67 @@ class RemoteAgent: name: str num_tokens: int kv_mem_desc: nixlBind.nixlRegDList - kv_xfer_handles: nixl_prepped_dlist_handle \ No newline at end of file + kv_xfer_handles: nixl_prepped_dlist_handle + + +@dataclass +class RemotePrefillStatus: + group_req_id: int + status: int + + def serialize(self): + return json.dumps(asdict(self)).encode() + + @classmethod + def deserialize(cls, data: bytes): + return cls(**json.loads(data.decode())) + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + self._lock = Lock() + + def __getitem__(self, key): + with self._lock: + return self._dict[key] + + def __setitem__(self, key, value): + with self._lock: + self._dict[key] = value + + def __delitem__(self, key): + with self._lock: + del self._dict[key] + + def __contains__(self, key): + with self._lock: + return key in self._dict + + def __len__(self) -> int: + with self._lock: + return len(self._dict) + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def items(self): + with self._lock: + return list(self._dict.items()) + + def keys(self): + with self._lock: + return list(self._dict.keys()) + + def pop(self, key: Any, default: Optional[Any] = None) -> Any: + with self._lock: + return self._dict.pop(key, default) + + def values(self): + with self._lock: + return list(self._dict.values()) + + def clear(self) -> None: + with self._lock: + self._dict.clear() \ No newline at end of file From bbdd86aa259965f6436d617c96ef273b132b028c Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 22 Apr 2025 18:10:54 +0800 Subject: [PATCH 04/19] fix long connect time.
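
Connection setup used to be fire-and-forget: the decode client PUSHed its ConnectRequest and immediately queued prefill work, so the first request stalled while every backend rank was still registering the remote agent. Both control sockets are now PAIR, and the prefill server fans the ConnectRequest out to all tp backends, waits for one "OK" ack per rank, and only then replies with a success flag; on a failed connect the client reports status -1 back to its backends instead of sending the prefill. A minimal sketch of the handshake, assuming a hypothetical inproc endpoint in place of the real tcp://<prefill_ip>:<pd_remote_prefill_port> socket and plain strings in place of ConnectRequest:

    import threading
    import zmq

    ENDPOINT = "inproc://connect_handshake"  # hypothetical endpoint for this sketch

    ctx = zmq.Context()
    server = ctx.socket(zmq.PAIR)  # PAIR: one fixed peer, bidirectional
    server.bind(ENDPOINT)

    def prefill_side(sock: zmq.Socket, tp_size: int) -> None:
        sock.recv_pyobj()                # ConnectRequest with per-tp agent metadata
        acks = ["OK"] * tp_size          # stand-in for the per-rank backend queue acks
        sock.send_pyobj(all(a == "OK" for a in acks))

    t = threading.Thread(target=prefill_side, args=(server, 2))
    t.start()

    client = ctx.socket(zmq.PAIR)
    client.connect(ENDPOINT)
    client.send_pyobj("REMOTE_CONNECT")
    assert client.recv_pyobj() is True   # decode blocks until the prefill side acks
    t.join()
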
--- .../pd_disaggregation/impl_for_pd_base.py | 36 +++++++++----- .../pd_disaggregation/nixl_kv_transporter.py | 10 ++-- .../pd_disaggregation/pd_remote_prefill.py | 49 +++++++++++++++---- 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py index c760551ca..77c7afc64 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py @@ -2,6 +2,7 @@ import time import torch.multiprocessing as mp from typing import Dict, List +import queue import numpy as np @@ -51,20 +52,32 @@ def init_custom(self): def _prefill_wait_loop(self): while True: + def handle_remote_prefill(req_status: RemotePrefillStatus): + group_req_id = req_status.group_req_id + status = req_status.status + if run_req := self.remote_prefilled_reqs.get(group_req_id, None): + shm_req: PDChunkedPrefillReq = run_req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, status) + self.remote_prefilled_reqs.pop(group_req_id) + logger.info(f"remote prefill reqeust: {group_req_id} done with status: {status}") + else: + logger.warning(f"remote prefill reqeust: {group_req_id} not found") + + # from local + try: + req_status = self.from_remote_queue.get_nowait() + handle_remote_prefill(req_status) + except queue.Empty: + pass + + # from remote notifies = self.nixl_agent.get_new_notifs() for agent_name, req_statuses in notifies.items(): for req_status in req_statuses: prefill_status = RemotePrefillStatus.deserialize(req_status) - group_req_id = prefill_status.group_req_id - status = prefill_status.status - if run_req := self.remote_prefilled_reqs.get(group_req_id, None): - shm_req: PDChunkedPrefillReq = run_req.shm_req - shm_req.set_pd_req_rank_state(self.rank_in_dp, status) - self.remote_prefilled_reqs.pop(group_req_id) - logger.info(f"remote prefill reqeust: {group_req_id} done with status: {status}") - else: - logger.warning(f"remote prefill reqeust: {group_req_id} not found") - time.sleep(0) + handle_remote_prefill(prefill_status) + + time.sleep(0.001) def _wait_transfer_loop(self): @@ -81,7 +94,7 @@ def _wait_transfer_loop(self): shm_req: PDChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, state) del self.inflght_transfer_requests[req_id] - time.sleep(0) + time.sleep(0.001) def _handle_prefill_loop(self): @@ -96,6 +109,7 @@ def _handle_prefill_loop(self): agent_metadatas = request.agent_metadatas, agent_mem_descs = request.agent_mem_descs )) + self.to_remote_queue.put("OK") if request.type == RemoteRequstType.REMOTE_PREFILL: request: PrefillRequest diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py index 6b737a37f..4eab304a0 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py @@ -82,9 +82,13 @@ def register_kv_buffer(self, kv_buffer: Tensor): def add_remote_agent(self, remote_agent: NixlMetadata): - for agent_metadata, num_tokens, agent_mem_desc in zip(remote_agent.agent_metadatas, - remote_agent.num_tokens, - remote_agent.agent_mem_descs): + for idx, (agent_metadata, num_tokens, agent_mem_desc) in 
enumerate(zip(remote_agent.agent_metadatas, + remote_agent.num_tokens, + remote_agent.agent_mem_descs)): + if self.tp_idx != idx: + self.remote_agents[remote_agent.id].append(None) + continue + peer_name = self.nixl_agent.add_remote_agent(agent_metadata) mem_desc = self.nixl_agent.deserialize_descs(agent_mem_desc) logger.info("Added remote agent %s with mem desc %s", peer_name, mem_desc) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py index 309213773..267a7ff69 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py @@ -14,6 +14,7 @@ RemotePrefillRequest, RemotePrefillServerInfo, RemotePrefillTask, + RemotePrefillStatus ) from .nixl_kv_transporter import NixlMetadata @@ -61,7 +62,7 @@ def __init__(self, # build control path _ctx = zmq.Context() - self.recv_from_decode = _ctx.socket(zmq.PULL) + self.recv_from_decode = _ctx.socket(zmq.PAIR) self.host_ip = get_hostname_ip() self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}") @@ -79,13 +80,27 @@ def main_loop(self): request: RemoteRequest = self.recv_from_decode.recv_pyobj() logger.info(f"recevied request from decode, type: {request.type}") + # forward request to prefill server + for queue in self.to_backend_queues: + queue.put(request) + + if request.type == RemoteRequstType.REMOTE_CONNECT: + success = True + for idx in range(self.tp_size): + ack = self.from_backend_queue.get() + if ack != "OK": + success = False + break + + self.recv_from_decode.send_pyobj(success) + if not success: + logger.warning(f"Remote connect failed: {request}") + + if request.type == RemoteRequstType.REMOTE_PREFILL: request: PrefillRequest = request self.trigger_prefill(request) - # forward request to prefill server - for queue in self.to_backend_queues: - queue.put(request) def trigger_prefill(self, request: PrefillRequest): @@ -111,7 +126,7 @@ def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): # build control path if not exist if server_info.perfill_server_id not in self.remote_prefill_servers: _ctx = zmq.Context() - _socket = _ctx.socket(zmq.PUSH) + _socket = _ctx.socket(zmq.PAIR) connect_str = f"tcp://{server_info.prefill_server_ip}:{server_info.prefill_server_port}" _socket.connect(connect_str) _socket.send_pyobj(ConnectRequest( @@ -121,7 +136,15 @@ def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): agent_metadatas=self.local_agent_meta.agent_metadatas, agent_mem_descs=self.local_agent_meta.agent_mem_descs)) - self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info) + success = _socket.recv_pyobj() + if success: + self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info) + return True + else: + logger.warning("Remote Prefill Server Connect Failed") + return False + + return True def main_loop(self): self.local_init() @@ -129,10 +152,16 @@ def main_loop(self): prefill_tasks: RemotePrefillTask = self.from_backend_queue.get() # connect first - self.connect_to_prefill_server(prefill_tasks.server_info) - # do prefill - self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request) - + if(self.connect_to_prefill_server(prefill_tasks.server_info)): + # do prefill + self.remote_prefill(prefill_tasks.server_info.perfill_server_id, 
prefill_tasks.prefill_request) + else: + # failed to connect to the remote prefill server, notify all backends + for queue in self.to_backend_queues: + queue.put(RemotePrefillStatus( + group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id, + status=-1, + )) # place request to server do remote prefill def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): From 64bd8c555b2da0bfe53b981bcef3c75d214c8348 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 23 Apr 2025 13:06:04 +0800 Subject: [PATCH 05/19] nixl pd. --- lightllm/server/api_cli.py | 6 +- lightllm/server/api_start.py | 2 +- lightllm/server/core/objs/shm_req_manager.py | 2 +- lightllm/server/core/objs/start_args_type.py | 3 +- lightllm/server/httpserver/manager.py | 14 +-- lightllm/server/httpserver/pd_loop.py | 7 +- .../httpserver_for_pd_master/manager.py | 103 +++++++++++++++++- lightllm/server/pd_io_struct.py | 19 +++- lightllm/server/router/manager.py | 22 ++-- .../pd_disaggregation/impl_for_pd_base.py | 71 +----------- .../server/router/model_infer/model_rpc.py | 26 +++-- lightllm/server/router/req_queue/__init__.py | 4 +- lightllm/utils/health_check.py | 2 +- 13 files changed, 166 insertions(+), 115 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 70d892bd3..640d19f7a 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server"], + choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -58,14 +58,14 @@ def make_argument_parser() -> argparse.ArgumentParser: "--pd_remote_prefill_http_port", type=int, default=42001, - help="p d mode, remote prefill node used for kv move manager rpyc server port", + help="nixl pd mode, prefill node used for triggering prefill http port.", ) parser.add_argument( "--pd_remote_prefill_port", type=int, default=42002, - help="p d mode, remote prefill node used for kv move manager rpyc server port", + help="nixl pd mode, prefill and decode used for meta exchange.", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index fa100f119..4a7775677 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -69,7 +69,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index ca7d9bebe..f3e9fe711 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -33,7 +33,7 @@ def get_req_class_type(self): if args.token_healing_mode: return TokenHealingReq - if args.run_mode == 'prefill' or args.run_mode == 'decode': + if args.run_mode in ["nixl_prefill", "nixl_decode"]: return PDChunkedPrefillReq if args.disable_chunked_prefill: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0adb61a4c..3b159d664 100644 ---
a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -6,7 +6,8 @@ @dataclass class StartArgs: - run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master"]}) + run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master", + "nixl_prefill", "nixl_decode"]}) host: str = field(default="127.0.0.1") port: int = field(default=8000) zmq_mode: str = field( diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e02eaaf79..91cae273f 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -98,7 +98,7 @@ def __init__( self.metric_client = MetricClient(metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] + assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() @@ -225,7 +225,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): # health 请求 request_id 为负数,直接返回 if is_health_req: return sampling_params.group_request_id - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() else: @@ -235,7 +235,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert sampling_params.group_request_id != -1 group_request_id = sampling_params.group_request_id sampling_params.group_request_id = group_request_id - elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D: + elif self.pd_mode.is_P_or_D(): assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting" group_request_id = sampling_params.group_request_id else: @@ -441,7 +441,7 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode == NodeRole.P: + if self.pd_mode.is_P(): if self.enable_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), @@ -454,7 +454,7 @@ async def transfer_to_next_module( ) return - if self.pd_mode == NodeRole.D: + if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), @@ -462,7 +462,7 @@ async def transfer_to_next_module( ) return - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if self.enable_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), @@ -514,7 +514,7 @@ async def _wait_to_token_package( # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 metadata["prompt_tokens"] = prompt_tokens # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode - if self.pd_mode == NodeRole.P and is_first_token: + if self.pd_mode.is_P() and is_first_token: metadata["prompt_ids"] = prompt_ids prompt_cache_len = metadata.pop("prompt_cache_len", 0) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 18579fc5f..6fb87f3dd 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -34,7 +34,8 @@ async def pd_handle_loop(manager: HttpServerManager): manager.host_ip = manager.args.host asyncio.create_task(timer_log(manager)) - asyncio.create_task(pd_handle_loop_from_d(manager)) + if manager.pd_mode.is_NP_or_ND(): + asyncio.create_task(pd_handle_loop_from_d(manager)) id_to_handle_task: 
Dict[int, asyncio.Task] = {} @@ -94,7 +95,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O logger.info(f"Sent registration JSON: {regist_json}") # 转发任务 - if manager.pd_mode == NodeRole.D: + if manager.pd_mode.is_D(): forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 @@ -188,7 +189,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): async def pd_handle_loop_from_d(manager: HttpServerManager): - if manager.pd_mode != NodeRole.P: + if manager.pd_mode != NodeRole.NP: return context = zmq.asyncio.Context(2) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 597d3d3b1..a3920fb50 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -9,7 +9,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType +from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType, NodeRole from lightllm.server.core.objs import SamplingParams from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -50,10 +50,11 @@ async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode == "prefill": + client_pd_mode: NodeRole = NodeRole(pd_client.mode) + if client_pd_mode.is_P(): self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode == "decode": + elif client_pd_mode.is_D(): self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: @@ -73,6 +74,15 @@ async def remove_pd(self, pd_info_json): logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return + async def update_req_status(self, upkv_status: UpKVStatus): + try: + group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) + up_status_event = self.req_id_to_out_inf[group_request_id].up_status_event + up_status_event.upkv_status = upkv_status + up_status_event.set() + except: + pass + return def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): kwargs = {} if kwargs is None else kwargs @@ -177,6 +187,82 @@ async def fetch_stream( up_status_event = req_status.up_status_event + d_start_args = d_node.start_args + decode_node_dict = { + "node_id": d_start_args["pd_node_id"], + "ip": d_start_args["host"], + "rpyc_port": d_start_args["pd_decode_rpyc_port"], + "max_new_tokens": sampling_params.max_new_tokens - 1, + "pd_master_node_id": self.args.pd_node_id, + } + + old_max_new_tokens = sampling_params.max_new_tokens + sampling_params.max_new_tokens = 1 + sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None) + sampling_params.suggested_dp_index = -1 + + await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + + while True: + await req_status.wait_to_ready() + if await request.is_disconnected(): + raise Exception(f"req_id {group_request_id} disconnected") + + if await 
req_status.can_read(self.req_id_to_out_inf): + token_list = await req_status.pop_all_tokens() + for sub_req_id, request_output, metadata, finish_status in token_list: + if old_max_new_tokens != 1: + finish_status = FinishStatus(FinishStatus.NO_FINISH) + else: + finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH) + # 得到 p 节点返回的 prompt_ids 信息 + if metadata.get("prompt_ids", None) is not None: + prompt_ids = metadata.get("prompt_ids") + prompt_ids.append(metadata.get("id")) + yield sub_req_id, request_output, metadata, finish_status + break + + # 如果只需要一个输出 token,prefill 完就直接结束掉吧 + if old_max_new_tokens == 1: + return + + try: + await asyncio.wait_for(up_status_event.wait(), timeout=60) + except asyncio.TimeoutError: + logger.warning(f"group_request_id: {group_request_id} kv move time out err") + assert False, f"req_id {group_request_id} kv move time out, server is busy" + + sampling_params.move_kv_to_decode_node.initialize(None) + sampling_params.max_new_tokens = old_max_new_tokens - 1 + sampling_params.suggested_dp_index = up_status_event.upkv_status.dp_index + + await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, multimodal_params)))) + + while True: + await req_status.wait_to_ready() + if await request.is_disconnected(): + raise Exception(f"req_id {group_request_id} disconnected") + if await req_status.can_read(self.req_id_to_out_inf): + token_list = await req_status.pop_all_tokens() + for sub_req_id, request_output, metadata, finish_status in token_list: + yield sub_req_id, request_output, metadata, finish_status + + return + + async def fetch_stream_nixl( + self, + p_node: PD_Client_Obj, + d_node: PD_Client_Obj, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ): + group_request_id = sampling_params.group_request_id + + req_status = ReqStatus(group_request_id, p_node, d_node) + self.req_id_to_out_inf[group_request_id] = req_status + p_start_args = p_node.start_args prefill_node_dict = { "node_id": p_start_args["pd_node_id"], @@ -216,7 +302,11 @@ async def _wait_to_token_package( unfinished_count = sampling_params.best_of is_first_token = True - async for sub_req_id, out_str, metadata, finish_status in self.fetch_stream( + client_mode: NodeRole = NodeRole(d_node.mode) + + fetch_stream = self.fetch_stream_nixl if client_mode.is_NP_or_ND() else self.fetch_stream + + async for sub_req_id, out_str, metadata, finish_status in fetch_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): @@ -271,6 +361,11 @@ async def abort(self, group_request_id): except: pass + try: + await req_status.p_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) + except: + pass + try: await req_status.d_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) except: diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index cb0bb8a1e..69d1e6f08 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -13,23 +13,30 @@ class NodeRole(enum.Enum): P = "prefill" D = "decode" + + NP = "nixl_prefill" + ND = "nixl_decode" + NORMAL = "normal" PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.D + return self == NodeRole.D or self == NodeRole.ND def is_P(self): - return self == NodeRole.P + return self == NodeRole.P or self == NodeRole.NP def is_normal(self): return self == NodeRole.NORMAL def is_P_or_NORMAL(self): - return (self == 
NodeRole.P) or (self == NodeRole.NORMAL) + return self.is_P() or self.is_normal() def is_P_or_D(self): - return (self == NodeRole.P) or (self == NodeRole.D) + return self.is_P() or self.is_D() + + def is_NP_or_ND(self): + return self == NodeRole.NP or self == NodeRole.ND class ObjType(enum.Enum): @@ -47,8 +54,8 @@ class PD_Client_Obj: websocket: WebSocket = None # 用于通信的 websocket 连接对象 def __post_init__(self): - if self.mode not in ["prefill", "decode"]: - error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" + if self.mode not in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + error_info = f"""mode must in ["prefill", "decode", "nixl_prefill", "nixl_decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index b8b4ee583..a1c2f8509 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -97,8 +97,8 @@ def __init__(self, args, router_port, detokenization_port, metric_port): self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) self.metric_client = MetricClient(metric_port) - self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] - self.is_pd_decode_mode = self.args.run_mode == "decode" + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] + self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 # 主要是为了防止调度失误,造成 OOM 等错误 self.router_lock = mp.Lock() @@ -199,12 +199,13 @@ async def wait_to_model_ready(self): if self.args.run_mode == "prefill": # 启动 prefill kv move 管理进程 - # from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import ( - # start_prefill_kv_move_manager_process, - # ) + from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import ( + start_prefill_kv_move_manager_process, + ) - # start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_prefill": from lightllm.server.router.model_infer.mode_backend.pd_disaggregation.pd_remote_prefill import ( start_pd_remote_prefill_server_process ) @@ -219,12 +220,13 @@ async def wait_to_model_ready(self): if self.args.run_mode == "decode": # 启动 decode kv move 管理进程 - # from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( - # start_decode_kv_move_manager_process, - # ) + from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( + start_decode_kv_move_manager_process, + ) - # start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_decode": from lightllm.server.router.model_infer.mode_backend.pd_disaggregation.pd_remote_prefill import ( start_pd_remote_prefill_client_process ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py index 77c7afc64..b9655ee78 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py @@ -25,6 
+25,7 @@ class PDNIXLBackendBase(ModeBackend): + _THEAD_WAIT_INTERVAL = 0.001 def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, @@ -77,7 +78,7 @@ def handle_remote_prefill(req_status: RemotePrefillStatus): prefill_status = RemotePrefillStatus.deserialize(req_status) handle_remote_prefill(prefill_status) - time.sleep(0.001) + time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) def _wait_transfer_loop(self): @@ -94,7 +95,7 @@ def _wait_transfer_loop(self): shm_req: PDChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, state) del self.inflght_transfer_requests[req_id] - time.sleep(0.001) + time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) def _handle_prefill_loop(self): @@ -250,69 +251,3 @@ def _prefill_abort_remote(self, req_objs: List[InferReq]): group_req_id) del self.remote_prefill_requests[group_req_id] - - # def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False, prefill=True): - # uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( - # req_ids, - # no_decode - # ) - - # progressing_reqs = [] - # new_ok_or_prefill_reqs = [] - - # ok_or_prefill_reqs = ok_finished_reqs if prefill else prefill_reqs - # success_reqs = progressing_reqs if prefill else decode_reqs - - # # filter remote prefill requests - # for r in ok_or_prefill_reqs: - # r: InferReq - # if r.in_prefill_or_transfer: - # shm_req: PDChunkedPrefillReq = r.shm_req - # # state is updated by router - # state = shm_req.get_pd_req_state() - # if state == 1: - # success_reqs.append(r) - # r.in_prefill_or_transfer = False - # elif state == -1: - # aborted_reqs.append(r) - # r.in_prefill_or_transfer = False - # elif state == 0: # in progress - # progressing_reqs.append(r) - # else: - # logger.warning(f"remote prefill request {r.req_id} unexpected state {state}") - # continue - - # new_ok_or_prefill_reqs.append(r) - - # if prefill: - # return uninit_reqs, aborted_reqs, new_ok_or_prefill_reqs, prefill_reqs, decode_reqs, progressing_reqs - # else: - # return uninit_reqs, aborted_reqs, ok_finished_reqs, new_ok_or_prefill_reqs, decode_reqs, progressing_reqs - - # def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False): - # uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = super()._get_classed_reqs( - # req_ids, - # no_decode - # ) - # new_ok_finished = [] - # transfer_reqs = [] - # # filter remote prefill requests - # for r in ok_finished_reqs: - # r: InferReq - # if r.kv_transfering: - # shm_req: PDChunkedPrefillReq = r.shm_req - # state = shm_req.get_pd_req_state() # state is updated by last post_handle, change is reflected here - # if state == 1: - # new_ok_finished.append(r) - # r.kv_transfering = False - # elif state == -1: - # aborted_reqs.append(r) - # r.kv_transfering = False - # elif state == 0: # in progress - # transfer_reqs.append(r) - # else: - # logger.warning(f"remote prefill request {r.req_id} unexpected state {state}") - # continue - # new_ok_finished.append(r) - - # return uninit_reqs, aborted_reqs, new_ok_finished, prefill_reqs, decode_reqs, transfer_reqs \ No newline at end of file diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 7f427c84f..b88414257 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -124,28 +124,38 @@ def init_model(self, kvargs): ), "only one constraint mode can be true" is_prefill_node = kvargs.get("args", 
None).run_mode == "prefill" is_decode_node = kvargs.get("args", None).run_mode == "decode" + is_nixl_prefill_node = kvargs.get("args", None).run_mode == "nixl_prefill" + is_nixl_decode_node = kvargs.get("args", None).run_mode == "nixl_decode" else: is_outlines_constraint_mode = False is_xgrammar_constraint_mode = False is_prefill_node = False is_decode_node = False + is_nixl_prefill_node = False + is_nixl_decode_node = False if is_prefill_node: if kvargs.get("args", None).dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) else: - # self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) - self.backend = PDNIXLBackendForPrefillNode(self.info_queue, - self.result_queue, - self.mem_queue) + self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + + elif is_nixl_prefill_node: + assert kvargs.get("args", None).dp == 1 + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, + self.result_queue, + self.mem_queue) elif is_decode_node: if kvargs.get("args", None).dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) else: - # self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) - self.backend = PDNIXLBackendForDecodeNode(self.info_queue, - self.result_queue, - self.mem_queue) + self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + + elif is_nixl_decode_node: + assert kvargs.get("args", None).dp == 1 + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, + self.result_queue, + self.mem_queue) elif kvargs.get("dp_size", 1) > 1: self.backend = DPChunkedPrefillBackend() elif use_reward_model: diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 736ab5e59..36ba7879e 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -15,9 +15,9 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): return ChunkedPrefillQueue if args.first_token_constraint_mode: return ChunkedPrefillQueue - if args.run_mode == "decode": + if args.run_mode in ["decode", "nixl_decode"]: return QueueForPDDecode - if args.run_mode == "prefill": + if args.run_mode in ["prefill", "nixl_prefill"]: return QueueForPDChunkedPrefill if args.disable_chunked_prefill: diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py index f6305e209..ee0778b65 100644 --- a/lightllm/utils/health_check.py +++ b/lightllm/utils/health_check.py @@ -70,7 +70,7 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req health_obj.begin_check() try: request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}} - if args.run_mode == "prefill": + if args.run_mode in ["prefill", "nixl_prefill"]: request_dict["parameters"]["max_new_tokens"] = 1 prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] From daec2759caa6245f979f3993deb8f6475ae2cd53 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 23 Apr 2025 13:19:36 +0800 Subject: [PATCH 06/19] fixup. 
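In nixl mode the prefill node returns its output through the decode path instead of uploading tokens to the pd master, so the forwarding task is now skipped only for NodeRole.NP; the per-request booleans remote_prefilling and kv_transfering are also folded into a single in_prefill_or_transfer flag. A toy sketch of how one boolean plus the shared tri-state (-1 failed, 0 in flight, 1 done) classifies requests; the function and tuple layout below are illustrative, not the backend API:

from typing import List, Tuple

def classify(reqs: List[Tuple[int, bool, int]]):
    # Each entry: (req_id, in_prefill_or_transfer, shared_state).
    ready, in_flight, failed = [], [], []
    for req_id, busy, state in reqs:
        if busy and state == 0:
            in_flight.append(req_id)  # remote prefill / kv transfer still running
        elif busy and state == -1:
            failed.append(req_id)  # remote side reported failure, abort path
        else:
            ready.append(req_id)  # never went remote, or finished with state 1
    return ready, in_flight, failed

print(classify([(1, False, 0), (2, True, 1), (3, True, -1), (4, True, 0)]))  # ([1, 2], [4], [3])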
--- lightllm/server/httpserver/pd_loop.py | 2 +- lightllm/server/router/model_infer/infer_batch.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 6fb87f3dd..9d9aa1860 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -95,7 +95,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O logger.info(f"Sent registration JSON: {regist_json}") # 转发任务 - if manager.pd_mode.is_D(): + if manager.pd_mode != NodeRole.NP: # nixl prefill don't need up token to master forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 2c32d61e7..f5d26af8a 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -309,8 +309,7 @@ def init_all(self): self.initialized = True self.paused = False - self.remote_prefilling = False - self.kv_transfering = False + self.in_prefill_or_transfer = False return def is_uninitialized(self): From 8916d1d2506b351b9183bdf4af51bc6602681103 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 23 Apr 2025 18:59:47 +0800 Subject: [PATCH 07/19] fixup. --- lightllm/server/core/objs/req.py | 15 +- .../httpserver_for_pd_master/manager.py | 8 +- lightllm/server/router/batch.py | 2 +- lightllm/server/router/manager.py | 4 +- .../pd_disaggregation/impl_for_pd_base.py | 27 ++- .../pd_disaggregation/impl_for_pd_decode.py | 3 +- .../pd_disaggregation/pd_remote_prefill.py | 156 +++++++++++------- 7 files changed, 140 insertions(+), 75 deletions(-) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 5f81cf0b4..7fd0a90d6 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -8,7 +8,6 @@ from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_dp_world_size from typing import List, Any, Union @@ -95,6 +94,7 @@ class Req(ctypes.Structure): ("reward_score", ctypes.c_float), # 请求回复累计概率和 ("cumlogprob", ctypes.c_float), + ("dp_world_size", ctypes.c_int), ] def get_str(self): @@ -344,6 +344,7 @@ class PDChunkedPrefillReq(ChunkedPrefillReq): def post_init(self): super().post_init() self.create_pd_req_state_shm_array() + self.dp_world_size = 0 def create_pd_req_state_shm_array(self): service_uni_name = get_unique_server_name() @@ -368,10 +369,12 @@ def set_pd_req_rank_state(self, tp_id: int, state: int): # state: -1 for failed, 0 for in progress, 1 for success # set by router - def set_pd_req_state(self, dp_world_size: int): - unique_state = np.unique(self.pd_req_state_shm.arr[:dp_world_size]) - self.pd_req_state_shm.arr[dp_world_size] = unique_state[0] + def set_pd_req_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + unique_state = np.unique(self.pd_req_state_shm.arr[:self.dp_world_size]) + self.pd_req_state_shm.arr[self.dp_world_size] = unique_state[0] # read by all rank - def get_pd_req_state(self, dp_world_size: int): - return self.pd_req_state_shm.arr[dp_world_size] + def get_pd_req_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before 
calling this" + return self.pd_req_state_shm.arr[self.dp_world_size] diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index a3920fb50..090b142e7 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -106,8 +106,8 @@ async def select_p_d_node( ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: import random - p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) + p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None + d_node = random.choice(self.decode_nodes) if self.decode_nodes else None return p_node, d_node async def generate( @@ -129,6 +129,10 @@ async def generate( p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params) + if not p_node or not d_node: + logger.error(f"{group_request_id}: No p_node or d_node found") + return + results_generator = self._wait_to_token_package( p_node, d_node, diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 641d3ab57..0b98d40af 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -56,7 +56,7 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): unfinished_req_ids.append(req.request_id) if isinstance(req, PDChunkedPrefillReq): req.link_pd_req_state_shm_array() - req.set_pd_req_state(get_dp_world_size()) + req.set_pd_req_state() self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index a1c2f8509..d8fa71b90 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -22,7 +22,7 @@ from .req_queue import build_req_queue from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, PDChunkedPrefillReq from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs @@ -245,6 +245,8 @@ def add_req(self, group_req_indexes: GroupReqIndexes): req = self.shm_req_manager.get_req_obj_by_index(req_index) req.multimodal_params = group_req_indexes.multimodal_params req.start_time = group_req_indexes.time_mark + if isinstance(req, PDChunkedPrefillReq): + req.dp_world_size = self.world_size req_group.append(req) logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py index b9655ee78..fb4ebe0d4 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py @@ -23,7 +23,6 @@ logger = init_logger(__name__) - class PDNIXLBackendBase(ModeBackend): _THEAD_WAIT_INTERVAL = 0.001 def __init__(self, @@ -56,13 +55,18 @@ def _prefill_wait_loop(self): def handle_remote_prefill(req_status: RemotePrefillStatus): group_req_id = req_status.group_req_id status = req_status.status + if status != 1: + logger.warning(f"remote prefill reqeust: {group_req_id} done with state: {status}") if run_req := 
self.remote_prefilled_reqs.get(group_req_id, None): shm_req: PDChunkedPrefillReq = run_req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, status) self.remote_prefilled_reqs.pop(group_req_id) - logger.info(f"remote prefill reqeust: {group_req_id} done with status: {status}") + if self.is_master_in_dp: + logger.info(f"remote prefill reqeust: {group_req_id} done with status: {status} " + f"took: {time.time() - run_req.remote_prefill_start} seconds") else: - logger.warning(f"remote prefill reqeust: {group_req_id} not found") + if self.is_master_in_dp: + logger.warning(f"remote prefill reqeust: {group_req_id} not found") # from local try: @@ -86,15 +90,21 @@ def _wait_transfer_loop(self): done_req_ids = self.nixl_agent.get_done_tranfers() for req_id, state in done_req_ids: - logger.info(f"wait transfer done: {req_id} state: {state}") + if state != 1: + logger.info(f"wait transfer done: {req_id} state: {state}") + if req_id not in self.inflght_transfer_requests: - logger.warning(f"{req_id} not found in inflght_transfer_requests") + if self.is_master_in_dp: + logger.warning(f"{req_id} not found in inflght_transfer_requests") continue req: InferReq = self.inflght_transfer_requests[req_id] shm_req: PDChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, state) del self.inflght_transfer_requests[req_id] + if self.is_master_in_dp: + logger.info(f"req: {req_id} kv transfer with state: {state} " + f"took: {time.time() - req.kv_transfer_start} seconds") time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) @@ -128,6 +138,7 @@ def _transfer_kv_to_remote(self, req: InferReq): # kick off kv transfer if req.finish_status.is_finished(): + req.kv_transfer_start = time.time() kv_transfer_req = KVMoveRequest( group_req_id=group_req_id, token_ids=self.model.req_manager.req_to_token_indexs[req.req_idx][:req.cur_kv_len].tolist() @@ -149,7 +160,7 @@ def _decode_filter_reqs(self, prefill_reqs: List[InferReq], for req in aborted_reqs: if req.in_prefill_or_transfer: shm_req: PDChunkedPrefillReq = req.shm_req - state = shm_req.get_pd_req_state(self.dp_world_size) + state = shm_req.get_pd_req_state() if state != 0: new_aborted_reqs.append(req) req.in_prefill_or_transfer = False @@ -161,7 +172,7 @@ def _decode_filter_reqs(self, prefill_reqs: List[InferReq], if req.in_prefill_or_transfer: shm_req: PDChunkedPrefillReq = req.shm_req # state is updated by router - state = shm_req.get_pd_req_state(self.dp_world_size) + state = shm_req.get_pd_req_state() if state == 1: # success req.cur_kv_len = req.get_cur_total_len() - 1 decode_reqs.append(req) @@ -186,7 +197,7 @@ def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: L for req in ok_finished_reqs: if req.in_prefill_or_transfer: shm_req: PDChunkedPrefillReq = req.shm_req - state = shm_req.get_pd_req_state(self.dp_world_size) + state = shm_req.get_pd_req_state() if state == 1: # success new_ok_finished_reqs.append(req) req.in_prefill_or_transfer = False diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py index d9002dc78..213632b72 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py @@ -1,3 +1,4 @@ +import time import torch import torch.multiprocessing as mp import threading @@ -45,7 +46,6 @@ def _build_remote_prefill_task(self, index: int, 
kwargs: Dict, req: InferReq):
         mem_indexes = kwargs.get('mem_indexes')
         b_start_loc = kwargs.get('b_start_loc')
-        logger.info(req.shm_req.get_str())
         prefill_request = RemotePrefillRequest(
             prompt = req.shm_req.get_prompt_ids(),
             sampling_params=req.shm_req.sample_params,
             multimodal_params=MultimodalParams.from_dict(req.multimodal_params),
             local_cached_len=req.cur_kv_len,
@@ -82,6 +82,7 @@ def decode(self):
             # forward each req to remote prefill
             # since the token index are the same across TPs, we only need to trigger prefill on master
             if self.is_master_in_dp:
+                run_req.remote_prefill_start = time.time()
                 self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req))
                 shm_req.set_pd_req_rank_state(self.rank_in_dp, 0)  # set in progress state
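The decode-side trigger above enqueues the RemotePrefillTask from the dp-master rank only (token indexes are identical across TP ranks), records remote_prefill_start for timing, and leaves every rank polling the shared per-rank state until the waiter thread flips it. A runnable toy version of that enqueue-and-poll handshake; the queue, state list, and N_RANKS are stand-ins, not LightLLM objects:

import queue
import threading
import time

N_RANKS = 2
task_queue: "queue.Queue[int]" = queue.Queue()
req_state = [0] * N_RANKS  # convention from the series: -1 failed, 0 in flight, 1 done


def remote_prefill_worker():
    group_req_id = task_queue.get()  # a RemotePrefillTask in the real code
    time.sleep(0.01)  # stand-in for the actual remote prefill and kv write-back
    for rank in range(N_RANKS):
        req_state[rank] = 1  # the waiter thread sets each rank's state


threading.Thread(target=remote_prefill_worker, daemon=True).start()
task_queue.put(1001)  # only the master rank enqueues
while any(s == 0 for s in req_state):  # every rank spins until the state flips
    time.sleep(0.001)
print("prefill finished, per-rank states:", req_state)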
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py
index 309213773..66605ba35 100644
--- a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py
+++ b/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py
@@ -1,10 +1,13 @@
-from typing import List
+from typing import List, Any
 import zmq
+import inspect
+import pickle
 import torch.multiprocessing as mp

 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.net_utils import get_hostname_ip
+from lightllm.utils.graceful_utils import graceful_registry

 from .pd_remote_prefill_obj import (
     ConnectRequest,
@@ -21,6 +24,23 @@
 logger = init_logger(__name__)

+class SockWithPoller:
+    def __init__(self, sock: zmq.Socket):
+        self.sock = sock
+        self.poller = zmq.Poller()
+        self.poller.register(self.sock, zmq.POLLIN)
+
+    def recv_pyobj(self, timeout: int = 5):
+        socks = dict(self.poller.poll(timeout * 1000))
+        if socks:
+            if socks.get(self.sock) == zmq.POLLIN:
+                return self.sock.recv_pyobj()
+        else:
+            return None
+
+    def send_pyobj(self, obj: Any):
+        return self.sock.send_pyobj(obj)
+
 class PDRemotePrefillBase:
     def __init__(self,
                  id: int,
@@ -62,7 +82,7 @@ def __init__(self,
         # build control path
         _ctx = zmq.Context()
-        self.recv_from_decode = _ctx.socket(zmq.PAIR)
+        self.recv_from_decode = _ctx.socket(zmq.ROUTER)
         self.host_ip = get_hostname_ip()
         self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}")
@@ -73,33 +93,37 @@ def __init__(self,
         self.prefill_requests = {}

-
     def main_loop(self):
         self.local_init()
         while True:
-            request: RemoteRequest = self.recv_from_decode.recv_pyobj()
-            logger.info(f"recevied request from decode, type: {request.type}")
+            try:
+                client_obj, msg = self.recv_from_decode.recv_multipart()
+                request: RemoteRequest = pickle.loads(msg)
+                logger.info(f"received request from decode, type: {request.type}")
+
+                # forward request to prefill server
+                for queue in self.to_backend_queues:
+                    queue.put(request)
-            # forward request to prefill server
-            for queue in self.to_backend_queues:
-                queue.put(request)
+                if request.type == RemoteRequstType.REMOTE_CONNECT:
+                    success = True
+                    for idx in range(self.tp_size):
+                        ack = self.from_backend_queue.get()
+                        if ack != "OK":
+                            success = False
+                            break
-            if request.type == RemoteRequstType.REMOTE_CONNECT:
-                success = True
-                for idx in range(self.tp_size):
-                    ack = self.from_backend_queue.get()
-                    if ack != "OK":
-                        success = False
-                        break
+                    self.recv_from_decode.send_multipart([client_obj, pickle.dumps(success)])
+                    if not success:
+                        logger.warning(f"Remote connect failed: {request}")
-                self.recv_from_decode.send_pyobj(success)
-                if not success:
-                    logger.warning(f"Remote connect failed: {request}")
+                if request.type == RemoteRequstType.REMOTE_PREFILL:
+                    request: PrefillRequest = request
+                    self.trigger_prefill(request)
-            if request.type == RemoteRequstType.REMOTE_PREFILL:
-                request: PrefillRequest = request
-                self.trigger_prefill(request)
+            except Exception as e:
+                logger.error(f"Error in remote prefill server loop: {e}", exc_info=e)
@@ -121,47 +145,66 @@ def __init__(self,
         # map from server id to prefill server info
         self.remote_prefill_servers = {}
+        self.client_socket_cnt = 0
+        self._ctx = zmq.Context()
+
+    def _connect_server(self, server_ip: str, port: int):
+        _socket = self._ctx.socket(zmq.DEALER)
+        _socket.setsockopt_string(zmq.IDENTITY, f"{self.id}_{self.client_socket_cnt}")
+        self.client_socket_cnt += 1
+        connect_str = f"tcp://{server_ip}:{port}"
+        _socket.connect(connect_str)
+        return SockWithPoller(_socket)
+
+    def _send_nixl_agent(self, socket: SockWithPoller):
+        socket.send_pyobj(ConnectRequest(
+            type=RemoteRequstType.REMOTE_CONNECT,
+            decode_id=self.id,
+            num_tokens=self.local_agent_meta.num_tokens,
+            agent_metadatas=self.local_agent_meta.agent_metadatas,
+            agent_mem_descs=self.local_agent_meta.agent_mem_descs))
+
+        success = socket.recv_pyobj(timeout=60)
+        if success is None:
+            logger.warning("timeout to recv remote nixl connect response")
+            return False
+
+        return success

     def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo):
+
+        if server_info.perfill_server_id in self.remote_prefill_servers:
+            return True
+
         # build control path if not exist
-        if server_info.perfill_server_id not in self.remote_prefill_servers:
-            _ctx = zmq.Context()
-            _socket = _ctx.socket(zmq.PAIR)
-            connect_str = f"tcp://{server_info.prefill_server_ip}:{server_info.prefill_server_port}"
-            _socket.connect(connect_str)
-            _socket.send_pyobj(ConnectRequest(
-                type=RemoteRequstType.REMOTE_CONNECT,
-                decode_id=self.id,
-                num_tokens=self.local_agent_meta.num_tokens,
-                agent_metadatas=self.local_agent_meta.agent_metadatas,
-                agent_mem_descs=self.local_agent_meta.agent_mem_descs))
-
-            success = _socket.recv_pyobj()
-            if success:
-                self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info)
-                return True
-            else:
-                logger.warning("Remote Prefill Server Connect Failed")
-                return False
-
-        return True
+        _socket = self._connect_server(server_info.prefill_server_ip, server_info.prefill_server_port)
+        success = self._send_nixl_agent(_socket)
+        if success:
+            self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info)
+            return True
+        else:
+            logger.warning("Remote Prefill Server Connect Failed")
+            return False
+
     def main_loop(self):
         self.local_init()
         while True:
-            prefill_tasks: RemotePrefillTask = self.from_backend_queue.get()
-
-            # connect first
-            if(self.connect_to_prefill_server(prefill_tasks.server_info)):
-                # do prefill
-                self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request)
-            else:
-                # failed to connect a remote
-                for idx in self.to_backend_queues:
-                    self.to_backend_queues.put(RemotePrefillStatus(
-                        group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id,
-                        status=-1,
-                    ))
+            try:
+                prefill_tasks: RemotePrefillTask = self.from_backend_queue.get()
+                # connect first
+                if self.connect_to_prefill_server(prefill_tasks.server_info):
+                    # do prefill
+                    self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request)
+                else:
+                    # failed to connect a remote, report failure to every backend queue
+                    for queue in self.to_backend_queues:
+                        queue.put(RemotePrefillStatus(
+                            group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id,
+                            status=-1,
+                        ))
+            except Exception as e:
+                logger.error(f"Remote prefill client loop error: {e}", exc_info=e)

     # place request to server do remote prefill
     def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest):
@@ -170,7 +213,6 @@ def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest):
         socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL,
                                          decode_id=self.id,
                                          data=prefill_request))
-
 def remote_prefill_server_loop(
     id: int,
     http_server_port: int,
@@ -179,6 +221,7 @@ def remote_prefill_server_loop(
     to_backend_queues: List[mp.Queue],
     agent_meta_queues: List[mp.Queue],
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
     server = PDRemotePrefillServer(id, http_server_port, server_port,
                                    from_backend_queue, to_backend_queues, agent_meta_queues)
     server.main_loop()
@@ -209,6 +252,7 @@ def remote_prefill_client_loop(
     from_backend_queue: mp.Queue,
     to_backend_queues: List[mp.Queue],
     agent_meta_queues: List[mp.Queue]):
+    graceful_registry(inspect.currentframe().f_code.co_name)
     client = PDRemotePrefillClient(
         id,
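The control-path change in the patch above replaces the one-to-one PAIR sockets with ROUTER/DEALER so a single prefill server can serve many identified decode clients, and wraps the reply wait in a poller so a dead peer cannot block the client forever. A self-contained sketch of that pattern, assuming nothing beyond pyzmq; the endpoint, identity string, and timeout are illustrative:

import pickle
import zmq

ctx = zmq.Context()

# Server side: one ROUTER socket multiplexes all identified clients.
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:42002")

# Client side: a DEALER with an explicit identity, as in _connect_server().
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt_string(zmq.IDENTITY, "decode_0_0")  # f"{id}_{socket_cnt}" in the patch
dealer.connect("tcp://127.0.0.1:42002")
dealer.send_pyobj({"type": "REMOTE_CONNECT"})

# ROUTER receives [identity, payload] and must echo the identity to route the reply.
ident, payload = router.recv_multipart()
print("server got:", pickle.loads(payload))
router.send_multipart([ident, pickle.dumps(True)])

# Client waits with a bounded poll instead of blocking, like SockWithPoller.
poller = zmq.Poller()
poller.register(dealer, zmq.POLLIN)
if dict(poller.poll(5 * 1000)).get(dealer) == zmq.POLLIN:
    print("connect ack:", dealer.recv_pyobj())
else:
    print("timed out waiting for connect ack")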
From 0d40968e3607fa6aa581666e025a8ab3fd6069da Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Wed, 23 Apr 2025 19:06:22 +0800
Subject: [PATCH 08/19] rename.

---
 lightllm/server/router/manager.py | 4 ++--
 lightllm/server/router/model_infer/mode_backend/__init__.py | 4 ++--
 .../{pd_disaggregation => pd_nixl}/impl_for_pd_base.py | 0
 .../{pd_disaggregation => pd_nixl}/impl_for_pd_decode.py | 0
 .../{pd_disaggregation => pd_nixl}/impl_for_pd_prefill.py | 0
 .../{pd_disaggregation => pd_nixl}/nixl_kv_transporter.py | 0
 .../{pd_disaggregation => pd_nixl}/pd_remote_prefill.py | 0
 .../{pd_disaggregation => pd_nixl}/pd_remote_prefill_obj.py | 0
 8 files changed, 4 insertions(+), 4 deletions(-)
 rename lightllm/server/router/model_infer/mode_backend/{pd_disaggregation => pd_nixl}/impl_for_pd_base.py (100%)
 rename lightllm/server/router/model_infer/mode_backend/{pd_disaggregation => pd_nixl}/impl_for_pd_decode.py (100%)
 rename lightllm/server/router/model_infer/mode_backend/{pd_disaggregation => pd_nixl}/impl_for_pd_prefill.py (100%)
 rename lightllm/server/router/model_infer/mode_backend/{pd_disaggregation => pd_nixl}/nixl_kv_transporter.py (100%)
 rename lightllm/server/router/model_infer/mode_backend/{pd_disaggregation => pd_nixl}/pd_remote_prefill.py (100%)
 rename lightllm/server/router/model_infer/mode_backend/{pd_disaggregation => pd_nixl}/pd_remote_prefill_obj.py (100%)

diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index d8fa71b90..57cce6d82 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -206,7 +206,7 @@ async def wait_to_model_ready(self):
             start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)

         if self.args.run_mode == "nixl_prefill":
-            from lightllm.server.router.model_infer.mode_backend.pd_disaggregation.pd_remote_prefill import (
+            from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import (
                 start_pd_remote_prefill_server_process
             )
             start_pd_remote_prefill_server_process(
@@ -227,7 +227,7 @@ async def wait_to_model_ready(self):
             start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)

         if self.args.run_mode == "nixl_decode":
-            from lightllm.server.router.model_infer.mode_backend.pd_disaggregation.pd_remote_prefill import (
+            from
lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( start_pd_remote_prefill_client_process ) start_pd_remote_prefill_client_process( diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index f2ef008ad..767da8f70 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -12,5 +12,5 @@ from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode -from .pd_disaggregation.impl_for_pd_prefill import PDNIXLBackendForPrefillNode -from .pd_disaggregation.impl_for_pd_decode import PDNIXLBackendForDecodeNode \ No newline at end of file +from .pd_nixl.impl_for_pd_prefill import PDNIXLBackendForPrefillNode +from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode \ No newline at end of file diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_base.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_decode.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_disaggregation/impl_for_pd_prefill.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_disaggregation/nixl_kv_transporter.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_disaggregation/pd_remote_prefill_obj.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py From b24cea6d5465b8e825b4f2bb98afc58b0129c72e Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 
23 Apr 2025 19:54:04 +0800 Subject: [PATCH 09/19] fix lint. --- lightllm/server/core/objs/req.py | 2 +- lightllm/server/core/objs/shm_req_manager.py | 1 - lightllm/server/core/objs/start_args_type.py | 6 +- lightllm/server/httpserver/pd_loop.py | 12 +- lightllm/server/multimodal_params.py | 2 +- lightllm/server/pd_io_struct.py | 2 + lightllm/server/router/manager.py | 12 +- .../server/router/model_infer/infer_batch.py | 4 - .../model_infer/mode_backend/__init__.py | 2 +- .../mode_backend/pd_nixl/impl_for_pd_base.py | 94 +++++++------- .../pd_nixl/impl_for_pd_decode.py | 34 ++--- .../pd_nixl/impl_for_pd_prefill.py | 16 +-- .../pd_nixl/nixl_kv_transporter.py | 39 +++--- .../mode_backend/pd_nixl/pd_remote_prefill.py | 120 +++++++++--------- .../pd_nixl/pd_remote_prefill_obj.py | 8 +- .../server/router/model_infer/model_rpc.py | 11 +- 16 files changed, 175 insertions(+), 190 deletions(-) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7fd0a90d6..24538e714 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -371,7 +371,7 @@ def set_pd_req_rank_state(self, tp_id: int, state: int): # set by router def set_pd_req_state(self): assert self.dp_world_size > 0, "dp_world_size should be set before calling this" - unique_state = np.unique(self.pd_req_state_shm.arr[:self.dp_world_size]) + unique_state = np.unique(self.pd_req_state_shm.arr[: self.dp_world_size]) self.pd_req_state_shm.arr[self.dp_world_size] = unique_state[0] # read by all rank diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index f3e9fe711..fd7901158 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -41,7 +41,6 @@ def get_req_class_type(self): else: return ChunkedPrefillReq - def get_max_req_num(self): args: StartArgs = get_env_start_args() return args.running_max_req_size diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 3b159d664..8f354da23 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -6,8 +6,10 @@ @dataclass class StartArgs: - run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master", - "nixl_prefill", "nixl_decode"]}) + run_mode: str = field( + default="normal", + metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + ) host: str = field(default="127.0.0.1") port: int = field(default=8000) zmq_mode: str = field( diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 9d9aa1860..7eb95145a 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -95,7 +95,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O logger.info(f"Sent registration JSON: {regist_json}") # 转发任务 - if manager.pd_mode != NodeRole.NP: # nixl prefill don't need up token to master + if manager.pd_mode != NodeRole.NP: # nixl prefill don't need up token to master forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 @@ -205,13 +205,9 @@ async def pd_handle_loop_from_d(manager: HttpServerManager): ) = await manager.recv_from_d.recv_pyobj() # 触发推理的task - async def pd_process_generate( - manager: "HttpServerManager", prompt, sampling_params, multimodal_params - ): 
+ async def pd_process_generate(manager: "HttpServerManager", prompt, sampling_params, multimodal_params): try: - async for _, _, _, _ in manager.generate( - prompt, sampling_params, multimodal_params, None - ): + async for _, _, _, _ in manager.generate(prompt, sampling_params, multimodal_params, None): pass except BaseException as e: logger.error(str(e)) @@ -219,4 +215,4 @@ async def pd_process_generate( asyncio.create_task(pd_process_generate(manager, prompt, sampling_params, multimodal_params)) except Exception as e: - logger.exception(f"pd loop generate error: {str(e)}") \ No newline at end of file + logger.exception(f"pd loop generate error: {str(e)}") diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index c57edfe4e..2e2dab5b0 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -156,7 +156,7 @@ def to_dict(self): @classmethod def from_dict(cls, data: dict): - if 'images' not in data: + if "images" not in data: return cls() return cls(images=data["images"]) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 69d1e6f08..a26fc9da5 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -120,12 +120,14 @@ class PDTransJoinInfo: # 一次连接,使用一个 uuid 为其标识 connect_id: str + @dataclass class RemotePrefillServerInfo: perfill_server_id: int prefill_server_ip: str prefill_server_port: int + @dataclass class PDTransLeaveInfo: decode_id: int diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 57cce6d82..03ad8d127 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -118,9 +118,7 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] - self.result_queues: List[mp.Queue] = [ - mp.Queue() for _ in range(self.world_size) - ] + self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.world_size)] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() @@ -207,15 +205,16 @@ async def wait_to_model_ready(self): if self.args.run_mode == "nixl_prefill": from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( - start_pd_remote_prefill_server_process + start_pd_remote_prefill_server_process, ) + start_pd_remote_prefill_server_process( self.args.pd_node_id, http_server_port=self.args.pd_remote_prefill_http_port, server_port=self.args.pd_remote_prefill_port, from_backend_queue=self.info_queue, to_backend_queues=self.result_queues, - agent_meta_queues=self.mem_queues + agent_meta_queues=self.mem_queues, ) if self.args.run_mode == "decode": @@ -228,8 +227,9 @@ async def wait_to_model_ready(self): if self.args.run_mode == "nixl_decode": from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( - start_pd_remote_prefill_client_process + start_pd_remote_prefill_client_process, ) + start_pd_remote_prefill_client_process( self.args.pd_node_id, from_backend_queue=self.info_queue, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f5d26af8a..3f165ca08 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -286,10 +286,6 @@ def init_all(self): self.cur_output_len = 0 self.finish_status = FinishStatus() - - # print(f"[INFO] init_all: 
{self.shm_req.group_req_id} {self.shm_req.get_pd_req_state()} {self.remote_prefilling}", - # f"{self.cur_kv_len} {self.get_cur_total_len()}") - if self.paused or not self.initialized: # 如果是具有 prompt_cache 的使用特性则需要进行提前的填充和恢复操作。 if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1: diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 767da8f70..0a4654fd4 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -13,4 +13,4 @@ from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode from .pd_nixl.impl_for_pd_prefill import PDNIXLBackendForPrefillNode -from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode \ No newline at end of file +from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index fb4ebe0d4..7f9765c9a 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -1,4 +1,3 @@ - import time import torch.multiprocessing as mp from typing import Dict, List @@ -12,23 +11,23 @@ from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from .nixl_kv_transporter import NixlMetadata, NixlKVTransporter -from .pd_remote_prefill_obj import (PrefillRequest, - RemoteRequest, - RemoteRequstType, - ConnectRequest, - KVMoveRequest, - RemotePrefillStatus, - ThreadSafeDict) +from .pd_remote_prefill_obj import ( + PrefillRequest, + RemoteRequest, + RemoteRequstType, + ConnectRequest, + KVMoveRequest, + RemotePrefillStatus, + ThreadSafeDict, +) logger = init_logger(__name__) class PDNIXLBackendBase(ModeBackend): _THEAD_WAIT_INTERVAL = 0.001 - def __init__(self, - to_remote_queue: mp.Queue, - from_remote_queue: mp.Queue, - nixl_meta_queue: mp.Queue): + + def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_meta_queue: mp.Queue): super().__init__() self.to_remote_queue = to_remote_queue self.from_remote_queue = from_remote_queue @@ -41,17 +40,16 @@ def __init__(self, self.remote_prefill_requests: ThreadSafeDict = ThreadSafeDict() self.inflght_transfer_requests: ThreadSafeDict = ThreadSafeDict() - def init_custom(self): self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.tp_rank) self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer) - self.nixl_meta_queue.put((self.nixl_agent.agent_metadata, - self.nixl_agent.num_tokens, - self.nixl_agent.local_mem_desc)) - + self.nixl_meta_queue.put( + (self.nixl_agent.agent_metadata, self.nixl_agent.num_tokens, self.nixl_agent.local_mem_desc) + ) def _prefill_wait_loop(self): while True: + def handle_remote_prefill(req_status: RemotePrefillStatus): group_req_id = req_status.group_req_id status = req_status.status @@ -62,8 +60,10 @@ def handle_remote_prefill(req_status: RemotePrefillStatus): shm_req.set_pd_req_rank_state(self.rank_in_dp, status) self.remote_prefilled_reqs.pop(group_req_id) if self.is_master_in_dp: - logger.info(f"remote prefill reqeust: {group_req_id} done with status: {status} " - f"took: {time.time() - run_req.remote_prefill_start} seconds") + 
logger.info( + f"remote prefill reqeust: {group_req_id} done with status: {status} " + f"took: {time.time() - run_req.remote_prefill_start} seconds" + ) else: if self.is_master_in_dp: logger.warning(f"remote prefill reqeust: {group_req_id} not found") @@ -84,7 +84,6 @@ def handle_remote_prefill(req_status: RemotePrefillStatus): time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) - def _wait_transfer_loop(self): while True: done_req_ids = self.nixl_agent.get_done_tranfers() @@ -103,30 +102,35 @@ def _wait_transfer_loop(self): shm_req.set_pd_req_rank_state(self.rank_in_dp, state) del self.inflght_transfer_requests[req_id] if self.is_master_in_dp: - logger.info(f"req: {req_id} kv transfer with state: {state} " - f"took: {time.time() - req.kv_transfer_start} seconds") + logger.info( + f"req: {req_id} kv transfer with state: {state} " + f"took: {time.time() - req.kv_transfer_start} seconds" + ) time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) - def _handle_prefill_loop(self): while True: request: RemoteRequest = self.from_remote_queue.get() if request.type == RemoteRequstType.REMOTE_CONNECT: request: ConnectRequest logger.info(f"connect request received from: {request.decode_id}") - self.nixl_agent.add_remote_agent(NixlMetadata( - id = request.decode_id, - num_tokens = request.num_tokens, - agent_metadatas = request.agent_metadatas, - agent_mem_descs = request.agent_mem_descs - )) + self.nixl_agent.add_remote_agent( + NixlMetadata( + id=request.decode_id, + num_tokens=request.num_tokens, + agent_metadatas=request.agent_metadatas, + agent_mem_descs=request.agent_mem_descs, + ) + ) self.to_remote_queue.put("OK") if request.type == RemoteRequstType.REMOTE_PREFILL: request: PrefillRequest group_request_id = request.data.sampling_params.group_request_id - logger.info(f"prefill request received from decode: {request.decode_id} " - f"and group request id: {group_request_id}") + logger.info( + f"prefill request received from decode: {request.decode_id} " + f"and group request id: {group_request_id}" + ) self.remote_prefill_requests[group_request_id] = request def _transfer_kv_to_remote(self, req: InferReq): @@ -141,7 +145,7 @@ def _transfer_kv_to_remote(self, req: InferReq): req.kv_transfer_start = time.time() kv_transfer_req = KVMoveRequest( group_req_id=group_req_id, - token_ids=self.model.req_manager.req_to_token_indexs[req.req_idx][:req.cur_kv_len].tolist() + token_ids=self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].tolist(), ) remote_request = self.remote_prefill_requests[group_req_id] self.nixl_agent.write_blocks(kv_transfer_req, remote_request) @@ -150,13 +154,14 @@ def _transfer_kv_to_remote(self, req: InferReq): req.kv_transfering = True self.inflght_transfer_requests[group_req_id] = req - def _decode_filter_reqs(self, prefill_reqs: List[InferReq], - aborted_reqs: List[InferReq], decode_reqs: List[InferReq]): + def _decode_filter_reqs( + self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq] + ): new_prefill_reqs: List[InferReq] = [] new_aborted_reqs: List[InferReq] = [] remote_prefill_reqs: List[InferReq] = [] - # filter out aborted requests + # filter out aborted requests for req in aborted_reqs: if req.in_prefill_or_transfer: shm_req: PDChunkedPrefillReq = req.shm_req @@ -165,7 +170,7 @@ def _decode_filter_reqs(self, prefill_reqs: List[InferReq], new_aborted_reqs.append(req) req.in_prefill_or_transfer = False else: - #TODO trigger remote abort + # TODO trigger remote abort remote_prefill_reqs.append(req) for req in 
prefill_reqs: @@ -173,14 +178,14 @@ def _decode_filter_reqs(self, prefill_reqs: List[InferReq], shm_req: PDChunkedPrefillReq = req.shm_req # state is updated by router state = shm_req.get_pd_req_state() - if state == 1: # success + if state == 1: # success req.cur_kv_len = req.get_cur_total_len() - 1 decode_reqs.append(req) req.in_prefill_or_transfer = False - elif state == -1: # failure + elif state == -1: # failure aborted_reqs.append(req) req.in_prefill_or_transfer = False - elif state == 0: # in progress + elif state == 0: # in progress remote_prefill_reqs.append(req) else: logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") @@ -198,10 +203,10 @@ def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: L if req.in_prefill_or_transfer: shm_req: PDChunkedPrefillReq = req.shm_req state = shm_req.get_pd_req_state() - if state == 1: # success + if state == 1: # success new_ok_finished_reqs.append(req) req.in_prefill_or_transfer = False - elif state == -1: # failure + elif state == -1: # failure aborted_reqs.append(req) req.in_prefill_or_transfer = False elif state == 0: @@ -235,7 +240,7 @@ def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): input_ids.append(input_id) start_loc += input_token_len - nopad_b_start_loc.append(start_loc) # last request + nopad_b_start_loc.append(start_loc) # last request input_ids = np.concatenate(input_ids, dtype=np.int64) # g_infer_state_lock.acquire() # I don't think it's needed @@ -257,8 +262,5 @@ def _prefill_abort_remote(self, req_objs: List[InferReq]): for req_obj in req_objs: group_req_id = req_obj.shm_req.group_req_id if group_req_id in self.remote_prefill_requests: - self.nixl_agent.send_abort_notify( - self.remote_prefill_requests[group_req_id].decode_id, - group_req_id) + self.nixl_agent.send_abort_notify(self.remote_prefill_requests[group_req_id].decode_id, group_req_id) del self.remote_prefill_requests[group_req_id] - diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py index 213632b72..4ba256a20 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py @@ -11,10 +11,7 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.server.multimodal_params import MultimodalParams -from .pd_remote_prefill_obj import ( - RemotePrefillTask, - RemotePrefillServerInfo, - RemotePrefillRequest) +from .pd_remote_prefill_obj import RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest from .impl_for_pd_base import PDNIXLBackendBase @@ -22,40 +19,34 @@ class PDNIXLBackendForDecodeNode(PDNIXLBackendBase): - def __init__(self, - prefill_task_queue: mp.Queue, - prefill_done_queue: mp.Queue, - nix_meta_queue: mp.Queue) -> None: + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) - def init_custom(self): super().init_custom() self.wait_prefill_thread = threading.Thread(target=self._prefill_wait_loop, daemon=True) self.wait_prefill_thread.start() return - def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): prefill_node = req.shm_req.sample_params.move_kv_to_decode_node.to_dict() prefill_node_info = RemotePrefillServerInfo( - 
perfill_server_id=prefill_node['node_id'], - prefill_server_ip=prefill_node['ip'], - prefill_server_port=prefill_node['rpyc_port'], + perfill_server_id=prefill_node["node_id"], + prefill_server_ip=prefill_node["ip"], + prefill_server_port=prefill_node["rpyc_port"], ) - mem_indexes = kwargs.get('mem_indexes') - b_start_loc = kwargs.get('b_start_loc') + mem_indexes = kwargs.get("mem_indexes") + b_start_loc = kwargs.get("b_start_loc") prefill_request = RemotePrefillRequest( - prompt = req.shm_req.get_prompt_ids(), + prompt=req.shm_req.get_prompt_ids(), sampling_params=req.shm_req.sample_params, multimodal_params=MultimodalParams.from_dict(req.multimodal_params), local_cached_len=req.cur_kv_len, - token_ids=mem_indexes[b_start_loc[index]: b_start_loc[index+1]], + token_ids=mem_indexes[b_start_loc[index] : b_start_loc[index + 1]], ) return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request) - def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs, init_req_obj=False) return @@ -85,7 +76,7 @@ def decode(self): run_req.remote_prefill_start = time.time() self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) - shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) # set in progress state + shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) # set in progress state run_req.in_prefill_or_transfer = True self.remote_prefilled_reqs[shm_req.group_req_id] = run_req @@ -103,8 +94,9 @@ def decode(self): next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False) + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - return \ No newline at end of file + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py index fc3437131..ed77c9263 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py @@ -13,13 +13,9 @@ class PDNIXLBackendForPrefillNode(PDNIXLBackendBase): - def __init__(self, - transfer_task_queue: mp.Queue, - transfer_done_queue: mp.Queue, - nixl_meta_queue: mp.Queue) -> None: + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) - def init_custom(self): super().init_custom() self.handle_prefill_loop_thread = threading.Thread(target=self._handle_prefill_loop, daemon=True) @@ -28,7 +24,6 @@ def init_custom(self): self.wait_transfer_loop_thread.start() return - def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs) return @@ -40,8 +35,6 @@ def decode(self): ) ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) - # print(f"{self.rank_in_dp}: {len(uinit_reqs)} uninit, {len(aborted_reqs)} aborted, {len(ok_finished_reqs)} ok finished, " - # f"{len(prefill_reqs)} new prefill, {len(decode_reqs)} decode, {len(transfer_reqs)} transfer_reqs") assert len(uinit_reqs) == 0 assert len(decode_reqs) == 0 @@ -65,10 +58,11 @@ def decode(self): next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - run_reqs, 
next_token_ids, next_token_logprobs, + run_reqs, + next_token_ids, + next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False, - extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req) + extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req), ) return - diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py index 4eab304a0..7ae9a4bf6 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -1,4 +1,3 @@ - from collections import defaultdict from typing import Dict, List, Any from torch import Tensor @@ -15,11 +14,13 @@ try: from nixl._api import nixl_agent as NixlWrapper from nixl._api import nixlBind + logger.info("Nixl is available") except ImportError: logger.warning("nixl is not installed, which is required for pd disagreggation!!!") NixlWrapper = None + @dataclass class NixlMetadata: id: str @@ -27,6 +28,7 @@ class NixlMetadata: agent_metadatas: list[bytes] agent_mem_descs: list[bytes] + class NixlKVTransporter: def __init__(self, node_id: int, tp_idx: int): self.node_id = node_id @@ -36,7 +38,7 @@ def __init__(self, node_id: int, tp_idx: int): self.num_layers = -1 self.num_tokens = -1 self.num_heads = -1 - self.head_dims= -1 + self.head_dims = -1 self.token_len = -1 self.layer_len = -1 @@ -62,13 +64,14 @@ def local_mem_desc(self): def get_new_notifs(self): return self.nixl_agent.get_new_notifs() - def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, agent_name: str = ""): base_addr, _, device_id, _ = reg_desc[0] tokens_data = [] for layer_id in range(self.num_layers): for token_id in range(num_tokens): - tokens_data.append((base_addr + layer_id * self.layer_len + token_id * self.token_len, self.token_len, device_id)) + tokens_data.append( + (base_addr + layer_id * self.layer_len + token_id * self.token_len, self.token_len, device_id) + ) descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) @@ -80,11 +83,10 @@ def register_kv_buffer(self, kv_buffer: Tensor): self.reg_desc = self.nixl_agent.register_memory(kv_buffer) self.local_xfer_handles = self._create_xfer_handles(self.reg_desc, self.num_tokens) - def add_remote_agent(self, remote_agent: NixlMetadata): - for idx, (agent_metadata, num_tokens, agent_mem_desc) in enumerate(zip(remote_agent.agent_metadatas, - remote_agent.num_tokens, - remote_agent.agent_mem_descs)): + for idx, (agent_metadata, num_tokens, agent_mem_desc) in enumerate( + zip(remote_agent.agent_metadatas, remote_agent.num_tokens, remote_agent.agent_mem_descs) + ): if self.tp_idx != idx: self.remote_agents[remote_agent.id].append(None) continue @@ -94,7 +96,10 @@ def add_remote_agent(self, remote_agent: NixlMetadata): logger.info("Added remote agent %s with mem desc %s", peer_name, mem_desc) kv_xfer_handles = self._create_xfer_handles(mem_desc, num_tokens, agent_name=peer_name) self.remote_agents[remote_agent.id].append( - RemoteAgent(name=peer_name, kv_mem_desc=mem_desc, num_tokens=num_tokens, kv_xfer_handles=kv_xfer_handles)) + RemoteAgent( + name=peer_name, kv_mem_desc=mem_desc, num_tokens=num_tokens, kv_xfer_handles=kv_xfer_handles + ) + ) def _get_token_desc_ids(self, token_ids: List[int]): descs_ids = [] @@ -108,7 +113,9 @@ def write_blocks(self, 
request: KVMoveRequest, prefill_request: PrefillRequest): skip_kv_move_len = prefill_request.data.local_cached_len src_token_ids = request.token_ids[skip_kv_move_len:] dst_token_ids = prefill_request.data.token_ids - remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][self.tp_idx] #TODO one-one mapping now + remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][ + self.tp_idx + ] # TODO one-one mapping now if len(src_token_ids) > 0: assert len(src_token_ids) == len(dst_token_ids), f"{len(src_token_ids)} {len(dst_token_ids)}" @@ -118,13 +125,12 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): src_handle = self.local_xfer_handles dst_handle = remote_agent.kv_xfer_handles notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=1) - handle = self.nixl_agent.make_prepped_xfer("WRITE", - src_handle, src_token_descs, - dst_handle, dst_token_descs, - notify_status.serialize()) + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status.serialize() + ) status = self.nixl_agent.transfer(handle) - assert status != 'ERR' + assert status != "ERR" self.inflight_transfers[group_reqeust_id] = (handle, remote_agent, False) @@ -132,7 +138,6 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): return None - def send_abort_notify(self, remote_id: int, group_reqeust_id): remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1) @@ -141,7 +146,6 @@ def send_abort_notify(self, remote_id: int, group_reqeust_id): if group_reqeust_id in self.inflight_transfers: self.inflight_transfers[group_reqeust_id][2] = True - def get_done_tranfers(self): done_req_ids = [] @@ -177,4 +181,3 @@ def shutdonw(self): for agent in agents: self.nixl_agent.remove_remote_agent(agent.name) self.nixl_agent.release_xfer_handle(agent.kv_xfer_handles) - diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py index 66605ba35..2e5fa7604 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -17,13 +17,14 @@ RemotePrefillRequest, RemotePrefillServerInfo, RemotePrefillTask, - RemotePrefillStatus + RemotePrefillStatus, ) from .nixl_kv_transporter import NixlMetadata logger = init_logger(__name__) + class SockWithPoller: def __init__(self, sock: zmq.Socket): self.sock = sock @@ -41,19 +42,21 @@ def recv_pyobj(self, timeout: int = 5): def send_pyobj(self, obj: Any): return self.sock.send_pyobj(obj) + class PDRemotePrefillBase: - def __init__(self, - id: int, - from_backend_queue: mp.Queue, - to_backend_queues: List[mp.Queue], - agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl - ): + def __init__( + self, + id: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl + ): self.id = id self.tp_size = len(agent_meta_queues) self.agent_meta_queues = agent_meta_queues self.from_backend_queue = from_backend_queue self.to_backend_queues = to_backend_queues - self.local_agent_meta = None + self.local_agent_meta = None def local_init(self): agent_metas = NixlMetadata(id=self.id, 
agent_metadatas=[], num_tokens=[], agent_mem_descs=[]) @@ -69,13 +72,15 @@ def local_init(self): class PDRemotePrefillServer(PDRemotePrefillBase): - def __init__(self, - id: int, - http_server_port: int, - server_port: int, - from_backend_queue: mp.Queue, - to_backend_queues: List[mp.Queue], - agent_meta_queues: List[mp.Queue]): + def __init__( + self, + id: int, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], + ): super().__init__(id, from_backend_queue, to_backend_queues, agent_meta_queues) # map from client id to decode server info self.remote_decode_clients = {} @@ -92,7 +97,6 @@ def __init__(self, self.prefill_requests = {} - def main_loop(self): self.local_init() while True: @@ -117,7 +121,6 @@ def main_loop(self): if not success: logger.warning(f"Remote connect failed: {request}") - if request.type == RemoteRequstType.REMOTE_PREFILL: request: PrefillRequest = request self.trigger_prefill(request) @@ -125,22 +128,21 @@ def main_loop(self): except Exception as e: logger.error(f"Error in remote prefill server loop: {e}", exc_info=e) - - def trigger_prefill(self, request: PrefillRequest): - self.send_to_httpserver.send_pyobj((request.data.prompt, request.data.sampling_params, request.data.multimodal_params)) + self.send_to_httpserver.send_pyobj( + (request.data.prompt, request.data.sampling_params, request.data.multimodal_params) + ) self.prefill_requests[request.data.sampling_params.group_request_id] = request - class PDRemotePrefillClient(PDRemotePrefillBase): - - def __init__(self, - id: int, - from_backend_queue: mp.Queue, # only tp0 will trigger prefill - to_backend_queues: List[mp.Queue], # one to many done queue - agent_meta_queues: List[mp.Queue] - ): + def __init__( + self, + id: int, + from_backend_queue: mp.Queue, # only tp0 will trigger prefill + to_backend_queues: List[mp.Queue], # one to many done queue + agent_meta_queues: List[mp.Queue], + ): super().__init__(id, from_backend_queue, to_backend_queues, agent_meta_queues) # map from server id to prefill server info @@ -157,12 +159,15 @@ def _connect_server(self, server_ip: str, port: int): return SockWithPoller(_socket) def _send_nixl_agent(self, socket: SockWithPoller): - socket.send_pyobj(ConnectRequest( - type=RemoteRequstType.REMOTE_CONNECT, - decode_id=self.id, - num_tokens=self.local_agent_meta.num_tokens, - agent_metadatas=self.local_agent_meta.agent_metadatas, - agent_mem_descs=self.local_agent_meta.agent_mem_descs)) + socket.send_pyobj( + ConnectRequest( + type=RemoteRequstType.REMOTE_CONNECT, + decode_id=self.id, + num_tokens=self.local_agent_meta.num_tokens, + agent_metadatas=self.local_agent_meta.agent_metadatas, + agent_mem_descs=self.local_agent_meta.agent_mem_descs, + ) + ) success = socket.recv_pyobj(timeout=60) if success is None: @@ -186,23 +191,24 @@ def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): logger.warning("Remote Prefill Server Connect Failed") return False - def main_loop(self): self.local_init() while True: try: prefill_tasks: RemotePrefillTask = self.from_backend_queue.get() # connect first - if(self.connect_to_prefill_server(prefill_tasks.server_info)): + if self.connect_to_prefill_server(prefill_tasks.server_info): # do prefill self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request) else: # failed to connect a remote for idx in self.to_backend_queues: - self.to_backend_queues.put(RemotePrefillStatus( - 
group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id, - status=-1, - )) + idx.put( + RemotePrefillStatus( + group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id, + status=-1, + ) + ) except Exception as e: logger.error(f"Remote prefill client loop error: {e}", exc_info=e) @@ -222,8 +228,9 @@ def remote_prefill_server_loop( agent_meta_queues: List[mp.Queue], ): graceful_registry(inspect.currentframe().f_code.co_name) - server = PDRemotePrefillServer(id, http_server_port, server_port, - from_backend_queue, to_backend_queues, agent_meta_queues) + server = PDRemotePrefillServer( + id, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues + ) server.main_loop() @@ -237,9 +244,7 @@ def start_pd_remote_prefill_server_process( ): proc = mp.Process( target=remote_prefill_server_loop, - args=( - id, http_server_port, server_port, - from_backend_queue, to_backend_queues, agent_meta_queues) + args=(id, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues), ) proc.start() assert proc.is_alive() @@ -248,30 +253,25 @@ def start_pd_remote_prefill_server_process( def remote_prefill_client_loop( - id: int, - from_backend_queue: mp.Queue, - to_backend_queues: List[mp.Queue], - agent_meta_queues: List[mp.Queue]): + id: int, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] +): graceful_registry(inspect.currentframe().f_code.co_name) client = PDRemotePrefillClient( - id, - from_backend_queue, - to_backend_queues, - agent_meta_queues, - ) + id, + from_backend_queue, + to_backend_queues, + agent_meta_queues, + ) client.main_loop() + def start_pd_remote_prefill_client_process( - id: int, - from_backend_queue: mp.Queue, - to_backend_queues: List[mp.Queue], - agent_meta_queues: List[mp.Queue] + id: int, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] ): proc = mp.Process( - target=remote_prefill_client_loop, - args=(id, from_backend_queue, to_backend_queues, agent_meta_queues) + target=remote_prefill_client_loop, args=(id, from_backend_queue, to_backend_queues, agent_meta_queues) ) proc.start() assert proc.is_alive() diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py index 871780d90..308fc8b19 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -18,17 +18,19 @@ logger.error("nixl is not installed, which is required for pd disagreggation!!!") raise + class RemoteRequstType(Enum): REMOTE_CONNECT = 1 REMOTE_PREFILL = 2 + @dataclass class RemotePrefillRequest: prompt: Union[str, List[int]] sampling_params: SamplingParams multimodal_params: MultimodalParams - local_cached_len: int # will skip transfer - token_ids: List[int] # mem cache indexes + local_cached_len: int # will skip transfer + token_ids: List[int] # mem cache indexes @dataclass @@ -130,4 +132,4 @@ def values(self): def clear(self) -> None: with self._lock: - self._dict.clear() \ No newline at end of file + self._dict.clear() diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index b88414257..564aea3a2 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py 
@@ -22,7 +22,7 @@ ChunckedPrefillForPrefillNode, DPChunkedForPrefillNode, PDNIXLBackendForPrefillNode, - PDNIXLBackendForDecodeNode + PDNIXLBackendForDecodeNode, ) from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray from lightllm.utils.log_utils import init_logger @@ -142,9 +142,7 @@ def init_model(self, kvargs): elif is_nixl_prefill_node: assert kvargs.get("args", None).dp == 1 - self.backend = PDNIXLBackendForPrefillNode(self.info_queue, - self.result_queue, - self.mem_queue) + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) elif is_decode_node: if kvargs.get("args", None).dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) @@ -153,9 +151,7 @@ def init_model(self, kvargs): elif is_nixl_decode_node: assert kvargs.get("args", None).dp == 1 - self.backend = PDNIXLBackendForDecodeNode(self.info_queue, - self.result_queue, - self.mem_queue) + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) elif kvargs.get("dp_size", 1) > 1: self.backend = DPChunkedPrefillBackend() elif use_reward_model: @@ -334,6 +330,7 @@ async def start_model_process( router_lock: mp.Queue, ): import lightllm.utils.rpyc_fix_utils as _ + # 单卡单机时不使用 rpc if node_world_size == 1 and args.nnodes == 1: return ModelRpcServer( From 24fef1f1d3ff24d1b8af640313c73224a5fec590 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 24 Apr 2025 16:46:48 +0800 Subject: [PATCH 10/19] support dp for pd_nixl --- lightllm/server/pd_io_struct.py | 8 ++ lightllm/server/router/manager.py | 19 ++- .../model_infer/mode_backend/__init__.py | 2 + .../mode_backend/pd_nixl/impl_for_pd_base.py | 6 +- .../pd_nixl/impl_for_pd_decode.py | 1 - .../pd_nixl/impl_for_pd_decode_dp.py | 121 ++++++++++++++++++ .../pd_nixl/impl_for_pd_prefill_dp.py | 109 ++++++++++++++++ .../mode_backend/pd_nixl/pd_remote_prefill.py | 93 +++++++------- .../pd_nixl/pd_remote_prefill_obj.py | 36 ++++++ .../server/router/model_infer/model_rpc.py | 16 ++- 10 files changed, 351 insertions(+), 60 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index a26fc9da5..b4831072a 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -127,6 +127,14 @@ class RemotePrefillServerInfo: prefill_server_ip: str prefill_server_port: int +@dataclass +class DistInfo: + world_size: int + nnodes: int + dp_size: int + dp_world_size: int + dp_size_in_node: int + node_world_size: int @dataclass class PDTransLeaveInfo: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 03ad8d127..0cdff2d44 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -34,6 +34,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.pd_io_struct import DistInfo logger = init_logger(__name__) @@ -49,6 +50,7 @@ def __init__(self, args, router_port, detokenization_port, metric_port): self.dp_size = args.dp # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, args.dp // self.nnodes) + self.dp_world_size = self.world_size // self.dp_size self.is_multinode_tp = args.nnodes > 1 and 
args.dp == 1 self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1 # 判断是否是保守调度,保守调度不会发生暂停 req 的情况,但是有些场景可能影响吞吐 @@ -118,7 +120,7 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] - self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.world_size)] + self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.node_world_size)] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() @@ -134,8 +136,8 @@ async def wait_to_model_ready(self): rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, info_queue=self.info_queue, - result_queue=self.result_queues[rank_id], - mem_queue=self.mem_queues[rank_id], + result_queue=self.result_queues[rank_id % node_world_size], + mem_queue=self.mem_queues[rank_id % node_world_size], router_lock=self.router_lock, ) self.model_rpc_servers.append(rpc_model) @@ -190,7 +192,7 @@ async def wait_to_model_ready(self): get_unique_server_name(), self.max_total_token_num, node_world_size=self.node_world_size, - dp_world_size=self.world_size // self.dp_size, + dp_world_size=self.dp_world_size, ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") @@ -208,8 +210,12 @@ async def wait_to_model_ready(self): start_pd_remote_prefill_server_process, ) + dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size, + self.dp_world_size, self.dp_size_in_node, self.node_world_size) + start_pd_remote_prefill_server_process( self.args.pd_node_id, + dist_info = dist_info, http_server_port=self.args.pd_remote_prefill_http_port, server_port=self.args.pd_remote_prefill_port, from_backend_queue=self.info_queue, @@ -229,9 +235,12 @@ async def wait_to_model_ready(self): from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( start_pd_remote_prefill_client_process, ) + dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size, + self.dp_world_size, self.dp_size_in_node, self.node_world_size) start_pd_remote_prefill_client_process( self.args.pd_node_id, + dist_info, from_backend_queue=self.info_queue, to_backend_queues=self.result_queues, agent_meta_queues=self.mem_queues, @@ -246,7 +255,7 @@ def add_req(self, group_req_indexes: GroupReqIndexes): req.multimodal_params = group_req_indexes.multimodal_params req.start_time = group_req_indexes.time_mark if isinstance(req, PDChunkedPrefillReq): - req.dp_world_size = self.world_size + req.dp_world_size = self.dp_world_size req_group.append(req) logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 0a4654fd4..fe13a60a7 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -14,3 +14,5 @@ from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode from .pd_nixl.impl_for_pd_prefill import PDNIXLBackendForPrefillNode from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode +from .pd_nixl.impl_for_pd_decode_dp import PDNIXLDPBackendForDecodeNode +from .pd_nixl.impl_for_pd_prefill_dp import PDNIXLDPBackendForPrefillNode diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py 
b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index 7f9765c9a..6bd04dcca 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -41,7 +41,7 @@ def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_ self.inflght_transfer_requests: ThreadSafeDict = ThreadSafeDict() def init_custom(self): - self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.tp_rank) + self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.rank_in_node) self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer) self.nixl_meta_queue.put( (self.nixl_agent.agent_metadata, self.nixl_agent.num_tokens, self.nixl_agent.local_mem_desc) @@ -243,11 +243,11 @@ def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): nopad_b_start_loc.append(start_loc) # last request input_ids = np.concatenate(input_ids, dtype=np.int64) - # g_infer_state_lock.acquire() # I don't think it's needed + if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) - # g_infer_state_lock.release() + kwargs = { "batch_size": len(run_reqs), "input_ids": input_ids, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py index 4ba256a20..80ac3f1c4 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py @@ -81,7 +81,6 @@ def decode(self): self.remote_prefilled_reqs[shm_req.group_req_id] = run_req if decode_reqs: - # print(f"decode req: {self.rank_in_dp}: {len(decode_reqs)}") kwargs, run_reqs = prepare_decode_inputs(decode_reqs) logits = self.model.forward(**kwargs) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py new file mode 100644 index 000000000..6be2873e9 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py @@ -0,0 +1,121 @@ +import time +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from typing import List +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDChunkedPrefillReq +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs + +from .impl_for_pd_decode import PDNIXLBackendForDecodeNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForDecodeNode(PDNIXLBackendForDecodeNode): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap + + def init_custom(self): + super().init_custom() + + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + from 
lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs + kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) + self.model.forward(**kwargs) + assert len(run_reqs) == 0 and padded_req_num == 1 + + return + + def decode(self): + + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + # filter out remote prefilling reqs + prefill_reqs, aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs) + + self._filter_reqs(aborted_reqs) + + # allocate kv cache, do remote prefill + if prefill_reqs: + # TODO: we could allocate cache later after remote prefill done and get a signal from remote + # but it will have a risk to not have enough cache for this request. + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + shm_req: PDChunkedPrefillReq = run_req.shm_req + # forward each req to remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + run_req.remote_prefill_start = time.time() + self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) + + shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req + + self.reduce_tensor.fill_(len(decode_reqs)) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX) + max_decode_num = self.reduce_tensor.item() + if max_decode_num != 0: + if not self.enable_decode_microbatch_overlap: + self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + else: + self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + + kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs( + decode_reqs, max_decode_num, is_multimodal=self.is_multimodal + ) + logits = self.model.forward(**kwargs) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + if len(run_reqs) != 0: + logits = logits[0 : len(run_reqs), :] + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) + return + + def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( + padded_overlap_prepare_decode_inputs, + ) + + ( + micro_batch, + run_reqs, + padded_req_num, + micro_batch1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal) + + logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + req_num, 
req_num1 = len(run_reqs), len(run_reqs1) + all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) + + all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) + + all_run_reqs = run_reqs + run_reqs1 + if all_run_reqs: + next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py new file mode 100644 index 000000000..3e72f73ce --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py @@ -0,0 +1,109 @@ +import threading +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from typing import List, Tuple +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs + +from .impl_for_pd_base import PDNIXLBackendBase +from .impl_for_pd_prefill import PDNIXLBackendForPrefillNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForPrefillNode(PDNIXLBackendForPrefillNode): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap + + def init_custom(self): + super().init_custom() + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + return + + def decode(self): + uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + + ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) + + assert len(uinit_reqs) == 0 + assert len(decode_reqs) == 0 + + self._prefill_abort_remote(aborted_reqs) + self._filter_reqs(aborted_reqs) + + if ok_finished_reqs: + for req in ok_finished_reqs: + self._transfer_kv_to_remote(req) + self._filter_reqs(ok_finished_reqs) + ok_finished_reqs.clear() + + current_dp_prefill_num = len(prefill_reqs) + self.reduce_tensor.fill_(current_dp_prefill_num) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_prefill_num = self.reduce_tensor.item() + if max_prefill_num != 0: + if not self.enable_prefill_microbatch_overlap: + self.normal_prefill_reqs(prefill_reqs, max_prefill_num, uinit_reqs, ok_finished_reqs) + else: + self.overlap_prefill_reqs(prefill_reqs, max_prefill_num, uinit_reqs, ok_finished_reqs) + + 
self._overlap_req_init_and_filter(uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + + kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal + ) + logits = self.model.forward(**kwargs) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + if len(run_reqs) != 0: + logits = logits[0 : len(run_reqs), :] + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False, + extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req), + ) + + def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( + padded_overlap_prepare_prefill_inputs, + ) + + ( + micro_batch, + run_reqs, + padded_req_num, + micro_batch1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal) + logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) + + all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) + + all_run_reqs = run_reqs + run_reqs1 + if all_run_reqs: + next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False, + extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req), + ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py index 2e5fa7604..e07a46e50 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -1,13 +1,14 @@ from typing import List, Any import zmq import inspect -import pickle +import random import torch.multiprocessing as mp from lightllm.utils.log_utils import init_logger from lightllm.utils.net_utils import get_hostname_ip from lightllm.utils.graceful_utils import graceful_registry +from lightllm.server.pd_io_struct import DistInfo from .pd_remote_prefill_obj import ( ConnectRequest, @@ -18,41 +19,25 @@ RemotePrefillServerInfo, RemotePrefillTask, RemotePrefillStatus, + SockWithPoller, ) from .nixl_kv_transporter import NixlMetadata - logger = init_logger(__name__) -class SockWithPoller: - def __init__(self, sock: zmq.Socket): - self.sock = sock - self.poller = 
zmq.Poller() - self.poller.register(self.sock, zmq.POLLIN) - - def recv_pyobj(self, timeout: int = 5): - socks = dict(self.poller.poll(timeout * 1000)) - if socks: - if socks.get(self.sock) == zmq.POLLIN: - return self.sock.recv_pyobj() - else: - None - - def send_pyobj(self, obj: Any): - return self.sock.send_pyobj(obj) - - class PDRemotePrefillBase: def __init__( self, id: int, + dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl ): self.id = id - self.tp_size = len(agent_meta_queues) + self.dist_info = dist_info + assert len(agent_meta_queues) == dist_info.node_world_size self.agent_meta_queues = agent_meta_queues self.from_backend_queue = from_backend_queue self.to_backend_queues = to_backend_queues @@ -60,7 +45,7 @@ def __init__( def local_init(self): agent_metas = NixlMetadata(id=self.id, agent_metadatas=[], num_tokens=[], agent_mem_descs=[]) - for tp in range(self.tp_size): + for tp in range(self.dist_info.node_world_size): agent_metadata, num_tokens, mem_desc = self.agent_meta_queues[tp].get(timeout=60) logger.info(f"Received agent_metadata from {tp} with mem reg: {mem_desc}") agent_metas.num_tokens.append(num_tokens) @@ -75,41 +60,40 @@ class PDRemotePrefillServer(PDRemotePrefillBase): def __init__( self, id: int, + dist_info: DistInfo, http_server_port: int, server_port: int, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue], ): - super().__init__(id, from_backend_queue, to_backend_queues, agent_meta_queues) + super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) # map from client id to decode server info self.remote_decode_clients = {} # build control path _ctx = zmq.Context() - self.recv_from_decode = _ctx.socket(zmq.ROUTER) + self.recv_from_decode = SockWithPoller(_ctx.socket(zmq.ROUTER)) self.host_ip = get_hostname_ip() self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}") # build trigger remote prefill path - self.send_to_httpserver = _ctx.socket(zmq.PUSH) + self.send_to_httpserver = SockWithPoller(_ctx.socket(zmq.PUSH)) self.send_to_httpserver.connect(f"tcp://{self.host_ip}:{http_server_port}") - self.prefill_requests = {} - def main_loop(self): self.local_init() while True: try: - client_obj, msg = self.recv_from_decode.recv_multipart() - request: RemoteRequest = pickle.loads(msg) + client_obj, request = self.recv_from_decode.recv_pyobj_multipart() + request: RemoteRequest logger.info(f"received request from decode, type: {request.type}") - # forward request to prefill server - for queue in self.to_backend_queues: - queue.put(request) - if request.type == RemoteRequstType.REMOTE_CONNECT: + # forward request to all prefill servers + for queue in self.to_backend_queues: + queue.put(request) + success = True for idx in range(self.tp_size): ack = self.from_backend_queue.get() @@ -117,33 +101,45 @@ def main_loop(self): success = False break - self.recv_from_decode.send_multipart([client_obj, pickle.dumps(success)]) + self.recv_from_decode.send_pyobj_multipart(client_obj, success) if not success: logger.warning(f"Remote connect failed: {request}") if request.type == RemoteRequstType.REMOTE_PREFILL: request: PrefillRequest = request - self.trigger_prefill(request) + if self.dist_info.dp_size_in_node > 1: + group_req_id = request.data.sampling_params.group_request_id + suggested_dp_index = request.data.sampling_params.suggested_dp_index + if suggested_dp_index < 0: # not likely to happen + suggested_dp_index = random.randint(0, self.dist_info.dp_size_in_node - 1) + request.data.sampling_params.suggested_dp_index = suggested_dp_index + logger.warning(f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}") + + for local_rank in range(suggested_dp_index * self.dist_info.dp_world_size, + (suggested_dp_index + 1) * self.dist_info.dp_world_size): + self.to_backend_queues[local_rank].put(request) + else: + for queue in self.to_backend_queues: + queue.put(request) + + self.send_to_httpserver.send_pyobj( + (request.data.prompt, request.data.sampling_params, request.data.multimodal_params) + ) except Exception as e: logger.error(f"Error in remote prefill server loop: {e}", exc_info=e) - def trigger_prefill(self, request: PrefillRequest): - self.send_to_httpserver.send_pyobj( - (request.data.prompt, request.data.sampling_params, request.data.multimodal_params) - ) - self.prefill_requests[request.data.sampling_params.group_request_id] = request - class PDRemotePrefillClient(PDRemotePrefillBase): def __init__( self, id: int, + dist_info: DistInfo, from_backend_queue: mp.Queue, # only tp0 will trigger prefill to_backend_queues: List[mp.Queue], # one to many done queue agent_meta_queues: List[mp.Queue], ): - super().__init__(id, from_backend_queue, to_backend_queues, agent_meta_queues) + super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) # map from server id to prefill server info self.remote_prefill_servers = {} @@ -221,6 +217,7 @@ def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): def remote_prefill_server_loop( id: int, + dist_info: DistInfo, http_server_port: int, server_port: int, from_backend_queue: mp.Queue, @@ -229,13 +226,14 @@ ): graceful_registry(inspect.currentframe().f_code.co_name) server = PDRemotePrefillServer( - id, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues + id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues ) server.main_loop() def start_pd_remote_prefill_server_process( id: int, + dist_info: DistInfo, http_server_port: int, server_port: int, from_backend_queue: mp.Queue, @@ -244,7 +242,7 @@ ): proc = mp.Process( target=remote_prefill_server_loop, - args=(id, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues), + args=(id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues), ) proc.start() assert proc.is_alive() @@ -253,12 +251,13 @@ def remote_prefill_client_loop( - id: int, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] + id: int, dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] ): graceful_registry(inspect.currentframe().f_code.co_name) client = PDRemotePrefillClient( id, + dist_info, from_backend_queue, to_backend_queues, agent_meta_queues, @@ -267,11 +266,11 @@ def start_pd_remote_prefill_client_process( - id: int, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] + id: int, dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] ): proc = mp.Process( - 
target=remote_prefill_client_loop, args=(id, from_backend_queue, to_backend_queues, agent_meta_queues) + target=remote_prefill_client_loop, args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) ) proc.start() assert proc.is_alive() diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py index 308fc8b19..17cb64e10 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -3,6 +3,8 @@ import json from typing import List, Union, Optional, Any from threading import Lock +import pickle +import zmq from lightllm.utils.log_utils import init_logger from lightllm.server.core.objs import SamplingParams @@ -76,6 +78,8 @@ class RemoteAgent: class RemotePrefillStatus: group_req_id: int status: int + chunk_id: int + is_last: bool def serialize(self): return json.dumps(asdict(self)).encode() @@ -133,3 +137,35 @@ def values(self): def clear(self) -> None: with self._lock: self._dict.clear() + + +class SockWithPoller: + def __init__(self, sock: zmq.Socket): + self.sock = sock + self.poller = zmq.Poller() + self.poller.register(self.sock, zmq.POLLIN) + + def recv_pyobj(self, timeout: int = 5): + socks = dict(self.poller.poll(timeout * 1000)) + if socks: + if socks.get(self.sock) == zmq.POLLIN: + return self.sock.recv_pyobj() + else: + None + + def send_pyobj(self, obj: Any): + return self.sock.send_pyobj(obj) + + def recv_pyobj_multipart(self): + client_id, data = self.sock.recv_multipart() + return client_id, pickle.loads(data) + + def send_pyobj_multipart(self, client_id: str, data: Any): + return self.sock.send_multipart([client_id, pickle.dumps(data)]) + + def bind(self, addr: str): + return self.sock.bind(addr) + + def connect(self, addr: str): + return self.sock.connect(addr) + diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 564aea3a2..e20d1bffd 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -23,6 +23,8 @@ DPChunkedForPrefillNode, PDNIXLBackendForPrefillNode, PDNIXLBackendForDecodeNode, + PDNIXLDPBackendForPrefillNode, + PDNIXLDPBackendForDecodeNode, ) from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray from lightllm.utils.log_utils import init_logger @@ -141,8 +143,11 @@ def init_model(self, kvargs): self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) elif is_nixl_prefill_node: - assert kvargs.get("args", None).dp == 1 - self.backend = PDNIXLBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + if kvargs.get("args", None).dp > 1: + self.backend = PDNIXLDPBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + elif is_decode_node: if kvargs.get("args", None).dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) @@ -150,8 +155,11 @@ def init_model(self, kvargs): self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) elif is_nixl_decode_node: - assert kvargs.get("args", None).dp == 1 - self.backend = PDNIXLBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + if kvargs.get("args", None).dp > 1: + self.backend = 
PDNIXLDPBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + elif kvargs.get("dp_size", 1) > 1: self.backend = DPChunkedPrefillBackend() elif use_reward_model: From 7daf76886ab44610866e125814bf5c25ad33cd89 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 24 Apr 2025 19:51:33 +0800 Subject: [PATCH 11/19] chunk transfer. --- .../mode_backend/pd_nixl/impl_for_pd_base.py | 61 +++++++++---- .../pd_nixl/nixl_kv_transporter.py | 89 ++++++++++++++----- .../mode_backend/pd_nixl/pd_remote_prefill.py | 9 +- .../pd_nixl/pd_remote_prefill_obj.py | 20 ++++- 4 files changed, 134 insertions(+), 45 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index 6bd04dcca..a3d27ed9f 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -19,6 +19,7 @@ KVMoveRequest, RemotePrefillStatus, ThreadSafeDict, + TransferState, ) logger = init_logger(__name__) @@ -55,15 +56,17 @@ def handle_remote_prefill(req_status: RemotePrefillStatus): status = req_status.status if status != 1: logger.warning(f"remote prefill reqeust: {group_req_id} done with state: {status}") + if run_req := self.remote_prefilled_reqs.get(group_req_id, None): - shm_req: PDChunkedPrefillReq = run_req.shm_req - shm_req.set_pd_req_rank_state(self.rank_in_dp, status) - self.remote_prefilled_reqs.pop(group_req_id) - if self.is_master_in_dp: - logger.info( - f"remote prefill reqeust: {group_req_id} done with status: {status} " - f"took: {time.time() - run_req.remote_prefill_start} seconds" - ) + if req_status.is_last or status != 1: + shm_req: PDChunkedPrefillReq = run_req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, status) + self.remote_prefilled_reqs.pop(group_req_id) + if self.is_master_in_dp: + logger.info( + f"remote prefill reqeust: {group_req_id} done with status: {status} " + f"took: {time.time() - run_req.remote_prefill_start} seconds" + ) else: if self.is_master_in_dp: logger.warning(f"remote prefill reqeust: {group_req_id} not found") @@ -100,12 +103,15 @@ def _wait_transfer_loop(self): req: InferReq = self.inflght_transfer_requests[req_id] shm_req: PDChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, state) - del self.inflght_transfer_requests[req_id] + transfer_state = self.remote_prefill_requests[req_id].transfer_state if self.is_master_in_dp: logger.info( f"req: {req_id} kv transfer with state: {state} " - f"took: {time.time() - req.kv_transfer_start} seconds" + f"took: {time.time() - transfer_state.start_time} seconds" ) + del self.remote_prefill_requests[req_id] + del self.inflght_transfer_requests[req_id] + time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) def _handle_prefill_loop(self): @@ -140,20 +146,37 @@ def _transfer_kv_to_remote(self, req: InferReq): logger.info(f"remote prefill request {group_req_id} not found") return - # kick off kv transfer - if req.finish_status.is_finished(): - req.kv_transfer_start = time.time() - kv_transfer_req = KVMoveRequest( - group_req_id=group_req_id, - token_ids=self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].tolist(), + remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id] + if remote_request.transfer_state is None: + 
remote_request.transfer_state = TransferState( + start_time=time.time(), + current_kv_len=0, + current_chunk_id=0, ) - remote_request = self.remote_prefill_requests[group_req_id] - self.nixl_agent.write_blocks(kv_transfer_req, remote_request) + + transfer_state = remote_request.transfer_state + token_index = self.model.req_manager.req_to_token_indexs[req.req_idx] + is_finished = req.finish_status.is_finished() + + kv_transfer_req = KVMoveRequest( + group_req_id=group_req_id, + token_ids=token_index[: req.cur_kv_len].tolist(), + prev_kv_len=transfer_state.current_kv_len, + cur_kv_len=req.cur_kv_len, + ) + # kick off kv transfer + self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) + + if transfer_state.current_chunk_id == 0: shm_req: PDChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) - req.kv_transfering = True + req.in_prefill_or_transfer = True self.inflght_transfer_requests[group_req_id] = req + transfer_state.current_kv_len = req.cur_kv_len + transfer_state.current_chunk_id += 1 + + + def _decode_filter_reqs( self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq] ): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py index 7ae9a4bf6..bbfe5df3f 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -6,7 +6,10 @@ from lightllm.utils.log_utils import init_logger -from .pd_remote_prefill_obj import RemoteAgent, KVMoveRequest, PrefillRequest, RemotePrefillStatus, ThreadSafeDict +from .pd_remote_prefill_obj import ( + RemoteAgent, KVMoveRequest, PrefillRequest, + RemotePrefillStatus, ThreadSafeDict, KVMoveRequestState + ) logger = init_logger(__name__) @@ -108,11 +111,20 @@ def _get_token_desc_ids(self, token_ids: List[int]): descs_ids.append(layer_id * self.num_tokens + token_id) return descs_ids - def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): + def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, is_finished: bool): group_reqeust_id = request.group_req_id skip_kv_move_len = prefill_request.data.local_cached_len - src_token_ids = request.token_ids[skip_kv_move_len:] - dst_token_ids = prefill_request.data.token_ids + + # current kv len does not exceed the remote cached kv len, nothing new to transfer + if request.cur_kv_len <= skip_kv_move_len: + return + + kv_move_start = max(skip_kv_move_len, request.prev_kv_len) + kv_move_end = request.cur_kv_len + + src_token_ids = request.token_ids[kv_move_start:] + dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len : kv_move_end - skip_kv_move_len] + remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][ self.tp_idx ] # TODO one-one mapping now @@ -124,7 +136,12 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): src_handle = self.local_xfer_handles dst_handle = remote_agent.kv_xfer_handles - notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=1) + notify_status = RemotePrefillStatus( + group_req_id=group_reqeust_id, + status=1, + chunk_id=prefill_request.transfer_state.current_chunk_id, + is_last=is_finished) + handle = self.nixl_agent.make_prepped_xfer( "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status.serialize() ) @@ -132,7 +149,14 @@ def write_blocks(self, 
request: KVMoveRequest, prefill_request: PrefillRequest): status = self.nixl_agent.transfer(handle) assert status != "ERR" - self.inflight_transfers[group_reqeust_id] = (handle, remote_agent, False) + if group_reqeust_id not in self.inflight_transfers: + self.inflight_transfers[group_reqeust_id] = KVMoveRequestState( + handles=[], + done_handles=[], + remote_agent=remote_agent, + abort=False + ) + self.inflight_transfers[group_reqeust_id].handles.append(handle) return handle @@ -140,36 +164,57 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest): def send_abort_notify(self, remote_id: int, group_reqeust_id): remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] - notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1) + notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1, chunk_id=-1, is_last=True) self.nixl_agent.send_notif(remote_agent.name, notify_status.serialize()) if group_reqeust_id in self.inflight_transfers: - self.inflight_transfers[group_reqeust_id][2] = True + self.inflight_transfers[group_reqeust_id].abort = True def get_done_tranfers(self): done_req_ids = [] - for req_id, (handle, remote_agent, is_abort) in self.inflight_transfers.items(): - if is_abort: + for req_id, kv_move_state in self.inflight_transfers.items(): + kv_move_state: KVMoveRequestState + if kv_move_state.abort: logger.warning(f"{req_id} Transfer aborted") done_req_ids.append((req_id, -1)) continue - remote_agent: RemoteAgent - xfer_state = self.nixl_agent.check_xfer_state(handle) - if xfer_state == "DONE": - done_req_ids.append((req_id, 1)) - elif xfer_state == "PROC": - continue - else: - logger.warning(f"{req_id} Transfer failed with state {xfer_state}") + remote_agent: RemoteAgent = kv_move_state.remote_agent + + left_handles = [] + failed = False + for handle in kv_move_state.handles: + if failed: + left_handles.append(handle) + continue + + xfer_state = self.nixl_agent.check_xfer_state(handle) + + if xfer_state == "DONE": + kv_move_state.done_handles.append(handle) + elif xfer_state == "PROC": + left_handles.append(handle) + else: + logger.warning(f"{req_id} Transfer failed with state {xfer_state}") + failed = True + kv_move_state.done_handles.append(handle) + notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1, chunk_id=-1, is_last=True) + self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) + + kv_move_state.handles = left_handles + + if failed: done_req_ids.append((req_id, -1)) - notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1) - self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) + elif len(left_handles) == 0: + done_req_ids.append((req_id, 1)) for req_id, _ in done_req_ids: - # release will abort inflight transfer - self.nixl_agent.release_xfer_handle(self.inflight_transfers[req_id][0]) + kv_move_state: KVMoveRequestState = self.inflight_transfers[req_id] + for handle in kv_move_state.handles + kv_move_state.done_handles: + # release will abort inflight transfer + self.nixl_agent.release_xfer_handle(handle) + del self.inflight_transfers[req_id] return done_req_ids diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py index e07a46e50..79490a703 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py +++ 
b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -95,13 +95,15 @@ def main_loop(self): queue.put(request) success = True - for idx in range(self.tp_size): + for idx in range(self.dist_info.node_world_size): ack = self.from_backend_queue.get() + logger.info(f"received ack from backend {idx}: {ack}") if ack != "OK": success = False break self.recv_from_decode.send_pyobj_multipart(client_obj, success) + logger.info(f"Sent ack to decode: {success}") if not success: logger.warning(f"Remote connect failed: {request}") @@ -166,6 +168,7 @@ def _send_nixl_agent(self, socket: SockWithPoller): ) success = socket.recv_pyobj(timeout=60) + logger.info(f"recv remote nixl connect response {success}") if success is None: logger.warning("timeout to recv remote nixl connect response") return False @@ -203,6 +206,8 @@ def main_loop(self): RemotePrefillStatus( group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id, status=-1, + chunk_id=-1, + is_last=True, ) ) except Exception as e: @@ -212,7 +217,7 @@ def main_loop(self): def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): socket, _ = self.remote_prefill_servers[server_id] prefill_request.sampling_params.max_new_tokens = 1 - socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request)) + socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None)) def remote_prefill_server_loop( diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py index 17cb64e10..32d9aa1ad 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -14,7 +14,7 @@ logger = init_logger(__name__) try: - from nixl._api import nixlBind, nixl_prepped_dlist_handle + from nixl._api import nixlBind, nixl_prepped_dlist_handle, nixl_xfer_handle except ImportError: logger.error("nixl is not installed, which is required for pd disagreggation!!!") @@ -53,17 +53,26 @@ class ConnectRequest(RemoteRequest): agent_metadatas: List[bytes] agent_mem_descs: List[bytes] +@dataclass +class TransferState: + start_time: float + current_kv_len: int + current_chunk_id: int @dataclass class PrefillRequest(RemoteRequest): decode_id: int data: RemotePrefillRequest + # transfer status + transfer_state: Optional[TransferState] @dataclass class KVMoveRequest: group_req_id: int token_ids: List[int] + prev_kv_len: int + cur_kv_len: int @dataclass @@ -73,6 +82,13 @@ class RemoteAgent: kv_mem_desc: nixlBind.nixlRegDList kv_xfer_handles: nixl_prepped_dlist_handle +@dataclass +class KVMoveRequestState: + handles: List[nixl_xfer_handle] + done_handles: List[nixl_xfer_handle] + remote_agent: RemoteAgent + abort: bool + @dataclass class RemotePrefillStatus: @@ -160,7 +176,7 @@ def recv_pyobj_multipart(self): client_id, data = self.sock.recv_multipart() return client_id, pickle.loads(data) - def send_pyobj_multipart(self, client_id: str, data: Any): + def send_pyobj_multipart(self, client_id: bytes, data: Any): return self.sock.send_multipart([client_id, pickle.dumps(data)]) def bind(self, addr: str): From 046db94208b8462a02733e1bf059931372cf9580 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 20 May 2025 14:56:31 +0800 Subject: [PATCH 12/19] fix max_new_token. 
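In the NIXL flow the decode node owns the entire generation budget: the
prefill side is pinned to a single trigger token when the request is
shipped over (remote_prefill sets max_new_tokens = 1 on its copy of the
sampling params), and that token is never streamed to the client, so
forwarding max_new_tokens - 1 here undercounted the decode budget by one.
A minimal sketch of the arithmetic this one-liner corrects (hypothetical
helper, names are illustrative and not part of the codebase):

    def forwarded_budget(max_new_tokens: int, nixl_mode: bool) -> int:
        # Legacy PD flow: the prefill node emits the first real token,
        # so the decode node only needs max_new_tokens - 1 more.
        # NIXL flow: the prefill node's lone token is just a trigger,
        # so the decode node must keep the full budget.
        return max_new_tokens if nixl_mode else max_new_tokens - 1

    assert forwarded_budget(16, nixl_mode=True) == 16   # after this patch
    assert forwarded_budget(16, nixl_mode=False) == 15  # legacy behavior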
--- lightllm/server/httpserver_for_pd_master/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 090b142e7..d97770afb 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -272,7 +272,7 @@ async def fetch_stream_nixl( "node_id": p_start_args["pd_node_id"], "ip": p_start_args["host"], "rpyc_port": p_start_args["pd_remote_prefill_port"], - "max_new_tokens": sampling_params.max_new_tokens - 1, + "max_new_tokens": sampling_params.max_new_tokens, "pd_master_node_id": self.args.pd_node_id, } From c6871e14ebd90f58ceb319bf562e1159f569c90e Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 28 May 2025 15:40:59 +0800 Subject: [PATCH 13/19] rename args and update request state with list. --- lightllm/server/api_cli.py | 4 +- lightllm/server/core/objs/__init__.py | 2 +- lightllm/server/core/objs/req.py | 63 ++++++++++--------- lightllm/server/core/objs/shm_req_manager.py | 4 +- lightllm/server/httpserver/pd_loop.py | 2 +- .../httpserver_for_pd_master/manager.py | 2 +- lightllm/server/router/batch.py | 5 +- lightllm/server/router/manager.py | 10 +-- .../server/router/model_infer/infer_batch.py | 5 +- .../mode_backend/pd_nixl/impl_for_pd_base.py | 14 ++--- .../pd_nixl/impl_for_pd_decode.py | 4 +- .../pd_nixl/impl_for_pd_decode_dp.py | 4 +- 12 files changed, 61 insertions(+), 58 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 640d19f7a..7b5b021f6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -55,14 +55,14 @@ def make_argument_parser() -> argparse.ArgumentParser: help="The port number for the config server in config_server mode.", ) parser.add_argument( - "--pd_remote_prefill_http_port", + "--pd_nixl_remote_prefill_http_port", type=int, default=42001, help="nixl pd mode, prefill node used for triggering prefill http port.", ) parser.add_argument( - "--pd_remote_prefill_port", + "--pd_nixl_remote_prefill_port", type=int, default=42002, help="nixl pd mode, prefill and decode used for meta exchange.", diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 2bdd9a51c..0096d7baa 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,4 +1,4 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus, PDChunkedPrefillReq +from .req import Req, FinishStatus, PDNIXLChunkedPrefillReq from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 24538e714..5d9789fc8 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -94,7 +94,6 @@ class Req(ctypes.Structure): ("reward_score", ctypes.c_float), # 请求回复累计概率和 ("cumlogprob", ctypes.c_float), - ("dp_world_size", ctypes.c_int), ] def get_str(self): @@ -337,44 +336,50 @@ def post_init( return -class PDChunkedPrefillReq(ChunkedPrefillReq): +class PdNixlReqState(ctypes.Structure): _pack_ = 4 - _MAX_TP_SIZE = 128 - - def post_init(self): - super().post_init() - self.create_pd_req_state_shm_array() + _MAX_TP_SIZE = 32 + _fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)] + def __init__(self): self.dp_world_size = 0 + self.state = (ctypes.c_int * self._MAX_TP_SIZE)(*([0] * self._MAX_TP_SIZE)) 
- def create_pd_req_state_shm_array(self): - service_uni_name = get_unique_server_name() - name = f"{service_uni_name}_shm_pd_req_state_{self.index_in_shm_mem}" - # self.dp_world_size = PDChunkedPrefillReq._MAX_TP_SIZE # get_dp_world_size() - self.pd_req_state_shm = ShmArray(name, (PDChunkedPrefillReq._MAX_TP_SIZE + 1,), dtype=np.int8) - self.pd_req_state_shm.create_shm() - self.pd_req_state_shm.arr.fill(0) - return + def set_dp_world_size(self, size: int): + self.dp_world_size = size - def link_pd_req_state_shm_array(self): - service_uni_name = get_unique_server_name() - # self.dp_world_size = PDChunkedPrefillReq._MAX_TP_SIZE #get_dp_world_size() - name = f"{service_uni_name}_shm_pd_req_state_{self.index_in_shm_mem}" - self.pd_req_state_shm = ShmArray(name, (PDChunkedPrefillReq._MAX_TP_SIZE + 1,), dtype=np.int8) - self.pd_req_state_shm.link_shm() - return + def set_tp_state(self, tp_id: int, state: int): + assert self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size, \ + f"tp_id {tp_id} out of range [0, {self.dp_world_size})" + self.state[tp_id] = state + + def set_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + unique_state = np.unique(self.state[: self.dp_world_size]) + self.state[self.dp_world_size] = unique_state[0] - # called by each tp rank, no contention + def get_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + return self.state[self.dp_world_size] + + +class PDNIXLChunkedPrefillReq(ChunkedPrefillReq): + _pack_ = 4 + _fields_ = ChunkedPrefillReq._fields_ + [ + # 用于pd nixl状态同步 + ("pd_nixl_req_state", PdNixlReqState)] + + def set_dp_world_size(self, dp_world_size): + self.pd_nixl_req_state.dp_world_size = dp_world_size + + # called by each tp rank, no contention def set_pd_req_rank_state(self, tp_id: int, state: int): - self.pd_req_state_shm.arr[tp_id] = state + self.pd_nixl_req_state.set_tp_state(tp_id, state) # state: -1 for failed, 0 for in progress, 1 for success # set by router def set_pd_req_state(self): - assert self.dp_world_size > 0, "dp_world_size should be set before calling this" - unique_state = np.unique(self.pd_req_state_shm.arr[: self.dp_world_size]) - self.pd_req_state_shm.arr[self.dp_world_size] = unique_state[0] + self.pd_nixl_req_state.set_state() # read by all rank def get_pd_req_state(self): - assert self.dp_world_size > 0, "dp_world_size should be set before calling this" - return self.pd_req_state_shm.arr[self.dp_world_size] + return self.pd_nixl_req_state.get_state() \ No newline at end of file diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index fd7901158..5dd73ad02 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -3,7 +3,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger -from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq, PDChunkedPrefillReq +from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq, PDNIXLChunkedPrefillReq from .shm_array import ShmArray from .atomic_array_lock import AtomicShmArrayLock, AtomicLockItem from .atomic_lock import AtomicShmLock @@ -34,7 +34,7 @@ def get_req_class_type(self): return TokenHealingReq if args.run_mode in ["nixl_prefill", "nixl_decode"]: - return PDChunkedPrefillReq + return PDNIXLChunkedPrefillReq if args.disable_chunked_prefill: return NormalReq diff 
--git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 7eb95145a..94394182d 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -194,7 +194,7 @@ async def pd_handle_loop_from_d(manager: HttpServerManager): context = zmq.asyncio.Context(2) manager.recv_from_d = context.socket(zmq.PULL) - manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_remote_prefill_http_port}") + manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_nixl_remote_prefill_http_port}") while True: try: diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index d97770afb..65bec6a1c 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -271,7 +271,7 @@ async def fetch_stream_nixl( prefill_node_dict = { "node_id": p_start_args["pd_node_id"], "ip": p_start_args["host"], - "rpyc_port": p_start_args["pd_remote_prefill_port"], + "rpyc_port": p_start_args["pd_nixl_remote_prefill_port"], "max_new_tokens": sampling_params.max_new_tokens, "pd_master_node_id": self.args.pd_node_id, } diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 0b98d40af..5b8edeb29 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -1,7 +1,7 @@ import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union -from lightllm.server.core.objs import ShmReqManager, Req, PDChunkedPrefillReq +from lightllm.server.core.objs import ShmReqManager, Req, PDNIXLChunkedPrefillReq from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size @@ -54,8 +54,7 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): req = None else: unfinished_req_ids.append(req.request_id) - if isinstance(req, PDChunkedPrefillReq): - req.link_pd_req_state_shm_array() + if isinstance(req, PDNIXLChunkedPrefillReq): req.set_pd_req_state() self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 0cdff2d44..7f02c1e72 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -22,7 +22,7 @@ from .req_queue import build_req_queue from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager, PDChunkedPrefillReq +from lightllm.server.core.objs import ShmReqManager, PDNIXLChunkedPrefillReq from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs @@ -216,8 +216,8 @@ async def wait_to_model_ready(self): start_pd_remote_prefill_server_process( self.args.pd_node_id, dist_info = dist_info, - http_server_port=self.args.pd_remote_prefill_http_port, - server_port=self.args.pd_remote_prefill_port, + http_server_port=self.args.pd_nixl_remote_prefill_http_port, + server_port=self.args.pd_nixl_remote_prefill_port, from_backend_queue=self.info_queue, to_backend_queues=self.result_queues, agent_meta_queues=self.mem_queues, @@ -254,8 +254,8 @@ def add_req(self, group_req_indexes: GroupReqIndexes): req = self.shm_req_manager.get_req_obj_by_index(req_index) req.multimodal_params = group_req_indexes.multimodal_params req.start_time = group_req_indexes.time_mark - if isinstance(req, PDChunkedPrefillReq): - req.dp_world_size = 
self.dp_world_size + if isinstance(req, PDNIXLChunkedPrefillReq): + req.set_dp_world_size(self.dp_world_size) req_group.append(req) logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 3f165ca08..1f8888417 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -10,7 +10,7 @@ from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager from lightllm.utils.infer_utils import mark_start, mark_end -from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager, PDChunkedPrefillReq +from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager, PDNIXLChunkedPrefillReq from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id @@ -264,8 +264,7 @@ def init_all(self): self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() - if isinstance(self.shm_req, PDChunkedPrefillReq): - self.shm_req.link_pd_req_state_shm_array() + if isinstance(self.shm_req, PDNIXLChunkedPrefillReq): self.in_prefill_or_transfer = False self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index a3d27ed9f..c2b767f54 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -6,7 +6,7 @@ from lightllm.utils.log_utils import init_logger -from lightllm.server.core.objs.req import PDChunkedPrefillReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq @@ -59,7 +59,7 @@ def handle_remote_prefill(req_status: RemotePrefillStatus): if run_req := self.remote_prefilled_reqs.get(group_req_id, None): if req_status.is_last or status != 1: - shm_req: PDChunkedPrefillReq = run_req.shm_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, status) self.remote_prefilled_reqs.pop(group_req_id) if self.is_master_in_dp: @@ -101,7 +101,7 @@ def _wait_transfer_loop(self): continue req: InferReq = self.inflght_transfer_requests[req_id] - shm_req: PDChunkedPrefillReq = req.shm_req + shm_req: PDNIXLChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, state) transfer_state = self.remote_prefill_requests[req_id].transfer_state if self.is_master_in_dp: @@ -168,7 +168,7 @@ def _transfer_kv_to_remote(self, req: InferReq): self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) if transfer_state.current_chunk_id == 0: - shm_req: PDChunkedPrefillReq = req.shm_req + shm_req: PDNIXLChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) req.in_prefill_or_transfer = True self.inflght_transfer_requests[group_req_id] = req @@ -187,7 +187,7 @@ def _decode_filter_reqs( # filter 
out aborted requests for req in aborted_reqs: if req.in_prefill_or_transfer: - shm_req: PDChunkedPrefillReq = req.shm_req + shm_req: PDNIXLChunkedPrefillReq = req.shm_req state = shm_req.get_pd_req_state() if state != 0: new_aborted_reqs.append(req) @@ -198,7 +198,7 @@ def _decode_filter_reqs( for req in prefill_reqs: if req.in_prefill_or_transfer: - shm_req: PDChunkedPrefillReq = req.shm_req + shm_req: PDNIXLChunkedPrefillReq = req.shm_req # state is updated by router state = shm_req.get_pd_req_state() if state == 1: # success @@ -224,7 +224,7 @@ def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: L for req in ok_finished_reqs: if req.in_prefill_or_transfer: - shm_req: PDChunkedPrefillReq = req.shm_req + shm_req: PDNIXLChunkedPrefillReq = req.shm_req state = shm_req.get_pd_req_state() if state == 1: # success new_ok_finished_reqs.append(req) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py index 80ac3f1c4..6b30759a6 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py @@ -5,7 +5,7 @@ from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from typing import List, Tuple, Dict from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.core.objs.req import PDChunkedPrefillReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample @@ -69,7 +69,7 @@ def decode(self): kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) for idx, run_req in enumerate(run_reqs): run_req: InferReq = run_req - shm_req: PDChunkedPrefillReq = run_req.shm_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req # forward each req to remote prefill # since the token index are the same across TPs, we only need to trigger prefill on master if self.is_master_in_dp: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py index 6be2873e9..72db6fceb 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py @@ -4,7 +4,7 @@ import torch.distributed as dist from typing import List from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.core.objs.req import PDChunkedPrefillReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.utils.envs_utils import get_env_start_args @@ -49,7 +49,7 @@ def decode(self): kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) for idx, run_req in enumerate(run_reqs): run_req: InferReq = run_req - shm_req: PDChunkedPrefillReq = run_req.shm_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req # forward each req to remote prefill # since the token index are the same across TPs, we only need to trigger prefill on master if 
self.is_master_in_dp: From a52dbf147fbdeb80e14225d7cf38a333c752b5ae Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 28 May 2025 16:36:34 +0800 Subject: [PATCH 14/19] update & add docker --- Dockerfile.nixl | 89 +++++++++++++++++++ .../pd_nixl/impl_for_pd_prefill.py | 7 +- .../pd_nixl/impl_for_pd_prefill_dp.py | 12 +-- 3 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 Dockerfile.nixl diff --git a/Dockerfile.nixl b/Dockerfile.nixl new file mode 100644 index 000000000..8ac409505 --- /dev/null +++ b/Dockerfile.nixl @@ -0,0 +1,89 @@ +FROM nvcr.io/nvidia/tritonserver:25.04-py3-min as base +ARG PYTORCH_VERSION=2.6.0 +ARG PYTHON_VERSION=3.9 +ARG CUDA_VERSION=12.4 +ARG MAMBA_VERSION=23.1.0-1 +ARG TARGETPLATFORM +ARG INSTALL_NIXL=false + +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl-dev \ + curl \ + g++ \ + make \ + git && \ + rm -rf /var/lib/apt/lists/* + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + + +WORKDIR /root + +COPY ./requirements.txt /lightllm/requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip pip install -r /lightllm/requirements.txt --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124 + +RUN pip install --no-cache-dir https://github.com/ModelTC/flash-attn-3-build/releases/download/v2.7.4.post1/flash_attn-3.0.0b1-cp39-cp39-linux_x86_64.whl + +RUN --mount=type=cache,target=/root/.cache/pip pip install nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100 + +RUN if [ "$INSTALL_NIXL" == "true" ]; then \ + apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ + DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ + rm -rf /usr/lib/ucx && \ + rm -rf /opt/hpcx/ucx && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout v1.19.x && \ + ./autogen.sh && ./configure \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs=yes \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + ldconfig; \ + fi + +RUN if [ "$INSTALL_NIXL" == "true" ]; then \ + apt-get update && apt-get install -y pkg-config tmux net-tools; \ + cd /usr/local/src; \ + pip install --upgrade meson pybind11 patchelf; \ + git clone https://github.com/ai-dynamo/nixl.git -b main && \ + cd nixl && \ + rm -rf build && \ + mkdir build && \ + meson setup build/ --prefix=/usr/local/nixl && \ + cd build && \ + ninja && \ + ninja install && \ + cd .. && pip install . --no-deps; \ + fi + +COPY . 
/lightllm +RUN pip install -e /lightllm --no-cache-dir diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py index ed77c9263..78af8cc52 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py @@ -40,12 +40,7 @@ def decode(self): assert len(decode_reqs) == 0 self._prefill_abort_remote(aborted_reqs) - self._filter_reqs(aborted_reqs) - - if ok_finished_reqs: - for req in ok_finished_reqs: - self._transfer_kv_to_remote(req) - self._filter_reqs(ok_finished_reqs) + self._filter_reqs(aborted_reqs + ok_finished_reqs) if prefill_reqs: kwargs, run_reqs = prepare_prefill_inputs( diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py index 3e72f73ce..63605508c 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py @@ -39,13 +39,13 @@ def decode(self): assert len(decode_reqs) == 0 self._prefill_abort_remote(aborted_reqs) - self._filter_reqs(aborted_reqs) + self._filter_reqs(aborted_reqs + ok_finished_reqs) - if ok_finished_reqs: - for req in ok_finished_reqs: - self._transfer_kv_to_remote(req) - self._filter_reqs(ok_finished_reqs) - ok_finished_reqs.clear() + # if ok_finished_reqs: + # for req in ok_finished_reqs: + # self._transfer_kv_to_remote(req) + # self._filter_reqs(ok_finished_reqs) + # ok_finished_reqs.clear() current_dp_prefill_num = len(prefill_reqs) self.reduce_tensor.fill_(current_dp_prefill_num) From 20d129891c43f74c4cea26ee3557b3b87655cf20 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 29 May 2025 17:58:14 +0800 Subject: [PATCH 15/19] fix. 
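Four small fixes bundled here: a missing trailing space in the
recycle-loop log line (the concatenated f-strings ran together); the
request is now registered in inflght_transfer_requests and its rank
state initialized before nixl_agent.write_blocks is kicked off, so a
transfer that completes immediately cannot be reported by
_wait_transfer_loop before the bookkeeping exists; failed requests in
_decode_filter_reqs now go into new_aborted_reqs, the list being built,
instead of being appended back onto the already-consumed input list;
and ThreadSafeDict.__len__ read self._data, an attribute that never
existed. A minimal sketch of the wrapper that last fix touches,
reconstructed for illustration only, assuming storage lives in
self._dict as the class's other methods show:

    import threading

    class ThreadSafeDict:
        def __init__(self):
            self._dict = {}
            self._lock = threading.Lock()

        def __setitem__(self, key, value):
            with self._lock:
                self._dict[key] = value

        def __len__(self) -> int:
            with self._lock:
                # was len(self._data): AttributeError on first call
                return len(self._dict)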
--- lightllm/server/httpserver/manager.py | 2 +- .../model_infer/mode_backend/pd_nixl/impl_for_pd_base.py | 9 +++++---- .../mode_backend/pd_nixl/pd_remote_prefill_obj.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 91cae273f..0e9af6a83 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -611,7 +611,7 @@ async def recycle_resource_loop(self): pre_time_mark = time.time() for req_status in self.req_id_to_out_inf.values(): logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index c2b767f54..f076ae73e 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -164,14 +164,15 @@ def _transfer_kv_to_remote(self, req: InferReq): prev_kv_len=transfer_state.current_kv_len, cur_kv_len=req.cur_kv_len, ) - # kick off kv transfer - self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) - if transfer_state.current_chunk_id == 0: shm_req: PDNIXLChunkedPrefillReq = req.shm_req shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) req.in_prefill_or_transfer = True self.inflght_transfer_requests[group_req_id] = req + logger.debug(f"put {group_req_id} into inflght_transfer_requests and size: {len(self.inflght_transfer_requests)}") + + # kick off kv transfer + self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) transfer_state.current_kv_len = req.cur_kv_len transfer_state.current_chunk_id += 1 @@ -206,7 +207,7 @@ def _decode_filter_reqs( decode_reqs.append(req) req.in_prefill_or_transfer = False elif state == -1: # failure - aborted_reqs.append(req) + new_aborted_reqs.append(req) req.in_prefill_or_transfer = False elif state == 0: # in progress remote_prefill_reqs.append(req) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py index 32d9aa1ad..1e3b6f1dc 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -128,7 +128,7 @@ def __contains__(self, key): def __len__(self) -> int: with self._lock: - return len(self._data) + return len(self._dict) def get(self, key, default=None): with self._lock: From 597e516873ceba030932b148b73972d6fafe81a9 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 30 May 2025 10:18:23 +0800 Subject: [PATCH 16/19] fix lint. 
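Formatting-only pass: as far as these hunks show, every change is
whitespace, blank-line, or call-reflow churn (wrapping the long
DistInfo(...) and _post_handle(...) argument lists, two blank lines
between top-level definitions, restoring the newline at the end of
req.py); no behavioral change is intended.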
--- lightllm/server/core/objs/req.py | 13 +++++--- lightllm/server/pd_io_struct.py | 2 ++ lightllm/server/router/manager.py | 23 ++++++++++--- .../mode_backend/pd_nixl/impl_for_pd_base.py | 7 ++-- .../pd_nixl/impl_for_pd_decode_dp.py | 1 + .../pd_nixl/impl_for_pd_prefill_dp.py | 12 +++++-- .../pd_nixl/nixl_kv_transporter.py | 26 ++++++++------- .../mode_backend/pd_nixl/pd_remote_prefill.py | 33 ++++++++++++++----- .../pd_nixl/pd_remote_prefill_obj.py | 4 ++- 9 files changed, 86 insertions(+), 35 deletions(-) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 5d9789fc8..9ac010e98 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -340,6 +340,7 @@ class PdNixlReqState(ctypes.Structure): _pack_ = 4 _MAX_TP_SIZE = 32 _fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)] + def __init__(self): self.dp_world_size = 0 self.state = (ctypes.c_int * self._MAX_TP_SIZE)(*([0] * self._MAX_TP_SIZE)) @@ -348,8 +349,9 @@ def set_dp_world_size(self, size: int): self.dp_world_size = size def set_tp_state(self, tp_id: int, state: int): - assert self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size, \ - f"tp_id {tp_id} out of range [0, {self.dp_world_size})" + assert ( + self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size + ), f"tp_id {tp_id} out of range [0, {self.dp_world_size})" self.state[tp_id] = state def set_state(self): @@ -366,12 +368,13 @@ class PDNIXLChunkedPrefillReq(ChunkedPrefillReq): _pack_ = 4 _fields_ = ChunkedPrefillReq._fields_ + [ # 用于pd nixl状态同步 - ("pd_nixl_req_state", PdNixlReqState)] + ("pd_nixl_req_state", PdNixlReqState) + ] def set_dp_world_size(self, dp_world_size): self.pd_nixl_req_state.dp_world_size = dp_world_size - # called by each tp rank, no contention + # called by each tp rank, no contention def set_pd_req_rank_state(self, tp_id: int, state: int): self.pd_nixl_req_state.set_tp_state(tp_id, state) @@ -382,4 +385,4 @@ def set_pd_req_state(self): # read by all rank def get_pd_req_state(self): - return self.pd_nixl_req_state.get_state() \ No newline at end of file + return self.pd_nixl_req_state.get_state() diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index b4831072a..4267afaee 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -127,6 +127,7 @@ class RemotePrefillServerInfo: prefill_server_ip: str prefill_server_port: int + @dataclass class DistInfo: world_size: int @@ -136,6 +137,7 @@ class DistInfo: dp_size_in_node: int node_world_size: int + @dataclass class PDTransLeaveInfo: decode_id: int diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 7f02c1e72..ff654050b 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -210,12 +210,18 @@ async def wait_to_model_ready(self): start_pd_remote_prefill_server_process, ) - dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size, - self.dp_world_size, self.dp_size_in_node, self.node_world_size) + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) start_pd_remote_prefill_server_process( self.args.pd_node_id, - dist_info = dist_info, + dist_info=dist_info, http_server_port=self.args.pd_nixl_remote_prefill_http_port, server_port=self.args.pd_nixl_remote_prefill_port, from_backend_queue=self.info_queue, @@ -235,8 +241,15 @@ async def wait_to_model_ready(self): from 
lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( start_pd_remote_prefill_client_process, ) - dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size, - self.dp_world_size, self.dp_size_in_node, self.node_world_size) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) start_pd_remote_prefill_client_process( self.args.pd_node_id, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index f076ae73e..0edf8441f 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -160,7 +160,7 @@ def _transfer_kv_to_remote(self, req: InferReq): kv_transfer_req = KVMoveRequest( group_req_id=group_req_id, - token_ids=token_index[ : req.cur_kv_len].tolist(), + token_ids=token_index[: req.cur_kv_len].tolist(), prev_kv_len=transfer_state.current_kv_len, cur_kv_len=req.cur_kv_len, ) @@ -169,7 +169,9 @@ def _transfer_kv_to_remote(self, req: InferReq): shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) req.in_prefill_or_transfer = True self.inflght_transfer_requests[group_req_id] = req - logger.debug(f"put {group_req_id} into inflght_transfer_requests and size: {len(self.inflght_transfer_requests)}") + logger.debug( + f"put {group_req_id} into inflght_transfer_requests and size: {len(self.inflght_transfer_requests)}" + ) # kick off kv transfer self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) @@ -177,7 +179,6 @@ def _transfer_kv_to_remote(self, req: InferReq): transfer_state.current_kv_len = req.cur_kv_len transfer_state.current_chunk_id += 1 - def _decode_filter_reqs( self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq] ): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py index 72db6fceb..c6790da58 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py @@ -25,6 +25,7 @@ def init_custom(self): self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs + kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) self.model.forward(**kwargs) assert len(run_reqs) == 0 and padded_req_num == 1 diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py index 63605508c..5d3555e27 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py @@ -73,7 +73,11 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False, + run_reqs, + next_token_ids, + 
next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req), ) @@ -104,6 +108,10 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False, + all_run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req), ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py index bbfe5df3f..9e35e6508 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -7,9 +7,13 @@ from lightllm.utils.log_utils import init_logger from .pd_remote_prefill_obj import ( - RemoteAgent, KVMoveRequest, PrefillRequest, - RemotePrefillStatus, ThreadSafeDict, KVMoveRequestState - ) + RemoteAgent, + KVMoveRequest, + PrefillRequest, + RemotePrefillStatus, + ThreadSafeDict, + KVMoveRequestState, +) logger = init_logger(__name__) @@ -120,10 +124,10 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, return kv_move_start = max(skip_kv_move_len, request.prev_kv_len) - kv_move_end = request.cur_kv_len + kv_move_end = request.cur_kv_len src_token_ids = request.token_ids[kv_move_start:] - dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len: kv_move_end] + dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len : kv_move_end] remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][ self.tp_idx @@ -140,7 +144,8 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, group_req_id=group_reqeust_id, status=1, chunk_id=prefill_request.transfer_state.current_chunk_id, - is_last=is_finished) + is_last=is_finished, + ) handle = self.nixl_agent.make_prepped_xfer( "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status.serialize() @@ -151,10 +156,7 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, if group_reqeust_id not in self.inflight_transfers: self.inflight_transfers[group_reqeust_id] = KVMoveRequestState( - handles=[], - done_handles=[], - remote_agent=remote_agent, - abort=False + handles=[], done_handles=[], remote_agent=remote_agent, abort=False ) self.inflight_transfers[group_reqeust_id].handles.append(handle) @@ -199,7 +201,9 @@ def get_done_tranfers(self): logger.warning(f"{req_id} Transfer failed with state {xfer_state}") failed = True kv_move_state.done_handles.append(handle) - notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1, chunk_id=-1, is_last=True) + notify_failed_status = RemotePrefillStatus( + group_req_id=req_id, status=-1, chunk_id=-1, is_last=True + ) self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) kv_move_state.handles = left_handles diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py index 79490a703..9df2c5664 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -112,13 +112,17 @@ def main_loop(self): if self.dist_info.dp_size_in_node > 1: group_req_id = request.data.sampling_params.group_request_id suggested_dp_index = request.data.sampling_params.suggested_dp_index - if suggested_dp_index < 0: # not likely to happen + if suggested_dp_index < 0: # not likely to happen suggested_dp_index = random.randint(0, self.dist_info.dp_size_in_node) request.data.sampling_params.suggested_dp_index = suggested_dp_index - logger.warning(f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}") + logger.warning( + f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}" + ) - for local_rank in range(suggested_dp_index * self.dist_info.dp_world_size, - (suggested_dp_index + 1) * self.dist_info.dp_world_size): + for local_rank in range( + suggested_dp_index * self.dist_info.dp_world_size, + (suggested_dp_index + 1) * self.dist_info.dp_world_size, + ): self.to_backend_queues[local_rank].put(request) else: for queue in self.to_backend_queues: @@ -217,7 +221,11 @@ def main_loop(self): def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): socket, _ = self.remote_prefill_servers[server_id] prefill_request.sampling_params.max_new_tokens = 1 - socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None)) + socket.send_pyobj( + PrefillRequest( + type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None + ) + ) def remote_prefill_server_loop( @@ -256,7 +264,11 @@ def start_pd_remote_prefill_server_process( def remote_prefill_client_loop( - id: int, dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], ): graceful_registry(inspect.currentframe().f_code.co_name) @@ -271,11 +283,16 @@ def remote_prefill_client_loop( def start_pd_remote_prefill_client_process( - id: int, dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue] + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], ): proc = mp.Process( - target=remote_prefill_client_loop, args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) + target=remote_prefill_client_loop, + args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues), ) proc.start() assert proc.is_alive() diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py index 1e3b6f1dc..dc54f6a33 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -53,12 +53,14 @@ class ConnectRequest(RemoteRequest): agent_metadatas: List[bytes] agent_mem_descs: List[bytes] + @dataclass class TransferState: start_time: float current_kv_len: int current_chunk_id: int + @dataclass class PrefillRequest(RemoteRequest): decode_id: int @@ -82,6 +84,7 @@ class RemoteAgent: kv_mem_desc: 
nixlBind.nixlRegDList kv_xfer_handles: nixl_prepped_dlist_handle + @dataclass class KVMoveRequestState: handles: List[nixl_xfer_handle] @@ -184,4 +187,3 @@ def bind(self, addr: str): def connect(self, addr: str): return self.sock.connect(addr) - From dbad0f270950d9259b224307b1905ba19266f72d Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 4 Jun 2025 18:09:49 +0800 Subject: [PATCH 17/19] fix abort. --- Dockerfile.nixl | 2 +- .../router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Dockerfile.nixl b/Dockerfile.nixl index 8ac409505..332356af3 100644 --- a/Dockerfile.nixl +++ b/Dockerfile.nixl @@ -4,7 +4,7 @@ ARG PYTHON_VERSION=3.9 ARG CUDA_VERSION=12.4 ARG MAMBA_VERSION=23.1.0-1 ARG TARGETPLATFORM -ARG INSTALL_NIXL=false +ARG INSTALL_NIXL=true ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index 0edf8441f..85b28f8f7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -197,6 +197,8 @@ def _decode_filter_reqs( else: # TODO trigger remote abort remote_prefill_reqs.append(req) + else: + new_aborted_reqs.append(req) for req in prefill_reqs: if req.in_prefill_or_transfer: From 0829dc6cfd64662c5f6fb7c6eb9f5439ea3d64b8 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 6 Jun 2025 14:14:39 +0800 Subject: [PATCH 18/19] fix prefill num_tokens > decode num_tokens --- .../mode_backend/pd_nixl/nixl_kv_transporter.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py index 9e35e6508..352d7b247 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -47,7 +47,6 @@ def __init__(self, node_id: int, tp_idx: int): self.num_heads = -1 self.head_dims = -1 self.token_len = -1 - self.layer_len = -1 self.reg_desc = None self.local_xfer_handles = None @@ -73,11 +72,12 @@ def get_new_notifs(self): def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, agent_name: str = ""): base_addr, _, device_id, _ = reg_desc[0] + layer_len = num_tokens * self.token_len tokens_data = [] for layer_id in range(self.num_layers): for token_id in range(num_tokens): tokens_data.append( - (base_addr + layer_id * self.layer_len + token_id * self.token_len, self.token_len, device_id) + (base_addr + layer_id * layer_len + token_id * self.token_len, self.token_len, device_id) ) descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) @@ -85,7 +85,6 @@ def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, def register_kv_buffer(self, kv_buffer: Tensor): self.num_layers, self.num_tokens, self.num_heads, self.head_dim = kv_buffer.shape self.token_len = self.num_heads * self.head_dim * kv_buffer.element_size() - self.layer_len = self.num_tokens * self.token_len self.reg_desc = self.nixl_agent.register_memory(kv_buffer) self.local_xfer_handles = self._create_xfer_handles(self.reg_desc, self.num_tokens) @@ -108,11 
+107,11 @@ def add_remote_agent(self, remote_agent: NixlMetadata): ) ) - def _get_token_desc_ids(self, token_ids: List[int]): + def _get_token_desc_ids(self, token_ids: List[int], num_tokens: int): descs_ids = [] for layer_id in range(self.num_layers): for token_id in token_ids: - descs_ids.append(layer_id * self.num_tokens + token_id) + descs_ids.append(layer_id * num_tokens + token_id) return descs_ids def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, is_finished: bool): @@ -135,8 +134,8 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, if len(src_token_ids) > 0: assert len(src_token_ids) == len(dst_token_ids), f"{len(src_token_ids)} {len(dst_token_ids)}" - src_token_descs = self._get_token_desc_ids(src_token_ids) - dst_token_descs = self._get_token_desc_ids(dst_token_ids) + src_token_descs = self._get_token_desc_ids(src_token_ids, self.num_tokens) + dst_token_descs = self._get_token_desc_ids(dst_token_ids, remote_agent.num_tokens) src_handle = self.local_xfer_handles dst_handle = remote_agent.kv_xfer_handles From 68480558523ae128d4eb0dea7502a12241ec75f6 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 24 Jun 2025 17:19:14 +0800 Subject: [PATCH 19/19] fix chunked transfer & bugs. --- Dockerfile.nixl | 4 +-- lightllm/server/httpserver/manager.py | 14 ++++++----- .../model_infer/mode_backend/base_backend.py | 5 ++++ .../mode_backend/pd_nixl/impl_for_pd_base.py | 25 +++++++++++++++---- .../pd_nixl/impl_for_pd_prefill.py | 5 +++- .../pd_nixl/impl_for_pd_prefill_dp.py | 4 +-- .../pd_nixl/nixl_kv_transporter.py | 21 ++++++++++------ .../pd_nixl/pd_remote_prefill_obj.py | 1 + 8 files changed, 55 insertions(+), 24 deletions(-) diff --git a/Dockerfile.nixl b/Dockerfile.nixl index 332356af3..427a230e5 100644 --- a/Dockerfile.nixl +++ b/Dockerfile.nixl @@ -71,14 +71,14 @@ RUN if [ "$INSTALL_NIXL" == "true" ]; then \ fi RUN if [ "$INSTALL_NIXL" == "true" ]; then \ - apt-get update && apt-get install -y pkg-config tmux net-tools; \ + apt-get update && apt-get install -y pkg-config tmux net-tools; \ cd /usr/local/src; \ pip install --upgrade meson pybind11 patchelf; \ git clone https://github.com/ai-dynamo/nixl.git -b main && \ cd nixl && \ rm -rf build && \ mkdir build && \ - meson setup build/ --prefix=/usr/local/nixl && \ + meson setup build/ --prefix=/usr/local/nixl --buildtype=release && \ cd build && \ ninja && \ ninja install && \ diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0e9af6a83..14e971a6a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -412,7 +412,7 @@ async def transfer_to_next_module_or_node( if self.is_multinode_tp_master: async with self.transfer_lock: for sender in self.multinode_req_manager: - sender.send_pyobj( + await sender.send_pyobj( (prompt, sampling_params, original_multimodal_params), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -443,12 +443,14 @@ async def transfer_to_next_module( if self.pd_mode.is_P(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + + # P 模式下,直接将请求发送到路由器 + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -456,7 +458,7 @@ async def transfer_to_next_module( if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 - self.send_to_router.send_pyobj( + 
self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -464,12 +466,12 @@ async def transfer_to_next_module( if self.pd_mode.is_normal(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e4db0aca7..8f2531869 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -250,6 +250,7 @@ def _post_handle( is_chuncked_mode: bool, do_filter_finished_reqs: bool, extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + extra_post_req_handle_chunk_func: Optional[Callable[[InferReq], None]] = None, ) -> List[int]: """ extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 @@ -274,6 +275,10 @@ def _post_handle( finished_req_ids.append(req_obj.shm_req.request_id) continue + if extra_post_req_handle_chunk_func is not None: + # 如果存在额外的处理函数,则调用这个函数进行处理。 + extra_post_req_handle_chunk_func(req_obj) + # 对于没有到达需要输出 token 阶段的请求,直接略过 if req_obj.cur_kv_len < req_obj.get_cur_total_len(): continue diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py index 85b28f8f7..2156944d8 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -26,13 +26,14 @@ class PDNIXLBackendBase(ModeBackend): - _THEAD_WAIT_INTERVAL = 0.001 + _THREAD_WAIT_INTERVAL = 0.001 def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_meta_queue: mp.Queue): super().__init__() self.to_remote_queue = to_remote_queue self.from_remote_queue = from_remote_queue self.nixl_meta_queue = nixl_meta_queue + self.prefill_post_handle_queue = queue.Queue() # for decode self.remote_prefilled_reqs: ThreadSafeDict = ThreadSafeDict() @@ -85,12 +86,21 @@ def handle_remote_prefill(req_status: RemotePrefillStatus): prefill_status = RemotePrefillStatus.deserialize(req_status) handle_remote_prefill(prefill_status) - time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + def _handle_transfer_loop(self): + while True: + try: + req: InferReq = self.prefill_post_handle_queue.get() + self._transfer_kv_to_remote(req) + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + except queue.Empty: + pass + def _wait_transfer_loop(self): while True: done_req_ids = self.nixl_agent.get_done_tranfers() - for req_id, state in done_req_ids: if state != 1: logger.info(f"wait transfer done: {req_id} state: {state}") @@ -112,7 +122,7 @@ def _wait_transfer_loop(self): del self.remote_prefill_requests[req_id] del self.inflght_transfer_requests[req_id] - time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL) + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) def _handle_prefill_loop(self): while True: @@ -145,7 +155,7 @@ def _transfer_kv_to_remote(self, req: InferReq): if group_req_id not in self.remote_prefill_requests: logger.info(f"remote prefill request {group_req_id} not found")
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index e4db0aca7..8f2531869 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -250,6 +250,7 @@ def _post_handle(
         is_chuncked_mode: bool,
         do_filter_finished_reqs: bool,
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+        extra_post_req_handle_chunk_func: Optional[Callable[[InferReq], None]] = None,
     ) -> List[int]:
         """
         extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于
@@ -274,6 +275,10 @@ def _post_handle(
                 finished_req_ids.append(req_obj.shm_req.request_id)
                 continue
 
+            if extra_post_req_handle_chunk_func is not None:
+                # If an extra per-chunk handler is provided, invoke it for this chunk.
+                extra_post_req_handle_chunk_func(req_obj)
+
             # 对于没有到达需要输出 token 阶段的请求,直接略过
             if req_obj.cur_kv_len < req_obj.get_cur_total_len():
                 continue
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py
index 85b28f8f7..2156944d8 100644
--- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py
+++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py
@@ -26,13 +26,14 @@
 
 class PDNIXLBackendBase(ModeBackend):
-    _THEAD_WAIT_INTERVAL = 0.001
+    _THREAD_WAIT_INTERVAL = 0.001
 
     def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_meta_queue: mp.Queue):
         super().__init__()
         self.to_remote_queue = to_remote_queue
         self.from_remote_queue = from_remote_queue
         self.nixl_meta_queue = nixl_meta_queue
+        self.prefill_post_handle_queue = queue.Queue()
 
         # for decode
         self.remote_prefilled_reqs: ThreadSafeDict = ThreadSafeDict()
@@ -85,12 +86,21 @@ def handle_remote_prefill(req_status: RemotePrefillStatus):
                 prefill_status = RemotePrefillStatus.deserialize(req_status)
                 handle_remote_prefill(prefill_status)
 
-            time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL)
+            time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL)
+
+    def _handle_transfer_loop(self):
+        while True:
+            try:
+                # a timeout makes queue.Empty reachable; a bare get() would never raise it
+                req: InferReq = self.prefill_post_handle_queue.get(timeout=PDNIXLBackendBase._THREAD_WAIT_INTERVAL)
+                self._transfer_kv_to_remote(req)
+            except queue.Empty:
+                pass
+
     def _wait_transfer_loop(self):
         while True:
             done_req_ids = self.nixl_agent.get_done_tranfers()
-
             for req_id, state in done_req_ids:
                 if state != 1:
                     logger.info(f"wait transfer done: {req_id} state: {state}")
@@ -112,7 +122,7 @@ def _wait_transfer_loop(self):
                     del self.remote_prefill_requests[req_id]
                     del self.inflght_transfer_requests[req_id]
 
-            time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL)
+            time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL)
 
     def _handle_prefill_loop(self):
         while True:
@@ -145,7 +155,7 @@ def _transfer_kv_to_remote(self, req: InferReq):
         if group_req_id not in self.remote_prefill_requests:
             logger.info(f"remote prefill request {group_req_id} not found")
             return
-
+        start = time.time()
         remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id]
         if remote_request.transfer_state is None:
             remote_request.transfer_state = TransferState(
@@ -178,6 +188,11 @@ def _transfer_kv_to_remote(self, req: InferReq):
 
         transfer_state.current_kv_len = req.cur_kv_len
         transfer_state.current_chunk_id += 1
+        logger.info(
+            f"transfer kv to remote: {group_req_id} "
+            f"current chunk id: {transfer_state.current_chunk_id} "
+            f"took: {time.time() - start:.3f} seconds"
+        )
 
     def _decode_filter_reqs(
         self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq]
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py
index 78af8cc52..736e52835 100644
--- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py
+++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py
@@ -20,7 +20,10 @@ def init_custom(self):
         super().init_custom()
         self.handle_prefill_loop_thread = threading.Thread(target=self._handle_prefill_loop, daemon=True)
         self.wait_transfer_loop_thread = threading.Thread(target=self._wait_transfer_loop, daemon=True)
+        self.handle_transfer_loop_thread = threading.Thread(target=self._handle_transfer_loop, daemon=True)
+
         self.handle_prefill_loop_thread.start()
+        self.handle_transfer_loop_thread.start()
         self.wait_transfer_loop_thread.start()
         return
 
@@ -58,6 +61,6 @@ def decode(self):
             next_token_logprobs,
             is_chuncked_mode=True,
             do_filter_finished_reqs=False,
-            extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
+            extra_post_req_handle_chunk_func=lambda req: self.prefill_post_handle_queue.put(req),
         )
         return
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py
index 5d3555e27..3f4354e42 100644
--- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py
+++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py
@@ -78,7 +78,7 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int
             next_token_logprobs,
             is_chuncked_mode=True,
             do_filter_finished_reqs=False,
-            extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
+            extra_post_req_handle_chunk_func=lambda req: self.prefill_post_handle_queue.put(req),
         )
 
 def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):
@@ -113,5 +113,5 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
             next_token_logprobs,
             is_chuncked_mode=True,
             do_filter_finished_reqs=False,
-            extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
+            extra_post_req_handle_chunk_func=lambda req: self.prefill_post_handle_queue.put(req),
         )
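The thread and queue changes above share one goal: get the chunked KV write out of the model-step path. The per-chunk hook fired from _post_handle now only enqueues the request on prefill_post_handle_queue, and the new _handle_transfer_loop daemon thread drains that queue and calls _transfer_kv_to_remote. A toy model of the producer/consumer handoff, with illustrative names and timings rather than lightllm's real ones:

    import queue
    import threading
    import time

    post_handle_q: queue.Queue = queue.Queue()

    def transfer_worker():
        # Consumer: runs the (slow) transfer off the decode thread.
        while True:
            try:
                req_id = post_handle_q.get(timeout=0.001)  # timeout keeps Empty reachable
            except queue.Empty:
                continue
            time.sleep(0.01)  # stand-in for a chunked KV write
            print(f"chunk for request {req_id} transferred")

    threading.Thread(target=transfer_worker, daemon=True).start()

    for step in range(3):        # producer: the per-chunk post-handle hook
        post_handle_q.put(step)  # enqueue and return immediately
        print(f"decode step {step} continues without waiting")

    time.sleep(0.1)  # let the daemon thread drain before the demo exits

Compared with the old extra_post_req_handle_func route, which ran _transfer_kv_to_remote inline, the enqueue costs microseconds, so a slow transfer no longer stretches every prefill iteration.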
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py
index 352d7b247..6fb9673e4 100644
--- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py
+++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py
@@ -2,7 +2,6 @@
 from typing import Dict, List, Any
 from torch import Tensor
 from dataclasses import dataclass
-import threading
 
 from lightllm.utils.log_utils import init_logger
 
@@ -126,28 +125,29 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest,
         kv_move_end = request.cur_kv_len
 
         src_token_ids = request.token_ids[kv_move_start:]
-        dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len : kv_move_end]
+        dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len : kv_move_end - skip_kv_move_len]
         remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][
             self.tp_idx
         ]  # TODO one-one mapping now
 
         if len(src_token_ids) > 0:
-            assert len(src_token_ids) == len(dst_token_ids), f"{len(src_token_ids)} {len(dst_token_ids)}"
+            assert len(src_token_ids) == len(dst_token_ids), f"{len(src_token_ids)} {len(dst_token_ids)} {kv_move_start} {kv_move_end} {skip_kv_move_len}"
             src_token_descs = self._get_token_desc_ids(src_token_ids, self.num_tokens)
             dst_token_descs = self._get_token_desc_ids(dst_token_ids, remote_agent.num_tokens)
 
             src_handle = self.local_xfer_handles
             dst_handle = remote_agent.kv_xfer_handles
+
             notify_status = RemotePrefillStatus(
                 group_req_id=group_reqeust_id,
                 status=1,
                 chunk_id=prefill_request.transfer_state.current_chunk_id,
                 is_last=is_finished,
-            )
+            ).serialize() if is_finished else b""
             handle = self.nixl_agent.make_prepped_xfer(
-                "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status.serialize()
+                "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status
             )
             status = self.nixl_agent.transfer(handle)
@@ -155,10 +155,14 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest,
 
             if group_reqeust_id not in self.inflight_transfers:
                 self.inflight_transfers[group_reqeust_id] = KVMoveRequestState(
-                    handles=[], done_handles=[], remote_agent=remote_agent, abort=False
+                    handles=[], done_handles=[], remote_agent=remote_agent, abort=False, is_last_arrived=False
                 )
+
             self.inflight_transfers[group_reqeust_id].handles.append(handle)
+            if is_finished:
+                self.inflight_transfers[group_reqeust_id].is_last_arrived = True
+
             return handle
         return None
@@ -173,7 +177,6 @@ def send_abort_notify(self, remote_id: int, group_reqeust_id):
 
     def get_done_tranfers(self):
         done_req_ids = []
-
         for req_id, kv_move_state in self.inflight_transfers.items():
             kv_move_state: KVMoveRequestState
             if kv_move_state.abort:
@@ -181,6 +184,9 @@ def get_done_tranfers(self):
                 done_req_ids.append((req_id, -1))
                 continue
 
+            if not kv_move_state.is_last_arrived:
+                continue
+
             remote_agent: RemoteAgent = kv_move_state.remote_agent
 
             left_handles = []
@@ -219,7 +225,6 @@ def get_done_tranfers(self):
                 self.nixl_agent.release_xfer_handle(handle)
             del self.inflight_transfers[req_id]
-
         return done_req_ids
 
     def shutdonw(self):
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py
index dc54f6a33..02ee6c4ed 100644
--- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py
@@ -91,6 +91,7 @@ class KVMoveRequestState:
     done_handles: List[nixl_xfer_handle]
     remote_agent: RemoteAgent
     abort: bool
+    is_last_arrived: bool
 
 
 @dataclass
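Two of the fixes above are easiest to sanity-check side by side. First, the destination slice in write_blocks now subtracts skip_kv_move_len from both bounds, so the source and destination windows of every chunk stay the same length; the decode side's buffer simply starts skip_kv_move_len tokens later. Second, only the final chunk carries a serialized notify payload, and get_done_tranfers refuses to report a request until the chunk flagged is_last has been issued, so the decode side cannot observe "done" while intermediate chunks are still in flight. A small worked check of the slicing, with toy slot ids and a made-up skip length:

    src_slots = list(range(100, 110))  # local kv slots for tokens 0..9
    dst_slots = list(range(200, 208))  # remote slots for tokens 2..9 (first 2 reused)
    skip = 2                           # tokens already present on the decode side

    def chunk(kv_move_start, kv_move_end):
        src = src_slots[kv_move_start:kv_move_end]
        # old, buggy upper bound: dst_slots[kv_move_start - skip : kv_move_end]
        dst = dst_slots[kv_move_start - skip : kv_move_end - skip]
        assert len(src) == len(dst), (len(src), len(dst))
        return src, dst

    print(chunk(2, 6))   # first chunk moves tokens 2..5
    print(chunk(6, 10))  # second chunk moves tokens 6..9

With the old upper bound the destination window is skip tokens longer than the source window, so the length assert fires on the first chunk of any request with a nonzero skip, which is presumably the chunked-transfer failure this commit's title refers to.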