Skip to content

Commit

Permalink
chore: Add support for unit selection in HDF5 loader
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Sep 23, 2024
1 parent 3b841fb commit 68da7d5
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
5 changes: 4 additions & 1 deletion benchmarks/openx_by_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
self.log_frequency = log_frequency
self.results = []
self.log_level = "debug"
self.unit = "frame"

def measure_average_trajectory_size(self):
"""Calculates the average size of trajectory files in the dataset directory."""
Expand Down Expand Up @@ -237,7 +238,8 @@ def __init__(

def get_loader(self):
    """Build a VLA dataloader for this benchmark's dataset directory.

    Returns:
        The dataloader produced by ``get_vla_dataloader``, configured with
        this benchmark's batch size, the shared cache directory, and the
        iteration unit (``self.unit``, e.g. "frame" vs. "trajectory").
    """
    # PEP 8: no spaces around '=' in keyword arguments.
    return get_vla_dataloader(
        self.dataset_dir,
        batch_size=self.batch_size,
        cache_dir=CACHE_DIR,
        unit=self.unit,
    )


Expand Down Expand Up @@ -265,6 +267,7 @@ def get_loader(self):
path=os.path.join(self.dataset_dir, "*.h5"),
batch_size=self.batch_size,
num_workers=0, # You can adjust this if needed
unit = self.unit,
)


Expand Down
2 changes: 1 addition & 1 deletion evaluation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ do
echo "Running benchmarks with batch size: $batch_size"

# python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
done
23 changes: 20 additions & 3 deletions fog_x/loader/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __del__(self):
p.join()


class HDF5IterableDataset(IterableDataset):
class HDF5IterableEpisodeDataset(IterableDataset):
def __init__(self, path, batch_size=1):
# Note: batch size = 1 is to bypass the dataloader without pytorch dataloader
self.hdf5_loader = HDF5Loader(path, 1)
Expand All @@ -120,9 +120,26 @@ def hdf5_collate_fn(batch):
# Convert data to PyTorch tensors
return batch

class HDF5IterableFrameDataset(IterableDataset):
    """Iterable dataset yielding one item at a time from HDF5 files.

    Unlike the episode-level dataset, ``__next__`` unwraps the
    single-element batch produced by the underlying loader and returns
    the item itself, so the outer ``DataLoader`` can re-batch at the
    "frame" granularity.
    """

    def __init__(self, path, batch_size=1):
        # ``batch_size`` is accepted for interface symmetry with the
        # episode dataset but deliberately not forwarded: the loader is
        # pinned to batch size 1 and batching is left to the DataLoader.
        self.hdf5_loader = HDF5Loader(path, 1)

    def __iter__(self):
        return self

    def __next__(self):
        # next() already raises StopIteration on exhaustion, so the
        # original try/except re-raise was a redundant no-op.
        batch = next(self.hdf5_loader)
        # Loader yields batches of size 1; unwrap to a single item.
        # NOTE(review): assumes batch[0] is one frame for the "frame"
        # unit — confirm against HDF5Loader's yield semantics.
        return batch[0]

def get_hdf5_dataloader(path: str, batch_size: int = 1, num_workers: int = 0):
dataset = HDF5IterableDataset(path, batch_size)
def get_hdf5_dataloader(path: str, batch_size: int = 1, num_workers: int = 0, unit: str = "trajectory"):
if unit == "trajectory":
dataset = HDF5IterableEpisodeDataset(path, batch_size)
elif unit == "frame":
dataset = HDF5IterableFrameDataset(path, batch_size)
return DataLoader(
dataset,
batch_size=batch_size,
Expand Down

0 comments on commit 68da7d5

Please sign in to comment.