diff --git a/benchmarks/openx_by_frame.py b/benchmarks/openx_by_frame.py
index 6979af1..f717045 100644
--- a/benchmarks/openx_by_frame.py
+++ b/benchmarks/openx_by_frame.py
@@ -237,10 +237,17 @@ def __init__(
         self.file_extension = ".vla"
 
     def get_loader(self):
-        return get_vla_dataloader(
-            self.dataset_dir, batch_size=self.batch_size, cache_dir=CACHE_DIR,
-            unit = self.unit,
-        )
+        if self.unit == "frame":
+            return get_vla_dataloader(
+                self.dataset_dir, batch_size=1, cache_dir=CACHE_DIR,
+                unit = self.unit,
+                slice_size=self.batch_size,
+            )
+        else:
+            return get_vla_dataloader(
+                self.dataset_dir, batch_size=self.batch_size, cache_dir=CACHE_DIR,
+                unit = self.unit,
+            )
 
 
 class HDF5Handler(DatasetHandler):
@@ -352,13 +359,13 @@ def evaluation(args):
         logger.debug(f"Evaluating dataset: {dataset_name}")
 
         handlers = [
-            # VLAHandler(
-            #     args.exp_dir,
-            #     dataset_name,
-            #     args.num_batches,
-            #     args.batch_size,
-            #     args.log_frequency,
-            # ),
+            VLAHandler(
+                args.exp_dir,
+                dataset_name,
+                args.num_batches,
+                args.batch_size,
+                args.log_frequency,
+            ),
             HDF5Handler(
                 args.exp_dir,
                 dataset_name,
@@ -366,13 +373,13 @@ def evaluation(args):
                 args.batch_size,
                 args.log_frequency,
             ),
-            # LeRobotHandler(
-            #     args.exp_dir,
-            #     dataset_name,
-            #     args.num_batches,
-            #     args.batch_size,
-            #     args.log_frequency,
-            # ),
+            LeRobotHandler(
+                args.exp_dir,
+                dataset_name,
+                args.num_batches,
+                args.batch_size,
+                args.log_frequency,
+            ),
             # RLDSHandler(
             #     args.exp_dir,
             #     dataset_name,
diff --git a/fog_x/loader/vla.py b/fog_x/loader/vla.py
index 2db5ace..92f8f96 100644
--- a/fog_x/loader/vla.py
+++ b/fog_x/loader/vla.py
@@ -196,6 +196,163 @@ def _read_vla(self, data_path, return_type = None):
     def get_batch(self):
         return [self.__next__() for _ in range(self.batch_size)]
 
+class VLAFrameLoader:
+    """Stream random frame slices sampled from ``.vla`` trajectory files.
+
+    Background worker processes repeatedly pick a random file, read a random
+    window of ``slice_size`` frames from it, and push that window onto a
+    bounded queue that ``get_batch_by_slice`` drains.
+    """
+
+    def __init__(self, path: Text, batch_size=1, cache_dir="/tmp/fog_x/cache/", buffer_size=50, num_workers=-1, return_type="numpy", split="all", slice_size=1):
+        self.files = self._get_files(path, split)
+        self.split = split
+
+        self.cache_dir = cache_dir
+        self.batch_size = batch_size
+        self.return_type = return_type
+        self.buffer_size = buffer_size
+        self.buffer = mp.Queue(maxsize=buffer_size)
+        if num_workers == -1:
+            num_workers = 2
+        self.num_workers = num_workers
+        self.processes = []
+        self.slice_size = slice_size
+        random.shuffle(self.files)
+        self._start_workers()
+
+    def _get_files(self, path, split):
+        """Resolve ``path`` (glob pattern, directory or single file) to a file list.
+
+        ``split`` keeps the first 90% ("train"), the last 10% ("val") or
+        everything ("all"); any other value raises ``ValueError``.
+        """
+        if "*" in path:
+            ret = glob.glob(path)
+        elif os.path.isdir(path):
+            ret = glob.glob(os.path.join(path, "*.vla"))
+        else:
+            ret = [path]
+        # glob returns files in arbitrary order; sort so that the train/val
+        # boundary is the same on every run and every machine.
+        ret.sort()
+        cut = int(len(ret) * 0.9)
+        if split == "train":
+            ret = ret[:cut]
+        elif split == "val":
+            ret = ret[cut:]
+        elif split != "all":
+            raise ValueError(f"Invalid split: {split}")
+        return ret
+
+    def _read_vla_slice(self, data_path):
+        """Load a random window of at most ``slice_size`` frames from one file."""
+        traj = fog_x.Trajectory(data_path, cache_dir=self.cache_dir)
+        total_frames = len(traj)
+        if self.slice_size > total_frames:
+            # Trajectory shorter than the window: take all of it.
+            start_idx = 0
+            end_idx = total_frames
+        else:
+            start_idx = random.randint(0, total_frames - self.slice_size)
+            end_idx = start_idx + self.slice_size
+
+        slice_data = traj.load_slice(start_idx, end_idx)
+        return slice_data
+
+    def _worker(self):
+        """Worker-process loop that keeps feeding random slices into the buffer.
+
+        The loop only stops by itself when the file list is empty; otherwise it
+        runs until the process is terminated. Each read gets ``max_retries``
+        attempts, with a fresh random file on every attempt.
+        """
+        max_retries = 3
+        while True:
+            if not self.files:
+                logger.info("Worker finished")
+                break
+
+            for attempt in range(max_retries):
+                try:
+                    file_path = random.choice(self.files)
+                    data = self._read_vla_slice(file_path)
+                    self.buffer.put(data)
+                    break  # Exit the retry loop if successful
+                except Exception as e:
+                    logger.error(f"Error reading {file_path} on attempt {attempt + 1}: {e}")
+                    if attempt + 1 == max_retries:
+                        logger.error(f"Failed to read {file_path} after {max_retries} attempts")
+
+    def _start_workers(self):
+        """Spawn ``num_workers`` background reader processes."""
+        for _ in range(self.num_workers):
+            # Workers loop forever, so they must be daemonic: otherwise the
+            # interpreter would wait on them at exit if __del__ never ran.
+            p = mp.Process(target=self._worker, daemon=True)
+            p.start()
+            logger.debug(f"Started worker {p.pid}")
+            self.processes.append(p)
+
+    def get_batch_by_slice(self):
+        """Collect up to ``batch_size`` slices from the buffer.
+
+        Returns ``None`` if nothing was collected, every worker has exited and
+        the buffer is empty. Otherwise returns a list, which may hold fewer
+        than ``batch_size`` items (even zero) once the 5 second deadline passes.
+        """
+        from queue import Empty  # the exception mp.Queue.get raises on timeout
+
+        batch = []
+        timeout = 5  # Adjust this value based on your needs
+        start_time = time.time()
+
+        while len(batch) < self.batch_size:
+            if time.time() - start_time > timeout:
+                logger.warning(f"Timeout reached while getting batch. Batch size: {len(batch)}")
+                break
+
+            try:
+                item = self.buffer.get(timeout=1)
+                batch.append(item)
+            except Empty:
+                if all(not p.is_alive() for p in self.processes) and self.buffer.empty():
+                    if len(batch) == 0:
+                        return None  # No more data available
+                    else:
+                        break  # Return partial batch
+
+        return batch
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = self.get_batch_by_slice()
+        if batch is None:
+            # None means every worker has exited; reap them before respawning
+            # so self.processes does not keep growing with dead processes.
+            for p in self.processes:
+                p.join()
+            self.processes = []
+            random.shuffle(self.files)
+            self._start_workers()
+            raise StopIteration
+        return batch
+
+    def __len__(self):
+        return len(self.files)
+
+    def peek(self):
+        file = random.choice(self.files)
+        return self._read_vla_slice(file)
+
+    def __del__(self):
+        # __init__ may raise before self.processes exists (e.g. bad split).
+        for p in getattr(self, "processes", []):
+            p.terminate()
+            p.join()
+
 import torch
 from torch.utils.data import IterableDataset, DataLoader
 from fog_x.loader.vla import VLALoader
@@ -216,6 +373,26 @@ def __next__(self):
             raise StopIteration
         return batch[0]  # Return a single item, not a batch
 
+class VLAIterableFrameDataset(IterableDataset):
+    """Iterable dataset that yields one frame slice at a time from a VLAFrameLoader."""
+
+    def __init__(self, path: Text, cache_dir: Optional[Text] = None, buffer_size: int = 1000, slice_size: int = 1):
+        self.vla_loader = VLAFrameLoader(path, batch_size=1, cache_dir=cache_dir, buffer_size=buffer_size, slice_size=slice_size)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            batch = self.vla_loader.get_batch_by_slice()
+            if batch is None:
+                raise StopIteration
+            if batch:
+                return batch[0]  # Return a single item, not a batch
+            # Empty list: the fetch timed out while workers are still running
+            # (e.g. the first read is still building its cache), so wait again
+            # instead of indexing into an empty batch.
+
 def vla_collate_fn(batch):
     # Convert data to PyTorch tensors
     # You may need to adjust this based on the structure of your VLA data
@@ -226,9 +403,17 @@ def get_vla_dataloader(
     batch_size: int = 1,
     cache_dir: Optional[Text] = None,
     buffer_size: int = 1000,
-    num_workers: int = 0
+    num_workers: int = 0,
+    unit: str = "trajectory",
+    slice_size: int = 1
 ):
-    dataset = VLAIterableDataset(path, cache_dir, buffer_size)
+    if unit == "trajectory":
+        dataset = VLAIterableDataset(path, cache_dir, buffer_size)
+    elif unit == "frame":
+        dataset = VLAIterableFrameDataset(path, cache_dir, buffer_size, slice_size)
+    else:
+        raise ValueError(f"Invalid unit: {unit}. Choose 'trajectory' or 'frame'.")
+
     return DataLoader(
         dataset,
         batch_size=batch_size,
diff --git a/fog_x/trajectory.py b/fog_x/trajectory.py
index da8f9d7..69c67ad 100644
--- a/fog_x/trajectory.py
+++ b/fog_x/trajectory.py
@@ -107,7 +107,19 @@ def _get_current_timestamp(self):
         return current_time
 
     def __len__(self):
-        raise NotImplementedError
+        """Return the number of packets with a timestamp in the first stream.
+
+        Every call opens the container and walks all packets of that stream,
+        so callers that need the length repeatedly should keep the result.
+        """
+        length = 0
+        # Context manager: the container is closed even if demuxing raises.
+        with av.open(self.path, mode="r", format="matroska") as container:
+            for packet in container.demux([container.streams[0]]):
+                if packet.dts is not None:
+                    length += 1
+        logger.debug(f"Length of the stream is {length}")
+        return length
 
     def __getitem__(self, key):
         """
@@ -219,7 +231,34 @@ def _convert_h5_cache_to_tensor(h5_cache):
         else:
             raise ValueError(f"Invalid return_type {return_type}")
 
-
+    def load_slice(self, start, end):
+        """Return the ``[start, end)`` range of each feature of the trajectory.
+
+        If the cache file is missing, the container is loaded and written to
+        the cache first; otherwise only the requested range is read from it.
+        """
+        if not os.path.exists(self.cache_file_name):
+            logger.debug(f"Loading the container file {self.path}, saving to cache {self.cache_file_name}")
+            np_cache = self._load_from_container()
+            try:
+                self._write_to_cache(np_cache)
+            except Exception as e:
+                logger.error(f"Error writing to cache file {self.cache_file_name}: {e}")
+            # A mapping cannot be sliced directly; slice each feature instead.
+            # NOTE(review): assumes _load_from_container returns {name: array},
+            # matching the cached branch below - confirm.
+            if isinstance(np_cache, dict):
+                return {key: value[start:end] for key, value in np_cache.items()}
+            return np_cache[start:end]
+
+        # TODO: currently keys are hardcoded to observation and action
+        np_cache = {}
+        with h5py.File(self.cache_file_name, "r") as h5_cache:
+            for key in h5_cache['observation'].keys():
+                np_cache[f'observation/{key}'] = h5_cache[f'observation/{key}'][start:end]
+            for key in h5_cache['action'].keys():
+                np_cache[f'action/{key}'] = h5_cache[f'action/{key}'][start:end]
+        return np_cache
 
     def init_feature_streams(self, feature_spec: Dict):
         """