
Commit f4e52da

Author: Weichao Luo
Commit message: paged transfer.
1 parent: 551d94b

13 files changed: +583 −115 lines

lightllm/common/basemodel/basemodel.py (4 additions, 0 deletions)

@@ -166,6 +166,10 @@ def _init_kv_move_buffer(self):
         # this initialization is only needed in the P/D-disaggregated inference modes
         if self.run_mode in ["prefill", "decode"]:
             self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size)
+        elif self.run_mode in ["nixl_prefill", "nixl_decode"]:
+            page_num = int(os.getenv("PD_NIXL_MOVE_PAGE_NUM", 32))
+            page_size = int(os.getenv("PD_NIXL_MOVE_PAGE_SIZE", 1024))
+            self.mem_manager.alloc_paged_kv_move_buffer(page_num, page_size)

     def _check_mem_size(self):
         self.max_total_token_num = self.mem_manager.size
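
With the defaults above, the paged move buffer holds PD_NIXL_MOVE_PAGE_NUM × PD_NIXL_MOVE_PAGE_SIZE = 32 × 1024 = 32768 token slots. A minimal sketch of how the env-var driven sizing plays out, using made-up model dimensions (not values from this repo):

import os

# hypothetical model dimensions, for illustration only
layer_num, head_num, head_dim, dtype_bytes = 32, 8, 128, 2  # e.g. fp16/bf16

page_num = int(os.getenv("PD_NIXL_MOVE_PAGE_NUM", 32))
page_size = int(os.getenv("PD_NIXL_MOVE_PAGE_SIZE", 1024))

# mirrors the base-class shape (page_num, page_size, layer_num, 2 * head_num, head_dim)
elems = page_num * page_size * layer_num * 2 * head_num * head_dim
print(f"{elems * dtype_bytes / 2**30:.2f} GiB for the paged KV-move buffer")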

lightllm/common/deepseek2_mem_manager.py (6 additions, 0 deletions)

@@ -36,6 +36,12 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2]
         return

+    def alloc_paged_kv_move_buffer(self, page_num, page_size):
+        self.kv_move_buffer = torch.empty(
+            (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
+        )
+        return
+
     def send_to_decode_node(
         self,
         move_tasks: List[KVMoveTask],
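
Note that, unlike the base MemoryManager version below, this buffer has no factor of 2 on head_num, consistent with MLA-style models keeping a single compressed KV entry per token rather than separate K and V heads. As a hedged sketch of staging a request's tokens into one page of such a buffer (the (layer_num, size, head_num, head_dim) layout of kv_buffer is an assumption, not taken from this diff):

import torch

def stage_tokens_into_page(kv_buffer, kv_move_buffer, token_idxs, page_id):
    # kv_buffer: (layer_num, size, head_num, head_dim) -- assumed source layout
    # kv_move_buffer: (page_num, page_size, layer_num, head_num, head_dim)
    n = len(token_idxs)
    rows = kv_buffer[:, token_idxs]                          # (layer_num, n, head_num, head_dim)
    kv_move_buffer[page_id, :n] = rows.permute(1, 0, 2, 3)   # page layout is token-major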

lightllm/common/mem_manager.py (8 additions, 0 deletions)

@@ -92,6 +92,14 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1]
         return

+    def alloc_paged_kv_move_buffer(self, page_num, page_size):
+        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+            raise NotImplementedError("subclass need reimpl this method")
+        self.kv_move_buffer = torch.empty(
+            (page_num, page_size, self.layer_num, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
+        )
+        return
+
     def send_to_decode_node(
         self,
         move_tasks: List[KVMoveTask],
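
The type(self) guard makes the inherited base implementation refuse to run on any subclass that has not overridden it, since the buffer layout is model-specific (the DeepSeek2 manager above supplies its own). A toy illustration of the effect, with a made-up subclass name:

# toy classes for illustration; not the real MemoryManager
class MemoryManager:
    def alloc_paged_kv_move_buffer(self, page_num, page_size):
        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
            raise NotImplementedError("subclass need reimpl this method")
        print("base allocation")

class ExoticMemManager(MemoryManager):
    pass  # forgot to override

MemoryManager().alloc_paged_kv_move_buffer(32, 1024)    # prints "base allocation"
ExoticMemManager().alloc_paged_kv_move_buffer(32, 1024)  # raises NotImplementedError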

lightllm/server/router/model_infer/infer_batch.py (0 additions, 2 deletions)

@@ -264,8 +264,6 @@ def init_all(self):
         self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
         self.shm_req.link_prompt_ids_shm_array()
         self.shm_req.link_logprobs_shm_array()
-        if isinstance(self.shm_req, PDNIXLChunkedPrefillReq):
-            self.in_prefill_or_transfer = False

         self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
         if self.sampling_param.shm_param.input_penalty:

lightllm/server/router/model_infer/mode_backend/base_backend.py (1 addition, 1 deletion)

@@ -291,7 +291,7 @@ def _post_handle(
             req_obj.update_finish_status(self.eos_id)

             if extra_post_req_handle_chunk_func is not None:
-                extra_post_req_handle_chunk_func(req_obj)
+                extra_post_req_handle_chunk_func(req_obj, next_token_id, next_token_logprob)

             if extra_post_req_handle_func is not None:
                 extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
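
With this change the chunked-prefill callback receives the same arguments as the per-request callback. A hypothetical handler matching the new signature (the body is illustrative, not from this commit):

def my_chunk_handler(req_obj, next_token_id, next_token_logprob):
    # e.g. record per-chunk sampling results on the request object
    req_obj.chunk_token_ids = getattr(req_obj, "chunk_token_ids", [])
    req_obj.chunk_token_ids.append((next_token_id, next_token_logprob))

# would be passed as _post_handle(..., extra_post_req_handle_chunk_func=my_chunk_handler)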

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py (289 additions, 59 deletions)

Large diffs are not rendered by default.
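
The bulk of this commit lives in this unrendered file: the prefill-wait and transfer loops become coroutines driven through a new _start_async_loop helper, which shows up as a thread target in the smaller diffs below. Since the diff is hidden, here is only a plausible minimal shape for such a helper, assuming it pins one asyncio event loop per thread; this is a guess, not the actual implementation:

import asyncio

def _start_async_loop(self, async_loop_fn):
    # run one coroutine-based loop on a dedicated event loop owned by this thread
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(async_loop_fn())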

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py (16 additions, 4 deletions)

@@ -2,6 +2,7 @@
 import torch
 import torch.multiprocessing as mp
 import threading
+from concurrent.futures import ThreadPoolExecutor
 from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
 from typing import List, Tuple, Dict
 from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
@@ -11,7 +12,7 @@
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.server.multimodal_params import MultimodalParams

-from .pd_remote_prefill_obj import RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest
+from .pd_remote_prefill_obj import RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest, RemoteTransferStatusType

 from .impl_for_pd_base import PDNIXLBackendBase

@@ -24,7 +25,10 @@ def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, n

     def init_custom(self):
         super().init_custom()
-        self.wait_prefill_thread = threading.Thread(target=self._prefill_wait_loop, daemon=True)
+        self.wait_prefill_thread = threading.Thread(
+            target=self._start_async_loop, args=(self._prefill_wait_loop_async,), daemon=True
+        )
+        self.wait_move_page_pool = ThreadPoolExecutor(max_workers=4)
         self.wait_prefill_thread.start()
         return

@@ -44,9 +48,15 @@ def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq):
             multimodal_params=MultimodalParams.from_dict(req.multimodal_params),
             local_cached_len=req.cur_kv_len,
             token_ids=mem_indexes[b_start_loc[index] : b_start_loc[index + 1]],
+            page_ids=self.page_scheduer.borrow()  # get page ids for this request; blocks when not enough pages
         )
         return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request)

+    def _trigger_remote_prefill(self, req_id: int, index: int, kwargs: Dict, req: InferReq):
+        remote_prefill_task = self._build_remote_prefill_task(index, kwargs, req)
+        self.request_to_page_ids[req_id] = remote_prefill_task.prefill_request.page_ids
+        self.to_remote_queue.put(remote_prefill_task)
+
     def prefill(self, reqs: List[Tuple]):
         self._init_reqs(reqs, init_req_obj=False)
         return
@@ -74,9 +84,11 @@ def decode(self):
             # since the token indices are the same across TPs, we only need to trigger prefill on master
             if self.is_master_in_dp:
                 run_req.remote_prefill_start = time.time()
-                self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req))
+                # this call may block the calling thread, so run it in a thread pool
+                self.wait_move_page_pool.submit(self._trigger_remote_prefill, shm_req.group_req_id, idx, kwargs, run_req)

-            shm_req.set_pd_req_rank_state(self.rank_in_dp, 0)  # set in-progress state
+            shm_req.set_pd_req_rank_state(self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value)  # set in-progress state
             run_req.in_prefill_or_transfer = True
             self.remote_prefilled_reqs[shm_req.group_req_id] = run_req
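
Borrowing pages can block until enough pages are free, which is why _trigger_remote_prefill is submitted to wait_move_page_pool instead of running on the decode thread. A minimal sketch of a blocking page pool with the borrow() shape seen above (the class, its constructor, and give_back() are assumptions; only borrow() appears in the diff, under the code's own page_scheduer spelling):

import threading

class PageScheduler:
    # a minimal sketch of a blocking page pool, not the commit's implementation
    def __init__(self, page_num: int, pages_per_req: int = 1):
        self._free = list(range(page_num))
        self._cond = threading.Condition()
        self._n = pages_per_req

    def borrow(self):
        # block until enough free pages, then hand out their ids
        with self._cond:
            while len(self._free) < self._n:
                self._cond.wait()
            pages, self._free = self._free[: self._n], self._free[self._n :]
            return pages

    def give_back(self, page_ids):
        # return pages after the transfer completes and wake blocked borrowers
        with self._cond:
            self._free.extend(page_ids)
            self._cond.notify_all()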

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py (5 additions, 3 deletions)

@@ -10,7 +10,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs

-from .impl_for_pd_decode import PDNIXLBackendForDecodeNode
+from .impl_for_pd_decode import PDNIXLBackendForDecodeNode, RemoteTransferStatusType

 logger = init_logger(__name__)

@@ -55,9 +55,11 @@ def decode(self):
             # since the token indices are the same across TPs, we only need to trigger prefill on master
             if self.is_master_in_dp:
                 run_req.remote_prefill_start = time.time()
-                self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req))
+                # this call may block the calling thread, so run it in a thread pool
+                self.wait_move_page_pool.submit(self._trigger_remote_prefill, shm_req.group_req_id, idx, kwargs, run_req)

-            shm_req.set_pd_req_rank_state(self.rank_in_dp, 0)  # set in-progress state
+            shm_req.set_pd_req_rank_state(self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value)  # set in-progress state
             run_req.in_prefill_or_transfer = True
             self.remote_prefilled_reqs[shm_req.group_req_id] = run_req

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py (9 additions, 3 deletions)

@@ -18,9 +18,15 @@ def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue,

     def init_custom(self):
         super().init_custom()
-        self.handle_prefill_loop_thread = threading.Thread(target=self._handle_prefill_loop, daemon=True)
-        self.wait_transfer_loop_thread = threading.Thread(target=self._wait_transfer_loop, daemon=True)
-        self.handle_transfer_loop_thread = threading.Thread(target=self._handle_transfer_loop, daemon=True)
+        self.handle_prefill_loop_thread = threading.Thread(
+            target=self._start_async_loop, args=(self._handle_prefill_loop,), daemon=True
+        )
+        self.wait_transfer_loop_thread = threading.Thread(
+            target=self._start_async_loop, args=(self._wait_page_transfer_loop,), daemon=True
+        )
+        self.handle_transfer_loop_thread = threading.Thread(
+            target=self._start_async_loop, args=(self._handle_transfer_loop,), daemon=True
+        )

         self.handle_prefill_loop_thread.start()
         self.handle_transfer_loop_thread.start()

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py (0 additions, 6 deletions)

@@ -41,12 +41,6 @@ def decode(self):
         self._prefill_abort_remote(aborted_reqs)
         self._filter_reqs(aborted_reqs + ok_finished_reqs)

-        # if ok_finished_reqs:
-        #     for req in ok_finished_reqs:
-        #         self._transfer_kv_to_remote(req)
-        #     self._filter_reqs(ok_finished_reqs)
-        #     ok_finished_reqs.clear()
-
         current_dp_prefill_num = len(prefill_reqs)
         self.reduce_tensor.fill_(current_dp_prefill_num)
         dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
