diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..d16e9335b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third-party/cutlass"] + path = third-party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/docs/CN/source/getting_started/multimodal_model_quickstart.rst b/docs/CN/source/getting_started/multimodal_model_quickstart.rst new file mode 100644 index 000000000..cc3eaf724 --- /dev/null +++ b/docs/CN/source/getting_started/multimodal_model_quickstart.rst @@ -0,0 +1,11 @@ +..multimodal_model_quickstart.rst +------------------------- + +下载多模态模型(如llava系列、internvl系列、qwen_vl系列等)的模型以后,在终端使用下面的代码部署API服务: + +.. code-block:: console + + $ python -m lightllm.server.api_server --model_dir ~/models/llava-7b-chat --use_dynamic_prompt_cache --enable_multimodal + +.. note:: + 上面代码中的 ``--model_dir`` 参数需要修改为你本机实际的模型路径。 diff --git a/lightllm-kernel/CMakeLists.txt b/lightllm-kernel/CMakeLists.txt new file mode 100644 index 000000000..25a9855b6 --- /dev/null +++ b/lightllm-kernel/CMakeLists.txt @@ -0,0 +1,65 @@ +cmake_minimum_required(VERSION 3.22) +project(lightllm_kernel LANGUAGES CXX CUDA) + +# GPU 架构:缺省支持 A100(80)、Ampere(86)、Ada/L40s/4090(89)、Hopper(90), +if(NOT CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 80;86;89;90) +endif() + +# 找 PyTorch & Python +find_package(Torch REQUIRED) +find_package(Python REQUIRED COMPONENTS Development) +find_package(CUDAToolkit REQUIRED) + +# 收集 csrc 下的 .cpp/.cu +file(GLOB_RECURSE SRC_CPP CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/csrc/*.cpp") +file(GLOB_RECURSE SRC_CUDA CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/csrc/*.cu") + +# 编译生成 Python 扩展, _C.so +if (NOT TARGET _C) + add_library(_C SHARED ${SRC_CPP} ${SRC_CUDA}) + + # C++17 更方便调度宏 + target_compile_features(_C PRIVATE cxx_std_17) + target_include_directories(_C PRIVATE + ${TORCH_INCLUDE_DIRS} + ${CUDAToolkit_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/csrc + ${PROJECT_SOURCE_DIR}/../third-party/cutlass/include + ) + target_link_libraries(_C + PRIVATE + ${TORCH_LIBRARIES} + Python::Python + CUDA::cudart + CUDA::cuda_driver) + + + # 输出文件名 _C.so,无前缀 + set_target_properties(_C PROPERTIES + PREFIX "" + OUTPUT_NAME "_C" + BUILD_RPATH "\$ORIGIN;\$ORIGIN/../torch/lib" + INSTALL_RPATH "\$ORIGIN;\$ORIGIN/../torch/lib" + ) +endif() +# 安装:把 _C.so、Python 包和 csrc 一起拷到 site-packages +include(GNUInstallDirs) + +# 1) 计算 Python site-packages 路径 + +message(STATUS "Installing to ARCH = ${Python_SITEARCH}") +message(STATUS "Installing to PURE = ${Python_SITELIB}") + +# 2) 安装编译好的 _C.so 到 lightllm_kernel 目录 +install(TARGETS _C + LIBRARY DESTINATION ${Python_SITEARCH}/lightllm_kernel) + +# 3) 安装 Python 源码包 +install(DIRECTORY ${PROJECT_SOURCE_DIR}/lightllm_kernel + DESTINATION ${Python_SITELIB}) + +# 4) 安装 csrc 源码以供 JIT fallback +install(DIRECTORY ${PROJECT_SOURCE_DIR}/csrc + DESTINATION ${Python_SITELIB}/lightllm_kernel) diff --git a/lightllm-kernel/LICENSE b/lightllm-kernel/LICENSE new file mode 100644 index 000000000..7a4a3ea24 --- /dev/null +++ b/lightllm-kernel/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/lightllm-kernel/Makefile b/lightllm-kernel/Makefile new file mode 100644 index 000000000..5b7100bb6 --- /dev/null +++ b/lightllm-kernel/Makefile @@ -0,0 +1,14 @@ +.PHONY: build clean submodule + +SUBMODULE_DIR = third-party/cutlass + +submodule: + git submodule update --init --recursive + +build: submodule + # 8.0-> A100, 8.6-> A10, 8.9-> L40s/4090, 9.0+PTX-> Hopper + TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9;9.0+PTX" \ + python -m pip install -v . + +clean: + rm -rf build dist *.egg-info \ No newline at end of file diff --git a/lightllm-kernel/README-CH.md b/lightllm-kernel/README-CH.md new file mode 100644 index 000000000..647a594b8 --- /dev/null +++ b/lightllm-kernel/README-CH.md @@ -0,0 +1,42 @@ +# LightLLM-Kernel + +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +lightllm-kernel 是大模型推理系统 LightLLM 的 CUDA 算子库。它提供了在大型模型推理过程中所需的一系列自定义 GPU 运算算子,以加速关键步骤的计算。 + +## 功能列表 + +| Module | Description | +|--------------|-------------------------------------------------------------------------------------------------| +| **Attention** | Optimized Multi-Head Attention kernels with fused QKV operations and efficient softmax | +| **MoE** | Expert routing and computation kernels for Mixture-of-Experts architectures | +| **Quant** | Low-precision quantization support (INT8/INT4) for weights and activations | +| **Extensions**| Continuous expansion of optimized operations for emerging model architectures | + +## 安装方法 + +lightllm_kernel 提供了静态编译以及JIT(Just-In-Time)动态编译的安装方式。推荐使用静态编译安装以获得最佳性能,同时也支持开发者使用可编辑安装进行开发调试。 + +### System Requirements +- NVIDIA GPU with Compute Capability ≥ 7.0 (Volta+) +- CUDA 11.8 or higher +- Python 3.8+ + +### Installation Methods + +#### Static Compilation (Recommended) +```bash +git clone https://github.com/YourUsername/lightllm_kernel.git +cd lightllm_kernel +make build +# Alternative using pip +pip install . +``` + +## 贡献指南 +欢迎社区开发者为 lightllm_kernel 做出贡献!如果您计划新增自定义算子或改进现有功能,请参考以下指南: +- 新增算子实现:在 csrc/ 目录下添加您的 CUDA/C++ 源码文件,添加时建议参考现有算子的代码风格和结构。 +- 注册Python接口:在 csrc/ops_bindings.cpp中,将新增的算子通过 PyBind11 或 TORCH_LIBRARY 等机制注册到 Python 接口。 +- 导出算子到Python模块:在lightllm_kernel/ops/__init__.py只添加相应的导出代码,使新算子包含在 lightllm_kernel.ops 模块中。 +- 本地测试:开发完成后,请在本地对您的更改进行测试。您可以编译安装新的版本并编写简单的脚本调用新算子,检查其功能和性能是否符合预期。如果项目附带了测试用例,也请运行所有测试确保不引入回归。 +- \ No newline at end of file diff --git a/lightllm-kernel/README.md b/lightllm-kernel/README.md new file mode 100644 index 000000000..9ce4bce41 --- /dev/null +++ b/lightllm-kernel/README.md @@ -0,0 +1,39 @@ +# LightLLM-Kernel + +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +LightLLM-Kernel is a high-performance CUDA kernel library powering the LightLLM inference system. It provides optimized GPU implementations for critical operations in large language model (LLM) inference, delivering significant performance improvements through carefully crafted CUDA kernels. 
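+
+A minimal usage sketch is shown below; it simply mirrors how `rmsnorm_bf16` is exercised in `benchmark/bench_rms_norm.py` and assumes the package has already been installed as described under Installation:
+
+```python
+import torch
+from lightllm_kernel.ops import rmsnorm_bf16
+
+# bf16 activations and weight on GPU, using the same shapes as the benchmark script
+x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
+
+y = rmsnorm_bf16(x, w, 1e-6)  # RMSNorm over the last dimension, matching the torch reference in the benchmark
+```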
+ +## Project Overview + +LightLLM-Kernel serves as the computational backbone for LightLLM framework, offering: +- **Custom CUDA Kernels**: Highly optimized implementations for transformer-based model operations +- **Memory Efficiency**: Reduced memory footprint through advanced quantization techniques +- **Scalability**: Support for large model architectures including MoE (Mixture-of-Experts) models + +## Key Features + +### Core Modules +| Module | Description | +|--------------|-------------------------------------------------------------------------------------------------| +| **Attention** | Optimized Multi-Head Attention kernels with fused QKV operations and efficient softmax | +| **MoE** | Expert routing and computation kernels for Mixture-of-Experts architectures | +| **Quant** | Low-precision quantization support (INT8/INT4) for weights and activations | +| **Extensions**| Continuous expansion of optimized operations for emerging model architectures | + +## Installation + +### System Requirements +- NVIDIA GPU with Compute Capability ≥ 7.0 (Volta+) +- CUDA 11.8 or higher +- Python 3.8+ + +### Installation Methods + +#### Static Compilation (Recommended) +```bash +git clone https://github.com/YourUsername/lightllm_kernel.git +cd lightllm_kernel +make build +# Alternative using pip +pip install . \ No newline at end of file diff --git a/lightllm-kernel/benchmark/bench_quant_per_token_bf16_fp8.py b/lightllm-kernel/benchmark/bench_quant_per_token_bf16_fp8.py new file mode 100644 index 000000000..cd2eb291f --- /dev/null +++ b/lightllm-kernel/benchmark/bench_quant_per_token_bf16_fp8.py @@ -0,0 +1,71 @@ +import time +import torch +import itertools +from typing import Optional, Tuple +from vllm import _custom_ops as ops +from sgl_kernel import sgl_per_token_quant_fp8 + +try: + from lightllm_kernel.ops import per_token_quant_bf16_fp8 +except ImportError: + raise ImportError("lightllm-kernel op per_token_quant_bf16_fp8 not found.") + +fp8_type_ = torch.float8_e4m3fn + + +def vllm_per_token_quant_fp8( + input: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True) + + +def sglang_per_token_quant_fp8( + input: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32) + output = torch.empty_like(input, device=input.device, dtype=fp8_type_) + sgl_per_token_quant_fp8(input, output, scale) + + return output, scale + + +def lightllm_per_token_quant_fp8( + input: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + return per_token_quant_bf16_fp8(input) + + +def dequantize(q: torch.Tensor, scale: torch.Tensor): + return q.to(torch.bfloat16) * scale.view(-1, *((1,) * (q.dim() - 1))) + + +def benchmark(fn, name, inp, iterations=200): + for _ in range(20): + q, s = fn(inp) + torch.cuda.synchronize() + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + for _ in range(iterations): + q, s = fn(inp) + ender.record() + torch.cuda.synchronize() + avg_ms = starter.elapsed_time(ender) / iterations + + q, s = fn(inp) + recon = dequantize(q, s) + err = recon - inp.to(torch.bfloat16) + mse = err.pow(2).mean().item() + max_err = err.abs().max().item() + + print(f"{name:20s} | latency: {avg_ms:7.3f} ms | MSE: {mse:.3e} | MaxErr: {max_err:.3e}") + + +if __name__ == "__main__": + batch, seq_len = 64, 4096 + device = "cuda" + inp = torch.randn(batch, seq_len, device=device, dtype=torch.bfloat16) + + 
benchmark(vllm_per_token_quant_fp8, "vllm_ops", inp) + benchmark(sglang_per_token_quant_fp8, "sgl_kernel", inp) + benchmark(lightllm_per_token_quant_fp8, "lightllm_kernel", inp) diff --git a/lightllm-kernel/benchmark/bench_rms_norm.py b/lightllm-kernel/benchmark/bench_rms_norm.py new file mode 100644 index 000000000..c591c53cb --- /dev/null +++ b/lightllm-kernel/benchmark/bench_rms_norm.py @@ -0,0 +1,78 @@ +import time +import torch +from typing import Optional, Tuple, Union + +from vllm import _custom_ops as vllm_ops +from lightllm_kernel.ops import rmsnorm_bf16 as lightllm_rms_norm +from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm as triton_rms_norm + + +def vllm_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + residual: Optional[torch.Tensor] = None, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def torch_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float): + mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + inv_std = torch.rsqrt(mean_sq + eps) + out = x * inv_std * w + return out + + +def benchmark(fn, name, x, w, eps, iterations=200): + for _ in range(10): + _ = fn(x, w, eps) + torch.cuda.synchronize() + + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + starter.record() + for _ in range(iterations): + _ = fn(x, w, eps) + ender.record() + torch.cuda.synchronize() + latency_ms = starter.elapsed_time(ender) / iterations + + y_ref = torch_rmsnorm(x, w, eps) + y_out = fn(x, w, eps) + err = y_out - y_ref + mse = err.pow(2).mean().item() + max_err = err.abs().max().item() + + print(f"{name:20s} | latency: {latency_ms:7.3f} ms | MSE: {mse:.3e} | MaxErr: {max_err:.3e}") + + +if __name__ == "__main__": + + batch, dim = 64, 1024 + eps = 1e-6 + device = "cuda" + + x = torch.randn(batch, dim, device=device, dtype=torch.bfloat16) + w = torch.randn(dim, device=device, dtype=torch.bfloat16) + + benchmark(torch_rmsnorm, "torch_rmsnorm", x, w, eps) + benchmark(lightllm_rms_norm, "lightllm_rms_norm", x, w, eps) + benchmark(triton_rms_norm, "triton_rms_norm", x, w, eps) + benchmark(vllm_rmsnorm, "vllm_rmsnorm", x, w, eps) diff --git a/lightllm-kernel/benchmark/bench_tp_norm.py b/lightllm-kernel/benchmark/bench_tp_norm.py new file mode 100644 index 000000000..53599ebb3 --- /dev/null +++ b/lightllm-kernel/benchmark/bench_tp_norm.py @@ -0,0 +1,86 @@ +# bench_tp_norm_tp4.py +import os +import torch +import torch.distributed as dist +from types import SimpleNamespace + +from lightllm_kernel.ops import ( + rmsnorm_bf16, + pre_tp_norm_bf16, + post_tp_norm_bf16, +) + + +def init_dist(): + dist.init_process_group("nccl", init_method="env://") + rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(rank) + return rank, dist.get_world_size() + + +def tp_norm_cuda(x, w, cfg): + if cfg.tp_world == 1: + return rmsnorm_bf16(x, w, cfg.eps) + + var_local = pre_tp_norm_bf16(x) + dist.all_reduce(var_local, op=dist.ReduceOp.SUM) + return post_tp_norm_bf16(x, w, var_local, cfg.global_embed, cfg.eps) + + +def tp_norm_ref(x, w, cfg): + x32 = x.to(torch.float32) + var = 
x32.pow(2).sum(-1, keepdim=True) + if cfg.tp_world > 1: + dist.all_reduce(var, op=dist.ReduceOp.SUM) + x32 = x32 * torch.rsqrt(var / cfg.global_embed + cfg.eps) + return (w.to(torch.float32) * x32).to(x.dtype) + + +def bench(fn, tag, x, w, cfg, iters=200): + for _ in range(20): + fn(x, w, cfg) + torch.cuda.synchronize() + t0 = torch.cuda.Event(True) + t1 = torch.cuda.Event(True) + t0.record() + for _ in range(iters): + fn(x, w, cfg) + t1.record() + torch.cuda.synchronize() + ms = t0.elapsed_time(t1) / iters + + ref = tp_norm_ref(x, w, cfg).to(torch.float32) + out = fn(x, w, cfg).to(torch.float32) + mse = (out - ref).pow(2).mean().item() + err = (out - ref).abs().max().item() + + if dist.get_rank() == 0: + print(f"{tag:18s}| {ms:6.3f} ms | MSE {mse:.3e} | MaxErr {err:.3e}") + + +if __name__ == "__main__": + rank, world = init_dist() + + tp_world = 4 + pad_heads, dim_h = 32, 1024 + local_embed = pad_heads * dim_h + global_embed = local_embed * tp_world + tokens = 2048 + eps = 1e-6 + + x = torch.randn(tokens, local_embed, device=f"cuda:{rank}", dtype=torch.bfloat16) + w = torch.randn(local_embed, device=f"cuda:{rank}", dtype=torch.bfloat16) + + cfg = SimpleNamespace(tp_world=tp_world, global_embed=global_embed, eps=eps) + + if rank == 0: + print( + f"tp={tp_world}, tokens={tokens}, local_embed={local_embed}, " f"global_embed={global_embed}, dtype=bf16\n" + ) + dist.barrier() + + bench(tp_norm_ref, "torch_ref", x, w, cfg) + bench(tp_norm_cuda, "cuda_kernel", x, w, cfg) + + dist.destroy_process_group() +# python -m torch.distributed.run --nproc_per_node=4 bench_tp_norm.py diff --git a/lightllm-kernel/csrc/allgather/all_gather.cu b/lightllm-kernel/csrc/allgather/all_gather.cu new file mode 100644 index 000000000..56e4a863d --- /dev/null +++ b/lightllm-kernel/csrc/allgather/all_gather.cu @@ -0,0 +1,150 @@ +#include +#include +#include +#include + +#include "ops_common.h" +#include "all_gather.cuh" + +namespace lightllm { +namespace ops { +// Fake pointer type, must match fptr_t type in ops.h. +// We use this type alias to indicate when pointers are passed in as int64_t. +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +fptr_t init_custom_gather_ar(const std::vector& fake_ipc_ptrs, + torch::Tensor& rank_data, int64_t rank, + bool full_nvlink) { + int world_size = fake_ipc_ptrs.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + vllm::Signal* ipc_ptrs[8]; + for (int i = 0; i < world_size; i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + return (fptr_t) new vllm::CustomAllgather(ipc_ptrs, rank_data.data_ptr(), + rank_data.numel(), rank, world_size, + full_nvlink); +} + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. 
A[:, 1:, 1:]: Not OK + */ +bool _is_weak_contiguous_gather(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); +} + +/** + * Performs an out-of-place allgather and stores result in out. + * + * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. + * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first + * copied into _reg_buffer. + */ +void all_gather(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + + fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK(_is_weak_contiguous_gather(out)); + TORCH_CHECK(_is_weak_contiguous_gather(inp)); + auto input_size = inp.numel() * inp.element_size(); + auto reg_buffer = reinterpret_cast(_reg_buffer); + if (reg_buffer) { + TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, + cudaMemcpyDeviceToDevice, stream)); + } else { + reg_buffer = inp.data_ptr(); + } + switch (out.scalar_type()) { + case at::ScalarType::Float: { + fa->allgather(stream, reinterpret_cast(reg_buffer), + reinterpret_cast(out.data_ptr()), + inp.numel()); + break; + } + case at::ScalarType::Half: { + fa->allgather(stream, reinterpret_cast(reg_buffer), + reinterpret_cast(out.data_ptr()), inp.numel()); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allgather( + stream, reinterpret_cast(reg_buffer), + reinterpret_cast(out.data_ptr()), inp.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "custom allgather only supports float32, float16 and bfloat16"); + } +} + +void allgather_dispose(fptr_t _fa) { + delete reinterpret_cast(_fa); +} + +int64_t meta_size() { return sizeof(vllm::Signal); } + +void allgather_register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { + auto fa = reinterpret_cast(_fa); + TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); + void* ipc_ptrs[8]; + for (int i = 0; i < fake_ipc_ptrs.size(); i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + fa->register_buffer(ipc_ptrs); +} + +// Use vector to represent byte data for python binding compatibility. +std::tuple, std::vector> +allgather_get_graph_buffer_ipc_meta(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); + std::vector bytes(handle.begin(), handle.end()); + return std::make_tuple(bytes, offsets); +} + +// Use vector to represent byte data for python binding compatibility. 
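+// Each inner vector carries one rank's concatenated cudaIpcMemHandle_t bytes; they are re-packed
+// into std::string objects before being forwarded to CustomAllgather::register_graph_buffers.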
+void allgather_register_graph_buffers(fptr_t _fa, + const std::vector>& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); + std::vector bytes; + bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + bytes.reserve(handles.size()); + fa->register_graph_buffers(bytes, offsets); +} + + } // namespace ops +} // namespace lightllm \ No newline at end of file diff --git a/lightllm-kernel/csrc/allgather/all_gather.cuh b/lightllm-kernel/csrc/allgather/all_gather.cuh new file mode 100644 index 000000000..99cb579be --- /dev/null +++ b/lightllm-kernel/csrc/allgather/all_gather.cuh @@ -0,0 +1,287 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include "all_reduce.cuh" + +// #define CUDACHECK(cmd) \ +// do { \ +// cudaError_t e = cmd; \ +// if (e != cudaSuccess) { \ +// printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ +// cudaGetErrorString(e)); \ +// exit(EXIT_FAILURE); \ +// } \ +// } while (0) + +namespace vllm { + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct gather_packed_t { + // the (P)acked type for load/store + using P = array_t; +}; + +template +__global__ void __launch_bounds__(512, 1) + custom_all_gather_kernel(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename gather_packed_t::P; + multi_gpu_barrier(sg, self_sg, rank); + for (int idx = tid; idx < size; idx += stride) { + #pragma unroll + for (int step = 0; step < ngpus; step ++) { + int src_rank = (rank - step + ngpus) % ngpus; // 当前步骤中数据来源的进程 + P* ptr = (P*)_dp->ptrs[src_rank]; + int dst_offset = src_rank * size; // 数据在 recv_buf 中的存储位置 + // 从 src_rank 的 handle 中读取数据,并存储到 recv_buf + int dst_idx = dst_offset + idx; + ((P*)result)[dst_idx] = ptr[idx]; + } + } + multi_gpu_barrier(sg, self_sg, rank); + +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + +class CustomAllgather { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + RankSignals sg_; + // Stores an map from a pointer to its peer pointters from all ranks. + std::unordered_map buffers_; + Signal* self_sg_; + + // Stores rank data from all ranks. This is mainly for cuda graph purposes. + // For cuda graph to work, all kernel arguments must be fixed during graph + // capture time. However, the peer pointers are not known during graph capture + // time. Therefore, during capture, we increment the rank data pointer and use + // that as the argument to the kernel. The kernel arguments are stored in + // graph_unreg_buffers_. The actual peer pointers will be filled in at the + // memory pointed to by the pointers in graph_unreg_buffers_ when + // the IPC handles are exchanged between ranks. + // + // The overall process looks like this: + // 1. Graph capture. + // 2. Each rank obtains the IPC handles for each addresses used during cuda + // graph capture using get_graph_buffer_ipc_meta. + // 3. (In Python) all gather the IPC handles. + // 4. Obtain the peer pointers by opening the IPC handles, and store them in + // the rank data array at corresponding positions. 
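+  // d_rank_data_base_ advances through the rank_data arena as buffers are registered;
+  // d_rank_data_end_ marks its capacity (checked in check_rank_data_capacity).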
+ RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * Signals are an array of ipc-enabled buffers from all ranks. + * For each of the buffer, the layout is as follows: + * | -- sizeof(Signal) -- | ------ a few MB ----- | + * The first section is for allgather synchronization, and the second section + * is for storing the intermediate results required by some allgather algos. + * + * Note: this class does not own any device memory. Any required buffers + * are passed in from the constructor. + */ + CustomAllgather(Signal** signals, void* rank_data, size_t rank_data_sz, + int rank, int world_size, bool full_nvlink = true) + : rank_(rank), + world_size_(world_size), + full_nvlink_(full_nvlink), + self_sg_(signals[rank]), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + sg_.signals[i] = signals[i]; + } + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair> get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + /** + * Register already-shared IPC pointers. + */ + void register_buffer(void** ptrs) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + data.ptrs[i] = ptrs[i]; + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[ptrs[rank_]] = d_data; + } + + // Note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allgather, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. 
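+  // Opens each peer rank's IPC handle, builds one RankData entry per captured buffer, and copies
+  // the entries into the device-side rank data array before clearing graph_unreg_buffers_.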
+ void register_graph_buffers( + const std::vector& handles, + const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * Performs allgather, assuming input has already been registered. + * + * Block and grid default configs are results after careful grid search. Using + * 36 blocks give the best or close to the best runtime on the devices I + * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only + * take a small amount of SMs. Not quite sure the underlying reason, but my + * guess is that too many SMs will cause contention on NVLink bus. + */ + template + void allgather(cudaStream_t stream, T* input, T* output, int size, + int threads = 512, int block_limit = 36) { + auto d = gather_packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allgather currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + size /= d; + // auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + KL(ngpus, custom_all_gather_kernel); \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allgather only supports num gpus in (2,4,6,8). 
Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllgather() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void vllm::CustomAllgather::allgather(cudaStream_t, half *, + half *, int, int, int); +*/ +} // namespace vllm diff --git a/lightllm-kernel/csrc/allgather/all_reduce.cuh b/lightllm-kernel/csrc/allgather/all_reduce.cuh new file mode 100644 index 000000000..6be4d4f2b --- /dev/null +++ b/lightllm-kernel/csrc/allgather/all_reduce.cuh @@ -0,0 +1,516 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace vllm { + +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; +struct Signal { + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; +}; + +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; + +struct __align__(16) RankSignals { Signal* signals[8]; }; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float& assign_add(float& a, float b) { return a += b; } + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = 
upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" + : "=r"(flag) + : "l"(flag_addr)); +#endif + return flag; +} + +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +// is_start: whether this is the very first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. + if (threadIdx.x < ngpus) { + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. + auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } + } + if constexpr (is_start || need_fence) __syncthreads(); +} + +template +DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast
<P>
(tmp); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + multi_gpu_barrier(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } + multi_gpu_barrier(sg, self_sg, rank); +} + +template +DINLINE P* get_tmp_buf(Signal* sg) { + return (P*)(((Signal*)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; + const P* ptrs[ngpus]; + P* tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); + } + auto tmp_out = tmps[0]; + multi_gpu_barrier(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + multi_gpu_barrier(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + RankSignals sg_; + // Stores an map from a pointer to its peer pointters from all ranks. + std::unordered_map buffers_; + Signal* self_sg_; + + // Stores rank data from all ranks. This is mainly for cuda graph purposes. + // For cuda graph to work, all kernel arguments must be fixed during graph + // capture time. However, the peer pointers are not known during graph capture + // time. Therefore, during capture, we increment the rank data pointer and use + // that as the argument to the kernel. The kernel arguments are stored in + // graph_unreg_buffers_. The actual peer pointers will be filled in at the + // memory pointed to by the pointers in graph_unreg_buffers_ when + // the IPC handles are exchanged between ranks. + // + // The overall process looks like this: + // 1. Graph capture. + // 2. Each rank obtains the IPC handles for each addresses used during cuda + // graph capture using get_graph_buffer_ipc_meta. + // 3. (In Python) all gather the IPC handles. + // 4. Obtain the peer pointers by opening the IPC handles, and store them in + // the rank data array at corresponding positions. + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * Signals are an array of ipc-enabled buffers from all ranks. + * For each of the buffer, the layout is as follows: + * | -- sizeof(Signal) -- | ------ a few MB ----- | + * The first section is for allreduce synchronization, and the second section + * is for storing the intermediate results required by some allreduce algos. + * + * Note: this class does not own any device memory. Any required buffers + * are passed in from the constructor. 
+ */ + CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, + int rank, int world_size, bool full_nvlink = true) + : rank_(rank), + world_size_(world_size), + full_nvlink_(full_nvlink), + self_sg_(signals[rank]), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + sg_.signals[i] = signals[i]; + } + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair> get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + /** + * Register already-shared IPC pointers. + */ + void register_buffer(void** ptrs) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + data.ptrs[i] = ptrs[i]; + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[ptrs[rank_]] = d_data; + } + + // Note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers( + const std::vector& handles, + const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * Performs allreduce, assuming input has already been registered. 
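+   * (Unregistered inputs are only tolerated while a CUDA graph capture is active; their peer
+   * pointers are filled in later through register_graph_buffers.)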
+ * + * Block and grid default configs are results after careful grid search. Using + * 36 blocks give the best or close to the best runtime on the devices I + * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only + * take a small amount of SMs. Not quite sure the underlying reason, but my + * guess is that too many SMs will cause contention on NVLink bus. + */ + template + void allreduce(cudaStream_t stream, T* input, T* output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, + half *, int, int, int); +*/ +} // namespace vllm diff --git a/lightllm-kernel/csrc/attention/decode_attention_kernel.cu b/lightllm-kernel/csrc/attention/decode_attention_kernel.cu new file mode 100644 index 000000000..3fd4ce336 --- /dev/null +++ b/lightllm-kernel/csrc/attention/decode_attention_kernel.cu @@ -0,0 +1,569 @@ +#include +#include // need for FLT_MAX +#include +#include +#include +#include "ops_common.h" +#include +#include +#include +#include + +namespace lightllm { +namespace ops { + +# include +#define LIGHT_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define LIGHT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, LIGHT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +template +__device__ inline float tofloat(T value) { + return static_cast(value); +} + +// Specialization for __half +template <> +__device__ inline float tofloat<__half>(__half value) { + return __half2float(value); +} + +// Specialization for __nv_bfloat16 +template <> +__device__ inline float tofloat<__nv_bfloat16>(__nv_bfloat16 value) { + return __bfloat162float(value); +} + +template +struct BytesToType; + +template <> +struct BytesToType<2> +{ + using type = uint16_t; +}; +template <> +struct BytesToType<4> +{ + using type = uint32_t; +}; +template <> +struct BytesToType<8> +{ + using type = uint64_t; +}; +template <> +struct BytesToType<16> +{ + using type = float4; +}; + +template +__device__ inline void copy(const void* local, void* data) +{ + using T = typename BytesToType::type; + + const T* in = static_cast(local); + T* out = static_cast(data); + *out = *in; +} + +template +__device__ inline +float attn_thread_group_dot(T* local_q, T* local_k) +{ + // Helper function for QK Dot. + // [TODO] It should be optimized by type fp32x4. + + float qk = 0.0f; +# pragma unroll + for(int32_t i = 0; i < ELEMENT_NUM; i++) { + qk += tofloat(local_q[i]) * tofloat(local_k[i]); + } +#pragma unroll + for (int32_t mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +__device__ inline +float attn_block_reduce_max(float reducing, float* shared_mem) +{ + // Helper function for reduce softmax qkmax. + constexpr int32_t WARP_SIZE = 32; + const int32_t lane_id = threadIdx.x % WARP_SIZE; + const int32_t warp_id = threadIdx.x / WARP_SIZE; + +# pragma unroll + for (int32_t mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + reducing = fmaxf(reducing, __shfl_xor_sync(uint32_t(-1), reducing, mask)); + } + + if (lane_id == 0) { + shared_mem[warp_id] = reducing; + } + __syncthreads(); + + if (lane_id < WPT) reducing = shared_mem[lane_id]; + else reducing = -FLT_MAX; + +# pragma unroll + for (int32_t mask = WPT / 2; mask >= 1; mask /= 2) { + reducing = fmaxf(reducing, __shfl_xor_sync(uint32_t(-1), reducing, mask)); + } + + reducing = __shfl_sync(uint32_t(-1), reducing, 0); + return reducing; +} + +template +__device__ inline +float attn_block_reduce_sum(float reducing, float *shared_mem) +{ + // Helper function for reduce softmax exp sum. 
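+    // Two-stage block reduction (same pattern as attn_block_reduce_max above):
+    // each warp butterfly-reduces its lanes with __shfl_xor_sync and lane 0
+    // writes the partial to shared_mem[warp_id]; every warp then re-reduces the
+    // WPT partials and broadcasts lane 0, so all threads hold the block-wide sum.
+    // With the TPB = 256 launches below, WPT = 256 / 32 = 8.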
+ constexpr int32_t WARP_SIZE = 32; + const int32_t lane_id = threadIdx.x % WARP_SIZE; + const int32_t warp_id = threadIdx.x / WARP_SIZE; + +# pragma unroll + for (int32_t mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + reducing += __shfl_xor_sync(uint32_t(-1), reducing, mask); + } + + if (lane_id == 0) shared_mem[warp_id] = reducing; + __syncthreads(); + + if (lane_id < WPT) reducing = shared_mem[lane_id]; + +# pragma unroll + for (int32_t mask = WPT / 2; mask >= 1; mask /= 2) { + reducing += __shfl_xor_sync(uint32_t(-1), reducing, mask); + } + reducing = __shfl_sync(uint32_t(-1), reducing, 0); + return reducing; +} + +template< + int32_t HEAD_SIZE, + int32_t THREAD_GROUP_SIZE, // how many threads inside a group + int32_t TPB, + int32_t QUANT_GROUP, + typename T> +__global__ +void dynamic_batching_decoding_cache_attention_fp16_kernel( + T* __restrict__ output, // [context_lens, num_heads..., head_size] + + const T* __restrict__ query, // [seq_lens, num_heads..., head_size] + const int8_t* k_cache, // [max_token, num_kv_heads, head_size] + const T* k_scale, // [max_token, num_kv_heads, head_size / quant_group(8)] + const int8_t* v_cache, // [max_token, num_kv_heads, head_size] + const T* v_scale, // [max_token, num_kv_heads, head_size / quant_group(8)] + + const float attn_scale, + + const int64_t output_stride_s, + const int64_t output_stride_h, + + const int64_t query_stride_s, + const int64_t query_stride_h, + + const int64_t kcache_stride_s, + const int64_t kcache_stride_h, + + const int64_t vcache_stride_s, + const int64_t vcache_stride_h, + + const int32_t * __restrict__ b_seq_len, + const int32_t * __restrict__ b_req_idx, + const int32_t * __restrict__ req_to_tokens, + const int64_t req_to_tokens_stride, + const int64_t max_len_in_batch, + const int64_t gqa_group_size) { + + /* --- Decoding Attention Kernel Implementation --- */ + constexpr int64_t WARP_SIZE = 32; // warp size + constexpr int64_t WPT = TPB / WARP_SIZE; // warp per thread block, TPB for Thread per block 4, block_size + constexpr int64_t GPW = WARP_SIZE / THREAD_GROUP_SIZE; // thread group per warp 4 + constexpr int64_t GPT = WARP_SIZE / THREAD_GROUP_SIZE * WPT; // thread group per thread block 16 + + // const int64_t num_heads = gridDim.x; + const int64_t head_idx = blockIdx.x; + const int64_t batch_idx = blockIdx.y; + + const int64_t seq_len = b_seq_len[batch_idx]; + const int64_t cur_req_idx = b_req_idx[batch_idx]; + const int32_t * b_start_loc = req_to_tokens + cur_req_idx * req_to_tokens_stride; + + constexpr int64_t VEC_SIZE = 16 / sizeof(T); // 128 bits, 这个是 cuda 能操作的最大的一个单位的数吧,8 + + // ------------------------------------------------ // + // Step 1. Load Q into Thread Reg. + constexpr int64_t VEC_LEN = (HEAD_SIZE / VEC_SIZE) / THREAD_GROUP_SIZE; // 128 / 8 / 8 = 2 + + static_assert((HEAD_SIZE / THREAD_GROUP_SIZE) % VEC_SIZE == 0); + static_assert(HEAD_SIZE % THREAD_GROUP_SIZE == 0); + static_assert(QUANT_GROUP == 8); + + constexpr int64_t QUANT_GROUP_SHIFT = 3; + + // The elements in Q, K, and V will be evenly distributed across each thread group. 
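+    // Worked example for the <HEAD_SIZE=128, THREAD_GROUP_SIZE=8, TPB=256> launch
+    // used below with a 16-bit T: VEC_SIZE = 16 / 2 = 8 (128-bit accesses are the
+    // widest single load/store a CUDA thread issues), VEC_LEN = (128 / 8) / 8 = 2,
+    // so each thread of a group holds 16 of the head's 128 Q elements via two
+    // strided 128-bit loads; WPT = 256 / 32 = 8 warps and GPW = 32 / 8 = 4 groups
+    // per warp give GPT = 32 thread groups, i.e. 32 context tokens per iteration.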
+ T local_q[VEC_SIZE * VEC_LEN]; // 2 * 8 + + const int64_t warp_id = threadIdx.x / WARP_SIZE; + const int64_t warp_lane_id = threadIdx.x % WARP_SIZE; + const int64_t group_id = warp_lane_id / THREAD_GROUP_SIZE; + const int64_t group_lane_id = warp_lane_id % THREAD_GROUP_SIZE; + const int64_t kv_head_idx = head_idx / gqa_group_size; + + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + // copy 128(16 * 8) bits from Q to Local Q + + // 这个地方是错开间隔读取的,不知道如果设置成为连续位置读取会不会一样呢? + copy( + &query[ + batch_idx * query_stride_s + + head_idx * query_stride_h + + (group_lane_id + i * THREAD_GROUP_SIZE) * VEC_SIZE + ], + &local_q[i * VEC_SIZE]); + } + // ------------------------------------------------ // + // Step 2. Solve QK Dot + + const int64_t context_len = seq_len; + extern __shared__ float logits[]; + float qk_max = -FLT_MAX; + + for (int64_t base_id = warp_id * GPW; base_id < context_len; base_id += GPT) { + int8_t local_k_quant[VEC_SIZE * VEC_LEN]; + T local_k[VEC_SIZE * VEC_LEN]; + T local_k_scale[VEC_LEN]; + const int64_t context_id = base_id + group_id; + const int64_t mem_context_id = *(b_start_loc + context_id); + + // all thread groups within a warp must be launched together. + if (context_id >= context_len){ + memset(local_k, 0, sizeof(local_k)); + } else { + const int64_t key_offset + = (mem_context_id) * kcache_stride_s + + kv_head_idx * kcache_stride_h + + group_lane_id * VEC_SIZE; + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + // copy 128(16 * 8) bits from K to Local K + const int64_t key_idx = key_offset + i * THREAD_GROUP_SIZE * VEC_SIZE; + copy(&k_cache[key_idx], &local_k_quant[i * VEC_SIZE]); + + const int64_t key_scale_idx = key_idx >> QUANT_GROUP_SHIFT; + local_k_scale[i] = k_scale[key_scale_idx]; + } + + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + #pragma unroll + for (int64_t j = 0; j < VEC_SIZE; j++) { + local_k[i * VEC_SIZE + j] + = local_k_scale[i] * (T)local_k_quant[i * VEC_SIZE + j]; + } + } + } + + // Ready for QK Dot + const float qk_dot + = attn_scale + * attn_thread_group_dot(local_q, local_k); + + if (group_lane_id == 0 && context_id < context_len) { + logits[context_id] = qk_dot; + qk_max = fmaxf(qk_dot, qk_max); + } + } + + // ------------------------------------------------ // + // Step 3. Softmax + + __shared__ float red_smem[WPT]; + + qk_max = attn_block_reduce_max(qk_max, red_smem); + + float exp_sum = 0.0f; + for (int64_t context_id = threadIdx.x; context_id < context_len; context_id += TPB){ + logits[context_id] -= qk_max; + logits[context_id] = exp(logits[context_id]); + exp_sum += logits[context_id]; + } + + static_assert(WPT == 2 || WPT == 4 || WPT == 8 || WPT == 16 || WPT == 32 || WPT == 64); + exp_sum = attn_block_reduce_sum(exp_sum, red_smem); + + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int64_t context_id = threadIdx.x; context_id < context_len; context_id += TPB) { + logits[context_id] *= inv_sum; + } + __syncthreads(); // Must have this. + + // ------------------------------------------------ // + // Step 4. Solve logits * V + + int8_t local_v_quant[VEC_SIZE * VEC_LEN]; + float local_v[VEC_SIZE * VEC_LEN]; + T local_v_scale[VEC_LEN]; + + #pragma unroll + for(int32_t i = 0; i < VEC_SIZE * VEC_LEN; i++) { + local_v[i] = 0; + } + + for (int64_t base_id = warp_id * GPW; base_id < context_len; base_id += GPT) { + const int64_t context_id = base_id + group_id; + const int64_t mem_context_id = *(b_start_loc + context_id); + // all thread groups within a warp must be launched together. 
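+        // Dequantization: with QUANT_GROUP = 8, eight consecutive int8 cache
+        // entries share one scale, so the scale index is the element index >> 3
+        // (QUANT_GROUP_SHIFT). Each dequantized V element, v_scale * (float)v_int8,
+        // is weighted by its softmax probability logits[context_id] and
+        // accumulated in fp32 registers.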
+ if (context_id < context_len){ + const int64_t value_offset + = (mem_context_id) * vcache_stride_s + + kv_head_idx * vcache_stride_h + + group_lane_id * VEC_SIZE; + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + // copy 128(16 * 8) bits from V to Local V + const int64_t value_idx = value_offset + i * THREAD_GROUP_SIZE * VEC_SIZE; + copy(&v_cache[value_idx], &local_v_quant[i * VEC_SIZE]); + + const int64_t value_scale_idx = value_idx >> QUANT_GROUP_SHIFT; + local_v_scale[i] = v_scale[value_scale_idx]; + } + + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + #pragma unroll + for (int64_t j = 0; j < VEC_SIZE; j++) { + local_v[i * VEC_SIZE + j] += (tofloat(local_v_scale[i]) + * (float)local_v_quant[i * VEC_SIZE + j] + * logits[context_id]); + } + } + } + } + + #pragma unroll + for (int32_t i = 0; i < VEC_SIZE * VEC_LEN; i++) { + #pragma unroll + for (int32_t mask = THREAD_GROUP_SIZE; mask <= WARP_SIZE >> 1; mask = mask << 1) { + local_v[i] += __shfl_xor_sync(uint32_t(-1), local_v[i], mask); + } + } + + __syncthreads(); + + // do some reuse + for (int64_t i = threadIdx.x; i < HEAD_SIZE; i += TPB){ + logits[i] = 0; + } + + __syncthreads(); + + if (warp_lane_id < THREAD_GROUP_SIZE) { + #pragma unroll + for (int32_t i = 0; i < VEC_LEN; i++) { + #pragma unroll + for (int32_t j = 0; j < VEC_SIZE; j++) { + atomicAdd( + logits + i * THREAD_GROUP_SIZE * VEC_SIZE + warp_lane_id * VEC_SIZE + j, + local_v[i * VEC_SIZE + j] + ); + } + } + } + + __syncthreads(); + + for (int64_t i = threadIdx.x; i < HEAD_SIZE; i += TPB){ + output[batch_idx * output_stride_s + head_idx * output_stride_h + i] = logits[i]; + } +} + + +template +void run_group_int8kv_decode_attention_kernel( + T* __restrict__ output, + const T* __restrict__ query, + const int8_t* k_cache, + const T* k_scale, + const int8_t* v_cache, + const T* v_scale, + const float attn_scale, + const int64_t output_stride_s, + const int64_t output_stride_h, + const int64_t query_stride_s, + const int64_t query_stride_h, + const int64_t kcache_stride_s, + const int64_t kcache_stride_h, + const int64_t vcache_stride_s, + const int64_t vcache_stride_h, + const int32_t * __restrict__ b_seq_len, + const int32_t * __restrict__ b_req_idx, + const int32_t * __restrict__ req_to_tokens, + const int64_t req_to_tokens_stride, + const int64_t max_len_in_batch, + + const int64_t batch_size, + const int64_t q_head_num, + const int64_t head_dim, + const int64_t gqa_group_size) { + + constexpr int64_t WARP_SIZE = 32; + constexpr int64_t TPB = 256; + constexpr int64_t MAX_SHM_SIZE = 48 * 1024; + + constexpr int64_t reduce_shm_size = TPB / WARP_SIZE * sizeof(float); + const int64_t logits_size = max(max_len_in_batch * sizeof(float), head_dim * sizeof(float)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (reduce_shm_size + logits_size <= MAX_SHM_SIZE) { + const dim3 grid_size = {(unsigned int)q_head_num, (unsigned int)batch_size, 1}; + switch (head_dim){ + case 64: + dynamic_batching_decoding_cache_attention_fp16_kernel<64, 4, 256, 8> + <<>> + ( + output, query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_stride_s, output_stride_h, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + case 96: + dynamic_batching_decoding_cache_attention_fp16_kernel<96, 4, 256, 8> + <<>> + ( + output, query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + 
output_stride_s, output_stride_h, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + case 128: + dynamic_batching_decoding_cache_attention_fp16_kernel<128, 8, 256, 8> + <<>> + ( + output, query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_stride_s, output_stride_h, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + case 256: + dynamic_batching_decoding_cache_attention_fp16_kernel<256, 16, 256, 8> + <<>> + ( + output, query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_stride_s, output_stride_h, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + default: + assert(false); + } + } else { + assert(false); + } +} + +void group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor req_to_tokens, at::Tensor b_req_idx, at::Tensor b_seq_len, int max_len_in_batch) { + int64_t batch_size = b_seq_len.sizes()[0]; + int64_t head_num = q.sizes()[1]; + int64_t head_dim = q.sizes()[2]; // q shape [batchsize, head_num, head_dim] + float att_scale = 1.0 / std::sqrt(head_dim); + int64_t kv_head_num = k.sizes()[1]; + assert(head_num % kv_head_num == 0); + int64_t gqa_group_size = head_num / kv_head_num; + LIGHT_DISPATCH_FLOATING_TYPES(q.scalar_type(), "group_int8kv_decode_attention", ([&]{ + run_group_int8kv_decode_attention_kernel( + o.data_ptr(), q.data_ptr(), + k.data_ptr(), k_s.data_ptr(), + v.data_ptr(), v_s.data_ptr(), + att_scale, + o.stride(0), + o.stride(1), + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + b_seq_len.data_ptr(), + b_req_idx.data_ptr(), + req_to_tokens.data_ptr(), + req_to_tokens.stride(0), + max_len_in_batch, + batch_size, + head_num, + head_dim, + gqa_group_size + ); + } + )); +} + +void group_int8kv_decode_attention( + torch::Tensor o, + torch::Tensor q, + torch::Tensor k, + torch::Tensor k_s, + torch::Tensor v, + torch::Tensor v_s, + torch::Tensor req_to_tokens, + torch::Tensor b_req_idx, + torch::Tensor b_seq_len, + int64_t max_len_in_batch) +{ + group_int8kv_decode_attention( + o, + q, + k, + k_s, + v, + v_s, + req_to_tokens, + b_req_idx, + b_seq_len, + static_cast(max_len_in_batch) + ); +} + + +} +} \ No newline at end of file diff --git a/lightllm-kernel/csrc/attention/decode_attention_kernel_in8kv_flashdecoding.cu b/lightllm-kernel/csrc/attention/decode_attention_kernel_in8kv_flashdecoding.cu new file mode 100644 index 000000000..c55eaaf6f --- /dev/null +++ b/lightllm-kernel/csrc/attention/decode_attention_kernel_in8kv_flashdecoding.cu @@ -0,0 +1,650 @@ + +#include +#include +#include // need for FLT_MAX +#include +#include +#include +#include +#include +#include "ops_common.h" +# include + +#include +#include + +namespace lightllm { +namespace ops { + +template +__device__ inline float tofloat(T value) { + return static_cast(value); +} + +// Specialization for __half +template <> +__device__ inline float tofloat<__half>(__half value) { + return __half2float(value); +} + +// Specialization for __nv_bfloat16 +template <> +__device__ 
inline float tofloat<__nv_bfloat16>(__nv_bfloat16 value) { + return __bfloat162float(value); +} + +#define LIGHT_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define LIGHT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, LIGHT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +template +struct BytesToType; + +template <> +struct BytesToType<2> +{ + using type = uint16_t; +}; +template <> +struct BytesToType<4> +{ + using type = uint32_t; +}; +template <> +struct BytesToType<8> +{ + using type = uint64_t; +}; +template <> +struct BytesToType<16> +{ + using type = float4; +}; + +template +__device__ inline void copy(const void* local, void* data) +{ + using T = typename BytesToType::type; + + const T* in = static_cast(local); + T* out = static_cast(data); + *out = *in; +} + +template +__device__ inline +float attn_thread_group_dot(T* local_q, T* local_k) +{ + // Helper function for QK Dot. + // [TODO] It should be optimized by type fp32x4. + + float qk = 0.0f; +# pragma unroll + for(int32_t i = 0; i < ELEMENT_NUM; i++) { + qk += tofloat(local_q[i]) * tofloat(local_k[i]); + } +#pragma unroll + for (int32_t mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +__device__ inline +float attn_block_reduce_max(float reducing, float* shared_mem) +{ + // Helper function for reduce softmax qkmax. + constexpr int32_t WARP_SIZE = 32; + const int32_t lane_id = threadIdx.x % WARP_SIZE; + const int32_t warp_id = threadIdx.x / WARP_SIZE; + +# pragma unroll + for (int32_t mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + reducing = fmaxf(reducing, __shfl_xor_sync(uint32_t(-1), reducing, mask)); + } + + if (lane_id == 0) { + shared_mem[warp_id] = reducing; + } + __syncthreads(); + + if (lane_id < WPT) reducing = shared_mem[lane_id]; + else reducing = -FLT_MAX; + +# pragma unroll + for (int32_t mask = WPT / 2; mask >= 1; mask /= 2) { + reducing = fmaxf(reducing, __shfl_xor_sync(uint32_t(-1), reducing, mask)); + } + + reducing = __shfl_sync(uint32_t(-1), reducing, 0); + return reducing; +} + +template +__device__ inline +float attn_block_reduce_sum(float reducing, float *shared_mem) +{ + // Helper function for reduce softmax exp sum. 
+ constexpr int32_t WARP_SIZE = 32; + const int32_t lane_id = threadIdx.x % WARP_SIZE; + const int32_t warp_id = threadIdx.x / WARP_SIZE; + +# pragma unroll + for (int32_t mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + reducing += __shfl_xor_sync(uint32_t(-1), reducing, mask); + } + + if (lane_id == 0) shared_mem[warp_id] = reducing; + __syncthreads(); + + if (lane_id < WPT) reducing = shared_mem[lane_id]; + +# pragma unroll + for (int32_t mask = WPT / 2; mask >= 1; mask /= 2) { + reducing += __shfl_xor_sync(uint32_t(-1), reducing, mask); + } + reducing = __shfl_sync(uint32_t(-1), reducing, 0); + return reducing; +} + +template< + int32_t HEAD_SIZE, + int32_t THREAD_GROUP_SIZE, // how many threads inside a group + int32_t TPB, + int32_t QUANT_GROUP, + typename T> +__global__ +void dynamic_batching_flashdecoding_cache_attention_int8kv_kernel( + const int64_t seq_block_size, + + T* __restrict__ output_emb, + T* __restrict__ output_logexpsum, + // T* __restrict__ output, // [context_lens, num_heads..., head_size] + + const T* __restrict__ query, // [seq_lens, num_heads..., head_size] + const int8_t* k_cache, // [max_token, num_kv_heads, head_size] + const T* k_scale, // [max_token, num_kv_heads, head_size / quant_group(8)] + const int8_t* v_cache, // [max_token, num_kv_heads, head_size] + const T* v_scale, // [max_token, num_kv_heads, head_size / quant_group(8)] + + const float attn_scale, + + const int64_t output_emb_stride_b, + const int64_t output_emb_stride_h, + const int64_t output_emb_stride_s, + const int64_t output_emb_stride_d, + + const int64_t output_logexpsum_stride_b, + const int64_t output_logexpsum_stride_h, + const int64_t output_logexpsum_stride_s, + + const int64_t query_stride_s, + const int64_t query_stride_h, + + const int64_t kcache_stride_s, + const int64_t kcache_stride_h, + + const int64_t vcache_stride_s, + const int64_t vcache_stride_h, + + const int32_t * __restrict__ b_seq_len, + const int32_t * __restrict__ b_req_idx, + const int32_t * __restrict__ req_to_tokens, + const int64_t req_to_tokens_stride, + const int64_t max_len_in_batch, + const int64_t gqa_group_size) { + + /* --- Decoding Attention Kernel Implementation --- */ + constexpr int64_t WARP_SIZE = 32; // warp size + constexpr int64_t WPT = TPB / WARP_SIZE; // warp per thread block, TPB for Thread per block 4, block_size + constexpr int64_t GPW = WARP_SIZE / THREAD_GROUP_SIZE; // thread group per warp 4 + constexpr int64_t GPT = WARP_SIZE / THREAD_GROUP_SIZE * WPT; // thread group per thread block 16 + + // const int64_t num_heads = gridDim.x; + const int64_t head_idx = blockIdx.x; + const int64_t batch_idx = blockIdx.y; + const int64_t seq_block_idx = blockIdx.z; + + const int64_t seq_len = b_seq_len[batch_idx]; + const int64_t cur_req_idx = b_req_idx[batch_idx]; + const int32_t * b_start_loc = req_to_tokens + cur_req_idx * req_to_tokens_stride + seq_block_idx * seq_block_size; + + constexpr int64_t VEC_SIZE = 16 / sizeof(T); // 128 bits, 这个是 cuda 能操作的最大的一个单位的数吧,8 + + // ------------------------------------------------ // + // Step 1. Load Q into Thread Reg. + constexpr int64_t VEC_LEN = (HEAD_SIZE / VEC_SIZE) / THREAD_GROUP_SIZE; // 128 / 8 / 8 = 2 + + static_assert((HEAD_SIZE / THREAD_GROUP_SIZE) % VEC_SIZE == 0); + static_assert(HEAD_SIZE % THREAD_GROUP_SIZE == 0); + static_assert(QUANT_GROUP == 8); + + constexpr int64_t QUANT_GROUP_SHIFT = 3; + + // The elements in Q, K, and V will be evenly distributed across each thread group. 
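+    // Flash-decoding split: blockIdx.z selects one chunk of seq_block_size
+    // tokens, so this block attends over at most seq_block_size keys/values and
+    // writes a partial embedding plus its log-sum-exp, logf(exp_sum) + qk_max,
+    // from which a later combine step can merge the per-chunk partials.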
+ T local_q[VEC_SIZE * VEC_LEN]; // 2 * 8 + + const int64_t warp_id = threadIdx.x / WARP_SIZE; + const int64_t warp_lane_id = threadIdx.x % WARP_SIZE; + const int64_t group_id = warp_lane_id / THREAD_GROUP_SIZE; + const int64_t group_lane_id = warp_lane_id % THREAD_GROUP_SIZE; + const int64_t kv_head_idx = head_idx / gqa_group_size; + + if (seq_len <= seq_block_idx * seq_block_size) { + return; + } + const int64_t context_len = min(seq_len - seq_block_idx * seq_block_size, seq_block_size); + + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + // copy 128(16 * 8) bits from Q to Local Q + + // 这个地方是错开间隔读取的,不知道如果设置成为连续位置读取会不会一样呢? + copy( + &query[ + batch_idx * query_stride_s + + head_idx * query_stride_h + + (group_lane_id + i * THREAD_GROUP_SIZE) * VEC_SIZE + ], + &local_q[i * VEC_SIZE]); + } + // ------------------------------------------------ // + // Step 2. Solve QK Dot + + extern __shared__ float logits[]; + float qk_max = -FLT_MAX; + + for (int64_t base_id = warp_id * GPW; base_id < context_len; base_id += GPT) { + int8_t local_k_quant[VEC_SIZE * VEC_LEN]; + T local_k[VEC_SIZE * VEC_LEN]; + T local_k_scale[VEC_LEN]; + const int64_t context_id = base_id + group_id; + const int64_t mem_context_id = *(b_start_loc + context_id); + + // all thread groups within a warp must be launched together. + if (context_id >= context_len){ + memset(local_k, 0, sizeof(local_k)); + } else { + const int64_t key_offset + = (mem_context_id) * kcache_stride_s + + kv_head_idx * kcache_stride_h + + group_lane_id * VEC_SIZE; + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + // copy 128(16 * 8) bits from K to Local K + const int64_t key_idx = key_offset + i * THREAD_GROUP_SIZE * VEC_SIZE; + copy(&k_cache[key_idx], &local_k_quant[i * VEC_SIZE]); + + const int64_t key_scale_idx = key_idx >> QUANT_GROUP_SHIFT; + local_k_scale[i] = k_scale[key_scale_idx]; + } + + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + #pragma unroll + for (int64_t j = 0; j < VEC_SIZE; j++) { + local_k[i * VEC_SIZE + j] + = local_k_scale[i] * (T)local_k_quant[i * VEC_SIZE + j]; + } + } + } + + // Ready for QK Dot + const float qk_dot + = attn_scale + * attn_thread_group_dot(local_q, local_k); + + if (group_lane_id == 0 && context_id < context_len) { + logits[context_id] = qk_dot; + qk_max = fmaxf(qk_dot, qk_max); + } + } + + // ------------------------------------------------ // + // Step 3. Softmax + + __shared__ float red_smem[WPT]; + + qk_max = attn_block_reduce_max(qk_max, red_smem); + + float exp_sum = 0.0f; + for (int64_t context_id = threadIdx.x; context_id < context_len; context_id += TPB){ + logits[context_id] -= qk_max; + logits[context_id] = exp(logits[context_id]); + exp_sum += logits[context_id]; + } + + static_assert(WPT == 2 || WPT == 4 || WPT == 8 || WPT == 16 || WPT == 32 || WPT == 64); + exp_sum = attn_block_reduce_sum(exp_sum, red_smem); + + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int64_t context_id = threadIdx.x; context_id < context_len; context_id += TPB) { + logits[context_id] *= inv_sum; + } + __syncthreads(); // Must have this. + + // ------------------------------------------------ // + // Step 4. 
Solve logits * V + + int8_t local_v_quant[VEC_SIZE * VEC_LEN]; + float local_v[VEC_SIZE * VEC_LEN]; + T local_v_scale[VEC_LEN]; + + #pragma unroll + for(int32_t i = 0; i < VEC_SIZE * VEC_LEN; i++) { + local_v[i] = 0; + } + + for (int64_t base_id = warp_id * GPW; base_id < context_len; base_id += GPT) { + const int64_t context_id = base_id + group_id; + const int64_t mem_context_id = *(b_start_loc + context_id); + // all thread groups within a warp must be launched together. + if (context_id < context_len){ + const int64_t value_offset + = (mem_context_id) * vcache_stride_s + + kv_head_idx * vcache_stride_h + + group_lane_id * VEC_SIZE; + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + // copy 128(16 * 8) bits from V to Local V + const int64_t value_idx = value_offset + i * THREAD_GROUP_SIZE * VEC_SIZE; + copy(&v_cache[value_idx], &local_v_quant[i * VEC_SIZE]); + + const int64_t value_scale_idx = value_idx >> QUANT_GROUP_SHIFT; + local_v_scale[i] = v_scale[value_scale_idx]; + } + + #pragma unroll + for (int64_t i = 0; i < VEC_LEN; i++) { + #pragma unroll + for (int64_t j = 0; j < VEC_SIZE; j++) { + local_v[i * VEC_SIZE + j] += (tofloat(local_v_scale[i]) + * (float)local_v_quant[i * VEC_SIZE + j] + * logits[context_id]); + } + } + } + } + + #pragma unroll + for (int32_t i = 0; i < VEC_SIZE * VEC_LEN; i++) { + #pragma unroll + for (int32_t mask = THREAD_GROUP_SIZE; mask <= WARP_SIZE >> 1; mask = mask << 1) { + local_v[i] += __shfl_xor_sync(uint32_t(-1), local_v[i], mask); + } + } + + __syncthreads(); + + // do some reuse + for (int64_t i = threadIdx.x; i < HEAD_SIZE; i += TPB){ + logits[i] = 0; + } + + __syncthreads(); + + if (warp_lane_id < THREAD_GROUP_SIZE) { + #pragma unroll + for (int32_t i = 0; i < VEC_LEN; i++) { + #pragma unroll + for (int32_t j = 0; j < VEC_SIZE; j++) { + atomicAdd( + logits + i * THREAD_GROUP_SIZE * VEC_SIZE + warp_lane_id * VEC_SIZE + j, + local_v[i * VEC_SIZE + j] + ); + } + } + } + + __syncthreads(); + + for (int64_t i = threadIdx.x; i < HEAD_SIZE; i += TPB) { + output_emb[batch_idx * output_emb_stride_b + head_idx * output_emb_stride_h + seq_block_idx * output_emb_stride_s + i] = logits[i]; + } + + output_logexpsum[batch_idx * output_logexpsum_stride_b + head_idx * output_logexpsum_stride_h + seq_block_idx] = logf(exp_sum) + qk_max; +} + + +template +void run_group_int8kv_decode_flashattention_kernel( + const int64_t seq_block_size, + T* __restrict__ output_emb, + T* __restrict__ output_logexpsum, + const T* __restrict__ query, + const int8_t* k_cache, + const T* k_scale, + const int8_t* v_cache, + const T* v_scale, + const float attn_scale, + + const int64_t output_emb_stride_b, + const int64_t output_emb_stride_h, + const int64_t output_emb_stride_s, + const int64_t output_emb_stride_d, + + const int64_t output_logexpsum_stride_b, + const int64_t output_logexpsum_stride_h, + const int64_t output_logexpsum_stride_s, + + const int64_t query_stride_s, + const int64_t query_stride_h, + const int64_t kcache_stride_s, + const int64_t kcache_stride_h, + const int64_t vcache_stride_s, + const int64_t vcache_stride_h, + const int32_t * __restrict__ b_seq_len, + const int32_t * __restrict__ b_req_idx, + const int32_t * __restrict__ req_to_tokens, + const int64_t req_to_tokens_stride, + const int64_t max_len_in_batch, + + const int64_t batch_size, + const int64_t q_head_num, + const int64_t head_dim, + const int64_t gqa_group_size) { + + constexpr int64_t WARP_SIZE = 32; + constexpr int64_t TPB = 256; + constexpr int64_t MAX_SHM_SIZE = 48 * 1024; + + constexpr 
int64_t reduce_shm_size = TPB / WARP_SIZE * sizeof(float); + const int64_t logits_size = max(seq_block_size * sizeof(float), head_dim * sizeof(float)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (reduce_shm_size + logits_size <= MAX_SHM_SIZE) { + const dim3 grid_size = {static_cast(q_head_num), static_cast(batch_size), static_cast((max_len_in_batch + seq_block_size - 1) / seq_block_size)}; + switch (head_dim){ + case 64: + dynamic_batching_flashdecoding_cache_attention_int8kv_kernel<64, 4, 256, 8> + <<>> + ( + seq_block_size, + output_emb, + output_logexpsum, + query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_emb_stride_b, + output_emb_stride_h, + output_emb_stride_s, + output_emb_stride_d, + output_logexpsum_stride_b, + output_logexpsum_stride_h, + output_logexpsum_stride_s, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + case 96: + dynamic_batching_flashdecoding_cache_attention_int8kv_kernel<96, 4, 256, 8> + <<>> + ( + seq_block_size, + output_emb, + output_logexpsum, + query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_emb_stride_b, + output_emb_stride_h, + output_emb_stride_s, + output_emb_stride_d, + output_logexpsum_stride_b, + output_logexpsum_stride_h, + output_logexpsum_stride_s, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + case 128: + dynamic_batching_flashdecoding_cache_attention_int8kv_kernel<128, 8, 256, 8> + <<>> + ( + seq_block_size, + output_emb, + output_logexpsum, + query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_emb_stride_b, + output_emb_stride_h, + output_emb_stride_s, + output_emb_stride_d, + output_logexpsum_stride_b, + output_logexpsum_stride_h, + output_logexpsum_stride_s, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + case 256: + dynamic_batching_flashdecoding_cache_attention_int8kv_kernel<256, 16, 256, 8> + <<>> + ( + seq_block_size, + output_emb, + output_logexpsum, + query, k_cache, k_scale, v_cache, v_scale, + attn_scale, + output_emb_stride_b, + output_emb_stride_h, + output_emb_stride_s, + output_emb_stride_d, + output_logexpsum_stride_b, + output_logexpsum_stride_h, + output_logexpsum_stride_s, + query_stride_s, query_stride_h, + kcache_stride_s, kcache_stride_h, + vcache_stride_s, vcache_stride_h, + b_seq_len, b_req_idx, req_to_tokens, + req_to_tokens_stride, + max_len_in_batch, + gqa_group_size + ); + break; + default: + assert(false); + } + } else { + assert(false); + } +} + +void group_int8kv_flashdecoding_attention(const int seq_block_size, at::Tensor mid_o_emb, at::Tensor mid_o_logexpsum, float att_scale, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor req_to_tokens, at::Tensor b_req_idx, at::Tensor b_seq_len, int max_len_in_batch) { + int64_t batch_size = b_seq_len.sizes()[0]; + int64_t head_num = q.sizes()[1]; + int64_t head_dim = q.sizes()[2]; // q shape [batchsize, head_num, head_dim] + int64_t kv_head_num = k.sizes()[1]; + assert(head_num % kv_head_num == 0); + int64_t gqa_group_size = head_num / kv_head_num; + + 
LIGHT_DISPATCH_FLOATING_TYPES(q.scalar_type(), "group_int8kv_flashdecoding_attention", ([&] { + run_group_int8kv_decode_flashattention_kernel( + seq_block_size, + mid_o_emb.data_ptr(), + mid_o_logexpsum.data_ptr(), + q.data_ptr(), + k.data_ptr(), k_s.data_ptr(), + v.data_ptr(), v_s.data_ptr(), + att_scale, + + mid_o_emb.stride(0), + mid_o_emb.stride(1), + mid_o_emb.stride(2), + mid_o_emb.stride(3), + mid_o_logexpsum.stride(0), + mid_o_logexpsum.stride(1), + mid_o_logexpsum.stride(2), + + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + b_seq_len.data_ptr(), + b_req_idx.data_ptr(), + req_to_tokens.data_ptr(), + req_to_tokens.stride(0), + max_len_in_batch, + batch_size, + head_num, + head_dim, + gqa_group_size + ); + })); + +} + +void group_int8kv_flashdecoding_attention( + const int64_t seq_block_size, + torch::Tensor mid_o_emb, + torch::Tensor mid_o_logexpsum, + fp32_t att_scale, + torch::Tensor q, + torch::Tensor k, + torch::Tensor k_s, + torch::Tensor v, + torch::Tensor v_s, + torch::Tensor req_to_tokens, + torch::Tensor b_req_idx, + torch::Tensor b_seq_len, + int64_t max_len_in_batch) +{ + group_int8kv_flashdecoding_attention( + static_cast(seq_block_size), + mid_o_emb, + mid_o_logexpsum, + att_scale, + q, + k, + k_s, + v, + v_s, + req_to_tokens, + b_req_idx, + b_seq_len, + static_cast(max_len_in_batch) + ); +} + +} +} \ No newline at end of file diff --git a/lightllm-kernel/csrc/cuda_compat.h b/lightllm-kernel/csrc/cuda_compat.h new file mode 100644 index 000000000..82e55613d --- /dev/null +++ b/lightllm-kernel/csrc/cuda_compat.h @@ -0,0 +1,49 @@ +#pragma once + +#ifdef USE_ROCM + #include +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif diff --git a/lightllm-kernel/csrc/fusion/add_norm_quant.cu b/lightllm-kernel/csrc/fusion/add_norm_quant.cu new file mode 100755 index 000000000..3684dffc8 --- /dev/null +++ b/lightllm-kernel/csrc/fusion/add_norm_quant.cu @@ -0,0 +1,551 @@ +#include "ops_common.h" +#include "reduce/sm70.cuh" + +namespace lightllm { +namespace ops { + +using namespace lightllm; + +template +__global__ void device_add_norm_quant_bf16_general( + bf16_t* __restrict__ input, // Input tensor in BF16 format + const bf16_t* __restrict__ residual, // Residual tensor in 
BF16 format + const bf16_t* __restrict__ weight, // Weight tensor in BF16 format + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format + fp32_t* __restrict__ scales, // Output scales for each group + const int64_t M, // Number of rows in the input tensor + const int32_t N, // Number of cols in the input tensor + const fp32_t eps // Epsilon value for numerical stability +) { + const fp32_t r_N = 1 / (fp32_t)N; // Reciprocal of N. + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _input = input + bid * N; + const bf16_t* _residual = residual + bid * N; + fp8_e4m3_t* _output = output + bid * N; + + fp32_t* _scales; + _scales = scales + bid; + + // Shared memory workspace to store data. + extern __shared__ bf16_t workspace1[]; + + // Local registers to hold data. + bf16_t local_input; + bf16_t local_residual; + bf16_t local_w; + bf16_t local_output; + fp8_e4m3_t local_f8; + + + // Each thread computes a partial sum of squares. + fp32_t local_square_sum = 0.0f; + for (int32_t i = tid; i < N; i += TPB) { + local_input = _input[i]; + local_residual = _residual[i]; + + fp32_t x = cvt_bf16_f32(local_input); + fp32_t r = cvt_bf16_f32(local_residual); + local_input = cvt_f32_bf16(x + r); + fp32_t tmp = cvt_bf16_f32(local_input); + local_square_sum += tmp * tmp; + + _input[i] = local_input; + workspace1[i] = local_input; + } + + const fp32_t reduced_square_sum = lightllm::reduce::sm70::sync_block_reduce_sum_f32(local_square_sum); + + // Compute the mean square and then the inverse RMS normalization factor. + // For RMSNorm, the normalization factor is 1/sqrt(mean(x^2)+eps). + const fp32_t mean_square = reduced_square_sum * r_N; + const fp32_t inv_norm = rsqrtf(mean_square + eps); + + // Normalize each element using the computed normalization factor. 
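+    // Per-row reference of the fused computation (illustrative pseudocode only):
+    //   y[i]  = x[i] + r[i]                      // residual add, written back to input
+    //   n[i]  = y[i] * rsqrt(mean(y^2) + eps) * w[i]
+    //   scale = max_i |n[i]| / 448.0             // FP8 E4M3 dynamic range
+    //   q[i]  = fp8_e4m3(n[i] / (scale + 1e-7))  // scale is stored per row in `scales`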
+ fp32_t local_max = -FLT_MAX; + for (int32_t i = tid; i < N; i += TPB) { + local_input = workspace1[i]; + local_w = weight[i]; + + fp32_t x = cvt_bf16_f32(local_input); + fp32_t w = cvt_bf16_f32(local_w); + + fp32_t ret = x * inv_norm * w; + local_output = cvt_f32_bf16(ret); + fp32_t tmp = cvt_bf16_f32(local_output); + local_max = fmaxf(local_max, fabsf(tmp)); + + workspace1[i] = local_output; + } + + // Reduce the maximum value across the block + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32(local_max); + + // Compute the scale factor with epsilon to avoid division by zero + constexpr fp32_t epsilon = 1e-7f; + const fp32_t scale = reduced_max / FP8_E4M3_MAX; + const fp32_t inv_scale = 1.0f / (scale + epsilon); + + for (int32_t i = tid; i < N; i += TPB) { + local_output = workspace1[i]; + + fp32_t tmp = cvt_bf16_f32(local_output); + fp32_t ret = tmp * inv_scale; + local_f8 = fp8_e4m3_t(ret); + + _output[i] = local_f8; + } + + if(tid == 0){ + *_scales = scale; + } +} + + + +template +__global__ void device_add_norm_quant_bf16_vpt( + bf16_t* __restrict__ input, // Input tensor in BF16 format + const bf16_t* __restrict__ residual, // Residual tensor in BF16 format + const bf16_t* __restrict__ weight, // Weight tensor in BF16 format + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format + fp32_t* __restrict__ scales, // Output scales for each group + const int64_t M, // Number of rows in the input tensor + const int32_t N, // Number of cols in the input tensor + const fp32_t eps // Epsilon value for numerical stability +) { + constexpr int32_t VPT = 8; // Number of FP16 values processed per thread. + const fp32_t r_N = 1 / (fp32_t)N; // Reciprocal of N. + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _input = input + bid * N; + const bf16_t* _residual = residual + bid * N; + fp8_e4m3_t* _output = output + bid * N; + + fp32_t* _scales; + _scales = scales + bid; + + // Shared memory workspace to store vectorized (half2) data. + // Note: since each bf16x2_t holds 2 half values, the workspace size is N/2. + extern __shared__ bf16x2_t workspace2[]; + + // Local registers to hold vectorized data. + bf16x2_t local_input[VPT / 2]; + bf16x2_t local_residual[VPT / 2]; + bf16x2_t local_w[VPT / 2]; + bf16x2_t local_output[VPT / 2]; + fp8x4_e4m3_t local_f8[VPT / 4]; + + + // Each thread computes a partial sum of squares. + fp32_t local_square_sum = 0.0f; + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load VPT FP16 elements from global memory (_input) into local vector (local_input). + vec_copy(_input + i, local_input); + // Load VPT FP16 elements from global memory (_residual) into local vector (local_residual). + vec_copy(_residual + i, local_residual); + + # pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + // Convert the bf16x2_t to fp32x2_t for computation. + fp32x2_t x = bf16x2_to_fp32x2(local_input[j]); + fp32x2_t r = bf16x2_to_fp32x2(local_residual[j]); + // Add the residual to the input. + local_input[j] = _float22bf162_rn(make_float2(x.x + r.x, x.y + r.y)); + + fp32x2_t tmp = bf16x2_to_fp32x2(local_input[j]); + local_square_sum += (tmp.x * tmp.x + tmp.y * tmp.y); + } + + // Store the loaded data into shared memory. + // Divide index by 2 because 'workspace' is an array of bf16x2_t. 
+ vec_copy(local_input, _input + i); + vec_copy(local_input, workspace2 + (i >> 1)); + } + + const fp32_t reduced_square_sum = lightllm::reduce::sm70::sync_block_reduce_sum_f32(local_square_sum); + + // Compute the mean square and then the inverse RMS normalization factor. + // For RMSNorm, the normalization factor is 1/sqrt(mean(x^2)+eps). + const fp32_t mean_square = reduced_square_sum * r_N; + const fp32_t inv_norm = rsqrtf(mean_square + eps); + + // Normalize each element using the computed normalization factor. + fp32_t local_max = -FLT_MAX; + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load the previously stored vectorized data from shared memory. + vec_copy(workspace2 + (i >> 1), local_input); + // Load the corresponding weight values from global memory. + vec_copy(weight + i, local_w); + + #pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_input[j]); + fp32x2_t w = bf16x2_to_fp32x2(local_w[j]); + // Apply normalization: multiply by inv_norm and then scale by the weight. + fp32x2_t ret = make_float2( + x.x * inv_norm * w.x, + x.y * inv_norm * w.y + ); + local_output[j] = _float22bf162_rn(ret); + + + fp32x2_t tmp = bf16x2_to_fp32x2(local_output[j]); + fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y)); + local_max = fmaxf(local_max, max); + } + + vec_copy(local_output, workspace2 + (i >> 1)); + } + + // Reduce the maximum value across the block + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32(local_max); + + // Compute the scale factor with epsilon to avoid division by zero + constexpr fp32_t epsilon = 1e-7f; + const fp32_t scale = reduced_max / FP8_E4M3_MAX; + const fp32_t inv_scale = 1.0f / (scale + epsilon); + + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + vec_copy(workspace2 + (i >> 1), local_output); + + #pragma unroll + for (int32_t j = 0; j < VPT/4; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_output[2 * j + 0]); + fp32x2_t y = bf16x2_to_fp32x2(local_output[2 * j + 1]); + fp32x4_t ret = make_float4( + x.x * inv_scale, + x.y * inv_scale, + y.x * inv_scale, + y.y * inv_scale + ); + local_f8[j] = fp8x4_e4m3_t(ret); + } + + vec_copy(local_f8, _output + i); + } + + if(tid == 0){ + *_scales = scale; + } +} + + +template +__global__ void device_add_norm_quant_bf16( + bf16_t* __restrict__ input, // Input tensor in BF16 format + const bf16_t* __restrict__ residual, // Residual tensor in BF16 format + const bf16_t* __restrict__ weight, // Weight tensor in BF16 format + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format + fp32_t* __restrict__ scales, // Output scales for each group + const int64_t M, // Number of rows in the input tensor + const fp32_t eps // Epsilon value for numerical stability +) { + constexpr int32_t VPT = 8; // Number of FP16 values processed per thread. + constexpr fp32_t r_N = 1 / (fp32_t)N; // Reciprocal of N. + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format + + static_assert(N % 2 == 0, "N must be even."); + static_assert(N % VPT == 0, "N must be a multiple of VPT."); + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _input = input + bid * N; + const bf16_t* _residual = residual + bid * N; + fp8_e4m3_t* _output = output + bid * N; + + fp32_t* _scales; + _scales = scales + bid; + + // Shared memory workspace to store vectorized (half2) data. + // Note: since each bf16x2_t holds 2 half values, the workspace size is N/2. 
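+    // Sizing note: the static workspace holds N/2 bf16x2_t = 2*N bytes; the
+    // largest specialization instantiated below (N = 12800) therefore uses
+    // 25,600 bytes, well within the 48 KB static shared-memory limit. Shapes not
+    // covered by the switch fall back to the _vpt/_general kernels, which size
+    // the workspace as dynamic shared memory instead.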
+ __shared__ bf16x2_t workspace[N / 2]; + + // Local registers to hold vectorized data. + bf16x2_t local_input[VPT / 2]; + bf16x2_t local_residual[VPT / 2]; + bf16x2_t local_w[VPT / 2]; + bf16x2_t local_output[VPT / 2]; + fp8x4_e4m3_t local_f8[VPT / 4]; + + + // Each thread computes a partial sum of squares. + fp32_t local_square_sum = 0.0f; + # pragma unroll + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load VPT FP16 elements from global memory (_input) into local vector (local_input). + vec_copy(_input + i, local_input); + // Load VPT FP16 elements from global memory (_residual) into local vector (local_residual). + vec_copy(_residual + i, local_residual); + + # pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + // Convert the bf16x2_t to fp32x2_t for computation. + fp32x2_t x = bf16x2_to_fp32x2(local_input[j]); + fp32x2_t r = bf16x2_to_fp32x2(local_residual[j]); + // Add the residual to the input. + local_input[j] = _float22bf162_rn(make_float2(x.x + r.x, x.y + r.y)); + + fp32x2_t tmp = bf16x2_to_fp32x2(local_input[j]); + local_square_sum += (tmp.x * tmp.x + tmp.y * tmp.y); + } + + // Store the loaded data into shared memory. + // Divide index by 2 because 'workspace' is an array of bf16x2_t. + vec_copy(local_input, _input + i); + vec_copy(local_input, workspace + (i >> 1)); + } + + const fp32_t reduced_square_sum = lightllm::reduce::sm70::sync_block_reduce_sum_f32(local_square_sum); + + // Compute the mean square and then the inverse RMS normalization factor. + // For RMSNorm, the normalization factor is 1/sqrt(mean(x^2)+eps). + const fp32_t mean_square = reduced_square_sum * r_N; + const fp32_t inv_norm = rsqrtf(mean_square + eps); + + // Normalize each element using the computed normalization factor. + fp32_t local_max = -FLT_MAX; + #pragma unroll + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load the previously stored vectorized data from shared memory. + vec_copy(workspace + (i >> 1), local_input); + // Load the corresponding weight values from global memory. + vec_copy(weight + i, local_w); + + #pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_input[j]); + fp32x2_t w = bf16x2_to_fp32x2(local_w[j]); + // Apply normalization: multiply by inv_norm and then scale by the weight. 
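+            // Note: the per-row max below is taken on the rounded bf16 result
+            // rather than the fp32 intermediate, so the FP8 scale is computed
+            // from exactly the values that will be quantized.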
+ fp32x2_t ret = make_float2( + x.x * inv_norm * w.x, + x.y * inv_norm * w.y + ); + local_output[j] = _float22bf162_rn(ret); + + + fp32x2_t tmp = bf16x2_to_fp32x2(local_output[j]); + fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y)); + local_max = fmaxf(local_max, max); + } + + vec_copy(local_output, workspace + (i >> 1)); + } + + // Reduce the maximum value across the block + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32(local_max); + + // Compute the scale factor with epsilon to avoid division by zero + constexpr fp32_t epsilon = 1e-7f; + const fp32_t scale = reduced_max / FP8_E4M3_MAX; + const fp32_t inv_scale = 1.0f / (scale + epsilon); + + #pragma unroll + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + vec_copy(workspace + (i >> 1), local_output); + + #pragma unroll + for (int32_t j = 0; j < VPT/4; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_output[2 * j + 0]); + fp32x2_t y = bf16x2_to_fp32x2(local_output[2 * j + 1]); + fp32x4_t ret = make_float4( + x.x * inv_scale, + x.y * inv_scale, + y.x * inv_scale, + y.y * inv_scale + ); + local_f8[j] = fp8x4_e4m3_t(ret); + } + + vec_copy(local_f8, _output + i); + } + + if(tid == 0){ + *_scales = scale; + } +} + +/** + * @brief Fused add norm quant + */ +std::tuple add_norm_quant_bf16_fp8( + Tensor& X, const Tensor &R, const Tensor &W, + const fp32_t eps +) { + TORCH_CHECK(X.ndimension() == 2, "Input tensor X must be 2D"); + TORCH_CHECK(R.ndimension() == 2, "Input tensor R must be 2D"); + TORCH_CHECK(W.ndimension() == 1, "Input tensor W must be 1D"); + + TORCH_CHECK(X.is_cuda(), "Input tensor X must be a CUDA tensor."); + TORCH_CHECK(R.is_cuda(), "Input tensor R must be a CUDA tensor."); + TORCH_CHECK(W.is_cuda(), "Input tensor W must be a CUDA tensor."); + + TORCH_CHECK(X.scalar_type() == c10::ScalarType::BFloat16, "Input tensor X must be BF16."); + TORCH_CHECK(R.scalar_type() == c10::ScalarType::BFloat16, "Input tensor R must be BF16."); + TORCH_CHECK(W.scalar_type() == c10::ScalarType::BFloat16, "Input tensor W must be BF16."); + + Tensor contiguous_X = X.is_contiguous() ? X : X.contiguous(); + Tensor contiguous_R = R.is_contiguous() ? R : R.contiguous(); + Tensor contiguous_W = W.is_contiguous() ? 
W : W.contiguous(); + + const uint32_t M = contiguous_X.size(0); + const uint32_t N = contiguous_X.size(1); + + Tensor output_q = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(torch::kFloat8_e4m3fn) + .device(contiguous_X.device()) + ); + Tensor scales = torch::empty( + {M, 1}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(contiguous_X.device()) + ); + + const int32_t blocks = M; + + switch (N) { + case 16: + device_add_norm_quant_bf16<128, 16> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 32: + device_add_norm_quant_bf16<128, 32> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 64: + device_add_norm_quant_bf16<128, 64> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 512: + device_add_norm_quant_bf16<128, 512> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 1024: + device_add_norm_quant_bf16<128, 1024> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 3200: + device_add_norm_quant_bf16<128, 3200> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 4096: + device_add_norm_quant_bf16<128, 4096> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + case 12800: + device_add_norm_quant_bf16<256, 12800> + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + eps + ); + break; + default: { + static constexpr int32_t TPB = 128; + const int64_t shared_mem_size = N * sizeof(bf16_t); + if (N % 8 == 0) { + device_add_norm_quant_bf16_vpt + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + N, + eps + ); + } else { + device_add_norm_quant_bf16_general + <<>>( + PTR(contiguous_X), + PTR(contiguous_R), + PTR(contiguous_W), + PTR(output_q), + PTR(scales), + M, + N, + eps + ); + } + } + } + + return {output_q, scales}; +} + +} // namespace ops +} // namespace lightllm \ No newline at end of file diff --git a/lightllm-kernel/csrc/fusion/gelu_per_token_quant.cu b/lightllm-kernel/csrc/fusion/gelu_per_token_quant.cu new file mode 100755 index 000000000..b204e9737 --- /dev/null +++ b/lightllm-kernel/csrc/fusion/gelu_per_token_quant.cu @@ -0,0 +1,367 @@ +#include "ops_common.h" +#include "reduce/sm70.cuh" + + +namespace lightllm { +namespace ops { + +using namespace lightllm; + +template +__global__ void device_gelu_per_token_quant_bf16_to_fp8( + const bf16_t* __restrict__ input, // Input tensor in BF16 format + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format + fp32_t* __restrict__ scales, // Output scales for each group + const int64_t M // Number of rows in the input tensor +) { + constexpr int32_t VPT = 8; + + static_assert(N % 2 == 0, "N must be even."); + static_assert(N % VPT == 0, "N must be a multiple of VPT."); + + const int32_t bid = blockIdx.x; + const int32_t tid = threadIdx.x; + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format + const bf16x2_t one = _float22bf162_rn(make_float2(1.0f, 1.0f)); + const bf16x2_t one_2 = 
_float22bf162_rn(make_float2(0.5f, 0.5f)); + + const bf16_t* _input = input + bid * N; // Input pointer for the group + fp8_e4m3_t* _output = output + bid * N; // Output pointer for the group + + fp32_t* _scales; + _scales = scales + bid; + + // Local arrays for intermediate storage + fp8x4_e4m3_t local_f8[VPT / 4]; + bf16x2_t local_bf16[VPT / 2]; + + __shared__ bf16x2_t workspace[N / 2]; + + fp32_t local_max = -FLT_MAX; + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + vec_copy(_input + i, local_bf16); + //gelu + #pragma unroll + for(int32_t j = 0; j< VPT/2; j++){ + fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]); + tmp.x = erf(tmp.x * 0.7071067811f); + tmp.y = erf(tmp.y * 0.7071067811f); + bf16x2_t tan = _float22bf162_rn(tmp); + tan = __hadd2(tan, one); + tan = __hmul2(tan, local_bf16[j]); + tan = __hmul2(tan, one_2); + local_bf16[j] = tan; + } + + vec_copy(local_bf16, workspace + (i >> 1)); + + #pragma unroll + for(int32_t j = 0; j< VPT/2; j++){ + fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]); + fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y)); + local_max = fmaxf(local_max, max); + } + } + + // Reduce the maximum value across the thread group + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32(local_max); + + // Compute the scale factor with epsilon to avoid division by zero + constexpr fp32_t epsilon = 1e-7f; + const fp32_t scale = reduced_max / FP8_E4M3_MAX; + const fp32_t inv_scale = 1.0f / (scale + epsilon); + + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + vec_copy(workspace + (i >> 1), local_bf16); + + #pragma unroll + for (int32_t j = 0; j < VPT/4; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]); + fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]); + fp32x4_t ret = make_float4( + x.x * inv_scale, + x.y * inv_scale, + y.x * inv_scale, + y.y * inv_scale + ); + local_f8[j] = fp8x4_e4m3_t(ret); + } + + vec_copy(local_f8, _output + i); + } + + if(tid == 0){ + *_scales = scale; + } +} + + +template +__global__ void gelu_per_token_quant_bf16_to_fp8_vpt( + const bf16_t* __restrict__ input, // Input tensor in BF16 format + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format + fp32_t* __restrict__ scales, // Output scales for each group + const int64_t M, // Number of rows in the input tensor + const int32_t N +) { + constexpr int32_t VPT = 8; + + const int32_t bid = blockIdx.x; + const int32_t tid = threadIdx.x; + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format + constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f; + constexpr fp32_t coeff = 0.044715f; + + const bf16_t* _input = input + bid * N; // Input pointer for the group + fp8_e4m3_t* _output = output + bid * N; // Output pointer for the group + + fp32_t* _scales; + _scales = scales + bid; + + // Local arrays for intermediate storage + fp8x4_e4m3_t local_f8[VPT / 4]; + bf16x2_t local_bf16[VPT / 2]; + + extern __shared__ bf16x2_t workspace[]; + + fp32_t local_max = -FLT_MAX; + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + vec_copy(_input + i, local_bf16); + + #pragma unroll + for(int32_t j = 0; j< VPT/2; j++){ + fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]); + + fp32_t tanh_arg1 = sqrt_2_over_pi * (tmp.x + coeff * tmp.x * tmp.x * tmp.x); + fp32_t tanh_arg2 = sqrt_2_over_pi * (tmp.y + coeff * tmp.y * tmp.y * tmp.y); + tmp.x = 0.5f * tmp.x * (1.0f + tanhf(tanh_arg1)); + tmp.y = 0.5f * tmp.y * (1.0f + tanhf(tanh_arg2)); + + local_bf16[j] = _float22bf162_rn(tmp); + } + + vec_copy(local_bf16, workspace + (i >> 1)); + 
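+        // GELU here uses the tanh approximation
+        //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+        // matching sqrt_2_over_pi and coeff above, whereas the fixed-N kernel
+        // above applies the exact erf form 0.5 * x * (1 + erf(x / sqrt(2))).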
+ // Compute the max for the VPT elements. + #pragma unroll + for(int32_t j = 0; j< VPT/2; j++){ + fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]); + fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y)); + local_max = fmaxf(local_max, max); + } + } + + // Reduce the maximum value across the thread group + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32(local_max); + + // Compute the scale factor with epsilon to avoid division by zero + constexpr fp32_t epsilon = 1e-7f; + const fp32_t scale = reduced_max / FP8_E4M3_MAX; + const fp32_t inv_scale = 1.0f / (scale + epsilon); + + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + vec_copy(workspace + (i >> 1), local_bf16); + + #pragma unroll + for (int32_t j = 0; j < VPT/4; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]); + fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]); + fp32x4_t ret = make_float4( + x.x * inv_scale, + x.y * inv_scale, + y.x * inv_scale, + y.y * inv_scale + ); + local_f8[j] = fp8x4_e4m3_t(ret); + } + + vec_copy(local_f8, _output + i); + } + + if(tid == 0){ + *_scales = scale; + } +} + + +template +__global__ void gelu_per_token_quant_bf16_to_fp8_general( + const bf16_t* __restrict__ input, // Input tensor in BF16 format + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format + fp32_t* __restrict__ scales, // Output scales for each group + const int64_t M, // Number of rows in the input tensor + const int32_t N +) { + const int32_t bid = blockIdx.x; + const int32_t tid = threadIdx.x; + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format + constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f; + constexpr fp32_t coeff = 0.044715f; + + const bf16_t* _input = input + bid * N; // Input pointer for the group + fp8_e4m3_t* _output = output + bid * N; // Output pointer for the group + + fp32_t* _scales; + _scales = scales + bid; + + extern __shared__ bf16_t workspace_[]; + + fp32_t local_max = -FLT_MAX; + + for (int32_t i = tid; i < N; i += TPB) { + fp32_t tmp = cvt_bf16_f32(_input[i]); + fp32_t tanh_arg = sqrt_2_over_pi * (tmp + coeff * tmp * tmp * tmp); + tmp = 0.5f * tmp * (1.0f + tanhf(tanh_arg)); + local_max = fmaxf(local_max, fabsf(tmp)); + workspace_[i] = cvt_f32_bf16(tmp); + } + + // Reduce the maximum value across the thread group + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32(local_max); + + // Compute the scale factor with epsilon to avoid division by zero + constexpr fp32_t epsilon = 1e-7f; + const fp32_t scale = reduced_max / FP8_E4M3_MAX; + const fp32_t inv_scale = 1.0f / (scale + epsilon); + + for (int32_t i = tid; i < N; i += TPB) { + // Load the previously stored vectorized data from shared memory. + fp32_t x = cvt_bf16_f32(workspace_[i]); + // Apply normalization: multiply by inv_norm and then scale by the weight. + fp32_t ret = x * inv_scale; + _output[i] = fp8_e4m3_t(ret); + } + + if(tid == 0){ + *_scales = scale; + } +} + +void gelu_per_token_quant_bf16_fp8 ( + Tensor& output, + const Tensor& input, + Tensor& scales +) { + TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor"); + TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional"); + TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type"); + + Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous(); + Tensor contiguous_scales = scales.is_contiguous() ? 
scales : scales.contiguous(); + + const int64_t M = input.size(0); + const int64_t N = input.size(1); + + const int32_t blocks = M; + + switch (N) { + case 16: + device_gelu_per_token_quant_bf16_to_fp8<64, 16> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 32: + device_gelu_per_token_quant_bf16_to_fp8<64, 32> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 64: + device_gelu_per_token_quant_bf16_to_fp8<64, 64> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 512: + device_gelu_per_token_quant_bf16_to_fp8<64, 512> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + + case 1024: + device_gelu_per_token_quant_bf16_to_fp8<128, 1024> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 2048: + device_gelu_per_token_quant_bf16_to_fp8<128, 2048> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 3200: + device_gelu_per_token_quant_bf16_to_fp8<128, 3200> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 4096: + device_gelu_per_token_quant_bf16_to_fp8<256, 4096> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + case 12800: + device_gelu_per_token_quant_bf16_to_fp8<256, 12800> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M + ); + break; + default: { + static constexpr int32_t TPB = 128; + int32_t sharedmem = N / 2 * sizeof(bf16x2_t); + if (N % 8 == 0) { + gelu_per_token_quant_bf16_to_fp8_vpt<128> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M, N + ); + } + else { + gelu_per_token_quant_bf16_to_fp8_general<128> + <<>>( + PTR(contiguous_input), + PTR(output), + PTR(contiguous_scales), + M, N + ); + } + } + } + return ; +} + +} // namespace ops +} // namespace lightllm \ No newline at end of file diff --git a/lightllm-kernel/csrc/fusion/post_tp_norm.cu b/lightllm-kernel/csrc/fusion/post_tp_norm.cu new file mode 100755 index 000000000..89f711405 --- /dev/null +++ b/lightllm-kernel/csrc/fusion/post_tp_norm.cu @@ -0,0 +1,364 @@ +#include "ops_common.h" +#include "reduce/sm70.cuh" + +namespace lightllm { +namespace ops { + +using namespace lightllm; + +/** + * @brief CUDA kernel to perform RMS normalization on an FP16 tensor. + * + * Each block processes one row of the input tensor. + * + * @tparam TPB Threads per block. + * @tparam N Number of FP16 elements in one row. + * + * @param X Pointer to the input tensor in global memory. [M, N] + * @param W Pointer to the weight tensor in global memory. [N] + * @param V Pointer to the variance tensor in global memory. [M] + * @param Y Pointer to the output tensor in global memory. [M, N] + * @param M Number of rows in the tensor. + * @param eps Epsilon for numerical stability. + */ +template +__global__ +void device_post_tp_norm_bf16_general( + bf16_t __restrict__ *X, // [M, N] Input tensor pointer. + const bf16_t __restrict__ *W, // [N] Weight tensor pointer. + const fp32_t __restrict__ *V, // [M] variance + bf16_t __restrict__ *Y, // [M, N] Output tensor pointer. + const int32_t M, // Number of rows. + const int32_t N, + const int32_t embed_dim, // if multiGPUs, embed_dim differs from N + const fp32_t eps // Epsilon for numerical stability. +) { + const fp32_t r_N = 1 / (fp32_t)embed_dim; // Reciprocal of N. 
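+    // Note: this kernel implements the second half of a tensor-parallel RMSNorm.
+    // V[bid] is assumed to already hold the per-row sum of squares over the full
+    // embed_dim (e.g. computed by pre_tp_norm_bf16 and all-reduced across ranks),
+    // while N is only the width of the local shard, so each element is scaled as
+    //     Y[i] = X[i] * rsqrtf(V[bid] / embed_dim + eps) * W[i].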
+ + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _X = X + bid * N; + bf16_t* _Y = Y + bid * N; + + // Local registers to hold data. + bf16_t local_x = cvt_f32_bf16(0.0f); + bf16_t local_w = cvt_f32_bf16(0.0f); + bf16_t local_y = cvt_f32_bf16(0.0f); + + fp32_t reduced_square_sum = V[bid]; + + // Compute the mean square and then the inverse RMS normalization factor. + // For RMSNorm, the normalization factor is 1/sqrt(mean(x^2)+eps). + fp32_t mean_square = reduced_square_sum * r_N; + fp32_t inv_norm = rsqrtf(mean_square + eps); + + for (int32_t i = tid; i < N; i += TPB) { + local_x = _X[i]; + local_w = W[i]; + + fp32_t x = cvt_bf16_f32(local_x); + fp32_t w = cvt_bf16_f32(local_w); + + fp32_t ret = x * inv_norm * w; + local_y = cvt_f32_bf16(ret); + + _Y[i] = local_y; + } +} + + +/** + * @brief CUDA kernel to perform RMS normalization on an FP16 tensor. + * + * Each block processes one row of the input tensor. The kernel loads the + * data in a vectorized manner (using half2), computes the mean square, + * calculates the reciprocal square root (i.e. 1/sqrt(mean_square+eps)), + * and then normalizes the input row element‐wise while scaling with a weight. + * + * @tparam TPB Threads per block. + * @tparam N Number of FP16 elements in one row (must be a multiple of VPT). + * + * @param X Pointer to the input tensor in global memory. [M, N] + * @param W Pointer to the weight tensor in global memory. [N] + * @param V Pointer to the variance tensor in global memory. [M] + * @param Y Pointer to the output tensor in global memory. [M, N] + * @param M Number of rows in the tensor. + * @param eps Epsilon for numerical stability. + */ +template +__global__ +void device_post_tp_norm_bf16_vpt( + bf16_t __restrict__ *X, // [M, N] Input tensor pointer. + const bf16_t __restrict__ *W, // [N] Weight tensor pointer. + const fp32_t __restrict__ *V, // [M] variance + bf16_t __restrict__ *Y, // [M, N] Output tensor pointer. + const int32_t M, // Number of rows. + const int32_t N, + const int32_t embed_dim, // if multiGPUs, embed_dim differs from N + const fp32_t eps // Epsilon for numerical stability. +) { + constexpr int32_t VPT = 8; // Number of bf16 values processed per thread. + const fp32_t r_N = 1 / (fp32_t)embed_dim; // Reciprocal of N. + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _X = X + bid * N; + bf16_t* _Y = Y + bid * N; + + // Local registers to hold vectorized data. + bf16x2_t local_x[VPT / 2]; + bf16x2_t local_w[VPT / 2]; + bf16x2_t local_y[VPT / 2]; + + fp32_t reduced_square_sum = V[bid]; + + // Compute the mean square and then the inverse RMS normalization factor. + // For RMSNorm, the normalization factor is 1/sqrt(mean(x^2)+eps). + fp32_t mean_square = reduced_square_sum * r_N; + fp32_t inv_norm = rsqrtf(mean_square + eps); + + // Normalize each element using the computed normalization factor. + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load the previously stored vectorized data from global memory. + vec_copy(_X + i, local_x); + // Load the corresponding weight values from global memory. + vec_copy(W + i, local_w); + + #pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_x[j]); + fp32x2_t w = bf16x2_to_fp32x2(local_w[j]); + // Apply normalization: multiply by inv_norm and then scale by the weight. 
+ fp32x2_t ret = make_float2( + x.x * inv_norm * w.x, + x.y * inv_norm * w.y + ); + local_y[j] = _float22bf162_rn(ret); + } + // Write the normalized vectorized data back to global memory. + vec_copy(local_y, _Y + i); + } +} + +/** + * @brief CUDA kernel to perform RMS normalization on an FP16 tensor. + * + * Each block processes one row of the input tensor. The kernel loads the + * data in a vectorized manner (using half2), computes the mean square, + * calculates the reciprocal square root (i.e. 1/sqrt(mean_square+eps)), + * and then normalizes the input row element‐wise while scaling with a weight. + * + * @tparam TPB Threads per block. + * @tparam N Number of FP16 elements in one row (must be a multiple of VPT). + * + * @param X Pointer to the input tensor in global memory. [M, N] + * @param W Pointer to the weight tensor in global memory. [N] + * @param V Pointer to the variance tensor in global memory. [M] + * @param Y Pointer to the output tensor in global memory. [M, N] + * @param M Number of rows in the tensor. + * @param eps Epsilon for numerical stability. + */ +template +__global__ +void device_post_tp_norm_bf16( + bf16_t __restrict__ *X, // [M, N] Input tensor pointer. + const bf16_t __restrict__ *W, // [N] Weight tensor pointer. + const fp32_t __restrict__ *V, // [M] variance + bf16_t __restrict__ *Y, // [M, N] Output tensor pointer. + const int32_t M, // Number of rows. + const int32_t embed_dim, // if multiGPUs, embed_dim differs from N + const fp32_t eps // Epsilon for numerical stability. +) { + constexpr int32_t VPT = 8; // Number of bf16 values processed per thread. + const fp32_t r_N = 1 / (fp32_t)embed_dim; // Reciprocal of N. + + static_assert(N % 2 == 0, "N must be even."); + static_assert(N % VPT == 0, "N must be a multiple of VPT."); + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _X = X + bid * N; + bf16_t* _Y = Y + bid * N; + + // Local registers to hold vectorized data. + bf16x2_t local_x[VPT / 2]; + bf16x2_t local_w[VPT / 2]; + bf16x2_t local_y[VPT / 2]; + + fp32_t reduced_square_sum = V[bid]; + + // Compute the mean square and then the inverse RMS normalization factor. + // For RMSNorm, the normalization factor is 1/sqrt(mean(x^2)+eps). + fp32_t mean_square = reduced_square_sum * r_N; + fp32_t inv_norm = rsqrtf(mean_square + eps); + + // Normalize each element using the computed normalization factor. + # pragma unroll + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load the previously stored vectorized data from global memory. + vec_copy(_X + i, local_x); + // Load the corresponding weight values from global memory. + vec_copy(W + i, local_w); + + #pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + fp32x2_t x = bf16x2_to_fp32x2(local_x[j]); + fp32x2_t w = bf16x2_to_fp32x2(local_w[j]); + // Apply normalization: multiply by inv_norm and then scale by the weight. + fp32x2_t ret = make_float2( + x.x * inv_norm * w.x, + x.y * inv_norm * w.y + ); + local_y[j] = _float22bf162_rn(ret); + } + // Write the normalized vectorized data back to global memory. + vec_copy(local_y, _Y + i); + } +} + +/** + * @brief Launch RMSNorm kernel for FP16 tensors with aligned 16-element rows. + * + * This function validates the input tensors, ensures they are contiguous, + * selects the appropriate kernel configuration based on the row width N, + * and launches the CUDA kernel. + * + * @param X Input tensor with shape [M, N] (FP16, CUDA). 
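+ * @param V Per-row sum-of-squares tensor with shape [M] (FP32, CUDA); expected to be the
+ *          output of pre_tp_norm_bf16, all-reduced across tensor-parallel ranks.
+ * @param embed_dim Full embedding dimension used to form the mean square; equals N on a single GPU.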
+ * @param W Weight tensor with shape [N] (FP16, CUDA). + * @param eps Epsilon for numerical stability. + * @return Output tensor with the same shape as X. + */ +Tensor post_tp_norm_bf16(Tensor &X, const Tensor &W, const Tensor &V, const int embed_dim, const fp32_t eps) { + TORCH_CHECK(X.ndimension() == 2 || X.ndimension() == 4, "Input tensor must be 2D or 4D"); + TORCH_CHECK(X.is_cuda(), "Input tensor must be a CUDA tensor."); + TORCH_CHECK(X.scalar_type() == c10::ScalarType::BFloat16, "Input tensor must be BF16."); + + Tensor contiguous_X = X.is_contiguous() ? X : X.contiguous(); + Tensor contiguous_W = W.is_contiguous() ? W : W.contiguous(); + Tensor contiguous_V = V.is_contiguous() ? V : V.contiguous(); + + Tensor input_tensor; + uint32_t M, N; + Tensor Y; + + if (X.ndimension() == 2) { + M = contiguous_X.size(0); + N = contiguous_X.size(1); + input_tensor = contiguous_X; + Y = torch::empty_like(input_tensor); + } else { + const uint32_t d0 = contiguous_X.size(0); + const uint32_t d1 = contiguous_X.size(1); + const uint32_t d2 = contiguous_X.size(2); + const uint32_t d3 = contiguous_X.size(3); + + M = d0 * d1; + N = d2 * d3; + input_tensor = contiguous_X.view({M, N}); + Y = torch::empty_like(input_tensor); + } + + // Each CUDA block processes one row. + const int32_t blocks = M; + + // Kernel dispatch based on the value of N. + switch (N) { + case 768: + device_post_tp_norm_bf16<128, 768> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 1024: + device_post_tp_norm_bf16<128, 1024> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 1664: + device_post_tp_norm_bf16<128, 1664> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 2048: + device_post_tp_norm_bf16<128, 2048> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 3200: + device_post_tp_norm_bf16<128, 3200> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 4096: + device_post_tp_norm_bf16<256, 4096> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 8192: + device_post_tp_norm_bf16<512, 8192> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + case 10240: + device_post_tp_norm_bf16<512, 10240> + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, embed_dim, eps + ); + break; + default: + static constexpr int32_t TPB = 256; + if (N % 8 == 0) { + device_post_tp_norm_bf16_vpt + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, N, embed_dim, eps + ); + } else { + device_post_tp_norm_bf16_general + <<>>( + PTR(input_tensor), PTR(contiguous_W), + PTR(contiguous_V), PTR(Y), + M, N, embed_dim, eps + ); + } + } + + // need to reshape Y back to 4 dimens + if (X.ndimension() == 4) { + Y = Y.reshape(X.sizes()); + } + + return Y; +} + +} // namespace ops +} // namespace lightllm \ No newline at end of file diff --git a/lightllm-kernel/csrc/fusion/pre_tp_norm.cu b/lightllm-kernel/csrc/fusion/pre_tp_norm.cu new file mode 100755 index 000000000..966cf5ce7 --- /dev/null +++ b/lightllm-kernel/csrc/fusion/pre_tp_norm.cu @@ -0,0 +1,257 @@ +#include "ops_common.h" +#include "reduce/sm70.cuh" + +namespace lightllm { +namespace ops { + +using 
namespace lightllm; + +/** + * @tparam TPB Threads per block. + * @tparam N Number of bf16 elements in one row. + * + * @param X Pointer to the input tensor in global memory. [M, N] + * @param M Number of rows in the tensor. + */ +template +__global__ +void device_pre_tp_norm_bf16_general( + bf16_t __restrict__ *X, // [M, N] Input tensor pointer. + fp32_t __restrict__ *V, // [M] Variance tensor pointer. + const int32_t M, // Number of rows. + const int32_t N +) { + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _X = X + bid * N; + + bf16_t local_x = cvt_f32_bf16(0.0f); + fp32_t local_square_sum = 0.0f; + for (int32_t i = tid; i < N; i += TPB) { + local_x = _X[i]; + + fp32_t tmp = cvt_bf16_f32(local_x); + + local_square_sum += tmp * tmp; + } + + fp32_t block_square_sum = lightllm::reduce::sm70::sync_block_reduce_sum_f32(local_square_sum); + + if (tid == 0) { + V[bid] = block_square_sum; + } + +} + + + +/** + * @tparam TPB Threads per block. + * @tparam N Number of bf16 elements in one row (must be a multiple of VPT). + * + * @param X Pointer to the input tensor in global memory. [M, N] + * @param M Number of rows in the tensor. + */ +template +__global__ +void device_pre_tp_norm_bf16_vpt( + bf16_t __restrict__ *X, // [M, N] Input tensor pointer. + fp32_t __restrict__ *V, // [M] Variance tensor pointer. + const int32_t M, // Number of rows. + const int32_t N +) { + constexpr int32_t VPT = 8; // Number of bf16 values processed per thread. + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _X = X + bid * N; + + // Local registers to hold vectorized data. + bf16x2_t local_x[VPT / 2]; + + // Each thread computes a partial sum of squares. + fp32_t local_square_sum = 0.0f; + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load VPT bf16 elements from global memory (_X) into local vector (local_x). + vec_copy(_X + i, local_x); + + // Compute the sum of squares for the VPT elements. + #pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + fp32x2_t tmp = bf16x2_to_fp32x2(local_x[j]); + local_square_sum += (tmp.x * tmp.x + tmp.y * tmp.y); + } + } + + // Reduce the partial sums across the block, block reduce sum will invoke __syncthread(); + V[bid] = lightllm::reduce::sm70::sync_block_reduce_sum_f32(local_square_sum); + +} + + +/** + * @tparam TPB Threads per block. + * @tparam N Number of bf16 elements in one row (must be a multiple of VPT). + * + * @param X Pointer to the input tensor in global memory. [M, N] + * @param M Number of rows in the tensor. + */ +template +__global__ +void device_pre_tp_norm_bf16( + bf16_t __restrict__ *X, // [M, N] Input tensor pointer. + fp32_t __restrict__ *V, // [M] Variance tensor pointer. + const int32_t M // Number of rows. +) { + constexpr int32_t VPT = 8; // Number of bf16 values processed per thread. + + static_assert(N % 2 == 0, "N must be even."); + static_assert(N % VPT == 0, "N must be a multiple of VPT."); + + const int32_t tid = threadIdx.x; + const int32_t bid = blockIdx.x; + + // Each block processes one row of the input tensor. + bf16_t* _X = X + bid * N; + + // Local registers to hold vectorized data. + bf16x2_t local_x[VPT / 2]; + + // Each thread computes a partial sum of squares. + fp32_t local_square_sum = 0.0f; + # pragma unroll + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { + // Load VPT bf16 elements from global memory (_X) into local vector (local_x). 
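+        // (With N known at compile time the trip count of this loop is fixed, so the
+        //  unroll pragma above can fully unroll it; the *_vpt and *_general variants
+        //  handle row widths not covered by the dispatch switch in pre_tp_norm_bf16.)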
+ vec_copy(_X + i, local_x); + + // Compute the sum of squares for the VPT elements. + #pragma unroll + for (int32_t j = 0; j < VPT / 2; j++) { + fp32x2_t tmp = bf16x2_to_fp32x2(local_x[j]); + local_square_sum += (tmp.x * tmp.x + tmp.y * tmp.y); + } + } + + // Reduce the partial sums across the block, block reduce sum will invoke __syncthread(); + V[bid] = lightllm::reduce::sm70::sync_block_reduce_sum_f32(local_square_sum); + +} + +/** + * @param X Input tensor with shape [M, N] (bf16, CUDA). + */ +Tensor pre_tp_norm_bf16(Tensor &X) { + TORCH_CHECK(X.ndimension() == 2 || X.ndimension() == 4, "Input tensor must be 2D or 4D"); + TORCH_CHECK(X.is_cuda(), "Input tensor must be a CUDA tensor."); + TORCH_CHECK(X.scalar_type() == c10::ScalarType::BFloat16, "Input tensor must be BF16."); + + Tensor contiguous_X = X.is_contiguous() ? X : X.contiguous(); + Tensor input_tensor; + uint32_t M, N; + Tensor V; + + if (X.ndimension() == 2) { + M = contiguous_X.size(0); + N = contiguous_X.size(1); + input_tensor = contiguous_X; + V = torch::empty( + {M}, + torch::TensorOptions() + .dtype(c10::ScalarType::Float) + .device(contiguous_X.device()) + ); + } else { + const uint32_t d0 = contiguous_X.size(0); + const uint32_t d1 = contiguous_X.size(1); + const uint32_t d2 = contiguous_X.size(2); + const uint32_t d3 = contiguous_X.size(3); + + M = d0 * d1; + N = d2 * d3; + input_tensor = contiguous_X.view({M, N}); + V = torch::empty( + {M}, + torch::TensorOptions() + .dtype(c10::ScalarType::Float) + .device(contiguous_X.device()) + ); + } + + + // Each CUDA block processes one row. + const int32_t blocks = M; + + // Kernel dispatch based on the value of N. + switch (N) { + case 768: + device_pre_tp_norm_bf16<128, 768> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 1024: + device_pre_tp_norm_bf16<128, 1024> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 1664: + device_pre_tp_norm_bf16<128, 1664> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 2048: + device_pre_tp_norm_bf16<128, 2048> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 3200: + device_pre_tp_norm_bf16<128, 3200> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 4096: + device_pre_tp_norm_bf16<256, 4096> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 8192: + device_pre_tp_norm_bf16<512, 8192> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + case 10240: + device_pre_tp_norm_bf16<512, 10240> + <<>>( + PTR(input_tensor), PTR(V), M + ); + break; + default: { + static constexpr int32_t TPB = 256; + if (N % 8 == 0) { + device_pre_tp_norm_bf16_vpt + <<>>( + PTR(input_tensor), PTR(V), M, N + ); + } else { + device_pre_tp_norm_bf16_general + <<>>( + PTR(input_tensor), PTR(V), M, N + ); + } + } + } + return V; +} + +} // namespace ops +} // namespace lightllm \ No newline at end of file diff --git a/lightllm-kernel/csrc/gemm/Epilogues.md b/lightllm-kernel/csrc/gemm/Epilogues.md new file mode 100755 index 000000000..aae04157b --- /dev/null +++ b/lightllm-kernel/csrc/gemm/Epilogues.md @@ -0,0 +1,147 @@ +# CUTLASS Epilogues + +## Introduction +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +Currently, we only support symmetric quantization for weights, +and symmetric and asymmetric quantization for activations. +Both can be quantized per-tensor or per-channel (weights) / per-token (activations). + +There are 4 epilogues: +1. ScaledEpilogue: symmetric quantization for activations, no bias. +1. 
ScaledEpilogueBias: symmetric quantization for activations, supports bias.
+1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
+1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
+
+We do not have epilogues for asymmetric quantization of activations without bias, in order to reduce final binary size.
+Instead, if no bias is passed, the epilogue will use 0 as the bias.
+That induces a redundant addition operation (and runtime check), but the performance impact is minor.
+
+## Underlying Linear Algebra
+
+More details are available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).
+
+If $` \widehat X `$ is the quantized $` X `$, our matrices become the following:
+
+```math
+A = s_a (\widehat A - J_a z_a)
+```
+```math
+B = s_b \widehat B
+```
+```math
+D = A B + C
+```
+```math
+D = s_a s_b \widehat D + C
+```
+
+Here, D is the output of the GEMM and C is the bias.
+A holds the activations and supports asymmetric quantization;
+B holds the weights and supports only symmetric quantization.
+$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
+$` z_a `$ is the zero-point for activations, and $` J_a `$ is the all-ones matrix with the same dimensions as A.
+Additional epilogues would be required to support asymmetric quantization for weights.
+
+Expanding further, we can calculate $` \widehat D `$ as follows:
+
+```math
+A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
+```
+```math
+A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
+```
+```math
+\widehat D = \widehat A \widehat B - z_a J_a \widehat B
+```
+
+Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
+and $` J_a \widehat B `$ is known ahead of time.
+Each of its rows is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
+
+## Epilogues
+
+### ScaledEpilogue
+This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
+The output of the GEMM is:
+
+```math
+\widehat D = \widehat A \widehat B
+```
+```math
+D = s_a s_b \widehat D
+```
+```math
+D = s_a s_b \widehat A \widehat B
+```
+
+Epilogue parameters:
+- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
+- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
+
+### ScaledEpilogueBias
+This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
+The output of the GEMM is:
+
+```math
+\widehat D = \widehat A \widehat B
+```
+```math
+D = s_a s_b \widehat D + C
+```
+```math
+D = s_a s_b \widehat A \widehat B + C
+```
+
+Epilogue parameters:
+- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
+- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
+- `bias` is the bias and is always per-channel (row-vector).
+
+### ScaledEpilogueAzp
+This epilogue computes the asymmetric per-tensor quantization for activations with bias.
+The output of the GEMM is:
+
+```math
+\widehat D = \widehat A \widehat B - z_a J_a \widehat B
+```
+```math
+D = s_a s_b \widehat D + C
+```
+```math
+D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
+```
+
+Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
+That is precomputed and stored in `azp_with_adj` as a row-vector.
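+
+As a sanity check, the math above can be written out in a few lines of PyTorch. This is only an
+illustrative reference, not part of the kernel API: the function name and the concrete shapes and
+dtypes are assumptions chosen for clarity (the parameter names follow the epilogue parameters
+documented here), and the real CUTLASS epilogue fuses the same arithmetic into the GEMM instead of
+materializing intermediates.
+
+```python
+import torch
+
+def scaled_epilogue_azp_reference(Dq, scale_a, scale_b, azp_with_adj, bias):
+    # Reference sketch only (assumed shapes/dtypes); mirrors the ScaledEpilogueAzp formula above.
+    # Dq:           [M, N] raw GEMM accumulator (A_hat @ B_hat), e.g. int32
+    # scale_a:      per-tensor scalar or [M, 1] per-token scales
+    # scale_b:      per-tensor scalar or [1, N] per-channel scales
+    # azp_with_adj: [1, N] precomputed z_a * (1^T @ B_hat)
+    # bias:         [1, N] per-channel bias
+    acc = Dq.to(torch.float32) - azp_with_adj      # subtract the zero-point term
+    return scale_a * scale_b * acc + bias          # de-quantize and add the bias
+```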
+
+Epilogue parameters:
+- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
+  - Generally this will be per-tensor, as the zero-points are per-tensor.
+- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
+- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$) and is per-channel (row-vector).
+- `bias` is the bias and is always per-channel (row-vector).
+
+To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
+
+### ScaledEpilogueAzpPerToken
+This epilogue computes the asymmetric per-token quantization for activations with bias.
+
+The output of the GEMM is the same as above, but here $` z_a `$ is a column-vector.
+That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
+
+Epilogue parameters:
+- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
+  - Generally this will be per-token, as the zero-points are per-token.
+- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
+- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$) and is per-channel (row-vector).
+- `azp` is the zero-point (`z_a`) and is per-token (column-vector).
+- `bias` is the bias and is always per-channel (row-vector).
+
+To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
+
+The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
+```
+out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
+```
diff --git a/lightllm-kernel/csrc/gemm/scaled_mm_c3x.cu b/lightllm-kernel/csrc/gemm/scaled_mm_c3x.cu
new file mode 100755
index 000000000..55d623755
--- /dev/null
+++ b/lightllm-kernel/csrc/gemm/scaled_mm_c3x.cu
@@ -0,0 +1,73 @@
+#include <cudaTypedefs.h>
+
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+
+  #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
+  #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace lightllm {
+namespace ops {
+
+using namespace lightllm;
+/*
+  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
+  NVIDIA GPUs with sm90a (Hopper) or later.
+*/
+
+template