10 changes: 8 additions & 2 deletions src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Any, Generic, Tuple, TypeVar
from typing import Any, Generic, Optional, Tuple, TypeVar

import torch
import torch.nn as nn
@@ -204,12 +204,15 @@ def get_pfs_and_pbs(
def get_trajectories_scores(
self,
trajectories: Trajectories,
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Given a batch of trajectories, calculate forward & backward policy scores.

Args:
trajectories: Trajectories to evaluate.
log_rewards: Log rewards to use. If None, use the log_rewards from the
trajectories.
recalculate_all_logprobs: Whether to re-evaluate all logprobs.

Returns: A tuple of float tensors of shape (n_trajectories,)
@@ -224,7 +227,10 @@ def get_trajectories_scores(
total_log_pf_trajectories = log_pf_trajectories.sum(dim=0)
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

log_rewards = trajectories.log_rewards
if log_rewards is None:
log_rewards = trajectories.log_rewards
assert log_rewards is not None
assert log_rewards.shape == (trajectories.n_trajectories,)

if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
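Usage sketch (not part of the diff): one way a caller might exercise the new `log_rewards` keyword on `get_trajectories_scores`, e.g. to score a batch against tempered rewards. The `gflownet` and `trajectories` objects and the `beta` factor are illustrative assumptions, not code from this PR.

```python
def scores_with_shaped_rewards(gflownet, trajectories, beta: float = 2.0):
    """Score a sampled batch against tempered log-rewards (illustrative only).

    `gflownet` is assumed to be any trajectory-based GFlowNet exposing the
    get_trajectories_scores method shown above; `trajectories` is a sampled
    batch whose log_rewards were filled in at sampling time.
    """
    assert trajectories.log_rewards is not None
    shaped_log_rewards = beta * trajectories.log_rewards  # shape: (n_trajectories,)

    # Passing log_rewards=None (the default) keeps the previous behaviour of
    # reading trajectories.log_rewards.
    return gflownet.get_trajectories_scores(
        trajectories,
        log_rewards=shaped_log_rewards,
        recalculate_all_logprobs=False,
    )
```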
31 changes: 21 additions & 10 deletions src/gfn/gflownet/detailed_balance.py
@@ -1,5 +1,5 @@
import math
from typing import Tuple
from typing import Optional, Tuple

import torch

@@ -94,12 +94,18 @@ def get_pfs_and_pbs(
)

def get_scores(
self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = True
self,
env: Env,
transitions: Transitions,
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Given a batch of transitions, calculate the scores.

Args:
transitions: a batch of transitions.
log_rewards: log rewards of the transitions. If None, use the log rewards
from the transitions.

Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the transitions with
the current self.pf. The following applies:
@@ -120,6 +126,11 @@ def get_scores(
states = transitions.states
actions = transitions.actions

if log_rewards is None:
log_rewards = transitions.log_rewards
assert log_rewards is not None
assert log_rewards.shape == (transitions.n_transitions,)

if len(states) == 0:
return (
torch.tensor(self.log_prob_min, device=transitions.device),
@@ -144,10 +155,10 @@ def get_scores(
log_F_s = self.logF(states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states)
fl_log_rewards = env.log_reward(states)
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
log_F_s = log_F_s + log_rewards
fl_log_rewards = fl_log_rewards.clamp_min(self.log_reward_clip_min)
log_F_s = log_F_s + fl_log_rewards

preds = log_pf_actions + log_F_s

@@ -180,10 +191,7 @@ def get_scores(

log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
assert transitions.log_rewards is not None
valid_transitions_log_rewards = transitions.log_rewards[
~transitions.states.is_sink_state
]
valid_transitions_log_rewards = log_rewards[~transitions.states.is_sink_state]
log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
valid_transitions_is_terminating
]
@@ -198,6 +206,7 @@ def loss(
self,
env: Env,
transitions: Transitions,
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
) -> torch.Tensor:
@@ -207,7 +216,9 @@ def loss(
3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
"""
warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs)
_, _, scores = self.get_scores(env, transitions, recalculate_all_logprobs)
_, _, scores = self.get_scores(
env, transitions, log_rewards, recalculate_all_logprobs
)
scores = scores**2
loss = loss_reduce(scores, reduction)

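Usage sketch (not part of the diff): the updated `DBGFlowNet.loss` call with an optional override, assuming `db_gflownet`, `env` and `transitions` already exist. Because the forward-looking bonus now lives in its own `fl_log_rewards` variable, an override passed here is no longer clobbered inside `get_scores`.

```python
def db_loss_with_override(db_gflownet, env, transitions, log_rewards=None):
    """Compute the detailed-balance loss, optionally against external rewards.

    When log_rewards is None, the transitions' stored log_rewards are used,
    matching the old behaviour; an override must have shape (n_transitions,).
    """
    if log_rewards is not None:
        assert log_rewards.shape == (transitions.n_transitions,)
    return db_gflownet.loss(
        env,
        transitions,
        log_rewards=log_rewards,
        recalculate_all_logprobs=False,
        reduction="mean",
    )
```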
10 changes: 8 additions & 2 deletions src/gfn/gflownet/flow_matching.py
@@ -1,5 +1,5 @@
import warnings
from typing import Any
from typing import Any, Optional

import torch

@@ -189,6 +189,7 @@ def loss(
self,
env: DiscreteEnv,
states_container: StatesContainer[DiscreteStates],
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
) -> torch.Tensor:
@@ -209,11 +210,16 @@ def loss(
states_container.intermediary_states,
states_container.intermediary_conditioning,
)

if log_rewards is None:
log_rewards = states_container.log_rewards
assert log_rewards is not None
assert log_rewards.shape == (len(states_container.states),)
rm_loss = self.reward_matching_loss(
env,
states_container.terminating_states,
states_container.terminating_conditioning,
states_container.terminating_log_rewards,
log_rewards[states_container.is_terminating],
)
return fm_loss + self.alpha * rm_loss

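For reference, a small sketch of the selection the flow-matching loss now performs: the override (or the container's own `log_rewards`) is checked against the full batch size, then masked down to terminating states before entering `reward_matching_loss`. Attribute names are taken from the diff; their exact semantics are assumed.

```python
import torch


def terminating_log_rewards(states_container, log_rewards=None) -> torch.Tensor:
    """Reproduce the reward selection used by the FM loss above (sketch)."""
    if log_rewards is None:
        log_rewards = states_container.log_rewards
    assert log_rewards is not None
    # One log-reward per state in the container, terminating or not.
    assert log_rewards.shape == (len(states_container.states),)
    # Only terminating states feed the reward-matching term.
    return log_rewards[states_container.is_terminating]
```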
17 changes: 12 additions & 5 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -1,6 +1,6 @@
import math
import warnings
from typing import List, Literal, Tuple
from typing import List, Literal, Optional, Tuple

import torch

@@ -169,6 +169,7 @@ def calculate_targets(
sink_states_mask: MaskTensor,
full_mask: MaskTensor,
i: int,
log_rewards: Optional[torch.Tensor] = None,
) -> TargetsTensor:
"""
Calculate the targets tensor for the current sub-trajectory length.
@@ -182,12 +183,16 @@ def calculate_targets(
sink_states_mask: A mask tensor of shape (max_length, n_trajectories) representing sink states.
full_mask: A mask tensor of shape (max_length, n_trajectories) representing full states.
i: The sub-trajectory length.
log_rewards: Optional tensor of shape (n_trajectories) containing the log rewards. If None, use the log rewards from the trajectories.

Returns: The targets tensor of shape (max_length + 1 - i, n_trajectories).
"""
targets = torch.full_like(preds, fill_value=-float("inf"))
assert trajectories.log_rewards is not None
log_rewards = trajectories.log_rewards[trajectories.terminating_idx >= i]
if log_rewards is None:
log_rewards = trajectories.log_rewards
assert log_rewards is not None
assert log_rewards.shape == (trajectories.n_trajectories,)
log_rewards = log_rewards[trajectories.terminating_idx >= i]

if math.isfinite(self.log_reward_clip_min):
log_rewards.clamp_min(self.log_reward_clip_min)
@@ -246,8 +251,8 @@ def calculate_log_state_flows(
log_F = self.logF(valid_states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states).unsqueeze(-1)
log_F = log_F + log_rewards
fl_log_rewards = env.log_reward(states).unsqueeze(-1)
log_F = log_F + fl_log_rewards

log_state_flows[mask[:-1]] = log_F.squeeze()
return log_state_flows
@@ -276,6 +281,7 @@ def get_scores(
self,
env: Env,
trajectories: Trajectories,
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Scores all submitted trajectories.
@@ -325,6 +331,7 @@ def get_scores(
sink_states_mask,
full_mask,
i,
log_rewards,
)

flattening_mask = trajectories.terminating_idx.lt(
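A toy, self-contained illustration of the filtering `calculate_targets` applies per sub-trajectory length: only trajectories long enough to contain a sub-trajectory of length `i` keep their terminal log-reward target. The tensor values below are invented.

```python
import torch

# Three trajectories terminating after 2, 4 and 1 steps respectively.
terminating_idx = torch.tensor([2, 4, 1])
log_rewards = torch.tensor([-1.0, -0.5, -2.0])  # one log-reward per trajectory

i = 2  # current sub-trajectory length
# Trajectories shorter than i are dropped from the targets for this length.
selected = log_rewards[terminating_idx >= i]
print(selected)  # tensor([-1.0000, -0.5000])
```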
8 changes: 5 additions & 3 deletions src/gfn/gflownet/trajectory_balance.py
@@ -3,7 +3,7 @@
and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
"""

from typing import cast
from typing import Optional, cast

import torch
import torch.nn as nn
@@ -55,6 +55,7 @@ def loss(
self,
env: Env,
trajectories: Trajectories,
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
) -> torch.Tensor:
@@ -69,7 +70,7 @@ def loss(
del env # unused
warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs)
_, _, scores = self.get_trajectories_scores(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
trajectories, log_rewards, recalculate_all_logprobs=recalculate_all_logprobs
)

# If the conditioning values exist, we pass them to self.logZ
@@ -113,6 +114,7 @@ def loss(
self,
env: Env,
trajectories: Trajectories,
log_rewards: Optional[torch.Tensor] = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
) -> torch.Tensor:
@@ -124,7 +126,7 @@ def loss(
del env # unused
warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs)
_, _, scores = self.get_trajectories_scores(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
trajectories, log_rewards, recalculate_all_logprobs=recalculate_all_logprobs
)
scores = (scores - scores.mean()).pow(2)
loss = loss_reduce(scores, reduction)
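Usage sketch (not part of the diff): a training step passing the new keyword through `TBGFlowNet.loss`; `LogPartitionVarianceGFlowNet.loss` gained the same signature. All objects here (`tb_gflownet`, `env`, `trajectories`, `optimizer`) are assumed to be set up as in the library's usual training loop.

```python
def tb_training_step(tb_gflownet, env, trajectories, optimizer, log_rewards=None):
    """One trajectory-balance update, optionally against overridden rewards."""
    optimizer.zero_grad()
    # Leaving log_rewards=None reproduces the previous behaviour of scoring
    # against trajectories.log_rewards.
    loss = tb_gflownet.loss(
        env,
        trajectories,
        log_rewards=log_rewards,
        recalculate_all_logprobs=False,
    )
    loss.backward()
    optimizer.step()
    return loss.detach()
```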