From f606a256fc17081edae7ae8f2528d95a22c3f51a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 14 Jul 2025 16:11:43 +0100 Subject: [PATCH 1/4] Implement Turing.Inference.getlogp_external --- Project.toml | 4 ++-- ext/SliceSamplingTuringExt.jl | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 0c40d12..28fec35 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SliceSampling" uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf" -version = "0.7.6" +version = "0.7.7" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -21,7 +21,7 @@ Distributions = "0.25" LinearAlgebra = "1" LogDensityProblems = "2" Random = "1" -Turing = "0.37, 0.38, 0.39" +Turing = "0.39.5" julia = "1.10" [extras] diff --git a/ext/SliceSamplingTuringExt.jl b/ext/SliceSamplingTuringExt.jl index b24486d..1b95882 100644 --- a/ext/SliceSamplingTuringExt.jl +++ b/ext/SliceSamplingTuringExt.jl @@ -39,13 +39,19 @@ function Turing.Inference.getparams( ) return state.transition.params end + +function Turing.Inference.getlogp_external( + ::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state +) + return t.lp +end # end function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction) model = ℓ.model vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform()) vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) - θ = vi_spl[:] + θ = vi_spl[:] init_attempt_count = 1 while !all(isfinite.(θ)) From fd5805058dde9be1d25b41ac91e12860d765276c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 14 Jul 2025 16:14:24 +0100 Subject: [PATCH 2/4] Add tests --- test/turing.jl | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/turing.jl b/test/turing.jl index f6c2e89..e09149f 100644 --- a/test/turing.jl +++ b/test/turing.jl @@ -8,8 +8,10 @@ return nothing end + @model logp_check() = x ~ Normal() + n_samples = 1000 - model = demo() + model = demo() @testset for sampler in [ RandPermGibbs(Slice(1)), @@ -30,6 +32,11 @@ ) chain = sample(model, externalsampler(sampler), n_samples; progress=false) + + chain_logp_check = sample( + logp_check(), externalsampler(sampler), 100; progress=false + ) + @test isapprox(logpdf.(Normal(), chain_logp_check[:x]), chain_logp_check[:logp]) end @testset "gibbs($sampler)" for sampler in [ @@ -46,5 +53,10 @@ n_samples; progress=false, ) + + chain_logp_check = sample( + logp_check(), Turing.Gibbs(:x => externalsampler(sampler)), 100; progress=false + ) + @test isapprox(logpdf.(Normal(), chain_logp_check[:x]), chain_logp_check[:logp]) end end From 2391a06d15ebe9385e87348e3764694ee83de5d2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Jul 2025 11:31:26 +0100 Subject: [PATCH 3/4] Fix tests --- ext/SliceSamplingTuringExt.jl | 23 ++++++++--------------- test/turing.jl | 22 ++++++++++++++++++---- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/ext/SliceSamplingTuringExt.jl b/ext/SliceSamplingTuringExt.jl index 1b95882..44f9f6a 100644 --- a/ext/SliceSamplingTuringExt.jl +++ b/ext/SliceSamplingTuringExt.jl @@ -22,24 +22,17 @@ Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true -function Turing.Inference.getparams( - ::Turing.DynamicPPL.Model, sample::SliceSampling.UnivariateSliceState -) +const SliceSamplingStates = Union{ + SliceSampling.UnivariateSliceState, + SliceSampling.GibbsState, + SliceSampling.HitAndRunState, + SliceSampling.LatentSliceState, + SliceSampling.GibbsPolarSliceState, +} +function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates) return sample.transition.params end -function Turing.Inference.getparams( - ::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState -) - return state.transition.params -end - -function Turing.Inference.getparams( - ::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState -) - return state.transition.params -end - function Turing.Inference.getlogp_external( ::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state ) diff --git a/test/turing.jl b/test/turing.jl index e09149f..f5c6104 100644 --- a/test/turing.jl +++ b/test/turing.jl @@ -8,7 +8,10 @@ return nothing end - @model logp_check() = x ~ Normal() + @model function logp_check() + a ~ Normal() + return b ~ Normal() + end n_samples = 1000 model = demo() @@ -36,7 +39,11 @@ chain_logp_check = sample( logp_check(), externalsampler(sampler), 100; progress=false ) - @test isapprox(logpdf.(Normal(), chain_logp_check[:x]), chain_logp_check[:logp]) + @test isapprox( + logpdf.(Normal(), chain_logp_check[:a]) .+ + logpdf.(Normal(), chain_logp_check[:b]), + chain_logp_check[:lp], + ) end @testset "gibbs($sampler)" for sampler in [ @@ -55,8 +62,15 @@ ) chain_logp_check = sample( - logp_check(), Turing.Gibbs(:x => externalsampler(sampler)), 100; progress=false + logp_check(), + Turing.Gibbs(:a => externalsampler(sampler), :b => externalsampler(sampler)), + 100; + progress=false, + ) + @test isapprox( + logpdf.(Normal(), chain_logp_check[:a]) .+ + logpdf.(Normal(), chain_logp_check[:b]), + chain_logp_check[:lp], ) - @test isapprox(logpdf.(Normal(), chain_logp_check[:x]), chain_logp_check[:logp]) end end From 1993b40c187377d0f0dc1d7ddda6de7c2c475aae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Jul 2025 14:25:09 +0100 Subject: [PATCH 4/4] Don't set arch=x64 for macos-latest --- .github/workflows/CI.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f5c5a80..de618e8 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,14 +25,11 @@ jobs: - ubuntu-latest - macOS-latest - windows-latest - arch: - - x64 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1