-
Notifications
You must be signed in to change notification settings - Fork 71
[Flex Attention Perf] Backwards cherry-pick for Inductor Autotune refactor #2392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flex Attention Perf] Backwards cherry-pick for Inductor Autotune refactor #2392
Conversation
…pytorch#147452) This change was reverted in pytorch#147388 for regressing an internal workload. I have removed the additional ir.device_type calls in mm_scaled and unpack_mixed_mm.py which could be contributing to the additional compile time. Pull Request resolved: pytorch#147452 Approved by: https://github.com/jansel (cherry picked from commit 32299e5)
Jenkins build for 820aa47fd4b2cf232f11c42ee9f8ae84b9ef63af commit finished as FAILURE |
Some model results: After tuning cherry pick | TORCHINDUCTOR_FLEX_SEARCH_SPACE="EXHAUSTIVE" | score_mod bench
|
Jenkins build for 77b19c13455ae93e0ffd70bff7d33b261df438ac commit finished as FAILURE |
Looks like some failures to address : AttributeError: 'GemmConfig' object has no attribute '_replace' |
Replaces pytorch#143286 Adds ROCm specific MM configs for max-autotune incorporating ROCm specific triton tuning kernargs such as waves_per_eu, kpack, matrix_instr_nonkdim. This PR also introduces behavior to allow tuning for GROUP_M in triton gemm case. Dynamo huggingface inference benchmarks: `TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="TRITON" python huggingface.py --performance --inference --bfloat16 --backend=inductor` GEOMEAN speedup (before): | 1.35x GEOMEAN speedup (after): | 1.42x name | Eager - abs latency | old - abs_latency | old - speedup | new - abs_latency | new - speedup -- | -- | -- | -- | -- | -- AlbertForMaskedLM | 26.22 | 26.52 | 98.86% | 24.58 | 106.67% AlbertForQuestionAnswering | 25.96 | 26.40 | 98.33% | 24.10 | 107.73% AllenaiLongformerBase | 21.03 | 10.65 | 197.50% | 10.49 | 200.58% BartForCausalLM | 7.77 | 9.76 | 79.63% | 8.79 | 88.46% BartForConditionalGeneration | 14.44 | 12.86 | 112.26% | 11.96 | 120.70% BertForMaskedLM | 8.10 | 8.82 | 91.89% | 8.57 | 94.53% BertForQuestionAnswering | 6.82 | 7.32 | 93.20% | 7.10 | 96.18% BlenderbotForCausalLM | 10.97 | 11.39 | 96.34% | 10.10 | 108.65% BlenderbotSmallForCausalLM | 5.91 | 5.44 | 108.72% | 4.82 | 122.67% BlenderbotSmallForConditionalGeneration | 12.64 | 9.65 | 130.94% | 9.11 | 138.83% CamemBert | 8.35 | 9.15 | 91.24% | 8.86 | 94.27% DebertaForMaskedLM | 10.92 | 6.09 | 179.44% | 5.90 | 185.05% DebertaForQuestionAnswering | 14.29 | 7.70 | 185.59% | 7.26 | 196.75% DebertaV2ForMaskedLM | 15.47 | 10.22 | 151.32% | 9.34 | 165.55% DebertaV2ForQuestionAnswering | 14.98 | 6.11 | 245.28% | 6.28 | 238.40% DistilBertForMaskedLM | 8.37 | 8.70 | 96.30% | 8.22 | 101.92% DistilBertForQuestionAnswering | 10.21 | 10.54 | 96.88% | 10.39 | 98.36% DistillGPT2 | 8.77 | 6.78 | 129.40% | 6.31 | 138.88% ElectraForCausalLM | 10.32 | 4.70 | 219.45% | 4.60 | 224.29% ElectraForQuestionAnswering | 11.48 | 5.62 | 204.20% | 5.44 | 210.95% GPT2ForSequenceClassification | 6.21 | 5.72 | 108.50% | 5.58 | 111.26% GoogleFnet | 26.51 | 20.81 | 127.37% | 19.91 | 133.11% LayoutLMForMaskedLM | 12.09 | 7.99 | 151.28% | 7.66 | 157.80% LayoutLMForSequenceClassification | 10.62 | 6.49 | 163.67% | 6.25 | 169.95% M2M100ForConditionalGeneration | 14.98 | 10.20 | 146.79% | 9.89 | 151.42% MBartForCausalLM | 7.67 | 9.78 | 78.44% | 8.87 | 86.55% MBartForConditionalGeneration | 13.45 | 12.69 | 105.99% | 12.03 | 111.82% MT5ForConditionalGeneration | 19.96 | 5.32 | 375.37% | 5.08 | 393.01% MegatronBertForCausalLM | 13.22 | 7.86 | 168.07% | 7.18 | 184.01% MegatronBertForQuestionAnswering | 15.62 | 11.81 | 132.21% | 11.02 | 141.68% MobileBertForMaskedLM | 26.63 | 10.82 | 245.99% | 11.95 | 222.73% MobileBertForQuestionAnswering | 23.53 | 7.55 | 311.51% | 9.53 | 247.03% OPTForCausalLM | 7.33 | 7.64 | 95.93% | 7.56 | 96.90% PLBartForCausalLM | 8.73 | 7.63 | 114.40% | 7.37 | 118.58% PLBartForConditionalGeneration | 10.46 | 8.50 | 122.98% | 8.16 | 128.13% PegasusForCausalLM | 7.18 | 7.37 | 97.42% | 6.64 | 108.22% PegasusForConditionalGeneration | 16.47 | 16.66 | 98.87% | 14.18 | 116.13% RobertaForCausalLM | 10.30 | 9.95 | 103.52% | 9.52 | 108.25% RobertaForQuestionAnswering | 6.37 | 7.13 | 89.28% | 6.79 | 93.87% T5ForConditionalGeneration | 12.40 | 6.72 | 184.51% | 6.48 | 191.16% T5Small | 12.02 | 6.66 | 180.55% | 6.32 | 190.33% TrOCRForCausalLM | 14.12 | 13.31 | 106.11% | 12.45 | 113.41% XGLMForCausalLM | 16.48 | 6.23 | 264.52% | 6.35 | 259.51% XLNetLMHeadModel | 74.87 | 62.23 | 120.32% | 57.95 | 129.19% YituTechConvBert | 20.21 | 10.50 | 192.48% | 9.97 | 202.72% We are also seeing improvement ~9% on internal addmm benchmark This PR will also slightly reduce the compilation time on AMD max-autotune as before this change we assess every config with matrix_instr_nonkdim [0, 16] but we remove this and use 16 for all configs with this update. No CI to test the max-autotune perf currently but this will be enabled via pytorch#148672 after which we can investigate more tuning updates and config pruning Pull Request resolved: pytorch#147315 Approved by: https://github.com/jansel, https://github.com/eellison (cherry picked from commit 2299087)
…n mm and addmm (pytorch#150587) Summary: This PR introduces additional autotuning configurations for the persistent+TMA version of Triton `mm` and `addmm` operations. The new configurations are as follows: * `(128, 128, 64, 5, 8)` * `(256, 128, 64, 4, 8)` * `(128, 128, 64, 5, 4)` These configurations were selected based on exhaustive autotuning performed on commonly used shapes from an internal foundational model. While these new configs are generally more performant across the board, we see notable gains a few specific cases: * In scenarios where `n >> m, k`, the configurations `(128, 128, 64, 5, 8)` and `(256, 128, 64, 4, 8)` tend to produce an additional 5-10% speedup over the aten baseline compared to the original configurations. * Similarly, the configuration `(128, 128, 64, 5, 4)` yields approximately an 8% improvement in scenarios where k >> m, n. These enhancements are expected to provide performance benefits across diverse use cases, particularly when compared to the original set of configurations. Test Plan: contbuild & OSS CI Reviewers: paulzhan Pull Request resolved: pytorch#150587 Approved by: https://github.com/PaulZhang12, https://github.com/drisspg, https://github.com/eellison (cherry picked from commit 5acc3e2)
This PR primarily unifies the flex attention config logic with the GEMM/Conv config approach pytorch#147452 this will make it much easier to handle optimisation pathways for particular triton backends. This PR also introduces: 1. Introduces an exhaustive tuning mode for flex attention via TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE="EXHAUSTIVE" to allow for wide scale benchmarking for perf investigation use cases. 3. Updates configs for ROCm flex autotune path providing perf optimisations AMD perf numbers on score mod benchmark (default inputs) flex_attn | mode | Speedup (Avg) | Speedup (Max) -- | -- | -- | -- fwd | autotune before PR | 2.608 | 20.56 fwd | autotune after PR | 2.862 | 22 fwd | exhaustive_autotune | 2.943 | 22.471 bwd | autotune before PR | 2.196 | 9.831 bwd | autotune after PR | 2.423 | 11.331 bwd | exhaustive_autotune | 2.566 | 13.87 Pull Request resolved: pytorch#156307 Approved by: https://github.com/drisspg, https://github.com/jansel (cherry picked from commit 03023f1)
77b19c1
to
b4a6b64
Compare
Jenkins build for b4a6b648d054a1d3a90c87c7b355884b0bec40a9 commit finished as FAILURE |
@pruthvistony @jithunnair-amd failures are not related please help merge |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tested. already in 2.8
Required for flex attention and gemm improvements