Merge pull request #4 from KeplerC/eddie-changes
incorporated most of Peter's comments
KeplerC authored Apr 16, 2024
2 parents c5000d5 + df2013b commit daac973
Showing 1 changed file with 178 additions and 164 deletions.
fog_x/dataset.py: 342 changes (178 additions, 164 deletions)
@@ -138,142 +138,157 @@ def export(
             format (str): Supported formats are `rtx`, `open-x`, and `rlds`.
         """
         if format == "rtx" or format == "open-x" or format == "rlds":
-            if export_path == None:
-                export_path = self.path + "/export"
-            if not os.path.exists(export_path):
-                os.makedirs(export_path)
-
-            import dm_env
-            import tensorflow as tf
-            import tensorflow_datasets as tfds
-            from envlogger import step_data
-            from tensorflow_datasets.core.features import Tensor
-
-            from fog_x.rlds.writer import CloudBackendWriter
-
-            self.obs_keys += obs_keys
-            self.act_keys += act_keys
-            self.step_keys += step_keys
-
-            (
-                observation_tf_dict,
-                action_tf_dict,
-                step_tf_dict,
-            ) = self._get_tf_feature_dicts(
-                self.obs_keys,
-                self.act_keys,
-                self.step_keys,
-            )
-
-            logger.info("Exporting dataset as RT-X format")
-            logger.info(f"Observation keys: {observation_tf_dict}")
-            logger.info(f"Action keys: {action_tf_dict}")
-            logger.info(f"Step keys: {step_tf_dict}")
-
-            # generate tensorflow configuration file
-            ds_config = tfds.rlds.rlds_base.DatasetConfig(
-                name=self.name,
-                description="",
-                homepage="",
-                citation="",
-                version=tfds.core.Version("0.0.1"),
-                release_notes={
-                    "0.0.1": "Initial release.",
-                },
-                observation_info=observation_tf_dict,
-                action_info=action_tf_dict,
-                reward_info=(
-                    step_tf_dict["reward"]
-                    if "reward" in step_tf_dict
-                    else Tensor(shape=(), dtype=tf.float32)
-                ),
-                discount_info=(
-                    step_tf_dict["discount"]
-                    if "discount" in step_tf_dict
-                    else Tensor(shape=(), dtype=tf.float32)
-                ),
-            )
-
-            ds_identity = tfds.core.dataset_info.DatasetIdentity(
-                name=ds_config.name,
-                version=tfds.core.Version(version),
-                data_dir=export_path,
-                module_name="",
-            )
-            writer = CloudBackendWriter(
-                data_directory=export_path,
-                ds_config=ds_config,
-                ds_identity=ds_identity,
-                max_episodes_per_file=max_episodes_per_file,
-            )
-
-            # export the dataset
-            episodes = self.get_episodes_from_metadata()
-            for episode in episodes:
-                steps = episode.collect().rows(named=True)
-                for i in range(len(steps)):
-                    step = steps[i]
-                    observationd = {}
-                    actiond = {}
-                    stepd = {}
-                    for k, v in step.items():
-                        # logger.info(f"key: {k}")
-                        if k not in self.features:
-                            if k != "episode_id" and k != "Timestamp":
-                                logger.info(
-                                    f"Feature {k} not found in the dataset features."
-                                )
-                            continue
-                        feature_spec = self.features[k].to_tf_feature_type()
-                        if (
-                            isinstance(feature_spec, tfds.core.features.Tensor)
-                            and feature_spec.shape != ()
-                        ):
-                            # reverse the process
-                            value = np.load(io.BytesIO(v)).astype(feature_spec.np_dtype)
-                        elif (
-                            isinstance(feature_spec, tfds.core.features.Tensor)
-                            and feature_spec.shape == ()
-                        ):
-                            value = np.array(v, dtype=feature_spec.np_dtype)
-                        elif isinstance(feature_spec, tfds.core.features.Image):
-                            value = np.load(io.BytesIO(v)).astype(feature_spec.np_dtype)
-                        else:
-                            value = v
-
-                        if k in self.obs_keys:
-                            observationd[k] = value
-                        elif k in self.act_keys:
-                            actiond[k] = value
-                        else:
-                            stepd[k] = value
-
-                    # logger.info(
-                    #     f"Step: {stepd}"
-                    #     f"Observation: {observationd}"
-                    #     f"Action: {actiond}"
-                    # )
-                    timestep = dm_env.TimeStep(
-                        step_type=dm_env.StepType.FIRST,
-                        reward=np.float32(
-                            0.0
-                        ),  # stepd["reward"] if "reward" in step else np.float32(0.0),
-                        discount=np.float32(
-                            0.0
-                        ),  # stepd["discount"] if "discount" in step else np.float32(0.0),
-                        observation=observationd,
-                    )
-                    stepdata = step_data.StepData(
-                        timestep=timestep, action=actiond, custom_data=None
-                    )
-                    if i < len(steps) - 1:
-                        writer._record_step(stepdata, is_new_episode=False)
-                    else:
-                        writer._record_step(stepdata, is_new_episode=True)
-
-            pass
-        else:
-            raise ValueError("Unsupported export format")
+            self.export_rtx(export_path, max_episodes_per_file, version, obs_keys, act_keys, step_keys)
+        else:
+            raise ValueError("Unsupported export format")
+
+    def export_rtx(
+        self,
+        export_path: Optional[str] = None,
+        max_episodes_per_file: int = 1,
+        version: str = "0.0.1",
+        obs_keys=[],
+        act_keys=[],
+        step_keys=[]
+    ):
+        if export_path == None:
+            export_path = self.path + "/export"
+        if not os.path.exists(export_path):
+            os.makedirs(export_path)
+
+        import dm_env
+        import tensorflow as tf
+        import tensorflow_datasets as tfds
+        from envlogger import step_data
+        from tensorflow_datasets.core.features import Tensor
+
+        from fog_x.rlds.writer import CloudBackendWriter
+
+        self.obs_keys += obs_keys
+        self.act_keys += act_keys
+        self.step_keys += step_keys
+
+        (
+            observation_tf_dict,
+            action_tf_dict,
+            step_tf_dict,
+        ) = self._get_tf_feature_dicts(
+            self.obs_keys,
+            self.act_keys,
+            self.step_keys,
+        )
+
+        logger.info("Exporting dataset as RT-X format")
+        logger.info(f"Observation keys: {observation_tf_dict}")
+        logger.info(f"Action keys: {action_tf_dict}")
+        logger.info(f"Step keys: {step_tf_dict}")
+
+        # generate tensorflow configuration file
+        ds_config = tfds.rlds.rlds_base.DatasetConfig(
+            name=self.name,
+            description="",
+            homepage="",
+            citation="",
+            version=tfds.core.Version("0.0.1"),
+            release_notes={
+                "0.0.1": "Initial release.",
+            },
+            observation_info=observation_tf_dict,
+            action_info=action_tf_dict,
+            reward_info=(
+                step_tf_dict["reward"]
+                if "reward" in step_tf_dict
+                else Tensor(shape=(), dtype=tf.float32)
+            ),
+            discount_info=(
+                step_tf_dict["discount"]
+                if "discount" in step_tf_dict
+                else Tensor(shape=(), dtype=tf.float32)
+            ),
+        )
+
+        ds_identity = tfds.core.dataset_info.DatasetIdentity(
+            name=ds_config.name,
+            version=tfds.core.Version(version),
+            data_dir=export_path,
+            module_name="",
+        )
+        writer = CloudBackendWriter(
+            data_directory=export_path,
+            ds_config=ds_config,
+            ds_identity=ds_identity,
+            max_episodes_per_file=max_episodes_per_file,
+        )
+
+        # export the dataset
+        episodes = self.get_episodes_from_metadata()
+        for episode in episodes:
+            steps = episode.collect().rows(named=True)
+            for i in range(len(steps)):
+                step = steps[i]
+                observationd = {}
+                actiond = {}
+                stepd = {}
+                for k, v in step.items():
+                    # logger.info(f"key: {k}")
+                    if k not in self.features:
+                        if k != "episode_id" and k != "Timestamp":
+                            logger.info(
+                                f"Feature {k} not found in the dataset features."
+                            )
+                        continue
+                    feature_spec = self.features[k].to_tf_feature_type()
+                    if (
+                        isinstance(feature_spec, tfds.core.features.Tensor)
+                        and feature_spec.shape != ()
+                    ):
+                        # reverse the process
+                        value = np.load(io.BytesIO(v)).astype(
+                            feature_spec.np_dtype
+                        )
+                    elif (
+                        isinstance(feature_spec, tfds.core.features.Tensor)
+                        and feature_spec.shape == ()
+                    ):
+                        value = np.array(v, dtype=feature_spec.np_dtype)
+                    elif isinstance(
+                        feature_spec, tfds.core.features.Image
+                    ):
+                        value = np.load(io.BytesIO(v)).astype(
+                            feature_spec.np_dtype
+                        )
+                    else:
+                        value = v
+
+                    if k in self.obs_keys:
+                        observationd[k] = value
+                    elif k in self.act_keys:
+                        actiond[k] = value
+                    else:
+                        stepd[k] = value
+
+                # logger.info(
+                #     f"Step: {stepd}"
+                #     f"Observation: {observationd}"
+                #     f"Action: {actiond}"
+                # )
+                timestep = dm_env.TimeStep(
+                    step_type=dm_env.StepType.FIRST,
+                    reward=np.float32(
+                        0.0
+                    ),  # stepd["reward"] if "reward" in step else np.float32(0.0),
+                    discount=np.float32(
+                        0.0
+                    ),  # stepd["discount"] if "discount" in step else np.float32(0.0),
+                    observation=observationd,
+                )
+                stepdata = step_data.StepData(
+                    timestep=timestep, action=actiond, custom_data=None
+                )
+                if i < len(steps) - 1:
+                    writer._record_step(stepdata, is_new_episode=False)
+                else:
+                    writer._record_step(stepdata, is_new_episode=True)
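
Note on this hunk: export() becomes a thin format dispatcher and the RT-X logic moves unchanged into export_rtx(). A minimal usage sketch of the refactored call path follows; the Dataset constructor arguments are assumptions for illustration, not taken from this diff.

    # Hedged sketch: calling the refactored export path.
    # The Dataset constructor arguments are assumed, not shown in this diff.
    import fog_x

    dataset = fog_x.dataset.Dataset(name="demo", path="/tmp/fog_x")  # assumed constructor

    # export() now only validates `format` and delegates to export_rtx()
    dataset.export(format="rtx")

    # export_rtx() takes the same knobs that export() forwards to it
    dataset.export_rtx(
        export_path="/tmp/fog_x/export",
        max_episodes_per_file=1,
        version="0.0.1",
    )

One design caveat visible in the new signature: the mutable default arguments (obs_keys=[], act_keys=[], step_keys=[]) are shared across calls in Python, so repeated exports on the same process could accumulate keys.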

@@ -315,7 +330,7 @@ def load_rtx_episodes(

         for tf_episode in ds:
             logger.info(tf_episode)
-            fog_epsiode = self.new_episode(
+            fog_episode = self.new_episode(
                 metadata=additional_metadata,
             )
             for step in tf_episode["steps"]:
@@ -338,7 +353,7 @@ def load_rtx_episodes(
                     else:
                         value = v2.numpy()

-                    fog_epsiode.add(
+                    fog_episode.add(
                         feature=str(k2),
                         value=value,
                         feature_type=FeatureType(
@@ -350,13 +365,13 @@ def load_rtx_episodes(
                 elif k == "action":
                     self.act_keys.append(k2)
                 else:
-                    fog_epsiode.add(
+                    fog_episode.add(
                         feature=str(k),
                         value=v.numpy(),
                         feature_type=FeatureType(tf_feature_spec=data_type[k]),
                     )
                 self.step_keys.append(k)
-        fog_epsiode.close()
+        fog_episode.close()
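
These three hunks only fix the fog_epsiode/fog_episode typo, but the surrounding context shows the episode-recording pattern the importer relies on. A hedged sketch of that pattern; the Dataset constructor, the feature name, and omitting feature_type are assumptions (the diff always passes feature_type=FeatureType(...)).

    # Hedged sketch of the new_episode()/add()/close() pattern shown above.
    import numpy as np
    import fog_x

    dataset = fog_x.dataset.Dataset(name="demo", path="/tmp/fog_x")  # assumed constructor

    fog_episode = dataset.new_episode()          # the diff also passes metadata=...
    fog_episode.add(
        feature="action",                        # assumed feature name
        value=np.zeros(7, dtype=np.float32),     # assumed value
    )
    fog_episode.close()                          # finalize, as fog_episode.close() does above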

@@ -429,35 +444,6 @@ def pytorch_dataset_builder(self, metadata=None, **kwargs):

import torch
from torch.utils.data import Dataset

-        class PyTorchDataset(Dataset):
-            def __init__(self, episodes, features):
-                """
-                Initialize the dataset with the episodes and features.
-                :param episodes: A list of episodes loaded from the database.
-                :param features: A dictionary of features to be included in the dataset.
-                """
-                self.episodes = episodes
-                self.features = features
-
-            def __len__(self):
-                """
-                Return the total number of episodes in the dataset.
-                """
-                return len(self.episodes)
-
-            def __getitem__(self, idx):
-                """
-                Retrieve the idx-th episode from the dataset.
-                Depending on the structure, you may need to process the episode
-                and its features here.
-                """
-                print("Retrieving episode at index", idx)
-                episode = self.episodes[idx].collect().to_pandas()
-                # Process the episode and its features here
-                # For simplicity, let's assume we're just returning the episode
-                return episode

episodes = self.get_episodes_from_metadata(metadata)

# Initialize the PyTorch dataset with the episodes and features
@@ -485,3 +471,31 @@ def get_as_huggingface_dataset(self):

hf_dataset = datasets.load_dataset("parquet", data_files=parquet_files)
return hf_dataset

+
+class PyTorchDataset(Dataset):
+    def __init__(self, episodes, features):
+        """
+        Initialize the dataset with the episodes and features.
+        :param episodes: A list of episodes loaded from the database.
+        :param features: A dictionary of features to be included in the dataset.
+        """
+        self.episodes = episodes
+        self.features = features
+
+    def __len__(self):
+        """
+        Return the total number of episodes in the dataset.
+        """
+        return len(self.episodes)
+
+    def __getitem__(self, idx):
+        """
+        Retrieve the idx-th episode from the dataset.
+        Depending on the structure, you may need to process the episode
+        and its features here.
+        """
+        print("Retrieving episode at index", idx)
+        episode = self.episodes[idx].collect().to_pandas()
+        # Process the episode and its features here
+        # For simplicity, let's assume we're just returning the episode
+        return episode
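
The relocated PyTorchDataset returns one pandas DataFrame per episode from __getitem__, so torch's default collate function cannot batch it. A hedged usage sketch follows; that pytorch_dataset_builder() returns the PyTorchDataset instance is an assumption, since the diff does not show the method's return statement.

    # Hedged usage sketch for the relocated PyTorchDataset class.
    from torch.utils.data import DataLoader

    torch_dataset = dataset.pytorch_dataset_builder()  # assumed to return PyTorchDataset
    loader = DataLoader(
        torch_dataset,
        batch_size=1,                    # each item is a pandas DataFrame, so
        collate_fn=lambda batch: batch,  # bypass default tensor collation
    )
    for episodes in loader:
        print(episodes[0].head())        # first (and only) episode in the batch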
