diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 648cdb83dbd..10c469f1c66 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -626,7 +626,6 @@ class ReverseRetOpt final : public OpRewritePattern { // skip primal return if (val == Activity::enzyme_constnoneed || - val == Activity::enzyme_activenoneed || val == Activity::enzyme_dupnoneed) { newRetActivityArgs.push_back(iattr); continue; @@ -636,15 +635,35 @@ class ReverseRetOpt final : public OpRewritePattern { switch (val) { case Activity::enzyme_active: - if (!res.use_empty()) { - outs_args.push_back(res); - out_ty.push_back(res.getType()); - newRetActivityArgs.push_back(iattr); - } else { + if (res.use_empty()) { changed = true; auto new_activenn = ActivityAttr::get(rewriter.getContext(), Activity::enzyme_activenoneed); newRetActivityArgs.push_back(new_activenn); + } else { + int in_idx = 0; + for (auto act : inpActivity) { + auto v = cast(act).getValue(); + in_idx += + (v == Activity::enzyme_dup || v == Activity::enzyme_dupnoneed) + ? 2 + : 1; + } + in_idx += out_idx; + auto dres = uop.getInputs()[in_idx]; + + if (matchPattern(dres, m_Zero()) || + matchPattern(dres, m_AnyZeroFloat())) { + changed = true; + auto new_const = ActivityAttr::get(rewriter.getContext(), + Activity::enzyme_const); + newRetActivityArgs.push_back(new_const); + } else { + newRetActivityArgs.push_back(iattr); + } + + outs_args.push_back(res); + out_ty.push_back(res.getType()); } break; @@ -668,7 +687,31 @@ class ReverseRetOpt final : public OpRewritePattern { newRetActivityArgs.push_back(iattr); break; - case Activity::enzyme_activenoneed: + case Activity::enzyme_activenoneed: { + int in_idx = 0; + for (auto act : inpActivity) { + auto v = cast(act).getValue(); + in_idx += + (v == Activity::enzyme_dup || v == Activity::enzyme_dupnoneed) + ? 2 + : 1; + } + in_idx += out_idx; + + auto dres = uop.getInputs()[in_idx]; + + if (matchPattern(dres, m_Zero()) || + matchPattern(dres, m_AnyZeroFloat())) { + changed = true; + auto new_constnn = ActivityAttr::get(rewriter.getContext(), + Activity::enzyme_constnoneed); + newRetActivityArgs.push_back(new_constnn); + } else { + newRetActivityArgs.push_back(iattr); + } + + continue; + } case Activity::enzyme_constnoneed: case Activity::enzyme_dupnoneed: break; @@ -763,6 +806,10 @@ class ReverseRetOpt final : public OpRewritePattern { } else if (new_val == Activity::enzyme_constnoneed && old_val == Activity::enzyme_const) { ++oldIdx; // skip const primal + } else if (new_val == Activity::enzyme_const && + old_val == Activity::enzyme_active) { + uop.getOutputs()[oldIdx++].replaceAllUsesWith( + newOp.getOutputs()[newIdx++]); } } } diff --git a/enzyme/test/MLIR/ReverseMode/canonicalize.mlir b/enzyme/test/MLIR/ReverseMode/canonicalize.mlir index 48816d72b84..33676ec3d7a 100644 --- a/enzyme/test/MLIR/ReverseMode/canonicalize.mlir +++ b/enzyme/test/MLIR/ReverseMode/canonicalize.mlir @@ -41,4 +41,20 @@ module { // CHECK: enzyme.autodiff @square2(%arg0, %arg1, %arg2, %arg3){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} return %cst : f32 } + + // Test 5: active -> const for ret_activity (iff derivative is 0) + func.func @test5(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32,f32) { + %cst = arith.constant 0.0000e+00 : f32 + %r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme, #enzyme] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32) + // CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} + return %r#0,%r#1,%r#2,%r#3 : f32,f32,f32,f32 + } + + // Test 6: active -> activenoneed/const -> constnoneed for ret_activity + func.func @test6(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32) { + %cst = arith.constant 0.0000e+00 : f32 + %r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme, #enzyme] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32) + // CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} + return %r#0,%r#2,%r#3 : f32,f32,f32 + } }