Tianze Yang†, Yucheng Shi†, Mengnan Du, Xuansheng Wu, Qiaoyu Tan, Jin Sun, Ninghao Liu
† Equal contribution
This is the official repository for our paper "Concept-Centric Token Interpretation for Vector-Quantized Generative Models", accepted at the International Conference on Machine Learning (ICML) 2025.
We propose a novel framework, CORTEX, for interpreting tokens in vector-quantized generative models through a concept-centric lens.
Figure: Our pipeline for token-level concept interpretation.
| Requirement | Value |
|---|---|
| Python | 3.12.3 |
| Conda env | CORTEX |
# Option A (preferred): use the YAML file
conda env create -f environment.yml  # creates env named "CORTEX"
# Option B: use the requirements file
conda create -n CORTEX python=3.12.3
conda activate CORTEX
pip install -r requirements.txt

CORTEX/
├── VQGAN_explanation/ # Experiments & analyses based on VQGAN
├── Dalle_explanation/ # Experiments & analyses based on DALLE
├── environment.yml # Conda environment specification (preferred)
├── requirements.txt # Pip fallback dependency list
└── README.md # Repository overview (you are here)
cd VQGAN_explanation

This subfolder contains the implementation of CORTEX to explain the VQGAN model.
CORTEX/VQGAN_explanation/
├── checkpoints/ # Model checkpoints (download required)
├── datasets/ # Datasets (download required)
├── eval/ # Evaluation scripts
│ ├── codebook_level_explanation.py
│ ├── sample_concept_level_explanation.py
│   └── sample_image_level_explanation.py
├── logs/ # Training logs
├── results/ # Results directory
├── model.py # IEM architecture
├── new_vqgan.py # Modified file to place into the VQGAN repository
├── dataset.py # Dataset loader
├── train.py # Training script for IEM
├── test.py # Evaluation script
├── TIS_computation.py # Token Importance Score computation
├── TIS_analysis.py # TIS analysis for concept-level explanations
└── generate_freq_based_tokens.py # Generate frequency-based baseline
- Clone the VQGAN repository.
- Place the `new_vqgan.py` file into the VQGAN repository under the `taming-transformers/taming/models` directory (required if you want to run `eval/codebook_level_explanation.py`).
- Download the datasets, or generate your own dataset, and replace the `datasets` directory. (The dataset was generated using the VQGAN model.)
- Download pre-trained checkpoints, or train your own IEMs, and place them in the `checkpoints` directory.
📥 Data and Checkpoints Download:
You can download our generated dataset from Download Datasets and our pre-trained checkpoints from Download Checkpoints.
⚠️ Note: The dataset is quite large. For efficiency, we recommend generating only the required subset for your task instead of downloading the entire dataset.
You can train your own Interpretable Explanation Model (IEM) on different Vector-Quantized Generative Models (VQGMs).
The model input is a token-based embedding with shape (256, 16, 16). To train IEMs on other VQGMs, you need to first generate the required dataset:
- For each image, save its token-based embedding (of shape 256 × 16 × 16).
- During generation, record the corresponding token indices and label.
- Save this metadata in a `.csv` file following the format of train_embeddings.csv.
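The steps above can be sketched as follows. This is a hypothetical illustration, not the repository's actual generation script: `encode_to_tokens` stands in for your VQGM's encoder (here it just produces random data), and the CSV column names are assumptions modeled loosely on train_embeddings.csv — check that file for the exact format.

```python
import csv
import os
import numpy as np

os.makedirs("datasets", exist_ok=True)

def encode_to_tokens(image):
    """Hypothetical stand-in for a VQGM encoder: a real encoder returns a
    (256, 16, 16) token-based embedding plus the codebook indices used."""
    rng = np.random.default_rng(0)
    embedding = rng.standard_normal((256, 16, 16)).astype(np.float32)
    indices = rng.integers(0, 1024, size=16 * 16)  # one index per spatial token
    return embedding, indices

rows = []
for i, (image, label) in enumerate([(None, 3), (None, 7)]):  # dummy image/label pairs
    embedding, indices = encode_to_tokens(image)
    np.save(f"datasets/embedding_{i}.npy", embedding)  # IEM input, shape (256, 16, 16)
    rows.append({
        "file": f"embedding_{i}.npy",
        "token_indices": " ".join(map(str, indices)),
        "label": label,
    })

# Metadata CSV in the spirit of train_embeddings.csv (column names are assumed)
with open("datasets/train_embeddings.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["file", "token_indices", "label"])
    writer.writeheader()
    writer.writerows(rows)
```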
python train.py --model {model_name}

where `model_name` ∈ {1, 2, 3, 4}.
python test.py --model {model_name}

To compute Token Importance Scores:

python TIS_computation.py --model {model_name} --data_type {data_type} --batch_size {batch_size} --gpu {gpu_number}

- `model_name`: 1, 2, 3, or 4
- `data_type`: `train` for `eval/sample_concept_level_explanation.py`; `test` for `eval/sample_image_level_explanation.py`
- `batch_size`: integer batch size
- `gpu_number`: GPU device index
Example:
python TIS_computation.py --model 1 --data_type train --batch_size 25 --gpu 1
⚠️ This process may take considerable time depending on the dataset size and GPU.
To generate the frequency-based baseline:

python generate_freq_based_tokens.py

To run the TIS analysis for concept-level explanations:

python TIS_analysis.py --model {model_name}

To produce sample-level explanations:

cd eval
python sample_image_level_explanation.py --model {model_name}
python sample_concept_level_explanation.py --model {model_name} --top_n {top_n_value} --token_num {token_num}

- `top_n`: select the top-n tokens per image
- `token_num`: number of tokens to use
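As a purely illustrative toy of what concept-level token ranking looks like (not the paper's actual TIS formula, which is implemented in TIS_computation.py), tokens can be ranked by their mean importance across samples:

```python
import numpy as np

# Toy illustration only: rank tokens by mean per-sample importance score.
scores = np.array([
    [0.1, 0.9, 0.3],   # importance of three tokens for sample 1
    [0.2, 0.8, 0.4],   # importance of the same tokens for sample 2
])
mean_tis = scores.mean(axis=0)            # average importance per token
top_tokens = np.argsort(mean_tis)[::-1]   # most important tokens first
print(top_tokens.tolist())  # → [1, 2, 0]
```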
Replace the line inside codebook_level_explanation.py:
VQGAN_directory = {Your VQGAN directory}

with your actual VQGAN repo path.
Run:
python codebook_level_explanation.py --model {model_name} --steps {optimization_steps} --lr {learning_rate} --optimization_type {token_selection or embedding}

Example:
python codebook_level_explanation.py --model 1 --optimization_type token

cd Dalle_explanation

This subfolder contains the implementation of CORTEX to explain the DALL·E-mini model.
CORTEX/Dalle_explanation/
├── checkpoints/ # Model checkpoints (download required)
├── datasets/ # Datasets (download required)
├── bias_detection.py # Bias detection using TIS
├── dataset.py # Dataset loader
├── model.py # IEM architecture
├── test.py # Evaluation script
├── train.py # Training script for IEM
├── TIS_computation.py # Token Importance Score computation
└── TIS_analysis.py # TIS analysis
- Download the datasets generated by DALL·E-mini and replace the `datasets` directory.
  📥 Dataset Download: Datasets
- Download the pre-trained checkpoints and place them in the `checkpoints` directory.
  📥 Checkpoints Download: Checkpoints
⚠️ Note: In this experiment, we only pretrained the CNN-based model; you can also train IEMs with other architectures.
You can train an Interpretable Explanation Model (IEM) on DALL·E-mini embeddings using:
python train.py --model 1

python test.py --model 1 --bias_type doctor_color  # or doctor_gender

python TIS_computation.py --model 1 --bias_type doctor_color  # or doctor_gender

python TIS_analysis.py --model 1 --bias_type doctor_color  # or doctor_gender

python bias_detection.py --model 1 --bias_type doctor_color --top_n {top_n_value} --token_num {token_num_value}  # or use doctor_gender

This project is licensed under the Apache License 2.0.
You may use, modify, and distribute this code under the terms of the license.
For full license details, please refer to the LICENSE file included in the repository.
