-
Notifications
You must be signed in to change notification settings - Fork 252
Allow gradient transform parameters to be dynamic #516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@hawkinsp Pinging you since you recently repaired some type annotation errors. The optimizer classes accepting only float breaks type annotations for the Tjax shim classes (https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/gradient/transforms.py). Tjax provides a parallel set of optimizers, identical in functionality, except they support dynamic optimizer parameters. They do this by storing dynamic fields in a dataclass rather than closing over parameters. However, the optimizer functionality is delegated to Optax, which means calling Optax update methods with Jax arrays. Is there any reason Optax methods can't accept such arrays? Would it be possible to widen these parameter types to |
@mtthss Would you mind taking a look at this? |
Hello. I was on paternity leave for most of the past year. Are you still having this issue? Happy to look into it if that's the case |
@mtthss Hello, yes I'm still getting the type errors. (Congrats on becoming a father!) |
which arguments are causing errors to you? |
All of the ones I changed. I maintain a shim library so that I can use optax with dynamic, inspectable parameters. What I ended up doing for the time being is to mark every use of optax with Thanks for taking a look at this. |
(Of course, my dream would be that you adopt the dynamic design so that I don't have to maintain my shim library 😄.) |
d18a5cb
to
59b2964
Compare
@mtthss Do you mind taking a look at this? |
59b2964
to
bcf5cbf
Compare
I'm a new maintainer. At some point, optax moved to https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.inject_hyperparams to make parameters available in the state. I think the point of these type annotations is to explicitly indicate that (like you're pointing out) the parameters are closed over and not really modifiable without |
That's a separate issue from this PR. (I believe, in fact, it's the first issue #1.)
Even if the parameters are closed over, that doesn't make them floats. A closed-over Jax array is still a Jax array. The annotations are simply wrong. It's just that most users probably don't hit this because most users are probably just using fixed float values. |
Let me talk to @vroulet |
bcf5cbf
to
e200b35
Compare
That's a completely fair point. We'll make an attempt to change the types (this will affect all internal code so many bugs possible) but a priori it should not pose problems. We'll keep you updated Neil. (It'll be next month a priori) |
No description provided.