Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support heterogeneous dicts in infos #250

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions minari/data_collector/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def step(
self._buffer = EpisodeBuffer(
id=self._episode_id,
observations=step_data["observation"],
infos=step_data["info"],
infos=[step_data["info"]],
)

return obs, rew, terminated, truncated, info
Expand Down Expand Up @@ -206,7 +206,7 @@ def reset(
seed=seed,
options=options,
observations=step_data["observation"],
infos=step_data["info"] if self._record_infos else None,
infos=[step_data["info"]] if self._record_infos else None,
)
return obs, info

Expand Down
11 changes: 4 additions & 7 deletions minari/data_collector/episode_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class EpisodeBuffer:
rewards: list = field(default_factory=list)
terminations: list = field(default_factory=list)
truncations: list = field(default_factory=list)
infos: Optional[dict] = None
infos: Optional[list] = field(default_factory=list)

def add_step_data(self, step_data: StepData) -> EpisodeBuffer:
"""Add step data dictionary to episode buffer.
Expand Down Expand Up @@ -54,11 +54,8 @@ def _append(data, buffer):
else:
actions = jtu.tree_map(_append, step_data["action"], self.actions)

if self.infos is None:
infos = jtu.tree_map(lambda x: [x], step_data["info"])
else:
infos = jtu.tree_map(_append, step_data["info"], self.infos)

if self.infos is not None:
self.infos.append(step_data["info"])
jamartinh marked this conversation as resolved.
Show resolved Hide resolved
self.rewards.append(step_data["reward"])
self.terminations.append(step_data["termination"])
self.truncations.append(step_data["truncation"])
Expand All @@ -72,7 +69,7 @@ def _append(data, buffer):
rewards=self.rewards,
terminations=self.terminations,
truncations=self.truncations,
infos=infos,
infos=self.infos,
)

def __len__(self) -> int:
Expand Down
36 changes: 23 additions & 13 deletions minari/dataset/_storages/arrow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import json
import pathlib
from itertools import zip_longest
from typing import Any, Dict, Iterable, Optional, Sequence
from typing import Any, Dict, Iterable, Optional, Sequence, List

import gymnasium as gym
import numpy as np


try:
import pyarrow as pa
import pyarrow.dataset as ds
Expand All @@ -19,6 +18,7 @@

from minari.data_collector.episode_buffer import EpisodeBuffer
from minari.dataset.minari_storage import MinariStorage
from minari.dataset._storages.serde import NumpyEncoder, serialize_dict, deserialize_dict


class ArrowStorage(MinariStorage):
Expand Down Expand Up @@ -84,6 +84,17 @@ def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
)

def _to_dict(id, episode):
if "infos" in episode.column_names:
try:
infos = decode_info_list(episode["infos"])
except Exception as e: # for backwards compatibility
try:
infos = _decode_info(episode["infos"])
except Exception as e:
raise ValueError(f"Failed to decode infos: {e}")
else:
infos = []

return {
"id": id,
"observations": _decode_space(
Expand All @@ -93,11 +104,7 @@ def _to_dict(id, episode):
"rewards": np.asarray(episode["rewards"])[:-1],
"terminations": np.asarray(episode["terminations"])[:-1],
"truncations": np.asarray(episode["truncations"])[:-1],
"infos": (
_decode_info(episode["infos"])
if "infos" in episode.column_names
else {}
),
"infos": infos,
}

return map(_to_dict, episode_indices, dataset.to_batches())
Expand Down Expand Up @@ -128,7 +135,7 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
"truncations": np.pad(truncations, ((0, pad))),
}
if episode_data.infos:
episode_batch["infos"] = _encode_info(episode_data.infos)
episode_batch["infos"] = encode_info_list(episode_data.infos)
episode_batch = pa.RecordBatch.from_pydict(episode_batch)

total_steps += len(rewards)
Expand Down Expand Up @@ -257,8 +264,11 @@ def _decode_info(values: pa.Array):
return nested_dict


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes NumPy arrays as nested Python lists."""

    def default(self, obj):
        # Arrays are not natively JSON-serializable; fall back to list form.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Anything else defers to the base encoder (which raises TypeError).
        return super().default(obj)
def encode_info_list(info_list: List[Dict[str, Any]]) -> pa.Array:
    """Serialize one info dict per step into a pyarrow array of JSON strings."""
    encoded = [serialize_dict(info) for info in info_list]
    return pa.array(encoded, type=pa.string())
jamartinh marked this conversation as resolved.
Show resolved Hide resolved


def decode_info_list(values: pa.Array) -> List[Dict[str, Any]]:
    """Inverse of :func:`encode_info_list`: JSON-string array -> list of dicts."""
    decoded = []
    for entry in values:
        decoded.append(deserialize_dict(entry.as_py()))
    return decoded

75 changes: 71 additions & 4 deletions minari/dataset/_storages/hdf5_storage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

import json
import pathlib
from collections import OrderedDict
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union, Any

import gymnasium as gym
import numpy as np

from minari.data_collector import EpisodeBuffer
from minari.dataset.minari_storage import MinariStorage

from minari.dataset._storages.serde import serialize_dict, deserialize_dict

try:
import h5py
Expand Down Expand Up @@ -117,8 +118,10 @@ def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
infos = None
if "infos" in ep_group:
info_group = ep_group["infos"]
assert isinstance(info_group, h5py.Group)
infos = _decode_info(info_group)
if isinstance(info_group, h5py.Group): # for backward compatibility
infos = _decode_info(info_group)
else:
infos = read_dict_dataset_from_group(ep_group, "infos")

ep_dict = {
"id": ep_idx,
Expand Down Expand Up @@ -213,6 +216,12 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group):
elif not isinstance(data, Iterable):
if data is not None:
episode_group.create_dataset(key, data=data)

elif key == "infos":
print("infos", type(data))
print(data)

create_dict_dataset_in_group(episode_group, "infos", data)
else:
dtype = None
if all(map(lambda elem: isinstance(elem, str), data)):
Expand Down Expand Up @@ -262,3 +271,61 @@ def unflatten_dict(d: Dict) -> Dict:
current = current[key]
current[keys[-1]] = v
return result


def infer_dtype(value):
    """Infer an HDF5-compatible dtype for a single info value.

    Returns a NumPy dtype for scalars/numeric data and an h5py
    variable-length string dtype for text, heterogeneous lists, dicts
    and unknown types (those are expected to be stored JSON-encoded).
    """
    # NOTE: bool must be checked BEFORE int — Python's bool is a subclass
    # of int, so the previous ordering made this branch unreachable and
    # silently typed booleans as int64.
    if isinstance(value, (bool, np.bool_)):
        return np.bool_
    elif isinstance(value, str):
        return h5py.special_dtype(vlen=str)
    elif isinstance(value, (int, np.integer)):
        return np.int64
    elif isinstance(value, (float, np.floating)):
        return np.float64
    elif isinstance(value, list):
        if all(isinstance(item, str) for item in value):
            return h5py.special_dtype(vlen=str)
        elif all(isinstance(item, (int, float, np.integer, np.floating)) for item in value):
            return np.float64
        else:
            return h5py.special_dtype(vlen=str)  # Store as JSON string
    elif isinstance(value, np.ndarray):
        # Unicode/bytes arrays cannot keep their native dtype in HDF5.
        if value.dtype.kind in ['U', 'S']:
            return h5py.special_dtype(vlen=str)
        else:
            return value.dtype
    elif isinstance(value, dict):
        return h5py.special_dtype(vlen=str)  # Store as JSON string
    else:
        return h5py.special_dtype(vlen=str)  # Default to string for unknown types


def serialize_value(value):
    """Convert a single info value into a form writable to an HDF5 dataset.

    Primitive scalars pass through unchanged; arrays become (nested)
    lists; dicts and heterogeneous lists are JSON-encoded; anything
    else is stringified.
    """
    primitive_types = (str, int, float, bool, np.integer, np.floating)
    if isinstance(value, primitive_types):
        return value
    if isinstance(value, np.ndarray):
        # Text arrays are normalized to Python strings before listing.
        if value.dtype.kind in ['U', 'S']:
            return value.astype(str).tolist()
        return value.tolist()
    if isinstance(value, list):
        if all(isinstance(item, primitive_types) for item in value):
            return value
        return json.dumps(value)
    if isinstance(value, dict):
        return json.dumps(value)
    return str(value)


def create_dict_dataset_in_group(group, dataset_name, dict_list: List[Dict[str, Any]]):
    """Write a list of dicts into ``group`` as one JSON string per element."""
    payload = [serialize_dict(entry) for entry in dict_list]
    str_dtype = h5py.special_dtype(vlen=str)
    dataset = group.create_dataset(dataset_name, (len(payload),), dtype=str_dtype)
    dataset[:] = payload
    return dataset

def read_dict_dataset_from_group(group, dataset_name):
    """Inverse of :func:`create_dict_dataset_in_group`: decode each JSON entry."""
    decoded = []
    for entry in group[dataset_name]:
        decoded.append(deserialize_dict(entry))
    return decoded
23 changes: 23 additions & 0 deletions minari/dataset/_storages/serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import json
from typing import Dict, Any

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy types to native Python equivalents.

    Arrays become nested lists; any NumPy scalar (integer, floating,
    bool_, str_, ...) is unwrapped via ``.item()``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # np.generic covers ALL NumPy scalars — the previous
        # (np.integer, np.floating) check missed np.bool_ and np.str_,
        # which would raise TypeError when present in an info dict.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


def serialize_dict(data: Dict[str, Any]) -> str:
    """Encode ``data`` as a JSON string, coercing NumPy values via NumpyEncoder."""
    return NumpyEncoder().encode(data)


def deserialize_dict(serialized_data: str) -> Dict[str, Any]:
    """Inverse of :func:`serialize_dict`: parse a JSON string back into a dict."""
    data: Dict[str, Any] = json.loads(serialized_data)
    return data
13 changes: 11 additions & 2 deletions minari/dataset/episode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@ class EpisodeData:
rewards: np.ndarray
terminations: np.ndarray
truncations: np.ndarray
infos: dict
infos: dict | list # dict is for backwards compatibility

def __len__(self) -> int:
return len(self.rewards)

def __repr__(self) -> str:
if isinstance(self.infos, dict):
infos_repr = f"infos=dict with the following keys: {list(self.infos.keys())}"
elif isinstance(self.infos, list):
infos_repr = (f"infos=list of dicts with the following keys: "
f"{set(key for d in self.infos for key in d.keys())}")
elif self.infos is None:
infos_repr = "infos=None"
else:
raise ValueError(f"Unexpected type for infos: {type(self.infos)}")
return (
"EpisodeData("
f"id={self.id}, "
Expand All @@ -29,7 +38,7 @@ def __repr__(self) -> str:
f"rewards=ndarray of {len(self.rewards)} floats, "
f"terminations=ndarray of {len(self.terminations)} bools, "
f"truncations=ndarray of {len(self.truncations)} bools, "
f"infos=dict with the following keys: {list(self.infos.keys())}"
f"{infos_repr}"
")"
)

Expand Down
35 changes: 22 additions & 13 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,18 +493,27 @@ def check_data_integrity(dataset: MinariDataset, episode_indices: List[int]):
assert total_steps == dataset.total_steps


def get_info_at_step_index(infos: Dict, step_index: int) -> Dict:
result = {}
for key in infos.keys():
if isinstance(infos[key], dict):
result[key] = get_info_at_step_index(infos[key], step_index)
elif isinstance(infos[key], np.ndarray):
result[key] = infos[key][step_index]
else:
raise ValueError(
"Infos are in an unsupported format; see Minari documentation for supported formats."
)
return result
def get_info_at_step_index(infos: Dict|List[Dict], step_index: int) -> Dict:

if isinstance(infos, dict): # for backwards compatibility
result = {}
for key in infos.keys():
if isinstance(infos[key], dict):
result[key] = get_info_at_step_index(infos[key], step_index)
elif isinstance(infos[key], np.ndarray):
result[key] = infos[key][step_index]
else:
raise ValueError(
"Infos are in an unsupported format; see Minari documentation for supported formats."
)
return result
elif isinstance(infos, list):
return infos[step_index]

else:
raise ValueError(
"Infos are in an unsupported format; see Minari documentation for supported formats."
)


def _reconstuct_obs_or_action_at_index_recursive(
Expand Down Expand Up @@ -599,7 +608,7 @@ def check_episode_data_integrity(
episode_data_list: Union[List[EpisodeData], MinariDataset],
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
info_sample: Optional[dict] = None,
info_sample: Optional[Union[dict, List[Dict]]] = None,
):
"""Checks to see if a list of EpisodeData instances has consistent data and that the observations and actions are in the appropriate spaces.

Expand Down
3 changes: 2 additions & 1 deletion tests/data_collector/callbacks/test_step_data_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_data_collector_step_data_callback_info_correction(
record_infos=True,
)
# here we are checking to make sure that if we have an environment changing its info
# structure across steps, it is results in a error
# structure across steps, IT IS OK!
with pytest.raises(ValueError):
num_episodes = 10
env.reset(seed=42)
Expand All @@ -176,4 +176,5 @@ def test_data_collector_step_data_callback_info_correction(
_, _, terminated, truncated, _ = env.step(action)

env.reset()
raise ValueError("This should BE reached")
env.close()
7 changes: 6 additions & 1 deletion tests/data_collector/test_data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ def test_truncation_without_reset(dataset_id, env_id, data_format, register_dumm
else:
assert np.array_equal(first_step.observations, last_step.observations)

check_infos_equal(last_step.infos, first_step.infos)
if isinstance(last_step.infos, dict) and isinstance(first_step.infos, dict):
check_infos_equal(last_step.infos, first_step.infos)
else:
assert last_step.infos == first_step.infos
check_infos_equal(last_step.infos[0], first_step.infos[0])

last_step = get_single_step_from_episode(episode, -1)
assert bool(last_step.truncations) is True

Expand Down
2 changes: 1 addition & 1 deletion tests/dataset/test_minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_episode_data(space: gym.Space):
rewards=rewards,
terminations=terminations,
truncations=truncations,
infos={"info": True},
infos=[{"info": True}],
)

pattern = r"EpisodeData\("
Expand Down
Loading