Skip to content

Commit

Permalink
Fix non-used device in jax to torch
Browse files Browse the repository at this point in the history
  • Loading branch information
mantasu authored Jul 2, 2024
1 parent b18e6d1 commit c1ec71d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gymnasium/wrappers/jax_to_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def _jax_iterable_to_torch(
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(jax_to_torch(v) for v in value)
return type(value)._make(jax_to_torch(v, device) for v in value)
else:
return type(value)(jax_to_torch(v) for v in value)
return type(value)(jax_to_torch(v, device) for v in value)


class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
Expand Down

0 comments on commit c1ec71d

Please sign in to comment.