diff --git a/configs/deepseek-v2-lite-eagle3.json b/configs/deepseek-v2-lite-eagle3.json
index 9ddad46d..da12c0fb 100644
--- a/configs/deepseek-v2-lite-eagle3.json
+++ b/configs/deepseek-v2-lite-eagle3.json
@@ -25,7 +25,7 @@
     "mscale": 0.707,
     "mscale_all_dim": 0.707,
     "original_max_position_embeddings": 4096,
-    "type": "yarn"
+    "rope_type": "yarn"
   },
   "rope_theta": 10000,
   "sliding_window": null,
diff --git a/examples/run_deepseek_v2_lite_eagle3_online.sh b/examples/run_deepseek_v2_lite_eagle3_online.sh
new file mode 100644
index 00000000..449dd10f
--- /dev/null
+++ b/examples/run_deepseek_v2_lite_eagle3_online.sh
@@ -0,0 +1,21 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+# train eagle3 for deepseek-v2-lite
+NUM_GPUS=${1:-8}
+
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3_online.py \
+    --target-model-path DeepSeek-V2-Lite \
+    --draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \
+    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
+    --output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3 \
+    --num-epochs 10 \
+    --batch-size 1 \
+    --tp-size 1 \
+    --learning-rate 1e-4 \
+    --max-length 2048 \
+    --chat-template deepseek \
+    --cache-dir $ROOT_DIR/cache \
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 22e36a94..3b900a09 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -337,6 +337,116 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(
+    num_rotations, dim, base=10000, max_position_embeddings=2048
+):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(
+    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+    low = math.floor(
+        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    )
+    high = math.ceil(
+        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    )
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min_val, max_val, dim):
+    if min_val == max_val:
+        max_val += 0.001  # Prevent singularity
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (
+        max_val - min_val
+    )
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (
+            self.base
+            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+        freq_inter = 1.0 / (
+            self.scaling_factor
+            * self.base
+            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
+            device=device, dtype=torch.float32
+        )
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer(
+            "cos_cached",
+            (emb.cos() * _mscale)[None, None, :, :].to(dtype),
+            persistent=False,
+        )
+        self.register_buffer(
+            "sin_cached",
+            (emb.sin() * _mscale)[None, None, :, :].to(dtype),
+            persistent=False,
+        )
+
+
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -397,6 +507,19 @@ def _init_rope(self):
             self.rotary_emb = LlamaMutiRotaryEmbedding(
                 self.head_dim, max_position_embeddings=self.max_position_embeddings
             )
+        elif scaling_type == "yarn":
+            self.rotary_emb = LlamaYarnRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                original_max_position_embeddings=self.config.rope_scaling[
+                    "original_max_position_embeddings"
+                ],
+                scaling_factor=self.config.rope_scaling["factor"],
+                beta_fast=self.config.rope_scaling["beta_fast"],
+                beta_slow=self.config.rope_scaling["beta_slow"],
+                mscale=self.config.rope_scaling["mscale"],
+                mscale_all_dim=self.config.rope_scaling["mscale_all_dim"],
+            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
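The config change renames the `rope_scaling` key from `"type"` to `"rope_type"`, the key name newer `transformers` configs use. The sketch below is a hypothetical compatibility helper, not part of this patch (the function name and fallback value are invented), showing how a loader could accept either spelling:

```python
def get_rope_scaling_type(rope_scaling: dict) -> str:
    # Prefer the newer "rope_type" key, fall back to the legacy "type" key.
    return rope_scaling.get("rope_type") or rope_scaling.get("type", "default")


assert get_rope_scaling_type({"rope_type": "yarn"}) == "yarn"
assert get_rope_scaling_type({"type": "yarn"}) == "yarn"
```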
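The new `LlamaYarnRotaryEmbedding._set_cos_sin_cache` blends two sets of inverse frequencies: `freq_extra` (unscaled, kept for high-frequency dimensions) and `freq_inter` (divided by the scaling factor, used for low-frequency dimensions), joined by a linear ramp between the `beta_fast`/`beta_slow` correction bounds. Below is a standalone sketch of that blending using `rope_theta=10000` and `original_max_position_embeddings=4096` from the config and the helpers' default beta values; the head dimension of 64 and scaling factor of 40.0 are illustrative assumptions, not values taken from this patch:

```python
import math

import torch

dim, base, factor = 64, 10000, 40.0  # dim and factor are assumed for the demo
beta_fast, beta_slow, orig_max_pos = 32, 1, 4096

exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
freq_extra = 1.0 / (base**exponents)           # unscaled (extrapolation) frequencies
freq_inter = 1.0 / (factor * base**exponents)  # position-interpolated frequencies


def correction_dim(num_rotations: float) -> float:
    # Inverse dim formula from the patch: the dimension whose wavelength makes
    # `num_rotations` full rotations within the original context window.
    return (dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


low = max(math.floor(correction_dim(beta_fast)), 0)
high = min(math.ceil(correction_dim(beta_slow)), dim - 1)

# Linear ramp: 0 below `low` (keep freq_extra), 1 above `high` (use freq_inter).
ramp = torch.clamp(
    (torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1
)
inv_freq_mask = 1.0 - ramp
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask

print(inv_freq[:3] / freq_extra[:3])    # ~1.0: high-frequency dims untouched
print(inv_freq[-3:] / freq_extra[-3:])  # ~1/factor: low-frequency dims interpolated
```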