
A bug/unexpected loss of precision occurs when executing train_rsl_rl.py and train_jax_ppo.py #18

Open
Zi-ang-Cao opened this issue Jan 17, 2025 · 3 comments

Zi-ang-Cao commented Jan 17, 2025

Hi MuJoCo Playground team,

Thanks for bringing such an exciting GPU-accelerated MuJoCo environment to us! I hope the learned policies in mujoco_playground will have better sim2real performance. However, when I tried the examples in the learning folder, both train_rsl_rl.py and train_jax_ppo.py give me "bug/unexpected loss of precision" warnings.

I am using Ubuntu 22.04 + NVIDIA-SMI 550 (CUDA 12.4) + CUDA toolkit 12.1. This is how I set up the virtual environment via Conda:

conda create -n mjx python=3.10 -y
conda activate mjx
conda install nvidia/label/cuda-12.4.0::cuda

conda install "jaxlib=*=*cuda*" jax -c conda-forge
pip install torch==2.3.1 torchvision==0.18.1

pip install mujoco
pip install mujoco_mjx
pip install brax
cd mujoco_playground
pip install -e .

After running python learning/train_rsl_rl.py --env_name=G1JoystickRoughTerrain --use_wandb, the most concerning messages are pasted below, and the xla_dump_to folder is attached to this issue as well.
report_bug_rsrl_rl.zip

E0116 21:53:35.782328  160376 buffer_comparator.cc:157] Difference at 2: 0.857876, expected 1.0906
E0116 21:53:35.782335  160376 buffer_comparator.cc:157] Difference at 3: 1.05171, expected 0.555048
E0116 21:53:35.782336  160376 buffer_comparator.cc:157] Difference at 4: 1.00491, expected 0.702288
E0116 21:53:35.782338  160376 buffer_comparator.cc:157] Difference at 5: 0.839983, expected 1.15234
E0116 21:53:35.782339  160376 buffer_comparator.cc:157] Difference at 8: 0.697264, expected 0.893167
E0116 21:53:35.782340  160376 buffer_comparator.cc:157] Difference at 9: 0.751796, expected 0.312863
E0116 21:53:35.782341  160376 buffer_comparator.cc:157] Difference at 10: 0.914679, expected 0.354761
E0116 21:53:35.782343  160376 buffer_comparator.cc:157] Difference at 13: 0.676378, expected 0.405283
E0116 21:53:35.782344  160376 buffer_comparator.cc:157] Difference at 15: 0.317772, expected 1.22127
E0116 21:53:35.782345  160376 buffer_comparator.cc:157] Difference at 16: 0.297798, expected 1.3303
2025-01-16 21:53:35.783104: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0116 21:53:35.786688  160376 buffer_comparator.cc:157] Difference at 16: -nan, expected 0
E0116 21:53:35.786699  160376 buffer_comparator.cc:157] Difference at 17: -nan, expected 0
E0116 21:53:35.786701  160376 buffer_comparator.cc:157] Difference at 18: -nan, expected 0
E0116 21:53:35.786702  160376 buffer_comparator.cc:157] Difference at 19: -nan, expected 0
E0116 21:53:35.786703  160376 buffer_comparator.cc:157] Difference at 20: -nan, expected 0
E0116 21:53:35.786704  160376 buffer_comparator.cc:157] Difference at 21: -nan, expected 0
E0116 21:53:35.786705  160376 buffer_comparator.cc:157] Difference at 22: -nan, expected 0
E0116 21:53:35.786706  160376 buffer_comparator.cc:157] Difference at 23: -nan, expected 0
E0116 21:53:35.786708  160376 buffer_comparator.cc:157] Difference at 24: -nan, expected 0
E0116 21:53:35.786709  160376 buffer_comparator.cc:157] Difference at 25: -nan, expected 0
2025-01-16 21:53:35.786711: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.

Similarly, the main error messages and the related xla_dump_to folder from executing python learning/train_jax_ppo.py --env_name=G1JoystickRoughTerrain --use_wandb are provided below.
report_bug_jax_ppo.zip

2025-01-16 21:59:15.164089: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0116 21:59:15.166167  165887 buffer_comparator.cc:157] Difference at 16: 0.806099, expected 8.41456
E0116 21:59:15.166175  165887 buffer_comparator.cc:157] Difference at 17: 0.669318, expected 10.7305
E0116 21:59:15.166177  165887 buffer_comparator.cc:157] Difference at 18: 0.618111, expected 8.46751
E0116 21:59:15.166179  165887 buffer_comparator.cc:157] Difference at 19: 1.59979, expected 11.3751
E0116 21:59:15.166181  165887 buffer_comparator.cc:157] Difference at 20: 1.41715, expected 9.13166
E0116 21:59:15.166182  165887 buffer_comparator.cc:157] Difference at 21: 1.41268, expected 9.03136
E0116 21:59:15.166184  165887 buffer_comparator.cc:157] Difference at 22: 0.627852, expected 9.2793
E0116 21:59:15.166186  165887 buffer_comparator.cc:157] Difference at 23: 0.627606, expected 9.35567
E0116 21:59:15.166189  165887 buffer_comparator.cc:157] Difference at 24: 0.660828, expected 9.69482
E0116 21:59:15.166191  165887 buffer_comparator.cc:157] Difference at 25: 1.15741, expected 10.8501
2025-01-16 21:59:15.166194: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0116 21:59:15.168032  165887 buffer_comparator.cc:157] Difference at 16: 0.806099, expected 8.41456
E0116 21:59:15.168036  165887 buffer_comparator.cc:157] Difference at 17: 0.669318, expected 10.7305
E0116 21:59:15.168039  165887 buffer_comparator.cc:157] Difference at 18: 0.618111, expected 8.46751
E0116 21:59:15.168040  165887 buffer_comparator.cc:157] Difference at 19: 1.59979, expected 11.3751
E0116 21:59:15.168042  165887 buffer_comparator.cc:157] Difference at 20: 1.41715, expected 9.13166
E0116 21:59:15.168044  165887 buffer_comparator.cc:157] Difference at 21: 1.41268, expected 9.03136
E0116 21:59:15.168045  165887 buffer_comparator.cc:157] Difference at 22: 0.627852, expected 9.2793
E0116 21:59:15.168047  165887 buffer_comparator.cc:157] Difference at 23: 0.627606, expected 9.35567
E0116 21:59:15.168048  165887 buffer_comparator.cc:157] Difference at 24: 0.660828, expected 9.69482
E0116 21:59:15.168051  165887 buffer_comparator.cc:157] Difference at 25: 1.15741, expected 10.8501
2025-01-16 21:59:15.168053: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
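
The xla_dump_to folders attached above correspond to XLA's --xla_dump_to flag; a minimal sketch of setting it from Python before JAX is imported (the dump path is illustrative, and exporting XLA_FLAGS in the shell before launching the training script works the same way):

import os

# Point XLA at a dump directory; compiled HLO and autotuning artifacts are written
# there as programs are compiled. Must be set before jax is imported.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"

import jax  # noqa: E402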
btaba (Collaborator) commented Jan 19, 2025

Hi @Zi-ang-Cao, I'm not really sure what the issue is, as I can't reproduce it on my end. I'm curious whether running the MJX tutorial notebooks with brax training on your setup also produces those issues, or if it's just the training scripts in Playground.

Zi-ang-Cao (Author) commented Jan 19, 2025

Hi @btaba, I downloaded the tutorial notebook and ran it with my local machine setup. The following screenshots show that the bug is still present.

[Screenshots: the same buffer_comparator / gemm_fusion_autotuner precision errors reproduced from the MJX tutorial notebook]

When you tried to reproduce the error, did you create the conda environment in the same way?

# Ubuntu 22.04 + NVIDIA-SMI 550 (CUDA12.4) + CUDA-toolkit 12.1
conda create -n mjx python=3.10 -y
conda activate mjx

# If I do not use the following line, `pip install -U "jax[cuda12]"` will pull in CUDA toolkit 12.6 and give me an error at execution time.
# That's why I install `jax` via the following lines instead.
conda install nvidia/label/cuda-12.4.0::cuda
conda install "jaxlib=*=*cuda*" jax -c conda-forge

pip install torch==2.3.1 torchvision==0.18.1

pip install mujoco
pip install mujoco_mjx
pip install brax
cd mujoco_playground
pip install -e .

conda list | grep nvidia gives me:

cuda                      12.4.0                        0    nvidia/label/cuda-12.4.0
cuda-cccl                 12.4.99                       0    nvidia/label/cuda-12.4.0
cuda-demo-suite           12.4.99                       0    nvidia/label/cuda-12.4.0
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi

conda list | grep jax gives me:

jax                       0.4.35             pyhd8ed1ab_1    conda-forge
jax-cuda12-pjrt           0.4.35                   pypi_0    pypi
jax-cuda12-plugin         0.4.35                   pypi_0    pypi
jaxlib                    0.4.35          cuda126py310h5e1a0f3_200    conda-forge
jaxopt                    0.8.3                    pypi_0    pypi

pip list | grep torch gives me:

torch                    2.3.1
torchvision              0.18.1

pip list | grep brax gives me:

brax                     0.12.1

pip list | grep mujoco gives me:

mujoco                   3.2.7
mujoco-mjx               3.2.7
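
To help isolate whether the mismatch comes from plain JAX GPU GEMMs rather than anything in Playground, here is a minimal float32 check (sizes are illustrative) comparing the GPU result against a float64 NumPy reference:

import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2048, 2048)).astype(np.float32)
b = rng.standard_normal((2048, 2048)).astype(np.float32)

# precision=HIGHEST forces full float32 accumulation (avoids TF32 on Ampere+ GPUs).
gpu_out = np.asarray(
    jnp.dot(jnp.asarray(a), jnp.asarray(b), precision=jax.lax.Precision.HIGHEST)
    .block_until_ready()
)
ref_out = a.astype(np.float64) @ b.astype(np.float64)

# A healthy GEMM should differ from the reference by a tiny fraction of the matrix
# scale; mismatches on the order of the values themselves (as in the logs above)
# point to a real problem.
scale = np.abs(ref_out).max()
print("max abs diff / scale:", np.abs(gpu_out - ref_out).max() / scale)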

btaba (Collaborator) commented Jan 19, 2025

Ok, that's good to know. Since the issue pops up for RSL-RL, for brax, and even without any Playground deps (the MJX tutorial notebook), it seems like the issue has something to do with your setup/device + JAX.

The issue seems to pop up for other JAX users as well. This might work. Otherwise I'd open a JAX bug.
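
One possible workaround along those lines (an assumption on my part, not confirmed in this thread) is to turn down XLA's GPU autotuning so the Triton GEMM candidates that fail the buffer comparison are never selected; the flag has to be set before JAX is imported:

import os

# Assumed workaround: autotune level 0 disables the GPU GEMM autotuning pass whose
# buffer comparison produces the errors above. Another flag sometimes suggested is
# --xla_gpu_enable_triton_gemm=false, which falls back to cuBLAS GEMMs.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"

import jax  # noqa: E402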

Does the RL training converge despite these errors?

btaba self-assigned this Jan 19, 2025