From 287431d7bddf1c5c21d35c3f9816db61aa3e25f8 Mon Sep 17 00:00:00 2001 From: Carlo Date: Tue, 17 Jun 2025 18:53:25 +0100 Subject: [PATCH 1/5] first attempt at adding nested regions --- .../runtime/stablehlo_composite_helper.cpp | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp index 101b36908555..93b4beeea4a6 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp @@ -7,6 +7,7 @@ #include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LogicalResult.h" @@ -120,8 +121,29 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { std::unordered_map> boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); - for (const auto& [unused, ops] : boundary_output_ops_map) { - if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + struct BoundaryGroup { + std::string key; + llvm::SmallVector ops; + size_t first_order; // lexical min + }; + + llvm::SmallVector groups; + groups.reserve(boundary_output_ops_map.size()); + + for (auto& kv : boundary_output_ops_map) { + size_t min_ord = std::numeric_limits::max(); + for (mlir::Operation* op : kv.second) { + if (op != nullptr) min_ord = std::min(min_ord, op_order_map.at(op)); + } + groups.push_back({kv.first, kv.second, min_ord}); + } + + llvm::sort(groups, [](const BoundaryGroup& a, const BoundaryGroup& b) { + return a.first_order < b.first_order; // inner → outer + }); + + for (auto& grp : groups) { + if (mlir::failed(BuildStableHLOComposite(grp.ops, op_order_map))) { func_op.emitError() << "failed to build composite."; return signalPassFailure(); } From 8ee9d3b556c4c29a73c4f021f51c5dd9deb4e38a Mon Sep 17 00:00:00 2001 From: Carlo Date: Tue, 17 Jun 2025 19:22:25 +0100 Subject: [PATCH 2/5] changed comparison order + rehceck for op ids --- torch_xla/csrc/runtime/stablehlo_composite_helper.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp index 93b4beeea4a6..e237eb6049a2 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp @@ -139,10 +139,11 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } llvm::sort(groups, [](const BoundaryGroup& a, const BoundaryGroup& b) { - return a.first_order < b.first_order; // inner → outer + return a.first_order > b.first_order; // inner → outer }); for (auto& grp : groups) { + op_order_map = BuildOpOrderMap(func_op); if (mlir::failed(BuildStableHLOComposite(grp.ops, op_order_map))) { func_op.emitError() << "failed to build composite."; return signalPassFailure(); From 8cf8637e4d071745812fb5098ac6c567a3b3eecc Mon Sep 17 00:00:00 2001 From: Carlo Date: Wed, 18 Jun 2025 14:38:00 +0100 Subject: [PATCH 3/5] removed composites from outer fns --- .../runtime/stablehlo_composite_helper.cpp | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp index e237eb6049a2..572cc715cde9 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp @@ -1,5 +1,6 @@ #include "torch_xla/csrc/runtime/stablehlo_composite_helper.h" +#include #include #include #include @@ -124,22 +125,22 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { struct BoundaryGroup { std::string key; llvm::SmallVector ops; - size_t first_order; // lexical min + size_t last_order; }; llvm::SmallVector groups; groups.reserve(boundary_output_ops_map.size()); for (auto& kv : boundary_output_ops_map) { - size_t min_ord = std::numeric_limits::max(); + size_t last_ord = 0; for (mlir::Operation* op : kv.second) { - if (op != nullptr) min_ord = std::min(min_ord, op_order_map.at(op)); + if (op != nullptr) last_ord = std::max(last_ord, op_order_map.at(op)); } - groups.push_back({kv.first, kv.second, min_ord}); + groups.push_back({kv.first, kv.second, last_ord}); } llvm::sort(groups, [](const BoundaryGroup& a, const BoundaryGroup& b) { - return a.first_order > b.first_order; // inner → outer + return a.last_order < b.last_order; }); for (auto& grp : groups) { @@ -344,6 +345,22 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } } + llvm::DenseSet wrapper_set(output_ops.begin(), + output_ops.end()); + + for (mlir::Operation* mark : output_ops) + if (mark->use_empty()) mark->erase(); + + for (mlir::Operation* op : llvm::reverse(impl_ops)) { + if (wrapper_set.contains(op) || !op->use_empty()) continue; + + bool pure_or_composite = mlir::wouldOpBeTriviallyDead(op) || + llvm::isa(op) || + llvm::isa(op); + + if (pure_or_composite) op->erase(); + } + if (!mlir::sortTopologically(composite_op->getBlock())) { composite_op->emitError() << "The graph is not acyclic after BuildStableHLOCompositePass pass."; From 1dc03f867843b05a27c3ba6d561fef536e2a0bc4 Mon Sep 17 00:00:00 2001 From: Carlo Date: Wed, 18 Jun 2025 16:22:22 +0100 Subject: [PATCH 4/5] Fixed stableHLO nested regions --- torch_xla/csrc/runtime/stablehlo_composite_helper.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp index 572cc715cde9..c743f8527285 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/runtime/stablehlo_composite_helper.h" -#include #include #include #include @@ -8,7 +7,6 @@ #include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "llvm/ADT/STLExtras.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LogicalResult.h" From 1945b59483733ca9b0661e8a67b6248a51972f7d Mon Sep 17 00:00:00 2001 From: Carlomus <48855305+Carlomus@users.noreply.github.com> Date: Wed, 25 Jun 2025 11:07:11 +0100 Subject: [PATCH 5/5] Update test_composite.py Renamed _impl in test, so that first impl is 'impl', first is 'impl_0' and so on --- test/stablehlo/test_composite.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 8fe211475ba1..b2c188aba944 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -147,10 +147,10 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) def test_composite_builder_sdpa_pattern(self): @@ -175,10 +175,10 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) def test_composite_builder_export_sdpa_pattern(self): @@ -208,10 +208,10 @@ def forward(self, x, y): stablehlo = stablehlo_gm.get_stablehlo_text() self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) if has_tf_package(): self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) @@ -240,10 +240,10 @@ def forward(self, x, y): stablehlo = stablehlo_gm.get_stablehlo_text() self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) if has_tf_package(): self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))