CatFlow is a flow matching framework for de novo generation and structure prediction of heterogeneous catalysts. It jointly generates slab structures and adsorbate coordinates within a unified objective, directly capturing surface-adsorbate interactions.
- Co-generation of slab-adsorbate systems: First framework to jointly generate slab structures and adsorbate coordinates via flow matching.
- Factorized representation: Decomposes slab-adsorbate systems into primitive cells, transformation matrices, vacuum scaling factors, and adsorbates, reducing learnable variables by 9.2× on average.
- Two task modes: Supports both de novo generation (with discrete flow matching for atomic species) and structure prediction (fixed composition).
- Adsorbate conditioning: Generates catalyst structures conditioned on target adsorbate species.
conda env create -f environment.yml
conda activate catflow

The main dependencies include:
- PyTorch >= 2.8
- PyTorch Lightning
- PyMatGen
- ASE
- fairchem-core
CatFlow uses the Open Catalyst 2020 (OC20) IS2RES dataset. The data processing pipeline transforms raw OC20 structures into a factorized representation.
The processed datasets with generated samples and computed energy values are available at Zenodo.
To perform the data processing pipeline manually from raw OC20 structures, follow the instructions below.
# Download and extract OC20 IS2RES set
wget https://dl.fbaipublicfiles.com/opencatalystproject/data/is2res_train_val_test_lmdbs.tar.gz
tar -zxvf is2res_train_val_test_lmdbs.tar.gz
# Download OC20 data mapping file
wget https://dl.fbaipublicfiles.com/opencatalystproject/data/oc20_data_mapping.pkl
# Extract metadata and process into factorized representation
python scripts/data_processing/extract_metadata.py
python scripts/processing.py

The factorized representation consists of four components:
- Primitive cell $(S_{\text{prim}})$: The repeating unit containing atomic species, coordinates, and lattice matrix
- Transformation matrix $(M \in \mathbb{Z}^{3 \times 3})$: Specifies how to construct the slab from the primitive cell
- Vacuum scaling factor $(k_{\text{vac}})$: Determines the vacuum region height
- Adsorbate: Atomic species (condition) and Cartesian coordinates (learnable)
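
To illustrate how the first three components could combine into a slab cell, here is a minimal NumPy sketch. It assumes the row-vector lattice convention (each row is a lattice vector) and that the vacuum scaling factor simply stretches the third lattice vector. The function name `build_slab_lattice` and the example values are hypothetical and are not taken from the CatFlow code.

```python
import numpy as np

def build_slab_lattice(prim_lattice, M, k_vac):
    """Sketch: slab lattice from a primitive lattice, an integer matrix M, and k_vac."""
    prim_lattice = np.asarray(prim_lattice, dtype=float)
    M = np.asarray(M, dtype=int)
    # Supercell lattice: integer combinations of the primitive lattice vectors.
    slab = M @ prim_lattice
    # Assumption: k_vac stretches the third (surface-normal) vector to add vacuum.
    slab[2] *= k_vac
    return slab

# Hypothetical primitive cell (in angstrom) and transformation matrix.
prim = np.array([[2.5, 0.0, 0.0],
                 [0.0, 2.5, 0.0],
                 [0.0, 0.0, 4.0]])
M = np.array([[2, 0, 0],
              [0, 2, 0],
              [0, 0, 3]])

slab = build_slab_lattice(prim, M, k_vac=2.0)
# |det M| is the number of primitive cells contained in the supercell.
n_copies = abs(round(np.linalg.det(M)))
```

Because the supercell holds `|det M|` copies of the primitive cell, modelling the primitive cell plus `M` and `k_vac` involves far fewer coordinates than modelling every slab atom directly, which is the motivation for the factorization.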
Train the model with discrete flow matching for atomic species:
bash bash_scripts/train_gen.sh

Key configurations:
- model.flow_model_args.dng=true: Enable discrete flow matching for de novo generation
- model.training_args.flow_loss_type=x1_loss: Use x1 prediction loss
- model.training_args.loss_type=l1: L1 loss for geometric variables
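
For readers unfamiliar with the x1 prediction loss, the following is a rough NumPy sketch of one training-loss evaluation for a continuous variable. It assumes a Gaussian prior and a linear interpolation path between noise and data, which are common choices but are not confirmed by this README. `x1_prediction_l1_loss` and `predict_x1` are hypothetical names, and the real model is a neural network rather than a plain callable.

```python
import numpy as np

def x1_prediction_l1_loss(predict_x1, x1, rng):
    """Sketch of an x1-prediction flow matching loss with an L1 penalty."""
    # Sample noise x0 and one time t in [0, 1) per example.
    x0 = rng.standard_normal(x1.shape)
    t = rng.uniform(size=(x1.shape[0],) + (1,) * (x1.ndim - 1))
    # Assumption: linear path from noise (t=0) to data (t=1).
    x_t = (1.0 - t) * x0 + t * x1
    # The model predicts the clean sample x1 directly from (x_t, t).
    x1_hat = predict_x1(x_t, t)
    # L1 loss between the prediction and the true data point.
    return float(np.mean(np.abs(x1_hat - x1)))

rng = np.random.default_rng(0)
x1 = rng.standard_normal((4, 3))  # e.g. 4 atoms with 3 coordinates each

# A perfect predictor gives zero loss; returning the noisy input does not.
loss_oracle = x1_prediction_l1_loss(lambda x_t, t: x1, x1, rng)
loss_identity = x1_prediction_l1_loss(lambda x_t, t: x_t, x1, rng)
```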
Train the model with fixed atomic species:
bash bash_scripts/train_pred.sh

Key configurations:
- model.flow_model_args.dng=false: Disable discrete flow matching (species are given)
| Argument | Description | Default |
|---|---|---|
| train.pl_trainer.devices | Number of GPUs | 8 |
| model.training_args.lr | Learning rate | 1e-4 |
| model.training_args.warmup_steps | Warmup steps | 5000 |
| data.batch_size.train | Batch size per device | 128 |
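
As a quick sanity check on these defaults, the sketch below computes the global batch size and a warmup schedule. It assumes no gradient accumulation and a linear learning-rate ramp. Neither is stated in this README, so treat both helper functions as hypothetical illustrations rather than the repository's behaviour.

```python
def effective_batch_size(devices=8, per_device_batch=128):
    """Global batch size, assuming no gradient accumulation."""
    return devices * per_device_batch

def warmup_lr(step, base_lr=1e-4, warmup_steps=5000):
    """Learning rate at a given step, assuming a linear warmup ramp."""
    return base_lr * min(1.0, step / warmup_steps)

global_batch = effective_batch_size()   # 8 GPUs x 128 samples each
lr_halfway = warmup_lr(2500)            # halfway through warmup
lr_after = warmup_lr(10000)             # after warmup has finished
```

When training on fewer GPUs, the global batch size shrinks proportionally, so the learning rate or the per-device batch size may need adjusting.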
bash bash_scripts/sample_all_subsets.sh

This script runs parallel generation across multiple GPUs. Key parameters:
CHECKPOINT_PATH="path/to/checkpoint.ckpt"
DATA_ROOT="path/to/dataset"
BASE_OUTPUT_DIR="path/to/outputs"

For single dataset sampling:
python scripts/sampling/save_samples.py \
--checkpoint $CHECKPOINT_PATH \
--val_lmdb_path $LMDB_PATH \
--output_dir $OUTPUT_DIR \
--num_samples 1 \
--sampling_steps 50 \
--batch_size 128

Evaluate the adsorption energy of generated structures:
bash bash_scripts/eval_E_ads.sh

This script:
- Relaxes generated structures using pretrained GNN potentials (UMA)
- Computes adsorption energies: $\Delta E_{\text{ads}} = E_{\text{sys}} - E_{\text{slab}} - E_{\text{ads}}$
- Aggregates results across all adsorbate subsets
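
The adsorption energy formula above can be written as a one-line helper. The function name and the energy values below are hypothetical and serve only to show the sign convention. In the actual pipeline the three energies come from the relaxed structures.

```python
def adsorption_energy(e_sys, e_slab, e_ads):
    """Delta E_ads = E_sys - E_slab - E_ads, with all energies in the same unit."""
    return e_sys - e_slab - e_ads

# Hypothetical energies in eV: relaxed slab+adsorbate, clean slab, isolated adsorbate.
delta_e = adsorption_energy(e_sys=-105.3, e_slab=-100.0, e_ads=-4.8)
# A negative value means the combined system is lower in energy than its parts.
```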
| Metric | Description |
|---|---|
| Validity | Interatomic distances > 0.5 Å, cell volume > 0.1 ų |
| Uniqueness | Non-duplicate slab structures (via StructureMatcher) |
| Compositional diversity | Mean pairwise distance of compositional fingerprints |
| Match rate | Structural match with ground truth (structure prediction) |
| RMSD | Root mean square deviation from ground truth |
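
The validity criterion in the table can be sketched as follows. This version assumes distances are measured under periodic boundary conditions and only checks the 27 neighbouring cell images, which is adequate for cells that are not strongly skewed. The function `is_valid` is a hypothetical illustration, not the repository's evaluation code.

```python
import itertools
import numpy as np

def is_valid(frac_coords, lattice, min_dist=0.5, min_volume=0.1):
    """Sketch: all interatomic distances > min_dist (A) and cell volume > min_volume (A^3)."""
    lat = np.asarray(lattice, dtype=float)          # rows are lattice vectors
    frac = np.asarray(frac_coords, dtype=float)
    frac = frac - np.floor(frac)                    # wrap into the unit cell
    if abs(np.linalg.det(lat)) <= min_volume:
        return False
    # All 27 combinations of -1/0/+1 shifts along the three lattice vectors.
    shifts = np.array(list(itertools.product((-1, 0, 1), repeat=3)), dtype=float)
    # delta[i, j, s] = frac[j] + shifts[s] - frac[i], in fractional coordinates.
    delta = frac[None, :, None, :] + shifts[None, None, :, :] - frac[:, None, None, :]
    dist = np.linalg.norm(delta @ lat, axis=-1)     # Cartesian distances, shape (n, n, 27)
    # Ignore the zero distance of each atom to itself in the home cell.
    n = len(frac)
    zero_shift = int(np.flatnonzero(np.all(shifts == 0, axis=1))[0])
    dist[np.arange(n), np.arange(n), zero_shift] = np.inf
    return bool(dist.min() > min_dist)

cubic = 4.0 * np.eye(3)
ok = is_valid([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], cubic)
too_close = is_valid([[0.0, 0.0, 0.0], [0.05, 0.0, 0.0]], cubic)
across_boundary = is_valid([[0.01, 0.0, 0.0], [0.99, 0.0, 0.0]], cubic)
```

The third example shows why the periodic images matter: the two atoms look far apart inside the cell but are only 0.08 Å apart across the cell boundary.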
CatFlow/
├── configs/                        # Hydra configuration files
├── bash_scripts/                   # Training and evaluation scripts
├── scripts/
│   ├── sampling/                   # Sampling utilities
│   └── relax_energy/               # Energy evaluation scripts
├── src/
│   ├── data/                       # Data loading and processing
│   ├── models/
│   │   ├── flow.py                 # Flow matching model
│   │   ├── layers.py               # Encoder, decoder, transformer layers
│   │   └── transformers.py         # DiT blocks and embedders
│   └── module/
│       └── effcat_module.py        # PyTorch Lightning module
├── environment.yml
└── primitive_atom_distribution.json

