diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 601b2a48a..3e37142b1 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -205,6 +205,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache")
     parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size")
+    parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true", help="enable hierarchical prompt cache")
     parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
     parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
     parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
@@ -311,7 +312,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch"
     )
     parser.add_argument(
-        "--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2"
+        "--visual_gpu_ids", nargs="+", type=int, default=[0, 1, 2, 3, 4, 5, 6, 7], help="List of GPU IDs to use, e.g., 0 1 2"
     )
     parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT")
     parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT")
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index de1e690a2..3a4093482 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -158,6 +158,10 @@ def normal_or_p_d_start(args):
             args.batch_max_tokens >= args.chunked_prefill_size
         ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size"
 
+    # use_hi_dynamic_prompt_cache requires the dynamic prompt cache to stay enabled
+    if args.use_hi_dynamic_prompt_cache:
+        assert not args.disable_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache cannot be combined with --disable_dynamic_prompt_cache"
+
     # help to manage data stored on Ceph
     if "s3://" in args.model_dir:
         from lightllm.utils.petrel_helper import s3_model_prepare
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 8a43d983d..ac76d27ed 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -45,6 +45,7 @@ class StartArgs:
     router_max_wait_tokens: int = field(default=6)
     disable_aggressive_schedule: bool = field(default=False)
     disable_dynamic_prompt_cache: bool = field(default=False)
+    use_hi_dynamic_prompt_cache: bool = field(default=False)
     chunked_prefill_size: int = field(default=8192)
     disable_chunked_prefill: bool = field(default=False)
     diverse_mode: bool = field(default=False)
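Note on the flag wiring above: `--use_hi_dynamic_prompt_cache` layers a disk-backed cache on top of the existing in-GPU radix cache, so startup rejects it when `--disable_dynamic_prompt_cache` is also set. A minimal, self-contained sketch of that dependency check, using a throwaway parser rather than LightLLM's real `make_argument_parser` (flag names mirror the diff; the snippet itself is illustrative only):

```python
import argparse

# Illustrative parser reproducing only the two related flags from api_cli.py.
parser = argparse.ArgumentParser()
parser.add_argument("--disable_dynamic_prompt_cache", action="store_true")
parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true")
args = parser.parse_args(["--use_hi_dynamic_prompt_cache"])

# Mirrors the assertion added in api_start.py: the hierarchical cache needs the
# dynamic prompt cache to stay enabled.
assert not (args.use_hi_dynamic_prompt_cache and args.disable_dynamic_prompt_cache), (
    "use_hi_dynamic_prompt_cache cannot be combined with --disable_dynamic_prompt_cache"
)
```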
diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py
new file mode 100644
index 000000000..31a306f67
--- /dev/null
+++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py
@@ -0,0 +1,128 @@
+import torch
+import time
+import tempfile
+import numpy as np
+import torch.distributed as dist
+from os.path import join
+from .radix_cache import RadixCache, TreeNode, match
+from typing import Tuple, Dict, Set, List
+from lightllm.common.mem_manager import MemoryManager
+from lightllm.utils.log_utils import init_logger
+from threading import Lock
+from enum import Enum
+from .shared_arr import SharedArray
+from kvcache.python.jit import PyLocalCacheService
+
+logger = init_logger(__name__)
+
+def wait_until_ready(task, timeout=10.0, check_interval=0.01):
+    start_time = time.time()
+    while not task.ready():
+        time.sleep(check_interval)
+        if time.time() - start_time > timeout:
+            logger.error("Current kv cache task not ready in time")
+            return False
+    return True
+
+class LocalCacheManager:
+
+    def __init__(self, unique_name: str, rank_in_node: int, mem_manager):
+        tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}")
+        self.cache_file = join(tmp_dir, "cache_file")
+        all_buffers = mem_manager.kv_buffer
+        all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
+
+        self.py_cache_service = PyLocalCacheService(
+            file=self.cache_file,
+            storage_size=128 * (1024 ** 3),  # 128GB
+            num_shard=32,
+            kvcache_tensor=all_buffers,
+            num_worker=8,
+        )
+
+    def insert(self, tokens, kv_page_indexer, start_pos=0):
+        t = self.py_cache_service.create(
+            tokens=tokens,
+            kv_page_indexer=kv_page_indexer,
+            mode="w",
+            start_pos=start_pos)
+        res = wait_until_ready(t)
+        if not res:
+            self.py_cache_service.az5(t)
+
+    def read(self, tokens, kv_page_indexer, start_pos=0):
+        t = self.py_cache_service.create(
+            tokens=tokens,
+            kv_page_indexer=kv_page_indexer,
+            mode="r",
+            start_pos=start_pos)
+        res = wait_until_ready(t)
+        return res
+
+    def query(self, tokens):
+        query_result = self.py_cache_service.query(tokens)
+        max_len = 0
+        for result in query_result:
+            if result:
+                max_len += 1
+            else:
+                break
+        return max_len * self.block_size
+
+    @property
+    def block_size(self):
+        return self.py_cache_service.tokens_per_block
+
+class HiRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager):
+        super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
+        self.rank_in_node = rank_in_node
+        self.local_cache_manager = LocalCacheManager(
+            unique_name,
+            rank_in_node,
+            mem_manager,
+        )
+        self.is_hi_radix_cache = True
+        self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64)
+        self.disk_cache_match_count.arr[0] = 0
+        self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64)
+        self.total_match_count.arr[0] = 0
+        self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32)
+        self.disk_cache_match_ratio.arr[0] = 0.0
+        logger.info(f"Initializing HiRadixCache {rank_in_node}")
+
+    def insert(self, key, value=None):
+        share_len = super().insert(key, value)
+        if share_len == 0:
+            return 0
+        self.local_cache_manager.insert(key, value)
+        return share_len
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        self.total_match_count.arr[0] += 1
+        ans_value_list = []
+        ans_value = None
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+        if tree_node.node_prefix_total_len != 0:
+            ans_value = torch.concat(ans_value_list)
+        max_len = 0
+        if tree_node.node_prefix_total_len < len(key):
+            max_len = self.local_cache_manager.query(key)
+        if max_len > tree_node.node_prefix_total_len:
+            pull_len = max_len - tree_node.node_prefix_total_len
+            self.disk_cache_match_count.arr[0] += 1
+            self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0]
+            self.free_radix_cache_to_get_enough_token(pull_len)
+            buffers = self.mem_manager.alloc(pull_len)
+            start_pos = 0
+            if ans_value is not None:
+                buffers = torch.concat([ans_value, buffers])
+                start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size
+            logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk")
+            res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos)
+            if res:
+                super().insert(key[:max_len], buffers)
+            else:
+                self.mem_manager.free(buffers[tree_node.node_prefix_total_len:])
+        return super().match_prefix(key, update_refs=update_refs)
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index aeffd3a67..45dd50099 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -116,6 +116,8 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo
         )
         self.tree_total_tokens_num.arr[0] = 0
 
+        self.is_hi_radix_cache = False
+
     def insert(self, key, value=None):
         if value is None:
             value = key
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 45e82ff3d..d9387a1d9 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -167,6 +167,7 @@ async def wait_to_model_ready(self):
             "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs,
             "use_reward_model": self.args.use_reward_model,
             "disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache,
+            "use_hi_dynamic_prompt_cache": self.args.use_hi_dynamic_prompt_cache,
             "data_type": self.args.data_type,
             "eos_id": self.eos_id,
             "diverse_mode": self.args.diverse_mode,
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 10b68245c..0774244b6 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -110,6 +110,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
             self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
             req.shared_kv_node = None
 
+
     def _save_promptcache_kvbuffer(self):
         """
         save prompt cache kv buffer
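The core of `HiRadixCache.match_prefix` above is length/offset bookkeeping: the disk cache is only queried when the GPU radix tree covers less than the full key, hits are reported in whole blocks, and the read starts at the block boundary containing the last GPU-matched token. A standalone illustration of that arithmetic (block size and lengths are invented for the example; in the real code `block_size` comes from `PyLocalCacheService.tokens_per_block`):

```python
# Invented numbers, for illustration only.
block_size = 64
gpu_matched_len = 150                              # prefix already held by the GPU radix tree
disk_matched_len = 4 * block_size                  # query() reports 4 whole blocks on disk -> 256 tokens

if disk_matched_len > gpu_matched_len:
    pull_len = disk_matched_len - gpu_matched_len  # 106 extra KV page slots to allocate
    # Round down to the block holding the last GPU-matched token; the already-matched
    # pages are prepended to the destination buffer, so the read can start mid-buffer.
    start_pos = (gpu_matched_len - 1) // block_size * block_size  # 128
    assert (pull_len, start_pos) == (106, 128)
```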
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 329dc9f3b..50b4aa547 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -12,6 +12,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.models import get_model
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
+from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache
 from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
@@ -53,6 +54,7 @@ def init_model(self, kvargs):
         self.chunked_prefill_size = self.args.chunked_prefill_size
         self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs
         self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
+        self.use_hi_dynamic_prompt_cache = self.args.use_hi_dynamic_prompt_cache
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = self.args.disable_cudagraph
 
@@ -118,7 +120,14 @@ def init_model(self, kvargs):
         self.model: TpPartBaseModel = self.model  # for easy typing
         set_random_seed(2147483647)
         self.radix_cache = (
-            RadixCache(
+            HiRadixCache(
+                get_unique_server_name(),
+                self.model.mem_manager.size,
+                self.rank_in_node,
+                mem_manager=self.model.mem_manager
+            )
+            if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache
+            else RadixCache(
                 get_unique_server_name(),
                 self.model.mem_manager.size,
                 self.rank_in_node,
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 00528fec7..b83b68e97 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -59,6 +59,7 @@ def decode(self):
                 prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
             )
             model_output = self.model.forward(model_input)
+            self.store_hicache_after_prefill(run_reqs)
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py
index b0eb2b58f..38f785501 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py
@@ -35,6 +35,7 @@ def decode(self):
             )
             model_output = self.model.forward(model_input)
             logits = model_output.logits
+            self.store_hicache_after_prefill(run_reqs)
 
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
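The selection in `base_backend.init_model` keys off both settings: `HiRadixCache` is only constructed when the dynamic prompt cache is enabled *and* the new flag is set; otherwise the plain `RadixCache` is kept. The same decision written as a small factory (illustration only; the helper function is not part of this PR, and the `None` branch for a disabled cache is an assumption, but both constructors are called exactly as in the diff):

```python
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache


def build_radix_cache(use_dynamic, use_hi_dynamic, unique_name, total_token_num, rank_in_node, mem_manager):
    # Hypothetical helper mirroring the conditional added to base_backend.init_model.
    if not use_dynamic:
        return None  # assumption: no radix cache at all when the dynamic prompt cache is disabled
    cache_cls = HiRadixCache if use_hi_dynamic else RadixCache
    return cache_cls(unique_name, total_token_num, rank_in_node, mem_manager=mem_manager)
```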
diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py
new file mode 100644
index 000000000..bb82457c4
--- /dev/null
+++ b/test/server/test_hicache.py
@@ -0,0 +1,155 @@
+# test_hicache.py
+import torch
+import time
+import random
+from threading import Thread, Event
+from queue import Queue
+from lightllm.server.router.dynamic_prompt.cache_controller import (
+    HiCacheController,
+    CacheNode,
+    BLOCK_SIZE,
+    HiHostService,
+    HiHostTask,
+)
+
+
+class MockMemoryManager:
+    """Mock memory manager that simply hands out consecutive index values."""
+
+    def __init__(self):
+        self.current_idx = 0
+        self.kvcache_store = {}
+
+    def alloc(self, size):
+        indices = list(range(self.current_idx, self.current_idx + size))
+        self.current_idx += size
+        self.store(indices, torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)]))
+        return indices
+
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kvcache_store[index] = load_tensor_dict["kv_buffer"]
+
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kvcache_store[index]}
+
+    def to_kvcache(self, indices):
+        assert all(
+            [idx in self.kvcache_store for idx in indices]
+        ), f"Not all of {indices} are found in kvcache_store"
+        return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices])
+
+    def store(self, indices, value):
+        print(f"[TEST:MemManager] Storing {value.shape} at {indices}")
+        for idx, value_dim in zip(indices, range(value.shape[0])):
+            self.kvcache_store[idx] = value[value_dim]
+            print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}")
+        return indices
+
+    def free(self, indices):
+        print(f"[TEST:MemManager] Freeing {indices}")
+        for idx in indices:
+            del self.kvcache_store[idx]
+
+
+def setup():
+    mem_manager = MockMemoryManager()
+    service = HiHostService()
+    hicache = HiCacheController(mem_manager)
+    hicache.service = service  # inject the mock service
+
+    indices = mem_manager.alloc(5)
+    print(mem_manager.to_kvcache(indices))
+
+    # pre-compute the size of a single token's KV cache
+    dummy_indices = mem_manager.alloc(1)
+    kvcache = mem_manager.to_kvcache(dummy_indices[:1])
+    token_size = kvcache.nelement() * kvcache.element_size()
+    print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}")
+
+    return mem_manager, service, hicache, token_size
+
+
+def test_basic_write_read(mem_manager, hicache, token_size):
+    # compute how many tokens fit in one block
+    tokens_per_block = BLOCK_SIZE // token_size
+    print(f"[TEST] Each block can hold {tokens_per_block} tokens")
+
+    # generate test data that exactly fills one block
+    token_ids = list(range(tokens_per_block))
+    indices = mem_manager.alloc(len(token_ids))
+    kvcache = mem_manager.to_kvcache(indices)
+    print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}")
+
+    # write to the cache
+    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
+    time.sleep(2)
+
+    # wait for the write tasks to finish
+    hicache.service.wait_till_all_finished()
+
+    mem_manager.free(indices)
+
+    # read back and verify
+    result = hicache.read(torch.tensor(token_ids))
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
+    print("[TEST] Basic test passed. Retrieved kvcache\n\n")
+
+
+def test_node_splitting(mem_manager, hicache, token_size):
+    tokens_per_block = BLOCK_SIZE // token_size
+    # generate data spanning more than one block
+    token_ids = list(range(12, 12 + tokens_per_block * 3 + 1))
+    indices = mem_manager.alloc(len(token_ids))
+    kvcache = mem_manager.to_kvcache(indices)
+
+    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
+    time.sleep(2)
+    hicache.service.wait_till_all_finished()
+
+    # the root node should now have children
+    root = hicache.root
+    assert len(root.children) > 0
+    print(f"\nRoot node has {len(root.children)} children")
+
+    # read back the full sequence
+    result = hicache.read(torch.tensor(token_ids))
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
+    print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n")
+
+
+def test_partial_read(mem_manager, hicache):
+    token_ids = [97, 98, 99, 100, 101, 102]
+    indices = mem_manager.alloc(len(token_ids))
+    kvcache = mem_manager.to_kvcache(indices)
+    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
+    time.sleep(2)
+    hicache.service.wait_till_all_finished()
+
+    # query a prefix that exists
+    result = hicache.read(torch.tensor([97, 98, 99]))
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache[:3]).all()
+    print("[TEST] Partial read passed")
+
+    # query a prefix that does not exist (diverges at the third token)
+    result = hicache.read(torch.tensor([97, 98, 100]))
+    assert len(result) == 2
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache[:2]).all()
+    print(f"[TEST] Non-existent prefix returned: {result.tolist()}")
+
+
+def main():
+    mem_manager, service, hicache, token_size = setup()
+    try:
+        test_basic_write_read(mem_manager, hicache, token_size)
+        test_node_splitting(mem_manager, hicache, token_size)
+        test_partial_read(mem_manager, hicache)
+    finally:
+        service.shutdown()
+
+
+if __name__ == "__main__":
+    main()
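Taken together, the intended call pattern for the new cache mirrors the existing `RadixCache` API. A hedged end-to-end sketch (assumes a real LightLLM `MemoryManager` whose `kv_buffer` the disk cache service can map, and that `kvcache.python.jit.PyLocalCacheService` is importable; the token ids and sizes below are invented):

```python
import torch
from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache


def demo(mem_manager, unique_name="demo_cache", rank_in_node=0):
    # mem_manager must be LightLLM's MemoryManager; its .size and .kv_buffer are used
    # by HiRadixCache / LocalCacheManager exactly as in base_backend.init_model above.
    cache = HiRadixCache(unique_name, mem_manager.size, rank_in_node, mem_manager)

    prompt = torch.arange(0, 512, dtype=torch.int64)  # invented prompt token ids
    kv_pages = mem_manager.alloc(len(prompt))         # KV page indexes produced by prefill
    cache.insert(prompt, kv_pages)                    # GPU radix tree insert, mirrored to the local disk cache

    # A later request with the same prefix first matches the GPU tree; anything evicted
    # from GPU but still on disk is pulled back inside match_prefix before the normal
    # RadixCache.match_prefix result is returned.
    return cache.match_prefix(prompt, update_refs=True)
```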