src/gfn/env.py: 144 additions & 1 deletion
@@ -1,5 +1,10 @@
import warnings
from abc import ABC, abstractmethod
from typing import Optional, Tuple, cast
from collections import Counter
from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast

if TYPE_CHECKING:
    from gfn.gflownet import GFlowNet

import torch
from torch_geometric.data import Data as GeometricData
@@ -679,6 +684,144 @@ def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates:
        self.update_masks(new_states)
        return new_states

    def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor:
        """Computes the empirical distribution over terminating states.

        This requires the environment to support the enumeration APIs
        `get_terminating_states_indices` and `n_terminating_states`.

        Args:
            states: A batch of terminating `DiscreteStates`.

        Returns:
            A 1D tensor of shape `(n_terminating_states,)` with empirical frequencies.
        """
        try:
            states_indices = (
                self.get_terminating_states_indices(states).cpu().numpy().tolist()
            )
        except NotImplementedError as e:
            warnings.warn(
                "Environment does not implement `get_terminating_states_indices`,\n"
                "which is required to compute the empirical distribution. Skipping.",
                UserWarning,
            )
            raise e

        counter = Counter(str(idx) for idx in states_indices)
        try:
            counter_list = [
                counter[str(state_idx)] if str(state_idx) in counter else 0
                for state_idx in range(self.n_terminating_states)
            ]
        except NotImplementedError as e:
            warnings.warn(
                "Environment does not implement `n_terminating_states`,\n"
                "which is required to compute the empirical distribution. Skipping.",
                UserWarning,
            )
            raise e

        denom = len(states_indices)
        if denom == 0:
            warnings.warn(
                "No terminating states provided to compute empirical distribution.",
                UserWarning,
            )
            return torch.zeros(
                (self.n_terminating_states,), dtype=torch.get_default_dtype()
            )

        return torch.tensor(counter_list, dtype=torch.get_default_dtype()) / denom
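
A minimal usage sketch of the method above (not part of this diff): `env` and `gflownet` are hypothetical, with `env` assumed to be a `DiscreteEnv` subclass implementing `get_terminating_states_indices` and `n_terminating_states`, and `gflownet` a trained GFlowNet.

    # `env` and `gflownet` are assumed to already exist (see note above).
    samples = gflownet.sample_terminating_states(env, 1000)
    empirical = env.get_terminating_state_dist(samples)  # shape: (n_terminating_states,)
    assert torch.isclose(empirical.sum(), torch.tensor(1.0))  # frequencies sum to 1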

    def validate(
        self,
        gflownet: "GFlowNet",
        n_validation_samples: int = 1000,
        visited_terminating_states: Optional[DiscreteStates] = None,
    ) -> Tuple[Dict[str, float], DiscreteStates | None]:
        """Evaluate a GFlowNet against this environment's true distribution.

        Designed for environments with known target reward distributions. Computes
        the L1 distance between the empirical distribution of sampled terminating
        states and the environment's `true_dist`. If the GFlowNet has a learned
        `logZ` and the environment implements `log_partition`, also reports the
        absolute difference between learned and target `logZ`.

        Returns an empty result with a user-facing warning when validation is not
        applicable (e.g., missing enumeration APIs or `true_dist`).
        """
        # Check availability of true distribution.
        try:
            true_dist = self.true_dist
            if isinstance(true_dist, torch.Tensor):
                true_dist = true_dist.cpu()
            else:
                warnings.warn(
                    "Environment `true_dist` is not a tensor; cannot validate.",
                    UserWarning,
                )
                return {}, visited_terminating_states
        except NotImplementedError:
            warnings.warn(
                "Environment does not implement `true_dist`; validation is skipped.",
                UserWarning,
            )
            return {}, visited_terminating_states

        # Attempt to retrieve true logZ if available.
        true_logZ: float | None = None
        try:
            true_logZ = self.log_partition
        except NotImplementedError:
            true_logZ = None

        # Determine which terminating states to use.
        if visited_terminating_states is None:
            if n_validation_samples <= 0:
                warnings.warn(
                    "n_validation_samples <= 0 and no visited states provided; nothing to validate.",
                    UserWarning,
                )
                return {}, None
            sampled_terminating_states = gflownet.sample_terminating_states(
                self, n_validation_samples
            )
            assert isinstance(sampled_terminating_states, DiscreteStates)
        else:
            sampled_terminating_states = visited_terminating_states[
                -n_validation_samples:
            ]

        # Compute empirical distribution; may require enumeration support.
        try:
            final_states_dist = self.get_terminating_state_dist(
                sampled_terminating_states
            )
        except NotImplementedError:
            # Already warned in helper; return gracefully.
            return {}, sampled_terminating_states

        if final_states_dist.numel() == 0:
            warnings.warn(
                "Empirical distribution is empty (no terminating samples).", UserWarning
            )
            return {}, sampled_terminating_states

        l1_dist = (final_states_dist - true_dist).abs().mean().item()
        validation_info: Dict[str, float] = {"l1_dist": l1_dist}

        # Report logZ difference if both sides are available.
        learned_logZ: float | None = None
        if hasattr(gflownet, "logZ") and isinstance(
            getattr(gflownet, "logZ"), torch.Tensor
        ):
            learned_logZ = float(getattr(gflownet, "logZ").item())
        if learned_logZ is not None and true_logZ is not None:
            validation_info["logZ_diff"] = abs(learned_logZ - true_logZ)

        return validation_info, sampled_terminating_states
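
A hedged end-to-end sketch of `validate` (not part of this diff), reusing the hypothetical `env` and `gflownet` from the note above:

    validation_info, visited = env.validate(gflownet, n_validation_samples=1000)
    if validation_info:  # an empty dict means validation was skipped (a warning was already emitted)
        print(validation_info["l1_dist"])  # mean |empirical - true_dist| over terminating states
        print(validation_info.get("logZ_diff"))  # present only when both logZ values are available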

    def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
        """Optional method to return the indices of the states in the environment.

Expand Down
src/gfn/gym/helpers/hypergrid/rewards.py: 0 additions & 114 deletions

This file was deleted.
