
Conversation

josephdviviano
Collaborator

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

changes:

  • added documentation
  • small change to preprocessors, to allow them to cast outputs to defined dtypes.
  • added a RecurrentDiscretePolicyEstimator, which accepts a carry in addition to States during forward.
  • The sampler and log probability calculations now use an EstimatorAdapter for all communication with estimators. These adapters handle all estimator-specific logic - for example, the management of the carry for RecurrentDiscretePolicyEstimator (but there might be many more applications of this). The DefaultEstimatorAdapter replicates exactly the behaviour of the library before this addition, and is used unless the user specifies some other adapter.
  • EstimatorAdapters can either be vectorized or non-vectorized. Most, but not all, EstimatorAdapters can leverage vectorized operations during the log-probability computation stage; the code supports both paths. During sampling, operations are not vectorized due to the sequential nature of the operation.
  • EstimatorAdapters have the following three methods (see the sketch after this list):
    • init_context - returns the appropriate Context class; currently, only RolloutContext is available.
    • compute - given some states, a context, and a step mask, interfaces with the estimator.
    • record - records per-step artifacts into trajectory-level buffers. These objects are what are eventually emitted by the context at the end of a rollout.
  • Adapters use RolloutContext to maintain the state of a rollout during sampling / log_prob calculations. The context has a finalize() method which returns a dict of artifacts to be stored.
  • For most use cases the default feature set of DefaultEstimatorAdapter and RolloutContext will be sufficient, but the design allows the user to define new contexts in cases where new variables need to be tracked during sampling.
  • Added a suite of basic recurrent modules.
  • All probability calculations optionally accept unique adapters for pf and pb. This covers the extremely unusual case where a user needs different estimator logic for the forward and backward trajectories, and therefore a different adapter type for each.
  • Added tests ensuring:
    • The new modules work as expected (smoke tests).
    • The new probability-calculation logic works as expected. These tests duplicate the original probability-calculation logic, to ensure the adapter-supported logic is numerically identical.
    • Integration tests for estimators, samplers, and adapters.
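
To make the adapter surface concrete, here is a minimal sketch of the three methods plus the context's finalize(). The names follow the description above, but the exact signatures in the PR may differ and the bodies are illustrative only:

from dataclasses import dataclass, field


@dataclass
class RolloutContext:
    # Per-trajectory buffers filled in over the course of a rollout.
    trajectory_log_probs: list = field(default_factory=list)
    trajectory_estimator_outputs: list = field(default_factory=list)
    extras: dict = field(default_factory=dict)

    def finalize(self) -> dict:
        # Emit everything accumulated during the rollout.
        return {
            "log_probs": self.trajectory_log_probs,
            "estimator_outputs": self.trajectory_estimator_outputs,
            **self.extras,
        }


class DefaultEstimatorAdapter:
    """Mediates all sampler <-> estimator communication."""

    def __init__(self, estimator):
        self.estimator = estimator

    def init_context(self) -> RolloutContext:
        return RolloutContext()

    def compute(self, states, ctx, step_mask):
        # Only the still-active (not-done) trajectories are passed to the
        # estimator; indexing a States object by a mask is illustrative here.
        return self.estimator(states[step_mask]), ctx

    def record(self, ctx, log_probs, estimator_output):
        ctx.trajectory_log_probs.append(log_probs)
        ctx.trajectory_estimator_outputs.append(estimator_output)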

@josephdviviano josephdviviano self-assigned this Oct 9, 2025
@josephdviviano josephdviviano requested a review from saleml October 9, 2025 16:36
Collaborator

@younik younik left a comment

The Adapter adds a layer of complexity that I don't see as beneficial. If we need to support RNNs, for example, shall we just change the interface of the Estimator to also return extra? In the case of a non-recurrent NN this is empty; in the recurrent case it can be the hidden state.

Then the user can write their preferred estimator.
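
As a rough sketch of that alternative (the module and its layout here are illustrative only, not library API), the forward pass would return its usual output plus an extra that is empty for feedforward modules and carries the hidden state for recurrent ones:

import torch.nn as nn


class RecurrentEstimatorModule(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_actions):
        super().__init__()
        self.rnn = nn.GRUCell(input_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, n_actions)

    def forward(self, states, extra=None):
        # `extra` threads the recurrent hidden state between steps; a
        # non-recurrent estimator would ignore it and return an empty dict.
        hidden = None if extra is None else extra.get("hidden")
        hidden = self.rnn(states, hidden)
        return self.head(hidden), {"hidden": hidden}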

Comment on lines 77 to 92
pf_estimator = RecurrentDiscretePolicyEstimator(
    module=model,
    n_actions=env.n_actions,
    is_backward=False,
).to(device)

# GFlowNet (Trajectory Balance), tree DAG -> pb=None, constant_pb=True,
# Use a recurrent adapter for the PF.
gflownet = TBGFlowNet(
    pf=pf_estimator,
    pb=None,
    init_logZ=0.0,
    constant_pb=True,
    pf_adapter=RecurrentEstimatorAdapter(pf_estimator),
)
gflownet = gflownet.to(device)
Collaborator

Why does the user need to specify an adapter when there is a 1-to-1 correspondence?
We already know we need the RecurrentEstimatorAdapter for RecurrentDiscretePolicyEstimator.

Collaborator Author

This won't be generally true, right?

Collaborator

How can it be? I see you can do multiple adapters for a single estimator, but you can't really have an adapter for multiple estimators, as they are too entangled. Indeed, the purpose of the adapter is "handling arbitrary estimator interfaces", so the adapter needs to know the specific estimator interface to be able to adapt it to a standard API.
This means the adapter must already know what the estimator type is.

Collaborator Author

@josephdviviano josephdviviano Oct 12, 2025

I admit we're getting into some fairly esoteric territory but I am just worried about boxing ourselves in too much. I can think of examples in both directions:

  1. the default adapter works for many kinds of estimators (currently).
  2. for a single estimator (a transformer sequence model), I might want to do non-vectorized or vectorized log-prob calculations. So there are at minimum multiple configurations of the same adapter for such an estimator.

I like your idea though, it seems cleaner. What do you think of this? Also interested in your POV @saleml / @hyeok9855 :

  • Each estimator automatically inherits the DefaultAdapter.
  • New estimators can define a different default adapter.
  • The sampler uses whatever adapter is provided by the estimator.
  • When instantiating an estimator, one can provide adapters to override the default one.
  • The gflownet no longer needs to know anything about adapters.

This would allow every estimator to be explicitly associated with an adapter, while keeping the flexibility to mix and match them for some specific need the rare user will come up with. And again, for 99% of cases, people will just use the default adapter.
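
Sketching that proposal (class and attribute names are hypothetical here, not the merged API):

class DefaultEstimatorAdapter:
    def __init__(self, estimator):
        self.estimator = estimator


class RecurrentEstimatorAdapter(DefaultEstimatorAdapter):
    pass  # would additionally thread the recurrent carry through each step


class Estimator:
    # Subclasses override this to pair themselves with a different default adapter.
    default_adapter_cls = DefaultEstimatorAdapter

    def __init__(self, module, adapter_cls=None):
        self.module = module
        # An adapter passed at construction overrides the estimator's default.
        self.adapter = (adapter_cls or self.default_adapter_cls)(self)


class RecurrentDiscretePolicyEstimator(Estimator):
    default_adapter_cls = RecurrentEstimatorAdapter

The sampler would then just read estimator.adapter, and TBGFlowNet would no longer need a pf_adapter argument.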

Collaborator Author

I've added a default adapter for every estimator.

@josephdviviano
Collaborator Author

josephdviviano commented Oct 12, 2025

@younik you're looking at this PR as if all of this complexity is only there to support recurrent estimators, when in fact it's a way for us to support arbitrary estimators, and I'm using recurrent estimators as an example.

By this I mean: in the sampling process of the gflownet, the estimator accepts (state, thing_a, thing_b, ...). We created a special case for conditioning, but people keep coming up with new ideas, and they need a way to do this sort of thing without re-writing the sampler / log-prob calculation loops. This was a way of allowing that to happen relatively transparently. One can invent some new adapter, but they don't need to worry about getting the sampler loop right.

The goal of this library is to be a good platform for researchers to build on. So we can't make it too difficult for people to innovate on a core element of the algorithm without needing to re-write the most bug-prone elements of the algorithm (the sampling / log-prob loops are the easiest places to introduce bugs).

For recurrent estimators, I'm OK with putting the carry in the estimator iff there's a JIT-compatible way of doing so, and removing the recurrent adapter, but that doesn't solve the broader problem of not supporting arbitrary estimator interfaces. Besides, the recurrent adapter logic is very simple.

If you think this is a non-issue (meaning we have some other way of handling arbitrary estimator interfaces), can you state your case / propose a design? I don't feel confident saying that our current sampler logic covers all possible estimator interfaces across the space of possible algorithms. It's not clear to me that we can shove all of this complexity into the environment's State object.

And, if you don't use adapters, this is completely transparent to the user. So I'm not convinced this added complexity is actually bad for users unless they have to re-write a sampler and all the log-prob calculations -- that is for sure hard -- I had to do it for this PR and it was really tricky to get right. In other words, if you're in a position where you need to re-write the sampler / log-prob calculation loops, you're in a much better position if you only need to write a new adapter.

I have put an enormous amount of work into this PR, please try to discuss it at a level that respects the amount of work I've done here.

tweak of how default preprocessor is defined
Collaborator

@younik younik left a comment

What I meant is to simply have the Adapter class merged into Estimator, because I think the adapter logic is too dependent on Estimator, so it doesn't add anything to the expected API (you are saying Estimator can do anything, but the adapter needs to transform anything into a fixed API).

Concretely:

  1. let the Estimator have an init_context method, which returns a tensordict (see inline comment)
  2. The forward function will also accept the ctx, which is a tensordict, and output a new one. Since it is the estimator who created it, we know what fields to expect.

The user can easily define custom logic by subclassing Estimator.
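
If I follow, that would look roughly like this (a sketch only; the method and field names are placeholders, and the tensordict usage is assumed rather than proposed verbatim):

import torch
from tensordict import TensorDict


class MyRecurrentEstimator(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, n_actions):
        super().__init__()
        self.rnn = torch.nn.GRUCell(input_dim, hidden_dim)
        self.head = torch.nn.Linear(hidden_dim, n_actions)
        self.hidden_dim = hidden_dim

    def init_context(self, batch_size, device):
        # The estimator creates its own context, so downstream code knows
        # exactly which fields to expect.
        return TensorDict(
            {"hidden": torch.zeros(batch_size, self.hidden_dim, device=device)},
            batch_size=[batch_size],
        )

    def forward(self, states, ctx):
        hidden = self.rnn(states, ctx["hidden"])
        # Return the usual output plus an updated context.
        return self.head(hidden), TensorDict(
            {"hidden": hidden}, batch_size=[states.shape[0]]
        )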

Comment on lines 109 to 114
"conditioning",
"carry",
"trajectory_log_probs",
"trajectory_estimator_outputs",
"current_estimator_output",
"extras",
Collaborator

I can't find any place where trajectory_log_probs, trajectory_estimator_outputs, current_estimator_output are used. This is like the default Context, so it should contain only the minimal things.
As you mentioned, we can't know what the user will come up with, so we should let the user define this for particular cases.

Also, carry and conditioning are special cases; can they just go in extras?
And finally, this is a dictionary of tensors, with a device and batch size. Can this just be a tensordict?

See general comment for why I believe we don't need this class

Collaborator Author

@josephdviviano josephdviviano Oct 12, 2025

  1. I'm ok with removing the context class entirely, and instead using a tensordict. I'll have to move the relevant logic back into the adapter, but that's ok with me. Thanks for pointing that out!
  2. I don't think supporting recurrent policies or conditional generation is boutique enough to consider them "special cases". I'd prefer to leave these as explicit fields (insofar as there will be a spec for the "expected tensordict").
  3. I don't really understand your comment that you don't see where "trajectory_log_probs, trajectory_estimator_outputs, current_estimator_output" are used, but I'll try to simplify this stuff when I make the move to a tensordict :)

Collaborator Author

OK, an update.

I spent a few hours trying to remove the context class entirely, but it actually makes the code much more complicated because you need to pad everything everywhere to ensure the size of the batch does not change at all during the trajectory. This is because not all trajectories terminate at the same time, and we slice along the batch dimension to remove these elements from computation, but tensordict enforces that the batch size of every tensor in the dict does not change.
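
To illustrate the problem (a toy example in plain torch, not library code): because trajectories finish at different times, the batch of active states shrinks from step to step, so the per-step outputs are ragged and cannot live in a container with a fixed batch size unless everything is padded.

import torch

batch_size, max_steps, dim = 4, 3, 2
dones = torch.zeros(batch_size, dtype=torch.bool)
per_step_outputs = []

for t in range(max_steps):
    active = ~dones
    # Stand-in for the env states of the still-active trajectories.
    states_active = torch.randn(int(active.sum()), dim)
    per_step_outputs.append(states_active)
    dones |= torch.rand(batch_size) < 0.5  # some trajectories terminate

print([o.shape[0] for o in per_step_outputs])  # e.g. [4, 2, 1] -- ragged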

I'm going to propose that we keep the current structure. We can explore a simplification down the road, but I actually think this slots class is the simplest option.

What I will do instead is do a pass of the API / docstrings, streamlining them a bit.

…tion, added some useful safeguard assertions, bugfix related to saving estimator_outputs path
@josephdviviano
Collaborator Author

OK, I've streamlined the Context Adapter elements. I'm now going to move the Adapters into the Estimators.

@josephdviviano
Collaborator Author

I've unified the adapter API; tests are matching and passing. I think this PR does a good job of addressing your previous concerns @younik

masked_conditioning = None
step_mask = ~dones

valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
Collaborator Author

Should I be calling sample_actions which calls compute_dist? Or something like that.

@hyeok9855
Collaborator

I'm still reviewing this huge PR. Let me briefly summarise my understanding of the high-level purpose of this PR:

  1. Non-vectorised probability calculation, which is basically a for loop over T (trajectory length).
  2. A context for an estimator that generalises (a) trajectory-wise contexts like conditioning, and (b) per-step contexts like recurrent carry.
  3. The `Adapter` handles the contexts and does what was previously done by estimators (e.g., estimator_output -> probs)

First, I believe bullet 1 is independent of the adapter-related things, so it can be separated from this PR just to make it smaller.

To be brutally honest, I still think the Adapter makes things unnecessarily complicated, and there must be a simpler way to support those functionalities; e.g., as @younik raised before, we might be able to get rid of the adapter and incorporate it all under the estimator.

@josephdviviano
Collaborator Author

Based on the discussion today, I've implemented changes to this PR here:

#413

which should be merged before we merge this one.

Make adapters logic part of estimators via `PolicyMixin`
Collaborator

@hyeok9855 hyeok9855 left a comment

Thanks @josephdviviano for your hard work!

This seems much better than before, though I still think it could be improved. I believe the overall structure is good to go, and refactorings can be done in follow-up PRs (maybe I can do that before the diffusion sampling stuff).

Please check my comments.

Comment on lines +181 to +184
# Build the distribution.
dist = self.to_probability_distribution(
    states_active, estimator_outputs, **policy_kwargs
)
Collaborator

It seems better to move .to_probability_distribution under PolicyMixin.

Collaborator Author

I'm confused. This is in PolicyMixin.compute_dist.

"""
precopmputed_estimator_outputs = getattr(ctx, "current_estimator_output", None)

if step_mask is None and precopmputed_estimator_outputs is not None:
Collaborator

Can we remove step_mask here and redefine compute_dist under RecurrentPolicyMixin?

Collaborator

As far as I understand, there are only two cases,

  1. RecurrentPolicyMixin <=> non-vectorized prob calculation <=> step_mask is not None
  2. Non-recurrent PolicyMixin <=> vectorized prob calculation <=> step_mask is None

So, I think it would be cleaner to get rid of step_mask from the general PolicyMixin and make it specific to RecurrentPolicyMixin.

(Of course, I can imagine a non-recurrent policy + non-vectorized prob calculation (e.g., for memory efficiency), but I guess in that case we can just use RecurrentPolicyMixin.)

Collaborator

This might be a radical idea, but perhaps it would be better to make get_trajectory_pfs and get_trajectory_pbs methods of PolicyMixin.

Collaborator Author

@josephdviviano josephdviviano Oct 14, 2025

I'd be comfortable putting get_trajectory/transition_* back into the estimator via a PolicyMixin. I agree, follow-up PR. The logic lives more comfortably in the Policy than anywhere else.

It would also help me possibly generalize your local-search sampler work, which I would like to do in a follow-up. I need to study how that code works more closely.

Collaborator Author

@josephdviviano josephdviviano Oct 14, 2025

OTOH, I don't want to remove the step mask. It means we have to duplicate more code for the recurrent estimator, and it means the non-vectorized path would have to be removed from the PolicyMixin, which someone might find useful.

There are actually 3 cases. You can run vectorized or non-vectorized for the PolicyMixin; only the recurrent mixin disables the vectorized path.
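
A toy illustration of those three configurations (pure torch; nothing here is library API):

import torch

T, B, dim, n_actions = 5, 4, 3, 7
states = torch.randn(T, B, dim)            # all states of a batch of trajectories
masks = torch.rand(T, B) > 0.2             # which trajectories are still active at each step
policy = torch.nn.Linear(dim, n_actions)   # stand-in estimator

# Case 1: PolicyMixin, vectorized -- one call over the flattened (T * B) states.
logits_vectorized = policy(states.reshape(T * B, dim))

# Case 2: PolicyMixin, non-vectorized -- a loop over T using a step mask,
# which someone might want, e.g., for memory reasons.
logits_stepwise = [policy(states[t][masks[t]]) for t in range(T)]

# Case 3: RecurrentPolicyMixin -- forced onto the step-wise path, because the
# carry (hidden state) makes the computation inherently sequential.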

Collaborator Author

In general this is designed to be extensible. We don't want to remove everything that isn't currently required, because we expect to need to design other PolicyMixins in the future.

@josephdviviano josephdviviano merged commit 29c7ba4 into master Oct 14, 2025
3 checks passed