diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index 5a3cb9b8..53543328 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
 import math
 from abc import ABC, abstractmethod
-from typing import Any, Generic, Tuple, TypeVar
+from typing import Any, Generic, Optional, Tuple, TypeVar
 
 import torch
 import torch.nn as nn
@@ -204,12 +204,15 @@ def get_pfs_and_pbs(
     def get_trajectories_scores(
         self,
         trajectories: Trajectories,
+        log_rewards: Optional[torch.Tensor] = None,
         recalculate_all_logprobs: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Given a batch of trajectories, calculate forward & backward policy scores.
 
         Args:
             trajectories: Trajectories to evaluate.
+            log_rewards: Log rewards to use. If None, use the log_rewards from the
+                trajectories.
             recalculate_all_logprobs: Whether to re-evaluate all logprobs.
 
         Returns: A tuple of float tensors of shape (n_trajectories,)
@@ -224,7 +227,10 @@ def get_trajectories_scores(
         total_log_pf_trajectories = log_pf_trajectories.sum(dim=0)
         total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)
 
-        log_rewards = trajectories.log_rewards
+        if log_rewards is None:
+            log_rewards = trajectories.log_rewards
+        assert log_rewards is not None
+        assert log_rewards.shape == (trajectories.n_trajectories,)
 
         if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
             log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py
index 8ac682b0..37591f07 100644
--- a/src/gfn/gflownet/detailed_balance.py
+++ b/src/gfn/gflownet/detailed_balance.py
@@ -1,5 +1,5 @@
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
@@ -94,12 +94,18 @@ def get_pfs_and_pbs(
         )
 
     def get_scores(
-        self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = True
+        self,
+        env: Env,
+        transitions: Transitions,
+        log_rewards: Optional[torch.Tensor] = None,
+        recalculate_all_logprobs: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Given a batch of transitions, calculate the scores.
 
         Args:
             transitions: a batch of transitions.
+            log_rewards: log rewards of the transitions. If None, use the log rewards
+                from the transitions.
         Unless recalculate_all_logprobs=True, in which case we re-evaluate the
         logprobs of the transitions with the current self.pf.
         The following applies:
@@ -120,6 +126,11 @@ def get_scores(
         states = transitions.states
         actions = transitions.actions
 
+        if log_rewards is None:
+            log_rewards = transitions.log_rewards
+        assert log_rewards is not None
+        assert log_rewards.shape == (transitions.n_transitions,)
+
         if len(states) == 0:
             return (
                 torch.tensor(self.log_prob_min, device=transitions.device),
@@ -144,10 +155,10 @@ def get_scores(
 
         log_F_s = self.logF(states).squeeze(-1)
         if self.forward_looking:
-            log_rewards = env.log_reward(states)
+            fl_log_rewards = env.log_reward(states)
             if math.isfinite(self.log_reward_clip_min):
-                log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
-            log_F_s = log_F_s + log_rewards
+                fl_log_rewards = fl_log_rewards.clamp_min(self.log_reward_clip_min)
+            log_F_s = log_F_s + fl_log_rewards
 
         preds = log_pf_actions + log_F_s
 
@@ -180,10 +191,7 @@ def get_scores(
         log_F_s_next = torch.zeros_like(log_pb_actions)
         log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
 
-        assert transitions.log_rewards is not None
-        valid_transitions_log_rewards = transitions.log_rewards[
-            ~transitions.states.is_sink_state
-        ]
+        valid_transitions_log_rewards = log_rewards[~transitions.states.is_sink_state]
         log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
             valid_transitions_is_terminating
         ]
@@ -198,6 +206,7 @@ def loss(
         self,
         env: Env,
         transitions: Transitions,
+        log_rewards: Optional[torch.Tensor] = None,
         recalculate_all_logprobs: bool = True,
         reduction: str = "mean",
     ) -> torch.Tensor:
@@ -207,7 +216,9 @@ def loss(
         3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
         """
         warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs)
-        _, _, scores = self.get_scores(env, transitions, recalculate_all_logprobs)
+        _, _, scores = self.get_scores(
+            env, transitions, log_rewards, recalculate_all_logprobs
+        )
         scores = scores**2
         loss = loss_reduce(scores, reduction)
 
diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py
index 10f866c0..852abae5 100644
--- a/src/gfn/gflownet/flow_matching.py
+++ b/src/gfn/gflownet/flow_matching.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
@@ -189,6 +189,7 @@ def loss(
         self,
         env: DiscreteEnv,
         states_container: StatesContainer[DiscreteStates],
+        log_rewards: Optional[torch.Tensor] = None,
         recalculate_all_logprobs: bool = True,
         reduction: str = "mean",
     ) -> torch.Tensor:
@@ -209,11 +210,16 @@ def loss(
             states_container.intermediary_states,
             states_container.intermediary_conditioning,
         )
+
+        if log_rewards is None:
+            log_rewards = states_container.log_rewards
+        assert log_rewards is not None
+        assert log_rewards.shape == (len(states_container.states),)
         rm_loss = self.reward_matching_loss(
             env,
             states_container.terminating_states,
             states_container.terminating_conditioning,
-            states_container.terminating_log_rewards,
+            log_rewards[states_container.is_terminating],
         )
 
         return fm_loss + self.alpha * rm_loss
diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py
index e6771eaa..4c67d171 100644
--- a/src/gfn/gflownet/sub_trajectory_balance.py
+++ b/src/gfn/gflownet/sub_trajectory_balance.py
@@ -1,6 +1,6 @@
 import math
 import warnings
-from typing import List, Literal, Tuple
+from typing import List, Literal, Optional, Tuple
 
 import torch
 
@@ -169,6 +169,7 @@ def calculate_targets(
         sink_states_mask: MaskTensor,
         full_mask: MaskTensor,
         i: int,
+        log_rewards: Optional[torch.Tensor] = None,
     ) -> TargetsTensor:
         """
         Calculate the targets tensor for the current sub-trajectory length.
@@ -182,12 +183,16 @@ def calculate_targets(
             sink_states_mask: A mask tensor of shape (max_length, n_trajectories) representing sink states.
             full_mask: A mask tensor of shape (max_length, n_trajectories) representing full states.
             i: The sub-trajectory length.
+            log_rewards: Optional tensor of shape (n_trajectories) containing the log rewards. If None, use the log rewards from the trajectories.
 
         Returns: The targets tensor of shape (max_length + 1 - i, n_trajectories).
         """
         targets = torch.full_like(preds, fill_value=-float("inf"))
-        assert trajectories.log_rewards is not None
-        log_rewards = trajectories.log_rewards[trajectories.terminating_idx >= i]
+        if log_rewards is None:
+            log_rewards = trajectories.log_rewards
+        assert log_rewards is not None
+        assert log_rewards.shape == (trajectories.n_trajectories,)
+        log_rewards = log_rewards[trajectories.terminating_idx >= i]
 
         if math.isfinite(self.log_reward_clip_min):
             log_rewards.clamp_min(self.log_reward_clip_min)
@@ -246,8 +251,8 @@ def calculate_log_state_flows(
 
         log_F = self.logF(valid_states).squeeze(-1)
         if self.forward_looking:
-            log_rewards = env.log_reward(states).unsqueeze(-1)
-            log_F = log_F + log_rewards
+            fl_log_rewards = env.log_reward(states).unsqueeze(-1)
+            log_F = log_F + fl_log_rewards
 
         log_state_flows[mask[:-1]] = log_F.squeeze()
         return log_state_flows
@@ -276,6 +281,7 @@ def get_scores(
         self,
         env: Env,
         trajectories: Trajectories,
+        log_rewards: Optional[torch.Tensor] = None,
         recalculate_all_logprobs: bool = True,
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """Scores all submitted trajectories.
@@ -325,6 +331,7 @@ def get_scores(
                 sink_states_mask,
                 full_mask,
                 i,
+                log_rewards,
             )
 
             flattening_mask = trajectories.terminating_idx.lt(
diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py
index 7558a2b5..e41067a6 100644
--- a/src/gfn/gflownet/trajectory_balance.py
+++ b/src/gfn/gflownet/trajectory_balance.py
@@ -3,7 +3,7 @@
 and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
 """
 
-from typing import cast
+from typing import Optional, cast
 
 import torch
 import torch.nn as nn
@@ -55,6 +55,7 @@ def loss(
         self,
         env: Env,
         trajectories: Trajectories,
+        log_rewards: Optional[torch.Tensor] = None,
         recalculate_all_logprobs: bool = True,
         reduction: str = "mean",
     ) -> torch.Tensor:
@@ -69,7 +70,7 @@ def loss(
         del env  # unused
         warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs)
         _, _, scores = self.get_trajectories_scores(
-            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
+            trajectories, log_rewards, recalculate_all_logprobs=recalculate_all_logprobs
         )
 
         # If the conditioning values exist, we pass them to self.logZ
@@ -113,6 +114,7 @@ def loss(
         self,
         env: Env,
         trajectories: Trajectories,
+        log_rewards: Optional[torch.Tensor] = None,
         recalculate_all_logprobs: bool = True,
         reduction: str = "mean",
     ) -> torch.Tensor:
@@ -124,7 +126,7 @@ def loss(
         del env  # unused
         warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs)
         _, _, scores = self.get_trajectories_scores(
-            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
+            trajectories, log_rewards, recalculate_all_logprobs=recalculate_all_logprobs
        )
         scores = (scores - scores.mean()).pow(2)
         loss = loss_reduce(scores, reduction)
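For reference, a minimal usage sketch of the new keyword (not part of the patch): gflownet, env, and trajectories stand for an already-constructed GFlowNet (e.g. a trajectory-balance one), its environment, and a sampled trajectory batch; only the log_rewards argument itself comes from the diff above. Passing nothing keeps the previous behaviour of reading trajectories.log_rewards, while passing a tensor of shape (n_trajectories,) overrides it, e.g. with tempered or externally re-computed log rewards.

# Usage sketch under the assumptions stated above; object names are illustrative.

# Default: log rewards are taken from the container, as before this patch.
loss = gflownet.loss(env, trajectories)

# Override: supply a (n_trajectories,)-shaped tensor explicitly, here an
# illustrative tempered copy of the stored log rewards.
assert trajectories.log_rewards is not None
beta = 2.0  # hypothetical inverse temperature
loss = gflownet.loss(env, trajectories, log_rewards=beta * trajectories.log_rewards)
loss.backward()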