From 1c19d56238f1cd181327bc9ba8b23fc785c6003c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 18 Jun 2025 18:05:32 +0200 Subject: [PATCH 01/12] feat: support forward-mode Mooncake [experimental] --- .../docs/src/explanation/backends.md | 4 +- .../DifferentiationInterfaceMooncakeExt.jl | 20 ++-- .../forward_onearg.jl | 92 ++++++++++++++ .../forward_twoarg.jl | 112 ++++++++++++++++++ .../src/DifferentiationInterface.jl | 2 + .../test/Back/Mooncake/test.jl | 6 +- 6 files changed, 227 insertions(+), 9 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 013d9e7c8..b04f35f6b 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github. 
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences) - [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) - [`AutoGTPSA`](@extref ADTypes.AutoGTPSA) -- [`AutoMooncake`](@extref ADTypes.AutoMooncake) +- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) - [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff) - [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff) - [`AutoSymbolics`](@extref ADTypes.AutoSymbolics) @@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato | `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | + | `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 | | `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends: | `AutoForwardDiff` | ✅ | ✅ | | `AutoGTPSA` | ✅ | ❌ | | `AutoMooncake` | ✅ | ✅ | +| `AutoMooncakeForward` | ✅ | ✅ | | `AutoPolyesterForwardDiff` | ✅ | ✅ | | `AutoReverseDiff` | ✅ | ❌ | | `AutoSymbolics` | ✅ | ✅ | diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 6253ea229..0d40a2590 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -1,29 +1,35 @@ module DifferentiationInterfaceMooncakeExt -using ADTypes: ADTypes, AutoMooncake +using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward import DifferentiationInterface as DI using Mooncake: CoDual, Config, + 
Dual, + prepare_derivative_cache, prepare_gradient_cache, prepare_pullback_cache, + primal, + tangent, tangent_type, + value_and_derivative!!, value_and_gradient!!, value_and_pullback!!, + zero_dual, zero_tangent, _copy_output, _copy_to_output!! -DI.check_available(::AutoMooncake) = true +const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}} -get_config(::AutoMooncake{Nothing}) = Config() -get_config(backend::AutoMooncake{<:Config}) = backend.config +DI.check_available(::AnyAutoMooncake) = true -# tangents need to be copied before returning, otherwise they are still aliased in the cache -mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x) -mycopy(x) = deepcopy(x) +get_config(::AnyAutoMooncake{Nothing}) = Config() +get_config(backend::AnyAutoMooncake{<:Config}) = backend.config include("onearg.jl") include("twoarg.jl") +include("forward_onearg.jl") +include("forward_twoarg.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl new file mode 100644 index 000000000..c9228d972 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -0,0 +1,92 @@ +## Pushforward + +struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + cache::Tcache + dx_righttype::DX +end + +function DI.prepare_pushforward_nokwarg( + strict::Val, + f::F, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + config = get_config(backend) + # TODO: silence_debug_messages + cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config.debug_mode) + dx_righttype = zero_tangent(x) + prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype) + return prep +end + +function DI.value_and_pushforward( + f::F, + 
prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X} + DI.check_prep(f, prep, backend, x, tx, contexts...) + ys_and_ty = map(tx) do dx + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + y_dual = value_and_derivative!!( + prep.cache, + zero_dual(f), + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + y = primal(y_dual) + dy = _copy_output(tangent(y_dual)) + return y, dy + end + y = first(ys_and_ty[1]) + ty = last.(ys_and_ty) + return y, ty +end + +function DI.pushforward( + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] +end + +function DI.value_and_pushforward!( + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) + foreach(copyto!, ty, new_ty) + return y, ty +end + +function DI.pushforward!( + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...) 
+ return ty +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl new file mode 100644 index 000000000..f90524643 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -0,0 +1,112 @@ +## Pushforward + +struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + cache::Tcache + dx_righttype::DX + dy_righttype::DY +end + +function DI.prepare_pushforward_nokwarg( + strict::Val, + f!::F, + y, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) + config = get_config(backend) + # TODO: silence_debug_messages + cache = prepare_derivative_cache( + f!, y, x, map(DI.unwrap, contexts)...; config.debug_mode + ) + dx_righttype = zero_tangent(x) + dy_righttype = zero_tangent(y) + prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype) + return prep +end + +function DI.value_and_pushforward( + f!::F, + y, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + ty = map(tx) do dx + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + y_dual = zero_dual(y) + value_and_derivative!!( + prep.cache, + zero_dual(f!), + y_dual, + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + dy = _copy_output(tangent(y_dual)) + return dy + end + return y, ty +end + +function DI.pushforward( + f!::F, + y, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
+ return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] +end + +function DI.value_and_pushforward!( + f!::F, + y::Y, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X,Y} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + foreach(tx, ty) do dx, dy + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + dy_righttype = + dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) + value_and_derivative!!( + prep.cache, + zero_dual(f!), + Dual(y, dy_righttype), + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + dy === dy_righttype || copyto!(dy, dy_righttype) + end + return y, ty +end + +function DI.pushforward!( + f!::F, + y, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f!, y, ty, prep, backend, x, tx, contexts...) + DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) 
+ return ty +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 32e699572..85f25fa28 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -28,6 +28,7 @@ using ADTypes: AutoForwardDiff, AutoGTPSA, AutoMooncake, + AutoMooncakeForward, AutoPolyesterForwardDiff, AutoReverseDiff, AutoSymbolics, @@ -115,6 +116,7 @@ export AutoFiniteDifferences export AutoForwardDiff export AutoGTPSA export AutoMooncake +export AutoMooncakeForward export AutoPolyesterForwardDiff export AutoReverseDiff export AutoSymbolics diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 8c9ab839a..7f48c3899 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())] +backends = [ + AutoMooncake(; config=nothing), + AutoMooncake(; config=Mooncake.Config()), + AutoMooncakeForward(; config=nothing); +] for backend in backends @test check_available(backend) From 2f9b365ef1007dad935274c96cacd965ba2ea97f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Jul 2025 10:33:59 +0200 Subject: [PATCH 02/12] Fix comma --- DifferentiationInterface/test/Back/Mooncake/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 7f48c3899..b695179f1 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -13,7 +13,7 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [ 
AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config()), - AutoMooncakeForward(; config=nothing); + AutoMooncakeForward(; config=nothing), ] for backend in backends From 9c838f35f12c5a8f8c25dfc27d7c319405eb3656 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 13:58:13 +0200 Subject: [PATCH 03/12] Bump versions --- DifferentiationInterface/Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 3efc6c248..cef7316e7 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.4" +version = "0.7.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -56,7 +56,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.13.0" +ADTypes = "1.17.0" Aqua = "0.8.12" ChainRulesCore = "1.23.0" ComponentArrays = "0.15.27" @@ -77,7 +77,7 @@ JET = "0.9" JLArrays = "0.2.0" JuliaFormatter = "1,2" LinearAlgebra = "1" -Mooncake = "0.4.122" +Mooncake = "0.4.147" Pkg = "1" PolyesterForwardDiff = "0.1.2" Random = "1" From 3e1bd9c347444d8664502d24b61d7bb6ca1e6839 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:32:56 +0200 Subject: [PATCH 04/12] Test rule --- .../test/Back/DifferentiateWith/test.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index d2bf57f88..975e6ec22 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ 
-62,13 +62,17 @@ end; @testset for scen in filter(differentiatewith_scenarios()) do scen DIT.operator(scen) == :pullback end - Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true) + Mooncake.TestUtils.test_rule( + StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode() + ) end end; @testset "Mooncake errors" begin MooncakeDifferentiateWithError = - Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError + Base.get_extension( + DifferentiationInterface, :DifferentiationInterfaceMooncakeExt + ).MooncakeDifferentiateWithError e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) @test sprint(showerror, e) == From 80d5f73a30b42d49863214e75e04ec4a66ddb7da Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:39:00 +0200 Subject: [PATCH 05/12] Format --- DifferentiationInterface/test/Back/DifferentiateWith/test.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 975e6ec22..c5403c9a1 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -70,9 +70,7 @@ end; @testset "Mooncake errors" begin MooncakeDifferentiateWithError = - Base.get_extension( - DifferentiationInterface, :DifferentiationInterfaceMooncakeExt - ).MooncakeDifferentiateWithError + Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) @test sprint(showerror, e) == From 0c3a08402f9acdbac8379f571eaa13b047d4c6cf Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:54:30 +0200 Subject: [PATCH 06/12] Fix --- 
.../test/Back/DifferentiateWith/test.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index c5403c9a1..684270cc6 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -63,14 +63,16 @@ end; DIT.operator(scen) == :pullback end Mooncake.TestUtils.test_rule( - StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode() + StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode ) end end; @testset "Mooncake errors" begin MooncakeDifferentiateWithError = - Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError + Base.get_extension( + DifferentiationInterface, :DifferentiationInterfaceMooncakeExt + ).MooncakeDifferentiateWithError e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) @test sprint(showerror, e) == From 01b90bf202fd210db4e7fba85957a98aecbdd364 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:17:25 +0200 Subject: [PATCH 07/12] Format --- DifferentiationInterface/test/Back/DifferentiateWith/test.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 684270cc6..9c655e001 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -70,9 +70,7 @@ end; @testset "Mooncake errors" begin MooncakeDifferentiateWithError = - Base.get_extension( - DifferentiationInterface, :DifferentiationInterfaceMooncakeExt - ).MooncakeDifferentiateWithError + Base.get_extension(DifferentiationInterface, 
:DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) @test sprint(showerror, e) == From 39a22dcc36e0c633148e93c5419f5967862d0f34 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 17:18:42 +0200 Subject: [PATCH 08/12] Fix coverage --- .../DifferentiationInterfaceMooncakeExt.jl | 2 +- .../DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 15d22c26e..d037498c9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -33,7 +33,7 @@ using Mooncake: const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}} -DI.check_available(::AnyAutoMooncake) = true +DI.check_available(::AnyAutoMooncake{C}) where {C} = true get_config(::AnyAutoMooncake{Nothing}) = Config() get_config(backend::AnyAutoMooncake{<:Config}) = backend.config diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index f90524643..56dc3411c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -58,7 +58,7 @@ end function DI.pushforward( f!::F, y, - prep::MooncakeOneArgPushforwardPrep, + prep::MooncakeTwoArgPushforwardPrep, backend::AutoMooncakeForward, x, tx::NTuple, @@ -72,7 +72,7 @@ function DI.value_and_pushforward!( f!::F, y::Y, ty::NTuple, - 
prep::MooncakeOneArgPushforwardPrep, + prep::MooncakeTwoArgPushforwardPrep, backend::AutoMooncakeForward, x::X, tx::NTuple, @@ -100,7 +100,7 @@ function DI.pushforward!( f!::F, y, ty::NTuple, - prep::MooncakeOneArgPushforwardPrep, + prep::MooncakeTwoArgPushforwardPrep, backend::AutoMooncakeForward, x, tx::NTuple, From 076166a0d1504ef347e9e263a90131f7e7b54f20 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 17:20:20 +0200 Subject: [PATCH 09/12] Docs --- DifferentiationInterface/docs/src/explanation/backends.md | 4 ++-- DifferentiationInterface/src/misc/differentiate_with.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 9cee8be93..3be877d0d 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github. - [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences) - [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) - [`AutoGTPSA`](@extref ADTypes.AutoGTPSA) -- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) +- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) (the latter is experimental) - [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff) - [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff) - [`AutoSymbolics`](@extref ADTypes.AutoSymbolics) @@ -97,7 +97,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. 
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. -At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). ## Implementations diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index 256d46f75..98cb63802 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -13,7 +13,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be !!! warning `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments. - It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). 
Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. + It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). !!! warning From da539b37598bfa93fbdf5c4adbf5f68e21429cb9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 17:51:57 +0200 Subject: [PATCH 10/12] Fix check prep --- .../ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 56dc3411c..d2c66f290 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -106,7 +106,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {F,C} - DI.check_prep(f!, y, ty, prep, backend, x, tx, contexts...) + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) 
return ty end From e5edb0bb5dc9158bdb670eb909ac5c55f1bf3d59 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 18:42:14 +0200 Subject: [PATCH 11/12] Fix --- .../ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index d2c66f290..ca554bf27 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -107,6 +107,6 @@ function DI.pushforward!( contexts::Vararg{DI.Context,C}; ) where {F,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) - DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) + DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) 
return ty end From bc486822c15c789b1323df633c547384961ed6a9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 13 Aug 2025 20:52:37 +0200 Subject: [PATCH 12/12] Add config --- .../DifferentiationInterfaceMooncakeExt/forward_onearg.jl | 5 +++-- .../DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index c9228d972..ebf8601d5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -16,8 +16,9 @@ function DI.prepare_pushforward_nokwarg( ) where {F,C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) config = get_config(backend) - # TODO: silence_debug_messages - cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config.debug_mode) + cache = prepare_derivative_cache( + f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages + ) dx_righttype = zero_tangent(x) prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype) return prep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index ca554bf27..56b655b2e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -18,9 +18,13 @@ function DI.prepare_pushforward_nokwarg( ) where {F,C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) config = get_config(backend) - # TODO: silence_debug_messages cache = prepare_derivative_cache( - f!, y, x, map(DI.unwrap, contexts)...; 
config.debug_mode + f!, + y, + x, + map(DI.unwrap, contexts)...; + config.debug_mode, + config.silence_debug_messages, ) dx_righttype = zero_tangent(x) dy_righttype = zero_tangent(y)