
[quant] deepgemm-fp8w8a8-b128 quantize #952


Merged (8 commits) on Jul 2, 2025
Changes from all commits
@@ -84,16 +84,18 @@ def __init__(
         self.e_score_correction_bias = None
         self.w2_list = [None] * ep_load_expert_num
         self.w2_scale_list = [None] * ep_load_expert_num
-        self.scoring_func = network_config["scoring_func"]
+        self.scoring_func = network_config.get("scoring_func", "softmax")
         self.w1 = [None, None]  # weight, weight_scale
         self.w2 = [None, None]  # weight, weight_scale
         self.use_fp8_w8a8 = self.quant_method is not None

+        network_config["n_group"] = network_config.get("n_group", 0)
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.use_grouped_topk = network_config["n_group"] > 0
         self.norm_topk_prob = network_config["norm_topk_prob"]
         self.n_group = network_config["n_group"]
+        network_config["topk_group"] = network_config.get("topk_group", 0)
         self.topk_group = network_config["topk_group"]
+        network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 0)
         self.routed_scaling_factor = network_config["routed_scaling_factor"]

         self.lock = threading.Lock()
lightllm/common/quantization/deepgemm_quant.py (9 changes: 7 additions & 2 deletions)
@@ -41,11 +41,16 @@ def __init__(self):
         self.act_scale_suffix = None  # no support for static input tensor scale for ds model.

     def quantize(self, weight: torch.Tensor):
+        from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

Review comment (medium): Consider moving this import to the top of the file for better code organization and to avoid the potential performance implications of local imports.
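A minimal sketch of the suggested change, assuming the module-level import does not create a circular dependency (a common reason such imports stay local). `DeepGemmQuantMethodSketch` is a hypothetical stand-in for the actual quant-method class in deepgemm_quant.py:

```python
import torch

# Resolve the Triton kernel import once at module load time instead of on
# every quantize() call (the reviewer's suggestion). Assumes no circular import.
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant


class DeepGemmQuantMethodSketch:  # hypothetical stand-in for the class in deepgemm_quant.py
    def __init__(self, block_size: int = 128):
        self.block_size = block_size

    def quantize(self, weight: torch.Tensor):
        # Same behavior as the diff above, without the per-call local import.
        return weight_quant(weight, self.block_size)
```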


-        raise Exception("Not implemented")
+        return weight_quant(weight, self.block_size)

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        qweight, weight_scale, input_scale = weights
+        if len(weights) == 3:
+            qweight, weight_scale, input_scale = weights
+        else:
+            qweight, weight_scale = weights
+            input_scale = None
         m, k = input_tensor.shape
         n = weights[0].shape[1]
         if input_scale is None:
New file: lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py
@@ -0,0 +1,58 @@
+import torch
+import triton
+import triton.language as tl
+from lightllm.utils.dist_utils import get_current_device_id
+
+
+@triton.jit
+def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n_blocks = tl.cdiv(N, BLOCK_SIZE)
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
+
+    amax = tl.max(tl.abs(x))
+
+    max_fp8e4m3_val = 448.0
+    scale = amax / max_fp8e4m3_val
+    y = (x / (scale + 1e-6)).to(y_ptr.dtype.element_ty)
+
+    tl.store(y_ptr + offs, y, mask=mask)
+    tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)
+
+
+def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    M, N = x.size()
+
+    y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
+
+    num_blocks_m = triton.cdiv(M, block_size)
+    num_blocks_n = triton.cdiv(N, block_size)
+    s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
+    weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
+    return y_quant, s_scales
+
+
+def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    x = x.cuda(get_current_device_id())
+    if x.dim() == 3:
+        y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
+        num_blocks_m = triton.cdiv(x.shape[1], block_size)
+        num_blocks_n = triton.cdiv(x.shape[2], block_size)
+        s_scales = torch.empty((x.shape[0], num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
+        for i in range(x.shape[0]):
+            y_quant[i], s_scales[i] = mm_weight_quant(x[i], block_size)
+        return y_quant, s_scales
+    else:
+        y_quant, s_scales = mm_weight_quant(x, block_size)
+        return y_quant.t(), s_scales.t()

Review comment (medium): The function returns transposed tensors for 2D inputs but not for 3D inputs. Add a comment explaining why the transpose is necessary for 2D tensors but not for 3D tensors.
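A small usage sketch (shapes follow from the code above; the tensors and sizes are made up for illustration) of the asymmetry the comment points out: the 2D path returns transposed outputs, while the 3D per-expert path keeps the input layout.

```python
import torch
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

# 2D weight of shape (M, N) = (1024, 512): the result is transposed, so the
# quantized tensor is (N, M) in fp8 and the scales are (N/128, M/128).
w2d = torch.randn(1024, 512, dtype=torch.bfloat16, device="cuda")
q2d, s2d = weight_quant(w2d, block_size=128)
print(q2d.shape, s2d.shape)  # torch.Size([512, 1024]) torch.Size([4, 8])

# 3D expert weights (num_experts, M, N) = (8, 1024, 512): each expert is
# quantized separately and the original layout is preserved.
w3d = torch.randn(8, 1024, 512, dtype=torch.bfloat16, device="cuda")
q3d, s3d = weight_quant(w3d, block_size=128)
print(q3d.shape, s3d.shape)  # torch.Size([8, 1024, 512]) torch.Size([8, 8, 4])
```

Presumably the 2D transpose turns an (out_features, in_features) weight into the (k, n) layout that `apply()` above expects when it reads `n = weights[0].shape[1]`, while the 3D expert weights are consumed by the fused-MoE path in their original layout.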

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py (10 changes: 3 additions & 7 deletions)
@@ -105,22 +105,18 @@ def _moe_ffn_edp(

         hidden_states = input
         token_num, hidden_dim = hidden_states.shape
-        if self.n_shared_experts is not None:
-            shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)

         router_logits = layer_weight.moe_gate.mm(hidden_states)
         ep_output = layer_weight.experts.experts(
             hidden_states,
             router_logits=router_logits,
             top_k=self.num_experts_per_tok,
             renormalize=self.norm_topk_prob,
-            use_grouped_topk=self.n_group,
-            topk_group=self.topk_group,
-            num_expert_group=self.n_group,
+            use_grouped_topk=False,
+            topk_group=None,
+            num_expert_group=None,
             is_prefill=infer_state.is_prefill,
         )
-        if self.n_shared_experts is not None:
-            ep_output.add_(shared_output)

         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
lightllm/models/qwen3_moe/model.py (5 changes: 5 additions & 0 deletions)
@@ -5,6 +5,7 @@
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3.model import Qwen3TpPartModel
 from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager


 logger = init_logger(__name__)
@@ -21,3 +22,7 @@ class Qwen3MOEModel(Qwen3TpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _init_custom(self):
+        super()._init_custom()
+        dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])