diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py
index bccdacab2a679..23696b0431515 100644
--- a/test/inductor/test_combo_kernels.py
+++ b/test/inductor/test_combo_kernels.py
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
 
         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
 
-    @requires_cuda
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
 
 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):
diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index b0a6c4d4441ed..887a071d80eed 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -651,31 +651,6 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)
 
-    def test_2d_reduction_no_x_dim(self):
-        """
-        Tests a 2D reduction without an "x" dimension.
-        """
-        # We need a size to get no x dim.
-        view = self._discontiguous_tensor((2, 346), self.device)
-
-        # Expect 1 block pointer for the input.
-        result, (code,) = run_and_compare(
-            self,
-            torch.prod,
-            view,
-            expected_num_block_pointers=1,
-            expected_num_triton_kernels=1,
-            config_patches=tiled_reduction_config,
-        )
-
-        # Check that there's no X dimension in the signature.
-        (signature_line,) = (
-            line for line in code.splitlines() if line.startswith("def triton")
-        )
-        self.assertNotIn("BLOCK", signature_line)
-
-        # Check for 2 reduction dimensions in the body.
-        self._assert_reduction_ndims(code, 2)
 
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py
index ce7e941ee1ff6..3304edf8261da 100644
--- a/torch/_inductor/choices.py
+++ b/torch/_inductor/choices.py
@@ -109,11 +109,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
         Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
         So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
         Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
         """
-        return (
-            features.get_reduction_hint() == ReductionHint.INNER
-            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
-        )
+        return False
 
     @staticmethod
     def reduction_split_factor(
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index dc3b9d218ec28..e6d3cd3fc7646 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1713,13 +1713,10 @@ def should_use_persistent_reduction(self) -> bool:
         )
 
     def want_no_x_dim(self):
-        if (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-        ):
-            if self.fixed_config:
-                return self.fixed_config["XBLOCK"] == 1
-            return V.choices.want_no_x_dim(self.features)
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
         return False
 
     @property
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index e1ae739b97fea..50565690fd96f 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2049,25 +2049,39 @@ def _persistent_reduction_configs(
     xnumel = size_hints["x"]
     rnumel = get_total_reduction_numel(size_hints)
 
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
+
     configs = [
         triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
         for xblock in (1, 8, 32, 128)
-        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
+        if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
     ]
 
     # TODO(jansel): we should be able to improve these heuristics
-    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints: