
JAX 0.7.2 breaks least_squares #170

@jagourq


With JAX 0.7.2, the following code fails with `TypeError: cannot create weak reference to 'Flatten' object`:

```python
import jax.numpy as jnp
import optimistix as omx

# Residuals vanish at w = [5.0, 42.0, -2.0].
def test(w):
    return (w - jnp.array([5.0, 42.0, -2.0]))**2

solver = omx.LevenbergMarquardt(rtol=1e-4, atol=1e-4)
# least_squares expects fn(y, args); args are unused here.
res = omx.least_squares(lambda w, _: test(w), solver, jnp.zeros(3))
print(res.value)
```
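When triaging a version-specific regression like this one, it can help to report the exact installed versions alongside the reproducer. The sketch below uses only the standard library's `importlib.metadata`; the function name `report_versions` is hypothetical, and the package list (`jax`, `jaxlib`, `equinox`, `optimistix`) is an assumption about which distributions are relevant, not something stated in the report.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("jax", "jaxlib", "equinox", "optimistix")):
    """Hypothetical helper: map each distribution name to its installed version, or None."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Returning `None` instead of raising keeps the helper usable in environments where only part of the stack is installed.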
