A framework for implementing and reproducing the paper "Gridded Transformer Neural Processes for Spatio-Temporal Data" (ICML 2025) in Python.
conda create -n tnp python=3.12
conda activate tnp
pip install -r requirements.txt
pip install -e .Note: For this to work you need to have a CDS account (https://cds.climate.copernicus.eu/) and a file ~/.cdsapirc which contains your credentials.
For this version, the credential file uses the format:
url: https://cds.climate.copernicus.eu/api
key: <api-key>
python experiments/era5_download.py --start_year 2009 --end_year 2019 --num_processes 4python experiments/mm_era5_download.py --start_year 2009 --end_year 2019 --num_processes 4python experiments/generate_gp_data.py --gen_name train --num_processes 4 --config experiments/configs/generators/pregenerate-gp.yml
python experiments/generate_gp_data.py --gen_name train-large-lengthscale --num_processes 4 --config experiments/configs/generators/pregenerate-gp.yml
python experiments/generate_gp_data.py --gen_name test --num_processes 4 --config experiments/configs/generators/pregenerate-gp.yml
python experiments/generate_gp_data.py --gen_name test-large-lengthscale --num_processes 4 --config experiments/configs/generators/pregenerate-gp.ymlHere we train the Swin-TNP using the pseudo-token grid encoder (PT-GE).
python experiments/lightning_train.py --config experiments/configs/models/gp/swintnp.yml experiments/configs/models/grid_encoders/pt-ge.yml experiments/configs/models/grid_decoders/mhca-gd.yml experiments/configs/generators/gp.ymlHere we train the Swin-TNP using the pseudo-token grid encoder (PT-GE).
python experiments/slurm_lightning_train.py --config experiments/configs/models/ootg-swintnp.yml experiments/configs/models/grid_encoders/ootg-pt-ge.yml experiments/configs/models/grid_decoders/mhca-gd.yml experiments/configs/generators/ootg-era5.yml experiments/configs/data/era5-nc-files.ymlFor using the kernel-interpolation grid encoder (KI-GE), we use:
python experiments/slurm_lightning_train.py --config experiments/configs/models/ootg-swintnp.yml experiments/configs/models/grid_encoders/ootg-setconv-ge.yml experiments/configs/models/grid_decoders/mhca-gd.yml experiments/configs/generators/ootg-era5.yml experiments/configs/data/era5-nc-files.ymlFor benchmarking against the ConvCNP, we use:
python experiments/slurm_lightning_train.py --config experiments/configs/models/ootg-convcnp.yml experiments/configs/models/cnns/unet.yml experiments/configs/generators/ootg-era5.yml experiments/configs/data/era5-nc-files.ymlFor the translation equivariant version, we use the following:
python experiments/slurm_lightning_train.py --config experiments/configs/models/ootg-swintetnp.yml experiments/configs/models/grid_encoders/ootg-pt-te-ge.yml experiments/configs/models/grid_decoders/temhca-gd.yml experiments/configs/generators/ootg-era5.yml experiments/configs/data/era5-nc-files.ymlFor the approximately translation equivariant version, we use the following:
python experiments/slurm_lightning_train.py --config experiments/configs/models/ootg-swinatetnp.yml experiments/configs/models/grid_encoders/ootg-pt-te-ge.yml experiments/configs/models/grid_decoders/temhca-gd.yml experiments/configs/generators/ootg-era5.yml experiments/configs/data/era5-nc-files.ymlHere we train the Swin-TNP using the multi pseudo-token grid encoder (PT-GE).
python experiments/slurm_lightning_train.py --config experiments/configs/models/mm-swintnp.yml experiments/configs/models/grid_encoders/pt-ge-tt.yml experiments/configs/models/grid_decoders/mhca-gd.yml experiments/configs/generators/mm-era5.yml experiments/configs/data/mm-era5-nc-files.ymlFor using the kernel-interpolation grid encoder (KI-GE), we use the following command:
python experiments/slurm_lightning_train.py --config experiments/configs/models/mm-swintnp.yml experiments/configs/models/grid_encoders/setconv-ge-tt.yml experiments/configs/models/grid_decoders/mhca-gd.yml experiments/configs/generators/mm-era5.yml experiments/configs/data/mm-era5-nc-files.ymlFor benchmarking against ConvCNP, we use:
python experiments/slurm_lightning_train.py --config experiments/configs/models/mm-convcnp.yml experiments/configs/models/cnns/unet.yml experiments/configs/generators/mm-era5.yml experiments/configs/data/mm-era5-nc-files.ymlFor the translation equivariant version, we use the following:
python experiments/slurm_lightning_train.py --config experiments/configs/models/mm-swintetnp.yml experiments/configs/models/grid_encoders/pt-te-ge-tt.yml experiments/configs/models/grid_decoders/temhca-gd.yml experiments/configs/generators/mm-era5.yml experiments/configs/data/mm-era5-nc-files.ymlFor the approximately translation equivariant version, we use the following:
python experiments/slurm_lightning_train.py --config experiments/configs/models/mm-swinatetnp.yml experiments/configs/models/grid_encoders/pt-te-ge-tt.yml experiments/configs/models/grid_decoders/temhca-gd.yml experiments/configs/generators/mm-era5.yml experiments/configs/data/mm-era5-nc-files.ymlTo acknowledge the repository or paper, please cite
@misc{ashman2024griddedtransformerneuralprocesses,
title={Gridded Transformer Neural Processes for Large Unstructured Spatio-Temporal Data},
author={Matthew Ashman and Cristiana Diaconu and Eric Langezaal and Adrian Weller and Richard E. Turner},
year={2024},
eprint={2410.06731},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/2410.06731},
}