diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..bd2a5c76e --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..ea69378f4 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No 
newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..b9a717cee --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json index 5c0dab42b..37ba845fa 100644 --- a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -1 +1 @@ -{"1": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, 
"BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}} \ No newline at end of file +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..286de4928 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 
index 000000000..cd4b2b79e --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..fe56e1c44 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..25333e743 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..ed56a6fc7 --- /dev/null +++ 
b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..5e2f44cb0 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..bc763e8bc --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": 
{"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..457d72dc8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json index 394ce3193..a4f26860b 100644 --- a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -1 +1 @@ -{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, 
"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 5}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4}} \ No newline at end of file +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..f1a0658ba --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, 
"GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..217515264 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 4}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}} \ No newline at end of file diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index f7a24ae0f..750c94291 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -84,17 +84,17 @@ def __init__( self.e_score_correction_bias = None self.w2_list = [None] * ep_load_expert_num self.w2_scale_list = [None] * ep_load_expert_num - self.scoring_func = network_config["scoring_func"] + self.scoring_func = "softmax" # network_config["scoring_func"] self.w1 = [None, None] # weight, weight_scale self.w2 = [None, None] # weight, weight_scale self.use_fp8_w8a8 = self.quant_method is not None - + network_config["n_group"] = 0 self.num_experts_per_tok = network_config["num_experts_per_tok"] self.use_grouped_topk = network_config["n_group"] > 0 self.norm_topk_prob = network_config["norm_topk_prob"] self.n_group = network_config["n_group"] - self.topk_group = network_config["topk_group"] - self.routed_scaling_factor = network_config["routed_scaling_factor"] + self.topk_group = 0 # network_config["topk_group"] + self.routed_scaling_factor = 0 # network_config["routed_scaling_factor"] self.lock = threading.Lock() # init buffer diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 131e65f54..5a4e84f82 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -16,6 +16,7 @@ def __init__( e_score_correction_bias_name: str, weight_prefix: str, n_routed_experts: int, + num_fused_shared_experts: int, split_inter_size: int, data_type: torch.dtype, network_config: Dict[str, Any], @@ -34,7 +35,10 @@ def __init__( self.e_score_correction_bias_name = e_score_correction_bias_name self.weight_prefix = weight_prefix - self.n_routed_experts = n_routed_experts + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_current_rank_in_dp() @@ -63,7 +67,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=self.scoring_func, + num_fused_shared_experts=self.num_fused_shared_experts, ) + if self.num_fused_shared_experts > 0: + topk_ids[:, -1] = self.n_routed_experts - 1 + topk_weights[:, -1] = 1.0 / self.routed_scaling_factor w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None @@ -93,16 +101,18 @@ def _fuse(self): and None not in self.experts_gate_projs and None not in self.w2_list ): - w1_list = [] + gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape + up_out_dim, up_in_dim = self.experts_up_projs[0].shape + assert gate_in_dim == up_in_dim + dtype = self.experts_gate_projs[0].dtype + total_expert_num = self.n_routed_experts + + w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") + for i_experts in range(self.n_routed_experts): - expert_gate_up_proj = torch.cat( - [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0 - ) - expert_gate_up_proj = expert_gate_up_proj - w1_list.append(expert_gate_up_proj) - - inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1] - w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size) + w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] + w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] + inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self.quantized_weight and self.quant_method is not None: @@ -123,17 +133,19 @@ def _fuse_weight_scale(self): and None not in self.experts_gate_proj_scales and None not in self.w2_scale_list ): - w1_scale_list = [] - for i_experts in range(self.n_routed_experts): - expert_gate_up_proj_scale = torch.cat( - [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0 - ) - w1_scale_list.append(expert_gate_up_proj_scale) - - inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1] - w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view( - len(w1_scale_list), inter_shape, hidden_size + gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape + up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape + assert gate_in_dim == up_in_dim + dtype = self.experts_gate_proj_scales[0].dtype + total_expert_num = self.n_routed_experts + + w1_scale = torch.empty( + 
(total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" ) + + for i_experts in range(self.n_routed_experts): + w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] + w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( len(self.w2_scale_list), inter_shape, hidden_size diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index c1e239bef..b6c5f123b 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -34,6 +34,7 @@ from .moe_silu_and_mul import silu_and_mul_fwd from .moe_sum_reduce import moe_sum_reduce from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.utils.dist_utils import get_current_rank_in_dp FFN_MOE_CHUNK_SIZE = 8 * 1024 @@ -220,8 +221,13 @@ def moe_align1( @triton.jit def moe_align2_kernel( experts_token_num_ptr, # [expert_num,] - mblocks_to_expert_id, # [max_num_m_blocks,] - mblocks_to_m_index, # [max_num_m_blocks,] + expert_to_token_index_ptr, # [expert_num, token_num * topk_num] + expert_to_token_index_stride_0, + expert_to_weights_ptr, + expert_to_weights_stride_0, + mblocks_to_expert_id_ptr, # [max_num_m_blocks,] + padded_expert_to_token_index_ptr, + padded_expert_to_weights_ptr, expert_num, max_num_m_blocks, BLOCK_M: tl.constexpr, @@ -241,27 +247,49 @@ def moe_align2_kernel( block_off = tl.arange(0, 128) for start_loc in range(0, cur_block_num, 128): tl.store( - mblocks_to_expert_id + block_start + start_loc + block_off, + mblocks_to_expert_id_ptr + block_start + start_loc + block_off, expert_id, mask=start_loc + block_off < cur_block_num, ) + + cur_expert_to_token_index_ptr = expert_to_token_index_ptr + expert_id * expert_to_token_index_stride_0 + for start_loc in range(0, cur_block_num): + offset = start_loc * BLOCK_M + tl.arange(0, BLOCK_M) + m_index = tl.load(cur_expert_to_token_index_ptr + offset, mask=offset < cur_expert_token_num, other=0) tl.store( - mblocks_to_m_index + block_start + start_loc + block_off, - start_loc + block_off, - mask=start_loc + block_off < cur_block_num, + padded_expert_to_token_index_ptr + block_start * BLOCK_M + offset, + m_index, + mask=offset < cur_expert_token_num, + ) + + m_weight = tl.load( + expert_to_weights_ptr + expert_id * expert_to_weights_stride_0 + offset, + mask=offset < cur_expert_token_num, + other=0.0, + ) + tl.store( + padded_expert_to_weights_ptr + block_start * BLOCK_M + offset, + m_weight, + mask=offset < cur_expert_token_num, ) if expert_id == expert_num - 1: for extra_fill_start in range(block_start + cur_block_num, max_num_m_blocks, 128): tl.store( - mblocks_to_expert_id + extra_fill_start + block_off, + mblocks_to_expert_id_ptr + extra_fill_start + block_off, -1, mask=extra_fill_start + block_off < max_num_m_blocks, ) return -def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, block_m: int): +def moe_align2( + token_num_mul_topk_num: int, + exports_token_num: torch.Tensor, + block_m: int, + expert_to_token_index: torch.Tensor, + expert_to_weights: torch.Tensor, +): """ exports_token_num is tensor shape [expert_num] , will get expert need handle token num. out tensor is a tensor that contain block schduel infos tensor. 
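A minimal host-side sketch (illustrative only, not the Triton kernel itself) of the layout the reworked moe_align2 returns, using the names from this diff: per-expert token indices and routing weights are copied into flat arrays whose per-expert segments are padded up to a multiple of BLOCK_M, with -1 marking padded index slots so grouped_matmul_kernel can derive its token mask from index != -1 instead of reloading per-block m offsets. The real kernel additionally over-allocates to max_num_m_blocks and fills the trailing block ids with -1.

import torch

def moe_align2_reference(expert_to_token_num, expert_to_token_index, expert_to_weights, block_m):
    # expert_to_token_num: [expert_num]; expert_to_token_index / expert_to_weights: [expert_num, token_num * topk]
    expert_num = expert_to_token_num.shape[0]
    block_counts = [(int(n) + block_m - 1) // block_m for n in expert_to_token_num]
    total_blocks = sum(block_counts)
    mblocks_to_expert_id = torch.full((total_blocks,), -1, dtype=torch.int32)
    padded_index = torch.full((total_blocks * block_m,), -1, dtype=torch.int32)
    padded_weights = torch.zeros((total_blocks * block_m,), dtype=torch.float32)
    block_start = 0
    for expert_id in range(expert_num):
        n_tok = int(expert_to_token_num[expert_id])
        n_blk = block_counts[expert_id]
        # every m-block in this segment is owned by expert_id
        mblocks_to_expert_id[block_start:block_start + n_blk] = expert_id
        row_start = block_start * block_m
        # copy this expert's token indices and routing weights; slots beyond n_tok stay -1 / 0
        padded_index[row_start:row_start + n_tok] = expert_to_token_index[expert_id, :n_tok]
        padded_weights[row_start:row_start + n_tok] = expert_to_weights[expert_id, :n_tok]
        block_start += n_blk
    return mblocks_to_expert_id, padded_index, padded_weights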
@@ -269,14 +297,20 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m) mblocks_to_expert_id = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") - mblocks_to_m_index = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") + padded_expert_to_token_index = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda").fill_(-1) + padded_expert_to_weights = torch.empty(max_num_tokens_padded, dtype=torch.float32, device="cuda") expert_num = exports_token_num.shape[0] grid = (expert_num,) moe_align2_kernel[grid]( exports_token_num, + expert_to_token_index, + expert_to_token_index.stride(0), + expert_to_weights, + expert_to_weights.stride(0), mblocks_to_expert_id, - mblocks_to_m_index, + padded_expert_to_token_index, + padded_expert_to_weights, expert_num, max_num_m_blocks, BLOCK_M=block_m, @@ -285,13 +319,14 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo num_stages=1, ) - return mblocks_to_expert_id, mblocks_to_m_index + return mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights @triton.jit def grouped_matmul_kernel( mblocks_to_expert_id, # [max_m_block_size] - mblocks_to_m_index, # [max_m_block_size] + padded_expert_to_token_index, # [max_m_block_size] + padded_expert_to_weights, # [max_m_block_size] k, # int n, # int topk_num, # int @@ -307,12 +342,7 @@ def grouped_matmul_kernel( weight_stride_0, weight_stride_1, weight_stride_2, - expert_to_weights_ptr, # [expert_num, token_num * topk] - expert_to_weights_stride0, - expert_to_weights_stride1, expert_to_token_num, # [expert_num] - expert_to_token_index, # [expert_num, token_num * topk_num] - expert_to_token_index_stride_0, out_ptr, # [token_num * topk_num, n] out_stride_0, out_stride_1, @@ -350,28 +380,14 @@ def grouped_matmul_kernel( if expert_id == -1: return - - tile_m_idx = tl.load(mblocks_to_m_index + pid_m) tile_n_idx = pid_n - - # get the gemm size of the current problem - cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last") - # do regular gemm here - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - token_mask = offs_am < cur_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + # token_mask = offs_am < cur_m a_m_index = tl.load( - expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am, - mask=token_mask, - other=0, + padded_expert_to_token_index + offs_am, ) - if MUL_ROUTED_WEIGHT: - a_m_scale = tl.load( - expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am, - mask=token_mask, - other=0.0, - ) - + token_mask = a_m_index != -1 offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -437,6 +453,11 @@ def grouped_matmul_kernel( accumulator *= ab_scale if MUL_ROUTED_WEIGHT: + a_m_scale = tl.load( + padded_expert_to_weights + offs_am, + mask=token_mask, + other=0.0, + ) accumulator *= a_m_scale[:, None] c = accumulator.to(compute_type) @@ -530,16 +551,22 @@ def grouped_matmul( token_inputs, token_input_scale = qinput_tensor, input_scale if reused_mblock_infos is None: - mblocks_to_expert_id, mblocks_to_m_index = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M) + mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2( + token_num_mul_topk_num, 
expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights + ) else: # when up group gemm and down group gemm use same BLOCK_SIZE_M, # can reuse (mblocks_to_expert_id, mblocks_to_m_index) created by moe_align2 kernel. - mblocks_to_expert_id, mblocks_to_m_index, reused_block_size_m = reused_mblock_infos + ( + mblocks_to_expert_id, + padded_expert_to_token_index, + padded_expert_to_weights, + reused_block_size_m, + ) = reused_mblock_infos if reused_block_size_m != BLOCK_SIZE_M: - mblocks_to_expert_id, mblocks_to_m_index = moe_align2( - token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M + mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2( + token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights ) - block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0] grid = (block_num,) @@ -548,7 +575,8 @@ def grouped_matmul( grouped_matmul_kernel[grid]( mblocks_to_expert_id, - mblocks_to_m_index, + padded_expert_to_token_index, + padded_expert_to_weights, k, n, topk_num, @@ -570,12 +598,7 @@ def grouped_matmul( expert_weights.stride(0), expert_weights.stride(1), expert_weights.stride(2), - expert_to_weights, - expert_to_weights.stride(0), - expert_to_weights.stride(1), expert_to_token_num, - expert_to_token_index, - expert_to_token_index.stride(0), out, out.stride(0), out.stride(1), @@ -594,7 +617,7 @@ def grouped_matmul( num_warps=num_warps, num_stages=num_stages, ) - return (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M) + return (mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights, BLOCK_SIZE_M) def fused_experts_impl( @@ -625,7 +648,6 @@ def fused_experts_impl( CHUNK_SIZE = FFN_MOE_CHUNK_SIZE topk_num = topk_ids.shape[1] M = min(num_tokens, CHUNK_SIZE) - intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = alloc_tensor_func( (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype diff --git a/lightllm/common/fused_moe/grouped_topk.py b/lightllm/common/fused_moe/grouped_topk.py index e8eae1b15..b0e7f51a5 100644 --- a/lightllm/common/fused_moe/grouped_topk.py +++ b/lightllm/common/fused_moe/grouped_topk.py @@ -208,6 +208,7 @@ def triton_grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", group_score_used_topk_num=2, + num_fused_shared_experts: int = 0, ): if correction_bias is not None: @@ -222,8 +223,8 @@ def triton_grouped_topk( dtype = torch.float32 scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") - out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") + out_topk_weights = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.float32, device="cuda") + out_topk_ids = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.long, device="cuda") assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/fused_moe/moe_kernel_configs.py b/lightllm/common/fused_moe/moe_kernel_configs.py index 0b107ede3..3b47b14c3 100644 --- a/lightllm/common/fused_moe/moe_kernel_configs.py +++ b/lightllm/common/fused_moe/moe_kernel_configs.py @@ -42,12 +42,12 @@ def try_to_get_best_config( else: if M <= expert_num: config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 1, + "num_stages": 3, } else: config = { diff --git a/lightllm/common/fused_moe/moe_silu_and_mul.py b/lightllm/common/fused_moe/moe_silu_and_mul.py index 3f6bdb44f..5c62dbc90 100644 --- a/lightllm/common/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/fused_moe/moe_silu_and_mul.py @@ -54,6 +54,54 @@ def _silu_and_mul_kernel( ) +@triton.jit +def _silu_and_mul_kernel_fast( + input_ptr, + output_ptr, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_n, + BLOCK_N: tl.constexpr, + NEED_MASK: tl.constexpr, +): + stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) + stride_output_m = tl.cast(stride_output_m, dtype=tl.int64) + + cur_batch = tl.program_id(0) + pid = tl.program_id(1) + n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N) + + up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n) + gate_offsets = cur_batch * stride_input_m + n_offsets[None, :] + res_offsets = cur_batch * stride_output_m + n_offsets[None, :] + if NEED_MASK: + mask = n_offsets[None, :] < size_n + else: + mask = True + + up = tl.load( + input_ptr + up_offsets, + mask=mask, + other=0.0, + ) + gate = tl.load( + input_ptr + gate_offsets, + mask=mask, + other=0.0, + ).to(tl.float32) + + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + + tl.store( + output_ptr + res_offsets, + up * gate, + mask=mask, + ) + + def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): assert input.is_contiguous() assert output.is_contiguous() @@ -68,6 +116,26 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): if not run_config: run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype)) + if size_m <= 4096: + BLOCK_N = run_config["BLOCK_N"] + grid = ( + size_m, + triton.cdiv(size_n, BLOCK_N), + ) + NEED_MASK = size_n % BLOCK_N != 0 + _silu_and_mul_kernel_fast[grid]( + input, + output, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_n, + BLOCK_N=BLOCK_N, + NEED_MASK=NEED_MASK, + ) + return + BLOCK_M = run_config["BLOCK_M"] BLOCK_N = run_config["BLOCK_N"] num_warps = run_config["num_warps"] diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py index ca8d22f48..92303b0c5 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/fused_moe/topk_select.py @@ -181,6 +181,7 @@ def select_experts( num_expert_group: Optional[int] = None, scoring_func: str = "softmax", custom_routing_function: Optional[Callable] = None, + num_fused_shared_experts: int = 0, ): from lightllm.common.fused_moe.topk_select import fused_topk from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk @@ -216,6 +217,7 @@ def select_experts( topk_group=topk_group, scoring_func=scoring_func, group_score_used_topk_num=group_score_topk_num, + num_fused_shared_experts=num_fused_shared_experts, ) elif custom_routing_function is None: diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 622a9711c..3bfbadc33 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -49,12 +49,12 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ m, k = input_tensor.shape n = weights[0].shape[1] if input_scale is None: - input_scale = torch.empty((m, k // self.block_size), 
dtype=torch.float32, device=input_tensor.device) qinput_tensor = self.cache_manager.alloc_tensor( (m, k), qweight.dtype, device=qweight.device, is_graph_out=False ) - per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale) - input_scale = tma_align_input_scale(input_scale) + _, input_scale = per_token_group_quant_fp8( + input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True + ) if out is None: if use_custom_tensor_mananger: diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py index aa3b5f61d..760cda137 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py +++ b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py @@ -6,7 +6,7 @@ from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops from frozendict import frozendict from functools import lru_cache -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple try: from deep_gemm import ceil_div @@ -109,17 +109,46 @@ def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, x_q: torch.Tensor, - x_s: torch.Tensor, + x_s: torch.Tensor = None, eps: float = 1e-10, dtype: torch.dtype = torch.float8_e4m3fn, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + alloc_func: Callable = torch.empty, ): + # Adapted from + # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290 if HAS_SGL_KERNEL: finfo = torch.finfo(dtype) fp8_max, fp8_min = finfo.max, finfo.min + if column_major_scales: + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = alloc_func( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = alloc_func( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, -2) + else: + if x_s is None: + x_s = alloc_func( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max) else: lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn) + return x_q, x_s + # copy from # https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58 @@ -229,8 +258,8 @@ def test_per_token_group_quant_fp8(): x_q = torch.randn((1024, 8192)).cuda().to(torch.float8_e4m3fn) # x_s = torch.randn((1024, 8192 // group_size), dtype=torch.float32).cuda() - x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t() - per_token_group_quant_fp8(x, group_size, x_q, x_s) + # x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t() + _, x_s = per_token_group_quant_fp8(x, group_size, x_q, None, column_major_scales=True) x_s = x_s[:1024] th_x_q, th_x_s = torch_quant(x, group_size) print("th_x_s - x_s", torch.abs(th_x_s - x_s.reshape(-1)).max()) @@ -238,4 +267,5 @@ def test_per_token_group_quant_fp8(): if __name__ == "__main__": - test_tma_align() + test_per_token_group_quant_fp8() + # test_tma_align() diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ba752a4e8..5d5cbc55f 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -19,7 +19,6 @@ from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo @@ -666,7 +665,8 @@ def _moe_ffn( hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - if self.n_shared_experts is not None: + # if fused_shared_experts is not enabled, compute shared_output + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -682,7 +682,7 @@ def _moe_ffn( hidden_states.mul_(self.routed_scaling_factor) - if self.n_shared_experts is not None: + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: hidden_states.add_(shared_output) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 7a9f3c150..dc2b1e285 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -3,7 +3,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.utils.envs_utils import enable_env_vars +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, MultiROWMMWeight, @@ -39,6 +39,9 @@ def _parse_config(self): self.v_head_dim = self.network_config_["v_head_dim"] self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] + self.num_fused_shared_experts = 0 + if get_env_start_args().enable_fused_shared_experts and self.is_moe: + self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) def _init_weight_names(self): if self.q_lora_rank is None: @@ -96,8 +99,25 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size): )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + def _rename_shared_experts(self, weights, weight_scale_suffix): + old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_experts" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + proj_names = ["gate_proj", "down_proj", "up_proj"] + for i in range(self.num_fused_shared_experts): + expert_id = self.n_routed_experts + i + for proj in proj_names: + weight_tensor = weights.get(f"{old_prefix}.{proj}.weight") + if weight_tensor is not None: + weights[f"{new_prefix}.{expert_id}.{proj}.weight"] = weight_tensor + if self.quant_cfg.quantized_weight: + scale_tensor = 
weights.get(f"{old_prefix}.{proj}." + weight_scale_suffix) + if scale_tensor is not None: + weights[f"{new_prefix}.{expert_id}.{proj}." + weight_scale_suffix] = scale_tensor + def load_hf_weights(self, weights): kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") + if self.quant_cfg.quantized_weight: + weight_scale_suffix = kv_b_quant_method.weight_scale_suffix if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] @@ -105,29 +125,27 @@ def load_hf_weights(self, weights): if self.quant_cfg.quantized_weight: kv_b_proj_ = weight_dequant( kv_b_proj_.cuda(), - weights[ - f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - ].cuda(), + weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), ).cpu() weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_) weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_) if ( self.quant_cfg.quantized_weight - and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - in weights + and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix in weights ): - kv_b_proj_scale_ = weights[ - f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - ] + kv_b_proj_scale_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix] block_size = 128 - weights[ - f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + kv_b_quant_method.weight_scale_suffix - ] = self._load_kb_scale(kv_b_proj_scale_, block_size) - weights[ - f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + kv_b_quant_method.weight_scale_suffix - ] = self._load_vb_scale(kv_b_proj_scale_, block_size) + weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + weight_scale_suffix] = self._load_kb_scale( + kv_b_proj_scale_, block_size + ) + weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj." 
+ weight_scale_suffix] = self._load_vb_scale(
+                    kv_b_proj_scale_, block_size
+                )
+        # rename the shared experts weight
+        if self.num_fused_shared_experts > 0:
+            self._rename_shared_experts(weights, weight_scale_suffix)
 
         return super().load_hf_weights(weights)
 
     def _init_qkvo(self):
@@ -198,6 +216,8 @@ def _init_qkvo(self):
         )
 
     def _load_mlp(self, mlp_prefix):
+        if self.num_fused_shared_experts > 0:
+            return
         self.gate_up_proj = MultiROWMMWeight(
             weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
             data_type=self.data_type_,
@@ -235,6 +255,7 @@ def _init_moe(self):
             e_score_correction_bias_name=self.e_score_correction_bias_name,
             weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts",
             n_routed_experts=self.n_routed_experts,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             split_inter_size=moe_intermediate_size // self.tp_world_size_,
             data_type=self.data_type_,
             network_config=self.network_config_,
diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
index 93ff323f3..6f2e333db 100644
--- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
+++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
@@ -5,59 +5,52 @@
 
 
 @triton.jit
-def _rotary_kernel(
+def _rotary_kernel_q(
     Q,
-    K,
     Cos,
     Sin,
     stride_qbs,
     stride_qh,
     stride_qd,
-    stride_kbs,
-    stride_kh,
-    stride_kd,
     stride_cosbs,
     stride_cosd,
     stride_sinbs,
     stride_sind,
     max_total_len,
     HEAD_Q,
-    HEAD_K,
     # N_CTX is the context length to be computed
     BLOCK_HEAD: tl.constexpr,
     BLOCK_SEQ: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
 ):
     cur_head_index = tl.program_id(0)
+    if cur_head_index >= HEAD_Q:
+        return
     cur_seq_index = tl.program_id(1)
 
-    cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
     cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
 
     dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2
     dim_range1 = dim_range0 + 1
 
     off_q0 = (
-        cur_seq_range[:, None, None] * stride_qbs
-        + cur_head_range[None, :, None] * stride_qh
-        + dim_range0[None, None, :] * stride_qd
+        cur_seq_range[:, None, None] * stride_qbs + cur_head_index * stride_qh + dim_range0[None, None, :] * stride_qd
     )
     off_q1 = (
-        cur_seq_range[:, None, None] * stride_qbs
-        + cur_head_range[None, :, None] * stride_qh
-        + dim_range1[None, None, :] * stride_qd
+        cur_seq_range[:, None, None] * stride_qbs + cur_head_index * stride_qh + dim_range1[None, None, :] * stride_qd
     )
+    mask = cur_seq_range[:, None, None] < max_total_len
 
     cos_range = tl.arange(0, BLOCK_DMODEL // 2)
     off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd
 
     q0 = tl.load(
         Q + off_q0,
-        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),
+        mask=mask,
         other=0.0,
     )
     q1 = tl.load(
         Q + off_q1,
-        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),
+        mask=mask,
         other=0.0,
     )
 
@@ -67,34 +60,51 @@ def _rotary_kernel(
     out0 = q0 * cos - q1 * sin
     out1 = q0 * sin + q1 * cos
 
-    tl.store(
-        Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)
-    )
-    tl.store(
-        Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)
-    )
+    tl.store(Q + off_q0, out0, mask=mask)
+    tl.store(Q + off_q1, out1, mask=mask)
+    return
 
-    off_k0 = (
-        cur_seq_range[:, None, None] * stride_kbs
-        + cur_head_range[None, :, None] * stride_kh
-        + dim_range0[None, None, :] * stride_kd
-    )
-    off_k1 = (
-        cur_seq_range[:, None, None] * stride_kbs
-        + cur_head_range[None, :, None] * stride_kh
-        + dim_range1[None, None, :] * stride_kd
-    )
+
+@triton.jit
+def _rotary_kernel_k(
+    K,
+    Cos,
+    Sin,
+    stride_kbs,
+    stride_kh,
+    stride_kd,
+    stride_cosbs,
+    stride_cosd,
+    stride_sinbs,
+    stride_sind,
+    max_total_len,
+    HEAD_K,  # HEAD_K is 1.
+    BLOCK_SEQ: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    cur_seq_index = tl.program_id(0)
+
+    cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
+
+    dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2
+    dim_range1 = dim_range0 + 1
+
+    cos_range = tl.arange(0, BLOCK_DMODEL // 2)
+    off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd
+
+    off_k0 = cur_seq_range[:, None, None] * stride_kbs + dim_range0[None, None, :] * stride_kd
+    off_k1 = cur_seq_range[:, None, None] * stride_kbs + dim_range1[None, None, :] * stride_kd
 
     off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd
 
     k0 = tl.load(
         K + off_k0,
-        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+        mask=(cur_seq_range[:, None, None] < max_total_len),
         other=0.0,
     )
     k1 = tl.load(
         K + off_k1,
-        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+        mask=(cur_seq_range[:, None, None] < max_total_len),
         other=0.0,
     )
 
@@ -107,12 +117,12 @@ def _rotary_kernel(
     tl.store(
         K + off_k0,
         out_k0,
-        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+        mask=(cur_seq_range[:, None, None] < max_total_len),
     )
     tl.store(
         K + off_k1,
         out_k1,
-        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+        mask=(cur_seq_range[:, None, None] < max_total_len),
     )
     return
 
@@ -126,21 +136,36 @@ def rotary_emb_fwd(q, k, cos, sin):
     assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
 
     BLOCK_SEQ = 16
-    BLOCK_HEAD = 4
+    BLOCK_HEAD = 2
    if head_dim >= 128:
         num_warps = 8
     else:
         num_warps = 4
-
-    grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
-    _rotary_kernel[grid](
+    grid = (triton.next_power_of_2(head_num_q), triton.cdiv(total_len, BLOCK_SEQ))
+    _rotary_kernel_q[grid](
         q,
-        k,
         cos,
         sin,
         q.stride(0),
         q.stride(1),
         q.stride(2),
+        cos.stride(0),
+        cos.stride(1),
+        sin.stride(0),
+        sin.stride(1),
+        total_len,
+        head_num_q,
+        BLOCK_HEAD=BLOCK_HEAD,
+        BLOCK_SEQ=BLOCK_SEQ,
+        BLOCK_DMODEL=head_dim,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    grid = (triton.cdiv(total_len, BLOCK_SEQ),)
+    _rotary_kernel_k[grid](
+        k,
+        cos,
+        sin,
         k.stride(0),
         k.stride(1),
         k.stride(2),
@@ -149,9 +174,7 @@ def rotary_emb_fwd(q, k, cos, sin):
         sin.stride(0),
         sin.stride(1),
         total_len,
-        head_num_q,
         head_num_k,
-        BLOCK_HEAD=BLOCK_HEAD,
         BLOCK_SEQ=BLOCK_SEQ,
         BLOCK_DMODEL=head_dim,
         num_warps=num_warps,
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 125134659..6e277a528 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -16,7 +16,7 @@
 from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
-from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
+from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo
 
diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py
index 45c250378..5a376d1b9 100644
--- a/lightllm/models/qwen2_vl/vision_process.py
+++ b/lightllm/models/qwen2_vl/vision_process.py
@@ -44,7 +44,7 @@
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    VideoInput,
+    # VideoInput,
     get_image_size,
     infer_channel_dimension_format,
     is_scaled_image,
@@ -54,6 +54,8 @@
     valid_images,
     validate_preprocess_arguments,
 )
+
+VideoInput = None
 from transformers.utils import TensorType, is_vision_available, logging
 
 logger = logging.get_logger(__name__)
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 57d10bdcd..03506fd9e 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -105,22 +105,17 @@ def _moe_ffn_edp(
 
         hidden_states = input
         token_num, hidden_dim = hidden_states.shape
-        if self.n_shared_experts is not None:
-            shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)
 
         router_logits = layer_weight.moe_gate.mm(hidden_states)
         ep_output = layer_weight.experts.experts(
             hidden_states,
             router_logits=router_logits,
-            top_k=self.num_experts_per_tok,
+            top_k=8,
             renormalize=self.norm_topk_prob,
-            use_grouped_topk=self.n_group,
-            topk_group=self.topk_group,
-            num_expert_group=self.n_group,
+            use_grouped_topk=False,
+            topk_group=None,
+            num_expert_group=None,
             is_prefill=infer_state.is_prefill,
         )
-        if self.n_shared_experts is not None:
-            ep_output.add_(shared_output)
-
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index b3421a325..8eff289b1 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -5,6 +5,7 @@
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3.model import Qwen3TpPartModel
 from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager
 
 logger = init_logger(__name__)
 
@@ -21,3 +22,7 @@ class Qwen3MOEModel(Qwen3TpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _init_custom(self):
+        super()._init_custom()
+        dist_group_manager.new_deepep_group(256, self.config["hidden_size"])
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 601b2a48a..e84966b08 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -411,6 +411,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""",
     )
+    parser.add_argument(
+        "--enable_fused_shared_experts",
+        action="store_true",
+        help="""Whether to enable fused shared experts for deepseekv3 model.""",
+    )
     parser.add_argument(
         "--mtp_mode",
         choices=["deepseekv3", None],
diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index d223931ed..a5967dad8 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -26,7 +26,9 @@ def get_unique_server_name():
 def set_cuda_arch(args):
     if not torch.cuda.is_available():
         return
-    if args.enable_flashinfer_prefill or args.enable_flashinfer_decode:
+    from lightllm.utils.sgl_utils import HAS_FLASHINFER
+
+    if HAS_FLASHINFER:
         capability = torch.cuda.get_device_capability()
         arch = f"{capability[0]}.{capability[1]}"
         os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
@@ -77,15 +79,16 @@ def get_lightllm_websocket_max_message_size():
     return int(os.getenv("LIGHTLLM_WEBSOCKET_MAX_SIZE", 16 * 1024 * 1024))
 
 
-# get_redundancy_expert_ids and get_redundancy_expert_num are primarily used to obtain the IDs and number of redundant experts during inference.
-# They depend on a configuration file specified by ep_redundancy_expert_config_path, which is a JSON formatted text file.
-# The content format is as follows:
-# {
-#     "redundancy_expert_num": 1,  # Number of redundant experts per rank
-#     "0": [0],  # Key: layer_index (string), Value: list of original expert IDs that are redundant for this layer
-#     "1": [0],
-#     "default": [0]  # Default list of redundant expert IDs if layer-specific entry is not found
-# }
+# get_redundancy_expert_ids and get_redundancy_expert_num are primarily used to obtain the IDs
+# and number of redundant experts during inference. They depend on a configuration file specified
+# by ep_redundancy_expert_config_path, which is a JSON formatted text file.
+# The content format is as follows:
+# {
+#     "redundancy_expert_num": 1,  # Number of redundant experts per rank
+#     "0": [0],  # Key: layer_index (string), Value: list of redundant expert IDs of this layer
+#     "1": [0],
+#     "default": [0]  # Default list of redundant expert IDs if layer-specific entry is not found
+# }
 
 
 @lru_cache(maxsize=None)
diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py
index b48a62506..3a183c47e 100644
--- a/lightllm/utils/sgl_utils.py
+++ b/lightllm/utils/sgl_utils.py
@@ -30,3 +30,15 @@
         "sgl_kernel is not installed, or the installed version did not support fa3. \
         Try to upgrade it."
     )
+
+try:
+    import flashinfer
+    from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+
+    HAS_FLASHINFER = True
+except:
+    HAS_FLASHINFER = False
+    logger.warning(
+        "flashinfer is not installed, you can't use its APIs. \
+        You can install it by running `pip install flashinfer`."
+    )
diff --git a/test/kernel/fuse_moe_tuning.py b/test/kernel/fuse_moe_tuning.py
index 6e971573a..c15129d97 100644
--- a/test/kernel/fuse_moe_tuning.py
+++ b/test/kernel/fuse_moe_tuning.py
@@ -7,6 +7,7 @@
 from typing import List
 from lightllm.utils.log_utils import init_logger
 from transformers import AutoConfig
+import torch.nn.functional as F
 
 logger = init_logger(__name__)
 
@@ -61,6 +62,7 @@ def test_kernel(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
+    num_fused_experts: int,
     **config,
 ):
     set_seed()
@@ -68,6 +70,8 @@
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
 
     w1_scale = w2_scale = None
+    if num_fused_experts > 0:
+        expert_num += num_fused_experts
 
     if use_fp8_w8a8:
         init_dtype = dtype
@@ -91,19 +95,21 @@
         w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda()
         w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda()
 
-    rnd_logics = torch.randn(m, expert_num, device="cuda")
+    rnd_logics = torch.randn(m, expert_num - num_fused_experts, device="cuda")
     topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
-    topk_weights = torch.randn((m, topk), device="cuda", dtype=dtype) / 10
+    topk_weights = torch.randn((m, topk + num_fused_experts), device="cuda", dtype=dtype) / 10
+    if num_fused_experts > 0:
+        topk_ids = F.pad(topk_ids, (0, 1), mode="constant", value=expert_num)
 
-    expert_to_tokens = torch.empty((expert_num, topk * m), dtype=torch.int32, device="cuda")
-    expert_to_weights = torch.empty((expert_num, topk * m), dtype=torch.float32, device="cuda")
+    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.int32, device="cuda")
+    expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda")
     moe_align(topk_ids=topk_ids, out=expert_to_tokens)
     expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
-    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk)
+    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_experts)
 
-    out1 = torch.zeros((m * topk, 2 * n), dtype=torch.bfloat16, device="cuda")
-    down_in = torch.zeros((m * topk, n), dtype=torch.bfloat16, device="cuda")
-    out2 = torch.zeros((m * topk, k), dtype=torch.bfloat16, device="cuda")
+    out1 = torch.zeros((m * (topk + 1), 2 * n), dtype=torch.bfloat16, device="cuda")
+    down_in = torch.zeros((m * (topk + 1), n), dtype=torch.bfloat16, device="cuda")
+    out2 = torch.zeros((m * (topk + 1), k), dtype=torch.bfloat16, device="cuda")
 
     for _ in range(test_count):
         input_tuples.append(
@@ -219,6 +225,7 @@ def worker(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
+    num_fused_experts: int,
     test_configs,
     queue,
 ):
@@ -235,6 +242,7 @@
                 use_fp8_w8a8=use_fp8_w8a8,
                 is_up=is_up,
                 block_shape=block_shape,
+                num_fused_experts=num_fused_experts,
                 **test_configs[index],
             )
             queue.put(cost_time)  # Put result in queue
@@ -302,6 +310,7 @@ def tuning_configs(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
+    num_fused_experts: int,
 ):
     os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
     best_config, best_cost_time = None, 10000000
@@ -325,6 +334,7 @@
                     use_fp8_w8a8,
                     is_up,
                     block_shape,
+                    num_fused_experts,
                     test_configs,
                     queue,
                 ),
@@ -359,6 +369,7 @@
                     use_fp8_w8a8,
                     is_up,
                     block_shape,
+                    num_fused_experts,
                     test_configs,
                     queue,
                 ),
@@ -393,15 +404,16 @@ def main(args):
     if config.architectures[0] == "Qwen3MoeForCausalLM":
         expert_num = config.num_experts
         topk_num = config.num_experts_per_tok
-        n = 2 * config.moe_intermediate_size // args.tp
+        n = config.moe_intermediate_size // args.tp
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         expert_num = config.n_routed_experts
         topk_num = config.num_experts_per_tok
-        n = 2 * config.moe_intermediate_size // args.tp
+        n = config.moe_intermediate_size // args.tp
     else:
         pass
 
     hidden_dim = getattr(config, "hidden_size", None) or config.text_config.hidden_size
+    print(n, hidden_dim)
     use_fp8_w8a8 = args.use_fp8_w8a8
     block_shape = None
     if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config:
@@ -424,6 +436,7 @@
                 "use_fp8_w8a8": use_fp8_w8a8,
                 "is_up": True,
                 "block_shape": block_shape,
+                "num_fused_experts": args.num_fused_experts,
             },
         )
         up_dict[m] = ans
@@ -431,7 +444,7 @@
         N=n * 2,
         K=hidden_dim,
         topk_num=topk_num,
-        expert_num=expert_num,
+        expert_num=expert_num + 1,
         mul_routed_weight=False,
         use_fp8_w8a8=use_fp8_w8a8,
         out_dtype=str(torch.bfloat16),
@@ -453,6 +466,7 @@
                 "use_fp8_w8a8": use_fp8_w8a8,
                 "is_up": False,
                 "block_shape": block_shape,
+                "num_fused_experts": args.num_fused_experts,
             },
         )
         down_dict[m] = ans
@@ -461,7 +475,7 @@
         N=hidden_dim,
         K=n,
         topk_num=1,
-        expert_num=expert_num,
+        expert_num=expert_num + 1,
         mul_routed_weight=True,
         use_fp8_w8a8=use_fp8_w8a8,
         out_dtype=str(torch.bfloat16),
@@ -474,5 +488,6 @@
     parser.add_argument("--model_dir", type=str, default="deepseek-ai/DeepSeek-R1")
     parser.add_argument("--tp", type=int, default=8)
     parser.add_argument("--use_fp8_w8a8", action="store_true")
+    parser.add_argument("--num_fused_experts", type=int, default=0)
     args = parser.parse_args()
     main(args)
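Note (not part of the patch): a minimal sanity-check sketch for the split Q/K rotary kernels introduced above. The shapes (1024 tokens, 16 query heads, a single rope K head, head_dim 64, i.e. DeepSeek-V2-style MLA), the bfloat16 dtype, the `ref_rotary` helper name, and the tolerances are all assumptions chosen for illustration; the reference function simply re-implements the same interleaved even/odd rotation that the kernels compute, in plain PyTorch.

import torch
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd


def ref_rotary(x, cos, sin):
    # x: (total_len, head_num, head_dim); cos/sin: (total_len, head_dim // 2)
    x0, x1 = x[..., 0::2], x[..., 1::2]
    c, s = cos[:, None, :], sin[:, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * c - x1 * s
    out[..., 1::2] = x0 * s + x1 * c
    return out


total_len, head_num_q, head_dim = 1024, 16, 64
q = torch.randn(total_len, head_num_q, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_len, 1, head_dim, device="cuda", dtype=torch.bfloat16)
t = torch.rand(total_len, head_dim // 2, device="cuda")
cos, sin = torch.cos(t).to(torch.bfloat16), torch.sin(t).to(torch.bfloat16)

ref_q, ref_k = ref_rotary(q, cos, sin), ref_rotary(k, cos, sin)
rotary_emb_fwd(q, k, cos, sin)  # rotates q and k in place via the two new kernels
torch.testing.assert_close(q, ref_q, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(k, ref_k, rtol=2e-2, atol=2e-2)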