Skip to content

Commit

Permalink
Refactor dataset loading method and update dataset list in dataset_organizer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Apr 16, 2024
1 parent e065c9e commit a29712b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 11 deletions.
78 changes: 70 additions & 8 deletions examples/analytics/dataset_organizer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,74 @@
import fog_x

# Single-dataset demo: open a fog_x dataset named "demo_ds" rooted at
# ~/test_dataset, then run RT-X metadata preparation for one named dataset.
dataset = fog_x.dataset.Dataset(name="demo_ds", path="~/test_dataset")

dataset._prepare_rtx_metadata(name="berkeley_autolab_ur5")



# Names of the RT-X datasets to process; the loop below walks this list in
# order and passes each entry as the `name` of the dataset to load.
DATASETS = [
    "fractal20220817_data",
    "kuka",
    "bridge",
    "taco_play",
    "jaco_play",
    "berkeley_cable_routing",
    "roboturk",
    "nyu_door_opening_surprising_effectiveness",
    "viola",
    "berkeley_autolab_ur5",
    "toto",
    "language_table",
    "columbia_cairlab_pusht_real",
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
    "nyu_rot_dataset_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_buds_dataset_converted_externally_to_rlds",
    "nyu_franka_play_dataset_converted_externally_to_rlds",
    "maniskill_dataset_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "austin_sailor_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "bc_z",
    "usc_cloth_sim_converted_externally_to_rlds",
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds",
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
    "utokyo_saytap_converted_externally_to_rlds",
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "robo_net",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "stanford_mask_vit_converted_externally_to_rlds",
    "tokyo_u_lsmo_converted_externally_to_rlds",
    "dlr_sara_pour_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "dlr_edan_shared_control_converted_externally_to_rlds",
    "asu_table_top_converted_externally_to_rlds",
    "stanford_robocook_converted_externally_to_rlds",
    "eth_agent_affordances",
    "imperialcollege_sawyer_wrist_cam",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "uiuc_d3field",
    "utaustin_mutex",
    "berkeley_fanuc_manipulation",
    "cmu_play_fusion",
    "cmu_stretch",
    "berkeley_gnm_recon",
    "berkeley_gnm_cory_hall",
    "berkeley_gnm_sac_son",
]

# Run RT-X metadata preparation once per entry in DATASETS, requesting a
# shuffled sample of size 10 for each.
for dataset_name in DATASETS:
    # Each dataset gets its own handle; all of them share one root path.
    dataset = fog_x.dataset.Dataset(name=dataset_name, path="~/test_dataset")

    dataset._prepare_rtx_metadata(
        name=dataset_name,
        sample_size=10,
        shuffle=True,
    )


15 changes: 12 additions & 3 deletions fog_x/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def _prepare_rtx_metadata(
self,
name: str,
export_path: Optional[str] = None,
sample_size = 20,
shuffle = False,
seed = 42,
):

# this is only required if rtx format is used
Expand All @@ -350,11 +353,13 @@ def _prepare_rtx_metadata(

b = tfds.builder_from_directory(builder_dir=dataset2path(name))
ds = b.as_dataset(split="all")
if shuffle:
ds = ds.shuffle(sample_size, seed=seed)
data_type = b.info.features["steps"]
counter = 0

if export_path == None:
export_path = self.path + "/viz"
export_path = self.path + "/" + self.name + "_viz"
if not os.path.exists(export_path):
os.makedirs(export_path)

Expand All @@ -364,7 +369,7 @@ def _prepare_rtx_metadata(

additional_metadata = {
"load_from": name,
"load_index": f"all, {counter}",
"load_index": f"all, {shuffle}, {seed}, {counter}",
}

logger.info(tf_episode)
Expand All @@ -384,14 +389,16 @@ def _prepare_rtx_metadata(
output_filename = f"{self.name}_{counter}_{feature_name}"
output_path = f"{export_path}/{output_filename}"

frame_size = (image.shape[1], image.shape[0])

# save the initial image
cv2.imwrite(f"{output_path}.jpg", image)
# save the video
video_writers[feature_name] = cv2.VideoWriter(
f"{output_path}.mp4",
cv2.VideoWriter_fourcc(*"mp4v"),
10,
(640, 480),
frame_size
)
additional_metadata[f"video_path_{feature_name}"] = output_filename

Expand All @@ -409,6 +416,8 @@ def _prepare_rtx_metadata(
video_writers = {}
fog_episode.close(save_data = False, additional_metadata = additional_metadata)
counter += 1
if counter > sample_size:
break

def _load_rtx_step_data_from_tf_step(
self,
Expand Down

0 comments on commit a29712b

Please sign in to comment.