Commit

vla done!
KeplerC committed Sep 23, 2024
1 parent ca0c0b6 commit 85f5266
Showing 3 changed files with 203 additions and 22 deletions.
43 changes: 25 additions & 18 deletions benchmarks/openx_by_frame.py
@@ -237,10 +237,17 @@ def __init__(
self.file_extension = ".vla"

def get_loader(self):
-        return get_vla_dataloader(
-            self.dataset_dir, batch_size=self.batch_size, cache_dir=CACHE_DIR,
-            unit = self.unit,
-        )
+        if self.unit == "frame":
+            return get_vla_dataloader(
+                self.dataset_dir, batch_size=1, cache_dir=CACHE_DIR,
+                unit = self.unit,
+                slice_size=self.batch_size,
+            )
+        else:
+            return get_vla_dataloader(
+                self.dataset_dir, batch_size=self.batch_size, cache_dir=CACHE_DIR,
+                unit = self.unit,
+            )


class HDF5Handler(DatasetHandler):
@@ -352,27 +359,27 @@ def evaluation(args):
logger.debug(f"Evaluating dataset: {dataset_name}")

handlers = [
-        # VLAHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
+        VLAHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
        HDF5Handler(
            args.exp_dir,
            dataset_name,
            args.num_batches,
            args.batch_size,
            args.log_frequency,
        ),
-        # LeRobotHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
+        LeRobotHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
# RLDSHandler(
# args.exp_dir,
# dataset_name,
144 changes: 142 additions & 2 deletions fog_x/loader/vla.py
@@ -196,6 +196,125 @@ def _read_vla(self, data_path, return_type = None):
def get_batch(self):
return [self.__next__() for _ in range(self.batch_size)]

class VLAFrameLoader:
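    """
    Iterates over random fixed-length windows ("slices") of consecutive frames drawn
    from .vla trajectory files; background worker processes keep a shared queue filled.
    """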
def __init__(self, path: Text, batch_size=1, cache_dir="/tmp/fog_x/cache/", buffer_size=50, num_workers=-1, return_type="numpy", split="all", slice_size=1):
self.files = self._get_files(path, split)
self.split = split

self.cache_dir = cache_dir
self.batch_size = batch_size
self.return_type = return_type
self.buffer_size = buffer_size
self.buffer = mp.Queue(maxsize=buffer_size)
if num_workers == -1:
num_workers = 2
self.num_workers = num_workers
self.processes = []
self.slice_size = slice_size
random.shuffle(self.files)
self._start_workers()

def _get_files(self, path, split):
ret = []
if "*" in path:
ret = glob.glob(path)
elif os.path.isdir(path):
ret = glob.glob(os.path.join(path, "*.vla"))
else:
ret = [path]
if split == "train":
ret = ret[:int(len(ret)*0.9)]
elif split == "val":
ret = ret[int(len(ret)*0.9):]
elif split == "all":
pass
else:
raise ValueError(f"Invalid split: {split}")
return ret

def _read_vla_slice(self, data_path):
traj = fog_x.Trajectory(data_path, cache_dir=self.cache_dir)
total_frames = len(traj)
if self.slice_size > total_frames:
start_idx = 0
end_idx = total_frames
else:
start_idx = random.randint(0, total_frames - self.slice_size)
end_idx = start_idx + self.slice_size

slice_data = traj.load_slice(start_idx, end_idx)
return slice_data

def _worker(self):
max_retries = 3
while True:
if not self.files:
logger.info("Worker finished")
break

for attempt in range(max_retries):
try:
file_path = random.choice(self.files)
data = self._read_vla_slice(file_path)
self.buffer.put(data)
break # Exit the retry loop if successful
except Exception as e:
logger.error(f"Error reading {file_path} on attempt {attempt + 1}: {e}")
if attempt + 1 == max_retries:
logger.error(f"Failed to read {file_path} after {max_retries} attempts")

def _start_workers(self):
for _ in range(self.num_workers):
p = mp.Process(target=self._worker)
p.start()
logger.debug(f"Started worker {p.pid}")
self.processes.append(p)

def get_batch_by_slice(self):
batch = []
timeout = 5 # Adjust this value based on your needs
start_time = time.time()

while len(batch) < self.batch_size:
if time.time() - start_time > timeout:
logger.warning(f"Timeout reached while getting batch. Batch size: {len(batch)}")
break

try:
item = self.buffer.get(timeout=1)
batch.append(item)
except mp.queues.Empty:
if all(not p.is_alive() for p in self.processes) and self.buffer.empty():
if len(batch) == 0:
return None # No more data available
else:
break # Return partial batch

return batch

def __iter__(self):
return self

def __next__(self):
batch = self.get_batch_by_slice()
if batch is None:
random.shuffle(self.files)
self._start_workers()
raise StopIteration
return batch

def __len__(self):
return len(self.files)

def peek(self):
file = random.choice(self.files)
return self._read_vla_slice(file)

def __del__(self):
for p in self.processes:
p.terminate()
p.join()
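A minimal usage sketch for the loader above, assuming a hypothetical directory of .vla files; the import path matches this file (fog_x/loader/vla.py). Each call to get_batch_by_slice() returns up to batch_size slices, each a dict of arrays covering slice_size consecutive frames from one randomly chosen trajectory, and peek() reads a single slice without going through the worker queue.

    from fog_x.loader.vla import VLAFrameLoader

    loader = VLAFrameLoader(
        "/data/openx/bridge",            # hypothetical directory of .vla files
        batch_size=4,                    # slices per batch
        slice_size=16,                   # consecutive frames per slice
        cache_dir="/tmp/fog_x/cache/",
    )
    batch = loader.get_batch_by_slice()  # list of up to 4 slice dicts (None once drained)
    sample = loader.peek()               # a single slice, read directly from a random file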

import torch
from torch.utils.data import IterableDataset, DataLoader
from fog_x.loader.vla import VLALoader
@@ -216,6 +335,19 @@ def __next__(self):
raise StopIteration
return batch[0] # Return a single item, not a batch

class VLAIterableFrameDataset(IterableDataset):
def __init__(self, path: Text, cache_dir: Optional[Text] = None, buffer_size: int = 1000, slice_size: int = 1):
self.vla_loader = VLAFrameLoader(path, batch_size=1, cache_dir=cache_dir, buffer_size=buffer_size, slice_size=slice_size)

def __iter__(self):
return self

def __next__(self):
batch = self.vla_loader.get_batch_by_slice()
if batch is None:
raise StopIteration
return batch[0] # Return a single item, not a batch

def vla_collate_fn(batch):
# Convert data to PyTorch tensors
# You may need to adjust this based on the structure of your VLA data
@@ -226,9 +358,17 @@ def get_vla_dataloader(
batch_size: int = 1,
cache_dir: Optional[Text] = None,
buffer_size: int = 1000,
-    num_workers: int = 0
+    num_workers: int = 0,
+    unit: str = "trajectory",
+    slice_size: int = 1
):
-    dataset = VLAIterableDataset(path, cache_dir, buffer_size)
+    if unit == "trajectory":
+        dataset = VLAIterableDataset(path, cache_dir, buffer_size)
+    elif unit == "frame":
+        dataset = VLAIterableFrameDataset(path, cache_dir, buffer_size, slice_size)
+    else:
+        raise ValueError(f"Invalid unit: {unit}. Choose 'trajectory' or 'frame'.")

return DataLoader(
dataset,
batch_size=batch_size,
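A hedged sketch of calling the extended factory, assuming the package layout in this diff (fog_x/loader/vla.py) and a hypothetical dataset directory. With unit="frame" each dataset item is a slice of slice_size consecutive frames, so batch_size counts slices rather than trajectories; this mirrors the benchmark above, which pins batch_size to 1 in frame mode and forwards its own batch size as slice_size.

    from fog_x.loader.vla import get_vla_dataloader

    # Trajectory-level loading (the previous behaviour, still the default).
    traj_loader = get_vla_dataloader("/data/openx/bridge", batch_size=2)

    # Frame-level loading: every item is a window of 16 consecutive frames
    # sampled from a randomly chosen trajectory.
    frame_loader = get_vla_dataloader(
        "/data/openx/bridge",
        batch_size=1,
        unit="frame",
        slice_size=16,
    )
    first_batch = next(iter(frame_loader))  # one batch of frame slices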
38 changes: 36 additions & 2 deletions fog_x/trajectory.py
@@ -107,7 +107,22 @@ def _get_current_timestamp(self):
return current_time

def __len__(self):
-        raise NotImplementedError
+        def _get_length_of_stream(container, stream):
+            """
+            Get the length of the stream.
+            """
+            length = 0
+            for packet in container.demux([stream]):
+                if packet.dts is not None:
+                    length += 1
+            return length
+
+        container_to_get_length = av.open(self.path, mode="r", format="matroska")
+        streams = container_to_get_length.streams
+        length = _get_length_of_stream(container_to_get_length, streams[0])
+        logger.debug(f"Length of the stream is {length}")
+        container_to_get_length.close()
+        return length

def __getitem__(self, key):
"""
@@ -219,7 +234,26 @@ def _convert_h5_cache_to_tensor(h5_cache):
else:
raise ValueError(f"Invalid return_type {return_type}")


    def load_slice(self, start, end):
        """
        Load frames [start, end) of the trajectory, reading from the HDF5 cache when
        it exists and otherwise decoding the container and populating the cache first.
        """
        if not os.path.exists(self.cache_file_name):
            logger.debug(f"Loading the container file {self.path}, saving to cache {self.cache_file_name}")
            np_cache = self._load_from_container()
            try:
                self._write_to_cache(np_cache)
            except Exception as e:
                logger.error(f"Error writing to cache file {self.cache_file_name}: {e}")
            # Slice each feature individually; np_cache is assumed to be a flat dict of
            # per-feature arrays, so it cannot be sliced as a whole.
            return {key: value[start:end] for key, value in np_cache.items()}

        # TODO: currently keys are hardcoded to observation and action
        np_cache = {}
        with h5py.File(self.cache_file_name, "r") as h5_cache:
            for key in h5_cache['observation'].keys():
                np_cache[f'observation/{key}'] = h5_cache[f'observation/{key}'][start:end]
            for key in h5_cache['action'].keys():
                np_cache[f'action/{key}'] = h5_cache[f'action/{key}'][start:end]
        return np_cache

def init_feature_streams(self, feature_spec: Dict):
"""
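A small usage sketch for the new Trajectory methods, assuming a single .vla file at a hypothetical path and the observation/action cache layout noted in the TODO above; len(traj) counts decodable packets in the first stream, and load_slice returns per-feature arrays restricted to the requested frame range.

    import fog_x

    traj = fog_x.Trajectory("/data/openx/bridge/episode_0.vla", cache_dir="/tmp/fog_x/cache/")
    n = len(traj)                            # number of frames in the first stream
    window = traj.load_slice(0, min(16, n))  # first (up to) 16 frames
    # window is a dict keyed like "observation/<feature>" and "action/<feature>",
    # with each value sliced to the requested range.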
