[Tutorial] Fix 06-fused-attention.py of FP8 provider #7043
base: main
Conversation
Pull Request Overview
This PR fixes the FP8 dtype handling in the fused-attention kernel by separating key and value offset calculations and updating related configuration details. Key changes include:
- Renaming and separating offset variables for key and value computations.
- Adjusting offset calculation for FP8 dtype and updating the tensor descriptor creation.
- Expanding configuration options for BLOCK_N and refining device-specific configuration conditions.
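For reference, a minimal sketch of the separated offset computation (the variable names match the quoted lines below; the else branch is an assumption about the non-FP8 path):

offsetk_y = offset_y + lo
if dtype == tl.float8e5:
    # for FP8 the value tensor is stored transposed, so its offset is scaled by HEAD_DIM
    offsetv_y = offset_y * HEAD_DIM + lo
else:
    offsetv_y = offset_y + lo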
Comments suppressed due to low confidence (4)
python/tutorials/06-fused-attention.py:55
- [nitpick] The variable 'offsetk_y' now clearly denotes the key tensor offset. Consider updating any adjacent comments or documentation to clarify the separation between key and value offsets to enhance readability.
offsetk_y = offset_y + lo
python/tutorials/06-fused-attention.py:56
- [nitpick] It would be beneficial to add a comment explaining why offsetv_y is computed as 'offset_y * HEAD_DIM + lo' for the FP8 dtype, especially in light of the new stride requirements.
if dtype == tl.float8e5:
python/tutorials/06-fused-attention.py:171
- [nitpick] Consider adding an inline comment to explain the purpose and expected behavior of FP8_OUTPUT in controlling the descriptor configuration, to aid future maintenance.
if FP8_OUTPUT:
python/tutorials/06-fused-attention.py:130
- [nitpick] Consider including a brief comment explaining the significance of checking for device capability 9 and the BLOCK_M * BLOCK_N threshold to improve clarity for future maintainers.
and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
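For context, that condition comes from the tutorial's autotune config filter. A self-contained sketch is below; the helper names keep and is_cuda, and the num_warps check, are assumptions based on the tutorial, not the exact diff.

import torch

def is_cuda():
    # simplified stand-in for the tutorial's backend check
    return torch.cuda.is_available()

def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    # On compute capability 9 (Hopper), small tiles combined with 8 warps are
    # assumed to underutilize the hardware, so such configs are dropped.
    return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9
                and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)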
Force-pushed from 1ccc3f4 to 5450875
Thanks, can you add an fp8 case to test_op?
It appears that it already does: triton/python/tutorials/06-fused-attention.py, lines 613 to 620 (at 5450875), tests both forward and backward.
Signed-off-by: Whitney Tsang <[email protected]>
Signed-off-by: Whitney Tsang <[email protected]>
Force-pushed from 5ec5364 to 00c9235
Signed-off-by: Whitney Tsang <[email protected]>
Force-pushed from 00c9235 to c324519
@peterbell10
Signed-off-by: Whitney Tsang <[email protected]>
When the provider is fp8, v is permuted as illustrated below, and the new stride is (H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX). This PR fixes the FP8 dtype handling in the fused-attention kernel by separating the k and v offset calculations and updating related configuration details. Key changes include renaming and separating the offset variables for the k and v computations.
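A minimal illustration of that layout (shapes chosen arbitrarily; the permutation is assumed to mirror the tutorial's fp8 path):

import torch

Z, H, N_CTX, HEAD_DIM = 2, 4, 1024, 64
v = torch.randn(Z, H, N_CTX, HEAD_DIM)

# Materialize the transposed layout, then view it back with the original shape:
# the logical shape stays (Z, H, N_CTX, HEAD_DIM) but the strides become
# (H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX).
v = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
print(v.shape)     # torch.Size([2, 4, 1024, 64])
print(v.stride())  # (262144, 65536, 1, 1024)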