configs/deepseek-v2-lite-eagle3.json (1 addition, 1 deletion)
@@ -25,7 +25,7 @@
"mscale": 0.707,
"mscale_all_dim": 0.707,
"original_max_position_embeddings": 4096,
"type": "yarn"
"rope_type": "yarn"
},
"rope_theta": 10000,
"sliding_window": null,
examples/run_deepseek_v2_lite_eagle3_online.sh (new file, 21 additions)
@@ -0,0 +1,21 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

# train eagle3 for deepseek-v2-lite
NUM_GPUS=${1:-8}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_online.py \
--target-model-path DeepSeek-V2-Lite \
--draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
--output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3 \
--num-epochs 10 \
--batch-size 1 \
--tp-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template deepseek \
    --cache-dir $ROOT_DIR/cache
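
Usage note: the script takes the GPU count as its only positional argument (default 8), so `bash examples/run_deepseek_v2_lite_eagle3_online.sh 4` launches the same job on 4 GPUs. It expects the DeepSeek-V2-Lite weights to be resolvable via the `DeepSeek-V2-Lite` path passed to --target-model-path, and the ShareGPT data to already exist at cache/dataset/sharegpt.jsonl under the repository root.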
specforge/modeling/draft/llama3_eagle.py (123 additions)
@@ -337,6 +337,116 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)


# Find dim range bounds based on rotations
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (
max_val - min_val
)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func


class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
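    """Rotary embedding with YaRN scaling.

    Blends the original ("extrapolated") inverse frequencies with the
    position-interpolated ones per rotary dimension, using a linear ramp
    between the correction bounds derived from beta_fast and beta_slow, and
    scales the cached cos/sin tables by the YaRN attention factor computed
    via yarn_get_mscale.
    """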

def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
self.scaling_factor = scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
dim = self.dim

freq_extra = 1.0 / (
self.base
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
freq_inter = 1.0 / (
self.scaling_factor
* self.base
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)

low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
device=device, dtype=torch.float32
)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(seq_len, device=device, dtype=torch.float32)

freqs = torch.outer(t, inv_freq)

_mscale = float(
yarn_get_mscale(self.scaling_factor, self.mscale)
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)

emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
"cos_cached",
(emb.cos() * _mscale)[None, None, :, :].to(dtype),
persistent=False,
)
self.register_buffer(
"sin_cached",
(emb.sin() * _mscale)[None, None, :, :].to(dtype),
persistent=False,
)


class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

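To make the NTK-by-parts blend implemented by the helpers above concrete, the following sketch imports them and reproduces the inverse-frequency mix for one illustrative setting (rotary dim 64, base 10000, original context 4096, beta_fast 32, beta_slow 1, scaling factor 4; these numbers are illustrative and, except where shown, not taken from this diff). Pair dimensions below `low` keep the original frequencies, those above `high` are fully interpolated, and the ramp mixes the two in between.

import torch

# Assumes the specforge package is importable.
from specforge.modeling.draft.llama3_eagle import (
    yarn_find_correction_range,
    yarn_linear_ramp_mask,
)

dim, base, orig_ctx = 64, 10000, 4096
scaling_factor, beta_fast, beta_slow = 4.0, 32, 1

low, high = yarn_find_correction_range(beta_fast, beta_slow, dim, base, orig_ctx)
print(low, high)  # 10 23 for these illustrative numbers

freq_extra = 1.0 / base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
freq_inter = freq_extra / scaling_factor

# Weight of the original ("extra") frequencies: 1 below `low`, 0 above `high`.
extra_weight = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
inv_freq = freq_inter * (1 - extra_weight) + freq_extra * extra_weight

assert torch.allclose(inv_freq[:low], freq_extra[:low])              # untouched dims
assert torch.allclose(inv_freq[high + 1:], freq_inter[high + 1:])    # interpolated dims
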
@@ -397,6 +507,19 @@ def _init_rope(self):
self.rotary_emb = LlamaMutiRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings
)
elif scaling_type == "yarn":
self.rotary_emb = LlamaYarnRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
original_max_position_embeddings=self.config.rope_scaling[
"original_max_position_embeddings"
],
scaling_factor=self.config.rope_scaling["factor"],
beta_fast=self.config.rope_scaling["beta_fast"],
beta_slow=self.config.rope_scaling["beta_slow"],
mscale=self.config.rope_scaling["mscale"],
mscale_all_dim=self.config.rope_scaling["mscale_all_dim"],
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

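One observation on the scaling path, based on the values visible in the config hunk above: _set_cos_sin_cache multiplies the cached cos/sin by yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim), and because the draft config sets mscale == mscale_all_dim == 0.707 that ratio is exactly 1.0, so for this config only the frequency interpolation alters the embedding. The snippet below just evaluates the attention factor itself for a few illustrative scales (the config's actual "factor" value is not visible in this diff):

from specforge.modeling.draft.llama3_eagle import yarn_get_mscale

# YaRN attention factor 0.1 * mscale * log(scale) + 1 (1.0 for scale <= 1),
# evaluated with mscale = 0.707 as in the config above.
for scale in (1.0, 4.0, 40.0):
    print(scale, yarn_get_mscale(scale, 0.707))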