Official PyTorch Implementation of T2W Weight Generation Framework
T2W leverages Diffusion Transformer (DiT) to directly synthesize neural network weights from natural language descriptions.
Paper: Text2Weight: Bridging Natural Language and Neural Network Weight Spaces (ACM MM 2025)
T2W presents a novel paradigm for neural network weight generation using text-conditioned diffusion models. Given a textual description of the target task, our framework generates functional network weights that can be directly deployed for downstream classification tasks.
- Text-Conditioned Weight Synthesis: Generate task-specific neural network weights from natural language descriptions
- Diffusion Transformer Architecture: Employ DiT as the generative backbone for high-quality weight generation
- LoRA-Style Classifier Head: Generate weights for a lightweight two-layer MLP adapter based on LoRA principles
- Multi-Dataset Support: Support for CIFAR-100, Caltech-256, and ImageNet benchmarks
Text Description → CLIP Encoder → Condition Features → DiT → Weight Prediction → LoRA-MLP Classifier → Downstream Evaluation
```
T2W/
├── T2W/                          # Core framework module
│   ├── diffusion/                # Diffusion process implementation
│   │   ├── gaussian_diffusion.py # Gaussian diffusion utilities
│   │   ├── respace.py            # Timestep respacing
│   │   └── timestep_sampler.py   # Timestep sampling strategies
│   ├── models/
│   │   └── transformer.py        # DiT model architecture
│   ├── distributed.py            # Multi-GPU training utilities
│   ├── meters.py                 # Training/testing meters
│   └── utils.py                  # General utilities
├── data_gen/                     # Dataset generation scripts
│   ├── cifar100/                 # CIFAR-100 data generation
│   │   ├── generate_config.py    # Step 1: Generate task configurations
│   │   ├── generate_dataset.py   # Step 2: Train adapters and generate weights
│   │   ├── dataset_config.py     # Step 3: Build dataset with CLIP features
│   │   ├── normalization.py      # Weight normalization utilities
│   │   └── p_dataset.py          # PyTorch dataset for training
│   ├── caltech256/               # Caltech-256 data generation
│   └── imagenet/                 # ImageNet data generation
├── configs/                      # Hydra configuration files
│   ├── train/                    # Training configurations
│   └── test/                     # Testing configurations
├── main.py                       # Main training script (CIFAR-100)
├── eval.py                       # Evaluation/inference script (CIFAR-100)
├── vis.py                        # Visualization utilities
└── LICENSE
```
```bash
# Create conda environment
conda create -n t2w python=3.10
conda activate t2w

# Install PyTorch (adjust CUDA version as needed)
pip install "torch>=2.0.0" "torchvision>=0.15.0" --index-url https://download.pytorch.org/whl/cu118

# Install dependencies
pip install clip
pip install torchmetrics
pip install tqdm
pip install numpy
```

Hardware requirements:

- GPU: NVIDIA GPU with at least 24GB VRAM (recommended: A100, RTX 4090)
- RAM: 32GB+ system memory
- Storage: 50GB+ for datasets and checkpoints
The data generation pipeline creates (text, weight) pairs for training T2W. Each sample consists of:
- Text Features: CLIP-encoded features of the task description (512-dim); see the sketch after this list
- Weight: Pre-trained LoRA-style MLP classifier weights for the corresponding task
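For reference, a minimal sketch of producing such 512-dim text features, assuming the OpenAI CLIP package with a ViT-B/32 backbone (the exact prompt template used by `dataset_config.py` may differ):

```python
import torch
import clip

# ViT-B/32 yields 512-dim text embeddings, matching the feature size above.
model, _ = clip.load("ViT-B/32", device="cpu")
tokens = clip.tokenize(["a classification task over: dog, cat, bird"])
with torch.no_grad():
    text_features = model.encode_text(tokens)  # shape: (1, 512)
```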
The data generation process involves three sequential steps:
| Step | Script | Description |
|---|---|---|
| 1 | `generate_config.py` | Generate random task configurations (class combinations) |
| 2 | `generate_dataset.py` | Train adapter weights for each task |
| 3 | `dataset_config.py` | Build final dataset with CLIP text features |
This script generates random combinations of classes to create diverse training tasks.
```bash
cd data_gen/cifar100

# Generate task configurations
python generate_config.py
```

Output: `data_point.json`, which contains task configurations with randomly selected class subsets.
Configuration options (modify in the script; the sketch below illustrates this step):

- `data_point`: Number of tasks to generate (default: 12000)
- `class_num_range`: Range of classes per task (default: 8-32)
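For illustration, a minimal sketch of what this step amounts to, using the defaults above (placeholder class names; the script's actual sampling logic may differ):

```python
import json
import random

# Illustrative reimplementation of Step 1: sample random class subsets.
CIFAR100_CLASSES = [f"class_{i}" for i in range(100)]  # placeholder names
tasks = []
for idx in range(12000):                    # data_point
    k = random.randint(8, 32)               # class_num_range
    tasks.append({
        "index": idx,
        "selected_classes": random.sample(CIFAR100_CLASSES, k),
    })
with open("data_point.json", "w") as f:
    json.dump(tasks, f)
```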
This script trains LoRA-style MLP adapters for each task configuration.
```bash
# First, train the base model on all classes.
# Modify the script to call train_base() first (uncomment in __main__).
python generate_dataset.py 0 1 0  # args: begin, end, gpu_id
# This generates: checkpoints/base.pt
```

After obtaining `base.pt`, train task-specific adapters:

```bash
# Train task adapters (can be parallelized across GPUs)
python generate_dataset.py 0 3000 0      # GPU 0: tasks 0-3000
python generate_dataset.py 3000 6000 1   # GPU 1: tasks 3000-6000
python generate_dataset.py 6000 9000 2   # GPU 2: tasks 6000-9000
python generate_dataset.py 9000 12000 3  # GPU 3: tasks 9000-12000
```

Arguments:

- `begin`: Starting task index
- `end`: Ending task index
- `gpu_id`: CUDA device ID

Output: `checkpoints/{task_id}/task.pt` - Adapter weights for each task
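For intuition, a minimal sketch of what such a LoRA-style two-layer adapter could look like (illustrative only; the residual connection is an assumption, and the repo's actual module lives in `generate_dataset.py`):

```python
import torch.nn as nn

class LoRAAdapter(nn.Module):
    """Two-layer bottleneck MLP over frozen 512-dim backbone features."""
    def __init__(self, feat_dim: int = 512, hidden_dim: int = 16):
        super().__init__()
        self.linear1 = nn.Linear(feat_dim, hidden_dim)  # down-projection: (hidden_dim x 512) weight
        self.linear2 = nn.Linear(hidden_dim, feat_dim)  # up-projection: (512 x hidden_dim) weight

    def forward(self, x):
        # Residual connection (an assumption here) keeps the adapter a small,
        # LoRA-style perturbation of the frozen backbone's features.
        return x + self.linear2(self.linear1(x))
```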
This script processes the trained weights and generates CLIP text features for conditioning.
```bash
python dataset_config.py
```

Output:

- `train.json`: Training dataset (80% of samples)
- `val.json`: Validation dataset (20% of samples)
Each JSON entry has the following structure:
```json
{
  "index": 0,
  "selected_classes": ["dog", "cat", "bird", ...],
  "text_features": [0.023, -0.156, ...],
  "path": "checkpoints/0/task.pt"
}
```
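A hedged sketch of consuming one such entry; the flattening order below is illustrative and must match whatever `p_dataset.py` actually does:

```python
import json
import torch

# Load one training pair: CLIP text features plus flattened adapter weights.
with open("train.json") as f:
    entry = json.load(f)[0]

condition = torch.tensor(entry["text_features"])            # (512,)
state_dict = torch.load(entry["path"], map_location="cpu")  # adapter checkpoint
flat_weights = torch.cat([v.flatten() for v in state_dict.values()])
```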
### CIFAR-100

```bash
cd data_gen/cifar100

# Step 1: Generate configurations
python generate_config.py

# Step 2a: Train base model
# Edit generate_dataset.py: uncomment train_base() in __main__
python generate_dataset.py 0 1 0

# Step 2b: Train task adapters
# Edit generate_dataset.py: uncomment train_task() in __main__
python generate_dataset.py 0 12000 0

# Step 3: Build dataset
python dataset_config.py
```

### Caltech-256
```bash
cd data_gen/caltech256

# Download the Caltech-256 dataset first
# Extract to: ./data/caltech256/256_ObjectCategories/

# Then follow the same steps as CIFAR-100
python generate_config.py
python generate_dataset.py 0 1 0      # Train base
python generate_dataset.py 0 12000 0  # Train tasks
python dataset_config.py
```

### ImageNet
```bash
cd data_gen/imagenet

# Ensure ImageNet is downloaded and extracted
# Update the data path in generate_dataset.py

# Follow the same steps as CIFAR-100
python generate_config.py
python generate_dataset.py 0 1 0      # Train base
python generate_dataset.py 0 12000 0  # Train tasks
python dataset_config.py
```

Note: The current `main.py` is configured for the CIFAR-100 dataset. To train on other datasets (Caltech-256, ImageNet), modify the dataset setup in `main.py` accordingly.
```bash
python main.py --config-name cifar100_classifier
```

Modify `configs/train/cifar100_classifier.yaml` to adjust training parameters:
```yaml
# Basic settings
amp: false                      # Automatic Mixed Precision
num_gpus: 1                     # Number of GPUs
exp_name: cifar100_classifier   # Experiment name
resume_path: null               # Path to resume from a checkpoint

# Dataset settings
dataset:
  normalizer: openai            # Weight normalization strategy
  openai_coeff: 1.646           # Normalization coefficient
  num_workers: 8                # Data loading workers

# Transformer architecture
transformer:
  ema: true                     # Use Exponential Moving Average
  predict_xstart: true          # Predict x_start directly
  chunk_size: 576               # Token chunk size
  n_embd: 2048                  # Embedding dimension
  n_layer: 12                   # Number of transformer layers
  n_head: 16                    # Number of attention heads

# Training hyperparameters
train:
  base_lr: 4e-4                 # Base learning rate
  num_ep: 15000                 # Number of epochs
  mb_size: 550                  # Batch size
  ema_decay: 0.9999             # EMA decay rate
  grad_clip: 0.1                # Gradient clipping
```

Checkpoints are saved to `results/{exp_name}/checkpoints/`:

- `best.pt`: Best-performing model
- `{epoch:04d}.pt`: Periodic checkpoints
Note: The current `eval.py` is configured for the CIFAR-100 dataset. To evaluate on other datasets, modify the dataset paths and configurations in `eval.py` accordingly.
Use `eval.py` to generate weights and evaluate them on downstream classification tasks:

```bash
python eval.py --config-name cifar100_classifier
```

Before running evaluation, modify the configuration variables at the top of `eval.py`:
```python
# Configuration: modify these paths according to your setup
DATA_ROOT = "./data/cifar100"
CHECKPOINT_PATH = "./results/cifar100_classifier/checkpoints/best.pt"
```

The evaluation script performs the following steps:
1. Load T2W Model: Load the trained T2W checkpoint
2. Sample Weights: Generate neural network weights from text conditions using the diffusion process
3. Build Classifiers: Convert generated weights to functional classifier networks (sketched below)
4. Evaluate: Test the generated classifiers on downstream CIFAR-100 classification tasks
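A hedged sketch of step 3, unflattening one generated vector into the adapter's `state_dict` (tensor names and shapes follow the weight-space layout described under Architecture; the repo's actual reconstruction may differ):

```python
from math import prod
import torch

def unflatten_to_state_dict(flat, hidden_dim=16, feat_dim=512):
    """Split a flat generated vector into the four adapter tensors."""
    shapes = {
        "linear1.weight": (hidden_dim, feat_dim),
        "linear1.bias":   (hidden_dim,),
        "linear2.weight": (feat_dim, hidden_dim),
        "linear2.bias":   (feat_dim,),
    }
    state_dict, offset = {}, 0
    for name, shape in shapes.items():
        n = prod(shape)
        state_dict[name] = flat[offset:offset + n].reshape(shape)
        offset += n
    return state_dict

generated = torch.randn(16912)  # illustrative size; the repo's vectors may be padded
sd = unflatten_to_state_dict(generated)
```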
The framework reports the following metrics (a minimal computation sketch follows the list):
- Classification Accuracy: Top-1 accuracy on the test set
- MSE Loss: Mean squared error between generated and ground-truth weights
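A minimal illustration of computing these two metrics with `torchmetrics` and dummy tensors (the class count and shapes are placeholders):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

# Dummy tensors; the real script evaluates generated classifiers on CIFAR-100 tasks.
logits = torch.randn(32, 10)                  # predictions for a 10-class task
labels = torch.randint(0, 10, (32,))
top1 = MulticlassAccuracy(num_classes=10, top_k=1, average="micro")(logits, labels)

gen_w = torch.randn(32, 17424)                # generated weight vectors
gt_w = torch.randn(32, 17424)                 # ground-truth adapter weights
mse = torch.nn.functional.mse_loss(gen_w, gt_w)
```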
Example output:

```
Using device: cuda:0
Now Sample: torch.Size([32, 17424])
ADAPTED Accuracy: 78.45%
```
Evaluation results are saved to:

- `cifar100_eval_results.json`: Contains accuracy and loss for each batch
The core generative model uses a Transformer architecture adapted for diffusion:
```
Input Parameters (Noised) → Token Encoding → Positional Embedding
                                  ↓
                 Multi-Head Self-Attention (×N layers)
                                  ↓
                  Token Decoding → Output Parameters
```
Key components:
- Parameter Encoder: Projects weight chunks into transformer token embeddings (sketched below)
- Frequency Embedder: Encodes diffusion timesteps using sinusoidal features
- CLIP Condition: Integrates text features as additional conditioning tokens
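For intuition, a hedged sketch of the parameter-chunking idea, using `chunk_size: 576` and `n_embd: 2048` from the training config (the actual encoder in `transformer.py` may differ):

```python
import torch
import torch.nn as nn

chunk_size, n_embd = 576, 2048
flat = torch.randn(1, 17424)                    # one (noised) flat weight vector

# Pad to a multiple of chunk_size, then split into per-chunk tokens.
pad = (-flat.shape[1]) % chunk_size
flat = torch.nn.functional.pad(flat, (0, pad))  # (1, 17856)
tokens = flat.view(1, -1, chunk_size)           # (1, 31, 576)

encoder = nn.Linear(chunk_size, n_embd)         # parameter encoder (one option)
token_embeddings = encoder(tokens)              # (1, 31, 2048)
```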
Generated weights correspond to a two-layer MLP adapter:
- `linear1.weight`: (hidden_dim × 512) - First-layer weights
- `linear1.bias`: (hidden_dim,) - First-layer bias
- `linear2.weight`: (512 × hidden_dim) - Second-layer weights
- `linear2.bias`: (512,) - Second-layer bias

Total parameters: ~17K with hidden_dim=16 (verified in the snippet below)
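A quick check of that count, assuming the shapes listed above:

```python
hidden_dim, feat_dim = 16, 512
n_params = (hidden_dim * feat_dim + hidden_dim   # linear1.weight + linear1.bias
            + feat_dim * hidden_dim + feat_dim)  # linear2.weight + linear2.bias
print(n_params)  # 16912, i.e. ~17K
```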
If you find this work useful, please consider citing:
```bibtex
@inproceedings{tian2025text2weight,
  title={Text2Weight: Bridging Natural Language and Neural Network Weight Spaces},
  author={Tian, Bowen and Chen, Wenshuo and Li, Zexi and Lai, Songning and Wu, Jiemin and Yue, Yutao},
  booktitle={Proceedings of the 33rd ACM International Conference on Multimedia},
  pages={10152--10160},
  year={2025}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.
This codebase builds upon several excellent open-source projects.
For questions and issues, please open an issue on this repository or contact the authors directly.

