diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index ce7e941ee1ff..00de22393abf 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -1,20 +1,33 @@ from __future__ import annotations import typing -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import sympy +import torch + from . import config from .codecache import write_text from .metrics import get_metric_table, is_metric_table_enabled from .runtime.hints import DeviceProperties, ReductionHint from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse +from .template_heuristics import ( + BaseConfigHeuristic, + CPUConfigHeuristic, + CUDAConfigHeuristic, + ROCmConfigHeuristic, + XPUConfigHeuristic, +) from .virtualized import V if TYPE_CHECKING: - import torch + from collections.abc import Generator + from functools import partial + + from triton import Config as TritonConfig + from torch.utils._ordered_set import OrderedSet from .codegen.simd_kernel_features import SIMDKernelFeatures @@ -40,6 +53,99 @@ class MyHeuristics(InductorChoices): torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()) """ + def get_config_heuristics( + self, device_type: Optional[str] = "cuda" + ) -> BaseConfigHeuristic: + if device_type == "cuda": + if torch.version.hip is None: + return CUDAConfigHeuristic() + else: + return ROCmConfigHeuristic() + elif device_type == "xpu": + return XPUConfigHeuristic() + elif device_type == "cpu": + return CPUConfigHeuristic() + else: + return BaseConfigHeuristic() + + # GEMM configs + def get_base_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + if config.max_autotune_gemm_search_space != "EXHAUSTIVE": + return mm_heuristics.get_mm_configs() + else: + return mm_heuristics.get_exhaustive_mm_configs() + + def get_extra_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_extra_mm_configs() + + def get_int8_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_int8_mm_configs() + + def get_mixed_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_mixed_mm_configs() + + def get_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_persistent_mm_configs() + + def get_scaled_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_scaled_mm_configs() + + def get_scaled_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_scaled_persistent_mm_configs() + + def get_mm_plus_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return 
mm_heuristics.get_mm_plus_mm_configs() + + # Conv configs + def get_conv_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + conv_heuristics = self.get_config_heuristics(device_type) + return conv_heuristics.get_conv_configs() + + # Flex attention configs + def get_flex_attention_fwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype) + + def get_flex_attention_bwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype) + + def get_flex_decode_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_decode_configs(head_dim, dtype) + def triton_kernel_kwargs( self, kernel_cls: type[TritonKernel], diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2b05bb24d747..ed930ad296a9 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -397,12 +397,22 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +# Specify the size of the search space for flex attention autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" +).upper() # type: ignore[assignment] + # NOTE: This feature is deprecated and will be defauled to False in the future. # Whether we fall back to ATen or hard error when no matches are found during autotuning autotune_fallback_to_aten = ( os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1" ) +# DEPRECATED. This setting is ignored. 
+autotune_fallback_to_aten = False + # the value used as a fallback for the unbacked SymInts # that can appear in the input shapes (e.g., in autotuning) unbacked_symint_fallback = 8192 diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index b0b1e07787a6..c3886111cb02 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -24,7 +24,7 @@ _is_static_problem, addmm_epilogue, mm_args, - mm_configs, + mm_config_kwargs, mm_options, should_fallback_to_aten, ) @@ -46,12 +46,6 @@ def _is_large_block_for_cpu(m, n, k): return m * n > 2**12 -def bmm_configs(m, n, k, *, device_type): - if device_type == "cpu": - return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu) - return mm_configs(m, n, k) - - bmm_template = TritonTemplate( name="bmm", grid=bmm_grid, @@ -184,8 +178,14 @@ def may_require_contiguous(t, meta_t): # options to tune from choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + if use_triton_template(layout): - for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)): + for config in bmm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): bmm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2), @@ -239,8 +239,14 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): if use_aten_gemm_kernels() else [] ) + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + if use_triton_template(layout): - for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)): + for config in bmm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): bmm_template.maybe_append_choice( choices, input_nodes=(inp, mat1, mat2), diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 3d74718a01eb..9e6c5e8d42b8 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import cast, Optional, TYPE_CHECKING, TypedDict +from typing import Optional, TYPE_CHECKING, TypedDict import torch from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate @@ -29,7 +29,7 @@ use_triton_template, ) from ..virtualized import V -from .mm_common import build_rocm_gemm_configs, filtered_configs +from .mm_common import mm_config_kwargs if TYPE_CHECKING: @@ -61,31 +61,6 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv): ) -# List of dictionaries to store the kernel configs. 
Configs that evaluate to true -# will be utilised on the target platform -kernel_configs = [ - # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" - {"config": (64, 256, 16, 2, 4), "cond": True}, - {"config": (256, 64, 16, 2, 4), "cond": True}, - {"config": (1024, 16, 16, 1, 8), "cond": True}, - {"config": (128, 128, 32, 2, 8), "cond": True}, - {"config": (64, 64, 32, 2, 4), "cond": True}, - {"config": (64, 256, 32, 2, 8), "cond": True}, - {"config": (256, 64, 32, 2, 8), "cond": True}, -] - -# Create filtered list of configs based on conv -platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in kernel_configs - if config["cond"] -) - -# On ROCm convert num_stages to 1 as pipelining provides no benefit -if torch.version.hip and torch.cuda.is_available(): - platform_configs = build_rocm_gemm_configs(platform_configs) - - def _is_large_block_for_cpu(m, n, k): # Thresholds are experimentally determined to reduce Triton CPU compile times if m > 256 or n > 256 or k > 256: @@ -93,19 +68,6 @@ def _is_large_block_for_cpu(m, n, k): return m * n * k > 2**17 -def conv_configs(m, n, k, *, device_type, **kwargs): - if device_type == "cpu": - return filtered_configs( - m, - n, - k, - configs=platform_configs, - scale=0.5, - exclude=_is_large_block_for_cpu, - ) - return filtered_configs(m, n, k, configs=platform_configs) - - LOOP_BODY_2D = """ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W @@ -497,6 +459,8 @@ def convolution( "groups": groups, } + device_type = ir.get_device_type(x) + if len(x.get_size()) == len(weight.get_size()) - 1: # add batch dimension to simplify rest of function return L[aten.squeeze]( @@ -511,11 +475,7 @@ def convolution( # Always convert conv1D to 2D for Intel GPU. # Only conv2D can be converted to channel last layout, # which have much better performance. - if ( - len(x.get_size()) == 3 - and len(kernel_shape) == 1 - and ir.get_device_type(x) == "xpu" - ): + if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu": kwargs.update( { "stride": (1,) + stride, @@ -564,7 +524,7 @@ def channels_last_conv(): ): return convert_1x1_conv_to_mm(x, weight, bias) - if bias is not None and ir.get_device_type(x) != "cpu": + if bias is not None and device_type != "cpu": # peel off the bias, cudnn is slower with it result = convolution(x, weight, None, **kwargs) return L[aten.add]( @@ -639,11 +599,13 @@ def channels_last_conv(): ): choices.append(aten_conv1x1_via_mm.bind(args, layout)) + conv_configs = V.choices.get_conv_configs(device_type) + for cfg in conv_configs( sympy_product([x.get_size()[0], *x.get_size()[2:]]), out_chan, in_chan, - device_type=ir.get_device_type(x), + **mm_config_kwargs(device_type, _is_large_block_for_cpu), ): if ndim == 2: conv2d_template.maybe_append_choice( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 4a3abe58d9f9..b1bd86590fa0 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -18,7 +18,6 @@ from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import ValueRanges -from .. 
import config from ..ir import ( Buffer, ComputedBuffer, @@ -782,102 +781,6 @@ class Mode(Enum): bwd = auto() -def _get_rocm_config(query, mode: Mode) -> tuple[int, int, int, int]: - dtype = query.get_dtype() - head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) - fwd_config = None - - if mode == Mode.fwd: - if head_dim <= 256: - if dtype == torch.float32: - fwd_config = (64, 64, 4, 1) - else: - fwd_config = (128, 64, 8, 1) - fwd_config = _rocm_default_config.get((dtype, head_dim), fwd_config) - else: # modest hardware or extremely large head_dim - if dtype == torch.float32: - fwd_config = (32, 16, 4, 1) - else: - fwd_config = (64, 32, 4, 1) - return fwd_config - else: # bwd - assert mode == Mode.bwd - if dtype == torch.float32: - return (16, 16, 4, 1) - elif head_dim <= 256: - if head_dim == 64: - return (64, 64, 4, 1) - elif head_dim == 128: - return (64, 128, 8, 1) - else: - return (64, 64, 4, 1) - else: # modest hardware or extremely large head_dim - return (16, 16, 4, 1) - - -def _get_nv_config(query, mode: Mode) -> tuple[int, int, int, int]: - dtype = query.get_dtype() - head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) - fwd_config = None - bwd_config = None - capability = torch.cuda.get_device_capability() - - if mode == Mode.fwd: - if head_dim <= 256: - if dtype == torch.float32: - fwd_config = (64, 64, 4, 3) - else: - fwd_config = (128, 64, 4, 3) - if capability >= (9, 0): - fwd_config = _h100_default_config.get((dtype, head_dim), fwd_config) - elif capability >= (8, 0): - fwd_config = _a100_default_config.get((dtype, head_dim), fwd_config) - else: # modest hardware or extremely large head_dim - if dtype == torch.float32: - fwd_config = (32, 16, 4, 3) - else: - fwd_config = (64, 32, 4, 3) - return fwd_config - - else: # bwd - assert mode == Mode.bwd - if dtype == torch.float32: - bwd_config = (16, 16, 4, 1) - elif head_dim <= 256 and capability >= (9, 0): # H100 - if head_dim == 64: - bwd_config = (64, 64, 4, 3) - elif head_dim == 128: - bwd_config = (64, 128, 8, 3) - else: - bwd_config = (64, 64, 4, 2) - elif capability >= (8, 0): - if head_dim >= 64: - bwd_config = (32, 128, 4, 3) - elif head_dim == 128: - # SM86/89 have smaller shared memory sizes - num_stages = 3 if capability[-1] == 0 else 2 - bwd_config = (64, 64, 4, num_stages) - else: - bwd_config = (64, 64, 4, 2) - else: # modest hardware or extremely large head_dim - bwd_config = (16, 16, 4, 1) - return bwd_config - - -def _get_default_config_fwd(query) -> tuple[int, int, int, int]: - if torch.version.hip is None: - return _get_nv_config(query, mode=Mode.fwd) - else: - return _get_rocm_config(query, mode=Mode.fwd) - - -def _get_default_config_bwd(query) -> tuple[int, int, int, int]: - if torch.version.hip is None: - return _get_nv_config(query, mode=Mode.bwd) - else: - return _get_rocm_config(query, mode=Mode.bwd) - - def create_num_blocks_fake_generator(sparse_indices): # The idea here is that we need to create a real tensor with real data # that's representative for benchmarking. 
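Illustrative sketch (not part of the patch): the per-device flex attention defaults deleted above (_get_rocm_config, _get_nv_config, _get_default_config_fwd/_bwd) now live in template_heuristics.py and are reached through the InductorChoices handler added in choices.py, so they can be overridden without touching the lowering. A minimal override could look like the following, assuming the handler API shown earlier in this diff; the FlexConfig values are illustrative, not tuned:

    from torch._inductor.choices import InductorChoices
    from torch._inductor.template_heuristics import FlexConfig
    from torch._inductor.virtualized import V


    class MyFlexHeuristics(InductorChoices):
        def get_flex_attention_fwd_configs(self, head_dim, dtype, device_type="cuda"):
            # Example only: pin large head dims to a single conservative tile,
            # otherwise defer to the built-in per-device heuristics.
            if head_dim > 128:
                return [FlexConfig(block_m=32, block_n=32, num_stages=1, num_warps=4)]
            return super().get_flex_attention_fwd_configs(head_dim, dtype, device_type)


    V.set_choices_handler(MyFlexHeuristics())

The flex_attention lowering in the next hunk then consumes these objects via conf.block_m, conf.block_n, conf.num_warps and conf.num_stages instead of unpacking bare tuples.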
@@ -1462,35 +1365,28 @@ def flex_attention( set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) choices: list[Any] = [] - configs: list[tuple[int, int, int, int]] = [] - configs.append(_get_default_config_fwd(query)) - if config.max_autotune: - configs += [ - (128, 64, 4, 3), - (128, 128, 4, 3), - (128, 128, 8, 2), - (64, 128, 4, 3), - (64, 64, 4, 3), - ] - # On ROCm convert num_stages to 1 to avoid shmem issues - if torch.version.hip: - configs = [(c[0], c[1], c[2], 1) for c in configs] + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype) # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) - # ROCm specific considerations - if torch.version.hip: - kernel_options["kpack"] = 2 - # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. original_kernel_options = kernel_options.copy() - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: - if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0: + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + ): if len(configs) == 1: raise ValueError( f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We " @@ -1508,14 +1404,30 @@ def flex_attention( cur_kernel_options[k[4:]] = v if k.startswith("bwd_"): cur_kernel_options.pop(k) - cur_kernel_options.setdefault("num_stages", num_stages) - cur_kernel_options.setdefault("num_warps", num_warps) - cur_kernel_options.setdefault("BLOCK_M", BLOCK_M) - cur_kernel_options.setdefault("BLOCK_N", BLOCK_N) + + cur_kernel_options.setdefault("num_stages", conf.num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Disabling TMA by default, only explicit kernel_options supported for now + cur_kernel_options.setdefault("USE_TMA", False) + + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + # Blocksparse options cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + error = flex_attention_template.maybe_append_choice( choices=choices, input_nodes=[ @@ -2578,13 +2490,21 @@ def flex_attention_backward(*args, **kwargs): if BLOCK2 % BLOCK1 == 0 ] ) + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype) + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + original_kernel_options = kernel_options.copy() - for BLOCK1, BLOCK2, 
num_warps, num_stages in configs: + for conf in configs: if ( - SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0 - or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0 - or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0 - or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0 + SPARSE_KV_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0 ): continue @@ -2598,17 +2518,29 @@ def flex_attention_backward(*args, **kwargs): cur_kernel_options[k[4:]] = v if k.startswith("fwd_"): cur_kernel_options.pop(k) - cur_kernel_options.setdefault("num_warps", num_warps) - cur_kernel_options.setdefault("num_stages", num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N1", conf.block_n) + cur_kernel_options.setdefault("BLOCK_M2", conf.block_n) + cur_kernel_options.setdefault("BLOCK_N2", conf.block_m) - cur_kernel_options.setdefault("BLOCK_M1", BLOCK1) - cur_kernel_options.setdefault("BLOCK_N1", BLOCK2) - cur_kernel_options.setdefault("BLOCK_M2", BLOCK2) - cur_kernel_options.setdefault("BLOCK_N2", BLOCK1) # Blocksparse options cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + flex_attention_backward_template.maybe_append_choice( choices=choices, input_nodes=[ diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 54b665595b65..74e511ca6a5e 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -8,7 +8,7 @@ import torch from torch._inductor.virtualized import V -from .. import config, ir +from .. import ir from ..ir import FixedLayout, FlexibleLayout from ..lowering import empty, empty_strided, lowerings from ..runtime.runtime_utils import is_power_of_2, next_power_of_2 @@ -318,21 +318,6 @@ def get_split_k(B: int, H: int, Mk: int) -> int: return split_k -def _get_decoding_default_config(key) -> tuple[int, int, int]: - dtype = key.get_dtype() - head_dim = key.get_size()[-1] - sm_version = torch.cuda.get_device_capability() - default_config = (64, 2, 1) - if sm_version >= (9, 0): - if head_dim > 128 and dtype == torch.float32: - return default_config - if torch.version.hip is None: - return (64, 2, 3) - else: - return (64, 2, 1) - return default_config - - def create_flex_decoding_kernel(*args, **kwargs): from .flex_attention import set_head_dim_values @@ -427,19 +412,9 @@ def create_flex_decoding_kernel(*args, **kwargs): mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) choices: list[Any] = [] - configs: list[tuple[int, int, int]] = [] - configs.append(_get_decoding_default_config(key)) - # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops. 
- if config.max_autotune: - configs += [ - (64, 2, 2), - (32, 2, 3), - (128, 2, 3), - ] - - # Use num_stages=1 on ROCm to avoid shmem limitation - if torch.version.hip: - configs = [(c[0], c[1], 1) for c in configs] + dtype = key.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1]) + configs = V.choices.get_flex_decode_configs(head_dim, dtype) # TODO: fix autotuning. @@ -522,8 +497,12 @@ def create_flex_decoding_kernel(*args, **kwargs): # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. - for BLOCK_N, num_warps, num_stages in configs: - if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0: + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0: continue cur_kernel_options = original_kernel_options.copy() @@ -535,10 +514,24 @@ def create_flex_decoding_kernel(*args, **kwargs): if k.startswith("bwd_"): cur_kernel_options.pop(k) # Performance tuning - cur_kernel_options.setdefault("BLOCK_N", BLOCK_N) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - cur_kernel_options.setdefault("num_warps", num_warps) - cur_kernel_options.setdefault("num_stages", num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Set default to False + cur_kernel_options.setdefault("USE_TMA", False) + + # Add ROCm-specific parameters if they exist in the config + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) flex_decoding_template.maybe_append_choice( choices=choices, diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index f9f91dc342f0..16ae263bc3d4 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -28,7 +28,6 @@ TritonTemplate, ) from ..utils import ( - get_gpu_shared_memory, get_tma_workspace_arg, use_aten_gemm_kernels, use_ck_gemm_template, @@ -41,17 +40,13 @@ from .mm_common import ( _is_static_problem, addmm_epilogue, - extra_mm_configs, - int8_mm_configs, mm_args, - mm_configs, + mm_config_kwargs, mm_grid, mm_options, - persistent_mm_configs, persistent_mm_grid, persistent_mm_options, should_fallback_to_aten, - triton_config, ) @@ -359,15 +354,6 @@ def _is_large_block_for_cpu(m, n, k): return m * n > 2**13 -def mm_config_kwargs(device): - if device == "cpu": - return { - "scale": 0.5, - "exclude": _is_large_block_for_cpu, - } - return {} - - def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): """ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt @@ -385,6 +371,7 @@ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): @register_lowering(aten.mm, type_promotion_kind=None) def tuned_mm(mat1, mat2, *, layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + device_type = ir.get_device_type(mat1) name = "mm" # below is for getting an overview logging info of inductor mms @@ -410,8 +397,15 @@ def tuned_mm(mat1, mat2, *, layout=None): [aten_mm.bind((mat1, mat2), 
aten_layout)] if use_aten_gemm_kernels() else [] ) static_shape, is_nonzero = _is_static_problem(layout) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + extra_mm_configs = V.choices.get_extra_mm_configs(device_type) + if is_nonzero and use_triton_template(layout): - for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))): + for config in mm_configs( + m, n, k, *mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): mm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2), @@ -420,7 +414,7 @@ def tuned_mm(mat1, mat2, *, layout=None): ) if use_triton_tma_template(mat1, mat2): for config in persistent_mm_configs( - m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) ): persistent_tma_mm_template.maybe_append_choice( choices, @@ -459,7 +453,7 @@ def tuned_mm(mat1, mat2, *, layout=None): always_included.append("extern_mm") num_choices_before_extra_configs = len(choices) for config in extra_mm_configs( - m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) ): mm_template.maybe_append_choice( choices, @@ -521,6 +515,8 @@ def tuned_int_mm(mat1, mat2, *, layout=None): layout, ) + device_type = ir.get_device_type(mat1) + static_shape, is_nonzero = _is_static_problem(layout) use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) @@ -532,9 +528,12 @@ def tuned_int_mm(mat1, mat2, *, layout=None): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True ) + + int8_mm_configs = V.choices.get_int8_mm_configs(device_type) + if is_nonzero and use_triton_template(layout, enable_int32=True): for config in int8_mm_configs( - m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) ): mm_template.maybe_append_choice( choices, @@ -552,6 +551,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None): @register_lowering(aten.addmm, type_promotion_kind=None) def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): ordered_kwargs_for_cpp_kernel = ("beta", "alpha") + device_type = ir.get_device_type(mat1) m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) static_shape, is_nonzero = _is_static_problem(layout) @@ -617,8 +617,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): ), ) + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + if is_nonzero and use_triton_template(layout): - for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))): + for config in mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): mm_template.maybe_append_choice( choices, input_nodes=(inp_expanded, mat1, mat2), @@ -630,7 +635,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): if use_triton_tma_template(mat1, mat2): for config in persistent_mm_configs( - m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) ): persistent_tma_mm_template.maybe_append_choice( choices, @@ -769,52 +774,6 @@ def dims_are_int(dims): return all(isinstance(dim, int) for dim in dims) -def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout): - m, n, k = get_size_hints(mat1, 
mat2, m, n, k) - if not dims_are_int([m, n, k]): - return None - - if mat1.dtype != torch.float16: - return None - - # only use heuristic if we are running on an A100 - # torch.cuda.get_device_capability() >= (8, 0) returns true for A10G - # which does not have enough shared memory for one of the configs - if ( - not torch.cuda.get_device_capability() >= (8, 0) - ) or get_gpu_shared_memory() != 166912: - return None - - if m == 1 and (n % 16 != 0 or k % 16 != 0): - return None - - if m <= 16 and n >= 4096 and k >= 4096: - return triton_config( - BLOCK_M=16, - BLOCK_N=64, - BLOCK_K=128, - num_stages=5, - num_warps=4, - ) - elif m > 16 and m <= 32 and n >= 4096 and k >= 4096: - return triton_config( - BLOCK_M=32, - BLOCK_N=32, - BLOCK_K=128, - num_stages=5, - num_warps=4, - ) - elif m > 32 and m <= 64 and n >= 4096 and k >= 4096: - return triton_config( - BLOCK_M=64, - BLOCK_N=32, - BLOCK_K=128, - num_stages=5, - num_warps=4, - ) - return None - - def mm_autoheuristic( mat1, mat2, diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index b4c5ea612023..08d946869364 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,440 +1,22 @@ # mypy: allow-untyped-defs -import functools -import itertools import logging -from collections.abc import Sequence -from typing import Any, cast +from typing import Any import sympy import torch from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn from torch._inductor.virtualized import V -from torch.utils._ordered_set import OrderedSet from .. import config as inductor_config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import ChoiceCaller, Layout -from ..runtime.runtime_utils import next_power_of_2 -from ..utils import ( - get_backend_num_stages, - get_num_sms, - TMA_DESCRIPTOR_SIZE, - use_aten_gemm_kernels, -) +from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE, use_aten_gemm_kernels log = logging.getLogger(__name__) -def triton_config(num_stages, num_warps, **kwargs): - from triton import Config # type: ignore[attr-defined] - - return Config(kwargs, num_stages=num_stages, num_warps=num_warps) - - -def build_rocm_gemm_configs(configs): - rocm_num_stages = get_backend_num_stages() - return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs) - - -def filtered_configs( - m: int, - n: int, - k: int, - configs: Sequence[tuple[int, int, int, int, int]], - has_int8_tensor=False, - scale=1, - exclude=lambda m, n, k: False, -): - """ - Heuristic to shrink configs when they are bigger than the input size - - :param scale: scale factor applied to the config values - :param exclude: whether a given config should be excluded - """ - from torch._inductor import config - - max_mm_configs = config.test_configs.max_mm_configs - - min_block_size = 16 - # block_k=16 seems to be causing issues - # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424 - min_block_size_k = 32 if has_int8_tensor else 16 - m = max( - next_power_of_2( - V.graph.sizevars.size_hint( - m, - fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] - ) - ), - min_block_size, - ) - n = max( - next_power_of_2( - V.graph.sizevars.size_hint( - n, - fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] - ) - ), - min_block_size, - ) - k = max( - next_power_of_2( - V.graph.sizevars.size_hint( - k, - fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] - ) - ), - 
min_block_size_k, - ) - used = OrderedSet[tuple[int, ...]]() - for block_m, block_n, block_k, num_stages, num_warps in configs: - # shrink configs for small sizes - block_m = max(min(int(block_m * scale), m), min_block_size) - block_n = max(min(int(block_n * scale), n), min_block_size) - block_k = max(min(int(block_k * scale), k), min_block_size_k) - - if exclude(block_m, block_n, block_k): - continue - - # each warp computes 16x16 tile = 256 - num_warps = min(num_warps, block_m * block_n // 256) - if torch.version.hip: - kpack = 2 - for matrix_instr_nonkdim in [0, 16]: - if matrix_instr_nonkdim != 0 and ( - block_m % matrix_instr_nonkdim != 0 - or block_n % matrix_instr_nonkdim != 0 - ): - # block_m and block_n must be a multiple of matrix_instr_nonkdim - continue - - if ( - block_m, - block_n, - block_k, - num_stages, - num_warps, - matrix_instr_nonkdim, - kpack, - ) not in used and ( - max_mm_configs is None or len(used) < max_mm_configs - ): - used.add( - ( - block_m, - block_n, - block_k, - num_stages, - num_warps, - matrix_instr_nonkdim, - kpack, - ) - ) - yield triton_config( - BLOCK_M=block_m, - BLOCK_N=block_n, - BLOCK_K=block_k, - num_stages=num_stages, - num_warps=num_warps, - matrix_instr_nonkdim=matrix_instr_nonkdim, - kpack=kpack, - ) - else: - if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and ( - max_mm_configs is None or len(used) < max_mm_configs - ): - used.add((block_m, block_n, block_k, num_stages, num_warps, 0)) - yield triton_config( - BLOCK_M=block_m, - BLOCK_N=block_n, - BLOCK_K=block_k, - num_stages=num_stages, - num_warps=num_warps, - ) - - -# List of dictionaries to store the kernel configs. Configs that evaluate to true -# will be utilised on the target platform. The configs are as follows: -# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) -mm_kernel_configs = ( - [ - {"config": (32, 32, 16, 1, 2), "cond": True}, - {"config": (32, 32, 128, 2, 4), "cond": True}, - {"config": (32, 64, 32, 5, 8), "cond": True}, - {"config": (64, 32, 32, 5, 8), "cond": True}, - {"config": (64, 32, 128, 5, 4), "cond": True}, - {"config": (64, 64, 16, 2, 4), "cond": True}, - {"config": (64, 64, 32, 2, 4), "cond": True}, - {"config": (64, 64, 64, 3, 8), "cond": True}, - {"config": (64, 64, 128, 5, 4), "cond": True}, - {"config": (64, 128, 32, 3, 4), "cond": True}, - {"config": (64, 128, 32, 4, 8), "cond": True}, - {"config": (64, 128, 64, 3, 4), "cond": True}, - {"config": (64, 128, 128, 4, 4), "cond": True}, - {"config": (128, 64, 32, 3, 4), "cond": True}, - {"config": (128, 64, 32, 4, 8), "cond": True}, - {"config": (128, 128, 32, 2, 8), "cond": True}, - {"config": (128, 128, 32, 3, 4), "cond": True}, - {"config": (128, 128, 64, 3, 4), "cond": True}, - {"config": (128, 128, 64, 5, 8), "cond": True}, - ] - if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" - else [ - {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True} - for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( - [16, 32, 64, 128, 256], repeat=3 - ) - for num_stages in [1, 2, 3, 4, 5] - for num_warps in [2, 4, 8] - ] -) - -# these are only used in tuned_mm when AutoHeuristic is enabled -# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned -# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 -# which saves compilation time (since less configs are autotuned) and potentially increase performance -# because the learned heuristic might predict a config that is not part 
mm_configs -extra_mm_kernel_configs = [ - {"config": (16, 32, 16, 3, 2), "cond": True}, - {"config": (16, 32, 32, 4, 2), "cond": True}, - {"config": (16, 32, 32, 5, 2), "cond": True}, - {"config": (64, 64, 128, 3, 4), "cond": True}, - {"config": (128, 64, 32, 2, 2), "cond": True}, - {"config": (128, 64, 64, 3, 8), "cond": True}, - {"config": (128, 64, 128, 4, 8), "cond": True}, - {"config": (128, 128, 32, 4, 4), "cond": True}, - {"config": (128, 128, 64, 3, 8), "cond": True}, - {"config": (128, 128, 64, 5, 4), "cond": True}, -] - -int8_mm_kernel_configs = [ - {"config": (64, 64, 32, 2, 4), "cond": True}, - {"config": (64, 128, 32, 3, 4), "cond": True}, - {"config": (128, 64, 32, 3, 4), "cond": True}, - {"config": (64, 128, 32, 4, 8), "cond": True}, - {"config": (128, 64, 32, 4, 8), "cond": True}, - {"config": (64, 32, 32, 5, 8), "cond": True}, - {"config": (32, 64, 32, 5, 8), "cond": True}, - {"config": (128, 128, 32, 2, 8), "cond": True}, - {"config": (64, 64, 64, 3, 8), "cond": True}, - # {"config": (32, 32, 128, 2, 4), "cond": True}, - # {"config": (64, 64, 16, 2, 4), "cond": True}, - # {"config": (32, 32, 16, 1, 2), "cond": True}, - {"config": (128, 256, 128, 3, 8), "cond": True}, - {"config": (256, 128, 128, 3, 8), "cond": True}, -] - -# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192). -mixed_mm_kernel_configs_small_m = [ - {"config": (16, 128, 256, 3, 4), "cond": True}, - {"config": (16, 128, 256, 5, 8), "cond": True}, -] - -mixed_mm_kernel_configs = ( - mm_kernel_configs + mixed_mm_kernel_configs_small_m - if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" - else mm_kernel_configs -) - -persistent_mm_kernel_configs = [ - {"config": (128, 256, 64, 3, 8), "cond": True}, - {"config": (128, 128, 64, 3, 8), "cond": True}, - {"config": (128, 128, 128, 3, 8), "cond": True}, - {"config": (128, 128, 128, 3, 4), "cond": True}, - {"config": (128, 128, 64, 4, 8), "cond": True}, -] - -scaled_mm_kernel_configs = [ - {"config": (128, 256, 32, 3, 8), "cond": True}, - {"config": (256, 128, 32, 3, 8), "cond": True}, - {"config": (256, 64, 32, 4, 4), "cond": True}, - {"config": (64, 256, 32, 4, 4), "cond": True}, - {"config": (128, 128, 32, 4, 4), "cond": True}, - {"config": (128, 64, 32, 4, 4), "cond": True}, - {"config": (64, 128, 32, 4, 4), "cond": True}, - {"config": (128, 32, 32, 4, 4), "cond": True}, - {"config": (64, 32, 32, 5, 2), "cond": True}, - {"config": (256, 128, 128, 3, 8), "cond": True}, - {"config": (256, 64, 128, 4, 4), "cond": True}, - {"config": (64, 256, 128, 4, 4), "cond": True}, - {"config": (128, 128, 128, 4, 4), "cond": True}, - {"config": (128, 64, 64, 4, 4), "cond": True}, - {"config": (64, 128, 64, 4, 4), "cond": True}, - {"config": (128, 32, 64, 4, 4), "cond": True}, - {"config": (64, 32, 64, 5, 2), "cond": True}, - {"config": (16, 32, 32, 2, 2), "cond": True}, - {"config": (16, 64, 32, 2, 2), "cond": True}, - {"config": (16, 128, 32, 2, 4), "cond": True}, - {"config": (16, 256, 32, 2, 4), "cond": True}, - {"config": (16, 32, 64, 2, 2), "cond": True}, - {"config": (16, 64, 64, 2, 2), "cond": True}, - {"config": (16, 128, 64, 2, 4), "cond": True}, - {"config": (16, 256, 64, 2, 4), "cond": True}, - {"config": (32, 32, 32, 2, 2), "cond": True}, - {"config": (32, 64, 32, 2, 2), "cond": True}, - {"config": (32, 128, 32, 2, 4), "cond": True}, - {"config": (32, 256, 32, 2, 4), "cond": True}, - {"config": (32, 32, 64, 2, 2), "cond": True}, - {"config": (32, 64, 64, 2, 2), "cond": True}, - {"config": (32, 128, 64, 2, 4), 
"cond": True}, - {"config": (32, 256, 64, 2, 4), "cond": True}, - {"config": (16, 32, 32, 3, 2), "cond": True}, - {"config": (16, 64, 32, 3, 2), "cond": True}, - {"config": (16, 128, 32, 3, 4), "cond": True}, - {"config": (16, 256, 32, 3, 4), "cond": True}, - {"config": (16, 32, 64, 3, 2), "cond": True}, - {"config": (16, 64, 64, 3, 2), "cond": True}, - {"config": (16, 128, 64, 3, 4), "cond": True}, - {"config": (16, 256, 64, 3, 4), "cond": True}, - {"config": (32, 32, 32, 3, 2), "cond": True}, - {"config": (32, 64, 32, 3, 2), "cond": True}, - {"config": (32, 128, 32, 3, 4), "cond": True}, - {"config": (32, 256, 32, 3, 4), "cond": True}, - {"config": (32, 32, 64, 3, 2), "cond": True}, - {"config": (32, 64, 64, 3, 2), "cond": True}, - {"config": (32, 128, 64, 3, 4), "cond": True}, - {"config": (32, 256, 64, 3, 4), "cond": True}, - {"config": (16, 32, 32, 4, 2), "cond": True}, - {"config": (16, 64, 32, 4, 2), "cond": True}, - {"config": (16, 128, 32, 4, 4), "cond": True}, - {"config": (16, 256, 32, 4, 4), "cond": True}, - {"config": (16, 32, 64, 4, 2), "cond": True}, - {"config": (16, 64, 64, 4, 2), "cond": True}, - {"config": (16, 128, 64, 4, 4), "cond": True}, - {"config": (16, 256, 64, 4, 4), "cond": True}, - {"config": (32, 32, 32, 4, 2), "cond": True}, - {"config": (32, 64, 32, 4, 2), "cond": True}, - {"config": (32, 128, 32, 4, 4), "cond": True}, - {"config": (32, 256, 32, 4, 4), "cond": True}, - {"config": (32, 32, 64, 4, 2), "cond": True}, - {"config": (32, 64, 64, 4, 2), "cond": True}, - {"config": (32, 128, 64, 4, 4), "cond": True}, - {"config": (32, 256, 64, 4, 4), "cond": True}, - {"config": (16, 32, 32, 5, 2), "cond": True}, - {"config": (16, 64, 32, 5, 2), "cond": True}, - {"config": (16, 128, 32, 5, 4), "cond": True}, - {"config": (16, 256, 32, 5, 4), "cond": True}, - {"config": (16, 32, 64, 5, 2), "cond": True}, - {"config": (16, 64, 64, 5, 2), "cond": True}, - {"config": (16, 128, 64, 5, 4), "cond": True}, - {"config": (16, 256, 64, 5, 4), "cond": True}, - {"config": (32, 32, 32, 5, 2), "cond": True}, - {"config": (32, 64, 32, 5, 2), "cond": True}, - {"config": (32, 128, 32, 5, 4), "cond": True}, - {"config": (32, 256, 32, 5, 4), "cond": True}, - {"config": (32, 32, 64, 5, 2), "cond": True}, - {"config": (32, 64, 64, 5, 2), "cond": True}, - {"config": (32, 128, 64, 5, 4), "cond": True}, - {"config": (32, 256, 64, 5, 4), "cond": True}, - {"config": (16, 32, 32, 6, 2), "cond": True}, - {"config": (16, 64, 32, 6, 2), "cond": True}, - {"config": (16, 128, 32, 6, 4), "cond": True}, - {"config": (16, 256, 32, 6, 4), "cond": True}, - {"config": (16, 32, 64, 6, 2), "cond": True}, - {"config": (16, 64, 64, 6, 2), "cond": True}, - {"config": (16, 128, 64, 6, 4), "cond": True}, - {"config": (16, 256, 64, 6, 4), "cond": True}, - {"config": (32, 32, 32, 6, 2), "cond": True}, - {"config": (32, 64, 32, 6, 2), "cond": True}, - {"config": (32, 128, 32, 6, 4), "cond": True}, - {"config": (32, 256, 32, 6, 4), "cond": True}, - {"config": (32, 32, 64, 6, 2), "cond": True}, - {"config": (32, 64, 64, 6, 2), "cond": True}, - {"config": (32, 128, 64, 6, 4), "cond": True}, - {"config": (32, 256, 64, 6, 4), "cond": True}, -] - -scaled_persistent_mm_kernel_configs = [ - {"config": (128, 128, 64, 3, 8), "cond": True}, - {"config": (128, 128, 128, 3, 8), "cond": True}, - {"config": (128, 128, 128, 4, 8), "cond": True}, - {"config": (128, 128, 128, 4, 4), "cond": True}, - {"config": (128, 128, 128, 3, 4), "cond": True}, - {"config": (128, 128, 128, 5, 4), "cond": True}, - {"config": (128, 128, 128, 5, 8), 
"cond": True}, - {"config": (128, 128, 128, 6, 8), "cond": True}, - {"config": (128, 128, 64, 4, 8), "cond": True}, -] - - -# Create filtered list of configs based on cond evaluation -mm_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in mm_kernel_configs - if config["cond"] -) -extra_mm_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in extra_mm_kernel_configs - if config["cond"] -) -int8_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in int8_mm_kernel_configs - if config["cond"] -) -mixed_mm_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in mixed_mm_kernel_configs - if config["cond"] -) -persistent_mm_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in persistent_mm_kernel_configs - if config["cond"] -) -scaled_mm_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in scaled_mm_kernel_configs - if config["cond"] -) -scaled_persistent_mm_platform_configs = tuple( - cast(tuple[int, int, int, int, int], config["config"]) - for config in scaled_persistent_mm_kernel_configs - if config["cond"] -) - -# On ROCm convert num_stages to improve performance -if torch.version.hip and torch.cuda.is_available(): - mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs) - extra_mm_platform_configs = build_rocm_gemm_configs(extra_mm_platform_configs) - int8_platform_configs = build_rocm_gemm_configs(int8_platform_configs) - mixed_mm_platform_configs = build_rocm_gemm_configs(mixed_mm_platform_configs) - scaled_mm_platform_configs = build_rocm_gemm_configs(scaled_mm_platform_configs) - -mm_configs = functools.partial( - filtered_configs, - configs=mm_platform_configs, -) - -extra_mm_configs = functools.partial( - filtered_configs, - configs=extra_mm_platform_configs, -) - -int8_mm_configs = functools.partial( - filtered_configs, - configs=int8_platform_configs, -) - -persistent_mm_configs = functools.partial( - filtered_configs, - configs=persistent_mm_platform_configs, -) - -scaled_mm_configs = functools.partial( - filtered_configs, - configs=scaled_mm_platform_configs, -) - -scaled_persistent_mm_configs = functools.partial( - filtered_configs, - configs=scaled_persistent_mm_platform_configs, -) - - def should_fallback_to_aten(choices: list[ChoiceCaller]) -> bool: if len(choices) == 0 and not use_aten_gemm_kernels(): if inductor_config.autotune_fallback_to_aten: @@ -490,8 +72,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): not inductor_config.force_same_precision or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) ) - return dict( - GROUP_M=8, + options_dict = dict( EVEN_K=even_k_symbolic, ALLOW_TF32=allow_tf32, ACC_TYPE=acc_type(layout.dtype), @@ -500,6 +81,13 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): **config.kwargs, ) + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in config.kwargs: + group_m = config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + def persistent_mm_options(mat1, mat2): return dict( @@ -552,6 +140,15 @@ def mm_args( return [m, n, k, layout, mat1, mat2, *others] +def mm_config_kwargs(device, exclude_condition): + if device == "cpu": + return { + "scale": 0.5, + "exclude": exclude_condition, + } + return {} + + def addmm_epilogue(dtype, alpha, beta): def epilogue(acc, bias): if alpha != 1: diff --git 
a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index f432cbb13b6f..ac6bbee6c75a 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -1,8 +1,8 @@ # mypy: allow-untyped-defs -import functools import torch +from .. import ir from ..lowering import lowerings from ..select_algorithm import ( autotune_select_algorithm, @@ -112,101 +112,14 @@ ) -@functools.lru_cache(None) -def mm_configs(): - import triton - - # List of dictionaries to store the kernel configs. Configs that evaluate to true - # will be utilised on the target platform - mm_triton_configs = [ - { - "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, - "num_stages": 2, - "num_warps": 4, - "cond": True, - }, - { - "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, - "num_stages": 3, - "num_warps": 8, - "cond": True, - }, - { - "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, - "num_stages": 4, - "num_warps": 16, - "cond": True, - }, - { - "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, - "num_stages": 4, - "num_warps": 8, - "cond": True, - }, - { - "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, - "num_stages": 4, - "num_warps": 8, - "cond": True, - }, - { - "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, - "num_stages": 1, - "num_warps": 8, - "cond": True, - }, - { - "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, - "num_stages": 1, - "num_warps": 8, - "cond": True, - }, - { - "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, - "num_stages": 1, - "num_warps": 8, - "cond": torch.version.hip is None, - }, - { - "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, - "num_stages": 2, - "num_warps": 4, - "cond": True, - }, - { - "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, - "num_stages": 1, - "num_warps": 2, - "cond": True, - }, - ] - - # Filter out configs in which cond evaluates to true - # On ROCm convert num_stages to 1 as pipelining provides no benefit - if torch.version.hip: - filtered_configs = [ - triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"]) - for c in mm_triton_configs - if c["cond"] - ] - else: - filtered_configs = [ - triton.Config( - c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"] - ) - for c in mm_triton_configs - if c["cond"] - ] - - return filtered_configs - - def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): """ Computes mm(mat1, mat2) + mm(mat3, mat4) """ m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + device_type = ir.get_device_type(mat1) + # Optimization is optional, because we can always just not do the fusion if ( m1 * n1 == 0 @@ -231,6 +144,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): if use_aten_gemm_kernels() else [] ) + + mm_configs = V.choices.get_mm_plus_mm_configs(device_type) + if use_triton_template(layout1): for config in mm_configs(): # see https://github.com/openai/triton/issues/1298 diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py index 506dc30f1ffd..aa917e120168 100644 --- a/torch/_inductor/kernel/mm_scaled.py +++ b/torch/_inductor/kernel/mm_scaled.py @@ -11,7 +11,7 @@ from torch.utils._triton import has_triton_tma_device from ..config import triton as triton_config -from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox +from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox from ..lowering import 
add_layout_constraint, constrain_to_fx_strides, register_lowering from ..select_algorithm import ( autotune_select_algorithm, @@ -27,13 +27,12 @@ use_ck_gemm_template, use_triton_template, ) +from ..virtualized import V from .mm_common import ( _is_static_problem, mm_args, mm_grid, persistent_mm_grid, - scaled_mm_configs, - scaled_persistent_mm_configs, should_fallback_to_aten, ) @@ -508,6 +507,7 @@ def tuned_scaled_mm( m, n, k, layout, mat_a, mat_b = mm_args( mat_a, mat_b, layout=layout, out_dtype=out_dtype ) + # below is for getting an overview logging info of inductor mms counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 log.info( @@ -520,6 +520,8 @@ def tuned_scaled_mm( layout, ) + device_type = get_device_type(mat_a) + check_supported_striding(mat_a, mat_b) scale_a, scale_b = realize_inputs(scale_a, scale_b) @@ -544,6 +546,11 @@ def tuned_scaled_mm( _, is_nonzero = _is_static_problem(layout) + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + if is_nonzero and use_triton_template(layout, enable_float8=True): if use_persistent_tma(k, bias is not None): for config in scaled_persistent_mm_configs(m, n, k): diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 6c997285beec..05a04b4030a5 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -44,7 +44,8 @@ def set_driver_to_gpu(): def get_backend_options(): - driver = triton.runtime.driver + from triton.runtime import driver + target = driver.active.get_current_target() backend = triton.compiler.compiler.make_backend(target) options = backend.parse_options(dict()) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index b4e138d6fcfb..3dd1f2db38cf 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1247,6 +1247,7 @@ def make_kernel_render(out_node): ), "num_stages": num_stages, "num_warps": num_warps, + "GROUP_M": kwargs.get("GROUP_M", -1), "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), "acc_type": str(kwargs.get("ACC_TYPE", None)), }, diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py new file mode 100644 index 000000000000..5567350f84e6 --- /dev/null +++ b/torch/_inductor/template_heuristics.py @@ -0,0 +1,1130 @@ +from __future__ import annotations + +import dataclasses +import itertools +from functools import partial +from threading import Lock +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch.utils._ordered_set import OrderedSet + +from . 
import config +from .utils import get_backend_num_stages +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Generator + + from triton import Config as TritonConfig + + +# Gemm Configs +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = 8 + + +ConvConfig = BaseConfig + + +# FlexAttention Configs +@dataclasses.dataclass +class FlexConfig: + """ + Base Config class for flex attention + - FlexAttn forward, backward and flex decode will use this + + NOTE: + For flex_attn bwd block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2 + + """ + + block_m: int + block_n: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexDecodeConfig: + """ + Config class for flex decoding + """ + + block_n: int + num_stages: int + num_warps: int + + +# ROCm classes +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexConfig(FlexConfig): + """ + ROCm subclass for FlexAttn, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexDecodeConfig(FlexDecodeConfig): + """ + ROCm subclass for FlexDecode, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +class BaseHeuristicSingleton(type): + """ + Thread-safe implementation of single to be used in the config heuristic subclasses + to ensure heavy __init__ calls are not repeatedly run + """ + + _instances: dict[type[Any], Any] = {} + _lock: Lock = Lock() + + def __call__( + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any + ) -> BaseConfigHeuristic: + with cls._lock: + if cls not in cls._instances: + instance = super().__call__() + cls._instances[cls] = instance + return cls._instances[cls] + + +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): + """ + Base class for mm_configs, device specific triton kernels config inherit from here + """ + + def __init__(self) -> None: + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform. 
The configs are as follows: + # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + for group_m in [8] + ] + + # these are only used in tuned_mm when AutoHeuristic is enabled + # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned + # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 + # which saves compilation time (since less configs are autotuned) and potentially increase performance + # because the learned heuristic might predict a config that is not part mm_configs + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), + ] + + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), + ] + + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(256, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + GemmConfig(16, 256, 32, 2, 4), + 
GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), + ] + + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), + ] + + # TODO: Unify with other gemm patterns, mm_plus_mm currently follows + # slightly different pattern than rest + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), + ] + + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + 
ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), + ] + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(128, 64, 3, 4), + FlexConfig(128, 128, 3, 4), + FlexConfig(128, 128, 2, 8), + FlexConfig(64, 128, 3, 4), + FlexConfig(64, 64, 3, 4), + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, s, w) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(64, 3, 2), + FlexDecodeConfig(32, 3, 2), + FlexDecodeConfig(128, 3, 2), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, num_stages, num_warps) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(block_n, num_stages, num_warps) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": num_warps, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def _scale_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + scale: float, + has_int8_tensor: bool, + exclude: Callable[[int, int, int], bool], + ) -> list[BaseConfig]: + """ + Scales and filters matrix multiplication configs based on input size. 
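+
+        Each block size is scaled by scale, capped at its problem dimension
+        (the size hint rounded up to the next power of two) and floored at 16
+        (32 for block_k when has_int8_tensor is set); configs rejected by
+        exclude are then dropped.
+
+        Example: with m = n = k = 64 and scale = 1, a config with
+        (block_m, block_n, block_k) = (128, 128, 64) is clamped to (64, 64, 64).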
+ """ + from .runtime.runtime_utils import next_power_of_2 + + min_block_size = 16 + min_block_size_k = 32 if has_int8_tensor else 16 + + m = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size, + ) + n = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size, + ) + k = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size_k, + ) + + scaled_configs = [] + for c in configs: + scaled_config = dataclasses.replace( + c, + block_m=max(min(int(c.block_m * scale), m), min_block_size), + block_n=max(min(int(c.block_n * scale), n), min_block_size), + block_k=max(min(int(c.block_k * scale), k), min_block_size_k), + ) + + if not exclude( + scaled_config.block_m, scaled_config.block_n, scaled_config.block_k + ): + scaled_configs.append(scaled_config) + + return scaled_configs + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: int = 1, + exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, + ) -> Generator[TritonConfig, None, None]: + scaled_configs = self._scale_mm_configs( + m, n, k, configs, scale, has_int8_tensor, exclude + ) + return self._finalize_mm_configs(scaled_configs) + + def triton_config( + self, num_stages: int, num_warps: int, **kwargs: Any + ) -> TritonConfig: + from triton import Config as TritonConfig # type: ignore[attr-defined] + + return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps) + + def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.mm_configs) + + def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs) + + def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs) + + def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs) + + def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + mm_configs = ( + self.mm_configs + self.mixed_mm_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.mm_configs + ) + return partial(self.preprocess_mm_configs, configs=mm_configs) + + def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs) + + def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs) + + def get_scaled_persistent_mm_configs( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + return partial( + self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs + ) + + def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.conv_configs) + + # Flex attn helpers + def 
get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class CPUConfigHeuristic(BaseConfigHeuristic): + pass + + +class CUDAConfigHeuristic(BaseConfigHeuristic): + """ + Child class for CUDA device specific gemm/flex attention/conv/ configs. 
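+
+    Forward flex attention defaults are looked up per (dtype, head_dim) from
+    the h100_default_flex_config / a100_default_flex_config tables below,
+    selected by compute capability ((9, 0)+ and (8, 0)+ respectively).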
+ """ + + def __init__(self) -> None: + super().__init__() + + self.h100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 128, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.a100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(128, 32, 3, 4), + (torch.float32, 256): FlexConfig(64, 16, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 3, 4), + (torch.float16, 64): FlexConfig(128, 64, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(32, 64, 3, 4), + } + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + if capability >= (9, 0): + default_config = self.h100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (8, 0): + default_config = self.a100_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = FlexConfig(16, 16, 1, 4) + elif head_dim <= 256 and capability >= (9, 0): # H100 + if head_dim == 64: + default_config = FlexConfig(64, 64, 3, 4) + elif head_dim == 128: + default_config = FlexConfig(64, 128, 3, 8) + else: + default_config = FlexConfig(64, 64, 2, 4) + elif capability >= (8, 0): # A100 + if head_dim == 64: + default_config = FlexConfig(32, 128, 3, 4) + elif head_dim == 128: + # SM86/89 have smaller shared memory sizes + num_stages = 3 if capability[1] == 0 else 2 + default_config = FlexConfig(64, 64, num_stages, 4) + else: + default_config = FlexConfig(64, 64, 2, 4) + else: # modest hardware or extremely large head_dim + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + capability = torch.cuda.get_device_capability() + + default_config = FlexDecodeConfig(64, 1, 2) + + 
flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + if capability >= (9, 0): # sm_90+ + if head_dim > 128 and dtype == torch.float32: + default_config = FlexDecodeConfig(64, 1, 2) + else: + default_config = FlexDecodeConfig(64, 3, 2) + else: + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class ROCmConfigHeuristic(BaseConfigHeuristic): + """ + Child class for ROCm specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + + self.default_num_stages = get_backend_num_stages() + + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + 
ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m, + matrix_instr_nonkdim, + waves_per_eu, + kpack, + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, self.default_num_stages] + for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] + ] + + self.default_flex_config = { + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8), + (torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4), + } + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + for BLOCK1 in [16, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for w in [4, 8] + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w, mfma) + for BLOCK1 in [16, 32, 64] + for BLOCK2 in [32, 64, 128] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for mfma in [0, 16] + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(32, 1, 4), + ROCmFlexDecodeConfig(64, 1, 4), + ROCmFlexDecodeConfig(128, 1, 4), + ROCmFlexDecodeConfig(32, 1, 8), + ROCmFlexDecodeConfig(64, 1, 8), + ROCmFlexDecodeConfig(128, 1, 8), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, num_stages, num_warps, mfma, wpeu) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + def _filter_configs( + self, configs: list[BaseConfig], new_num_stages: int + ) -> list[BaseConfig]: + # TODO: _filter_configs can be removed once backend specific configs are added + # for all methods + for c in configs: + c.num_stages = self.default_num_stages + return configs + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. 
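+
+        In addition to the base dedup key, ROCm configs carry
+        matrix_instr_nonkdim, waves_per_eu and kpack (defaulting to 16, 0 and 2
+        when unset), and any config whose block_m or block_n is not a multiple
+        of a non-zero matrix_instr_nonkdim is skipped.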
+ """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.extra_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.int8_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + mm_configs = ( + self.mm_configs + self.mixed_mm_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.mm_configs + ) + filtered_configs = self._filter_configs(mm_configs, self.default_num_stages) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.persistent_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.scaled_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_scaled_persistent_mm_configs( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.scaled_persistent_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1) + return partial(self._finalize_mm_configs, configs=filtered_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, 
None]]: + filtered_configs = self._filter_configs( + self.conv_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(128, 64, 1, 8) + default_config = self.default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = ROCmFlexConfig(32, 16, 1, 4) + else: + default_config = ROCmFlexConfig(64, 32, 1, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = ROCmFlexConfig(16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = ROCmFlexConfig(64, 64, 1, 4) + elif head_dim == 128: + default_config = ROCmFlexConfig(64, 128, 1, 8) + else: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = ROCmFlexDecodeConfig(64, 1, 4) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class XPUConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for XPU specific overrides. + """
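A minimal sketch of how a backend-specific subclass can plug into this hierarchy, for example by filling in the XPUConfigHeuristic placeholder above. The subclass name and the trimmed config list are illustrative assumptions, not tuned values; only names defined in this patch (XPUConfigHeuristic, GemmConfig, BaseConfig) are assumed.

class MyXPUConfigHeuristic(XPUConfigHeuristic):  # hypothetical subclass, illustration only
    def __init__(self) -> None:
        super().__init__()
        # Swap the base GEMM sweep for a deliberately small list; the values
        # are copied from the base mm_configs above, not XPU-tuned.
        self.mm_configs: list[BaseConfig] = [
            GemmConfig(64, 64, 32, 2, 4),
            GemmConfig(128, 128, 32, 3, 4),
        ]

No getter override is needed: the inherited get_mm_configs already returns partial(self.preprocess_mm_configs, configs=self.mm_configs), so the trimmed list flows through the same scaling and dedup pipeline. This follows the pattern of ROCmConfigHeuristic above, which substitutes ROCmGemmConfig entries and layers its backend-specific handling in _finalize_mm_configs.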