This repository contains the official implementation of Model Tensor Planning (MTP), a sampling-based model predictive control (MPC) framework that performs high-entropy control generation using structured tensor sampling. See project website.
MTP is implemented entirely in JAX and supports real-time control in high-dimensional systems with GPU acceleration via JIT and MuJoCo XLA.
- Tensor Sampling: Generates globally diverse trajectory candidates via sampling over randomized multipartite graphs.
- Spline Grid Interpolation: Smooths sampled controls using B-spline or Akima splines for dynamically feasible execution.
- β-Mixing Strategy: Blends global (exploratory) and local (exploitative) samples at each planning iteration.
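The first and third ideas above can be sketched in a few lines of JAX. This is a minimal illustration, not the repository's implementation: the function names `sample_tensor_paths` and `mix_samples`, the uniform sampling of graph nodes inside box bounds `u_min`/`u_max`, and the Gaussian form of the local samples are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def sample_tensor_paths(key, num_paths, M, N, u_min, u_max):
    """Sample waypoint sequences as random paths through an M-layer graph.

    Assumption: each layer holds N candidate controls drawn uniformly in
    [u_min, u_max]; a path picks one node per layer.
    """
    key_nodes, key_idx = jax.random.split(key)
    dim = u_min.shape[0]
    # Node tensor: (M layers, N candidates per layer, control dimension).
    nodes = jax.random.uniform(key_nodes, (M, N, dim), minval=u_min, maxval=u_max)
    # One random node index per layer for every path: (num_paths, M).
    idx = jax.random.randint(key_idx, (num_paths, M), 0, N)
    # Gather paths[p, m] = nodes[m, idx[p, m]] -> (num_paths, M, dim).
    return nodes[jnp.arange(M)[None, :], idx]


def mix_samples(key, mean, sigma, num_samples, beta, N, u_min, u_max):
    """Blend global graph paths with local Gaussian samples around `mean`.

    Assumption: beta is the fraction of samples drawn globally.
    """
    M, dim = mean.shape
    num_global = int(round(beta * num_samples))
    num_local = num_samples - num_global
    key_g, key_l = jax.random.split(key)
    global_u = sample_tensor_paths(key_g, num_global, M, N, u_min, u_max)
    noise = jax.random.normal(key_l, (num_local, M, dim))
    local_u = jnp.clip(mean + sigma * noise, u_min, u_max)
    return jnp.concatenate([global_u, local_u], axis=0)
```

Under these assumptions, a small `beta` keeps most candidates close to the current plan, while a larger `beta` spends more of the sampling budget on paths that cover the whole control range.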
NOTE: `MTP-Bspline` and `MTP-Akima` depend on a hydrax fork that diverged from the original hydrax at the spline support PR. To match the newest commit, I implemented `MTP-Cubic` (untuned), a version that samples both global and local splines using `interpax`, matching the new API design of the original `hydrax`. To play around with `MTP-Cubic`, please check out the `experimental` branch of both `mtp` and the hydrax fork.
Clone this repository and its submodules:
git clone --recurse-submodules git@github.com:anindex/mtp.git
cd mtp
Install the environment and libraries:
# Conda environment setup
conda update -n base conda -y
conda create -n mtp python=3.12
conda activate mtp
conda env config vars set CUDA_HOME=""
conda activate mtp
conda install -c nvidia/label/cuda-12.9.0 cuda-toolkit=12.9.0 -y
conda install -c conda-forge cudnn=9.10.1.4 -y
conda install pip -y
conda activate mtp
# Install hydrax dependency
cd hydrax
pip install -e .
# Install MTP
cd ..
pip install -e .
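After installation, a quick way to confirm that JAX can see an accelerator is to list its devices. This check is a suggestion rather than part of the official setup; on a machine without a working CUDA install, JAX falls back to the CPU backend.

```python
import jax

# Lists the devices JAX will use; a GPU entry indicates the CUDA setup works.
devices = jax.devices()
backend = jax.default_backend()
print(f"backend: {backend}, devices: {devices}")
```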
All examples are configured to run with MTP by default. To switch between planners (e.g., `cem`, `mppi`, `ps`, `oes`, `de`), replace the last argument.
python examples/navigation.py mtp
python examples/double_cart_pole.py mtp
python examples/pusht.py mtp
python examples/crane.py mtp
python examples/walker.py mtp
python examples/cube.py mtp
python examples/pendulum.py mtp
python examples/g1_standup.py mtp
python examples/g1_mocap.py mtp
To achieve the best performance across tasks, here are recommended tuning guidelines:
| Symbol | Description | Typical Range |
|---|---|---|
| `M` | Number of control waypoints (graph depth) | 2–3 (depending on horizon T) |
| `N` | Number of control candidates per waypoint (graph width) | 30–100 |
| `β` | Mixing rate (exploration vs. exploitation) | 0.01–0.6 (lower = more stable, less exploration) |
| `E` | Number of elites | 5–100 (depends on task complexity) |
| `σ_min` | Minimum noise std for CEM sampling | 0.05–0.2 |
| `σ_max` | Maximum noise std for CEM sampling | 0.3–0.5 |
| `interpolation` | Interpolation type | 'linear', 'bspline', 'akima' |
| `α` | CEM smoothing weight (optional) | 0.0–0.5 |
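To make the roles of E, σ_min, σ_max, and α concrete, here is a hedged sketch of a single CEM refit step in JAX. The function name `cem_update` and the exact update rule (in particular, α weighting the previous mean and std) are assumptions for illustration; the repository's own update may differ.

```python
import jax.numpy as jnp


def cem_update(mean, sigma, samples, costs, num_elites, sigma_min, sigma_max, alpha):
    """One CEM refit: pick elites, refit mean/std, clip the std, then smooth.

    samples: (num_samples, M, dim) control waypoints, costs: (num_samples,).
    """
    # Indices of the E lowest-cost samples.
    elite_idx = jnp.argsort(costs)[:num_elites]
    elites = samples[elite_idx]
    elite_mean = elites.mean(axis=0)
    # Keep the noise scale inside [sigma_min, sigma_max].
    elite_std = jnp.clip(elites.std(axis=0), sigma_min, sigma_max)
    # Assumed convention: alpha = 0 takes the elite statistics directly,
    # larger alpha keeps more of the previous distribution.
    new_mean = alpha * mean + (1.0 - alpha) * elite_mean
    new_sigma = alpha * sigma + (1.0 - alpha) * elite_std
    return new_mean, new_sigma
```

With this convention, clipping at σ_min prevents the sampling distribution from collapsing, and a nonzero α damps abrupt changes between planning iterations.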
- B-Spline (default for stable tasks): Good for underactuated systems or where smoothness is critical.
- Akima Spline (use for aggressive control): Works well in contact-rich environments (e.g., dexterous manipulation).
Run `scripts/plot_splines.py` to see spline tensors.
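For intuition about what interpolation does to a sampled waypoint sequence, the sketch below expands M waypoints into a dense control sequence over the horizon. It uses plain linear interpolation via `jnp.interp` (corresponding to the 'linear' option in the table), not the B-spline or Akima schemes; the function name `interpolate_waypoints` and the uniform knot spacing are assumptions for the example.

```python
import jax
import jax.numpy as jnp


def interpolate_waypoints(waypoints, horizon):
    """Linearly interpolate (M, dim) control waypoints to (horizon, dim) controls."""
    M = waypoints.shape[0]
    # Assumption: waypoints are evenly spaced over the normalized horizon [0, 1].
    t_knots = jnp.linspace(0.0, 1.0, M)
    t_query = jnp.linspace(0.0, 1.0, horizon)
    # Interpolate each control dimension independently.
    interp_dim = lambda column: jnp.interp(t_query, t_knots, column)
    return jax.vmap(interp_dim, in_axes=1, out_axes=1)(waypoints)
```

A spline variant would replace the per-dimension `jnp.interp` call with a smoother interpolant while keeping the same input and output shapes.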
This codebase builds upon Hydrax and MuJoCo XLA. Special thanks to Vince Kurtz and other contributors.
If you found this repository useful, please consider citing these references:
@misc{le2025mtp,
title={Model Tensor Planning},
author={An T. Le and Khai Nguyen and Minh Nhat Vu and João Carvalho and Jan Peters},
year={2025},
eprint={2505.01059},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2505.01059},
}
@misc{kurtz2024hydrax,
title={Hydrax: Sampling-based model predictive control on GPU with JAX and MuJoCo MJX},
author={Kurtz, Vince},
year={2024},
note={https://github.com/vincekurtz/hydrax}
}