Add fake balance for EP mode #962

Open · wants to merge 3 commits into base: main
22 changes: 22 additions & 0 deletions lightllm/common/fused_moe/grouped_fused_moe_ep.py
@@ -5,6 +5,7 @@
import triton.language as tl
from typing import Any, Callable, Dict, Optional, Tuple
import torch.distributed as dist
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
@@ -142,6 +143,15 @@ def fused_experts_impl(

# scatter
all_tokens = sum(num_recv_tokens_per_expert_list)  # total number of received tokens across experts (after padding)

if get_env_start_args().enable_ep_fake_balance:
rank = dist.get_rank()
if rank == 0:
logger.info(
f"prefill, [{rank}], all_tokens = {all_tokens}, "
f"num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}"
)

# gather_out shape [receive_num_tokens, hidden]
gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
if all_tokens > 0:
@@ -219,6 +229,18 @@ def fused_experts_impl(
async_finish=False,
return_recv_hook=False,
)

# NOTE: when the decoding CUDA graph is enabled, the logger cannot be called, so this logging is only available with --disable_cudagraph
args = get_env_start_args()
if args.enable_ep_fake_balance and args.disable_cudagraph:
rank = dist.get_rank()
all_tokens = sum(masked_m)
if rank == 0:
logger.info(
f"decode, [{rank}], all_tokens = {all_tokens}, "
f"expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}"
)

# deepgemm
gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
# low latency combine
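As a side note (not part of this diff), the per-expert counts logged above lend themselves to a quick imbalance check. A minimal sketch with made-up numbers, using nothing beyond plain Python:

```python
# Hypothetical values standing in for num_recv_tokens_per_expert_list from the log above.
num_recv_tokens_per_expert_list = [3, 5, 4, 4]

all_tokens = sum(num_recv_tokens_per_expert_list)
mean_load = all_tokens / len(num_recv_tokens_per_expert_list)
# max/mean ratio: 1.0 means perfectly balanced, which is what --enable_ep_fake_balance forces.
imbalance = max(num_recv_tokens_per_expert_list) / mean_load if mean_load > 0 else 1.0
print(f"all_tokens={all_tokens}, imbalance={imbalance:.2f}")  # all_tokens=16, imbalance=1.25
```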
9 changes: 9 additions & 0 deletions lightllm/common/fused_moe/topk_select.py
@@ -21,6 +21,8 @@
import torch
from lightllm.utils.sgl_utils import sgl_ops
from lightllm.utils.light_utils import light_ops
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.balance_utils import BalancedTensor
from typing import Callable, List, Optional, Tuple
from lightllm.common.fused_moe.softmax_topk import softmax_topk

@@ -227,4 +229,11 @@ def select_experts(
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
)

# EP fake balance: replace the router-selected experts with an evenly balanced assignment
if get_env_start_args().enable_ep_fake_balance:
num_tokens, num_experts = router_logits.shape
balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(num_tokens)
topk_ids.copy_(balance_topk_ids)

return topk_weights, topk_ids
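For reference, a minimal standalone sketch (not part of this diff) of what the override above does when --enable_ep_fake_balance is set: the routing weights from the router are kept, but the expert ids are replaced in place by a balanced assignment. The shapes, the CUDA device, and the direct torch.topk call are illustrative assumptions.

```python
import torch
from lightllm.utils.balance_utils import BalancedTensor

num_tokens, num_experts, top_k = 16, 256, 8
router_logits = torch.randn(num_tokens, num_experts, device="cuda")

# What the router picked (illustrative; select_experts computes this differently).
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
topk_ids = topk_ids.to(torch.int)

# With fake balance enabled, the expert ids are overwritten with a balanced assignment,
# while topk_weights are left untouched.
balanced = BalancedTensor(num_experts=num_experts, num_selected=top_k)
topk_ids.copy_(balanced.get_balance_topk_ids(num_tokens))
```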
3 changes: 3 additions & 0 deletions lightllm/server/api_cli.py
@@ -333,6 +333,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)

parser.add_argument("--enable_ep_fake_balance", action="store_true", help="Enable the fake balance of the EP mode")

parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

parser.add_argument(
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
enable_monitor_auth: bool = field(default=False)
enable_ep_fake_balance: bool = field(default=False)
disable_cudagraph: bool = field(default=False)
graph_max_batch_size: int = field(default=256)
graph_split_batch_size: int = field(default=32)
67 changes: 67 additions & 0 deletions lightllm/utils/balance_utils.py
@@ -0,0 +1,67 @@
import torch
import os

import threading

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def singleton_threadsafe(cls):
instances = {}
lock = threading.Lock()

def get_instance(*args, **kwargs):
# A key that includes the arguments is needed for parameter-dependent singletons.
# Using a tuple of args and a frozenset of kwargs items makes it hashable.
key = (cls, args, frozenset(kwargs.items()))
with lock:
if key not in instances:
instances[key] = cls(*args, **kwargs)
return instances[key]

return get_instance


@singleton_threadsafe
class BalancedTensor:
def __init__(self, num_experts=256, num_selected=8):
self.balanced_tensors = {}
self.num_experts = num_experts
self.num_selected = num_selected

def generate_balanced_tensor(self, num_tokens):
# Evenly distribute num_tokens across the num_experts experts, picking num_selected experts per token.
# The num_selected experts chosen for a single token must all be distinct.
tensor = torch.empty((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")

for i in range(num_tokens):
selected_mask = torch.zeros(self.num_experts, dtype=torch.bool, device="cuda")
for j in range(self.num_selected):
# Use a large value for already selected experts to exclude them
load_view = torch.where(selected_mask, torch.iinfo(expert_load.dtype).max, expert_load)

min_load_indices = torch.where(load_view == load_view.min())[0]

if len(min_load_indices) > 1:
# If there are multiple least-loaded experts, select one randomly
rand_idx = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
chosen_expert = min_load_indices[rand_idx]
else:
chosen_expert = min_load_indices[0]

tensor[i, j] = chosen_expert
expert_load[chosen_expert] += 1
selected_mask[chosen_expert] = True

return tensor

def get_balance_topk_ids(self, num_tokens):
if num_tokens in self.balanced_tensors:
return self.balanced_tensors[num_tokens]

tensor = self.generate_balanced_tensor(num_tokens)
self.balanced_tensors[num_tokens] = tensor
return tensor
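A hedged usage sketch of the new helper (not part of this diff): the greedy least-loaded loop keeps per-expert loads within one of each other, and since 64 * 8 is an exact multiple of 256, every expert ends up with exactly 2 assignments; the @singleton_threadsafe decorator makes repeated constructions with the same arguments share the per-num_tokens cache. The concrete numbers are illustrative and a CUDA device is assumed.

```python
import torch
from lightllm.utils.balance_utils import BalancedTensor

bt = BalancedTensor(num_experts=256, num_selected=8)
ids = bt.get_balance_topk_ids(num_tokens=64)  # shape (64, 8), cached per num_tokens

# Every expert receives the same load: 64 * 8 / 256 = 2 assignments each.
load = torch.bincount(ids.flatten().long(), minlength=256)
print(load.min().item(), load.max().item())  # 2 2

# Same constructor arguments return the same (cached) instance.
assert BalancedTensor(num_experts=256, num_selected=8) is bt
```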