Diffusion Transformer Policy: Scaling Diffusion Transformer for Generalist Vision-Language-Action Learning
To run the code, install the requirements. The code runs on Python 3.10 with PyTorch 2.2.0, TensorFlow 2.15.0, and CUDA 12.1.
pip install -r requirements.txt
Then, clone the code as follows:
git clone https://github.com/zhihou7/dit_policy
We provide the following pretrained models, which can be used for fine-tuning.
Model | Description | Checkpoint Path |
---|---|---|
DiT Policy | Diffusion Transformer Policy | Google Drive |
DiT Policy | Diffusion Transformer Policy (w/o image augmentation) | Google Drive |
Diffusion MLP Head | Transformer with Diffusion Head Policy (w/o image augmentation) | Google Drive |
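After downloading a checkpoint, you can sanity-check it before fine-tuning. A minimal sketch, assuming the file is a standard PyTorch checkpoint and the local path is a placeholder:

```python
import torch

# Inspect a downloaded checkpoint on CPU. "dit_policy_checkpoint.pth" is a
# placeholder for wherever you saved the Google Drive file.
ckpt = torch.load("dit_policy_checkpoint.pth", map_location="cpu")

# The weights may be stored directly or nested under a key such as "model".
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
for name, value in list(state_dict.items())[:5]:
    shape = tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    print(name, shape)
```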
We train the network with 32 GPUs. Before running the code, update the S3 keys "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", and "S3_ENDPOINT".
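For example, assuming the dataloader reads the credentials from the environment, they can be set before launching training (the values below are placeholders; if the repository expects them in a config file, set them there instead):

```python
# Sketch: export the S3 credentials as environment variables before training.
import os

os.environ["AWS_ACCESS_KEY_ID"] = "<your-access-key-id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<your-secret-access-key>"
os.environ["S3_ENDPOINT"] = "<your-s3-endpoint>"
```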
python scripts/train_diffusion_oxe.py task_name=openx_full_train_o2_p32 dataset.traj_length=32 num_pred_action=31 scheduler_type=1 shuffle_buffer_size=128000 dataname=oxe_magic_soup_plus task_name=oxe_full_train_o2_p32_wotimestep_oxe_noclamp_filter batch_size=256
We observe that image augmentation is beneficial for SimplerEnv in our experiments. If you want to use image augmentation, please add `+image_aug=1` to the command.
Here, we provide an example of fine-tuning with LoRA, i.e., the 10-shot fine-tuning command on a real Franka arm.
python3 scripts/finetune_realdata.py +pretrained_path=dit_policy_checkpoint.pth dataset.traj_per_episode=16 dataset.traj_length=1 task_name=new_test_nodiffhead_few10_250124 num_pred_action=1 dataname=lab_907_1 batch_size=32 dataset.train_data_list=<your pkl data list file containing the names of the collected pkl files> use_lora=True scheduler_type=0 dataset.num_given_observation=1 max_iters=10000
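For reference, `use_lora=True` enables LoRA-style fine-tuning, i.e., small low-rank adapters are trained on top of frozen pretrained weights. The sketch below illustrates the general technique in PyTorch; it is not the repository's implementation:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: frozen base weight plus a low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # freeze the pretrained weight
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)   # start as a zero (identity) update
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

# Usage: wrap a pretrained projection so only the adapter weights are trained.
layer = LoRALinear(nn.Linear(512, 512))
out = layer(torch.randn(4, 512))
```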
`scheduler_type=0` indicates that we use 100 DDPM training steps.
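As an illustration of what a 100-step DDPM schedule means for the action denoising objective, here is a sketch using the Hugging Face `diffusers` scheduler; the repository may implement its own scheduler, and the tensor shapes below are placeholders:

```python
import torch
from diffusers import DDPMScheduler  # illustration only

# A 100-step DDPM training schedule, matching what scheduler_type=0 denotes.
scheduler = DDPMScheduler(num_train_timesteps=100)

actions = torch.randn(8, 10, 7)                      # (batch, action chunk, action dim), placeholder shapes
noise = torch.randn_like(actions)
timesteps = torch.randint(0, 100, (actions.shape[0],))

# During training, the clean action chunk is noised and the model learns to predict the noise.
noisy_actions = scheduler.add_noise(actions, noise, timesteps)
```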
First, follow instruction-calvin to install the CALVIN environment.
We train the network with 4 GPUs.
python scripts/train_diffusion_sim.py --config-name config_diffusion_calvin batch_size=32 dataset.traj_length=11 num_pred_action=10 task_name=calvin_exp dataset.num_given_observation=2 dataset=fix_camera use_close_loop_eval=True close_loop_eval.test_episodes_num=32 dataset.use_baseframe_action=True taskname=task_ABC_D dataname=calvin_mc close_loop_eval.eval_iters=10000 close_loop_eval.test_episodes_num=250 scheduler_type=0 wrap_grmg_data=2 +pretrained_path=dit_policy_checkpoint.pth +use_adjust_scheduler=true lr=0.0001 epoch=15 +min_lr_scale=0.01 scheduler.warmup_epochs=1
Method | LIBERO-Spatial | LIBERO-Object | LIBERO-Goal | LIBERO-Long | Average |
---|---|---|---|---|---|
Diffusion Policy from scratch | 78.3% | 92.5% | 68.3% | 50.5% | 72.4% |
Octo fine-tuned | 78.9% | 85.7% | 84.6% | 51.1% | 75.1% |
OpenVLA fine-tuned | 84.7% | 88.4% | 79.2% | 53.7% | 76.5% |
Ours fine-tuned | 84.2% | 96.3% | 85.4% | 63.8% | 82.4% |
Method | Input | 1 | 2 | 3 | 4 | 5 | Avg.Len. |
---|---|---|---|---|---|---|---|
RoboFlamingo | S-RGB, G-RGB | 82.4% | 61.9% | 46.6% | 33.1% | 23.5% | 2.47 |
SuSIE | S-RGB | 87.0% | 69.0% | 49.0% | 38.0% | 26.0% | 2.69 |
GR-1 | S-RGB, G-RGB, P | 85.4% | 71.2% | 59.6% | 49.7% | 40.1% | 3.06 |
3D Diffuser | S-RGBD, G-RGBD, Proprio, Cam | 92.2% | 78.7% | 63.9% | 51.2% | 41.2% | 3.27 |
Ours w/o pretraining | S-RGB | 89.5% | 63.3% | 39.8% | 27.3% | 18.5% | 2.38 |
Ours | S-RGB | 94.5% | 82.5% | 72.8% | 61.3% | 50.0% | 3.61 |
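In the table above, columns 1-5 report the success rate of completing chains of 1 to 5 consecutive instructions, and Avg.Len. is the expected number of instructions completed in a row, i.e., the sum of the five chain success rates. A quick check against our row:

```python
# Average completed sequence length = sum of the 1..5 chain success rates.
ours = [0.945, 0.825, 0.728, 0.613, 0.500]
print(round(sum(ours), 2))  # 3.61, matching the Avg.Len. column
```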
Simulation Benchmark Evaluations
This evaluation is based on SimplerEnv.
Task / Metric | Ours | RT-1-X | RT-2-X | Octo-Base | Octo-Small | OpenVLA |
---|---|---|---|---|---|---|
coke_can/matching_avg | 0.727 | 0.567 | 0.787 | 0.170 | n/a | 0.163 |
coke_can/variant_avg | 0.600 | 0.490 | 0.823 | 0.006 | n/a | 0.545 |
coke_can/matching/horizontal | 0.850 | 0.820 | 0.740 | 0.210 | n/a | 0.270 |
coke_can/matching/vertical | 0.740 | 0.330 | 0.740 | 0.210 | n/a | 0.030 |
coke_can/matching/standing | 0.590 | 0.550 | 0.880 | 0.090 | n/a | 0.190 |
coke_can/variant/horizontal | 0.680 | 0.569 | 0.822 | 0.005 | n/a | 0.711 |
coke_can/variant/vertical | 0.507 | 0.204 | 0.754 | 0.000 | n/a | 0.271 |
coke_can/variant/standing | 0.613 | 0.698 | 0.893 | 0.013 | n/a | 0.653 |
move_near/variant | 0.521 | 0.323 | 0.792 | 0.031 | n/a | 0.477 |
move_near/matching | 0.491 | 0.317 | 0.779 | 0.042 | n/a | 0.462 |
drawer/matching_avg | 0.463 | 0.597 | 0.250 | 0.227 | n/a | 0.356 |
drawer/variant_avg | 0.375 | 0.294 | 0.353 | 0.011 | n/a | 0.177 |
drawer/matching/open | 0.231 | 0.296 | 0.157 | 0.009 | n/a | 0.194 |
drawer/matching/close | 0.694 | 0.891 | 0.343 | 0.444 | n/a | 0.518 |
drawer/variant/open | 0.216 | 0.069 | 0.333 | 0.000 | n/a | 0.158 |
drawer/variant/close | 0.535 | 0.519 | 0.372 | 0.021 | n/a | 0.195 |
put_spoon_on_tablecloth/matching_partial | 0.250 | 0.167 | n/a | 0.347 | 0.778 | 0.041 |
put_spoon_on_tablecloth/matching_entire | 0.167 | 0.000 | n/a | 0.125 | 0.472 | 0.000 |
put_carrot_on_plate/matching_partial | 0.208 | 0.208 | n/a | 0.528 | 0.278 | 0.333 |
put_carrot_on_plate/matching_entire | 0.167 | 0.042 | n/a | 0.083 | 0.097 | 0.000 |
stack_green_block_on_yellow_block/matching_partial | 0.083 | 0.083 | n/a | 0.319 | 0.403 | 0.125 |
stack_green_block_on_yellow_block/matching_entire | 0.000 | 0.000 | n/a | 0.000 | 0.042 | 0.000 |
put_eggplant_in_basket/matching_partial | 0.083 | 0.000 | n/a | 0.667 | 0.875 | 0.083 |
put_eggplant_in_basket/matching_entire | 0.000 | 0.000 | n/a | 0.431 | 0.569 | 0.041 |
apple_in_drawer/matching_avg | 0.042 | 0.213 | 0.037 | 0.000 | 0.000 | n/a |
apple_in_drawer/variant_avg | 0.035 | 0.101 | 0.206 | 0.000 | 0.000 | n/a |
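The `*_avg` rows are simple means over their per-orientation entries; for example, our coke_can/matching_avg is the mean of the horizontal, vertical, and standing matching scores:

```python
# coke_can/matching_avg is the mean of the three object orientations.
horizontal, vertical, standing = 0.850, 0.740, 0.590
print(round((horizontal + vertical + standing) / 3, 3))  # 0.727
```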
In our experiments, we use the Bridge_orig dataset from TFDS on Google Cloud, in which the images have been resized (480×512 → 224×224), causing image distortion. We think this may significantly affect the evaluation on Bridge. Please note that RT-2-X is a very large model trained with web-scale data.
Please refer to the project page.
The OXE dataloader is based on OpenVLA, the CALVIN dataloader is based on GR-MG, and the architecture is based on transformers.
If you find our code or models useful in your work, please cite our paper:
@article{hou2024diffusion,
title={Diffusion Transformer Policy: Scaling Diffusion Transformer for Generalist Vision-Language-Action Learning},
author={Hou, Zhi and Zhang, Tianyi and Xiong, Yuwen and Pu, Hengjun and Zhao, Chengyang and Tong, Ronglei and Qiao, Yu and Dai, Jifeng and Chen, Yuntao},
journal={arXiv preprint arXiv:2410.15959},
year={2024}
}