
Conversation

josephdviviano (Collaborator)

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

  • Added 3 new hypergrid tasks that should be more challenging. Note that the specifics are very much up for debate. I tried to identify environments that are easy to divide and conquer vs. those that require compositional knowledge (and therefore some amount of knowledge sharing among agents in a multi-agent setting).
  • Added mode verification logic (to ensure that your particular configuration actually contains modes to find).
  • Added lots of tests around these new rewards.
  • Added visualizations of the reward landscape for these various rewards.

@josephdviviano self-assigned this Oct 3, 2025
@josephdviviano added the enhancement label Oct 3, 2025
@younik (Collaborator) left a comment

I am not able to review 1,000+ math-dense LOC for hypergrid.py :(
If you want a careful review, consider splitting this.


self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
self._mode_stats_kind = "approx"
except Exception:
warnings.warn("+ Warning: Failed to compute mode_stats (skipping).")
Collaborator

Better to use logger.exception here, to print the exception as well.

Also, it would be better to avoid catching Exception in general. Why can this fail?

Collaborator

This would catch the ValueError in the "exact" branch as well. Is this what we want? Should we catch at all?
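
For example (a sketch only; the logger setup, the exception types, and the _compute_mode_ids helper are assumptions for illustration, not the PR's actual code):

```python
import logging

logger = logging.getLogger(__name__)

try:
    # Hypothetical helper standing in for the mode-id computation above.
    ids = self._compute_mode_ids()
    self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
    self._mode_stats_kind = "approx"
except (RuntimeError, MemoryError):
    # logger.exception records the message *and* the traceback of the active
    # exception, and the narrow exception tuple avoids swallowing e.g. the
    # ValueError raised by the "exact" branch.
    logger.exception("Failed to compute approximate mode_stats; skipping.")
```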

Comment on lines +564 to +570
# Cheap exact threshold (up to ~200k states)
if self.n_states <= 200_000:
axes = [
torch.arange(self.height, dtype=torch.long) for _ in range(self.ndim)
]
grid = torch.cartesian_prod(*axes)
rewards = self.reward_fn(grid)
Collaborator

How did you come up with this number? Doing the cartesian product seems memory-intensive.

Collaborator Author

This number might need to be lowered; it was arbitrary.
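
For a rough sense of the cost: the full grid is height**ndim rows of ndim int64 values, so 200,000 states at ndim = 4 is already ~6 MB before the reward tensor. One alternative is to evaluate the grid in chunks instead of materializing the whole cartesian product; a sketch (function and constant names are made up for illustration):

```python
import torch

CHUNK_SIZE = 65_536  # assumed per-chunk budget


def iter_grid_chunks(height: int, ndim: int, chunk_size: int = CHUNK_SIZE):
    """Yield the height**ndim grid as (chunk, ndim) tensors, never all at once."""
    n_states = height**ndim
    for start in range(0, n_states, chunk_size):
        idx = torch.arange(start, min(start + chunk_size, n_states), dtype=torch.long)
        # Mixed-radix decode: digit d of flat index i in base `height`.
        digits = [(idx // height**d) % height for d in range(ndim)]
        yield torch.stack(digits, dim=1)


# e.g. rewards = torch.cat([reward_fn(chunk) for chunk in iter_grid_chunks(height, ndim)])
```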

Comment on lines +572 to +574
except Exception:
# Fall back to heuristic paths below
pass
Collaborator

Maybe add a logger. In general, I don't think it is a good idea to hide this much from the user: sometimes we compute the exact mode existence and sometimes we use a heuristic.

Collaborator Author

yes, agreed
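
Something like this, perhaps (every name below is a hypothetical stand-in, just to show the fallback being surfaced instead of a silent pass):

```python
import logging

logger = logging.getLogger(__name__)


def count_modes_exact() -> int:      # hypothetical stand-in for the exact branch
    raise MemoryError("grid too large")


def count_modes_heuristic() -> int:  # hypothetical stand-in for the heuristic paths below
    return 0


try:
    n_modes = count_modes_exact()
except (MemoryError, RuntimeError):
    # Surface the fallback instead of a silent `pass`, so the user knows the
    # reported mode statistics are heuristic rather than exact.
    logger.warning("Exact mode verification failed; falling back to the heuristic check.")
    n_modes = count_modes_heuristic()
```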

Comment on lines +849 to +874
for col in range(m):
# Find pivot
piv = None
for r in range(row, k):
if A[r, col]:
piv = r
break
if piv is None:
continue
# Swap
if piv != row:
A[[row, piv]] = A[[piv, row]]
c[[row, piv]] = c[[piv, row]]
# Eliminate below
for r in range(row + 1, k):
if A[r, col]:
A[r, :] ^= A[row, :]
c[r] ^= c[row]
row += 1
if row == k:
break
# Check for inconsistency: 0 = 1 rows
for r in range(k):
if not A[r, :].any() and c[r]:
return False
return True
Collaborator

I didn't check the details, tbh, but it seems quite inefficient and not easily readable. Can we rely on SciPy for this?

https://stackoverflow.com/questions/15638650/is-there-a-standard-solution-for-gauss-elimination-in-python

Collaborator Author

I'll look into it
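
For what it's worth, SciPy's linear-algebra routines work over the reals rather than GF(2), so they don't directly replace this; a vectorized NumPy rank-style check is one option (a standalone sketch, independent of the PR's code; the third-party galois package is another route):

```python
import numpy as np


def gf2_consistent(A: np.ndarray, c: np.ndarray) -> bool:
    """Return True if A x = c has a solution over GF(2) (A: (k, m) 0/1, c: (k,) 0/1)."""
    aug = np.concatenate([A.astype(np.uint8) % 2, (c.astype(np.uint8) % 2)[:, None]], axis=1)
    k, n = aug.shape
    row = 0
    for col in range(n - 1):  # never pivot on the augmented column
        pivots = np.nonzero(aug[row:, col])[0]
        if pivots.size == 0:
            continue
        piv = row + int(pivots[0])
        aug[[row, piv]] = aug[[piv, row]]  # swap the pivot row into place
        # XOR the pivot row into every other row with a 1 in this column (vectorized).
        targets = np.nonzero(aug[:, col])[0]
        targets = targets[targets != row]
        aug[targets] ^= aug[row]
        row += 1
        if row == k:
            break
    # Inconsistent iff some row reduces to 0 = 1.
    return not np.any(~aug[:, :-1].any(axis=1) & (aug[:, -1] == 1))
```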

"""
with torch.no_grad():
device = torch.device("cpu")
B = min(2048, max(128, 8 * self.ndim))
Collaborator

What are these numbers? Maybe use constants to improve clarity.
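
For example (constant names are assumptions; the values are the ones currently hard-coded above):

```python
MIN_MC_BATCH = 128       # floor so low-dimensional grids still get a useful sample
MAX_MC_BATCH = 2048      # ceiling to bound memory use
MC_SAMPLES_PER_DIM = 8   # scale the batch size with dimensionality

B = min(MAX_MC_BATCH, max(MIN_MC_BATCH, MC_SAMPLES_PER_DIM * self.ndim))
```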

Comment on lines +470 to +488
try:
all_states = self.all_states
if all_states is not None:
mask = self.mode_mask(all_states)
ids = self.mode_ids(all_states)
ids = ids[mask]
ids = ids[ids >= 0]
return int(torch.unique(ids).numel())
except Exception:
pass
if self._mode_stats_kind == "exact" and self._n_modes_via_ids_exact is not None:
return int(self._n_modes_via_ids_exact)
if (
self._mode_stats_kind == "approx"
and self._n_modes_via_ids_estimate is not None
):
return int(self._n_modes_via_ids_estimate)

return 2**self.ndim
Collaborator

do we need to recompute this every time?

Collaborator Author

No, you're right, it should be stored.
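
e.g. with functools.cached_property, which computes the value on first access and caches it on the instance; a minimal sketch with made-up names (not the PR's actual class):

```python
from functools import cached_property

import torch


class ModeCounterSketch:
    """Illustration only; the attribute names are assumptions."""

    def __init__(self, ids: torch.Tensor, mask: torch.Tensor):
        self._ids = ids
        self._mask = mask

    @cached_property
    def n_modes(self) -> int:
        # Runs once; subsequent accesses return the cached value.
        ids = self._ids[self._mask]
        ids = ids[ids >= 0]
        return int(torch.unique(ids).numel())
```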

Comment on lines +478 to +479
except Exception:
pass
Collaborator

Similar to the other comment, this is not nice for debuggability.

@josephdviviano (Collaborator Author)

Hi @younik - I hear you, this is a big PR. The "splits" would have to be along tasks, though, so the resulting PRs would still be large.

I appreciate your comments on the code. I think it would also be worth looking at the tasks themselves (the ones plotted in the notebook) to see whether they make sense; I'm not convinced by all of them.

I would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.

@hyeok9855 (Collaborator)

In the commit above, I fixed the comments for the Deceptive Reward and also fixed a pyright error.

@hyeok9855 (Collaborator)

> I would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.

I do think the Template Minkowski and Bitwise/XOR rewards are not very interesting to benchmark, especially if you care about mode coverage. Multiplicative/Coprime seems challenging, but you may want to increase the reward for modes further from the origin.
