This repo is the official code release for the ICML 2024 conference paper:
Rethinking Transformers in Solving POMDPs
In this work, we challenge the suitability of Transformers as sequence models in Partially Observable RL by leveraging regular language and circuit complexity theories. We advocate Linear RNNs as a promising alternative.
In the paper, we compare representative models including GPT, LSTM, and LRU on three different tasks to validate our theory through experiments. This codebase is used to reproduce the experimental results from the paper.
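As a concrete illustration of the regular-language setting (not code from this repo), the Parity task asks the model to decide whether a bit sequence contains an even number of 1s; a two-state DFA solves it exactly:

```python
# Minimal sketch of the Parity task: the model observes a bit sequence and
# must report whether the number of 1s seen so far is even.
def parity_label(bits):
    """Return 1 if `bits` contains an even number of 1s, else 0."""
    state = 0  # DFA state: 0 = even count of 1s so far, 1 = odd
    for b in bits:
        if b == 1:
            state = 1 - state  # flip state on every 1
    return 1 if state == 0 else 0

print(parity_label([1, 0, 1, 1]))  # three 1s (odd) -> 0
```

Despite its simplicity, tracking this single bit over long horizons is exactly the kind of state-tracking that the paper argues distinguishes recurrent models from Transformers.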
For pomdp-discrete, run the following commands:
cd pomdp-discrete
conda create -n tfporl-discrete python=3.8
conda activate tfporl-discrete
pip install -r requirements.txt
For pomdp-continuous, run the following commands:
cd pomdp-continuous
conda create -n tfporl-continuous python=3.8
conda activate tfporl-continuous
pip install -r requirements.txt
If you run into any problems, please refer to the JAX installation guidance.
We can only guarantee reproducibility with the environment configuration described below.
First, download the file from this link and extract it with tar -xvf the_file_name into the ~/.mujoco folder. Then, run the following commands.
cd defog
conda create -n tfporl-defog python=3.8.17
conda activate tfporl-defog
After that, add the following lines to your ~/.bashrc file:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/YOUR_PATH_TO_THIS/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
Remember to source ~/.bashrc to make the changes take effect.
Install D4RL by following the guidance in D4RL.
Downgrade the dm-control and mujoco packages:
pip install mujoco==2.3.7
pip install dm-control==1.0.14
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
To download the original D4RL data, run:
python download_d4rl_datasets.py
After installing the packages, you can run the following scripts to reproduce the results:
cd pomdp-discrete
# for regular language tasks
python main.py \
--config_env configs/envs/regular_parity.py \
--config_env.env_name 25 \
--config_rl configs/rl/dqn_default.py \
--train_episodes 40000 \
--config_seq configs/seq_models/gpt_default.py \
--config_seq.model.seq_model_config.n_layer {n_layer} \
--config_seq.sampled_seq_len -1 \
--config_seq.model.action_embedder.hidden_size=0 \
--config_rl.config_critic.hidden_dims="()"
# for Passive T-maze
python main.py \
--config_env configs/envs/tmaze_passive.py \
--config_env.env_name 50 \
--config_rl configs/rl/dqn_default.py \
--train_episodes 20000 \
--config_seq configs/seq_models/lstm_default.py \
--config_seq.sampled_seq_len -1
# for Passive Visual Match
python main.py \
--config_env configs/envs/visual_match.py \
--config_env.env_name 60 \
--config_rl configs/rl/sacd_default.py \
--shared_encoder --freeze_critic \
--train_episodes 40000 \
--config_seq configs/seq_models/gpt_cnn.py \
--config_seq.sampled_seq_len -1
In the scripts, env_name is the maximum training length of the regular language task. You can try other regular language tasks in pomdp-discrete/configs/envs/ and other sequence models in pomdp-discrete/configs/seq_models/.
Feel free to add other regular languages in pomdp-discrete/envs/regular.py by specifying their DFAs.
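As an illustration of specifying a language by its DFA, here is a minimal sketch; the exact interface expected by pomdp-discrete/envs/regular.py may differ, and the names below are hypothetical:

```python
# Hypothetical sketch: a DFA given as states, alphabet, start state,
# accepting states, and a (state, symbol) -> next-state transition table.
# The actual format expected by pomdp-discrete/envs/regular.py may differ.
EVEN_ZEROS_DFA = {
    "states": {"even", "odd"},
    "alphabet": {0, 1},
    "start": "even",
    "accepting": {"even"},
    # Count the number of 0s modulo 2.
    "transitions": {
        ("even", 0): "odd",
        ("even", 1): "even",
        ("odd", 0): "even",
        ("odd", 1): "odd",
    },
}

def dfa_accepts(dfa, word):
    """Run the DFA on `word` and report whether it ends in an accepting state."""
    state = dfa["start"]
    for symbol in word:
        state = dfa["transitions"][(state, symbol)]
    return state in dfa["accepting"]

print(dfa_accepts(EVEN_ZEROS_DFA, [0, 1, 0]))  # two 0s (even) -> True
```

Any regular language can be described this way, so adding a new task amounts to writing down its transition table.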
After installing packages, you can run the following script to reproduce results:
python main.py \
--config_env configs/envs/pomdps/pybullet_p.py \
--config_env.env_name cheetah \
--config_rl configs/rl/td3_default.py \
--config_seq configs/seq_models/lstm_default.py \
--config_seq.sampled_seq_len 64 \
--train_episodes 1500 \
--shared_encoder --freeze_all
In the scripts, env_name is the control task type, including ant, walker, cheetah, and hopper. You can change the POMDP variant by replacing pybullet_p with pybullet_v, and try other sequence models in pomdp-continuous/configs/seq_models/.
After installing the packages and data, you can run the following script to reproduce the results:
cd defog
python main.py env=hopper model=dt
You can replace hopper with halfcheetah or walker2d. You can also replace dt with dlstm or dlru to test other sequence models.
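A small loop can sweep all nine task/model combinations; the sketch below is a dry run that only prints the commands (swap print for subprocess.run to actually launch them):

```python
# Dry-run sweep over the DeFog tasks and sequence models listed above.
# Replace print(cmd) with subprocess.run(cmd.split()) to execute for real.
from itertools import product

envs = ["hopper", "halfcheetah", "walker2d"]
models = ["dt", "dlstm", "dlru"]

for env, model in product(envs, models):
    cmd = f"python main.py env={env} model={model}"
    print(cmd)  # 9 commands in total
```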
The code is largely based on prior works:
This work is licensed under the MIT license. See the LICENSE file for details.
If you find our work useful, please consider citing:
@inproceedings{Lu2024Rethink,
  title={Rethinking Transformers in Solving POMDPs},
  author={Chenhao Lu and Ruizhe Shi and Yuyao Liu and Kaizhe Hu and Simon S. Du and Huazhe Xu},
  booktitle={International Conference on Machine Learning},
  year={2024}
}