Code for the paper:
Will Grathwohl*, Jacob Kelly*, Milad Hashemi, Mohammad Norouzi, Kevin Swersky, David Duvenaud. "No MCMC for me: Amortized sampling for fast and stable training of energy-based models" arXiv preprint (2020). [arxiv] [bibtex]
*Equal Contribution
Code for implementing Variational Entropy Regularized Approximate maximum likelihood (VERA). Contains scripts for training VERA and using VERA for JEM training. Code is also available for training semi-supervised models on tabular data, mode counting experiments, and tractable likelihood models experiments.
For more info on me and my work, please check out my website, twitter, or Google Scholar.
Many thanks to my amazing co-authors: Jacob Kelly, Milad Hashemi, Mohammad Norouzi, Kevin Swersky, David Duvenaud.
pytorch==1.5.1
torchvision==0.6.1
numpy
scikit-learn
matplotlib
seaborn
tqdm
A brief explanation of hyperparameters that can be set from flags and their names in the paper.
- --clf_weight: Classification weight (\alpha)
- --pg_control: Gradient norm penalty (\gamma)
- --ent_weight: Entropy regularization weight (\lambda)
- --clf_ent_weight: Classification entropy (\beta)
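As a rough illustration of how the four weight flags above could be parsed, here is a minimal argparse sketch. It is not the repository's actual parser: the `build_parser` helper is hypothetical, the `float` types are an assumption inferred from the values used in the training command, and the defaults are left as `None` because the real defaults are not stated here.

```python
import argparse


def build_parser():
    # Hypothetical helper: mirrors only the four weight flags listed above.
    parser = argparse.ArgumentParser(description="Sketch of the VERA weight flags")
    # float types are an assumption; real defaults are unknown, so None is used.
    parser.add_argument("--clf_weight", type=float, default=None,
                        help="classification weight (alpha in the paper)")
    parser.add_argument("--pg_control", type=float, default=None,
                        help="gradient norm penalty (gamma in the paper)")
    parser.add_argument("--ent_weight", type=float, default=None,
                        help="entropy regularization weight (lambda in the paper)")
    parser.add_argument("--clf_ent_weight", type=float, default=None,
                        help="classification entropy (beta in the paper)")
    return parser


# Parse the weight values that appear in the CIFAR training command.
args = build_parser().parse_args(
    ["--ent_weight", "0.0001", "--pg_control", ".1", "--clf_weight", "100."]
)
print(args.clf_weight, args.pg_control, args.ent_weight, args.clf_ent_weight)
```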
An explanation of flags for different modes of training. Without any of these flags, an unsupervised VERA model will be trained.
- --clf_only: For training a classifier on its own, i.e. without an EBM as in JEM.
- --jem: Do JEM training.
- --labels_per_class: If this is greater than zero, use this many labels per class for semi-supervised learning. If zero (default), do full-label training.
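To make the --labels_per_class behaviour concrete, the sketch below picks a fixed number of labeled examples from each class and falls back to all examples when the value is zero. The function name `select_labeled_indices` and the `seed` parameter are hypothetical; the repository may select its labeled subset differently.

```python
import random
from collections import defaultdict


def select_labeled_indices(labels, labels_per_class, seed=0):
    """Return sorted indices of the examples that keep their labels.

    labels_per_class <= 0 means full-label training: every index is returned.
    Raises ValueError if a class has fewer examples than labels_per_class.
    """
    if labels_per_class <= 0:
        return list(range(len(labels)))
    # Group example indices by class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    rng = random.Random(seed)
    chosen = []
    for label in sorted(by_class):
        # Sample without replacement within each class.
        chosen.extend(rng.sample(by_class[label], labels_per_class))
    return sorted(chosen)


# Toy example: 3 classes with 10 examples each, keep 4 labels per class.
toy_labels = [c for c in range(3) for _ in range(10)]
labeled = select_labeled_indices(toy_labels, labels_per_class=4)
print(len(labeled))
```

The remaining examples would be treated as unlabeled data in semi-supervised training.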
To train a CIFAR10/CIFAR100 JEM model as in the paper (pretrained models available here), run:
# DATASET is cifar10 or cifar100
python train.py --dataset DATASET \
                --ent_weight 0.0001 --noise_dim 128 \
                --viz_every 1000 --save_dir /YOUR/SAVE/DIR --data_aug --dropout .3 --thicc_resnet \
                --ckpt_path /PATH/TO/YOUR/MODEL.pt --generator_type vera --n_epochs 200 --print_every 100 \
                --lr .00003 --glr .00006 --post_lr .00003 --batch_size 40 --pg_control .1 \
                --decay_epochs 150 175 --jem --warmup_iters 2500 --clf_weight 100. --g_feats 256
To evaluate the classifier (on CIFAR10):
python eval.py --ckpt_path /PATH/TO/YOUR/MODEL.pt --eval test_clf --dataset cifar_test
To do OOD detection (on CIFAR100):
python eval.py --ckpt_path /PATH/TO/YOUR/MODEL.pt --eval OOD --ood_dataset cifar_100
To generate a histogram of OOD scores:
python eval.py --ckpt_path /PATH/TO/YOUR/MODEL.pt --eval logp_hist --datasets cifar10 svhn --save_dir /YOUR/HIST/FOLDER
To generate unconditional samples:
python eval.py --ckpt_path /PATH/TO/YOUR/MODEL.pt --eval uncond_samples --save_dir /YOUR/SAVE/DIR --n_sample_steps 100 --n_steps 40
To generate conditional samples:
python eval.py --ckpt_path /PATH/TO/YOUR/MODEL.pt --eval cond_samples --save_dir /YOUR/SAVE/DIR --n_sample_steps 100 --n_steps 40
Models for the mode counting experiments can be trained by passing the --dataset stackmnist flag.
Code for counting captured modes of a saved model is available in mode_counting/stackmnist_mode.py.
Code for training the MNIST classifier for counting modes is available in mode_counting/mnist_classify.py.
Hyperparameters may be found in the paper. In particular, note that results were reported on MNIST rescaled to 64x64, which can be specified with --img_size 64.
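The idea behind mode counting on stacked MNIST can be sketched as follows. This is an illustration under stated assumptions, not the logic of mode_counting/stackmnist_mode.py: it assumes each generated image stacks three MNIST digits as channels and that a separately trained classifier has already predicted one digit (0 to 9) per channel, so there are at most 10 x 10 x 10 = 1000 modes. The `count_modes` function is hypothetical.

```python
def count_modes(digit_predictions, num_classes=10):
    """Count distinct digit triples among the classified samples.

    digit_predictions: iterable of (d0, d1, d2), one predicted digit per
    channel of each generated image (assumed to come from an MNIST classifier).
    """
    modes = set()
    for d0, d1, d2 in digit_predictions:
        # Encode the triple as a single id in [0, num_classes**3).
        modes.add((d0 * num_classes + d1) * num_classes + d2)
    return len(modes)


# Toy example: four samples, two of which fall in the same mode.
toy_predictions = [(0, 0, 0), (0, 0, 0), (1, 2, 3), (3, 2, 1)]
num_modes = count_modes(toy_predictions)
print(num_modes)
```

A generator that covers every digit combination would reach the maximum of 1000 under these assumptions.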
Tabular data for semi-supervised classification must be downloaded manually and placed in datasets/.
Download 1000_train.csv.gz and 1000_test.csv.gz from here. Unzip each of these files and place them in datasets/HEPMASS/.
Download UCI HAR Dataset.zip from here. Unzip. Rename the resulting folder to HUMAN/ and place this folder in datasets/.
Download data.zip from here. Unzip. Place the resulting file in datasets/CROP/.
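For the gzipped HEPMASS files, the unzip-and-place step can also be done from Python. The sketch below uses only the standard library; the `gunzip_to` helper is hypothetical and is not part of this repository.

```python
import gzip
import shutil
from pathlib import Path


def gunzip_to(src_gz, dest_dir):
    """Decompress a .gz file into dest_dir and return the output path."""
    src_gz = Path(src_gz)
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    # Path.stem drops only the final suffix: "1000_train.csv.gz" -> "1000_train.csv".
    out_path = dest_dir / src_gz.stem
    with gzip.open(src_gz, "rb") as f_in, open(out_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    return out_path
```

For example, `gunzip_to("1000_train.csv.gz", "datasets/HEPMASS")` would write datasets/HEPMASS/1000_train.csv, assuming the archive was downloaded to the current directory.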
If you want to use all three datasets, the datasets/ folder should include these files:
datasets/
|-- HEPMASS
|   |-- 1000_train.csv
|   |-- 1000_test.csv
|-- HUMAN
|   |-- train
|       |-- X_train.txt
|       |-- y_train.txt
|   |-- test
|       |-- X_test.txt
|       |-- y_test.txt
|-- CROP
|   |-- WinnipegDataset.txt
Some code from this repository was adapted from the following repositories:
@article{grathwohl2020nomcmc,
  title={No MCMC for me: Amortized sampling for fast and stable training of energy-based models},
  author={Grathwohl, Will and Kelly, Jacob and Hashemi, Milad and Norouzi, Mohammad and Swersky, Kevin and Duvenaud, David},
  journal={arXiv preprint arXiv:2010.04230},
  year={2020}
}