diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 8fe211475ba..b2c188aba94 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -147,10 +147,10 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) def test_composite_builder_sdpa_pattern(self): @@ -175,10 +175,10 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) def test_composite_builder_export_sdpa_pattern(self): @@ -208,10 +208,10 @@ def forward(self, x, y): stablehlo = stablehlo_gm.get_stablehlo_text() self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) if has_tf_package(): self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) @@ -240,10 +240,10 @@ def forward(self, x, y): stablehlo = stablehlo_gm.get_stablehlo_text() self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) if has_tf_package(): self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp index 101b3690855..c743f852728 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp @@ -120,8 +120,30 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { std::unordered_map> boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); - for (const auto& [unused, ops] : boundary_output_ops_map) { - if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + struct BoundaryGroup { + std::string key; + llvm::SmallVector ops; + size_t last_order; + }; + + llvm::SmallVector groups; + groups.reserve(boundary_output_ops_map.size()); + + for (auto& kv : boundary_output_ops_map) { + size_t last_ord = 0; + for (mlir::Operation* op : kv.second) { + if (op != nullptr) last_ord = std::max(last_ord, op_order_map.at(op)); + } + groups.push_back({kv.first, kv.second, last_ord}); + } + + llvm::sort(groups, [](const BoundaryGroup& a, const BoundaryGroup& b) { + return a.last_order < b.last_order; + }); + + for (auto& grp : groups) { + op_order_map = BuildOpOrderMap(func_op); + if (mlir::failed(BuildStableHLOComposite(grp.ops, op_order_map))) { func_op.emitError() << "failed to build composite."; return signalPassFailure(); } @@ -321,6 +343,22 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } } + llvm::DenseSet wrapper_set(output_ops.begin(), + output_ops.end()); + + for (mlir::Operation* mark : output_ops) + if (mark->use_empty()) mark->erase(); + + for (mlir::Operation* op : llvm::reverse(impl_ops)) { + if (wrapper_set.contains(op) || !op->use_empty()) continue; + + bool pure_or_composite = mlir::wouldOpBeTriviallyDead(op) || + llvm::isa(op) || + llvm::isa(op); + + if (pure_or_composite) op->erase(); + } + if (!mlir::sortTopologically(composite_op->getBlock())) { composite_op->emitError() << "The graph is not acyclic after BuildStableHLOCompositePass pass.";