- 
                Notifications
    You must be signed in to change notification settings 
- Fork 229
Gibbs sampler #2647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Gibbs sampler #2647
Changes from 5 commits
c0158ea
              a972b5a
              bdb7f73
              c3cc773
              714c1e8
              97c571d
              94b723d
              891ac14
              2058ae5
              b0812a3
              d910312
              4b1dc2f
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -54,17 +54,48 @@ chain = sample(model, Gibbs( | |
| ), 1000) | ||
| ``` | ||
| """ | ||
| struct GibbsConditional{S,C} <: InferenceAlgorithm | ||
| struct GibbsConditional{C} <: InferenceAlgorithm | ||
| conditional::C | ||
|  | ||
| function GibbsConditional(sym::Symbol, conditional::C) where {C} | ||
| return new{sym,C}(conditional) | ||
| return new{C}(conditional) | ||
| end | ||
| end | ||
|  | ||
| # Mark GibbsConditional as a valid Gibbs component | ||
| isgibbscomponent(::GibbsConditional) = true | ||
|  | ||
| # Required methods for Gibbs constructor | ||
| Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable | ||
|  | ||
| """ | ||
| find_global_varinfo(context, fallback_vi) | ||
|  | ||
| Traverse the context stack to find global variable information from | ||
| GibbsContext, ConditionContext, FixedContext, etc. | ||
| """ | ||
| function find_global_varinfo(context, fallback_vi) | ||
| # Start with the given context and traverse down | ||
| current_context = context | ||
|  | ||
| while current_context !== nothing | ||
| if current_context isa GibbsContext | ||
| # Found GibbsContext, return its global varinfo | ||
| return get_global_varinfo(current_context) | ||
| elseif hasproperty(current_context, :childcontext) && | ||
| isdefined(DynamicPPL, :childcontext) | ||
|          | ||
| # Move to child context if it exists | ||
| current_context = DynamicPPL.childcontext(current_context) | ||
| else | ||
| # No more child contexts | ||
| break | ||
| end | ||
| end | ||
|  | ||
| # If no GibbsContext found, use the fallback | ||
| return fallback_vi | ||
| end | ||
|  | ||
| """ | ||
| DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) | ||
|  | ||
|  | @@ -97,12 +128,10 @@ function AbstractMCMC.step( | |
| alg = sampler.alg | ||
|  | ||
| # For GibbsConditional within Gibbs, we need to get all variable values | ||
| # Check if we're in a Gibbs context | ||
| global_vi = if hasproperty(model, :context) && model.context isa GibbsContext | ||
| # We're in a Gibbs context, get the global varinfo | ||
| get_global_varinfo(model.context) | ||
| # Traverse the context stack to find all conditioned/fixed/Gibbs variables | ||
| global_vi = if hasproperty(model, :context) | ||
|          | ||
| find_global_varinfo(model.context, state) | ||
| else | ||
| # We're not in a Gibbs context, use the current state | ||
| state | ||
| end | ||
|  | ||
|  | @@ -119,34 +148,10 @@ function AbstractMCMC.step( | |
| updated = rand(rng, conddist) | ||
|  | ||
| # Update the variable in state | ||
| # We need to get the actual VarName for this variable | ||
| # The symbol S tells us which variable to update | ||
| vn = VarName{S}() | ||
|  | ||
| # Check if the variable needs to be a vector | ||
| new_vi = if haskey(state, vn) | ||
| # Update the existing variable | ||
| DynamicPPL.setindex!!(state, updated, vn) | ||
| else | ||
| # Try to find the variable with indices | ||
| # This handles cases where the variable might have indices | ||
| local updated_vi = state | ||
| found = false | ||
| for key in keys(state) | ||
| if DynamicPPL.getsym(key) == S | ||
| updated_vi = DynamicPPL.setindex!!(state, updated, key) | ||
| found = true | ||
| break | ||
| end | ||
| end | ||
| if !found | ||
| error("Could not find variable $S in VarInfo") | ||
| end | ||
| updated_vi | ||
| end | ||
|  | ||
| # Update log joint probability | ||
| new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) | ||
| # The Gibbs sampler ensures that state only contains one variable | ||
| # Get the variable name from the keys | ||
| varname = first(keys(state)) | ||
| new_vi = DynamicPPL.setindex!!(state, updated, varname) | ||
|  | ||
| return nothing, new_vi | ||
| end | ||
|  | @@ -166,80 +171,3 @@ function setparams_varinfo!!( | |
| # the state is nothing and we don't need to update anything | ||
| return params | ||
| end | ||
|  | ||
| """ | ||
| gibbs_initialstep_recursive( | ||
| rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state | ||
| ) | ||
|  | ||
| Initialize the GibbsConditional sampler. | ||
| """ | ||
| function gibbs_initialstep_recursive( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, | ||
| target_varnames::AbstractVector{<:VarName}, | ||
| global_vi::DynamicPPL.AbstractVarInfo, | ||
| prev_state, | ||
| ) | ||
| # GibbsConditional doesn't need any special initialization | ||
| # Just perform one sampling step | ||
| return gibbs_step_recursive( | ||
| rng, model, sampler_wrapped, target_varnames, global_vi, nothing | ||
| ) | ||
| end | ||
|  | ||
| """ | ||
| gibbs_step_recursive( | ||
| rng, model, sampler::GibbsConditional, target_varnames, global_vi, state | ||
| ) | ||
|  | ||
| Perform a single step of GibbsConditional sampling. | ||
| """ | ||
| function gibbs_step_recursive( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, | ||
| target_varnames::AbstractVector{<:VarName}, | ||
| global_vi::DynamicPPL.AbstractVarInfo, | ||
| state, | ||
| ) where {S} | ||
| sampler = sampler_wrapped.alg | ||
|  | ||
| # Extract conditioned values as a NamedTuple | ||
| # Include both random variables and observed data | ||
| condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) | ||
| condvals_obs = NamedTuple{keys(model.args)}(model.args) | ||
| condvals = merge(condvals_vars, condvals_obs) | ||
|  | ||
| # Get the conditional distribution | ||
| conddist = sampler.conditional(condvals) | ||
|  | ||
| # Sample from the conditional distribution | ||
| updated = rand(rng, conddist) | ||
|  | ||
| # Update the variable in global_vi | ||
| # We need to get the actual VarName for this variable | ||
| # The symbol S tells us which variable to update | ||
| vn = VarName{S}() | ||
|  | ||
| # Check if the variable needs to be a vector | ||
| if haskey(global_vi, vn) | ||
| # Update the existing variable | ||
| global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) | ||
| else | ||
| # Try to find the variable with indices | ||
| # This handles cases where the variable might have indices | ||
| for key in keys(global_vi) | ||
| if DynamicPPL.getsym(key) == S | ||
| global_vi = DynamicPPL.setindex!!(global_vi, updated, key) | ||
| break | ||
| end | ||
| end | ||
| end | ||
|  | ||
| # Update log joint probability | ||
| global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) | ||
|  | ||
| return nothing, global_vi | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this being called somewhere? Might be, but I don't remember having a need for
lengthof samplers.