diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
index f7a24ae0f..cc925525c 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
@@ -84,16 +84,18 @@ def __init__(
         self.e_score_correction_bias = None
         self.w2_list = [None] * ep_load_expert_num
         self.w2_scale_list = [None] * ep_load_expert_num
-        self.scoring_func = network_config["scoring_func"]
+        self.scoring_func = network_config.get("scoring_func", "softmax")
         self.w1 = [None, None]  # weight, weight_scale
         self.w2 = [None, None]  # weight, weight_scale
         self.use_fp8_w8a8 = self.quant_method is not None
-
+        network_config["n_group"] = network_config.get("n_group", 0)
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.use_grouped_topk = network_config["n_group"] > 0
         self.norm_topk_prob = network_config["norm_topk_prob"]
         self.n_group = network_config["n_group"]
+        network_config["topk_group"] = network_config.get("topk_group", 0)
         self.topk_group = network_config["topk_group"]
+        network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 0)
         self.routed_scaling_factor = network_config["routed_scaling_factor"]
         self.lock = threading.Lock()

diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py
index 622a9711c..8d14805ad 100644
--- a/lightllm/common/quantization/deepgemm_quant.py
+++ b/lightllm/common/quantization/deepgemm_quant.py
@@ -41,11 +41,16 @@ def __init__(self):
         self.act_scale_suffix = None  # no support for static input tensor scale for ds model.

     def quantize(self, weight: torch.Tensor):
-        raise Exception("Not implemented")
+        from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant
+        return weight_quant(weight, self.block_size)

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        qweight, weight_scale, input_scale = weights
+        if len(weights) == 3:
+            qweight, weight_scale, input_scale = weights
+        else:
+            qweight, weight_scale = weights
+            input_scale = None
         m, k = input_tensor.shape
         n = weights[0].shape[1]
         if input_scale is None:

diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
new file mode 100644
index 000000000..11c1897d7
--- /dev/null
+++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -0,0 +1,58 @@
+import torch
+import triton
+import triton.language as tl
+from lightllm.utils.dist_utils import get_current_device_id
+
+
+@triton.jit
+def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n_blocks = tl.cdiv(N, BLOCK_SIZE)
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
+
+    amax = tl.max(tl.abs(x))
+
+    max_fp8e4m3_val = 448.0
+    scale = amax / max_fp8e4m3_val
+    y = (x / (scale + 1e-6)).to(y_ptr.dtype.element_ty)
+
+    tl.store(y_ptr + offs, y, mask=mask)
+    tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)
+
+
+def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    M, N = x.size()
+
+    y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
+
+    num_blocks_m = triton.cdiv(M, block_size)
+    num_blocks_n = triton.cdiv(N, block_size)
+    s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
+    weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
+    return y_quant, s_scales
+
+
+def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    x = x.cuda(get_current_device_id())
+    if x.dim() == 3:
+        y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
+        num_blocks_m = triton.cdiv(x.shape[1], block_size)
+        num_blocks_n = triton.cdiv(x.shape[2], block_size)
+        s_scales = torch.empty((x.shape[0], num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+        for i in range(x.shape[0]):
+            y_quant[i], s_scales[i] = mm_weight_quant(x[i], block_size)
+        return y_quant, s_scales
+    else:
+        y_quant, s_scales = mm_weight_quant(x, block_size)
+        return y_quant.t(), s_scales.t()

diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 57d10bdcd..2e01bc6e4 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -105,8 +105,6 @@ def _moe_ffn_edp(
         hidden_states = input
         token_num, hidden_dim = hidden_states.shape

-        if self.n_shared_experts is not None:
-            shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)

         router_logits = layer_weight.moe_gate.mm(hidden_states)
         ep_output = layer_weight.experts.experts(
@@ -114,13 +112,11 @@
             hidden_states,
             router_logits=router_logits,
             top_k=self.num_experts_per_tok,
             renormalize=self.norm_topk_prob,
-            use_grouped_topk=self.n_group,
-            topk_group=self.topk_group,
-            num_expert_group=self.n_group,
+            use_grouped_topk=False,
+            topk_group=None,
+            num_expert_group=None,
             is_prefill=infer_state.is_prefill,
         )
-        if self.n_shared_experts is not None:
-            ep_output.add_(shared_output)
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output

diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index b3421a325..10a505127 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -5,6 +5,7 @@
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3.model import Qwen3TpPartModel
 from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager

 logger = init_logger(__name__)

@@ -21,3 +22,7 @@ class Qwen3MOEModel(Qwen3TpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _init_custom(self):
+        super()._init_custom()
+        dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
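
Note (not part of the diff): the `.get(...)` defaults in `fused_moe_weight_ep.py` exist because Qwen3-MoE-style configs omit the grouped-top-k fields that DeepSeek-style configs carry, so direct indexing raised `KeyError` in `FusedMoeWeightEP.__init__`. A minimal sketch of the pattern, using a hypothetical config dict:

```python
# Hypothetical Qwen3-MoE-style config: no "n_group", "topk_group",
# or "routed_scaling_factor" keys (illustration only, not from the PR).
network_config = {"num_experts_per_tok": 8, "norm_topk_prob": True}

# Backfill a default before reading, as the patched __init__ does.
network_config["n_group"] = network_config.get("n_group", 0)
use_grouped_topk = network_config["n_group"] > 0
print(use_grouped_topk)  # False -> plain (ungrouped) top-k routing
```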
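
Note: a minimal usage sketch of the new `weight_quant` helper, assuming a CUDA device with `float8_e4m3fn` support and a context where `get_current_device_id()` resolves (single GPU here). Shapes are worked out for the default 128x128 block:

```python
import torch
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

w = torch.randn(512, 1024, dtype=torch.bfloat16)
qw, scales = weight_quant(w, block_size=128)  # helper moves w to the current device

# 2D weights come back transposed: qw is (1024, 512) float8_e4m3fn and
# scales is (8, 4) float32 -- one scale per 128x128 block of the original.
print(qw.shape, qw.dtype, scales.shape, scales.dtype)

# 3D (per-expert) weights keep their leading dim and are not transposed:
w3 = torch.randn(4, 256, 512, dtype=torch.bfloat16)
qw3, scales3 = weight_quant(w3)  # qw3: (4, 256, 512), scales3: (4, 2, 4)
```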
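
Note: a rough round-trip check for the block kernel (an illustration, not from the PR): dequantize with the returned per-block scales and compare against the original. The small residual comes from fp8 rounding plus the `1e-6` epsilon the kernel adds to the divisor but does not store.

```python
import torch
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import mm_weight_quant

B = 128
w = torch.randn(256, 384, dtype=torch.float32, device="cuda")
qw, s = mm_weight_quant(w, block_size=B)  # qw: (256, 384) fp8, s: (2, 3) fp32

# Broadcast each per-block scale back over its 128x128 tile and dequantize
# (dims here are exact multiples of the block size, so no edge handling).
s_full = s.repeat_interleave(B, dim=0).repeat_interleave(B, dim=1)  # (256, 384)
w_hat = qw.to(torch.float32) * s_full
print((w - w_hat).abs().max())  # small quantization error expected
```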