Skip to content

Commit 0b090c9

Browse files
Author: Weichao Luo
Commit message: fix lint
1 parent 1c28ca9 commit 0b090c9

File tree

8 files changed: 75 additions (+75), 29 deletions (-29)

8 files changed

+75
-29
lines changed

lightllm/server/pd_io_struct.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,7 @@ class RemotePrefillServerInfo:
127127
prefill_server_ip: str
128128
prefill_server_port: int
129129

130+
130131
@dataclass
131132
class DistInfo:
132133
world_size: int
@@ -136,6 +137,7 @@ class DistInfo:
136137
dp_size_in_node: int
137138
node_world_size: int
138139

140+
139141
@dataclass
140142
class PDTransLeaveInfo:
141143
decode_id: int

lightllm/server/router/manager.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -210,12 +210,18 @@ async def wait_to_model_ready(self):
210210
start_pd_remote_prefill_server_process,
211211
)
212212

213-
dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size,
214-
self.dp_world_size, self.dp_size_in_node, self.node_world_size)
213+
dist_info = DistInfo(
214+
self.world_size,
215+
self.nnodes,
216+
self.dp_size,
217+
self.dp_world_size,
218+
self.dp_size_in_node,
219+
self.node_world_size,
220+
)
215221

216222
start_pd_remote_prefill_server_process(
217223
self.args.pd_node_id,
218-
dist_info = dist_info,
224+
dist_info=dist_info,
219225
http_server_port=self.args.pd_remote_prefill_http_port,
220226
server_port=self.args.pd_remote_prefill_port,
221227
from_backend_queue=self.info_queue,
@@ -235,8 +241,15 @@ async def wait_to_model_ready(self):
235241
from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import (
236242
start_pd_remote_prefill_client_process,
237243
)
238-
dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size,
239-
self.dp_world_size, self.dp_size_in_node, self.node_world_size)
244+
245+
dist_info = DistInfo(
246+
self.world_size,
247+
self.nnodes,
248+
self.dp_size,
249+
self.dp_world_size,
250+
self.dp_size_in_node,
251+
self.node_world_size,
252+
)
240253

241254
start_pd_remote_prefill_client_process(
242255
self.args.pd_node_id,

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def _transfer_kv_to_remote(self, req: InferReq):
160160

161161
kv_transfer_req = KVMoveRequest(
162162
group_req_id=group_req_id,
163-
token_ids=token_index[ : req.cur_kv_len].tolist(),
163+
token_ids=token_index[: req.cur_kv_len].tolist(),
164164
prev_kv_len=transfer_state.current_kv_len,
165165
cur_kv_len=req.cur_kv_len,
166166
)
@@ -176,7 +176,6 @@ def _transfer_kv_to_remote(self, req: InferReq):
176176
transfer_state.current_kv_len = req.cur_kv_len
177177
transfer_state.current_chunk_id += 1
178178

179-
180179
def _decode_filter_reqs(
181180
self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq]
182181
):

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ def init_custom(self):
2525

2626
self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
2727
from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs
28+
2829
kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal)
2930
self.model.forward(**kwargs)
3031
assert len(run_reqs) == 0 and padded_req_num == 1

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,11 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int
7373
next_token_ids = next_token_ids.detach().cpu().numpy()
7474
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
7575
self._post_handle(
76-
run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
76+
run_reqs,
77+
next_token_ids,
78+
next_token_logprobs,
79+
is_chuncked_mode=True,
80+
do_filter_finished_reqs=False,
7781
extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
7882
)
7983

@@ -104,6 +108,10 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
104108
next_token_ids = next_token_ids.detach().cpu().numpy()
105109
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
106110
self._post_handle(
107-
all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
111+
all_run_reqs,
112+
next_token_ids,
113+
next_token_logprobs,
114+
is_chuncked_mode=True,
115+
do_filter_finished_reqs=False,
108116
extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
109117
)

lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,13 @@
77
from lightllm.utils.log_utils import init_logger
88

99
from .pd_remote_prefill_obj import (
10-
RemoteAgent, KVMoveRequest, PrefillRequest,
11-
RemotePrefillStatus, ThreadSafeDict, KVMoveRequestState
12-
)
10+
RemoteAgent,
11+
KVMoveRequest,
12+
PrefillRequest,
13+
RemotePrefillStatus,
14+
ThreadSafeDict,
15+
KVMoveRequestState,
16+
)
1317

1418

1519
logger = init_logger(__name__)
@@ -120,10 +124,10 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest,
120124
return
121125

122126
kv_move_start = max(skip_kv_move_len, request.prev_kv_len)
123-
kv_move_end = request.cur_kv_len
127+
kv_move_end = request.cur_kv_len
124128

125129
src_token_ids = request.token_ids[kv_move_start:]
126-
dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len: kv_move_end]
130+
dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len : kv_move_end]
127131

128132
remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][
129133
self.tp_idx
@@ -140,7 +144,8 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest,
140144
group_req_id=group_reqeust_id,
141145
status=1,
142146
chunk_id=prefill_request.transfer_state.current_chunk_id,
143-
is_last=is_finished)
147+
is_last=is_finished,
148+
)
144149

