Official implementation of "Scaling World-Model Reinforcement Learning Through Diffusion Policy Optimization" by
Xiaoyuan Cheng* (ucesxc4@ucl.ac.uk), Wenxuan Yuan* (YUAN0186@e.ntu.edu.sg), Zhancun Mu, Yuanzhao Zhang, Yiming Yang, Hai Wang, Zhuo Sun†, Che Liu†
MBDPO is a model-based reinforcement learning framework that unifies search and policy optimization through a diffusion policy representation inside a learned latent world model. Instead of building an explicit planner (e.g. MPPI) on top of the world model, MBDPO reformulates policy optimization as a diffusion process over imagined trajectories, where the score field is corrected by model-based returns and anchored to the behavior distribution via an implicit energy function. This eliminates the structural misalignment between search and value learning that limits prior world-model approaches, and yields monotonic scaling of performance with model capacity.
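To make the idea of a return-corrected score field concrete, here is a toy one-dimensional sketch, not the MBDPO algorithm itself. It assumes a Gaussian behavior distribution, a quadratic stand-in for the model-based return, and plain Langevin dynamics as the sampler. The names `behavior_score`, `return_gradient`, `sample_actions`, and all constants are hypothetical and chosen only for illustration.

```python
import random


def behavior_score(a, mu=0.0, sigma=1.0):
    """Score (gradient of the log-density) of a toy Gaussian behavior distribution."""
    return -(a - mu) / sigma**2


def return_gradient(a, a_star=2.0):
    """Gradient of a toy return Q(a) = -0.5 * (a - a_star)^2 (stand-in for model-based returns)."""
    return -(a - a_star)


def sample_actions(n_samples=2000, n_steps=200, step_size=0.05, temperature=1.0, seed=0):
    """Langevin sampling with the behavior score corrected by the return gradient."""
    rng = random.Random(seed)
    actions = []
    for _ in range(n_samples):
        a = rng.gauss(0.0, 1.0)  # start from noise
        for _ in range(n_steps):
            # Guided score: stay close to the behavior distribution, drift toward high return.
            score = behavior_score(a) + return_gradient(a) / temperature
            a += step_size * score + (2.0 * step_size) ** 0.5 * rng.gauss(0.0, 1.0)
        actions.append(a)
    return actions


actions = sample_actions()
mean_action = sum(actions) / len(actions)
print(f"mean sampled action: {mean_action:.2f}")
```

Under these assumptions the samples concentrate between the behavior mean (0.0) and the return maximizer (2.0), around 1.0, which shows how the return term shifts the policy while the behavior score keeps it anchored.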
The repository contains code for training and evaluating MBDPO across 121 continuous control tasks in three settings: online from scratch, multi-task offline pretraining, and offline-to-online (O2O) fine-tuning.
MBDPO learns structured latent trajectories across locomotion and manipulation tasks. Cyclic behaviors form closed-loop patterns, while goal-directed tasks produce smooth trajectories toward successful completion.
DMControl
- Cheetah Run Front: reward 740.7
- Cup Spin: reward 840.4
- Reacher Hard: reward 985.0
- Walker Run: reward 769.2
MetaWorld
- Bin Picking: reward 1585.0, success rate 1.00
- Disassemble: reward 1556.2, success rate 1.00
- Door Close: reward 1549.1, success rate 1.00
- Lever Pull: reward 1664.9, success rate 0.90
We provide ready-to-use Conda environment files for different experiment suites.
# Example: create environment for MT80 experiments
conda env create -f conda_envs/mbdpo-mt80.yml
conda activate mbdpo-mt80
# Optional: other provided environments
# conda env create -f conda_envs/mbdpo-ms2.yml
# conda env create -f conda_envs/mbdpo-myo.yml
See the notes for each environment in this link.
For multi-task offline pretraining, we use the replay buffers from the open-source TD-MPC2 dataset (mt30 and mt80).
To download (remember to adjust the dataset path accordingly in configuration yaml files):
- mt30:
mkdir -p ./offline_dataset/mt30
seq 0 3 | xargs -I {} -P 4 wget -c \
-O ./offline_dataset/mt30/chunk_{}.pt \
"https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_{}.pt?download=true"
- mt80:
mkdir -p ./offline_dataset/mt80
seq 0 19 | xargs -I {} -P 4 wget -c \
-O ./offline_dataset/mt80/chunk_{}.pt \
"https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_{}.pt?download=true"
This codebase supports all 121 continuous control tasks used in our technical report: DMControl (39 tasks), MetaWorld (50 tasks), ManiSkill2 (5 tasks), MyoSuite (10 tasks), Locomotion (7 tasks), and Visual RL (10 tasks). In the DMControl domain, we include the 11 custom tasks, following the setup of TD-MPC2.
See this link for the detailed task list and notes for each domain.
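Before launching offline pretraining, it can help to confirm that every dataset chunk from the download step is present. The sketch below assumes the `./offline_dataset/<split>/chunk_<i>.pt` layout used by the download commands, with 4 chunks for mt30 and 20 for mt80 (the ranges passed to `seq`). The helper `missing_chunks` is hypothetical and not part of the repository.

```python
from pathlib import Path

# Chunk counts taken from the download commands (seq 0 3 and seq 0 19).
EXPECTED_CHUNKS = {"mt30": 4, "mt80": 20}


def missing_chunks(dataset_root, split):
    """Return the names of expected chunk files that are absent for a split."""
    split_dir = Path(dataset_root) / split
    expected = [f"chunk_{i}.pt" for i in range(EXPECTED_CHUNKS[split])]
    return [name for name in expected if not (split_dir / name).is_file()]


for split in EXPECTED_CHUNKS:
    missing = missing_chunks("./offline_dataset", split)
    status = "complete" if not missing else f"missing {len(missing)} chunk(s)"
    print(f"{split}: {status}")
```

If the dataset lives elsewhere, pass that root instead of `./offline_dataset` and keep it consistent with the path set in the configuration yaml files.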
python scripts/train.py task=dog-run seed=1 steps=4000000
or with the parallel launcher:
python scripts/online_parallel_train.py --config cfgs/online_parallel_config.yaml
The default training entry point remains the Torch implementation. The JAX
implementation is namespaced under MBDPO/jax_impl/ and uses a separate entry point:
# JAX single-task online
python scripts/train_jax.py task=cheetah-run steps=100000 model_size=1
# JAX offline from NumPy .npz chunks
python scripts/train_jax.py mode=offline task=mt30 data_dir=/path/to/mt30_npz_chunks
# JAX data-parallel updates
python scripts/train_jax.py mode=offline task=mt30 data_dir=/path/to/mt30_npz_chunks \
batch_size=4096 jax_data_parallel_devices=4
The JAX backend shares environment construction and common config/logger/task
metadata with the Torch path, but keeps JAX-specific model, math, replay buffer,
diffusion planner, and data-parallel update code under MBDPO/jax_impl/. It has not
been validated on every paper task.
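The data-parallel command above passes `batch_size=4096` together with `jax_data_parallel_devices=4`. Assuming `batch_size` is the global batch and is split evenly across devices (an assumption, not something this README confirms), each device would process 1024 samples per update. A minimal pure-Python sketch of that split, using the hypothetical helper `shard_batch`:

```python
def shard_batch(batch, n_devices):
    """Split a global batch into equally sized per-device shards."""
    if len(batch) % n_devices != 0:
        raise ValueError("global batch size must be divisible by the device count")
    per_device = len(batch) // n_devices
    return [batch[i * per_device:(i + 1) * per_device] for i in range(n_devices)]


# Sample indices stand in for real transitions.
global_batch = list(range(4096))
shards = shard_batch(global_batch, n_devices=4)
print(len(shards), len(shards[0]))
```

Under this assumption, choosing a batch size that is a multiple of the device count avoids uneven shards.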
python scripts/train.py task=mt80 multitask=true
# or
python scripts/train.py task=mt30 multitask=true
python scripts/offline_to_online.py \
checkpoint=/path/to/checkpoint.pt \
save_path=/path/to/output_dir \
off2on_task="walker-run" \
steps=40000
python scripts/evaluate.py \
task=mt80 \
checkpoint=/path/to/checkpoint.pt \
eval_episodes=10
For parameter usage, please refer to this description.
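When scripting several runs, the `key=value` arguments shown above can be assembled programmatically. The sketch below only builds and prints the command lines; the helper `build_command` is hypothetical and not part of the repository, and the script names and argument keys are the ones used in this README.

```python
import shlex


def build_command(script, **overrides):
    """Build an argv list of the form: python <script> key=value ..."""
    args = ["python", script]
    for key, value in overrides.items():
        if isinstance(value, bool):
            # The README writes booleans in lowercase (e.g. multitask=true).
            value = "true" if value else "false"
        args.append(f"{key}={value}")
    return args


finetune = build_command(
    "scripts/offline_to_online.py",
    checkpoint="/path/to/checkpoint.pt",
    save_path="/path/to/output_dir",
    off2on_task="walker-run",
    steps=40000,
)
evaluate = build_command(
    "scripts/evaluate.py",
    task="mt80",
    checkpoint="/path/to/checkpoint.pt",
    eval_episodes=10,
)
for cmd in (finetune, evaluate):
    print(shlex.join(cmd))
```

To execute a command rather than print it, the argv list can be passed to `subprocess.run(cmd, check=True)` from the repository root.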
@misc{cheng2026scalingworldmodelreinforcementlearning,
title={Scaling World-Model Reinforcement Learning Through Diffusion Policy Optimization},
author={Xiaoyuan Cheng and Wenxuan Yuan and Zhancun Mu and Yuanzhao Zhang and Yiming Yang and Hai Wang and Zhuo Sun and Che Liu},
year={2026},
eprint={2605.26282},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2605.26282},
}
Contributions are welcome — bug reports, questions, feature requests, and pull requests all help. To get started, please open an issue or submit a pull request.
For details on reporting bugs, the pull request process, and code style, see CONTRIBUTING.md. For questions about the paper itself, feel free to contact Xiaoyuan Cheng: ucesxc4@ucl.ac.uk and Wenxuan Yuan: YUAN0186@e.ntu.edu.sg.
This project is released under the MIT License.
Note that this repository depends on third-party code and simulators (DMControl, Meta-World, ManiSkill2, MyoSuite, etc.), which are subject to their own respective licenses.








