diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c7760e995..ff3290233 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -23,6 +23,7 @@ from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch +from lightllm.utils.envs_utils import set_model_init_status logger = init_logger(__name__) @@ -103,6 +104,7 @@ def __init__(self, kvargs): self._init_cudagraph() self._check_max_len_infer() torch.cuda.empty_cache() + set_model_init_status(True) return def _init_config(self): diff --git a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv_fp8.py b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv_fp8.py new file mode 100644 index 000000000..f542d2c90 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv_fp8.py @@ -0,0 +1,98 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_kv_per_head_fp8( + K, + Dest_loc, + Out, + scale, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] + o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] + + # to fp8 + scale_ptrs = scale + offs_h + scales = tl.load(scale_ptrs, mask=offs_h < head_num, other=1.0) + k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + k_scale = k / scales[:, None] + k_fp8 = tl.clamp(k_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + + tl.store(o_ptrs, k_fp8, mask=offs_h[:, None] < head_num) + return + + +@torch.no_grad() +def destindex_copy_kv_fp8(K, DestLoc, scales, Out): + if scales is None: + Out[DestLoc] = K.to(torch.float8_e4m3fn) + return + + seq_len = DestLoc.shape[0] + head_num = K.shape[1] + head_dim = K.shape[2] + assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] + BLOCK_HEAD = triton.next_power_of_2(head_num) + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_kv_per_head_fp8[grid]( + K, + DestLoc, + Out, + scales, + K.stride(0), + K.stride(1), + K.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + head_num, + BLOCK_DMODEL=head_dim, + BLOCK_HEAD=BLOCK_HEAD, + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + num_warps=num_warps, + num_stages=1, + ) + + +if __name__ == "__main__": + import torch.nn.functional as F + from lightllm.utils.vllm_utils import vllm_ops + + B, N_CTX, H, HEAD_DIM = 32, 1024, 16, 128 + dtype = torch.bfloat16 + NUM = B + dest_loc = torch.arange(NUM).cuda() * 2 + kv = torch.randn((len(dest_loc), H, HEAD_DIM), dtype=dtype).cuda() + out = torch.zeros((B * N_CTX, H, HEAD_DIM), dtype=torch.uint8).cuda() + scale = kv.abs().amax(dim=(0, 2)).to(torch.float32) / 448 + destindex_copy_kv_fp8(kv, dest_loc, scale, out.view(torch.float8_e4m3fn)) + + assert torch.allclose( + out[:, :, :HEAD_DIM][dest_loc].view(torch.float8_e4m3fn).float() * scale.view(H, 1).expand(NUM, H, 1), + 
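+ # the dequantized copy (stored fp8 value times its per-head scale) should recover the original kv within fp8 rounding error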
kv.float(), + atol=1e-5, + rtol=1e-1, + ) diff --git a/lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py b/lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py new file mode 100644 index 000000000..34a12bc91 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py @@ -0,0 +1,151 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _per_head_max_reduce_kernel( + Q, + Scales, + StartLoc, + stride_q_t, + stride_q_h, + stride_scales_b, + FP8_MAX: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_D: tl.constexpr, +): + b_id = tl.program_id(0) + h_id = tl.program_id(1) + + max_val = 0.0 + + start_loc = tl.load(StartLoc + b_id) + end_loc = tl.load(StartLoc + b_id + 1) + for t_offset in range(start_loc, end_loc, BLOCK_T): + t_idx = t_offset + tl.arange(0, BLOCK_T) + q_range = tl.arange(0, BLOCK_D) + q_ptrs = Q + t_idx[:, None] * stride_q_t + h_id * stride_q_h + q_range[None, :] + mask = (t_idx[:, None] < end_loc) & (q_range[None, :] < stride_q_h) + q_vals = tl.load(q_ptrs, mask=mask, other=0.0) + max_val = tl.maximum(tl.max(q_vals.abs()), max_val) + + scale = tl.where(max_val > 0, max_val / FP8_MAX, 1.0) + scale_ptr = Scales + b_id * stride_scales_b + h_id + tl.store(scale_ptr, scale) + + +@triton.jit +def _apply_quantization_kernel( + Q, + Q_out, + BatchIds, + Scales, + stride_q_t, + stride_q_h, + stride_qout_t, + stride_qout_h, + stride_scales_b, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + BLOCK_D: tl.constexpr, +): + t_id = tl.program_id(0) + h_id = tl.program_id(1) + + batch_id = tl.load(BatchIds + t_id) + scale_ptr = Scales + batch_id * stride_scales_b + h_id + scale = tl.load(scale_ptr) + + q_range = tl.arange(0, BLOCK_D) + q_ptrs = Q + t_id * stride_q_t + h_id * stride_q_h + q_range + qout_ptrs = Q_out + t_id * stride_qout_t + h_id * stride_qout_h + q_range + mask = q_range < stride_q_h + q_vals = tl.load(q_ptrs, mask=mask, other=0.0) + q_scaled = q_vals / scale + q_clamped = tl.clamp(q_scaled, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + tl.store(qout_ptrs, q_clamped, mask=q_range < stride_qout_h) + + +@torch.no_grad() +def q_per_head_fp8_quant(q, seq_lens, b1_start_loc, scale_out=None, token_batch_ids=None): + T, H, D = q.shape + B = seq_lens.shape[0] + + BLOCK_D = triton.next_power_of_2(D) + BLOCK_T = 256 + num_warps = 4 + num_stages = 2 + + q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn) + if scale_out is None: + scale_out = torch.empty((B, H), dtype=torch.float32, device=q.device) + if token_batch_ids is None: + token_batch_ids = torch.repeat_interleave(torch.arange(B, device=q.device), seq_lens) + + _per_head_max_reduce_kernel[(B, H)]( + q, + scale_out, + b1_start_loc, + q.stride(0), + q.stride(1), + scale_out.stride(0), + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + BLOCK_T=BLOCK_T, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + _apply_quantization_kernel[(T, H)]( + q, + q_out, + token_batch_ids, + scale_out, + q.stride(0), + q.stride(1), + q_out.stride(0), + q_out.stride(1), + scale_out.stride(0), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return q_out, scale_out + + +def ref_q_per_head_fp8_quant(q, seq_lens): + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + B = seq_lens.size(0) + device = q.device + token_batch_ids = torch.repeat_interleave(torch.arange(B, device=device), seq_lens) 
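+ # reference path: per-(batch, head) abs-max via scatter_reduce, scale = max / fp8_max (1.0 where max == 0), then clamp and cast to fp8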
+ max_per_time_head = q.abs().amax(dim=2) + max_per_bh = torch.zeros((B, max_per_time_head.size(1)), device=device, dtype=max_per_time_head.dtype) + max_per_bh.scatter_reduce_( + 0, + token_batch_ids.unsqueeze(-1).expand(-1, max_per_time_head.size(1)), + max_per_time_head, + reduce="amax", + include_self=False, + ) + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) + scale_expanded = scales[token_batch_ids].view(-1, scales.size(1), 1) + q_q = (q / scale_expanded).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q_q, scales + + +if __name__ == "__main__": + B, T, H, D = 200, 1000, 4, 7 * 128 + seq_lens = torch.ones((B,), dtype=torch.int32).cuda() * T // B + start_locs = torch.zeros(B + 1, dtype=torch.int32).cuda() + start_locs[1:] = seq_lens.cumsum(dim=0) + q = torch.randn((T, H, D), dtype=torch.float32).cuda() + + q_out, scales = q_per_head_fp8_quant(q, seq_lens, start_locs) + q_out1, scales1 = ref_q_per_head_fp8_quant(q, seq_lens) + assert torch.allclose(scales, scales1, atol=1e-10, rtol=0) + assert torch.allclose(q_out.int(), q_out1.int(), atol=1e-10, rtol=0) diff --git a/lightllm/common/calibration_fp8kv_mem_manager.py b/lightllm/common/calibration_fp8kv_mem_manager.py new file mode 100755 index 000000000..2c896d495 --- /dev/null +++ b/lightllm/common/calibration_fp8kv_mem_manager.py @@ -0,0 +1,6 @@ +from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager + + +class CalibrationFP8KVMemoryManager(OfflineFP8QuantMemManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=False) diff --git a/lightllm/common/export_calibration_mem_manager.py b/lightllm/common/export_calibration_mem_manager.py new file mode 100755 index 000000000..b2749176e --- /dev/null +++ b/lightllm/common/export_calibration_mem_manager.py @@ -0,0 +1,6 @@ +from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager + + +class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=True) diff --git a/lightllm/common/mem_utils.py b/lightllm/common/mem_utils.py index 359ceada5..dfb8e849d 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/mem_utils.py @@ -1,5 +1,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager +from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager +from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from lightllm.utils.log_utils import init_logger @@ -20,6 +22,12 @@ def select_mem_manager_class(mode): logger.info("Model kv cache using mode triton int8kv") elif "triton_fp8kv" in mode: raise Exception("currently only for deepseek") + elif "offline_calibration_fp8kv" in mode: + memory_manager_class = CalibrationFP8KVMemoryManager + logger.info("Model kv cache using mode offline calibration fp8kv") + elif "export_fp8kv_calibration" in mode: + memory_manager_class = ExportCalibrationMemoryManager + logger.info("Using mode export fp8kv calibration") else: 
memory_manager_class = MemoryManager logger.info("Model kv cache using mode normal") diff --git a/lightllm/common/offline_fp8_quant_mem_manager.py b/lightllm/common/offline_fp8_quant_mem_manager.py new file mode 100755 index 000000000..686371231 --- /dev/null +++ b/lightllm/common/offline_fp8_quant_mem_manager.py @@ -0,0 +1,162 @@ +import os +import json +import torch +import torch.distributed as dist +from lightllm.utils.envs_utils import get_kv_quant_calibration_inference_count +from lightllm.utils.envs_utils import get_kv_quant_calibration_warmup_count +from lightllm.utils.dist_utils import get_global_rank +from lightllm.utils.config_utils import get_model_architectures +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args, get_model_init_status + +logger = init_logger(__name__) + +from .mem_manager import MemoryManager + + +class OfflineFP8QuantMemManager(MemoryManager): + def __init__( + self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_export_mode=False + ): + # 这里用uint8存储量化后的kv,方便兼容各种torch算子。fp8量化目前采用离线方案,kv_buffer不存储scale + super().__init__( + size, dtype if is_export_mode else torch.uint8, head_num, head_dim, layer_num, always_copy, mem_fraction + ) + + self.qmax = torch.finfo(torch.float8_e4m3fn).max + self.qmin = torch.finfo(torch.float8_e4m3fn).min + self.layer_num = layer_num + self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num + self.count = 0 + self.scales = None + self.scales_list = None + self.abs_max = None + + if is_export_mode: + scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2] + self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda") + elif get_env_start_args().kv_quant_calibration_config_path is not None: + logger.info( + f"kv_quant_calibration_config_path {get_env_start_args().kv_quant_calibration_config_path} is set, " + "will load kv quant calibration config" + ) + cfg = self._load_and_check_config() + + self.scales_list = cfg["scales"] + self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"]) + if not get_env_start_args().enable_fa3: + self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1) + if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: + half_head = self.total_head_num // 2 + start_head = dist.get_rank() * head_num + end_head = start_head + head_num + k_scales = self.scales[:, start_head:end_head].contiguous() + v_scales = self.scales[:, start_head + half_head : end_head + half_head].contiguous() + current_scales = torch.cat((k_scales, v_scales), dim=-1) + + self.scales_list = current_scales.tolist() + self.scales = current_scales + else: + logger.warning("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales") + + def _load_and_check_config(self): + if os.path.exists(get_env_start_args().kv_quant_calibration_config_path): + with open(get_env_start_args().kv_quant_calibration_config_path, "r") as f: + cfg = json.load(f) + + if cfg["qmin"] != self.qmin: + raise ValueError(f"qmin {cfg['qmin']} in config not match torch.float8_e4m3fn.min {self.qmin}") + if cfg["qmax"] != self.qmax: + raise ValueError(f"qmax {cfg['qmax']} in config not match torch.float8_e4m3fn.max {self.qmax}") + model_arch = get_model_architectures(get_env_start_args().model_dir) + if cfg["architectures"] != model_arch: + raise ValueError( + 
f"architectures {cfg['architectures']} in config " f"not match current model_arch {model_arch}" + ) + if cfg["num_layers"] != self.layer_num: + raise ValueError( + f"num_layers {cfg['num_layers']} in config " f"not match current layer_num {self.layer_num}" + ) + if cfg["num_head"] != self.total_head_num: + raise ValueError( + f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}" + ) + if get_env_start_args().enable_fa3: + if cfg["quant_type"] != "per_head": + raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend") + else: + if cfg["quant_type"] != "per_tensor": + raise ValueError(f"quant type {cfg['quant_type']} in config not match flashinfer backend") + + return cfg + else: + raise FileNotFoundError( + f"kv_quant_calibration_config {get_env_start_args().kv_quant_calibration_config_path} not found" + ) + + def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): + inference_counts = get_kv_quant_calibration_inference_count() + warmup_counts = get_kv_quant_calibration_warmup_count() + if not get_model_init_status() or self.count >= warmup_counts + inference_counts: + return + + if self.count == 0 and layer_index == 0: + logger.info("kv cache calibration mode will collect kv cache data for quantization calibration") + + if self.abs_max is not None and self.count >= warmup_counts: + if get_env_start_args().enable_fa3: + kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32) + else: + k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32) + v_max = kv_buffer[:, self.head_num :, :].abs().amax(dim=()).to(torch.float32) + kv_max = torch.tensor([k_max, v_max], device="cuda", dtype=torch.float32) + self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max) + if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1: + final_abs_max = self.abs_max + if dist.is_initialized() and dist.get_world_size() > 1: + if get_env_start_args().enable_fa3: + k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1) + k_max = k_max.contiguous() + v_max = v_max.contiguous() + gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())] + gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())] + dist.all_gather(gathered_k_max, k_max, group=None, async_op=False) + dist.all_gather(gathered_v_max, v_max, group=None, async_op=False) + k_max = torch.cat(gathered_k_max, dim=-1) + v_max = torch.cat(gathered_v_max, dim=-1) + final_abs_max = torch.cat((k_max, v_max), dim=-1) + else: + dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False) + + self.scales = final_abs_max / self.qmax + self.scales = torch.where(self.scales > 0, self.scales, torch.ones_like(self.scales)) + + if get_global_rank() == 0: + self.abs_max = final_abs_max + self._export_calibration_data() + + if layer_index == self.layer_num - 1: + self.count += 1 + + def _export_calibration_data(self): + model_arch = get_model_architectures(get_env_start_args().model_dir) + cfg = { + "version": "1.0", + "architectures": model_arch, + "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor", + "qmin": self.qmin, + "qmax": self.qmax, + "num_layers": self.layer_num, + "num_head": self.total_head_num, + "scales_shape": list(self.abs_max.shape), + "scales": self.scales.cpu().numpy().tolist(), + } + with open("./kv_cache_calib.json", "w") as f: + json.dump(cfg, f, indent=4) + logger.info( + f"Export kv cache calibration 
data to kv_cache_calib.json, " + f"architectures: {model_arch}, " + f"qmin: {self.qmin}, qmax: {self.qmax}, " + f"total heads: {self.total_head_num}, " + f"scales_shape: {list(self.abs_max.shape)}, " + ) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 0ca940ae8..3427ff9ee 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -28,7 +28,9 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty((self.batch_size, self.max_seq_len), dtype=torch.int32).to(input_ids.device) + self.page_table = torch.empty( + (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device + ) self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) else: # Meta information of flashattention for decoding @@ -43,12 +45,49 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): : self.batch_size * model.graph_max_len_in_batch ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to( - input_ids.device + self.page_table = torch.empty( + (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device ) self.page_table[:, :max_seq_len_k].copy_( model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] ) self.page_table[:, max_seq_len_k:].fill_(0) + + if "offline_calibration_fp8kv" in model.mode: + if self.is_prefill: + device = input_ids.device + # q_scale和token_batch_ids在对q做per head量化使用,为了节省资源在推理外部初始化 + self.q_scale = torch.empty( + (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device + ) + self.token_batch_ids = torch.repeat_interleave( + torch.arange(self.batch_size, device=device), self.b_q_seq_len + ) + + offline_scales = self.mem_manager.scales + head_num = self.mem_manager.head_num + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num] + .view(-1, 1, head_num) + .expand(offline_scales.shape[0], self.batch_size, head_num) + if offline_scales is not None + else torch.ones( + (self.mem_manager.layer_num, self.batch_size, head_num), + dtype=torch.float32, + device=input_ids.device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:] + .view(-1, 1, head_num) + .expand(offline_scales.shape[0], self.batch_size, head_num) + if offline_scales is not None + else torch.ones( + (self.mem_manager.layer_num, self.batch_size, head_num), + dtype=torch.float32, + device=input_ids.device, + ) + ) return diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 125134659..4b06a75c3 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -19,14 +19,23 @@ from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv +from 
lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + +if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant +else: + scaled_fp8_quant = None logger = init_logger(__name__) @@ -60,11 +69,34 @@ def _bind_norm(self): def _bind_attention(self): if get_env_start_args().enable_fa3: - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_flashattention, self) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + if "offline_calibration_fp8kv" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) + elif "export_fp8kv_calibration" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashattention, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_flashattention, self + ) + self._copy_kv_to_mem_cache = partial( + LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self + ) + elif not self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashattention, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_flashattention, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + else: + raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") return elif get_env_start_args().enable_flashinfer_prefill: self._context_attention_kernel = partial( @@ -102,6 +134,15 @@ def _bind_attention(self): elif "triton_int8kv" in self.mode: self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self) + elif "offline_calibration_fp8kv" in self.mode: + assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self + ) elif "triton_flashdecoding" in self.mode: self._token_attention_kernel = partial( LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self @@ -120,7 +161,12 @@ def 
_bind_attention(self): LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: + elif "export_fp8kv_calibration" in self.mode: + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self) + self._copy_kv_to_mem_cache = partial( + LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self + ) + elif not self.mode: if get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self @@ -128,6 +174,8 @@ def _bind_attention(self): else: self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + else: + raise Exception(f"Unsupported mode: {self.mode}") return @@ -185,6 +233,26 @@ def _tpsp_get_qkv( ) return q, cache_kv + def _context_attention_flashinfer_kernel_fp8( + self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = kv.unsqueeze(1) + k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) + v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) + offline_scales = infer_state.mem_manager.scales_list + k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None + v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None + infer_state.prefill_wrapper.run( + q.view(q.shape[0], -1, self.head_dim_), + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor.view(q.shape[0], -1, self.head_dim_), + ) + return o_tensor + def _context_attention_flashinfer_kernel( self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None ) -> torch.Tensor: @@ -249,7 +317,7 @@ def _context_attention_kernel_ppl_int8kv( ) return o_tensor - def _context_attention_flashattention(self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None): + def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ ) @@ -279,6 +347,49 @@ def _context_attention_flashattention(self, q, kv, infer_state: LlamaInferStateI ) return o + def _context_attention_flashattention_fp8( + self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + q, q_scale = q_per_head_fp8_quant( + q.view(q.shape[0], self.tp_k_head_num_, -1), + infer_state.b_seq_len, + infer_state.cu_seqlens_q, + infer_state.q_scale, + infer_state.token_batch_ids, + ) + cache_k = ( + (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) + .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) + .view(torch.float8_e4m3fn) + ) + cache_v = ( + ( + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ] + ) + .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) + .view(torch.float8_e4m3fn) + ) + o = flash_attn_with_kvcache( + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + 
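+ # q_descale comes from the per-head quantization above; k/v descales were precomputed per layer in FlashAttentionStateInfo.init_some_extra_state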
cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.q_max_seq_len, + causal=True, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale, + k_descale=infer_state.k_descale[self.layer_num_], + v_descale=infer_state.v_descale[self.layer_num_], + return_softmax_lse=False, + ) + return o + def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: @@ -365,12 +476,27 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) return + def _copy_kv_to_mem_cache_with_calibration(self, buffer, mem_index, mem_manager): + destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + mem_manager.update_calibration_data(buffer, self.layer_num_) + return + def _copy_kv_to_mem_cache_int8kv(self, buffer, mem_index, mem_manager): destindex_copy_quantize_kv( buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] ) return + def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): + scales = mem_manager.scales + destindex_copy_kv_fp8( + buffer, + mem_index, + scales[self.layer_num_] if scales is not None else None, + mem_manager.kv_buffer[self.layer_num_].view(torch.float8_e4m3fn), + ) + return + def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager): from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_quantize_kv @@ -387,6 +513,26 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): ) return + def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) + + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) + k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) + v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) + offline_scales = infer_state.mem_manager.scales_list + k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None + v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None + infer_state.decode_wrapper.run( + q.view(calcu_shape1), + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor.view(calcu_shape1), + ) + return o_tensor + def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) @@ -678,7 +824,7 @@ def _token_decode_attention_gqa_flashdecoding_vsm( alloc_tensor_func=self.alloc_tensor, ) - def _token_decode_attention_flashattention(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): + def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ ) @@ -708,6 +854,43 @@ def _token_decode_attention_flashattention(self, q, infer_state: LlamaInferState ) return o + def _token_decode_attention_flashattention_fp8( + self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + 
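+ # decode fp8 path: view the uint8 kv cache as float8_e4m3fn and dynamically quantize q per (token, head) before calling FA3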
cache_k = ( + (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) + .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) + .view(torch.float8_e4m3fn) + ) + cache_v = ( + ( + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ] + ) + .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) + .view(torch.float8_e4m3fn) + ) + q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True) + o = flash_attn_with_kvcache( + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_), + k_descale=infer_state.k_descale[self.layer_num_], + v_descale=infer_state.v_descale[self.layer_num_], + return_softmax_lse=False, + ) + return o + def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index b55c17afd..abc258e8b 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -41,7 +41,7 @@ def __init__(self, model): ), ] self.q_data_type = model.data_type - self.kv_data_type = model.data_type + self.kv_data_type = torch.float8_e4m3fn if "offline_calibration_fp8kv" in model.mode else model.data_type @ModelRegistry("llama") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3f3eaf96f..d904c727f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -164,11 +164,18 @@ def make_argument_parser() -> argparse.ArgumentParser: default=[], nargs="+", help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding - | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv, + | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv + | export_fp8kv_calibration triton_flashdecoding mode is for long context, current support llama llama2 qwen; triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2; + offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 or flashinfer backend, + currently only for llama and qwen model; + export_fp8kv_calibration record and export kv cache quant calibration results to a json file. + It can be used for llama and qwen model. + Calibration need to disable cudagraph and use fa3 or flashinfer backend. + Tp size must no more than head num when calibration. ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; ppl_fp16 mode use ppl fast fp16 decode attention kernel; you need to read source code to make sure the supported detail mode for all models""", @@ -442,4 +449,10 @@ def make_argument_parser() -> argparse.ArgumentParser: but ensure that the model is compatible with the specified step count. currently, deepseekv3 model only support 1 step""", ) + parser.add_argument( + "--kv_quant_calibration_config_path", + type=str, + default=None, + help="""Path of the kv quant calibration config. 
It can be used for llama and qwen model.""", + ) return parser diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index de1e690a2..6e6c27b5e 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -110,6 +110,21 @@ def normal_or_p_d_start(args): if args.return_all_prompt_logprobs: assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache" assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill" + if "offline_calibration_fp8kv" in args.mode: + assert args.enable_fa3 is True or ( + args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True + ), ( + "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or " + "--enable_flashinfer_prefill and --enable_flashinfer_decode" + ) + if "export_fp8kv_calibration" in args.mode: + assert args.enable_fa3 is True or ( + args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True + ), ( + "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or " + "--enable_flashinfer_prefill and --enable_flashinfer_decode" + ) + assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" # 部分模式还不能支持与高级动态调度算法协同,to do. if args.diverse_mode: diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index b3cb65ef5..eaad8cde3 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -2,7 +2,7 @@ import ctypes from typing import Tuple -LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 128)) +LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 696)) LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) @@ -24,7 +24,9 @@ def __init__(self): def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int): str_bytes = token_str.encode("utf-8") - assert len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES + assert ( + len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES + ), f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes." 
ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ec1eb427e..f76fbc8c8 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -93,3 +93,4 @@ class StartArgs: mtp_mode: Optional[str] = field(default=None) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) + kv_quant_calibration_config_path: Optional[str] = field(default=None) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index a3ceb99e1..7f0601c92 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -31,6 +31,16 @@ def get_eos_token_ids(model_path: str): assert False, "error eos_token_id format in config.json" +def get_model_architectures(model_path: str): + try: + config_json = get_config_json(model_path) + arch = config_json["architectures"][0] + return arch + except: + logger.error("can not get architectures from config.json, return unknown_architecture") + return "unknown_architecture" + + def get_vocab_size(model_path: str): try: config_json = get_config_json(model_path) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index d223931ed..b78784d82 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -77,15 +77,18 @@ def get_lightllm_websocket_max_message_size(): return int(os.getenv("LIGHTLLM_WEBSOCKET_MAX_SIZE", 16 * 1024 * 1024)) -# get_redundancy_expert_ids and get_redundancy_expert_num are primarily used to obtain the IDs and number of redundant experts during inference. -# They depend on a configuration file specified by ep_redundancy_expert_config_path, which is a JSON formatted text file. -# The content format is as follows: -# { -# "redundancy_expert_num": 1, # Number of redundant experts per rank -# "0": [0], # Key: layer_index (string), Value: list of original expert IDs that are redundant for this layer -# "1": [0], -# "default": [0] # Default list of redundant expert IDs if layer-specific entry is not found -# } +# get_redundancy_expert_ids and get_redundancy_expert_num are primarily +# used to obtain the IDs and number of redundant experts during inference. +# They depend on a configuration file specified by ep_redundancy_expert_config_path, +# which is a JSON formatted text file. 
+# The content format is as follows: +# { +# "redundancy_expert_num": 1, # Number of redundant experts per rank +# "0": [0], # Key: layer_index (string), +# # Value: list of original expert IDs that are redundant for this layer +# "1": [0], +# "default": [0] # Default list of redundant expert IDs if layer-specific entry is not found +# } @lru_cache(maxsize=None) @@ -132,3 +135,29 @@ def get_redundancy_expert_update_interval(): @lru_cache(maxsize=None) def get_redundancy_expert_update_max_load_count(): return int(os.getenv("LIGHTLLM_REDUNDANCY_EXPERT_UPDATE_MAX_LOAD_COUNT", 1)) + + +@lru_cache(maxsize=None) +def get_kv_quant_calibration_warmup_count(): + # 服务启动后前warmup次推理不计入量化校准统计,该参数可以控制在一个更大的校准数据集的不同位置处开始校准。 + return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_WARMUP_COUNT", 0)) + + +@lru_cache(maxsize=None) +def get_kv_quant_calibration_inference_count(): + # warmup后开始进行量化校准统计,推理次数达到inference_count后输出统计校准结果,通过该参数可以控制对量化校准数据的采集量。 + return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000)) + + +g_model_init_done = False + + +def get_model_init_status(): + global g_model_init_done + return g_model_init_done + + +def set_model_init_status(status: bool): + global g_model_init_done + g_model_init_done = status + return g_model_init_done diff --git a/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_14b.json b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_14b.json new file mode 100644 index 000000000..4dec53ae4 --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_14b.json @@ -0,0 +1,879 @@ +{ + "version": "1.0", + "architectures": "Qwen2ForCausalLM", + "quant_type": "per_head", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 48, + "num_head": 8, + "scales_shape": [ + 48, + 16 + ], + "scales": [ + [ + 0.0555245578289032, + 0.0518973246216774, + 0.0357142873108387, + 0.0449218787252903, + 0.0571986623108387, + 0.0426897332072258, + 0.0376674123108387, + 0.0474330373108387, + 0.0096261166036129, + 0.008754185400903225, + 0.0052315848879516125, + 0.00798688642680645, + 0.01297433115541935, + 0.007603236939758062, + 0.006417410913854837, + 0.00456891767680645 + ], + [ + 0.0156947560608387, + 0.0210658498108387, + 0.0138811394572258, + 0.0182756707072258, + 0.01785714365541935, + 0.02441406436264515, + 0.012834821827709675, + 0.01311383955180645, + 0.0013427735539153218, + 0.0022495815064758062, + 0.0016392299439758062, + 0.0012468610657379031, + 0.0019618445076048374, + 0.003191266907379031, + 0.00115966796875, + 0.0011247907532379031 + ], + [ + 0.02748326025903225, + 0.01409040205180645, + 0.0563616082072258, + 0.0262276791036129, + 0.02734375186264515, + 0.0279017873108387, + 0.02385602705180645, + 0.02832031436264515, + 0.007777622900903225, + 0.0030343192629516125, + 0.006801060400903225, + 0.004847935400903225, + 0.00505719892680645, + 0.0029820033814758062, + 0.0044991630129516125, + 0.005615234840661287 + ], + [ + 0.0322265625, + 0.02539062686264515, + 0.0686383992433548, + 0.0319475457072258, + 0.0260881707072258, + 0.02469308115541935, + 0.0301339291036129, + 0.02287946455180645, + 0.00359235517680645, + 0.0074288509786129, + 0.01492745615541935, + 0.00725446455180645, + 0.006487165577709675, + 0.0062081473879516125, + 0.005754743702709675, + 0.008021763525903225 + ], + [ + 0.02859933115541935, + 0.0313895121216774, + 0.02343750186264515, + 0.0272042416036129, + 0.0327845998108387, + 0.02664620615541935, + 0.0298549123108387, + 0.0290178582072258, 
+ 0.0066964291036129, + 0.005824497900903225, + 0.005684989038854837, + 0.008858817629516125, + 0.007533482275903225, + 0.007463728077709675, + 0.005092076025903225, + 0.00857979990541935 + ], + [ + 0.0269252248108387, + 0.0393415205180645, + 0.02650669775903225, + 0.0230189748108387, + 0.0359933041036129, + 0.0359933041036129, + 0.0330636166036129, + 0.033203125, + 0.010393415577709675, + 0.007603236939758062, + 0.008754185400903225, + 0.01004464365541935, + 0.0084054134786129, + 0.0091378353536129, + 0.00823102705180645, + 0.01346261240541935 + ], + [ + 0.0341796875, + 0.0382254496216774, + 0.0379464291036129, + 0.0287388414144516, + 0.0329241082072258, + 0.0340401791036129, + 0.0379464291036129, + 0.0359933041036129, + 0.00833565928041935, + 0.0148577019572258, + 0.009835380129516125, + 0.007882255129516125, + 0.008265904150903225, + 0.011718750931322575, + 0.007777622900903225, + 0.010951451025903225 + ], + [ + 0.0336216539144516, + 0.0341796875, + 0.0376674123108387, + 0.041015625, + 0.0343191996216774, + 0.0339006707072258, + 0.0368303582072258, + 0.0379464291036129, + 0.00969587080180645, + 0.0101143978536129, + 0.0087890625, + 0.00809151865541935, + 0.01457868330180645, + 0.01346261240541935, + 0.011718750931322575, + 0.01067243330180645 + ], + [ + 0.02469308115541935, + 0.0357142873108387, + 0.02943638525903225, + 0.0387834832072258, + 0.0412946455180645, + 0.037109375, + 0.02483258955180645, + 0.0323660746216774, + 0.0135323666036129, + 0.01025390625, + 0.012276786379516125, + 0.00833565928041935, + 0.0135323666036129, + 0.0120675228536129, + 0.009835380129516125, + 0.009765625 + ], + [ + 0.0404575914144516, + 0.0357142873108387, + 0.0379464291036129, + 0.0333426371216774, + 0.0404575914144516, + 0.0362723246216774, + 0.0350167416036129, + 0.0376674123108387, + 0.008370536379516125, + 0.0076729916036129, + 0.008928571827709675, + 0.01262555830180645, + 0.0088936947286129, + 0.01018415205180645, + 0.01409040205180645, + 0.010811942629516125 + ], + [ + 0.0337611623108387, + 0.0373883955180645, + 0.0359933041036129, + 0.0357142873108387, + 0.0418526791036129, + 0.02845982275903225, + 0.0337611623108387, + 0.0379464291036129, + 0.01736886240541935, + 0.01067243330180645, + 0.013253348879516125, + 0.01262555830180645, + 0.01053292490541935, + 0.0133928582072258, + 0.01297433115541935, + 0.011439732275903225 + ], + [ + 0.0426897332072258, + 0.0279017873108387, + 0.0426897332072258, + 0.0340401791036129, + 0.0415736623108387, + 0.0491071455180645, + 0.0368303582072258, + 0.0345982164144516, + 0.011788505129516125, + 0.010951451025903225, + 0.01611328125, + 0.010323661379516125, + 0.01067243330180645, + 0.010323661379516125, + 0.009765625, + 0.007219587452709675 + ], + [ + 0.0259486623108387, + 0.03125, + 0.03515625, + 0.02832031436264515, + 0.0415736623108387, + 0.0468750037252903, + 0.0348772332072258, + 0.0418526791036129, + 0.014160157181322575, + 0.007219587452709675, + 0.011439732275903225, + 0.012904576025903225, + 0.009835380129516125, + 0.011230469681322575, + 0.007638114038854837, + 0.0140206478536129 + ], + [ + 0.0421316996216774, + 0.0343191996216774, + 0.0404575914144516, + 0.0407366082072258, + 0.0362723246216774, + 0.0323660746216774, + 0.0429687537252903, + 0.0350167416036129, + 0.01053292490541935, + 0.01004464365541935, + 0.010323661379516125, + 0.0096261166036129, + 0.01025390625, + 0.01492745615541935, + 0.00920758955180645, + 0.014229911379516125 + ], + [ + 0.0359933041036129, + 0.02943638525903225, + 0.0407366082072258, + 0.0355747789144516, + 0.0323660746216774, + 
0.02385602705180645, + 0.02957589365541935, + 0.0339006707072258, + 0.010463169775903225, + 0.01150948740541935, + 0.00920758955180645, + 0.011718750931322575, + 0.01164899580180645, + 0.01506696455180645, + 0.012276786379516125, + 0.013741630129516125 + ], + [ + 0.0424107164144516, + 0.0432477705180645, + 0.03055245615541935, + 0.0457589291036129, + 0.0432477705180645, + 0.0454799123108387, + 0.0387834832072258, + 0.0334821455180645, + 0.0161830373108387, + 0.008056640625, + 0.009835380129516125, + 0.008614677004516125, + 0.011369978077709675, + 0.0133928582072258, + 0.008754185400903225, + 0.0130440853536129 + ], + [ + 0.0359933041036129, + 0.0287388414144516, + 0.0488281287252903, + 0.0404575914144516, + 0.0290178582072258, + 0.0279017873108387, + 0.037109375, + 0.0393415205180645, + 0.010742188431322575, + 0.011928013525903225, + 0.0096261166036129, + 0.014229911379516125, + 0.012207032181322575, + 0.0154854916036129, + 0.0081612728536129, + 0.012695313431322575 + ], + [ + 0.0404575914144516, + 0.0326450914144516, + 0.0341796875, + 0.0385044664144516, + 0.0341796875, + 0.0396205373108387, + 0.0345982164144516, + 0.03055245615541935, + 0.008928571827709675, + 0.012276786379516125, + 0.011369978077709675, + 0.01736886240541935, + 0.013323103077709675, + 0.0158342644572258, + 0.011369978077709675, + 0.01102120615541935 + ], + [ + 0.03027343936264515, + 0.02929687686264515, + 0.0482700914144516, + 0.0333426371216774, + 0.0396205373108387, + 0.0407366082072258, + 0.0474330373108387, + 0.0426897332072258, + 0.014648438431322575, + 0.00906808115541935, + 0.01150948740541935, + 0.013671875931322575, + 0.01639229990541935, + 0.011439732275903225, + 0.009835380129516125, + 0.0163225457072258 + ], + [ + 0.03515625, + 0.0373883955180645, + 0.0354352705180645, + 0.0398995541036129, + 0.0471540205180645, + 0.0329241082072258, + 0.0429687537252903, + 0.0421316996216774, + 0.010811942629516125, + 0.01604352705180645, + 0.01199776865541935, + 0.01625279150903225, + 0.0159737728536129, + 0.01262555830180645, + 0.01116071455180645, + 0.01018415205180645 + ], + [ + 0.0340401791036129, + 0.041015625, + 0.0269252248108387, + 0.0488281287252903, + 0.0376674123108387, + 0.0407366082072258, + 0.0499441996216774, + 0.0340401791036129, + 0.0143694207072258, + 0.0158342644572258, + 0.01506696455180645, + 0.017578125, + 0.01492745615541935, + 0.0101143978536129, + 0.0149972103536129, + 0.010951451025903225 + ], + [ + 0.0315290205180645, + 0.0382254496216774, + 0.041015625, + 0.0426897332072258, + 0.0385044664144516, + 0.0404575914144516, + 0.0452008955180645, + 0.0418526791036129, + 0.01708984375, + 0.01150948740541935, + 0.00882394053041935, + 0.00969587080180645, + 0.01297433115541935, + 0.01360212080180645, + 0.0192522332072258, + 0.0166713185608387 + ], + [ + 0.0502232164144516, + 0.0329241082072258, + 0.0355747789144516, + 0.0449218787252903, + 0.0308314748108387, + 0.0647321492433548, + 0.0387834832072258, + 0.0474330373108387, + 0.008858817629516125, + 0.02762276865541935, + 0.011858259327709675, + 0.009835380129516125, + 0.0168108269572258, + 0.0163225457072258, + 0.015136719681322575, + 0.013253348879516125 + ], + [ + 0.0320870541036129, + 0.0465959832072258, + 0.037109375, + 0.03125, + 0.0661272332072258, + 0.0313895121216774, + 0.0393415205180645, + 0.0311104916036129, + 0.01311383955180645, + 0.009905134327709675, + 0.01199776865541935, + 0.01639229990541935, + 0.012834821827709675, + 0.0185546875, + 0.01067243330180645, + 0.01541573740541935 + ], + [ + 0.033203125, + 0.0435267873108387, + 
0.0354352705180645, + 0.0288783498108387, + 0.0325055830180645, + 0.03515625, + 0.0460379496216774, + 0.02957589365541935, + 0.011928013525903225, + 0.00969587080180645, + 0.00906808115541935, + 0.0159737728536129, + 0.0138811394572258, + 0.012904576025903225, + 0.010463169775903225, + 0.0181361623108387 + ], + [ + 0.0329241082072258, + 0.0352957621216774, + 0.0382254496216774, + 0.0385044664144516, + 0.0354352705180645, + 0.0390625, + 0.0424107164144516, + 0.0376674123108387, + 0.015276228077709675, + 0.01688058115541935, + 0.011718750931322575, + 0.011928013525903225, + 0.01346261240541935, + 0.01297433115541935, + 0.012834821827709675, + 0.01262555830180645 + ], + [ + 0.0354352705180645, + 0.0407366082072258, + 0.0507812537252903, + 0.0343191996216774, + 0.0359933041036129, + 0.0449218787252903, + 0.0421316996216774, + 0.0288783498108387, + 0.0184151791036129, + 0.01409040205180645, + 0.01653180830180645, + 0.01199776865541935, + 0.01722935400903225, + 0.0140206478536129, + 0.01492745615541935, + 0.0213448666036129 + ], + [ + 0.0479910746216774, + 0.03125, + 0.0429687537252903, + 0.0299944207072258, + 0.0359933041036129, + 0.0320870541036129, + 0.0396205373108387, + 0.0415736623108387, + 0.014787946827709675, + 0.011928013525903225, + 0.015136719681322575, + 0.014787946827709675, + 0.0143694207072258, + 0.013671875931322575, + 0.01604352705180645, + 0.0135323666036129 + ], + [ + 0.0347377248108387, + 0.0297154039144516, + 0.0426897332072258, + 0.0421316996216774, + 0.03125, + 0.0426897332072258, + 0.037109375, + 0.0368303582072258, + 0.012346540577709675, + 0.0171595998108387, + 0.013253348879516125, + 0.0101143978536129, + 0.0407366082072258, + 0.0143694207072258, + 0.0200892873108387, + 0.01981026865541935 + ], + [ + 0.0376674123108387, + 0.0330636166036129, + 0.0398995541036129, + 0.0398995541036129, + 0.0479910746216774, + 0.02762276865541935, + 0.0700334832072258, + 0.0326450914144516, + 0.01199776865541935, + 0.01702008955180645, + 0.01722935400903225, + 0.0120675228536129, + 0.0154854916036129, + 0.0145089291036129, + 0.01967076025903225, + 0.01213727705180645 + ], + [ + 0.0471540205180645, + 0.0387834832072258, + 0.0449218787252903, + 0.0474330373108387, + 0.0429687537252903, + 0.0365513414144516, + 0.0809151828289032, + 0.0362723246216774, + 0.015625, + 0.01199776865541935, + 0.01053292490541935, + 0.0212053582072258, + 0.01443917490541935, + 0.0163225457072258, + 0.0158342644572258, + 0.01506696455180645 + ], + [ + 0.041015625, + 0.0449218787252903, + 0.0357142873108387, + 0.0382254496216774, + 0.0398995541036129, + 0.0482700914144516, + 0.0429687537252903, + 0.0608258955180645, + 0.01067243330180645, + 0.012346540577709675, + 0.014229911379516125, + 0.013323103077709675, + 0.01248604990541935, + 0.015625, + 0.01102120615541935, + 0.014160157181322575 + ], + [ + 0.0336216539144516, + 0.0352957621216774, + 0.0387834832072258, + 0.0329241082072258, + 0.041015625, + 0.0373883955180645, + 0.0390625, + 0.0435267873108387, + 0.0185546875, + 0.01360212080180645, + 0.014787946827709675, + 0.0203683041036129, + 0.013671875931322575, + 0.01611328125, + 0.012207032181322575, + 0.013183594681322575 + ], + [ + 0.0340401791036129, + 0.0365513414144516, + 0.0398995541036129, + 0.064453125, + 0.037109375, + 0.0347377248108387, + 0.0468750037252903, + 0.0424107164144516, + 0.01688058115541935, + 0.0184151791036129, + 0.0143694207072258, + 0.014299665577709675, + 0.012416294775903225, + 0.013323103077709675, + 0.0184151791036129, + 0.0205078125 + ], + [ + 0.02553013525903225, + 
0.0327845998108387, + 0.0309709832072258, + 0.0440848246216774, + 0.0287388414144516, + 0.0477120541036129, + 0.0401785746216774, + 0.0396205373108387, + 0.014160157181322575, + 0.01702008955180645, + 0.017578125, + 0.012346540577709675, + 0.0203683041036129, + 0.014787946827709675, + 0.015276228077709675, + 0.01395089365541935 + ], + [ + 0.0382254496216774, + 0.03515625, + 0.0382254496216774, + 0.0440848246216774, + 0.0493861623108387, + 0.0505022332072258, + 0.02664620615541935, + 0.0407366082072258, + 0.0200892873108387, + 0.02078683115541935, + 0.01967076025903225, + 0.0193917416036129, + 0.01555524580180645, + 0.014648438431322575, + 0.01576451025903225, + 0.0158342644572258 + ], + [ + 0.0532924123108387, + 0.041015625, + 0.0396205373108387, + 0.041015625, + 0.0319475457072258, + 0.0460379496216774, + 0.0404575914144516, + 0.0340401791036129, + 0.0185546875, + 0.0182756707072258, + 0.0269252248108387, + 0.01457868330180645, + 0.02357701025903225, + 0.01555524580180645, + 0.0252511166036129, + 0.02566964365541935 + ], + [ + 0.0379464291036129, + 0.0432477705180645, + 0.0382254496216774, + 0.0387834832072258, + 0.0347377248108387, + 0.0426897332072258, + 0.0322265625, + 0.0424107164144516, + 0.0279017873108387, + 0.01883370615541935, + 0.02455357275903225, + 0.0259486623108387, + 0.0166015625, + 0.02273995615541935, + 0.01869419775903225, + 0.03055245615541935 + ], + [ + 0.0322265625, + 0.0452008955180645, + 0.0393415205180645, + 0.0491071455180645, + 0.0379464291036129, + 0.0362723246216774, + 0.0359933041036129, + 0.0468750037252903, + 0.02176339365541935, + 0.0192522332072258, + 0.01653180830180645, + 0.0181361623108387, + 0.01981026865541935, + 0.0181361623108387, + 0.0191127248108387, + 0.01869419775903225 + ], + [ + 0.0382254496216774, + 0.041015625, + 0.0319475457072258, + 0.0407366082072258, + 0.0362723246216774, + 0.0432477705180645, + 0.0323660746216774, + 0.0449218787252903, + 0.0239955373108387, + 0.0185546875, + 0.02246093936264515, + 0.01994977705180645, + 0.02636718936264515, + 0.0182756707072258, + 0.0223214291036129, + 0.015206473879516125 + ], + [ + 0.0376674123108387, + 0.0404575914144516, + 0.0401785746216774, + 0.0357142873108387, + 0.0435267873108387, + 0.0336216539144516, + 0.0306919664144516, + 0.0396205373108387, + 0.02845982275903225, + 0.01967076025903225, + 0.0231584832072258, + 0.0185546875, + 0.0269252248108387, + 0.0290178582072258, + 0.02385602705180645, + 0.02636718936264515 + ], + [ + 0.0320870541036129, + 0.0291573666036129, + 0.037109375, + 0.0373883955180645, + 0.0415736623108387, + 0.0465959832072258, + 0.03027343936264515, + 0.0401785746216774, + 0.0213448666036129, + 0.02371651865541935, + 0.0239955373108387, + 0.0309709832072258, + 0.0223214291036129, + 0.0171595998108387, + 0.0213448666036129, + 0.02483258955180645 + ], + [ + 0.0385044664144516, + 0.0457589291036129, + 0.03027343936264515, + 0.0354352705180645, + 0.0426897332072258, + 0.03041294775903225, + 0.0322265625, + 0.0368303582072258, + 0.0297154039144516, + 0.02929687686264515, + 0.015625, + 0.02483258955180645, + 0.0220424123108387, + 0.0212053582072258, + 0.01883370615541935, + 0.02273995615541935 + ], + [ + 0.0571986623108387, + 0.0415736623108387, + 0.0421316996216774, + 0.0354352705180645, + 0.0393415205180645, + 0.0393415205180645, + 0.0357142873108387, + 0.0463169664144516, + 0.0344587080180645, + 0.02273995615541935, + 0.02539062686264515, + 0.0323660746216774, + 0.03125, + 0.03027343936264515, + 0.0398995541036129, + 0.0401785746216774 + ], + [ + 0.0323660746216774, + 
0.0345982164144516, + 0.037109375, + 0.0362723246216774, + 0.0440848246216774, + 0.0505022332072258, + 0.0429687537252903, + 0.0404575914144516, + 0.0259486623108387, + 0.0357142873108387, + 0.0336216539144516, + 0.02469308115541935, + 0.0376674123108387, + 0.0390625, + 0.0345982164144516, + 0.0344587080180645 + ], + [ + 0.0622209832072258, + 0.0379464291036129, + 0.0452008955180645, + 0.0352957621216774, + 0.0418526791036129, + 0.0505022332072258, + 0.0546875037252903, + 0.0415736623108387, + 0.0382254496216774, + 0.02190290205180645, + 0.0398995541036129, + 0.0376674123108387, + 0.0566406287252903, + 0.0474330373108387, + 0.0471540205180645, + 0.0521763414144516 + ], + [ + 0.0421316996216774, + 0.0505022332072258, + 0.0333426371216774, + 0.0859375074505806, + 0.0454799123108387, + 0.0390625, + 0.0499441996216774, + 0.0387834832072258, + 0.0496651791036129, + 0.0563616082072258, + 0.0412946455180645, + 0.0329241082072258, + 0.065011166036129, + 0.0675223246216774, + 0.0493861623108387, + 0.02469308115541935 + ], + [ + 0.0359933041036129, + 0.0347377248108387, + 0.0530133955180645, + 0.037109375, + 0.0396205373108387, + 0.0306919664144516, + 0.0580357164144516, + 0.0563616082072258, + 0.0588727705180645, + 0.0457589291036129, + 0.0407366082072258, + 0.0382254496216774, + 0.0535714328289032, + 0.0334821455180645, + 0.0538504496216774, + 0.0736607164144516 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_32b.json b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_32b.json new file mode 100644 index 000000000..a5c4a56ed --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_32b.json @@ -0,0 +1,1167 @@ +{ + "version": "1.0", + "architectures": "Qwen2ForCausalLM", + "quant_type": "per_head", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 64, + "num_head": 8, + "scales_shape": [ + 64, + 16 + ], + "scales": [ + [ + 0.0552455373108387, + 0.0518973246216774, + 0.0362723246216774, + 0.0440848246216774, + 0.0571986623108387, + 0.0452008955180645, + 0.0333426371216774, + 0.0468750037252903, + 0.0106026791036129, + 0.011928013525903225, + 0.006835937965661287, + 0.012765067629516125, + 0.0210658498108387, + 0.010463169775903225, + 0.007707868702709675, + 0.006801060400903225 + ], + [ + 0.01492745615541935, + 0.02148437686264515, + 0.013183594681322575, + 0.0181361623108387, + 0.01771763525903225, + 0.02539062686264515, + 0.013741630129516125, + 0.012834821827709675, + 0.0015694755129516125, + 0.002284458838403225, + 0.002406529150903225, + 0.0019095285097137094, + 0.0013078962219879031, + 0.0035400392953306437, + 0.0011509486939758062, + 0.0018484933534637094 + ], + [ + 0.02859933115541935, + 0.01457868330180645, + 0.0558035746216774, + 0.0262276791036129, + 0.02483258955180645, + 0.02162388525903225, + 0.0169503353536129, + 0.0291573666036129, + 0.007882255129516125, + 0.0035226005129516125, + 0.00833565928041935, + 0.005440848413854837, + 0.004045759327709675, + 0.00334821455180645, + 0.004708426538854837, + 0.006731306202709675 + ], + [ + 0.0315290205180645, + 0.0223214291036129, + 0.0513392873108387, + 0.0320870541036129, + 0.0231584832072258, + 0.02092633955180645, + 0.02483258955180645, + 0.02273995615541935, + 0.0044991630129516125, + 0.012904576025903225, + 0.0182756707072258, + 0.0086495541036129, + 0.00906808115541935, + 0.00701032392680645, + 0.007533482275903225, + 0.0096261166036129 + ], + [ + 0.0288783498108387, + 
0.02748326025903225, + 0.0205078125, + 0.02566964365541935, + 0.0267857164144516, + 0.02845982275903225, + 0.0313895121216774, + 0.02762276865541935, + 0.008754185400903225, + 0.0059640067629516125, + 0.006801060400903225, + 0.0069405697286129, + 0.01150948740541935, + 0.00969587080180645, + 0.005092076025903225, + 0.014648438431322575 + ], + [ + 0.02636718936264515, + 0.0412946455180645, + 0.0315290205180645, + 0.0251116082072258, + 0.0415736623108387, + 0.0387834832072258, + 0.0404575914144516, + 0.0373883955180645, + 0.0133928582072258, + 0.008998326025903225, + 0.013183594681322575, + 0.01018415205180645, + 0.008998326025903225, + 0.00927734375, + 0.010463169775903225, + 0.008370536379516125 + ], + [ + 0.0350167416036129, + 0.0385044664144516, + 0.0418526791036129, + 0.02845982275903225, + 0.0336216539144516, + 0.0365513414144516, + 0.0350167416036129, + 0.037109375, + 0.009905134327709675, + 0.0138811394572258, + 0.01116071455180645, + 0.009486607275903225, + 0.010951451025903225, + 0.012904576025903225, + 0.007777622900903225, + 0.01248604990541935 + ], + [ + 0.0337611623108387, + 0.0327845998108387, + 0.0368303582072258, + 0.0418526791036129, + 0.0348772332072258, + 0.0313895121216774, + 0.0333426371216774, + 0.0306919664144516, + 0.0130440853536129, + 0.009835380129516125, + 0.01213727705180645, + 0.009347098879516125, + 0.0156947560608387, + 0.01604352705180645, + 0.011928013525903225, + 0.012346540577709675 + ], + [ + 0.0242745541036129, + 0.0323660746216774, + 0.0269252248108387, + 0.0373883955180645, + 0.0393415205180645, + 0.037109375, + 0.0252511166036129, + 0.0333426371216774, + 0.01199776865541935, + 0.010811942629516125, + 0.01004464365541935, + 0.00927734375, + 0.011300223879516125, + 0.010951451025903225, + 0.013671875931322575, + 0.0101143978536129 + ], + [ + 0.0376674123108387, + 0.0344587080180645, + 0.0347377248108387, + 0.0337611623108387, + 0.0415736623108387, + 0.0337611623108387, + 0.0345982164144516, + 0.0385044664144516, + 0.009765625, + 0.01018415205180645, + 0.012416294775903225, + 0.00927734375, + 0.00927734375, + 0.01102120615541935, + 0.01443917490541935, + 0.01262555830180645 + ], + [ + 0.0340401791036129, + 0.0350167416036129, + 0.0333426371216774, + 0.0357142873108387, + 0.0396205373108387, + 0.0269252248108387, + 0.0326450914144516, + 0.0368303582072258, + 0.01590401865541935, + 0.01248604990541935, + 0.01590401865541935, + 0.014299665577709675, + 0.01297433115541935, + 0.0154854916036129, + 0.01625279150903225, + 0.010742188431322575 + ], + [ + 0.0421316996216774, + 0.0269252248108387, + 0.0432477705180645, + 0.0362723246216774, + 0.0407366082072258, + 0.0471540205180645, + 0.0382254496216774, + 0.0379464291036129, + 0.015206473879516125, + 0.010393415577709675, + 0.0176478810608387, + 0.01443917490541935, + 0.012904576025903225, + 0.010742188431322575, + 0.00969587080180645, + 0.0084054134786129 + ], + [ + 0.0259486623108387, + 0.0288783498108387, + 0.0327845998108387, + 0.02748326025903225, + 0.0387834832072258, + 0.0457589291036129, + 0.0345982164144516, + 0.0387834832072258, + 0.01981026865541935, + 0.008126395754516125, + 0.015625, + 0.011369978077709675, + 0.0096261166036129, + 0.0130440853536129, + 0.00920758955180645, + 0.0193917416036129 + ], + [ + 0.0387834832072258, + 0.0318080373108387, + 0.0390625, + 0.0352957621216774, + 0.03515625, + 0.0319475457072258, + 0.0401785746216774, + 0.0329241082072258, + 0.01443917490541935, + 0.00969587080180645, + 0.011300223879516125, + 0.0101143978536129, + 0.011230469681322575, + 0.01674107275903225, + 
0.0106026791036129, + 0.01688058115541935 + ], + [ + 0.0350167416036129, + 0.02748326025903225, + 0.0404575914144516, + 0.0334821455180645, + 0.0287388414144516, + 0.0249720998108387, + 0.02832031436264515, + 0.0309709832072258, + 0.012834821827709675, + 0.01248604990541935, + 0.01213727705180645, + 0.01248604990541935, + 0.009765625, + 0.0158342644572258, + 0.015206473879516125, + 0.012276786379516125 + ], + [ + 0.0368303582072258, + 0.0350167416036129, + 0.02553013525903225, + 0.0426897332072258, + 0.0407366082072258, + 0.0390625, + 0.0316685289144516, + 0.0260881707072258, + 0.0149972103536129, + 0.00830078125, + 0.009765625, + 0.00927734375, + 0.010811942629516125, + 0.00906808115541935, + 0.01025390625, + 0.011300223879516125 + ], + [ + 0.0344587080180645, + 0.02859933115541935, + 0.041015625, + 0.0347377248108387, + 0.02845982275903225, + 0.02650669775903225, + 0.0325055830180645, + 0.0440848246216774, + 0.012834821827709675, + 0.01213727705180645, + 0.009416853077709675, + 0.0149972103536129, + 0.0145089291036129, + 0.0130440853536129, + 0.009416853077709675, + 0.01262555830180645 + ], + [ + 0.03515625, + 0.0288783498108387, + 0.02762276865541935, + 0.0359933041036129, + 0.0337611623108387, + 0.0347377248108387, + 0.0298549123108387, + 0.03055245615541935, + 0.00969587080180645, + 0.01150948740541935, + 0.012276786379516125, + 0.011858259327709675, + 0.009765625, + 0.014299665577709675, + 0.01067243330180645, + 0.01555524580180645 + ], + [ + 0.02636718936264515, + 0.02441406436264515, + 0.0291573666036129, + 0.0313895121216774, + 0.0352957621216774, + 0.0359933041036129, + 0.0438058041036129, + 0.037109375, + 0.012695313431322575, + 0.008544921875, + 0.01771763525903225, + 0.012904576025903225, + 0.014160157181322575, + 0.009835380129516125, + 0.008928571827709675, + 0.01150948740541935 + ], + [ + 0.0287388414144516, + 0.0329241082072258, + 0.0297154039144516, + 0.0365513414144516, + 0.033203125, + 0.02371651865541935, + 0.0393415205180645, + 0.0373883955180645, + 0.014160157181322575, + 0.0168108269572258, + 0.01213727705180645, + 0.01248604990541935, + 0.0125558041036129, + 0.01150948740541935, + 0.01297433115541935, + 0.008998326025903225 + ], + [ + 0.0288783498108387, + 0.0355747789144516, + 0.0231584832072258, + 0.037109375, + 0.03125, + 0.0362723246216774, + 0.0387834832072258, + 0.02539062686264515, + 0.01311383955180645, + 0.012695313431322575, + 0.009905134327709675, + 0.01702008955180645, + 0.0133928582072258, + 0.011928013525903225, + 0.0148577019572258, + 0.011718750931322575 + ], + [ + 0.0288783498108387, + 0.02943638525903225, + 0.03125, + 0.02845982275903225, + 0.02929687686264515, + 0.0319475457072258, + 0.0376674123108387, + 0.0365513414144516, + 0.011788505129516125, + 0.011230469681322575, + 0.0091378353536129, + 0.0110909603536129, + 0.01409040205180645, + 0.014718192629516125, + 0.012765067629516125, + 0.012346540577709675 + ], + [ + 0.0452008955180645, + 0.02832031436264515, + 0.0272042416036129, + 0.0362723246216774, + 0.02343750186264515, + 0.0339006707072258, + 0.0306919664144516, + 0.0376674123108387, + 0.01004464365541935, + 0.012904576025903225, + 0.01164899580180645, + 0.010463169775903225, + 0.010323661379516125, + 0.00955636240541935, + 0.01297433115541935, + 0.011439732275903225 + ], + [ + 0.0232979916036129, + 0.0404575914144516, + 0.03027343936264515, + 0.0315290205180645, + 0.0549665205180645, + 0.0277622789144516, + 0.0373883955180645, + 0.02832031436264515, + 0.0106026791036129, + 0.01248604990541935, + 0.0096261166036129, + 0.01722935400903225, 
+ 0.013811384327709675, + 0.0159737728536129, + 0.00955636240541935, + 0.012346540577709675 + ], + [ + 0.02469308115541935, + 0.0379464291036129, + 0.02832031436264515, + 0.02553013525903225, + 0.02734375186264515, + 0.02469308115541935, + 0.0373883955180645, + 0.0241350457072258, + 0.010742188431322575, + 0.008754185400903225, + 0.010811942629516125, + 0.013253348879516125, + 0.01555524580180645, + 0.010463169775903225, + 0.0091378353536129, + 0.0106026791036129 + ], + [ + 0.02832031436264515, + 0.0288783498108387, + 0.0297154039144516, + 0.0309709832072258, + 0.0277622789144516, + 0.0345982164144516, + 0.0315290205180645, + 0.0348772332072258, + 0.0185546875, + 0.013183594681322575, + 0.00920758955180645, + 0.012695313431322575, + 0.01199776865541935, + 0.014718192629516125, + 0.013183594681322575, + 0.011718750931322575 + ], + [ + 0.0316685289144516, + 0.02734375186264515, + 0.0368303582072258, + 0.02190290205180645, + 0.0318080373108387, + 0.0385044664144516, + 0.0281808041036129, + 0.0320870541036129, + 0.0135323666036129, + 0.01360212080180645, + 0.0171595998108387, + 0.01248604990541935, + 0.0115792416036129, + 0.0115792416036129, + 0.01360212080180645, + 0.01443917490541935 + ], + [ + 0.0432477705180645, + 0.02762276865541935, + 0.037109375, + 0.0279017873108387, + 0.02483258955180645, + 0.0260881707072258, + 0.02859933115541935, + 0.0347377248108387, + 0.01457868330180645, + 0.01443917490541935, + 0.012207032181322575, + 0.0125558041036129, + 0.012904576025903225, + 0.013253348879516125, + 0.014229911379516125, + 0.013741630129516125 + ], + [ + 0.0299944207072258, + 0.0259486623108387, + 0.0415736623108387, + 0.0421316996216774, + 0.0311104916036129, + 0.0379464291036129, + 0.02748326025903225, + 0.0299944207072258, + 0.0185546875, + 0.0135323666036129, + 0.013183594681322575, + 0.01004464365541935, + 0.0161830373108387, + 0.01262555830180645, + 0.013323103077709675, + 0.011439732275903225 + ], + [ + 0.037109375, + 0.0337611623108387, + 0.02929687686264515, + 0.0318080373108387, + 0.0385044664144516, + 0.0241350457072258, + 0.0616629496216774, + 0.02845982275903225, + 0.013671875931322575, + 0.01409040205180645, + 0.011928013525903225, + 0.009835380129516125, + 0.01702008955180645, + 0.0133928582072258, + 0.0145089291036129, + 0.01199776865541935 + ], + [ + 0.0368303582072258, + 0.0306919664144516, + 0.0368303582072258, + 0.0452008955180645, + 0.0415736623108387, + 0.02957589365541935, + 0.0694754496216774, + 0.0325055830180645, + 0.0135323666036129, + 0.010742188431322575, + 0.0101143978536129, + 0.0319475457072258, + 0.0143694207072258, + 0.01116071455180645, + 0.01395089365541935, + 0.0135323666036129 + ], + [ + 0.037109375, + 0.03515625, + 0.0337611623108387, + 0.0357142873108387, + 0.0379464291036129, + 0.0418526791036129, + 0.0334821455180645, + 0.0432477705180645, + 0.011718750931322575, + 0.01262555830180645, + 0.012765067629516125, + 0.01004464365541935, + 0.00927734375, + 0.01248604990541935, + 0.0125558041036129, + 0.0213448666036129 + ], + [ + 0.0337611623108387, + 0.0291573666036129, + 0.0398995541036129, + 0.0401785746216774, + 0.0313895121216774, + 0.02246093936264515, + 0.0365513414144516, + 0.0365513414144516, + 0.013183594681322575, + 0.014787946827709675, + 0.012346540577709675, + 0.02162388525903225, + 0.02176339365541935, + 0.01611328125, + 0.01004464365541935, + 0.0168108269572258 + ], + [ + 0.0393415205180645, + 0.0323660746216774, + 0.0315290205180645, + 0.0393415205180645, + 0.0359933041036129, + 0.0359933041036129, + 0.0355747789144516, + 0.03125, + 
0.013671875931322575, + 0.012416294775903225, + 0.017578125, + 0.01639229990541935, + 0.01409040205180645, + 0.02273995615541935, + 0.0161830373108387, + 0.0177873894572258 + ], + [ + 0.0277622789144516, + 0.02762276865541935, + 0.0376674123108387, + 0.0297154039144516, + 0.0345982164144516, + 0.0385044664144516, + 0.0457589291036129, + 0.0426897332072258, + 0.0202287957072258, + 0.013253348879516125, + 0.01736886240541935, + 0.01722935400903225, + 0.0140206478536129, + 0.01708984375, + 0.012416294775903225, + 0.0221819207072258 + ], + [ + 0.0329241082072258, + 0.0376674123108387, + 0.0343191996216774, + 0.0373883955180645, + 0.0454799123108387, + 0.0306919664144516, + 0.0429687537252903, + 0.0382254496216774, + 0.01555524580180645, + 0.0223214291036129, + 0.01639229990541935, + 0.01953125, + 0.0203683041036129, + 0.01799665205180645, + 0.0156947560608387, + 0.0140206478536129 + ], + [ + 0.03027343936264515, + 0.0435267873108387, + 0.0272042416036129, + 0.0415736623108387, + 0.0365513414144516, + 0.0385044664144516, + 0.0443638414144516, + 0.0327845998108387, + 0.0158342644572258, + 0.02148437686264515, + 0.01492745615541935, + 0.0184151791036129, + 0.01981026865541935, + 0.01150948740541935, + 0.01953125, + 0.0156947560608387 + ], + [ + 0.0308314748108387, + 0.0350167416036129, + 0.0362723246216774, + 0.0362723246216774, + 0.0382254496216774, + 0.0352957621216774, + 0.0421316996216774, + 0.0421316996216774, + 0.01409040205180645, + 0.014718192629516125, + 0.015136719681322575, + 0.01311383955180645, + 0.01674107275903225, + 0.0174386166036129, + 0.0252511166036129, + 0.0177873894572258 + ], + [ + 0.0482700914144516, + 0.02762276865541935, + 0.033203125, + 0.0438058041036129, + 0.0341796875, + 0.0719866082072258, + 0.0337611623108387, + 0.0477120541036129, + 0.01067243330180645, + 0.0368303582072258, + 0.01346261240541935, + 0.01555524580180645, + 0.0210658498108387, + 0.01981026865541935, + 0.01771763525903225, + 0.01869419775903225 + ], + [ + 0.0306919664144516, + 0.0432477705180645, + 0.0379464291036129, + 0.02664620615541935, + 0.0655691996216774, + 0.0341796875, + 0.0390625, + 0.0318080373108387, + 0.01555524580180645, + 0.013253348879516125, + 0.01506696455180645, + 0.02190290205180645, + 0.0203683041036129, + 0.02650669775903225, + 0.011858259327709675, + 0.01883370615541935 + ], + [ + 0.0306919664144516, + 0.0387834832072258, + 0.0337611623108387, + 0.0290178582072258, + 0.0352957621216774, + 0.0306919664144516, + 0.0404575914144516, + 0.02832031436264515, + 0.0172991082072258, + 0.01311383955180645, + 0.013253348879516125, + 0.02064732275903225, + 0.01674107275903225, + 0.01702008955180645, + 0.011858259327709675, + 0.02092633955180645 + ], + [ + 0.0311104916036129, + 0.0359933041036129, + 0.0376674123108387, + 0.0382254496216774, + 0.0365513414144516, + 0.0407366082072258, + 0.0426897332072258, + 0.0401785746216774, + 0.01953125, + 0.0230189748108387, + 0.01443917490541935, + 0.01688058115541935, + 0.01736886240541935, + 0.015625, + 0.01541573740541935, + 0.01576451025903225 + ], + [ + 0.0337611623108387, + 0.0352957621216774, + 0.0440848246216774, + 0.0352957621216774, + 0.0362723246216774, + 0.0446428582072258, + 0.0387834832072258, + 0.0288783498108387, + 0.0205078125, + 0.01611328125, + 0.02064732275903225, + 0.01457868330180645, + 0.02078683115541935, + 0.014718192629516125, + 0.0210658498108387, + 0.02357701025903225 + ], + [ + 0.0446428582072258, + 0.0280412957072258, + 0.0460379496216774, + 0.02859933115541935, + 0.0318080373108387, + 0.0326450914144516, + 
0.0418526791036129, + 0.0401785746216774, + 0.01785714365541935, + 0.012765067629516125, + 0.0191127248108387, + 0.0164620541036129, + 0.01799665205180645, + 0.015625, + 0.0203683041036129, + 0.0164620541036129 + ], + [ + 0.03125, + 0.02859933115541935, + 0.0421316996216774, + 0.0396205373108387, + 0.0393415205180645, + 0.0471540205180645, + 0.037109375, + 0.0412946455180645, + 0.0138811394572258, + 0.02162388525903225, + 0.015625, + 0.01409040205180645, + 0.065011166036129, + 0.0191127248108387, + 0.0221819207072258, + 0.0272042416036129 + ], + [ + 0.0345982164144516, + 0.0297154039144516, + 0.0345982164144516, + 0.0398995541036129, + 0.0482700914144516, + 0.02553013525903225, + 0.07421875, + 0.0279017873108387, + 0.0192522332072258, + 0.02078683115541935, + 0.02148437686264515, + 0.0164620541036129, + 0.0193917416036129, + 0.01750837080180645, + 0.02148437686264515, + 0.0148577019572258 + ], + [ + 0.0429687537252903, + 0.037109375, + 0.0415736623108387, + 0.0471540205180645, + 0.0432477705180645, + 0.0359933041036129, + 0.0842633992433548, + 0.0322265625, + 0.01785714365541935, + 0.01506696455180645, + 0.0125558041036129, + 0.02343750186264515, + 0.0169503353536129, + 0.0203683041036129, + 0.01883370615541935, + 0.0174386166036129 + ], + [ + 0.0373883955180645, + 0.0468750037252903, + 0.0348772332072258, + 0.0379464291036129, + 0.0379464291036129, + 0.0429687537252903, + 0.0407366082072258, + 0.0622209832072258, + 0.012416294775903225, + 0.013741630129516125, + 0.01653180830180645, + 0.01639229990541935, + 0.012765067629516125, + 0.01639229990541935, + 0.013323103077709675, + 0.014160157181322575 + ], + [ + 0.0315290205180645, + 0.03055245615541935, + 0.0382254496216774, + 0.03041294775903225, + 0.0401785746216774, + 0.0330636166036129, + 0.0322265625, + 0.0463169664144516, + 0.02176339365541935, + 0.01994977705180645, + 0.02162388525903225, + 0.0203683041036129, + 0.01953125, + 0.01708984375, + 0.015276228077709675, + 0.01708984375 + ], + [ + 0.0365513414144516, + 0.0323660746216774, + 0.0373883955180645, + 0.063058041036129, + 0.0313895121216774, + 0.0311104916036129, + 0.0387834832072258, + 0.0359933041036129, + 0.0205078125, + 0.02273995615541935, + 0.01702008955180645, + 0.01722935400903225, + 0.0153459832072258, + 0.01702008955180645, + 0.02078683115541935, + 0.01994977705180645 + ], + [ + 0.02539062686264515, + 0.03041294775903225, + 0.0262276791036129, + 0.0407366082072258, + 0.02762276865541935, + 0.0454799123108387, + 0.037109375, + 0.037109375, + 0.017578125, + 0.01897321455180645, + 0.02190290205180645, + 0.014648438431322575, + 0.02553013525903225, + 0.01785714365541935, + 0.01967076025903225, + 0.0172991082072258 + ], + [ + 0.0341796875, + 0.0337611623108387, + 0.0379464291036129, + 0.0415736623108387, + 0.0471540205180645, + 0.0443638414144516, + 0.0269252248108387, + 0.0387834832072258, + 0.02176339365541935, + 0.01722935400903225, + 0.0241350457072258, + 0.0288783498108387, + 0.0191127248108387, + 0.0200892873108387, + 0.01897321455180645, + 0.01883370615541935 + ], + [ + 0.0560825914144516, + 0.0398995541036129, + 0.0404575914144516, + 0.0354352705180645, + 0.0362723246216774, + 0.0407366082072258, + 0.0418526791036129, + 0.0299944207072258, + 0.01736886240541935, + 0.0322265625, + 0.0339006707072258, + 0.01625279150903225, + 0.0242745541036129, + 0.0158342644572258, + 0.02832031436264515, + 0.0341796875 + ], + [ + 0.0362723246216774, + 0.0404575914144516, + 0.0355747789144516, + 0.0365513414144516, + 0.0308314748108387, + 0.0396205373108387, + 0.0287388414144516, + 
0.0390625, + 0.0341796875, + 0.01897321455180645, + 0.0251116082072258, + 0.0334821455180645, + 0.01736886240541935, + 0.02748326025903225, + 0.02553013525903225, + 0.0379464291036129 + ], + [ + 0.0323660746216774, + 0.0435267873108387, + 0.0345982164144516, + 0.0463169664144516, + 0.0357142873108387, + 0.0343191996216774, + 0.0330636166036129, + 0.0415736623108387, + 0.0287388414144516, + 0.02343750186264515, + 0.01785714365541935, + 0.0223214291036129, + 0.0221819207072258, + 0.02064732275903225, + 0.02539062686264515, + 0.0251116082072258 + ], + [ + 0.0309709832072258, + 0.0376674123108387, + 0.0316685289144516, + 0.037109375, + 0.0339006707072258, + 0.0407366082072258, + 0.0329241082072258, + 0.041015625, + 0.03041294775903225, + 0.0202287957072258, + 0.02832031436264515, + 0.0230189748108387, + 0.02957589365541935, + 0.0192522332072258, + 0.02664620615541935, + 0.0174386166036129 + ], + [ + 0.0316685289144516, + 0.0376674123108387, + 0.0337611623108387, + 0.0306919664144516, + 0.0432477705180645, + 0.0315290205180645, + 0.0280412957072258, + 0.0376674123108387, + 0.0354352705180645, + 0.01883370615541935, + 0.0313895121216774, + 0.0213448666036129, + 0.0287388414144516, + 0.0368303582072258, + 0.02832031436264515, + 0.0322265625 + ], + [ + 0.0340401791036129, + 0.0242745541036129, + 0.0355747789144516, + 0.0347377248108387, + 0.037109375, + 0.0446428582072258, + 0.0281808041036129, + 0.0412946455180645, + 0.02762276865541935, + 0.0323660746216774, + 0.0298549123108387, + 0.0382254496216774, + 0.02832031436264515, + 0.0205078125, + 0.0212053582072258, + 0.02734375186264515 + ], + [ + 0.0288783498108387, + 0.0396205373108387, + 0.0320870541036129, + 0.0355747789144516, + 0.0426897332072258, + 0.0298549123108387, + 0.02929687686264515, + 0.03515625, + 0.0373883955180645, + 0.0379464291036129, + 0.0172991082072258, + 0.0325055830180645, + 0.0298549123108387, + 0.0290178582072258, + 0.01869419775903225, + 0.0359933041036129 + ], + [ + 0.0569196455180645, + 0.0390625, + 0.0365513414144516, + 0.0341796875, + 0.0357142873108387, + 0.0365513414144516, + 0.0390625, + 0.0468750037252903, + 0.037109375, + 0.0259486623108387, + 0.0327845998108387, + 0.0401785746216774, + 0.0355747789144516, + 0.0306919664144516, + 0.0452008955180645, + 0.0454799123108387 + ], + [ + 0.0330636166036129, + 0.0368303582072258, + 0.0343191996216774, + 0.0344587080180645, + 0.0387834832072258, + 0.0638950914144516, + 0.0435267873108387, + 0.0421316996216774, + 0.02845982275903225, + 0.02176339365541935, + 0.0438058041036129, + 0.033203125, + 0.0446428582072258, + 0.0449218787252903, + 0.0407366082072258, + 0.0474330373108387 + ], + [ + 0.0591517873108387, + 0.0452008955180645, + 0.0493861623108387, + 0.0435267873108387, + 0.0471540205180645, + 0.0558035746216774, + 0.0616629496216774, + 0.0496651791036129, + 0.0440848246216774, + 0.02650669775903225, + 0.0421316996216774, + 0.0477120541036129, + 0.0605468787252903, + 0.0555245578289032, + 0.0597098246216774, + 0.0594308078289032 + ], + [ + 0.0485491082072258, + 0.0605468787252903, + 0.0452008955180645, + 0.0792410746216774, + 0.0513392873108387, + 0.0438058041036129, + 0.0538504496216774, + 0.0432477705180645, + 0.0566406287252903, + 0.0599888414144516, + 0.0482700914144516, + 0.0279017873108387, + 0.0672433078289032, + 0.0613839328289032, + 0.0538504496216774, + 0.02385602705180645 + ], + [ + 0.0390625, + 0.0327845998108387, + 0.0479910746216774, + 0.0404575914144516, + 0.0390625, + 0.0318080373108387, + 0.0633370578289032, + 0.0560825914144516, + 0.0736607164144516, + 
0.0591517873108387, + 0.0613839328289032, + 0.0521763414144516, + 0.0619419664144516, + 0.0385044664144516, + 0.063058041036129, + 0.078683041036129 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_72b.json b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_72b.json new file mode 100644 index 000000000..680d51502 --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_72b.json @@ -0,0 +1,1455 @@ +{ + "version": "1.0", + "architectures": "Qwen2ForCausalLM", + "quant_type": "per_head", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 80, + "num_head": 8, + "scales_shape": [ + 80, + 16 + ], + "scales": [ + [ + 0.0415736623108387, + 0.0277622789144516, + 0.0320870541036129, + 0.0355747789144516, + 0.02064732275903225, + 0.02957589365541935, + 0.0239955373108387, + 0.02092633955180645, + 0.0006931850221008062, + 0.0003117152664344758, + 0.0003204345703125, + 0.0004752023087348789, + 0.00041634697117842734, + 0.00018419539264868945, + 0.00019727434846572578, + 0.0005558559205383062 + ], + [ + 0.013671875931322575, + 0.01297433115541935, + 0.01771763525903225, + 0.0213448666036129, + 0.0200892873108387, + 0.0184151791036129, + 0.0164620541036129, + 0.0362723246216774, + 0.0002528599579818547, + 0.00032479423680342734, + 0.00028773717349395156, + 0.0003618513001129031, + 0.0025634765625, + 0.0005384173127822578, + 0.0006365095032379031, + 0.0003901890595443547 + ], + [ + 0.0259486623108387, + 0.03027343936264515, + 0.0239955373108387, + 0.0280412957072258, + 0.0267857164144516, + 0.0329241082072258, + 0.02343750186264515, + 0.0306919664144516, + 0.0028599330689758062, + 0.003923689015209675, + 0.005371094215661287, + 0.004150390625, + 0.0054757255129516125, + 0.003505161963403225, + 0.005405971314758062, + 0.003069196594879031 + ], + [ + 0.0272042416036129, + 0.02441406436264515, + 0.02580915205180645, + 0.02287946455180645, + 0.02064732275903225, + 0.03027343936264515, + 0.02176339365541935, + 0.0401785746216774, + 0.0020839148201048374, + 0.002894810400903225, + 0.002650669775903225, + 0.0016566686099395156, + 0.001796177588403225, + 0.003313337219879031, + 0.0016217913944274187, + 0.003383091650903225 + ], + [ + 0.03041294775903225, + 0.02566964365541935, + 0.0270647332072258, + 0.02762276865541935, + 0.0343191996216774, + 0.0379464291036129, + 0.01897321455180645, + 0.0325055830180645, + 0.004952567163854837, + 0.003383091650903225, + 0.0086495541036129, + 0.005266462452709675, + 0.0035226005129516125, + 0.0029820033814758062, + 0.005650111939758062, + 0.004778181202709675 + ], + [ + 0.01869419775903225, + 0.0323660746216774, + 0.0202287957072258, + 0.03041294775903225, + 0.0239955373108387, + 0.02385602705180645, + 0.0373883955180645, + 0.02580915205180645, + 0.003993443213403225, + 0.004115513525903225, + 0.006173270288854837, + 0.003976004663854837, + 0.004638671875, + 0.005405971314758062, + 0.0054757255129516125, + 0.0036969867069274187 + ], + [ + 0.0171595998108387, + 0.03125, + 0.02176339365541935, + 0.01967076025903225, + 0.0262276791036129, + 0.0325055830180645, + 0.0279017873108387, + 0.02371651865541935, + 0.003191266907379031, + 0.005022321827709675, + 0.0027553015388548374, + 0.0048828125, + 0.005126953125, + 0.004307338502258062, + 0.006487165577709675, + 0.005196707788854837 + ], + [ + 0.01541573740541935, + 0.01967076025903225, + 0.02664620615541935, + 0.02287946455180645, + 0.0239955373108387, + 
0.0354352705180645, + 0.0270647332072258, + 0.0291573666036129, + 0.0044991630129516125, + 0.004673549439758062, + 0.006173270288854837, + 0.00809151865541935, + 0.005196707788854837, + 0.005580357275903225, + 0.003016880713403225, + 0.003976004663854837 + ], + [ + 0.0379464291036129, + 0.0365513414144516, + 0.0267857164144516, + 0.0345982164144516, + 0.0339006707072258, + 0.0325055830180645, + 0.0297154039144516, + 0.0308314748108387, + 0.003138951025903225, + 0.0031215124763548374, + 0.008021763525903225, + 0.0037318640388548374, + 0.004237583838403225, + 0.00432477705180645, + 0.002947126282379031, + 0.002406529150903225 + ], + [ + 0.0440848246216774, + 0.0404575914144516, + 0.0327845998108387, + 0.0493861623108387, + 0.0368303582072258, + 0.0418526791036129, + 0.0393415205180645, + 0.0446428582072258, + 0.00347028486430645, + 0.00830078125, + 0.006801060400903225, + 0.003941127564758062, + 0.004708426538854837, + 0.005092076025903225, + 0.005754743702709675, + 0.005684989038854837 + ], + [ + 0.0449218787252903, + 0.0277622789144516, + 0.0407366082072258, + 0.0449218787252903, + 0.0454799123108387, + 0.0443638414144516, + 0.0435267873108387, + 0.0449218787252903, + 0.006382533814758062, + 0.005615234840661287, + 0.0035400392953306437, + 0.006487165577709675, + 0.00578962080180645, + 0.00383649580180645, + 0.0047433036379516125, + 0.004185268189758062 + ], + [ + 0.02832031436264515, + 0.0421316996216774, + 0.0424107164144516, + 0.0468750037252903, + 0.0404575914144516, + 0.0435267873108387, + 0.0184151791036129, + 0.0385044664144516, + 0.006870815064758062, + 0.0047433036379516125, + 0.0081612728536129, + 0.00676618330180645, + 0.005092076025903225, + 0.00627790205180645, + 0.011928013525903225, + 0.0052315848879516125 + ], + [ + 0.0288783498108387, + 0.0287388414144516, + 0.0333426371216774, + 0.0398995541036129, + 0.02929687686264515, + 0.0313895121216774, + 0.0429687537252903, + 0.0390625, + 0.00578962080180645, + 0.00456891767680645, + 0.008056640625, + 0.008265904150903225, + 0.006347656715661287, + 0.005580357275903225, + 0.0069405697286129, + 0.005754743702709675 + ], + [ + 0.0336216539144516, + 0.0424107164144516, + 0.0443638414144516, + 0.0446428582072258, + 0.0379464291036129, + 0.0435267873108387, + 0.037109375, + 0.0396205373108387, + 0.0057198661379516125, + 0.006626674439758062, + 0.0064522880129516125, + 0.00676618330180645, + 0.006382533814758062, + 0.0038888114504516125, + 0.0062081473879516125, + 0.00481305830180645 + ], + [ + 0.0359933041036129, + 0.0488281287252903, + 0.0418526791036129, + 0.0407366082072258, + 0.0429687537252903, + 0.02287946455180645, + 0.0507812537252903, + 0.0435267873108387, + 0.0043770927004516125, + 0.005022321827709675, + 0.00456891767680645, + 0.00420270673930645, + 0.006138393189758062, + 0.008370536379516125, + 0.004150390625, + 0.00481305830180645 + ], + [ + 0.0313895121216774, + 0.0401785746216774, + 0.0315290205180645, + 0.0220424123108387, + 0.0426897332072258, + 0.0426897332072258, + 0.02469308115541935, + 0.0362723246216774, + 0.009416853077709675, + 0.009416853077709675, + 0.006103516090661287, + 0.0074288509786129, + 0.006975446827709675, + 0.00676618330180645, + 0.0054757255129516125, + 0.007324219215661287 + ], + [ + 0.0357142873108387, + 0.0373883955180645, + 0.0454799123108387, + 0.0315290205180645, + 0.0337611623108387, + 0.0365513414144516, + 0.0398995541036129, + 0.0379464291036129, + 0.007603236939758062, + 0.0049874442629516125, + 0.007777622900903225, + 0.007533482275903225, + 0.005092076025903225, + 0.005894252564758062, 
+ 0.00652204267680645, + 0.0064522880129516125 + ], + [ + 0.0457589291036129, + 0.0398995541036129, + 0.0457589291036129, + 0.0280412957072258, + 0.0387834832072258, + 0.02664620615541935, + 0.0488281287252903, + 0.03125, + 0.005650111939758062, + 0.00652204267680645, + 0.00554548017680645, + 0.007568359840661287, + 0.005405971314758062, + 0.007289341650903225, + 0.005196707788854837, + 0.005894252564758062 + ], + [ + 0.0465959832072258, + 0.0355747789144516, + 0.0339006707072258, + 0.0359933041036129, + 0.03027343936264515, + 0.0232979916036129, + 0.0390625, + 0.0426897332072258, + 0.007393973413854837, + 0.01053292490541935, + 0.008370536379516125, + 0.0052315848879516125, + 0.006173270288854837, + 0.01025390625, + 0.006487165577709675, + 0.00578962080180645 + ], + [ + 0.041015625, + 0.0308314748108387, + 0.0435267873108387, + 0.0329241082072258, + 0.0362723246216774, + 0.0308314748108387, + 0.0454799123108387, + 0.0355747789144516, + 0.006068638525903225, + 0.009974888525903225, + 0.0120675228536129, + 0.007045201025903225, + 0.006312779150903225, + 0.00774274580180645, + 0.007045201025903225, + 0.007952009327709675 + ], + [ + 0.0318080373108387, + 0.0297154039144516, + 0.0379464291036129, + 0.0390625, + 0.0465959832072258, + 0.0424107164144516, + 0.0379464291036129, + 0.0401785746216774, + 0.00830078125, + 0.0052315848879516125, + 0.005371094215661287, + 0.005615234840661287, + 0.007219587452709675, + 0.005371094215661287, + 0.007463728077709675, + 0.005266462452709675 + ], + [ + 0.0327845998108387, + 0.041015625, + 0.02162388525903225, + 0.0449218787252903, + 0.0252511166036129, + 0.0418526791036129, + 0.0401785746216774, + 0.0373883955180645, + 0.004411970265209675, + 0.006243024952709675, + 0.00857979990541935, + 0.00439453125, + 0.004778181202709675, + 0.004673549439758062, + 0.004220145288854837, + 0.004638671875 + ], + [ + 0.0340401791036129, + 0.02859933115541935, + 0.0382254496216774, + 0.041015625, + 0.0415736623108387, + 0.0373883955180645, + 0.0337611623108387, + 0.0379464291036129, + 0.0076729916036129, + 0.012765067629516125, + 0.00676618330180645, + 0.008440290577709675, + 0.01116071455180645, + 0.01883370615541935, + 0.00871930830180645, + 0.01688058115541935 + ], + [ + 0.0421316996216774, + 0.0443638414144516, + 0.02650669775903225, + 0.0336216539144516, + 0.0407366082072258, + 0.0262276791036129, + 0.0421316996216774, + 0.0316685289144516, + 0.006661551538854837, + 0.01541573740541935, + 0.00603376142680645, + 0.008126395754516125, + 0.012346540577709675, + 0.006382533814758062, + 0.0079171322286129, + 0.00920758955180645 + ], + [ + 0.0507812537252903, + 0.0373883955180645, + 0.0350167416036129, + 0.0415736623108387, + 0.0269252248108387, + 0.0382254496216774, + 0.0396205373108387, + 0.02260044775903225, + 0.00920758955180645, + 0.0081612728536129, + 0.005859375465661287, + 0.006487165577709675, + 0.00701032392680645, + 0.00809151865541935, + 0.010393415577709675, + 0.007114955689758062 + ], + [ + 0.0368303582072258, + 0.0407366082072258, + 0.0398995541036129, + 0.0421316996216774, + 0.0421316996216774, + 0.0401785746216774, + 0.0336216539144516, + 0.0348772332072258, + 0.01004464365541935, + 0.01067243330180645, + 0.01116071455180645, + 0.007080078590661287, + 0.009416853077709675, + 0.011369978077709675, + 0.009974888525903225, + 0.01053292490541935 + ], + [ + 0.0267857164144516, + 0.0435267873108387, + 0.02636718936264515, + 0.0457589291036129, + 0.0418526791036129, + 0.0429687537252903, + 0.0318080373108387, + 0.0393415205180645, + 0.008370536379516125, + 
0.006835937965661287, + 0.01443917490541935, + 0.009486607275903225, + 0.010323661379516125, + 0.006801060400903225, + 0.01248604990541935, + 0.00882394053041935 + ], + [ + 0.0438058041036129, + 0.0488281287252903, + 0.0387834832072258, + 0.0382254496216774, + 0.0465959832072258, + 0.0424107164144516, + 0.0460379496216774, + 0.0485491082072258, + 0.0106026791036129, + 0.006661551538854837, + 0.00955636240541935, + 0.013323103077709675, + 0.0078125, + 0.007568359840661287, + 0.007533482275903225, + 0.00823102705180645 + ], + [ + 0.0387834832072258, + 0.0365513414144516, + 0.0488281287252903, + 0.0407366082072258, + 0.0387834832072258, + 0.0315290205180645, + 0.037109375, + 0.0376674123108387, + 0.0084054134786129, + 0.010811942629516125, + 0.007080078590661287, + 0.008056640625, + 0.013253348879516125, + 0.007882255129516125, + 0.007114955689758062, + 0.0192522332072258 + ], + [ + 0.0510602705180645, + 0.0471540205180645, + 0.0446428582072258, + 0.0407366082072258, + 0.0479910746216774, + 0.0471540205180645, + 0.0407366082072258, + 0.0319475457072258, + 0.0079171322286129, + 0.0096261166036129, + 0.008510044775903225, + 0.0086495541036129, + 0.008928571827709675, + 0.008928571827709675, + 0.01506696455180645, + 0.011230469681322575 + ], + [ + 0.0435267873108387, + 0.0404575914144516, + 0.0460379496216774, + 0.0424107164144516, + 0.0446428582072258, + 0.0502232164144516, + 0.0449218787252903, + 0.0337611623108387, + 0.01018415205180645, + 0.00784737803041935, + 0.00823102705180645, + 0.0110909603536129, + 0.0101143978536129, + 0.008684431202709675, + 0.007114955689758062, + 0.012276786379516125 + ], + [ + 0.0496651791036129, + 0.0432477705180645, + 0.0390625, + 0.0387834832072258, + 0.0449218787252903, + 0.0396205373108387, + 0.0418526791036129, + 0.0426897332072258, + 0.01213727705180645, + 0.01688058115541935, + 0.013183594681322575, + 0.00920758955180645, + 0.009765625, + 0.01004464365541935, + 0.009347098879516125, + 0.009765625 + ], + [ + 0.0454799123108387, + 0.0421316996216774, + 0.0398995541036129, + 0.0333426371216774, + 0.0438058041036129, + 0.0457589291036129, + 0.0424107164144516, + 0.0465959832072258, + 0.013253348879516125, + 0.014648438431322575, + 0.009835380129516125, + 0.014718192629516125, + 0.007882255129516125, + 0.00830078125, + 0.0110909603536129, + 0.01395089365541935 + ], + [ + 0.0329241082072258, + 0.041015625, + 0.037109375, + 0.0471540205180645, + 0.0463169664144516, + 0.0412946455180645, + 0.0424107164144516, + 0.0452008955180645, + 0.01953125, + 0.0221819207072258, + 0.0086495541036129, + 0.008928571827709675, + 0.009486607275903225, + 0.00955636240541935, + 0.00906808115541935, + 0.010951451025903225 + ], + [ + 0.0482700914144516, + 0.0373883955180645, + 0.0418526791036129, + 0.0407366082072258, + 0.0320870541036129, + 0.0488281287252903, + 0.0368303582072258, + 0.0316685289144516, + 0.01025390625, + 0.01164899580180645, + 0.00906808115541935, + 0.0115792416036129, + 0.009416853077709675, + 0.00969587080180645, + 0.008126395754516125, + 0.01736886240541935 + ], + [ + 0.03125, + 0.0365513414144516, + 0.0376674123108387, + 0.0404575914144516, + 0.0446428582072258, + 0.0513392873108387, + 0.0412946455180645, + 0.0350167416036129, + 0.0110909603536129, + 0.011230469681322575, + 0.012695313431322575, + 0.010393415577709675, + 0.0148577019572258, + 0.00927734375, + 0.0087890625, + 0.010811942629516125 + ], + [ + 0.0471540205180645, + 0.0429687537252903, + 0.0415736623108387, + 0.0429687537252903, + 0.0429687537252903, + 0.0362723246216774, + 0.0337611623108387, + 
0.0426897332072258, + 0.01053292490541935, + 0.01248604990541935, + 0.009835380129516125, + 0.0140206478536129, + 0.008196149952709675, + 0.01067243330180645, + 0.008928571827709675, + 0.0115792416036129 + ], + [ + 0.0438058041036129, + 0.0488281287252903, + 0.0398995541036129, + 0.0463169664144516, + 0.0347377248108387, + 0.0412946455180645, + 0.0541294664144516, + 0.0468750037252903, + 0.0084054134786129, + 0.00906808115541935, + 0.015206473879516125, + 0.0086495541036129, + 0.00830078125, + 0.00955636240541935, + 0.0106026791036129, + 0.0130440853536129 + ], + [ + 0.0505022332072258, + 0.0429687537252903, + 0.0474330373108387, + 0.0471540205180645, + 0.037109375, + 0.0474330373108387, + 0.0468750037252903, + 0.0513392873108387, + 0.0091378353536129, + 0.008858817629516125, + 0.009486607275903225, + 0.0120675228536129, + 0.009765625, + 0.009486607275903225, + 0.008056640625, + 0.008056640625 + ], + [ + 0.0474330373108387, + 0.0465959832072258, + 0.0463169664144516, + 0.0463169664144516, + 0.0382254496216774, + 0.0465959832072258, + 0.0496651791036129, + 0.037109375, + 0.009347098879516125, + 0.00927734375, + 0.007952009327709675, + 0.0096261166036129, + 0.00920758955180645, + 0.0125558041036129, + 0.00798688642680645, + 0.009905134327709675 + ], + [ + 0.033203125, + 0.0491071455180645, + 0.0438058041036129, + 0.0502232164144516, + 0.0516183041036129, + 0.0319475457072258, + 0.0482700914144516, + 0.0373883955180645, + 0.011369978077709675, + 0.007359096314758062, + 0.006731306202709675, + 0.01018415205180645, + 0.007393973413854837, + 0.007777622900903225, + 0.008928571827709675, + 0.012695313431322575 + ], + [ + 0.0454799123108387, + 0.0415736623108387, + 0.0449218787252903, + 0.0435267873108387, + 0.0350167416036129, + 0.0474330373108387, + 0.0337611623108387, + 0.0532924123108387, + 0.012695313431322575, + 0.007777622900903225, + 0.012346540577709675, + 0.010811942629516125, + 0.0074288509786129, + 0.01360212080180645, + 0.008928571827709675, + 0.0079171322286129 + ], + [ + 0.0457589291036129, + 0.0393415205180645, + 0.0393415205180645, + 0.0443638414144516, + 0.0452008955180645, + 0.0488281287252903, + 0.0516183041036129, + 0.0566406287252903, + 0.01116071455180645, + 0.01018415205180645, + 0.0149972103536129, + 0.01067243330180645, + 0.011439732275903225, + 0.0079171322286129, + 0.0088936947286129, + 0.011718750931322575 + ], + [ + 0.0376674123108387, + 0.0412946455180645, + 0.0463169664144516, + 0.0485491082072258, + 0.0463169664144516, + 0.0336216539144516, + 0.0521763414144516, + 0.0376674123108387, + 0.01639229990541935, + 0.009416853077709675, + 0.0101143978536129, + 0.013741630129516125, + 0.013323103077709675, + 0.011788505129516125, + 0.010742188431322575, + 0.01625279150903225 + ], + [ + 0.0435267873108387, + 0.0376674123108387, + 0.0404575914144516, + 0.0471540205180645, + 0.0429687537252903, + 0.0468750037252903, + 0.0336216539144516, + 0.0485491082072258, + 0.0115792416036129, + 0.011928013525903225, + 0.013741630129516125, + 0.0125558041036129, + 0.008998326025903225, + 0.011439732275903225, + 0.012416294775903225, + 0.00927734375 + ], + [ + 0.0382254496216774, + 0.0440848246216774, + 0.0471540205180645, + 0.0362723246216774, + 0.0368303582072258, + 0.0443638414144516, + 0.0418526791036129, + 0.0505022332072258, + 0.0153459832072258, + 0.01004464365541935, + 0.009974888525903225, + 0.0182756707072258, + 0.0143694207072258, + 0.0159737728536129, + 0.011788505129516125, + 0.0125558041036129 + ], + [ + 0.0477120541036129, + 0.0393415205180645, + 0.0421316996216774, + 
0.0415736623108387, + 0.0396205373108387, + 0.0401785746216774, + 0.0460379496216774, + 0.0398995541036129, + 0.01150948740541935, + 0.009835380129516125, + 0.008021763525903225, + 0.02148437686264515, + 0.01492745615541935, + 0.02762276865541935, + 0.012904576025903225, + 0.007463728077709675 + ], + [ + 0.0320870541036129, + 0.0457589291036129, + 0.0449218787252903, + 0.0432477705180645, + 0.0463169664144516, + 0.0440848246216774, + 0.0393415205180645, + 0.0429687537252903, + 0.012346540577709675, + 0.0164620541036129, + 0.02190290205180645, + 0.0184151791036129, + 0.0306919664144516, + 0.01102120615541935, + 0.02734375186264515, + 0.0191127248108387 + ], + [ + 0.0477120541036129, + 0.0485491082072258, + 0.0359933041036129, + 0.0404575914144516, + 0.0449218787252903, + 0.0365513414144516, + 0.0569196455180645, + 0.0376674123108387, + 0.01611328125, + 0.0298549123108387, + 0.01639229990541935, + 0.0231584832072258, + 0.01164899580180645, + 0.0168108269572258, + 0.01409040205180645, + 0.0260881707072258 + ], + [ + 0.0454799123108387, + 0.0418526791036129, + 0.0465959832072258, + 0.0429687537252903, + 0.0491071455180645, + 0.0306919664144516, + 0.0432477705180645, + 0.0418526791036129, + 0.0288783498108387, + 0.0418526791036129, + 0.02287946455180645, + 0.02748326025903225, + 0.02260044775903225, + 0.0269252248108387, + 0.02748326025903225, + 0.0185546875 + ], + [ + 0.0521763414144516, + 0.0432477705180645, + 0.0424107164144516, + 0.0438058041036129, + 0.0457589291036129, + 0.0418526791036129, + 0.0337611623108387, + 0.0477120541036129, + 0.01708984375, + 0.0176478810608387, + 0.02148437686264515, + 0.0313895121216774, + 0.014229911379516125, + 0.0184151791036129, + 0.010323661379516125, + 0.02176339365541935 + ], + [ + 0.0359933041036129, + 0.0594308078289032, + 0.0563616082072258, + 0.0541294664144516, + 0.0468750037252903, + 0.0463169664144516, + 0.0479910746216774, + 0.0527343787252903, + 0.012207032181322575, + 0.0163225457072258, + 0.0269252248108387, + 0.0288783498108387, + 0.0330636166036129, + 0.0330636166036129, + 0.0418526791036129, + 0.0318080373108387 + ], + [ + 0.0471540205180645, + 0.0491071455180645, + 0.0546875037252903, + 0.0443638414144516, + 0.0446428582072258, + 0.0538504496216774, + 0.0527343787252903, + 0.0571986623108387, + 0.013811384327709675, + 0.014160157181322575, + 0.01262555830180645, + 0.01708984375, + 0.0313895121216774, + 0.0323660746216774, + 0.0193917416036129, + 0.02162388525903225 + ], + [ + 0.0454799123108387, + 0.0468750037252903, + 0.0424107164144516, + 0.0541294664144516, + 0.0510602705180645, + 0.0365513414144516, + 0.0443638414144516, + 0.0272042416036129, + 0.011858259327709675, + 0.0185546875, + 0.014787946827709675, + 0.0203683041036129, + 0.01541573740541935, + 0.012695313431322575, + 0.010811942629516125, + 0.0299944207072258 + ], + [ + 0.0426897332072258, + 0.0390625, + 0.0527343787252903, + 0.0443638414144516, + 0.0558035746216774, + 0.0510602705180645, + 0.0407366082072258, + 0.0482700914144516, + 0.013741630129516125, + 0.02483258955180645, + 0.0301339291036129, + 0.01346261240541935, + 0.01492745615541935, + 0.0220424123108387, + 0.014718192629516125, + 0.0213448666036129 + ], + [ + 0.0390625, + 0.0510602705180645, + 0.0452008955180645, + 0.0524553582072258, + 0.0479910746216774, + 0.0552455373108387, + 0.0527343787252903, + 0.0446428582072258, + 0.0297154039144516, + 0.0232979916036129, + 0.0148577019572258, + 0.0182756707072258, + 0.0164620541036129, + 0.012416294775903225, + 0.0169503353536129, + 0.014787946827709675 + ], + [ + 
0.0521763414144516, + 0.0594308078289032, + 0.0558035746216774, + 0.0499441996216774, + 0.0446428582072258, + 0.0421316996216774, + 0.0477120541036129, + 0.0488281287252903, + 0.0130440853536129, + 0.017578125, + 0.01311383955180645, + 0.0174386166036129, + 0.01457868330180645, + 0.02580915205180645, + 0.02357701025903225, + 0.0385044664144516 + ], + [ + 0.0502232164144516, + 0.0680803582072258, + 0.0460379496216774, + 0.0327845998108387, + 0.0396205373108387, + 0.0341796875, + 0.0485491082072258, + 0.0605468787252903, + 0.015136719681322575, + 0.0242745541036129, + 0.014299665577709675, + 0.0241350457072258, + 0.0182756707072258, + 0.02148437686264515, + 0.0156947560608387, + 0.0110909603536129 + ], + [ + 0.0496651791036129, + 0.0491071455180645, + 0.0496651791036129, + 0.0518973246216774, + 0.0376674123108387, + 0.0407366082072258, + 0.0524553582072258, + 0.0491071455180645, + 0.0164620541036129, + 0.01576451025903225, + 0.0185546875, + 0.014160157181322575, + 0.01443917490541935, + 0.01967076025903225, + 0.01653180830180645, + 0.012346540577709675 + ], + [ + 0.0457589291036129, + 0.0491071455180645, + 0.0446428582072258, + 0.0350167416036129, + 0.0479910746216774, + 0.041015625, + 0.0440848246216774, + 0.0546875037252903, + 0.0182756707072258, + 0.0153459832072258, + 0.01506696455180645, + 0.0176478810608387, + 0.0149972103536129, + 0.01799665205180645, + 0.01492745615541935, + 0.0182756707072258 + ], + [ + 0.0527343787252903, + 0.041015625, + 0.0577566996216774, + 0.0493861623108387, + 0.0569196455180645, + 0.0454799123108387, + 0.0491071455180645, + 0.0446428582072258, + 0.01722935400903225, + 0.01688058115541935, + 0.02260044775903225, + 0.0191127248108387, + 0.014648438431322575, + 0.01506696455180645, + 0.01722935400903225, + 0.0212053582072258 + ], + [ + 0.0429687537252903, + 0.0443638414144516, + 0.0510602705180645, + 0.0479910746216774, + 0.0513392873108387, + 0.0443638414144516, + 0.0505022332072258, + 0.0530133955180645, + 0.0149972103536129, + 0.01360212080180645, + 0.01869419775903225, + 0.01702008955180645, + 0.0185546875, + 0.02929687686264515, + 0.0290178582072258, + 0.0153459832072258 + ], + [ + 0.0485491082072258, + 0.0535714328289032, + 0.0454799123108387, + 0.0599888414144516, + 0.0418526791036129, + 0.0516183041036129, + 0.0552455373108387, + 0.0611049123108387, + 0.02287946455180645, + 0.0191127248108387, + 0.0193917416036129, + 0.0231584832072258, + 0.01541573740541935, + 0.02162388525903225, + 0.0203683041036129, + 0.0252511166036129 + ], + [ + 0.0507812537252903, + 0.0482700914144516, + 0.0499441996216774, + 0.0457589291036129, + 0.0474330373108387, + 0.0532924123108387, + 0.0362723246216774, + 0.0611049123108387, + 0.0161830373108387, + 0.014787946827709675, + 0.01785714365541935, + 0.01674107275903225, + 0.02469308115541935, + 0.0262276791036129, + 0.0185546875, + 0.0301339291036129 + ], + [ + 0.0577566996216774, + 0.0588727705180645, + 0.0404575914144516, + 0.0558035746216774, + 0.0485491082072258, + 0.0438058041036129, + 0.0499441996216774, + 0.0418526791036129, + 0.0154854916036129, + 0.02260044775903225, + 0.01953125, + 0.01625279150903225, + 0.02455357275903225, + 0.01994977705180645, + 0.014648438431322575, + 0.017578125 + ], + [ + 0.0482700914144516, + 0.0474330373108387, + 0.0516183041036129, + 0.0555245578289032, + 0.0521763414144516, + 0.0527343787252903, + 0.0421316996216774, + 0.0552455373108387, + 0.0249720998108387, + 0.0315290205180645, + 0.01785714365541935, + 0.01799665205180645, + 0.0223214291036129, + 0.0460379496216774, + 
0.01576451025903225, + 0.0287388414144516 + ], + [ + 0.0558035746216774, + 0.0538504496216774, + 0.0552455373108387, + 0.0491071455180645, + 0.0502232164144516, + 0.0468750037252903, + 0.0546875037252903, + 0.0429687537252903, + 0.02566964365541935, + 0.02636718936264515, + 0.02832031436264515, + 0.02273995615541935, + 0.0205078125, + 0.02580915205180645, + 0.0192522332072258, + 0.0176478810608387 + ], + [ + 0.0493861623108387, + 0.0446428582072258, + 0.0524553582072258, + 0.0491071455180645, + 0.0524553582072258, + 0.0432477705180645, + 0.0521763414144516, + 0.0488281287252903, + 0.02580915205180645, + 0.0291573666036129, + 0.0281808041036129, + 0.01869419775903225, + 0.01869419775903225, + 0.0176478810608387, + 0.02190290205180645, + 0.0184151791036129 + ], + [ + 0.0510602705180645, + 0.0583147332072258, + 0.0521763414144516, + 0.0544084832072258, + 0.0549665205180645, + 0.0412946455180645, + 0.0535714328289032, + 0.0571986623108387, + 0.0182756707072258, + 0.0288783498108387, + 0.0262276791036129, + 0.02357701025903225, + 0.0309709832072258, + 0.01771763525903225, + 0.0336216539144516, + 0.0359933041036129 + ], + [ + 0.0499441996216774, + 0.0496651791036129, + 0.0538504496216774, + 0.0385044664144516, + 0.0613839328289032, + 0.0424107164144516, + 0.0421316996216774, + 0.0446428582072258, + 0.01736886240541935, + 0.0272042416036129, + 0.0203683041036129, + 0.03125, + 0.0185546875, + 0.0202287957072258, + 0.01722935400903225, + 0.0193917416036129 + ], + [ + 0.0482700914144516, + 0.0387834832072258, + 0.0552455373108387, + 0.0602678582072258, + 0.0552455373108387, + 0.0510602705180645, + 0.0530133955180645, + 0.0502232164144516, + 0.02287946455180645, + 0.014718192629516125, + 0.0260881707072258, + 0.02162388525903225, + 0.0368303582072258, + 0.0385044664144516, + 0.0359933041036129, + 0.01785714365541935 + ], + [ + 0.0468750037252903, + 0.0488281287252903, + 0.0438058041036129, + 0.0552455373108387, + 0.0546875037252903, + 0.0521763414144516, + 0.0415736623108387, + 0.0588727705180645, + 0.0368303582072258, + 0.0421316996216774, + 0.02553013525903225, + 0.0373883955180645, + 0.0401785746216774, + 0.0203683041036129, + 0.02162388525903225, + 0.01625279150903225 + ], + [ + 0.0611049123108387, + 0.0541294664144516, + 0.0555245578289032, + 0.0505022332072258, + 0.0544084832072258, + 0.0435267873108387, + 0.0518973246216774, + 0.0449218787252903, + 0.01708984375, + 0.0373883955180645, + 0.0365513414144516, + 0.0327845998108387, + 0.0241350457072258, + 0.0443638414144516, + 0.0232979916036129, + 0.0212053582072258 + ], + [ + 0.0627790242433548, + 0.0530133955180645, + 0.0555245578289032, + 0.0541294664144516, + 0.0412946455180645, + 0.0538504496216774, + 0.0513392873108387, + 0.0560825914144516, + 0.0347377248108387, + 0.0418526791036129, + 0.0443638414144516, + 0.02650669775903225, + 0.0231584832072258, + 0.0231584832072258, + 0.0491071455180645, + 0.0341796875 + ], + [ + 0.0471540205180645, + 0.0597098246216774, + 0.0613839328289032, + 0.0588727705180645, + 0.0510602705180645, + 0.0452008955180645, + 0.0505022332072258, + 0.0535714328289032, + 0.02859933115541935, + 0.0297154039144516, + 0.0350167416036129, + 0.0315290205180645, + 0.0379464291036129, + 0.0239955373108387, + 0.02385602705180645, + 0.0429687537252903 + ], + [ + 0.0560825914144516, + 0.0541294664144516, + 0.0680803582072258, + 0.0412946455180645, + 0.0516183041036129, + 0.0594308078289032, + 0.0505022332072258, + 0.0491071455180645, + 0.0336216539144516, + 0.0549665205180645, + 0.0479910746216774, + 0.0299944207072258, + 
0.0260881707072258, + 0.0297154039144516, + 0.0221819207072258, + 0.0319475457072258 + ], + [ + 0.0591517873108387, + 0.0479910746216774, + 0.0510602705180645, + 0.0460379496216774, + 0.0471540205180645, + 0.0429687537252903, + 0.0532924123108387, + 0.0577566996216774, + 0.0513392873108387, + 0.0585937537252903, + 0.0379464291036129, + 0.0616629496216774, + 0.02832031436264515, + 0.02832031436264515, + 0.0619419664144516, + 0.0299944207072258 + ], + [ + 0.0613839328289032, + 0.0496651791036129, + 0.0588727705180645, + 0.0379464291036129, + 0.0496651791036129, + 0.0602678582072258, + 0.0530133955180645, + 0.0566406287252903, + 0.0390625, + 0.0426897332072258, + 0.0457589291036129, + 0.0330636166036129, + 0.0429687537252903, + 0.0355747789144516, + 0.0477120541036129, + 0.0339006707072258 + ], + [ + 0.0412946455180645, + 0.0530133955180645, + 0.0608258955180645, + 0.0599888414144516, + 0.02469308115541935, + 0.0424107164144516, + 0.0460379496216774, + 0.0521763414144516, + 0.0485491082072258, + 0.0521763414144516, + 0.03041294775903225, + 0.0337611623108387, + 0.0301339291036129, + 0.0385044664144516, + 0.0385044664144516, + 0.0412946455180645 + ], + [ + 0.0454799123108387, + 0.0385044664144516, + 0.0376674123108387, + 0.0463169664144516, + 0.0474330373108387, + 0.0382254496216774, + 0.0326450914144516, + 0.0315290205180645, + 0.02441406436264515, + 0.0357142873108387, + 0.0322265625, + 0.0262276791036129, + 0.02566964365541935, + 0.03515625, + 0.0220424123108387, + 0.0185546875 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen3_235b.json b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen3_235b.json new file mode 100644 index 000000000..870ee695a --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen3_235b.json @@ -0,0 +1,955 @@ +{ + "version": "1.0", + "architectures": "Qwen3MoeForCausalLM", + "quant_type": "per_head", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 94, + "num_head": 4, + "scales_shape": [ + 94, + 8 + ], + "scales": [ + [ + 0.0334821455180645, + 0.0493861623108387, + 0.0385044664144516, + 0.0552455373108387, + 9.264265099773183e-05, + 0.00022888185048941523, + 0.00015912737580947578, + 0.00015476772387046367 + ], + [ + 0.037109375, + 0.0333426371216774, + 0.0333426371216774, + 0.0398995541036129, + 0.00028773717349395156, + 0.00010844640200957656, + 0.00010517665941733867, + 0.00013024467625655234 + ], + [ + 0.0719866082072258, + 0.0809151828289032, + 0.0731026828289032, + 0.068359375, + 0.0002833775361068547, + 0.00018964495393447578, + 0.00017547608877066523, + 0.00014168876805342734 + ], + [ + 0.0613839328289032, + 0.0703125, + 0.0591517873108387, + 0.0521763414144516, + 0.00028773717349395156, + 0.00042506627505645156, + 0.0002615792618598789, + 0.0004381452454254031 + ], + [ + 0.0725446492433548, + 0.0697544664144516, + 0.0694754496216774, + 0.0482700914144516, + 0.0003291539032943547, + 0.00026702880859375, + 0.00040108818211592734, + 0.0003117152664344758 + ], + [ + 0.1484375, + 0.1311383992433548, + 0.1037946492433548, + 0.1233258992433548, + 0.00038582939305342734, + 0.00033569338847883046, + 0.00031825475161895156, + 0.0006583078065887094 + ], + [ + 0.037109375, + 0.0491071455180645, + 0.0421316996216774, + 0.0390625, + 0.0005558559205383062, + 0.0006495884736068547, + 0.0003444126923568547, + 0.0003422328445594758 + ], + [ + 0.0345982164144516, + 0.0538504496216774, + 0.0463169664144516, + 
0.0569196455180645, + 0.00038582939305342734, + 0.00030953544774092734, + 0.0007193429628387094, + 0.00047738212742842734 + ], + [ + 0.1082589328289032, + 0.1010044664144516, + 0.0831473246216774, + 0.1093750074505806, + 0.0008326939423568547, + 0.0004686628235504031, + 0.0005711147096008062, + 0.00046212333836592734 + ], + [ + 0.0555245578289032, + 0.0412946455180645, + 0.0449218787252903, + 0.02929687686264515, + 0.0007672991487197578, + 0.0009765625, + 0.000518798828125, + 0.0004359654267318547 + ], + [ + 0.0853794664144516, + 0.0870535746216774, + 0.1037946492433548, + 0.086495541036129, + 0.00033351354068145156, + 0.0004141671524848789, + 0.0004207066376693547, + 0.0003160749329254031 + ], + [ + 0.0482700914144516, + 0.0396205373108387, + 0.0429687537252903, + 0.0407366082072258, + 0.0003465925110504031, + 0.0004381452454254031, + 0.0004207066376693547, + 0.0004512242157943547 + ], + [ + 0.03027343936264515, + 0.0357142873108387, + 0.0368303582072258, + 0.0538504496216774, + 0.0005427769501693547, + 0.0003531319962348789, + 0.00030517578125, + 0.00033569338847883046 + ], + [ + 0.0583147332072258, + 0.0585937537252903, + 0.0493861623108387, + 0.0454799123108387, + 0.00038146975566633046, + 0.0004359654267318547, + 0.00047956197522580624, + 0.00040108818211592734 + ], + [ + 0.0675223246216774, + 0.0753348246216774, + 0.0652901828289032, + 0.0594308078289032, + 0.0002833775361068547, + 0.00031825475161895156, + 0.0003422328445594758, + 0.0004991804016754031 + ], + [ + 0.0368303582072258, + 0.0440848246216774, + 0.0471540205180645, + 0.0505022332072258, + 0.0006059919251129031, + 0.00033569338847883046, + 0.00043160576024092734, + 0.0004446847306098789 + ], + [ + 0.0594308078289032, + 0.0482700914144516, + 0.0426897332072258, + 0.0449218787252903, + 0.0003989083634223789, + 0.0006495884736068547, + 0.0006277902284637094, + 0.00030299596255645156 + ], + [ + 0.0524553582072258, + 0.0714285746216774, + 0.0758928582072258, + 0.0725446492433548, + 0.00039454869693145156, + 0.0004752023087348789, + 0.0004141671524848789, + 0.0003465925110504031 + ], + [ + 0.0691964328289032, + 0.0758928582072258, + 0.0516183041036129, + 0.0638950914144516, + 0.0005275181611068547, + 0.00045776370097883046, + 0.00034877232974395156, + 0.0003574916336219758 + ], + [ + 0.0446428582072258, + 0.0330636166036129, + 0.0404575914144516, + 0.0412946455180645, + 0.0003683907852973789, + 0.00044250491191633046, + 0.0003749302704818547, + 0.0005580357392318547 + ], + [ + 0.0499441996216774, + 0.0387834832072258, + 0.0426897332072258, + 0.0488281287252903, + 0.00044686454930342734, + 0.0005253383424133062, + 0.00038146975566633046, + 0.0003749302704818547 + ], + [ + 0.0446428582072258, + 0.0390625, + 0.0449218787252903, + 0.0396205373108387, + 0.00038146975566633046, + 0.0003923688782379031, + 0.00040980748599395156, + 0.0004076276673004031 + ], + [ + 0.0315290205180645, + 0.0415736623108387, + 0.0479910746216774, + 0.0446428582072258, + 0.0003727504226844758, + 0.0004076276673004031, + 0.00046430318616330624, + 0.00032479423680342734 + ], + [ + 0.0404575914144516, + 0.0516183041036129, + 0.0318080373108387, + 0.0499441996216774, + 0.0004512242157943547, + 0.00033569338847883046, + 0.0005449567688629031, + 0.0003749302704818547 + ], + [ + 0.0463169664144516, + 0.0339006707072258, + 0.0446428582072258, + 0.0424107164144516, + 0.0006147112580947578, + 0.0003836495743598789, + 0.0004207066376693547, + 0.0005296979798004031 + ], + [ + 0.0797991082072258, + 0.0987723246216774, + 0.0616629496216774, + 0.065011166036129, + 
0.00046212333836592734, + 0.0004599435196723789, + 0.00039672854472883046, + 0.0006277902284637094 + ], + [ + 0.0354352705180645, + 0.0418526791036129, + 0.0463169664144516, + 0.03125, + 0.0006103515625, + 0.0005711147096008062, + 0.0004817417939193547, + 0.0013950893189758062 + ], + [ + 0.0326450914144516, + 0.0315290205180645, + 0.0485491082072258, + 0.0563616082072258, + 0.0006713867769576609, + 0.0005231585237197578, + 0.0006888253847137094, + 0.0004664830048568547 + ], + [ + 0.02943638525903225, + 0.0429687537252903, + 0.02929687686264515, + 0.03125, + 0.0005275181611068547, + 0.0005078997346572578, + 0.00044686454930342734, + 0.00043160576024092734 + ], + [ + 0.0549665205180645, + 0.0457589291036129, + 0.0438058041036129, + 0.0290178582072258, + 0.0013078962219879031, + 0.0013253348879516125, + 0.0007542201783508062, + 0.0005841936799697578 + ], + [ + 0.0959821492433548, + 0.08203125, + 0.1116071492433548, + 0.1037946492433548, + 0.0006277902284637094, + 0.0005296979798004031, + 0.0006365095032379031, + 0.0005057199159637094 + ], + [ + 0.0491071455180645, + 0.0440848246216774, + 0.0443638414144516, + 0.0432477705180645, + 0.00044032506411895156, + 0.0003880092117469758, + 0.0005057199159637094, + 0.0005427769501693547 + ], + [ + 0.0329241082072258, + 0.0393415205180645, + 0.0373883955180645, + 0.0563616082072258, + 0.0005885533173568547, + 0.0004076276673004031, + 0.0004752023087348789, + 0.0008980887942016125 + ], + [ + 0.0577566996216774, + 0.0588727705180645, + 0.0524553582072258, + 0.0460379496216774, + 0.0004970005829818547, + 0.0006844656891189516, + 0.0005711147096008062, + 0.0011160714784637094 + ], + [ + 0.0619419664144516, + 0.0809151828289032, + 0.066964291036129, + 0.0638950914144516, + 0.000579833984375, + 0.0004970005829818547, + 0.0004686628235504031, + 0.0004926409455947578 + ], + [ + 0.0365513414144516, + 0.0452008955180645, + 0.0474330373108387, + 0.0513392873108387, + 0.0010027204407379031, + 0.0004207066376693547, + 0.0006583078065887094, + 0.0005296979798004031 + ], + [ + 0.0527343787252903, + 0.0527343787252903, + 0.0449218787252903, + 0.0438058041036129, + 0.0005623953766189516, + 0.00047084264224395156, + 0.0013863700442016125, + 0.0005929129547439516 + ], + [ + 0.0658482164144516, + 0.0703125, + 0.0703125, + 0.0731026828289032, + 0.0007106236298568547, + 0.0012730190064758062, + 0.0006888253847137094, + 0.0009678432252258062 + ], + [ + 0.0705915242433548, + 0.0672433078289032, + 0.0482700914144516, + 0.0583147332072258, + 0.0008021763642318547, + 0.0009329660097137094, + 0.0006844656891189516, + 0.0010811942629516125 + ], + [ + 0.0468750037252903, + 0.037109375, + 0.0362723246216774, + 0.0429687537252903, + 0.0006147112580947578, + 0.0006495884736068547, + 0.0006234305328689516, + 0.0010986328125 + ], + [ + 0.0521763414144516, + 0.0376674123108387, + 0.0404575914144516, + 0.0499441996216774, + 0.0007890973938629031, + 0.0008021763642318547, + 0.0008806501282379031, + 0.0005296979798004031 + ], + [ + 0.0465959832072258, + 0.0421316996216774, + 0.0390625, + 0.0393415205180645, + 0.0010332380188629031, + 0.0005841936799697578, + 0.0009068080689758062, + 0.000640869140625 + ], + [ + 0.0340401791036129, + 0.0398995541036129, + 0.0471540205180645, + 0.0443638414144516, + 0.0005405971314758062, + 0.0006713867769576609, + 0.0006539481109939516, + 0.0004926409455947578 + ], + [ + 0.0457589291036129, + 0.0571986623108387, + 0.0385044664144516, + 0.0502232164144516, + 0.0013253348879516125, + 0.0006059919251129031, + 0.0007149832672439516, + 0.0006452288362197578 + ], + 
[ + 0.0446428582072258, + 0.0345982164144516, + 0.0435267873108387, + 0.0435267873108387, + 0.0008021763642318547, + 0.0006147112580947578, + 0.0008239746675826609, + 0.0007280622376129031 + ], + [ + 0.0853794664144516, + 0.102120541036129, + 0.0613839328289032, + 0.0655691996216774, + 0.0009155274019576609, + 0.0012642997317016125, + 0.0008588518830947578, + 0.0010986328125 + ], + [ + 0.0336216539144516, + 0.0387834832072258, + 0.0460379496216774, + 0.02832031436264515, + 0.0008544922457076609, + 0.0008370536379516125, + 0.0007411412079818547, + 0.0019880023319274187 + ], + [ + 0.0555245578289032, + 0.0597098246216774, + 0.0541294664144516, + 0.0477120541036129, + 0.0010506766848266125, + 0.0008937290986068547, + 0.0006452288362197578, + 0.0008457729127258062 + ], + [ + 0.0616629496216774, + 0.0814732164144516, + 0.0758928582072258, + 0.068917416036129, + 0.0007280622376129031, + 0.0007193429628387094, + 0.0007498605409637094, + 0.0008414132753387094 + ], + [ + 0.0365513414144516, + 0.0457589291036129, + 0.0485491082072258, + 0.0513392873108387, + 0.0018310548039153218, + 0.0006103515625, + 0.0014822824159637094, + 0.0007542201783508062 + ], + [ + 0.0613839328289032, + 0.0516183041036129, + 0.0518973246216774, + 0.0488281287252903, + 0.0020054408814758062, + 0.0006626674439758062, + 0.003749302588403225, + 0.0009024484315887094 + ], + [ + 0.0697544664144516, + 0.0725446492433548, + 0.0691964328289032, + 0.0770089328289032, + 0.0015345982974395156, + 0.0028076174203306437, + 0.0014212472597137094, + 0.0021100726444274187 + ], + [ + 0.0725446492433548, + 0.0666852742433548, + 0.0471540205180645, + 0.0616629496216774, + 0.0010593959596008062, + 0.0021798270754516125, + 0.002284458838403225, + 0.0024937221314758062 + ], + [ + 0.0457589291036129, + 0.0390625, + 0.0387834832072258, + 0.0446428582072258, + 0.0009111677063629031, + 0.0010768346255645156, + 0.0011509486939758062, + 0.0018310548039153218 + ], + [ + 0.0510602705180645, + 0.0407366082072258, + 0.0421316996216774, + 0.0516183041036129, + 0.0014648438664153218, + 0.0017089844914153218, + 0.002406529150903225, + 0.0012294225161895156 + ], + [ + 0.0465959832072258, + 0.0396205373108387, + 0.0390625, + 0.0396205373108387, + 0.0030517580453306437, + 0.0014822824159637094, + 0.0021885463502258062, + 0.002162388525903225 + ], + [ + 0.0326450914144516, + 0.0510602705180645, + 0.0479910746216774, + 0.0438058041036129, + 0.0012642997317016125, + 0.0010332380188629031, + 0.0012904576724395156, + 0.0023890906013548374 + ], + [ + 0.0407366082072258, + 0.0513392873108387, + 0.0333426371216774, + 0.0493861623108387, + 0.0038190570194274187, + 0.0024239677004516125, + 0.001953125, + 0.0013078962219879031 + ], + [ + 0.0460379496216774, + 0.037109375, + 0.0452008955180645, + 0.0474330373108387, + 0.0011509486939758062, + 0.0010506766848266125, + 0.0020839148201048374, + 0.0026855471078306437 + ], + [ + 0.0725446492433548, + 0.094308041036129, + 0.0591517873108387, + 0.0585937537252903, + 0.0024762835819274187, + 0.004063197877258062, + 0.001220703125, + 0.001918247900903225 + ], + [ + 0.0341796875, + 0.0393415205180645, + 0.0465959832072258, + 0.0301339291036129, + 0.002458845032379031, + 0.0020664760377258062, + 0.0016479493351653218, + 0.004167829640209675 + ], + [ + 0.0319475457072258, + 0.0343191996216774, + 0.0482700914144516, + 0.0563616082072258, + 0.0016305106692016125, + 0.001953125, + 0.0018223354127258062, + 0.0013602121034637094 + ], + [ + 0.033203125, + 0.0415736623108387, + 0.0368303582072258, + 0.0315290205180645, + 
0.0016828265506774187, + 0.0016479493351653218, + 0.0017525809817016125, + 0.002528599463403225 + ], + [ + 0.0546875037252903, + 0.0452008955180645, + 0.0418526791036129, + 0.0313895121216774, + 0.0018659320194274187, + 0.0024937221314758062, + 0.002197265625, + 0.0015171596314758062 + ], + [ + 0.0898437574505806, + 0.08203125, + 0.0976562574505806, + 0.0837053582072258, + 0.0018310548039153218, + 0.0014648438664153218, + 0.002458845032379031, + 0.0013950893189758062 + ], + [ + 0.0505022332072258, + 0.0424107164144516, + 0.0426897332072258, + 0.0418526791036129, + 0.0024239677004516125, + 0.0017177037661895156, + 0.0014299665344879031, + 0.0014474052004516125 + ], + [ + 0.0330636166036129, + 0.0415736623108387, + 0.0357142873108387, + 0.0549665205180645, + 0.0014910016907379031, + 0.0014910016907379031, + 0.0016915458254516125, + 0.002197265625 + ], + [ + 0.0571986623108387, + 0.0544084832072258, + 0.0560825914144516, + 0.0463169664144516, + 0.0013340541627258062, + 0.0021100726444274187, + 0.0025634765625, + 0.003627232275903225 + ], + [ + 0.0591517873108387, + 0.0622209832072258, + 0.066964291036129, + 0.0555245578289032, + 0.0022670202888548374, + 0.002528599463403225, + 0.0020228796638548374, + 0.0023542132694274187 + ], + [ + 0.03515625, + 0.0465959832072258, + 0.0485491082072258, + 0.0530133955180645, + 0.00359235517680645, + 0.0016566686099395156, + 0.002153669251129031, + 0.0020490374881774187 + ], + [ + 0.0566406287252903, + 0.0479910746216774, + 0.0454799123108387, + 0.0438058041036129, + 0.006417410913854837, + 0.0024937221314758062, + 0.01067243330180645, + 0.002580915344879031 + ], + [ + 0.0700334832072258, + 0.0633370578289032, + 0.0703125, + 0.0666852742433548, + 0.003679548157379031, + 0.00906808115541935, + 0.004603794775903225, + 0.006068638525903225 + ], + [ + 0.0694754496216774, + 0.064453125, + 0.0485491082072258, + 0.0577566996216774, + 0.002580915344879031, + 0.006975446827709675, + 0.00701032392680645, + 0.007882255129516125 + ], + [ + 0.0454799123108387, + 0.0457589291036129, + 0.0339006707072258, + 0.0385044664144516, + 0.0020839148201048374, + 0.0021275111939758062, + 0.0042550223879516125, + 0.007777622900903225 + ], + [ + 0.0541294664144516, + 0.0373883955180645, + 0.0426897332072258, + 0.0485491082072258, + 0.0040283203125, + 0.00371442548930645, + 0.005754743702709675, + 0.00359235517680645 + ], + [ + 0.0477120541036129, + 0.0407366082072258, + 0.0357142873108387, + 0.041015625, + 0.009905134327709675, + 0.0028599330689758062, + 0.006870815064758062, + 0.005092076025903225 + ], + [ + 0.0344587080180645, + 0.0449218787252903, + 0.0488281287252903, + 0.0376674123108387, + 0.00347028486430645, + 0.0030866351444274187, + 0.0030343192629516125, + 0.0087890625 + ], + [ + 0.0493861623108387, + 0.0443638414144516, + 0.0327845998108387, + 0.0424107164144516, + 0.009765625, + 0.009347098879516125, + 0.005440848413854837, + 0.0038190570194274187 + ], + [ + 0.0468750037252903, + 0.0396205373108387, + 0.0449218787252903, + 0.0432477705180645, + 0.004063197877258062, + 0.004098074976354837, + 0.006591797340661287, + 0.006835937965661287 + ], + [ + 0.078125, + 0.094308041036129, + 0.0694754496216774, + 0.0571986623108387, + 0.007882255129516125, + 0.01213727705180645, + 0.0048828125, + 0.0048828125 + ], + [ + 0.0359933041036129, + 0.0390625, + 0.0491071455180645, + 0.0297154039144516, + 0.006382533814758062, + 0.0059640067629516125, + 0.005929129663854837, + 0.00749860517680645 + ], + [ + 0.0319475457072258, + 0.0333426371216774, + 0.0482700914144516, + 0.0544084832072258, 
+ 0.007149832788854837, + 0.01981026865541935, + 0.00871930830180645, + 0.012207032181322575 + ], + [ + 0.0327845998108387, + 0.0449218787252903, + 0.0401785746216774, + 0.0345982164144516, + 0.008544921875, + 0.0135323666036129, + 0.0069405697286129, + 0.012207032181322575 + ], + [ + 0.0541294664144516, + 0.0435267873108387, + 0.0415736623108387, + 0.0325055830180645, + 0.00798688642680645, + 0.007638114038854837, + 0.006801060400903225, + 0.004342215601354837 + ], + [ + 0.0385044664144516, + 0.0385044664144516, + 0.0387834832072258, + 0.0343191996216774, + 0.004778181202709675, + 0.006556919775903225, + 0.012834821827709675, + 0.005894252564758062 + ], + [ + 0.0385044664144516, + 0.0412946455180645, + 0.0339006707072258, + 0.0460379496216774, + 0.007568359840661287, + 0.01311383955180645, + 0.01897321455180645, + 0.014718192629516125 + ], + [ + 0.0376674123108387, + 0.0393415205180645, + 0.0297154039144516, + 0.0259486623108387, + 0.010742188431322575, + 0.00920758955180645, + 0.01213727705180645, + 0.005196707788854837 + ], + [ + 0.0368303582072258, + 0.0299944207072258, + 0.0318080373108387, + 0.03041294775903225, + 0.012207032181322575, + 0.012276786379516125, + 0.010742188431322575, + 0.011928013525903225 + ], + [ + 0.0299944207072258, + 0.0354352705180645, + 0.0316685289144516, + 0.02748326025903225, + 0.009486607275903225, + 0.0135323666036129, + 0.009416853077709675, + 0.02176339365541935 + ], + [ + 0.0326450914144516, + 0.0279017873108387, + 0.02762276865541935, + 0.0269252248108387, + 0.010323661379516125, + 0.008858817629516125, + 0.011439732275903225, + 0.0279017873108387 + ], + [ + 0.0262276791036129, + 0.0249720998108387, + 0.03125, + 0.0269252248108387, + 0.0140206478536129, + 0.02845982275903225, + 0.0252511166036129, + 0.0133928582072258 + ], + [ + 0.0524553582072258, + 0.0641741082072258, + 0.0546875037252903, + 0.0496651791036129, + 0.0429687537252903, + 0.0185546875, + 0.0398995541036129, + 0.0115792416036129 + ], + [ + 0.0407366082072258, + 0.0322265625, + 0.0438058041036129, + 0.0421316996216774, + 0.010463169775903225, + 0.017578125, + 0.0148577019572258, + 0.0318080373108387 + ], + [ + 0.0426897332072258, + 0.0443638414144516, + 0.0319475457072258, + 0.0457589291036129, + 0.0185546875, + 0.013253348879516125, + 0.0220424123108387, + 0.0251116082072258 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen3_30b.json b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen3_30b.json new file mode 100644 index 000000000..f9404de0f --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen3_30b.json @@ -0,0 +1,495 @@ +{ + "version": "1.0", + "architectures": "Qwen3MoeForCausalLM", + "quant_type": "per_head", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 48, + "num_head": 4, + "scales_shape": [ + 48, + 8 + ], + "scales": [ + [ + 0.1199776828289032, + 0.133928582072258, + 0.2232142984867096, + 0.172991082072258, + 0.00019182478717993945, + 0.00018201555940322578, + 0.0002626691712066531, + 0.00018964495393447578 + ], + [ + 0.1155133992433548, + 0.1071428656578064, + 0.1032366082072258, + 0.1104910746216774, + 0.0005275181611068547, + 0.00035531181492842734, + 0.0003138951142318547, + 0.00038146975566633046 + ], + [ + 0.0452008955180645, + 0.0412946455180645, + 0.0627790242433548, + 0.0343191996216774, + 0.0003574916336219758, + 0.0003574916336219758, + 0.0007498605409637094, + 0.0003204345703125 + ], + [ + 0.3482142984867096, 
+ 0.3571428656578064, + 0.28125, + 0.345982164144516, + 0.0006975446594879031, + 0.0007934570894576609, + 0.0009591239504516125, + 0.0008370536379516125 + ], + [ + 0.06640625, + 0.0552455373108387, + 0.090401791036129, + 0.0571986623108387, + 0.0009809222538024187, + 0.0005754743469879031, + 0.0005449567688629031, + 0.0009155274019576609 + ], + [ + 0.1032366082072258, + 0.0814732164144516, + 0.1116071492433548, + 0.1183035746216774, + 0.0006931850221008062, + 0.0005078997346572578, + 0.0007455008453689516, + 0.0004686628235504031 + ], + [ + 0.0465959832072258, + 0.0350167416036129, + 0.0530133955180645, + 0.0619419664144516, + 0.0007760184234939516, + 0.0006190708954818547, + 0.0006583078065887094, + 0.0007237026002258062 + ], + [ + 0.0931919664144516, + 0.0853794664144516, + 0.0965401828289032, + 0.0870535746216774, + 0.0005100795533508062, + 0.0005296979798004031, + 0.0006626674439758062, + 0.0006190708954818547 + ], + [ + 0.1411830484867096, + 0.1233258992433548, + 0.0998883992433548, + 0.1210937574505806, + 0.0008762904908508062, + 0.0006495884736068547, + 0.0012032645754516125, + 0.0005972726503387094 + ], + [ + 0.106026791036129, + 0.1043526828289032, + 0.0965401828289032, + 0.0848214328289032, + 0.0008283343049697578, + 0.0007934570894576609, + 0.0007890973938629031, + 0.0006801060517318547 + ], + [ + 0.0599888414144516, + 0.0460379496216774, + 0.041015625, + 0.0223214291036129, + 0.0007498605409637094, + 0.0007019043550826609, + 0.0012294225161895156, + 0.0015171596314758062 + ], + [ + 0.0415736623108387, + 0.0390625, + 0.0571986623108387, + 0.0652901828289032, + 0.0007672991487197578, + 0.0008719308534637094, + 0.0006452288362197578, + 0.0015520368469879031 + ], + [ + 0.1132812574505806, + 0.126116082072258, + 0.1439732164144516, + 0.1010044664144516, + 0.0010245187440887094, + 0.0007367815705947578, + 0.0007455008453689516, + 0.0007324219332076609 + ], + [ + 0.0510602705180645, + 0.0499441996216774, + 0.0538504496216774, + 0.0505022332072258, + 0.0012032645754516125, + 0.0013166156131774187, + 0.0008588518830947578, + 0.0008850098238326609 + ], + [ + 0.0468750037252903, + 0.0426897332072258, + 0.0527343787252903, + 0.0563616082072258, + 0.0010637555969879031, + 0.002458845032379031, + 0.0009504046174697578, + 0.0010811942629516125 + ], + [ + 0.0502232164144516, + 0.0440848246216774, + 0.0535714328289032, + 0.0376674123108387, + 0.0008283343049697578, + 0.0008283343049697578, + 0.0007237026002258062, + 0.0007455008453689516 + ], + [ + 0.0493861623108387, + 0.0347377248108387, + 0.0446428582072258, + 0.0379464291036129, + 0.0011117118410766125, + 0.0010855539003387094, + 0.0009155274019576609, + 0.0008893694612197578 + ], + [ + 0.0647321492433548, + 0.0591517873108387, + 0.0859375074505806, + 0.0655691996216774, + 0.0010593959596008062, + 0.0011073522036895156, + 0.0010768346255645156, + 0.0013340541627258062 + ], + [ + 0.0404575914144516, + 0.0424107164144516, + 0.0513392873108387, + 0.0655691996216774, + 0.0007978167268447578, + 0.0011858259094879031, + 0.0009373256471008062, + 0.00115966796875 + ], + [ + 0.0831473246216774, + 0.0803571492433548, + 0.1127232164144516, + 0.0753348246216774, + 0.0019967216067016125, + 0.0011771066347137094, + 0.0009678432252258062, + 0.0010942731751129031 + ], + [ + 0.1356026828289032, + 0.0987723246216774, + 0.125, + 0.106026791036129, + 0.0011422294192016125, + 0.0018310548039153218, + 0.0011509486939758062, + 0.0009111677063629031 + ], + [ + 0.1015625074505806, + 0.0853794664144516, + 0.0725446492433548, + 0.0842633992433548, + 
0.007359096314758062, + 0.0012381417909637094, + 0.0014038087101653218, + 0.0013427735539153218 + ], + [ + 0.0591517873108387, + 0.0479910746216774, + 0.0440848246216774, + 0.0365513414144516, + 0.004342215601354837, + 0.00390625, + 0.01102120615541935, + 0.0057198661379516125 + ], + [ + 0.03125, + 0.0426897332072258, + 0.0532924123108387, + 0.0530133955180645, + 0.0016392299439758062, + 0.00833565928041935, + 0.00676618330180645, + 0.009905134327709675 + ], + [ + 0.0853794664144516, + 0.0892857164144516, + 0.1127232164144516, + 0.0574776828289032, + 0.01067243330180645, + 0.011439732275903225, + 0.004708426538854837, + 0.007045201025903225 + ], + [ + 0.0440848246216774, + 0.0479910746216774, + 0.0549665205180645, + 0.0502232164144516, + 0.010951451025903225, + 0.007777622900903225, + 0.003679548157379031, + 0.02148437686264515 + ], + [ + 0.0485491082072258, + 0.0376674123108387, + 0.0493861623108387, + 0.0502232164144516, + 0.0017002651002258062, + 0.0028076174203306437, + 0.0018397740786895156, + 0.0020228796638548374 + ], + [ + 0.0524553582072258, + 0.0477120541036129, + 0.0541294664144516, + 0.0412946455180645, + 0.00201416015625, + 0.002092634094879031, + 0.0016392299439758062, + 0.0017002651002258062 + ], + [ + 0.0421316996216774, + 0.0393415205180645, + 0.0424107164144516, + 0.0438058041036129, + 0.0019269671756774187, + 0.0020839148201048374, + 0.0017002651002258062, + 0.001735142432153225 + ], + [ + 0.0585937537252903, + 0.0605468787252903, + 0.0594308078289032, + 0.0831473246216774, + 0.0018920899601653218, + 0.0017089844914153218, + 0.0019095285097137094, + 0.0029122489504516125 + ], + [ + 0.0477120541036129, + 0.0454799123108387, + 0.0569196455180645, + 0.0633370578289032, + 0.0014038087101653218, + 0.0018048968631774187, + 0.0013166156131774187, + 0.0019880023319274187 + ], + [ + 0.082589291036129, + 0.0959821492433548, + 0.1015625074505806, + 0.0848214328289032, + 0.0029645648319274187, + 0.0014735631411895156, + 0.0023542132694274187, + 0.001796177588403225 + ], + [ + 0.1395089328289032, + 0.1049107164144516, + 0.1127232164144516, + 0.1132812574505806, + 0.0079171322286129, + 0.0096261166036129, + 0.002458845032379031, + 0.0016392299439758062 + ], + [ + 0.0993303582072258, + 0.0954241082072258, + 0.074776791036129, + 0.0666852742433548, + 0.01653180830180645, + 0.0022321429569274187, + 0.005336216650903225, + 0.001674107275903225 + ], + [ + 0.0560825914144516, + 0.0465959832072258, + 0.0429687537252903, + 0.0337611623108387, + 0.013183594681322575, + 0.008858817629516125, + 0.01897321455180645, + 0.011788505129516125 + ], + [ + 0.0368303582072258, + 0.0421316996216774, + 0.0474330373108387, + 0.0558035746216774, + 0.002772740088403225, + 0.01346261240541935, + 0.01018415205180645, + 0.017578125 + ], + [ + 0.0837053582072258, + 0.0926339328289032, + 0.1121651828289032, + 0.0585937537252903, + 0.0185546875, + 0.02092633955180645, + 0.00906808115541935, + 0.008684431202709675 + ], + [ + 0.0438058041036129, + 0.0580357164144516, + 0.0580357164144516, + 0.0591517873108387, + 0.0182756707072258, + 0.01346261240541935, + 0.005650111939758062, + 0.0343191996216774 + ], + [ + 0.0549665205180645, + 0.0401785746216774, + 0.0569196455180645, + 0.0510602705180645, + 0.0036621096078306437, + 0.006243024952709675, + 0.00334821455180645, + 0.004150390625 + ], + [ + 0.0499441996216774, + 0.0471540205180645, + 0.0546875037252903, + 0.0396205373108387, + 0.004429408814758062, + 0.0049874442629516125, + 0.0034877234138548374, + 0.0035749163944274187 + ], + [ + 0.0440848246216774, + 
0.0415736623108387, + 0.0541294664144516, + 0.0488281287252903, + 0.006138393189758062, + 0.005440848413854837, + 0.005894252564758062, + 0.004847935400903225 + ], + [ + 0.0625, + 0.0703125, + 0.0546875037252903, + 0.0892857164144516, + 0.00578962080180645, + 0.006347656715661287, + 0.005894252564758062, + 0.005859375465661287 + ], + [ + 0.0460379496216774, + 0.0412946455180645, + 0.06640625, + 0.03125, + 0.007777622900903225, + 0.01213727705180645, + 0.009835380129516125, + 0.005894252564758062 + ], + [ + 0.0471540205180645, + 0.0457589291036129, + 0.0452008955180645, + 0.0393415205180645, + 0.013323103077709675, + 0.013323103077709675, + 0.0168108269572258, + 0.010881696827709675 + ], + [ + 0.02943638525903225, + 0.0313895121216774, + 0.0376674123108387, + 0.0348772332072258, + 0.0120675228536129, + 0.0221819207072258, + 0.00955636240541935, + 0.008614677004516125 + ], + [ + 0.03041294775903225, + 0.0319475457072258, + 0.0311104916036129, + 0.033203125, + 0.02176339365541935, + 0.0352957621216774, + 0.02455357275903225, + 0.02734375186264515 + ], + [ + 0.0580357164144516, + 0.0435267873108387, + 0.0546875037252903, + 0.0421316996216774, + 0.01869419775903225, + 0.02385602705180645, + 0.02064732275903225, + 0.01981026865541935 + ], + [ + 0.0566406287252903, + 0.0532924123108387, + 0.0474330373108387, + 0.0546875037252903, + 0.02343750186264515, + 0.010393415577709675, + 0.02092633955180645, + 0.0163225457072258 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_14b.json b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_14b.json new file mode 100644 index 000000000..98d827cb3 --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_14b.json @@ -0,0 +1,207 @@ +{ + "version": "1.0", + "architectures": "Qwen2ForCausalLM", + "quant_type": "per_tensor", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 48, + "num_head": 8, + "scales_shape": [ + 48, + 2 + ], + "scales": [ + [ + 0.0574776828289032, + 0.01297433115541935 + ], + [ + 0.02441406436264515, + 0.0031040736939758062 + ], + [ + 0.0558035746216774, + 0.0076729916036129 + ], + [ + 0.066964291036129, + 0.01492745615541935 + ], + [ + 0.0323660746216774, + 0.008684431202709675 + ], + [ + 0.0396205373108387, + 0.010811942629516125 + ], + [ + 0.0387834832072258, + 0.012416294775903225 + ], + [ + 0.0412946455180645, + 0.0143694207072258 + ], + [ + 0.0429687537252903, + 0.013183594681322575 + ], + [ + 0.0424107164144516, + 0.013811384327709675 + ], + [ + 0.0412946455180645, + 0.0156947560608387 + ], + [ + 0.0491071455180645, + 0.0143694207072258 + ], + [ + 0.0477120541036129, + 0.014787946827709675 + ], + [ + 0.0424107164144516, + 0.015206473879516125 + ], + [ + 0.0407366082072258, + 0.014648438431322575 + ], + [ + 0.0449218787252903, + 0.0164620541036129 + ], + [ + 0.0482700914144516, + 0.0154854916036129 + ], + [ + 0.0412946455180645, + 0.01883370615541935 + ], + [ + 0.0477120541036129, + 0.0166015625 + ], + [ + 0.0443638414144516, + 0.0164620541036129 + ], + [ + 0.0491071455180645, + 0.01702008955180645 + ], + [ + 0.0454799123108387, + 0.01981026865541935 + ], + [ + 0.0627790242433548, + 0.0262276791036129 + ], + [ + 0.0661272332072258, + 0.0191127248108387 + ], + [ + 0.0471540205180645, + 0.01702008955180645 + ], + [ + 0.0424107164144516, + 0.0181361623108387 + ], + [ + 0.0493861623108387, + 0.0203683041036129 + ], + [ + 0.0493861623108387, + 0.01590401865541935 + ], + 
[ + 0.0435267873108387, + 0.037109375 + ], + [ + 0.0700334832072258, + 0.0191127248108387 + ], + [ + 0.0809151828289032, + 0.0193917416036129 + ], + [ + 0.0608258955180645, + 0.0154854916036129 + ], + [ + 0.0432477705180645, + 0.0184151791036129 + ], + [ + 0.064453125, + 0.0205078125 + ], + [ + 0.0463169664144516, + 0.0203683041036129 + ], + [ + 0.0499441996216774, + 0.0202287957072258 + ], + [ + 0.0532924123108387, + 0.02845982275903225 + ], + [ + 0.0429687537252903, + 0.03055245615541935 + ], + [ + 0.0502232164144516, + 0.0252511166036129 + ], + [ + 0.0463169664144516, + 0.0281808041036129 + ], + [ + 0.0452008955180645, + 0.02943638525903225 + ], + [ + 0.0474330373108387, + 0.0325055830180645 + ], + [ + 0.0463169664144516, + 0.0306919664144516 + ], + [ + 0.0574776828289032, + 0.0401785746216774 + ], + [ + 0.0496651791036129, + 0.0415736623108387 + ], + [ + 0.063058041036129, + 0.0569196455180645 + ], + [ + 0.0898437574505806, + 0.06640625 + ], + [ + 0.0585937537252903, + 0.0731026828289032 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_32b.json b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_32b.json new file mode 100644 index 000000000..7b16ef354 --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_32b.json @@ -0,0 +1,271 @@ +{ + "version": "1.0", + "architectures": "Qwen2ForCausalLM", + "quant_type": "per_tensor", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 64, + "num_head": 8, + "scales_shape": [ + 64, + 2 + ], + "scales": [ + [ + 0.0571986623108387, + 0.0210658498108387 + ], + [ + 0.0252511166036129, + 0.003557477844879031 + ], + [ + 0.0555245578289032, + 0.00847516767680645 + ], + [ + 0.0513392873108387, + 0.0182756707072258 + ], + [ + 0.0311104916036129, + 0.014718192629516125 + ], + [ + 0.041015625, + 0.012834821827709675 + ], + [ + 0.0382254496216774, + 0.013671875931322575 + ], + [ + 0.0404575914144516, + 0.01590401865541935 + ], + [ + 0.0385044664144516, + 0.013183594681322575 + ], + [ + 0.0415736623108387, + 0.01590401865541935 + ], + [ + 0.0401785746216774, + 0.01639229990541935 + ], + [ + 0.0482700914144516, + 0.0172991082072258 + ], + [ + 0.0454799123108387, + 0.0200892873108387 + ], + [ + 0.0404575914144516, + 0.0193917416036129 + ], + [ + 0.0398995541036129, + 0.01611328125 + ], + [ + 0.041015625, + 0.01360212080180645 + ], + [ + 0.0429687537252903, + 0.01492745615541935 + ], + [ + 0.0357142873108387, + 0.014718192629516125 + ], + [ + 0.0429687537252903, + 0.0171595998108387 + ], + [ + 0.0382254496216774, + 0.0181361623108387 + ], + [ + 0.0387834832072258, + 0.01625279150903225 + ], + [ + 0.0379464291036129, + 0.0135323666036129 + ], + [ + 0.0429687537252903, + 0.012834821827709675 + ], + [ + 0.0549665205180645, + 0.01688058115541935 + ], + [ + 0.0373883955180645, + 0.0163225457072258 + ], + [ + 0.0348772332072258, + 0.01967076025903225 + ], + [ + 0.0385044664144516, + 0.0169503353536129 + ], + [ + 0.041015625, + 0.0138811394572258 + ], + [ + 0.0418526791036129, + 0.0241350457072258 + ], + [ + 0.0616629496216774, + 0.01799665205180645 + ], + [ + 0.0694754496216774, + 0.02483258955180645 + ], + [ + 0.0432477705180645, + 0.0213448666036129 + ], + [ + 0.0401785746216774, + 0.01967076025903225 + ], + [ + 0.0387834832072258, + 0.0231584832072258 + ], + [ + 0.0460379496216774, + 0.02078683115541935 + ], + [ + 0.0460379496216774, + 0.0210658498108387 + ], + [ + 0.0449218787252903, + 
0.0210658498108387 + ], + [ + 0.0421316996216774, + 0.0231584832072258 + ], + [ + 0.0725446492433548, + 0.0362723246216774 + ], + [ + 0.0655691996216774, + 0.0252511166036129 + ], + [ + 0.0412946455180645, + 0.02246093936264515 + ], + [ + 0.0426897332072258, + 0.0221819207072258 + ], + [ + 0.0463169664144516, + 0.0262276791036129 + ], + [ + 0.0446428582072258, + 0.02064732275903225 + ], + [ + 0.0471540205180645, + 0.0792410746216774 + ], + [ + 0.07421875, + 0.02092633955180645 + ], + [ + 0.0842633992433548, + 0.02483258955180645 + ], + [ + 0.0622209832072258, + 0.0171595998108387 + ], + [ + 0.0443638414144516, + 0.0221819207072258 + ], + [ + 0.063058041036129, + 0.02287946455180645 + ], + [ + 0.0449218787252903, + 0.0267857164144516 + ], + [ + 0.0482700914144516, + 0.02734375186264515 + ], + [ + 0.0560825914144516, + 0.0362723246216774 + ], + [ + 0.0446428582072258, + 0.0373883955180645 + ], + [ + 0.0457589291036129, + 0.0287388414144516 + ], + [ + 0.0426897332072258, + 0.03055245615541935 + ], + [ + 0.0426897332072258, + 0.0345982164144516 + ], + [ + 0.0443638414144516, + 0.0376674123108387 + ], + [ + 0.0398995541036129, + 0.0390625 + ], + [ + 0.063058041036129, + 0.0463169664144516 + ], + [ + 0.0454799123108387, + 0.0440848246216774 + ], + [ + 0.0594308078289032, + 0.063058041036129 + ], + [ + 0.0797991082072258, + 0.0675223246216774 + ], + [ + 0.065011166036129, + 0.078125 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_72b.json b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_72b.json new file mode 100644 index 000000000..3a0d0c8cc --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_72b.json @@ -0,0 +1,335 @@ +{ + "version": "1.0", + "architectures": "Qwen2ForCausalLM", + "quant_type": "per_tensor", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 80, + "num_head": 8, + "scales_shape": [ + 80, + 2 + ], + "scales": [ + [ + 0.0415736623108387, + 0.0006931850221008062 + ], + [ + 0.0362723246216774, + 0.0025111609138548374 + ], + [ + 0.0327845998108387, + 0.005580357275903225 + ], + [ + 0.0401785746216774, + 0.0036446708254516125 + ], + [ + 0.0376674123108387, + 0.00830078125 + ], + [ + 0.037109375, + 0.006103516090661287 + ], + [ + 0.0316685289144516, + 0.006068638525903225 + ], + [ + 0.0354352705180645, + 0.008196149952709675 + ], + [ + 0.0379464291036129, + 0.0076729916036129 + ], + [ + 0.0493861623108387, + 0.0078125 + ], + [ + 0.0454799123108387, + 0.00578962080180645 + ], + [ + 0.0471540205180645, + 0.011928013525903225 + ], + [ + 0.0435267873108387, + 0.00833565928041935 + ], + [ + 0.0460379496216774, + 0.007463728077709675 + ], + [ + 0.0510602705180645, + 0.0086495541036129 + ], + [ + 0.0432477705180645, + 0.009347098879516125 + ], + [ + 0.0452008955180645, + 0.008510044775903225 + ], + [ + 0.0485491082072258, + 0.006870815064758062 + ], + [ + 0.0463169664144516, + 0.010323661379516125 + ], + [ + 0.0454799123108387, + 0.012346540577709675 + ], + [ + 0.0465959832072258, + 0.008056640625 + ], + [ + 0.0446428582072258, + 0.0084054134786129 + ], + [ + 0.0418526791036129, + 0.01625279150903225 + ], + [ + 0.0446428582072258, + 0.015206473879516125 + ], + [ + 0.0505022332072258, + 0.00906808115541935 + ], + [ + 0.0432477705180645, + 0.011439732275903225 + ], + [ + 0.0460379496216774, + 0.01248604990541935 + ], + [ + 0.0485491082072258, + 0.01360212080180645 + ], + [ + 0.0482700914144516, + 
0.01883370615541935 + ], + [ + 0.0513392873108387, + 0.014718192629516125 + ], + [ + 0.0496651791036129, + 0.01297433115541935 + ], + [ + 0.0482700914144516, + 0.01555524580180645 + ], + [ + 0.0465959832072258, + 0.01506696455180645 + ], + [ + 0.0468750037252903, + 0.0220424123108387 + ], + [ + 0.0482700914144516, + 0.0174386166036129 + ], + [ + 0.0521763414144516, + 0.0143694207072258 + ], + [ + 0.0463169664144516, + 0.0145089291036129 + ], + [ + 0.0527343787252903, + 0.0154854916036129 + ], + [ + 0.0502232164144516, + 0.011230469681322575 + ], + [ + 0.0482700914144516, + 0.012834821827709675 + ], + [ + 0.0516183041036129, + 0.01297433115541935 + ], + [ + 0.0527343787252903, + 0.0140206478536129 + ], + [ + 0.0560825914144516, + 0.01506696455180645 + ], + [ + 0.0516183041036129, + 0.0149972103536129 + ], + [ + 0.0485491082072258, + 0.01248604990541935 + ], + [ + 0.0513392873108387, + 0.01967076025903225 + ], + [ + 0.0471540205180645, + 0.02748326025903225 + ], + [ + 0.0465959832072258, + 0.03027343936264515 + ], + [ + 0.0560825914144516, + 0.0290178582072258 + ], + [ + 0.0507812537252903, + 0.0415736623108387 + ], + [ + 0.0521763414144516, + 0.03125 + ], + [ + 0.0580357164144516, + 0.0418526791036129 + ], + [ + 0.0580357164144516, + 0.0323660746216774 + ], + [ + 0.0583147332072258, + 0.03027343936264515 + ], + [ + 0.0555245578289032, + 0.02483258955180645 + ], + [ + 0.0560825914144516, + 0.0298549123108387 + ], + [ + 0.0594308078289032, + 0.0385044664144516 + ], + [ + 0.0691964328289032, + 0.0221819207072258 + ], + [ + 0.0524553582072258, + 0.01967076025903225 + ], + [ + 0.0538504496216774, + 0.01981026865541935 + ], + [ + 0.0583147332072258, + 0.0213448666036129 + ], + [ + 0.0535714328289032, + 0.0290178582072258 + ], + [ + 0.0580357164144516, + 0.0260881707072258 + ], + [ + 0.0616629496216774, + 0.0309709832072258 + ], + [ + 0.0594308078289032, + 0.02580915205180645 + ], + [ + 0.0577566996216774, + 0.0477120541036129 + ], + [ + 0.0583147332072258, + 0.033203125 + ], + [ + 0.0530133955180645, + 0.0288783498108387 + ], + [ + 0.0588727705180645, + 0.0387834832072258 + ], + [ + 0.0627790242433548, + 0.0306919664144516 + ], + [ + 0.0625, + 0.0404575914144516 + ], + [ + 0.0608258955180645, + 0.0421316996216774 + ], + [ + 0.0641741082072258, + 0.0407366082072258 + ], + [ + 0.063058041036129, + 0.0518973246216774 + ], + [ + 0.0633370578289032, + 0.0443638414144516 + ], + [ + 0.064453125, + 0.0521763414144516 + ], + [ + 0.0585937537252903, + 0.0608258955180645 + ], + [ + 0.0627790242433548, + 0.0485491082072258 + ], + [ + 0.0605468787252903, + 0.0530133955180645 + ], + [ + 0.0488281287252903, + 0.0365513414144516 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen3_235b.json b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen3_235b.json new file mode 100644 index 000000000..0c0a47e86 --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen3_235b.json @@ -0,0 +1,391 @@ +{ + "version": "1.0", + "architectures": "Qwen3MoeForCausalLM", + "quant_type": "per_tensor", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 94, + "num_head": 4, + "scales_shape": [ + 94, + 2 + ], + "scales": [ + [ + 0.0552455373108387, + 0.00022888185048941523 + ], + [ + 0.0404575914144516, + 0.0002899169921875 + ], + [ + 0.0792410746216774, + 0.0002855573548004031 + ], + [ + 0.0705915242433548, + 0.0004512242157943547 + ], + [ + 0.0719866082072258, + 
0.0003923688782379031 + ], + [ + 0.149553582072258, + 0.0006801060517318547 + ], + [ + 0.0491071455180645, + 0.0007106236298568547 + ], + [ + 0.0566406287252903, + 0.0007367815705947578 + ], + [ + 0.1104910746216774, + 0.0007629395113326609 + ], + [ + 0.0555245578289032, + 0.0009722028626129031 + ], + [ + 0.106026791036129, + 0.0005929129547439516 + ], + [ + 0.0482700914144516, + 0.0005384173127822578 + ], + [ + 0.0538504496216774, + 0.0005296979798004031 + ], + [ + 0.0583147332072258, + 0.0005318777984939516 + ], + [ + 0.0736607164144516, + 0.0005580357392318547 + ], + [ + 0.0505022332072258, + 0.00045558385318145156 + ], + [ + 0.0602678582072258, + 0.0007716587861068547 + ], + [ + 0.078683041036129, + 0.0004991804016754031 + ], + [ + 0.074776791036129, + 0.0004817417939193547 + ], + [ + 0.0446428582072258, + 0.0005296979798004031 + ], + [ + 0.0499441996216774, + 0.0006583078065887094 + ], + [ + 0.0446428582072258, + 0.00040980748599395156 + ], + [ + 0.0479910746216774, + 0.00044904439710080624 + ], + [ + 0.0499441996216774, + 0.000518798828125 + ], + [ + 0.0463169664144516, + 0.0006321498658508062 + ], + [ + 0.0998883992433548, + 0.0006190708954818547 + ], + [ + 0.0460379496216774, + 0.0015520368469879031 + ], + [ + 0.0560825914144516, + 0.0007498605409637094 + ], + [ + 0.0429687537252903, + 0.0005623953766189516 + ], + [ + 0.0549665205180645, + 0.0013078962219879031 + ], + [ + 0.109933041036129, + 0.0006365095032379031 + ], + [ + 0.0491071455180645, + 0.000518798828125 + ], + [ + 0.0563616082072258, + 0.0008196149719879031 + ], + [ + 0.0571986623108387, + 0.0011422294192016125 + ], + [ + 0.078125, + 0.0006975446594879031 + ], + [ + 0.0513392873108387, + 0.0008501325501129031 + ], + [ + 0.0544084832072258, + 0.00128173828125 + ], + [ + 0.0764508992433548, + 0.00115966796875 + ], + [ + 0.0711495578289032, + 0.001068115234375 + ], + [ + 0.0460379496216774, + 0.0009024484315887094 + ], + [ + 0.0521763414144516, + 0.0008065360016189516 + ], + [ + 0.0465959832072258, + 0.0009329660097137094 + ], + [ + 0.0471540205180645, + 0.0006801060517318547 + ], + [ + 0.0544084832072258, + 0.001220703125 + ], + [ + 0.0446428582072258, + 0.0008283343049697578 + ], + [ + 0.1037946492433548, + 0.0011945453006774187 + ], + [ + 0.0460379496216774, + 0.001970563782379031 + ], + [ + 0.0591517873108387, + 0.0010942731751129031 + ], + [ + 0.0809151828289032, + 0.0008414132753387094 + ], + [ + 0.0513392873108387, + 0.001735142432153225 + ], + [ + 0.0611049123108387, + 0.0036621096078306437 + ], + [ + 0.078125, + 0.0028076174203306437 + ], + [ + 0.0725446492433548, + 0.002406529150903225 + ], + [ + 0.0452008955180645, + 0.0019880023319274187 + ], + [ + 0.0513392873108387, + 0.0023890906013548374 + ], + [ + 0.0465959832072258, + 0.0030866351444274187 + ], + [ + 0.0513392873108387, + 0.0023890906013548374 + ], + [ + 0.0502232164144516, + 0.0038888114504516125 + ], + [ + 0.0465959832072258, + 0.0028424945194274187 + ], + [ + 0.0970982164144516, + 0.004185268189758062 + ], + [ + 0.0465959832072258, + 0.0036097937263548374 + ], + [ + 0.0563616082072258, + 0.0019967216067016125 + ], + [ + 0.0415736623108387, + 0.0025634765625 + ], + [ + 0.0546875037252903, + 0.002580915344879031 + ], + [ + 0.0976562574505806, + 0.0023716518189758062 + ], + [ + 0.0505022332072258, + 0.0018833705689758062 + ], + [ + 0.0549665205180645, + 0.0022495815064758062 + ], + [ + 0.0580357164144516, + 0.0034528460819274187 + ], + [ + 0.0652901828289032, + 0.002406529150903225 + ], + [ + 0.0530133955180645, + 0.003679548157379031 + ], + [ + 
0.0569196455180645, + 0.010742188431322575 + ], + [ + 0.0700334832072258, + 0.0088936947286129 + ], + [ + 0.0691964328289032, + 0.0079171322286129 + ], + [ + 0.0460379496216774, + 0.007638114038854837 + ], + [ + 0.0541294664144516, + 0.0057198661379516125 + ], + [ + 0.0477120541036129, + 0.01018415205180645 + ], + [ + 0.0488281287252903, + 0.008858817629516125 + ], + [ + 0.0485491082072258, + 0.009905134327709675 + ], + [ + 0.0468750037252903, + 0.006835937965661287 + ], + [ + 0.094308041036129, + 0.012276786379516125 + ], + [ + 0.0491071455180645, + 0.010811942629516125 + ], + [ + 0.0544084832072258, + 0.0193917416036129 + ], + [ + 0.0449218787252903, + 0.0148577019572258 + ], + [ + 0.0541294664144516, + 0.009835380129516125 + ], + [ + 0.0398995541036129, + 0.013253348879516125 + ], + [ + 0.0460379496216774, + 0.01981026865541935 + ], + [ + 0.0393415205180645, + 0.01311383955180645 + ], + [ + 0.0368303582072258, + 0.01395089365541935 + ], + [ + 0.0362723246216774, + 0.02064732275903225 + ], + [ + 0.0326450914144516, + 0.02580915205180645 + ], + [ + 0.0315290205180645, + 0.0290178582072258 + ], + [ + 0.064453125, + 0.0429687537252903 + ], + [ + 0.0426897332072258, + 0.0319475457072258 + ], + [ + 0.0463169664144516, + 0.02762276865541935 + ] + ] +} \ No newline at end of file diff --git a/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen3_30b.json b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen3_30b.json new file mode 100644 index 000000000..b3f6b714e --- /dev/null +++ b/test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen3_30b.json @@ -0,0 +1,207 @@ +{ + "version": "1.0", + "architectures": "Qwen3MoeForCausalLM", + "quant_type": "per_tensor", + "qmin": -448.0, + "qmax": 448.0, + "num_layers": 48, + "num_head": 4, + "scales_shape": [ + 48, + 2 + ], + "scales": [ + [ + 0.2232142984867096, + 0.0002506801101844758 + ], + [ + 0.1143973246216774, + 0.00047956197522580624 + ], + [ + 0.0611049123108387, + 0.0008283343049697578 + ], + [ + 0.3482142984867096, + 0.0009111677063629031 + ], + [ + 0.0876116082072258, + 0.0009634835878387094 + ], + [ + 0.1183035746216774, + 0.0007542201783508062 + ], + [ + 0.0619419664144516, + 0.0009329660097137094 + ], + [ + 0.0993303582072258, + 0.0006670270813629031 + ], + [ + 0.1439732164144516, + 0.0012032645754516125 + ], + [ + 0.1065848246216774, + 0.0008414132753387094 + ], + [ + 0.0599888414144516, + 0.0015171596314758062 + ], + [ + 0.0641741082072258, + 0.0015258790226653218 + ], + [ + 0.1428571492433548, + 0.0008588518830947578 + ], + [ + 0.0538504496216774, + 0.0012642997317016125 + ], + [ + 0.0566406287252903, + 0.0019967216067016125 + ], + [ + 0.0535714328289032, + 0.0009634835878387094 + ], + [ + 0.0496651791036129, + 0.0011422294192016125 + ], + [ + 0.0859375074505806, + 0.0014212472597137094 + ], + [ + 0.0655691996216774, + 0.0011509486939758062 + ], + [ + 0.1132812574505806, + 0.0018223354127258062 + ], + [ + 0.1350446492433548, + 0.0017700196476653218 + ], + [ + 0.1010044664144516, + 0.00701032392680645 + ], + [ + 0.0591517873108387, + 0.01102120615541935 + ], + [ + 0.0558035746216774, + 0.009905134327709675 + ], + [ + 0.1127232164144516, + 0.011439732275903225 + ], + [ + 0.0549665205180645, + 0.0213448666036129 + ], + [ + 0.0513392873108387, + 0.002650669775903225 + ], + [ + 0.0541294664144516, + 0.001970563782379031 + ], + [ + 0.0438058041036129, + 0.0020839148201048374 + ], + [ + 0.0837053582072258, + 0.003138951025903225 + ], + [ + 
0.0633370578289032, + 0.002040318213403225 + ], + [ + 0.1032366082072258, + 0.0032087054569274187 + ], + [ + 0.141741082072258, + 0.00927734375 + ], + [ + 0.1010044664144516, + 0.0163225457072258 + ], + [ + 0.0560825914144516, + 0.01897321455180645 + ], + [ + 0.0560825914144516, + 0.017578125 + ], + [ + 0.1104910746216774, + 0.0210658498108387 + ], + [ + 0.0580357164144516, + 0.0343191996216774 + ], + [ + 0.0563616082072258, + 0.00627790205180645 + ], + [ + 0.0546875037252903, + 0.0049874442629516125 + ], + [ + 0.0544084832072258, + 0.005929129663854837 + ], + [ + 0.0892857164144516, + 0.006243024952709675 + ], + [ + 0.06640625, + 0.011788505129516125 + ], + [ + 0.0471540205180645, + 0.01457868330180645 + ], + [ + 0.0376674123108387, + 0.01576451025903225 + ], + [ + 0.033203125, + 0.03027343936264515 + ], + [ + 0.0583147332072258, + 0.0213448666036129 + ], + [ + 0.0563616082072258, + 0.0212053582072258 + ] + ] +} \ No newline at end of file diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 6cd97cfdf..3db0ac004 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -41,6 +41,7 @@ def test_model_inference(args): "run_mode": "normal", "max_seq_length": args.max_req_total_len, "disable_cudagraph": args.disable_cudagraph, + "mode": args.mode, } proc = multiprocessing.Process( target=tppart_model_infer, @@ -213,10 +214,12 @@ def torch_profile(fn, log_dir=None): ) as prof: fn() if get_current_rank_in_dp() == 0: - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) -def run_forward_once(model_kvargs, input_len, output_len, batch_size, model_part, enable_overlap, torch_profile=False): +def run_forward_once( + model_kvargs, input_len, output_len, batch_size, model_part, enable_overlap, enable_torch_profile=False +): test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) test_data = test_data.reshape(-1) test_data = torch.from_numpy(test_data).cuda() @@ -274,7 +277,7 @@ def run_forward_once(model_kvargs, input_len, output_len, batch_size, model_part f"prefill throughput: {dp_size * batch_size * input_len / (time.time() - prefill_start_time)} tokens/s" ) - if torch_profile: + if enable_torch_profile: print("Profile Prefill") try: torch_profile( @@ -312,7 +315,7 @@ def run_forward_once(model_kvargs, input_len, output_len, batch_size, model_part b_seq_len, total_token_num, ) - if torch_profile: + if enable_torch_profile and i == output_len - 1: try: torch_profile( lambda: decode_fn( @@ -391,7 +394,7 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an batch_size=b, model_part=model_part, enable_overlap=enable_overlap, - torch_profile=False, + enable_torch_profile=False, ) # test @@ -402,7 +405,7 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an batch_size=b, model_part=model_part, enable_overlap=enable_overlap, - torch_profile=False, + enable_torch_profile=args.torch_profile, ) if rank_id == 0: print("=" * 50) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py new file mode 100644 index 000000000..737bb655b --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py @@ -0,0 +1,154 @@ +import torch +import time +import pytest +import triton as tl +import 
numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + # for i in range(Z * N_CTX): + # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) + + max_input_len = Z * N_CTX + req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + if N_CTX > 1: + b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv[:, :KV_HEADS, :], + kv[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_table = torch.empty((batch_size, N_CTX), 
dtype=torch.int32, device="cuda") + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + + k_cache = kv[:, :KV_HEADS, :] + v_cache = kv[:, KV_HEADS:, :] + # o1 = flash_attn_with_kvcache( + # q=q, + # k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), + # v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), + # page_table=page_table, + # cache_seqlens=infer_state.b_seq_len, + # cu_seqlens_q=q_starts, + # cu_seqlens_k_new=kv_starts, + # max_seqlen_q=N_CTX, + # causal=True, + # window_size=(-1, -1), + # softcap=0.0, + # return_softmax_lse=False, + # ) + + q, q_scale = q_per_head_fp8_quant(q.view(q.shape[0], kv_heads, -1), q_lens, q_starts) + k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) + v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) + o1 = flash_attn_with_kvcache( + q=q.view(-1, q_heads, head_dim), + k_cache=k.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), + v_cache=v.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), + # page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale.view(batch_size, kv_heads), + k_descale=k_scale.view(batch_size, kv_heads), + v_descale=v_scale.view(batch_size, kv_heads), + return_softmax_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1.item() == 1 + + +if __name__ == "__main__": + test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py new file mode 100644 index 000000000..5ee2306ad --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py @@ -0,0 +1,145 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + +if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant +else: + scaled_fp8_quant = None + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + # for i in range(Z * N_CTX): + # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 64 + 1) + + max_input_len = Z * N_CTX + req_to_token_indexs = torch.randperm(max_input_len, 
dtype=torch.int32).cuda().view(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + if N_CTX > 1: + b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + kv_start_loc = b_seq_len.cumsum(0) - b_seq_len + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv[:, :KV_HEADS, :], + kv[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_size = 1 + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + kv_indices = torch.arange(Z * N_CTX).cuda().int() + for b, sl, start in zip(b_req_idx, b_seq_len, kv_start_loc): + kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + qo_indptr_buf=q_indptr, + paged_kv_indptr_buf=kv_indptr, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len_buffer, + ) + kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32) + k_cache = kv[:, :KV_HEADS, :].contiguous() + v_cache = kv[:, KV_HEADS:, :].contiguous() + k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) + v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=q.dtype, + kv_data_type=torch.float8_e4m3fn, + ) + wrapper.run( + q, + (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), + k_scale=k_scale, + v_scale=v_scale, + out=o1, + return_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-2, rtol=2e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1 == 1 + + +if __name__ == "__main__": + test_context_attention_fwd_flashinfer_fp8(16, 1024, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad.py b/unit_tests/models/llama/test_token_attention_nopad.py index d44b9812b..1bbb29166 100644 --- a/unit_tests/models/llama/test_token_attention_nopad.py +++ 
b/unit_tests/models/llama/test_token_attention_nopad.py @@ -5,10 +5,6 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, - context_attention_fwd_no_prompt_cache, -) from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.req_manager import ReqManager from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd @@ -70,7 +66,7 @@ def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state): for e in [128] ], ) -def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): +def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim dtype = torch.bfloat16 q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py new file mode 100644 index 000000000..a7f48ab89 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py @@ -0,0 +1,187 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, 
head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + # for i in range(Z * N_CTX): + # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) + + max_input_len = Z * N_CTX + req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv[:, :KV_HEADS, :], + kv[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + # gqa_decode_attention_fwd( + # q, + # kv[:,:KV_HEADS,:], + # kv[:,KV_HEADS:,:], + # o, + # req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_seq_len, + # ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_starts = torch.arange(0, Z + 1).int().cuda() + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + k_cache = kv[:, :KV_HEADS, :].contiguous() + v_cache = kv[:, KV_HEADS:, :].contiguous() + # o1 = flash_attn_with_kvcache( + # q=q, + # k_cache=k_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), + # v_cache=v_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), + # # page_table=page_table, + # cache_seqlens=infer_state.b_seq_len, + # cu_seqlens_q=q_starts, + # cu_seqlens_k_new=kv_starts, + # max_seqlen_q=1, + # causal=False, + # window_size=(-1, -1), + # softcap=0.0, + # return_softmax_lse=False, + # ) + + q, q_scale = scaled_fp8_quant(q.view(batch_size * kv_heads, -1), use_per_token_if_dynamic=True) + k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) + v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) + o1 = flash_attn_with_kvcache( + q=q.view(-1, q_heads, head_dim), + k_cache=k.view(-1, N_CTX, kv_heads, head_dim), + v_cache=v.view(-1, N_CTX, kv_heads, head_dim), + # page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale.view(batch_size, kv_heads), + k_descale=k_scale.view(batch_size, kv_heads), + v_descale=v_scale.view(batch_size, kv_heads), + return_softmax_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + 
print(cos_sim1) + assert cos_sim1 == 1 + + +if __name__ == "__main__": + test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py new file mode 100644 index 000000000..5c0e595b9 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py @@ -0,0 +1,170 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + # for i in range(Z * N_CTX): + # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) + + max_input_len = Z * N_CTX + req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv[:, :KV_HEADS, :], + kv[:, KV_HEADS:, :], + o, + Q_HEADS, + 
HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + # gqa_decode_attention_fwd( + # q, + # kv[:,:KV_HEADS,:], + # kv[:,KV_HEADS:,:], + # o, + # req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_seq_len, + # ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_size = 1 + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + kv_indptr = kv_starts + kv_indices = torch.arange(Z * N_CTX).cuda().int() + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=kv_indptr, + paged_kv_indices_buffer=kv_indices, + paged_kv_last_page_len_buffer=kv_last_page_len_buffer, + ) + kv_last_page_len_buffer = torch.full((batch_size,), page_size, dtype=torch.int32) + k_cache = kv[:, :KV_HEADS, :].contiguous() + v_cache = kv[:, KV_HEADS:, :].contiguous() + k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) + v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len_buffer, + q_heads, + kv_heads, + head_dim, + page_size, + q_data_type=dtype, + kv_data_type=torch.float8_e4m3fn, + non_blocking=True, + ) + wrapper.run( + q, + (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), + k_scale=k_scale, + v_scale=v_scale, + out=o1, + return_lse=False, + ) + + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1 == 1.0 + + +if __name__ == "__main__": + test_token_attention_nopad_flashinfer_fp8(16, 16384, 28, 4, 128)
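
Note (illustrative, not part of the diff): the FP8 tests above all rely on the same per-(batch, head) dynamic quantization idea — take an abs-max over the valid tokens of each head, divide by the FP8-E4M3 maximum to get a scale, clamp, and cast. The standalone sketch below mirrors the kv_quantize_per_head_fp8 helper from test_token_attention_nopad_fa3_fp8.py and checks the round-trip with the same cosine-similarity style of comparison the tests use. The quantize_kv_per_head_fp8 name, the shapes, and the 0.98 threshold are assumptions chosen for the example, not values taken from this PR.

# Illustrative sketch: per-(batch, head) FP8 quantization round-trip check.
import torch
import torch.nn.functional as F


def quantize_kv_per_head_fp8(kv_buffer: torch.Tensor, seq_lens: torch.Tensor):
    # kv_buffer: [B, S_max, H, D]; scales are computed only over the valid
    # positions of each sequence, mirroring the reference helper in the tests.
    B, S_max, H, D = kv_buffer.shape
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    seq_range = torch.arange(S_max, device=kv_buffer.device)[None, :]
    valid = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1)
    max_per_bh = (kv_buffer * valid).float().abs().amax(dim=(1, 3))  # [B, H]
    scales = torch.where(max_per_bh > 0, max_per_bh / fp8_info.max, torch.ones_like(max_per_bh))
    q = (kv_buffer / scales.view(B, 1, H, 1)).clamp(fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)
    return q, scales


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, S, H, D = 4, 64, 4, 128  # assumed toy shapes for the demo
    seq_lens = torch.randint(1, S + 1, (B,), device=device)
    kv = torch.randn((B, S, H, D), dtype=torch.bfloat16, device=device)

    kv_fp8, scales = quantize_kv_per_head_fp8(kv, seq_lens)
    kv_deq = kv_fp8.float() * scales.view(B, 1, H, 1)

    # Compare only the valid positions; the FP8 round-trip should stay close
    # to the bf16 input, so cosine similarity per batch should be near 1.
    valid = (torch.arange(S, device=device)[None, :] < seq_lens[:, None]).view(B, S, 1, 1)
    cos = F.cosine_similarity((kv.float() * valid).flatten(1), (kv_deq * valid).flatten(1), dim=1)
    print(cos)
    assert (cos > 0.98).all()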