Skip to content

Commit

Permalink
Merge pull request #22 from BerkeleyAutomation/slicing
Browse files Browse the repository at this point in the history
LeRobot frame slicing added
  • Loading branch information
KeplerC authored Sep 23, 2024
2 parents 68da7d5 + 081dd04 commit ae494f9
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions fog_x/loader/lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,64 @@ def _frame_to_numpy(frame):

def get_batch(self):
return next(self)


class LeRobotLoader_ByFrame(BaseLoader):
def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None):
super(LeRobotLoader, self).__init__(path)
self.batch_size = batch_size
self.dataset = LeRobotDataset(root="/mnt/data/fog_x/hf/", repo_id=dataset_name, delta_timestamps=delta_timestamps)
self.episode_index = 0

def __len__(self):
return len(self.dataset.episode_data_index["from"])

def __iter__(self):
return self

def __next__(self):
max_retries = 3
batch_of_episodes = []

def _frame_to_numpy(frame):
return {k: np.array(v) for k, v in frame.items()}
for _ in range(self.batch_size):
episode = []
for attempt in range(max_retries):
try:
# repeat
if self.episode_index >= len(self.dataset):
self.episode_index = 0
try:
from_idx = self.dataset.episode_data_index["from"][self.episode_index].item()
to_idx = self.dataset.episode_data_index["to"][self.episode_index].item()
except Exception as e:
self.episode_index = 0
continue

# Randomly select random_frames from episode
random_frames = 16
episode_length = to_idx - from_idx
if episode_length <= random_frames:
random_from = from_idx
random_to = to_idx
else:
random_from = np.random.randint(from_idx, to_idx - 15)
random_to = random_from + 16
frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(random_from, random_to)]
episode.extend(frames)
self.episode_index += 1
break
except Exception as e:
if attempt == max_retries - 1:
raise e
self.episode_index += 1


batch_of_episodes.append((episode))


return batch_of_episodes

def get_batch(self):
return next(self)

0 comments on commit ae494f9

Please sign in to comment.