diff --git a/Dockerfile.nixl b/Dockerfile.nixl new file mode 100644 index 000000000..427a230e5 --- /dev/null +++ b/Dockerfile.nixl @@ -0,0 +1,89 @@ +FROM nvcr.io/nvidia/tritonserver:25.04-py3-min as base +ARG PYTORCH_VERSION=2.6.0 +ARG PYTHON_VERSION=3.9 +ARG CUDA_VERSION=12.4 +ARG MAMBA_VERSION=23.1.0-1 +ARG TARGETPLATFORM +ARG INSTALL_NIXL=true + +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl-dev \ + curl \ + g++ \ + make \ + git && \ + rm -rf /var/lib/apt/lists/* + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + + +WORKDIR /root + +COPY ./requirements.txt /lightllm/requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip pip install -r /lightllm/requirements.txt --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124 + +RUN pip install --no-cache-dir https://github.com/ModelTC/flash-attn-3-build/releases/download/v2.7.4.post1/flash_attn-3.0.0b1-cp39-cp39-linux_x86_64.whl + +RUN --mount=type=cache,target=/root/.cache/pip pip install nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100 + +RUN if [ "$INSTALL_NIXL" == "true" ]; then \ + apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ + DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ + rm -rf /usr/lib/ucx && \ + rm -rf /opt/hpcx/ucx && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout v1.19.x && \ + ./autogen.sh && ./configure \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs=yes \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + ldconfig; \ + fi + +RUN if [ "$INSTALL_NIXL" == "true" ]; then \ + apt-get update && apt-get install -y pkg-config tmux net-tools; \ + cd /usr/local/src; \ + pip install --upgrade meson pybind11 patchelf; \ + git clone https://github.com/ai-dynamo/nixl.git -b main && \ + cd nixl && \ + rm -rf build && \ + mkdir build && \ + meson setup build/ --prefix=/usr/local/nixl --buildtype=release && \ + cd build && \ + ninja && \ + ninja install && \ + cd .. && pip install . --no-deps; \ + fi + +COPY . 
/lightllm +RUN pip install -e /lightllm --no-cache-dir diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c7760e995..8540b9a74 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -174,6 +174,10 @@ def _init_kv_move_buffer(self): # p d 分离的推理模式下才需要做这一步初始化 if self.run_mode in ["prefill", "decode"]: self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size) + elif self.run_mode in ["nixl_prefill", "nixl_decode"]: + page_num = int(os.getenv("PD_NIXL_MOVE_PAGE_NUM", 32)) + page_size = int(os.getenv("PD_NIXL_MOVE_PAGE_SIZE", 1024)) + self.mem_manager.alloc_paged_kv_move_buffer(page_num, page_size) def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 6ddec24e2..c0a0b72b9 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -36,6 +36,12 @@ def alloc_kv_move_buffer(self, max_req_total_len): self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2] return + def alloc_paged_kv_move_buffer(self, page_num, page_size): + self.kv_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + ) + return + def send_to_decode_node( self, move_tasks: List[KVMoveTask], diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..4f8292bf2 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -96,6 +96,14 @@ def alloc_kv_move_buffer(self, max_req_total_len): self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1] return + def alloc_paged_kv_move_buffer(self, page_num, page_size): + if isinstance(self, MemoryManager) and type(self) != MemoryManager: + raise NotImplementedError("subclass need reimpl this method") + self.kv_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + ) + return + def send_to_decode_node( self, move_tasks: List[KVMoveTask], diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3f3eaf96f..43c8c68c1 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server"], + choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -54,6 +54,20 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="The port number for the config server in config_server mode.", ) + parser.add_argument( + "--pd_nixl_remote_prefill_http_port", + type=int, + default=42001, + help="nixl pd mode, prefill node used for triggering prefill http port.", + ) + + parser.add_argument( + "--pd_nixl_remote_prefill_port", + type=int, + default=42002, + help="nixl pd mode, prefill and decode used for meta exchange.", + ) + parser.add_argument( "--model_name", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index de1e690a2..3fbf5cd43 100644 --- 
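Reviewer note, not part of the patch: the nixl_prefill / nixl_decode modes above swap the per-request KV move buffer for a paged staging buffer whose size comes from PD_NIXL_MOVE_PAGE_NUM (default 32) and PD_NIXL_MOVE_PAGE_SIZE (default 1024). A minimal sketch of how one page stages a chunk of tokens, with shapes taken from alloc_paged_kv_move_buffer above (toy sizes, illustrative only):

import torch

# kv_buffer is (layer_num, total_tokens, kv_heads, head_dim); one move-buffer page is
# (page_size, layer_num, kv_heads, head_dim), so packing/unpacking needs a transpose
# of the (layer, token) axes, as the prefill/decode backends later in this diff do.
layer_num, total_tokens, kv_heads, head_dim = 2, 16, 4, 8
page_num, page_size = 4, 4
kv_buffer = torch.randn(layer_num, total_tokens, kv_heads, head_dim)
kv_move_buffer = torch.empty(page_num, page_size, layer_num, kv_heads, head_dim)

token_ids = torch.tensor([3, 5, 6])  # KV slots shipped in this chunk
page_id, kv_len = 0, len(token_ids)
# pack on the sender: (layer, token, ...) -> (token, layer, ...) inside one page
kv_move_buffer[page_id][:kv_len] = kv_buffer[:, token_ids].transpose(0, 1)
# unpack on the receiver: (token, layer, ...) -> (layer, token, ...)
kv_buffer[:, token_ids] = kv_move_buffer[page_id][:kv_len].transpose(0, 1)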
a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -67,7 +67,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 5594df6a0..ab390a602 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,5 +1,5 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus +from .req import Req, FinishStatus, PDNIXLChunkedPrefillReq from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray from .start_args_type import StartArgs diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f2ebadad1..2b063c64f 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -106,6 +106,7 @@ def get_str(self): f"shm_cur_kv_len:{self.shm_cur_kv_len}," f"shm_cur_output_len:{self.shm_cur_output_len}," f"finish_status:{self.finish_status.is_finished()}" + f"group_id: {self.group_req_id}" ) def init( @@ -354,3 +355,55 @@ def post_init( # 错误问题。 self.sample_params.max_new_tokens = self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6 return + + +class PdNixlReqState(ctypes.Structure): + _pack_ = 4 + _MAX_TP_SIZE = 32 + _fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)] + + def __init__(self): + self.dp_world_size = 0 + self.state = (ctypes.c_int * self._MAX_TP_SIZE)(*([0] * self._MAX_TP_SIZE)) + + def set_dp_world_size(self, size: int): + self.dp_world_size = size + + def set_tp_state(self, tp_id: int, state: int): + assert ( + self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size + ), f"tp_id {tp_id} out of range [0, {self.dp_world_size})" + self.state[tp_id] = state + + def set_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + unique_state = np.unique(self.state[: self.dp_world_size]) + self.state[self.dp_world_size] = unique_state[0] + + def get_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + return self.state[self.dp_world_size] + + +class PDNIXLChunkedPrefillReq(ChunkedPrefillReq): + _pack_ = 4 + _fields_ = ChunkedPrefillReq._fields_ + [ + # 用于pd nixl状态同步 + ("pd_nixl_req_state", PdNixlReqState) + ] + + def set_dp_world_size(self, dp_world_size): + self.pd_nixl_req_state.dp_world_size = dp_world_size + + # called by each tp rank, no contention + def set_pd_req_rank_state(self, tp_id: int, state: int): + self.pd_nixl_req_state.set_tp_state(tp_id, state) + + # state: -1 for failed, 0 for in progress, 1 for success + # set by router + def set_pd_req_state(self): + self.pd_nixl_req_state.set_state() + + # read by all rank + def get_pd_req_state(self): + return self.pd_nixl_req_state.get_state() diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index 315eb938e..5dd73ad02 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -3,7 +3,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger -from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq +from .req 
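Reviewer note, not part of the patch: PdNixlReqState above keeps one state per TP rank in shared memory, and set_state() collapses them with np.unique; because the result is sorted ascending, element 0 is the minimum, so failure (-1) outranks in-progress (0), which outranks success (1). A small standalone check of that aggregation rule (illustrative only):

import numpy as np

def aggregate_rank_states(states):
    # mirrors PdNixlReqState.set_state(): the minimum of the per-rank states wins
    return int(np.unique(states)[0])

assert aggregate_rank_states([1, 1, 1, 1]) == 1    # all ranks finished -> success
assert aggregate_rank_states([1, 0, 1, 1]) == 0    # one rank still moving KV -> in progress
assert aggregate_rank_states([1, 0, -1, 1]) == -1  # any failed rank -> failed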
import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq, PDNIXLChunkedPrefillReq from .shm_array import ShmArray from .atomic_array_lock import AtomicShmArrayLock, AtomicLockItem from .atomic_lock import AtomicShmLock @@ -33,6 +33,9 @@ def get_req_class_type(self): if args.token_healing_mode: return TokenHealingReq + if args.run_mode in ["nixl_prefill", "nixl_decode"]: + return PDNIXLChunkedPrefillReq + if args.disable_chunked_prefill: return NormalReq else: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ec1eb427e..e50826cde 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -6,7 +6,10 @@ @dataclass class StartArgs: - run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master"]}) + run_mode: str = field( + default="normal", + metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + ) host: str = field(default="127.0.0.1") port: int = field(default=8000) zmq_mode: str = field( diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fa455c225..6a1a33a78 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -101,7 +101,7 @@ def __init__( self.metric_client = MetricClient(metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] + assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() @@ -228,7 +228,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): # health 请求 request_id 为负数,直接返回 if is_health_req: return sampling_params.group_request_id - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() else: @@ -238,7 +238,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert sampling_params.group_request_id != -1 group_request_id = sampling_params.group_request_id sampling_params.group_request_id = group_request_id - elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D: + elif self.pd_mode.is_P_or_D(): assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting" group_request_id = sampling_params.group_request_id else: @@ -419,7 +419,7 @@ async def transfer_to_next_module_or_node( if self.is_multinode_tp_master: async with self.transfer_lock: for sender in self.multinode_req_manager: - sender.send_pyobj( + await sender.send_pyobj( (prompt, sampling_params, original_multimodal_params), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -448,35 +448,37 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode == NodeRole.P: + if self.pd_mode.is_P(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + + # P 模式下,直接将请求发送到路由器 + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.D: + if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 - self.send_to_router.send_pyobj( + await 
self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -521,7 +523,7 @@ async def _wait_to_token_package( # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 metadata["prompt_tokens"] = prompt_tokens # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode - if self.pd_mode == NodeRole.P and is_first_token: + if self.pd_mode.is_P() and is_first_token: metadata["prompt_ids"] = prompt_ids prompt_cache_len = metadata.pop("prompt_cache_len", 0) @@ -623,7 +625,7 @@ async def recycle_resource_loop(self): pre_time_mark = time.time() for req_status in self.req_id_to_out_inf.values(): logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 10a4a8ec5..94394182d 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -5,6 +5,7 @@ import socket import httpx import base64 +import zmq from typing import Dict, Optional from lightllm.server.pd_io_struct import NodeRole, ObjType from lightllm.server.httpserver.async_queue import AsyncQueue @@ -33,6 +34,8 @@ async def pd_handle_loop(manager: HttpServerManager): manager.host_ip = manager.args.host asyncio.create_task(timer_log(manager)) + if manager.pd_mode.is_NP_or_ND(): + asyncio.create_task(pd_handle_loop_from_d(manager)) id_to_handle_task: Dict[int, asyncio.Task] = {} @@ -92,7 +95,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O logger.info(f"Sent registration JSON: {regist_json}") # 转发任务 - forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) + if manager.pd_mode != NodeRole.NP: # nixl prefill don't need up token to master + forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: @@ -182,3 +186,33 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): handle_list = await forwarding_queue.wait_to_get_all_data() if handle_list: await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list))) + + +async def pd_handle_loop_from_d(manager: HttpServerManager): + if manager.pd_mode != NodeRole.NP: + return + + context = zmq.asyncio.Context(2) + manager.recv_from_d = context.socket(zmq.PULL) + manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_nixl_remote_prefill_http_port}") + + while True: + try: + ( + prompt, + sampling_params, + multimodal_params, + ) = await manager.recv_from_d.recv_pyobj() + + # 触发推理的task + async def pd_process_generate(manager: "HttpServerManager", prompt, sampling_params, multimodal_params): + try: + async for _, _, _, _ in manager.generate(prompt, sampling_params, multimodal_params, None): + pass + except BaseException as e: + logger.error(str(e)) + + asyncio.create_task(pd_process_generate(manager, prompt, sampling_params, 
multimodal_params)) + + except Exception as e: + logger.exception(f"pd loop generate error: {str(e)}") diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 05b2d987c..65bec6a1c 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -1,20 +1,15 @@ import sys -import zmq -import zmq.asyncio import asyncio import uvloop -import rpyc import time -import hashlib import datetime -import aiohttp import ujson as json import pickle asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType +from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType, NodeRole from lightllm.server.core.objs import SamplingParams from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -55,10 +50,11 @@ async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode == "prefill": + client_pd_mode: NodeRole = NodeRole(pd_client.mode) + if client_pd_mode.is_P(): self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode == "decode": + elif client_pd_mode.is_D(): self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: @@ -110,8 +106,8 @@ async def select_p_d_node( ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: import random - p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) + p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None + d_node = random.choice(self.decode_nodes) if self.decode_nodes else None return p_node, d_node async def generate( @@ -133,6 +129,10 @@ async def generate( p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params) + if not p_node or not d_node: + logger.error(f"{group_request_id}: No p_node or d_node found") + return + results_generator = self._wait_to_token_package( p_node, d_node, @@ -233,8 +233,8 @@ async def fetch_stream( try: await asyncio.wait_for(up_status_event.wait(), timeout=60) except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") - raise ServerBusyError() + logger.warning(f"group_request_id: {group_request_id} kv move time out err") + assert False, f"req_id {group_request_id} kv move time out, server is busy" sampling_params.move_kv_to_decode_node.initialize(None) sampling_params.max_new_tokens = old_max_new_tokens - 1 @@ -253,6 +253,43 @@ async def fetch_stream( return + async def fetch_stream_nixl( + self, + p_node: PD_Client_Obj, + d_node: PD_Client_Obj, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ): + group_request_id = sampling_params.group_request_id + + req_status = ReqStatus(group_request_id, p_node, d_node) + self.req_id_to_out_inf[group_request_id] = req_status + + p_start_args = p_node.start_args + prefill_node_dict = { + "node_id": p_start_args["pd_node_id"], + "ip": p_start_args["host"], + "rpyc_port": p_start_args["pd_nixl_remote_prefill_port"], + "max_new_tokens": 
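Reviewer note, not part of the patch: pd_handle_loop_from_d above binds a PULL socket on --pd_nixl_remote_prefill_http_port and expects one pickled (prompt, sampling_params, multimodal_params) tuple per request before calling manager.generate. The matching sender on the decode side is not in this hunk; a hedged sketch of what such a sender might look like, with the host/port wiring assumed:

import zmq
import zmq.asyncio

async def push_prefill_request(prefill_host, prefill_http_port,
                               prompt, sampling_params, multimodal_params):
    # assumed decode-side counterpart of the PULL socket bound in pd_handle_loop_from_d
    ctx = zmq.asyncio.Context.instance()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(f"tcp://{prefill_host}:{prefill_http_port}")
    await sock.send_pyobj((prompt, sampling_params, multimodal_params))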
sampling_params.max_new_tokens, + "pd_master_node_id": self.args.pd_node_id, + } + + sampling_params.move_kv_to_decode_node.initialize(prefill_node_dict) + sampling_params.suggested_dp_index = -1 + + await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + + while True: + await req_status.wait_to_ready() + if await request.is_disconnected(): + raise Exception(f"req_id {group_request_id} disconnected") + if await req_status.can_read(self.req_id_to_out_inf): + token_list = await req_status.pop_all_tokens() + for sub_req_id, request_output, metadata, finish_status in token_list: + yield sub_req_id, request_output, metadata, finish_status + async def _wait_to_token_package( self, p_node: PD_Client_Obj, @@ -269,7 +306,11 @@ async def _wait_to_token_package( unfinished_count = sampling_params.best_of is_first_token = True - async for sub_req_id, out_str, metadata, finish_status in self.fetch_stream( + client_mode: NodeRole = NodeRole(d_node.mode) + + fetch_stream = self.fetch_stream_nixl if client_mode.is_NP_or_ND() else self.fetch_stream + + async for sub_req_id, out_str, metadata, finish_status in fetch_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index bf320e199..2e2dab5b0 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -154,6 +154,12 @@ def to_dict(self): ret["audios"] = [a.to_dict() for a in self.audios] return ret + @classmethod + def from_dict(cls, data: dict): + if "images" not in data: + return cls() + return cls(images=data["images"]) + def to_origin_dict(self): """ 将内容转换为原始请求的形式,主要用于请求转发 diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 414e3c74a..4267afaee 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -13,23 +13,30 @@ class NodeRole(enum.Enum): P = "prefill" D = "decode" + + NP = "nixl_prefill" + ND = "nixl_decode" + NORMAL = "normal" PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.D + return self == NodeRole.D or self == NodeRole.ND def is_P(self): - return self == NodeRole.P + return self == NodeRole.P or self == NodeRole.NP def is_normal(self): return self == NodeRole.NORMAL def is_P_or_NORMAL(self): - return (self == NodeRole.P) or (self == NodeRole.NORMAL) + return self.is_P() or self.is_normal() def is_P_or_D(self): - return (self == NodeRole.P) or (self == NodeRole.D) + return self.is_P() or self.is_D() + + def is_NP_or_ND(self): + return self == NodeRole.NP or self == NodeRole.ND class ObjType(enum.Enum): @@ -47,8 +54,8 @@ class PD_Client_Obj: websocket: WebSocket = None # 用于通信的 websocket 连接对象 def __post_init__(self): - if self.mode not in ["prefill", "decode"]: - error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" + if self.mode not in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + error_info = f"""mode must in ["prefill", "decode", "nixl_prefill", "nixl_decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return @@ -114,6 +121,23 @@ class PDTransJoinInfo: connect_id: str +@dataclass +class RemotePrefillServerInfo: + perfill_server_id: int + prefill_server_ip: str + prefill_server_port: int + + +@dataclass +class DistInfo: + world_size: int + nnodes: int + dp_size: int + dp_world_size: int + dp_size_in_node: int + node_world_size: int + + @dataclass class 
PDTransLeaveInfo: decode_id: int diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 14a987f49..94cbd0536 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -1,8 +1,9 @@ import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union -from lightllm.server.core.objs import ShmReqManager, Req +from lightllm.server.core.objs import ShmReqManager, Req, PDNIXLChunkedPrefillReq from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_dp_world_size logger = init_logger(__name__) @@ -53,6 +54,8 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): req = None else: unfinished_req_ids.append(req.request_id) + if isinstance(req, PDNIXLChunkedPrefillReq): + req.set_pd_req_state() self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index c10847e3f..3c97c1e31 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -16,7 +16,7 @@ from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager, StartArgs +from lightllm.server.core.objs import ShmReqManager, StartArgs, PDNIXLChunkedPrefillReq from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs @@ -28,6 +28,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.pd_io_struct import DistInfo logger = init_logger(__name__) @@ -44,6 +45,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.dp_size = args.dp # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, args.dp // self.nnodes) + self.dp_world_size = self.world_size // self.dp_size self.is_multinode_tp = args.nnodes > 1 and args.dp == 1 self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1 # 判断是否是保守调度,保守调度不会发生暂停 req 的情况,但是有些场景可能影响吞吐 @@ -94,8 +96,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) self.metric_client = MetricClient(metric_port) - self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] - self.is_pd_decode_mode = self.args.run_mode == "decode" + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] + self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 # 主要是为了防止调度失误,造成 OOM 等错误 self.router_lock = mp.Lock() @@ -114,12 +116,14 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] + self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.node_world_size)] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() assert (self.world_size % self.nnodes) == 0 node_world_size = self.world_size // self.nnodes for rank_id in range(self.node_rank * node_world_size, (self.node_rank + 1) * node_world_size): + rpc_model = await 
start_model_process( args=self.args, rank=rank_id, @@ -128,7 +132,8 @@ async def wait_to_model_ready(self): rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, info_queue=self.info_queue, - mem_queue=self.mem_queues[(rank_id % node_world_size)], + result_queue=self.result_queues[rank_id % node_world_size], + mem_queue=self.mem_queues[rank_id % node_world_size], router_lock=self.router_lock, ) self.model_rpc_servers.append(rpc_model) @@ -181,7 +186,7 @@ async def wait_to_model_ready(self): get_unique_server_name(), self.max_total_token_num, node_world_size=self.node_world_size, - dp_world_size=self.world_size // self.dp_size, + dp_world_size=self.dp_world_size, ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") @@ -194,6 +199,30 @@ async def wait_to_model_ready(self): start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_prefill": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( + start_pd_remote_prefill_server_process, + ) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) + + start_pd_remote_prefill_server_process( + self.args.pd_node_id, + dist_info=dist_info, + http_server_port=self.args.pd_nixl_remote_prefill_http_port, + server_port=self.args.pd_nixl_remote_prefill_port, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) + if self.args.run_mode == "decode": # 启动 decode kv move 管理进程 from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( @@ -202,6 +231,28 @@ async def wait_to_model_ready(self): start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_decode": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( + start_pd_remote_prefill_client_process, + ) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) + + start_pd_remote_prefill_client_process( + self.args.pd_node_id, + dist_info, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) + return def add_req(self, group_req_indexes: GroupReqIndexes): @@ -210,6 +261,8 @@ def add_req(self, group_req_indexes: GroupReqIndexes): req = self.shm_req_manager.get_req_obj_by_index(req_index) req.multimodal_params = group_req_indexes.multimodal_params req.start_time = group_req_indexes.time_mark + if isinstance(req, PDNIXLChunkedPrefillReq): + req.set_dp_world_size(self.dp_world_size) req_group.append(req) logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..6c919b363 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -10,7 +10,7 @@ from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager from lightllm.utils.infer_utils import mark_start, mark_end -from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager +from lightllm.server.core.objs import Req, 
SamplingParams, FinishStatus, ShmReqManager, PDNIXLChunkedPrefillReq from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id @@ -119,7 +119,7 @@ def _save_promptcache_kvbuffer(self): https://arxiv.org/abs/2403.01241 """ prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key - print(f"prompt_cache_token_id : {prompt_cache_token_id}") + # print(f"prompt_cache_token_id : {prompt_cache_token_id}") index = range(len(prompt_cache_token_id)) prompt_cache_kv_buffer = self.radix_cache.mem_manager.get_index_kv_buffer(index) torch.save(prompt_cache_kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt") @@ -266,12 +266,14 @@ def __init__( # 当开启后,mtp_gen_token_ids 保存多生成的多余的token_id,但是在后面的 # 步骤中需要重新进行校验。 self.mtp_gen_token_ids: List[int] = [] + self.in_prefill_or_transfer = False def init_all(self): if self.initialized is False: self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() + self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self) @@ -307,6 +309,7 @@ def init_all(self): self.initialized = True self.paused = False + self.in_prefill_or_transfer = False return def is_uninitialized(self): diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 7ad15f00f..9d43c3b7b 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -20,3 +20,7 @@ from .continues_batch.pd_mode.decode_node_impl.decode_impl_mtp import ContinuesBatchBackendForMtpDecodeNode from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_mtp import ChunckedPrefillForMtpPrefillNode from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_mtp_for_dp_chuncked import DPChunkedForMtpPrefillNode +from .pd_nixl.impl_for_pd_prefill import PDNIXLBackendForPrefillNode +from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode +from .pd_nixl.impl_for_pd_decode_dp import PDNIXLDPBackendForDecodeNode +from .pd_nixl.impl_for_pd_prefill_dp import PDNIXLDPBackendForPrefillNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index dd1ea45fe..fdfb01bd3 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -261,6 +261,7 @@ def _post_handle( is_chuncked_mode: bool, do_filter_finished_reqs: bool, extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + call_post_handle_for_chunk: bool = False, ) -> List[int]: """ extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 @@ -284,6 +285,10 @@ def _post_handle( # 对于没有到达需要输出 token 阶段的请求,直接略过, 说明还 # 处于chuncked prefill kv 填充的阶段。 if req_obj.cur_kv_len < req_obj.get_cur_total_len(): + # chunk transfer + if call_post_handle_for_chunk and extra_post_req_handle_func: + extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob) + continue # 将生成的下一个token的信息写入到管理对象中。 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py index 184fc7a1c..6aee9af10 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py @@ -46,6 +46,7 @@ def normal_prefill_reqs( ok_finished_reqs: List[InferReq], mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None, extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + call_post_handle_for_chunk: bool = False, ): model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal @@ -69,6 +70,7 @@ def normal_prefill_reqs( is_chuncked_mode=not self.disable_chunked_prefill, do_filter_finished_reqs=False, extra_post_req_handle_func=extra_post_req_handle_func, + call_post_handle_for_chunk=call_post_handle_for_chunk, ) return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index e575cab14..8db9b5eb4 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,5 +1,5 @@ import torch -from typing import List, Tuple +from typing import List, Tuple, Callable, Optional from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from lightllm.common.basemodel.batch_objs import ModelOutput @@ -52,7 +52,15 @@ def decode(self): self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) return - def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + def normal_prefill_reqs( + self, + prefill_reqs: List[InferReq], + max_prefill_num: int, + uninit_reqs, + ok_finished_reqs, + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + call_post_handle_for_chunk: bool = False, + ): model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( prefill_reqs, is_multimodal=self.is_multimodal ) @@ -65,7 +73,13 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False + run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + extra_post_req_handle_func=extra_post_req_handle_func, + call_post_handle_for_chunk=call_post_handle_for_chunk, ) return @@ -117,7 +131,15 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini ) return - def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + def overlap_prefill_reqs( + self, + prefill_reqs: List[InferReq], + max_prefill_num: int, + uninit_reqs, + ok_finished_reqs, + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + call_post_handle_for_chunk: bool = False, + ): ( micro_input, run_reqs, @@ -142,6 +164,12 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( - all_run_reqs, 
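Reviewer note, not part of the patch: the call_post_handle_for_chunk flag threaded through _post_handle, normal_prefill_reqs and overlap_prefill_reqs above makes the extra post-handle callback fire even for requests still in the chunked-prefill phase (cur_kv_len < total length), which is what lets a NIXL prefill backend hand each chunk's KV off for transfer as soon as it is computed. A runnable toy of that control flow (not the real signature):

def post_handle(reqs, on_req, call_post_handle_for_chunk=False):
    for cur_kv_len, total_len in reqs:
        if cur_kv_len < total_len:            # still filling KV chunks
            if call_post_handle_for_chunk:
                on_req((cur_kv_len, total_len))
            continue
        on_req((cur_kv_len, total_len))       # reached the token-emitting stage

seen = []
post_handle([(128, 512), (512, 512)], seen.append, call_post_handle_for_chunk=True)
assert seen == [(128, 512), (512, 512)]       # chunk-phase request is now included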
next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False + all_run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + extra_post_req_handle_func=extra_post_req_handle_func, + call_post_handle_for_chunk=call_post_handle_for_chunk, ) return diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 185570d74..c87883c75 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -81,7 +81,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx.append(req.req_idx) input_id = req.get_last_gen_token() seq_len = req.get_cur_total_len() - assert req.cur_kv_len == seq_len - 1 + assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" b_seq_len.append(seq_len) input_ids.append(input_id) total_token_num += seq_len diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py new file mode 100644 index 000000000..5ddb9219b --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -0,0 +1,558 @@ +import time +from concurrent.futures import ThreadPoolExecutor +import torch.multiprocessing as mp +import torch +from typing import Dict, List +import queue +import numpy as np +import asyncio +import pickle +import threading + + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq + +from .nixl_kv_transporter import NixlMetadata, NixlKVTransporter +from .pd_remote_prefill_obj import ( + PrefillRequest, + RemoteRequest, + RemoteRequstType, + ConnectRequest, + KVMoveRequest, + RemotePrefillStatus, + ThreadSafeDict, + TransferState, + SafePageIndexScheduler, + RemoteTransferType, + RemoteTransferStatusType, + PageTransferAck, + NotificationType, + Notification, +) + +logger = init_logger(__name__) + + +class PDNIXLBackendBase(ModeBackend): + _THREAD_WAIT_INTERVAL = 0.001 + + def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_meta_queue: mp.Queue): + super().__init__() + self.to_remote_queue = to_remote_queue + self.from_remote_queue = from_remote_queue + self.nixl_meta_queue = nixl_meta_queue + self.prefill_post_handle_queue = queue.Queue() + + # for decode + self.remote_prefilled_reqs: ThreadSafeDict = ThreadSafeDict() + self.request_to_page_ids: ThreadSafeDict = ThreadSafeDict() + self.request_to_first_token: ThreadSafeDict = ThreadSafeDict() + + # for prefill + self.remote_prefill_requests: ThreadSafeDict = ThreadSafeDict() + self.inflght_transfer_requests: ThreadSafeDict = ThreadSafeDict() + + def init_custom(self): + self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.rank_in_node) + self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer) + self.nixl_agent.register_kv_move_buffer(self.model.mem_manager.kv_move_buffer) + self.page_scheduer = SafePageIndexScheduler(self.nixl_agent.num_pages) + + self.nixl_meta_queue.put( + ( + self.nixl_agent.agent_metadata, + self.nixl_agent.num_tokens, + self.nixl_agent.num_pages, + 
self.nixl_agent.local_mem_desc, + self.nixl_agent.local_page_mem_desc, + ) + ) + + def _start_async_loop(self, async_loop_func): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(async_loop_func()) + + async def _handle_remote_prefill(self, req_status: RemotePrefillStatus): + group_req_id = req_status.group_req_id + status = req_status.status + if status != RemoteTransferStatusType.SUCCESS: + logger.warning(f"remote prefill reqeust: {group_req_id} done with state: {status}") + + ret = None + if run_req := self.remote_prefilled_reqs.get(group_req_id, None): + if ( + req_status.transfer_type == RemoteTransferType.PAGE_TRANSFER + and status == RemoteTransferStatusType.SUCCESS + ): + kv_start, kv_len = req_status.kv_start, req_status.kv_len + token_ids = g_infer_context.req_manager.req_to_token_indexs[run_req.req_idx][ + kv_start : kv_start + kv_len + ] # gpu tensor + self.model.mem_manager.kv_buffer[:, token_ids, :, :] = self.model.mem_manager.kv_move_buffer[ + req_status.page_id + ][:kv_len].transpose(0, 1) + ret = PageTransferAck(group_req_id=group_req_id, page_id=req_status.page_id) + + if req_status.is_last or status != RemoteTransferStatusType.SUCCESS: + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, status.value) + self.remote_prefilled_reqs.pop(group_req_id) + self.request_to_first_token[group_req_id] = (req_status.next_token_id, req_status.next_token_logprob) + + if self.is_master_in_dp: + # return page ids + if group_req_id in self.request_to_page_ids: + self.page_scheduer.return_(self.request_to_page_ids[group_req_id]) + del self.request_to_page_ids[group_req_id] + + logger.info( + f"remote prefill reqeust: {group_req_id} done with status: {status} " + f"took: {time.time() - run_req.remote_prefill_start} seconds" + ) + ret = None + + else: + if self.is_master_in_dp: + logger.warning(f"remote prefill reqeust: {group_req_id} not found") + + return ret + + async def _prefill_wait_loop_async(self): + while True: + # from local + try: + req_status = self.from_remote_queue.get_nowait() + await self._handle_remote_prefill(req_status) + except queue.Empty: + pass + + # from remote + notifies = self.nixl_agent.get_new_notifs() + for agent_name, req_statuses in notifies.items(): + acks = [] + for req_statuses_bytes in req_statuses: + noti: Notification = Notification.from_bytes(req_statuses_bytes) + if noti.type == NotificationType.REMOTE_MD: + self.nixl_agent.connect_to_remote(agent_name, noti.data) + elif noti.type == NotificationType.TRANSFER_NOTIFY: + for req_status in noti.data: + prefill_status = RemotePrefillStatus.deserialize(req_status) + ack = await self._handle_remote_prefill(prefill_status) + if ack: + acks.append(ack) + if len(acks) > 0: + # wait for copy done + torch.cuda.current_stream().synchronize() + logger.info(f"send {len(acks)} acks to {agent_name}") + self.nixl_agent.send_transfer_notify(agent_name, acks) + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + def _handle_chunked_transfer(self, req: InferReq, next_token_id: int = None, next_token_logprob: float = None): + if next_token_id: + next_token_id = int(next_token_id) + next_token_logprob = float(next_token_logprob) + + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + group_req_id = shm_req.group_req_id + if group_req_id not in self.remote_prefill_requests: + logger.info(f"remote prefill request {group_req_id} not found") + return + + remote_request: PrefillRequest = 
self.remote_prefill_requests[group_req_id] + if remote_request.transfer_state is None: + remote_request.transfer_state = TransferState( + start_time=time.time(), + current_chunk_id=0, + transfered_kv_len=remote_request.data.local_cached_len, + current_kv_len=req.cur_kv_len, + is_finished=req.finish_status.is_finished(), + token_index=self.model.req_manager.req_to_token_indexs[req.req_idx].tolist(), + free_page_ids=remote_request.data.page_ids.copy(), + next_token_id=next_token_id, + next_token_logprob=next_token_logprob, + lock=threading.Lock(), + ) + shm_req.set_pd_req_rank_state(self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value) + req.in_prefill_or_transfer = True + self.inflght_transfer_requests[group_req_id] = req + else: + transfer_state: TransferState = remote_request.transfer_state + with transfer_state.lock: + transfer_state.current_chunk_id += 1 + transfer_state.current_kv_len = req.cur_kv_len + transfer_state.is_finished = req.finish_status.is_finished() + if next_token_id: + transfer_state.next_token_id = next_token_id + transfer_state.next_token_logprob = next_token_logprob + + async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveRequest]): + start = time.time() + requests_by_agents = dict() + transfer_pages = self.page_scheduer.borrow(len(transfer_reqs)) + # first copy the kv to transfer pages & build notification + for trans_req, page_index in zip(transfer_reqs, transfer_pages): + trans_req: KVMoveRequest + group_req_id = trans_req.group_req_id + remote_request: PrefillRequest = self.remote_prefill_requests.get(group_req_id) + transfer_state: TransferState = remote_request.transfer_state + decode_id: int = remote_request.decode_id + if decode_id not in requests_by_agents: + requests_by_agents[decode_id] = ([], [], []) + + with transfer_state.lock: + + start_kv_len = transfer_state.transfered_kv_len + trans_kv_len = min(trans_req.cur_kv_len - trans_req.prev_kv_len, self.nixl_agent.page_size) + trans_kv_index = transfer_state.token_index[start_kv_len : start_kv_len + trans_kv_len] + self.model.mem_manager.kv_move_buffer[page_index][:trans_kv_len] = self.model.mem_manager.kv_buffer[ + :, trans_kv_index, :, : + ].transpose(0, 1) + + receive_page = transfer_state.free_page_ids.pop(0) + requests_by_agents[decode_id][0].append(page_index) + requests_by_agents[decode_id][1].append(receive_page) + is_last = transfer_state.is_finished and start_kv_len + trans_kv_len == transfer_state.current_kv_len + + requests_by_agents[decode_id][2].append( + RemotePrefillStatus( + transfer_type=RemoteTransferType.PAGE_TRANSFER, + group_req_id=group_req_id, + status=RemoteTransferStatusType.SUCCESS, + chunk_id=transfer_state.current_chunk_id, + is_last=is_last, + page_id=receive_page, + kv_start=start_kv_len, + kv_len=trans_kv_len, + next_token_id=transfer_state.next_token_id, + next_token_logprob=transfer_state.next_token_logprob, + ) + ) + transfer_state.transfered_kv_len += trans_kv_len + + # wait copy done + torch.cuda.current_stream().synchronize() + for decode_id, (transfer_pages, receive_pages, notifications) in requests_by_agents.items(): + assert len(transfer_reqs) == len(receive_pages), "transfer_reqs and receive_pages should have same length" + # transfer + self.nixl_agent.write_blocks_paged(decode_id, transfer_pages, receive_pages, notifications) + + logger.info(f"transfer kv to remote paged batch: {len(transfer_reqs)} " f"took: {time.time() - start} seconds") + + async def _handle_transfer_loop(self): + while True: + free_transfer_pages = 
self.page_scheduer.current_size() + if free_transfer_pages < 1: + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + continue + + transfer_reqs = [] + for group_req_id, req in self.inflght_transfer_requests.items(): + remote_request: PrefillRequest = self.remote_prefill_requests.get(group_req_id) + transfer_state: TransferState = remote_request.transfer_state + with transfer_state.lock: + if transfer_state.completed() or len(transfer_state.free_page_ids) == 0: + continue + + if transfer_state.transfered_kv_len >= transfer_state.current_kv_len: + continue + + transfer_reqs.append( + KVMoveRequest( + group_req_id=group_req_id, + prev_kv_len=transfer_state.transfered_kv_len, + cur_kv_len=transfer_state.current_kv_len, + ) + ) + if len(transfer_reqs) >= free_transfer_pages: + break + + if len(transfer_reqs) > 0: + await self._transfer_kv_to_remote_paged_batch(transfer_reqs) + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + async def _wait_page_transfer_loop(self): + while True: + # local pages can be reused as soon as transfer is done + done_pages, done_requests = await self.nixl_agent.get_done_page_transfers() + if len(done_pages): + self.page_scheduer.return_(done_pages) + + # release requests when prefill done + for req_id, status in done_requests: + if req_id not in self.inflght_transfer_requests: + if self.is_master_in_dp: + logger.warning(f"{req_id} not found in inflght_transfer_requests") + continue + + req: InferReq = self.inflght_transfer_requests[req_id] + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, status.value) + transfer_state = self.remote_prefill_requests[req_id].transfer_state + if self.is_master_in_dp: + logger.info( + f"req: {req_id} kv transfer with state: {status} " + f"took: {time.time() - transfer_state.start_time} seconds" + ) + # only delete success transfers, failed / aborted will delete after send abort notification + if status == RemoteTransferStatusType.SUCCESS: + del self.inflght_transfer_requests[req_id] + del self.remote_prefill_requests[req_id] + + # remote pages should be released after nofication received + notifies = self.nixl_agent.get_new_notifs() + for _, trans_acks in notifies.items(): + for trans_ack_bytes in trans_acks: + trans_acks_noti: Notification = Notification.from_bytes(trans_ack_bytes) + assert trans_acks_noti.type == NotificationType.TRANSFER_NOTIFY_ACK + for trans_ack in trans_acks_noti.data: + ack = PageTransferAck.deserialize(trans_ack) + remote_request: PrefillRequest = self.remote_prefill_requests.get(ack.group_req_id) + if remote_request is None: + continue + + transfer_state: TransferState = remote_request.transfer_state + with transfer_state.lock: + transfer_state.free_page_ids.append(ack.page_id) + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + async def _wait_transfer_loop(self): + while True: + done_req_ids = self.nixl_agent.get_done_tranfers() + for req_id, state in done_req_ids: + if state != 1: + logger.info(f"wait transfer done: {req_id} state: {state}") + + if req_id not in self.inflght_transfer_requests: + if self.is_master_in_dp: + logger.warning(f"{req_id} not found in inflght_transfer_requests") + continue + + req: InferReq = self.inflght_transfer_requests[req_id] + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, state) + transfer_state = self.remote_prefill_requests[req_id].transfer_state + if self.is_master_in_dp: + logger.info( + f"req: {req_id} kv transfer with state: 
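Reviewer note, not part of the patch: the transfer loops above only rely on a small surface of SafePageIndexScheduler: current_size(), a blocking borrow(n) and return_(pages). That class lives in pd_remote_prefill_obj.py and is not shown in this diff; a minimal thread-safe pool with the same call shape (an assumption, not the actual implementation) could look like:

import queue

class PageIndexPool:
    def __init__(self, num_pages: int):
        self._free = queue.Queue()
        for i in range(num_pages):
            self._free.put(i)

    def current_size(self) -> int:
        return self._free.qsize()

    def borrow(self, n: int = 1):
        # blocks until n page indices are available
        return [self._free.get() for _ in range(n)]

    def return_(self, pages):
        for p in pages:
            self._free.put(p)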
{state} " + f"took: {time.time() - transfer_state.start_time} seconds" + ) + del self.remote_prefill_requests[req_id] + del self.inflght_transfer_requests[req_id] + + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + async def _handle_prefill_loop(self): + while True: + request: RemoteRequest = self.from_remote_queue.get() + if request.type == RemoteRequstType.REMOTE_CONNECT: + request: ConnectRequest + logger.info(f"connect request received from: {request.decode_id}") + self.nixl_agent.add_remote_agent( + NixlMetadata( + id=request.decode_id, + num_tokens=request.num_tokens, + num_pages=request.num_pages, + agent_metadatas=request.agent_metadatas, + agent_mem_descs=request.agent_mem_descs, + agent_page_mem_descs=request.agent_page_mem_descs, + ) + ) + self.to_remote_queue.put("OK") + + if request.type == RemoteRequstType.REMOTE_PREFILL: + request: PrefillRequest + group_request_id = request.data.sampling_params.group_request_id + logger.info( + f"prefill request received from decode: {request.decode_id} " + f"and group request id: {group_request_id}" + ) + self.remote_prefill_requests[group_request_id] = request + + def _transfer_kv_to_remote(self, req: InferReq, group_req_id: int, cur_kv_len: int, is_finished: bool): + start = time.time() + remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id] + + transfer_state = remote_request.transfer_state + token_index = self.model.req_manager.req_to_token_indexs[req.req_idx] + + kv_transfer_req = KVMoveRequest( + group_req_id=group_req_id, + token_ids=token_index[:cur_kv_len].tolist(), + prev_kv_len=transfer_state.current_kv_len, + cur_kv_len=cur_kv_len, + ) + if transfer_state.current_chunk_id == 0: + self.inflght_transfer_requests[group_req_id] = req + logger.debug( + f"put {group_req_id} into inflght_transfer_requests and size: {len(self.inflght_transfer_requests)}" + ) + + # kick off kv transfer + self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) + + transfer_state.current_kv_len = cur_kv_len + transfer_state.current_chunk_id += 1 + logger.info( + f"transfer kv to remote: {group_req_id} " + f"current chunk id: {transfer_state.current_chunk_id} {cur_kv_len} " + f"took: {time.time() - start} seconds" + ) + + def _post_remote_prefill(self, req: InferReq, success: bool = True): + + req.in_prefill_or_transfer = False + req.cur_kv_len = req.get_cur_total_len() + if self.is_master_in_dp: + req.shm_req.shm_cur_kv_len = req.cur_kv_len + + group_req_id = req.shm_req.group_req_id + if not success: + self.request_to_first_token.pop(group_req_id, None) + return + + assert group_req_id in self.request_to_first_token + token_id, token_logprob = self.request_to_first_token.pop(group_req_id) + + req.set_next_gen_token_id(token_id, token_logprob) + req.cur_output_len += 1 + + req.out_token_id_count[token_id] += 1 + req.update_finish_status(self.eos_id) + + if self.is_master_in_dp: + req.shm_req.shm_cur_output_len = req.cur_output_len + + if req.finish_status.is_finished(): + req.shm_req.finish_token_index = req.get_cur_total_len() - 1 + req.shm_req.finish_status = req.finish_status + + req.shm_req.candetoken_out_len = req.cur_output_len + + def _decode_filter_reqs( + self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq] + ): + new_prefill_reqs: List[InferReq] = [] + new_aborted_reqs: List[InferReq] = [] + remote_prefill_reqs: List[InferReq] = [] + + # filter out aborted requests + for req in aborted_reqs: + if req.in_prefill_or_transfer: + shm_req: 
PDNIXLChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state() + if state != RemoteTransferStatusType.IN_PROGRESS.value: + new_aborted_reqs.append(req) + self._post_remote_prefill(req, False) + else: + remote_prefill_reqs.append(req) + else: + new_aborted_reqs.append(req) + + for req in prefill_reqs: + if req.in_prefill_or_transfer: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + # state is updated by router + state = shm_req.get_pd_req_state() + if state == RemoteTransferStatusType.SUCCESS.value: # success + self._post_remote_prefill(req) + decode_reqs.append(req) + elif state == RemoteTransferStatusType.FAILED.value: # failure + self._post_remote_prefill(req, False) + new_aborted_reqs.append(req) + elif state == RemoteTransferStatusType.IN_PROGRESS.value: # in progress + remote_prefill_reqs.append(req) + else: + logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") + continue + + new_prefill_reqs.append(req) + + return new_prefill_reqs, new_aborted_reqs, decode_reqs, remote_prefill_reqs + + def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: List[InferReq]): + new_ok_finished_reqs = [] + kv_transfer_reqs = [] + + for req in ok_finished_reqs: + if req.in_prefill_or_transfer: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state() + if state == RemoteTransferStatusType.SUCCESS.value: # success + new_ok_finished_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == RemoteTransferStatusType.FAILED.value: # failure + aborted_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == RemoteTransferStatusType.IN_PROGRESS.value: + kv_transfer_reqs.append(req) + else: + logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") + continue + + new_ok_finished_reqs.append(req) + + return new_ok_finished_reqs, aborted_reqs, kv_transfer_reqs + + def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): + run_reqs = [] + start_loc = 0 + input_ids = [] + nopad_b_req_idx = [] + nopad_b_start_loc = [] + nopad_b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + nopad_b_req_idx.append(req.req_idx) + nopad_b_start_loc.append(start_loc) + + input_token_ids = req.get_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + nopad_b_seq_len.append(seq_len) + input_ids.append(input_id) + start_loc += input_token_len + + nopad_b_start_loc.append(start_loc) # last request + + input_ids = np.concatenate(input_ids, dtype=np.int64) + + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + + req_to_token_indexs = g_infer_context.req_manager.req_to_token_indexs + for idx, req_idx in enumerate(nopad_b_req_idx): + cur_kv_len = req_objs[idx].cur_kv_len + seq_len = nopad_b_seq_len[idx] + mem_start = nopad_b_start_loc[idx] + mem_end = nopad_b_start_loc[idx + 1] + req_to_token_indexs[req_idx, cur_kv_len : nopad_b_seq_len[idx]] = mem_indexes[mem_start:mem_end] + + kwargs = { + "batch_size": len(run_reqs), + "input_ids": input_ids, + "mem_indexes": mem_indexes.tolist(), + "b_req_idx": nopad_b_req_idx, + "b_start_loc": nopad_b_start_loc, + "b_seq_len": nopad_b_seq_len, + } + + return kwargs, run_reqs + + def _prefill_abort_remote(self, req_objs: List[InferReq]): + for req_obj in req_objs: + 
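Reviewer note, not part of the patch: _prepare_remote_prefill_inputs above reserves KV slots on the decode node before any KV has arrived, and each request's slice of the freshly allocated mem_indexes is picked out with the cumulative b_start_loc offsets. A runnable miniature of that slicing (illustrative only):

import numpy as np

def slice_mem_indexes(new_token_counts, mem_indexes):
    # b_start_loc[i] is the running offset of request i into the allocated slots
    b_start_loc = np.concatenate(([0], np.cumsum(new_token_counts)))
    return [mem_indexes[b_start_loc[i]:b_start_loc[i + 1]] for i in range(len(new_token_counts))]

mem_indexes = np.arange(100, 107)                  # 7 newly allocated KV slots
per_req = slice_mem_indexes([3, 4], mem_indexes)   # req0 needs 3 tokens, req1 needs 4
assert per_req[0].tolist() == [100, 101, 102]
assert per_req[1].tolist() == [103, 104, 105, 106]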
group_req_id = req_obj.shm_req.group_req_id + if group_req_id in self.remote_prefill_requests: + self.nixl_agent.send_abort_notify(self.remote_prefill_requests[group_req_id].decode_id, group_req_id) + del self.remote_prefill_requests[group_req_id] + if group_req_id in self.inflght_transfer_requests: + del self.inflght_transfer_requests[group_req_id] diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py new file mode 100644 index 000000000..8a642f2e6 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py @@ -0,0 +1,106 @@ +import time +import torch.multiprocessing as mp +import threading +from concurrent.futures import ThreadPoolExecutor +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend +from typing import List, Tuple, Dict +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.utils.log_utils import init_logger +from lightllm.server.multimodal_params import MultimodalParams + +from .pd_remote_prefill_obj import ( + RemotePrefillTask, + RemotePrefillServerInfo, + RemotePrefillRequest, + RemoteTransferStatusType, +) + +from .impl_for_pd_base import PDNIXLBackendBase + +logger = init_logger(__name__) + + +class PDNIXLBackendForDecodeNode(PDNIXLBackendBase): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + + def init_custom(self): + super().init_custom() + self.wait_prefill_thread = threading.Thread( + target=self._start_async_loop, args=(self._prefill_wait_loop_async,), daemon=True + ) + self.wait_move_page_pool = ThreadPoolExecutor(max_workers=4) + self.wait_prefill_thread.start() + return + + def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): + prefill_node = req.shm_req.sample_params.move_kv_to_decode_node.to_dict() + prefill_node_info = RemotePrefillServerInfo( + perfill_server_id=prefill_node["node_id"], + prefill_server_ip=prefill_node["ip"], + prefill_server_port=prefill_node["rpyc_port"], + ) + + mem_indexes = kwargs.get("mem_indexes") + b_start_loc = kwargs.get("b_start_loc") + prefill_request = RemotePrefillRequest( + prompt=req.shm_req.get_prompt_ids(), + sampling_params=req.shm_req.sample_params, + multimodal_params=MultimodalParams.from_dict(req.multimodal_params), + local_cached_len=req.cur_kv_len, + token_ids=mem_indexes[b_start_loc[index] : b_start_loc[index + 1]], + page_ids=self.page_scheduer.borrow(), # get page ids for this request, blocking when not enough pages + ) + return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request) + + def _trigger_remote_prefill(self, req_id: int, index: int, kwargs: Dict, req: InferReq): + remote_prefill_task = self._build_remote_prefill_task(index, kwargs, req) + self.request_to_page_ids[req_id] = remote_prefill_task.prefill_request.page_ids + self.to_remote_queue.put(remote_prefill_task) + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def decode(self): + + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + # filter out remote prefilling reqs + prefill_reqs, 
aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs) + + self._filter_reqs(aborted_reqs) + + # allocate kv cache, do remote prefill + if prefill_reqs: + # TODO: we could allocate cache later after remote prefill done and get a signal from remote + # but it will have a risk to not have enough cache for this request. + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + # forward each req to remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + run_req.remote_prefill_start = time.time() + # since this function may blocking the calling thread, so we do it in a thread pool + self.wait_move_page_pool.submit( + self._trigger_remote_prefill, shm_req.group_req_id, idx, kwargs, run_req + ) + + shm_req.set_pd_req_rank_state( + self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value + ) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req + + if decode_reqs: + ContinuesBatchBackend.normal_decode( + self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs + ) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py new file mode 100644 index 000000000..85c45869f --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py @@ -0,0 +1,74 @@ +import time +import torch +import torch.multiprocessing as mp +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend + +from .impl_for_pd_decode import PDNIXLBackendForDecodeNode, RemoteTransferStatusType + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForDecodeNode(PDNIXLBackendForDecodeNode): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap + + def init_custom(self): + super().init_custom() + + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs + + kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) + self.model.forward(**kwargs) + assert len(run_reqs) == 0 and padded_req_num == 1 + + return + + def decode(self): + + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + # filter out remote prefilling reqs + prefill_reqs, aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs) + + self._filter_reqs(aborted_reqs) + + # allocate kv cache, 
do remote prefill + if prefill_reqs: + # TODO: we could allocate cache later after remote prefill done and get a signal from remote + # but it will have a risk to not have enough cache for this request. + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + # forward each req to remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + run_req.remote_prefill_start = time.time() + # since this function may blocking the calling thread, so we do it in a thread pool + self.wait_move_page_pool.submit( + self._trigger_remote_prefill, shm_req.group_req_id, idx, kwargs, run_req + ) + + shm_req.set_pd_req_rank_state( + self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value + ) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req + + max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs) + if max_decode_num != 0: + if not self.enable_decode_microbatch_overlap: + DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + else: + DPChunkedPrefillBackend.overlap_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py new file mode 100644 index 000000000..b182a8daf --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py @@ -0,0 +1,61 @@ +import threading +import torch.multiprocessing as mp +from typing import List, Tuple +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend +from .impl_for_pd_base import PDNIXLBackendBase + +logger = init_logger(__name__) + + +class PDNIXLBackendForPrefillNode(PDNIXLBackendBase): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + + def init_custom(self): + super().init_custom() + self.handle_prefill_loop_thread = threading.Thread( + target=self._start_async_loop, args=(self._handle_prefill_loop,), daemon=True + ) + self.wait_transfer_loop_thread = threading.Thread( + target=self._start_async_loop, args=(self._wait_page_transfer_loop,), daemon=True + ) + self.handle_transfer_loop_thread = threading.Thread( + target=self._start_async_loop, args=(self._handle_transfer_loop,), daemon=True + ) + + self.handle_prefill_loop_thread.start() + self.handle_transfer_loop_thread.start() + self.wait_transfer_loop_thread.start() + return + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs) + return + + def decode(self): + uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + + ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) + + assert len(uinit_reqs) == 0 + assert len(decode_reqs) == 0 + + 
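+        # aborted requests may still be known to the remote decode side; notify it
+        # before they are filtered out locally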
self._prefill_abort_remote(aborted_reqs) + self._filter_reqs(aborted_reqs) + + if prefill_reqs: + ContinuesBatchBackend.normal_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + uninit_reqs=uinit_reqs, + ok_finished_reqs=ok_finished_reqs, + extra_post_req_handle_func=self._handle_chunked_transfer, + call_post_handle_for_chunk=True, + ) + self._overlap_req_init_and_filter(uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py new file mode 100644 index 000000000..ff4561cb2 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py @@ -0,0 +1,62 @@ +import torch +import torch.multiprocessing as mp +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend + +from .impl_for_pd_prefill import PDNIXLBackendForPrefillNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForPrefillNode(PDNIXLBackendForPrefillNode): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap + + def init_custom(self): + super().init_custom() + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + return + + def decode(self): + uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + + ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) + + assert len(uinit_reqs) == 0 + assert len(decode_reqs) == 0 + + self._prefill_abort_remote(aborted_reqs) + self._filter_reqs(aborted_reqs) + + # 进行 chuncked prefill + dp_prefill_req_nums, max_prefill_num = self._dp_all_gather_prefill_req_num(prefill_reqs=prefill_reqs) + if self.chunked_prefill_state.dp_need_prefill(prefill_reqs, decode_reqs, dp_prefill_req_nums, max_prefill_num): + if not self.enable_prefill_microbatch_overlap: + DPChunkedPrefillBackend.normal_prefill_reqs( + self, + prefill_reqs, + max_prefill_num, + uinit_reqs, + ok_finished_reqs, + extra_post_req_handle_func=self._handle_chunked_transfer, + call_post_handle_for_chunk=True, + ) + else: + DPChunkedPrefillBackend.overlap_prefill_reqs( + self, + prefill_reqs, + max_prefill_num, + uinit_reqs, + ok_finished_reqs, + extra_post_req_handle_func=self._handle_chunked_transfer, + call_post_handle_for_chunk=True, + ) + + self._overlap_req_init_and_filter(uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py new file mode 100644 index 000000000..fffc22858 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -0,0 +1,374 @@ +from collections import defaultdict +from typing import Dict, List, Any +from torch import Tensor +from dataclasses import dataclass +import queue +import pickle 
+import time + +from lightllm.utils.log_utils import init_logger + +from .pd_remote_prefill_obj import ( + RemoteAgent, + KVMoveRequest, + PrefillRequest, + RemotePrefillStatus, + ThreadSafeDict, + KVMoveRequestState, + PageTransferAck, + RemoteTransferStatusType, + RemoteTransferType, + NotificationType, + Notification, +) + + +logger = init_logger(__name__) + +try: + from nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixlBind + + logger.info("Nixl is available") +except ImportError: + logger.warning("nixl is not installed, which is required for pd disagreggation!!!") + NixlWrapper = None + + +@dataclass +class NixlMetadata: + id: str + num_tokens: list[int] + num_pages: list[int] + agent_metadatas: list[bytes] + agent_mem_descs: list[bytes] + agent_page_mem_descs: list[bytes] + + +class NixlKVTransporter: + def __init__(self, node_id: int, tp_idx: int): + self.node_id = node_id + self.tp_idx = tp_idx + self.nixl_agent = NixlWrapper(self.agent_name, None) + + self.num_layers = -1 + self.num_tokens = -1 + self.num_heads = -1 + self.head_dims = -1 + self.token_len = -1 + self.num_pages = -1 + self.page_size = -1 + self.page_len = -1 + + self.reg_desc = None + self.local_xfer_handles = None + self.page_reg_desc = None + self.page_local_xfer_handles = None + + self.remote_agents = defaultdict(list) + + self.inflight_transfers: ThreadSafeDict = ThreadSafeDict() + self.inflight_page_transfers: ThreadSafeDict = ThreadSafeDict() + + @property + def agent_name(self) -> str: + return f"{self.node_id}_{self.tp_idx}" + + @property + def agent_metadata(self): + return self.nixl_agent.get_agent_metadata() + + @property + def local_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.reg_desc) + + @property + def local_page_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.page_reg_desc) + + def get_new_notifs(self): + return self.nixl_agent.get_new_notifs() + + def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + layer_len = num_tokens * self.token_len + tokens_data = [0] * (self.num_layers * num_tokens) + idx = 0 + for layer_id in range(self.num_layers): + for token_id in range(num_tokens): + tokens_data[idx] = ( + base_addr + layer_id * layer_len + token_id * self.token_len, + self.token_len, + device_id, + ) + idx += 1 + descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) + + def register_kv_buffer(self, kv_buffer: Tensor): + self.num_layers, self.num_tokens, self.num_heads, self.head_dim = kv_buffer.shape + self.token_len = self.num_heads * self.head_dim * kv_buffer.element_size() + + self.reg_desc = self.nixl_agent.register_memory(kv_buffer) + self.local_xfer_handles = self._create_xfer_handles(self.reg_desc, self.num_tokens) + + def _create_paged_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, page_num: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + pages_data = [] + for page_id in range(page_num): + pages_data.append((base_addr + page_id * self.page_len, self.page_len, device_id)) + descs = self.nixl_agent.get_xfer_descs(pages_data, "VRAM", True) + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) + + def register_kv_move_buffer(self, kv_move_buffer: Tensor): + self.num_pages, self.page_size, _, _, _ = kv_move_buffer.shape + self.page_len = self.page_size * self.num_layers * self.token_len + 
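+        # one page carries page_size tokens for every layer, so a page spans
+        # page_size * num_layers * token_len bytes and is always transferred as a whole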
self.page_reg_desc = self.nixl_agent.register_memory(kv_move_buffer) + self.page_local_xfer_handles = self._create_paged_xfer_handles(self.page_reg_desc, self.num_pages) + + def add_remote_agent(self, remote_agent: NixlMetadata): + for idx, (agent_metadata, num_tokens, num_pages, agent_mem_desc, agent_page_mem_desc) in enumerate( + zip( + remote_agent.agent_metadatas, + remote_agent.num_tokens, + remote_agent.num_pages, + remote_agent.agent_mem_descs, + remote_agent.agent_page_mem_descs, + ) + ): + if self.tp_idx != idx: + self.remote_agents[remote_agent.id].append(None) + continue + + peer_name = self.nixl_agent.add_remote_agent(agent_metadata) + if isinstance(peer_name, bytes): + peer_name = peer_name.decode() + + self.nixl_agent.send_notif( + peer_name, Notification(type=NotificationType.REMOTE_MD, data=self.agent_metadata).to_bytes() + ) + + mem_desc = self.nixl_agent.deserialize_descs(agent_mem_desc) + kv_xfer_handles = self._create_xfer_handles(mem_desc, num_tokens, agent_name=peer_name) + + page_mem_desc = self.nixl_agent.deserialize_descs(agent_page_mem_desc) + kv_page_xfer_handles = self._create_paged_xfer_handles(page_mem_desc, num_pages, agent_name=peer_name) + + logger.info("Added remote agent %s with mem desc %s", peer_name, page_mem_desc) + self.remote_agents[remote_agent.id].append( + RemoteAgent( + name=peer_name, + kv_mem_desc=mem_desc, + num_tokens=num_tokens, + kv_xfer_handles=kv_xfer_handles, + kv_page_mem_desc=page_mem_desc, + num_pages=num_pages, + kv_page_xfer_handles=kv_page_xfer_handles, + ) + ) + + def connect_to_remote(self, name: str, remote_md: bytes): + target = self.nixl_agent.add_remote_agent(remote_md) + if isinstance(target, bytes): + target = target.decode() + assert name == target, "Target name {} does not match remote name {}".format(target, name) + + def _get_token_desc_ids(self, token_ids: List[int], num_tokens: int): + token_ids_len, idx = len(token_ids), 0 + descs_ids = [0] * (self.num_layers * token_ids_len) + for layer_id in range(self.num_layers): + for token_id in token_ids: + descs_ids[idx] = layer_id * num_tokens + token_id + idx += 1 + return descs_ids + + def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, is_finished: bool): + group_reqeust_id = request.group_req_id + skip_kv_move_len = prefill_request.data.local_cached_len + + # current kv len is less than remote cached kv len, just skip + if request.cur_kv_len <= skip_kv_move_len: + return + + kv_move_start = max(skip_kv_move_len, request.prev_kv_len) + kv_move_end = request.cur_kv_len + + src_token_ids = request.token_ids[kv_move_start:] + dst_token_ids = prefill_request.data.token_ids[ + kv_move_start - skip_kv_move_len : kv_move_end - skip_kv_move_len + ] + + remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][ + self.tp_idx + ] # TODO one-one mapping now + + if len(src_token_ids) > 0: + assert len(src_token_ids) == len(dst_token_ids), ( + f"{len(src_token_ids)} {len(dst_token_ids)} {kv_move_start} " + f"{kv_move_end} {skip_kv_move_len}, {len(prefill_request.data.token_ids)}" + ) + src_token_descs = self._get_token_desc_ids(src_token_ids, self.num_tokens) + dst_token_descs = self._get_token_desc_ids(dst_token_ids, remote_agent.num_tokens) + + src_handle = self.local_xfer_handles + dst_handle = remote_agent.kv_xfer_handles + + notify_status = ( + RemotePrefillStatus( + group_req_id=group_reqeust_id, + status=1, + chunk_id=prefill_request.transfer_state.current_chunk_id, + is_last=is_finished, + ).serialize() + if is_finished + else b"" + 
) + + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status + ) + + status = self.nixl_agent.transfer(handle) + assert status != "ERR" + + if group_reqeust_id not in self.inflight_transfers: + self.inflight_transfers[group_reqeust_id] = KVMoveRequestState( + handles=[], done_handles=[], remote_agent=remote_agent, abort=False, is_last_arrived=False + ) + + self.inflight_transfers[group_reqeust_id].handles.append(handle) + + if is_finished: + self.inflight_transfers[group_reqeust_id].is_last_arrived = True + + return handle + + return None + + def write_blocks_paged( + self, + remote_id: int, + transfer_pages: List[int], + receive_pages: List[int], + notifications: List[RemotePrefillStatus], + ): + remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] + src_handle = self.page_local_xfer_handles + dst_handle = remote_agent.kv_page_xfer_handles + notify_status = Notification(type=NotificationType.TRANSFER_NOTIFY, data=[n.serialize() for n in notifications]) + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", src_handle, transfer_pages, dst_handle, receive_pages, notify_status.to_bytes() + ) + status = self.nixl_agent.transfer(handle) + assert status != "ERR", f"Transfer failed with status {status} for handle {handle}" + self.inflight_page_transfers[handle] = (transfer_pages, receive_pages, notifications, remote_agent) + + def send_transfer_notify(self, agent_name: str, acks: List[PageTransferAck]): + assert len(acks) > 0, "Acks should not be empty" + acks_noti = Notification(type=NotificationType.TRANSFER_NOTIFY_ACK, data=[ack.serialize() for ack in acks]) + self.nixl_agent.send_notif(agent_name, acks_noti.to_bytes()) + + def send_abort_notify(self, remote_id: int, group_req_id: int): + remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] + notify_status = RemotePrefillStatus( + group_req_id=group_req_id, + transfer_type=RemoteTransferType.PAGE_TRANSFER, + status=RemoteTransferStatusType.FAILED, + is_last=True, + ) + self.nixl_agent.send_notif( + remote_agent.name, + Notification(type=NotificationType.TRANSFER_NOTIFY, data=[notify_status.serialize()]).to_bytes(), + ) + + if group_req_id in self.inflight_transfers: + self.inflight_transfers[group_req_id].abort = True + + async def get_done_page_transfers(self): + done_pages = [] + done_requests = [] + for handle, (transfer_pages, _, notifications, _) in self.inflight_page_transfers.items(): + xfer_state = self.nixl_agent.check_xfer_state(handle) + if xfer_state == "DONE": + done_pages.extend(transfer_pages) + done_requests.extend( + [(x.group_req_id, RemoteTransferStatusType.SUCCESS) for x in notifications if x.is_last] + ) + self.nixl_agent.release_xfer_handle(handle) + del self.inflight_page_transfers[handle] + + elif xfer_state == "PROC": + continue + else: + logger.warning(f"Transfer failed with state {xfer_state} for handle {handle}") + done_pages.extend(transfer_pages) + done_requests.extend([(x.group_req_id, RemoteTransferStatusType.FAILED) for x in notifications]) + self.nixl_agent.release_xfer_handle(handle) + del self.inflight_page_transfers[handle] + + return done_pages, done_requests + + def get_done_tranfers(self): + done_req_ids = [] + for req_id, kv_move_state in self.inflight_transfers.items(): + kv_move_state: KVMoveRequestState + if kv_move_state.abort: + logger.warning(f"{req_id} Transfer aborted") + done_req_ids.append((req_id, -1)) + continue + + if not kv_move_state.is_last_arrived: + continue + + remote_agent: 
RemoteAgent = kv_move_state.remote_agent + + left_handles = [] + failed = False + for handle in kv_move_state.handles: + if failed: + left_handles.append(handle) + continue + + xfer_state = self.nixl_agent.check_xfer_state(handle) + + if xfer_state == "DONE": + kv_move_state.done_handles.append(handle) + elif xfer_state == "PROC": + left_handles.append(handle) + else: + logger.warning(f"{req_id} Transfer failed with state {xfer_state}") + failed = True + kv_move_state.done_handles.append(handle) + notify_failed_status = RemotePrefillStatus( + group_req_id=req_id, status=-1, chunk_id=-1, is_last=True + ) + self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) + + kv_move_state.handles = left_handles + + if failed: + done_req_ids.append((req_id, -1)) + elif len(left_handles) == 0: + done_req_ids.append((req_id, 1)) + + for req_id, _ in done_req_ids: + kv_move_state: KVMoveRequestState = self.inflight_transfers[req_id] + for handle in kv_move_state.handles + kv_move_state.done_handles: + # release will abort inflight transfer + self.nixl_agent.release_xfer_handle(handle) + + del self.inflight_transfers[req_id] + return done_req_ids + + def shutdown(self): + self.nixl_agent.deregister_memory(self.reg_desc) + self.nixl_agent.release_dlist_handle(self.local_xfer_handles) + self.nixl_agent.release_dlist_handle(self.page_local_xfer_handles) + for id, agents in self.remote_agents.items(): + for agent in agents: + self.nixl_agent.remove_remote_agent(agent.name) + self.nixl_agent.release_dlist_handle(agent.kv_xfer_handles) + self.nixl_agent.release_dlist_handle(agent.kv_page_xfer_handles) + + for handle in self.inflight_page_transfers: + self.nixl_agent.release_xfer_handle(handle) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py new file mode 100644 index 000000000..d1fa4003a --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -0,0 +1,314 @@ +from typing import List, Any +import zmq +import inspect +import random +import time + +import torch.multiprocessing as mp + +from lightllm.utils.log_utils import init_logger +from lightllm.utils.net_utils import get_hostname_ip +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.server.pd_io_struct import DistInfo + +from .pd_remote_prefill_obj import ( + ConnectRequest, + RemoteRequest, + RemoteRequstType, + PrefillRequest, + RemotePrefillRequest, + RemotePrefillServerInfo, + RemotePrefillTask, + RemotePrefillStatus, + RemoteTransferStatusType, + RemoteTransferType, + SockWithPoller, +) +from .nixl_kv_transporter import NixlMetadata + +logger = init_logger(__name__) + + +class PDRemotePrefillBase: + def __init__( + self, + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl + ): + self.id = id + self.dist_info = dist_info + assert len(agent_meta_queues) == dist_info.node_world_size + self.agent_meta_queues = agent_meta_queues + self.from_backend_queue = from_backend_queue + self.to_backend_queues = to_backend_queues + self.local_agent_meta = None + + def local_init(self): + agent_metas = NixlMetadata( + id=self.id, + agent_metadatas=[], + num_tokens=[], + num_pages=[], + agent_mem_descs=[], + agent_page_mem_descs=[], + ) + for tp in range(self.dist_info.node_world_size): + agent_metadata, num_tokens, 
num_pages, mem_desc, page_mem_desc = self.agent_meta_queues[tp].get(timeout=60) + logger.info(f"Received agent_metadata from {tp} with mem reg: {mem_desc}") + agent_metas.num_tokens.append(num_tokens) + agent_metas.num_pages.append(num_pages) + agent_metas.agent_metadatas.append(agent_metadata) + agent_metas.agent_mem_descs.append(mem_desc) + agent_metas.agent_page_mem_descs.append(page_mem_desc) + + self.local_agent_meta = agent_metas + logger.info("All local kv cache registered.") + + +class PDRemotePrefillServer(PDRemotePrefillBase): + def __init__( + self, + id: int, + dist_info: DistInfo, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], + ): + super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) + # map from client id to decode server info + self.remote_decode_clients = {} + + # build control path + _ctx = zmq.Context() + self.recv_from_decode = SockWithPoller(_ctx.socket(zmq.ROUTER)) + self.host_ip = get_hostname_ip() + self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}") + + # build trigger remote prefill path + self.send_to_httpserver = SockWithPoller(_ctx.socket(zmq.PUSH)) + self.send_to_httpserver.connect(f"tcp://{self.host_ip}:{http_server_port}") + + def main_loop(self): + self.local_init() + while True: + try: + client_obj, request = self.recv_from_decode.recv_pyobj_multipart() + request: RemoteRequest + logger.info(f"recevied request from decode, type: {request.type}") + + if request.type == RemoteRequstType.REMOTE_CONNECT: + # forward request to all prefill server + for queue in self.to_backend_queues: + queue.put(request) + + success = True + for idx in range(self.dist_info.node_world_size): + ack = self.from_backend_queue.get() + logger.info(f"received ack from backend {idx}: {ack}") + if ack != "OK": + success = False + break + + self.recv_from_decode.send_pyobj_multipart(client_obj, success) + logger.info(f"Sent ack to decode: {success}") + if not success: + logger.warning(f"Remote connect failed: {request}") + + if request.type == RemoteRequstType.REMOTE_PREFILL: + request: PrefillRequest = request + if self.dist_info.dp_size_in_node > 1: + group_req_id = request.data.sampling_params.group_request_id + suggested_dp_index = request.data.sampling_params.suggested_dp_index + if suggested_dp_index < 0: # not likely to happen + suggested_dp_index = random.randint(0, self.dist_info.dp_size_in_node) + request.data.sampling_params.suggested_dp_index = suggested_dp_index + logger.warning( + f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}" + ) + + for local_rank in range( + suggested_dp_index * self.dist_info.dp_world_size, + (suggested_dp_index + 1) * self.dist_info.dp_world_size, + ): + self.to_backend_queues[local_rank].put(request) + else: + for queue in self.to_backend_queues: + queue.put(request) + + self.send_to_httpserver.send_pyobj( + (request.data.prompt, request.data.sampling_params, request.data.multimodal_params) + ) + + except Exception as e: + logger.error(f"Error in remote prefill server loop: {e}", exc_info=e) + + +class PDRemotePrefillClient(PDRemotePrefillBase): + def __init__( + self, + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, # only tp0 will trigger prefill + to_backend_queues: List[mp.Queue], # one to many done queue + agent_meta_queues: List[mp.Queue], + ): + super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) + # 
map from server id to prefill server info + + self.remote_prefill_servers = {} + self.client_socket_cnt = 0 + self._ctx = zmq.Context() + + def _connect_server(self, server_ip: str, port: int): + _socket = self._ctx.socket(zmq.DEALER) + _socket.setsockopt_string(zmq.IDENTITY, f"{self.id}_{self.client_socket_cnt}") + self.client_socket_cnt += 1 + connect_str = f"tcp://{server_ip}:{port}" + _socket.connect(connect_str) + return SockWithPoller(_socket) + + def _send_nixl_agent(self, socket: SockWithPoller): + socket.send_pyobj( + ConnectRequest( + type=RemoteRequstType.REMOTE_CONNECT, + decode_id=self.id, + num_tokens=self.local_agent_meta.num_tokens, + num_pages=self.local_agent_meta.num_pages, + agent_metadatas=self.local_agent_meta.agent_metadatas, + agent_mem_descs=self.local_agent_meta.agent_mem_descs, + agent_page_mem_descs=self.local_agent_meta.agent_page_mem_descs, + ) + ) + + success = socket.recv_pyobj(timeout=60) + logger.info(f"recv remote nixl connect response {success}") + if success is None: + logger.warning("timeout to recv remote nixl connect response") + return False + + return success + + def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): + + if server_info.perfill_server_id in self.remote_prefill_servers: + return True + + # build control path if not exist + _socket = self._connect_server(server_info.prefill_server_ip, server_info.prefill_server_port) + success = self._send_nixl_agent(_socket) + if success: + self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info) + return True + else: + logger.warning("Remote Prefill Server Connect Failed") + return False + + def main_loop(self): + self.local_init() + while True: + try: + prefill_tasks: RemotePrefillTask = self.from_backend_queue.get() + # connect first + if self.connect_to_prefill_server(prefill_tasks.server_info): + # do prefill + self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request) + else: + # failed to connect a remote + for idx in self.to_backend_queues: + self.to_backend_queues.put( + RemotePrefillStatus( + transfer_type=RemoteTransferType.PAGE_TRANSFER, + group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id, + status=RemoteTransferStatusType.FAILED, + is_last=True, + ) + ) + except Exception as e: + logger.error(f"Remote prefill client loop error: {e}", exc_info=e) + + # place request to server do remote prefill + def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): + socket, _ = self.remote_prefill_servers[server_id] + prefill_request.sampling_params.max_new_tokens = 1 + socket.send_pyobj( + PrefillRequest( + type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None + ) + ) + + +def remote_prefill_server_loop( + id: int, + dist_info: DistInfo, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + graceful_registry(inspect.currentframe().f_code.co_name) + server = PDRemotePrefillServer( + id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues + ) + server.main_loop() + + +def start_pd_remote_prefill_server_process( + id: int, + dist_info: DistInfo, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + proc = mp.Process( + target=remote_prefill_server_loop, + args=(id, dist_info, 
http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues), + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill server with id: {id} started!") + return proc + + +def remote_prefill_client_loop( + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + graceful_registry(inspect.currentframe().f_code.co_name) + + client = PDRemotePrefillClient( + id, + dist_info, + from_backend_queue, + to_backend_queues, + agent_meta_queues, + ) + client.main_loop() + + +def start_pd_remote_prefill_client_process( + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + + proc = mp.Process( + target=remote_prefill_client_loop, + args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues), + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill client with id: {id} started!") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py new file mode 100644 index 000000000..99a61cc43 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -0,0 +1,301 @@ +from dataclasses import dataclass, asdict +from enum import Enum +import json +from typing import List, Union, Optional, Any +from threading import Lock, Condition +import pickle +import zmq +import threading + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.pd_io_struct import RemotePrefillServerInfo + +logger = init_logger(__name__) + +try: + from nixl._api import nixlBind, nixl_prepped_dlist_handle, nixl_xfer_handle + +except ImportError: + logger.error("nixl is not installed, which is required for pd disagreggation!!!") + raise + + +class RemoteRequstType(Enum): + REMOTE_CONNECT = 1 + REMOTE_PREFILL = 2 + + +@dataclass +class RemotePrefillRequest: + prompt: Union[str, List[int]] + sampling_params: SamplingParams + multimodal_params: MultimodalParams + local_cached_len: int # will skip transfer + token_ids: List[int] # mem cache indexes + page_ids: List[int] # transfer page indexes + + +@dataclass +class RemotePrefillTask: + server_info: RemotePrefillServerInfo + prefill_request: RemotePrefillRequest + + +@dataclass +class RemoteRequest: + type: RemoteRequstType + + +@dataclass +class ConnectRequest(RemoteRequest): + decode_id: int + num_tokens: List[int] + num_pages: List[int] + agent_metadatas: List[bytes] + agent_mem_descs: List[bytes] + agent_page_mem_descs: List[bytes] + + +@dataclass +class TransferState: + start_time: float + lock: threading.Lock + free_page_ids: List[int] + + current_kv_len: int = 0 + current_chunk_id: int = 0 + + transfered_kv_len: int = 0 + transfered_chunk_id: int = 0 + + token_index: List[int] = None + is_finished: bool = False + + next_token_id: int = None + next_token_logprob: float = None + + def completed(self): + return self.is_finished and self.transfered_kv_len == self.current_kv_len + + +@dataclass +class PrefillRequest(RemoteRequest): + decode_id: int + data: RemotePrefillRequest + # transfer status + transfer_state: Optional[TransferState] + + +@dataclass +class KVMoveRequest: + group_req_id: int + prev_kv_len: int + cur_kv_len: int + + 
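+# A RemoteAgent caches, per peer rank, everything needed to address that peer again:
+# its NIXL agent name, the deserialized memory descriptors, and the prepared transfer
+# handle lists for both the token KV cache and the paged move buffer.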
+@dataclass +class RemoteAgent: + name: str + num_tokens: int + num_pages: int + kv_mem_desc: nixlBind.nixlRegDList + kv_xfer_handles: nixl_prepped_dlist_handle + kv_page_mem_desc: nixlBind.nixlRegDList + kv_page_xfer_handles: nixl_prepped_dlist_handle + + +@dataclass +class KVMoveRequestState: + handles: List[nixl_xfer_handle] + done_handles: List[nixl_xfer_handle] + remote_agent: RemoteAgent + abort: bool + is_last_arrived: bool + + +class SerializableBase: + def to_dict(self): + return asdict(self) + + def serialize(self): + return json.dumps(self.to_dict()).encode() + + @classmethod + def from_dict(cls, dict_obj): + return cls(**dict_obj) + + @classmethod + def deserialize(cls, data: bytes): + return cls.from_dict(json.loads(data.decode())) + + +class RemoteTransferType(Enum): + TOKEN_TRANSFER = 1 + PAGE_TRANSFER = 2 + + +class RemoteTransferStatusType(Enum): + FAILED = -1 + IN_PROGRESS = 0 + SUCCESS = 1 + + +@dataclass +class RemotePrefillStatus(SerializableBase): + transfer_type: RemoteTransferType + group_req_id: int + status: RemoteTransferStatusType + chunk_id: int = -1 + is_last: bool = False + page_id: int = -1 + kv_start: int = 0 + kv_len: int = 0 + next_token_id: int = None + next_token_logprob: float = None + + def to_dict(self): + dict_obj = asdict(self) + dict_obj["transfer_type"] = self.transfer_type.name + dict_obj["status"] = self.status.name + return dict_obj + + @classmethod + def from_dict(cls, dict_obj): + dict_obj["transfer_type"] = RemoteTransferType[dict_obj["transfer_type"]] + dict_obj["status"] = RemoteTransferStatusType[dict_obj["status"]] + return cls(**dict_obj) + + +@dataclass +class PageTransferAck(SerializableBase): + group_req_id: int + page_id: int + + +class NotificationType(Enum): + REMOTE_MD = 1 + TRANSFER_NOTIFY = 2 + TRANSFER_NOTIFY_ACK = 3 + + +@dataclass +class Notification: + type: NotificationType + data: Union[bytes, List[bytes]] + + def to_bytes(self): + return pickle.dumps(self) + + @classmethod + def from_bytes(cls, data): + return pickle.loads(data) + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + self._lock = Lock() + + def __getitem__(self, key): + with self._lock: + return self._dict[key] + + def __setitem__(self, key, value): + with self._lock: + self._dict[key] = value + + def __delitem__(self, key): + with self._lock: + del self._dict[key] + + def __contains__(self, key): + with self._lock: + return key in self._dict + + def __len__(self) -> int: + with self._lock: + return len(self._dict) + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def items(self): + with self._lock: + return list(self._dict.items()) + + def keys(self): + with self._lock: + return list(self._dict.keys()) + + def pop(self, key: Any, default: Optional[Any] = None) -> Any: + with self._lock: + return self._dict.pop(key, default) + + def values(self): + with self._lock: + return list(self._dict.values()) + + def clear(self) -> None: + with self._lock: + self._dict.clear() + + +class SockWithPoller: + def __init__(self, sock: zmq.Socket): + self.sock = sock + self.poller = zmq.Poller() + self.poller.register(self.sock, zmq.POLLIN) + + def recv_pyobj(self, timeout: int = 5): + socks = dict(self.poller.poll(timeout * 1000)) + if socks: + if socks.get(self.sock) == zmq.POLLIN: + return self.sock.recv_pyobj() + else: + None + + def send_pyobj(self, obj: Any): + return self.sock.send_pyobj(obj) + + def recv_pyobj_multipart(self): + client_id, data = self.sock.recv_multipart() + return client_id, 
pickle.loads(data) + + def send_pyobj_multipart(self, client_id: bytes, data: Any): + return self.sock.send_multipart([client_id, pickle.dumps(data)]) + + def bind(self, addr: str): + return self.sock.bind(addr) + + def connect(self, addr: str): + return self.sock.connect(addr) + + +class SafePageIndexScheduler: + def __init__(self, num_pages: int): + self.num_pages = num_pages + self.items = list(range(num_pages)) + self.lock = Lock() + self.cond = Condition(self.lock) + + def borrow(self, num_pages: int = 2) -> List[int]: + if num_pages > self.num_pages: + raise ValueError(f"Cannot borrow {num_pages} pages, only {self.num_pages} available.") + + with self.cond: + while len(self.items) < num_pages: + self.cond.wait() + ret, self.items = self.items[:num_pages], self.items[num_pages:] + return ret + + def return_(self, items: List[int]): + with self.cond: + self.items.extend(items) + self.cond.notify_all() + + def current_size(self) -> int: + with self.lock: + return len(self.items) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 311c2725f..6f0d41657 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -27,6 +27,10 @@ ContinuesBatchBackendForMtpDecodeNode, ChunckedPrefillForMtpPrefillNode, DPChunkedForMtpPrefillNode, + PDNIXLBackendForPrefillNode, + PDNIXLBackendForDecodeNode, + PDNIXLDPBackendForPrefillNode, + PDNIXLDPBackendForDecodeNode, ) from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray @@ -48,12 +52,14 @@ def __init__( rpc_event: multiprocessing.Event, rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, ): super().__init__() self.args: StartArgs = args self.node_world_size = node_world_size self.info_queue = info_queue + self.result_queue = result_queue self.mem_queue = mem_queue self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -124,6 +130,8 @@ def init_model(self, kvargs): assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" is_prefill_node = self.args.run_mode == "prefill" is_decode_node = self.args.run_mode == "decode" + is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" + is_nixl_decode_node = self.args.run_mode == "nixl_decode" enable_mtp = self.args.mtp_mode is not None @@ -138,6 +146,14 @@ def init_model(self, kvargs): self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) else: self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + + elif is_nixl_prefill_node: + assert not enable_mtp, "nixl pd does not support mtp now." + if self.args.dp > 1: + self.backend = PDNIXLDPBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + elif is_decode_node: if enable_mtp: if self.args.dp > 1: @@ -149,11 +165,20 @@ def init_model(self, kvargs): self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) else: self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + + elif is_nixl_decode_node: + assert not enable_mtp, "nixl pd does not support mtp now." 
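+            # info_queue / result_queue / mem_queue become the backend's task, done and
+            # NIXL-metadata queues respectively (see the PDNIXL backend constructors above)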
+ if self.args.dp > 1: + self.backend = PDNIXLDPBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + elif self.args.dp > 1: if enable_mtp: self.backend = DPChunkedPrefillWithMTPBackend() else: - self.backend = DPChunkedPrefillBackend() + self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + elif use_reward_model: self.backend = RewardModelBackend() elif return_all_prompt_logprobs: @@ -282,6 +307,7 @@ def _init_env( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event: mp.Event, @@ -300,7 +326,7 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, result_queue, mem_queue ) success_event.set() @@ -316,6 +342,7 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, ): @@ -330,6 +357,7 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event, diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 736ab5e59..36ba7879e 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -15,9 +15,9 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): return ChunkedPrefillQueue if args.first_token_constraint_mode: return ChunkedPrefillQueue - if args.run_mode == "decode": + if args.run_mode in ["decode", "nixl_decode"]: return QueueForPDDecode - if args.run_mode == "prefill": + if args.run_mode in ["prefill", "nixl_prefill"]: return QueueForPDChunkedPrefill if args.disable_chunked_prefill: diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py index f6305e209..ee0778b65 100644 --- a/lightllm/utils/health_check.py +++ b/lightllm/utils/health_check.py @@ -70,7 +70,7 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req health_obj.begin_check() try: request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}} - if args.run_mode == "prefill": + if args.run_mode in ["prefill", "nixl_prefill"]: request_dict["parameters"]["max_new_tokens"] = 1 prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"]
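For orientation, a minimal usage sketch of the SafePageIndexScheduler defined in pd_remote_prefill_obj.py above, which rations free pages of the paged KV move buffer for in-flight remote prefills; the function names and the pending_pages bookkeeping below are illustrative only, not part of the patch.

from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill_obj import SafePageIndexScheduler

scheduler = SafePageIndexScheduler(num_pages=32)  # one entry per page in the paged move buffer
pending_pages = {}  # group_req_id -> borrowed page ids (illustrative bookkeeping)

def start_remote_prefill(group_req_id: int) -> list:
    # borrow() blocks until enough free pages exist, which is why the decode backend
    # triggers remote prefill from a thread pool rather than from the main decode loop
    page_ids = scheduler.borrow(num_pages=2)
    pending_pages[group_req_id] = page_ids
    return page_ids  # shipped to the prefill node as RemotePrefillRequest.page_ids

def on_pages_acknowledged(group_req_id: int) -> None:
    # once the page transfer is acknowledged, hand the pages back so other requests
    # blocked inside borrow() can proceed
    scheduler.return_(pending_pages.pop(group_req_id))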