pd with nixl backend #856

Open · wants to merge 19 commits into main
89 changes: 89 additions & 0 deletions Dockerfile.nixl
@@ -0,0 +1,89 @@
FROM nvcr.io/nvidia/tritonserver:25.04-py3-min as base
ARG PYTORCH_VERSION=2.6.0
ARG PYTHON_VERSION=3.9
ARG CUDA_VERSION=12.4
ARG MAMBA_VERSION=23.1.0-1
ARG TARGETPLATFORM
ARG INSTALL_NIXL=true

ENV PATH=/opt/conda/bin:$PATH \
CONDA_PREFIX=/opt/conda

RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ca-certificates \
libssl-dev \
curl \
g++ \
make \
git && \
rm -rf /var/lib/apt/lists/*

RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh

RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya


WORKDIR /root

COPY ./requirements.txt /lightllm/requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip pip install -r /lightllm/requirements.txt --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124

RUN pip install --no-cache-dir https://github.com/ModelTC/flash-attn-3-build/releases/download/v2.7.4.post1/flash_attn-3.0.0b1-cp39-cp39-linux_x86_64.whl

RUN --mount=type=cache,target=/root/.cache/pip pip install nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100

RUN if [ "$INSTALL_NIXL" == "true" ]; then \
apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \
DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \
rm -rf /usr/lib/ucx && \
rm -rf /opt/hpcx/ucx && \
cd /usr/local/src && \
git clone https://github.com/openucx/ucx.git && \
cd ucx && \
git checkout v1.19.x && \
./autogen.sh && ./configure \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=/usr/local/cuda \
--with-verbs=yes \
--with-dm \
--with-gdrcopy=/usr/local \
--with-efa \
--enable-mt && \
make -j && \
make -j install-strip && \
ldconfig; \
fi

RUN if [ "$INSTALL_NIXL" == "true" ]; then \
apt-get update && apt-get install -y pkg-config tmux net-tools; \
cd /usr/local/src; \
pip install --upgrade meson ninja pybind11 patchelf; \
git clone https://github.com/ai-dynamo/nixl.git -b main && \
cd nixl && \
rm -rf build && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl --buildtype=release && \
cd build && \
ninja && \
ninja install && \
cd .. && pip install . --no-deps; \
fi

COPY . /lightllm
RUN pip install -e /lightllm --no-cache-dir
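
A quick way to sanity-check the NIXL install in the built image is an import smoke test (hedged: the module name `nixl` is assumed from the repo name, and the binding API may differ between nixl revisions):

```python
# Hedged smoke test, to be run inside the built image: verifies that the
# nixl package installed by `pip install . --no-deps` above is importable.
import importlib

nixl = importlib.import_module("nixl")  # module name assumed from the repo name
print("nixl import OK from:", getattr(nixl, "__file__", "unknown location"))
```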
16 changes: 15 additions & 1 deletion lightllm/server/api_cli.py
@@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
choices=["normal", "prefill", "decode", "pd_master", "config_server"],
choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
default="normal",
help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -54,6 +54,20 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="The port number for the config server in config_server mode.",
)
parser.add_argument(
"--pd_nixl_remote_prefill_http_port",
type=int,
default=42001,
help="nixl pd mode, prefill node used for triggering prefill http port.",
)

parser.add_argument(
"--pd_nixl_remote_prefill_port",
type=int,
default=42002,
help="nixl pd mode, prefill and decode used for meta exchange.",
)

parser.add_argument(
"--model_name",
type=str,
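
For reference, a standalone sketch of how the new flags parse; this mirrors only the options added above, not the full lightllm CLI:

```python
import argparse

# Minimal mirror of the flags touched by this diff, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--run_mode",
    type=str,
    choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
    default="normal",
)
parser.add_argument("--pd_nixl_remote_prefill_http_port", type=int, default=42001)
parser.add_argument("--pd_nixl_remote_prefill_port", type=int, default=42002)

args = parser.parse_args(["--run_mode", "nixl_prefill"])
assert args.run_mode == "nixl_prefill"
assert args.pd_nixl_remote_prefill_http_port == 42001  # http trigger port (default)
assert args.pd_nixl_remote_prefill_port == 42002       # metadata exchange port (default)
```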
2 changes: 1 addition & 1 deletion lightllm/server/api_start.py
@@ -69,7 +69,7 @@ def normal_or_p_d_start(args):

enable_mps()

if args.run_mode not in ["normal", "prefill", "decode"]:
if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]:
return

assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/__init__.py
@@ -1,4 +1,4 @@
from .sampling_params import SamplingParams
from .req import Req, FinishStatus
from .req import Req, FinishStatus, PDNIXLChunkedPrefillReq
from .shm_req_manager import ShmReqManager
from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray
53 changes: 53 additions & 0 deletions lightllm/server/core/objs/req.py
@@ -102,6 +102,7 @@ def get_str(self):
f"shm_cur_kv_len:{self.shm_cur_kv_len},"
f"shm_cur_output_len:{self.shm_cur_output_len},"
f"finish_status:{self.finish_status.is_finished()}"
f"group_id: {self.group_req_id}"
)

def init(
@@ -333,3 +334,55 @@ def post_init(
# error issues.
self.sample_params.max_new_tokens = self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6
return


class PdNixlReqState(ctypes.Structure):
_pack_ = 4
_MAX_TP_SIZE = 32
_fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)]

def __init__(self):
self.dp_world_size = 0
self.state = (ctypes.c_int * self._MAX_TP_SIZE)(*([0] * self._MAX_TP_SIZE))

def set_dp_world_size(self, size: int):
self.dp_world_size = size

def set_tp_state(self, tp_id: int, state: int):
assert (
self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size
), f"tp_id {tp_id} out of range [0, {self.dp_world_size})"
self.state[tp_id] = state

def set_state(self):
assert self.dp_world_size > 0, "dp_world_size should be set before calling this"
unique_state = np.unique(self.state[: self.dp_world_size])
self.state[self.dp_world_size] = unique_state[0]

def get_state(self):
assert self.dp_world_size > 0, "dp_world_size should be set before calling this"
return self.state[self.dp_world_size]


class PDNIXLChunkedPrefillReq(ChunkedPrefillReq):
_pack_ = 4
_fields_ = ChunkedPrefillReq._fields_ + [
# used for pd nixl state synchronization
("pd_nixl_req_state", PdNixlReqState)
]

def set_dp_world_size(self, dp_world_size):
self.pd_nixl_req_state.dp_world_size = dp_world_size

# called by each tp rank, no contention
def set_pd_req_rank_state(self, tp_id: int, state: int):
self.pd_nixl_req_state.set_tp_state(tp_id, state)

# state: -1 for failed, 0 for in progress, 1 for success
# set by router
def set_pd_req_state(self):
self.pd_nixl_req_state.set_state()

# read by all rank
def get_pd_req_state(self):
return self.pd_nixl_req_state.get_state()
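
A note on the aggregation rule in `PdNixlReqState.set_state`: `np.unique` returns sorted values, so element 0 is the minimum state across the participating ranks, and the result is stashed one slot past the last rank. In effect a single failed rank (-1) marks the whole request failed, and any in-progress rank (0) holds the aggregate at in-progress. A minimal standalone sketch of just that rule:

```python
import ctypes
import numpy as np

_MAX_TP_SIZE = 32

# Mirror of the shared-memory layout above, reduced to what the rule needs.
class PdNixlReqState(ctypes.Structure):
    _pack_ = 4
    _fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)]

s = PdNixlReqState()
s.dp_world_size = 4
for tp_id, st in enumerate([1, 1, 0, 1]):  # rank 2 still in progress
    s.state[tp_id] = st

# np.unique sorts ascending, so index 0 is the minimum across ranks;
# the aggregate lives at index dp_world_size, as in set_state/get_state.
s.state[s.dp_world_size] = int(np.unique(s.state[: s.dp_world_size])[0])
assert s.state[s.dp_world_size] == 0  # any -1 would win (failed); all 1s would read success
```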
5 changes: 4 additions & 1 deletion lightllm/server/core/objs/shm_req_manager.py
@@ -3,7 +3,7 @@
from lightllm.utils.envs_utils import get_unique_server_name
from multiprocessing import shared_memory
from lightllm.utils.log_utils import init_logger
from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq
from .req import Req, NormalReq, ChunkedPrefillReq, TokenHealingReq, PDNIXLChunkedPrefillReq
from .shm_array import ShmArray
from .atomic_array_lock import AtomicShmArrayLock, AtomicLockItem
from .atomic_lock import AtomicShmLock
@@ -33,6 +33,9 @@ def get_req_class_type(self):
if args.token_healing_mode:
return TokenHealingReq

if args.run_mode in ["nixl_prefill", "nixl_decode"]:
return PDNIXLChunkedPrefillReq

if args.disable_chunked_prefill:
return NormalReq
else:
5 changes: 4 additions & 1 deletion lightllm/server/core/objs/start_args_type.py
@@ -6,7 +6,10 @@

@dataclass
class StartArgs:
run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master"]})
run_mode: str = field(
default="normal",
metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]},
)
host: str = field(default="127.0.0.1")
port: int = field(default=8000)
zmq_mode: str = field(
30 changes: 16 additions & 14 deletions lightllm/server/httpserver/manager.py
@@ -98,7 +98,7 @@ def __init__(
self.metric_client = MetricClient(metric_port)

self.pd_mode: NodeRole = NodeRole(self.args.run_mode)
assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL]
assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND]
self.id_gen = ReqIDGenerator()
self.first_time_costs = MovingAverage()
self.per_token_costs = MovingAverage()
@@ -225,7 +225,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False):
# health requests use a negative request_id; return it directly
if is_health_req:
return sampling_params.group_request_id
if self.pd_mode == NodeRole.NORMAL:
if self.pd_mode.is_normal():
if not self.is_multinode_tp:
group_request_id = self.id_gen.generate_id()
else:
@@ -235,7 +235,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False):
assert sampling_params.group_request_id != -1
group_request_id = sampling_params.group_request_id
sampling_params.group_request_id = group_request_id
elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D:
elif self.pd_mode.is_P_or_D():
assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting"
group_request_id = sampling_params.group_request_id
else:
@@ -412,7 +412,7 @@ async def transfer_to_next_module_or_node(
if self.is_multinode_tp_master:
async with self.transfer_lock:
for sender in self.multinode_req_manager:
sender.send_pyobj(
await sender.send_pyobj(
(prompt, sampling_params, original_multimodal_params),
protocol=pickle.HIGHEST_PROTOCOL,
)
@@ -441,35 +441,37 @@ async def transfer_to_next_module(
group_req_objs: Optional[GroupReqObjs] = None,
):

if self.pd_mode == NodeRole.P:
if self.pd_mode.is_P():
if self.enable_multimodal:
self.send_to_visual.send_pyobj(
await self.send_to_visual.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
)
else:
self.send_to_router.send_pyobj(

# in P mode, send the request directly to the router
await self.send_to_router.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
)
return

if self.pd_mode == NodeRole.D:
if self.pd_mode.is_D():
# in D mode there is no need to transfer the real multimodal params, since P has already processed them; an empty one is enough
self.send_to_router.send_pyobj(
await self.send_to_router.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
)
return

if self.pd_mode == NodeRole.NORMAL:
if self.pd_mode.is_normal():
if self.enable_multimodal:
self.send_to_visual.send_pyobj(
await self.send_to_visual.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
)
else:
self.send_to_router.send_pyobj(
await self.send_to_router.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
)
@@ -514,7 +516,7 @@ async def _wait_to_token_package(
# the pd master node needs this for statistics, so it is returned to the pd master in the metadata
metadata["prompt_tokens"] = prompt_tokens
# the p node returns prompt_ids so the d node does not need to re-encode
if self.pd_mode == NodeRole.P and is_first_token:
if self.pd_mode.is_P() and is_first_token:
metadata["prompt_ids"] = prompt_ids

prompt_cache_len = metadata.pop("prompt_cache_len", 0)
@@ -611,7 +613,7 @@ async def recycle_resource_loop(self):
pre_time_mark = time.time()
for req_status in self.req_id_to_out_inf.values():
logger.info(
f"left req id {req_status.group_req_objs.group_req_id}"
f"left req id {req_status.group_req_objs.group_req_id} "
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
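
The hunks above lean on predicate helpers (`is_normal`, `is_P`, `is_D`, `is_P_or_D`, `is_NP_or_ND`) and the new `NP`/`ND` members of `NodeRole`, none of which appear in this diff (they live in `lightllm/server/pd_io_struct.py`). A plausible sketch, under the assumption that the nixl roles reuse the classic P/D code paths:

```python
from enum import Enum

class NodeRole(Enum):
    NORMAL = "normal"
    P = "prefill"
    D = "decode"
    NP = "nixl_prefill"   # new in this PR
    ND = "nixl_decode"    # new in this PR
    PD_MASTER = "pd_master"

    def is_normal(self) -> bool:
        return self == NodeRole.NORMAL

    def is_P(self) -> bool:
        # assumption: nixl prefill takes the same branches as classic prefill
        return self in (NodeRole.P, NodeRole.NP)

    def is_D(self) -> bool:
        return self in (NodeRole.D, NodeRole.ND)

    def is_P_or_D(self) -> bool:
        return self.is_P() or self.is_D()

    def is_NP_or_ND(self) -> bool:
        return self in (NodeRole.NP, NodeRole.ND)
```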
36 changes: 35 additions & 1 deletion lightllm/server/httpserver/pd_loop.py
@@ -5,6 +5,7 @@
import socket
import httpx
import base64
import zmq.asyncio  # zmq.asyncio.Context is used below; plain "import zmq" does not expose it
from typing import Dict, Optional
from lightllm.server.pd_io_struct import NodeRole, ObjType
from lightllm.server.httpserver.async_queue import AsyncQueue
@@ -33,6 +34,8 @@ async def pd_handle_loop(manager: HttpServerManager):
manager.host_ip = manager.args.host

asyncio.create_task(timer_log(manager))
if manager.pd_mode.is_NP_or_ND():
asyncio.create_task(pd_handle_loop_from_d(manager))

id_to_handle_task: Dict[int, asyncio.Task] = {}

@@ -92,7 +95,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_Obj):
logger.info(f"Sent registration JSON: {regist_json}")

# forwarding task
forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))
if manager.pd_mode != NodeRole.NP:  # nixl prefill does not need to send tokens up to the pd master
forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))

# receive requests from the pd master, run inference, and forward the generated tokens back to the pd master.
while True:
@@ -182,3 +186,33 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
handle_list = await forwarding_queue.wait_to_get_all_data()
if handle_list:
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))


async def pd_handle_loop_from_d(manager: HttpServerManager):
if manager.pd_mode != NodeRole.NP:
return

context = zmq.asyncio.Context(2)
manager.recv_from_d = context.socket(zmq.PULL)
manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_nixl_remote_prefill_http_port}")

while True:
try:
(
prompt,
sampling_params,
multimodal_params,
) = await manager.recv_from_d.recv_pyobj()

# task that triggers inference
async def pd_process_generate(manager: "HttpServerManager", prompt, sampling_params, multimodal_params):
try:
async for _, _, _, _ in manager.generate(prompt, sampling_params, multimodal_params, None):
pass
except BaseException as e:
logger.error(str(e))

asyncio.create_task(pd_process_generate(manager, prompt, sampling_params, multimodal_params))

except Exception as e:
logger.exception(f"pd loop generate error: {str(e)}")
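
`pd_handle_loop_from_d` shows only the receiving side; the decode node's sender is not part of this diff. A hedged sketch of what the peer presumably does (the host name and the dict stand-in for `SamplingParams` are illustrative):

```python
import pickle
import zmq

# Illustrative sender for pd_handle_loop_from_d: a nixl decode node PUSHes
# (prompt, sampling_params, multimodal_params) to the prefill node's
# --pd_nixl_remote_prefill_http_port, where the PULL socket above is bound.
context = zmq.Context()
send_to_prefill = context.socket(zmq.PUSH)
send_to_prefill.connect("tcp://prefill-host:42001")  # hypothetical host; 42001 is the default port

prompt = "hello world"
sampling_params = {"max_new_tokens": 16}  # stand-in; lightllm passes a SamplingParams object
multimodal_params = None

send_to_prefill.send_pyobj(
    (prompt, sampling_params, multimodal_params),
    protocol=pickle.HIGHEST_PROTOCOL,
)
```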