Skip to content

[AUTOGENERATED] [release/2.8] [SWDEV-539215] - Autotune support for persistent reduction and no_x_dim removal #2454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions test/inductor/test_combo_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):

self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

@requires_cuda
def test_persistent_reduction_no_x_dim(self):
def fn(x, y):
return x.sum(1), y.sum(1)

inps = (
torch.rand(16, 256, device="cuda"),
torch.rand(32, 256, device="cuda"),
)
torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
out_eager = fn(*inps)
out_compiled = torch.compile(fn)(*inps)

self.assertEqual(out_eager, out_compiled)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)


@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
Expand Down
25 changes: 0 additions & 25 deletions test/inductor/test_torchinductor_strided_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,31 +706,6 @@ def test_2d_reduction_odd_shapes(
# Check the code for multiple Rn_BLOCK's
self._assert_reduction_ndims(code, 2)

def test_2d_reduction_no_x_dim(self):
"""
Tests a 2D reduction without an "x" dimension.
"""
# We need a size to get no x dim.
view = self._discontiguous_tensor((2, 346), self.device)

# Expect 1 block pointer for the input.
result, (code,) = run_and_compare(
self,
torch.prod,
view,
expected_num_block_pointers=1,
expected_num_triton_kernels=1,
config_patches=tiled_reduction_config,
)

# Check that there's no X dimension in the signature.
(signature_line,) = (
line for line in code.splitlines() if line.startswith("def triton")
)
self.assertNotIn("BLOCK", signature_line)

# Check for 2 reduction dimensions in the body.
self._assert_reduction_ndims(code, 2)

@parametrize(
"size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
Expand Down
8 changes: 4 additions & 4 deletions torch/_inductor/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
Strangely this is faster than a [1, RBLOCK] block in some cases.

ROCm branch change: Remove want_no_x_dim for persistent reduction.
Inductor benchmarks show no perf advantage and simplifies autotune flow.
"""
return (
features.get_reduction_hint() == ReductionHint.INNER
and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
)
return False

@staticmethod
def reduction_split_factor(
Expand Down
11 changes: 4 additions & 7 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,13 +1768,10 @@ def should_use_persistent_reduction(self) -> bool:
)

def want_no_x_dim(self):
if (
self.persistent_reduction
and len(self.numels) == self.num_reduction_dims + 1
):
if self.fixed_config:
return self.fixed_config["XBLOCK"] == 1
return V.choices.want_no_x_dim(self.features)
"""
ROCm branch change: Remove want_no_x_dim for persistent reduction.
Inductor benchmarks show no perf advantage and simplifies autotune flow.
"""
return False

@property
Expand Down
25 changes: 19 additions & 6 deletions torch/_inductor/runtime/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2556,6 +2556,10 @@ def _persistent_reduction_configs(
rnumel = get_total_reduction_numel(size_hints)

MAX_PERSISTENT_BLOCK_NUMEL = 4096
max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
)

if "y" not in size_hints:
configs = [
Expand Down Expand Up @@ -2585,18 +2589,27 @@ def _persistent_reduction_configs(
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [
if not max_autotune_enabled: # Don't filter if tuning enabled
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]

if reduction_hint == ReductionHint.OUTER_TINY:
tiny_configs = [
triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
)
]
if max_autotune_enabled:
for tconfig in tiny_configs:
if tconfig not in configs:
configs.append(tconfig)
else:
configs = tiny_configs

for c in configs:
# we don't need Rn_BLOCK for persistent reduction
for prefix in size_hints:
Expand Down