Skip to content

Commit

Permalink
Refactor loading method in fog_x/dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Apr 16, 2024
1 parent 85186b6 commit 913b0c6
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions fog_x/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,12 @@ def load_rtx_episodes(
metadata=additional_metadata,
)
for step in tf_episode["steps"]:
self._load_rtx_step_data_from_tf_step(
step, fog_episode, additional_metadata, data_type,
ret = self._load_rtx_step_data_from_tf_step(
step, additional_metadata, data_type,
)
for r in ret:
fog_episode.add(**r)

fog_episode.close()

def _prepare_rtx_metadata(
Expand Down Expand Up @@ -362,16 +365,17 @@ def _prepare_rtx_metadata(
metadata=additional_metadata,
)
for step in tf_episode["steps"]:
self._load_rtx_step_data_from_tf_step(
step, fog_episode, additional_metadata, data_type,
ret = self._load_rtx_step_data_from_tf_step(
step, additional_metadata, data_type,
)
for r in ret:
fog_episode.add(**r)
fog_episode.close(save_data = False)
counter += 1

def _load_rtx_step_data_from_tf_step(
self,
step: Dict[str, Any],
fog_episode : Episode,
additional_metadata: Optional[Dict[str, Any]] = None,
data_type: Dict[str, Any] = {},
):
Expand All @@ -382,6 +386,7 @@ def _load_rtx_step_data_from_tf_step(
Tensor,
Text,
)
ret = []

for k, v in step.items():
if k == "observation" or k == "action":
Expand All @@ -401,24 +406,44 @@ def _load_rtx_step_data_from_tf_step(
value = memfile.getvalue()
else:
value = v2.numpy()
fog_episode.add(
feature=str(k2),
value=value,
feature_type=FeatureType(
tf_feature_spec=data_type[k][k2]
),

ret.append(
{
"feature": str(k2),
"value": value,
"feature_type": FeatureType(
tf_feature_spec=data_type[k][k2]
),
}
)
# fog_episode.add(
# feature=str(k2),
# value=value,
# feature_type=FeatureType(
# tf_feature_spec=data_type[k][k2]
# ),
# )
if k == "observation":
self.obs_keys.append(k2)
elif k == "action":
self.act_keys.append(k2)
else:
fog_episode.add(
feature=str(k),
value=v.numpy(),
feature_type=FeatureType(tf_feature_spec=data_type[k]),
# fog_episode.add(
# feature=str(k),
# value=v.numpy(),
# feature_type=FeatureType(tf_feature_spec=data_type[k]),
# )
ret.append(
{
"feature": str(k),
"value": v.numpy(),
"feature_type": FeatureType(
tf_feature_spec=data_type[k]
),
}
)
self.step_keys.append(k)
return ret


def get_episode_info(self) -> pandas.DataFrame:
Expand Down

0 comments on commit 913b0c6

Please sign in to comment.