From 755a9edfe7852fe504f75a9000aa9f9870efd400 Mon Sep 17 00:00:00 2001
From: lyh
Date: Sat, 4 Oct 2025 00:51:46 +0800
Subject: [PATCH 1/2] fix: remove attention mask shift & add pe shift

---
 specforge/core/eagle3.py | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 183c2f42..075fbb37 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -276,13 +276,7 @@ def forward(
             input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
-            if self.attention_backend == "sdpa":
-                ind = torch.arange(seq_length, device=attention_mask.device)
-                ind0 = ind[idx:]
-                ind1 = ind[: seq_length - idx]
-                attention_mask[:, :, ind0, ind1] = torch.finfo(
-                    attention_mask.dtype
-                ).min
+            position_ids += 1
             # Flex attention mask shirnking is handled inside attention module
 
         return plosses, vlosses, acces
@@ -658,13 +652,7 @@ def forward(
             input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
-            if self.attention_backend == "sdpa":
-                ind = torch.arange(seq_length, device=attention_mask.device)
-                ind0 = ind[idx:]
-                ind1 = ind[: seq_length - idx]
-                attention_mask[:, :, ind0, ind1] = torch.finfo(
-                    attention_mask.dtype
-                ).min
+            position_ids += 1
             # Flex attention mask shirnking is handled inside attention module
 
         return plosses, vlosses, acces

From 49f73cf18ec4f31ffd098085359d523e8d531b42 Mon Sep 17 00:00:00 2001
From: lyh
Date: Sat, 4 Oct 2025 10:33:19 +0800
Subject: [PATCH 2/2] fix: Remove pe shift

---
 specforge/core/eagle3.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 075fbb37..a7ef992f 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -276,7 +276,6 @@ def forward(
             input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
-            position_ids += 1
             # Flex attention mask shirnking is handled inside attention module
 
         return plosses, vlosses, acces
@@ -652,7 +651,6 @@ def forward(
            input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
-            position_ids += 1
             # Flex attention mask shirnking is handled inside attention module
 
         return plosses, vlosses, acces
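
Note (not part of the patch): across both commits, the SDPA-specific mask shift is removed and the temporary `position_ids += 1` added in the first commit is reverted, leaving mask shrinking entirely to the attention module. For context, below is a minimal standalone sketch of what the removed SDPA branch computed, using the same variable names as the dropped hunk; `seq_length`, `idx`, and the mask shape are illustrative values, not taken from the repository.

# Sketch of the removed SDPA mask-shift logic. At TTT step `idx`, the
# additive attention mask's idx-th subdiagonal is set to the dtype's minimum,
# so each query position can no longer attend to the token `idx` steps behind it.
import torch

seq_length = 6
idx = 2  # hypothetical TTT step
attention_mask = torch.zeros(1, 1, seq_length, seq_length)

ind = torch.arange(seq_length, device=attention_mask.device)
ind0 = ind[idx:]                 # row indices idx .. seq_length-1
ind1 = ind[: seq_length - idx]   # column indices 0 .. seq_length-idx-1
attention_mask[:, :, ind0, ind1] = torch.finfo(attention_mask.dtype).min

print(attention_mask[0, 0])      # minimum values appear on the idx-th subdiagonal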