diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py
index a054464bf6689..b6f356e256713 100644
--- a/test/inductor/test_combo_kernels.py
+++ b/test/inductor/test_combo_kernels.py
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):

         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

-    @requires_cuda
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-

 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):
diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index 19f83a35e96d7..53308ccf7f463 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -706,31 +706,6 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)

-    def test_2d_reduction_no_x_dim(self):
-        """
-        Tests a 2D reduction without an "x" dimension.
-        """
-        # We need a size to get no x dim.
-        view = self._discontiguous_tensor((2, 346), self.device)
-
-        # Expect 1 block pointer for the input.
-        result, (code,) = run_and_compare(
-            self,
-            torch.prod,
-            view,
-            expected_num_block_pointers=1,
-            expected_num_triton_kernels=1,
-            config_patches=tiled_reduction_config,
-        )
-
-        # Check that there's no X dimension in the signature.
-        (signature_line,) = (
-            line for line in code.splitlines() if line.startswith("def triton")
-        )
-        self.assertNotIn("BLOCK", signature_line)
-
-        # Check for 2 reduction dimensions in the body.
-        self._assert_reduction_ndims(code, 2)
-
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py
index 00de22393abf3..cffe585a236a1 100644
--- a/torch/_inductor/choices.py
+++ b/torch/_inductor/choices.py
@@ -215,11 +215,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
         Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
         So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
         Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
         """
-        return (
-            features.get_reduction_hint() == ReductionHint.INNER
-            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
-        )
+        return False

     @staticmethod
     def reduction_split_factor(
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index a404abc136f52..dc8722b3cee8f 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1768,13 +1768,10 @@ def should_use_persistent_reduction(self) -> bool:
         )

     def want_no_x_dim(self):
-        if (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-        ):
-            if self.fixed_config:
-                return self.fixed_config["XBLOCK"] == 1
-            return V.choices.want_no_x_dim(self.features)
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
         return False

     @property
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 54c7e83c0879b..2ea6a2d467a67 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2556,6 +2556,10 @@ def _persistent_reduction_configs(
     rnumel = get_total_reduction_numel(size_hints)

     MAX_PERSISTENT_BLOCK_NUMEL = 4096
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )

     if "y" not in size_hints:
         configs = [
@@ -2585,18 +2589,27 @@ def _persistent_reduction_configs(
     if "y" in size_hints:
         pass
         # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints: