import time
import torch
import torch.multiprocessing as mp
-import torch.distributed as dist
-from typing import List
from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq
from lightllm.utils.log_utils import init_logger
-from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs
+from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend

from .impl_for_pd_decode import PDNIXLBackendForDecodeNode, RemoteTransferStatusType
@@ -24,7 +21,7 @@ def init_custom(self):
        super().init_custom()

        self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
-        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs
+        from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs

        kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal)
        self.model.forward(**kwargs)
@@ -63,62 +60,13 @@ def decode(self):
                run_req.in_prefill_or_transfer = True
                self.remote_prefilled_reqs[shm_req.group_req_id] = run_req

-        self.reduce_tensor.fill_(len(decode_reqs))
-        dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX)
-        max_decode_num = self.reduce_tensor.item()
+        max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs)
        if max_decode_num != 0:
            if not self.enable_decode_microbatch_overlap:
-                self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
+                DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
            else:
-                self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        return
-
-    def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
+                DPChunkedPrefillBackend.overlap_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)

-        kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
-        )
-        logits = self.model.forward(**kwargs)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        if len(run_reqs) != 0:
-            logits = logits[0 : len(run_reqs), :]
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
        return

-    def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
-        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import (
-            padded_overlap_prepare_decode_inputs,
-        )
-
-        (
-            micro_batch,
-            run_reqs,
-            padded_req_num,
-            micro_batch1,
-            run_reqs1,
-            padded_req_num1,
-        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)
-
-        logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1)
-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        req_num, req_num1 = len(run_reqs), len(run_reqs1)
-        all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device)
-
-        all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
-        all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
-
-        all_run_reqs = run_reqs + run_reqs1
-        if all_run_reqs:
-            next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            self._post_handle(
-                all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
-        return
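The net effect of this hunk: the PD-NIXL decode backend stops carrying its own copies of `normal_decode`/`overlap_decode` and the hand-rolled `dist.all_reduce` over `reduce_tensor`, and instead calls the shared `DPChunkedPrefillBackend` implementations explicitly with `self`, plus a `_dp_all_reduce_decode_req_num` helper. Below is a minimal sketch of that delegation pattern, reconstructed from the removed lines; the class name `PDNIXLDPDecodeBackendSketch`, the subclass relationship, and the body of `_dp_all_reduce_decode_req_num` are assumptions for illustration, not the actual lightllm implementation.

```python
import torch
import torch.distributed as dist

from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend


class PDNIXLDPDecodeBackendSketch(DPChunkedPrefillBackend):
    """Hypothetical subclass; only illustrates the delegation pattern used in the diff."""

    def _dp_all_reduce_decode_req_num(self, decode_reqs) -> int:
        # Assumed behavior, mirroring the removed lines: each DP rank reports its local
        # decode-request count, and the MAX across ranks is taken so that every rank
        # enters the (padded) decode step together.
        self.reduce_tensor.fill_(len(decode_reqs))
        dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX)
        return self.reduce_tensor.item()

    def decode_step(self, decode_reqs, uninit_reqs, ok_finished_reqs):
        # Hypothetical driver: invoke the shared implementation explicitly on
        # DPChunkedPrefillBackend, passing `self`, instead of keeping a local copy
        # of normal_decode/overlap_decode in this backend.
        max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs)
        if max_decode_num != 0:
            if not self.enable_decode_microbatch_overlap:
                DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
            else:
                DPChunkedPrefillBackend.overlap_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
```

Calling the base-class methods through the class object keeps a single decode/sampling path in `DPChunkedPrefillBackend`, so fixes to padding, sampling, or microbatch overlap no longer need to be duplicated in the PD-NIXL backend.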