Treeffuser is an easy-to-use package for probabilistic prediction on tabular data with tree-based diffusion models.
It estimates distributions of the form p(y|x), where x is a feature vector and y is a target vector.
Treeffuser can model conditional distributions p(y|x) that are arbitrarily complex (e.g., multimodal, heteroscedastic, non-Gaussian, heavy-tailed).
It is designed to adhere closely to the scikit-learn API and require minimal user tuning.
Install Treeffuser from PyPI:
pip install treeffuser
Install the development version:
pip install git+https://github.com/blei-lab/treeffuser.git@main
The GitHub repository is located at: https://github.com/blei-lab/treeffuser
Here's a simple example demonstrating how to use Treeffuser.
We generate a heteroscedastic response with two sinusoidal components and heavy tails.
import matplotlib.pyplot as plt
import numpy as np
from treeffuser import Treeffuser, Samples
# Generate data
seed = 0
rng = np.random.default_rng(seed=seed)
n = 5000
x = rng.uniform(0, 2 * np.pi, size=n)
z = rng.integers(0, 2, size=n)
y = z * np.sin(x - np.pi / 2) + (1 - z) * np.cos(x) + rng.laplace(scale=x / 30, size=n)
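To see why this response is bimodal and heteroscedastic, the short check below may help. It is a sketch that uses only NumPy, and the variable names (`branch_z1`, `branch_z0`, `noise_std_at_pi`, `noise_std_at_2pi`) are introduced here for illustration. It relies on the identity sin(x - pi/2) = -cos(x) and on the fact that Laplace noise with scale b has standard deviation sqrt(2) * b.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
x = rng.uniform(0, 2 * np.pi, size=1000)

# sin(x - pi/2) equals -cos(x), so the two branches mirror each other
# and p(y|x) has two modes, near cos(x) and -cos(x).
branch_z1 = np.sin(x - np.pi / 2)  # selected when z == 1
branch_z0 = np.cos(x)              # selected when z == 0
branches_are_mirrored = np.allclose(branch_z1, -branch_z0)

# Laplace noise with scale b has standard deviation sqrt(2) * b,
# so with scale x / 30 the noise spread grows linearly in x.
noise_std_at_pi = np.sqrt(2) * np.pi / 30
noise_std_at_2pi = np.sqrt(2) * 2 * np.pi / 30

print(branches_are_mirrored, noise_std_at_pi, noise_std_at_2pi)
```

Because z picks either branch with equal probability, a model that only predicts a conditional mean would land between the two modes, near zero, which is where little of the data lies.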
We fit Treeffuser and generate samples. We then plot the samples against the raw data.
# Fit the model
model = Treeffuser(seed=seed)
model.fit(x, y)
# Generate and plot samples
y_samples = model.sample(x, n_samples=1, seed=seed, verbose=True)
plt.scatter(x, y, s=1, label="observed data")
plt.scatter(x, y_samples[0, :], s=1, alpha=0.7, label="Treeffuser samples")
plt.legend()  # display the labels passed to scatter
plt.show()
Treeffuser accurately learns the target conditional densities and can generate samples from them.
These samples can be used to compute any downstream estimates of interest:
y_samples = model.sample(x, n_samples=100, verbose=True) # y_samples.shape[0] is 100
# Estimate downstream quantities of interest
y_mean = y_samples.mean(axis=0) # conditional mean
y_std = y_samples.std(axis=0) # conditional std
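Other sample-based estimates follow the same pattern of reducing over the sample axis. The sketch below is illustrative only: it does not call Treeffuser, and instead builds a stand-in array of shape (n_samples, n), which is an assumption about the layout of `y_samples` based on the indexing used above. The stand-in data, `coverage`, and `p_exceeds` are hypothetical names for this example.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n_samples, n = 200, 1000

# Stand-in data (hypothetical): Gaussian targets, plus "model" draws
# from the same conditional distribution, shaped (n_samples, n).
x = rng.uniform(0, 2 * np.pi, size=n)
y = np.cos(x) + rng.normal(scale=0.2, size=n)
y_samples = np.cos(x) + rng.normal(scale=0.2, size=(n_samples, n))

# 90% prediction interval for each x, taken across the sample axis.
lower, upper = np.quantile(y_samples, [0.05, 0.95], axis=0)

# Fraction of observed targets that fall inside their interval.
coverage = np.mean((y >= lower) & (y <= upper))

# Estimated probability that y exceeds 0.5, for each x.
p_exceeds = (y_samples > 0.5).mean(axis=0)

print(lower.shape, upper.shape, round(float(coverage), 3))
```

If the samples match the true conditional distribution, the empirical coverage should be close to the nominal 90%, which makes it a quick sanity check on held-out data.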
You can also use the Samples helper class:
y_samples = Samples(y_samples)
y_mean = y_samples.sample_mean()
y_std = y_samples.sample_std()
y_quantiles = y_samples.sample_quantile(q=[0.05, 0.95])
See the documentation for more information on available methods and parameters.
If you use Treeffuser in your work, please cite:
@article{beltranvelez2024treeffuser,
title={Treeffuser: Probabilistic Predictions via Conditional Diffusions with Gradient-Boosted Trees},
author={Nicolas Beltran-Velez and Alessandro Antonio Grande and Achille Nazaret and Alp Kucukelbir and David Blei},
year={2024},
eprint={2406.07658},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2406.07658},
}