Skip to content

[Flex Attention Perf] Backwards cherry-pick for Inductor Autotune refactor #2392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 11, 2025

Conversation

jataylo
Copy link

@jataylo jataylo commented Jul 22, 2025

Required for flex attention and gemm improvements

…pytorch#147452)

This change was reverted in pytorch#147388 for regressing an internal workload.

I have removed the additional ir.device_type calls in mm_scaled and unpack_mixed_mm.py which could be contributing to the additional compile time.

Pull Request resolved: pytorch#147452
Approved by: https://github.com/jansel

(cherry picked from commit 32299e5)
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jul 22, 2025

Jenkins build for 820aa47fd4b2cf232f11c42ee9f8ae84b9ef63af commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@jataylo jataylo changed the title Cherry-pick for Autotune refactor Cherry-pick for Inductor Autotune refactor Jul 25, 2025
@jataylo
Copy link
Author

jataylo commented Jul 25, 2025

Some model results:
Old tuning |
Model | MI300 - flex call ms
model1 | 8.61
model2 | 18.168
model3 | 11.257
model4 | 10.301
model5 | 19.0303

After tuning cherry pick |
Model | MI300 - flex call ms
model1 | 8.4089
model2 | 19.076
model3 | 11.5422
model4 | 10.1494
model5 | 18.9276

TORCHINDUCTOR_FLEX_SEARCH_SPACE="EXHAUSTIVE" |
Model | MI300 - flex call ms
model1 | 7.151
model2 | 17.1463
model3 | 10.1236
model4 | 9.119
model5 | 16.8396

score_mod bench

flex_attn mode Speedup (Avg) Speedup (Max)
fwd autotune before PR 2.608 20.56
fwd autotune after PR 2.862 22
fwd exhaustive_autotune 2.943 22.471
bwd autotune before PR 2.196 9.831
bwd autotune after PR 2.423 11.331
bwd exhaustive_autotune 2.566 13.87

@jataylo jataylo marked this pull request as ready for review July 25, 2025 09:15
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jul 25, 2025

Jenkins build for 77b19c13455ae93e0ffd70bff7d33b261df438ac commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@jataylo
Copy link
Author

jataylo commented Jul 30, 2025

Looks like some failures to address : AttributeError: 'GemmConfig' object has no attribute '_replace'

@jataylo jataylo marked this pull request as draft July 30, 2025 14:10
jataylo and others added 3 commits August 10, 2025 22:38
Replaces pytorch#143286

Adds ROCm specific MM configs for max-autotune incorporating ROCm specific triton tuning kernargs such as waves_per_eu, kpack, matrix_instr_nonkdim. This PR also introduces behavior to allow tuning for GROUP_M in triton gemm case.

Dynamo huggingface inference benchmarks:
`TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="TRITON" python huggingface.py --performance --inference --bfloat16 --backend=inductor`

GEOMEAN speedup (before): | 1.35x
GEOMEAN speedup (after): | 1.42x

name | Eager - abs latency | old - abs_latency | old - speedup | new - abs_latency | new - speedup
-- | -- | -- | -- | -- | --
AlbertForMaskedLM | 26.22 | 26.52 | 98.86% | 24.58 | 106.67%
AlbertForQuestionAnswering | 25.96 | 26.40 | 98.33% | 24.10 | 107.73%
AllenaiLongformerBase | 21.03 | 10.65 | 197.50% | 10.49 | 200.58%
BartForCausalLM | 7.77 | 9.76 | 79.63% | 8.79 | 88.46%
BartForConditionalGeneration | 14.44 | 12.86 | 112.26% | 11.96 | 120.70%
BertForMaskedLM | 8.10 | 8.82 | 91.89% | 8.57 | 94.53%
BertForQuestionAnswering | 6.82 | 7.32 | 93.20% | 7.10 | 96.18%
BlenderbotForCausalLM | 10.97 | 11.39 | 96.34% | 10.10 | 108.65%
BlenderbotSmallForCausalLM | 5.91 | 5.44 | 108.72% | 4.82 | 122.67%
BlenderbotSmallForConditionalGeneration | 12.64 | 9.65 | 130.94% | 9.11 | 138.83%
CamemBert | 8.35 | 9.15 | 91.24% | 8.86 | 94.27%
DebertaForMaskedLM | 10.92 | 6.09 | 179.44% | 5.90 | 185.05%
DebertaForQuestionAnswering | 14.29 | 7.70 | 185.59% | 7.26 | 196.75%
DebertaV2ForMaskedLM | 15.47 | 10.22 | 151.32% | 9.34 | 165.55%
DebertaV2ForQuestionAnswering | 14.98 | 6.11 | 245.28% | 6.28 | 238.40%
DistilBertForMaskedLM | 8.37 | 8.70 | 96.30% | 8.22 | 101.92%
DistilBertForQuestionAnswering | 10.21 | 10.54 | 96.88% | 10.39 | 98.36%
DistillGPT2 | 8.77 | 6.78 | 129.40% | 6.31 | 138.88%
ElectraForCausalLM | 10.32 | 4.70 | 219.45% | 4.60 | 224.29%
ElectraForQuestionAnswering | 11.48 | 5.62 | 204.20% | 5.44 | 210.95%
GPT2ForSequenceClassification | 6.21 | 5.72 | 108.50% | 5.58 | 111.26%
GoogleFnet | 26.51 | 20.81 | 127.37% | 19.91 | 133.11%
LayoutLMForMaskedLM | 12.09 | 7.99 | 151.28% | 7.66 | 157.80%
LayoutLMForSequenceClassification | 10.62 | 6.49 | 163.67% | 6.25 | 169.95%
M2M100ForConditionalGeneration | 14.98 | 10.20 | 146.79% | 9.89 | 151.42%
MBartForCausalLM | 7.67 | 9.78 | 78.44% | 8.87 | 86.55%
MBartForConditionalGeneration | 13.45 | 12.69 | 105.99% | 12.03 | 111.82%
MT5ForConditionalGeneration | 19.96 | 5.32 | 375.37% | 5.08 | 393.01%
MegatronBertForCausalLM | 13.22 | 7.86 | 168.07% | 7.18 | 184.01%
MegatronBertForQuestionAnswering | 15.62 | 11.81 | 132.21% | 11.02 | 141.68%
MobileBertForMaskedLM | 26.63 | 10.82 | 245.99% | 11.95 | 222.73%
MobileBertForQuestionAnswering | 23.53 | 7.55 | 311.51% | 9.53 | 247.03%
OPTForCausalLM | 7.33 | 7.64 | 95.93% | 7.56 | 96.90%
PLBartForCausalLM | 8.73 | 7.63 | 114.40% | 7.37 | 118.58%
PLBartForConditionalGeneration | 10.46 | 8.50 | 122.98% | 8.16 | 128.13%
PegasusForCausalLM | 7.18 | 7.37 | 97.42% | 6.64 | 108.22%
PegasusForConditionalGeneration | 16.47 | 16.66 | 98.87% | 14.18 | 116.13%
RobertaForCausalLM | 10.30 | 9.95 | 103.52% | 9.52 | 108.25%
RobertaForQuestionAnswering | 6.37 | 7.13 | 89.28% | 6.79 | 93.87%
T5ForConditionalGeneration | 12.40 | 6.72 | 184.51% | 6.48 | 191.16%
T5Small | 12.02 | 6.66 | 180.55% | 6.32 | 190.33%
TrOCRForCausalLM | 14.12 | 13.31 | 106.11% | 12.45 | 113.41%
XGLMForCausalLM | 16.48 | 6.23 | 264.52% | 6.35 | 259.51%
XLNetLMHeadModel | 74.87 | 62.23 | 120.32% | 57.95 | 129.19%
YituTechConvBert | 20.21 | 10.50 | 192.48% | 9.97 | 202.72%

