A bug/unexpected loss of precision occurs when executing train_rsl_rl.py and train_jax_ppo.py #18
Hi @Zi-ang-Cao, I'm not really sure what the issue is, as I can't reproduce it on my end. I'm curious whether running the MJX tutorial notebooks with brax training on your setup also produces those issues, or whether it's just the training scripts in Playground.
Hi Btaba, I downloaded the tutorial notebook and ran it on my local machine. The screenshot shows that the bug is still present. When you tried to reproduce the error, did you create the conda environment the same way?
# Ubuntu 22.04 + NVIDIA-SMI 550 (CUDA12.4) + CUDA-toolkit 12.1
conda create -n mjx python=3.10 -y
conda activate mjx
# Without the following line, `pip install -U "jax[cuda12]"` downloads CUDA toolkit 12.6, which gives me an error at runtime.
# That's why I install `jax` via the following lines.
conda install nvidia/label/cuda-12.4.0::cuda
conda install "jaxlib=*=*cuda*" jax -c conda-forge
pip install torch==2.3.1 torchvision==0.18.1
pip install mujoco
pip install mujoco_mjx
pip install brax
cd mujoco_playground
pip install -e .
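Because the environment mixes conda-forge `jaxlib`, a conda CUDA 12.4 toolkit, and pip-installed `torch`, `mujoco`, `mujoco_mjx`, and `brax`, it may help to record exactly which versions ended up installed before comparing setups or filing a JAX bug. The sketch below is only an assumption-based helper, not part of Playground: the package list is a hypothetical choice (including `jax-cuda12-plugin`, which a pip-based JAX install would normally pull in), and missing packages are reported as `None` rather than raising. Recent JAX releases also offer `jax.print_environment_info()`, which prints JAX-side details and, when available, `nvidia-smi` output.

```python
"""Collect installed versions of the packages involved in this setup.

Assumption: the distribution names below match what pip/conda installed;
the list itself is a hypothetical example, not taken from Playground.
"""
from importlib import metadata

# Hypothetical list of distributions relevant to a JAX/MJX/brax setup.
PACKAGES = [
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "mujoco",
    "mujoco-mjx",
    "brax",
    "torch",
]


def collect_versions(packages=PACKAGES):
    """Return a dict mapping each distribution name to its version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed (or installed under a different distribution name).
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name:>18}: {version or 'not installed'}")
```

Running it on both machines and diffing the output is a quick way to spot, for example, a `jaxlib` build that does not match the CUDA toolkit conda installed.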
OK, that's good to know. Since the issue shows up with RSL-RL, with brax, and even without any Playground dependencies (the MJX tutorial notebook), it seems to be related to your setup/device + JAX. Other JAX users seem to have hit the same issue as well. This might work. Otherwise, I'd open a JAX bug. Does the RL training converge despite these errors?
Hi MuJoCo Playground team,
Thanks for bringing such an exciting GPU-accelerated MuJoCo environment to us! I hope policies learned in mujoco_playground will have better sim2real performance. However, when I tried the examples in the `learning` folder, both `train_rsl_rl.py` and `train_jax_ppo.py` gave me bug/unexpected loss-of-precision warnings. I am using Ubuntu 22.04 + NVIDIA-SMI 550 (CUDA12.4) + CUDA-toolkit 12.1, and this is how I set up the virtual environment via Conda:
After running `python learning/train_rsl_rl.py --env_name=G1JoystickRoughTerrain --use_wandb`, the main concerning message is pasted below, and the `xla_dump_to` folder is attached to this issue as well: report_bug_rsrl_rl.zip
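A folder like the attached `xla_dump_to` one is what XLA writes when its `--xla_dump_to=<dir>` flag is passed through the `XLA_FLAGS` environment variable. The helper below is a small sketch of setting that flag from Python while keeping any other XLA flags already present; the function name and the `xla_dump` directory are hypothetical, and it assumes flag values contain no spaces. It has to run before JAX is imported, since XLA reads `XLA_FLAGS` when the backend is initialized and a later change may have no effect.

```python
"""Point XLA's compiler dump at a directory via XLA_FLAGS.

Assumptions: flag values contain no spaces, and this runs before `import jax`
so the flag is picked up when the XLA backend initializes.
"""
import os


def enable_xla_dump(dump_dir):
    """Set --xla_dump_to=<dump_dir> in XLA_FLAGS, preserving other flags."""
    flags = os.environ.get("XLA_FLAGS", "").split()
    # Remove an earlier dump target so only one --xla_dump_to flag remains.
    flags = [f for f in flags if not f.startswith("--xla_dump_to=")]
    flags.append(f"--xla_dump_to={dump_dir}")
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]
```

From a shell, the equivalent is to prefix the training command with the variable, e.g. `XLA_FLAGS=--xla_dump_to=xla_dump python learning/train_rsl_rl.py --env_name=G1JoystickRoughTerrain --use_wandb`.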
Similarly, the main error message and the related `xla_dump_to` folder after executing `python learning/train_jax_ppo.py --env_name=G1JoystickRoughTerrain --use_wandb` are provided below: report_bug_jax_ppo.zip