145150
handle = self.nixl_agent.make_prepped_xfer(
146151
"WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status.serialize()
@@ -151,10 +156,7 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest,
151156

152157
if group_reqeust_id not in self.inflight_transfers:
153158
self.inflight_transfers[group_reqeust_id] = KVMoveRequestState(
154-
handles=[],
155-
done_handles=[],
156-
remote_agent=remote_agent,
157-
abort=False
159+
handles=[], done_handles=[], remote_agent=remote_agent, abort=False
158160
)
159161
self.inflight_transfers[group_reqeust_id].handles.append(handle)
160162

@@ -199,7 +201,9 @@ def get_done_tranfers(self):
199201
logger.warning(f"{req_id} Transfer failed with state {xfer_state}")
200202
failed = True
201203
kv_move_state.done_handles.append(handle)
202-
notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1, chunk_id=-1, is_last=True)
204+
notify_failed_status = RemotePrefillStatus(
205+
group_req_id=req_id, status=-1, chunk_id=-1, is_last=True
206+
)
203207
self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize())
204208

205209
kv_move_state.handles = left_handles

lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py

Lines changed: 25 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -112,13 +112,17 @@ def main_loop(self):
112112
if self.dist_info.dp_size_in_node > 1:
113113
group_req_id = request.data.sampling_params.group_request_id
114114
suggested_dp_index = request.data.sampling_params.suggested_dp_index
115-
if suggested_dp_index < 0: # not likely to happen
115+
if suggested_dp_index < 0: # not likely to happen
116116
suggested_dp_index = random.randint(0, self.dist_info.dp_size_in_node)
117117
request.data.sampling_params.suggested_dp_index = suggested_dp_index
118-
logger.warning(f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}")
118+
logger.warning(
119+
f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}"
120+
)
119121

120-
for local_rank in range(suggested_dp_index * self.dist_info.dp_world_size,
121-
(suggested_dp_index + 1) * self.dist_info.dp_world_size):
122+
for local_rank in range(
123+
suggested_dp_index * self.dist_info.dp_world_size,
124+
(suggested_dp_index + 1) * self.dist_info.dp_world_size,
125+
):
122126
self.to_backend_queues[local_rank].put(request)
123127
else:
124128
for queue in self.to_backend_queues:
@@ -217,7 +221,11 @@ def main_loop(self):
217221
def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest):
218222
socket, _ = self.remote_prefill_servers[server_id]
219223
prefill_request.sampling_params.max_new_tokens = 1
220-
socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None))
224+
socket.send_pyobj(
225+
PrefillRequest(
226+
type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None
227+
)
228+
)
221229

222230

223231
def remote_prefill_server_loop(
@@ -256,7 +264,11 @@ def start_pd_remote_prefill_server_process(
256264

257265

258266
def remote_prefill_client_loop(
259-
id: int, dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue]
267+
id: int,
268+
dist_info: DistInfo,
269+
from_backend_queue: mp.Queue,
270+
to_backend_queues: List[mp.Queue],
271+
agent_meta_queues: List[mp.Queue],
260272
):
261273
graceful_registry(inspect.currentframe().f_code.co_name)
262274

@@ -271,11 +283,16 @@ def remote_prefill_client_loop(
271283

272284

273285
def start_pd_remote_prefill_client_process(
274-
id: int, dist_info: DistInfo, from_backend_queue: mp.Queue, to_backend_queues: List[mp.Queue], agent_meta_queues: List[mp.Queue]
286+
id: int,
287+
dist_info: DistInfo,
288+
from_backend_queue: mp.Queue,
289+
to_backend_queues: List[mp.Queue],
290+
agent_meta_queues: List[mp.Queue],
275291
):
276292

277293
proc = mp.Process(
278-
target=remote_prefill_client_loop, args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues)
294+
target=remote_prefill_client_loop,
295+
args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues),
279296
)
280297
proc.start()
281298
assert proc.is_alive()

lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,14 @@ class ConnectRequest(RemoteRequest):
5353
agent_metadatas: List[bytes]
5454
agent_mem_descs: List[bytes]
5555

56+
5657
@dataclass
5758
class TransferState:
5859
start_time: float
5960
current_kv_len: int
6061
current_chunk_id: int
6162

63+
6264
@dataclass
6365
class PrefillRequest(RemoteRequest):
6466
decode_id: int
@@ -82,6 +84,7 @@ class RemoteAgent:
8284
kv_mem_desc: nixlBind.nixlRegDList
8385
kv_xfer_handles: nixl_prepped_dlist_handle
8486

87+
8588
@dataclass
8689
class KVMoveRequestState:
8790
handles: List[nixl_xfer_handle]
@@ -184,4 +187,3 @@ def bind(self, addr: str):
184187

185188
def connect(self, addr: str):
186189
return self.sock.connect(addr)
187-

0 commit comments

Comments
 (0)