From e573046d3e0abf7443ce1bc4c31476538f5e3331 Mon Sep 17 00:00:00 2001 From: Kaiyuan Eric Chen Date: Mon, 23 Sep 2024 14:26:10 -0700 Subject: [PATCH] fix lerobot --- benchmarks/openx_by_frame.py | 4 ++-- evaluation.sh | 6 +++--- fog_x/loader/lerobot.py | 22 ++++++++-------------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/benchmarks/openx_by_frame.py b/benchmarks/openx_by_frame.py index f717045..94d3715 100644 --- a/benchmarks/openx_by_frame.py +++ b/benchmarks/openx_by_frame.py @@ -9,7 +9,7 @@ import fog_x import csv import stat -from fog_x.loader.lerobot import LeRobotLoader +from fog_x.loader.lerobot import LeRobotLoader_ByFrame from fog_x.loader.vla import get_vla_dataloader from fog_x.loader.hdf5 import get_hdf5_dataloader @@ -310,7 +310,7 @@ def __init__( def get_loader(self): path = os.path.join(self.exp_dir, "hf") - return LeRobotLoader(path, self.dataset_name, batch_size=self.batch_size) + return LeRobotLoader_ByFrame(path, self.dataset_name, batch_size=1, slice_length=self.batch_size) def _recursively_load_data(self, data): import torch diff --git a/evaluation.sh b/evaluation.sh index 66303e2..28ee235 100755 --- a/evaluation.sh +++ b/evaluation.sh @@ -3,7 +3,7 @@ sudo echo "Use sudo access for clearning cache" # Define a list of batch sizes to iterate through -batch_sizes=(4) +batch_sizes=(64) num_batches=200 # batch_sizes=(1 2) @@ -16,7 +16,7 @@ do echo "Running benchmarks with batch size: $batch_size" # python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size - python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size - # python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size + # python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size + python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size # python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size done \ No newline at end of file diff --git a/fog_x/loader/lerobot.py b/fog_x/loader/lerobot.py index bb4efae..fd69611 100644 --- a/fog_x/loader/lerobot.py +++ b/fog_x/loader/lerobot.py @@ -55,11 +55,11 @@ def get_batch(self): class LeRobotLoader_ByFrame(BaseLoader): - def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None): - super(LeRobotLoader, self).__init__(path) + def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None, slice_length=16): + super(LeRobotLoader_ByFrame, self).__init__(path) self.batch_size = batch_size self.dataset = LeRobotDataset(root="/mnt/data/fog_x/hf/", repo_id=dataset_name, delta_timestamps=delta_timestamps) - self.episode_index = 0 + self.slice_length = slice_length def __len__(self): return len(self.dataset.episode_data_index["from"]) @@ -78,33 +78,27 @@ def _frame_to_numpy(frame): for attempt in range(max_retries): try: # repeat - if self.episode_index >= len(self.dataset): - self.episode_index = 0 + self.episode_index = np.random.randint(0, len(self.dataset)) try: from_idx = self.dataset.episode_data_index["from"][self.episode_index].item() to_idx = self.dataset.episode_data_index["to"][self.episode_index].item() except Exception as e: - self.episode_index = 0 continue # Randomly select random_frames from episode - random_frames = 16 episode_length = to_idx - from_idx - if episode_length <= random_frames: + if episode_length <= self.slice_length: random_from = from_idx random_to = to_idx else: - random_from = np.random.randint(from_idx, to_idx - 15) - random_to = random_from + 16 - frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(random_from, random_to)] + random_from = np.random.randint(from_idx, to_idx - self.slice_length) + random_to = random_from + self.slice_length + frames = [self.dataset[idx] for idx in range(random_from, random_to)] episode.extend(frames) - self.episode_index += 1 break except Exception as e: if attempt == max_retries - 1: raise e - self.episode_index += 1 - batch_of_episodes.append((episode))