
Commit aa6acb3

Author: niushengxiao
feat: kv fp8 quant calibration for fa3

1 parent 364618c · commit aa6acb3
22 files changed: +198 −1636 lines
lightllm/common/calibration_fp8kv_mem_manager.py (new file, +6 lines)

from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


class CalibrationFP8KVMemoryManager(OfflineFP8QuantMemManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=False)
lightllm/common/export_calibration_mem_manager.py (new file, +6 lines)

from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=True)
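
The two wrappers differ only in the is_export_mode flag they pass to OfflineFP8QuantMemManager: the calibration manager serves quantized KV (stored as uint8) using previously exported scales, while the export manager keeps KV in the original dtype and collects abs-max statistics for a later export. A hypothetical construction follows, with made-up shapes; in practice the server builds these from the start args and model config:

import torch
from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager

# Illustrative sizes only; real values come from the model config.
kwargs = dict(size=4096, dtype=torch.bfloat16, head_num=8, head_dim=128, layer_num=32)

serve_mgr = CalibrationFP8KVMemoryManager(**kwargs)   # is_export_mode=False -> uint8 KV buffer
calib_mgr = ExportCalibrationMemoryManager(**kwargs)  # is_export_mode=True  -> original-dtype KV + stats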

lightllm/common/fp8kv_mem_manager.py (deleted, −9 lines)

lightllm/common/mem_manager.py (+0, −147)

@@ -1,6 +1,5 @@
 import re
 import os
-import json
 import torch
 import torch.distributed as dist
 from typing import List, Union
@@ -13,155 +12,10 @@
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
-from lightllm.utils.envs_utils import get_kv_quant_calibration_inference_count
-from lightllm.utils.envs_utils import get_kv_quant_calibration_warmup_count
-from lightllm.utils.dist_utils import get_global_rank
-from lightllm.utils.config_utils import get_model_architectures

 logger = init_logger(__name__)


-class OfflineFP8QuantManager:
-    def __init__(self, layer_num, head_num):
-        self.qmin = torch.finfo(torch.float8_e4m3fn).min
-        self.qmax = torch.finfo(torch.float8_e4m3fn).max
-        self.model_arch = get_model_architectures(get_env_start_args().model_dir)
-        self.layer_num = layer_num
-        self.head_num = head_num
-        self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
-        self.scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
-        self.scales = None
-        self.scales_list = []
-        self.abs_max = None
-        self.warmup_counts = get_kv_quant_calibration_warmup_count()
-        self.inference_counts = get_kv_quant_calibration_inference_count()
-        self.count = 0
-        self.enable_calib = False
-        if get_env_start_args().export_kv_quant_calibration:
-            self.abs_max = torch.zeros(self.scales_shape, dtype=torch.float32, device="cuda")
-        elif get_env_start_args().kv_quant_calibration_config_path is not None:
-            logger.info(
-                f"kv_quant_calibration_config_path {get_env_start_args().kv_quant_calibration_config_path} is set, "
-                "will load kv quant calibration config"
-            )
-            if os.path.exists(get_env_start_args().kv_quant_calibration_config_path):
-                with open(get_env_start_args().kv_quant_calibration_config_path, "r") as f:
-                    cfg = json.load(f)
-
-                if cfg["architectures"] != self.model_arch:
-                    raise ValueError(
-                        f"architectures {cfg['architectures']} in config "
-                        f"not match current model_arch {self.model_arch}"
-                    )
-                if cfg["num_layers"] != layer_num:
-                    raise ValueError(
-                        f"num_layers {cfg['num_layers']} in config " f"not match current layer_num {layer_num}"
-                    )
-                if cfg["num_head"] != self.total_head_num:
-                    raise ValueError(
-                        f"num_head {cfg['num_head']} in config "
-                        f"not match current model head num {self.total_head_num}"
-                    )
-                if get_env_start_args().enable_fa3:
-                    if cfg["quant_type"] != "per_head":
-                        raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend")
-                else:
-                    if cfg["quant_type"] != "per_tensor":
-                        raise ValueError(f"quant type {cfg['quant_type']} in config not match flashinfer backend")
-
-                self.qmin = cfg["qmin"]
-                self.qmax = cfg["qmax"]
-                self.scales_shape = cfg["scales_shape"]
-
-                full_scales_list = cfg["scales"]
-                self.scales_list = full_scales_list
-                self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(self.scales_shape)
-                if not get_env_start_args().enable_fa3:
-                    self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
-                if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
-                    half_head = self.total_head_num // 2
-                    start_head = dist.get_rank() * head_num
-                    end_head = start_head + head_num
-                    k_scales = self.scales[:, start_head:end_head].contiguous()
-                    v_scales = self.scales[:, start_head + half_head : end_head + half_head].contiguous()
-                    current_scales = torch.cat((k_scales, v_scales), dim=-1)
-
-                    self.scales_list = current_scales.tolist()
-                    self.scales = current_scales
-            else:
-                raise FileNotFoundError(
-                    f"kv_quant_calibration_config {get_env_start_args().kv_quant_calibration_config_path} not found"
-                )
-        elif "calibration_fp8kv" in get_env_start_args().mode:
-            logger.warning("scales is None, no kv_quant_calibration_config_path be set")
-
-    def enable_calibration(self):
-        assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode"
-        logger.info("Enable kv cache calibration, will collect kv cache data for quantization calibration")
-        self.enable_calib = True
-
-    def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
-        if not self.enable_calib or self.count >= self.warmup_counts + self.inference_counts:
-            return
-
-        if self.abs_max is not None and self.count >= self.warmup_counts:
-            if get_env_start_args().enable_fa3:
-                kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
-            else:
-                k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
-                v_max = kv_buffer[:, self.head_num :, :].abs().amax(dim=()).to(torch.float32)
-                kv_max = torch.tensor([k_max, v_max], device="cuda", dtype=torch.float32)
-            self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max)
-            if self.count == self.warmup_counts + self.inference_counts - 1 and layer_index == self.layer_num - 1:
-                final_abs_max = self.abs_max
-                if dist.is_initialized() and dist.get_world_size() > 1:
-                    if get_env_start_args().enable_fa3:
-                        k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
-                        k_max = k_max.contiguous()
-                        v_max = v_max.contiguous()
-                        gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
-                        gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
-                        dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
-                        dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
-                        k_max = torch.cat(gathered_k_max, dim=-1)
-                        v_max = torch.cat(gathered_v_max, dim=-1)
-                        final_abs_max = torch.cat((k_max, v_max), dim=-1)
-                    else:
-                        dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False)
-
-                self.scales = final_abs_max / self.qmax
-                self.scales = torch.where(self.scales > 0, self.scales, torch.ones_like(self.scales))
-
-                if get_global_rank() == 0:
-                    self.abs_max = final_abs_max
-                    self._export_calibration_data()
-
-        if layer_index == self.layer_num - 1:
-            self.count += 1
-
-    def _export_calibration_data(self):
-        cfg = {
-            "version": "1.0",
-            "architectures": self.model_arch,
-            "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
-            "qmin": self.qmin,
-            "qmax": self.qmax,
-            "num_layers": self.layer_num,
-            "num_head": self.total_head_num,
-            "scales_shape": list(self.abs_max.shape),
-            "scales": self.scales.cpu().numpy().tolist(),
-        }
-        with open("./kv_cache_calib.json", "w") as f:
-            json.dump(cfg, f, indent=4)
-        logger.info(
-            f"Export kv cache calibration data to kv_cache_calib.json, "
-            f"architectures: {self.model_arch}, "
-            f"qmin: {self.qmin}, qmax: {self.qmax}, "
-            f"total heads: {self.total_head_num}, "
-            f"scales_shape: {list(self.abs_max.shape)}, "
-        )
-
-
 class MemoryManager:
     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
         self.size = size
@@ -198,7 +52,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
-        self.offline_fp8_quant_manager = OfflineFP8QuantManager(layer_num, head_num)

     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
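
For a sense of the memory cost, get_cell_size multiplies out to bytes per token slot. A worked example with hypothetical values (head_num=8, head_dim=128, layer_num=32, bf16 elements):

# 2 (K and V) * head_num * head_dim * layer_num * element_size
cell_bytes = 2 * 8 * 128 * 32 * 2   # = 131072 bytes = 128 KiB per token slot
# In non-export calibration mode the KV dtype becomes uint8 (1 byte),
# halving this to 65536 bytes = 64 KiB per token slot.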

lightllm/common/mem_utils.py (+6, −2)

@@ -1,6 +1,7 @@
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager
-from lightllm.common.fp8kv_mem_manager import FP8KVMemoryManager
+from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
+from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
 from lightllm.utils.log_utils import init_logger
@@ -22,7 +23,10 @@ def select_mem_manager_class(mode):
     elif "triton_fp8kv" in mode:
         raise Exception("currently only for deepseek")
     elif "calibration_fp8kv" in mode:
-        memory_manager_class = FP8KVMemoryManager
+        memory_manager_class = CalibrationFP8KVMemoryManager
+        logger.info("Model kv cache using mode calibration fp8kv")
+    elif "export_fp8kv_calibration" in mode:
+        memory_manager_class = ExportCalibrationMemoryManager
         logger.info("Model kv cache using mode calibration fp8kv")
     else:
         memory_manager_class = MemoryManager
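
A quick sketch of how the selection is expected to behave, assuming mode is the list of mode flags taken from the start args:

from lightllm.common.mem_utils import select_mem_manager_class

cls = select_mem_manager_class(mode=["calibration_fp8kv"])
# -> CalibrationFP8KVMemoryManager (serves quantized uint8 KV using loaded scales)

cls = select_mem_manager_class(mode=["export_fp8kv_calibration"])
# -> ExportCalibrationMemoryManager (keeps original-dtype KV, collects abs-max stats)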
lightllm/common/offline_fp8_quant_mem_manager.py (new file, +156 lines)

import os
import json
import torch
import torch.distributed as dist
from lightllm.utils.envs_utils import get_kv_quant_calibration_inference_count
from lightllm.utils.envs_utils import get_kv_quant_calibration_warmup_count
from lightllm.utils.dist_utils import get_global_rank
from lightllm.utils.config_utils import get_model_architectures
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args

logger = init_logger(__name__)

from .mem_manager import MemoryManager


class OfflineFP8QuantMemManager(MemoryManager):
    def __init__(
        self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_export_mode=False
    ):
        # Quantized KV is stored as uint8 for compatibility with various torch operators.
        # FP8 quantization currently uses an offline scheme, so kv_buffer does not store scales.
        super().__init__(
            size, dtype if is_export_mode else torch.uint8, head_num, head_dim, layer_num, always_copy, mem_fraction
        )

        self.qmin = torch.finfo(torch.float8_e4m3fn).min
        self.qmax = torch.finfo(torch.float8_e4m3fn).max
        self.model_arch = get_model_architectures(get_env_start_args().model_dir)
        self.layer_num = layer_num
        self.head_num = head_num
        self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
        self.scales_shape = [layer_num, 2 * head_num]
        self.scales = None
        self.scales_list = []
        self.abs_max = None
        self.warmup_counts = get_kv_quant_calibration_warmup_count()
        self.inference_counts = get_kv_quant_calibration_inference_count()
        self.count = 0
        self.enable_calib = False
        self.is_export_mode = is_export_mode
        if is_export_mode:
            self.abs_max = torch.zeros(self.scales_shape, dtype=torch.float32, device="cuda")
        elif get_env_start_args().kv_quant_calibration_config_path is not None:
            logger.info(
                f"kv_quant_calibration_config_path {get_env_start_args().kv_quant_calibration_config_path} is set, "
                "will load kv quant calibration config"
            )
            if os.path.exists(get_env_start_args().kv_quant_calibration_config_path):
                with open(get_env_start_args().kv_quant_calibration_config_path, "r") as f:
                    cfg = json.load(f)

                if cfg["architectures"] != self.model_arch:
                    raise ValueError(
                        f"architectures {cfg['architectures']} in config "
                        f"not match current model_arch {self.model_arch}"
                    )
                if cfg["num_layers"] != layer_num:
                    raise ValueError(
                        f"num_layers {cfg['num_layers']} in config " f"not match current layer_num {layer_num}"
                    )
                if cfg["num_head"] != self.total_head_num:
                    raise ValueError(
                        f"num_head {cfg['num_head']} in config "
                        f"not match current model head num {self.total_head_num}"
                    )
                if get_env_start_args().enable_fa3:
                    if cfg["quant_type"] != "per_head":
                        raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
                else:
                    raise ValueError("only support per_head quant type for fa3 backend, use --enable_fa3 in start args")

                self.qmin = cfg["qmin"]
                self.qmax = cfg["qmax"]
                self.scales_shape = cfg["scales_shape"]

                full_scales_list = cfg["scales"]
                self.scales_list = full_scales_list
                self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(self.scales_shape)
                if dist.is_initialized() and dist.get_world_size() > 1:
                    half_head = self.total_head_num // 2
                    start_head = dist.get_rank() * head_num
                    end_head = start_head + head_num
                    k_scales = self.scales[:, start_head:end_head].contiguous()
                    v_scales = self.scales[:, start_head + half_head : end_head + half_head].contiguous()
                    current_scales = torch.cat((k_scales, v_scales), dim=-1)

                    self.scales_list = current_scales.tolist()
                    self.scales = current_scales
            else:
                raise FileNotFoundError(
                    f"kv_quant_calibration_config {get_env_start_args().kv_quant_calibration_config_path} not found"
                )
        else:
            logger.warning("scales is None, no kv_quant_calibration_config_path be set")

    def enable_calibration(self):
        assert (
            get_env_start_args().enable_fa3
        ), "Calibration is only supported in fa3 backend, use --enable_fa3 in start args"
        assert self.is_export_mode, "Calibration is only supported in export mode"
        assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode"
        logger.info("Enable kv cache calibration, will collect kv cache data for quantization calibration")
        self.enable_calib = True

    def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
        if not self.enable_calib or self.count >= self.warmup_counts + self.inference_counts:
            return

        if self.abs_max is not None and self.count >= self.warmup_counts:
            kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
            self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max)
            if self.count == self.warmup_counts + self.inference_counts - 1 and layer_index == self.layer_num - 1:
                final_abs_max = self.abs_max
                if dist.is_initialized() and dist.get_world_size() > 1:
                    k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
                    k_max = k_max.contiguous()
                    v_max = v_max.contiguous()
                    gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
                    gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
                    dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
                    dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
                    k_max = torch.cat(gathered_k_max, dim=-1)
                    v_max = torch.cat(gathered_v_max, dim=-1)
                    final_abs_max = torch.cat((k_max, v_max), dim=-1)

                self.scales = final_abs_max / self.qmax
                self.scales = torch.where(self.scales > 0, self.scales, torch.ones_like(self.scales))

                if get_global_rank() == 0:
                    self.abs_max = final_abs_max
                    self._export_calibration_data()

        if layer_index == self.layer_num - 1:
            self.count += 1

    def _export_calibration_data(self):
        cfg = {
            "version": "1.0",
            "architectures": self.model_arch,
            "quant_type": "per_head",
            "qmin": self.qmin,
            "qmax": self.qmax,
            "num_layers": self.layer_num,
            "num_head": self.total_head_num,
            "scales_shape": list(self.abs_max.shape),
            "scales": self.scales.cpu().numpy().tolist(),
        }
        with open("./kv_cache_calib.json", "w") as f:
            json.dump(cfg, f, indent=4)
        logger.info(
            f"Export kv cache calibration data to kv_cache_calib.json, "
            f"architectures: {self.model_arch}, "
            f"qmin: {self.qmin}, qmax: {self.qmax}, "
            f"total heads: {self.total_head_num}, "
            f"scales_shape: {list(self.abs_max.shape)}, "
        )
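
The exported scales are per layer and per head: abs_max has shape [layer_num, 2 * head_num], with K-head maxima in the first half of the last dimension and V-head maxima in the second, and each scale is abs_max / qmax. A minimal sketch of how such scales could be applied to a KV buffer, assuming a [tokens, 2 * head_num, head_dim] layout; the helper names are illustrative and not part of this commit:

import torch

def quantize_kv_per_head(kv_buffer: torch.Tensor, layer_scales: torch.Tensor) -> torch.Tensor:
    # kv_buffer: [tokens, 2 * head_num, head_dim] in bf16/fp16
    # layer_scales: [2 * head_num] scales for one layer (K heads first, then V heads)
    qmin = torch.finfo(torch.float8_e4m3fn).min
    qmax = torch.finfo(torch.float8_e4m3fn).max
    scaled = kv_buffer.to(torch.float32) / layer_scales.view(1, -1, 1)
    fp8 = scaled.clamp(qmin, qmax).to(torch.float8_e4m3fn)
    # Bit-reinterpret as uint8 to match the manager's non-export kv_buffer dtype.
    return fp8.view(torch.uint8)

def dequantize_kv_per_head(kv_uint8: torch.Tensor, layer_scales: torch.Tensor) -> torch.Tensor:
    fp8 = kv_uint8.view(torch.float8_e4m3fn)
    return fp8.to(torch.float32) * layer_scales.view(1, -1, 1)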

lightllm/models/llama/flashattention_infer_struct.py (+1, −1)

@@ -61,7 +61,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.page_table[:, max_seq_len_k:].fill_(0)

         if "calibration_fp8kv" in model.mode:
-            offline_scales = self.mem_manager.offline_fp8_quant_manager.scales
+            offline_scales = self.mem_manager.scales
             head_num = self.mem_manager.head_num
             self.k_descale = (
                 offline_scales[:, :head_num]
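
Given the manager's scale layout (per rank, K-head scales occupy the first head_num columns and V-head scales the next head_num), the descale split presumably looks like the following; the v_descale line is an assumption, since the diff is truncated at this point:

# Assumed layout: offline_scales is [layer_num, 2 * head_num] per rank,
# K-head descales first, V-head descales second.
k_descale = offline_scales[:, :head_num]   # [layer_num, head_num], shown in the diff above
v_descale = offline_scales[:, head_num:]   # assumption: the matching V-head slice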
