From 6e0d4b4982200a8cf7874d75bdc58267c256abc2 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 15 Oct 2025 12:31:18 -0700 Subject: [PATCH 1/7] [router] add l0+l1 tokenization cache --- sgl-router/Cargo.toml | 2 + sgl-router/benches/tokenizer_benchmark.rs | 499 ++++++++++++++++- .../py_src/sglang_router/router_args.py | 29 + sgl-router/src/config/types.rs | 51 ++ sgl-router/src/config/validation.rs | 26 + sgl-router/src/lib.rs | 22 + sgl-router/src/main.rs | 20 +- sgl-router/src/server.rs | 27 +- sgl-router/src/tokenizer/cache/fingerprint.rs | 103 ++++ sgl-router/src/tokenizer/cache/l0.rs | 216 ++++++++ sgl-router/src/tokenizer/cache/l1.rs | 501 ++++++++++++++++++ sgl-router/src/tokenizer/cache/mod.rs | 358 +++++++++++++ sgl-router/src/tokenizer/huggingface.rs | 24 +- sgl-router/src/tokenizer/mock.rs | 8 +- sgl-router/src/tokenizer/mod.rs | 2 + sgl-router/tests/api_endpoints_test.rs | 4 + sgl-router/tests/responses_api_test.rs | 10 + sgl-router/tests/test_pd_routing.rs | 1 + 18 files changed, 1880 insertions(+), 23 deletions(-) create mode 100644 sgl-router/src/tokenizer/cache/fingerprint.rs create mode 100644 sgl-router/src/tokenizer/cache/l0.rs create mode 100644 sgl-router/src/tokenizer/cache/l1.rs create mode 100644 sgl-router/src/tokenizer/cache/mod.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 0557f06602f..9d3602dcaf4 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -38,6 +38,7 @@ futures-util = "0.3" futures = "0.3" pyo3 = { version = "0.25.1", features = ["extension-module"] } dashmap = "6.1.0" +blake3 = "1.5" http = "1.1.0" tokio = { version = "1.42.0", features = ["full"] } async-trait = "0.1" @@ -53,6 +54,7 @@ metrics-exporter-prometheus = "0.17.0" uuid = { version = "1.10", features = ["v4", "serde"] } ulid = "1.2.1" parking_lot = "0.12.4" +rayon = "1.10" thiserror = "2.0.12" regex = "1.10" url = "2.5.4" diff --git a/sgl-router/benches/tokenizer_benchmark.rs b/sgl-router/benches/tokenizer_benchmark.rs index 6830d679782..1175bfb6a00 100644 --- a/sgl-router/benches/tokenizer_benchmark.rs +++ b/sgl-router/benches/tokenizer_benchmark.rs @@ -14,19 +14,34 @@ use std::{ use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use sglang_router_rs::tokenizer::{ - huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*, + cache::{CacheConfig, CachedTokenizer}, + huggingface::HuggingFaceTokenizer, + sequence::Sequence, + stop::*, + stream::DecodeStream, + traits::*, }; -// Include the common test utilities -#[path = "../tests/common/mod.rs"] -mod common; -use common::ensure_tokenizer_cached; - // Cache the tokenizer path for the entire benchmark run static TOKENIZER_PATH: OnceLock = OnceLock::new(); fn get_tokenizer_path() -> &'static PathBuf { - TOKENIZER_PATH.get_or_init(ensure_tokenizer_cached) + TOKENIZER_PATH.get_or_init(|| { + // Use Qwen3-4B-Instruct which has ChatML special tokens (<|im_start|>, <|im_end|>) + // with special: true, normalized: false - perfect for demonstrating L1 cache + let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime"); + let tokenizer_dir = rt.block_on(async { + sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf( + "Qwen/Qwen3-4B-Instruct-2507", + ) + .await + .expect("Failed to download Qwen3-4B-Instruct tokenizer from HuggingFace") + }); + + // The download_tokenizer_from_hf returns the directory containing tokenizer.json + // We need to construct the full path to tokenizer.json + tokenizer_dir.join("tokenizer.json") 
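+        // (First run downloads from huggingface.co; OnceLock memoizes the resolved
+        // path, so later calls in the same benchmark process skip the network.)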
+ }) } // Production target: 100k tokens per second @@ -1253,6 +1268,468 @@ fn bench_scaling_characteristics(c: &mut Criterion) { group.finish(); } +fn bench_l1_cache_chat_template(c: &mut Criterion) { + let tokenizer_path = get_tokenizer_path(); + let tokenizer = Arc::new( + HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()) + .expect("Failed to load tokenizer"), + ); + + let mut group = c.benchmark_group("l1_cache_chat"); + + // ============================================================================ + // SCENARIO 1: High Prefix Reuse (95%+ - Realistic Chat Application) + // ============================================================================ + // Most realistic: Same system prompt across 95%+ of requests, different user queries + // This is typical for chat applications where the same system context is reused + + let system_prompt = generate_system_prompt(8000); + + // Generate 100 different user queries of varying lengths (realistic distribution) + let user_queries: Vec = (0..100) + .map(|i| { + let base_queries = [ + "What is the capital of France?", + "Explain quantum mechanics in simple terms.", + "How do I sort an array in Python?", + "What are the benefits of exercise?", + "Can you help me write a resume?", + ]; + let query = base_queries[i % base_queries.len()]; + // Add variation to make each unique + format!("{} (Query #{})", query, i) + }) + .collect(); + + // Create prompts with ChatML format (same system prefix, different queries) + let realistic_prompts: Vec = user_queries + .iter() + .map(|query| { + format!( + "<|im_start|>system\n{}<|im_end|><|im_start|>user\n{}<|im_end|><|im_start|>assistant\n", + system_prompt, query + ) + }) + .collect(); + + // Baseline: No cache + let printed_baseline = Arc::new(AtomicBool::new(false)); + group.bench_function("realistic_chat_uncached", |b| { + let printed = printed_baseline.clone(); + let tokenizer = tokenizer.clone(); + let test_prompts = realistic_prompts.clone(); + + b.iter_custom(|iters| { + let start = Instant::now(); + for _ in 0..iters { + // Simulate 100 requests with different queries (realistic workload) + for prompt in &test_prompts { + black_box(tokenizer.encode(prompt).unwrap()); + } + } + let duration = start.elapsed(); + + if !printed.load(Ordering::Relaxed) { + let total_ops = iters * test_prompts.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | {:>20}", + "Uncached (baseline)", + test_prompts[0].len(), + ops_per_sec, + avg_time_us, + "N/A" + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + // L0-only: Should have 0% hit rate (all queries are unique) + let l0_only_config = CacheConfig { + enable_l0: true, + l0_max_entries: 10_000, + enable_l1: false, + l1_max_memory: 0, + }; + let cached_l0_only = Arc::new(CachedTokenizer::new(tokenizer.clone(), l0_only_config)); + + let printed_l0 = Arc::new(AtomicBool::new(false)); + group.bench_function("realistic_chat_l0_only", |b| { + let printed = printed_l0.clone(); + let cached = cached_l0_only.clone(); + let test_prompts = realistic_prompts.clone(); + + b.iter_custom(|iters| { + cached.clear_cache(); // Start fresh each iteration + let start = Instant::now(); + + for _ in 0..iters { + for prompt in &test_prompts { + black_box(cached.encode(prompt).unwrap()); + } + } + let duration = start.elapsed(); + + if 
!printed.load(Ordering::Relaxed) { + let total_ops = iters * test_prompts.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + let stats = cached.cache_stats().unwrap(); + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6}", + "L0-only (no benefit)", + test_prompts[0].len(), + ops_per_sec, + avg_time_us, + stats.hit_rate * 100.0, + "N/A" + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + // L0+L1: Should show significant speedup from prefix caching + let l0_l1_config = CacheConfig { + enable_l0: true, + l0_max_entries: 10_000, + enable_l1: true, + l1_max_memory: 50 * 1024 * 1024, + }; + let cached_l0_l1 = Arc::new(CachedTokenizer::new( + tokenizer.clone(), + l0_l1_config.clone(), + )); + + let printed_l0_l1 = Arc::new(AtomicBool::new(false)); + group.bench_function("realistic_chat_l0_l1", |b| { + let printed = printed_l0_l1.clone(); + let cached = cached_l0_l1.clone(); + let test_prompts = realistic_prompts.clone(); + + b.iter_custom(|iters| { + cached.clear_cache(); // Start fresh + + // Prime with first request to populate L1 with system prefix + cached.encode(&test_prompts[0]).unwrap(); + + let start = Instant::now(); + for _ in 0..iters { + // All subsequent requests benefit from L1 prefix cache + for prompt in &test_prompts { + black_box(cached.encode(prompt).unwrap()); + } + } + let duration = start.elapsed(); + + if !printed.load(Ordering::Relaxed) { + let total_ops = iters * test_prompts.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + let stats = cached.cache_stats().unwrap(); + let l1_stats = cached.l1_cache_stats().unwrap(); + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%", + "L0+L1 (95%+ prefix reuse)", + test_prompts[0].len(), + ops_per_sec, + avg_time_us, + stats.hit_rate * 100.0, + l1_stats.hit_rate * 100.0 + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + // ============================================================================ + // SCENARIO 2: Customer Service Bot (100% prefix reuse) + // ============================================================================ + // Identical greeting/instructions with different customer queries + + let service_system = "You are a helpful customer service assistant for TechCorp. \ + Always be polite, professional, and helpful. Our business hours are 9 AM to 5 PM EST. \ + We offer a 30-day return policy on all products. For technical issues, escalate to technical support. 
\ + For billing issues, escalate to accounting department.".repeat(20); // ~2KB + + let customer_queries = [ + "I need to return my laptop", + "My order hasn't arrived yet", + "How do I reset my password?", + "What's your return policy?", + "I was charged twice for my order", + "Can I change my shipping address?", + "Is my product under warranty?", + "I need help installing the software", + ]; + + let service_prompts: Vec = customer_queries + .iter() + .map(|query| { + format!( + "<|im_start|>system\n{}<|im_end|><|im_start|>user\n{}<|im_end|><|im_start|>assistant\n", + service_system, query + ) + }) + .collect(); + + // Service bot with L1-only (to compare against L0+L1) + let l1_only_config = CacheConfig { + enable_l0: false, + l0_max_entries: 0, + enable_l1: true, + l1_max_memory: 50 * 1024 * 1024, + }; + let service_cached_l1 = Arc::new(CachedTokenizer::new(tokenizer.clone(), l1_only_config)); + + let printed_service_l1 = Arc::new(AtomicBool::new(false)); + group.bench_function("customer_service_l1_only", |b| { + let printed = printed_service_l1.clone(); + let cached = service_cached_l1.clone(); + let test_prompts = service_prompts.clone(); + + b.iter_custom(|iters| { + cached.clear_cache(); + cached.encode(&test_prompts[0]).unwrap(); // Prime cache + + let start = Instant::now(); + for _ in 0..iters { + for prompt in &test_prompts { + black_box(cached.encode(prompt).unwrap()); + } + } + let duration = start.elapsed(); + + if !printed.load(Ordering::Relaxed) { + let total_ops = iters * test_prompts.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + let l1_stats = cached.l1_cache_stats().unwrap(); + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6} L1:{:>6.1}%", + "Customer Service (L1-only)", + test_prompts[0].len(), + ops_per_sec, + avg_time_us, + "N/A", + l1_stats.hit_rate * 100.0 + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + // Service bot with L0+L1 + let service_cached = Arc::new(CachedTokenizer::new( + tokenizer.clone(), + l0_l1_config.clone(), + )); + + let printed_service = Arc::new(AtomicBool::new(false)); + group.bench_function("customer_service_l0_l1", |b| { + let printed = printed_service.clone(); + let cached = service_cached.clone(); + let test_prompts = service_prompts.clone(); + + b.iter_custom(|iters| { + cached.clear_cache(); + cached.encode(&test_prompts[0]).unwrap(); // Prime cache + + let start = Instant::now(); + for _ in 0..iters { + for prompt in &test_prompts { + black_box(cached.encode(prompt).unwrap()); + } + } + let duration = start.elapsed(); + + if !printed.load(Ordering::Relaxed) { + let total_ops = iters * test_prompts.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + let stats = cached.cache_stats().unwrap(); + let l1_stats = cached.l1_cache_stats().unwrap(); + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%", + "Customer Service (100% reuse)", + test_prompts[0].len(), + ops_per_sec, + avg_time_us, + stats.hit_rate * 100.0, + l1_stats.hit_rate * 100.0 + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + // ============================================================================ + // SCENARIO 3: Multi-Turn Conversation (Progressive context building) + // 
============================================================================ + // Each turn builds on previous context (common in chat applications) + + let conversation_system = + "You are a helpful coding assistant. Help users write better code.".repeat(10); + + // Simulate a 5-turn conversation where context grows + let conversation_turns = vec![ + format!("<|im_start|>system\n{}<|im_end|><|im_start|>user\nHow do I sort an array in Python?<|im_end|><|im_start|>assistant\n", conversation_system), + format!("<|im_start|>system\n{}<|im_end|><|im_start|>user\nHow do I sort an array in Python?<|im_end|><|im_start|>assistant\nYou can use the sorted() function or list.sort() method.<|im_end|><|im_start|>user\nWhat's the difference between them?<|im_end|><|im_start|>assistant\n", conversation_system), + format!("<|im_start|>system\n{}<|im_end|><|im_start|>user\nHow do I sort an array in Python?<|im_end|><|im_start|>assistant\nYou can use the sorted() function or list.sort() method.<|im_end|><|im_start|>user\nWhat's the difference between them?<|im_end|><|im_start|>assistant\nsorted() creates a new list, sort() modifies in place.<|im_end|><|im_start|>user\nCan I sort by a custom key?<|im_end|><|im_start|>assistant\n", conversation_system), + ]; + + let conv_cached = Arc::new(CachedTokenizer::new( + tokenizer.clone(), + l0_l1_config.clone(), + )); + + let printed_conv = Arc::new(AtomicBool::new(false)); + group.bench_function("multi_turn_conversation", |b| { + let printed = printed_conv.clone(); + let cached = conv_cached.clone(); + let test_turns = conversation_turns.clone(); + + b.iter_custom(|iters| { + cached.clear_cache(); + + let start = Instant::now(); + for _ in 0..iters { + // Simulate progressive conversation (each turn shares prefix with previous) + for turn in &test_turns { + black_box(cached.encode(turn).unwrap()); + } + } + let duration = start.elapsed(); + + if !printed.load(Ordering::Relaxed) { + let total_ops = iters * test_turns.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + let stats = cached.cache_stats().unwrap(); + let l1_stats = cached.l1_cache_stats().unwrap(); + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%", + "Multi-turn Conversation", + test_turns[0].len(), + ops_per_sec, + avg_time_us, + stats.hit_rate * 100.0, + l1_stats.hit_rate * 100.0 + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + // ============================================================================ + // SCENARIO 4: Code Review Assistant (Same guidelines, different code snippets) + // ============================================================================ + + let review_system = "You are a code review assistant. 
Check for: \ + 1) Code quality and readability \ + 2) Performance issues \ + 3) Security vulnerabilities \ + 4) Best practices \ + 5) Documentation completeness" + .repeat(15); + + let code_snippets = [ + "function add(a, b) { return a + b; }", + "def factorial(n): return 1 if n <= 1 else n * factorial(n-1)", + "SELECT * FROM users WHERE id = $_GET['id']", // Security issue + "for (var i = 0; i < 10; i++) { setTimeout(() => console.log(i), 100); }", // Closure issue + ]; + + let review_prompts: Vec = code_snippets + .iter() + .map(|code| { + format!( + "<|im_start|>system\n{}<|im_end|><|im_start|>user\nReview this code:\n```\n{}\n```<|im_end|><|im_start|>assistant\n", + review_system, code + ) + }) + .collect(); + + let review_cached = Arc::new(CachedTokenizer::new(tokenizer.clone(), l0_l1_config)); + + let printed_review = Arc::new(AtomicBool::new(false)); + group.bench_function("code_review_assistant", |b| { + let printed = printed_review.clone(); + let cached = review_cached.clone(); + let test_prompts = review_prompts.clone(); + + b.iter_custom(|iters| { + cached.clear_cache(); + cached.encode(&test_prompts[0]).unwrap(); // Prime cache + + let start = Instant::now(); + for _ in 0..iters { + for prompt in &test_prompts { + black_box(cached.encode(prompt).unwrap()); + } + } + let duration = start.elapsed(); + + if !printed.load(Ordering::Relaxed) { + let total_ops = iters * test_prompts.len() as u64; + let ops_per_sec = total_ops as f64 / duration.as_secs_f64(); + let avg_time_us = duration.as_micros() as f64 / total_ops as f64; + let stats = cached.cache_stats().unwrap(); + let l1_stats = cached.l1_cache_stats().unwrap(); + + let result = format!( + "{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%", + "Code Review (high reuse)", + test_prompts[0].len(), + ops_per_sec, + avg_time_us, + stats.hit_rate * 100.0, + l1_stats.hit_rate * 100.0 + ); + add_result("l1_cache", result); + + printed.store(true, Ordering::Relaxed); + } + + duration + }); + }); + + group.finish(); +} + // Print final summary table fn print_summary() { println!("\n{}", "=".repeat(120)); @@ -1372,6 +1849,13 @@ fn print_summary() { "Operation", "Calls/sec", "Time/call", "Improvement" ); } + "l1_cache" => { + println!("L1 CACHE (PREFIX MATCHING) - REALISTIC WORKLOADS"); + println!( + "{:<30} | {:>8} | {:>12} | {:>12} | {:>20}", + "Scenario", "Size(B)", "Ops/sec", "Time(µs)", "Hit Rates" + ); + } _ => {} } println!("{}", "-".repeat(120)); @@ -1396,6 +1880,7 @@ fn run_benchmarks(c: &mut Criterion) { bench_latency_distribution(c); bench_scaling_characteristics(c); bench_memory_efficiency(c); + bench_l1_cache_chat_template(c); // Print summary at the end print_summary(); diff --git a/sgl-router/py_src/sglang_router/router_args.py b/sgl-router/py_src/sglang_router/router_args.py index 587e5a02334..96977180257 100644 --- a/sgl-router/py_src/sglang_router/router_args.py +++ b/sgl-router/py_src/sglang_router/router_args.py @@ -87,6 +87,11 @@ class RouterArgs: model_path: Optional[str] = None tokenizer_path: Optional[str] = None chat_template: Optional[str] = None + # Tokenizer cache configuration + tokenizer_cache_enable_l0: bool = False + tokenizer_cache_l0_max_entries: int = 10000 + tokenizer_cache_enable_l1: bool = False + tokenizer_cache_l1_max_memory: int = 50 * 1024 * 1024 # 50MB reasoning_parser: Optional[str] = None tool_call_parser: Optional[str] = None # Backend selection @@ -467,6 +472,30 @@ def add_cli_args( default=None, help="Chat template path (optional)", ) + parser.add_argument( + 
f"--{prefix}tokenizer-cache-enable-l0", + action="store_true", + default=RouterArgs.tokenizer_cache_enable_l0, + help="Enable L0 (whole-string exact match) tokenizer cache (default: False)", + ) + parser.add_argument( + f"--{prefix}tokenizer-cache-l0-max-entries", + type=int, + default=RouterArgs.tokenizer_cache_l0_max_entries, + help="Maximum number of entries in L0 tokenizer cache (default: 10000)", + ) + parser.add_argument( + f"--{prefix}tokenizer-cache-enable-l1", + action="store_true", + default=RouterArgs.tokenizer_cache_enable_l1, + help="Enable L1 (prefix matching) tokenizer cache (default: False)", + ) + parser.add_argument( + f"--{prefix}tokenizer-cache-l1-max-memory", + type=int, + default=RouterArgs.tokenizer_cache_l1_max_memory, + help="Maximum memory for L1 tokenizer cache in bytes (default: 50MB)", + ) parser.add_argument( f"--{prefix}reasoning-parser", type=str, diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index dc704657b9a..3f318634b36 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -81,6 +81,53 @@ pub struct RouterConfig { pub reasoning_parser: Option, /// Parser for handling tool-call interactions pub tool_call_parser: Option, + /// Tokenizer cache configuration + #[serde(default)] + pub tokenizer_cache: TokenizerCacheConfig, +} + +/// Tokenizer cache configuration +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct TokenizerCacheConfig { + /// Enable L0 cache (whole-string exact match) + #[serde(default = "default_enable_l0")] + pub enable_l0: bool, + /// Maximum number of entries in L0 cache + #[serde(default = "default_l0_max_entries")] + pub l0_max_entries: usize, + /// Enable L1 cache (prefix matching at fixed boundaries) + #[serde(default = "default_enable_l1")] + pub enable_l1: bool, + /// Maximum memory for L1 cache in bytes + #[serde(default = "default_l1_max_memory")] + pub l1_max_memory: usize, +} + +fn default_enable_l0() -> bool { + false +} + +fn default_l0_max_entries() -> usize { + 10_000 +} + +fn default_enable_l1() -> bool { + false +} + +fn default_l1_max_memory() -> usize { + 50 * 1024 * 1024 // 50MB +} + +impl Default for TokenizerCacheConfig { + fn default() -> Self { + Self { + enable_l0: default_enable_l0(), + l0_max_entries: default_l0_max_entries(), + enable_l1: default_enable_l1(), + l1_max_memory: default_l1_max_memory(), + } + } } fn default_history_backend() -> HistoryBackend { @@ -459,6 +506,7 @@ impl Default for RouterConfig { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: TokenizerCacheConfig::default(), } } } @@ -1004,6 +1052,7 @@ mod tests { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: TokenizerCacheConfig::default(), }; assert!(config.mode.is_pd_mode()); @@ -1072,6 +1121,7 @@ mod tests { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: TokenizerCacheConfig::default(), }; assert!(!config.mode.is_pd_mode()); @@ -1136,6 +1186,7 @@ mod tests { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: TokenizerCacheConfig::default(), }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 1d8457d6194..7f6bc14a7a0 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -42,6 +42,9 @@ impl ConfigValidator { } } + // Validate tokenizer cache configuration + Self::validate_tokenizer_cache(&config.tokenizer_cache)?; + 
Ok(()) } @@ -446,6 +449,29 @@ impl ConfigValidator { Ok(()) } + /// Validate tokenizer cache configuration + fn validate_tokenizer_cache(cache: &TokenizerCacheConfig) -> ConfigResult<()> { + // Validate L0 max entries when L0 is enabled + if cache.enable_l0 && cache.l0_max_entries == 0 { + return Err(ConfigError::InvalidValue { + field: "tokenizer_cache.l0_max_entries".to_string(), + value: cache.l0_max_entries.to_string(), + reason: "Must be > 0 when L0 cache is enabled".to_string(), + }); + } + + // Validate L1 max memory when L1 is enabled + if cache.enable_l1 && cache.l1_max_memory == 0 { + return Err(ConfigError::InvalidValue { + field: "tokenizer_cache.l1_max_memory".to_string(), + value: cache.l1_max_memory.to_string(), + reason: "Must be > 0 when L1 cache is enabled".to_string(), + }); + } + + Ok(()) + } + /// Validate compatibility between different configuration sections fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { // IGW mode is independent - skip other compatibility checks when enabled diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 6a9a9da1d2b..e08d7c5af0c 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -198,6 +198,10 @@ struct Router { model_path: Option, tokenizer_path: Option, chat_template: Option, + tokenizer_cache_enable_l0: bool, + tokenizer_cache_l0_max_entries: usize, + tokenizer_cache_enable_l1: bool, + tokenizer_cache_l1_max_memory: usize, reasoning_parser: Option, tool_call_parser: Option, backend: BackendType, @@ -350,6 +354,12 @@ impl Router { oracle, reasoning_parser: self.reasoning_parser.clone(), tool_call_parser: self.tool_call_parser.clone(), + tokenizer_cache: config::TokenizerCacheConfig { + enable_l0: self.tokenizer_cache_enable_l0, + l0_max_entries: self.tokenizer_cache_l0_max_entries, + enable_l1: self.tokenizer_cache_enable_l1, + l1_max_memory: self.tokenizer_cache_l1_max_memory, + }, }) } } @@ -415,6 +425,10 @@ impl Router { model_path = None, tokenizer_path = None, chat_template = None, + tokenizer_cache_enable_l0 = false, + tokenizer_cache_l0_max_entries = 10000, + tokenizer_cache_enable_l1 = false, + tokenizer_cache_l1_max_memory = 52428800, reasoning_parser = None, tool_call_parser = None, backend = BackendType::Sglang, @@ -480,6 +494,10 @@ impl Router { model_path: Option, tokenizer_path: Option, chat_template: Option, + tokenizer_cache_enable_l0: bool, + tokenizer_cache_l0_max_entries: usize, + tokenizer_cache_enable_l1: bool, + tokenizer_cache_l1_max_memory: usize, reasoning_parser: Option, tool_call_parser: Option, backend: BackendType, @@ -559,6 +577,10 @@ impl Router { model_path, tokenizer_path, chat_template, + tokenizer_cache_enable_l0, + tokenizer_cache_l0_max_entries, + tokenizer_cache_enable_l1, + tokenizer_cache_l1_max_memory, reasoning_parser, tool_call_parser, backend, diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index f32c8ddf3d1..4722c6a5be2 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -5,7 +5,7 @@ use sglang_router_rs::{ config::{ CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, - RouterConfig, RoutingMode, + RouterConfig, RoutingMode, TokenizerCacheConfig, }, metrics::PrometheusConfig, server::{self, ServerConfig}, @@ -270,6 +270,18 @@ struct CliArgs { #[arg(long)] chat_template: Option, + #[arg(long, default_value_t = false)] + tokenizer_cache_enable_l0: bool, + + #[arg(long, default_value_t = 10000)] + 
tokenizer_cache_l0_max_entries: usize, + + #[arg(long, default_value_t = false)] + tokenizer_cache_enable_l1: bool, + + #[arg(long, default_value_t = 52428800)] + tokenizer_cache_l1_max_memory: usize, + #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])] history_backend: String, @@ -581,6 +593,12 @@ impl CliArgs { oracle, reasoning_parser: self.reasoning_parser.clone(), tool_call_parser: self.tool_call_parser.clone(), + tokenizer_cache: TokenizerCacheConfig { + enable_l0: self.tokenizer_cache_enable_l0, + l0_max_entries: self.tokenizer_cache_l0_max_entries, + enable_l1: self.tokenizer_cache_enable_l1, + l1_max_memory: self.tokenizer_cache_l1_max_memory, + }, }) } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index b699ff08d55..762c3fc283b 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -48,7 +48,11 @@ use crate::{ reasoning_parser::ParserFactory as ReasoningParserFactory, routers::{router_manager::RouterManager, RouterTrait}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, - tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, + tokenizer::{ + cache::{CacheConfig, CachedTokenizer}, + factory as tokenizer_factory, + traits::Tokenizer, + }, tool_parser::ParserFactory as ToolParserFactory, }; @@ -864,7 +868,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Result<(), Box) + } else { + // Use base tokenizer directly without caching + Some(base_tokenizer) + }; let reasoning_parser_factory = Some(ReasoningParserFactory::new()); let tool_parser_factory = Some(ToolParserFactory::new()); diff --git a/sgl-router/src/tokenizer/cache/fingerprint.rs b/sgl-router/src/tokenizer/cache/fingerprint.rs new file mode 100644 index 00000000000..0adafd85012 --- /dev/null +++ b/sgl-router/src/tokenizer/cache/fingerprint.rs @@ -0,0 +1,103 @@ +//! Tokenizer Fingerprinting for Cache Invalidation +//! +//! Creates a unique fingerprint of a tokenizer's configuration to detect +//! when the tokenizer has changed and the cache needs to be cleared. 
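+//!
+//! A minimal sketch of the intended invalidation check (hypothetical caller;
+//! `CachedTokenizer` stores the fingerprint taken at construction):
+//!
+//! ```ignore
+//! let before = TokenizerFingerprint::from_tokenizer(tokenizer.as_ref());
+//! // ... tokenizer files reloaded or swapped here ...
+//! let after = TokenizerFingerprint::from_tokenizer(tokenizer.as_ref());
+//! if before != after {
+//!     cached.clear_cache(); // stale entries would map text to the wrong token IDs
+//! }
+//! ```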
+ +use super::super::traits::Tokenizer; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +/// A fingerprint of a tokenizer's configuration +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TokenizerFingerprint { + /// Size of the vocabulary + pub vocab_size: usize, + /// Hash of a sample of vocabulary tokens (for speed) + pub vocab_hash: u64, + /// Hash of special tokens + pub special_tokens_hash: u64, +} + +impl TokenizerFingerprint { + /// Create a fingerprint from a tokenizer + pub fn from_tokenizer(tokenizer: &dyn Tokenizer) -> Self { + let vocab_size = tokenizer.vocab_size(); + let vocab_hash = Self::compute_vocab_hash(tokenizer); + let special_tokens_hash = Self::compute_special_tokens_hash(tokenizer); + + Self { + vocab_size, + vocab_hash, + special_tokens_hash, + } + } + + /// Compute a hash of the vocabulary by sampling tokens + fn compute_vocab_hash(tokenizer: &dyn Tokenizer) -> u64 { + let mut hasher = DefaultHasher::new(); + let vocab_size = tokenizer.vocab_size(); + + // Sample up to 1000 tokens for speed + let sample_size = vocab_size.min(1000); + let step = if sample_size > 0 { + vocab_size / sample_size + } else { + 1 + }; + + for i in (0..vocab_size).step_by(step.max(1)) { + if let Some(token) = tokenizer.id_to_token(i as u32) { + token.hash(&mut hasher); + } + } + + hasher.finish() + } + + /// Compute a hash of special tokens + fn compute_special_tokens_hash(tokenizer: &dyn Tokenizer) -> u64 { + let mut hasher = DefaultHasher::new(); + let special_tokens = tokenizer.get_special_tokens(); + + special_tokens.bos_token.hash(&mut hasher); + special_tokens.eos_token.hash(&mut hasher); + special_tokens.unk_token.hash(&mut hasher); + special_tokens.sep_token.hash(&mut hasher); + special_tokens.pad_token.hash(&mut hasher); + special_tokens.cls_token.hash(&mut hasher); + special_tokens.mask_token.hash(&mut hasher); + special_tokens.additional_special_tokens.hash(&mut hasher); + + hasher.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tokenizer::mock::MockTokenizer; + + #[test] + fn test_fingerprint_equality() { + let tokenizer1 = MockTokenizer::new(); + let tokenizer2 = MockTokenizer::new(); + + let fp1 = TokenizerFingerprint::from_tokenizer(&tokenizer1); + let fp2 = TokenizerFingerprint::from_tokenizer(&tokenizer2); + + // Same tokenizer config should produce same fingerprint + assert_eq!(fp1, fp2); + } + + #[test] + fn test_fingerprint_consistency() { + let tokenizer = MockTokenizer::new(); + + let fp1 = TokenizerFingerprint::from_tokenizer(&tokenizer); + let fp2 = TokenizerFingerprint::from_tokenizer(&tokenizer); + + // Fingerprint should be consistent + assert_eq!(fp1, fp2); + assert_eq!(fp1.vocab_size, tokenizer.vocab_size()); + } +} diff --git a/sgl-router/src/tokenizer/cache/l0.rs b/sgl-router/src/tokenizer/cache/l0.rs new file mode 100644 index 00000000000..768ff42a3aa --- /dev/null +++ b/sgl-router/src/tokenizer/cache/l0.rs @@ -0,0 +1,216 @@ +//! L0 Cache: Whole-string exact match cache +//! +//! This is the simplest and most effective cache layer. +//! Key: input string → Value: full encoding result +//! +//! 
Expected hit rate: 60-90% for workloads with repeated system prompts + +use super::super::traits::Encoding; +use dashmap::DashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +/// L0 cache implementation using DashMap for lock-free reads +pub struct L0Cache { + /// The cache map: input string → encoding + map: Arc>, + /// Maximum number of entries before eviction + max_entries: usize, + /// Cache hit counter + hits: AtomicU64, + /// Cache miss counter + misses: AtomicU64, +} + +impl L0Cache { + /// Create a new L0 cache with the specified capacity + pub fn new(max_entries: usize) -> Self { + Self { + map: Arc::new(DashMap::with_capacity(max_entries.min(1024))), + max_entries, + hits: AtomicU64::new(0), + misses: AtomicU64::new(0), + } + } + + /// Get an encoding from the cache + pub fn get(&self, key: &str) -> Option { + match self.map.get(key) { + Some(entry) => { + self.hits.fetch_add(1, Ordering::Relaxed); + Some(entry.value().clone()) + } + None => { + self.misses.fetch_add(1, Ordering::Relaxed); + None + } + } + } + + /// Insert an encoding into the cache + pub fn insert(&self, key: String, value: Encoding) { + // Simple eviction: if we're at capacity, remove a random entry + // DashMap doesn't support LRU directly, so we use a simple strategy + if self.map.len() >= self.max_entries { + // Get the key to remove in a separate scope to ensure iterator is dropped + let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) }; // Iterator fully dropped here, all locks released + + // Now remove it + if let Some(k) = key_to_remove { + self.map.remove(&k); + } + } + + self.map.insert(key, value); + } + + /// Get the current number of entries in the cache + pub fn len(&self) -> usize { + self.map.len() + } + + /// Check if the cache is empty + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } + + /// Get cache statistics + pub fn stats(&self) -> CacheStats { + let hits = self.hits.load(Ordering::Relaxed); + let misses = self.misses.load(Ordering::Relaxed); + let total_requests = hits + misses; + + CacheStats { + hits, + misses, + entries: self.len(), + hit_rate: if total_requests > 0 { + hits as f64 / total_requests as f64 + } else { + 0.0 + }, + } + } + + /// Clear the cache + pub fn clear(&self) { + self.map.clear(); + self.hits.store(0, Ordering::Relaxed); + self.misses.store(0, Ordering::Relaxed); + } + + /// Estimate memory usage in bytes + pub fn memory_usage(&self) -> usize { + // Rough estimate: + // - Each entry: key (string) + value (encoding ~250 tokens * 4 bytes) + overhead + // - Average: ~2.2KB per entry + self.len() * 2200 + } +} + +#[derive(Debug, Clone)] +pub struct CacheStats { + pub hits: u64, + pub misses: u64, + pub entries: usize, + pub hit_rate: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tokenizer::traits::Encoding; + + fn mock_encoding(tokens: Vec) -> Encoding { + Encoding::Sp(tokens) + } + + #[test] + fn test_basic_get_set() { + let cache = L0Cache::new(10); + + // Miss + assert!(cache.get("hello").is_none()); + + // Insert + cache.insert("hello".to_string(), mock_encoding(vec![1, 2, 3])); + + // Hit + let result = cache.get("hello"); + assert!(result.is_some()); + assert_eq!(result.unwrap().token_ids(), &[1, 2, 3]); + } + + #[test] + fn test_eviction() { + let cache = L0Cache::new(2); + + cache.insert("a".to_string(), mock_encoding(vec![1])); + cache.insert("b".to_string(), mock_encoding(vec![2])); + + // Should evict when adding third + cache.insert("c".to_string(), 
mock_encoding(vec![3])); + + // Cache should have exactly 2 entries + assert_eq!(cache.len(), 2); + } + + #[test] + fn test_stats() { + let cache = L0Cache::new(10); + + cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3])); + + // 1 miss (initial get that returned None) + let _ = cache.get("missing"); + + // 1 hit + let _ = cache.get("test"); + + let stats = cache.stats(); + assert_eq!(stats.hits, 1); + assert_eq!(stats.misses, 1); + assert_eq!(stats.hit_rate, 0.5); + } + + #[test] + fn test_clear() { + let cache = L0Cache::new(10); + + cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3])); + assert_eq!(cache.len(), 1); + + cache.clear(); + assert_eq!(cache.len(), 0); + assert!(cache.get("test").is_none()); + } + + #[test] + fn test_concurrent_access() { + use std::thread; + + let cache = Arc::new(L0Cache::new(1000)); + let mut handles = vec![]; + + // Spawn 10 threads + for i in 0..10 { + let cache_clone = cache.clone(); + handles.push(thread::spawn(move || { + // Each thread inserts and reads + let key = format!("key_{}", i); + cache_clone.insert(key.clone(), mock_encoding(vec![i as u32])); + + // Read it back + let result = cache_clone.get(&key); + assert!(result.is_some()); + })); + } + + for handle in handles { + handle.join().unwrap(); + } + + // Should have 10 entries + assert_eq!(cache.len(), 10); + } +} diff --git a/sgl-router/src/tokenizer/cache/l1.rs b/sgl-router/src/tokenizer/cache/l1.rs new file mode 100644 index 00000000000..39505cb15ba --- /dev/null +++ b/sgl-router/src/tokenizer/cache/l1.rs @@ -0,0 +1,501 @@ +//! L1 Cache: Special-token boundary prefix cache +//! +//! Caches tokenization results at ALL special token boundaries. +//! Special tokens (like `<|im_start|>`, `<|im_end|>`) are atomic in BPE tokenizers (special: true, normalized: false), +//! making them the ONLY safe split points that guarantee correctness. +//! +//! **Design**: Cache at every special token boundary (not at fixed granularity intervals) +//! - Simple: No granularity parameter, no search windows +//! - Efficient: Fewer cache entries (10 instead of 64 for typical 8KB prompt) +//! - Natural: Aligns with actual chat template structure +//! +//! Example: +//! +//! Template: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\n{query}<|im_end|>" +//! +//! Request 1: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>" +//! Request 2: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>" +//! +//! Cache points: After each "<|im_end|>" (atomic tokens, guaranteed safe) +//! Result: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix) + +use super::super::traits::TokenIdType; +use blake3; +use dashmap::DashMap; +use std::mem::size_of; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +/// Hash type for cache keys +type Blake3Hash = [u8; 32]; + +/// Number of shards for concurrent access +const NUM_SHARDS: usize = 16; + +/// Find ALL special token boundaries in the text +/// +/// **ONLY uses special tokens** - these are atomic (special: true, normalized: false) in BPE, +/// guaranteeing: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix) +/// +/// No fallback to whitespace/punctuation - better to not cache than risk corruption. 
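+///
+/// Worked example (byte offsets; `<|im_start|>` is 12 bytes long):
+/// `"<|im_start|>user\nHi<|im_end|>"` with `["<|im_start|>", "<|im_end|>"]` yields
+/// `[12]`; the `<|im_end|>` boundary falls at `text.len()` and is skipped because
+/// no suffix would remain to tokenize.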
+///
+/// Common special tokens:
+/// - ChatML: `<|im_start|>`, `<|im_end|>`
+/// - Llama 3: `<|begin_of_text|>`, `<|end_of_text|>`, `<|eot_id|>`
+/// - GPT: `<|endoftext|>`
+/// - Custom: `<|reserved_special_token_N|>`
+///
+/// Returns positions immediately after each special token (where prefixes can be cached).
+fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
+    if special_tokens.is_empty() {
+        return Vec::new();
+    }
+
+    let mut boundaries = Vec::new();
+
+    // Find all special token end positions
+    for &token in special_tokens {
+        let mut start = 0;
+        while let Some(pos) = text[start..].find(token) {
+            let boundary = start + pos + token.len();
+            // Only cache boundaries that leave some suffix to tokenize
+            if boundary < text.len() {
+                boundaries.push(boundary);
+            }
+            start = boundary;
+        }
+    }
+
+    // Sort and deduplicate (in case multiple special tokens end at same position)
+    boundaries.sort_unstable();
+    boundaries.dedup();
+
+    boundaries
+}
+
+/// A cached prefix entry
+#[derive(Debug, Clone)]
+struct CachedPrefix {
+    /// The pre-computed token IDs for this prefix
+    tokens: Vec<TokenIdType>,
+    /// Last access timestamp (for LRU eviction)
+    last_accessed: Arc<AtomicU64>,
+    /// Size in bytes (for memory tracking during eviction)
+    size_bytes: usize,
+}
+
+/// L1 cache implementation with special-token-boundary prefix matching
+pub struct L1Cache {
+    /// Sharded maps for concurrent access
+    /// Key: Blake3 hash of bytes[0..boundary]
+    /// Value: Cached token IDs for that prefix
+    shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
+    /// Maximum memory in bytes
+    max_memory: usize,
+    /// Current memory usage estimate
+    current_memory: AtomicU64,
+    /// Cache hit counter
+    hits: AtomicU64,
+    /// Cache miss counter
+    misses: AtomicU64,
+    /// Monotonic counter for LRU timestamps
+    access_counter: AtomicU64,
+}
+
+impl L1Cache {
+    /// Create a new L1 cache with the specified memory limit
+    pub fn new(max_memory: usize) -> Self {
+        let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
+
+        Self {
+            shards,
+            max_memory,
+            current_memory: AtomicU64::new(0),
+            hits: AtomicU64::new(0),
+            misses: AtomicU64::new(0),
+            access_counter: AtomicU64::new(0),
+        }
+    }
+
+    /// Try to find the longest prefix match at special token boundaries
+    /// Returns (cached_tokens, byte_offset) if found
+    ///
+    /// Uses pre-computed tokens cached during insertion.
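+    ///
+    /// A minimal sketch of the intended call pattern (mirroring what
+    /// `CachedTokenizer::encode` does with the result):
+    ///
+    /// ```ignore
+    /// if let Some((prefix_tokens, offset)) = l1.longest_prefix_match(input, &specials) {
+    ///     // Only the suffix still needs the tokenizer; the split point is an
+    ///     // atomic special token, so concatenation is safe.
+    ///     let suffix = tokenizer.encode(&input[offset..])?;
+    ///     let mut tokens = prefix_tokens;
+    ///     tokens.extend_from_slice(suffix.token_ids());
+    /// }
+    /// ```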
+ pub fn longest_prefix_match( + &self, + input: &str, + special_tokens: &[&str], + ) -> Option<(Vec, usize)> { + let boundaries = find_special_token_boundaries(input, special_tokens); + + if boundaries.is_empty() { + self.misses.fetch_add(1, Ordering::Relaxed); + return None; + } + + // Search backwards from the longest boundary to find the best match + for &boundary_pos in boundaries.iter().rev() { + let prefix = &input[0..boundary_pos]; + let prefix_bytes = prefix.as_bytes(); + let hash = blake3::hash(prefix_bytes); + let hash_bytes: Blake3Hash = *hash.as_bytes(); + + let shard_idx = hash_bytes[0] as usize % NUM_SHARDS; + + if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) { + // Update last accessed timestamp for LRU + let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed); + entry.last_accessed.store(timestamp, Ordering::Relaxed); + + self.hits.fetch_add(1, Ordering::Relaxed); + return Some((entry.tokens.clone(), boundary_pos)); + } + } + + self.misses.fetch_add(1, Ordering::Relaxed); + None + } + + /// Insert prefix entries at ALL special token boundaries + /// + /// Re-tokenizes each prefix to ensure correctness (BPE tokenization is not prefix-stable). + /// This is more expensive on cache misses but provides correct tokens for cache hits. + /// + /// Optimized for workloads with high prefix reuse (e.g., chat templates with repeated system prompts). + pub fn insert_at_boundaries( + &self, + input: &str, + tokenizer: &E, + special_tokens: &[&str], + ) -> anyhow::Result<()> { + let boundaries = find_special_token_boundaries(input, special_tokens); + + if boundaries.is_empty() { + return Ok(()); + } + + // Calculate how much memory we need and tokenize each prefix + let mut entries_to_insert = Vec::new(); + for &boundary_pos in &boundaries { + // Extract prefix up to this special token boundary + let prefix = &input[0..boundary_pos]; + let prefix_bytes = prefix.as_bytes(); + let hash = blake3::hash(prefix_bytes); + let hash_bytes: Blake3Hash = *hash.as_bytes(); + + // Re-tokenize the prefix for guaranteed correctness + // This is the only way to know the exact token boundaries + let prefix_encoding = tokenizer.encode(prefix)?; + let prefix_tokens = prefix_encoding.token_ids().to_vec(); + + // Size = text bytes + token storage + let size_bytes = boundary_pos + prefix_tokens.len() * size_of::(); + + entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes)); + } + + if entries_to_insert.is_empty() { + return Ok(()); + } + + let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum(); + + // Evict if necessary + let current = self.current_memory.load(Ordering::Relaxed) as usize; + if current + total_size_needed > self.max_memory { + self.evict_lru(total_size_needed); + } + + // Insert all entries + for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert { + let shard_idx = hash_bytes[0] as usize % NUM_SHARDS; + + let cached = CachedPrefix { + tokens: prefix_tokens, + last_accessed: Arc::new(AtomicU64::new( + self.access_counter.load(Ordering::Relaxed), + )), + size_bytes, + }; + + self.shards[shard_idx].insert(hash_bytes, cached); + self.current_memory + .fetch_add(size_bytes as u64, Ordering::Relaxed); + } + + Ok(()) + } + + /// Evict least recently used entries using approximate LRU via random sampling + /// + /// This uses an approximate LRU strategy that's much faster than true LRU: + /// - Samples K random entries from the cache (K=32) + /// - Evicts the oldest entry among the samples + /// - Repeats until enough space 
is freed + /// + /// This provides O(samples) complexity instead of O(total_entries * log(total_entries)), + /// avoiding latency spikes when eviction is triggered on large caches. + /// + /// The approximation is excellent in practice - sampling 32 entries from a large cache + /// gives high probability of finding very old entries. + fn evict_lru(&self, space_needed: usize) { + const SAMPLE_SIZE: usize = 32; // Number of entries to sample per eviction round + let mut freed = 0usize; + let mut iteration = 0usize; + + // Keep evicting until we have enough space + while freed < space_needed { + // Collect samples from shards + let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE); + + // Sample entries across different shards + for i in 0..SAMPLE_SIZE { + // Distribute samples across shards using iteration and index for variety + let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS; + + // Get first entry from that shard (DashMap iteration order is arbitrary) + if let Some(entry) = self.shards[shard_idx].iter().next() { + let hash = *entry.key(); + let timestamp = entry.value().last_accessed.load(Ordering::Relaxed); + let size = entry.value().size_bytes; + samples.push((shard_idx, hash, timestamp, size)); + } + } + + if samples.is_empty() { + // Cache is empty, nothing to evict + break; + } + + // Find the oldest entry among samples + if let Some((shard_idx, hash, _, _)) = + samples.iter().min_by_key(|(_, _, ts, _)| ts).copied() + { + // Remove it + if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) { + freed += removed.size_bytes; + self.current_memory + .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed); + } + } + + iteration += 1; + } + } + + /// Get the number of entries in the cache + pub fn len(&self) -> usize { + self.shards.iter().map(|s| s.len()).sum() + } + + /// Check if the cache is empty + pub fn is_empty(&self) -> bool { + self.shards.iter().all(|s| s.is_empty()) + } + + /// Get cache statistics + pub fn stats(&self) -> L1CacheStats { + let hits = self.hits.load(Ordering::Relaxed); + let misses = self.misses.load(Ordering::Relaxed); + let total_requests = hits + misses; + + L1CacheStats { + hits, + misses, + entries: self.len(), + memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize, + hit_rate: if total_requests > 0 { + hits as f64 / total_requests as f64 + } else { + 0.0 + }, + } + } + + /// Clear the cache + pub fn clear(&self) { + for shard in &self.shards { + shard.clear(); + } + self.current_memory.store(0, Ordering::Relaxed); + self.hits.store(0, Ordering::Relaxed); + self.misses.store(0, Ordering::Relaxed); + } +} + +#[derive(Debug, Clone)] +pub struct L1CacheStats { + pub hits: u64, + pub misses: u64, + pub entries: usize, + pub memory_bytes: usize, + pub hit_rate: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tokenizer::mock::MockTokenizer; + + #[test] + fn test_basic_prefix_match() { + let cache = L1Cache::new(1024 * 1024); + let special_tokens = &["<|im_start|>", "<|im_end|>"]; + let tokenizer = MockTokenizer::new(); + + // Realistic ChatML template with special tokens + let input1 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! 
How are you doing today?<|im_end|>"; + + // Insert at special token boundaries (re-tokenizes prefixes) + cache + .insert_at_boundaries(input1, &tokenizer, special_tokens) + .unwrap(); + + // Should have cached at special token boundaries + assert!(!cache.is_empty()); + + // Search with same prefix but different user query + let input2 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>"; + let result = cache.longest_prefix_match(input2, special_tokens); + + // Should find a match at the special token boundary (after system message) + assert!(result.is_some()); + let (tokens, offset) = result.unwrap(); + assert!(offset > 0); + assert!(!tokens.is_empty()); + } + + #[test] + fn test_short_input_with_boundaries() { + let cache = L1Cache::new(1024 * 1024); + let special_tokens = &["<|im_start|>", "<|im_end|>"]; + let tokenizer = MockTokenizer::new(); + + // Short input with special tokens + let input = "<|im_start|>user\nHi<|im_end|>"; + + cache + .insert_at_boundaries(input, &tokenizer, special_tokens) + .unwrap(); + + // Should cache at <|im_start|> boundary (has suffix left) + assert!(!cache.is_empty()); + + // Should find a match + let result = cache.longest_prefix_match(input, special_tokens); + assert!(result.is_some()); + } + + #[test] + fn test_longest_match() { + let cache = L1Cache::new(1024 * 1024); + let special_tokens = &["<|im_start|>", "<|im_end|>"]; + let tokenizer = MockTokenizer::new(); + + // Create multi-turn conversation with multiple special token boundaries (~400 bytes) + let input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>"; + + cache + .insert_at_boundaries(input, &tokenizer, special_tokens) + .unwrap(); + + // Should have multiple entries at special token boundaries + assert!(cache.len() >= 2); // At least 2 boundaries + + // Search with partial conversation - should match at a special token boundary + let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>"; + let result = cache.longest_prefix_match(partial_input, special_tokens); + + // Should find a match at a special token boundary + assert!(result.is_some()); + let (_, offset) = result.unwrap(); + assert!(offset > 0); + assert!(offset <= partial_input.len()); + } + + #[test] + fn test_stats() { + let cache = L1Cache::new(1024 * 1024); + let special_tokens = &["<|im_start|>", "<|im_end|>"]; + let tokenizer = MockTokenizer::new(); + + // ChatML input with special tokens + let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! 
How are you today?<|im_end|>"; + + cache + .insert_at_boundaries(input, &tokenizer, special_tokens) + .unwrap(); + + // Try to find match + let _ = cache.longest_prefix_match(input, special_tokens); + + let stats = cache.stats(); + // Should have at least one hit (the longest special token boundary should match) + assert!(stats.hits >= 1); + assert_eq!(stats.hit_rate, 1.0); + } + + #[test] + fn test_clear() { + let cache = L1Cache::new(1024 * 1024); + let special_tokens = &["<|im_start|>", "<|im_end|>"]; + let tokenizer = MockTokenizer::new(); + + // ChatML input with special tokens + let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>"; + + cache + .insert_at_boundaries(input, &tokenizer, special_tokens) + .unwrap(); + assert!(!cache.is_empty()); + + cache.clear(); + assert!(cache.is_empty()); + + let stats = cache.stats(); + assert_eq!(stats.hits, 0); + assert_eq!(stats.misses, 0); + } + + #[test] + fn test_lru_eviction() { + // Create a small cache (5KB) to trigger eviction + let cache = L1Cache::new(5 * 1024); + let special_tokens = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"]; + let tokenizer = MockTokenizer::new(); + + // Insert first conversation + let input1 = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>"; + cache + .insert_at_boundaries(input1, &tokenizer, special_tokens) + .unwrap(); + + // Access the first entry to update its timestamp + let result = cache.longest_prefix_match(input1, special_tokens); + assert!(result.is_some()); + + // Insert second conversation + let input2 = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>"; + cache + .insert_at_boundaries(input2, &tokenizer, special_tokens) + .unwrap(); + + // Access the second entry to make it more recent + let result = cache.longest_prefix_match(input2, special_tokens); + assert!(result.is_some()); + + // Insert third conversation (should trigger eviction of oldest) + let input3 = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>"; + cache + .insert_at_boundaries(input3, &tokenizer, special_tokens) + .unwrap(); + + // Verify cache didn't exceed max memory + let stats = cache.stats(); + assert!(stats.memory_bytes <= 5 * 1024); + + // The most recently accessed entries should still be present + let result = cache.longest_prefix_match(input3, special_tokens); + assert!(result.is_some()); + } +} diff --git a/sgl-router/src/tokenizer/cache/mod.rs b/sgl-router/src/tokenizer/cache/mod.rs new file mode 100644 index 00000000000..c7eea01d618 --- /dev/null +++ b/sgl-router/src/tokenizer/cache/mod.rs @@ -0,0 +1,358 @@ +//! Tokenizer Caching Layer +//! +//! Provides a caching wrapper around any tokenizer implementation to speed up +//! repeated tokenization of the same strings (e.g., system prompts). +//! +//! # Architecture +//! 
- **L0 Cache**: Whole-string exact match (90% of wins)
+//! - **L1 Cache**: Prefix matching at special token boundaries (opt-in)
+//!
+//! # Usage
+//! ```ignore
+//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
+//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
+//! let encoding = cached.encode("Hello world")?;
+//! ```
+
+mod fingerprint;
+mod l0;
+mod l1;
+
+pub use fingerprint::TokenizerFingerprint;
+pub use l0::{CacheStats, L0Cache};
+pub use l1::{L1Cache, L1CacheStats};
+
+use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer};
+use anyhow::Result;
+use rayon::prelude::*;
+use std::sync::Arc;
+
+/// Configuration for the tokenizer cache
+#[derive(Debug, Clone)]
+pub struct CacheConfig {
+    /// Enable L0 (whole-string) cache
+    pub enable_l0: bool,
+    /// Maximum number of entries in L0 cache
+    pub l0_max_entries: usize,
+    /// Enable L1 (prefix) cache
+    pub enable_l1: bool,
+    /// Maximum memory for L1 cache in bytes
+    pub l1_max_memory: usize,
+}
+
+impl Default for CacheConfig {
+    fn default() -> Self {
+        Self {
+            enable_l0: true,
+            l0_max_entries: 10_000, // ~22MB memory for typical prompts
+            enable_l1: false,       // Opt-in for now
+            l1_max_memory: 50 * 1024 * 1024, // 50MB
+        }
+    }
+}
+
+/// A caching wrapper around any tokenizer
+pub struct CachedTokenizer {
+    /// The underlying tokenizer
+    inner: Arc<dyn Tokenizer>,
+    /// L0 cache (whole-string exact match)
+    l0: Option<L0Cache>,
+    /// L1 cache (prefix matching at special token boundaries)
+    l1: Option<L1Cache>,
+    /// Configuration
+    #[allow(dead_code)]
+    config: CacheConfig,
+    /// Fingerprint for cache invalidation
+    fingerprint: TokenizerFingerprint,
+    /// Cached special token strings (extracted once at construction)
+    special_token_strings: Vec<String>,
+}
+
+impl CachedTokenizer {
+    /// Create a new cached tokenizer
+    pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
+        let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
+
+        let l0 = if config.enable_l0 {
+            Some(L0Cache::new(config.l0_max_entries))
+        } else {
+            None
+        };
+
+        let l1 = if config.enable_l1 {
+            Some(L1Cache::new(config.l1_max_memory))
+        } else {
+            None
+        };
+
+        // Extract special tokens once at construction time
+        let special_token_strings = Self::extract_special_token_strings(&inner);
+
+        Self {
+            inner,
+            l0,
+            l1,
+            config,
+            fingerprint,
+            special_token_strings,
+        }
+    }
+
+    /// Extract all special token strings from the tokenizer (called once at construction)
+    fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
+        let special_tokens = tokenizer.get_special_tokens();
+        let mut tokens = Vec::new();
+
+        if let Some(ref token) = special_tokens.bos_token {
+            tokens.push(token.clone());
+        }
+        if let Some(ref token) = special_tokens.eos_token {
+            tokens.push(token.clone());
+        }
+        if let Some(ref token) = special_tokens.unk_token {
+            tokens.push(token.clone());
+        }
+        if let Some(ref token) = special_tokens.sep_token {
+            tokens.push(token.clone());
+        }
+        if let Some(ref token) = special_tokens.pad_token {
+            tokens.push(token.clone());
+        }
+        if let Some(ref token) = special_tokens.cls_token {
+            tokens.push(token.clone());
+        }
+        if let Some(ref token) = special_tokens.mask_token {
+            tokens.push(token.clone());
+        }
+
+        tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
+        tokens
+    }
+
+    /// Get L0 cache statistics
+    pub fn cache_stats(&self) -> Option<CacheStats> {
+        self.l0.as_ref().map(|cache| cache.stats())
+    }
+
+    /// Get L1 cache statistics
+    pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
+impl Encoder for CachedTokenizer {
+    fn encode(&self, input: &str) -> Result<Encoding> {
+        // L0 cache lookup (exact match)
+        if let Some(l0) = &self.l0 {
+            if let Some(cached) = l0.get(input) {
+                return Ok(cached);
+            }
+        }
+
+        // L1 cache lookup (prefix match at special token boundaries)
+        if let Some(l1) = &self.l1 {
+            let special_tokens: Vec<&str> = self
+                .special_token_strings
+                .iter()
+                .map(|s| s.as_str())
+                .collect();
+
+            if let Some((prefix_tokens, prefix_len)) =
+                l1.longest_prefix_match(input, &special_tokens)
+            {
+                // We have a prefix match - tokenize the suffix
+                let suffix = &input[prefix_len..];
+                if !suffix.is_empty() {
+                    let suffix_encoding = self.inner.encode(suffix)?;
+
+                    // Merge prefix tokens + suffix tokens
+                    // Safe because we're splitting at special token boundaries
+                    let mut merged_tokens = prefix_tokens;
+                    merged_tokens.extend_from_slice(suffix_encoding.token_ids());
+
+                    let merged_encoding = Encoding::Sp(merged_tokens);
+
+                    // Cache the full result in L0
+                    if let Some(l0) = &self.l0 {
+                        l0.insert(input.to_string(), merged_encoding.clone());
+                    }
+
+                    return Ok(merged_encoding);
+                }
+            }
+        }
+
+        // Full tokenization (both L0 and L1 miss)
+        let encoding = self.inner.encode(input)?;
+
+        // Cache in L0
+        if let Some(l0) = &self.l0 {
+            l0.insert(input.to_string(), encoding.clone());
+        }
+
+        // Cache in L1 at special token boundaries
+        // Re-tokenizes prefixes for correctness (optimized for high prefix reuse)
+        if let Some(l1) = &self.l1 {
+            let special_tokens: Vec<&str> = self
+                .special_token_strings
+                .iter()
+                .map(|s| s.as_str())
+                .collect();
+            let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), &special_tokens);
+            // Ignore errors in cache insertion - cache is best-effort
+        }
+
+        Ok(encoding)
+    }
+
+    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
+        // Process each input in parallel, leveraging thread-safe caches
+        // This maintains the parallelism from the underlying HuggingFaceTokenizer
+        inputs.par_iter().map(|&input| self.encode(input)).collect()
+    }
+}
+
+impl Decoder for CachedTokenizer {
+    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
+        // Decoding is not cached (it's fast enough and rarely repeated)
+        self.inner.decode(token_ids, skip_special_tokens)
+    }
+}
+
+impl Tokenizer for CachedTokenizer {
+    fn vocab_size(&self) -> usize {
+        self.inner.vocab_size()
+    }
+
+    fn get_special_tokens(&self) -> &SpecialTokens {
+        self.inner.get_special_tokens()
+    }
+
+    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
+        self.inner.token_to_id(token)
+    }
+
+    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
+        self.inner.id_to_token(id)
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tokenizer::mock::MockTokenizer;
+
+    #[test]
+    fn test_cache_hit() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
+
+        let input = "Hello world";
+
+        // First call - miss
+        let result1 = cached.encode(input).unwrap();
+
+        // Second call - hit
+        let result2 = cached.encode(input).unwrap();
+
+        // Results should be identical
+        assert_eq!(result1.token_ids(),
result2.token_ids()); + + // Check cache stats + let stats = cached.cache_stats().unwrap(); + assert_eq!(stats.hits, 1); + assert_eq!(stats.misses, 1); + } + + #[test] + fn test_cache_disabled() { + let tokenizer = Arc::new(MockTokenizer::new()); + let config = CacheConfig { + enable_l0: false, + l0_max_entries: 0, + enable_l1: false, + l1_max_memory: 0, + }; + let cached = CachedTokenizer::new(tokenizer, config); + + let input = "Hello world"; + + // Both calls should work even without cache + let result1 = cached.encode(input).unwrap(); + let result2 = cached.encode(input).unwrap(); + + assert_eq!(result1.token_ids(), result2.token_ids()); + + // No cache stats available + assert!(cached.cache_stats().is_none()); + } + + #[test] + fn test_encode_batch() { + let tokenizer = Arc::new(MockTokenizer::new()); + let cached = CachedTokenizer::new(tokenizer, CacheConfig::default()); + + let inputs = vec!["Hello", "world", "Hello"]; // "Hello" repeated + + let results = cached.encode_batch(&inputs).unwrap(); + + assert_eq!(results.len(), 3); + + // With parallel execution, duplicate inputs may be processed simultaneously + // and both see cache misses. Verify results are correct instead. + assert_eq!(results[0].token_ids(), results[2].token_ids()); // Both "Hello" should match + + // After batch processing, cache should be populated + // Subsequent calls should hit the cache + let _ = cached.encode("Hello").unwrap(); + let stats = cached.cache_stats().unwrap(); + + // Should have at least 1 hit from the call above (cache was populated by batch) + assert!( + stats.hits >= 1, + "Expected at least 1 cache hit after batch processing" + ); + } + + #[test] + fn test_decoder_passthrough() { + let tokenizer = Arc::new(MockTokenizer::new()); + let cached = CachedTokenizer::new(tokenizer, CacheConfig::default()); + + let tokens = vec![1, 2, 3]; + let decoded = cached.decode(&tokens, false).unwrap(); + + // Should just pass through to inner tokenizer + assert!(!decoded.is_empty()); + } + + #[test] + fn test_tokenizer_trait_methods() { + let tokenizer = Arc::new(MockTokenizer::new()); + let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default()); + + // Should pass through to inner tokenizer + assert_eq!(cached.vocab_size(), tokenizer.vocab_size()); + assert!(cached.token_to_id("Hello").is_some()); + assert!(cached.id_to_token(1).is_some()); + } +} diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs index 727d35715b8..552b7361a19 100644 --- a/sgl-router/src/tokenizer/huggingface.rs +++ b/sgl-router/src/tokenizer/huggingface.rs @@ -44,8 +44,8 @@ impl HuggingFaceTokenizer { // Extract special tokens let special_tokens = Self::extract_special_tokens(&tokenizer); - // Build vocab mappings - let vocab = tokenizer.get_vocab(false); + // Build vocab mappings (include special tokens to get added_tokens like <|im_start|>) + let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens let reverse_vocab: HashMap = vocab .iter() .map(|(token, &id)| (id, token.clone())) @@ -80,7 +80,7 @@ impl HuggingFaceTokenizer { /// Create from an existing HuggingFace tokenizer pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self { let special_tokens = Self::extract_special_tokens(&tokenizer); - let vocab = tokenizer.get_vocab(false); + let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens let reverse_vocab: HashMap = vocab .iter() .map(|(token, &id)| (id, token.clone())) @@ -98,8 +98,7 @@ impl 
HuggingFaceTokenizer { /// Extract special tokens from the tokenizer fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens { - // Try to get special tokens from the tokenizer - // This is a simplified version - actual implementation would need to handle various formats + // Get vocab with special tokens included (added_tokens like <|im_start|>) let vocab = tokenizer.get_vocab(true); let find_token = |patterns: &[&str]| -> Option { @@ -111,6 +110,19 @@ impl HuggingFaceTokenizer { None }; + // Extract additional special tokens (ChatML tokens like <|im_start|>, <|im_end|>, etc.) + // These are typically in added_tokens with special: true + let mut additional_special_tokens = Vec::new(); + for (token, _id) in vocab.iter() { + // Look for tokens that match common special token patterns + if token.starts_with("<|") && token.ends_with("|>") { + additional_special_tokens.push(token.clone()); + } else if token.starts_with("<|") && token.ends_with(">") { + // Alternative patterns like <|endoftext|> + additional_special_tokens.push(token.clone()); + } + } + SpecialTokens { bos_token: find_token(&["", "<|startoftext|>", "", "[CLS]"]), eos_token: find_token(&["", "<|endoftext|>", "", "[SEP]"]), @@ -119,7 +131,7 @@ impl HuggingFaceTokenizer { pad_token: find_token(&["", "", "[PAD]"]), cls_token: find_token(&["[CLS]", "", ""]), mask_token: find_token(&["[MASK]", "", ""]), - additional_special_tokens: vec![], + additional_special_tokens, } } diff --git a/sgl-router/src/tokenizer/mock.rs b/sgl-router/src/tokenizer/mock.rs index ab918db373c..136f622f8de 100644 --- a/sgl-router/src/tokenizer/mock.rs +++ b/sgl-router/src/tokenizer/mock.rs @@ -62,11 +62,9 @@ impl MockTokenizer { impl Encoder for MockTokenizer { fn encode(&self, input: &str) -> Result { - // Simple word-based tokenization for testing - let tokens: Vec = input - .split_whitespace() - .filter_map(|word| self.vocab.get(word).copied()) - .collect(); + // Simple character-based tokenization for testing + // Returns one token per character to ensure non-empty encodings + let tokens: Vec = input.bytes().map(|b| b as u32).collect(); Ok(Encoding::Sp(tokens)) } diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs index 651b340de23..78fe3915883 100644 --- a/sgl-router/src/tokenizer/mod.rs +++ b/sgl-router/src/tokenizer/mod.rs @@ -2,6 +2,7 @@ use std::{ops::Deref, sync::Arc}; use anyhow::Result; +pub mod cache; pub mod factory; pub mod hub; pub mod mock; @@ -22,6 +23,7 @@ pub mod tiktoken; mod tests; // Re-exports +pub use cache::{CacheConfig, CacheStats, CachedTokenizer, TokenizerFingerprint}; pub use factory::{ create_tokenizer, create_tokenizer_async, create_tokenizer_async_with_chat_template, create_tokenizer_from_file, create_tokenizer_with_chat_template, diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index a94b416f059..d880bf2c7be 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -69,6 +69,7 @@ impl TestContext { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; Self::new_with_config(config, worker_configs).await @@ -1406,6 +1407,7 @@ mod error_tests { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = TestContext::new_with_config( @@ -1735,6 +1737,7 @@ mod pd_mode_tests { oracle: None, reasoning_parser: None, 
tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; // Create app context @@ -1898,6 +1901,7 @@ mod request_id_tests { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 86cd154d086..589b0ce5d84 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -84,6 +84,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; // Create router and context @@ -284,6 +285,7 @@ async fn test_conversations_crud_basic() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -619,6 +621,7 @@ async fn test_multi_turn_loop_with_mcp() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -795,6 +798,7 @@ async fn test_max_tool_calls_limit() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -937,6 +941,7 @@ async fn setup_streaming_mcp_test() -> ( oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -1378,6 +1383,7 @@ async fn test_conversation_items_create_and_get() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -1479,6 +1485,7 @@ async fn test_conversation_items_delete() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -1586,6 +1593,7 @@ async fn test_conversation_items_max_limit() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -1663,6 +1671,7 @@ async fn test_conversation_items_unsupported_type() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); @@ -1739,6 +1748,7 @@ async fn test_conversation_items_multi_conversation_sharing() { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(), }; let ctx = common::create_test_context(router_cfg); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index dfe9cdcdce2..6e054126cfd 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -200,6 +200,7 @@ mod test_pd_routing { oracle: None, reasoning_parser: None, tool_call_parser: None, + tokenizer_cache: 
sglang_router_rs::config::TokenizerCacheConfig::default(), }; let app_context = { From 5f6c86355961e6b0e9f857047d382963928e2a51 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 16 Oct 2025 19:16:24 -0700 Subject: [PATCH 2/7] fixup --- sgl-router/src/tokenizer/cache/fingerprint.rs | 7 ++-- sgl-router/src/tokenizer/cache/l0.rs | 10 ++++-- sgl-router/src/tokenizer/cache/l1.rs | 14 +++++--- sgl-router/src/tokenizer/cache/mod.rs | 33 +++++++++---------- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/sgl-router/src/tokenizer/cache/fingerprint.rs b/sgl-router/src/tokenizer/cache/fingerprint.rs index 0adafd85012..be8565e5ecf 100644 --- a/sgl-router/src/tokenizer/cache/fingerprint.rs +++ b/sgl-router/src/tokenizer/cache/fingerprint.rs @@ -3,9 +3,12 @@ //! Creates a unique fingerprint of a tokenizer's configuration to detect //! when the tokenizer has changed and the cache needs to be cleared. +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, +}; + use super::super::traits::Tokenizer; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; /// A fingerprint of a tokenizer's configuration #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/sgl-router/src/tokenizer/cache/l0.rs b/sgl-router/src/tokenizer/cache/l0.rs index 768ff42a3aa..203ea5284e3 100644 --- a/sgl-router/src/tokenizer/cache/l0.rs +++ b/sgl-router/src/tokenizer/cache/l0.rs @@ -5,10 +5,14 @@ //! //! Expected hit rate: 60-90% for workloads with repeated system prompts -use super::super::traits::Encoding; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; + use dashmap::DashMap; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; + +use super::super::traits::Encoding; /// L0 cache implementation using DashMap for lock-free reads pub struct L0Cache { diff --git a/sgl-router/src/tokenizer/cache/l1.rs b/sgl-router/src/tokenizer/cache/l1.rs index 39505cb15ba..4bc2ca74a8e 100644 --- a/sgl-router/src/tokenizer/cache/l1.rs +++ b/sgl-router/src/tokenizer/cache/l1.rs @@ -19,12 +19,18 @@ //! Cache points: After each "<|im_end|>" (atomic tokens, guaranteed safe) //! 
Result: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)
 
-use super::super::traits::TokenIdType;
+use std::{
+    mem::size_of,
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc,
+    },
+};
+
 use blake3;
 use dashmap::DashMap;
-use std::mem::size_of;
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Arc;
+
+use super::super::traits::TokenIdType;
 
 /// Hash type for cache keys
 type Blake3Hash = [u8; 32];
diff --git a/sgl-router/src/tokenizer/cache/mod.rs b/sgl-router/src/tokenizer/cache/mod.rs
index c7eea01d618..a0bf27452eb 100644
--- a/sgl-router/src/tokenizer/cache/mod.rs
+++ b/sgl-router/src/tokenizer/cache/mod.rs
@@ -18,14 +18,15 @@ mod fingerprint;
 mod l0;
 mod l1;
 
+use std::sync::Arc;
+
+use anyhow::Result;
 pub use fingerprint::TokenizerFingerprint;
 pub use l0::{CacheStats, L0Cache};
 pub use l1::{L1Cache, L1CacheStats};
+use rayon::prelude::*;
 
 use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer};
-use anyhow::Result;
-use rayon::prelude::*;
-use std::sync::Arc;
 
 /// Configuration for the tokenizer cache
 #[derive(Debug, Clone)]
@@ -157,6 +158,14 @@ impl CachedTokenizer {
 
 impl Encoder for CachedTokenizer {
     fn encode(&self, input: &str) -> Result<Encoding> {
+        // Collect special tokens once if L1 is enabled (avoid redundant allocation)
+        let special_tokens: Option<Vec<&str>> = self.l1.as_ref().map(|_| {
+            self.special_token_strings
+                .iter()
+                .map(|s| s.as_str())
+                .collect()
+        });
+
         // L0 cache lookup (exact match)
         if let Some(l0) = &self.l0 {
             if let Some(cached) = l0.get(input) {
@@ -166,15 +175,9 @@
 
         // L1 cache lookup (prefix match at special token boundaries)
         if let Some(l1) = &self.l1 {
-            let special_tokens: Vec<&str> = self
-                .special_token_strings
-                .iter()
-                .map(|s| s.as_str())
-                .collect();
+            let tokens = special_tokens.as_ref().unwrap();
 
-            if let Some((prefix_tokens, prefix_len)) =
-                l1.longest_prefix_match(input, &special_tokens)
-            {
+            if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, tokens) {
                 // We have a prefix match - tokenize the suffix
                 let suffix = &input[prefix_len..];
                 if !suffix.is_empty() {
@@ -208,12 +211,8 @@ impl Encoder for CachedTokenizer {
         // Cache in L1 at special token boundaries
         // Re-tokenizes prefixes for correctness (optimized for high prefix reuse)
         if let Some(l1) = &self.l1 {
-            let special_tokens: Vec<&str> = self
-                .special_token_strings
-                .iter()
-                .map(|s| s.as_str())
-                .collect();
-            let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), &special_tokens);
+            let tokens = special_tokens.as_ref().unwrap();
+            let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), tokens);
             // Ignore errors in cache insertion - cache is best-effort
         }
 

From 229029bbba5c9acd534e6ef26470b8511b59663f Mon Sep 17 00:00:00 2001
From: Simo Lin
Date: Thu, 16 Oct 2025 19:36:24 -0700
Subject: [PATCH 3/7] fixup

---
 .../tests/tokenizer_cache_correctness_test.rs | 467 ++++++++++++++++++
 1 file changed, 467 insertions(+)
 create mode 100644 sgl-router/tests/tokenizer_cache_correctness_test.rs

diff --git a/sgl-router/tests/tokenizer_cache_correctness_test.rs b/sgl-router/tests/tokenizer_cache_correctness_test.rs
new file mode 100644
index 00000000000..66a1838fef1
--- /dev/null
+++ b/sgl-router/tests/tokenizer_cache_correctness_test.rs
@@ -0,0 +1,467 @@
+//! Cache correctness integration test
+//!
+//! This test validates that the tokenizer cache (L0, L1, and L0+L1 combined) produces
+//! exactly the same token IDs as uncached tokenization across multiple chat turns.
+//! 
Uses the real Qwen/Qwen3-4B-Instruct-2507 tokenizer to test with actual special tokens. + +use std::{ + path::PathBuf, + sync::{Arc, OnceLock}, +}; + +use sglang_router_rs::tokenizer::{ + cache::{CacheConfig, CachedTokenizer}, + hub::download_tokenizer_from_hf, + huggingface::HuggingFaceTokenizer, + traits::Encoder, +}; + +/// Global tokenizer path cache - download once, reuse across all tests +static TOKENIZER_PATH: OnceLock> = OnceLock::new(); + +/// Download Qwen3-4B-Instruct-2507 tokenizer once and cache the path +async fn get_tokenizer_path() -> Option { + // Check if already downloaded + if let Some(cached) = TOKENIZER_PATH.get() { + return cached.clone(); + } + + // Download tokenizer + let result = match download_tokenizer_from_hf("Qwen/Qwen3-4B-Instruct-2507").await { + Ok(cache_dir) => { + let tokenizer_path = cache_dir.join("tokenizer.json"); + if tokenizer_path.exists() { + Some(tokenizer_path) + } else { + println!("Tokenizer downloaded but tokenizer.json not found"); + None + } + } + Err(e) => { + println!("Failed to download tokenizer: {}", e); + None + } + }; + + // Cache the result (even if None, so we don't retry on failure) + TOKENIZER_PATH.set(result.clone()).ok(); + result +} + +/// Comprehensive multi-turn chat conversation for testing cache correctness +/// Uses Qwen's special tokens with diverse content to hit edge cases +const CHAT_TURNS: [&str; 29] = [ + // Basic conversation patterns + "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>", + "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|><|im_start|>user\nWhat is the capital of France?<|im_end|>", + "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|><|im_start|>user\nWhat is the capital of France?<|im_end|><|im_start|>assistant\nThe capital of France is Paris.<|im_end|>", + + // Different system prompts (testing different prefix patterns) + "<|im_start|>system\nYou are a coding tutor specializing in Rust programming.<|im_end|><|im_start|>user\nExplain ownership.<|im_end|>", + "<|im_start|>system\nYou are a math teacher.<|im_end|><|im_start|>user\nSolve: 2x + 5 = 13<|im_end|>", + + // Long conversation with multiple turns (testing longer prefixes) + "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|><|im_start|>user\nTell me about deep learning.<|im_end|><|im_start|>assistant\nDeep learning is a subset of machine learning that uses neural networks with multiple layers.<|im_end|><|im_start|>user\nWhat are the main architectures?<|im_end|>", + + // Code snippets (testing different character patterns) + "<|im_start|>system\nYou are a code reviewer.<|im_end|><|im_start|>user\nReview this code:\nfn main() {\n println!(\"Hello, world!\");\n}\n<|im_end|>", + "<|im_start|>system\nYou are a code reviewer.<|im_end|><|im_start|>user\nExplain this Rust code:\nimpl Drop for Box {\n fn drop(&mut self) { /* ... 
*/ }\n}\n<|im_end|>", + + // Mathematical content + "<|im_start|>system\nYou are a math tutor.<|im_end|><|im_start|>user\nProve that √2 is irrational using proof by contradiction.<|im_end|>", + "<|im_start|>system\nYou are a math tutor.<|im_end|><|im_start|>user\nCalculate: ∫(x² + 3x + 2)dx from 0 to 5<|im_end|>", + + // Multilingual content + "<|im_start|>system\nYou are a multilingual assistant.<|im_end|><|im_start|>user\nTranslate to French: The quick brown fox jumps over the lazy dog.<|im_end|>", + "<|im_start|>system\nYou are a multilingual assistant.<|im_end|><|im_start|>user\n你好,请帮我翻译这句话:I love programming in Rust.<|im_end|>", + "<|im_start|>system\nYou are a multilingual assistant.<|im_end|><|im_start|>user\nこんにちは!Rustについて教えてください。<|im_end|>", + + // Special characters and emojis + "<|im_start|>system\nYou are a friendly chatbot.<|im_end|><|im_start|>user\nWhat do you think about emojis? 😀🎉🚀💻<|im_end|>", + "<|im_start|>system\nYou are a data analyst.<|im_end|><|im_start|>user\nAnalyze this: {\"name\": \"test\", \"value\": 42, \"nested\": {\"key\": \"value\"}}<|im_end|>", + + // Very long message (testing large token counts) + "<|im_start|>system\nYou are a literature expert.<|im_end|><|im_start|>user\nAnalyze the themes in this passage: In the vast expanse of the digital realm, where bits and bytes dance in harmonious symphony, there exists a paradigm that transcends mere computation. This paradigm, known as machine learning, represents humanity's quest to imbue silicon with the spark of cognition. Deep neural networks, inspired by the intricate architecture of biological brains, layer upon layer of artificial neurons, each connection a synapse firing in the dark recesses of mathematical space. Through gradient descent, these networks learn patterns invisible to human perception, extracting meaning from chaos, signal from noise. 
The transformer architecture revolutionized this field, introducing attention mechanisms that allowed models to focus on relevant information, much like how humans selectively attend to important details in their environment.<|im_end|>", + + // Edge case: Multiple special tokens in sequence + "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHi<|im_end|><|im_start|>assistant\nHello!<|im_end|><|im_start|>user\nHow are you?<|im_end|>", + + // Edge case: Empty-ish messages + "<|im_start|>system\n<|im_end|><|im_start|>user\nTest<|im_end|>", + "<|im_start|>system\nBrief.<|im_end|><|im_start|>user\nOK<|im_end|>", + + // Technical documentation style + "<|im_start|>system\nYou are a technical writer.<|im_end|><|im_start|>user\nDocument the following API:\n\n```rust\npub struct CachedTokenizer {\n inner: Arc,\n l0: Option,\n l1: Option,\n}\n\nimpl Encoder for CachedTokenizer {\n fn encode(&self, input: &str) -> Result;\n}\n```\n<|im_end|>", + + // Conversation with code review + "<|im_start|>system\nYou are a senior Rust developer.<|im_end|><|im_start|>user\nReview for correctness:\n\nlet special_tokens: Option> = self.l1.as_ref().map(|_| {\n self.special_token_strings.iter().map(|s| s.as_str()).collect()\n});<|im_end|>", + + // Markdown formatted content + "<|im_start|>system\nYou are a documentation assistant.<|im_end|><|im_start|>user\nFormat this as markdown:\n\n# Cache Architecture\n\n## L0 Cache\n- Exact match\n- DashMap based\n- 10K entries\n\n## L1 Cache \n- Prefix match\n- Special token boundaries\n- 50MB memory\n<|im_end|>", + + // Complex nested structures + "<|im_start|>system\nYou are a JSON expert.<|im_end|><|im_start|>user\nValidate this JSON:\n{\n \"tokenizer_cache\": {\n \"enable_l0\": true,\n \"l0_max_entries\": 10000,\n \"enable_l1\": true,\n \"l1_max_memory\": 52428800,\n \"stats\": {\n \"hits\": [1, 2, 3],\n \"misses\": {\"count\": 5}\n }\n }\n}\n<|im_end|>", + + // SQL queries + "<|im_start|>system\nYou are a database expert.<|im_end|><|im_start|>user\nOptimize this query:\nSELECT u.name, COUNT(p.id) as post_count\nFROM users u\nLEFT JOIN posts p ON u.id = p.user_id\nWHERE u.created_at > '2024-01-01'\nGROUP BY u.id, u.name\nHAVING COUNT(p.id) > 5\nORDER BY post_count DESC;<|im_end|>", + + // Regex patterns + "<|im_start|>system\nYou are a regex expert.<|im_end|><|im_start|>user\nExplain this regex: ^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,}$<|im_end|>", + + // Command line examples + "<|im_start|>system\nYou are a DevOps engineer.<|im_end|><|im_start|>user\nExplain this command:\ncargo bench --bench tokenizer_benchmark -- --color=never | tee results.txt<|im_end|>", + + // Unicode edge cases + "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nTest: café, naïve, Zürich, 北京, 東京, मुंबई, Москва<|im_end|>", + + // Mixed content complexity + "<|im_start|>system\nYou are a software architect.<|im_end|><|im_start|>user\nDesign a caching system that:\n1. Handles 10K+ QPS\n2. Maintains 99.9% uptime \n3. Supports L0 (exact) and L1 (prefix) caching\n4. Uses Blake3 for hashing (10GB/s throughput)\n5. Implements LRU eviction\n6. Thread-safe with lock-free reads\n\nKey requirements:\n- Memory: 50MB L1 budget\n- Latency: <100µs p99\n- Correctness: 100% (no false tokens)\n<|im_end|>", + + // Very long technical discussion + "<|im_start|>system\nYou are a compiler expert.<|im_end|><|im_start|>user\nExplain why BPE tokenizers are not prefix-stable:\n\nThe core issue is that BPE applies merges based on local context. 
When you tokenize 'prefix' alone, it might apply merge rules differently than when tokenizing 'prefix + suffix' as a whole. For example:\n\ntokenize('hello world') might produce [hello, _world]\ntokenize('hello') + tokenize(' world') might produce [hel, lo, _wo, rld]\n\nThis is because the merge rules see different contexts. The space before 'world' in the first case is part of the token boundary, but in the second case, ' world' is tokenized in isolation.\n\nSpecial tokens solve this because they are:\n1. Atomic (never split or merged)\n2. Protected from normalization\n3. Marked with special: true flag\n4. Have normalized: false property\n\nThis guarantees: tokenize(prefix + special + suffix) = tokenize(prefix + special) + tokenize(suffix)\n\nOur L1 cache exploits this by:\n1. Finding all special token boundaries\n2. Re-tokenizing prefixes at those boundaries\n3. Caching the exact token IDs\n4. On cache hit, appending suffix tokens\n\nThis achieves both correctness (100%) and performance (22.7x speedup on high prefix reuse workloads).<|im_end|>", +]; + +#[tokio::test] +async fn test_cache_produces_identical_tokens() { + // Get tokenizer path (download once, cached across tests) + let tokenizer_path = match get_tokenizer_path().await { + Some(path) => path, + None => { + println!("Skipping test - tokenizer not available"); + return; + } + }; + + // Create base tokenizer (no cache) + let base_tokenizer = Arc::new( + HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()) + .expect("Failed to load base tokenizer"), + ); + + // Create cached tokenizers with different configurations + let l0_only_config = CacheConfig { + enable_l0: true, + l0_max_entries: 10_000, + enable_l1: false, + l1_max_memory: 0, + }; + + let l1_only_config = CacheConfig { + enable_l0: false, + l0_max_entries: 0, + enable_l1: true, + l1_max_memory: 50 * 1024 * 1024, + }; + + let l0_l1_config = CacheConfig { + enable_l0: true, + l0_max_entries: 10_000, + enable_l1: true, + l1_max_memory: 50 * 1024 * 1024, + }; + + let l0_tokenizer = Arc::new(CachedTokenizer::new(base_tokenizer.clone(), l0_only_config)); + let l1_tokenizer = Arc::new(CachedTokenizer::new(base_tokenizer.clone(), l1_only_config)); + let l0_l1_tokenizer = Arc::new(CachedTokenizer::new(base_tokenizer.clone(), l0_l1_config)); + + println!( + "\n=== Testing Cache Correctness Across {} Chat Turns ===\n", + CHAT_TURNS.len() + ); + + for (turn_idx, turn) in CHAT_TURNS.iter().enumerate() { + println!("Turn {}: Testing {} chars", turn_idx + 1, turn.len()); + + // Tokenize with base (no cache) + let base_encoding = base_tokenizer + .encode(turn) + .expect("Base tokenization failed"); + let base_tokens = base_encoding.token_ids(); + + // Tokenize with L0-only + let l0_encoding = l0_tokenizer.encode(turn).expect("L0 tokenization failed"); + let l0_tokens = l0_encoding.token_ids(); + + // Tokenize with L1-only + let l1_encoding = l1_tokenizer.encode(turn).expect("L1 tokenization failed"); + let l1_tokens = l1_encoding.token_ids(); + + // Tokenize with L0+L1 + let l0_l1_encoding = l0_l1_tokenizer + .encode(turn) + .expect("L0+L1 tokenization failed"); + let l0_l1_tokens = l0_l1_encoding.token_ids(); + + // Verify all configurations produce identical token IDs + assert_eq!( + base_tokens.len(), + l0_tokens.len(), + "Turn {}: L0 token count mismatch (base: {}, L0: {})", + turn_idx + 1, + base_tokens.len(), + l0_tokens.len() + ); + + assert_eq!( + base_tokens.len(), + l1_tokens.len(), + "Turn {}: L1 token count mismatch (base: {}, L1: {})", + turn_idx + 1, + 
base_tokens.len(), + l1_tokens.len() + ); + + assert_eq!( + base_tokens.len(), + l0_l1_tokens.len(), + "Turn {}: L0+L1 token count mismatch (base: {}, L0+L1: {})", + turn_idx + 1, + base_tokens.len(), + l0_l1_tokens.len() + ); + + // Compare token by token + for (token_idx, (((base_token, l0_token), l1_token), l0_l1_token)) in base_tokens + .iter() + .zip(l0_tokens.iter()) + .zip(l1_tokens.iter()) + .zip(l0_l1_tokens.iter()) + .enumerate() + { + assert_eq!( + base_token, + l0_token, + "Turn {}, token {}: L0 mismatch (base: {}, L0: {})", + turn_idx + 1, + token_idx, + base_token, + l0_token + ); + + assert_eq!( + base_token, + l1_token, + "Turn {}, token {}: L1 mismatch (base: {}, L1: {})", + turn_idx + 1, + token_idx, + base_token, + l1_token + ); + + assert_eq!( + base_token, + l0_l1_token, + "Turn {}, token {}: L0+L1 mismatch (base: {}, L0+L1: {})", + turn_idx + 1, + token_idx, + base_token, + l0_l1_token + ); + } + + println!( + " ✓ All configurations produced identical {} tokens", + base_tokens.len() + ); + } + + // Print cache statistics + if let Some(l0_stats) = l0_tokenizer.cache_stats() { + println!("\n=== L0 Cache Statistics ==="); + println!(" Hits: {}", l0_stats.hits); + println!(" Misses: {}", l0_stats.misses); + println!( + " Hit rate: {:.2}%", + if l0_stats.hits + l0_stats.misses > 0 { + l0_stats.hits as f64 / (l0_stats.hits + l0_stats.misses) as f64 * 100.0 + } else { + 0.0 + } + ); + println!(" Entries: {}", l0_stats.entries); + } + + if let Some(l1_stats) = l1_tokenizer.l1_cache_stats() { + println!("\n=== L1 Cache Statistics ==="); + println!(" Hits: {}", l1_stats.hits); + println!(" Misses: {}", l1_stats.misses); + println!( + " Hit rate: {:.2}%", + if l1_stats.hits + l1_stats.misses > 0 { + l1_stats.hits as f64 / (l1_stats.hits + l1_stats.misses) as f64 * 100.0 + } else { + 0.0 + } + ); + println!(" Entries: {}", l1_stats.entries); + println!(" Memory used: {} bytes", l1_stats.memory_bytes); + } + + if let Some(l0_stats) = l0_l1_tokenizer.cache_stats() { + if let Some(l1_stats) = l0_l1_tokenizer.l1_cache_stats() { + println!("\n=== L0+L1 Combined Cache Statistics ==="); + println!(" L0 Hits: {}", l0_stats.hits); + println!(" L1 Hits: {}", l1_stats.hits); + println!( + " Total Hit rate: {:.2}%", + if l0_stats.hits + l1_stats.hits + l0_stats.misses + l1_stats.misses > 0 { + (l0_stats.hits + l1_stats.hits) as f64 + / (l0_stats.hits + l1_stats.hits + l0_stats.misses + l1_stats.misses) as f64 + * 100.0 + } else { + 0.0 + } + ); + } + } + + println!("\n✓ All cache configurations produce identical tokenization results!"); +} + +#[tokio::test] +async fn test_cache_correctness_with_edge_cases() { + // Get tokenizer path (download once, cached across tests) + let tokenizer_path = match get_tokenizer_path().await { + Some(path) => path, + None => { + println!("Skipping test - tokenizer not available"); + return; + } + }; + + // Create base and cached tokenizers + let base_tokenizer = Arc::new( + HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()) + .expect("Failed to load base tokenizer"), + ); + + let cached_config = CacheConfig { + enable_l0: true, + l0_max_entries: 10_000, + enable_l1: true, + l1_max_memory: 50 * 1024 * 1024, + }; + + let cached_tokenizer = Arc::new(CachedTokenizer::new(base_tokenizer.clone(), cached_config)); + + println!("\n=== Testing Edge Cases and Complex Patterns ===\n"); + + // Edge cases that stress-test the cache + let edge_cases = [ + // Minimal messages + ("<|im_start|>system\n<|im_end|>", "Empty system message"), + 
("<|im_start|>user\na<|im_end|>", "Single character"), + + // Special token boundaries + ("<|im_start|>system\nA<|im_end|><|im_start|>user\nB<|im_end|><|im_start|>assistant\nC<|im_end|>", "Minimal multi-turn"), + + // Repeated exact queries (L0 hit test) + ("<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>", "Repeated query 1"), + ("<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>", "Repeated query 2"), + + // Same prefix, different suffix (L1 hit test) + ("<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 1+1?<|im_end|>", "Same prefix, diff suffix 1"), + ("<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>", "Same prefix, diff suffix 2"), + ("<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 3+3?<|im_end|>", "Same prefix, diff suffix 3"), + + // Unicode stress tests + ("<|im_start|>system\n你好<|im_end|><|im_start|>user\n世界<|im_end|>", "Chinese characters"), + ("<|im_start|>system\nこんにちは<|im_end|><|im_start|>user\n世界<|im_end|>", "Japanese + Chinese"), + ("<|im_start|>system\n🚀💻🎉<|im_end|><|im_start|>user\n😀😃😄<|im_end|>", "Emoji only"), + + // Whitespace edge cases + ("<|im_start|>system\n \n<|im_end|>", "Whitespace only"), + ("<|im_start|>system\n\n\n\n<|im_end|>", "Multiple newlines"), + ("<|im_start|>system\n\t\t\t<|im_end|>", "Tabs"), + + // Long token sequences + ("<|im_start|>system\nThe quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog.<|im_end|>", "Repeated phrase"), + + // Special characters + ("<|im_start|>system\n!@#$%^&*()_+-=[]{}|;':\",./<>?<|im_end|>", "ASCII special chars"), + ("<|im_start|>system\n`~\\<|im_end|>", "Backtick and tilde"), + + // Code with special formatting + ("<|im_start|>system\nCode: fn() -> Result<(), Box><|im_end|>", "Rust generics"), + ("<|im_start|>system\nRegex: ^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$<|im_end|>", "Email regex"), + + // Very long single token sequences (testing buffer handling) + ("<|im_start|>system\naaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa<|im_end|>", "Repeated 'a'"), + ("<|im_start|>system\n0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789<|im_end|>", "Repeated numbers"), + ]; + + let mut test_count = 0; + let mut mismatch_count = 0; + + for (query, description) in edge_cases.iter() { + test_count += 1; + + let base_tokens = base_tokenizer + .encode(query) + .expect("Base encoding failed") + .token_ids() + .to_vec(); + + let cached_tokens = cached_tokenizer + .encode(query) + .expect("Cached encoding failed") + .token_ids() + .to_vec(); + + if base_tokens != cached_tokens { + mismatch_count += 1; + println!(" ✗ {}: Token mismatch!", description); + println!( + " Base length: {}, Cached length: {}", + base_tokens.len(), + cached_tokens.len() + ); + + // Show first few mismatching tokens for debugging + for (i, (base, cached)) in base_tokens.iter().zip(cached_tokens.iter()).enumerate() { + if base != cached { + println!(" Token {}: base={}, cached={}", i, base, cached); + if i >= 5 { + break; + } + } + } + } else { + println!(" ✓ {}: {} tokens", description, base_tokens.len()); + } + } + + assert_eq!( + mismatch_count, 0, + "{} out of {} edge cases failed!", + mismatch_count, test_count + ); + + // Print cache statistics + if let Some(l0_stats) = cached_tokenizer.cache_stats() { + println!("\n=== 
Cache Statistics ==="); + println!( + " L0 Hits: {} ({:.1}% hit rate)", + l0_stats.hits, + if l0_stats.hits + l0_stats.misses > 0 { + l0_stats.hits as f64 / (l0_stats.hits + l0_stats.misses) as f64 * 100.0 + } else { + 0.0 + } + ); + } + + if let Some(l1_stats) = cached_tokenizer.l1_cache_stats() { + println!( + " L1 Hits: {} ({:.1}% hit rate)", + l1_stats.hits, + if l1_stats.hits + l1_stats.misses > 0 { + l1_stats.hits as f64 / (l1_stats.hits + l1_stats.misses) as f64 * 100.0 + } else { + 0.0 + } + ); + } + + println!("\n✓ All {} edge cases passed!", test_count); +} From 451f39fb2b721fd1c94a238332a1cbe20eb49aa7 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 16 Oct 2025 19:41:04 -0700 Subject: [PATCH 4/7] fixup --- sgl-router/src/tokenizer/huggingface.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs index 552b7361a19..5f4c1b375e1 100644 --- a/sgl-router/src/tokenizer/huggingface.rs +++ b/sgl-router/src/tokenizer/huggingface.rs @@ -110,18 +110,13 @@ impl HuggingFaceTokenizer { None }; - // Extract additional special tokens (ChatML tokens like <|im_start|>, <|im_end|>, etc.) - // These are typically in added_tokens with special: true - let mut additional_special_tokens = Vec::new(); - for (token, _id) in vocab.iter() { - // Look for tokens that match common special token patterns - if token.starts_with("<|") && token.ends_with("|>") { - additional_special_tokens.push(token.clone()); - } else if token.starts_with("<|") && token.ends_with(">") { - // Alternative patterns like <|endoftext|> - additional_special_tokens.push(token.clone()); - } - } + // Extract additional special tokens using the tokenizers library API + let additional_special_tokens: Vec = tokenizer + .get_added_tokens_decoder() + .iter() + .filter(|(_id, token)| token.special) // Only tokens marked as special: true + .map(|(_id, token)| token.content.clone()) + .collect(); SpecialTokens { bos_token: find_token(&["", "<|startoftext|>", "", "[CLS]"]), From 56485610a6ffac1eec06540d0a71a339dd5a25f1 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 16 Oct 2025 20:04:28 -0700 Subject: [PATCH 5/7] fixup --- sgl-router/src/tokenizer/factory.rs | 2 +- sgl-router/src/tokenizer/mock.rs | 15 ++++++++++++--- sgl-router/src/tokenizer/tests.rs | 2 +- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs index 8d8cde5f7a6..46cfae3de35 100644 --- a/sgl-router/src/tokenizer/factory.rs +++ b/sgl-router/src/tokenizer/factory.rs @@ -407,7 +407,7 @@ mod tests { #[test] fn test_mock_tokenizer_creation() { let tokenizer = create_tokenizer_from_file("mock").unwrap(); - assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens + assert_eq!(tokenizer.vocab_size(), 14); // Mock tokenizer has 14 tokens } #[test] diff --git a/sgl-router/src/tokenizer/mock.rs b/sgl-router/src/tokenizer/mock.rs index 136f622f8de..8e6abdb86e4 100644 --- a/sgl-router/src/tokenizer/mock.rs +++ b/sgl-router/src/tokenizer/mock.rs @@ -34,6 +34,12 @@ impl MockTokenizer { (".", 6), ("", 999), ("", 1000), + ("<|im_start|>", 1001), + ("<|im_end|>", 1002), + ("<|eot_id|>", 1003), + ("system", 7), + ("user", 8), + ("assistant", 9), ]; for (token, id) in tokens { @@ -62,9 +68,12 @@ impl MockTokenizer { impl Encoder for MockTokenizer { fn encode(&self, input: &str) -> Result { - // Simple character-based tokenization for testing - // Returns one token per character to 
ensure non-empty encodings
-        let tokens: Vec<u32> = input.bytes().map(|b| b as u32).collect();
+        // Simple word-based tokenization using the vocab
+        // Split by whitespace and look up each word (decoder adds spaces back)
+        let tokens: Vec<u32> = input
+            .split_whitespace()
+            .filter_map(|word| self.vocab.get(word).copied())
+            .collect();
 
         Ok(Encoding::Sp(tokens))
     }
diff --git a/sgl-router/src/tokenizer/tests.rs b/sgl-router/src/tokenizer/tests.rs
index acd1587669a..9ca8f60c30b 100644
--- a/sgl-router/src/tokenizer/tests.rs
+++ b/sgl-router/src/tokenizer/tests.rs
@@ -43,7 +43,7 @@ fn test_tokenizer_wrapper() {
     let text = tokenizer.decode(&[1, 2], false).unwrap();
     assert_eq!(text, "Hello world");
 
-    assert_eq!(tokenizer.vocab_size(), 8);
+    assert_eq!(tokenizer.vocab_size(), 14);
 
     assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
     assert_eq!(tokenizer.token_to_id("unknown"), None);

From 4cc7d49cc14f77fdfd6fb77be8624d58c30aece8 Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Sat, 18 Oct 2025 11:57:09 -0700
Subject: [PATCH 6/7] Fix `process_chat_messages` to work with cached tokenizer

---
 sgl-router/src/routers/grpc/utils.rs  | 17 ++++++++++++++---
 sgl-router/src/tokenizer/cache/mod.rs |  5 +++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs
index e39c9ff9e71..d540be443ec 100644
--- a/sgl-router/src/routers/grpc/utils.rs
+++ b/sgl-router/src/routers/grpc/utils.rs
@@ -26,6 +26,7 @@ use crate::{
         generate::GenerateFinishReason,
     },
     tokenizer::{
+        cache::CachedTokenizer,
         chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
         traits::Tokenizer,
         HuggingFaceTokenizer,
@@ -317,9 +318,19 @@ pub fn process_chat_messages(
     tokenizer: &dyn Tokenizer,
 ) -> Result {
     // Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
-    let formatted_text = if let Some(hf_tokenizer) =
-        tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>()
-    {
+    // First try direct downcast, then try via CachedTokenizer wrapper
+    let hf_tokenizer = tokenizer
+        .as_any()
+        .downcast_ref::<HuggingFaceTokenizer>()
+        .or_else(|| {
+            // If direct downcast fails, try to get inner tokenizer from CachedTokenizer
+            tokenizer
+                .as_any()
+                .downcast_ref::<CachedTokenizer>()
+                .and_then(|cached| cached.inner().as_any().downcast_ref::<HuggingFaceTokenizer>())
+        });
+
+    let formatted_text = if let Some(hf_tokenizer) = hf_tokenizer {
         // Get content format and transform messages accordingly
         let content_format = hf_tokenizer.chat_template_content_format();
         let mut transformed_messages = process_content_format(&request.messages, content_format)?;
diff --git a/sgl-router/src/tokenizer/cache/mod.rs b/sgl-router/src/tokenizer/cache/mod.rs
index a0bf27452eb..c3d86ec06a2 100644
--- a/sgl-router/src/tokenizer/cache/mod.rs
+++ b/sgl-router/src/tokenizer/cache/mod.rs
@@ -154,6 +154,11 @@ impl CachedTokenizer {
     pub fn fingerprint(&self) -> &TokenizerFingerprint {
         &self.fingerprint
     }
+
+    /// Get a reference to the inner (wrapped) tokenizer
+    pub fn inner(&self) -> &Arc<dyn Tokenizer> {
+        &self.inner
+    }
 }
 
 impl Encoder for CachedTokenizer {
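The patch above matters because `as_any()` on the wrapper returns the `CachedTokenizer` itself, so a direct `downcast_ref::<HuggingFaceTokenizer>()` fails whenever caching is enabled. A minimal sketch of the recovery pattern at a call site, assuming the crate's public re-exports; the `tokenizer.json` path is a placeholder:

```rust
use std::sync::Arc;

use sglang_router_rs::tokenizer::{
    cache::{CacheConfig, CachedTokenizer},
    huggingface::HuggingFaceTokenizer,
    traits::Tokenizer,
};

fn chat_template_source(tokenizer: &dyn Tokenizer) -> Option<&HuggingFaceTokenizer> {
    // Direct downcast works for a bare HuggingFaceTokenizer...
    tokenizer
        .as_any()
        .downcast_ref::<HuggingFaceTokenizer>()
        // ...otherwise unwrap one level of caching and try again.
        .or_else(|| {
            tokenizer
                .as_any()
                .downcast_ref::<CachedTokenizer>()
                .and_then(|c| c.inner().as_any().downcast_ref::<HuggingFaceTokenizer>())
        })
}

fn main() -> anyhow::Result<()> {
    let hf = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?); // placeholder path
    let cached = CachedTokenizer::new(hf, CacheConfig::default());
    assert!(chat_template_source(&cached).is_some());
    Ok(())
}
```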
From b9665650c794da7c2b89def0fb4c1d422dd7eb03 Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Sat, 18 Oct 2025 12:00:54 -0700
Subject: [PATCH 7/7] fmt files

---
 sgl-router/src/routers/grpc/utils.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs
index d540be443ec..c82d764aa30 100644
--- a/sgl-router/src/routers/grpc/utils.rs
+++ b/sgl-router/src/routers/grpc/utils.rs
@@ -327,7 +327,12 @@ pub fn process_chat_messages(
             tokenizer
                 .as_any()
                 .downcast_ref::<CachedTokenizer>()
-                .and_then(|cached| cached.inner().as_any().downcast_ref::<HuggingFaceTokenizer>())
+                .and_then(|cached| {
+                    cached
+                        .inner()
+                        .as_any()
+                        .downcast_ref::<HuggingFaceTokenizer>()
+                })
         });
 
     let formatted_text = if let Some(hf_tokenizer) = hf_tokenizer {
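Taken together, the series is exercised end to end like this -- a sketch under the same configuration as the correctness test; the `tokenizer.json` path is a placeholder for any model with ChatML-style special tokens, such as the Qwen checkpoint the tests download:

```rust
use std::sync::Arc;

use sglang_router_rs::tokenizer::{
    cache::{CacheConfig, CachedTokenizer},
    huggingface::HuggingFaceTokenizer,
    traits::Encoder,
};

fn main() -> anyhow::Result<()> {
    let base = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?); // placeholder path

    // L0 on by default; L1 opted in, mirroring the correctness test's config.
    let config = CacheConfig {
        enable_l0: true,
        l0_max_entries: 10_000,
        enable_l1: true,
        l1_max_memory: 50 * 1024 * 1024,
    };
    let cached = CachedTokenizer::new(base, config);

    let prefix = "<|im_start|>system\nYou are helpful.<|im_end|>";

    // First call: full tokenization; populates L0 (exact) and L1 (per-boundary prefixes).
    cached.encode(&format!("{prefix}<|im_start|>user\nHi<|im_end|>"))?;
    // Same string again: served from L0.
    cached.encode(&format!("{prefix}<|im_start|>user\nHi<|im_end|>"))?;
    // Shared prefix, new suffix: L1 supplies the prefix token IDs.
    cached.encode(&format!("{prefix}<|im_start|>user\nBye<|im_end|>"))?;

    if let Some(stats) = cached.cache_stats() {
        println!("L0 hits: {}, misses: {}", stats.hits, stats.misses);
    }
    Ok(())
}
```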