We are also seeing improvement ~9% on internal addmm benchmark

This PR will also slightly reduce the compilation time on AMD max-autotune as before this change we assess every config with matrix_instr_nonkdim [0, 16] but we remove this and use 16 for all configs with this update.

No CI to test the max-autotune perf currently but this will be enabled via pytorch#148672 after which we can investigate more tuning updates and config pruning

Pull Request resolved: pytorch#147315
Approved by: https://github.com/jansel, https://github.com/eellison

(cherry picked from commit 2299087)
…n mm and addmm (pytorch#150587)

Summary:
This PR introduces additional autotuning configurations for the persistent+TMA version of Triton `mm` and `addmm` operations. The new configurations are as follows:
* `(128, 128, 64, 5, 8)`
* `(256, 128, 64, 4, 8)`
* `(128, 128, 64, 5, 4)`

These configurations were selected based on exhaustive autotuning performed on commonly used shapes from an internal foundational model.

While these new configs are generally more performant across the board, we see notable gains a few specific cases:
* In scenarios where `n >> m, k`, the configurations `(128, 128, 64, 5, 8)` and `(256, 128, 64, 4, 8)` tend to produce an additional 5-10% speedup over the aten baseline compared to the original configurations.
* Similarly, the configuration `(128, 128, 64, 5, 4)` yields approximately an 8% improvement in scenarios where k >> m, n.

These enhancements are expected to provide performance benefits across diverse use cases, particularly when compared to the original set of configurations.

Test Plan:
contbuild & OSS CI

Reviewers: paulzhan

Pull Request resolved: pytorch#150587
Approved by: https://github.com/PaulZhang12, https://github.com/drisspg, https://github.com/eellison

(cherry picked from commit 5acc3e2)
This PR primarily unifies the flex attention config logic with the GEMM/Conv config approach pytorch#147452 this will make it much easier to handle optimisation pathways for particular triton backends.

This PR also introduces:
1. Introduces an exhaustive tuning mode for flex attention via TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE="EXHAUSTIVE" to allow for wide scale benchmarking for perf investigation use cases.
3. Updates configs for ROCm flex autotune path providing perf optimisations

AMD perf numbers on score mod benchmark (default inputs)
flex_attn | mode | Speedup (Avg) | Speedup (Max)
-- | -- | -- | --
fwd | autotune before PR | 2.608 | 20.56
fwd | autotune after PR | 2.862 | 22
fwd | exhaustive_autotune | 2.943 | 22.471
bwd | autotune before PR | 2.196 | 9.831
bwd | autotune after PR | 2.423 | 11.331
bwd | exhaustive_autotune | 2.566 | 13.87

Pull Request resolved: pytorch#156307
Approved by: https://github.com/drisspg, https://github.com/jansel

(cherry picked from commit 03023f1)
@jataylo jataylo force-pushed the autotune-refactor-27-cp branch from 77b19c1 to b4a6b64 Compare August 10, 2025 22:49
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Aug 10, 2025

Jenkins build for b4a6b648d054a1d3a90c87c7b355884b0bec40a9 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@jataylo
Copy link
Author

jataylo commented Aug 11, 2025

@pruthvistony @jithunnair-amd failures are not related please help merge

@jataylo jataylo marked this pull request as ready for review August 11, 2025 14:34
@jataylo jataylo changed the title Cherry-pick for Inductor Autotune refactor [Flex Attention Perf] Backwards cherry-pick for Inductor Autotune refactor Aug 11, 2025
Copy link
Collaborator

@pruthvistony pruthvistony left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tested. already in 2.8

@pruthvistony pruthvistony merged commit 2975a90 into ROCm:release/2.7 Aug 11, 2025
1 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants