New harder task #405
base: hypergrid_refactor
Conversation
    self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
    self._mode_stats_kind = "approx"
except Exception:
    warnings.warn("+ Warning: Failed to compute mode_stats (skipping).")
better to use logger.exception here, to print the exception as well
Also it would be better to avoid catching Exception in general. Why can this fail?
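A minimal sketch of that suggestion (the message text is taken from the diff above; the logger setup is an assumption):

```python
import logging

logger = logging.getLogger(__name__)

try:
    ...  # the mode-stats computation from the diff above
except Exception:
    # Unlike warnings.warn, logger.exception logs at ERROR level and
    # appends the active exception's traceback to the message.
    logger.exception("Failed to compute mode_stats (skipping).")
```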
this would catch the ValueError in the "exact" branch as well. Is this what we want? Should we catch at all?
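One way to scope the handler per branch, sketched with hypothetical names (none of these are the PR's actual API):

```python
def n_modes(kind, compute_exact, estimate_via_ids):
    """Sketch: only the approximate path gets a fallback."""
    if kind == "exact":
        # A ValueError raised here propagates to the caller.
        return compute_exact()
    try:
        return estimate_via_ids()
    except RuntimeError:
        # Assumed failure mode of the estimate; replace with the real one.
        return None
```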
# Cheap exact threshold (up to ~200k states)
if self.n_states <= 200_000:
    axes = [
        torch.arange(self.height, dtype=torch.long) for _ in range(self.ndim)
    ]
    grid = torch.cartesian_prod(*axes)
    rewards = self.reward_fn(grid)
how did you come up with this number? Doing the cartesian product seems memory-intensive
this number might need to be lowered. It was arbitrary.
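If memory is the concern rather than the state count itself, one alternative is to enumerate the grid chunk by chunk instead of materializing the full cartesian product. A sketch, assuming `reward_fn` accepts a `(batch, ndim)` long tensor as in the diff above:

```python
import torch

def scan_rewards_in_chunks(reward_fn, height: int, ndim: int, chunk_size: int = 65_536):
    """Peak memory is O(chunk_size * ndim) instead of O(height**ndim * ndim)."""
    n_states = height**ndim
    rewards = []
    for start in range(0, n_states, chunk_size):
        flat = torch.arange(start, min(start + chunk_size, n_states))
        # Decode flat indices into base-`height` digits, matching
        # torch.cartesian_prod's ordering (first axis varies slowest).
        coords = torch.stack(
            [(flat // height**d) % height for d in range(ndim - 1, -1, -1)],
            dim=1,
        )
        rewards.append(reward_fn(coords))
    return torch.cat(rewards)
```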
except Exception:
    # Fall back to heuristic paths below
    pass
maybe add a logger
I don't think in general it is a good idea to hide a lot of stuff from the user. Sometimes we compute the exact mode existence, sometimes we use a heuristic
yes, agreed
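A sketch of what surfacing the chosen path could look like; the two callables and the exception type are hypothetical stand-ins:

```python
import logging

logger = logging.getLogger(__name__)

def compute_mode_stats(exact_fn, heuristic_fn):
    """Log which path produced the result instead of falling back silently."""
    try:
        stats = exact_fn()
        logger.info("mode_stats: computed exactly")
        return stats, "exact"
    except MemoryError:
        # Narrow this to the real failure mode once it is known.
        logger.warning("mode_stats: exact path failed, using heuristic",
                       exc_info=True)
        return heuristic_fn(), "approx"
```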
for col in range(m):
    # Find pivot
    piv = None
    for r in range(row, k):
        if A[r, col]:
            piv = r
            break
    if piv is None:
        continue
    # Swap
    if piv != row:
        A[[row, piv]] = A[[piv, row]]
        c[[row, piv]] = c[[piv, row]]
    # Eliminate below
    for r in range(row + 1, k):
        if A[r, col]:
            A[r, :] ^= A[row, :]
            c[r] ^= c[row]
    row += 1
    if row == k:
        break
# Check for inconsistency: 0 = 1 rows
for r in range(k):
    if not A[r, :].any() and c[r]:
        return False
return True
I didn't check the details tbh, but it seems quite inefficient and not easily readable. Can we rely on scipy for this?
I'll look into it
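For what it's worth, SciPy's linear algebra routines work over floats, so an exact GF(2) solve would still need custom code or a dedicated package such as `galois`. A more vectorized NumPy rewrite of the same consistency check might look like this sketch:

```python
import numpy as np

def gf2_consistent(A: np.ndarray, c: np.ndarray) -> bool:
    """Return True if A x = c has a solution over GF(2).
    A is a (k, m) boolean matrix, c a (k,) boolean vector."""
    A, c = A.astype(bool).copy(), c.astype(bool).copy()
    k, m = A.shape
    row = 0
    for col in range(m):
        # Vectorized pivot search at or below the current row.
        hits = np.flatnonzero(A[row:, col])
        if hits.size == 0:
            continue
        piv = row + hits[0]
        if piv != row:
            A[[row, piv]] = A[[piv, row]]
            c[[row, piv]] = c[[piv, row]]
        # XOR the pivot row into every other row with a 1 in this column.
        mask = A[:, col].copy()
        mask[row] = False
        A[mask] ^= A[row]
        c[mask] ^= c[row]
        row += 1
        if row == k:
            break
    # Inconsistent iff some all-zero row of A pairs with a 1 in c.
    return not bool(((~A.any(axis=1)) & c).any())
```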
""" | ||
with torch.no_grad(): | ||
device = torch.device("cpu") | ||
B = min(2048, max(128, 8 * self.ndim)) |
what are these numbers? Maybe use constants to improve clarity
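A sketch with named constants; the names and the rationale comments are guesses at the intent, not taken from the PR:

```python
MIN_SCAN_BATCH = 128     # floor, so small problems still batch efficiently
MAX_SCAN_BATCH = 2048    # ceiling, to bound peak memory per batch
SCAN_BATCH_PER_DIM = 8   # scale the batch size with dimensionality

def scan_batch_size(ndim: int) -> int:
    return min(MAX_SCAN_BATCH, max(MIN_SCAN_BATCH, SCAN_BATCH_PER_DIM * ndim))
```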
try:
    all_states = self.all_states
    if all_states is not None:
        mask = self.mode_mask(all_states)
        ids = self.mode_ids(all_states)
        ids = ids[mask]
        ids = ids[ids >= 0]
        return int(torch.unique(ids).numel())
except Exception:
    pass
if self._mode_stats_kind == "exact" and self._n_modes_via_ids_exact is not None:
    return int(self._n_modes_via_ids_exact)
if (
    self._mode_stats_kind == "approx"
    and self._n_modes_via_ids_estimate is not None
):
    return int(self._n_modes_via_ids_estimate)

return 2**self.ndim
do we need to recompute this every time?
no, you're right, it should be stored.
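One idiomatic way to store it is `functools.cached_property`; the class and property names below are invented for illustration:

```python
from functools import cached_property

class ModeStats:
    def __init__(self, ndim: int):
        self.ndim = ndim

    @cached_property
    def n_modes_upper_bound(self) -> int:
        # Evaluated once on first access, then cached on the instance,
        # so 2**ndim is not recomputed on every call.
        return 2**self.ndim
```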
except Exception:
    pass
similar to the other comment, this is not nice for debuggability
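If the silent fallback has to stay, a sketch of keeping the traceback recoverable at DEBUG level (logger setup assumed as before):

```python
import logging

logger = logging.getLogger(__name__)

try:
    ...  # the all_states fast path from the diff above
except Exception:
    # Keep the fallback, but record the swallowed traceback so the
    # failure can still be diagnosed when debugging.
    logger.debug("all_states path failed; falling back", exc_info=True)
```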
Hi @younik - I hear you, this is a big PR. The "splits" would have to be along tasks, though, so the resulting PRs would still be large. I appreciate your comments on the code. I think it would make sense to also look at the tasks (the stuff that's plotted in the notebook) to see if they make sense. I'm not convinced by all of the tasks. I would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.
In the above commit, I fixed the comments of Deceptive Reward and also fixed a pyright error.
I do think
Description