JAX-PF: an efficient GPU-computing platform for differentiable phase-field (PF) simulation, built on top of JAX-FEM and leveraging JAX.
We want to emphasize the following four features that differentiate JAX-PF from other PF software:
- Ease-of-use: Leveraging the automatic differentiation (AD) capability of JAX, JAX-PF automatically generates Jacobians and derivatives of the different energy terms to machine precision, making it straightforward to implement multi-physics and multi-variable PF models.
- Automatic Sensitivity: Implicit time integration with customized adjoint-based AD enables efficient gradient-based optimization and inverse design of strongly nonlinear PF systems.
- High-performance GPU-acceleration: Through the XLA backend and vectorized operations, JAX-PF delivers competitive GPU performance, drastically reducing computational time relative to CPU-based solvers.
- Unified multiscale ecosystem with JAX-CPFEM: Built on the same JAX-FEM foundation, JAX-PF integrates seamlessly with JAX-CPFEM to enable coupled process–structure–property simulations (e.g., dynamic recrystallization), while preserving full differentiability for optimization and design.
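As an illustration of the AD-based derivatives mentioned in the first feature, the following is a minimal sketch in plain JAX. It does not use the JAX-PF API; the double-well energy, the barrier height `W`, and all function names are assumptions chosen for illustration only. It compares the AD derivative of a bulk free energy density with its hand-derived counterpart.

```python
import jax
import jax.numpy as jnp

W = 1.0  # hypothetical barrier height of the double-well energy


def double_well(phi):
    """Illustrative bulk free energy density f(phi) = W * phi^2 * (1 - phi)^2."""
    return W * phi**2 * (1.0 - phi) ** 2


# First and second derivatives obtained by automatic differentiation.
dfdphi = jax.grad(double_well)
d2fdphi2 = jax.grad(dfdphi)

phi = 0.3
ad_first = dfdphi(phi)
ad_second = d2fdphi2(phi)

# Hand-derived expressions for comparison.
analytic_first = 2.0 * W * phi * (1.0 - phi) * (1.0 - 2.0 * phi)
analytic_second = 2.0 * W * (1.0 - 6.0 * phi + 6.0 * phi**2)

# The same derivative evaluated pointwise over a whole field via vmap.
phi_field = jnp.linspace(0.0, 1.0, 5)
field_deriv = jax.vmap(dfdphi)(phi_field)

print("AD vs analytic first derivative:", ad_first, analytic_first)
print("AD vs analytic second derivative:", ad_second, analytic_second)
```

The two results should agree to floating-point round-off (single precision by default in JAX), with no hand-written derivative needed when the energy expression changes.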
🔥 Join us for the development of JAX-PF! This project is under active development!
Four benchmark problems are provided: Allen–Cahn, Cahn–Hilliard, and coupled Allen–Cahn and Cahn–Hilliard, each implemented with both explicit and implicit time integration, plus an Eshelby inclusion problem for lattice misfit in solid-state phase transformations.
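To give a feel for what explicit time integration of the Allen–Cahn equation involves, here is a minimal sketch in plain JAX. It is not the JAX-PF implementation: it uses finite differences on a periodic grid rather than finite elements, and every name and parameter value (`L_MOB`, `KAPPA`, `W`, `DX`, `DT`, the grid size) is a hypothetical choice. The driving force is obtained by differentiating the discrete total free energy with `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: mobility, gradient coefficient, barrier height,
# grid spacing, and time step (chosen small enough for a stable explicit step).
L_MOB, KAPPA, W, DX, DT = 1.0, 1.0, 1.0, 1.0, 0.05


def total_energy(phi):
    """Discrete free energy: double-well bulk term plus gradient term."""
    bulk = W * phi**2 * (1.0 - phi) ** 2
    gx = (jnp.roll(phi, -1, axis=0) - phi) / DX  # forward difference, periodic
    gy = (jnp.roll(phi, -1, axis=1) - phi) / DX
    return jnp.sum(bulk + 0.5 * KAPPA * (gx**2 + gy**2)) * DX**2


@jax.jit
def step(phi):
    """One forward-Euler step of d(phi)/dt = -L * (variational derivative of F)."""
    # AD of the discrete energy gives the nodal driving force; dividing by
    # the cell area converts it to a pointwise variational derivative.
    mu = jax.grad(total_energy)(phi) / DX**2
    return phi - DT * L_MOB * mu


# Small random fluctuations around phi = 0.5 as the initial condition.
key = jax.random.PRNGKey(0)
phi0 = 0.5 + 0.1 * (jax.random.uniform(key, (64, 64)) - 0.5)

phi = phi0
for _ in range(200):
    phi = step(phi)

print("initial energy:", total_energy(phi0))
print("final energy:  ", total_energy(phi))
```

Because the update is a gradient flow of the discrete energy, the total free energy should decrease over the run as long as the time step stays within the explicit stability limit.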
📣 Comparison between JAX-PF and PRISMS-PF.
Validation of benchmark problems in JAX-PF.
🔥 For each case, both explicit and implicit time-stepping schemes are provided.
The initial (left) and final (right) grain structure for a 2D grain growth simulation.
The distribution of composition during a simulation of spinodal decomposition from initial fluctuations (left) to final two distinct phases (right).
A 2D simulation of multi-variant precipitates in an Mg-Nd alloy.
A 3D simulation of a single-variant precipitate in an Mg-Nd alloy.
📣 Multiscale simulations (PF-CPFEM) using JAX-PF and JAX-CPFEM, which are built on top of the same underlying JAX-FEM ecosystem.
📣 A demo: calibration of material parameters.
JAX-PF supports Linux and macOS and depends on JAX-FEM.
JAX-FEM is a collection of several numerical tools, including the Finite Element Method (FEM). See JAX-FEM installation instructions. Depending on your hardware, you may install the CPU or GPU version of JAX. Both will work, while the GPU version usually gives better performance.
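After installation, a quick way to see whether JAX picked up the CPU or the GPU backend is to list the available devices. This is a small sketch using standard JAX calls; it is independent of JAX-FEM and JAX-PF.

```python
import jax

# Devices visible to JAX and the backend it will use by default.
devices = jax.devices()
backend = jax.default_backend()

print("Default backend:", backend)
for d in devices:
    print("Device:", d)
```

If only CPU devices are listed on a machine with a GPU, the CPU-only build of JAX was most likely installed.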
Neper is a free/open-source software package for polycrystal generation and meshing. It can be used to generate polycrystals with a wide variety of morphological properties. A good instructional video is available on YouTube.
Place the downloaded phaseField/ folder in the applications/ folder of JAX-FEM, and then you can run it.
For example, download the phaseField/allenCahn/explicit_fem folder, place it under the applications/ folder of JAX-FEM with the same relative path, and run
python -m applications.phaseField.allenCahn.explicit_fem.explicit_AC
from the root directory of JAX-FEM. Use ParaView for visualization.
📣 Coming soon!
If you find this library useful in academic or industry work, we would appreciate your support by 1) starring the project on GitHub, and 2) citing the relevant papers:
- Efficient GPU-computing simulation platform JAX-CPFEM for differentiable crystal plasticity finite element method. DOI: https://doi.org/10.1038/s41524-025-01528-2