Providing camera RGB values within the jitted _get_obs() method of an environment #551
eleninisioti started this conversation in General
Replies: 1 comment
-
Hi @eleninisioti, that's correct: the rendering code is not jittable. The closest workaround is to use madrona_mjx, which is still very much under development, so YMMV. Hope that helps!
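If full madrona_mjx is more than you need right now, one stopgap (just a sketch, not a supported API) is to keep brax's existing `render_array` but push it out of the traced computation with `jax.pure_callback`. The helper name `rgb_obs` and the fixed 64x64 resolution below are illustrative only:

```python
import jax
import numpy as np
from brax.io.image import render_array


def rgb_obs(sys, pipeline_state, height=64, width=64):
    """Sketch only: renders on the host, so it is slow and not differentiable,
    but the call can live inside a jitted _get_obs()."""

    def _render(state):
        # The callback receives the pipeline state with concrete numpy leaves,
        # which is what the mujoco-based renderer expects.
        img = render_array(sys, trajectory=state, camera="my_camera",
                           height=height, width=width)
        return np.asarray(img, dtype=np.uint8)

    out = jax.ShapeDtypeStruct((height, width, 3), np.uint8)
    # pure_callback makes the call traceable, but the rendering itself still
    # runs in Python on the host every time the jitted function executes.
    return jax.pure_callback(_render, out, pipeline_state)
```

Every step then pays a host round-trip, which defeats most of the benefit of jitting and needs extra care under vmap, so for on-device batched rendering madrona_mjx is still the way to go.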
-
I would like to get RGB observations from a camera within the step function of a brax environment. To do so, I add a camera named "my_camera" to my XML and then try to build the observation following the logic in io/image.py. The issue is that the code there uses mujoco rather than mujoco.mjx, so it cannot be jitted. If I try:
`image = render_array(self.sys, trajectory=pipeline_state, camera="my_camera", height=height, width=width)`
then it works when _get_obs() is not jitted, but when it is, it throws the error:
....brax/io/image.py", line 40, in get_image
    d.qpos, d.qvel = state.q, state.qd
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[5]
I then thought of taking some inspiration from mjx/pipeline.py and did:
```python
from mujoco import mjx
import mujoco
```
Unsurprisingly, this complains that the data I am passing to update_scene does not have the right format.
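For reference, here is a self-contained sketch of the non-jitted path I was aiming for (the toy XML, sizes, and variable names are illustrative; in my case the model and data come from the environment's pipeline state):

```python
import mujoco
from mujoco import mjx

# Toy model with a single free body and a fixed camera named "my_camera".
XML = """
<mujoco>
  <worldbody>
    <camera name="my_camera" pos="0 -1 0.5" xyaxes="1 0 0 0 0.5 1"/>
    <body>
      <joint type="free"/>
      <geom type="sphere" size="0.1"/>
    </body>
  </worldbody>
</mujoco>
"""
mj_model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.make_data(mjx_model)  # stand-in for the data held by the pipeline state

renderer = mujoco.Renderer(mj_model, height=64, width=64)

# get_data copies the device arrays back into a plain mujoco.MjData, which is
# the format update_scene expects, and exactly the step that cannot be traced.
mj_data = mjx.get_data(mj_model, mjx_data)
mujoco.mj_forward(mj_model, mj_data)  # refresh the quantities used for rendering

renderer.update_scene(mj_data, camera="my_camera")
rgb = renderer.render()  # uint8 array of shape (64, 64, 3)
```

This works fine outside of jit, but the mjx.Data-to-MjData conversion is where tracing breaks down.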
My understanding is that the camera-rendering code was not written with jitting in mind, and supporting it would require a different version of the renderer. But perhaps there is a workaround, or another way of accessing camera information?