diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 3efc6c248..cef7316e7 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.4" +version = "0.7.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -56,7 +56,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.13.0" +ADTypes = "1.17.0" Aqua = "0.8.12" ChainRulesCore = "1.23.0" ComponentArrays = "0.15.27" @@ -77,7 +77,7 @@ JET = "0.9" JLArrays = "0.2.0" JuliaFormatter = "1,2" LinearAlgebra = "1" -Mooncake = "0.4.122" +Mooncake = "0.4.147" Pkg = "1" PolyesterForwardDiff = "0.1.2" Random = "1" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index cb43633ee..3be877d0d 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github. 
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences) - [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) - [`AutoGTPSA`](@extref ADTypes.AutoGTPSA) -- [`AutoMooncake`](@extref ADTypes.AutoMooncake) +- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) (the latter is experimental) - [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff) - [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff) - [`AutoSymbolics`](@extref ADTypes.AutoSymbolics) @@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato | `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | + | `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 | | `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends: | `AutoForwardDiff` | ✅ | ✅ | | `AutoGTPSA` | ✅ | ❌ | | `AutoMooncake` | ✅ | ✅ | +| `AutoMooncakeForward` | ✅ | ✅ | | `AutoPolyesterForwardDiff` | ✅ | ✅ | | `AutoReverseDiff` | ✅ | ❌ | | `AutoSymbolics` | ✅ | ✅ | @@ -95,7 +97,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. 
-At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). ## Implementations diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 321378e23..d037498c9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -1,16 +1,22 @@ module DifferentiationInterfaceMooncakeExt -using ADTypes: ADTypes, AutoMooncake +using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward import DifferentiationInterface as DI using Mooncake: Mooncake, CoDual, Config, + Dual, + prepare_derivative_cache, prepare_gradient_cache, prepare_pullback_cache, + primal, + tangent, tangent_type, + value_and_derivative!!, value_and_gradient!!, value_and_pullback!!, + zero_dual, zero_tangent, rdata_type, fdata, @@ -25,17 +31,17 @@ using Mooncake: _copy_output, _copy_to_output!! 
-DI.check_available(::AutoMooncake) = true +const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}} -get_config(::AutoMooncake{Nothing}) = Config() -get_config(backend::AutoMooncake{<:Config}) = backend.config +DI.check_available(::AnyAutoMooncake{C}) where {C} = true -# tangents need to be copied before returning, otherwise they are still aliased in the cache -mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x) -mycopy(x) = deepcopy(x) +get_config(::AnyAutoMooncake{Nothing}) = Config() +get_config(backend::AnyAutoMooncake{<:Config}) = backend.config include("onearg.jl") include("twoarg.jl") +include("forward_onearg.jl") +include("forward_twoarg.jl") include("differentiate_with.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl new file mode 100644 index 000000000..ebf8601d5 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -0,0 +1,93 @@ +## Pushforward + +struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + cache::Tcache + dx_righttype::DX +end + +function DI.prepare_pushforward_nokwarg( + strict::Val, + f::F, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + config = get_config(backend) + cache = prepare_derivative_cache( + f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages + ) + dx_righttype = zero_tangent(x) + prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype) + return prep +end + +function DI.value_and_pushforward( + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X} + DI.check_prep(f, prep, backend, x, tx, contexts...) 
+ ys_and_ty = map(tx) do dx + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + y_dual = value_and_derivative!!( + prep.cache, + zero_dual(f), + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + y = primal(y_dual) + dy = _copy_output(tangent(y_dual)) + return y, dy + end + y = first(ys_and_ty[1]) + ty = last.(ys_and_ty) + return y, ty +end + +function DI.pushforward( + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] +end + +function DI.value_and_pushforward!( + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) + foreach(copyto!, ty, new_ty) + return y, ty +end + +function DI.pushforward!( + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...) 
+ return ty +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl new file mode 100644 index 000000000..56b655b2e --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -0,0 +1,116 @@ +## Pushforward + +struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + cache::Tcache + dx_righttype::DX + dy_righttype::DY +end + +function DI.prepare_pushforward_nokwarg( + strict::Val, + f!::F, + y, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) + config = get_config(backend) + cache = prepare_derivative_cache( + f!, + y, + x, + map(DI.unwrap, contexts)...; + config.debug_mode, + config.silence_debug_messages, + ) + dx_righttype = zero_tangent(x) + dy_righttype = zero_tangent(y) + prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype) + return prep +end + +function DI.value_and_pushforward( + f!::F, + y, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + ty = map(tx) do dx + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + y_dual = zero_dual(y) + value_and_derivative!!( + prep.cache, + zero_dual(f!), + y_dual, + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + dy = _copy_output(tangent(y_dual)) + return dy + end + return y, ty +end + +function DI.pushforward( + f!::F, + y, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
+ return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] +end + +function DI.value_and_pushforward!( + f!::F, + y::Y, + ty::NTuple, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X,Y} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + foreach(tx, ty) do dx, dy + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + dy_righttype = + dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) + value_and_derivative!!( + prep.cache, + zero_dual(f!), + Dual(y, dy_righttype), + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + dy === dy_righttype || copyto!(dy, dy_righttype) + end + return y, ty +end + +function DI.pushforward!( + f!::F, + y, + ty::NTuple, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) 
+ return ty +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 929abff94..32bfc4703 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -28,6 +28,7 @@ using ADTypes: AutoForwardDiff, AutoGTPSA, AutoMooncake, + AutoMooncakeForward, AutoPolyesterForwardDiff, AutoReverseDiff, AutoSymbolics, @@ -115,6 +116,7 @@ export AutoFiniteDifferences export AutoForwardDiff export AutoGTPSA export AutoMooncake +export AutoMooncakeForward export AutoPolyesterForwardDiff export AutoReverseDiff export AutoSymbolics diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index 256d46f75..98cb63802 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -13,7 +13,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be !!! warning `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments. - It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. 
+ It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). !!! warning diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index d2bf57f88..9c655e001 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -62,7 +62,9 @@ end; @testset for scen in filter(differentiatewith_scenarios()) do scen DIT.operator(scen) == :pullback end - Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true) + Mooncake.TestUtils.test_rule( + StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode + ) end end; diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 8c9ab839a..b695179f1 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())] +backends = [ + AutoMooncake(; config=nothing), + AutoMooncake(; config=Mooncake.Config()), + AutoMooncakeForward(; config=nothing), +] for backend in backends @test check_available(backend)