diff --git a/Dockerfile.nixl b/Dockerfile.nixl new file mode 100644 index 000000000..427a230e5 --- /dev/null +++ b/Dockerfile.nixl @@ -0,0 +1,89 @@ +FROM nvcr.io/nvidia/tritonserver:25.04-py3-min as base +ARG PYTORCH_VERSION=2.6.0 +ARG PYTHON_VERSION=3.9 +ARG CUDA_VERSION=12.4 +ARG MAMBA_VERSION=23.1.0-1 +ARG TARGETPLATFORM +ARG INSTALL_NIXL=true + +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl-dev \ + curl \ + g++ \ + make \ + git && \ + rm -rf /var/lib/apt/lists/* + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + + +WORKDIR /root + +COPY ./requirements.txt /lightllm/requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip pip install -r /lightllm/requirements.txt --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124 + +RUN pip install --no-cache-dir https://github.com/ModelTC/flash-attn-3-build/releases/download/v2.7.4.post1/flash_attn-3.0.0b1-cp39-cp39-linux_x86_64.whl + +RUN --mount=type=cache,target=/root/.cache/pip pip install nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100 + +RUN if [ "$INSTALL_NIXL" == "true" ]; then \ + apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ + DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ + rm -rf /usr/lib/ucx && \ + rm -rf /opt/hpcx/ucx && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout v1.19.x && \ + ./autogen.sh && ./configure \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs=yes \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + ldconfig; \ + fi + +RUN if [ "$INSTALL_NIXL" == "true" ]; then \ + apt-get update && apt-get install -y pkg-config tmux net-tools; \ + cd /usr/local/src; \ + pip install --upgrade meson pybind11 patchelf; \ + git clone https://github.com/ai-dynamo/nixl.git -b main && \ + cd nixl && \ + rm -rf build && \ + mkdir build && \ + meson setup build/ --prefix=/usr/local/nixl --buildtype=release && \ + cd build && \ + ninja && \ + ninja install && \ + cd .. && pip install . --no-deps; \ + fi + +COPY . 
/lightllm +RUN pip install -e /lightllm --no-cache-dir diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1aca6672c..7b5b021f6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server"], + choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -54,6 +54,20 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="The port number for the config server in config_server mode.", ) + parser.add_argument( + "--pd_nixl_remote_prefill_http_port", + type=int, + default=42001, + help="nixl pd mode, prefill node used for triggering prefill http port.", + ) + + parser.add_argument( + "--pd_nixl_remote_prefill_port", + type=int, + default=42002, + help="nixl pd mode, prefill and decode used for meta exchange.", + ) + parser.add_argument( "--model_name", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index fa100f119..4a7775677 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -69,7 +69,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index d4e69fd36..0096d7baa 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,4 +1,4 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus +from .req import Req, FinishStatus, PDNIXLChunkedPrefillReq from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index c8d8476e5..9ac010e98 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -102,6 +102,7 @@ def get_str(self): f"shm_cur_kv_len:{self.shm_cur_kv_len}," f"shm_cur_output_len:{self.shm_cur_output_len}," f"finish_status:{self.finish_status.is_finished()}" + f"group_id: {self.group_req_id}" ) def init( @@ -333,3 +334,55 @@ def post_init( # 错误问题。 self.sample_params.max_new_tokens = self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6 return + + +class PdNixlReqState(ctypes.Structure): + _pack_ = 4 + _MAX_TP_SIZE = 32 + _fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)] + + def __init__(self): + self.dp_world_size = 0 + self.state = (ctypes.c_int * self._MAX_TP_SIZE)(*([0] * self._MAX_TP_SIZE)) + + def set_dp_world_size(self, size: int): + self.dp_world_size = size + + def set_tp_state(self, tp_id: int, state: int): + assert ( + self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size + ), f"tp_id {tp_id} out of range [0, {self.dp_world_size})" + self.state[tp_id] = state + + def set_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + unique_state = np.unique(self.state[: self.dp_world_size]) + 
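+        # np.unique returns the sorted distinct values, so unique_state[0] is the
+        # minimum state across the dp_world_size ranks: -1 if any rank failed,
+        # else 0 if any rank is still in progress, else 1 when all succeeded.
+        # The aggregate is stored one slot past the per-rank states: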
self.state[self.dp_world_size] = unique_state[0] + + def get_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + return self.state[self.dp_world_size] + + +class PDNIXLChunkedPrefillReq(ChunkedPrefillReq): + _pack_ = 4 + _fields_ = ChunkedPrefillReq._fields_ + [ + # 用于pd nixl状态同步 + ("pd_nixl_req_state", PdNixlReqState) + ] + + def set_dp_world_size(self, dp_world_size): + self.pd_nixl_req_state.dp_world_size = dp_world_size + + # called by each tp rank, no contention + def set_pd_req_rank_state(self, tp_id: int, state: int): + self.pd_nixl_req_state.set_tp_state(tp_id, state) + + # state: -1 for failed, 0 for in progress, 1 for success + # set by router + def set_pd_req_state(self): + self.pd_nixl_req_state.set_state() + + # read by all rank + def get_pd_req_state(self): + return self.pd_nixl_req_state.get_state() diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index 315eb938e..5dd73ad02 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -3,7 +3,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger -from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq +from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq, PDNIXLChunkedPrefillReq from .shm_array import ShmArray from .atomic_array_lock import AtomicShmArrayLock, AtomicLockItem from .atomic_lock import AtomicShmLock @@ -33,6 +33,9 @@ def get_req_class_type(self): if args.token_healing_mode: return TokenHealingReq + if args.run_mode in ["nixl_prefill", "nixl_decode"]: + return PDNIXLChunkedPrefillReq + if args.disable_chunked_prefill: return NormalReq else: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0adb61a4c..8f354da23 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -6,7 +6,10 @@ @dataclass class StartArgs: - run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master"]}) + run_mode: str = field( + default="normal", + metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + ) host: str = field(default="127.0.0.1") port: int = field(default=8000) zmq_mode: str = field( diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e02eaaf79..14e971a6a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -98,7 +98,7 @@ def __init__( self.metric_client = MetricClient(metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] + assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() @@ -225,7 +225,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): # health 请求 request_id 为负数,直接返回 if is_health_req: return sampling_params.group_request_id - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() else: @@ -235,7 +235,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert sampling_params.group_request_id != -1 
group_request_id = sampling_params.group_request_id sampling_params.group_request_id = group_request_id - elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D: + elif self.pd_mode.is_P_or_D(): assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting" group_request_id = sampling_params.group_request_id else: @@ -412,7 +412,7 @@ async def transfer_to_next_module_or_node( if self.is_multinode_tp_master: async with self.transfer_lock: for sender in self.multinode_req_manager: - sender.send_pyobj( + await sender.send_pyobj( (prompt, sampling_params, original_multimodal_params), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -441,35 +441,37 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode == NodeRole.P: + if self.pd_mode.is_P(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + + # P 模式下,直接将请求发送到路由器 + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.D: + if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 - self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -514,7 +516,7 @@ async def _wait_to_token_package( # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 metadata["prompt_tokens"] = prompt_tokens # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode - if self.pd_mode == NodeRole.P and is_first_token: + if self.pd_mode.is_P() and is_first_token: metadata["prompt_ids"] = prompt_ids prompt_cache_len = metadata.pop("prompt_cache_len", 0) @@ -611,7 +613,7 @@ async def recycle_resource_loop(self): pre_time_mark = time.time() for req_status in self.req_id_to_out_inf.values(): logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 10a4a8ec5..94394182d 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -5,6 +5,7 @@ import socket import httpx import base64 +import zmq from typing import Dict, Optional from lightllm.server.pd_io_struct import NodeRole, ObjType from lightllm.server.httpserver.async_queue import AsyncQueue @@ -33,6 +34,8 @@ async def pd_handle_loop(manager: HttpServerManager): manager.host_ip = manager.args.host asyncio.create_task(timer_log(manager)) + if manager.pd_mode.is_NP_or_ND(): + asyncio.create_task(pd_handle_loop_from_d(manager)) id_to_handle_task: Dict[int, asyncio.Task] = {} @@ -92,7 +95,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O logger.info(f"Sent registration JSON: {regist_json}") # 转发任务 - forwarding_tokens_task = 
asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) + if manager.pd_mode != NodeRole.NP: # nixl prefill don't need up token to master + forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: @@ -182,3 +186,33 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): handle_list = await forwarding_queue.wait_to_get_all_data() if handle_list: await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list))) + + +async def pd_handle_loop_from_d(manager: HttpServerManager): + if manager.pd_mode != NodeRole.NP: + return + + context = zmq.asyncio.Context(2) + manager.recv_from_d = context.socket(zmq.PULL) + manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_nixl_remote_prefill_http_port}") + + while True: + try: + ( + prompt, + sampling_params, + multimodal_params, + ) = await manager.recv_from_d.recv_pyobj() + + # 触发推理的task + async def pd_process_generate(manager: "HttpServerManager", prompt, sampling_params, multimodal_params): + try: + async for _, _, _, _ in manager.generate(prompt, sampling_params, multimodal_params, None): + pass + except BaseException as e: + logger.error(str(e)) + + asyncio.create_task(pd_process_generate(manager, prompt, sampling_params, multimodal_params)) + + except Exception as e: + logger.exception(f"pd loop generate error: {str(e)}") diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 05b2d987c..65bec6a1c 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -1,20 +1,15 @@ import sys -import zmq -import zmq.asyncio import asyncio import uvloop -import rpyc import time -import hashlib import datetime -import aiohttp import ujson as json import pickle asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType +from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType, NodeRole from lightllm.server.core.objs import SamplingParams from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -55,10 +50,11 @@ async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode == "prefill": + client_pd_mode: NodeRole = NodeRole(pd_client.mode) + if client_pd_mode.is_P(): self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode == "decode": + elif client_pd_mode.is_D(): self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: @@ -110,8 +106,8 @@ async def select_p_d_node( ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: import random - p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) + p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None + d_node = random.choice(self.decode_nodes) if self.decode_nodes else None return p_node, d_node async def generate( @@ -133,6 +129,10 @@ async def generate( p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params) + if not p_node or not d_node: + 
logger.error(f"{group_request_id}: No p_node or d_node found") + return + results_generator = self._wait_to_token_package( p_node, d_node, @@ -233,8 +233,8 @@ async def fetch_stream( try: await asyncio.wait_for(up_status_event.wait(), timeout=60) except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") - raise ServerBusyError() + logger.warning(f"group_request_id: {group_request_id} kv move time out err") + assert False, f"req_id {group_request_id} kv move time out, server is busy" sampling_params.move_kv_to_decode_node.initialize(None) sampling_params.max_new_tokens = old_max_new_tokens - 1 @@ -253,6 +253,43 @@ async def fetch_stream( return + async def fetch_stream_nixl( + self, + p_node: PD_Client_Obj, + d_node: PD_Client_Obj, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ): + group_request_id = sampling_params.group_request_id + + req_status = ReqStatus(group_request_id, p_node, d_node) + self.req_id_to_out_inf[group_request_id] = req_status + + p_start_args = p_node.start_args + prefill_node_dict = { + "node_id": p_start_args["pd_node_id"], + "ip": p_start_args["host"], + "rpyc_port": p_start_args["pd_nixl_remote_prefill_port"], + "max_new_tokens": sampling_params.max_new_tokens, + "pd_master_node_id": self.args.pd_node_id, + } + + sampling_params.move_kv_to_decode_node.initialize(prefill_node_dict) + sampling_params.suggested_dp_index = -1 + + await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + + while True: + await req_status.wait_to_ready() + if await request.is_disconnected(): + raise Exception(f"req_id {group_request_id} disconnected") + if await req_status.can_read(self.req_id_to_out_inf): + token_list = await req_status.pop_all_tokens() + for sub_req_id, request_output, metadata, finish_status in token_list: + yield sub_req_id, request_output, metadata, finish_status + async def _wait_to_token_package( self, p_node: PD_Client_Obj, @@ -269,7 +306,11 @@ async def _wait_to_token_package( unfinished_count = sampling_params.best_of is_first_token = True - async for sub_req_id, out_str, metadata, finish_status in self.fetch_stream( + client_mode: NodeRole = NodeRole(d_node.mode) + + fetch_stream = self.fetch_stream_nixl if client_mode.is_NP_or_ND() else self.fetch_stream + + async for sub_req_id, out_str, metadata, finish_status in fetch_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index bf320e199..2e2dab5b0 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -154,6 +154,12 @@ def to_dict(self): ret["audios"] = [a.to_dict() for a in self.audios] return ret + @classmethod + def from_dict(cls, data: dict): + if "images" not in data: + return cls() + return cls(images=data["images"]) + def to_origin_dict(self): """ 将内容转换为原始请求的形式,主要用于请求转发 diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 414e3c74a..4267afaee 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -13,23 +13,30 @@ class NodeRole(enum.Enum): P = "prefill" D = "decode" + + NP = "nixl_prefill" + ND = "nixl_decode" + NORMAL = "normal" PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.D + return self == NodeRole.D or self == NodeRole.ND 
def is_P(self): - return self == NodeRole.P + return self == NodeRole.P or self == NodeRole.NP def is_normal(self): return self == NodeRole.NORMAL def is_P_or_NORMAL(self): - return (self == NodeRole.P) or (self == NodeRole.NORMAL) + return self.is_P() or self.is_normal() def is_P_or_D(self): - return (self == NodeRole.P) or (self == NodeRole.D) + return self.is_P() or self.is_D() + + def is_NP_or_ND(self): + return self == NodeRole.NP or self == NodeRole.ND class ObjType(enum.Enum): @@ -47,8 +54,8 @@ class PD_Client_Obj: websocket: WebSocket = None # 用于通信的 websocket 连接对象 def __post_init__(self): - if self.mode not in ["prefill", "decode"]: - error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" + if self.mode not in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + error_info = f"""mode must in ["prefill", "decode", "nixl_prefill", "nixl_decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return @@ -114,6 +121,23 @@ class PDTransJoinInfo: connect_id: str +@dataclass +class RemotePrefillServerInfo: + perfill_server_id: int + prefill_server_ip: str + prefill_server_port: int + + +@dataclass +class DistInfo: + world_size: int + nnodes: int + dp_size: int + dp_world_size: int + dp_size_in_node: int + node_world_size: int + + @dataclass class PDTransLeaveInfo: decode_id: int diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 15d04b208..5b8edeb29 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -1,8 +1,9 @@ import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union -from lightllm.server.core.objs import ShmReqManager, Req +from lightllm.server.core.objs import ShmReqManager, Req, PDNIXLChunkedPrefillReq from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_dp_world_size logger = init_logger(__name__) @@ -53,6 +54,8 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): req = None else: unfinished_req_ids.append(req.request_id) + if isinstance(req, PDNIXLChunkedPrefillReq): + req.set_pd_req_state() self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 2811e4228..ff654050b 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -22,7 +22,7 @@ from .req_queue import build_req_queue from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, PDNIXLChunkedPrefillReq from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs @@ -34,6 +34,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.pd_io_struct import DistInfo logger = init_logger(__name__) @@ -49,6 +50,7 @@ def __init__(self, args, router_port, detokenization_port, metric_port): self.dp_size = args.dp # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, args.dp // self.nnodes) + self.dp_world_size = self.world_size // self.dp_size self.is_multinode_tp = args.nnodes > 1 and args.dp == 1 
self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1 # 判断是否是保守调度,保守调度不会发生暂停 req 的情况,但是有些场景可能影响吞吐 @@ -97,8 +99,8 @@ def __init__(self, args, router_port, detokenization_port, metric_port): self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) self.metric_client = MetricClient(metric_port) - self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] - self.is_pd_decode_mode = self.args.run_mode == "decode" + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] + self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 # 主要是为了防止调度失误,造成 OOM 等错误 self.router_lock = mp.Lock() @@ -118,12 +120,14 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] + self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.node_world_size)] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() assert (self.world_size % self.nnodes) == 0 node_world_size = self.world_size // self.nnodes for rank_id in range(self.node_rank * node_world_size, (self.node_rank + 1) * node_world_size): + rpc_model = await start_model_process( args=self.args, rank=rank_id, @@ -132,7 +136,8 @@ async def wait_to_model_ready(self): rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, info_queue=self.info_queue, - mem_queue=self.mem_queues[(rank_id % node_world_size)], + result_queue=self.result_queues[rank_id % node_world_size], + mem_queue=self.mem_queues[rank_id % node_world_size], router_lock=self.router_lock, ) self.model_rpc_servers.append(rpc_model) @@ -187,7 +192,7 @@ async def wait_to_model_ready(self): get_unique_server_name(), self.max_total_token_num, node_world_size=self.node_world_size, - dp_world_size=self.world_size // self.dp_size, + dp_world_size=self.dp_world_size, ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") @@ -200,6 +205,30 @@ async def wait_to_model_ready(self): start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_prefill": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( + start_pd_remote_prefill_server_process, + ) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) + + start_pd_remote_prefill_server_process( + self.args.pd_node_id, + dist_info=dist_info, + http_server_port=self.args.pd_nixl_remote_prefill_http_port, + server_port=self.args.pd_nixl_remote_prefill_port, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) + if self.args.run_mode == "decode": # 启动 decode kv move 管理进程 from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( @@ -208,6 +237,28 @@ async def wait_to_model_ready(self): start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_decode": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( + start_pd_remote_prefill_client_process, + ) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) + + 
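+            # nixl_decode: launch the remote-prefill client process. Per the
+            # keyword arguments, it takes work from the backend ranks via
+            # info_queue (from_backend_queue), reports status back per rank via
+            # result_queues (to_backend_queues), and reads each rank's NIXL
+            # agent metadata from mem_queues (agent_meta_queues).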
start_pd_remote_prefill_client_process( + self.args.pd_node_id, + dist_info, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) + return def add_req(self, group_req_indexes: GroupReqIndexes): @@ -216,6 +267,8 @@ def add_req(self, group_req_indexes: GroupReqIndexes): req = self.shm_req_manager.get_req_obj_by_index(req_index) req.multimodal_params = group_req_indexes.multimodal_params req.start_time = group_req_indexes.time_mark + if isinstance(req, PDNIXLChunkedPrefillReq): + req.set_dp_world_size(self.dp_world_size) req_group.append(req) logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 09f649943..1f8888417 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -10,7 +10,7 @@ from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager from lightllm.utils.infer_utils import mark_start, mark_end -from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager +from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager, PDNIXLChunkedPrefillReq from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id @@ -118,7 +118,7 @@ def _save_promptcache_kvbuffer(self): https://arxiv.org/abs/2403.01241 """ prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key - print(f"prompt_cache_token_id : {prompt_cache_token_id}") + # print(f"prompt_cache_token_id : {prompt_cache_token_id}") index = range(len(prompt_cache_token_id)) prompt_cache_kv_buffer = self.radix_cache.mem_manager.get_index_kv_buffer(index) torch.save(prompt_cache_kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt") @@ -257,12 +257,16 @@ def __init__( self.vocab_size = vocab_size self.initialized = False self.paused = False + self.in_prefill_or_transfer = False def init_all(self): if self.initialized is False: self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() + if isinstance(self.shm_req, PDNIXLChunkedPrefillReq): + self.in_prefill_or_transfer = False + self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) if self.sampling_param.shm_param.input_penalty: self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids()) @@ -300,6 +304,7 @@ def init_all(self): self.initialized = True self.paused = False + self.in_prefill_or_transfer = False return def is_uninitialized(self): diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 4594eec28..fe13a60a7 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -12,3 +12,7 @@ from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode +from .pd_nixl.impl_for_pd_prefill import 
PDNIXLBackendForPrefillNode +from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode +from .pd_nixl.impl_for_pd_decode_dp import PDNIXLDPBackendForDecodeNode +from .pd_nixl.impl_for_pd_prefill_dp import PDNIXLDPBackendForPrefillNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e4db0aca7..8f2531869 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -250,6 +250,7 @@ def _post_handle( is_chuncked_mode: bool, do_filter_finished_reqs: bool, extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + extra_post_req_handle_chunk_func: Optional[Callable[[InferReq], None]] = None, ) -> List[int]: """ extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 @@ -274,6 +275,10 @@ def _post_handle( finished_req_ids.append(req_obj.shm_req.request_id) continue + if extra_post_req_handle_chunk_func is not None: + # 如果存在额外的处理函数,则调用这个函数进行处理。 + extra_post_req_handle_chunk_func(req_obj) + # 对于没有到达需要输出 token 阶段的请求,直接略过 if req_obj.cur_kv_len < req_obj.get_cur_total_len(): continue diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 9f174e8bc..66d056709 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -79,7 +79,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]): nopad_b_req_idx.append(req.req_idx) input_id = req.get_last_gen_token() seq_len = req.get_cur_total_len() - assert req.cur_kv_len == seq_len - 1 + assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" nopad_b_seq_len.append(seq_len) input_ids.append(input_id) nopad_total_token_num += seq_len diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py new file mode 100644 index 000000000..2156944d8 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -0,0 +1,308 @@ +import time +import torch.multiprocessing as mp +from typing import Dict, List +import queue +import numpy as np + + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq + +from .nixl_kv_transporter import NixlMetadata, NixlKVTransporter +from .pd_remote_prefill_obj import ( + PrefillRequest, + RemoteRequest, + RemoteRequstType, + ConnectRequest, + KVMoveRequest, + RemotePrefillStatus, + ThreadSafeDict, + TransferState, +) + +logger = init_logger(__name__) + + +class PDNIXLBackendBase(ModeBackend): + _THREAD_WAIT_INTERVAL = 0.001 + + def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_meta_queue: mp.Queue): + super().__init__() + self.to_remote_queue = to_remote_queue + self.from_remote_queue = from_remote_queue + self.nixl_meta_queue = nixl_meta_queue + self.prefill_post_handle_queue = queue.Queue() + + # for decode + self.remote_prefilled_reqs: ThreadSafeDict = ThreadSafeDict() + + # for prefill + self.remote_prefill_requests: ThreadSafeDict = ThreadSafeDict() + self.inflght_transfer_requests: 
ThreadSafeDict = ThreadSafeDict() + + def init_custom(self): + self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.rank_in_node) + self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer) + self.nixl_meta_queue.put( + (self.nixl_agent.agent_metadata, self.nixl_agent.num_tokens, self.nixl_agent.local_mem_desc) + ) + + def _prefill_wait_loop(self): + while True: + + def handle_remote_prefill(req_status: RemotePrefillStatus): + group_req_id = req_status.group_req_id + status = req_status.status + if status != 1: + logger.warning(f"remote prefill reqeust: {group_req_id} done with state: {status}") + + if run_req := self.remote_prefilled_reqs.get(group_req_id, None): + if req_status.is_last or status != 1: + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, status) + self.remote_prefilled_reqs.pop(group_req_id) + if self.is_master_in_dp: + logger.info( + f"remote prefill reqeust: {group_req_id} done with status: {status} " + f"took: {time.time() - run_req.remote_prefill_start} seconds" + ) + else: + if self.is_master_in_dp: + logger.warning(f"remote prefill reqeust: {group_req_id} not found") + + # from local + try: + req_status = self.from_remote_queue.get_nowait() + handle_remote_prefill(req_status) + except queue.Empty: + pass + + # from remote + notifies = self.nixl_agent.get_new_notifs() + for agent_name, req_statuses in notifies.items(): + for req_status in req_statuses: + prefill_status = RemotePrefillStatus.deserialize(req_status) + handle_remote_prefill(prefill_status) + + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + def _handle_transfer_loop(self): + while True: + try: + req: InferReq = self.prefill_post_handle_queue.get() + self._transfer_kv_to_remote(req) + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + except queue.Empty: + pass + + + def _wait_transfer_loop(self): + while True: + done_req_ids = self.nixl_agent.get_done_tranfers() + for req_id, state in done_req_ids: + if state != 1: + logger.info(f"wait transfer done: {req_id} state: {state}") + + if req_id not in self.inflght_transfer_requests: + if self.is_master_in_dp: + logger.warning(f"{req_id} not found in inflght_transfer_requests") + continue + + req: InferReq = self.inflght_transfer_requests[req_id] + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, state) + transfer_state = self.remote_prefill_requests[req_id].transfer_state + if self.is_master_in_dp: + logger.info( + f"req: {req_id} kv transfer with state: {state} " + f"took: {time.time() - transfer_state.start_time} seconds" + ) + del self.remote_prefill_requests[req_id] + del self.inflght_transfer_requests[req_id] + + time.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + def _handle_prefill_loop(self): + while True: + request: RemoteRequest = self.from_remote_queue.get() + if request.type == RemoteRequstType.REMOTE_CONNECT: + request: ConnectRequest + logger.info(f"connect request received from: {request.decode_id}") + self.nixl_agent.add_remote_agent( + NixlMetadata( + id=request.decode_id, + num_tokens=request.num_tokens, + agent_metadatas=request.agent_metadatas, + agent_mem_descs=request.agent_mem_descs, + ) + ) + self.to_remote_queue.put("OK") + + if request.type == RemoteRequstType.REMOTE_PREFILL: + request: PrefillRequest + group_request_id = request.data.sampling_params.group_request_id + logger.info( + f"prefill request received from decode: {request.decode_id} " + f"and group request id: {group_request_id}" + ) + 
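+                # Stash the request keyed by its group id; the KV pages are
+                # pushed later by _transfer_kv_to_remote() as each local prefill
+                # chunk completes (fed through prefill_post_handle_queue), and
+                # the entry is removed once _wait_transfer_loop observes the
+                # transfer result.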
self.remote_prefill_requests[group_request_id] = request + + def _transfer_kv_to_remote(self, req: InferReq): + group_req_id = req.shm_req.group_req_id + # set state + if group_req_id not in self.remote_prefill_requests: + logger.info(f"remote prefill request {group_req_id} not found") + return + start = time.time() + remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id] + if remote_request.transfer_state is None: + remote_request.transfer_state = TransferState( + start_time=time.time(), + current_kv_len=0, + current_chunk_id=0, + ) + + transfer_state = remote_request.transfer_state + token_index = self.model.req_manager.req_to_token_indexs[req.req_idx] + is_finished = req.finish_status.is_finished() + + kv_transfer_req = KVMoveRequest( + group_req_id=group_req_id, + token_ids=token_index[: req.cur_kv_len].tolist(), + prev_kv_len=transfer_state.current_kv_len, + cur_kv_len=req.cur_kv_len, + ) + if transfer_state.current_chunk_id == 0: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) + req.in_prefill_or_transfer = True + self.inflght_transfer_requests[group_req_id] = req + logger.debug( + f"put {group_req_id} into inflght_transfer_requests and size: {len(self.inflght_transfer_requests)}" + ) + + # kick off kv transfer + self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) + + transfer_state.current_kv_len = req.cur_kv_len + transfer_state.current_chunk_id += 1 + logger.info( + f"transfer kv to remote: {group_req_id} " + f"current chunk id: {transfer_state.current_chunk_id} " + f"took: {time.time() - start} seconds" + ) + + def _decode_filter_reqs( + self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq] + ): + new_prefill_reqs: List[InferReq] = [] + new_aborted_reqs: List[InferReq] = [] + remote_prefill_reqs: List[InferReq] = [] + + # filter out aborted requests + for req in aborted_reqs: + if req.in_prefill_or_transfer: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state() + if state != 0: + new_aborted_reqs.append(req) + req.in_prefill_or_transfer = False + else: + # TODO trigger remote abort + remote_prefill_reqs.append(req) + else: + new_aborted_reqs.append(req) + + for req in prefill_reqs: + if req.in_prefill_or_transfer: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + # state is updated by router + state = shm_req.get_pd_req_state() + if state == 1: # success + req.cur_kv_len = req.get_cur_total_len() - 1 + decode_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == -1: # failure + new_aborted_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == 0: # in progress + remote_prefill_reqs.append(req) + else: + logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") + continue + + new_prefill_reqs.append(req) + + return new_prefill_reqs, new_aborted_reqs, decode_reqs, remote_prefill_reqs + + def _prefill_filter_reqs(self, ok_finished_reqs: List[InferReq], aborted_reqs: List[InferReq]): + new_ok_finished_reqs = [] + kv_transfer_reqs = [] + + for req in ok_finished_reqs: + if req.in_prefill_or_transfer: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state() + if state == 1: # success + new_ok_finished_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == -1: # failure + aborted_reqs.append(req) + req.in_prefill_or_transfer = False + elif state == 0: + kv_transfer_reqs.append(req) + else: + logger.warning(f"remote 
prefill request {shm_req.group_req_id} unexpected state {state}") + continue + + new_ok_finished_reqs.append(req) + + return new_ok_finished_reqs, aborted_reqs, kv_transfer_reqs + + def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): + run_reqs = [] + start_loc = 0 + input_ids = [] + nopad_b_req_idx = [] + nopad_b_start_loc = [] + nopad_b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + nopad_b_req_idx.append(req.req_idx) + nopad_b_start_loc.append(start_loc) + + input_token_ids = req.get_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + nopad_b_seq_len.append(seq_len) + input_ids.append(input_id) + start_loc += input_token_len + + nopad_b_start_loc.append(start_loc) # last request + + input_ids = np.concatenate(input_ids, dtype=np.int64) + + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + + kwargs = { + "batch_size": len(run_reqs), + "input_ids": input_ids, + "mem_indexes": mem_indexes.tolist(), + "b_req_idx": nopad_b_req_idx, + "b_start_loc": nopad_b_start_loc, + "b_seq_len": nopad_b_seq_len, + } + return kwargs, run_reqs + + def _prefill_abort_remote(self, req_objs: List[InferReq]): + for req_obj in req_objs: + group_req_id = req_obj.shm_req.group_req_id + if group_req_id in self.remote_prefill_requests: + self.nixl_agent.send_abort_notify(self.remote_prefill_requests[group_req_id].decode_id, group_req_id) + del self.remote_prefill_requests[group_req_id] diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py new file mode 100644 index 000000000..6b30759a6 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py @@ -0,0 +1,101 @@ +import time +import torch +import torch.multiprocessing as mp +import threading +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from typing import List, Tuple, Dict +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.multimodal_params import MultimodalParams + +from .pd_remote_prefill_obj import RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest + +from .impl_for_pd_base import PDNIXLBackendBase + +logger = init_logger(__name__) + + +class PDNIXLBackendForDecodeNode(PDNIXLBackendBase): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + + def init_custom(self): + super().init_custom() + self.wait_prefill_thread = threading.Thread(target=self._prefill_wait_loop, daemon=True) + self.wait_prefill_thread.start() + return + + def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): + prefill_node = req.shm_req.sample_params.move_kv_to_decode_node.to_dict() + prefill_node_info = RemotePrefillServerInfo( + perfill_server_id=prefill_node["node_id"], + 
prefill_server_ip=prefill_node["ip"], + prefill_server_port=prefill_node["rpyc_port"], + ) + + mem_indexes = kwargs.get("mem_indexes") + b_start_loc = kwargs.get("b_start_loc") + prefill_request = RemotePrefillRequest( + prompt=req.shm_req.get_prompt_ids(), + sampling_params=req.shm_req.sample_params, + multimodal_params=MultimodalParams.from_dict(req.multimodal_params), + local_cached_len=req.cur_kv_len, + token_ids=mem_indexes[b_start_loc[index] : b_start_loc[index + 1]], + ) + return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request) + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def decode(self): + + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + # filter out remote prefilling reqs + prefill_reqs, aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs) + + self._filter_reqs(aborted_reqs) + + # allocate kv cache, do remote prefill + if prefill_reqs: + # TODO: we could allocate cache later after remote prefill done and get a signal from remote + # but it will have a risk to not have enough cache for this request. + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + # forward each req to remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + run_req.remote_prefill_start = time.time() + self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) + + shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req + + if decode_reqs: + kwargs, run_reqs = prepare_decode_inputs(decode_reqs) + logits = self.model.forward(**kwargs) + + self._overlap_req_init_and_filter( + uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True + ) + + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py new file mode 100644 index 000000000..c6790da58 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py @@ -0,0 +1,122 @@ +import time +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from typing import List +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import 
padded_prepare_decode_inputs + +from .impl_for_pd_decode import PDNIXLBackendForDecodeNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForDecodeNode(PDNIXLBackendForDecodeNode): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap + + def init_custom(self): + super().init_custom() + + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs + + kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) + self.model.forward(**kwargs) + assert len(run_reqs) == 0 and padded_req_num == 1 + + return + + def decode(self): + + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + # filter out remote prefilling reqs + prefill_reqs, aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs) + + self._filter_reqs(aborted_reqs) + + # allocate kv cache, do remote prefill + if prefill_reqs: + # TODO: we could allocate cache later after remote prefill done and get a signal from remote + # but it will have a risk to not have enough cache for this request. + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + # forward each req to remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + run_req.remote_prefill_start = time.time() + self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req)) + + shm_req.set_pd_req_rank_state(self.rank_in_dp, 0) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req + + self.reduce_tensor.fill_(len(decode_reqs)) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX) + max_decode_num = self.reduce_tensor.item() + if max_decode_num != 0: + if not self.enable_decode_microbatch_overlap: + self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + else: + self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + + kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs( + decode_reqs, max_decode_num, is_multimodal=self.is_multimodal + ) + logits = self.model.forward(**kwargs) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + if len(run_reqs) != 0: + logits = logits[0 : len(run_reqs), :] + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) + return + + def 
overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( + padded_overlap_prepare_decode_inputs, + ) + + ( + micro_batch, + run_reqs, + padded_req_num, + micro_batch1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal) + + logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) + + all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) + + all_run_reqs = run_reqs + run_reqs1 + if all_run_reqs: + next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py new file mode 100644 index 000000000..736e52835 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py @@ -0,0 +1,66 @@ +import threading +import torch +import torch.multiprocessing as mp +from typing import List, Tuple +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from .impl_for_pd_base import PDNIXLBackendBase + +logger = init_logger(__name__) + + +class PDNIXLBackendForPrefillNode(PDNIXLBackendBase): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + + def init_custom(self): + super().init_custom() + self.handle_prefill_loop_thread = threading.Thread(target=self._handle_prefill_loop, daemon=True) + self.wait_transfer_loop_thread = threading.Thread(target=self._wait_transfer_loop, daemon=True) + self.handle_transfer_loop_thread = threading.Thread(target=self._handle_transfer_loop, daemon=True) + + self.handle_prefill_loop_thread.start() + self.handle_transfer_loop_thread.start() + self.wait_transfer_loop_thread.start() + return + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs) + return + + def decode(self): + uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + + ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) + + assert len(uinit_reqs) == 0 + assert len(decode_reqs) == 0 + + self._prefill_abort_remote(aborted_reqs) + self._filter_reqs(aborted_reqs + 
ok_finished_reqs) + + if prefill_reqs: + kwargs, run_reqs = prepare_prefill_inputs( + prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal + ) + + logits = self.model.forward(**kwargs) + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + extra_post_req_handle_chunk_func=lambda req: self.prefill_post_handle_queue.put(req), + ) + return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py new file mode 100644 index 000000000..3f4354e42 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py @@ -0,0 +1,117 @@ +import threading +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from typing import List, Tuple +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs + +from .impl_for_pd_base import PDNIXLBackendBase +from .impl_for_pd_prefill import PDNIXLBackendForPrefillNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForPrefillNode(PDNIXLBackendForPrefillNode): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap + + def init_custom(self): + super().init_custom() + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + return + + def decode(self): + uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + + ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs) + + assert len(uinit_reqs) == 0 + assert len(decode_reqs) == 0 + + self._prefill_abort_remote(aborted_reqs) + self._filter_reqs(aborted_reqs + ok_finished_reqs) + + # if ok_finished_reqs: + # for req in ok_finished_reqs: + # self._transfer_kv_to_remote(req) + # self._filter_reqs(ok_finished_reqs) + # ok_finished_reqs.clear() + + current_dp_prefill_num = len(prefill_reqs) + self.reduce_tensor.fill_(current_dp_prefill_num) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_prefill_num = self.reduce_tensor.item() + if max_prefill_num != 0: + if not self.enable_prefill_microbatch_overlap: + self.normal_prefill_reqs(prefill_reqs, max_prefill_num, uinit_reqs, ok_finished_reqs) + else: + self.overlap_prefill_reqs(prefill_reqs, max_prefill_num, uinit_reqs, ok_finished_reqs) + + self._overlap_req_init_and_filter(uninit_reqs=uinit_reqs, 
ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + + kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal + ) + logits = self.model.forward(**kwargs) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + if len(run_reqs) != 0: + logits = logits[0 : len(run_reqs), :] + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + extra_post_req_handle_chunk_func=lambda req: self.prefill_post_handle_queue.put(req), + ) + + def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( + padded_overlap_prepare_prefill_inputs, + ) + + ( + micro_batch, + run_reqs, + padded_req_num, + micro_batch1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal) + logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1) + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) + + all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) + + all_run_reqs = run_reqs + run_reqs1 + if all_run_reqs: + next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._post_handle( + all_run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + extra_post_req_handle_chunk_func=lambda req: self.prefill_post_handle_queue.put(req), + ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py new file mode 100644 index 000000000..6fb9673e4 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -0,0 +1,236 @@ +from collections import defaultdict +from typing import Dict, List, Any +from torch import Tensor +from dataclasses import dataclass + +from lightllm.utils.log_utils import init_logger + +from .pd_remote_prefill_obj import ( + RemoteAgent, + KVMoveRequest, + PrefillRequest, + RemotePrefillStatus, + ThreadSafeDict, + KVMoveRequestState, +) + + +logger = init_logger(__name__) + +try: + from nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixlBind + + logger.info("Nixl is available") +except ImportError: + logger.warning("nixl is not installed, which is required for pd disagreggation!!!") + NixlWrapper = None + + +@dataclass +class NixlMetadata: + id: str + num_tokens: list[int] + agent_metadatas: list[bytes] + agent_mem_descs: 
list[bytes] + + +class NixlKVTransporter: + def __init__(self, node_id: int, tp_idx: int): + self.node_id = node_id + self.tp_idx = tp_idx + self.nixl_agent = NixlWrapper(self.agent_name, None) + + self.num_layers = -1 + self.num_tokens = -1 + self.num_heads = -1 + self.head_dims = -1 + self.token_len = -1 + + self.reg_desc = None + self.local_xfer_handles = None + + self.remote_agents = defaultdict(list) + + self.inflight_transfers: ThreadSafeDict = ThreadSafeDict() + + @property + def agent_name(self) -> str: + return f"{self.node_id}_{self.tp_idx}" + + @property + def agent_metadata(self): + return self.nixl_agent.get_agent_metadata() + + @property + def local_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.reg_desc) + + def get_new_notifs(self): + return self.nixl_agent.get_new_notifs() + + def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + layer_len = num_tokens * self.token_len + tokens_data = [] + for layer_id in range(self.num_layers): + for token_id in range(num_tokens): + tokens_data.append( + (base_addr + layer_id * layer_len + token_id * self.token_len, self.token_len, device_id) + ) + descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) + + def register_kv_buffer(self, kv_buffer: Tensor): + self.num_layers, self.num_tokens, self.num_heads, self.head_dim = kv_buffer.shape + self.token_len = self.num_heads * self.head_dim * kv_buffer.element_size() + + self.reg_desc = self.nixl_agent.register_memory(kv_buffer) + self.local_xfer_handles = self._create_xfer_handles(self.reg_desc, self.num_tokens) + + def add_remote_agent(self, remote_agent: NixlMetadata): + for idx, (agent_metadata, num_tokens, agent_mem_desc) in enumerate( + zip(remote_agent.agent_metadatas, remote_agent.num_tokens, remote_agent.agent_mem_descs) + ): + if self.tp_idx != idx: + self.remote_agents[remote_agent.id].append(None) + continue + + peer_name = self.nixl_agent.add_remote_agent(agent_metadata) + mem_desc = self.nixl_agent.deserialize_descs(agent_mem_desc) + logger.info("Added remote agent %s with mem desc %s", peer_name, mem_desc) + kv_xfer_handles = self._create_xfer_handles(mem_desc, num_tokens, agent_name=peer_name) + self.remote_agents[remote_agent.id].append( + RemoteAgent( + name=peer_name, kv_mem_desc=mem_desc, num_tokens=num_tokens, kv_xfer_handles=kv_xfer_handles + ) + ) + + def _get_token_desc_ids(self, token_ids: List[int], num_tokens: int): + descs_ids = [] + for layer_id in range(self.num_layers): + for token_id in token_ids: + descs_ids.append(layer_id * num_tokens + token_id) + return descs_ids + + def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, is_finished: bool): + group_reqeust_id = request.group_req_id + skip_kv_move_len = prefill_request.data.local_cached_len + + # current kv len is less than remote cached kv len, just skip + if request.cur_kv_len <= skip_kv_move_len: + return + + kv_move_start = max(skip_kv_move_len, request.prev_kv_len) + kv_move_end = request.cur_kv_len + + src_token_ids = request.token_ids[kv_move_start:] + dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len : kv_move_end - skip_kv_move_len] + + remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][ + self.tp_idx + ] # TODO one-one mapping now + + if len(src_token_ids) > 0: + assert len(src_token_ids) == len(dst_token_ids), 
f"{len(src_token_ids)} {len(dst_token_ids)} {kv_move_start} {kv_move_end} {skip_kv_move_len}" + src_token_descs = self._get_token_desc_ids(src_token_ids, self.num_tokens) + dst_token_descs = self._get_token_desc_ids(dst_token_ids, remote_agent.num_tokens) + + src_handle = self.local_xfer_handles + dst_handle = remote_agent.kv_xfer_handles + + notify_status = RemotePrefillStatus( + group_req_id=group_reqeust_id, + status=1, + chunk_id=prefill_request.transfer_state.current_chunk_id, + is_last=is_finished, + ).serialize() if is_finished else b"" + + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status + ) + + status = self.nixl_agent.transfer(handle) + assert status != "ERR" + + if group_reqeust_id not in self.inflight_transfers: + self.inflight_transfers[group_reqeust_id] = KVMoveRequestState( + handles=[], done_handles=[], remote_agent=remote_agent, abort=False, is_last_arrived=False + ) + + self.inflight_transfers[group_reqeust_id].handles.append(handle) + + if is_finished: + self.inflight_transfers[group_reqeust_id].is_last_arrived = True + + return handle + + return None + + def send_abort_notify(self, remote_id: int, group_reqeust_id): + remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] + notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1, chunk_id=-1, is_last=True) + self.nixl_agent.send_notif(remote_agent.name, notify_status.serialize()) + + if group_reqeust_id in self.inflight_transfers: + self.inflight_transfers[group_reqeust_id].abort = True + + def get_done_tranfers(self): + done_req_ids = [] + for req_id, kv_move_state in self.inflight_transfers.items(): + kv_move_state: KVMoveRequestState + if kv_move_state.abort: + logger.warning(f"{req_id} Transfer aborted") + done_req_ids.append((req_id, -1)) + continue + + if not kv_move_state.is_last_arrived: + continue + + remote_agent: RemoteAgent = kv_move_state.remote_agent + + left_handles = [] + failed = False + for handle in kv_move_state.handles: + if failed: + left_handles.append(handle) + continue + + xfer_state = self.nixl_agent.check_xfer_state(handle) + + if xfer_state == "DONE": + kv_move_state.done_handles.append(handle) + elif xfer_state == "PROC": + left_handles.append(handle) + else: + logger.warning(f"{req_id} Transfer failed with state {xfer_state}") + failed = True + kv_move_state.done_handles.append(handle) + notify_failed_status = RemotePrefillStatus( + group_req_id=req_id, status=-1, chunk_id=-1, is_last=True + ) + self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) + + kv_move_state.handles = left_handles + + if failed: + done_req_ids.append((req_id, -1)) + elif len(left_handles) == 0: + done_req_ids.append((req_id, 1)) + + for req_id, _ in done_req_ids: + kv_move_state: KVMoveRequestState = self.inflight_transfers[req_id] + for handle in kv_move_state.handles + kv_move_state.done_handles: + # release will abort inflight transfer + self.nixl_agent.release_xfer_handle(handle) + + del self.inflight_transfers[req_id] + return done_req_ids + + def shutdonw(self): + self.nixl_agent.deregister_memory(self.reg_desc) + self.nixl_agent.release_dlist_handle(self.local_xfer_handles) + for id, agents in self.remote_agents.items(): + for agent in agents: + self.nixl_agent.remove_remote_agent(agent.name) + self.nixl_agent.release_xfer_handle(agent.kv_xfer_handles) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py 
b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py new file mode 100644 index 000000000..9df2c5664 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -0,0 +1,300 @@ +from typing import List, Any +import zmq +import inspect +import random + +import torch.multiprocessing as mp + +from lightllm.utils.log_utils import init_logger +from lightllm.utils.net_utils import get_hostname_ip +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.server.pd_io_struct import DistInfo + +from .pd_remote_prefill_obj import ( + ConnectRequest, + RemoteRequest, + RemoteRequstType, + PrefillRequest, + RemotePrefillRequest, + RemotePrefillServerInfo, + RemotePrefillTask, + RemotePrefillStatus, + SockWithPoller, +) +from .nixl_kv_transporter import NixlMetadata + +logger = init_logger(__name__) + + +class PDRemotePrefillBase: + def __init__( + self, + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl + ): + self.id = id + self.dist_info = dist_info + assert len(agent_meta_queues) == dist_info.node_world_size + self.agent_meta_queues = agent_meta_queues + self.from_backend_queue = from_backend_queue + self.to_backend_queues = to_backend_queues + self.local_agent_meta = None + + def local_init(self): + agent_metas = NixlMetadata(id=self.id, agent_metadatas=[], num_tokens=[], agent_mem_descs=[]) + for tp in range(self.dist_info.node_world_size): + agent_metadata, num_tokens, mem_desc = self.agent_meta_queues[tp].get(timeout=60) + logger.info(f"Received agent_metadata from {tp} with mem reg: {mem_desc}") + agent_metas.num_tokens.append(num_tokens) + agent_metas.agent_metadatas.append(agent_metadata) + agent_metas.agent_mem_descs.append(mem_desc) + + self.local_agent_meta = agent_metas + logger.info("All local kv cache registered.") + + +class PDRemotePrefillServer(PDRemotePrefillBase): + def __init__( + self, + id: int, + dist_info: DistInfo, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], + ): + super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) + # map from client id to decode server info + self.remote_decode_clients = {} + + # build control path + _ctx = zmq.Context() + self.recv_from_decode = SockWithPoller(_ctx.socket(zmq.ROUTER)) + self.host_ip = get_hostname_ip() + self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}") + + # build trigger remote prefill path + self.send_to_httpserver = SockWithPoller(_ctx.socket(zmq.PUSH)) + self.send_to_httpserver.connect(f"tcp://{self.host_ip}:{http_server_port}") + + def main_loop(self): + self.local_init() + while True: + try: + client_obj, request = self.recv_from_decode.recv_pyobj_multipart() + request: RemoteRequest + logger.info(f"recevied request from decode, type: {request.type}") + + if request.type == RemoteRequstType.REMOTE_CONNECT: + # forward request to all prefill server + for queue in self.to_backend_queues: + queue.put(request) + + success = True + for idx in range(self.dist_info.node_world_size): + ack = self.from_backend_queue.get() + logger.info(f"received ack from backend {idx}: {ack}") + if ack != "OK": + success = False + break + + self.recv_from_decode.send_pyobj_multipart(client_obj, success) + logger.info(f"Sent ack to decode: {success}") + if not success: + 
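+                        # NOTE: a False ack means at least one local backend rank failed to register the
+                        # remote NIXL agent, so the decode client will treat this prefill server as unreachable.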
logger.warning(f"Remote connect failed: {request}") + + if request.type == RemoteRequstType.REMOTE_PREFILL: + request: PrefillRequest = request + if self.dist_info.dp_size_in_node > 1: + group_req_id = request.data.sampling_params.group_request_id + suggested_dp_index = request.data.sampling_params.suggested_dp_index + if suggested_dp_index < 0: # not likely to happen + suggested_dp_index = random.randint(0, self.dist_info.dp_size_in_node) + request.data.sampling_params.suggested_dp_index = suggested_dp_index + logger.warning( + f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}" + ) + + for local_rank in range( + suggested_dp_index * self.dist_info.dp_world_size, + (suggested_dp_index + 1) * self.dist_info.dp_world_size, + ): + self.to_backend_queues[local_rank].put(request) + else: + for queue in self.to_backend_queues: + queue.put(request) + + self.send_to_httpserver.send_pyobj( + (request.data.prompt, request.data.sampling_params, request.data.multimodal_params) + ) + + except Exception as e: + logger.error(f"Error in remote prefill server loop: {e}", exc_info=e) + + +class PDRemotePrefillClient(PDRemotePrefillBase): + def __init__( + self, + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, # only tp0 will trigger prefill + to_backend_queues: List[mp.Queue], # one to many done queue + agent_meta_queues: List[mp.Queue], + ): + super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues) + # map from server id to prefill server info + + self.remote_prefill_servers = {} + self.client_socket_cnt = 0 + self._ctx = zmq.Context() + + def _connect_server(self, server_ip: str, port: int): + _socket = self._ctx.socket(zmq.DEALER) + _socket.setsockopt_string(zmq.IDENTITY, f"{self.id}_{self.client_socket_cnt}") + self.client_socket_cnt += 1 + connect_str = f"tcp://{server_ip}:{port}" + _socket.connect(connect_str) + return SockWithPoller(_socket) + + def _send_nixl_agent(self, socket: SockWithPoller): + socket.send_pyobj( + ConnectRequest( + type=RemoteRequstType.REMOTE_CONNECT, + decode_id=self.id, + num_tokens=self.local_agent_meta.num_tokens, + agent_metadatas=self.local_agent_meta.agent_metadatas, + agent_mem_descs=self.local_agent_meta.agent_mem_descs, + ) + ) + + success = socket.recv_pyobj(timeout=60) + logger.info(f"recv remote nixl connect response {success}") + if success is None: + logger.warning("timeout to recv remote nixl connect response") + return False + + return success + + def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo): + + if server_info.perfill_server_id in self.remote_prefill_servers: + return True + + # build control path if not exist + _socket = self._connect_server(server_info.prefill_server_ip, server_info.prefill_server_port) + success = self._send_nixl_agent(_socket) + if success: + self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info) + return True + else: + logger.warning("Remote Prefill Server Connect Failed") + return False + + def main_loop(self): + self.local_init() + while True: + try: + prefill_tasks: RemotePrefillTask = self.from_backend_queue.get() + # connect first + if self.connect_to_prefill_server(prefill_tasks.server_info): + # do prefill + self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request) + else: + # failed to connect a remote + for idx in self.to_backend_queues: + self.to_backend_queues.put( + RemotePrefillStatus( + 
group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id, + status=-1, + chunk_id=-1, + is_last=True, + ) + ) + except Exception as e: + logger.error(f"Remote prefill client loop error: {e}", exc_info=e) + + # place request to server do remote prefill + def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest): + socket, _ = self.remote_prefill_servers[server_id] + prefill_request.sampling_params.max_new_tokens = 1 + socket.send_pyobj( + PrefillRequest( + type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None + ) + ) + + +def remote_prefill_server_loop( + id: int, + dist_info: DistInfo, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + graceful_registry(inspect.currentframe().f_code.co_name) + server = PDRemotePrefillServer( + id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues + ) + server.main_loop() + + +def start_pd_remote_prefill_server_process( + id: int, + dist_info: DistInfo, + http_server_port: int, + server_port: int, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + proc = mp.Process( + target=remote_prefill_server_loop, + args=(id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues), + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill server with id: {id} started!") + return proc + + +def remote_prefill_client_loop( + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + graceful_registry(inspect.currentframe().f_code.co_name) + + client = PDRemotePrefillClient( + id, + dist_info, + from_backend_queue, + to_backend_queues, + agent_meta_queues, + ) + client.main_loop() + + +def start_pd_remote_prefill_client_process( + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + + proc = mp.Process( + target=remote_prefill_client_loop, + args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues), + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill client with id: {id} started!") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py new file mode 100644 index 000000000..02ee6c4ed --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -0,0 +1,190 @@ +from dataclasses import dataclass, asdict +from enum import Enum +import json +from typing import List, Union, Optional, Any +from threading import Lock +import pickle +import zmq + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.pd_io_struct import RemotePrefillServerInfo + +logger = init_logger(__name__) + +try: + from nixl._api import nixlBind, nixl_prepped_dlist_handle, nixl_xfer_handle + +except ImportError: + logger.error("nixl is not installed, which is required for pd disagreggation!!!") + raise + + +class RemoteRequstType(Enum): + REMOTE_CONNECT = 1 + REMOTE_PREFILL = 2 + + +@dataclass +class RemotePrefillRequest: + prompt: 
Union[str, List[int]] + sampling_params: SamplingParams + multimodal_params: MultimodalParams + local_cached_len: int # will skip transfer + token_ids: List[int] # mem cache indexes + + +@dataclass +class RemotePrefillTask: + server_info: RemotePrefillServerInfo + prefill_request: RemotePrefillRequest + + +@dataclass +class RemoteRequest: + type: RemoteRequstType + + +@dataclass +class ConnectRequest(RemoteRequest): + decode_id: int + num_tokens: List[int] + agent_metadatas: List[bytes] + agent_mem_descs: List[bytes] + + +@dataclass +class TransferState: + start_time: float + current_kv_len: int + current_chunk_id: int + + +@dataclass +class PrefillRequest(RemoteRequest): + decode_id: int + data: RemotePrefillRequest + # transfer status + transfer_state: Optional[TransferState] + + +@dataclass +class KVMoveRequest: + group_req_id: int + token_ids: List[int] + prev_kv_len: int + cur_kv_len: int + + +@dataclass +class RemoteAgent: + name: str + num_tokens: int + kv_mem_desc: nixlBind.nixlRegDList + kv_xfer_handles: nixl_prepped_dlist_handle + + +@dataclass +class KVMoveRequestState: + handles: List[nixl_xfer_handle] + done_handles: List[nixl_xfer_handle] + remote_agent: RemoteAgent + abort: bool + is_last_arrived: bool + + +@dataclass +class RemotePrefillStatus: + group_req_id: int + status: int + chunk_id: int + is_last: bool + + def serialize(self): + return json.dumps(asdict(self)).encode() + + @classmethod + def deserialize(cls, data: bytes): + return cls(**json.loads(data.decode())) + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + self._lock = Lock() + + def __getitem__(self, key): + with self._lock: + return self._dict[key] + + def __setitem__(self, key, value): + with self._lock: + self._dict[key] = value + + def __delitem__(self, key): + with self._lock: + del self._dict[key] + + def __contains__(self, key): + with self._lock: + return key in self._dict + + def __len__(self) -> int: + with self._lock: + return len(self._dict) + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def items(self): + with self._lock: + return list(self._dict.items()) + + def keys(self): + with self._lock: + return list(self._dict.keys()) + + def pop(self, key: Any, default: Optional[Any] = None) -> Any: + with self._lock: + return self._dict.pop(key, default) + + def values(self): + with self._lock: + return list(self._dict.values()) + + def clear(self) -> None: + with self._lock: + self._dict.clear() + + +class SockWithPoller: + def __init__(self, sock: zmq.Socket): + self.sock = sock + self.poller = zmq.Poller() + self.poller.register(self.sock, zmq.POLLIN) + + def recv_pyobj(self, timeout: int = 5): + socks = dict(self.poller.poll(timeout * 1000)) + if socks: + if socks.get(self.sock) == zmq.POLLIN: + return self.sock.recv_pyobj() + else: + None + + def send_pyobj(self, obj: Any): + return self.sock.send_pyobj(obj) + + def recv_pyobj_multipart(self): + client_id, data = self.sock.recv_multipart() + return client_id, pickle.loads(data) + + def send_pyobj_multipart(self, client_id: bytes, data: Any): + return self.sock.send_multipart([client_id, pickle.dumps(data)]) + + def bind(self, addr: str): + return self.sock.bind(addr) + + def connect(self, addr: str): + return self.sock.connect(addr) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index ed1470fba..e20d1bffd 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ 
b/lightllm/server/router/model_infer/model_rpc.py @@ -21,6 +21,10 @@ DPForDecodeNode, ChunckedPrefillForPrefillNode, DPChunkedForPrefillNode, + PDNIXLBackendForPrefillNode, + PDNIXLBackendForDecodeNode, + PDNIXLDPBackendForPrefillNode, + PDNIXLDPBackendForDecodeNode, ) from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray from lightllm.utils.log_utils import init_logger @@ -40,12 +44,14 @@ def __init__( rpc_event: multiprocessing.Event, rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, ): super().__init__() self.args = args self.node_world_size = node_world_size self.info_queue = info_queue + self.result_queue = result_queue self.mem_queue = mem_queue self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -120,22 +126,40 @@ def init_model(self, kvargs): ), "only one constraint mode can be true" is_prefill_node = kvargs.get("args", None).run_mode == "prefill" is_decode_node = kvargs.get("args", None).run_mode == "decode" + is_nixl_prefill_node = kvargs.get("args", None).run_mode == "nixl_prefill" + is_nixl_decode_node = kvargs.get("args", None).run_mode == "nixl_decode" else: is_outlines_constraint_mode = False is_xgrammar_constraint_mode = False is_prefill_node = False is_decode_node = False + is_nixl_prefill_node = False + is_nixl_decode_node = False if is_prefill_node: if kvargs.get("args", None).dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) else: self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + + elif is_nixl_prefill_node: + if kvargs.get("args", None).dp > 1: + self.backend = PDNIXLDPBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + elif is_decode_node: if kvargs.get("args", None).dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) else: self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + + elif is_nixl_decode_node: + if kvargs.get("args", None).dp > 1: + self.backend = PDNIXLDPBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + elif kvargs.get("dp_size", 1) > 1: self.backend = DPChunkedPrefillBackend() elif use_reward_model: @@ -274,6 +298,7 @@ def _init_env( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event: mp.Event, @@ -292,7 +317,7 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, result_queue, mem_queue ) success_event.set() @@ -308,6 +333,7 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, ): @@ -323,6 +349,7 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue, + result_queue, mem_queue, ) @@ -335,6 +362,7 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event, diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 736ab5e59..36ba7879e 100644 --- 
a/lightllm/server/router/req_queue/__init__.py
+++ b/lightllm/server/router/req_queue/__init__.py
@@ -15,9 +15,9 @@ def _get_req_queue_class(args, router, dp_size_in_node: int):
         return ChunkedPrefillQueue
     if args.first_token_constraint_mode:
         return ChunkedPrefillQueue
-    if args.run_mode == "decode":
+    if args.run_mode in ["decode", "nixl_decode"]:
         return QueueForPDDecode
-    if args.run_mode == "prefill":
+    if args.run_mode in ["prefill", "nixl_prefill"]:
         return QueueForPDChunkedPrefill
     if args.disable_chunked_prefill:
diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py
index f6305e209..ee0778b65 100644
--- a/lightllm/utils/health_check.py
+++ b/lightllm/utils/health_check.py
@@ -70,7 +70,7 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req
     health_obj.begin_check()
     try:
         request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}}
-        if args.run_mode == "prefill":
+        if args.run_mode in ["prefill", "nixl_prefill"]:
             request_dict["parameters"]["max_new_tokens"] = 1
         prompt = request_dict.pop("inputs")
         sample_params_dict = request_dict["parameters"]