From 31c0729d58f24219d0cd2ef0842a8bab3c9e7a99 Mon Sep 17 00:00:00 2001 From: baishihao Date: Fri, 13 Jun 2025 20:22:53 +0800 Subject: [PATCH 01/14] speedup load --- .../meta_weights/fused_moe_weight_tp.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 131e65f54..abb0985b9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -93,16 +93,18 @@ def _fuse(self): and None not in self.experts_gate_projs and None not in self.w2_list ): - w1_list = [] + gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape + up_out_dim, up_in_dim = self.experts_up_projs[0].shape + assert gate_in_dim == up_in_dim + dtype = self.experts_gate_projs[0].dtype + total_expert_num = self.n_routed_experts + + w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") + for i_experts in range(self.n_routed_experts): - expert_gate_up_proj = torch.cat( - [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0 - ) - expert_gate_up_proj = expert_gate_up_proj - w1_list.append(expert_gate_up_proj) - - inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1] - w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size) + w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] + w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] + inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self.quantized_weight and self.quant_method is not None: @@ -123,17 +125,19 @@ def _fuse_weight_scale(self): and None not in self.experts_gate_proj_scales and None not in self.w2_scale_list ): - w1_scale_list = [] - for i_experts in range(self.n_routed_experts): - expert_gate_up_proj_scale = torch.cat( - [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0 - ) - w1_scale_list.append(expert_gate_up_proj_scale) - - inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1] - w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view( - len(w1_scale_list), inter_shape, hidden_size + gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape + up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape + assert gate_in_dim == up_in_dim + dtype = self.experts_gate_proj_scales[0].dtype + total_expert_num = self.n_routed_experts + + w1_scale = torch.empty( + (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" ) + + for i_experts in range(self.n_routed_experts): + w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] + w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( len(self.w2_scale_list), inter_shape, hidden_size From a440f2d2d318a4c219b619f4487787b799c416f6 Mon Sep 17 00:00:00 2001 From: baishihao Date: Fri, 13 Jun 2025 20:26:54 +0800 Subject: [PATCH 02/14] speedup kernel --- .../common/fused_moe/grouped_fused_moe.py | 117 
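The load-time win in PATCH 01/14 comes from preallocating the fused `w1` buffer once and slice-assigning each expert's gate and up projection into it, instead of materializing a per-expert `torch.cat` result and flattening the list afterwards. A minimal sketch of the same pattern, with illustrative shapes rather than any real model config:

```python
import torch

def fuse_gate_up(gate_projs, up_projs):
    # gate_projs / up_projs: per-expert [out_dim, in_dim] weights.
    gate_out, in_dim = gate_projs[0].shape
    up_out, up_in = up_projs[0].shape
    assert in_dim == up_in
    n_experts = len(gate_projs)
    # One allocation + two copies per expert, instead of cat + flatten + view.
    w1 = torch.empty((n_experts, gate_out + up_out, in_dim),
                     dtype=gate_projs[0].dtype, device="cpu")
    for e in range(n_experts):
        w1[e, :gate_out, :] = gate_projs[e]
        w1[e, gate_out:, :] = up_projs[e]
    return w1

if __name__ == "__main__":
    gates = [torch.randn(192, 512) for _ in range(4)]
    ups = [torch.randn(192, 512) for _ in range(4)]
    print(fuse_gate_up(gates, ups).shape)  # torch.Size([4, 384, 512])
```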
+++++++++++------- 1 file changed, 70 insertions(+), 47 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index c1e239bef..e31e50922 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -34,6 +34,7 @@ from .moe_silu_and_mul import silu_and_mul_fwd from .moe_sum_reduce import moe_sum_reduce from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.utils.dist_utils import get_current_rank_in_dp FFN_MOE_CHUNK_SIZE = 8 * 1024 @@ -220,8 +221,13 @@ def moe_align1( @triton.jit def moe_align2_kernel( experts_token_num_ptr, # [expert_num,] - mblocks_to_expert_id, # [max_num_m_blocks,] - mblocks_to_m_index, # [max_num_m_blocks,] + expert_to_token_index_ptr, # [expert_num, token_num * topk_num] + expert_to_token_index_stride_0, + expert_to_weights_ptr, + expert_to_weights_stride_0, + mblocks_to_expert_id_ptr, # [max_num_m_blocks,] + padded_expert_to_token_index_ptr, + padded_expert_to_weights_ptr, expert_num, max_num_m_blocks, BLOCK_M: tl.constexpr, @@ -241,27 +247,49 @@ def moe_align2_kernel( block_off = tl.arange(0, 128) for start_loc in range(0, cur_block_num, 128): tl.store( - mblocks_to_expert_id + block_start + start_loc + block_off, + mblocks_to_expert_id_ptr + block_start + start_loc + block_off, expert_id, mask=start_loc + block_off < cur_block_num, ) + + cur_expert_to_token_index_ptr = expert_to_token_index_ptr + expert_id * expert_to_token_index_stride_0 + for start_loc in range(0, cur_block_num): + offset = start_loc * BLOCK_M + tl.arange(0, BLOCK_M) + m_index = tl.load(cur_expert_to_token_index_ptr + offset, mask=offset < cur_expert_token_num, other=0) tl.store( - mblocks_to_m_index + block_start + start_loc + block_off, - start_loc + block_off, - mask=start_loc + block_off < cur_block_num, + padded_expert_to_token_index_ptr + block_start * BLOCK_M + offset, + m_index, + mask=offset < cur_expert_token_num, + ) + + m_weight = tl.load( + expert_to_weights_ptr + expert_id * expert_to_weights_stride_0 + offset, + mask=offset < cur_expert_token_num, + other=0.0, + ) + tl.store( + padded_expert_to_weights_ptr + block_start * BLOCK_M + offset, + m_weight, + mask=offset < cur_expert_token_num, ) if expert_id == expert_num - 1: for extra_fill_start in range(block_start + cur_block_num, max_num_m_blocks, 128): tl.store( - mblocks_to_expert_id + extra_fill_start + block_off, + mblocks_to_expert_id_ptr + extra_fill_start + block_off, -1, mask=extra_fill_start + block_off < max_num_m_blocks, ) return -def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, block_m: int): +def moe_align2( + token_num_mul_topk_num: int, + exports_token_num: torch.Tensor, + block_m: int, + expert_to_token_index: torch.Tensor, + expert_to_weights: torch.Tensor, +): """ exports_token_num is tensor shape [expert_num] , will get expert need handle token num. out tensor is a tensor that contain block schduel infos tensor. 
@@ -269,14 +297,20 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m) mblocks_to_expert_id = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") - mblocks_to_m_index = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") + padded_expert_to_token_index = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda").fill_(-1) + padded_expert_to_weights = torch.empty(max_num_tokens_padded, dtype=torch.float32, device="cuda") expert_num = exports_token_num.shape[0] grid = (expert_num,) moe_align2_kernel[grid]( exports_token_num, + expert_to_token_index, + expert_to_token_index.stride(0), + expert_to_weights, + expert_to_weights.stride(0), mblocks_to_expert_id, - mblocks_to_m_index, + padded_expert_to_token_index, + padded_expert_to_weights, expert_num, max_num_m_blocks, BLOCK_M=block_m, @@ -285,13 +319,14 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo num_stages=1, ) - return mblocks_to_expert_id, mblocks_to_m_index + return mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights @triton.jit def grouped_matmul_kernel( mblocks_to_expert_id, # [max_m_block_size] - mblocks_to_m_index, # [max_m_block_size] + padded_expert_to_token_index, # [max_m_block_size] + padded_expert_to_weights, # [max_m_block_size] k, # int n, # int topk_num, # int @@ -307,12 +342,7 @@ def grouped_matmul_kernel( weight_stride_0, weight_stride_1, weight_stride_2, - expert_to_weights_ptr, # [expert_num, token_num * topk] - expert_to_weights_stride0, - expert_to_weights_stride1, expert_to_token_num, # [expert_num] - expert_to_token_index, # [expert_num, token_num * topk_num] - expert_to_token_index_stride_0, out_ptr, # [token_num * topk_num, n] out_stride_0, out_stride_1, @@ -350,28 +380,14 @@ def grouped_matmul_kernel( if expert_id == -1: return - - tile_m_idx = tl.load(mblocks_to_m_index + pid_m) tile_n_idx = pid_n - - # get the gemm size of the current problem - cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last") - # do regular gemm here - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - token_mask = offs_am < cur_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + # token_mask = offs_am < cur_m a_m_index = tl.load( - expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am, - mask=token_mask, - other=0, + padded_expert_to_token_index + offs_am, ) - if MUL_ROUTED_WEIGHT: - a_m_scale = tl.load( - expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am, - mask=token_mask, - other=0.0, - ) - + token_mask = a_m_index != -1 offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -437,6 +453,11 @@ def grouped_matmul_kernel( accumulator *= ab_scale if MUL_ROUTED_WEIGHT: + a_m_scale = tl.load( + padded_expert_to_weights + offs_am, + mask=token_mask, + other=0.0, + ) accumulator *= a_m_scale[:, None] c = accumulator.to(compute_type) @@ -530,16 +551,22 @@ def grouped_matmul( token_inputs, token_input_scale = qinput_tensor, input_scale if reused_mblock_infos is None: - mblocks_to_expert_id, mblocks_to_m_index = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M) + mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2( + token_num_mul_topk_num, 
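For reference, the layout written by the reworked `moe_align2` can be reproduced on CPU as below: m-blocks that belong to no expert keep expert id -1, and the padded token-index buffer is pre-filled with -1 so the group GEMM kernel derives its token mask from `index != -1` instead of re-reading `expert_to_token_num` and the per-expert index table. This is a sketch for checking the layout only, not the production path:

```python
import torch

def moe_align2_ref(token_num_mul_topk_num, expert_to_token_num,
                   expert_to_token_index, expert_to_weights, block_m):
    expert_num = expert_to_token_num.shape[0]
    max_padded = token_num_mul_topk_num + expert_num * (block_m - 1)
    max_blocks = (max_padded + block_m - 1) // block_m
    mblocks_to_expert_id = torch.full((max_blocks,), -1, dtype=torch.int32)
    # Rounded up to whole blocks so every m-block owns a full BLOCK_M slot range.
    padded_index = torch.full((max_blocks * block_m,), -1, dtype=torch.int32)
    padded_weights = torch.zeros((max_blocks * block_m,), dtype=torch.float32)
    block_start = 0
    for e in range(expert_num):
        n = int(expert_to_token_num[e])
        n_blocks = (n + block_m - 1) // block_m
        mblocks_to_expert_id[block_start:block_start + n_blocks] = e
        base = block_start * block_m
        padded_index[base:base + n] = expert_to_token_index[e, :n]
        padded_weights[base:base + n] = expert_to_weights[e, :n].float()
        block_start += n_blocks
    return mblocks_to_expert_id, padded_index, padded_weights
```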
expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights + ) else: # when up group gemm and down group gemm use same BLOCK_SIZE_M, # can reuse (mblocks_to_expert_id, mblocks_to_m_index) created by moe_align2 kernel. - mblocks_to_expert_id, mblocks_to_m_index, reused_block_size_m = reused_mblock_infos + ( + mblocks_to_expert_id, + padded_expert_to_token_index, + padded_expert_to_weights, + reused_block_size_m, + ) = reused_mblock_infos if reused_block_size_m != BLOCK_SIZE_M: - mblocks_to_expert_id, mblocks_to_m_index = moe_align2( - token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M + mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2( + token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights ) - block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0] grid = (block_num,) @@ -548,7 +575,8 @@ def grouped_matmul( grouped_matmul_kernel[grid]( mblocks_to_expert_id, - mblocks_to_m_index, + padded_expert_to_token_index, + padded_expert_to_weights, k, n, topk_num, @@ -570,12 +598,7 @@ def grouped_matmul( expert_weights.stride(0), expert_weights.stride(1), expert_weights.stride(2), - expert_to_weights, - expert_to_weights.stride(0), - expert_to_weights.stride(1), expert_to_token_num, - expert_to_token_index, - expert_to_token_index.stride(0), out, out.stride(0), out.stride(1), @@ -594,7 +617,7 @@ def grouped_matmul( num_warps=num_warps, num_stages=num_stages, ) - return (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M) + return (mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights, BLOCK_SIZE_M) def fused_experts_impl( From bf33cdbf1a8d8ae6a00b5f2d835042dd1a131260 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sun, 15 Jun 2025 21:53:57 +0800 Subject: [PATCH 03/14] update tuning --- test/kernel/fuse_moe_tuning_fp8.py | 51 +++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/test/kernel/fuse_moe_tuning_fp8.py b/test/kernel/fuse_moe_tuning_fp8.py index a30de8d03..2e9fb88d8 100644 --- a/test/kernel/fuse_moe_tuning_fp8.py +++ b/test/kernel/fuse_moe_tuning_fp8.py @@ -58,14 +58,37 @@ def test_kernel( test_count: int, use_fp8_w8a8: bool, is_up: bool, + block_shape, **config, ): set_seed() input_tuples = [] a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((expert_num, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((expert_num, k, n), device="cuda", dtype=dtype) / 10 + w1_scale = w2_scale = None + + if use_fp8_w8a8: + init_dtype = dtype + w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda() + w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda() + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + if block_shape is None: + w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda() + w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda() + else: + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (2 * n + block_n - 1) // block_n + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k + w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda() + w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda() + else: + w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda() + w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda() + rnd_logics = 
torch.randn(m, expert_num, device="cuda") topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1) topk_weights = torch.randn((m, topk), device="cuda", dtype=dtype) / 10 @@ -75,12 +98,6 @@ def test_kernel( moe_align(topk_ids=topk_ids, out=expert_to_tokens) expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk) - if use_fp8_w8a8: - w1, w1_scale = quantize_moe(w1) - w2, w2_scale = quantize_moe(w2) - else: - w1_scale = torch.empty((0,)) - w2_scale = torch.empty((0,)) out1 = torch.zeros((m * topk, 2 * n), dtype=torch.bfloat16, device="cuda") down_in = torch.zeros((m * topk, n), dtype=torch.bfloat16, device="cuda") @@ -142,6 +159,7 @@ def test_kernel( a, w1, w2, w1_scale, w2_scale, topk_ids, topk_weights, out1, out2, down_in = input_tuples[index] if is_up: grouped_matmul( + topk_ids.numel(), a, None, expert_to_token_num, @@ -158,6 +176,7 @@ def test_kernel( ) else: grouped_matmul( + topk_ids.numel(), down_in, None, expert_to_token_num, @@ -197,6 +216,7 @@ def worker( test_count: int, use_fp8_w8a8: bool, is_up: bool, + block_shape, test_configs, queue, ): @@ -212,6 +232,7 @@ def worker( test_count=test_count, use_fp8_w8a8=use_fp8_w8a8, is_up=is_up, + block_shape=block_shape, **test_configs[index], ) queue.put(cost_time) # Put result in queue @@ -278,6 +299,7 @@ def tuning_configs( test_count: int, use_fp8_w8a8: bool, is_up: bool, + block_shape, ): os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) best_config, best_cost_time = None, 10000000 @@ -300,6 +322,7 @@ def tuning_configs( test_count, use_fp8_w8a8, is_up, + block_shape, test_configs, queue, ), @@ -333,6 +356,7 @@ def tuning_configs( test_count, use_fp8_w8a8, is_up, + block_shape, test_configs, queue, ), @@ -364,10 +388,11 @@ def tuning_configs( from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig # tuning to get deepseekv2 large configs and store in H800, tp 8 - expert_num = 160 - n = 192 # up is n * 2 - hidden_dim = 5120 - topk_num = 6 + expert_num = 256 + n = 256 # up is n * 2 + hidden_dim = 7168 + topk_num = 8 + block_shape = [128, 128] up_dict = {} for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]: @@ -383,6 +408,7 @@ def tuning_configs( "test_count": 20, "use_fp8_w8a8": True, "is_up": True, + "block_shape": block_shape, }, ) up_dict[m] = ans @@ -411,6 +437,7 @@ def tuning_configs( "test_count": 20, "use_fp8_w8a8": True, "is_up": False, + "block_shape": block_shape, }, ) down_dict[m] = ans From eec0be4310935fbccc38c9e74e297174d28a3fec Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sun, 15 Jun 2025 22:53:39 +0800 Subject: [PATCH 04/14] fix --- test/kernel/fuse_moe_tuning_bf16.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/kernel/fuse_moe_tuning_bf16.py b/test/kernel/fuse_moe_tuning_bf16.py index 712f2ab29..601f65685 100644 --- a/test/kernel/fuse_moe_tuning_bf16.py +++ b/test/kernel/fuse_moe_tuning_bf16.py @@ -142,6 +142,7 @@ def test_kernel( a, w1, w2, w1_scale, w2_scale, topk_ids, topk_weights, out1, out2, down_in = input_tuples[index] if is_up: grouped_matmul( + topk_ids.numel(), a, None, expert_to_token_num, @@ -158,6 +159,7 @@ def test_kernel( ) else: grouped_matmul( + topk_ids.numel(), down_in, None, expert_to_token_num, From 7c04f4871713a954740ce64869d4e8a08f891149 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 16 Jun 2025 14:01:46 +0800 Subject: [PATCH 05/14] add tuning config --- ..._num=1,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + 
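The retuned shapes above target a DeepSeek-V3-style MoE (256 experts, topk 8, hidden 7168) with block-wise fp8 quantization, where each 128x128 weight tile carries one scale. The scale shapes in the tuning script follow directly from the tile counts; a small sketch with those example dimensions:

```python
import torch

def blockwise_scale_shapes(expert_num, n2, k, block_shape):
    # n2 = 2 * n is the fused gate+up output dim, k is the hidden size.
    block_n, block_k = block_shape
    n_tiles_w1 = (n2 + block_n - 1) // block_n
    k_tiles_w1 = (k + block_k - 1) // block_k
    n_tiles_w2 = (k + block_n - 1) // block_n
    k_tiles_w2 = (n2 // 2 + block_k - 1) // block_k
    w1_scale = torch.rand(expert_num, n_tiles_w1, k_tiles_w1, dtype=torch.float32)
    w2_scale = torch.rand(expert_num, n_tiles_w2, k_tiles_w2, dtype=torch.float32)
    return w1_scale.shape, w2_scale.shape

print(blockwise_scale_shapes(expert_num=256, n2=512, k=7168, block_shape=[128, 128]))
# (torch.Size([256, 4, 56]), torch.Size([256, 56, 2]))
```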
..._num=1,use_fp8_w8a8=true}_NVIDIA_H200.json | 2 +- ..._num=1,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + ..._num=8,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + ..._num=8,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + ..._num=8,use_fp8_w8a8=true}_NVIDIA_H200.json | 2 +- .../kernel/benchmark_fused_moe_triton.py | 330 ++++++++++++++++++ .../service}/benchmark_client.py | 0 test/{ => benchmark/service}/benchmark_mcq.py | 84 ++--- test/{ => benchmark/service}/benchmark_qps.py | 0 .../service/benchmark_sharegpt.py} | 105 +++--- 11 files changed, 416 insertions(+), 111 deletions(-) create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json create mode 100644 test/benchmark/kernel/benchmark_fused_moe_triton.py rename test/{ => benchmark/service}/benchmark_client.py (100%) rename test/{ => benchmark/service}/benchmark_mcq.py (68%) rename test/{ => benchmark/service}/benchmark_qps.py (100%) rename test/{benchmark_serving.py => benchmark/service/benchmark_sharegpt.py} (71%) diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..ea69378f4 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git 
a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json index 5c0dab42b..37ba845fa 100644 --- a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -1 +1 @@ -{"1": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}} \ No newline at end of file +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..fe56e1c44 --- /dev/null +++ 
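Each tuned JSON maps a batch-of-tokens bucket (the M dimension handled by the grouped GEMM) to a Triton launch config. A hypothetical consumer of such a table is sketched below; the actual bucket-selection rule lives in `MoeGroupedGemmKernelConfig.try_to_get_best_config` and may differ (e.g. nearest-not-smaller bucket rather than nearest):

```python
import json

def pick_config(config_json: str, m: int) -> dict:
    table = {int(key): cfg for key, cfg in json.loads(config_json).items()}
    nearest = min(table, key=lambda key: abs(key - m))  # illustrative rule only
    return table[nearest]

example = '{"1": {"BLOCK_SIZE_M": 16, "num_warps": 2}, "64": {"BLOCK_SIZE_M": 16, "num_warps": 4}, "1024": {"BLOCK_SIZE_M": 32, "num_warps": 4}}'
print(pick_config(example, m=200))  # {'BLOCK_SIZE_M': 16, 'num_warps': 4}
```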
b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..25333e743 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..bc763e8bc --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json index 394ce3193..a4f26860b 100644 --- a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json +++ 
b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -1 +1 @@ -{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 5}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4}} \ No newline at end of file +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/test/benchmark/kernel/benchmark_fused_moe_triton.py b/test/benchmark/kernel/benchmark_fused_moe_triton.py new file mode 100644 index 000000000..6f7a5ee39 --- /dev/null +++ b/test/benchmark/kernel/benchmark_fused_moe_triton.py @@ -0,0 +1,330 @@ +# Adapted from +# https://github.com/sgl-project/sglang/blob/v0.4.6.post5/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py +import argparse + +import torch +import triton +import vllm +from transformers import AutoConfig +from lightllm.common.fused_moe.topk_select import select_experts +from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl +from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe as fused_moe_sglang, +) + + +def get_model_config(model_name: str, tp_size: int): + """Get model configuration parameters""" + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * 
intermediate_size // tp_size + elif config.architectures[0] == "Qwen2MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "Qwen3MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] in [ + "Grok1ForCausalLM", + "Grok1ImgGen", + "Grok1AForCausalLM", + ]: + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + else: + # Default: Mixtral + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + + vllm_version_num = vllm.__version_tuple__[0] * 100 + vllm.__version_tuple__[1] * 10 + vllm.__version_tuple__[2] + block_shape = None + if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config: + block_shape = config.quantization_config["weight_block_size"] + assert len(block_shape) == 2 + assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1" + + shape_configs = { + "num_experts": E, + "topk": topk, + "hidden_size": config.hidden_size, + "shard_intermediate_size": shard_intermediate_size, + "dtype": config.torch_dtype, + "block_shape": block_shape, + } + print(f"{shape_configs=}") + return shape_configs + + +def fused_moe_lightllm_api( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, + block_shape=None, +): + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=input_gating, + correction_bias=None, + use_grouped_topk=False, + top_k=topk, + renormalize=True, + topk_group=None, + num_expert_group=None, + scoring_func="softmax", + ) + use_fp8_w8a8 = use_fp8_w8a8 + + return fused_experts_impl( + hidden_states=x, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + + +def fused_moe_vllm_api( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, + block_shape=None, +): + if block_shape is not None: + return fused_moe_vllm( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) + else: + return fused_moe_vllm( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + +def fused_moe_sglang_api( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, + block_shape=None, +): + return fused_moe_sglang( + x, + w1, + w2, + 
input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 8, 16, 32, 64, 128], + line_arg="provider", + line_vals=[ + "vllm_fused_moe_triton", + "sglang_fused_moe_triton", + "lightllm_fused_moe_triton", + ], + line_names=[ + "vllm_fused_moe_triton", + "sglang_fused_moe_triton", + "lightllm_fused_moe_triton", + ], + styles=[ + ("blue", "-"), + ("green", "-"), + ("red", "-"), + ], + ylabel="Time (ms)", + plot_name="fused-moe-performance", + args={}, + ) +) +def benchmark(batch_size, provider, model_config, use_fp8=False): + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + + num_tokens = batch_size + num_experts = model_config["num_experts"] + hidden_size = model_config["hidden_size"] + shard_intermediate_size = model_config["shard_intermediate_size"] + topk = model_config["topk"] + dtype = model_config["dtype"] + block_shape = getattr(model_config, "block_shape", None) + block_shape = [128, 128] + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + w1_scale = w2_scale = a1_scale = a2_scale = None + + if use_fp8: + init_dtype = dtype + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype) + w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype) + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + if block_shape is None: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + else: + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n + n_tiles_w2 = (hidden_size + block_n - 1) // block_n + k_tiles_w1 = (hidden_size + block_k - 1) // block_k + k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k + w1_scale = torch.rand((num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_scale = torch.rand((num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + else: + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype) + w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype) + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + # Warmup + api_func = ( + fused_moe_vllm_api + if provider == "vllm_fused_moe_triton" + else fused_moe_sglang_api + if provider == "lightllm_fused_moe_triton" + else fused_moe_lightllm_api + ) + for _ in range(10): + api_func( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) + torch.cuda.synchronize() + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: api_func( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + )[0], + quantiles=quantiles, + ) + return ms, min_ms, max_ms + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1") + parser.add_argument("--tp-size", type=int, default=8) + parser.add_argument("--use-fp8", 
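The timing path of the new benchmark relies on `triton.testing.do_bench` with `quantiles=[0.5, 0.2, 0.8]`, which returns the median, p20 and p80 run times in milliseconds. A minimal standalone example of that call (a plain matmul used as a stand-in workload):

```python
import torch
import triton

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

ms, p20_ms, p80_ms = triton.testing.do_bench(lambda: a @ b, quantiles=[0.5, 0.2, 0.8])
print(f"matmul: median={ms:.3f} ms, p20={p20_ms:.3f} ms, p80={p80_ms:.3f} ms")
```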
action="store_true") + parser.add_argument( + "--save-path", + type=str, + default="./configs/benchmark_ops/vllm_sglang_fused_moe/", + ) + args = parser.parse_args() + + model_config = get_model_config(args.model, args.tp_size) + benchmark.run( + show_plots=True, + print_data=True, + save_path=args.save_path, + model_config=model_config, + use_fp8=args.use_fp8, + ) + + +if __name__ == "__main__": + main() diff --git a/test/benchmark_client.py b/test/benchmark/service/benchmark_client.py similarity index 100% rename from test/benchmark_client.py rename to test/benchmark/service/benchmark_client.py diff --git a/test/benchmark_mcq.py b/test/benchmark/service/benchmark_mcq.py similarity index 68% rename from test/benchmark_mcq.py rename to test/benchmark/service/benchmark_mcq.py index 51cdee830..828a970cc 100644 --- a/test/benchmark_mcq.py +++ b/test/benchmark/service/benchmark_mcq.py @@ -26,13 +26,13 @@ import aiohttp import numpy as np -from transformers import PreTrainedTokenizerBase from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast QUESTION = {} + + def get_tokenizer( tokenizer_name: str, tokenizer_mode: str = "slow", @@ -42,25 +42,21 @@ def get_tokenizer( """Gets a tokenizer for the given model name via Huggingface.""" if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = True if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): pass try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, - **kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs) except TypeError as e: - err_msg = ( - "Failed to load the tokenizer. If you are using a LLaMA-based " - f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer.") + err_msg = "Failed to load the tokenizer. 
{e}" raise RuntimeError(err_msg) from e if not isinstance(tokenizer, PreTrainedTokenizerFast): pass return tokenizer + # (prompt len, output len, latency) REQUEST_LATENCY: List[Tuple[int, int, float]] = [] @@ -73,11 +69,10 @@ def sample_requests( data = [] with open(dataset_path, "r") as f: questions = f.readlines() - gts = {} for question in questions: question = json.loads(question.strip()) file_name = question["file_name"].split(".")[0] - data.append((file_name, question['question_id'], question['instruction'], question['answer'])) + data.append((file_name, question["question_id"], question["instruction"], question["answer"])) if file_name not in QUESTION: QUESTION[file_name] = {} QUESTION[file_name][question["question_id"]] = [question["answer"]] @@ -107,25 +102,22 @@ async def send_request( output_len: int, port: int, ) -> None: - request_start_time = time.time() - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} headers = {"User-Agent": "Benchmark Client"} - file_name, question_id, inputs, answer = request - prompt = f"<系统> <对话历史> <知识> <最新问题> 用户:给出以下问题的答案:\n{inputs} SenseChat:" - print(prompt) - # prompt= "[Round {}]\n\n问:{}\n\n答:".format(1, inputs) - url = f'http://localhost:{port}/generate' + file_name, question_id, inputs, answer = request + prompt = "[Round {}]\n\n问:{}\n\n答:".format(1, inputs) + url = f"http://localhost:{port}/generate" data = { - 'inputs': prompt, - 'parameters': { - 'do_sample': False, - 'ignore_eos': True, - 'max_new_tokens': output_len, - # 'do_sample':True, + "inputs": prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + # 'do_sample':True, # 'top_p':0.8, # 'temperature':0.8 - # 'temperature': 0.1, - } + # 'temperature': 0.1, + }, } timeout = aiohttp.ClientTimeout(total=3 * 3600) async with aiohttp.ClientSession(timeout=timeout) as session: @@ -140,6 +132,7 @@ async def send_request( if "error" not in output: break + async def benchmark( input_requests: List[Tuple[str, int, int]], request_rate: float, @@ -153,18 +146,18 @@ async def benchmark( def IsOpen(ip, port): - s = socket.socket(socket.AF_INET,socket.SOCK_STREAM) - index=1 + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - s.connect((ip,int(port))) + s.connect((ip, int(port))) s.shutdown(2) - print('successfully launch model') + print("successfully launch model") return True except: time.sleep(10) return False + def main(args: argparse.Namespace): print(args) random.seed(args.seed) @@ -172,7 +165,6 @@ def main(args: argparse.Namespace): tokenizer = get_tokenizer(args.tokenizer, "slow") input_requests = sample_requests(args.dataset, tokenizer) - benchmark_start_time = time.time() asyncio.run(benchmark(input_requests, args.request_rate, args.port)) rights, alls = 0, 0 for file_name in QUESTION: @@ -186,19 +178,19 @@ def main(args: argparse.Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark the online serving throughput.") - parser.add_argument("--dataset", type=str, required=True, - help="Path to the dataset.") - parser.add_argument("--tokenizer", type=str, required=True, - help="Name or path of the tokenizer.") - parser.add_argument("--request-rate", type=float, default=float("inf"), - help="Number of requests per second. If this is inf, " - "then all the requests are sent at time 0. 
" - "Otherwise, we use Poisson process to synthesize " - "the request arrival times.") - parser.add_argument("--port", type=int, default=8000, - help="port number") + parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") + parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize " + "the request arrival times.", + ) + parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() main(args) diff --git a/test/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py similarity index 100% rename from test/benchmark_qps.py rename to test/benchmark/service/benchmark_qps.py diff --git a/test/benchmark_serving.py b/test/benchmark/service/benchmark_sharegpt.py similarity index 71% rename from test/benchmark_serving.py rename to test/benchmark/service/benchmark_sharegpt.py index 9cde7fd8d..c9f92f098 100644 --- a/test/benchmark_serving.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -25,11 +25,10 @@ import aiohttp import numpy as np -from transformers import PreTrainedTokenizerBase from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + def get_tokenizer( tokenizer_name: str, @@ -40,26 +39,22 @@ def get_tokenizer( """Gets a tokenizer for the given model name via Huggingface.""" if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): pass try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, - **kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs) except TypeError as e: - err_msg = ( - "Failed to load the tokenizer. If you are using a LLaMA-based " - f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer.") + err_msg = "Failed to load the tokenizer. {e}" raise RuntimeError(err_msg) from e if not isinstance(tokenizer, PreTrainedTokenizerFast): pass return tokenizer + # (prompt len, output len, latency) REQUEST_LATENCY: List[Tuple[int, int, float]] = [] @@ -73,23 +68,18 @@ def sample_requests( with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. - dataset = [ - data for data in dataset - if len(data["conversations"]) >= 2 - ] + dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - ] - + dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] + print("read data set finish") # Tokenize the prompts and completions. 
import random + dataset = random.sample(dataset, num_requests * 3) prompts = [prompt for prompt, _ in dataset] completions = [completion for _, completion in dataset] - + prompt_token_ids = tokenizer(prompts).input_ids completion_token_ids = tokenizer(completions).input_ids tokenized_dataset = [] @@ -135,26 +125,21 @@ async def get_request( await asyncio.sleep(interval) -async def send_request( - prompt: str, - prompt_len: int, - output_len: int -) -> None: +async def send_request(prompt: str, prompt_len: int, output_len: int) -> None: request_start_time = time.time() - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} headers = {"User-Agent": "Benchmark Client"} - url = 'http://localhost:8000/generate' - + url = "http://localhost:8000/generate" + data = { - 'inputs': prompt, - 'parameters': { - 'do_sample': False, - 'ignore_eos': True, - 'max_new_tokens': output_len, - # 'temperature': 0.1, - } + "inputs": prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + # 'temperature': 0.1, + }, } - timeout = aiohttp.ClientTimeout(total=3 * 3600) async with aiohttp.ClientSession(timeout=timeout) as session: @@ -165,7 +150,7 @@ async def send_request( chunks.append(chunk) output = b"".join(chunks).decode("utf-8") output = json.loads(output) - + if "error" not in output: break @@ -181,8 +166,7 @@ async def benchmark( tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): prompt, prompt_len, output_len = request - task = asyncio.create_task(send_request(prompt, - prompt_len, output_len)) + task = asyncio.create_task(send_request(prompt, prompt_len, output_len)) tasks.append(task) await asyncio.gather(*tasks) @@ -204,33 +188,28 @@ def main(args: argparse.Namespace): # Compute the latency statistics. avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) print(f"Average latency: {avg_latency:.2f} s") - avg_per_token_latency = np.mean([ - latency / (prompt_len + output_len) - for prompt_len, output_len, latency in REQUEST_LATENCY - ]) + avg_per_token_latency = np.mean( + [latency / (prompt_len + output_len) for prompt_len, output_len, latency in REQUEST_LATENCY] + ) print(f"Average latency per token: {avg_per_token_latency:.2f} s") - avg_per_output_token_latency = np.mean([ - latency / output_len - for _, output_len, latency in REQUEST_LATENCY - ]) - print("Average latency per output token: " - f"{avg_per_output_token_latency:.2f} s") + avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency in REQUEST_LATENCY]) + print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark the online serving throughput.") - parser.add_argument("--dataset", type=str, required=True, - help="Path to the dataset.") - parser.add_argument("--tokenizer", type=str, required=True, - help="Name or path of the tokenizer.") - parser.add_argument("--request-rate", type=float, default=float("inf"), - help="Number of requests per second. If this is inf, " - "then all the requests are sent at time 0. 
" - "Otherwise, we use Poisson process to synthesize " - "the request arrival times.") - parser.add_argument("--num-prompts", type=int, default=1000, - help="Number of prompts to process.") + parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") + parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize " + "the request arrival times.", + ) + parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() main(args) From 9665cccdf09d5acad61898acbcaf96430043a2ca Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 16 Jun 2025 19:22:45 +0800 Subject: [PATCH 06/14] improve the rotary_embed --- .../deepseek2/triton_kernel/rotary_emb.py | 107 +++++++++++------- 1 file changed, 65 insertions(+), 42 deletions(-) diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py index 93ff323f3..6f2e333db 100644 --- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py +++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py @@ -5,59 +5,52 @@ @triton.jit -def _rotary_kernel( +def _rotary_kernel_q( Q, - K, Cos, Sin, stride_qbs, stride_qh, stride_qd, - stride_kbs, - stride_kh, - stride_kd, stride_cosbs, stride_cosd, stride_sinbs, stride_sind, max_total_len, HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): cur_head_index = tl.program_id(0) + if cur_head_index >= HEAD_Q: + return cur_seq_index = tl.program_id(1) - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 dim_range1 = dim_range0 + 1 off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd + cur_seq_range[:, None, None] * stride_qbs + cur_head_index * stride_qh + dim_range0[None, None, :] * stride_qd ) off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd + cur_seq_range[:, None, None] * stride_qbs + cur_head_index * stride_qh + dim_range1[None, None, :] * stride_qd ) + mask = cur_seq_range[:, None, None] < max_total_len cos_range = tl.arange(0, BLOCK_DMODEL // 2) off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd q0 = tl.load( Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + mask=mask, other=0.0, ) q1 = tl.load( Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + mask=mask, other=0.0, ) @@ -67,34 +60,51 @@ def _rotary_kernel( out0 = q0 * cos - q1 * sin out1 = q0 * sin + q1 * cos - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) 
- ) + tl.store(Q + off_q0, out0, mask=mask) + tl.store(Q + off_q1, out1, mask=mask) + return - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) + +@triton.jit +def _rotary_kernel_k( + K, + Cos, + Sin, + stride_kbs, + stride_kh, + stride_kd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, + max_total_len, + HEAD_K, # HEAD_K is 1. + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_seq_index = tl.program_id(0) + + cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + dim_range1 = dim_range0 + 1 + + cos_range = tl.arange(0, BLOCK_DMODEL // 2) + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd + + off_k0 = cur_seq_range[:, None, None] * stride_kbs + dim_range0[None, None, :] * stride_kd + off_k1 = cur_seq_range[:, None, None] * stride_kbs + dim_range1[None, None, :] * stride_kd off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd k0 = tl.load( K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), other=0.0, ) k1 = tl.load( K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), other=0.0, ) @@ -107,12 +117,12 @@ def _rotary_kernel( tl.store( K + off_k0, out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), ) tl.store( K + off_k1, out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), ) return @@ -126,21 +136,36 @@ def rotary_emb_fwd(q, k, cos, sin): assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" BLOCK_SEQ = 16 - BLOCK_HEAD = 4 + BLOCK_HEAD = 2 if head_dim >= 128: num_warps = 8 else: num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( + grid = (triton.next_power_of_2(head_num_q), triton.cdiv(total_len, BLOCK_SEQ)) + _rotary_kernel_q[grid]( q, - k, cos, sin, q.stride(0), q.stride(1), q.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num_q, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + grid = (triton.cdiv(total_len, BLOCK_SEQ),) + _rotary_kernel_k[grid]( + k, + cos, + sin, k.stride(0), k.stride(1), k.stride(2), @@ -149,9 +174,7 @@ def rotary_emb_fwd(q, k, cos, sin): sin.stride(0), sin.stride(1), total_len, - head_num_q, head_num_k, - BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, num_warps=num_warps, From c3701e99e4f7e9aaea5c466935cd5fe4e5037905 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 16 Jun 2025 19:23:57 +0800 Subject: [PATCH 07/14] improve per_token_quant --- .../common/quantization/deepgemm_quant.py | 6 +-- .../triton_quant/fp8/fp8act_quant_kernel.py | 40 ++++++++++++++++--- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git 
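Splitting the fused kernel lets the Q path keep a (head, seq-block) grid with an early exit for padded head programs, while the single-KV-head K path runs on a plain seq-block grid without per-head masking. Both apply the same interleaved rotation over even/odd dims; a torch sketch for correctness checks, assuming `q: [total_len, head_num_q, head_dim]`, `k: [total_len, 1, head_dim]` and `cos/sin: [total_len, head_dim // 2]`:

```python
import torch

def rotary_ref(x, cos, sin):
    x0 = x[..., 0::2]            # even dims
    x1 = x[..., 1::2]            # odd dims
    c = cos[:, None, :]
    s = sin[:, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * c - x1 * s
    out[..., 1::2] = x0 * s + x1 * c
    return out
```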
a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 622a9711c..3bfbadc33 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -49,12 +49,12 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ m, k = input_tensor.shape n = weights[0].shape[1] if input_scale is None: - input_scale = torch.empty((m, k // self.block_size), dtype=torch.float32, device=input_tensor.device) qinput_tensor = self.cache_manager.alloc_tensor( (m, k), qweight.dtype, device=qweight.device, is_graph_out=False ) - per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale) - input_scale = tma_align_input_scale(input_scale) + _, input_scale = per_token_group_quant_fp8( + input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True + ) if out is None: if use_custom_tensor_mananger: diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py index aa3b5f61d..760cda137 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py +++ b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py @@ -6,7 +6,7 @@ from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops from frozendict import frozendict from functools import lru_cache -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple try: from deep_gemm import ceil_div @@ -109,17 +109,46 @@ def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, x_q: torch.Tensor, - x_s: torch.Tensor, + x_s: torch.Tensor = None, eps: float = 1e-10, dtype: torch.dtype = torch.float8_e4m3fn, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + alloc_func: Callable = torch.empty, ): + # Adapted from + # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290 if HAS_SGL_KERNEL: finfo = torch.finfo(dtype) fp8_max, fp8_min = finfo.max, finfo.min + if column_major_scales: + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = alloc_func( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = alloc_func( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, -2) + else: + if x_s is None: + x_s = alloc_func( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max) else: lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn) + return x_q, x_s + # copy from # https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58 @@ -229,8 +258,8 @@ def test_per_token_group_quant_fp8(): x_q = torch.randn((1024, 8192)).cuda().to(torch.float8_e4m3fn) # x_s = torch.randn((1024, 8192 // group_size), dtype=torch.float32).cuda() - x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t() - per_token_group_quant_fp8(x, group_size, x_q, x_s) + # x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t() + _, x_s = per_token_group_quant_fp8(x, 
group_size, x_q, None, column_major_scales=True) x_s = x_s[:1024] th_x_q, th_x_s = torch_quant(x, group_size) print("th_x_s - x_s", torch.abs(th_x_s - x_s.reshape(-1)).max()) @@ -238,4 +267,5 @@ def test_per_token_group_quant_fp8(): if __name__ == "__main__": - test_tma_align() + test_per_token_group_quant_fp8() + # test_tma_align() From 14eb826e6b61d228db213b2435315285f1c4616a Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 16 Jun 2025 20:00:56 +0800 Subject: [PATCH 08/14] silu improve --- lightllm/common/fused_moe/moe_silu_and_mul.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/lightllm/common/fused_moe/moe_silu_and_mul.py b/lightllm/common/fused_moe/moe_silu_and_mul.py index 3f6bdb44f..7080c22a7 100644 --- a/lightllm/common/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/fused_moe/moe_silu_and_mul.py @@ -54,6 +54,54 @@ def _silu_and_mul_kernel( ) +@triton.jit +def _silu_and_mul_kernel_fast( + input_ptr, + output_ptr, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_n, + BLOCK_N: tl.constexpr, + NEED_MASK: tl.constexpr, +): + stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) + stride_output_m = tl.cast(stride_output_m, dtype=tl.int64) + + cur_batch = tl.program_id(0) + pid = tl.program_id(1) + n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N) + + up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n) + gate_offsets = cur_batch * stride_input_m + n_offsets[None, :] + res_offsets = cur_batch * stride_output_m + n_offsets[None, :] + if NEED_MASK: + mask = n_offsets[None, :] < size_n + else: + mask = True + + up = tl.load( + input_ptr + up_offsets, + mask=mask, + other=0.0, + ) + gate = tl.load( + input_ptr + gate_offsets, + mask=mask, + other=0.0, + ).to(tl.float32) + + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + + tl.store( + output_ptr + res_offsets, + up * gate, + mask=mask, + ) + + def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): assert input.is_contiguous() assert output.is_contiguous() @@ -68,6 +116,26 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): if not run_config: run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype)) + if size_m <= 1024: + BLOCK_N = run_config["BLOCK_N"] + grid = ( + size_m, + triton.cdiv(size_n, BLOCK_N), + ) + NEED_MASK = size_n % BLOCK_N != 0 + _silu_and_mul_kernel_fast[grid]( + input, + output, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_n, + BLOCK_N=BLOCK_N, + NEED_MASK=NEED_MASK, + ) + return + BLOCK_M = run_config["BLOCK_M"] BLOCK_N = run_config["BLOCK_N"] num_warps = run_config["num_warps"] From c1a02acdd9b5e3161f977fe928bc4c8884d65e4c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 19 Jun 2025 22:06:22 +0800 Subject: [PATCH 09/14] update --- lightllm/common/basemodel/cuda_graph.py | 2 +- lightllm/common/fused_moe/moe_silu_and_mul.py | 2 +- .../layer_infer/transformer_layer_infer.py | 1 - .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/utils/envs_utils.py | 23 +++++++++++-------- lightllm/utils/sgl_utils.py | 12 ++++++++++ 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 41cd59039..81b7555c5 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -128,7 +128,7 @@ def replay(self, input_ids, infer_state, 
input_ids1=None, infer_state1=None): @torch.no_grad() def warmup(self, model): logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.") - for batch_size in range(self.max_batch_size, 0, -1): + for batch_size in range(self.max_batch_size, self.max_batch_size - 1, -1): # dummy prefill prefill_input_len = 1 dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") diff --git a/lightllm/common/fused_moe/moe_silu_and_mul.py b/lightllm/common/fused_moe/moe_silu_and_mul.py index 7080c22a7..5c62dbc90 100644 --- a/lightllm/common/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/fused_moe/moe_silu_and_mul.py @@ -116,7 +116,7 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): if not run_config: run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype)) - if size_m <= 1024: + if size_m <= 4096: BLOCK_N = run_config["BLOCK_N"] grid = ( size_m, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ba752a4e8..d34518e02 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -19,7 +19,6 @@ from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 125134659..6e277a528 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -16,7 +16,7 @@ from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index d223931ed..a5967dad8 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -26,7 +26,9 @@ def get_unique_server_name(): def set_cuda_arch(args): if not torch.cuda.is_available(): return - if args.enable_flashinfer_prefill or args.enable_flashinfer_decode: + from lightllm.utils.sgl_utils import HAS_FLASHINFER + + if HAS_FLASHINFER: capability = torch.cuda.get_device_capability() arch = f"{capability[0]}.{capability[1]}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" @@ -77,15 +79,16 @@ def get_lightllm_websocket_max_message_size(): return int(os.getenv("LIGHTLLM_WEBSOCKET_MAX_SIZE", 16 * 1024 * 1024)) -# 
get_redundancy_expert_ids and get_redundancy_expert_num are primarily used to obtain the IDs and number of redundant experts during inference. -# They depend on a configuration file specified by ep_redundancy_expert_config_path, which is a JSON formatted text file. -# The content format is as follows: -# { -# "redundancy_expert_num": 1, # Number of redundant experts per rank -# "0": [0], # Key: layer_index (string), Value: list of original expert IDs that are redundant for this layer -# "1": [0], -# "default": [0] # Default list of redundant expert IDs if layer-specific entry is not found -# } +# get_redundancy_expert_ids and get_redundancy_expert_num are primarily used to obtain the IDs +# and number of redundant experts during inference. They depend on a configuration file specified +# by ep_redundancy_expert_config_path, which is a JSON formatted text file. +# The content format is as follows: +# { +# "redundancy_expert_num": 1, # Number of redundant experts per rank +# "0": [0], # Key: layer_index (string), Value: list of redundant expert IDs of this layer +# "1": [0], +# "default": [0] # Default list of redundant expert IDs if layer-specific entry is not found +# } @lru_cache(maxsize=None) diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py index b48a62506..3a183c47e 100644 --- a/lightllm/utils/sgl_utils.py +++ b/lightllm/utils/sgl_utils.py @@ -30,3 +30,15 @@ "sgl_kernel is not installed, or the installed version did not support fa3. \ Try to upgrade it." ) + +try: + import flashinfer + from flashinfer.norm import fused_add_rmsnorm, rmsnorm + + HAS_FLASHINFER = True +except: + HAS_FLASHINFER = False + logger.warning( + "flashinfer is not installed, you can't use the api of it. \ + You can solve it by running `pip install flashinfer`." 
+ ) From 432c5211da07ebac5f0e66734b14d77cedd15de9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 19 Jun 2025 14:16:04 +0000 Subject: [PATCH 10/14] add config --- ...at16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 1 + ...at16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 1 + ...t16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 1 + ...at16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 1 + ...at16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 1 + ...t16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 1 + 6 files changed, 6 insertions(+) create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..bd2a5c76e --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, 
"num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..b9a717cee --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..cd4b2b79e --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 
128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..5e2f44cb0 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..457d72dc8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, 
"num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..217515264 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 4}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}} \ No newline at end of file From c1fe8122cd1a9b600e95592f67f652c4c0ae67c3 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 24 Jun 2025 19:59:35 +0800 Subject: [PATCH 11/14] remove unused scripts --- test/start_scripts/deepseek.sh | 87 --------------------------- test/start_scripts/test.sh | 107 --------------------------------- 2 files changed, 194 deletions(-) delete mode 100644 test/start_scripts/deepseek.sh delete mode 100644 test/start_scripts/test.sh diff --git a/test/start_scripts/deepseek.sh b/test/start_scripts/deepseek.sh deleted file mode 100644 index 78e40a116..000000000 --- a/test/start_scripts/deepseek.sh +++ /dev/null @@ -1,87 +0,0 @@ -# 单机 deepseek V3 ep 运行模式启动示例, 启动参数中的tp含义发生了变化,代表使用的所有卡数量,并不是tp推理。 -# max_total_token_num 可以按照实际场景调节。 -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 --model_dir /dev/shm/DeepSeek-R1 \ ---tp 8 \ ---dp 8 \ ---max_total_token_num 200000 \ ---graph_max_batch_size 64 \ ---batch_max_tokens 8192 \ ---enable_flashinfer_prefill \ ---enable_flashinfer_decode \ ---enable_prefill_microbatch_overlap \ ---disable_aggressive_schedule - -# H800 双机 deepseek V3 ep 运行模式启动实列 -# 启动命令中的 nccl_host 和 nccl_port 两个节点的必须一致,一般nccl_host设置为 node 0的ip。 -# max_total_token_num 最佳设置需要按照使用场景和显存情况配置。 -# 启动后两个节点的8088端口都可以接收访问的请求 -# node 0 -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 --model_dir /dev/shm/DeepSeek-R1 \ ---tp 16 \ ---dp 16 \ ---max_total_token_num 200000 \ ---graph_max_batch_size 
64 \ ---batch_max_tokens 8192 \ ---enable_flashinfer_prefill \ ---enable_flashinfer_decode \ ---enable_prefill_microbatch_overlap \ ---nnodes 2 \ ---node_rank 0 \ ---nccl_host \ ---nccl_port 2732 -# node 1 -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 --model_dir /dev/shm/DeepSeek-R1 \ ---tp 16 \ ---dp 16 \ ---max_total_token_num 200000 \ ---graph_max_batch_size 64 \ ---batch_max_tokens 8192 \ ---enable_flashinfer_prefill \ ---enable_flashinfer_decode \ ---enable_prefill_microbatch_overlap \ ---nnodes 2 \ ---node_rank 1 \ ---nccl_host \ ---nccl_port 2732 - -# pd 分离启动示列, 单机 做 P 和 D, 也支持多机组成的D和单机的P混合。 -# 目前 P D 分离的 PD master可能存在并发处理问题,还需提升。 - -# pd master 启动 -python -m lightllm.server.api_server --model_dir /dev/shm/DeepSeek-R1 --run_mode "pd_master" --host `hostname -i` --port 60011 - -# p 启动 -nvidia-cuda-mps-control -d -MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server --model_dir /dev/shm/DeepSeek-R1 \ ---run_mode "prefill" \ ---tp 8 \ ---dp 8 \ ---host `hostname -i` \ ---port 8019 \ ---nccl_port 2732 \ ---max_total_token_num 200000 \ ---batch_max_tokens 8192 \ ---enable_flashinfer_prefill \ ---enable_flashinfer_decode \ ---enable_prefill_microbatch_overlap \ ---disable_cudagraph \ ---pd_master_ip \ ---pd_master_port 60011 - -# d 启动 -nvidia-cuda-mps-control -d -MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server --model_dir /dev/shm/DeepSeek-R1 \ ---run_mode "decode" \ ---tp 8 \ ---dp 8 \ ---host `hostname -i` \ ---port 8121 \ ---nccl_port 12322 \ ---max_total_token_num 200000 \ ---graph_max_batch_size 64 \ ---enable_flashinfer_prefill \ ---enable_flashinfer_decode \ ---enable_prefill_microbatch_overlap \ ---pd_master_ip \ ---pd_master_port 60011 - diff --git a/test/start_scripts/test.sh b/test/start_scripts/test.sh deleted file mode 100644 index 8f3882386..000000000 --- a/test/start_scripts/test.sh +++ /dev/null @@ -1,107 +0,0 @@ -# pd start -python -m lightllm.server.api_server --model_dir /dev/shm/llama2-7b-chat --run_mode "pd_master" --host `hostname -i` --port 60011 - -nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=0,1,2,3 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /dev/shm/llama2-7b-chat \ ---run_mode "prefill" \ ---host `hostname -i` \ ---port 8019 \ ---tp 4 \ ---nccl_port 2732 \ ---max_total_token_num 400000 \ ---tokenizer_mode fast \ ---pd_master_ip `hostname -i` \ ---pd_master_port 60011 \ ---max_req_total_len 16000 \ ---running_max_req_size 128 \ ---disable_cudagraph - -nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=4,5,6,7 KV_TRANS_USE_P2P=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /dev/shm/llama2-7b-chat \ ---run_mode "decode" \ ---host `hostname -i` \ ---port 8121 \ ---nccl_port 12322 \ ---tp 4 \ ---max_total_token_num 400000 \ ---graph_max_len_in_batch 2048 \ ---graph_max_batch_size 16 \ ---tokenizer_mode fast \ ---pd_master_ip `hostname -i` \ ---pd_master_port 60011 - -# pd start1 -python -m lightllm.server.api_server --model_dir /dev/shm/llama2-7b-chat --run_mode "pd_master" --host `hostname -i` --port 60011 - -nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=0 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /dev/shm/llama2-7b-chat \ ---run_mode "prefill" \ ---host `hostname -i` \ ---port 8019 \ ---tp 1 \ ---nccl_port 2732 \ ---max_total_token_num 40000 \ ---tokenizer_mode fast \ ---pd_master_ip `hostname -i` \ ---pd_master_port 60011 \ ---max_req_total_len 16000 \ 
---running_max_req_size 128 \ ---disable_cudagraph - -nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=1 KV_TRANS_USE_P2P=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /dev/shm/llama2-7b-chat \ ---run_mode "decode" \ ---host `hostname -i` \ ---port 8121 \ ---nccl_port 12322 \ ---tp 1 \ ---max_total_token_num 40000 \ ---graph_max_len_in_batch 2048 \ ---graph_max_batch_size 16 \ ---tokenizer_mode fast \ ---pd_master_ip `hostname -i` \ ---pd_master_port 60011 - - -# normal start -LOADWORKER=8 python -m lightllm.server.api_server --port 8018 --model_dir /dev/shm/llama2-7b-chat --tp 2 --graph_max_batch_size 16 - - -# 多 pd_master 节点部署实列 -python -m lightllm.server.api_server --run_mode "config_server" --config_server_host 10.120.114.74 --config_server_port 60088 - -python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60011 --config_server_host 10.120.114.74 --config_server_port 60088 - -python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60012 --config_server_host 10.120.114.74 --config_server_port 60088 - - -nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=0 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ ---run_mode "prefill" \ ---host 10.120.178.74 \ ---port 8019 \ ---tp 1 \ ---nccl_port 2732 \ ---max_total_token_num 40000 \ ---tokenizer_mode fast \ ---max_req_total_len 16000 \ ---running_max_req_size 128 \ ---disable_cudagraph \ ---config_server_host 10.120.114.74 \ ---config_server_port 60088 - -CUDA_VISIBLE_DEVICES=1 KV_TRANS_USE_P2P=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ ---run_mode "decode" \ ---host 10.120.178.74 \ ---port 8121 \ ---nccl_port 12322 \ ---tp 1 \ ---max_total_token_num 40000 \ ---graph_max_len_in_batch 2048 \ ---graph_max_batch_size 16 \ ---tokenizer_mode fast \ ---config_server_host 10.120.114.74 \ ---config_server_port 60088 - - - From 038e089d198d1dac589ec0717ebdf49a9c5cb86b Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 1 Jul 2025 11:25:01 +0800 Subject: [PATCH 12/14] fix tuning --- test/kernel/fuse_moe_tuning.py | 39 +++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/test/kernel/fuse_moe_tuning.py b/test/kernel/fuse_moe_tuning.py index 6e971573a..dae85a7b5 100644 --- a/test/kernel/fuse_moe_tuning.py +++ b/test/kernel/fuse_moe_tuning.py @@ -7,6 +7,7 @@ from typing import List from lightllm.utils.log_utils import init_logger from transformers import AutoConfig +import torch.nn.functional as F logger = init_logger(__name__) @@ -61,6 +62,7 @@ def test_kernel( use_fp8_w8a8: bool, is_up: bool, block_shape, + num_fused_experts: int, **config, ): set_seed() @@ -68,6 +70,8 @@ def test_kernel( a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1_scale = w2_scale = None + if num_fused_experts > 0: + expert_num += num_fused_experts if use_fp8_w8a8: init_dtype = dtype @@ -91,19 +95,21 @@ def test_kernel( w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda() w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda() - rnd_logics = torch.randn(m, expert_num, device="cuda") + rnd_logics = torch.randn(m, expert_num - num_fused_experts, device="cuda") topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1) - topk_weights = torch.randn((m, topk), device="cuda", dtype=dtype) / 10 + topk_weights = 
torch.randn((m, topk + num_fused_experts), device="cuda", dtype=dtype) / 10 + if num_fused_experts > 0: + topk_ids = F.pad(topk_ids, (0, 1), mode="constant", value=expert_num) - expert_to_tokens = torch.empty((expert_num, topk * m), dtype=torch.int32, device="cuda") - expert_to_weights = torch.empty((expert_num, topk * m), dtype=torch.float32, device="cuda") + expert_to_tokens = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.int32, device="cuda") + expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda") moe_align(topk_ids=topk_ids, out=expert_to_tokens) expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") - moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk) + moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + 1) - out1 = torch.zeros((m * topk, 2 * n), dtype=torch.bfloat16, device="cuda") - down_in = torch.zeros((m * topk, n), dtype=torch.bfloat16, device="cuda") - out2 = torch.zeros((m * topk, k), dtype=torch.bfloat16, device="cuda") + out1 = torch.zeros((m * (topk + 1), 2 * n), dtype=torch.bfloat16, device="cuda") + down_in = torch.zeros((m * (topk + 1), n), dtype=torch.bfloat16, device="cuda") + out2 = torch.zeros((m * (topk + 1), k), dtype=torch.bfloat16, device="cuda") for _ in range(test_count): input_tuples.append( @@ -219,6 +225,7 @@ def worker( use_fp8_w8a8: bool, is_up: bool, block_shape, + num_fused_experts: int, test_configs, queue, ): @@ -235,6 +242,7 @@ def worker( use_fp8_w8a8=use_fp8_w8a8, is_up=is_up, block_shape=block_shape, + num_fused_experts=num_fused_experts, **test_configs[index], ) queue.put(cost_time) # Put result in queue @@ -302,6 +310,7 @@ def tuning_configs( use_fp8_w8a8: bool, is_up: bool, block_shape, + num_fused_experts: int, ): os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) best_config, best_cost_time = None, 10000000 @@ -325,6 +334,7 @@ def tuning_configs( use_fp8_w8a8, is_up, block_shape, + num_fused_experts, test_configs, queue, ), @@ -359,6 +369,7 @@ def tuning_configs( use_fp8_w8a8, is_up, block_shape, + num_fused_experts, test_configs, queue, ), @@ -393,15 +404,16 @@ def main(args): if config.architectures[0] == "Qwen3MoeForCausalLM": expert_num = config.num_experts topk_num = config.num_experts_per_tok - n = 2 * config.moe_intermediate_size // args.tp + n = config.moe_intermediate_size // args.tp elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: expert_num = config.n_routed_experts topk_num = config.num_experts_per_tok - n = 2 * config.moe_intermediate_size // args.tp + n = config.moe_intermediate_size // args.tp else: pass hidden_dim = getattr(config, "hidden_size", None) or config.text_config.hidden_size + print(n, hidden_dim) use_fp8_w8a8 = args.use_fp8_w8a8 block_shape = None if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config: @@ -424,6 +436,7 @@ def main(args): "use_fp8_w8a8": use_fp8_w8a8, "is_up": True, "block_shape": block_shape, + "num_fused_experts": args.num_fused_experts, }, ) up_dict[m] = ans @@ -431,7 +444,7 @@ def main(args): N=n * 2, K=hidden_dim, topk_num=topk_num, - expert_num=expert_num, + expert_num=expert_num + 1, mul_routed_weight=False, use_fp8_w8a8=use_fp8_w8a8, out_dtype=str(torch.bfloat16), @@ -453,6 +466,7 @@ def main(args): "use_fp8_w8a8": use_fp8_w8a8, "is_up": False, "block_shape": block_shape, + "num_fused_experts": args.num_fused_experts, }, ) 
down_dict[m] = ans @@ -461,7 +475,7 @@ def main(args): N=hidden_dim, K=n, topk_num=1, - expert_num=expert_num, + expert_num=expert_num + 1, mul_routed_weight=True, use_fp8_w8a8=use_fp8_w8a8, out_dtype=str(torch.bfloat16), @@ -474,5 +488,6 @@ def main(args): parser.add_argument("--model_dir", type=str, default="deepseek-ai/DeepSeek-R1") parser.add_argument("--tp", type=int, default=8) parser.add_argument("--use_fp8_w8a8", action="store_true") + parser.add_argument("--num_fused_experts", type=int, default=0) args = parser.parse_args() main(args) From aeb80a5a7f7d610eeb6f24850ac88aeba62db55b Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 1 Jul 2025 11:25:27 +0800 Subject: [PATCH 13/14] add fused expert --- ..._num=1,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + ..._num=8,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + ..._num=8,use_fp8_w8a8=true}_NVIDIA_H200.json | 1 + .../meta_weights/fused_moe_weight_ep.py | 8 +-- .../meta_weights/fused_moe_weight_tp.py | 10 +++- .../common/fused_moe/grouped_fused_moe.py | 1 - lightllm/common/fused_moe/grouped_topk.py | 5 +- .../common/fused_moe/moe_kernel_configs.py | 10 ++-- lightllm/common/fused_moe/topk_select.py | 2 + .../layer_infer/transformer_layer_infer.py | 5 +- .../layer_weights/transformer_layer_weight.py | 51 +++++++++++++------ lightllm/models/qwen2_vl/vision_process.py | 4 +- .../layer_infer/transformer_layer_infer.py | 13 ++--- lightllm/models/qwen3_moe/model.py | 5 ++ lightllm/server/api_cli.py | 5 ++ 15 files changed, 82 insertions(+), 40 deletions(-) create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..286de4928 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 
128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..ed56a6fc7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..f1a0658ba --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index f7a24ae0f..750c94291 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -84,17 +84,17 @@ def __init__( self.e_score_correction_bias = None self.w2_list = [None] * ep_load_expert_num self.w2_scale_list = [None] * ep_load_expert_num - self.scoring_func = network_config["scoring_func"] + self.scoring_func = "softmax" # network_config["scoring_func"] self.w1 = [None, None] # weight, weight_scale self.w2 = [None, None] # weight, weight_scale self.use_fp8_w8a8 = self.quant_method is not None - + network_config["n_group"] = 0 self.num_experts_per_tok = network_config["num_experts_per_tok"] self.use_grouped_topk = network_config["n_group"] > 0 self.norm_topk_prob = network_config["norm_topk_prob"] self.n_group = network_config["n_group"] - self.topk_group = network_config["topk_group"] - self.routed_scaling_factor = network_config["routed_scaling_factor"] + self.topk_group = 0 # network_config["topk_group"] + self.routed_scaling_factor = 0 # network_config["routed_scaling_factor"] self.lock = threading.Lock() # init buffer diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index abb0985b9..5a4e84f82 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -16,6 +16,7 @@ def __init__( e_score_correction_bias_name: str, weight_prefix: str, n_routed_experts: int, + num_fused_shared_experts: int, split_inter_size: int, data_type: torch.dtype, network_config: Dict[str, Any], @@ -34,7 +35,10 @@ def __init__( self.e_score_correction_bias_name = e_score_correction_bias_name self.weight_prefix = weight_prefix - self.n_routed_experts = n_routed_experts + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
+ self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_current_rank_in_dp() @@ -63,7 +67,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=self.scoring_func, + num_fused_shared_experts=self.num_fused_shared_experts, ) + if self.num_fused_shared_experts > 0: + topk_ids[:, -1] = self.n_routed_experts - 1 + topk_weights[:, -1] = 1.0 / self.routed_scaling_factor w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index e31e50922..b6c5f123b 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -648,7 +648,6 @@ def fused_experts_impl( CHUNK_SIZE = FFN_MOE_CHUNK_SIZE topk_num = topk_ids.shape[1] M = min(num_tokens, CHUNK_SIZE) - intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = alloc_tensor_func( (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype diff --git a/lightllm/common/fused_moe/grouped_topk.py b/lightllm/common/fused_moe/grouped_topk.py index e8eae1b15..b0e7f51a5 100644 --- a/lightllm/common/fused_moe/grouped_topk.py +++ b/lightllm/common/fused_moe/grouped_topk.py @@ -208,6 +208,7 @@ def triton_grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", group_score_used_topk_num=2, + num_fused_shared_experts: int = 0, ): if correction_bias is not None: @@ -222,8 +223,8 @@ def triton_grouped_topk( dtype = torch.float32 scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") - out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") + out_topk_weights = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.float32, device="cuda") + out_topk_ids = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.long, device="cuda") assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/fused_moe/moe_kernel_configs.py b/lightllm/common/fused_moe/moe_kernel_configs.py index 0b107ede3..3b47b14c3 100644 --- a/lightllm/common/fused_moe/moe_kernel_configs.py +++ b/lightllm/common/fused_moe/moe_kernel_configs.py @@ -42,12 +42,12 @@ def try_to_get_best_config( else: if M <= expert_num: config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 1, + "num_stages": 3, } else: config = { diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py index ca8d22f48..92303b0c5 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/fused_moe/topk_select.py @@ -181,6 +181,7 @@ def select_experts( num_expert_group: Optional[int] = None, scoring_func: str = "softmax", custom_routing_function: Optional[Callable] = None, + num_fused_shared_experts: int = 0, ): from lightllm.common.fused_moe.topk_select import fused_topk from 
lightllm.common.fused_moe.grouped_topk import triton_grouped_topk @@ -216,6 +217,7 @@ def select_experts( topk_group=topk_group, scoring_func=scoring_func, group_score_used_topk_num=group_score_topk_num, + num_fused_shared_experts=num_fused_shared_experts, ) elif custom_routing_function is None: diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index d34518e02..5d5cbc55f 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -665,7 +665,8 @@ def _moe_ffn( hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - if self.n_shared_experts is not None: + # if fused_shared_experts is not enabled, compute shared_output + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -681,7 +682,7 @@ def _moe_ffn( hidden_states.mul_(self.routed_scaling_factor) - if self.n_shared_experts is not None: + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: hidden_states.add_(shared_output) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 7a9f3c150..dc2b1e285 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -3,7 +3,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.utils.envs_utils import enable_env_vars +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, MultiROWMMWeight, @@ -39,6 +39,9 @@ def _parse_config(self): self.v_head_dim = self.network_config_["v_head_dim"] self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] + self.num_fused_shared_experts = 0 + if get_env_start_args().enable_fused_shared_experts and self.is_moe: + self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) def _init_weight_names(self): if self.q_lora_rank is None: @@ -96,8 +99,25 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size): )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + def _rename_shared_experts(self, weights, weight_scale_suffix): + old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_experts" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + proj_names = ["gate_proj", "down_proj", "up_proj"] + for i in range(self.num_fused_shared_experts): + expert_id = self.n_routed_experts + i + for proj in proj_names: + weight_tensor = weights.get(f"{old_prefix}.{proj}.weight") + if weight_tensor is not None: + weights[f"{new_prefix}.{expert_id}.{proj}.weight"] = weight_tensor + if self.quant_cfg.quantized_weight: + scale_tensor = weights.get(f"{old_prefix}.{proj}." + weight_scale_suffix) + if scale_tensor is not None: + weights[f"{new_prefix}.{expert_id}.{proj}." 
+ weight_scale_suffix] = scale_tensor + def load_hf_weights(self, weights): kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") + if self.quant_cfg.quantized_weight: + weight_scale_suffix = kv_b_quant_method.weight_scale_suffix if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] @@ -105,29 +125,27 @@ def load_hf_weights(self, weights): if self.quant_cfg.quantized_weight: kv_b_proj_ = weight_dequant( kv_b_proj_.cuda(), - weights[ - f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - ].cuda(), + weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), ).cpu() weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_) weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_) if ( self.quant_cfg.quantized_weight - and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - in weights + and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix in weights ): - kv_b_proj_scale_ = weights[ - f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - ] + kv_b_proj_scale_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix] block_size = 128 - weights[ - f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + kv_b_quant_method.weight_scale_suffix - ] = self._load_kb_scale(kv_b_proj_scale_, block_size) - weights[ - f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + kv_b_quant_method.weight_scale_suffix - ] = self._load_vb_scale(kv_b_proj_scale_, block_size) + weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + weight_scale_suffix] = self._load_kb_scale( + kv_b_proj_scale_, block_size + ) + weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj." 
+ weight_scale_suffix] = self._load_vb_scale( + kv_b_proj_scale_, block_size + ) + # rename the shared experts weight + if self.num_fused_shared_experts > 0: + self._rename_shared_experts(weights, weight_scale_suffix) return super().load_hf_weights(weights) def _init_qkvo(self): @@ -198,6 +216,8 @@ def _init_qkvo(self): ) def _load_mlp(self, mlp_prefix): + if self.num_fused_shared_experts > 0: + return self.gate_up_proj = MultiROWMMWeight( weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, @@ -235,6 +255,7 @@ def _init_moe(self): e_score_correction_bias_name=self.e_score_correction_bias_name, weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", n_routed_experts=self.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, split_inter_size=moe_intermediate_size // self.tp_world_size_, data_type=self.data_type_, network_config=self.network_config_, diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 45c250378..5a376d1b9 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -44,7 +44,7 @@ ChannelDimension, ImageInput, PILImageResampling, - VideoInput, + # VideoInput, get_image_size, infer_channel_dimension_format, is_scaled_image, @@ -54,6 +54,8 @@ valid_images, validate_preprocess_arguments, ) + +VideoInput = None from transformers.utils import TensorType, is_vision_available, logging logger = logging.get_logger(__name__) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 57d10bdcd..03506fd9e 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -105,22 +105,17 @@ def _moe_ffn_edp( hidden_states = input token_num, hidden_dim = hidden_states.shape - if self.n_shared_experts is not None: - shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) router_logits = layer_weight.moe_gate.mm(hidden_states) ep_output = layer_weight.experts.experts( hidden_states, router_logits=router_logits, - top_k=self.num_experts_per_tok, + top_k=8, renormalize=self.norm_topk_prob, - use_grouped_topk=self.n_group, - topk_group=self.topk_group, - num_expert_group=self.n_group, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, is_prefill=infer_state.is_prefill, ) - if self.n_shared_experts is not None: - ep_output.add_(shared_output) - ep_output = ep_output.view(token_num, hidden_dim) return ep_output diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b3421a325..8eff289b1 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -5,6 +5,7 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager logger = init_logger(__name__) @@ -21,3 +22,7 @@ class Qwen3MOEModel(Qwen3TpPartModel): def __init__(self, kvargs): super().__init__(kvargs) return + + def _init_custom(self): + super()._init_custom() + dist_group_manager.new_deepep_group(256, self.config["hidden_size"]) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 601b2a48a..e84966b08 100644 --- 
a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -411,6 +411,11 @@ def make_argument_parser() -> argparse.ArgumentParser: action="store_true", help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""", ) + parser.add_argument( + "--enable_fused_shared_experts", + action="store_true", + help="""Whether to enable fused shared experts for deepseekv3 model.""", + ) parser.add_argument( "--mtp_mode", choices=["deepseekv3", None], From 7e299d6b501820e8dfeb3cb2cefd6cb3b31a6d16 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 8 Jul 2025 17:00:50 +0800 Subject: [PATCH 14/14] update --- test/kernel/fuse_moe_tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/kernel/fuse_moe_tuning.py b/test/kernel/fuse_moe_tuning.py index dae85a7b5..c15129d97 100644 --- a/test/kernel/fuse_moe_tuning.py +++ b/test/kernel/fuse_moe_tuning.py @@ -105,7 +105,7 @@ def test_kernel( expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda") moe_align(topk_ids=topk_ids, out=expert_to_tokens) expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") - moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + 1) + moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_experts) out1 = torch.zeros((m * (topk + 1), 2 * n), dtype=torch.bfloat16, device="cuda") down_in = torch.zeros((m * (topk + 1), n), dtype=torch.bfloat16, device="cuda")
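
Note on the fused shared expert routing introduced in PATCH 13: the shared expert is appended as one extra expert slot in w1/w2, and every token's last top-k slot is forced to that expert id with weight 1/routed_scaling_factor (so that the later hidden_states.mul_(routed_scaling_factor) leaves the shared output with weight 1). The sketch below is a minimal, illustrative reproduction of that routing trick only — it is not the library API. It assumes a single fused shared expert, and it concatenates the extra column for clarity, whereas the actual patches preallocate topk + num_fused_shared_experts columns inside triton_grouped_topk and write the last column in place.

import torch

def append_fused_shared_expert(topk_ids: torch.Tensor,
                               topk_weights: torch.Tensor,
                               n_routed_experts: int,
                               routed_scaling_factor: float):
    """Toy sketch: append the fused shared expert to routed top-k results.

    topk_ids / topk_weights: [num_tokens, topk] produced by the routed gate.
    Returns [num_tokens, topk + 1] tensors whose extra slot always points at
    the shared expert, stored as expert id `n_routed_experts` (the slot that
    the shared expert weights occupy once fused into w1/w2).
    """
    num_tokens = topk_ids.shape[0]
    shared_ids = torch.full((num_tokens, 1), n_routed_experts,
                            dtype=topk_ids.dtype, device=topk_ids.device)
    # Routed outputs are scaled by routed_scaling_factor afterwards, so giving
    # the shared expert weight 1/routed_scaling_factor makes its final weight 1.
    shared_w = torch.full((num_tokens, 1), 1.0 / routed_scaling_factor,
                          dtype=topk_weights.dtype, device=topk_weights.device)
    return (torch.cat([topk_ids, shared_ids], dim=1),
            torch.cat([topk_weights, shared_w], dim=1))

This is also why PATCH 14 passes topk + num_fused_experts to moe_align1: once the shared expert rides along as an extra slot, every per-token expert count and buffer size downstream must account for topk + 1 entries instead of topk.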