
Commit 1cb82dc

Author: Weichao Luo (committed)
Commit message: rebase main.
1 parent 9e73079 commit 1cb82dc

File tree

8 files changed: +56 additions, -214 deletions


lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 3 additions & 10 deletions
@@ -261,7 +261,7 @@ def _post_handle(
         is_chuncked_mode: bool,
         do_filter_finished_reqs: bool,
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
-        extra_post_req_handle_chunk_func: Optional[Callable[[InferReq], None]] = None,
+        call_post_handle_for_chunk: bool = False,
     ) -> List[int]:
         """
         extra_post_req_handle_func provides an extra post-processing hook that runs once a request has a confirmed output token; it is mainly used for
@@ -282,16 +282,12 @@ def _post_handle(
             if self.is_master_in_dp:
                 shm_req.shm_cur_kv_len = req_obj.cur_kv_len

-            if extra_post_req_handle_chunk_func is not None:
-                # If an extra handler was supplied, call it here.
-                extra_post_req_handle_chunk_func(req_obj)
-
             # Requests that have not yet reached the token-output stage are skipped:
             # they are still in the chunked-prefill KV-filling phase.
             if req_obj.cur_kv_len < req_obj.get_cur_total_len():
                 # chunk transfer
-                if extra_post_req_handle_chunk_func is not None:
-                    extra_post_req_handle_chunk_func(req_obj)
+                if call_post_handle_for_chunk and extra_post_req_handle_func:
+                    extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)

                 continue

@@ -314,9 +310,6 @@ def _post_handle(
             # Update the request's finished status.
             req_obj.update_finish_status(self.eos_id)

-            if extra_post_req_handle_chunk_func is not None:
-                extra_post_req_handle_chunk_func(req_obj, next_token_id, next_token_logprob)
-
             if extra_post_req_handle_func is not None:
                 extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
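
The net effect is that _post_handle no longer takes a separate chunk-only callback: a single extra_post_req_handle_func is reused, and the new call_post_handle_for_chunk flag decides whether it also fires while a request is still filling its KV cache. The following standalone sketch (stand-in types, not lightllm's implementation) illustrates that dispatch:

# Minimal sketch of the dispatch _post_handle now performs: one hook, with
# call_post_handle_for_chunk deciding whether it also fires during chunked prefill.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeReq:  # hypothetical stand-in for InferReq
    cur_kv_len: int
    total_len: int


def post_handle_one(
    req: FakeReq,
    next_token_id: int,
    next_token_logprob: float,
    extra_post_req_handle_func: Optional[Callable[[FakeReq, int, float], None]] = None,
    call_post_handle_for_chunk: bool = False,
) -> None:
    if req.cur_kv_len < req.total_len:
        # Still filling the KV cache (chunked prefill): only call the hook when asked to.
        if call_post_handle_for_chunk and extra_post_req_handle_func:
            extra_post_req_handle_func(req, next_token_id, next_token_logprob)
        return
    # Request produced a real output token: always call the hook if one was given.
    if extra_post_req_handle_func is not None:
        extra_post_req_handle_func(req, next_token_id, next_token_logprob)


if __name__ == "__main__":
    log = lambda r, tid, lp: print(f"hook: kv={r.cur_kv_len}/{r.total_len} token={tid} logprob={lp:.2f}")
    post_handle_one(FakeReq(4, 10), 0, -0.5, log, call_post_handle_for_chunk=True)  # fires (chunk phase)
    post_handle_one(FakeReq(10, 10), 7, -0.1, log)                                  # fires (output phase)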

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,7 @@ def normal_prefill_reqs(
         ok_finished_reqs: List[InferReq],
         mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None,
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+        call_post_handle_for_chunk: bool = False
     ):
         model_input, run_reqs = prepare_prefill_inputs(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
@@ -69,6 +70,7 @@ def normal_prefill_reqs(
             is_chuncked_mode=not self.disable_chunked_prefill,
             do_filter_finished_reqs=False,
             extra_post_req_handle_func=extra_post_req_handle_func,
+            call_post_handle_for_chunk=call_post_handle_for_chunk,
         )
         return
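
Callers of ContinuesBatchBackend.normal_prefill_reqs can now thread the flag straight through to _post_handle. A hedged usage sketch, where _on_token is a hypothetical handler name and the call is shown as a comment because it needs a live backend instance:

# Hypothetical handler matching the Callable[[InferReq, int, float], None] shape.
def _on_token(req_obj, next_token_id, next_token_logprob):
    # With call_post_handle_for_chunk=True this also runs after each chunked-prefill step.
    print("token:", next_token_id, "logprob:", next_token_logprob)

# Inside a backend's prefill path (requires a real backend instance; shown as a comment):
# self.normal_prefill_reqs(
#     prefill_reqs=prefill_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs,
#     extra_post_req_handle_func=_on_token, call_post_handle_for_chunk=True,
# )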

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 13 additions & 5 deletions
@@ -1,5 +1,5 @@
 import torch
-from typing import List, Tuple
+from typing import List, Tuple, Callable, Optional
 from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
 from lightllm.common.basemodel.batch_objs import ModelOutput

@@ -52,7 +52,9 @@ def decode(self):
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return

-    def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):
+    def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs,
+                            extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+                            call_post_handle_for_chunk: bool = False):
         model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs(
             prefill_reqs, is_multimodal=self.is_multimodal
         )
@@ -65,7 +67,9 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
         self._post_handle(
-            run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+            run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
+            extra_post_req_handle_func=extra_post_req_handle_func,
+            call_post_handle_for_chunk=call_post_handle_for_chunk
         )
         return

@@ -117,7 +121,9 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
         )
         return

-    def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):
+    def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs,
+                             extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+                             call_post_handle_for_chunk: bool = False):
         (
             micro_input,
             run_reqs,
@@ -142,6 +148,8 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
         self._post_handle(
-            all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+            all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
+            extra_post_req_handle_func=extra_post_req_handle_func,
+            call_post_handle_for_chunk=call_post_handle_for_chunk
         )
         return
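
Both prefill paths keep the same post-sampling conversion visible in the context lines: sampled token ids and log-probabilities are moved to CPU numpy arrays before being handed to _post_handle. A self-contained sketch of just that step (synthetic tensors, requires torch):

# Standalone sketch of the sampling post-processing pattern used in these prefill paths.
import torch

next_token_ids = torch.tensor([5, 17, 3])          # sampled token ids (stand-in values)
next_token_probs = torch.tensor([0.9, 0.4, 0.75])  # their probabilities (stand-in values)

# Detach from the graph, move to CPU, and convert probabilities to log-probabilities.
ids_np = next_token_ids.detach().cpu().numpy()
logprobs_np = torch.log(next_token_probs).detach().cpu().numpy()

print(ids_np, logprobs_np)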

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py

Lines changed: 3 additions & 19 deletions
@@ -1,15 +1,12 @@
 import time
-import torch
 import torch.multiprocessing as mp
 import threading
 from concurrent.futures import ThreadPoolExecutor
-from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
+from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend
 from typing import List, Tuple, Dict
 from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
 from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq
 from lightllm.utils.log_utils import init_logger
-from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs
-from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.server.multimodal_params import MultimodalParams

 from .pd_remote_prefill_obj import RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest, RemoteTransferStatusType
@@ -93,21 +90,8 @@ def decode(self):
             self.remote_prefilled_reqs[shm_req.group_req_id] = run_req

         if decode_reqs:
-            kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            logits = self.model.forward(**kwargs)
-
-            self._overlap_req_init_and_filter(
-                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
-            )
-
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
+            ContinuesBatchBackend.normal_decode(
+                self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs)

         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-
         return
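
The decode path now delegates to ContinuesBatchBackend.normal_decode, invoked through the class with an explicit self instead of duplicating the forward/sample/post-handle loop locally. A minimal standalone sketch of that reuse pattern (illustrative class names, not the lightllm classes):

# Sketch of calling another class's method with an explicit `self` to reuse its logic.
class SharedBackend:
    def normal_decode(self, reqs):
        print(f"{type(self).__name__} decoding {len(reqs)} reqs")


class SpecializedBackend:
    def decode(self, reqs):
        if reqs:
            # Explicit-class call: runs SharedBackend's code on this object.
            SharedBackend.normal_decode(self, reqs)


SpecializedBackend().decode(["r0", "r1"])
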
Lines changed: 5 additions & 57 deletions
@@ -1,14 +1,11 @@
 import time
 import torch
 import torch.multiprocessing as mp
-import torch.distributed as dist
-from typing import List
 from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
 from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq
 from lightllm.utils.log_utils import init_logger
-from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs
+from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend

 from .impl_for_pd_decode import PDNIXLBackendForDecodeNode, RemoteTransferStatusType

@@ -24,7 +21,7 @@ def init_custom(self):
         super().init_custom()

         self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
-        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs
+        from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs

         kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal)
         self.model.forward(**kwargs)
@@ -63,62 +60,13 @@ def decode(self):
             run_req.in_prefill_or_transfer = True
             self.remote_prefilled_reqs[shm_req.group_req_id] = run_req

-        self.reduce_tensor.fill_(len(decode_reqs))
-        dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX)
-        max_decode_num = self.reduce_tensor.item()
+        max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs)
         if max_decode_num != 0:
             if not self.enable_decode_microbatch_overlap:
-                self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
+                DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
             else:
-                self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        return
-
-    def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
+                DPChunkedPrefillBackend.overlap_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)

-        kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
-        )
-        logits = self.model.forward(**kwargs)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        if len(run_reqs) != 0:
-            logits = logits[0 : len(run_reqs), :]
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
         return

-    def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
-        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import (
-            padded_overlap_prepare_decode_inputs,
-        )
-
-        (
-            micro_batch,
-            run_reqs,
-            padded_req_num,
-            micro_batch1,
-            run_reqs1,
-            padded_req_num1,
-        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)
-
-        logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1)
-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        req_num, req_num1 = len(run_reqs), len(run_reqs1)
-        all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device)
-
-        all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
-        all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
-
-        all_run_reqs = run_reqs + run_reqs1
-        if all_run_reqs:
-            next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            self._post_handle(
-                all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
-        return
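
The inline fill_/all_reduce/item sequence is replaced by a _dp_all_reduce_decode_req_num helper whose implementation is not shown in this diff. A hedged sketch of what the removed lines computed, so that every data-parallel rank agrees on the padded decode batch size:

# Sketch of the MAX all-reduce the removed code performed (and the new helper presumably wraps).
import torch
import torch.distributed as dist


def dp_all_reduce_decode_req_num(decode_reqs, device="cuda"):
    # Each rank contributes its local decode count; the MAX all-reduce yields the
    # global maximum, which callers use as the padded decode batch size.
    t = torch.tensor([len(decode_reqs)], dtype=torch.int32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.MAX)
    return int(t.item())


# Usage (only valid after dist.init_process_group(...) inside the data-parallel group):
# max_decode_num = dp_all_reduce_decode_req_num(decode_reqs)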

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py

Lines changed: 6 additions & 24 deletions
@@ -1,12 +1,9 @@
 import threading
-import torch
 import torch.multiprocessing as mp
 from typing import List, Tuple
-from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
-from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
 from lightllm.utils.log_utils import init_logger
-from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs
-from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
+from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend
 from .impl_for_pd_base import PDNIXLBackendBase

 logger = init_logger(__name__)
@@ -49,25 +46,10 @@ def decode(self):
         assert len(decode_reqs) == 0

         self._prefill_abort_remote(aborted_reqs)
-        self._filter_reqs(aborted_reqs + ok_finished_reqs)
+        self._filter_reqs(aborted_reqs)

         if prefill_reqs:
-            kwargs, run_reqs = prepare_prefill_inputs(
-                prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
-            )
-
-            logits = self.model.forward(**kwargs)
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-
-            self._post_handle(
-                run_reqs,
-                next_token_ids,
-                next_token_logprobs,
-                is_chuncked_mode=True,
-                do_filter_finished_reqs=False,
-                extra_post_req_handle_chunk_func=self._handle_chunked_transfer,
-            )
+            ContinuesBatchBackend.normal_prefill_reqs(
+                self, prefill_reqs=prefill_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs,
+                extra_post_req_handle_func=self._handle_chunked_transfer, call_post_handle_for_chunk=True)
         return
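
Because the chunk-specific callback parameter is gone, _handle_chunked_transfer is now passed as the general extra_post_req_handle_func, so it has to accept the (req_obj, next_token_id, next_token_logprob) signature even when invoked during chunked-prefill steps. A hypothetical handler shape (not lightllm's actual implementation):

# Hedged sketch of a chunk-transfer handler adapted to the unified hook signature.
class ChunkTransferMixin:
    def _handle_chunked_transfer(self, req_obj, next_token_id, next_token_logprob):
        # During chunked prefill only the KV progress on req_obj matters; the token
        # arguments become meaningful once the request reaches the output stage.
        print("transfer chunk, kv_len:", getattr(req_obj, "cur_kv_len", None))


ChunkTransferMixin()._handle_chunked_transfer(object(), 0, -0.3)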
