From 081dd040a341f6baa517fa50f3ee85d7eddcb256 Mon Sep 17 00:00:00 2001 From: LennoxFu Date: Mon, 23 Sep 2024 12:25:40 -0700 Subject: [PATCH] LeRobot frame slicing added --- fog_x/loader/lerobot.py | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/fog_x/loader/lerobot.py b/fog_x/loader/lerobot.py index 8953fb5..bb4efae 100644 --- a/fog_x/loader/lerobot.py +++ b/fog_x/loader/lerobot.py @@ -52,3 +52,64 @@ def _frame_to_numpy(frame): def get_batch(self): return next(self) + + +class LeRobotLoader_ByFrame(BaseLoader): + def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None): + super(LeRobotLoader, self).__init__(path) + self.batch_size = batch_size + self.dataset = LeRobotDataset(root="/mnt/data/fog_x/hf/", repo_id=dataset_name, delta_timestamps=delta_timestamps) + self.episode_index = 0 + + def __len__(self): + return len(self.dataset.episode_data_index["from"]) + + def __iter__(self): + return self + + def __next__(self): + max_retries = 3 + batch_of_episodes = [] + + def _frame_to_numpy(frame): + return {k: np.array(v) for k, v in frame.items()} + for _ in range(self.batch_size): + episode = [] + for attempt in range(max_retries): + try: + # repeat + if self.episode_index >= len(self.dataset): + self.episode_index = 0 + try: + from_idx = self.dataset.episode_data_index["from"][self.episode_index].item() + to_idx = self.dataset.episode_data_index["to"][self.episode_index].item() + except Exception as e: + self.episode_index = 0 + continue + + # Randomly select random_frames from episode + random_frames = 16 + episode_length = to_idx - from_idx + if episode_length <= random_frames: + random_from = from_idx + random_to = to_idx + else: + random_from = np.random.randint(from_idx, to_idx - 15) + random_to = random_from + 16 + frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(random_from, random_to)] + episode.extend(frames) + self.episode_index += 1 + break + except Exception as e: + if attempt == max_retries - 1: + raise e + self.episode_index += 1 + + + batch_of_episodes.append((episode)) + + + return batch_of_episodes + + def get_batch(self): + return next(self)