Official code for our paper "Selective Concept Bottlenecks Without Predefined Concepts" (TMLR 2025).
If you find this work useful, please consider citing it:
@article{schrodi2025selective,
title={Selective Concept Bottlenecks Without Predefined Concepts},
author={Schrodi, Simon and Schur, Julian and Argus, Max and Brox, Thomas},
journal={Transactions on Machine Learning Research},
year={2025}
}
Here is an overview of our method, UCBM:
- Install the conda environment via:
  conda env create --name ucbm --file env.yml
- Run the following to add the repository root to the environment's Python path:
  conda develop .
- Set the base path in constants.py (this is where everything will be saved) and the other paths if necessary.
- Download the models following the instructions from Trustworthy-ML-Lab/Label-free-CBM and adjust the paths in constants.py if necessary.
- Download the CUB dataset with:
  bash download_cub.sh
Discover concepts:
ImageNet:
python discover_concepts.py -d imagenet -b resnet_v2 --con_am 3
CUB:
python discover_concepts.py -d cub -b cub_rn18 --con_am 1
Places-365:
python discover_concepts.py -d places365 -b places365_rn18 --con_am 5
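In the commands above, the concept_data name later passed to train_cbm.py matches --con_am times the number of classes: 3 × 1000 = 3000 for ImageNet, 1 × 200 = 200 for CUB, and 5 × 365 = 1825 for Places-365. Reading --con_am as "concepts per class" is our inference from these numbers, not documented behaviour, and the fixed "_64" suffix is simply copied from the commands without knowing what it encodes. The hypothetical helper below derives the expected name:

```python
# Hypothetical helper: infers the --concept_data name from --con_am.
# Assumption: total concepts = con_am * number of classes, and the "_64"
# suffix is fixed (copied from the README commands; its meaning is unknown).
NUM_CLASSES = {"imagenet": 1000, "cub": 200, "places365": 365}


def concept_data_name(dataset: str, con_am: int, suffix: int = 64) -> str:
    """Return the concept-data name expected by train_cbm.py --concept_data."""
    total = con_am * NUM_CLASSES[dataset]
    return f"concepts_{total}_{suffix}"


if __name__ == "__main__":
    for dataset, con_am in [("imagenet", 3), ("cub", 1), ("places365", 5)]:
        print(dataset, concept_data_name(dataset, con_am))
```

If you change --con_am, this reading suggests updating --concept_data accordingly; check the discovery output directory to confirm the actual name.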
Train the CBM:
ImageNet:
python train_cbm.py -d imagenet -b resnet_v2 --concept_data "concepts_3000_64" --epochs 20 --lam_gate 0 --lam_w 1e-4 --dropout_p 0.1 --lr 0.001 --cls_save_name "topk_seed_0" --scale_choose 'no' --bias_choose 'learn' --normalize_concepts --relu 'no' --k 42 --seed 0
CUB:
python train_cbm.py -d cub -b cub_rn18 --concept_data "concepts_200_64" --epochs 20 --lam_gate 0 --lam_w 8e-4 --dropout_p 0.2 --lr 0.001 --cls_save_name "topk_seed_0" --scale_choose 'no' --bias_choose 'learn' --normalize_concepts --relu 'no' --k 66 --seed 0
Places-365:
python train_cbm.py -d places365 -b places365_rn18 --concept_data "concepts_1825_64" --epochs 20 --lam_gate 0 --lam_w 4e-4 --dropout_p 0.2 --lr 0.01 --cls_save_name "topk_seed_0" --scale_choose 'no' --bias_choose 'learn' --normalize_concepts --relu 'no' --k 162 --seed 0
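The training commands differ mainly in their sparsity and selection settings (--lam_w, --dropout_p, --lr, --k). The sketch below illustrates one plausible reading of how a selective concept bottleneck combines them for a single example: normalize the concept activations, keep only the top k, and feed them to a linear classifier whose weights carry an L1 penalty. Every mapping from flag to code here is our assumption, not documented behaviour of train_cbm.py; dropout and the gate penalty (--lam_gate, set to 0 above) are omitted, and all function names are hypothetical.

```python
# Minimal sketch of a selective (top-k) concept bottleneck for one example.
# Assumptions, not taken from train_cbm.py: --normalize_concepts standardizes
# each concept with training-set mean/std, --k keeps the k concepts with the
# largest absolute activation, --bias_choose 'learn' means a learned bias, and
# --lam_w scales an L1 penalty on the final layer. Names are hypothetical.
import math
from typing import List, Sequence, Tuple


def normalize(acts: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    # Standardize each concept activation: (a - mean) / std.
    return [(a - m) / s for a, m, s in zip(acts, mean, std)]


def topk_select(acts: Sequence[float], k: int) -> List[float]:
    # Keep the k largest-magnitude activations and zero out the rest.
    order = sorted(range(len(acts)), key=lambda i: abs(acts[i]), reverse=True)
    keep = set(order[:k])
    return [a if i in keep else 0.0 for i, a in enumerate(acts)]


def linear(x: Sequence[float], weights: Sequence[Sequence[float]], bias: Sequence[float]) -> List[float]:
    # One logit per class: dot(weights[c], x) + bias[c].
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    # Numerically stable softmax cross-entropy for a single example.
    mx = max(logits)
    log_z = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return log_z - logits[label]


def selective_cbm_loss(acts, label, mean, std, weights, bias, k, lam_w) -> Tuple[float, List[float]]:
    # Forward pass: normalize -> top-k selection -> linear head.
    z = topk_select(normalize(acts, mean, std), k)
    logits = linear(z, weights, bias)
    # Classification loss plus an L1 sparsity penalty on the linear weights.
    l1 = sum(abs(w) for row in weights for w in row)
    return cross_entropy(logits, label) + lam_w * l1, logits
```

Under this reading, --k 42 on ImageNet's 3000 concepts would mean each prediction draws on at most 42 concepts.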
In the commands below, cbm_name is the cls_save_name defined above, extended with a timestamp.
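Since the timestamp is added automatically, you may need to look up the full name after training. A hypothetical helper for that, assuming the trained model is saved under your base path as an entry whose name begins with cls_save_name (the exact save directory is not specified here, so pass it in yourself):

```python
# Hypothetical helper: return the most recently modified entry in save_dir
# whose name starts with cls_save_name. Assumption: saved CBMs are named
# "<cls_save_name>" plus a timestamp; save_dir location is up to the caller.
from pathlib import Path
from typing import Optional


def latest_cbm_name(save_dir: str, cls_save_name: str) -> Optional[str]:
    candidates = [p for p in Path(save_dir).iterdir() if p.name.startswith(cls_save_name)]
    if not candidates:
        return None  # nothing trained with this cls_save_name yet
    # Pick by modification time so the timestamp format does not matter.
    return max(candidates, key=lambda p: p.stat().st_mtime).name
```

Choosing by modification time avoids depending on how the timestamp is formatted.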
Plot explanations of individual decisions:
python explanation/prediction.py -d imagenet -b resnet50_v2 -c concepts_3000_64 --cbm_name $cbm_name
Plot explanations of a class:
python explanation/class.py -d imagenet -b resnet50_v2 -c concepts_3000_64 --cbm_name $cbm_name
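For a linear head on top of the selected concepts, a common way to explain an individual decision is to rank per-concept contributions (weight times activation) for the predicted class, and to explain a class by its largest weights. The sketch below shows that computation under the assumption that explanation/prediction.py and explanation/class.py follow this scheme; it only computes the rankings (no plotting), and the function names are hypothetical.

```python
# Sketch: explanations for a linear concept bottleneck head.
# Assumption: decisions are explained by contributions weight[pred][j] * acts[j],
# classes by their largest weights. Not the repository's plotting code.
from typing import List, Optional, Sequence, Tuple


def explain_prediction(
    acts: Sequence[float],
    weights: Sequence[Sequence[float]],
    concept_names: Sequence[str],
    bias: Optional[Sequence[float]] = None,
    top: int = 3,
) -> Tuple[int, List[Tuple[str, float]]]:
    bias = bias if bias is not None else [0.0] * len(weights)
    logits = [sum(w * a for w, a in zip(row, acts)) + b for row, b in zip(weights, bias)]
    pred = max(range(len(logits)), key=logits.__getitem__)
    # Zero contributions (unselected concepts or zero weights) are skipped.
    contribs = [(n, w * a) for n, w, a in zip(concept_names, weights[pred], acts) if w * a != 0.0]
    contribs.sort(key=lambda t: abs(t[1]), reverse=True)
    return pred, contribs[:top]


def explain_class(weights, class_idx, concept_names, top=3):
    # Concepts with the largest positive weight for the class.
    ranked = sorted(zip(concept_names, weights[class_idx]), key=lambda t: t[1], reverse=True)
    return ranked[:top]
```

The concept names and numbers in the test are made up; with real data, the activations would be the top-k selected concept activations of one image.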
We thank the following GitHub users, whose contributions are used in this repository:
- CRAFT from deel-ai/Craft
- Base CPU NMF implementation from scikit-learn/scikit-learn
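The NMF implementation acknowledged above is presumably used for CRAFT-style concept discovery, where non-negative activations X (one row per image crop) are factorized as X ≈ WH, with W holding per-crop concept coefficients and the rows of H acting as concept directions. For readers unfamiliar with NMF, the dependency-free sketch below shows the classic Lee and Seung multiplicative updates for the Frobenius objective. It is a generic illustration, not the repository's code; scikit-learn's NMF offers these updates via solver='mu' (its default solver is 'cd').

```python
# Generic NMF sketch (Lee and Seung multiplicative updates, Frobenius loss).
# Not the repository's implementation: pure Python, for illustration only.
import random
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def frob_error(x: Matrix, w: Matrix, h: Matrix) -> float:
    # Squared Frobenius norm ||X - WH||^2.
    wh = matmul(w, h)
    return sum((x[i][j] - wh[i][j]) ** 2 for i in range(len(x)) for j in range(len(x[0])))


def nmf(x: Matrix, rank: int, iters: int = 200, seed: int = 0, eps: float = 1e-9) -> Tuple[Matrix, Matrix]:
    # Factorize non-negative X (n x m) as W (n x rank) @ H (rank x m).
    rng = random.Random(seed)
    n, m = len(x), len(x[0])
    w = [[rng.random() for _ in range(rank)] for _ in range(n)]
    h = [[rng.random() for _ in range(m)] for _ in range(rank)]
    for _ in range(iters):
        # H <- H * (W^T X) / (W^T W H); eps avoids division by zero.
        wt = transpose(w)
        num, den = matmul(wt, x), matmul(matmul(wt, w), h)
        h = [[h[i][j] * num[i][j] / (den[i][j] + eps) for j in range(m)] for i in range(rank)]
        # W <- W * (X H^T) / (W H H^T)
        ht = transpose(h)
        num, den = matmul(x, ht), matmul(w, matmul(h, ht))
        w = [[w[i][j] * num[i][j] / (den[i][j] + eps) for j in range(rank)] for i in range(n)]
    return w, h
```

Multiplicative updates keep W and H non-negative automatically (products and ratios of non-negative numbers), which is why they are a popular starting point; coordinate-descent solvers typically converge faster on large activation matrices.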