NaNs at Inference #525
Replies: 8 comments
-
Just so I understand - you're training in MJX and evaluating the policy in C MuJoCo (presumably via the Python bindings) and seeing unstable physics? Is it possible that there's something else different between the training and eval environments, possibly the initial state? Are you hitting some terminating condition that you're ignoring during the eval? What does the video look like leading up to the instability? Feel free to post a colab.
-
I am training in MJX and then evaluating the policy in Python, the same as this part of the Barkour colab. As in the Barkour colab, the training and eval environments are identical, including the initial state. I monitor the termination condition when visualizing the policy and it is not terminating; it simply produces the NaN control value. Leading up to the instability (generally one frame), the initial state based on the keyframe I'm using is set and looks correct, and then it goes to NaN control values. To reiterate, I am using a pipeline that very closely mimics the Barkour colab. I have used this pipeline for many problems and am reasonably certain that it works. In the past, when I had NaNs they occurred during training due to an unstable simulation, or could be resolved with [...]. I'll see if I can put together a colab to reproduce this issue, but it does involve a reasonable train time and I may not be able to open source this just yet (hopefully soon, though). Are there any additional checks I could perform or logs I can provide?
-
OK, if you really think it's happening somewhere in the inference function, that's a bit surprising to me, but the good news is that's a pretty small surface area to search - really only a few hundred lines of code or so. You can try removing the @jit so you can trace through, or binary search for the NaN with [...]
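A minimal sketch of both approaches (removing the jit and probing for NaNs), assuming an `inference_fn(obs, rng)` like the one produced in the Barkour colab - the names here are illustrative, not the poster's actual code:

```python
# Sketch: hunt for the NaN inside the inference function.
# Assumes `inference_fn`, `obs`, and `rng` already exist as in the
# Barkour colab; adjust names to your pipeline.
import jax
import jax.numpy as jnp

# Option 1: raise an error at the first op that produces a NaN.
# Works best with jit disabled, so the traceback points at the offending line.
jax.config.update("jax_debug_nans", True)
with jax.disable_jit():
    act, _ = inference_fn(obs, rng)

# Option 2: bisect manually by checking the values flowing in and out.
print("obs contains NaN:", bool(jnp.isnan(obs).any()))
print("act contains NaN:", bool(jnp.isnan(act).any()))
```

If Option 1 fires inside the policy network rather than the physics step, that narrows the problem to the observation or the network weights.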
-
Thanks for the suggestion! I'll spend some time tracking down the error and share what I find.
-
Any luck?
-
I still need to investigate this some more, but I can share what I have figured out so far. First, make sure that your simulation is stable (a quick rollout check is sketched after this list). Simulations with features like many contacts, unrealistically high control actions, and highly constrained systems (e.g. the equality constraints creating a loop, as mentioned in the original post) can become unstable easily. This was not the case for me.
What produced NaNs:
What did not produce NaNs:
I will keep tracking this down in the coming weeks, but hopefully this helps!
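As one concrete way to separate "unstable simulation" from "bad inference output", here is a rough rollout check. It is a sketch that assumes the Barkour-style `jit_reset` / `jit_step` / `jit_inference_fn` names and an MJX pipeline state, not the poster's actual code:

```python
# Sketch: roll out the policy and report the first step where either the
# action or the physics state goes non-finite. The names jit_reset,
# jit_step, and jit_inference_fn are assumptions about the surrounding code.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
for step in range(500):  # or however long your eval episode is
    rng, act_rng = jax.random.split(rng)
    act, _ = jit_inference_fn(state.obs, act_rng)
    if not bool(jnp.isfinite(act).all()):
        print(f"non-finite action at step {step}")  # points at the policy/inference
        break
    state = jit_step(state, act)
    ps = state.pipeline_state  # for the MJX pipeline this wraps mjx.Data
    if not (bool(jnp.isfinite(ps.qpos).all()) and bool(jnp.isfinite(ps.qvel).all())):
        print(f"non-finite physics state at step {step}")  # points at the simulation
        break
```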
-
I have also seen scenarios in which a 4090 produces unstable physics while an A100 does not, given the exact same MJX environment and Python version. I have yet to track down why, but it probably has something to do with matmul precision defaults.
-
Indeed, we find that setting one of the matmul precision options helps on RTX devices.
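The specific options did not survive in this copy of the thread; the usual candidates look like the following (an assumption on my part, not necessarily the exact list meant above):

```python
# Sketch: raise matmul precision so low-precision TF32 tensor-core matmuls
# are not used. These are common options, not necessarily the ones the
# comment above refers to.
import jax

# Global setting:
jax.config.update("jax_default_matmul_precision", "float32")

# Or scoped to a block of code:
with jax.default_matmul_precision("float32"):
    act, _ = jit_inference_fn(state.obs, act_rng)  # hypothetical call
```

Setting the environment variable `NVIDIA_TF32_OVERRIDE=0` before launching the process is another way to force full-precision matmuls at the driver level.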
-
Hello,
I have been encountering an issue where my training runs error-free and learns well, but then NaN control values are generated at inference when collecting a trajectory to make a video of the task.
I am currently using the following lines to improve the precision and debug NaNs:
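A representative example of such lines (not necessarily the exact ones used in this run) would be:

```python
# Representative example only: typical debug/precision settings for JAX,
# not necessarily the exact lines from the original run.
import jax

jax.config.update("jax_debug_nans", True)                     # error on the first NaN produced
jax.config.update("jax_default_matmul_precision", "float32")  # avoid low-precision TF32 matmuls
```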
The error generated at inference from MuJoCo is:
The error from the inference function is the following:
I'm not sure how the training could work well and then generate NaNs at inference, as a NaN value during training would have thrown an error. My model does include a decent number of contacts and two equality constraints that create a loop constraint, but the model appears stable in MuJoCo and during training.
I do have a workaround to fix the issue, which is increasing to 64-bit precision:
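The line itself is not shown above; enabling 64-bit computation in JAX is typically done like this (a representative example):

```python
# Representative example: switch JAX to 64-bit (float64/int64) by default.
import jax

jax.config.update("jax_enable_x64", True)
```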
My main concern here is that the training time increases drastically, along with the required GPU memory. Training for 1 million steps went from 1 min 42 s to 3 min 42 s (on an RTX 4090), and the GPU memory to allocate went from ~20 GB to ~46 GB. Excluding some contacts allowed me to reduce this to 2 min 56 s and get back under the 24 GB limit so I can continue using this GPU.
My pipeline mirrors the Barkour training and inference pipeline very closely.
Some model details that may help (also very similar to the Barkour model; see the sketch after this list for how they're set):
training dt = 0.02
model.opt.timestep = 0.005
integrator = Euler (though I did try the RK4 and it didn't help)
eulerdamp = disable
iterations = 1
ls_iterations = 5
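A sketch of how these options map onto the MuJoCo/MJX Python API; the XML path and variable names are placeholders, not the actual model:

```python
# Sketch only: setting the listed options on a MuJoCo model before
# handing it to MJX. The XML path is a placeholder.
import mujoco
from mujoco import mjx

model = mujoco.MjModel.from_xml_path("my_model.xml")  # placeholder path
model.opt.timestep = 0.005
model.opt.integrator = mujoco.mjtIntegrator.mjINT_EULER
model.opt.iterations = 1
model.opt.ls_iterations = 5
model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_EULERDAMP  # eulerdamp = disable

mjx_model = mjx.put_model(model)

# With a training dt of 0.02 and a physics timestep of 0.005, the env takes
# 0.02 / 0.005 = 4 physics substeps per control step.
```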
I am using MuJoCo/MJX = 3.1.6 and Brax = 0.9.4 (though I also tried 0.10.5 and saw the same issues).
Is there a reason that I am encountering this behaviour when performing the inference?
Thanks!