Generalize samplers #410
Conversation
…or logic, recurrent estimators, recurrent modules
…com:GFNOrg/torchgfn into generalize_samplers
… generalize_samplers
…tion paths. simplified the API of adapters.
The Adapter adds a layer of complexity that I don't see as beneficial. If we need to support RNNs, for example, shall we just change the interface of the Estimator to also return an extra? For a non-recurrent NN this is empty; for a recurrent one it can be the hidden state.
Then the user can write their preferred estimator.
pf_estimator = RecurrentDiscretePolicyEstimator(
    module=model,
    n_actions=env.n_actions,
    is_backward=False,
).to(device)

# GFlowNet (Trajectory Balance), tree DAG -> pb=None, constant_pb=True,
# Use a recurrent adapter for the PF.
gflownet = TBGFlowNet(
    pf=pf_estimator,
    pb=None,
    init_logZ=0.0,
    constant_pb=True,
    pf_adapter=RecurrentEstimatorAdapter(pf_estimator),
)
gflownet = gflownet.to(device)
Why does the user need to specify an adapter when there is a 1-1 correspondence? We already know we need the RecurrentEstimatorAdapter for RecurrentDiscretePolicyEstimator.
This won't be generally true, right?
How can it be? I see you can do multiple adapters for a single estimator, but you can't really have an adapter for multiple estimators, as they are too entangled. Indeed, the purpose of the adapter is "handling arbitrary estimator interfaces", so the adapter needs to know the specific estimator interface to be able to adapt it to the standard API.
This means the adapter must already know what the estimator type is.
I admit we're getting into some fairly esoteric territory but I am just worried about boxing ourselves in too much. I can think of examples in both directions:
- the default estimator works for many kinds of estimators (currently).
- for a single estimator (transformer sequence model), I might want to do non-vectorized or vectorized log-prob calculations. So there are at minimum multiple configurations of the same adaptor for such an estimator.
I like your idea though, it seems cleaner. What do you think of this? Also interested in your POV @saleml / @hyeok9855 :
- Each estimator automatically inherits the DefaultAdaptor.
- New estimators written can define a different default adaptor.
- The sampler uses whatever adaptor is provided by the estimator.
- When instantiating some estimator, one can provide adaptors to override the default one.
- The gflownet no longer needs to know anything about adaptors.
This would allow every estimator to be explicitly associated with an adaptor, but still allow the flexibility to mix and match them due to some specific need the rare user will come up with. And again, for 99% of cases, people will just use the DefaultAdaptor.
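A minimal sketch of that proposal, written here purely for illustration; the class names (DefaultAdaptor, RecurrentAdaptor, etc.) and attributes are assumptions, not the library's actual API:

```python
class DefaultAdaptor:
    """Standard state -> estimator-output path; placeholder for the real logic."""
    def __init__(self, estimator):
        self.estimator = estimator


class Estimator:
    # Every estimator automatically inherits the default adaptor.
    default_adaptor_cls = DefaultAdaptor

    def __init__(self, adaptor_cls=None):
        # A user-supplied adaptor overrides the class default at instantiation.
        self.adaptor = (adaptor_cls or self.default_adaptor_cls)(self)


class RecurrentAdaptor(DefaultAdaptor):
    """Threads a hidden-state carry through each sampling step (placeholder)."""


class RecurrentEstimator(Estimator):
    # New estimators can declare a different default adaptor.
    default_adaptor_cls = RecurrentAdaptor


class Sampler:
    def __init__(self, estimator):
        # The sampler uses whatever adaptor the estimator provides;
        # the gflownet never needs to know about adaptors.
        self.adaptor = estimator.adaptor
```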
I've added a default adaptor for every estimator.
@younik you're looking at this PR as if all of this complexity exists only to support recurrent estimators, when in fact it's a way of supporting arbitrary estimators, and I'm using recurrent estimators as an example. By this I mean: in the sampling process of the gflownet, the estimator accepts (state, thing_a, thing_b, ...). We created a special case.

The goal of this library is to be a good platform for researchers to build on. So we can't make it too difficult for people to innovate on a core element of the algorithm without needing to re-write the most bug-prone elements of it (the sampling / log-prob loops are the easiest places to introduce bugs).

For recurrent estimators, I'm OK with putting the carry in the estimator iff there's a JIT-compatible way of doing so and removing the recurrent adaptor, but that doesn't solve the broader problem of not supporting arbitrary estimator interfaces. Besides, the recurrent adaptor logic is very simple.

If you think this is a non-issue (meaning we have some other way of handling arbitrary estimator interfaces), can you state your case / propose a design? I don't feel confident in saying that our current sampler logic covers all possible estimator interfaces across the space of possible algorithms. It's not clear to me that we can shove all of this complexity into the environment's State object.

And, if you don't use adapters, it's completely transparent to the user. So I'm not convinced this added complexity is actually bad for users unless they have to re-write a sampler and all the logprob calculations -- this is for sure hard -- I had to do it for this PR and it was really tricky to get right. In other words, if you're in a position where you need to re-write the samplers / log-prob calculation loops, you're in a much better position if you only need to write a new adapter.

I have put an enormous amount of work into this PR; please try to discuss it at a level that respects the amount of work I've done here.
tweak of how default preprocessor is defined
What I meant is simply to have the Adapter class merged into the Estimator, because I think the adapter logic is too dependent on the Estimator, so it doesn't add anything to the expected API (you are saying the Estimator can do anything, but the adapter needs to transform anything into a fixed API).
Concretely:
- let the Estimator have an init_context method, which returns a tensordict (see inline comment)
- the forward function will also accept the ctx, which is a tensordict, and output a new one. Since it is the estimator that created it, we know what fields to expect.
The user can easily define custom logic by subclassing Estimator.
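A hedged sketch of this alternative, assuming a tensordict-based context; the class name (RecurrentEstimatorSketch) and field names ("hidden") are illustrative, not the library's API:

```python
import torch
from tensordict import TensorDict


class RecurrentEstimatorSketch(torch.nn.Module):
    """Estimator that owns its context: init_context creates it, forward updates it."""

    def __init__(self, input_dim: int, hidden_dim: int, n_actions: int):
        super().__init__()
        self.rnn = torch.nn.GRUCell(input_dim, hidden_dim)
        self.head = torch.nn.Linear(hidden_dim, n_actions)
        self.hidden_dim = hidden_dim

    def init_context(self, batch_size: int, device=None) -> TensorDict:
        # A plain MLP estimator would return an empty TensorDict here;
        # the recurrent one stores its hidden state.
        return TensorDict(
            {"hidden": torch.zeros(batch_size, self.hidden_dim, device=device)},
            batch_size=[batch_size],
            device=device,
        )

    def forward(self, states_tensor: torch.Tensor, ctx: TensorDict):
        # Since this estimator created ctx, it knows which fields to expect.
        hidden = self.rnn(states_tensor, ctx["hidden"])
        ctx["hidden"] = hidden
        return self.head(hidden), ctx
```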
src/gfn/samplers.py
Outdated
"conditioning", | ||
"carry", | ||
"trajectory_log_probs", | ||
"trajectory_estimator_outputs", | ||
"current_estimator_output", | ||
"extras", |
I don't find any place where trajectory_log_probs, trajectory_estimator_outputs, or current_estimator_output are used. This is the default Context, so it should have only the minimal things.
As you mentioned, we can't know what the user will come up with, so we should let the user define this for particular cases.
Also, carry and conditioning are special cases; can they just go in extras?
And finally, this is a dictionary of tensors, with a device and a batch size. Can it just be a tensordict?
See general comment for why I believe we don't need this class
- I'm ok with removing the context class entirely, and instead using a tensordict. I'll have to move the relevant logic back into the adapter, but that's ok with me. Thanks for pointing that out!
- I don't think supporting recurrent policies or conditional generation is boutique enough to consider them "special cases". I'd prefer to leave these as explicit fields (insofar as there will be a spec for the "expected tensordict").
- I don't really understand your comment that you don't see where trajectory_log_probs, trajectory_estimator_outputs, and current_estimator_output are used, but I'll try to simplify this stuff when I make the move to a tensordict :)
OK, an update.
I spent a few hours trying to remove the context class entirely, but it actually makes the code much more complicated because you need to pad everything everywhere to ensure the size of the batch does not change at all during the trajectory. This is because not all trajectories terminate at the same time, and we slice along the batch dimension to remove these elements from computation, but tensordict enforces that the batch size of every tensor in the dict does not change.
I'm going to propose that we keep the current structure. We can explore a simplification down the road, but I actually think this slots class is the simplest option.
What I will do instead is do a pass of the API / docstrings, streamlining them a bit.
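For concreteness, a small illustration of the constraint described above (not library code; field names are made up): a TensorDict pins its batch_size, so once finished trajectories are sliced out of the batch, the smaller "active" tensors can't be stored back without padding or scattering into the original shape.

```python
import torch
from tensordict import TensorDict

ctx = TensorDict(
    {"hidden": torch.zeros(4, 8), "step_log_probs": torch.zeros(4)},
    batch_size=[4],
)

dones = torch.tensor([False, True, False, True])
active = ctx[~dones]  # a new TensorDict with batch_size=[2]

# Updating only the two active rows forces a scatter back into the
# full-size tensors, because ctx must keep batch_size=[4] throughout.
new_hidden = torch.randn(2, 8)
ctx["hidden"][~dones] = new_hidden
```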
…tion, added some useful safeguard assertions, bugfix related to saving estimator_outputs path
OK, I've streamlined the Context Adapter elements. I'm now going to move the Adapters into the Estimators.
I've unified the adaptor API; tests are matching and passing. I think this PR does a good job of addressing your previous concerns @younik
masked_conditioning = None
step_mask = ~dones

valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
Should I be calling sample_actions, which calls compute_dist? Or something like that.
I'm still reviewing this huge PR. Let me briefly summarise my understanding of the high-level purpose of this PR:
First, I believe bullet 1 is independent of the adapter-related things, so it can be separated from this PR just to make it smaller. To be brutally honest, I still think the
Based on the discussion today, I've implemented changes to this PR here: which should be merged before we merge this one.
Make adapters logic part of estimators via `PolicyMixin`
Thanks @josephdviviano for your hard work!
This seems much better than before, but I still think it could be improved. I believe the overall structure is good to go, and refactorings can be done in follow-up PRs (maybe I can do that, before the diffusion sampling stuff).
Please check my comments.
# Build the distribution.
dist = self.to_probability_distribution(
    states_active, estimator_outputs, **policy_kwargs
)
It seems better to move .to_probability_distribution under PolicyMixin.
I'm confused. This is in PolicyMixin.compute_dist.
src/gfn/estimators.py
Outdated
""" | ||
precopmputed_estimator_outputs = getattr(ctx, "current_estimator_output", None) | ||
|
||
if step_mask is None and precopmputed_estimator_outputs is not None: |
Can we remove step_mask here and redefine compute_dist under RecurrentPolicyMixin?
As far as I understand, there are only two cases,
- RecurrentPolicyMixin <=> non-vectorized prob calculation <=> step_mask is not None
- Non-recurrent PolicyMixin <=> vectorized prob calculation <=> step_mask is None
So, I think it would be cleaner to get rid of step_mask from the general PolicyMixin and make it specific to RecurrentPolicyMixin.
(Of course, I can imagine a non-recurrent policy + non-vectorized prob calculation (e.g., for memory efficiency), but I guess in that case we can just use RecurrentPolicyMixin.)
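A sketch of how that split might look; the method bodies are elided and the signatures are only an assumption about the PR's API, not its actual code:

```python
class PolicyMixin:
    def compute_dist(self, states, conditioning=None, **policy_kwargs):
        # Vectorized path over full trajectories; no step_mask required.
        ...


class RecurrentPolicyMixin(PolicyMixin):
    def compute_dist(self, states, conditioning=None, step_mask=None, **policy_kwargs):
        # Non-vectorized, per-step path; step_mask selects the still-active
        # trajectories at this step of the rollout.
        ...
```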
This might be a radical idea, but perhaps it would be better to make get_trajectory_pfs and get_trajectory_pbs methods of PolicyMixin.
I'd be comfortable putting get_trajectory/transition_* back into the estimator via a PolicyMixin. I agree; follow-up PR. The logic more comfortably lives in the Policy rather than anywhere else.
It would also help me possibly generalize your local search sampler work, which I would like to do in a follow-up. I need to study better how that code works.
OTOH, I don't want to remove the step mask. It means we'd have to duplicate more code for the recurrent estimator, and it means the non-vectorized path would have to be removed from PolicyMixin, which someone might find useful.
There are actually 3 cases: you can run vectorized or non-vectorized with the PolicyMixin; only the recurrent mixin disables the vectorized path.
In general this is designed to be extensible. We don't want to remove everything that isn't currently required, because we expect to need to design other PolicyMixins in the future.
… generalize_samplers
Description
changes:
- `RecurrentDiscretePolicyEstimator`, which accepts a `carry` in addition to States during forward.
- `EstimatorAdapter` for all communication with estimators. These adapters handle all estimator-specific logic - for example, the management of the `carry` for `RecurrentDiscretePolicyEstimator` (but there might be many more applications of this). The `DefaultEstimatorAdapter` replicates exactly the behaviour of the library before this addition, and is used unless the user specifies some other adapter.
- `EstimatorAdapter`s can be either vectorized or non-vectorized. Most, but not all, EstimatorAdapters can leverage vectorized operations during the computation of the log probabilities. The code supports both paths. During sampling, operations are not vectorized due to the sequential nature of the operation.
- `EstimatorAdapter`s have the following three methods (see the sketch after this list):
  - `init_context` - returns the appropriate Context class; currently, only `RolloutContext` is available.
  - `compute` - given some states, a context, and a step mask, interfaces with the estimator.
  - `record` - records per-step artifacts into trajectory-level buffers. These objects are what are eventually emitted by the context at the end of a rollout.
- `RolloutContext` maintains the state of a rollout during sampling / log-prob calculations. It has a `finalize()` method which returns a dict of artifacts to be stored.
- In most cases, `DefaultEstimatorAdapter` and `RolloutContext` will be sufficient, but the design allows the user to define new contexts in cases where new variables need to be tracked during sampling.
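To make the flow of these pieces concrete, here is a hedged sketch of how a rollout might thread the three adapter methods and the context together. The exact signatures of compute and record, and the env.step / env.is_done helpers, are assumptions for illustration only, not the PR's actual API.

```python
import torch


def rollout_sketch(adapter, env, states, max_steps):
    # init_context returns the context object (currently a RolloutContext).
    ctx = adapter.init_context(batch_size=len(states), device=states.device)
    dones = torch.zeros(len(states), dtype=torch.bool, device=states.device)

    for _ in range(max_steps):
        step_mask = ~dones
        if not step_mask.any():
            break

        # compute: given states, the context, and a step mask, interfaces
        # with the estimator (handling carry, conditioning, etc.).
        actions, log_probs, estimator_outputs = adapter.compute(states, ctx, step_mask)

        # record: stores per-step artifacts into trajectory-level buffers on ctx.
        adapter.record(ctx, step_mask, log_probs, estimator_outputs)

        states = env.step(states, actions)
        dones = env.is_done(states)

    # finalize emits the accumulated artifacts (log-probs, estimator outputs, ...).
    return ctx.finalize()
```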