Skip to content

[Flex Attention Perf] Backwards cherry-pick for Inductor Autotune refactor #2392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions torch/_inductor/choices.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
from __future__ import annotations

import typing
from typing import Any, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

import sympy

import torch

from . import config
from .codecache import write_text
from .metrics import get_metric_table, is_metric_table_enabled
from .runtime.hints import DeviceProperties, ReductionHint
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
from .template_heuristics import (
BaseConfigHeuristic,
CPUConfigHeuristic,
CUDAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
from .virtualized import V


if TYPE_CHECKING:
import torch
from collections.abc import Generator
from functools import partial

from triton import Config as TritonConfig

from torch.utils._ordered_set import OrderedSet

from .codegen.simd_kernel_features import SIMDKernelFeatures
Expand All @@ -40,6 +53,99 @@ class MyHeuristics(InductorChoices):
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""

def get_config_heuristics(
self, device_type: Optional[str] = "cuda"
) -> BaseConfigHeuristic:
if device_type == "cuda":
if torch.version.hip is None:
return CUDAConfigHeuristic()
else:
return ROCmConfigHeuristic()
elif device_type == "xpu":
return XPUConfigHeuristic()
elif device_type == "cpu":
return CPUConfigHeuristic()
else:
return BaseConfigHeuristic()

# GEMM configs
def get_base_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
return mm_heuristics.get_mm_configs()
else:
return mm_heuristics.get_exhaustive_mm_configs()

def get_extra_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_extra_mm_configs()

def get_int8_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_int8_mm_configs()

def get_mixed_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mixed_mm_configs()

def get_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_persistent_mm_configs()

def get_scaled_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_mm_configs()

def get_scaled_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_persistent_mm_configs()

def get_mm_plus_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mm_plus_mm_configs()

# Conv configs
def get_conv_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
conv_heuristics = self.get_config_heuristics(device_type)
return conv_heuristics.get_conv_configs()

# Flex attention configs
def get_flex_attention_fwd_configs(
self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
) -> list[Any]:
flex_heuristics = self.get_config_heuristics(device_type)
return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype)

def get_flex_attention_bwd_configs(
self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
) -> list[Any]:
flex_heuristics = self.get_config_heuristics(device_type)
return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype)

def get_flex_decode_configs(
self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
) -> list[Any]:
flex_heuristics = self.get_config_heuristics(device_type)
return flex_heuristics.get_flex_decode_configs(head_dim, dtype)

def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],
Expand Down
10 changes: 10 additions & 0 deletions torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,22 @@ def prologue_fusion_enabled() -> bool:
"TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]

# Specify the size of the search space for flex attention autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]

# NOTE: This feature is deprecated and will be defauled to False in the future.
# Whether we fall back to ATen or hard error when no matches are found during autotuning
autotune_fallback_to_aten = (
os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1"
)

# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False

# the value used as a fallback for the unbacked SymInts
# that can appear in the input shapes (e.g., in autotuning)
unbacked_symint_fallback = 8192
Expand Down
24 changes: 15 additions & 9 deletions torch/_inductor/kernel/bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
_is_static_problem,
addmm_epilogue,
mm_args,
mm_configs,
mm_config_kwargs,
mm_options,
should_fallback_to_aten,
)
Expand All @@ -46,12 +46,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**12


def bmm_configs(m, n, k, *, device_type):
if device_type == "cpu":
return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
return mm_configs(m, n, k)


bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
Expand Down Expand Up @@ -184,8 +178,14 @@ def may_require_contiguous(t, meta_t):

# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []

device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)

if use_triton_template(layout):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
Expand Down Expand Up @@ -239,8 +239,14 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_aten_gemm_kernels()
else []
)

device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)

if use_triton_template(layout):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=(inp, mat1, mat2),
Expand Down
56 changes: 9 additions & 47 deletions torch/_inductor/kernel/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import logging
from typing import cast, Optional, TYPE_CHECKING, TypedDict
from typing import Optional, TYPE_CHECKING, TypedDict

import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
Expand All @@ -29,7 +29,7 @@
use_triton_template,
)
from ..virtualized import V
from .mm_common import build_rocm_gemm_configs, filtered_configs
from .mm_common import mm_config_kwargs


if TYPE_CHECKING:
Expand Down Expand Up @@ -61,51 +61,13 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
)


# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
kernel_configs = [
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
{"config": (64, 256, 16, 2, 4), "cond": True},
{"config": (256, 64, 16, 2, 4), "cond": True},
{"config": (1024, 16, 16, 1, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 256, 32, 2, 8), "cond": True},
{"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create filtered list of configs based on conv
platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in kernel_configs
if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip and torch.cuda.is_available():
platform_configs = build_rocm_gemm_configs(platform_configs)


def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
return True
return m * n * k > 2**17


def conv_configs(m, n, k, *, device_type, **kwargs):
if device_type == "cpu":
return filtered_configs(
m,
n,
k,
configs=platform_configs,
scale=0.5,
exclude=_is_large_block_for_cpu,
)
return filtered_configs(m, n, k, configs=platform_configs)


LOOP_BODY_2D = """
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
Expand Down Expand Up @@ -497,6 +459,8 @@ def convolution(
"groups": groups,
}

device_type = ir.get_device_type(x)

if len(x.get_size()) == len(weight.get_size()) - 1:
# add batch dimension to simplify rest of function
return L[aten.squeeze](
Expand All @@ -511,11 +475,7 @@ def convolution(
# Always convert conv1D to 2D for Intel GPU.
# Only conv2D can be converted to channel last layout,
# which have much better performance.
if (
len(x.get_size()) == 3
and len(kernel_shape) == 1
and ir.get_device_type(x) == "xpu"
):
if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
kwargs.update(
{
"stride": (1,) + stride,
Expand Down Expand Up @@ -564,7 +524,7 @@ def channels_last_conv():
):
return convert_1x1_conv_to_mm(x, weight, bias)

if bias is not None and ir.get_device_type(x) != "cpu":
if bias is not None and device_type != "cpu":
# peel off the bias, cudnn is slower with it
result = convolution(x, weight, None, **kwargs)
return L[aten.add](
Expand Down Expand Up @@ -639,11 +599,13 @@ def channels_last_conv():
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))

conv_configs = V.choices.get_conv_configs(device_type)

for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
device_type=ir.get_device_type(x),
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
):
if ndim == 2:
conv2d_template.maybe_append_choice(
Expand Down
Loading