Status: Open
Labels: question (User queries)
Description
With JAX 0.7.2, the following code raises TypeError: cannot create weak reference to 'Flatten' object:
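import jax.numpy as jnp
import optimistix as omx


def test(w):
    # Residuals vanish (the least-squares minimum) at w = [5.0, 42.0, -2.0].
    return (w - jnp.array([5.0, 42.0, -2.0])) ** 2


solver = omx.LevenbergMarquardt(rtol=1e-4, atol=1e-4)
# least_squares calls fn(y, args); the lambda ignores the unused args.
res = omx.least_squares(lambda w, _: test(w), solver, jnp.zeros(3))
print(res.value)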
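Because the failure is tied to a specific JAX release, it may help to include the installed versions of JAX and of optimistix's JAX-ecosystem dependencies in the report. The sketch below uses only the standard library's importlib.metadata; the helper name collect_versions and the exact package list are illustrative assumptions, not part of optimistix.

```python
# Hypothetical helper (not part of optimistix): list installed versions of the
# packages involved in this report so they can be pasted into the issue.
import platform
import sys
from importlib.metadata import PackageNotFoundError, version

# Assumption: the distributions most relevant here. optimistix builds on
# equinox and lineax; jaxlib provides JAX's compiled backend.
PACKAGES = ("jax", "jaxlib", "equinox", "lineax", "optimistix")


def collect_versions(packages=PACKAGES):
    """Return a {package: version} mapping; uninstalled packages map to None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(f"python {sys.version.split()[0]} on {platform.platform()}")
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

JAX also provides jax.print_environment_info(), which prints JAX, jaxlib, NumPy, Python and device details; the helper above additionally covers the non-JAX packages.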