
[Tutorial] Fix 06-fused-attention.py of FP8 provider #7043


Open. whitneywhtsang wants to merge 4 commits into main from the 06-fused-attention branch.

Conversation

@whitneywhtsang (Collaborator) commented Jun 4, 2025

When the provider is fp8, v is permuted as shown below, and its new stride is (H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX).

        if mode == "fwd" and "fp8" in provider:
            v = v.permute(0, 1, 3, 2).contiguous()  # materialize the transposed (Z, H, HEAD_DIM, N_CTX) layout
            v = v.permute(0, 1, 3, 2)  # view back as (Z, H, N_CTX, HEAD_DIM); memory stays transposed
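
The double permute keeps v's logical shape while transposing its memory layout. A minimal standalone check (with hypothetical shapes, not the benchmark's actual sizes) confirms the stride quoted above:

import torch

# Hypothetical shapes, chosen only to illustrate the layout.
Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64
v = torch.randn(Z, H, N_CTX, HEAD_DIM)
v = v.permute(0, 1, 3, 2).contiguous()  # materialize (Z, H, HEAD_DIM, N_CTX)
v = v.permute(0, 1, 3, 2)               # logical shape back to (Z, H, N_CTX, HEAD_DIM)

assert v.shape == (Z, H, N_CTX, HEAD_DIM)
assert v.stride() == (H * N_CTX * HEAD_DIM, N_CTX * HEAD_DIM, 1, N_CTX)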

This PR fixes the FP8 dtype handling in the fused-attention kernel by separating k and v offset calculations and updating related configuration details. Key changes include:

  • Renaming and separating offset variables for k and v computations.
  • Adjusting offset calculation for FP8 dtype and updating the tensor descriptor creation.
  • Expanding configuration options for BLOCK_N and refining device-specific configuration conditions.

@Copilot (Copilot AI) left a comment

Comments suppressed due to low confidence (4)

python/tutorials/06-fused-attention.py:55

  • [nitpick] The variable 'offsetk_y' now clearly denotes the key tensor offset. Consider updating any adjacent comments or documentation to clarify the separation between key and value offsets to enhance readability.
offsetk_y = offset_y + lo

python/tutorials/06-fused-attention.py:56

  • [nitpick] It would be beneficial to add a comment explaining why offsetv_y is computed as 'offset_y * HEAD_DIM + lo' for the FP8 dtype, especially in light of the new stride requirements.
if dtype == tl.float8e5:

python/tutorials/06-fused-attention.py:171

  • [nitpick] Consider adding an inline comment to explain the purpose and expected behavior of FP8_OUTPUT in controlling the descriptor configuration, to aid future maintenance.
if FP8_OUTPUT:

python/tutorials/06-fused-attention.py:130

  • [nitpick] Consider including a brief comment explaining the significance of checking for device capability 9 and the BLOCK_M * BLOCK_N threshold to improve clarity for future maintainers.
and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
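
Read together, the suppressed comments outline the shape of the fix: k keeps the old row offset, while for fp8 the v offset is scaled by HEAD_DIM to address the transposed layout. A rough sketch reconstructed from the quoted lines (the non-fp8 branch and the surrounding kernel code are assumptions, not the exact diff):

# Sketch only, reconstructed from the review comments above; not the exact diff.
offsetk_y = offset_y + lo
if dtype == tl.float8e5:
    # fp8: v has stride (..., 1, N_CTX), so its row offset is scaled by HEAD_DIM
    offsetv_y = offset_y * HEAD_DIM + lo
else:
    # assumed non-fp8 branch: same row offset as k
    offsetv_y = offset_y + lo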

@whitneywhtsang changed the title from "Fix 06-fused-attention.py of FP8 dtype" to "Fix 06-fused-attention.py of FP8 provider" on Jun 4, 2025
@whitneywhtsang marked this pull request as ready for review on June 4, 2025 02:02
@whitneywhtsang requested a review from ptillet as a code owner on June 4, 2025 02:02
@whitneywhtsang changed the title from "Fix 06-fused-attention.py of FP8 provider" to "[Tutorial] Fix 06-fused-attention.py of FP8 provider" on Jun 4, 2025
@whitneywhtsang force-pushed the 06-fused-attention branch 2 times, most recently from 1ccc3f4 to 5450875 on June 4, 2025 19:20
@peterbell10 (Contributor) left a comment

Thanks, can you add an fp8 case to test_op?

@whitneywhtsang (Collaborator, Author)

> Thanks, can you add an fp8 case to test_op?

It appears that test_op is written for bwd, and from looking at bench_flash_attention, there is no difference between the triton-fp16 and triton-fp8 providers for bwd.

@peterbell10 (Contributor)

It tests both forward and backward:

# triton implementation
tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)

@whitneywhtsang force-pushed the 06-fused-attention branch 2 times, most recently from 5ec5364 to 00c9235 on June 5, 2025 16:24
@whitneywhtsang (Collaborator, Author)

@peterbell10 I get RuntimeError: "baddbmm_cuda" not implemented for 'Float8_e5m2'. Any suggestions on how to get a reference output for fp8?
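
One possible workaround (a suggestion, not necessarily what the PR adopts): upcast the fp8 inputs and compute the reference in fp16, since PyTorch provides no Float8_e5m2 matmul/baddbmm kernels. A minimal sketch, where reference_attention is a hypothetical helper rather than code from the tutorial:

import torch

def reference_attention(q, k, v, sm_scale, causal=False):
    # Upcast: matmul/baddbmm are not implemented for Float8_e5m2.
    q, k, v = (t.to(torch.float16) for t in (q, k, v))
    p = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    if causal:
        n = q.shape[-2]
        mask = torch.tril(torch.ones(n, n, dtype=torch.bool, device=q.device))
        p = p.masked_fill(~mask, float("-inf"))
    p = torch.softmax(p.float(), dim=-1).to(torch.float16)
    return torch.matmul(p, v)

The fp8 forward test could then compare the Triton output against this upcast reference with a relaxed tolerance.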
