9 changes: 0 additions & 9 deletions specforge/benchmarks/benchmark_flex_attention.py
@@ -111,15 +111,6 @@ def run_attention(
loss = output[0].sum()
loss_list.append(loss)

if attention_backend == "sdpa" and not is_last:
# Step 5.7: we need to update the loss mask
ind = torch.arange(seq_len, device=decoder_attention_mask.device)
ind0 = ind[idx:]
ind1 = ind[: seq_len - idx]
decoder_attention_mask[:, :, ind0, ind1] = torch.finfo(
decoder_attention_mask.dtype
).min

# Compute mean loss and backward pass
if loss_list:
mean_loss = sum(loss_list) / len(loss_list)
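Note for reviewers (not part of the diff): the block removed above, and the matching ones removed from `specforge/core/eagle3.py` and the tests below, re-masked the SDPA attention mask once per step. A standalone sketch of what those lines did, with purely illustrative shapes:

```python
import torch

# Illustrative shapes only; seq_len and idx stand in for the benchmark's loop values.
seq_len, idx = 6, 2
decoder_attention_mask = torch.zeros(1, 1, seq_len, seq_len)

# The deleted update wrote the dtype's minimum value along the diagonal `idx` rows
# below the main one, i.e. entries (idx, 0), (idx+1, 1), ..., (seq_len-1, seq_len-idx-1),
# shrinking the visible causal region by one more band each step.
ind = torch.arange(seq_len, device=decoder_attention_mask.device)
ind0 = ind[idx:]
ind1 = ind[: seq_len - idx]
decoder_attention_mask[:, :, ind0, ind1] = torch.finfo(decoder_attention_mask.dtype).min
```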
16 changes: 0 additions & 16 deletions specforge/core/eagle3.py
@@ -276,14 +276,6 @@ def forward(
input_ids = padding(input_ids, left=False)
position_mask = padding(position_mask, left=False)
loss_mask = padding(loss_mask, left=False)
if self.attention_backend == "sdpa":
ind = torch.arange(seq_length, device=attention_mask.device)
ind0 = ind[idx:]
ind1 = ind[: seq_length - idx]
attention_mask[:, :, ind0, ind1] = torch.finfo(
attention_mask.dtype
).min
# Flex attention mask shrinking is handled inside the attention module
return plosses, vlosses, acces


@@ -658,14 +650,6 @@ def forward(
input_ids = padding(input_ids, left=False)
position_mask = padding(position_mask, left=False)
loss_mask = padding(loss_mask, left=False)
if self.attention_backend == "sdpa":
ind = torch.arange(seq_length, device=attention_mask.device)
ind0 = ind[idx:]
ind1 = ind[: seq_length - idx]
attention_mask[:, :, ind0, ind1] = torch.finfo(
attention_mask.dtype
).min
# Flex attention mask shrinking is handled inside the attention module
return plosses, vlosses, acces


8 changes: 4 additions & 4 deletions specforge/modeling/draft/flex_attention.py
@@ -109,14 +109,14 @@ def compile_friendly_create_block_mask(


def generate_eagle3_mask(
seq_lengths: torch.Tensor, Q_LEN: int, KV_LEN: int, shift_left: int = 0
seq_lengths: torch.Tensor, Q_LEN: int, KV_LEN: int, lck: int = 0
):

def causal_mask(b, h, q_idx, kv_idx):
# Causal will keep shrinking by 1 diagonal due to the appended suffix
# Shrink the causal region by one diagonal
causal_mask = q_idx - shift_left >= kv_idx
padding_mask = kv_idx < seq_lengths[b]
causal_mask = q_idx >= kv_idx
padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b])
return causal_mask & padding_mask

def suffix_mask(b, h, q_idx, kv_idx):
@@ -126,5 +126,5 @@ def suffix_mask(b, h, q_idx, kv_idx):
return suffix_mask & padding_mask & diagnol_mask

mask_mod = or_masks(causal_mask, suffix_mask)
mask_mod.__name__ = f"eagle3_mask_Q_{Q_LEN}_KV_{KV_LEN}_shift_left_{shift_left}"
mask_mod.__name__ = f"eagle3_mask_Q_{Q_LEN}_KV_{KV_LEN}_lck_{lck}"
return mask_mod
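Reviewer's note (not part of the diff): the causal term is now a plain `q_idx >= kv_idx`, and the shrinking that `shift_left` used to express is carried by `seq_lengths` instead; callers subtract `lck` before building the mask (as the updated test below does), and `lck` only disambiguates the compiled mask's name. A minimal sketch of driving the new mask, assuming PyTorch's stock `create_block_mask` in place of the repo's `compile_friendly_create_block_mask` wrapper and purely illustrative sizes:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

from specforge.modeling.draft.flex_attention import generate_eagle3_mask

B, S, lck = 1, 1024, 256  # illustrative sizes, roughly mirroring the test below
Q_LEN, KV_LEN = S, S * 3

# New contract: shrink the effective length up front instead of shifting the causal diagonal.
seq_lengths = torch.tensor([S], device="cuda", dtype=torch.int32) - lck

mask_mod = generate_eagle3_mask(seq_lengths=seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck)
block_mask = create_block_mask(mask_mod, B=B, H=1, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device="cuda")

print(block_mask.to_dense().shape)  # block-granular view, as the test below inspects
```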
2 changes: 1 addition & 1 deletion specforge/modeling/draft/llama3_eagle.py
@@ -621,7 +621,7 @@ def forward(
seq_lengths=seq_lengths,
Q_LEN=q_len,
KV_LEN=key_cache.shape[-2],
shift_left=lck,
lck=lck,
),
B=bsz,
H=1, # Rely on broadcast
5 changes: 4 additions & 1 deletion specforge/utils.py
@@ -50,7 +50,10 @@ def load_config_from_file(config_path: str):


def print_with_rank(message):
logger.info(f"rank {dist.get_rank()}: {message}")
if dist.is_available() and dist.is_initialized():
logger.info(f"rank {dist.get_rank()}: {message}")
else:
logger.info(f"non-distributed: {message}")


def print_on_rank0(message):
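Aside (not part of the diff): with the new guard, `print_with_rank` can be called from single-process scripts as well as under `torchrun`. A minimal usage sketch, assuming the helper is importable from `specforge.utils`:

```python
import torch.distributed as dist

from specforge.utils import print_with_rank

# Single process: no process group is initialized, so the "non-distributed" branch logs.
print_with_rank("loading tokenizer")

# Distributed run (e.g. launched with torchrun): once the process group exists,
# each rank tags its own line, e.g. "rank 0: loading tokenizer".
if dist.is_available() and dist.is_initialized():
    print_with_rank("loading tokenizer")
```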
35 changes: 12 additions & 23 deletions tests/test_utils/test_flex_attention.py
@@ -134,14 +134,7 @@ def _test_forward_pass_comparison_for_seq_len(self, seq_len):
atol=1e-2,
rtol=1e-2,
)
if not is_last:
# Step 5.7: we need to update the loss mask
ind = torch.arange(seq_len, device=decoder_attention_mask.device)
ind0 = ind[idx:]
ind1 = ind[: seq_len - idx]
decoder_attention_mask[:, :, ind0, ind1] = torch.finfo(
decoder_attention_mask.dtype
).min

# Check output shape
expected_output_shape = (batch_size, seq_len, self.config.hidden_size)
self.assertEqual(output_flex.shape, expected_output_shape)
@@ -238,12 +231,6 @@ def _test_backward_pass_gradient_comparison_for_seq_len(self, seq_len):

if not is_last:
# Step 5.7: we need to update the loss mask
ind = torch.arange(seq_len, device=decoder_attention_mask.device)
ind0 = ind[idx:]
ind1 = ind[: seq_len - idx]
decoder_attention_mask[:, :, ind0, ind1] = torch.finfo(
decoder_attention_mask.dtype
).min
loss_mask = padding(loss_mask, left=False)
mean_loss = sum(loss_list) / len(loss_list)
mean_loss_flex = sum(loss_flex_list) / len(loss_flex_list)
@@ -268,14 +255,16 @@ def test_eagle3_flex_mask(self):
D = 128
Q_LEN = S
KV_LEN = S * 3
lck = 128 * 2
data_type = torch.bfloat16
query = norm_tensor((B, H, S, D), device="cuda", dtype=data_type)
key_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type)
value_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type)
seq_lengths = torch.tensor([S], device="cuda", dtype=torch.int32)
seq_lengths -= lck
block_mask = compile_friendly_create_block_mask(
mask_mod=generate_eagle3_mask(
seq_lengths=seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN, shift_left=128 * 2
seq_lengths=seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck
),
B=1,
H=1,
@@ -285,14 +274,14 @@
)
# fmt: off
expected_mask = torch.tensor([[[
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]]], dtype=torch.int32).to(query.device)
# fmt: on
dense_mask = block_mask.to_dense()