From c65e64a28d57512d7289f940150e7324f0a01249 Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Fri, 19 Apr 2024 16:44:28 -0400 Subject: [PATCH 1/8] fix: getting trainning ready --- app/vjepa/train.py | 12 ++++++++---- configs/pretrain/vith16_384.yaml | 10 +++++----- logs_and_checkpoints/.gitkeep | 0 3 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 logs_and_checkpoints/.gitkeep diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 2b55616..8ebf33b 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -381,10 +381,14 @@ def save_checkpoint(epoch, path): try: udata, masks_enc, masks_pred = next(loader) - except Exception: - logger.info('Exhausted data loaders. Refreshing...') - loader = iter(unsupervised_loader) - udata, masks_enc, masks_pred = next(loader) + + except StopIteration: + logger.info('Exhausted data loaders before completing all planned iterations. Ending epoch early...') + break # Exit the current epoch loop if there are no more data points to process + # except Exception: + # logger.info('Exhausted data loaders. Refreshing...') + # loader = iter(unsupervised_loader) + # udata, masks_enc, masks_pred = next(loader) assert len(masks_enc) == len(masks_pred), \ 'Currently require num encoder masks = num predictor masks' diff --git a/configs/pretrain/vith16_384.yaml b/configs/pretrain/vith16_384.yaml index 9c4055a..af4bb5f 100644 --- a/configs/pretrain/vith16_384.yaml +++ b/configs/pretrain/vith16_384.yaml @@ -4,9 +4,9 @@ tasks_per_node: 8 data: dataset_type: VideoDataset datasets: - - /your_path_to_kinetics710_csv_file_index.csv - - /your_path_to_ssv2_csv_file_index.csv - - /your_path_to_howto100m_csv_file_index.csv + - /home/ncdev/Documents/darwin/data/raw/v-jepa-pretrain.csv + # - /your_path_to_ssv2_csv_file_index.csv + # - /your_path_to_howto100m_csv_file_index.csv decode_one_clip: true batch_size: 10 num_clips: 1 @@ -30,7 +30,7 @@ data_aug: - 1.0 reprob: 0.0 logging: - folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ + folder: /home/ncdev/Documents/darwin/jepa/logs_and_checkpoints write_tag: jepa loss: loss_exp: 1.0 @@ -87,4 +87,4 @@ optimization: final_lr: 1.0e-06 ema: - 0.998 - - 1.0 + - 1.0 \ No newline at end of file diff --git a/logs_and_checkpoints/.gitkeep b/logs_and_checkpoints/.gitkeep new file mode 100644 index 0000000..e69de29 From e5680765fa717461381e7c9f2627fe5da91c7974 Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Tue, 30 Apr 2024 12:27:38 -0400 Subject: [PATCH 2/8] fix: CodeStyle --- app/main.py | 38 +- app/main_distributed.py | 75 ++-- app/scaffold.py | 8 +- app/vjepa/train.py | 444 ++++++++++++--------- app/vjepa/transforms.py | 29 +- app/vjepa/utils.py | 87 ++-- evals/image_classification_frozen/eval.py | 284 +++++++------ evals/main.py | 34 +- evals/main_distributed.py | 88 ++-- evals/scaffold.py | 14 +- evals/video_classification_frozen/eval.py | 288 +++++++------ evals/video_classification_frozen/utils.py | 114 +++--- setup.py | 1 + src/datasets/data_manager.py | 20 +- src/datasets/image_dataset.py | 28 +- src/datasets/utils/video/functional.py | 36 +- src/datasets/utils/video/randaugment.py | 48 +-- src/datasets/utils/video/randerase.py | 18 +- src/datasets/utils/video/transforms.py | 270 ++++++------- src/datasets/utils/weighted_sampler.py | 18 +- src/datasets/video_dataset.py | 66 +-- src/masks/multiblock3d.py | 39 +- src/masks/random_tube.py | 25 +- src/models/attentive_pooler.py | 44 +- src/models/predictor.py | 83 ++-- src/models/utils/modules.py | 70 ++-- src/models/utils/patch_embed.py | 12 
+- src/models/utils/pos_embs.py | 24 +- src/models/vision_transformer.py | 156 +++++--- src/utils/distributed.py | 18 +- src/utils/logging.py | 43 +- src/utils/monitoring.py | 16 +- src/utils/schedulers.py | 33 +- src/utils/tensors.py | 17 +- 34 files changed, 1400 insertions(+), 1188 deletions(-) diff --git a/app/main.py b/app/main.py index 52e1596..77d63e0 100644 --- a/app/main.py +++ b/app/main.py @@ -17,55 +17,59 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--fname', type=str, - help='name of config file to load', - default='configs.yaml') + "--fname", type=str, help="name of config file to load", default="configs.yaml" +) parser.add_argument( - '--devices', type=str, nargs='+', default=['cuda:0'], - help='which devices to use on local machine') + "--devices", + type=str, + nargs="+", + default=["cuda:0"], + help="which devices to use on local machine", +) def process_main(rank, fname, world_size, devices): import os - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) import logging from src.utils.logging import get_logger + logger = get_logger(force=True) if rank == 0: logger.setLevel(logging.INFO) else: logger.setLevel(logging.ERROR) - logger.info(f'called-params {fname}') + logger.info(f"called-params {fname}") # Load config params = None - with open(fname, 'r') as y_file: + with open(fname, "r") as y_file: params = yaml.load(y_file, Loader=yaml.FullLoader) - logger.info('loaded params...') + logger.info("loaded params...") # Log config if rank == 0: pprint.PrettyPrinter(indent=4).pprint(params) - dump = os.path.join(params['logging']['folder'], 'params-pretrain.yaml') - with open(dump, 'w') as f: + dump = os.path.join(params["logging"]["folder"], "params-pretrain.yaml") + with open(dump, "w") as f: yaml.dump(params, f) # Init distributed (access to comm between GPUS on same machine) world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) - logger.info(f'Running... (rank: {rank}/{world_size})') + logger.info(f"Running... 
(rank: {rank}/{world_size})") # Launch the app with loaded config - app_main(params['app'], args=params) + app_main(params["app"], args=params) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() num_gpus = len(args.devices) - mp.set_start_method('spawn') + mp.set_start_method("spawn") for rank in range(num_gpus): mp.Process( - target=process_main, - args=(rank, args.fname, num_gpus, args.devices) + target=process_main, args=(rank, args.fname, num_gpus, args.devices) ).start() diff --git a/app/main_distributed.py b/app/main_distributed.py index 11ac3a2..b36a646 100644 --- a/app/main_distributed.py +++ b/app/main_distributed.py @@ -20,32 +20,33 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--folder', type=str, - help='location to save submitit logs', - default='/fsx-jepa/massran/submitit/') + "--folder", + type=str, + help="location to save submitit logs", + default="/fsx-jepa/massran/submitit/", +) parser.add_argument( - '--exclude', type=str, - help='nodes to exclude from training', - default=None) + "--exclude", type=str, help="nodes to exclude from training", default=None +) parser.add_argument( - '--batch-launch', action='store_true', - help='whether fname points to a file to batch-lauch several config files') + "--batch-launch", + action="store_true", + help="whether fname points to a file to batch-lauch several config files", +) parser.add_argument( - '--fname', type=str, - help='yaml file containing config file names to launch', - default='configs.yaml') -parser.add_argument( - '--partition', type=str, - help='cluster partition to submit jobs on') -parser.add_argument( - '--time', type=int, default=4300, - help='time in minutes to run job') + "--fname", + type=str, + help="yaml file containing config file names to launch", + default="configs.yaml", +) +parser.add_argument("--partition", type=str, help="cluster partition to submit jobs on") +parser.add_argument("--time", type=int, default=4300, help="time in minutes to run job") class Trainer: def __init__(self, args_pretrain, load_model=None): - self.app = args_pretrain['app'] + self.app = args_pretrain["app"] self.args_pretrain = args_pretrain self.load_model = load_model @@ -54,7 +55,7 @@ def __call__(self): params = self.args_pretrain load_model = self.load_model - logger.info('loaded pretrain params...') + logger.info("loaded pretrain params...") pp = pprint.PrettyPrinter(indent=4) pp.pprint(params) @@ -64,7 +65,9 @@ def __call__(self): def checkpoint(self): fb_trainer = Trainer(self.args_pretrain, True) - return submitit.helpers.DelayedSubmission(fb_trainer,) + return submitit.helpers.DelayedSubmission( + fb_trainer, + ) def launch_app_with_parsed_args( @@ -74,19 +77,20 @@ def launch_app_with_parsed_args( timeout=4300, nodes=1, tasks_per_node=1, - exclude_nodes=None + exclude_nodes=None, ): executor = submitit.AutoExecutor( - folder=os.path.join(submitit_folder, 'job_%j'), - slurm_max_num_timeout=20) + folder=os.path.join(submitit_folder, "job_%j"), slurm_max_num_timeout=20 + ) executor.update_parameters( slurm_partition=partition, - slurm_mem_per_gpu='55G', + slurm_mem_per_gpu="55G", timeout_min=timeout, nodes=nodes, tasks_per_node=tasks_per_node, cpus_per_task=12, - gpus_per_node=tasks_per_node) + gpus_per_node=tasks_per_node, + ) if args.exclude is not None: executor.update_parameters(slurm_exclude=args.exclude) @@ -95,7 +99,9 @@ def launch_app_with_parsed_args( with executor.batch(): for ap in args_for_pretrain: fb_trainer = Trainer(ap) - job = executor.submit(fb_trainer,) + 
job = executor.submit( + fb_trainer, + ) trainers.append(fb_trainer) jobs.append(job) @@ -114,7 +120,7 @@ def launch(): # -- config, but actually specifies a list of other config files # -- to run in a slurm job array if args.batch_launch: - with open(args.fname, 'r') as y_file: + with open(args.fname, "r") as y_file: config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) # ---------------------------------------------------------------------- # @@ -124,13 +130,13 @@ def launch(): nodes, tasks_per_node = None, None configs = [] for f in config_fnames: - with open(f, 'r') as y_file: + with open(f, "r") as y_file: _params = yaml.load(y_file, Loader=yaml.FullLoader) - nodes = int(_params.get('nodes')) - tasks_per_node = int(_params.get('tasks_per_node')) + nodes = int(_params.get("nodes")) + tasks_per_node = int(_params.get("tasks_per_node")) configs += [_params] - logger.info(f'Loaded {len(configs)} config files') - logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') + logger.info(f"Loaded {len(configs)} config files") + logger.info(f"Running all jobs with {nodes=} / {tasks_per_node=}") # ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- # @@ -143,10 +149,11 @@ def launch(): timeout=args.time, nodes=nodes, tasks_per_node=tasks_per_node, - exclude_nodes=args.exclude) + exclude_nodes=args.exclude, + ) # ---------------------------------------------------------------------- # -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() launch() diff --git a/app/scaffold.py b/app/scaffold.py index 1b49a8b..7946924 100644 --- a/app/scaffold.py +++ b/app/scaffold.py @@ -15,7 +15,7 @@ def main(app, args, resume_preempt=False): - logger.info(f'Running pre-training of app: {app}') - return importlib.import_module(f'app.{app}.train').main( - args=args, - resume_preempt=resume_preempt) + logger.info(f"Running pre-training of app: {app}") + return importlib.import_module(f"app.{app}.train").main( + args=args, resume_preempt=resume_preempt + ) diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 8ebf33b..0390974 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -13,7 +13,7 @@ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE # -- TO EACH PROCESS - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] except Exception: pass @@ -37,7 +37,8 @@ get_logger, grad_logger, adamw_logger, - AverageMeter) + AverageMeter, +) from src.utils.tensors import repeat_interleave_batch from app.vjepa.utils import ( @@ -69,19 +70,19 @@ def main(args, resume_preempt=False): # ----------------------------------------------------------------------- # # -- META - cfgs_meta = args.get('meta') - load_model = cfgs_meta.get('load_checkpoint') or resume_preempt - r_file = cfgs_meta.get('read_checkpoint', None) - seed = cfgs_meta.get('seed', _GLOBAL_SEED) - save_every_freq = cfgs_meta.get('save_every_freq', -1) - skip_batches = cfgs_meta.get('skip_batches', -1) - use_sdpa = cfgs_meta.get('use_sdpa', False) - which_dtype = cfgs_meta.get('dtype') - logger.info(f'{which_dtype=}') - if which_dtype.lower() == 'bfloat16': + cfgs_meta = args.get("meta") + load_model = cfgs_meta.get("load_checkpoint") or resume_preempt + r_file = cfgs_meta.get("read_checkpoint", None) + seed = cfgs_meta.get("seed", _GLOBAL_SEED) + save_every_freq = 
cfgs_meta.get("save_every_freq", -1) + skip_batches = cfgs_meta.get("skip_batches", -1) + use_sdpa = cfgs_meta.get("use_sdpa", False) + which_dtype = cfgs_meta.get("dtype") + logger.info(f"{which_dtype=}") + if which_dtype.lower() == "bfloat16": dtype = torch.bfloat16 mixed_precision = True - elif which_dtype.lower() == 'float16': + elif which_dtype.lower() == "float16": dtype = torch.float16 mixed_precision = True else: @@ -89,72 +90,74 @@ def main(args, resume_preempt=False): mixed_precision = False # -- MASK - cfgs_mask = args.get('mask') + cfgs_mask = args.get("mask") # -- MODEL - cfgs_model = args.get('model') - model_name = cfgs_model.get('model_name') - pred_depth = cfgs_model.get('pred_depth') - pred_embed_dim = cfgs_model.get('pred_embed_dim') - uniform_power = cfgs_model.get('uniform_power', True) - use_mask_tokens = cfgs_model.get('use_mask_tokens', True) - zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + cfgs_model = args.get("model") + model_name = cfgs_model.get("model_name") + pred_depth = cfgs_model.get("pred_depth") + pred_embed_dim = cfgs_model.get("pred_embed_dim") + uniform_power = cfgs_model.get("uniform_power", True) + use_mask_tokens = cfgs_model.get("use_mask_tokens", True) + zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True) # -- DATA - cfgs_data = args.get('data') - dataset_type = cfgs_data.get('dataset_type', 'videodataset') - mask_type = cfgs_data.get('mask_type', 'multiblock3d') - dataset_paths = cfgs_data.get('datasets', []) - datasets_weights = cfgs_data.get('datasets_weights', None) + cfgs_data = args.get("data") + dataset_type = cfgs_data.get("dataset_type", "videodataset") + mask_type = cfgs_data.get("mask_type", "multiblock3d") + dataset_paths = cfgs_data.get("datasets", []) + datasets_weights = cfgs_data.get("datasets_weights", None) if datasets_weights is not None: - assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset' - batch_size = cfgs_data.get('batch_size') - num_clips = cfgs_data.get('num_clips') - num_frames = cfgs_data.get('num_frames') - tubelet_size = cfgs_data.get('tubelet_size') - sampling_rate = cfgs_data.get('sampling_rate') - duration = cfgs_data.get('clip_duration', None) - crop_size = cfgs_data.get('crop_size', 224) - patch_size = cfgs_data.get('patch_size') - pin_mem = cfgs_data.get('pin_mem', False) - num_workers = cfgs_data.get('num_workers', 1) - filter_short_videos = cfgs_data.get('filter_short_videos', False) - decode_one_clip = cfgs_data.get('decode_one_clip', True) - log_resource_util_data = cfgs_data.get('log_resource_utilization', False) + assert len(datasets_weights) == len( + dataset_paths + ), "Must have one sampling weight specified for each dataset" + batch_size = cfgs_data.get("batch_size") + num_clips = cfgs_data.get("num_clips") + num_frames = cfgs_data.get("num_frames") + tubelet_size = cfgs_data.get("tubelet_size") + sampling_rate = cfgs_data.get("sampling_rate") + duration = cfgs_data.get("clip_duration", None) + crop_size = cfgs_data.get("crop_size", 224) + patch_size = cfgs_data.get("patch_size") + pin_mem = cfgs_data.get("pin_mem", False) + num_workers = cfgs_data.get("num_workers", 1) + filter_short_videos = cfgs_data.get("filter_short_videos", False) + decode_one_clip = cfgs_data.get("decode_one_clip", True) + log_resource_util_data = cfgs_data.get("log_resource_utilization", False) # -- DATA AUGS - cfgs_data_aug = args.get('data_aug') - ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3]) - 
rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0]) - motion_shift = cfgs_data_aug.get('motion_shift', False) - reprob = cfgs_data_aug.get('reprob', 0.) - use_aa = cfgs_data_aug.get('auto_augment', False) + cfgs_data_aug = args.get("data_aug") + ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) + rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) + motion_shift = cfgs_data_aug.get("motion_shift", False) + reprob = cfgs_data_aug.get("reprob", 0.0) + use_aa = cfgs_data_aug.get("auto_augment", False) # -- LOSS - cfgs_loss = args.get('loss') - loss_exp = cfgs_loss.get('loss_exp') - reg_coeff = cfgs_loss.get('reg_coeff') + cfgs_loss = args.get("loss") + loss_exp = cfgs_loss.get("loss_exp") + reg_coeff = cfgs_loss.get("reg_coeff") # -- OPTIMIZATION - cfgs_opt = args.get('optimization') - ipe = cfgs_opt.get('ipe', None) - ipe_scale = cfgs_opt.get('ipe_scale', 1.0) - clip_grad = cfgs_opt.get('clip_grad', None) - wd = float(cfgs_opt.get('weight_decay')) - final_wd = float(cfgs_opt.get('final_weight_decay')) - num_epochs = cfgs_opt.get('epochs') - warmup = cfgs_opt.get('warmup') - start_lr = cfgs_opt.get('start_lr') - lr = cfgs_opt.get('lr') - final_lr = cfgs_opt.get('final_lr') - ema = cfgs_opt.get('ema') - betas = cfgs_opt.get('betas', (0.9, 0.999)) - eps = cfgs_opt.get('eps', 1.e-8) + cfgs_opt = args.get("optimization") + ipe = cfgs_opt.get("ipe", None) + ipe_scale = cfgs_opt.get("ipe_scale", 1.0) + clip_grad = cfgs_opt.get("clip_grad", None) + wd = float(cfgs_opt.get("weight_decay")) + final_wd = float(cfgs_opt.get("final_weight_decay")) + num_epochs = cfgs_opt.get("epochs") + warmup = cfgs_opt.get("warmup") + start_lr = cfgs_opt.get("start_lr") + lr = cfgs_opt.get("lr") + final_lr = cfgs_opt.get("final_lr") + ema = cfgs_opt.get("ema") + betas = cfgs_opt.get("betas", (0.9, 0.999)) + eps = cfgs_opt.get("eps", 1.0e-8) # -- LOGGING - cfgs_logging = args.get('logging') - folder = cfgs_logging.get('folder') - tag = cfgs_logging.get('write_tag') + cfgs_logging = args.get("logging") + folder = cfgs_logging.get("folder") + tag = cfgs_logging.get("write_tag") # ----------------------------------------------------------------------- # # ----------------------------------------------------------------------- # @@ -163,24 +166,24 @@ def main(args, resume_preempt=False): torch.manual_seed(seed) torch.backends.cudnn.benchmark = True try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except Exception: pass # -- init torch distributed backend world_size, rank = init_distributed() - logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") # -- set device if not torch.cuda.is_available(): - device = torch.device('cpu') + device = torch.device("cpu") else: - device = torch.device('cuda:0') + device = torch.device("cuda:0") torch.cuda.set_device(device) # -- log/checkpointing paths - log_file = os.path.join(folder, f'{tag}_r{rank}.csv') - latest_file = f'{tag}-latest.pth.tar' + log_file = os.path.join(folder, f"{tag}_r{rank}.csv") + latest_file = f"{tag}-latest.pth.tar" latest_path = os.path.join(folder, latest_file) load_path = None if load_model: @@ -192,15 +195,15 @@ def main(args, resume_preempt=False): # -- make csv_logger csv_logger = CSVLogger( log_file, - ('%d', 'epoch'), - ('%d', 'itr'), - ('%.5f', 'loss'), - ('%.5f', 'loss-jepa'), - ('%.5f', 'reg-loss'), - ('%.5f', 'enc-grad-norm'), - ('%.5f', 'pred-grad-norm'), - ('%d', 'gpu-time(ms)'), - ('%d', 'wall-time(ms)'), 
+ ("%d", "epoch"), + ("%d", "itr"), + ("%.5f", "loss"), + ("%.5f", "loss-jepa"), + ("%.5f", "reg-loss"), + ("%.5f", "enc-grad-norm"), + ("%.5f", "pred-grad-norm"), + ("%d", "gpu-time(ms)"), + ("%d", "wall-time(ms)"), ) # -- init model @@ -222,22 +225,24 @@ def main(args, resume_preempt=False): target_encoder = copy.deepcopy(encoder) # -- make data transforms - if mask_type == 'multiblock3d': - logger.info('Initializing basic multi-block mask') + if mask_type == "multiblock3d": + logger.info("Initializing basic multi-block mask") mask_collator = MB3DMaskCollator( crop_size=crop_size, num_frames=num_frames, patch_size=patch_size, tubelet_size=tubelet_size, - cfgs_mask=cfgs_mask) + cfgs_mask=cfgs_mask, + ) else: - logger.info('Initializing random tube mask') + logger.info("Initializing random tube mask") mask_collator = TubeMaskCollator( crop_size=crop_size, num_frames=num_frames, patch_size=patch_size, tubelet_size=tubelet_size, - cfgs_mask=cfgs_mask) + cfgs_mask=cfgs_mask, + ) transform = make_transforms( random_horizontal_flip=True, random_resize_aspect_ratio=ar_range, @@ -245,36 +250,37 @@ def main(args, resume_preempt=False): reprob=reprob, auto_augment=use_aa, motion_shift=motion_shift, - crop_size=crop_size) + crop_size=crop_size, + ) # -- init data-loaders/samplers - (unsupervised_loader, - unsupervised_sampler) = init_data( - data=dataset_type, - root_path=dataset_paths, - batch_size=batch_size, - training=True, - clip_len=num_frames, - frame_sample_rate=sampling_rate, - filter_short_videos=filter_short_videos, - decode_one_clip=decode_one_clip, - duration=duration, - num_clips=num_clips, - transform=transform, - datasets_weights=datasets_weights, - collator=mask_collator, - num_workers=num_workers, - world_size=world_size, - pin_mem=pin_mem, - rank=rank, - log_dir=folder if log_resource_util_data else None) + (unsupervised_loader, unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None, + ) try: _dlen = len(unsupervised_loader) except Exception: # Different interface for webdataset _dlen = unsupervised_loader.num_batches if ipe is None: ipe = _dlen - logger.info(f'iterations per epoch/dataest length: {ipe}/{_dlen}') + logger.info(f"iterations per epoch/dataest length: {ipe}/{_dlen}") # -- init optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( @@ -291,7 +297,8 @@ def main(args, resume_preempt=False): ipe_scale=ipe_scale, mixed_precision=mixed_precision, betas=betas, - eps=eps) + eps=eps, + ) encoder = DistributedDataParallel(encoder, static_graph=True) predictor = DistributedDataParallel(predictor, static_graph=True) target_encoder = DistributedDataParallel(target_encoder) @@ -299,8 +306,10 @@ def main(args, resume_preempt=False): p.requires_grad = False # -- momentum schedule - momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) - for i in range(int(ipe*num_epochs*ipe_scale)+1)) + momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs * ipe_scale) + 1) + ) start_epoch = 0 # -- load training 
checkpoint @@ -318,7 +327,8 @@ def main(args, resume_preempt=False): predictor=predictor, target_encoder=target_encoder, opt=optimizer, - scaler=scaler) + scaler=scaler, + ) for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() @@ -329,31 +339,31 @@ def save_checkpoint(epoch, path): if rank != 0: return save_dict = { - 'encoder': encoder.state_dict(), - 'predictor': predictor.state_dict(), - 'opt': optimizer.state_dict(), - 'scaler': None if scaler is None else scaler.state_dict(), - 'target_encoder': target_encoder.state_dict(), - 'epoch': epoch, - 'loss': loss_meter.avg, - 'batch_size': batch_size, - 'world_size': world_size, - 'lr': lr, + "encoder": encoder.state_dict(), + "predictor": predictor.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else scaler.state_dict(), + "target_encoder": target_encoder.state_dict(), + "epoch": epoch, + "loss": loss_meter.avg, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, } try: torch.save(save_dict, path) except Exception as e: - logger.info(f'Encountered exception when saving checkpoint: {e}') + logger.info(f"Encountered exception when saving checkpoint: {e}") - logger.info('Initializing loader...') + logger.info("Initializing loader...") loader = iter(unsupervised_loader) if skip_batches > 0: - logger.info(f'Skip {skip_batches} batches') + logger.info(f"Skip {skip_batches} batches") unsupervised_sampler.set_epoch(start_epoch) for itr in range(skip_batches): if itr % 10 == 0: - logger.info(f'Skip {itr}/{skip_batches} batches') + logger.info(f"Skip {itr}/{skip_batches} batches") try: udata = next(loader) except Exception: @@ -362,7 +372,7 @@ def save_checkpoint(epoch, path): # -- TRAINING LOOP for epoch in range(start_epoch, num_epochs): - logger.info('Epoch %d' % (epoch + 1)) + logger.info("Epoch %d" % (epoch + 1)) # -- update distributed-data-loader epoch unsupervised_sampler.set_epoch(epoch) @@ -381,22 +391,27 @@ def save_checkpoint(epoch, path): try: udata, masks_enc, masks_pred = next(loader) - + except StopIteration: - logger.info('Exhausted data loaders before completing all planned iterations. Ending epoch early...') + logger.info( + "Exhausted data loaders before completing all planned iterations. Ending epoch early..." + ) break # Exit the current epoch loop if there are no more data points to process # except Exception: # logger.info('Exhausted data loaders. 
Refreshing...') # loader = iter(unsupervised_loader) # udata, masks_enc, masks_pred = next(loader) - assert len(masks_enc) == len(masks_pred), \ - 'Currently require num encoder masks = num predictor masks' + assert len(masks_enc) == len( + masks_pred + ), "Currently require num encoder masks = num predictor masks" def load_clips(): # -- unsupervised video clips # Put each clip on the GPU and concatenate along batch # dimension - clips = torch.cat([u.to(device, non_blocking=True) for u in udata[0]], dim=0) + clips = torch.cat( + [u.to(device, non_blocking=True) for u in udata[0]], dim=0 + ) # Put each mask-enc/mask-pred pair on the GPU and reuse the # same mask pair for each clip @@ -410,6 +425,7 @@ def load_clips(): _masks_pred.append(_mp) return (clips, _masks_enc, _masks_pred) + clips, masks_enc, masks_pred = load_clips() for _i, m in enumerate(mask_meters): @@ -427,7 +443,9 @@ def forward_target(c): """ with torch.no_grad(): h = target_encoder(c) - h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim [B, N, D] + h = F.layer_norm( + h, (h.size(-1),) + ) # normalize over feature-dim [B, N, D] # -- create targets (masked regions of h) h = apply_masks(h, masks_pred, concat=False) return h @@ -442,36 +460,42 @@ def forward_context(c, h): return z def loss_fn(z, h): - loss = 0. + loss = 0.0 # Compute loss and accumulate for each mask-enc/mask-pred pair for zi, hi in zip(z, h): - loss += torch.mean(torch.abs(zi - hi)**loss_exp) / loss_exp + loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp loss /= len(masks_pred) return loss def reg_fn(z): - return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z) + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len( + z + ) # Step 1. Forward - loss_jepa, loss_reg = 0., 0. + loss_jepa, loss_reg = 0.0, 0.0 with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): h = forward_target(clips) z = forward_context(clips, h) loss_jepa = loss_fn(z, h) # jepa prediction loss pstd_z = reg_fn(z) # predictor variance across patches - loss_reg += torch.mean(F.relu(1.-pstd_z)) + loss_reg += torch.mean(F.relu(1.0 - pstd_z)) loss = loss_jepa + reg_coeff * loss_reg # Step 2. Backward & step - _enc_norm, _pred_norm = 0., 0. + _enc_norm, _pred_norm = 0.0, 0.0 if mixed_precision: scaler.scale(loss).backward() scaler.unscale_(optimizer) else: loss.backward() if (epoch > warmup) and (clip_grad is not None): - _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad) - _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad) + _enc_norm = torch.nn.utils.clip_grad_norm_( + encoder.parameters(), clip_grad + ) + _pred_norm = torch.nn.utils.clip_grad_norm_( + predictor.parameters(), clip_grad + ) if mixed_precision: scaler.step(optimizer) scaler.update() @@ -487,8 +511,10 @@ def reg_fn(z): # Step 3. momentum update of target encoder m = next(momentum_scheduler) with torch.no_grad(): - for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()): - param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) + for param_q, param_k in zip( + encoder.parameters(), target_encoder.parameters() + ): + param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) return ( float(loss), @@ -500,11 +526,25 @@ def reg_fn(z): grad_stats_pred, optim_stats, ) - (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step) - iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000. 
+ + ( + loss, + loss_jepa, + loss_reg, + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0 loss_meter.update(loss) - input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0))) - input_var_min = float(AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1)))) + input_var = float( + AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0)) + ) + input_var_min = float( + AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1))) + ) input_var_meter.update(input_var) input_var_min_meter.update(input_var_min) jepa_loss_meter.update(loss_jepa) @@ -523,68 +563,88 @@ def log_stats(): grad_stats.global_norm, grad_stats_pred.global_norm, gpu_etime_ms, - iter_elapsed_time_ms) + iter_elapsed_time_ms, + ) if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): logger.info( - '[%d, %5d] loss: %.3f | p%.3f r%.3f | ' - 'input_var: %.3f %.3f | ' - 'masks: %s ' - '[wd: %.2e] [lr: %.2e] ' - '[mem: %.2e] ' - '[gpu: %.1f ms]' - '[wall: %.1f ms]' - % (epoch + 1, itr, - loss_meter.avg, - jepa_loss_meter.avg, - reg_loss_meter.avg, - input_var_meter.avg, - input_var_min_meter.avg, - '[' + ', '.join(['%.1f' % m.avg for m in mask_meters]) + ']', - _new_wd, - _new_lr, - torch.cuda.max_memory_allocated() / 1024.0**2, - gpu_time_meter.avg, - wall_time_meter.avg)) + "[%d, %5d] loss: %.3f | p%.3f r%.3f | " + "input_var: %.3f %.3f | " + "masks: %s " + "[wd: %.2e] [lr: %.2e] " + "[mem: %.2e] " + "[gpu: %.1f ms]" + "[wall: %.1f ms]" + % ( + epoch + 1, + itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + "[" + + ", ".join(["%.1f" % m.avg for m in mask_meters]) + + "]", + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg, + ) + ) if optim_stats is not None: logger.info( - '[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]' - % (epoch + 1, itr, - optim_stats.get('exp_avg').avg, - optim_stats.get('exp_avg').min, - optim_stats.get('exp_avg').max, - optim_stats.get('exp_avg_sq').avg, - optim_stats.get('exp_avg_sq').min, - optim_stats.get('exp_avg_sq').max)) + "[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]" + % ( + epoch + 1, + itr, + optim_stats.get("exp_avg").avg, + optim_stats.get("exp_avg").min, + optim_stats.get("exp_avg").max, + optim_stats.get("exp_avg_sq").avg, + optim_stats.get("exp_avg_sq").min, + optim_stats.get("exp_avg_sq").max, + ) + ) if grad_stats is not None: logger.info( - '[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' - % (epoch + 1, itr, - grad_stats.first_layer, - grad_stats.last_layer, - grad_stats.min, - grad_stats.max, - grad_stats.global_norm)) + "[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" + % ( + epoch + 1, + itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm, + ) + ) if grad_stats_pred is not None: logger.info( - '[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' - % (epoch + 1, itr, - grad_stats_pred.first_layer, - grad_stats_pred.last_layer, - grad_stats_pred.min, - grad_stats_pred.max, - grad_stats_pred.global_norm)) + "[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" + % ( + epoch + 1, + itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + 
grad_stats_pred.max, + grad_stats_pred.global_norm, + ) + ) + log_stats() - assert not np.isnan(loss), 'loss is nan' + assert not np.isnan(loss), "loss is nan" # -- Save Checkpoint - logger.info('avg. loss %.3f' % loss_meter.avg) + logger.info("avg. loss %.3f" % loss_meter.avg) # -- Save Last if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): save_checkpoint(epoch + 1, latest_path) if save_every_freq > 0 and epoch % save_every_freq == 0: - save_every_file = f'{tag}-e{epoch}.pth.tar' + save_every_file = f"{tag}-e{epoch}.pth.tar" save_every_path = os.path.join(folder, save_every_file) save_checkpoint(epoch + 1, save_every_path) diff --git a/app/vjepa/transforms.py b/app/vjepa/transforms.py index 0854dd9..cc90645 100644 --- a/app/vjepa/transforms.py +++ b/app/vjepa/transforms.py @@ -14,14 +14,13 @@ def make_transforms( random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): _frames_augmentation = VideoTransform( @@ -42,14 +41,13 @@ class VideoTransform(object): def __init__( self, random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): self.random_horizontal_flip = random_horizontal_flip @@ -62,25 +60,28 @@ def __init__( self.std = torch.tensor(normalize[1], dtype=torch.float32) if not self.auto_augment: # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. - self.mean *= 255. - self.std *= 255. 
+ self.mean *= 255.0 + self.std *= 255.0 self.autoaug_transform = video_transforms.create_random_augment( input_size=(crop_size, crop_size), - auto_augment='rand-m7-n4-mstd0.5-inc1', - interpolation='bicubic', + auto_augment="rand-m7-n4-mstd0.5-inc1", + interpolation="bicubic", ) - self.spatial_transform = video_transforms.random_resized_crop_with_shift \ - if motion_shift else video_transforms.random_resized_crop + self.spatial_transform = ( + video_transforms.random_resized_crop_with_shift + if motion_shift + else video_transforms.random_resized_crop + ) self.reprob = reprob self.erase_transform = RandomErasing( reprob, - mode='pixel', + mode="pixel", max_count=1, num_splits=1, - device='cpu', + device="cpu", ) def __call__(self, buffer): diff --git a/app/vjepa/utils.py b/app/vjepa/utils.py index dc8668d..7bdecd5 100644 --- a/app/vjepa/utils.py +++ b/app/vjepa/utils.py @@ -16,9 +16,7 @@ import src.models.vision_transformer as video_vit import src.models.predictor as vit_pred from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper -from src.utils.schedulers import ( - WarmupCosineSchedule, - CosineWDSchedule) +from src.utils.schedulers import WarmupCosineSchedule, CosineWDSchedule from src.utils.tensors import trunc_normal_ logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -34,43 +32,43 @@ def load_checkpoint( scaler, ): try: - checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + checkpoint = torch.load(r_path, map_location=torch.device("cpu")) except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {e}") epoch = 0 try: - epoch = checkpoint['epoch'] + epoch = checkpoint["epoch"] # -- loading encoder - pretrained_dict = checkpoint['encoder'] + pretrained_dict = checkpoint["encoder"] msg = encoder.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}") # -- loading predictor - pretrained_dict = checkpoint['predictor'] + pretrained_dict = checkpoint["predictor"] msg = predictor.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}") # -- loading target_encoder if target_encoder is not None: print(list(checkpoint.keys())) - pretrained_dict = checkpoint['target_encoder'] + pretrained_dict = checkpoint["target_encoder"] msg = target_encoder.load_state_dict(pretrained_dict) logger.info( - f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' + f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}" ) # -- loading optimizer - opt.load_state_dict(checkpoint['opt']) + opt.load_state_dict(checkpoint["opt"]) if scaler is not None: - scaler.load_state_dict(checkpoint['scaler']) - logger.info(f'loaded optimizers from epoch {epoch}') - logger.info(f'read-path: {r_path}') + scaler.load_state_dict(checkpoint["scaler"]) + logger.info(f"loaded optimizers from epoch {epoch}") + logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {e}") epoch = 0 return ( @@ -88,7 +86,7 @@ def init_video_model( patch_size=16, num_frames=16, tubelet_size=2, - model_name='vit_base', + model_name="vit_base", crop_size=224, 
pred_depth=6, pred_embed_dim=384, @@ -107,7 +105,7 @@ def init_video_model( use_sdpa=use_sdpa, ) encoder = MultiMaskWrapper(encoder) - predictor = vit_pred.__dict__['vit_predictor']( + predictor = vit_pred.__dict__["vit_predictor"]( img_size=crop_size, use_mask_tokens=use_mask_tokens, patch_size=patch_size, @@ -147,8 +145,8 @@ def init_weights(m): def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') - logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') + logger.info(f"Encoder number of parameters: {count_parameters(encoder)}") + logger.info(f"Predictor number of parameters: {count_parameters(predictor)}") return encoder, predictor @@ -172,25 +170,40 @@ def init_opt( ): param_groups = [ { - 'params': (p for n, p in encoder.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in predictor.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in encoder.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': zero_init_bias_wd, - 'weight_decay': 0, - }, { - 'params': (p for n, p in predictor.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': zero_init_bias_wd, - 'weight_decay': 0, + "params": ( + p + for n, p in encoder.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in predictor.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in encoder.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": zero_init_bias_wd, + "weight_decay": 0, + }, + { + "params": ( + p + for n, p in predictor.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": zero_init_bias_wd, + "weight_decay": 0, }, ] - logger.info('Using AdamW') + logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) scheduler = WarmupCosineSchedule( optimizer, diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 56d2f28..9368c68 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -13,7 +13,7 @@ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE # -- TO EACH PROCESS - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] except Exception: pass @@ -35,18 +35,12 @@ from src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( - init_distributed, - AllReduce -) +from src.utils.distributed import init_distributed, AllReduce from src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( - AverageMeter, - CSVLogger -) +from src.utils.logging import AverageMeter, CSVLogger logging.basicConfig() logger = logging.getLogger() @@ -67,76 +61,75 @@ def main(args_eval, resume_preempt=False): # ----------------------------------------------------------------------- # # -- PRETRAIN - args_pretrain = args_eval.get('pretrain') - checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') - model_name = args_pretrain.get('model_name', None) - patch_size = args_pretrain.get('patch_size', None) - pretrain_folder = args_pretrain.get('folder', 
None) - ckp_fname = args_pretrain.get('checkpoint', None) - tag = args_pretrain.get('write_tag', None) - use_sdpa = args_pretrain.get('use_sdpa', True) - use_SiLU = args_pretrain.get('use_silu', False) - tight_SiLU = args_pretrain.get('tight_silu', True) - uniform_power = args_pretrain.get('uniform_power', False) + args_pretrain = args_eval.get("pretrain") + checkpoint_key = args_pretrain.get("checkpoint_key", "target_encoder") + model_name = args_pretrain.get("model_name", None) + patch_size = args_pretrain.get("patch_size", None) + pretrain_folder = args_pretrain.get("folder", None) + ckp_fname = args_pretrain.get("checkpoint", None) + tag = args_pretrain.get("write_tag", None) + use_sdpa = args_pretrain.get("use_sdpa", True) + use_SiLU = args_pretrain.get("use_silu", False) + tight_SiLU = args_pretrain.get("tight_silu", True) + uniform_power = args_pretrain.get("uniform_power", False) pretrained_path = os.path.join(pretrain_folder, ckp_fname) # Optional [for Video model]: - tubelet_size = args_pretrain.get('tubelet_size', 2) - frames_per_clip = args_pretrain.get('frames_per_clip', 1) + tubelet_size = args_pretrain.get("tubelet_size", 2) + frames_per_clip = args_pretrain.get("frames_per_clip", 1) # -- DATA - args_data = args_eval.get('data') - dataset_name = args_data.get('dataset_name') - num_classes = args_data.get('num_classes') - root_path = args_data.get('root_path', None) - image_folder = args_data.get('image_folder', None) - resolution = args_data.get('resolution', 224) + args_data = args_eval.get("data") + dataset_name = args_data.get("dataset_name") + num_classes = args_data.get("num_classes") + root_path = args_data.get("root_path", None) + image_folder = args_data.get("image_folder", None) + resolution = args_data.get("resolution", 224) # -- OPTIMIZATION - args_opt = args_eval.get('optimization') - batch_size = args_opt.get('batch_size') - num_epochs = args_opt.get('num_epochs') - wd = args_opt.get('weight_decay') - start_lr = args_opt.get('start_lr') - lr = args_opt.get('lr') - final_lr = args_opt.get('final_lr') - warmup = args_opt.get('warmup') - use_bfloat16 = args_opt.get('use_bfloat16') + args_opt = args_eval.get("optimization") + batch_size = args_opt.get("batch_size") + num_epochs = args_opt.get("num_epochs") + wd = args_opt.get("weight_decay") + start_lr = args_opt.get("start_lr") + lr = args_opt.get("lr") + final_lr = args_opt.get("final_lr") + warmup = args_opt.get("warmup") + use_bfloat16 = args_opt.get("use_bfloat16") # -- EXPERIMENT-ID/TAG (optional) - resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt - eval_tag = args_eval.get('tag', None) + resume_checkpoint = args_eval.get("resume_checkpoint", False) or resume_preempt + eval_tag = args_eval.get("tag", None) # ----------------------------------------------------------------------- # try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except Exception: pass if not torch.cuda.is_available(): - device = torch.device('cpu') + device = torch.device("cpu") else: - device = torch.device('cuda:0') + device = torch.device("cuda:0") torch.cuda.set_device(device) world_size, rank = init_distributed() - logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") # -- log/checkpointing paths - folder = os.path.join(pretrain_folder, 'image_classification_frozen/') + folder = os.path.join(pretrain_folder, "image_classification_frozen/") if eval_tag is not None: folder = os.path.join(folder, eval_tag) if not 
os.path.exists(folder): os.makedirs(folder, exist_ok=True) - log_file = os.path.join(folder, f'{tag}_r{rank}.csv') - latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') + log_file = os.path.join(folder, f"{tag}_r{rank}.csv") + latest_path = os.path.join(folder, f"{tag}-latest.pth.tar") # -- make csv_logger if rank == 0: - csv_logger = CSVLogger(log_file, - ('%d', 'epoch'), - ('%.5f', 'loss'), - ('%.5f', 'acc')) + csv_logger = CSVLogger( + log_file, ("%d", "epoch"), ("%.5f", "loss"), ("%.5f", "acc") + ) # Initialize model @@ -153,7 +146,8 @@ def main(args_eval, resume_preempt=False): checkpoint_key=checkpoint_key, use_SiLU=use_SiLU, tight_SiLU=tight_SiLU, - use_sdpa=use_sdpa) + use_sdpa=use_sdpa, + ) encoder.eval() for p in encoder.parameters(): p.requires_grad = False @@ -163,7 +157,7 @@ def main(args_eval, resume_preempt=False): embed_dim=encoder.embed_dim, num_heads=encoder.num_heads, depth=1, - num_classes=num_classes + num_classes=num_classes, ).to(device) train_loader = make_dataloader( @@ -174,7 +168,8 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=True) + training=True, + ) val_loader = make_dataloader( dataset_name=dataset_name, root_path=root_path, @@ -183,9 +178,10 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=False) + training=False, + ) ipe = len(train_loader) - logger.info(f'Dataloader created... iterations per epoch: {ipe}') + logger.info(f"Dataloader created... iterations per epoch: {ipe}") # -- optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( @@ -197,7 +193,8 @@ def main(args_eval, resume_preempt=False): iterations_per_epoch=ipe, warmup=warmup, num_epochs=num_epochs, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) classifier = DistributedDataParallel(classifier, static_graph=True) # -- load training checkpoint @@ -208,27 +205,28 @@ def main(args_eval, resume_preempt=False): r_path=latest_path, classifier=classifier, opt=optimizer, - scaler=scaler) - for _ in range(start_epoch*ipe): + scaler=scaler, + ) + for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() def save_checkpoint(epoch): save_dict = { - 'classifier': classifier.state_dict(), - 'opt': optimizer.state_dict(), - 'scaler': None if scaler is None else scaler.state_dict(), - 'epoch': epoch, - 'batch_size': batch_size, - 'world_size': world_size, - 'lr': lr + "classifier": classifier.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else scaler.state_dict(), + "epoch": epoch, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, } if rank == 0: torch.save(save_dict, latest_path) # TRAIN LOOP for epoch in range(start_epoch, num_epochs): - logger.info('Epoch %d' % (epoch + 1)) + logger.info("Epoch %d" % (epoch + 1)) train_acc = run_one_epoch( device=device, training=True, @@ -239,7 +237,8 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=train_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) val_acc = run_one_epoch( device=device, @@ -251,9 +250,12 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=val_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) - logger.info('[%5d] train: %.3f%% test: %.3f%%' % (epoch + 1, train_acc, val_acc)) + logger.info( + "[%5d] train: %.3f%% test: %.3f%%" % (epoch + 1, train_acc, val_acc) + ) if rank == 0: 
csv_logger.log(epoch + 1, train_acc, val_acc) save_checkpoint(epoch + 1) @@ -292,7 +294,7 @@ def run_one_epoch( outputs = classifier(outputs) loss = criterion(outputs, labels) - top1_acc = 100. * outputs.max(dim=1).indices.eq(labels).sum() / len(imgs) + top1_acc = 100.0 * outputs.max(dim=1).indices.eq(labels).sum() / len(imgs) top1_acc = float(AllReduce.apply(top1_acc)) top1_meter.update(top1_acc) @@ -310,68 +312,70 @@ def run_one_epoch( optimizer.zero_grad() if itr % 20 == 0: - logger.info('[%5d] %.3f%% (loss: %.3f) [mem: %.2e]' - % (itr, top1_meter.avg, loss, - torch.cuda.max_memory_allocated() / 1024.**2)) + logger.info( + "[%5d] %.3f%% (loss: %.3f) [mem: %.2e]" + % ( + itr, + top1_meter.avg, + loss, + torch.cuda.max_memory_allocated() / 1024.0**2, + ) + ) return top1_meter.avg -def load_checkpoint( - device, - r_path, - classifier, - opt, - scaler -): +def load_checkpoint(device, r_path, classifier, opt, scaler): try: - checkpoint = torch.load(r_path, map_location=torch.device('cpu')) - epoch = checkpoint['epoch'] + checkpoint = torch.load(r_path, map_location=torch.device("cpu")) + epoch = checkpoint["epoch"] # -- loading encoder - pretrained_dict = checkpoint['classifier'] + pretrained_dict = checkpoint["classifier"] msg = classifier.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained classifier from epoch {epoch} with msg: {msg}") # -- loading optimizer - opt.load_state_dict(checkpoint['opt']) + opt.load_state_dict(checkpoint["opt"]) if scaler is not None: - scaler.load_state_dict(checkpoint['scaler']) - logger.info(f'loaded optimizers from epoch {epoch}') - logger.info(f'read-path: {r_path}') + scaler.load_state_dict(checkpoint["scaler"]) + logger.info(f"loaded optimizers from epoch {epoch}") + logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {e}") epoch = 0 return classifier, opt, scaler, epoch -def load_pretrained( - encoder, - pretrained, - checkpoint_key='target_encoder' -): - logger.info(f'Loading pretrained model from {pretrained}') - checkpoint = torch.load(pretrained, map_location='cpu') +def load_pretrained(encoder, pretrained, checkpoint_key="target_encoder"): + logger.info(f"Loading pretrained model from {pretrained}") + checkpoint = torch.load(pretrained, map_location="cpu") try: pretrained_dict = checkpoint[checkpoint_key] except Exception: - pretrained_dict = checkpoint['encoder'] + pretrained_dict = checkpoint["encoder"] - pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} - pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} + pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} + pretrained_dict = { + k.replace("backbone.", ""): v for k, v in pretrained_dict.items() + } for k, v in encoder.state_dict().items(): if k not in pretrained_dict: logger.info(f'key "{k}" could not be found in loaded state dict') elif pretrained_dict[k].shape != v.shape: - logger.info(f'key "{k}" is of different shape in model and loaded state dict') + logger.info( + f'key "{k}" is of different shape in model and loaded state dict' + ) pretrained_dict[k] = v msg = encoder.load_state_dict(pretrained_dict, strict=False) print(encoder) - logger.info(f'loaded pretrained model with msg: {msg}') - logger.info(f'loaded pretrained encoder from epoch: 
{checkpoint["epoch"]}\n path: {pretrained}') + logger.info(f"loaded pretrained model with msg: {msg}") + logger.info( + f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}' + ) del checkpoint return encoder @@ -385,28 +389,31 @@ def make_dataloader( rank, resolution=224, training=False, - subset_file=None + subset_file=None, ): - normalization = ((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalization = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) if training: - logger.info('implementing auto-agument strategy') + logger.info("implementing auto-agument strategy") transform = timm_make_transforms( input_size=resolution, is_training=training, - auto_augment='original', - interpolation='bicubic', + auto_augment="original", + interpolation="bicubic", re_prob=0.25, - re_mode='pixel', + re_mode="pixel", re_count=1, mean=normalization[0], - std=normalization[1]) + std=normalization[1], + ) else: - transform = transforms.Compose([ - transforms.Resize(size=int(resolution * 256/224)), - transforms.CenterCrop(size=resolution), - transforms.ToTensor(), - transforms.Normalize(normalization[0], normalization[1])]) + transform = transforms.Compose( + [ + transforms.Resize(size=int(resolution * 256 / 224)), + transforms.CenterCrop(size=resolution), + transforms.ToTensor(), + transforms.Normalize(normalization[0], normalization[1]), + ] + ) data_loader, _ = init_data( data=dataset_name, @@ -419,7 +426,8 @@ def make_dataloader( training=training, copy_data=False, drop_last=False, - subset_file=subset_file) + subset_file=subset_file, + ) return data_loader @@ -436,7 +444,7 @@ def init_model( use_SiLU=False, tight_SiLU=True, uniform_power=False, - checkpoint_key='target_encoder' + checkpoint_key="target_encoder", ): encoder = vit.__dict__[model_name]( img_size=crop_size, @@ -449,15 +457,18 @@ def init_model( tight_SiLU=tight_SiLU, ) if frames_per_clip > 1: + def forward_prehook(module, input): input = input[0] # [B, C, H, W] input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1) - return (input) + return input encoder.register_forward_pre_hook(forward_prehook) encoder.to(device) - encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) + encoder = load_pretrained( + encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key + ) return encoder @@ -471,33 +482,42 @@ def init_opt( wd=1e-6, final_wd=1e-6, final_lr=0.0, - use_bfloat16=False + use_bfloat16=False, ): param_groups = [ { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': True, - 'weight_decay': 0 - } + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": True, + "weight_decay": 0, + }, ] - logger.info('Using AdamW') + logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups) scheduler = WarmupCosineSchedule( optimizer, - warmup_steps=int(warmup*iterations_per_epoch), + warmup_steps=int(warmup * iterations_per_epoch), start_lr=start_lr, ref_lr=ref_lr, final_lr=final_lr, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) wd_scheduler = CosineWDSchedule( optimizer, ref_wd=wd, final_wd=final_wd, - 
T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None return optimizer, scaler, scheduler, wd_scheduler diff --git a/evals/main.py b/evals/main.py index c614edb..fb9130b 100644 --- a/evals/main.py +++ b/evals/main.py @@ -18,19 +18,24 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--fname', type=str, - help='name of config file to load', - default='configs.yaml') + "--fname", type=str, help="name of config file to load", default="configs.yaml" +) parser.add_argument( - '--devices', type=str, nargs='+', default=['cuda:0'], - help='which devices to use on local machine') + "--devices", + type=str, + nargs="+", + default=["cuda:0"], + help="which devices to use on local machine", +) def process_main(rank, fname, world_size, devices): import os - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) import logging + logging.basicConfig() logger = logging.getLogger() if rank == 0: @@ -38,30 +43,29 @@ def process_main(rank, fname, world_size, devices): else: logger.setLevel(logging.ERROR) - logger.info(f'called-params {fname}') + logger.info(f"called-params {fname}") # Load config params = None - with open(fname, 'r') as y_file: + with open(fname, "r") as y_file: params = yaml.load(y_file, Loader=yaml.FullLoader) - logger.info('loaded params...') + logger.info("loaded params...") pp = pprint.PrettyPrinter(indent=4) pp.pprint(params) # Init distributed (access to comm between GPUS on same machine) world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) - logger.info(f'Running... (rank: {rank}/{world_size})') + logger.info(f"Running... 
(rank: {rank}/{world_size})") # Launch the eval with loaded config - eval_main(params['eval_name'], args_eval=params) + eval_main(params["eval_name"], args_eval=params) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() num_gpus = len(args.devices) - mp.set_start_method('spawn') + mp.set_start_method("spawn") for rank in range(num_gpus): mp.Process( - target=process_main, - args=(rank, args.fname, num_gpus, args.devices) + target=process_main, args=(rank, args.fname, num_gpus, args.devices) ).start() diff --git a/evals/main_distributed.py b/evals/main_distributed.py index 1f332a0..d885d69 100644 --- a/evals/main_distributed.py +++ b/evals/main_distributed.py @@ -22,32 +22,33 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--folder', type=str, - help='location to save submitit logs', - default='/fsx-jepa/massran/submitit/') + "--folder", + type=str, + help="location to save submitit logs", + default="/fsx-jepa/massran/submitit/", +) parser.add_argument( - '--exclude', type=str, - help='nodes to exclude from training', - default=None) + "--exclude", type=str, help="nodes to exclude from training", default=None +) parser.add_argument( - '--batch-launch', action='store_true', - help='whether fname points to a file to batch-lauch several config files') + "--batch-launch", + action="store_true", + help="whether fname points to a file to batch-lauch several config files", +) parser.add_argument( - '--fname', type=str, - help='yaml file containing config file names to launch', - default='configs.yaml') -parser.add_argument( - '--partition', type=str, - help='cluster partition to submit jobs on') -parser.add_argument( - '--time', type=int, default=4300, - help='time in minutes to run job') + "--fname", + type=str, + help="yaml file containing config file names to launch", + default="configs.yaml", +) +parser.add_argument("--partition", type=str, help="cluster partition to submit jobs on") +parser.add_argument("--time", type=int, default=4300, help="time in minutes to run job") class Trainer: def __init__(self, args_eval=None, resume_preempt=None): - self.eval_name = args_eval['eval_name'] + self.eval_name = args_eval["eval_name"] self.args_eval = args_eval self.resume_preempt = resume_preempt @@ -56,47 +57,47 @@ def __call__(self): args_eval = self.args_eval resume_preempt = self.resume_preempt - logger.info('loaded eval params...') + logger.info("loaded eval params...") pp = pprint.PrettyPrinter(indent=4) pp.pprint(args_eval) - eval_main( - eval_name, - args_eval=args_eval, - resume_preempt=resume_preempt) + eval_main(eval_name, args_eval=args_eval, resume_preempt=resume_preempt) def checkpoint(self): fb_trainer = Trainer(self.args_eval, True) - return submitit.helpers.DelayedSubmission(fb_trainer,) + return submitit.helpers.DelayedSubmission( + fb_trainer, + ) def launch_evals_with_parsed_args( args_for_evals, submitit_folder, - partition='learnlab,learnfair', + partition="learnlab,learnfair", timeout=4300, nodes=1, tasks_per_node=1, delay_seconds=10, - exclude_nodes=None + exclude_nodes=None, ): if not isinstance(args_for_evals, list): - logger.info(f'Passed in eval-args of type {type(args_for_evals)}') + logger.info(f"Passed in eval-args of type {type(args_for_evals)}") args_for_evals = [args_for_evals] time.sleep(delay_seconds) - logger.info('Launching evaluations in separate jobs...') + logger.info("Launching evaluations in separate jobs...") executor = submitit.AutoExecutor( - folder=os.path.join(submitit_folder, 'job_%j'), - 
slurm_max_num_timeout=20) + folder=os.path.join(submitit_folder, "job_%j"), slurm_max_num_timeout=20 + ) executor.update_parameters( slurm_partition=partition, - slurm_mem_per_gpu='55G', + slurm_mem_per_gpu="55G", timeout_min=timeout, nodes=nodes, tasks_per_node=tasks_per_node, cpus_per_task=12, - gpus_per_node=tasks_per_node) + gpus_per_node=tasks_per_node, + ) if exclude_nodes is not None: executor.update_parameters(slurm_exclude=exclude_nodes) @@ -105,12 +106,14 @@ def launch_evals_with_parsed_args( with executor.batch(): for ae in args_for_evals: fb_trainer = Trainer(ae) - job = executor.submit(fb_trainer,) + job = executor.submit( + fb_trainer, + ) trainers.append(fb_trainer) jobs.append(job) for job in jobs: - logger.info(f'Launched eval job with id {job.job_id}') + logger.info(f"Launched eval job with id {job.job_id}") def launch_evals(): @@ -124,7 +127,7 @@ def launch_evals(): # -- config, but actually specifies a list of other config files # -- to run in a slurm job array if args.batch_launch: - with open(args.fname, 'r') as y_file: + with open(args.fname, "r") as y_file: config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) # ---------------------------------------------------------------------- # @@ -134,13 +137,13 @@ def launch_evals(): nodes, tasks_per_node = None, None configs = [] for f in config_fnames: - with open(f, 'r') as y_file: + with open(f, "r") as y_file: _params = yaml.load(y_file, Loader=yaml.FullLoader) - nodes = int(_params.get('nodes')) - tasks_per_node = int(_params.get('tasks_per_node')) + nodes = int(_params.get("nodes")) + tasks_per_node = int(_params.get("tasks_per_node")) configs += [_params] - logger.info(f'Loaded {len(configs)} config files') - logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') + logger.info(f"Loaded {len(configs)} config files") + logger.info(f"Running all jobs with {nodes=} / {tasks_per_node=}") # ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- # @@ -153,10 +156,11 @@ def launch_evals(): timeout=args.time, nodes=nodes, tasks_per_node=tasks_per_node, - exclude_nodes=args.exclude) + exclude_nodes=args.exclude, + ) # ---------------------------------------------------------------------- # -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() launch_evals() diff --git a/evals/scaffold.py b/evals/scaffold.py index c816b87..cef8d3d 100644 --- a/evals/scaffold.py +++ b/evals/scaffold.py @@ -13,12 +13,8 @@ logger = logging.getLogger() -def main( - eval_name, - args_eval, - resume_preempt=False -): - logger.info(f'Running evaluation: {eval_name}') - return importlib.import_module(f'evals.{eval_name}.eval').main( - args_eval=args_eval, - resume_preempt=resume_preempt) +def main(eval_name, args_eval, resume_preempt=False): + logger.info(f"Running evaluation: {eval_name}") + return importlib.import_module(f"evals.{eval_name}.eval").main( + args_eval=args_eval, resume_preempt=resume_preempt + ) diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index f81f526..a128d51 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -13,7 +13,7 @@ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE # -- TO EACH PROCESS - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] + os.environ["CUDA_VISIBLE_DEVICES"] = 
os.environ["SLURM_LOCALID"] except Exception: pass @@ -33,23 +33,17 @@ from src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( - init_distributed, - AllReduce -) +from src.utils.distributed import init_distributed, AllReduce from src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( - AverageMeter, - CSVLogger -) +from src.utils.logging import AverageMeter, CSVLogger from evals.video_classification_frozen.utils import ( make_transforms, ClipAggregation, - FrameAggregation + FrameAggregation, ) logging.basicConfig() @@ -71,82 +65,81 @@ def main(args_eval, resume_preempt=False): # ----------------------------------------------------------------------- # # -- PRETRAIN - args_pretrain = args_eval.get('pretrain') - checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') - model_name = args_pretrain.get('model_name', None) - patch_size = args_pretrain.get('patch_size', None) - pretrain_folder = args_pretrain.get('folder', None) - ckp_fname = args_pretrain.get('checkpoint', None) - tag = args_pretrain.get('write_tag', None) - use_sdpa = args_pretrain.get('use_sdpa', True) - use_SiLU = args_pretrain.get('use_silu', False) - tight_SiLU = args_pretrain.get('tight_silu', True) - uniform_power = args_pretrain.get('uniform_power', False) + args_pretrain = args_eval.get("pretrain") + checkpoint_key = args_pretrain.get("checkpoint_key", "target_encoder") + model_name = args_pretrain.get("model_name", None) + patch_size = args_pretrain.get("patch_size", None) + pretrain_folder = args_pretrain.get("folder", None) + ckp_fname = args_pretrain.get("checkpoint", None) + tag = args_pretrain.get("write_tag", None) + use_sdpa = args_pretrain.get("use_sdpa", True) + use_SiLU = args_pretrain.get("use_silu", False) + tight_SiLU = args_pretrain.get("tight_silu", True) + uniform_power = args_pretrain.get("uniform_power", False) pretrained_path = os.path.join(pretrain_folder, ckp_fname) # Optional [for Video model]: - tubelet_size = args_pretrain.get('tubelet_size', 2) - pretrain_frames_per_clip = args_pretrain.get('frames_per_clip', 1) + tubelet_size = args_pretrain.get("tubelet_size", 2) + pretrain_frames_per_clip = args_pretrain.get("frames_per_clip", 1) # -- DATA - args_data = args_eval.get('data') - train_data_path = [args_data.get('dataset_train')] - val_data_path = [args_data.get('dataset_val')] - dataset_type = args_data.get('dataset_type', 'VideoDataset') - num_classes = args_data.get('num_classes') - eval_num_segments = args_data.get('num_segments', 1) - eval_frames_per_clip = args_data.get('frames_per_clip', 16) - eval_frame_step = args_pretrain.get('frame_step', 4) - eval_duration = args_pretrain.get('clip_duration', None) - eval_num_views_per_segment = args_data.get('num_views_per_segment', 1) + args_data = args_eval.get("data") + train_data_path = [args_data.get("dataset_train")] + val_data_path = [args_data.get("dataset_val")] + dataset_type = args_data.get("dataset_type", "VideoDataset") + num_classes = args_data.get("num_classes") + eval_num_segments = args_data.get("num_segments", 1) + eval_frames_per_clip = args_data.get("frames_per_clip", 16) + eval_frame_step = args_pretrain.get("frame_step", 4) + eval_duration = args_pretrain.get("clip_duration", None) + eval_num_views_per_segment = args_data.get("num_views_per_segment", 1) # -- OPTIMIZATION - args_opt = args_eval.get('optimization') - resolution = args_opt.get('resolution', 224) - batch_size = args_opt.get('batch_size') - attend_across_segments = 
args_opt.get('attend_across_segments', False) - num_epochs = args_opt.get('num_epochs') - wd = args_opt.get('weight_decay') - start_lr = args_opt.get('start_lr') - lr = args_opt.get('lr') - final_lr = args_opt.get('final_lr') - warmup = args_opt.get('warmup') - use_bfloat16 = args_opt.get('use_bfloat16') + args_opt = args_eval.get("optimization") + resolution = args_opt.get("resolution", 224) + batch_size = args_opt.get("batch_size") + attend_across_segments = args_opt.get("attend_across_segments", False) + num_epochs = args_opt.get("num_epochs") + wd = args_opt.get("weight_decay") + start_lr = args_opt.get("start_lr") + lr = args_opt.get("lr") + final_lr = args_opt.get("final_lr") + warmup = args_opt.get("warmup") + use_bfloat16 = args_opt.get("use_bfloat16") # -- EXPERIMENT-ID/TAG (optional) - resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt - eval_tag = args_eval.get('tag', None) + resume_checkpoint = args_eval.get("resume_checkpoint", False) or resume_preempt + eval_tag = args_eval.get("tag", None) # ----------------------------------------------------------------------- # try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except Exception: pass if not torch.cuda.is_available(): - device = torch.device('cpu') + device = torch.device("cpu") else: - device = torch.device('cuda:0') + device = torch.device("cuda:0") torch.cuda.set_device(device) world_size, rank = init_distributed() - logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") # -- log/checkpointing paths - folder = os.path.join(pretrain_folder, 'video_classification_frozen/') + folder = os.path.join(pretrain_folder, "video_classification_frozen/") if eval_tag is not None: folder = os.path.join(folder, eval_tag) if not os.path.exists(folder): os.makedirs(folder, exist_ok=True) - log_file = os.path.join(folder, f'{tag}_r{rank}.csv') - latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') + log_file = os.path.join(folder, f"{tag}_r{rank}.csv") + latest_path = os.path.join(folder, f"{tag}-latest.pth.tar") # -- make csv_logger if rank == 0: - csv_logger = CSVLogger(log_file, - ('%d', 'epoch'), - ('%.5f', 'loss'), - ('%.5f', 'acc')) + csv_logger = CSVLogger( + log_file, ("%d", "epoch"), ("%.5f", "loss"), ("%.5f", "acc") + ) # Initialize model @@ -163,7 +156,8 @@ def main(args_eval, resume_preempt=False): checkpoint_key=checkpoint_key, use_SiLU=use_SiLU, tight_SiLU=tight_SiLU, - use_sdpa=use_sdpa) + use_sdpa=use_sdpa, + ) if pretrain_frames_per_clip == 1: # Process each frame independently and aggregate encoder = FrameAggregation(encoder).to(device) @@ -172,7 +166,7 @@ def main(args_eval, resume_preempt=False): encoder = ClipAggregation( encoder, tubelet_size=tubelet_size, - attend_across_segments=attend_across_segments + attend_across_segments=attend_across_segments, ).to(device) encoder.eval() for p in encoder.parameters(): @@ -199,7 +193,8 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=True) + training=True, + ) val_loader = make_dataloader( dataset_type=dataset_type, root_path=val_data_path, @@ -213,9 +208,10 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=False) + training=False, + ) ipe = len(train_loader) - logger.info(f'Dataloader created... iterations per epoch: {ipe}') + logger.info(f"Dataloader created... 
iterations per epoch: {ipe}") # -- optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( @@ -227,7 +223,8 @@ def main(args_eval, resume_preempt=False): iterations_per_epoch=ipe, warmup=warmup, num_epochs=num_epochs, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) classifier = DistributedDataParallel(classifier, static_graph=True) # -- load training checkpoint @@ -238,27 +235,28 @@ def main(args_eval, resume_preempt=False): r_path=latest_path, classifier=classifier, opt=optimizer, - scaler=scaler) - for _ in range(start_epoch*ipe): + scaler=scaler, + ) + for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() def save_checkpoint(epoch): save_dict = { - 'classifier': classifier.state_dict(), - 'opt': optimizer.state_dict(), - 'scaler': None if scaler is None else scaler.state_dict(), - 'epoch': epoch, - 'batch_size': batch_size, - 'world_size': world_size, - 'lr': lr + "classifier": classifier.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else scaler.state_dict(), + "epoch": epoch, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, } if rank == 0: torch.save(save_dict, latest_path) # TRAIN LOOP for epoch in range(start_epoch, num_epochs): - logger.info('Epoch %d' % (epoch + 1)) + logger.info("Epoch %d" % (epoch + 1)) train_acc = run_one_epoch( device=device, training=True, @@ -272,7 +270,8 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=train_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) val_acc = run_one_epoch( device=device, @@ -287,9 +286,12 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=val_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) - logger.info('[%5d] train: %.3f%% test: %.3f%%' % (epoch + 1, train_acc, val_acc)) + logger.info( + "[%5d] train: %.3f%% test: %.3f%%" % (epoch + 1, train_acc, val_acc) + ) if rank == 0: csv_logger.log(epoch + 1, train_acc, val_acc) save_checkpoint(epoch + 1) @@ -324,7 +326,9 @@ def run_one_epoch( # Load data and put on GPU clips = [ - [dij.to(device, non_blocking=True) for dij in di] # iterate over spatial views of clip + [ + dij.to(device, non_blocking=True) for dij in di + ] # iterate over spatial views of clip for di in data[0] # iterate over temporal index of clip ] clip_indices = [d.to(device, non_blocking=True) for d in data[2]] @@ -349,13 +353,21 @@ def run_one_epoch( if attend_across_segments: loss = sum([criterion(o, labels) for o in outputs]) / len(outputs) else: - loss = sum([sum([criterion(ost, labels) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) + loss = ( + sum([sum([criterion(ost, labels) for ost in os]) for os in outputs]) + / len(outputs) + / len(outputs[0]) + ) with torch.no_grad(): if attend_across_segments: outputs = sum([F.softmax(o, dim=1) for o in outputs]) / len(outputs) else: - outputs = sum([sum([F.softmax(ost, dim=1) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) - top1_acc = 100. 
* outputs.max(dim=1).indices.eq(labels).sum() / batch_size + outputs = ( + sum([sum([F.softmax(ost, dim=1) for ost in os]) for os in outputs]) + / len(outputs) + / len(outputs[0]) + ) + top1_acc = 100.0 * outputs.max(dim=1).indices.eq(labels).sum() / batch_size top1_acc = float(AllReduce.apply(top1_acc)) top1_meter.update(top1_acc) @@ -373,68 +385,70 @@ def run_one_epoch( optimizer.zero_grad() if itr % 20 == 0: - logger.info('[%5d] %.3f%% (loss: %.3f) [mem: %.2e]' - % (itr, top1_meter.avg, loss, - torch.cuda.max_memory_allocated() / 1024.**2)) + logger.info( + "[%5d] %.3f%% (loss: %.3f) [mem: %.2e]" + % ( + itr, + top1_meter.avg, + loss, + torch.cuda.max_memory_allocated() / 1024.0**2, + ) + ) return top1_meter.avg -def load_checkpoint( - device, - r_path, - classifier, - opt, - scaler -): +def load_checkpoint(device, r_path, classifier, opt, scaler): try: - checkpoint = torch.load(r_path, map_location=torch.device('cpu')) - epoch = checkpoint['epoch'] + checkpoint = torch.load(r_path, map_location=torch.device("cpu")) + epoch = checkpoint["epoch"] # -- loading encoder - pretrained_dict = checkpoint['classifier'] + pretrained_dict = checkpoint["classifier"] msg = classifier.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained classifier from epoch {epoch} with msg: {msg}") # -- loading optimizer - opt.load_state_dict(checkpoint['opt']) + opt.load_state_dict(checkpoint["opt"]) if scaler is not None: - scaler.load_state_dict(checkpoint['scaler']) - logger.info(f'loaded optimizers from epoch {epoch}') - logger.info(f'read-path: {r_path}') + scaler.load_state_dict(checkpoint["scaler"]) + logger.info(f"loaded optimizers from epoch {epoch}") + logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {e}") epoch = 0 return classifier, opt, scaler, epoch -def load_pretrained( - encoder, - pretrained, - checkpoint_key='target_encoder' -): - logger.info(f'Loading pretrained model from {pretrained}') - checkpoint = torch.load(pretrained, map_location='cpu') +def load_pretrained(encoder, pretrained, checkpoint_key="target_encoder"): + logger.info(f"Loading pretrained model from {pretrained}") + checkpoint = torch.load(pretrained, map_location="cpu") try: pretrained_dict = checkpoint[checkpoint_key] except Exception: - pretrained_dict = checkpoint['encoder'] + pretrained_dict = checkpoint["encoder"] - pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} - pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} + pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} + pretrained_dict = { + k.replace("backbone.", ""): v for k, v in pretrained_dict.items() + } for k, v in encoder.state_dict().items(): if k not in pretrained_dict: logger.info(f'key "{k}" could not be found in loaded state dict') elif pretrained_dict[k].shape != v.shape: - logger.info(f'key "{k}" is of different shape in model and loaded state dict') + logger.info( + f'key "{k}" is of different shape in model and loaded state dict' + ) pretrained_dict[k] = v msg = encoder.load_state_dict(pretrained_dict, strict=False) print(encoder) - logger.info(f'loaded pretrained model with msg: {msg}') - logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}') + logger.info(f"loaded 
pretrained model with msg: {msg}") + logger.info( + f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}' + ) del checkpoint return encoder @@ -444,7 +458,7 @@ def make_dataloader( batch_size, world_size, rank, - dataset_type='VideoDataset', + dataset_type="VideoDataset", resolution=224, frames_per_clip=16, frame_step=4, @@ -454,14 +468,14 @@ def make_dataloader( allow_segment_overlap=True, training=False, num_workers=12, - subset_file=None + subset_file=None, ): # Make Video Transforms transform = make_transforms( training=training, num_views_per_clip=num_views_per_segment, random_horizontal_flip=False, - random_resize_aspect_ratio=(0.75, 4/3), + random_resize_aspect_ratio=(0.75, 4 / 3), random_resize_scale=(0.08, 1.0), reprob=0.25, auto_augment=True, @@ -484,7 +498,8 @@ def make_dataloader( num_workers=num_workers, copy_data=False, drop_last=False, - subset_file=subset_file) + subset_file=subset_file, + ) return data_loader @@ -501,7 +516,7 @@ def init_model( use_SiLU=False, tight_SiLU=True, uniform_power=False, - checkpoint_key='target_encoder' + checkpoint_key="target_encoder", ): encoder = vit.__dict__[model_name]( img_size=crop_size, @@ -515,7 +530,9 @@ def init_model( ) encoder.to(device) - encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) + encoder = load_pretrained( + encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key + ) return encoder @@ -529,33 +546,42 @@ def init_opt( wd=1e-6, final_wd=1e-6, final_lr=0.0, - use_bfloat16=False + use_bfloat16=False, ): param_groups = [ { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': True, - 'weight_decay': 0 - } + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": True, + "weight_decay": 0, + }, ] - logger.info('Using AdamW') + logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups) scheduler = WarmupCosineSchedule( optimizer, - warmup_steps=int(warmup*iterations_per_epoch), + warmup_steps=int(warmup * iterations_per_epoch), start_lr=start_lr, ref_lr=ref_lr, final_lr=final_lr, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) wd_scheduler = CosineWDSchedule( optimizer, ref_wd=wd, final_wd=final_wd, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None return optimizer, scaler, scheduler, wd_scheduler diff --git a/evals/video_classification_frozen/utils.py b/evals/video_classification_frozen/utils.py index 450f799..fe8c2c3 100644 --- a/evals/video_classification_frozen/utils.py +++ b/evals/video_classification_frozen/utils.py @@ -26,11 +26,7 @@ class FrameAggregation(nn.Module): """ def __init__( - self, - model, - max_frames=10000, - use_pos_embed=False, - attend_across_segments=False + self, model, max_frames=10000, use_pos_embed=False, attend_across_segments=False ): super().__init__() self.model = model @@ -41,8 +37,8 @@ def __init__( self.pos_embed = None if use_pos_embed: self.pos_embed = nn.Parameter( - torch.zeros(1, max_frames, embed_dim), - requires_grad=False) + torch.zeros(1, max_frames, embed_dim), 
requires_grad=False + ) sincos = get_1d_sincos_pos_embed(embed_dim, max_frames) self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) @@ -59,7 +55,7 @@ def forward(self, x, clip_indices=None): B, C, T, H, W = x.size() # Put each frame along the batch dimension - x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W) + x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W) outputs = self.model(x) _, N, D = outputs.size() @@ -69,13 +65,19 @@ def forward(self, x, clip_indices=None): B = B // num_views_per_clip all_outputs = [] for i in range(num_views_per_clip): - o = outputs[i*B:(i+1)*B] + o = outputs[i * B : (i + 1) * B] # Compute positional embedding if (self.pos_embed is not None) and (clip_indices is not None): pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] - pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) - pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension - pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] + pos_embed = apply_masks( + pos_embed, clip_indices, concat=False + ) # list(Tensor([B, T, D])) + pos_embed = torch.cat( + pos_embed, dim=1 + ) # concatenate along temporal dimension + pos_embed = pos_embed.unsqueeze(2).repeat( + 1, 1, N, 1 + ) # [B, T*num_clips, N, D] pos_embed = pos_embed.flatten(1, 2) o += pos_embed all_outputs += [o] @@ -94,7 +96,7 @@ def __init__( tubelet_size=2, max_frames=10000, use_pos_embed=False, - attend_across_segments=False + attend_across_segments=False, ): super().__init__() self.model = model @@ -107,8 +109,8 @@ def __init__( if use_pos_embed: max_T = max_frames // tubelet_size self.pos_embed = nn.Parameter( - torch.zeros(1, max_T, embed_dim), - requires_grad=False) + torch.zeros(1, max_T, embed_dim), requires_grad=False + ) sincos = get_1d_sincos_pos_embed(embed_dim, max_T) self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) @@ -131,9 +133,9 @@ def forward(self, x, clip_indices=None): eff_B = B * num_views_per_clip all_outputs = [[] for _ in range(num_views_per_clip)] for i in range(num_clips): - o = outputs[i*eff_B:(i+1)*eff_B] + o = outputs[i * eff_B : (i + 1) * eff_B] for j in range(num_views_per_clip): - all_outputs[j].append(o[j*B:(j+1)*B]) + all_outputs[j].append(o[j * B : (j + 1) * B]) if not self.attend_across_segments: return all_outputs @@ -146,11 +148,17 @@ def forward(self, x, clip_indices=None): # Compute positional embedding if (self.pos_embed is not None) and (clip_indices is not None): - clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices] + clip_indices = [c[:, :: self.tubelet_size] for c in clip_indices] pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] - pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) - pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension - pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] + pos_embed = apply_masks( + pos_embed, clip_indices, concat=False + ) # list(Tensor([B, T, D])) + pos_embed = torch.cat( + pos_embed, dim=1 + ) # concatenate along temporal dimension + pos_embed = pos_embed.unsqueeze(2).repeat( + 1, 1, N, 1 + ) # [B, T*num_clips, N, D] pos_embed = pos_embed.flatten(1, 2) outputs += pos_embed @@ -162,19 +170,18 @@ def forward(self, x, clip_indices=None): def make_transforms( training=True, random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, 
auto_augment=False, motion_shift=False, crop_size=224, num_views_per_clip=1, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): if not training and num_views_per_clip > 1: - print('Making EvalVideoTransform, multi-view') + print("Making EvalVideoTransform, multi-view") _frames_augmentation = EvalVideoTransform( num_views_per_clip=num_views_per_clip, short_side_size=crop_size, @@ -202,25 +209,26 @@ def __init__( self, training=True, random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): self.training = training short_side_size = int(crop_size * 256 / 224) - self.eval_transform = video_transforms.Compose([ - video_transforms.Resize(short_side_size, interpolation='bilinear'), - video_transforms.CenterCrop(size=(crop_size, crop_size)), - volume_transforms.ClipToTensor(), - video_transforms.Normalize(mean=normalize[0], std=normalize[1]) - ]) + self.eval_transform = video_transforms.Compose( + [ + video_transforms.Resize(short_side_size, interpolation="bilinear"), + video_transforms.CenterCrop(size=(crop_size, crop_size)), + volume_transforms.ClipToTensor(), + video_transforms.Normalize(mean=normalize[0], std=normalize[1]), + ] + ) self.random_horizontal_flip = random_horizontal_flip self.random_resize_aspect_ratio = random_resize_aspect_ratio @@ -232,20 +240,23 @@ def __init__( self.autoaug_transform = video_transforms.create_random_augment( input_size=(crop_size, crop_size), - auto_augment='rand-m7-n4-mstd0.5-inc1', - interpolation='bicubic', + auto_augment="rand-m7-n4-mstd0.5-inc1", + interpolation="bicubic", ) - self.spatial_transform = video_transforms.random_resized_crop_with_shift \ - if motion_shift else video_transforms.random_resized_crop + self.spatial_transform = ( + video_transforms.random_resized_crop_with_shift + if motion_shift + else video_transforms.random_resized_crop + ) self.reprob = reprob self.erase_transform = RandomErasing( reprob, - mode='pixel', + mode="pixel", max_count=1, num_splits=1, - device='cpu', + device="cpu", ) def __call__(self, buffer): @@ -289,16 +300,19 @@ def __init__( self, num_views_per_clip=1, short_side_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): self.views_per_clip = num_views_per_clip self.short_side_size = short_side_size - self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear') - self.to_tensor = video_transforms.Compose([ - volume_transforms.ClipToTensor(), - video_transforms.Normalize(mean=normalize[0], std=normalize[1]) - ]) + self.spatial_resize = video_transforms.Resize( + short_side_size, interpolation="bilinear" + ) + self.to_tensor = video_transforms.Compose( + [ + volume_transforms.ClipToTensor(), + video_transforms.Normalize(mean=normalize[0], std=normalize[1]), + ] + ) def __call__(self, buffer): @@ -312,11 +326,11 @@ def __call__(self, buffer): all_views = [] for i in range(num_views): - start = i*spatial_step + start = i * spatial_step if H > W: - view = buffer[:, start:start+side_len, :, :] + view = buffer[:, start : start + side_len, :, :] else: - view = buffer[:, :, start:start+side_len, :] + view = buffer[:, :, start : start + side_len, :] view = 
self.to_tensor(view) all_views.append(view) diff --git a/setup.py b/setup.py index 82de1e0..32c80f2 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ VERSION = "0.0.1" + def get_requirements(): with open("./requirements.txt") as reqsf: reqs = reqsf.readlines() diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py index cdb7ade..4ddc1b2 100644 --- a/src/datasets/data_manager.py +++ b/src/datasets/data_manager.py @@ -16,7 +16,7 @@ def init_data( batch_size, transform=None, shared_transform=None, - data='ImageNet', + data="ImageNet", collator=None, pin_mem=True, num_workers=8, @@ -45,10 +45,13 @@ def init_data( log_dir=None, ): - if (data.lower() == 'imagenet') \ - or (data.lower() == 'inat21') \ - or (data.lower() == 'places205'): + if ( + (data.lower() == "imagenet") + or (data.lower() == "inat21") + or (data.lower() == "places205") + ): from src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( transform=transform, batch_size=batch_size, @@ -63,10 +66,12 @@ def init_data( persistent_workers=persistent_workers, copy_data=copy_data, drop_last=drop_last, - subset_file=subset_file) + subset_file=subset_file, + ) - elif data.lower() == 'videodataset': + elif data.lower() == "videodataset": from src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( data_paths=root_path, batch_size=batch_size, @@ -86,6 +91,7 @@ def init_data( world_size=world_size, rank=rank, drop_last=drop_last, - log_dir=log_dir) + log_dir=log_dir, + ) return (data_loader, dist_sampler) diff --git a/src/datasets/image_dataset.py b/src/datasets/image_dataset.py index 84e9b08..ea34749 100644 --- a/src/datasets/image_dataset.py +++ b/src/datasets/image_dataset.py @@ -21,7 +21,7 @@ class ImageFolder(torchvision.datasets.ImageFolder): def __init__( self, root, - image_folder='imagenet_full_size/061417/', + image_folder="imagenet_full_size/061417/", transform=None, train=True, ): @@ -32,11 +32,11 @@ def __init__( :param train: whether to load train data (or validation) """ - suffix = 'train/' if train else 'val/' + suffix = "train/" if train else "val/" data_path = os.path.join(root, image_folder, suffix) - logger.info(f'data-path {data_path}') + logger.info(f"data-path {data_path}") super(ImageFolder, self).__init__(root=data_path, transform=transform) - logger.info('Initialized ImageFolder') + logger.info("Initialized ImageFolder") def make_imagedataset( @@ -53,18 +53,15 @@ def make_imagedataset( copy_data=False, drop_last=True, persistent_workers=False, - subset_file=None + subset_file=None, ): dataset = ImageFolder( - root=root_path, - image_folder=image_folder, - transform=transform, - train=training) - logger.info('ImageFolder dataset created') + root=root_path, image_folder=image_folder, transform=transform, train=training + ) + logger.info("ImageFolder dataset created") dist_sampler = torch.utils.data.distributed.DistributedSampler( - dataset=dataset, - num_replicas=world_size, - rank=rank) + dataset=dataset, num_replicas=world_size, rank=rank + ) data_loader = torch.utils.data.DataLoader( dataset, collate_fn=collator, @@ -73,7 +70,8 @@ def make_imagedataset( drop_last=drop_last, pin_memory=pin_mem, num_workers=num_workers, - persistent_workers=persistent_workers) - logger.info('ImageFolder unsupervised data loader created') + persistent_workers=persistent_workers, + ) + logger.info("ImageFolder unsupervised data loader created") return dataset, data_loader, dist_sampler diff --git 
a/src/datasets/utils/video/functional.py b/src/datasets/utils/video/functional.py index a91d15d..9e443f8 100644 --- a/src/datasets/utils/video/functional.py +++ b/src/datasets/utils/video/functional.py @@ -18,56 +18,54 @@ def _is_tensor_clip(clip): def crop_clip(clip, min_h, min_w, h, w): if isinstance(clip[0], np.ndarray): - cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip] elif isinstance(clip[0], PIL.Image.Image): - cropped = [ - img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip - ] + cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) return cropped -def resize_clip(clip, size, interpolation='bilinear'): +def resize_clip(clip, size, interpolation="bilinear"): if isinstance(clip[0], np.ndarray): if isinstance(size, numbers.Number): im_h, im_w, im_c = clip[0].shape # Min spatial dim already matches minimal size - if (im_w <= im_h and im_w == size) or (im_h <= im_w - and im_h == size): + if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[0], size[1] - if interpolation == 'bilinear': + if interpolation == "bilinear": np_inter = cv2.INTER_LINEAR else: np_inter = cv2.INTER_NEAREST - scaled = [ - cv2.resize(img, size, interpolation=np_inter) for img in clip - ] + scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip] elif isinstance(clip[0], PIL.Image.Image): if isinstance(size, numbers.Number): im_w, im_h = clip[0].size # Min spatial dim already matches minimal size - if (im_w <= im_h and im_w == size) or (im_h <= im_w - and im_h == size): + if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[1], size[0] - if interpolation == 'bilinear': + if interpolation == "bilinear": pil_inter = PIL.Image.BILINEAR else: pil_inter = PIL.Image.NEAREST scaled = [img.resize(size, pil_inter) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) return scaled @@ -83,7 +81,7 @@ def get_resize_sizes(im_h, im_w, size): def normalize(clip, mean, std, inplace=False): if not _is_tensor_clip(clip): - raise TypeError('tensor is not a torch clip.') + raise TypeError("tensor is not a torch clip.") if not inplace: clip = clip.clone() diff --git a/src/datasets/utils/video/randaugment.py b/src/datasets/utils/video/randaugment.py index 4c80a99..a0f060e 100644 --- a/src/datasets/utils/video/randaugment.py +++ b/src/datasets/utils/video/randaugment.py @@ -50,46 +50,34 @@ def _check_args_tf(kwargs): def shear_x(img, factor, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) def shear_y(img, factor, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), 
**kwargs) def translate_x_rel(img, pct, **kwargs): pixels = pct * img.size[0] _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) def translate_y_rel(img, pct, **kwargs): pixels = pct * img.size[1] _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) def translate_x_abs(img, pixels, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) def translate_y_abs(img, pixels, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) def rotate(img, degrees, **kwargs): @@ -334,12 +322,12 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None): self.magnitude = magnitude self.hparams = hparams.copy() self.kwargs = { - "fillcolor": hparams["img_mean"] - if "img_mean" in hparams - else _FILL, - "resample": hparams["interpolation"] - if "interpolation" in hparams - else _RANDOM_INTERPOLATION, + "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL, + "resample": ( + hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION + ), } # If magnitude_std is > 0, we introduce some randomness @@ -356,15 +344,11 @@ def __call__(self, img_list): magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = ( - self.level_fn(magnitude, self.hparams) - if self.level_fn is not None - else () + self.level_fn(magnitude, self.hparams) if self.level_fn is not None else () ) if isinstance(img_list, list): - return [ - self.aug_fn(img, *level_args, **self.kwargs) for img in img_list - ] + return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list] else: return self.aug_fn(img_list, *level_args, **self.kwargs) @@ -512,7 +496,5 @@ def rand_augment_transform(config_str, hparams): ra_ops = rand_augment_ops( magnitude=magnitude, hparams=hparams, transforms=transforms ) - choice_weights = ( - None if weight_idx is None else _select_rand_weights(weight_idx) - ) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/src/datasets/utils/video/randerase.py b/src/datasets/utils/video/randerase.py index d1f185c..0136bfe 100644 --- a/src/datasets/utils/video/randerase.py +++ b/src/datasets/utils/video/randerase.py @@ -15,18 +15,14 @@ import torch -def _get_pixels( - per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" -): +def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"): # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() # paths, flip the order so normal is run on CPU if this becomes a problem # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 if per_pixel: return torch.empty(patch_size, dtype=dtype, device=device).normal_() elif rand_color: - return torch.empty( - (patch_size[0], 1, 1), dtype=dtype, device=device - ).normal_() + return torch.empty((patch_size[0], 1, 1), dtype=dtype, 
device=device).normal_() else: return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) @@ -104,7 +100,7 @@ def _erase(self, img, chan, img_h, img_w, dtype): if w < img_w and h < img_h: top = random.randint(0, img_h - h) left = random.randint(0, img_w - w) - img[:, top:top + h, left:left + w] = _get_pixels( + img[:, top : top + h, left : left + w] = _get_pixels( self.per_pixel, self.rand_color, (chan, h, w), @@ -144,9 +140,7 @@ def _erase_cube( left = random.randint(0, img_w - w) for i in range(batch_start, batch_size): img_instance = img[i] - img_instance[ - :, top:top + h, left:left + w - ] = _get_pixels( + img_instance[:, top : top + h, left : left + w] = _get_pixels( self.per_pixel, self.rand_color, (chan, h, w), @@ -161,9 +155,7 @@ def __call__(self, input): else: batch_size, chan, img_h, img_w = input.size() # skip first slice of batch if num_splits is set (for clean portion of samples) - batch_start = ( - batch_size // self.num_splits if self.num_splits > 1 else 0 - ) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 if self.cube: self._erase_cube( input, diff --git a/src/datasets/utils/video/transforms.py b/src/datasets/utils/video/transforms.py index ffa8e61..c606c99 100644 --- a/src/datasets/utils/video/transforms.py +++ b/src/datasets/utils/video/transforms.py @@ -22,12 +22,12 @@ _pil_interpolation_to_str = { - Image.NEAREST: 'PIL.Image.NEAREST', - Image.BILINEAR: 'PIL.Image.BILINEAR', - Image.BICUBIC: 'PIL.Image.BICUBIC', - Image.LANCZOS: 'PIL.Image.LANCZOS', - Image.HAMMING: 'PIL.Image.HAMMING', - Image.BOX: 'PIL.Image.BOX', + Image.NEAREST: "PIL.Image.NEAREST", + Image.BILINEAR: "PIL.Image.BILINEAR", + Image.BICUBIC: "PIL.Image.BICUBIC", + Image.LANCZOS: "PIL.Image.LANCZOS", + Image.HAMMING: "PIL.Image.HAMMING", + Image.BOX: "PIL.Image.BOX", } @@ -35,11 +35,11 @@ def _pil_interp(method): - if method == 'bicubic': + if method == "bicubic": return Image.BICUBIC - elif method == 'lanczos': + elif method == "lanczos": return Image.LANCZOS - elif method == 'hamming': + elif method == "hamming": return Image.HAMMING else: return Image.BILINEAR @@ -68,17 +68,13 @@ def random_short_side_scale_jitter( `num boxes` x 4. 
""" if inverse_uniform_sampling: - size = int( - round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) - ) + size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))) else: size = int(round(np.random.uniform(min_size, max_size))) height = images.shape[2] width = images.shape[3] - if (width <= height and width == size) or ( - height <= width and height == size - ): + if (width <= height and width == size) or (height <= width and height == size): return images, boxes new_width = size new_height = size @@ -95,7 +91,7 @@ def random_short_side_scale_jitter( torch.nn.functional.interpolate( images, size=(new_height, new_width), - mode='bilinear', + mode="bilinear", align_corners=False, ), boxes, @@ -146,13 +142,9 @@ def random_crop(images, size, boxes=None): x_offset = 0 if width > size: x_offset = int(np.random.randint(0, width - size)) - cropped = images[ - :, :, y_offset:y_offset + size, x_offset:x_offset + size - ] + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] - cropped_boxes = ( - crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None - ) + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None return cropped, cropped_boxes @@ -227,7 +219,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): images = torch.nn.functional.interpolate( images, size=(height, width), - mode='bilinear', + mode="bilinear", align_corners=False, ) @@ -244,12 +236,8 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): x_offset = 0 elif spatial_idx == 2: x_offset = width - size - cropped = images[ - :, :, y_offset:y_offset + size, x_offset:x_offset + size - ] - cropped_boxes = ( - crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None - ) + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None if ndim == 3: cropped = cropped.squeeze(0) return cropped, cropped_boxes @@ -306,9 +294,7 @@ def grayscale(images): """ # R -> 0.299, G -> 0.587, B -> 0.114. 
img_gray = torch.tensor(images) - gray_channel = ( - 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] - ) + gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] img_gray[:, 0] = gray_channel img_gray[:, 1] = gray_channel img_gray[:, 2] = gray_channel @@ -332,20 +318,20 @@ def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): jitter = [] if img_brightness != 0: - jitter.append('brightness') + jitter.append("brightness") if img_contrast != 0: - jitter.append('contrast') + jitter.append("contrast") if img_saturation != 0: - jitter.append('saturation') + jitter.append("saturation") if len(jitter) > 0: order = np.random.permutation(np.arange(len(jitter))) for idx in range(0, len(jitter)): - if jitter[order[idx]] == 'brightness': + if jitter[order[idx]] == "brightness": images = brightness_jitter(img_brightness, images) - elif jitter[order[idx]] == 'contrast': + elif jitter[order[idx]] == "contrast": images = contrast_jitter(img_contrast, images) - elif jitter[order[idx]] == 'saturation': + elif jitter[order[idx]] == "saturation": images = saturation_jitter(img_saturation, images) return images @@ -439,7 +425,7 @@ def lighting_jitter(images, alphastd, eigval, eigvec): # T C H W channel_dim = 1 else: - raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") for idx in range(images.shape[channel_dim]): # C H W @@ -449,9 +435,7 @@ def lighting_jitter(images, alphastd, eigval, eigvec): elif len(images.shape) == 4: out_images[:, idx] = images[:, idx] + rgb[2 - idx] else: - raise NotImplementedError( - f'Unsupported dimension {len(images.shape)}' - ) + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") return out_images @@ -470,21 +454,13 @@ def color_normalization(images, mean, stddev): `num frames` x `channel` x `height` x `width`. 
""" if len(images.shape) == 3: - assert ( - len(mean) == images.shape[0] - ), 'channel mean not computed properly' - assert ( - len(stddev) == images.shape[0] - ), 'channel stddev not computed properly' + assert len(mean) == images.shape[0], "channel mean not computed properly" + assert len(stddev) == images.shape[0], "channel stddev not computed properly" elif len(images.shape) == 4: - assert ( - len(mean) == images.shape[1] - ), 'channel mean not computed properly' - assert ( - len(stddev) == images.shape[1] - ), 'channel stddev not computed properly' + assert len(mean) == images.shape[1], "channel mean not computed properly" + assert len(stddev) == images.shape[1], "channel stddev not computed properly" else: - raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") out_images = torch.zeros_like(images) for idx in range(len(mean)): @@ -494,9 +470,7 @@ def color_normalization(images, mean, stddev): elif len(images.shape) == 4: out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] else: - raise NotImplementedError( - f'Unsupported dimension {len(images.shape)}' - ) + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") return out_images @@ -568,11 +542,11 @@ def random_resized_crop( width = images.shape[3] i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) - cropped = images[:, :, i:i + h, j:j + w] + cropped = images[:, :, i : i + h, j : j + w] return torch.nn.functional.interpolate( cropped, size=(target_height, target_width), - mode='bilinear', + mode="bilinear", align_corners=False, ) @@ -608,15 +582,15 @@ def random_resized_crop_with_shift( w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] out = torch.zeros((3, t, target_height, target_width)) for ind in range(t): - out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate( images[ :, - ind:ind + 1, - i_s[ind]:i_s[ind] + h_s[ind], - j_s[ind]:j_s[ind] + w_s[ind], + ind : ind + 1, + i_s[ind] : i_s[ind] + h_s[ind], + j_s[ind] : j_s[ind] + w_s[ind], ], size=(target_height, target_width), - mode='bilinear', + mode="bilinear", align_corners=False, ) return out @@ -625,7 +599,7 @@ def random_resized_crop_with_shift( def create_random_augment( input_size, auto_augment=None, - interpolation='bilinear', + interpolation="bilinear", ): """ Get video randaug transform. @@ -648,13 +622,11 @@ def create_random_augment( img_size_min = min(img_size) else: img_size_min = img_size - aa_params = {'translate_const': int(img_size_min * 0.45)} - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - return transforms.Compose( - [rand_augment_transform(auto_augment, aa_params)] - ) + aa_params = {"translate_const": int(img_size_min * 0.45)} + if interpolation and interpolation != "random": + aa_params["interpolation"] = _pil_interp(interpolation) + if auto_augment.startswith("rand"): + return transforms.Compose([rand_augment_transform(auto_augment, aa_params)]) raise NotImplementedError @@ -668,9 +640,7 @@ def random_sized_crop_img( """ Performs Inception-style cropping (used for training). 
""" - assert ( - len(im.shape) == 3 - ), 'Currently only support image for random_sized_crop' + assert len(im.shape) == 3, "Currently only support image for random_sized_crop" h, w = im.shape[1:3] i, j, h, w = _get_param_spatial_crop( scale=jitter_scale, @@ -681,11 +651,11 @@ def random_sized_crop_img( log_scale=False, switch_hw=True, ) - cropped = im[:, i:i + h, j:j + w] + cropped = im[:, i : i + h, j : j + w] return torch.nn.functional.interpolate( cropped.unsqueeze(0), size=(size, size), - mode='bilinear', + mode="bilinear", align_corners=False, ).squeeze(0) @@ -711,16 +681,16 @@ def __init__( size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), - interpolation='bilinear', + interpolation="bilinear", ): if isinstance(size, tuple): self.size = size else: self.size = (size, size) if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - print('range should be of kind (min, max)') + print("range should be of kind (min, max)") - if interpolation == 'random': + if interpolation == "random": self.interpolation = _RANDOM_INTERPOLATION else: self.interpolation = _pil_interp(interpolation) @@ -784,19 +754,15 @@ def __call__(self, img): def __repr__(self): if isinstance(self.interpolation, (tuple, list)): - interpolate_str = ' '.join( + interpolate_str = " ".join( [_pil_interpolation_to_str[x] for x in self.interpolation] ) else: interpolate_str = _pil_interpolation_to_str[self.interpolation] - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += ', scale={0}'.format( - tuple(round(s, 4) for s in self.scale) - ) - format_string += ', ratio={0}'.format( - tuple(round(r, 4) for r in self.ratio) - ) - format_string += ', interpolation={0})'.format(interpolate_str) + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) return format_string @@ -833,12 +799,12 @@ def __call__(self, clip): if isinstance(clip[0], np.ndarray): return [np.fliplr(img) for img in clip] elif isinstance(clip[0], PIL.Image.Image): - return [ - img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip - ] + return [img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - ' but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + " but got list of {0}".format(type(clip[0])) + ) return clip @@ -852,7 +818,7 @@ class RandomResize(object): size (tuple): (widht, height) """ - def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + def __init__(self, ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation="nearest"): self.ratio = ratio self.interpolation = interpolation @@ -867,8 +833,7 @@ def __call__(self, clip): new_w = int(im_w * scaling_factor) new_h = int(im_h * scaling_factor) new_size = (new_w, new_h) - resized = FF.resize_clip( - clip, new_size, interpolation=self.interpolation) + resized = FF.resize_clip(clip, new_size, interpolation=self.interpolation) return resized @@ -882,13 +847,12 @@ class Resize(object): size (tuple): (widht, height) """ - def __init__(self, size, interpolation='nearest'): + def __init__(self, size, interpolation="nearest"): self.size = size self.interpolation = interpolation def __call__(self, clip): - resized = FF.resize_clip( - clip, self.size, interpolation=self.interpolation) + resized = FF.resize_clip(clip, self.size, interpolation=self.interpolation) return resized @@ -919,14 +883,18 @@ def __call__(self, clip): elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) if w > im_w or h > im_h: error_msg = ( - 'Initial image size should be larger then ' - 'cropped size but got cropped sizes : ({w}, {h}) while ' - 'initial image is ({im_w}, {im_h})'.format( - im_w=im_w, im_h=im_h, w=w, h=h)) + "Initial image size should be larger then " + "cropped size but got cropped sizes : ({w}, {h}) while " + "initial image is ({im_w}, {im_h})".format( + im_w=im_w, im_h=im_h, w=w, h=h + ) + ) raise ValueError(error_msg) x1 = random.randint(0, im_w - w) @@ -963,8 +931,10 @@ def __call__(self, clip): elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) if w != im_w and h != im_h: clip = FF.resize_clip(clip, self.size, interpolation="bilinear") im_h, im_w, im_c = clip[0].shape @@ -972,7 +942,7 @@ def __call__(self, clip): step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) cropped = [] for i in range(3): - if (im_h > self.size[0]): + if im_h > self.size[0]: x1 = 0 y1 = i * step cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) @@ -995,13 +965,11 @@ class RandomRotation(object): def __init__(self, degrees): if isinstance(degrees, numbers.Number): if degrees < 0: - raise ValueError('If degrees is a single number,' - 'must be positive') + raise ValueError("If degrees is a single number," "must be positive") degrees = (-degrees, degrees) else: if len(degrees) != 2: - raise ValueError('If degrees is a sequence,' - 'it must be of len 2.') + raise ValueError("If degrees is a sequence," "it must be of len 2.") self.degrees = degrees @@ -1014,14 +982,17 @@ def __call__(self, clip): PIL.Image or numpy.ndarray: Cropped list of images """ import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) if isinstance(clip[0], np.ndarray): rotated = [skimage.transform.rotate(img, angle) for img in clip] elif isinstance(clip[0], PIL.Image.Image): rotated = [img.rotate(angle) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + 
) return rotated @@ -1053,18 +1024,22 @@ def __call__(self, clip): elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) if w > im_w or h > im_h: error_msg = ( - 'Initial image size should be larger then ' - 'cropped size but got cropped sizes : ({w}, {h}) while ' - 'initial image is ({im_w}, {im_h})'.format( - im_w=im_w, im_h=im_h, w=w, h=h)) + "Initial image size should be larger then " + "cropped size but got cropped sizes : ({w}, {h}) while " + "initial image is ({im_w}, {im_h})".format( + im_w=im_w, im_h=im_h, w=w, h=h + ) + ) raise ValueError(error_msg) - x1 = int(round((im_w - w) / 2.)) - y1 = int(round((im_h - h) / 2.)) + x1 = int(round((im_w - w) / 2.0)) + y1 = int(round((im_h - h) / 2.0)) cropped = FF.crop_clip(clip, y1, x1, h, w) return cropped @@ -1093,20 +1068,17 @@ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): def get_params(self, brightness, contrast, saturation, hue): if brightness > 0: - brightness_factor = random.uniform( - max(0, 1 - brightness), 1 + brightness) + brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) else: brightness_factor = None if contrast > 0: - contrast_factor = random.uniform( - max(0, 1 - contrast), 1 + contrast) + contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) else: contrast_factor = None if saturation > 0: - saturation_factor = random.uniform( - max(0, 1 - saturation), 1 + saturation) + saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) else: saturation_factor = None @@ -1124,22 +1096,36 @@ def __call__(self, clip): list PIL.Image : list of transformed PIL.Image """ if isinstance(clip[0], np.ndarray): - raise TypeError( - 'Color jitter not yet implemented for numpy arrays') + raise TypeError("Color jitter not yet implemented for numpy arrays") elif isinstance(clip[0], PIL.Image.Image): brightness, contrast, saturation, hue = self.get_params( - self.brightness, self.contrast, self.saturation, self.hue) + self.brightness, self.contrast, self.saturation, self.hue + ) # Create img transform function sequence img_transforms = [] if brightness is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + img_transforms.append( + lambda img: torchvision.transforms.functional.adjust_brightness( + img, brightness + ) + ) if saturation is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + img_transforms.append( + lambda img: torchvision.transforms.functional.adjust_saturation( + img, saturation + ) + ) if hue is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + img_transforms.append( + lambda img: torchvision.transforms.functional.adjust_hue(img, hue) + ) if contrast is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + img_transforms.append( + lambda img: torchvision.transforms.functional.adjust_contrast( + img, contrast + ) + ) random.shuffle(img_transforms) # Apply to all images @@ -1150,8 +1136,10 @@ def __call__(self, clip): jittered_clip.append(jittered_img) else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + 
"Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) return jittered_clip @@ -1181,4 +1169,6 @@ def __call__(self, clip): return FF.normalize(clip, self.mean, self.std) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + return self.__class__.__name__ + "(mean={0}, std={1})".format( + self.mean, self.std + ) diff --git a/src/datasets/utils/weighted_sampler.py b/src/datasets/utils/weighted_sampler.py index fd40825..c228f4b 100644 --- a/src/datasets/utils/weighted_sampler.py +++ b/src/datasets/utils/weighted_sampler.py @@ -10,12 +10,7 @@ import numpy as np import torch -from torch.utils.data import ( - Dataset, - Sampler, - DistributedSampler, - WeightedRandomSampler -) +from torch.utils.data import Dataset, Sampler, DistributedSampler, WeightedRandomSampler class DatasetFromSampler(Dataset): @@ -34,7 +29,7 @@ def __len__(self) -> int: class DistributedSamplerWrapper(DistributedSampler): - """ Convert any Pytorch Sampler to a DistributedSampler """ + """Convert any Pytorch Sampler to a DistributedSampler""" def __init__( self, @@ -59,7 +54,7 @@ def __iter__(self) -> Iterator[int]: class CustomWeightedRandomSampler(WeightedRandomSampler): - """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + """Generalized WeightedRandomSampler to allow for more than 2^24 samples""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -69,7 +64,7 @@ def __iter__(self): range(0, len(self.weights)), size=self.num_samples, p=self.weights.numpy() / torch.sum(self.weights).numpy(), - replace=self.replacement + replace=self.replacement, ) rand_tensor = torch.from_numpy(rand_tensor) return iter(rand_tensor.tolist()) @@ -85,9 +80,8 @@ def __init__( shuffle: bool = True, ): weighted_sampler = CustomWeightedRandomSampler( - weights=weights, - num_samples=len(weights), - replacement=False) + weights=weights, num_samples=len(weights), replacement=False + ) super(DistributedWeightedSampler, self).__init__( sampler=weighted_sampler, diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index b05cc70..0fd153d 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -58,21 +58,18 @@ def make_videodataset( filter_long_videos=filter_long_videos, duration=duration, shared_transform=shared_transform, - transform=transform) + transform=transform, + ) - logger.info('VideoDataset dataset created') + logger.info("VideoDataset dataset created") if datasets_weights is not None: dist_sampler = DistributedWeightedSampler( - dataset.sample_weights, - num_replicas=world_size, - rank=rank, - shuffle=True) + dataset.sample_weights, num_replicas=world_size, rank=rank, shuffle=True + ) else: dist_sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=world_size, - rank=rank, - shuffle=True) + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) data_loader = torch.utils.data.DataLoader( dataset, @@ -82,14 +79,15 @@ def make_videodataset( drop_last=drop_last, pin_memory=pin_mem, num_workers=num_workers, - persistent_workers=num_workers > 0) - logger.info('VideoDataset unsupervised data loader created') + persistent_workers=num_workers > 0, + ) + logger.info("VideoDataset unsupervised data loader created") return dataset, data_loader, dist_sampler class VideoDataset(torch.utils.data.Dataset): - """ Video classification dataset. 
""" + """Video classification dataset.""" def __init__( self, @@ -120,21 +118,23 @@ def __init__( self.duration = duration if VideoReader is None: - raise ImportError('Unable to import "decord" which is required to read videos.') + raise ImportError( + 'Unable to import "decord" which is required to read videos.' + ) # Load video paths and labels samples, labels = [], [] self.num_samples_per_dataset = [] for data_path in self.data_paths: - if data_path[-4:] == '.csv': + if data_path[-4:] == ".csv": data = pd.read_csv(data_path, header=None, delimiter=" ") samples += list(data.values[:, 0]) labels += list(data.values[:, 1]) num_samples = len(data) self.num_samples_per_dataset.append(num_samples) - elif data_path[-4:] == '.npy': + elif data_path[-4:] == ".npy": data = np.load(data_path, allow_pickle=True) data = list(map(lambda x: repr(x)[1:-1], data)) samples += data @@ -169,10 +169,10 @@ def __getitem__(self, index): label = self.labels[index] def split_into_clips(video): - """ Split video into a list of clips """ + """Split video into a list of clips""" fpc = self.frames_per_clip nc = self.num_clips - return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + return [video[i * fpc : (i + 1) * fpc] for i in range(nc)] # Parse video into frames & apply data augmentations if self.shared_transform is not None: @@ -184,19 +184,19 @@ def split_into_clips(video): return buffer, label, clip_indices def loadvideo_decord(self, sample): - """ Load video content using Decord """ + """Load video content using Decord""" fname = sample if not os.path.exists(fname): - warnings.warn(f'video path not found {fname=}') + warnings.warn(f"video path not found {fname=}") return [], None _fsize = os.path.getsize(fname) if _fsize < 1 * 1024: # avoid hanging issue - warnings.warn(f'video too short {fname=}') + warnings.warn(f"video too short {fname=}") return [], None if _fsize > self.filter_long_videos: - warnings.warn(f'skipping long video of size {_fsize=} (bytes)') + warnings.warn(f"skipping long video of size {_fsize=} (bytes)") return [], None try: @@ -215,7 +215,7 @@ def loadvideo_decord(self, sample): clip_len = int(fpc * fstp) if self.filter_short_videos and len(vr) < clip_len: - warnings.warn(f'skipping video of length {len(vr)}') + warnings.warn(f"skipping video of length {len(vr)}") return [], None vr.seek(0) # Go to start of video before sampling frames @@ -235,7 +235,7 @@ def loadvideo_decord(self, sample): end_indx = np.random.randint(clip_len, partition_len) start_indx = end_indx - clip_len indices = np.linspace(start_indx, end_indx, num=fpc) - indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64) # -- indices = indices + i * partition_len else: @@ -244,8 +244,13 @@ def loadvideo_decord(self, sample): # we reach the desired clip length if not self.allow_clip_overlap: indices = np.linspace(0, partition_len, num=partition_len // fstp) - indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) - indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + indices = np.concatenate( + ( + indices, + np.ones(fpc - partition_len // fstp) * partition_len, + ) + ) + indices = np.clip(indices, 0, partition_len - 1).astype(np.int64) # -- indices = indices + i * partition_len @@ -254,8 +259,13 @@ def loadvideo_decord(self, sample): else: sample_len = min(clip_len, len(vr)) - 1 indices = np.linspace(0, sample_len, num=sample_len // fstp) - indices = np.concatenate((indices, np.ones(fpc - 
sample_len // fstp) * sample_len,)) - indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + indices = np.concatenate( + ( + indices, + np.ones(fpc - sample_len // fstp) * sample_len, + ) + ) + indices = np.clip(indices, 0, sample_len - 1).astype(np.int64) # -- clip_step = 0 if len(vr) > clip_len: diff --git a/src/masks/multiblock3d.py b/src/masks/multiblock3d.py index a7bbc3e..5ab5793 100644 --- a/src/masks/multiblock3d.py +++ b/src/masks/multiblock3d.py @@ -36,12 +36,12 @@ def __init__( num_frames=num_frames, spatial_patch_size=patch_size, temporal_patch_size=tubelet_size, - spatial_pred_mask_scale=m.get('spatial_scale'), - temporal_pred_mask_scale=m.get('temporal_scale'), - aspect_ratio=m.get('aspect_ratio'), - npred=m.get('num_blocks'), - max_context_frames_ratio=m.get('max_temporal_keep', 1.0), - max_keep=m.get('max_keep', None), + spatial_pred_mask_scale=m.get("spatial_scale"), + temporal_pred_mask_scale=m.get("temporal_scale"), + aspect_ratio=m.get("aspect_ratio"), + npred=m.get("num_blocks"), + max_context_frames_ratio=m.get("max_temporal_keep", 1.0), + max_keep=m.get("max_keep", None), ) self.mask_generators.append(mask_generator) @@ -80,9 +80,12 @@ def __init__( ): super(_MaskGenerator, self).__init__() if not isinstance(crop_size, tuple): - crop_size = (crop_size, ) * 2 + crop_size = (crop_size,) * 2 self.crop_size = crop_size - self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.height, self.width = ( + crop_size[0] // spatial_patch_size, + crop_size[1] // spatial_patch_size, + ) self.duration = num_frames // temporal_patch_size self.spatial_patch_size = spatial_patch_size @@ -92,9 +95,11 @@ def __init__( self.spatial_pred_mask_scale = spatial_pred_mask_scale self.temporal_pred_mask_scale = temporal_pred_mask_scale self.npred = npred - self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_context_duration = max( + 1, int(self.duration * max_context_frames_ratio) + ) # maximum number of time-steps (frames) spanned by context mask self.max_keep = max_keep # maximum number of patches to keep in context - self._itr_counter = Value('i', -1) # collator is shared across worker processes + self._itr_counter = Value("i", -1) # collator is shared across worker processes def step(self): i = self._itr_counter @@ -104,11 +109,7 @@ def step(self): return v def _sample_block_size( - self, - generator, - temporal_scale, - spatial_scale, - aspect_ratio_scale + self, generator, temporal_scale, spatial_scale, aspect_ratio_scale ): # -- Sample temporal block mask scale _rand = torch.rand(1, generator=generator).item() @@ -142,12 +143,12 @@ def _sample_block_mask(self, b_size): start = torch.randint(0, self.duration - t + 1, (1,)) mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) - mask[start:start+t, top:top+h, left:left+w] = 0 + mask[start : start + t, top : top + h, left : left + w] = 0 # Context mask will only span the first X frames # (X=self.max_context_frames) if self.max_context_duration < self.duration: - mask[self.max_context_duration:, :, :] = 0 + mask[self.max_context_duration :, :, :] = 0 # -- return mask @@ -176,7 +177,9 @@ def __call__(self, batch_size): empty_context = True while empty_context: - mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask_e = torch.ones( + (self.duration, self.height, self.width), dtype=torch.int32 + ) for _ in range(self.npred): 
mask_e *= self._sample_block_mask(p_size) mask_e = mask_e.flatten() diff --git a/src/masks/random_tube.py b/src/masks/random_tube.py index 84c0640..55fc65d 100644 --- a/src/masks/random_tube.py +++ b/src/masks/random_tube.py @@ -35,7 +35,7 @@ def __init__( num_frames=num_frames, spatial_patch_size=patch_size, temporal_patch_size=tubelet_size, - ratio=m.get('ratio'), + ratio=m.get("ratio"), ) self.mask_generators.append(mask_generator) @@ -69,21 +69,24 @@ def __init__( ): super(_MaskGenerator, self).__init__() if not isinstance(crop_size, tuple): - crop_size = (crop_size, ) * 2 + crop_size = (crop_size,) * 2 self.crop_size = crop_size - self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.height, self.width = ( + crop_size[0] // spatial_patch_size, + crop_size[1] // spatial_patch_size, + ) self.duration = num_frames // temporal_patch_size self.spatial_patch_size = spatial_patch_size self.temporal_patch_size = temporal_patch_size - self.num_patches_spatial = self.height*self.width + self.num_patches_spatial = self.height * self.width self.ratio = ratio - self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep_spatial = int(self.num_patches_spatial * (1.0 - self.ratio)) self.num_keep = self.num_keep_spatial * self.duration - self._itr_counter = Value('i', -1) # collator is shared across worker processes + self._itr_counter = Value("i", -1) # collator is shared across worker processes def step(self): i = self._itr_counter @@ -94,10 +97,12 @@ def step(self): def __call__(self, batch_size): def sample_mask(): - mask = np.hstack([ - np.zeros(self.num_patches_spatial - self.num_keep_spatial), - np.ones(self.num_keep_spatial), - ]) + mask = np.hstack( + [ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ] + ) np.random.shuffle(mask) mask = torch.tensor(np.tile(mask, (self.duration, 1))) mask = mask.flatten() diff --git a/src/models/attentive_pooler.py b/src/models/attentive_pooler.py index ecd9986..da0c528 100644 --- a/src/models/attentive_pooler.py +++ b/src/models/attentive_pooler.py @@ -10,16 +10,13 @@ import torch import torch.nn as nn -from src.models.utils.modules import ( - Block, - CrossAttention, - CrossAttentionBlock -) +from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock from src.utils.tensors import trunc_normal_ class AttentivePooler(nn.Module): - """ Attentive Pooler """ + """Attentive Pooler""" + def __init__( self, num_queries=1, @@ -30,7 +27,7 @@ def __init__( norm_layer=nn.LayerNorm, init_std=0.02, qkv_bias=True, - complete_block=True + complete_block=True, ): super().__init__() self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) @@ -42,24 +39,28 @@ def __init__( num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - norm_layer=norm_layer) + norm_layer=norm_layer, + ) else: self.cross_attention_block = CrossAttention( - dim=embed_dim, - num_heads=num_heads, - qkv_bias=qkv_bias) + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias + ) self.blocks = None if depth > 1: - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=False, - norm_layer=norm_layer) - for i in range(depth-1)]) + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer, + ) + for i in range(depth - 1) + ] + ) self.init_std = init_std 
trunc_normal_(self.query_tokens, std=self.init_std) @@ -103,7 +104,8 @@ def forward(self, x): class AttentiveClassifier(nn.Module): - """ Attentive Classifier """ + """Attentive Classifier""" + def __init__( self, embed_dim=768, diff --git a/src/models/predictor.py b/src/models/predictor.py index 2dd9a38..967a23d 100644 --- a/src/models/predictor.py +++ b/src/models/predictor.py @@ -13,15 +13,13 @@ from src.models.utils.modules import Block from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from src.utils.tensors import ( - trunc_normal_, - repeat_interleave_batch -) +from src.utils.tensors import trunc_normal_, repeat_interleave_batch from src.masks.utils import apply_masks class VisionTransformerPredictor(nn.Module): - """ Vision Transformer """ + """Vision Transformer""" + def __init__( self, img_size=224, @@ -54,10 +52,12 @@ def __init__( self.num_mask_tokens = 0 if use_mask_tokens: self.num_mask_tokens = num_mask_tokens - self.mask_tokens = nn.ParameterList([ - nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) - for i in range(num_mask_tokens) - ]) + self.mask_tokens = nn.ParameterList( + [ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ] + ) # Determine positional embedding self.input_size = img_size @@ -77,32 +77,35 @@ def __init__( * (img_size // patch_size) ) else: - self.num_patches = num_patches = ( - (img_size // patch_size) - * (img_size // patch_size) + self.num_patches = num_patches = (img_size // patch_size) * ( + img_size // patch_size ) # Position embedding self.uniform_power = uniform_power self.predictor_pos_embed = None self.predictor_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, predictor_embed_dim), - requires_grad=False) + torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False + ) # Attention Blocks - self.predictor_blocks = nn.ModuleList([ - Block( - dim=predictor_embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=nn.GELU, - attn_drop=attn_drop_rate, - grid_size=grid_size, - grid_depth=grid_depth, - norm_layer=norm_layer) - for i in range(depth)]) + self.predictor_blocks = nn.ModuleList( + [ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) # Normalize & project back to input dimension self.predictor_norm = norm_layer(predictor_embed_dim) @@ -128,7 +131,7 @@ def _init_pos_embed(self, pos_embed): grid_size, grid_depth, cls_token=False, - uniform_power=self.uniform_power + uniform_power=self.uniform_power, ) else: sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) @@ -155,20 +158,26 @@ def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): # Prepare diffusion noise schedule b1, b2 = noise_beta - beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + beta_scheduler = (b1 + i * (b2 - b1) / steps for i in range(steps)) alpha_scheduler = [] _alpha = 1.0 for _beta in beta_scheduler: - _alpha *= 1.-_beta + _alpha *= 1.0 - _beta alpha_scheduler += [_alpha] # Sample diffusion time step T = torch.randint(0, steps, (len(x),)) - alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + alpha = ( + torch.tensor(alpha_scheduler, device=x.device)[T] + .unsqueeze(-1) + .unsqueeze(-1) + ) # 
Normalize features and apply noise x = torch.nn.functional.layer_norm(x, (x.size(-1),)) - x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + x = alpha**0.5 * x + (1.0 - alpha) ** 0.5 * torch.randn( + x.shape, device=x.device + ) return x def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): @@ -179,7 +188,9 @@ def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): :params masks_tgt: indices of target tokens in input """ - assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + assert (masks_ctxt is not None) and ( + masks_tgt is not None + ), "Cannot run predictor without mask indices" if not isinstance(masks_ctxt, list): masks_ctxt = [masks_ctxt] @@ -241,6 +252,6 @@ def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): def vit_predictor(**kwargs): model = VisionTransformerPredictor( - mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs) + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) return model diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py index dc470d9..dd93b8f 100644 --- a/src/models/utils/modules.py +++ b/src/models/utils/modules.py @@ -17,7 +17,7 @@ def __init__( hidden_features=None, out_features=None, act_layer=nn.GELU, - drop=0. + drop=0.0, ): super().__init__() out_features = out_features or in_features @@ -43,14 +43,14 @@ def __init__( num_heads=8, qkv_bias=False, qk_scale=None, - attn_drop=0., - proj_drop=0., - use_sdpa=True + attn_drop=0.0, + proj_drop=0.0, + use_sdpa=True, ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) @@ -60,18 +60,24 @@ def __init__( def forward(self, x, mask=None): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): - x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.proj_drop_prob + ) attn = None else: attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v) + x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -83,11 +89,11 @@ def __init__( self, dim, num_heads, - mlp_ratio=4., + mlp_ratio=4.0, qkv_bias=False, qk_scale=None, - drop=0., - attn_drop=0., + drop=0.0, + attn_drop=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, grid_size=None, @@ -101,7 +107,8 @@ def __init__( qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop) + proj_drop=drop, + ) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -109,7 +116,8 @@ def __init__( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, - drop=drop) + drop=drop, + ) def forward(self, x, return_attention=False, mask=None): y, attn = self.attn(self.norm1(x), mask=mask) @@ -121,28 +129,30 @@ def forward(self, x, return_attention=False, mask=None): class CrossAttention(nn.Module): - def __init__( - self, - dim, - 
num_heads=12, - qkv_bias=False, - use_sdpa=True - ): + def __init__(self, dim, num_heads=12, qkv_bias=False, use_sdpa=True): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.scale = head_dim**-0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_sdpa = use_sdpa def forward(self, q, x): B, n, C = q.shape - q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = ( + self.q(q) + .reshape(B, n, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) B, N, C = x.shape - kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = ( + self.kv(x) + .reshape(B, N, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) if self.use_sdpa: @@ -151,11 +161,11 @@ def forward(self, q, x): else: xattn = (q @ k.transpose(-2, -1)) * self.scale xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) - q = (xattn @ v) + q = xattn @ v q = q.transpose(1, 2).reshape(B, n, C) q = self.proj(q) - + return q @@ -164,17 +174,19 @@ def __init__( self, dim, num_heads, - mlp_ratio=4., + mlp_ratio=4.0, qkv_bias=False, act_layer=nn.GELU, - norm_layer=nn.LayerNorm + norm_layer=nn.LayerNorm, ): super().__init__() self.norm1 = norm_layer(dim) self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.mlp = MLP( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer + ) def forward(self, q, x): y = self.xattn(q, self.norm1(x)) diff --git a/src/models/utils/patch_embed.py b/src/models/utils/patch_embed.py index 4ff4de5..42f3115 100644 --- a/src/models/utils/patch_embed.py +++ b/src/models/utils/patch_embed.py @@ -12,15 +12,13 @@ class PatchEmbed(nn.Module): """ Image to Patch Embedding """ - def __init__( - self, - patch_size=16, - in_chans=3, - embed_dim=768 - ): + + def __init__(self, patch_size=16, in_chans=3, embed_dim=768): super().__init__() self.patch_size = patch_size - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) def forward(self, x): B, C, H, W = x.shape diff --git a/src/models/utils/pos_embs.py b/src/models/utils/pos_embs.py index d1d82e2..f9792bd 100644 --- a/src/models/utils/pos_embs.py +++ b/src/models/utils/pos_embs.py @@ -9,11 +9,7 @@ def get_3d_sincos_pos_embed( - embed_dim, - grid_size, - grid_depth, - cls_token=False, - uniform_power=False + embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False ): """ grid_size: int of the grid height and width @@ -25,14 +21,16 @@ def get_3d_sincos_pos_embed( grid_d = np.arange(grid_depth, dtype=float) grid_h = np.arange(grid_size, dtype=float) grid_w = np.arange(grid_size, dtype=float) - grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + grid_h, grid_d, grid_w = np.meshgrid( + grid_h, grid_d, grid_w + ) # order of meshgrid is very important for indexing as [d,h,w] if not uniform_power: h_embed_dim = embed_dim // 4 w_embed_dim = embed_dim // 4 d_embed_dim = embed_dim // 
2 else: - h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2) emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) @@ -53,7 +51,9 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_h = np.arange(grid_size, dtype=float) grid_w = np.arange(grid_size, dtype=float) - grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + grid_w, grid_h = np.meshgrid( + grid_w, grid_h + ) # order of meshgrid is very important for indexing as [h, w] emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) @@ -86,11 +86,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py index a8748df..7f3fbc9 100644 --- a/src/models/vision_transformer.py +++ b/src/models/vision_transformer.py @@ -19,7 +19,8 @@ class VisionTransformer(nn.Module): - """ Vision Transformer """ + """Vision Transformer""" + def __init__( self, img_size=224, @@ -62,7 +63,8 @@ def __init__( patch_size=patch_size, tubelet_size=tubelet_size, in_chans=in_chans, - embed_dim=embed_dim) + embed_dim=embed_dim, + ) self.num_patches = ( (num_frames // tubelet_size) * (img_size // patch_size) @@ -70,36 +72,36 @@ def __init__( ) else: self.patch_embed = PatchEmbed( - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim) - self.num_patches = ( - (img_size // patch_size) - * (img_size // patch_size) + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim ) + self.num_patches = (img_size // patch_size) * (img_size // patch_size) # Position embedding self.uniform_power = uniform_power self.pos_embed = None self.pos_embed = nn.Parameter( - torch.zeros(1, self.num_patches, embed_dim), - requires_grad=False) + torch.zeros(1, self.num_patches, embed_dim), requires_grad=False + ) # Attention Blocks - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=nn.GELU, - grid_size=grid_size, - grid_depth=grid_depth, - attn_drop=attn_drop_rate, - norm_layer=norm_layer) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) self.norm = norm_layer(embed_dim) # ------ initialize weights @@ -119,7 +121,7 @@ def _init_pos_embed(self, pos_embed): grid_size, grid_depth, cls_token=False, - uniform_power=self.uniform_power + uniform_power=self.uniform_power, ) else: sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) @@ 
-215,15 +217,16 @@ def interpolate_pos_encoding(self, x, pos_embed): # in patches N_t = self.num_frames // self.tubelet_size N_h = N_w = self.input_size // self.patch_size - assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' + assert N_h * N_w * N_t == N, "Positional embedding initialized incorrectly" # Compute scale factor for spatio-temporal interpolation - scale_factor = (T/N_t, H/N_h, W/N_w) + scale_factor = (T / N_t, H / N_h, W / N_w) pos_embed = nn.functional.interpolate( pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), scale_factor=scale_factor, - mode='trilinear') + mode="trilinear", + ) pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) return pos_embed @@ -239,69 +242,120 @@ def interpolate_pos_encoding(self, x, pos_embed): scale_factor = math.sqrt(npatch / N) pos_embed = nn.functional.interpolate( - pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), scale_factor=scale_factor, - mode='bicubic') + mode="bicubic", + ) pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return pos_embed def vit_tiny(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_small(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_base(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_large(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_huge(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_giant(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=1408, + depth=40, + num_heads=16, + mlp_ratio=48 / 11, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def 
vit_gigantic(patch_size=14, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mpl_ratio=64/13, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + patch_size=patch_size, + embed_dim=1664, + depth=48, + num_heads=16, + mpl_ratio=64 / 13, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs ) return model VIT_EMBED_DIMS = { - 'vit_tiny': 192, - 'vit_small': 384, - 'vit_base': 768, - 'vit_large': 1024, - 'vit_huge': 1280, - 'vit_giant': 1408, - 'vit_gigantic': 1664, + "vit_tiny": 192, + "vit_small": 384, + "vit_base": 768, + "vit_large": 1024, + "vit_huge": 1280, + "vit_giant": 1408, + "vit_gigantic": 1664, } diff --git a/src/utils/distributed.py b/src/utils/distributed.py index cfba444..8b205c0 100644 --- a/src/utils/distributed.py +++ b/src/utils/distributed.py @@ -21,28 +21,26 @@ def init_distributed(port=37123, rank_and_world_size=(None, None)): return dist.get_world_size(), dist.get_rank() rank, world_size = rank_and_world_size - os.environ['MASTER_ADDR'] = 'localhost' + os.environ["MASTER_ADDR"] = "localhost" if (rank is None) or (world_size is None): try: - world_size = int(os.environ['SLURM_NTASKS']) - rank = int(os.environ['SLURM_PROCID']) - os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] + world_size = int(os.environ["SLURM_NTASKS"]) + rank = int(os.environ["SLURM_PROCID"]) + os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"] except Exception: - logger.info('SLURM vars not set (distributed training not available)') + logger.info("SLURM vars not set (distributed training not available)") world_size, rank = 1, 0 return world_size, rank try: - os.environ['MASTER_PORT'] = str(port) + os.environ["MASTER_PORT"] = str(port) torch.distributed.init_process_group( - backend='nccl', - world_size=world_size, - rank=rank + backend="nccl", world_size=world_size, rank=rank ) except Exception as e: world_size, rank = 1, 0 - logger.info(f'Rank: {rank}. Distributed training not available {e}') + logger.info(f"Rank: {rank}. Distributed training not available {e}") return world_size, rank diff --git a/src/utils/logging.py b/src/utils/logging.py index fcdd3fa..80141bd 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -12,10 +12,10 @@ def gpu_timer(closure, log_timings=True): - """ Helper to time gpu-time to execute closure() """ + """Helper to time gpu-time to execute closure()""" log_timings = log_timings and torch.cuda.is_available() - elapsed_time = -1. 
+ elapsed_time = -1.0 if log_timings: start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -36,8 +36,13 @@ def gpu_timer(closure, log_timings=True): def get_logger(name=None, force=False): - logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format=LOG_FORMAT, + datefmt=DATE_FORMAT, + force=force, + ) return logging.getLogger(name=name) @@ -47,18 +52,18 @@ def __init__(self, fname, *argv): self.fname = fname self.types = [] # -- print headers - with open(self.fname, '+a') as f: + with open(self.fname, "+a") as f: for i, v in enumerate(argv, 1): self.types.append(v[0]) if i < len(argv): - print(v[1], end=',', file=f) + print(v[1], end=",", file=f) else: - print(v[1], end='\n', file=f) + print(v[1], end="\n", file=f) def log(self, *argv): - with open(self.fname, '+a') as f: + with open(self.fname, "+a") as f: for i, tv in enumerate(zip(self.types, argv), 1): - end = ',' if i < len(argv) else '\n' + end = "," if i < len(argv) else "\n" print(tv[0] % tv[1], end=end, file=f) @@ -71,8 +76,8 @@ def __init__(self): def reset(self): self.val = 0 self.avg = 0 - self.max = float('-inf') - self.min = float('inf') + self.max = float("-inf") + self.min = float("inf") self.sum = 0 self.count = 0 @@ -93,26 +98,26 @@ def grad_logger(named_params): stats.first_layer = None stats.last_layer = None for n, p in named_params: - if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + if (p.grad is not None) and not (n.endswith(".bias") or len(p.shape) == 1): grad_norm = float(torch.norm(p.grad.data)) stats.update(grad_norm) - if 'qkv' in n: + if "qkv" in n: stats.last_layer = grad_norm if stats.first_layer is None: stats.first_layer = grad_norm if stats.first_layer is None or stats.last_layer is None: - stats.first_layer = stats.last_layer = 0. 
+ stats.first_layer = stats.last_layer = 0.0 return stats def adamw_logger(optimizer): - """ logging magnitude of first and second momentum buffers in adamw """ + """logging magnitude of first and second momentum buffers in adamw""" # TODO: assert that optimizer is instance of torch.optim.AdamW - state = optimizer.state_dict().get('state') + state = optimizer.state_dict().get("state") exp_avg_stats = AverageMeter() exp_avg_sq_stats = AverageMeter() for key in state: s = state.get(key) - exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) - exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) - return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} + exp_avg_stats.update(float(s.get("exp_avg").abs().mean())) + exp_avg_sq_stats.update(float(s.get("exp_avg_sq").abs().mean())) + return {"exp_avg": exp_avg_stats, "exp_avg_sq": exp_avg_sq_stats} diff --git a/src/utils/monitoring.py b/src/utils/monitoring.py index 95a7845..bfd13a3 100644 --- a/src/utils/monitoring.py +++ b/src/utils/monitoring.py @@ -56,11 +56,12 @@ def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): if stats_callback_fn is None: # Default callback def stats_callback_fn(resource_sample: ResourceStatsSample): - print( - f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + print(f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + elif not callable(stats_callback_fn): - raise ValueError("Callback needs to be callable, got {}".format( - type(stats_callback_fn))) + raise ValueError( + "Callback needs to be callable, got {}".format(type(stats_callback_fn)) + ) self.stats_callback_fn = stats_callback_fn def stop(self) -> None: @@ -121,8 +122,7 @@ def compress_cpu_affinity(cpu_affinity): if min_x == max_x: cpu_affinity_compressed.append("{}".format(min_x)) else: - cpu_affinity_compressed.append( - "{}-{}".format(min_x, max_x)) + cpu_affinity_compressed.append("{}-{}".format(min_x, max_x)) min_x = x max_x = x last_x = x @@ -131,8 +131,7 @@ def compress_cpu_affinity(cpu_affinity): if min_x == max_x: cpu_affinity_compressed.append("{}".format(min_x)) else: - cpu_affinity_compressed.append( - "{}-{}".format(min_x, max_x)) + cpu_affinity_compressed.append("{}-{}".format(min_x, max_x)) # Concat cpu_affinity_compressed = ",".join(cpu_affinity_compressed) @@ -167,6 +166,7 @@ def compress_cpu_affinity(cpu_affinity): if __name__ == "__main__": import multiprocessing import time + pid = multiprocessing.current_process().pid monitor_thread = ResourceMonitoringThread(pid, 1) monitor_thread.start() diff --git a/src/utils/schedulers.py b/src/utils/schedulers.py index df02e2b..ae29809 100644 --- a/src/utils/schedulers.py +++ b/src/utils/schedulers.py @@ -18,7 +18,7 @@ def __init__( ref_lr, T_max, last_epoch=-1, - final_lr=0. + final_lr=0.0, ): self.optimizer = optimizer self.start_lr = start_lr @@ -26,7 +26,7 @@ def __init__( self.final_lr = final_lr self.warmup_steps = warmup_steps self.T_max = T_max - warmup_steps - self._step = 0. + self._step = 0.0 def step(self): self._step += 1 @@ -36,34 +36,35 @@ def step(self): else: # -- progress after warmup progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) - new_lr = max(self.final_lr, - self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. 
+ math.cos(math.pi * progress))) + new_lr = max( + self.final_lr, + self.final_lr + + (self.ref_lr - self.final_lr) + * 0.5 + * (1.0 + math.cos(math.pi * progress)), + ) for group in self.optimizer.param_groups: - group['lr'] = new_lr + group["lr"] = new_lr return new_lr class CosineWDSchedule(object): - def __init__( - self, - optimizer, - ref_wd, - T_max, - final_wd=0. - ): + def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0): self.optimizer = optimizer self.ref_wd = ref_wd self.final_wd = final_wd self.T_max = T_max - self._step = 0. + self._step = 0.0 def step(self): self._step += 1 progress = self._step / self.T_max - new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * ( + 1.0 + math.cos(math.pi * progress) + ) if self.final_wd <= self.ref_wd: new_wd = max(self.final_wd, new_wd) @@ -71,6 +72,6 @@ def step(self): new_wd = min(self.final_wd, new_wd) for group in self.optimizer.param_groups: - if ('WD_exclude' not in group) or not group['WD_exclude']: - group['weight_decay'] = new_wd + if ("WD_exclude" not in group) or not group["WD_exclude"]: + group["weight_decay"] = new_wd return new_wd diff --git a/src/utils/tensors.py b/src/utils/tensors.py index 6ae2850..ff33644 100644 --- a/src/utils/tensors.py +++ b/src/utils/tensors.py @@ -19,7 +19,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 with torch.no_grad(): # Values are generated by using a truncated uniform distribution and @@ -37,7 +37,7 @@ def norm_cdf(x): tensor.erfinv_() # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) + tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range @@ -45,7 +45,7 @@ def norm_cdf(x): return tensor -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # type: (Tensor, float, float, float, float) -> Tensor return _no_grad_trunc_normal_(tensor, mean, std, a, b) @@ -64,8 +64,11 @@ def apply_masks(x, masks): def repeat_interleave_batch(x, B, repeat): N = len(x) // B - x = torch.cat([ - torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) - for i in range(N) - ], dim=0) + x = torch.cat( + [ + torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) + for i in range(N) + ], + dim=0, + ) return x From 0ec248e929f1e2952f3e969c46c3d84036bacf85 Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Fri, 17 May 2024 14:52:55 -0400 Subject: [PATCH 3/8] fix: Adding actions to V-JEPA --- app/main.py | 2 +- app/main_distributed.py | 2 +- app/scaffold.py | 2 +- app/vjepa/train.py | 98 +++++++-- app/vjepa/transforms.py | 2 +- app/vjepa/utils.py | 17 +- configs/pretrain/vith16_384.yaml | 2 +- evals/image_classification_frozen/eval.py | 2 +- evals/main.py | 2 +- evals/main_distributed.py | 2 +- evals/scaffold.py | 2 +- evals/video_classification_frozen/eval.py | 2 +- evals/video_classification_frozen/utils.py | 2 +- setup.py | 2 +- src/datasets/data_manager.py | 16 +- src/datasets/image_dataset.py | 193 +++++++++++++++++- src/datasets/utils/video/functional.py | 2 +- src/datasets/utils/video/randaugment.py | 2 +- src/datasets/utils/video/randerase.py | 2 +- src/datasets/utils/video/transforms.py | 2 
+- src/datasets/utils/video/volume_transforms.py | 2 +- src/datasets/utils/weighted_sampler.py | 2 +- src/datasets/video_dataset.py | 30 ++- src/masks/default.py | 2 +- src/masks/multiblock3d.py | 2 +- src/masks/random_tube.py | 2 +- src/masks/utils.py | 2 +- src/models/action_encoders.py | 35 ++++ src/models/attentive_pooler.py | 2 +- src/models/predictor.py | 2 +- src/models/utils/combine_encodings.py | 33 +++ src/models/utils/modules.py | 2 +- src/models/utils/multimask.py | 2 +- src/models/utils/patch_embed.py | 2 +- src/models/utils/pos_embs.py | 2 +- src/models/vision_transformer.py | 2 +- src/utils/distributed.py | 2 +- src/utils/logging.py | 2 +- src/utils/monitoring.py | 2 +- src/utils/schedulers.py | 2 +- src/utils/tensors.py | 2 +- 41 files changed, 432 insertions(+), 58 deletions(-) create mode 100644 src/models/action_encoders.py create mode 100644 src/models/utils/combine_encodings.py diff --git a/app/main.py b/app/main.py index 77d63e0..0858489 100644 --- a/app/main.py +++ b/app/main.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/app/main_distributed.py b/app/main_distributed.py index b36a646..defc098 100644 --- a/app/main_distributed.py +++ b/app/main_distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/app/scaffold.py b/app/scaffold.py index 7946924..d2f3b3e 100644 --- a/app/scaffold.py +++ b/app/scaffold.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 0390974..34a5507 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -6,6 +6,7 @@ # import os +import csv # -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS try: @@ -40,6 +41,12 @@ AverageMeter, ) from src.utils.tensors import repeat_interleave_batch +from src.models.utils.combine_encodings import ( + combine_encodings_concat, + combine_encodings_add, + AttentionFusion, +) + from app.vjepa.utils import ( load_checkpoint, @@ -64,7 +71,32 @@ logger = get_logger(__name__) +def generate_csv_file(data_dir, csv_filename="v-jepa-pretrain.csv"): + csv_filepath = os.path.join(data_dir, csv_filename) + logger.info(f"Generating CSV file: {csv_filepath}") + + valid_folders = [] + for folder_name in os.listdir(data_dir): + folder_path = os.path.join(data_dir, folder_name) + action_filepath = os.path.join(folder_path, "action_data.csv") + if os.path.isdir(folder_path) and os.path.isfile(action_filepath): + valid_folders.append(folder_path) + else: + logger.warning( + f"Skipping folder '{folder_name}' due to missing or invalid action_data.csv" + ) + + with open(csv_filepath, "w", newline="") as csvfile: + writer = csv.writer(csvfile, delimiter=" ") + for folder_path in valid_folders: + writer.writerow([folder_path, 0]) # Write folder path and dummy label (0) + + logger.info(f"CSV file generation complete. 
Found {len(valid_folders)} valid folders.") + + def main(args, resume_preempt=False): + # First let's go over the folders and generate the + # ----------------------------------------------------------------------- # # PASSED IN PARAMS FROM CONFIG FILE # ----------------------------------------------------------------------- # @@ -103,7 +135,7 @@ def main(args, resume_preempt=False): # -- DATA cfgs_data = args.get("data") - dataset_type = cfgs_data.get("dataset_type", "videodataset") + dataset_type = cfgs_data.get("dataset_type", "egovehicle_imagedataset") mask_type = cfgs_data.get("mask_type", "multiblock3d") dataset_paths = cfgs_data.get("datasets", []) datasets_weights = cfgs_data.get("datasets_weights", None) @@ -162,6 +194,12 @@ def main(args, resume_preempt=False): # ----------------------------------------------------------------------- # # ----------------------------------------------------------------------- # + # Generate CSV file (only if not already exists) + csv_filename = "v-jepa-pretrain.csv" + # if not os.path.exists(os.path.join(dataset_paths[0], csv_filename)): + generate_csv_file(dataset_paths[0]) + + np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.benchmark = True @@ -207,7 +245,7 @@ def main(args, resume_preempt=False): ) # -- init model - encoder, predictor = init_video_model( + encoder, predictor, action_encoder = init_video_model( uniform_power=uniform_power, use_mask_tokens=use_mask_tokens, num_mask_tokens=len(cfgs_mask), @@ -413,6 +451,11 @@ def load_clips(): [u.to(device, non_blocking=True) for u in udata[0]], dim=0 ) + # Load action data + actions = torch.cat( + [a.to(device, non_blocking=True) for a in actions], dim=0 + ) + # Put each mask-enc/mask-pred pair on the GPU and reuse the # same mask pair for each clip _masks_enc, _masks_pred = [], [] @@ -424,9 +467,9 @@ def load_clips(): _masks_enc.append(_me) _masks_pred.append(_mp) - return (clips, _masks_enc, _masks_pred) + return (clips, _masks_enc, _masks_pred, actions) - clips, masks_enc, masks_pred = load_clips() + clips, masks_enc, masks_pred, actions = load_clips() for _i, m in enumerate(mask_meters): m.update(masks_enc[_i][0].size(-1)) @@ -447,24 +490,43 @@ def forward_target(c): h, (h.size(-1),) ) # normalize over feature-dim [B, N, D] # -- create targets (masked regions of h) - h = apply_masks(h, masks_pred, concat=False) - return h - - def forward_context(c, h): + # h = apply_masks(h, masks_pred, concat=False) + # -- create targets (next frames of h) + h_next = h[ + :, 1:, : + ] # Assuming frames are ordered chronologically. These are ground truth next frames + return h_next + + def forward_context(c, h, a): """ Returns list of tensors of shape [B, N, D], one for each mask-pred. 
""" z = encoder(c, masks_enc) - z = predictor(z, h, masks_enc, masks_pred) + + # Encode the action representations + z_a = action_encoder(a) + + # Combine the encoded actions with the encoded video clips + z_combined = combine_encodings_concat(z, z_a) + + z = predictor(z_combined, h, masks_enc, masks_pred) return z - def loss_fn(z, h): + # def loss_fn(z, h): + # loss = 0.0 + # # Compute loss and accumulate for each mask-enc/mask-pred pair + # for zi, hi in zip(z, h): + # loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp + # # loss /= len(masks_pred) + # return loss + + def loss_fn(z_next, h_next): loss = 0.0 - # Compute loss and accumulate for each mask-enc/mask-pred pair - for zi, hi in zip(z, h): + # Compute loss between predicted next frames and ground truth next frames + for zi, hi in zip(z_next, h_next): loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp - loss /= len(masks_pred) + loss /= len(h_next) return loss def reg_fn(z): @@ -475,10 +537,10 @@ def reg_fn(z): # Step 1. Forward loss_jepa, loss_reg = 0.0, 0.0 with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): - h = forward_target(clips) - z = forward_context(clips, h) - loss_jepa = loss_fn(z, h) # jepa prediction loss - pstd_z = reg_fn(z) # predictor variance across patches + h_next = forward_target(clips) + z_next = forward_context(clips, h_next, actions) + loss_jepa = loss_fn(z_next, h_next) # jepa prediction loss + pstd_z = reg_fn(z_next) # predictor variance across patches loss_reg += torch.mean(F.relu(1.0 - pstd_z)) loss = loss_jepa + reg_coeff * loss_reg diff --git a/app/vjepa/transforms.py b/app/vjepa/transforms.py index cc90645..d5e6d0c 100644 --- a/app/vjepa/transforms.py +++ b/app/vjepa/transforms.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/app/vjepa/utils.py b/app/vjepa/utils.py index 7bdecd5..58e0339 100644 --- a/app/vjepa/utils.py +++ b/app/vjepa/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -14,6 +14,7 @@ import torch import src.models.vision_transformer as video_vit +from src.models.action_encoders import ActionEncoderContinuous, ActionEncoderDiscrete import src.models.predictor as vit_pred from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper from src.utils.schedulers import WarmupCosineSchedule, CosineWDSchedule @@ -95,6 +96,10 @@ def init_video_model( num_mask_tokens=2, zero_init_mask_tokens=True, use_sdpa=False, + action_type: str = "disc", # "cont", + num_actions=19, + embed_dim=32, + hidden_dim=32, ): encoder = video_vit.__dict__[model_name]( img_size=crop_size, @@ -105,6 +110,7 @@ def init_video_model( use_sdpa=use_sdpa, ) encoder = MultiMaskWrapper(encoder) + predictor = vit_pred.__dict__["vit_predictor"]( img_size=crop_size, use_mask_tokens=use_mask_tokens, @@ -122,6 +128,13 @@ def init_video_model( ) predictor = PredictorMultiMaskWrapper(predictor) + if action_type == "disc": + action_encoder = ActionEncoderDiscrete( + num_actions=num_actions, embed_dim=embed_dim, hidden_dim=hidden_dim + ) + # else: + # action_encoder = ActionEncoderContinuous(input_dim=) + def init_weights(m): if isinstance(m, torch.nn.Linear): trunc_normal_(m.weight, std=0.02) @@ -148,7 +161,7 @@ def count_parameters(model): logger.info(f"Encoder number of parameters: {count_parameters(encoder)}") logger.info(f"Predictor number of parameters: {count_parameters(predictor)}") - return encoder, predictor + return encoder, predictor, action_encoder def init_opt( diff --git a/configs/pretrain/vith16_384.yaml b/configs/pretrain/vith16_384.yaml index af4bb5f..a4cf73b 100644 --- a/configs/pretrain/vith16_384.yaml +++ b/configs/pretrain/vith16_384.yaml @@ -2,7 +2,7 @@ app: vjepa nodes: 30 tasks_per_node: 8 data: - dataset_type: VideoDataset + dataset_type: egovehicle_imagedataset datasets: - /home/ncdev/Documents/darwin/data/raw/v-jepa-pretrain.csv # - /your_path_to_ssv2_csv_file_index.csv diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 9368c68..91c57fb 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/evals/main.py b/evals/main.py index fb9130b..7a4b078 100644 --- a/evals/main.py +++ b/evals/main.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/evals/main_distributed.py b/evals/main_distributed.py index d885d69..2945922 100644 --- a/evals/main_distributed.py +++ b/evals/main_distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/evals/scaffold.py b/evals/scaffold.py index cef8d3d..c87b3ee 100644 --- a/evals/scaffold.py +++ b/evals/scaffold.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index a128d51..3093f96 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/evals/video_classification_frozen/utils.py b/evals/video_classification_frozen/utils.py index fe8c2c3..c1975dd 100644 --- a/evals/video_classification_frozen/utils.py +++ b/evals/video_classification_frozen/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/setup.py b/setup.py index 32c80f2..c987138 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py index 4ddc1b2..ac826f0 100644 --- a/src/datasets/data_manager.py +++ b/src/datasets/data_manager.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -93,5 +93,19 @@ def init_data( drop_last=drop_last, log_dir=log_dir, ) + elif data.lower() == "egovehicle_imagedataset": + from src.datasets.image_dataset import make_egovehicle_imagedataset + dataset, data_loader, dist_sampler = make_egovehicle_imagedataset( + data_paths=root_path, + batch_size=batch_size, + transform=transform, + shared_transform=shared_transform, + rank=rank, + world_size=world_size, + collator=collator, + drop_last=drop_last, + num_workers=num_workers, + pin_mem=pin_mem, + ) return (data_loader, dist_sampler) diff --git a/src/datasets/image_dataset.py b/src/datasets/image_dataset.py index ea34749..d655985 100644 --- a/src/datasets/image_dataset.py +++ b/src/datasets/image_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
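The new egovehicle_imagedataset branch in init_data above simply routes to make_egovehicle_imagedataset. A minimal illustration of reaching it; the path and sizes are placeholders, keyword names follow the init_data interface used elsewhere in this patch series, and defaults are assumed for any omitted arguments:

from src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    data="egovehicle_imagedataset",                    # routes to make_egovehicle_imagedataset
    root_path=["/path/to/timestamped_image_folders"],  # placeholder path list
    batch_size=10,
    transform=None,                                    # or the usual image transform pipeline
    collator=None,
    num_workers=4,
    pin_mem=True,
    world_size=1,
    rank=0,
)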
# # This source code is licensed under the license found in the @@ -11,6 +11,10 @@ import torch import torchvision +from datetime import datetime +import numpy as np +from torch.utils.data import DataLoader, DistributedSampler, Sampler + _GLOBAL_SEED = 0 logger = getLogger() @@ -75,3 +79,190 @@ def make_imagedataset( logger.info("ImageFolder unsupervised data loader created") return dataset, data_loader, dist_sampler + + +import os +import pandas as pd +import torch +from PIL import Image + + +class ImageDataset(torch.utils.data.Dataset): + def __init__( + self, + data_paths, # List of directories containing timestamped image folders + transform=None, + shared_transform=None, + ): + self.data_paths = data_paths + self.transform = transform + self.shared_transform = shared_transform + + # Load Image Paths and Labels + self.samples = [] + for data_path in self.data_paths: + timestamped_folders = [ + f + for f in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, f)) + ] + for folder in timestamped_folders: + folder_path = os.path.join(data_path, folder) + image_files = sorted( + os.listdir(folder_path) + ) # Sort for sequential order + self.samples.extend( + [(folder_path, image_file) for image_file in image_files] + ) # Store (folder_path, image_filename) tuples + + if len(self.samples) == 0: + raise RuntimeError(f"Found 0 image files in the data_paths: {data_paths}") + + def __getitem__(self, index): + folder_path, image_filename = self.samples[index] + image_path = os.path.join(folder_path, image_filename) + + # Load Image + image = Image.open(image_path) + + # Apply Transforms + if self.shared_transform is not None: + image = self.shared_transform(image) + if self.transform is not None: + image = self.transform(image) + + # Load Action Data + action_filepath = os.path.join(folder_path, "action_data.csv") + action_df = pd.read_csv(action_filepath) + + # Extract Timestamp from Image Filename + image_timestamp = self.extract_timestamp_from_filename(image_filename) + action_label = self.get_action_label_for_timestamp(action_df, image_timestamp) + + return ( + image, + action_label, + ) # Return the image and its corresponding action label + + def extract_timestamp_from_filename(self, filename): + timestamp_str = os.path.splitext(filename)[0].split("_")[ + 0 + ] # Get '20240516_175159' + timestamp = datetime.strptime(timestamp_str, "%Y%m%d_%H%M%S") + return timestamp + + def get_action_labels_for_clip(self, action_df, image_timestamps): + action_labels = [] + for timestamp in image_timestamps: + # Find closest action timestamps before and after the image timestamp + before_idx = action_df["timestamp"].searchsorted(timestamp) - 1 + after_idx = before_idx + 1 + + # Handle edge cases (first or last image) + before_idx = max(0, before_idx) + after_idx = min(len(action_df) - 1, after_idx) + + # Get action labels and timestamps + action_before = action_df.iloc[before_idx]["action_name"] + action_after = action_df.iloc[after_idx]["action_name"] + timestamp_before = action_df.iloc[before_idx]["timestamp"] + timestamp_after = action_df.iloc[after_idx]["timestamp"] + + # Linear Interpolation (if needed, can be removed for simple nearest neighbor) + weight_after = (timestamp - timestamp_before) / ( + timestamp_after - timestamp_before + ) + if weight_after < 0.5: # Closer to the previous action + action_label = action_before + else: # Closer to the next action + action_label = action_after + + action_labels.append(action_label) + + return action_labels + + + +class 
SequentialImageSampler(Sampler): + def __init__(self, image_dataset, num_replicas=None, rank=None): + super().__init__(image_dataset) + self.image_dataset = image_dataset + self.num_replicas = num_replicas + self.rank = rank + self.grouped_images = self.group_images_by_folder() + + def group_images_by_folder(self): + # Group image paths by folder, sorting by timestamp within each folder + grouped_images = {} + for folder_path, image_filename in self.image_dataset.samples: + grouped_images.setdefault(folder_path, []).append(image_filename) + for folder_path in grouped_images: + grouped_images[folder_path] = sorted(grouped_images[folder_path], key=self.image_dataset.extract_timestamp_from_filename) + return grouped_images + + def __iter__(self): + # Determine which folders this worker should handle + worker_folders = [ + folder + for i, folder in enumerate(sorted(self.grouped_images.keys())) + if i % self.num_replicas == self.rank + ] + + # Yield image indices in sequential order for each assigned folder + for folder_path in worker_folders: + for image_filename in self.grouped_images[folder_path]: + yield self.image_dataset.samples.index((folder_path, image_filename)) + + def __len__(self): + # Total number of samples across all workers + total_samples = sum(len(images) for images in self.grouped_images.values()) + # Number of samples for this worker + num_samples_per_worker = total_samples // self.num_replicas + # Add any remaining samples to the last worker + if self.rank == self.num_replicas - 1: + num_samples_per_worker += total_samples % self.num_replicas + return num_samples_per_worker + + +def make_egovehicle_imagedataset( + data_paths, + batch_size, + transform=None, + shared_transform=None, + rank=0, + world_size=1, + collator=None, + drop_last=True, + num_workers=10, + pin_mem=True, +): + dataset = ImageDataset( + data_paths=data_paths, + transform=transform, + shared_transform=shared_transform, + ) + + logger.info("ImageDataset created") + + # Ensure that each worker gets a subset of folders while maintaining sequential order + sampler = SequentialImageSampler(dataset, num_replicas=world_size, rank=rank) + + # Wrap the sampler with DistributedSampler for shuffling at the folder level + dist_sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) + + # DataLoader should use both samplers + data_loader = DataLoader( + dataset, + batch_sampler=dist_sampler, # Using batch_sampler instead of sampler + collate_fn=collator, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + persistent_workers=num_workers > 0, + ) + + logger.info("ImageDataset data loader created") + + return dataset, data_loader, dist_sampler diff --git a/src/datasets/utils/video/functional.py b/src/datasets/utils/video/functional.py index 9e443f8..3136cab 100644 --- a/src/datasets/utils/video/functional.py +++ b/src/datasets/utils/video/functional.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/datasets/utils/video/randaugment.py b/src/datasets/utils/video/randaugment.py index a0f060e..3837b89 100644 --- a/src/datasets/utils/video/randaugment.py +++ b/src/datasets/utils/video/randaugment.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
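ImageDataset.__getitem__ earlier in this patch calls self.get_action_label_for_timestamp(action_df, image_timestamp), but only the clip-level variant get_action_labels_for_clip is defined. A hypothetical single-timestamp helper that mirrors that logic, assuming the timestamp column of action_data.csv is stored in the same representation as the parsed filename timestamps so that subtraction and comparison are well defined:

    def get_action_label_for_timestamp(self, action_df, timestamp):
        # Pick the action row nearest in time, choosing between the rows
        # immediately before and after the image timestamp.
        before_idx = max(0, action_df["timestamp"].searchsorted(timestamp) - 1)
        after_idx = min(len(action_df) - 1, before_idx + 1)

        t_before = action_df.iloc[before_idx]["timestamp"]
        t_after = action_df.iloc[after_idx]["timestamp"]
        if t_after == t_before:
            return action_df.iloc[before_idx]["action_name"]

        weight_after = (timestamp - t_before) / (t_after - t_before)
        idx = after_idx if weight_after >= 0.5 else before_idx
        return action_df.iloc[idx]["action_name"]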
# # This source code is licensed under the license found in the diff --git a/src/datasets/utils/video/randerase.py b/src/datasets/utils/video/randerase.py index 0136bfe..b38602e 100644 --- a/src/datasets/utils/video/randerase.py +++ b/src/datasets/utils/video/randerase.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/datasets/utils/video/transforms.py b/src/datasets/utils/video/transforms.py index c606c99..2af9c69 100644 --- a/src/datasets/utils/video/transforms.py +++ b/src/datasets/utils/video/transforms.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/datasets/utils/video/volume_transforms.py b/src/datasets/utils/video/volume_transforms.py index 0a01bb3..cf42d06 100644 --- a/src/datasets/utils/video/volume_transforms.py +++ b/src/datasets/utils/video/volume_transforms.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/datasets/utils/weighted_sampler.py b/src/datasets/utils/weighted_sampler.py index c228f4b..f411c8f 100644 --- a/src/datasets/utils/weighted_sampler.py +++ b/src/datasets/utils/weighted_sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index 0fd153d..5b6e42a 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -117,6 +117,8 @@ def __init__( self.filter_long_videos = filter_long_videos self.duration = duration + self.frame_sample_rate = None + if VideoReader is None: raise ImportError( 'Unable to import "decord" which is required to read videos.' 
@@ -181,7 +183,27 @@ def split_into_clips(video): if self.transform is not None: buffer = [self.transform(clip) for clip in buffer] - return buffer, label, clip_indices + # Load Action Data + action_filepath = os.path.join(os.path.dirname(sample), "action_data.csv") + action_df = pd.read_csv(action_filepath) + action_labels = self.get_action_labels_for_clip(action_df, clip_indices) + + return buffer, label, clip_indices, action_labels + + def get_action_labels_for_clip(self, action_df, clip_indices): + # Convert video frame indices to timestamps in seconds + frame_timestamps = clip_indices / self.frame_sample_rate + + # Find the corresponding actions for each frame + action_labels = [] + for timestamp in frame_timestamps: + # Find the row with the nearest timestamp (modify this for interpolation if needed) + nearest_row = action_df.iloc[ + (action_df["timestamp"] - timestamp).abs().argmin() + ] + action_labels.append(nearest_row["action_name"]) + + return action_labels def loadvideo_decord(self, sample): """Load video content using Decord""" @@ -276,6 +298,10 @@ def loadvideo_decord(self, sample): all_indices.extend(list(indices)) buffer = vr.get_batch(all_indices).asnumpy() + + # Added the following line to extract the frame rate from video metadata + self.frame_sample_rate = vr.get_avg_fps() + return buffer, clip_indices def __len__(self): diff --git a/src/masks/default.py b/src/masks/default.py index 2810c0a..a95bbe0 100644 --- a/src/masks/default.py +++ b/src/masks/default.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/masks/multiblock3d.py b/src/masks/multiblock3d.py index 5ab5793..f306677 100644 --- a/src/masks/multiblock3d.py +++ b/src/masks/multiblock3d.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/masks/random_tube.py b/src/masks/random_tube.py index 55fc65d..00fb6a7 100644 --- a/src/masks/random_tube.py +++ b/src/masks/random_tube.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/masks/utils.py b/src/masks/utils.py index ca04af1..bb0ad76 100644 --- a/src/masks/utils.py +++ b/src/masks/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
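For intuition, get_action_labels_for_clip above converts each sampled frame index to a timestamp in seconds by dividing by the average FPS reported by decord, then takes the action row whose timestamp is closest. A small worked example with assumed values (30 FPS video, action timestamps in seconds):

import numpy as np
import pandas as pd

frame_sample_rate = 30.0                                 # assumed value of vr.get_avg_fps()
clip_indices = np.array([0, 15, 30, 45])                 # sampled frame indices
frame_timestamps = clip_indices / frame_sample_rate      # -> [0.0, 0.5, 1.0, 1.5] seconds

action_df = pd.DataFrame({
    "timestamp":   [0.0, 0.4, 1.1, 2.0],
    "action_name": ["idle", "accelerate", "turn_left", "brake"],
})

labels = [
    action_df.iloc[(action_df["timestamp"] - t).abs().argmin()]["action_name"]
    for t in frame_timestamps
]
print(labels)  # ['idle', 'accelerate', 'turn_left', 'turn_left']

Since loadvideo_decord runs earlier in __getitem__ to produce buffer and clip_indices, frame_sample_rate should already be populated when this lookup happens.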
# # This source code is licensed under the license found in the diff --git a/src/models/action_encoders.py b/src/models/action_encoders.py new file mode 100644 index 0000000..d6c3732 --- /dev/null +++ b/src/models/action_encoders.py @@ -0,0 +1,35 @@ +import math +from functools import partial + +import torch +import torch.nn as nn + + +class ActionEncoderDiscrete(nn.Module): + def __init__(self, num_actions, embed_dim, hidden_dim): + super(ActionEncoderDiscrete, self).__init__() + self.embedding = nn.Embedding(num_actions, embed_dim) + self.linear = nn.Linear(embed_dim, hidden_dim) + + def forward(self, actions): + embedded_actions = self.embedding(actions) + encoded_actions = self.linear(embedded_actions) + return encoded_actions + + +class ActionEncoderContinuous(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(ActionEncoderContinuous, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + *[ + nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + for _ in range(num_layers - 1) + ], + nn.Linear(hidden_dim, hidden_dim) + ) + + def forward(self, actions): + encoded_actions = self.mlp(actions) + return encoded_actions diff --git a/src/models/attentive_pooler.py b/src/models/attentive_pooler.py index da0c528..122a92b 100644 --- a/src/models/attentive_pooler.py +++ b/src/models/attentive_pooler.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/predictor.py b/src/models/predictor.py index 967a23d..0f76f45 100644 --- a/src/models/predictor.py +++ b/src/models/predictor.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/utils/combine_encodings.py b/src/models/utils/combine_encodings.py new file mode 100644 index 0000000..c8f0c22 --- /dev/null +++ b/src/models/utils/combine_encodings.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def combine_encodings_concat(z, z_a): + """ + Concatenation: Concatenate the encoded video clips and actions along the feature dimension. + """ + z_combined = torch.cat([z, z_a], dim=-1) + return z_combined + + +def combine_encodings_add(z, z_a): + """ + Addition: Add the encoded video clips and actions element-wise. + """ + z_combined = z + z_a + return z_combined + + +class AttentionFusion(nn.Module): + """ + Attention-based fusion: Use an attention mechanism to weight the importance of video clips and actions based on their relevance. + """ + + def __init__(self, hidden_dim): + super(AttentionFusion, self).__init__() + self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8) + + def forward(self, z, z_a): + z_combined, _ = self.attention(z, z_a, z_a) + return z_combined diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py index dd93b8f..f95ea0a 100644 --- a/src/models/utils/modules.py +++ b/src/models/utils/modules.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
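The two new modules above (action_encoders.py and combine_encodings.py) are small and composable; the sketch below shows the shapes involved when fusing a per-clip discrete action with encoded video tokens. Sizes are illustrative assumptions, and broadcasting the action embedding across tokens is one possible convention, not something the patch specifies:

import torch
from src.models.action_encoders import ActionEncoderDiscrete
from src.models.utils.combine_encodings import combine_encodings_add, combine_encodings_concat

B, N, D = 4, 16, 32                                  # batch, tokens, feature dim (illustrative)
action_encoder = ActionEncoderDiscrete(num_actions=19, embed_dim=32, hidden_dim=D)

actions = torch.randint(0, 19, (B,))                 # one discrete action id per clip
z_a = action_encoder(actions)                        # [B, D]
z = torch.randn(B, N, D)                             # encoded video tokens

z_a_tokens = z_a.unsqueeze(1).expand(-1, N, -1)      # broadcast action embedding to [B, N, D]
z_cat = combine_encodings_concat(z, z_a_tokens)      # [B, N, 2*D]  (feature dim doubles)
z_sum = combine_encodings_add(z, z_a_tokens)         # [B, N, D]    (needs matching shapes)
print(z_cat.shape, z_sum.shape)

Concatenation doubles the feature dimension, so whatever consumes the fused tensor has to be sized for 2*D features; addition keeps D but requires the two encodings to share a shape. AttentionFusion relies on torch.nn.MultiheadAttention, which expects sequence-first (L, B, E) inputs unless batch_first=True is passed.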
# # This source code is licensed under the license found in the diff --git a/src/models/utils/multimask.py b/src/models/utils/multimask.py index d480086..db2f841 100644 --- a/src/models/utils/multimask.py +++ b/src/models/utils/multimask.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/utils/patch_embed.py b/src/models/utils/patch_embed.py index 42f3115..6488421 100644 --- a/src/models/utils/patch_embed.py +++ b/src/models/utils/patch_embed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/utils/pos_embs.py b/src/models/utils/pos_embs.py index f9792bd..72be0bf 100644 --- a/src/models/utils/pos_embs.py +++ b/src/models/utils/pos_embs.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py index 7f3fbc9..c1fa5bb 100644 --- a/src/models/vision_transformer.py +++ b/src/models/vision_transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/utils/distributed.py b/src/utils/distributed.py index 8b205c0..46cf5dd 100644 --- a/src/utils/distributed.py +++ b/src/utils/distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/utils/logging.py b/src/utils/logging.py index 80141bd..f803910 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/utils/monitoring.py b/src/utils/monitoring.py index bfd13a3..505a2a7 100644 --- a/src/utils/monitoring.py +++ b/src/utils/monitoring.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/utils/schedulers.py b/src/utils/schedulers.py index ae29809..b3496d5 100644 --- a/src/utils/schedulers.py +++ b/src/utils/schedulers.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/utils/tensors.py b/src/utils/tensors.py index ff33644..4ec8767 100644 --- a/src/utils/tensors.py +++ b/src/utils/tensors.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the From c263d9cb94fda8c4a1ccaa44fb1569a4e3d2b48a Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Fri, 17 May 2024 15:13:18 -0400 Subject: [PATCH 4/8] fix: adding __init__py and alaunch.json for full coverage and debugging --- .vscode/launch.json | 18 ++++++++++++++++++ app/__init__.py | 0 app/main.py | 2 +- app/vjepa/__init__.py | 0 configs/__init__.py | 0 configs/evals/__init__.py | 0 configs/pretrain/__init__.py | 0 evals/__init__.py | 0 src/__init__.py | 0 src/datasets/__init__.py | 0 src/datasets/utils/__init__.py | 0 src/datasets/utils/video/__init__.py | 0 src/masks/__init__.py | 0 src/models/__init__.py | 0 src/models/utils/__init__.py | 0 src/utils/__init__.py | 0 16 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 .vscode/launch.json create mode 100644 app/__init__.py create mode 100644 app/vjepa/__init__.py create mode 100644 configs/__init__.py create mode 100644 configs/evals/__init__.py create mode 100644 configs/pretrain/__init__.py create mode 100644 evals/__init__.py create mode 100644 src/__init__.py create mode 100644 src/datasets/__init__.py create mode 100644 src/datasets/utils/__init__.py create mode 100644 src/datasets/utils/video/__init__.py create mode 100644 src/masks/__init__.py create mode 100644 src/models/__init__.py create mode 100644 src/models/utils/__init__.py create mode 100644 src/utils/__init__.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..822f4e6 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,18 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}" + } + } + ] +} \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/main.py b/app/main.py index 0858489..62af1c3 100644 --- a/app/main.py +++ b/app/main.py @@ -17,7 +17,7 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--fname", type=str, help="name of config file to load", default="configs.yaml" + "--fname", type=str, help="name of config file to load", default="configs/pretrain/vith16_384.yaml" ) parser.add_argument( "--devices", diff --git a/app/vjepa/__init__.py b/app/vjepa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/evals/__init__.py b/configs/evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/pretrain/__init__.py b/configs/pretrain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evals/__init__.py b/evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datasets/utils/__init__.py b/src/datasets/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datasets/utils/video/__init__.py b/src/datasets/utils/video/__init__.py new file mode 100644 index 0000000..e69de29 diff 
--git a/src/masks/__init__.py b/src/masks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 From 6eda27e966652dd88e9c74006a0afd04262d82a2 Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Mon, 20 May 2024 11:38:49 -0400 Subject: [PATCH 5/8] fix: adding __init__py and alaunch.json for full coverage and debugging --- .vscode/launch.json | 18 ++++++++++++++++++ app/__init__.py | 0 app/main.py | 2 +- app/vjepa/__init__.py | 0 configs/__init__.py | 0 configs/evals/__init__.py | 0 configs/pretrain/__init__.py | 0 evals/__init__.py | 0 src/__init__.py | 0 src/datasets/__init__.py | 0 src/datasets/utils/__init__.py | 0 src/datasets/utils/video/__init__.py | 0 src/masks/__init__.py | 0 src/models/__init__.py | 0 src/models/utils/__init__.py | 0 src/utils/__init__.py | 0 16 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 .vscode/launch.json create mode 100644 app/__init__.py create mode 100644 app/vjepa/__init__.py create mode 100644 configs/__init__.py create mode 100644 configs/evals/__init__.py create mode 100644 configs/pretrain/__init__.py create mode 100644 evals/__init__.py create mode 100644 src/__init__.py create mode 100644 src/datasets/__init__.py create mode 100644 src/datasets/utils/__init__.py create mode 100644 src/datasets/utils/video/__init__.py create mode 100644 src/masks/__init__.py create mode 100644 src/models/__init__.py create mode 100644 src/models/utils/__init__.py create mode 100644 src/utils/__init__.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..822f4e6 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,18 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}" + } + } + ] +} \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/main.py b/app/main.py index 0858489..62af1c3 100644 --- a/app/main.py +++ b/app/main.py @@ -17,7 +17,7 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--fname", type=str, help="name of config file to load", default="configs.yaml" + "--fname", type=str, help="name of config file to load", default="configs/pretrain/vith16_384.yaml" ) parser.add_argument( "--devices", diff --git a/app/vjepa/__init__.py b/app/vjepa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/evals/__init__.py b/configs/evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/pretrain/__init__.py b/configs/pretrain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evals/__init__.py b/evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datasets/utils/__init__.py b/src/datasets/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datasets/utils/video/__init__.py b/src/datasets/utils/video/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/masks/__init__.py b/src/masks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 From 3d09a4c9b27e5ba40e36b68d365f4028738cb313 Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Mon, 20 May 2024 12:18:07 -0400 Subject: [PATCH 6/8] fix: WIP --- .vscode/launch.json | 8 +- app/main_with_actions.py | 83 ++++ app/vjepa/train.py | 548 ++++++++++-------------- app/vjepa/train_with_actions.py | 715 ++++++++++++++++++++++++++++++++ src/datasets/image_dataset.py | 102 ++--- src/utils/tensors.py | 11 + 6 files changed, 1053 insertions(+), 414 deletions(-) create mode 100644 app/main_with_actions.py create mode 100644 app/vjepa/train_with_actions.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 822f4e6..616ba9e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,8 +11,12 @@ "program": "${file}", "console": "integratedTerminal", "env": { - "PYTHONPATH": "${workspaceFolder}" - } + "PYTHONPATH": "${workspaceFolder}", + "CUDA_VISIBLE_DEVICES": "0" + }, + "args": [ // Pass arguments here + "--fname", "configs/pretrain/vith16_384.yaml" + ], } ] } \ No newline at end of file diff --git a/app/main_with_actions.py b/app/main_with_actions.py new file mode 100644 index 0000000..b111d94 --- /dev/null +++ b/app/main_with_actions.py @@ -0,0 +1,83 @@ +# Copyright (c) NeoCybernetica, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import argparse +import importlib + +import multiprocessing as mp + +import pprint +import yaml + +from app.scaffold import main as app_main +from src.utils.distributed import init_distributed + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--fname", type=str, help="name of config file to load", default="configs/pretrain/vith16_384.yaml" +) +parser.add_argument( + "--devices", + type=str, + nargs="+", + default=["cuda:0"], + help="which devices to use on local machine", +) + + +def process_main(rank, fname, world_size, devices): + import os + + os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) + + import logging + from src.utils.logging import get_logger + + logger = get_logger(force=True) + if rank == 0: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.ERROR) + + logger.info(f"called-params {fname}") + + # Load config + params = None + with open(fname, "r") as y_file: + params = yaml.load(y_file, Loader=yaml.FullLoader) + logger.info("loaded params...") + + # Log config + if rank == 0: + pprint.PrettyPrinter(indent=4).pprint(params) + dump = os.path.join(params["logging"]["folder"], "params-pretrain.yaml") + with open(dump, "w") as f: + yaml.dump(params, f) + + # Init distributed (access to comm between GPUS on same machine) + world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) + logger.info(f"Running... (rank: {rank}/{world_size})") + + # Launch the app with loaded config + # app_main(params["app"], args=params) + + # Update this line to load your new train_with_actions module: + train_module = importlib.import_module(f"app.{params['app']}.train_with_actions") + + # Launch the app with loaded config (use the imported train_module): + train_module.main(args=params) + + +if __name__ == "__main__": + args = parser.parse_args() + num_gpus = len(args.devices) + mp.set_start_method("spawn") + for rank in range(num_gpus): + mp.Process( + target=process_main, args=(rank, args.fname, num_gpus, args.devices) + ).start() diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 34a5507..568fdfe 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -6,7 +6,6 @@ # import os -import csv # -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS try: @@ -14,7 +13,7 @@ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE # -- TO EACH PROCESS - os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] except Exception: pass @@ -38,15 +37,8 @@ get_logger, grad_logger, adamw_logger, - AverageMeter, -) + AverageMeter) from src.utils.tensors import repeat_interleave_batch -from src.models.utils.combine_encodings import ( - combine_encodings_concat, - combine_encodings_add, - AttentionFusion, -) - from app.vjepa.utils import ( load_checkpoint, @@ -71,50 +63,25 @@ logger = get_logger(__name__) -def generate_csv_file(data_dir, csv_filename="v-jepa-pretrain.csv"): - csv_filepath = os.path.join(data_dir, csv_filename) - logger.info(f"Generating CSV file: {csv_filepath}") - - valid_folders = [] - for folder_name in os.listdir(data_dir): - folder_path = os.path.join(data_dir, folder_name) - action_filepath = os.path.join(folder_path, "action_data.csv") - if os.path.isdir(folder_path) and os.path.isfile(action_filepath): - 
valid_folders.append(folder_path) - else: - logger.warning( - f"Skipping folder '{folder_name}' due to missing or invalid action_data.csv" - ) - - with open(csv_filepath, "w", newline="") as csvfile: - writer = csv.writer(csvfile, delimiter=" ") - for folder_path in valid_folders: - writer.writerow([folder_path, 0]) # Write folder path and dummy label (0) - - logger.info(f"CSV file generation complete. Found {len(valid_folders)} valid folders.") - - def main(args, resume_preempt=False): - # First let's go over the folders and generate the - # ----------------------------------------------------------------------- # # PASSED IN PARAMS FROM CONFIG FILE # ----------------------------------------------------------------------- # # -- META - cfgs_meta = args.get("meta") - load_model = cfgs_meta.get("load_checkpoint") or resume_preempt - r_file = cfgs_meta.get("read_checkpoint", None) - seed = cfgs_meta.get("seed", _GLOBAL_SEED) - save_every_freq = cfgs_meta.get("save_every_freq", -1) - skip_batches = cfgs_meta.get("skip_batches", -1) - use_sdpa = cfgs_meta.get("use_sdpa", False) - which_dtype = cfgs_meta.get("dtype") - logger.info(f"{which_dtype=}") - if which_dtype.lower() == "bfloat16": + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') or resume_preempt + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype=}') + if which_dtype.lower() == 'bfloat16': dtype = torch.bfloat16 mixed_precision = True - elif which_dtype.lower() == "float16": + elif which_dtype.lower() == 'float16': dtype = torch.float16 mixed_precision = True else: @@ -122,106 +89,98 @@ def main(args, resume_preempt=False): mixed_precision = False # -- MASK - cfgs_mask = args.get("mask") + cfgs_mask = args.get('mask') # -- MODEL - cfgs_model = args.get("model") - model_name = cfgs_model.get("model_name") - pred_depth = cfgs_model.get("pred_depth") - pred_embed_dim = cfgs_model.get("pred_embed_dim") - uniform_power = cfgs_model.get("uniform_power", True) - use_mask_tokens = cfgs_model.get("use_mask_tokens", True) - zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True) + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) # -- DATA - cfgs_data = args.get("data") - dataset_type = cfgs_data.get("dataset_type", "egovehicle_imagedataset") - mask_type = cfgs_data.get("mask_type", "multiblock3d") - dataset_paths = cfgs_data.get("datasets", []) - datasets_weights = cfgs_data.get("datasets_weights", None) + cfgs_data = args.get('data') + dataset_type = cfgs_data.get('dataset_type', 'videodataset') + mask_type = cfgs_data.get('mask_type', 'multiblock3d') + dataset_paths = cfgs_data.get('datasets', []) + datasets_weights = cfgs_data.get('datasets_weights', None) if datasets_weights is not None: - assert len(datasets_weights) == len( - dataset_paths - ), "Must have one sampling weight specified for each dataset" - batch_size = cfgs_data.get("batch_size") - num_clips = cfgs_data.get("num_clips") - num_frames = cfgs_data.get("num_frames") - 
tubelet_size = cfgs_data.get("tubelet_size") - sampling_rate = cfgs_data.get("sampling_rate") - duration = cfgs_data.get("clip_duration", None) - crop_size = cfgs_data.get("crop_size", 224) - patch_size = cfgs_data.get("patch_size") - pin_mem = cfgs_data.get("pin_mem", False) - num_workers = cfgs_data.get("num_workers", 1) - filter_short_videos = cfgs_data.get("filter_short_videos", False) - decode_one_clip = cfgs_data.get("decode_one_clip", True) - log_resource_util_data = cfgs_data.get("log_resource_utilization", False) + assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset' + batch_size = cfgs_data.get('batch_size') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = cfgs_data.get('patch_size') + pin_mem = cfgs_data.get('pin_mem', False) + num_workers = cfgs_data.get('num_workers', 1) + filter_short_videos = cfgs_data.get('filter_short_videos', False) + decode_one_clip = cfgs_data.get('decode_one_clip', True) + log_resource_util_data = cfgs_data.get('log_resource_utilization', False) # -- DATA AUGS - cfgs_data_aug = args.get("data_aug") - ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) - rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) - motion_shift = cfgs_data_aug.get("motion_shift", False) - reprob = cfgs_data_aug.get("reprob", 0.0) - use_aa = cfgs_data_aug.get("auto_augment", False) + cfgs_data_aug = args.get('data_aug') + ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3]) + rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0]) + motion_shift = cfgs_data_aug.get('motion_shift', False) + reprob = cfgs_data_aug.get('reprob', 0.) 
+ use_aa = cfgs_data_aug.get('auto_augment', False) # -- LOSS - cfgs_loss = args.get("loss") - loss_exp = cfgs_loss.get("loss_exp") - reg_coeff = cfgs_loss.get("reg_coeff") + cfgs_loss = args.get('loss') + loss_exp = cfgs_loss.get('loss_exp') + reg_coeff = cfgs_loss.get('reg_coeff') # -- OPTIMIZATION - cfgs_opt = args.get("optimization") - ipe = cfgs_opt.get("ipe", None) - ipe_scale = cfgs_opt.get("ipe_scale", 1.0) - clip_grad = cfgs_opt.get("clip_grad", None) - wd = float(cfgs_opt.get("weight_decay")) - final_wd = float(cfgs_opt.get("final_weight_decay")) - num_epochs = cfgs_opt.get("epochs") - warmup = cfgs_opt.get("warmup") - start_lr = cfgs_opt.get("start_lr") - lr = cfgs_opt.get("lr") - final_lr = cfgs_opt.get("final_lr") - ema = cfgs_opt.get("ema") - betas = cfgs_opt.get("betas", (0.9, 0.999)) - eps = cfgs_opt.get("eps", 1.0e-8) + cfgs_opt = args.get('optimization') + ipe = cfgs_opt.get('ipe', None) + ipe_scale = cfgs_opt.get('ipe_scale', 1.0) + clip_grad = cfgs_opt.get('clip_grad', None) + wd = float(cfgs_opt.get('weight_decay')) + final_wd = float(cfgs_opt.get('final_weight_decay')) + num_epochs = cfgs_opt.get('epochs') + warmup = cfgs_opt.get('warmup') + start_lr = cfgs_opt.get('start_lr') + lr = cfgs_opt.get('lr') + final_lr = cfgs_opt.get('final_lr') + ema = cfgs_opt.get('ema') + betas = cfgs_opt.get('betas', (0.9, 0.999)) + eps = cfgs_opt.get('eps', 1.e-8) # -- LOGGING - cfgs_logging = args.get("logging") - folder = cfgs_logging.get("folder") - tag = cfgs_logging.get("write_tag") + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = cfgs_logging.get('write_tag') # ----------------------------------------------------------------------- # # ----------------------------------------------------------------------- # - # Generate CSV file (only if not already exists) - csv_filename = "v-jepa-pretrain.csv" - # if not os.path.exists(os.path.join(dataset_paths[0], csv_filename)): - generate_csv_file(dataset_paths[0]) - - np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.benchmark = True try: - mp.set_start_method("spawn") + mp.set_start_method('spawn') except Exception: pass # -- init torch distributed backend world_size, rank = init_distributed() - logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') # -- set device if not torch.cuda.is_available(): - device = torch.device("cpu") + device = torch.device('cpu') else: - device = torch.device("cuda:0") + device = torch.device('cuda:0') torch.cuda.set_device(device) # -- log/checkpointing paths - log_file = os.path.join(folder, f"{tag}_r{rank}.csv") - latest_file = f"{tag}-latest.pth.tar" + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_file = f'{tag}-latest.pth.tar' latest_path = os.path.join(folder, latest_file) load_path = None if load_model: @@ -233,19 +192,19 @@ def main(args, resume_preempt=False): # -- make csv_logger csv_logger = CSVLogger( log_file, - ("%d", "epoch"), - ("%d", "itr"), - ("%.5f", "loss"), - ("%.5f", "loss-jepa"), - ("%.5f", "reg-loss"), - ("%.5f", "enc-grad-norm"), - ("%.5f", "pred-grad-norm"), - ("%d", "gpu-time(ms)"), - ("%d", "wall-time(ms)"), + ('%d', 'epoch'), + ('%d', 'itr'), + ('%.5f', 'loss'), + ('%.5f', 'loss-jepa'), + ('%.5f', 'reg-loss'), + ('%.5f', 'enc-grad-norm'), + ('%.5f', 'pred-grad-norm'), + ('%d', 'gpu-time(ms)'), + ('%d', 'wall-time(ms)'), ) # -- init model - encoder, predictor, action_encoder = init_video_model( + encoder, predictor = init_video_model( 
uniform_power=uniform_power, use_mask_tokens=use_mask_tokens, num_mask_tokens=len(cfgs_mask), @@ -263,24 +222,22 @@ def main(args, resume_preempt=False): target_encoder = copy.deepcopy(encoder) # -- make data transforms - if mask_type == "multiblock3d": - logger.info("Initializing basic multi-block mask") + if mask_type == 'multiblock3d': + logger.info('Initializing basic multi-block mask') mask_collator = MB3DMaskCollator( crop_size=crop_size, num_frames=num_frames, patch_size=patch_size, tubelet_size=tubelet_size, - cfgs_mask=cfgs_mask, - ) + cfgs_mask=cfgs_mask) else: - logger.info("Initializing random tube mask") + logger.info('Initializing random tube mask') mask_collator = TubeMaskCollator( crop_size=crop_size, num_frames=num_frames, patch_size=patch_size, tubelet_size=tubelet_size, - cfgs_mask=cfgs_mask, - ) + cfgs_mask=cfgs_mask) transform = make_transforms( random_horizontal_flip=True, random_resize_aspect_ratio=ar_range, @@ -288,37 +245,36 @@ def main(args, resume_preempt=False): reprob=reprob, auto_augment=use_aa, motion_shift=motion_shift, - crop_size=crop_size, - ) + crop_size=crop_size) # -- init data-loaders/samplers - (unsupervised_loader, unsupervised_sampler) = init_data( - data=dataset_type, - root_path=dataset_paths, - batch_size=batch_size, - training=True, - clip_len=num_frames, - frame_sample_rate=sampling_rate, - filter_short_videos=filter_short_videos, - decode_one_clip=decode_one_clip, - duration=duration, - num_clips=num_clips, - transform=transform, - datasets_weights=datasets_weights, - collator=mask_collator, - num_workers=num_workers, - world_size=world_size, - pin_mem=pin_mem, - rank=rank, - log_dir=folder if log_resource_util_data else None, - ) + (unsupervised_loader, + unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None) try: _dlen = len(unsupervised_loader) except Exception: # Different interface for webdataset _dlen = unsupervised_loader.num_batches if ipe is None: ipe = _dlen - logger.info(f"iterations per epoch/dataest length: {ipe}/{_dlen}") + logger.info(f'iterations per epoch/dataest length: {ipe}/{_dlen}') # -- init optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( @@ -335,8 +291,7 @@ def main(args, resume_preempt=False): ipe_scale=ipe_scale, mixed_precision=mixed_precision, betas=betas, - eps=eps, - ) + eps=eps) encoder = DistributedDataParallel(encoder, static_graph=True) predictor = DistributedDataParallel(predictor, static_graph=True) target_encoder = DistributedDataParallel(target_encoder) @@ -344,10 +299,8 @@ def main(args, resume_preempt=False): p.requires_grad = False # -- momentum schedule - momentum_scheduler = ( - ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) - for i in range(int(ipe * num_epochs * ipe_scale) + 1) - ) + momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) + for i in range(int(ipe*num_epochs*ipe_scale)+1)) start_epoch = 0 # -- load training checkpoint @@ -365,8 +318,7 @@ def main(args, resume_preempt=False): predictor=predictor, target_encoder=target_encoder, opt=optimizer, - 
scaler=scaler, - ) + scaler=scaler) for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() @@ -377,31 +329,31 @@ def save_checkpoint(epoch, path): if rank != 0: return save_dict = { - "encoder": encoder.state_dict(), - "predictor": predictor.state_dict(), - "opt": optimizer.state_dict(), - "scaler": None if scaler is None else scaler.state_dict(), - "target_encoder": target_encoder.state_dict(), - "epoch": epoch, - "loss": loss_meter.avg, - "batch_size": batch_size, - "world_size": world_size, - "lr": lr, + 'encoder': encoder.state_dict(), + 'predictor': predictor.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'target_encoder': target_encoder.state_dict(), + 'epoch': epoch, + 'loss': loss_meter.avg, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr, } try: torch.save(save_dict, path) except Exception as e: - logger.info(f"Encountered exception when saving checkpoint: {e}") + logger.info(f'Encountered exception when saving checkpoint: {e}') - logger.info("Initializing loader...") + logger.info('Initializing loader...') loader = iter(unsupervised_loader) if skip_batches > 0: - logger.info(f"Skip {skip_batches} batches") + logger.info(f'Skip {skip_batches} batches') unsupervised_sampler.set_epoch(start_epoch) for itr in range(skip_batches): if itr % 10 == 0: - logger.info(f"Skip {itr}/{skip_batches} batches") + logger.info(f'Skip {itr}/{skip_batches} batches') try: udata = next(loader) except Exception: @@ -410,7 +362,7 @@ def save_checkpoint(epoch, path): # -- TRAINING LOOP for epoch in range(start_epoch, num_epochs): - logger.info("Epoch %d" % (epoch + 1)) + logger.info('Epoch %d' % (epoch + 1)) # -- update distributed-data-loader epoch unsupervised_sampler.set_epoch(epoch) @@ -429,32 +381,18 @@ def save_checkpoint(epoch, path): try: udata, masks_enc, masks_pred = next(loader) - - except StopIteration: - logger.info( - "Exhausted data loaders before completing all planned iterations. Ending epoch early..." - ) - break # Exit the current epoch loop if there are no more data points to process - # except Exception: - # logger.info('Exhausted data loaders. Refreshing...') - # loader = iter(unsupervised_loader) - # udata, masks_enc, masks_pred = next(loader) - assert len(masks_enc) == len( - masks_pred - ), "Currently require num encoder masks = num predictor masks" + except Exception: + logger.info('Exhausted data loaders. 
Refreshing...') + loader = iter(unsupervised_loader) + udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len(masks_pred), \ + 'Currently require num encoder masks = num predictor masks' def load_clips(): # -- unsupervised video clips # Put each clip on the GPU and concatenate along batch # dimension - clips = torch.cat( - [u.to(device, non_blocking=True) for u in udata[0]], dim=0 - ) - - # Load action data - actions = torch.cat( - [a.to(device, non_blocking=True) for a in actions], dim=0 - ) + clips = torch.cat([u.to(device, non_blocking=True) for u in udata[0]], dim=0) # Put each mask-enc/mask-pred pair on the GPU and reuse the # same mask pair for each clip @@ -467,9 +405,8 @@ def load_clips(): _masks_enc.append(_me) _masks_pred.append(_mp) - return (clips, _masks_enc, _masks_pred, actions) - - clips, masks_enc, masks_pred, actions = load_clips() + return (clips, _masks_enc, _masks_pred) + clips, masks_enc, masks_pred = load_clips() for _i, m in enumerate(mask_meters): m.update(masks_enc[_i][0].size(-1)) @@ -486,78 +423,51 @@ def forward_target(c): """ with torch.no_grad(): h = target_encoder(c) - h = F.layer_norm( - h, (h.size(-1),) - ) # normalize over feature-dim [B, N, D] + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim [B, N, D] # -- create targets (masked regions of h) - # h = apply_masks(h, masks_pred, concat=False) - # -- create targets (next frames of h) - h_next = h[ - :, 1:, : - ] # Assuming frames are ordered chronologically. These are ground truth next frames - return h_next - - def forward_context(c, h, a): + h = apply_masks(h, masks_pred, concat=False) + return h + + def forward_context(c, h): """ Returns list of tensors of shape [B, N, D], one for each mask-pred. """ z = encoder(c, masks_enc) - - # Encode the action representations - z_a = action_encoder(a) - - # Combine the encoded actions with the encoded video clips - z_combined = combine_encodings_concat(z, z_a) - - z = predictor(z_combined, h, masks_enc, masks_pred) + z = predictor(z, h, masks_enc, masks_pred) return z - # def loss_fn(z, h): - # loss = 0.0 - # # Compute loss and accumulate for each mask-enc/mask-pred pair - # for zi, hi in zip(z, h): - # loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp - # # loss /= len(masks_pred) - # return loss - - def loss_fn(z_next, h_next): - loss = 0.0 - # Compute loss between predicted next frames and ground truth next frames - for zi, hi in zip(z_next, h_next): - loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp - loss /= len(h_next) + def loss_fn(z, h): + loss = 0. + # Compute loss and accumulate for each mask-enc/mask-pred pair + for zi, hi in zip(z, h): + loss += torch.mean(torch.abs(zi - hi)**loss_exp) / loss_exp + loss /= len(masks_pred) return loss def reg_fn(z): - return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len( - z - ) + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z) # Step 1. Forward - loss_jepa, loss_reg = 0.0, 0.0 + loss_jepa, loss_reg = 0., 0. 
with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): - h_next = forward_target(clips) - z_next = forward_context(clips, h_next, actions) - loss_jepa = loss_fn(z_next, h_next) # jepa prediction loss - pstd_z = reg_fn(z_next) # predictor variance across patches - loss_reg += torch.mean(F.relu(1.0 - pstd_z)) + h = forward_target(clips) + z = forward_context(clips, h) + loss_jepa = loss_fn(z, h) # jepa prediction loss + pstd_z = reg_fn(z) # predictor variance across patches + loss_reg += torch.mean(F.relu(1.-pstd_z)) loss = loss_jepa + reg_coeff * loss_reg # Step 2. Backward & step - _enc_norm, _pred_norm = 0.0, 0.0 + _enc_norm, _pred_norm = 0., 0. if mixed_precision: scaler.scale(loss).backward() scaler.unscale_(optimizer) else: loss.backward() if (epoch > warmup) and (clip_grad is not None): - _enc_norm = torch.nn.utils.clip_grad_norm_( - encoder.parameters(), clip_grad - ) - _pred_norm = torch.nn.utils.clip_grad_norm_( - predictor.parameters(), clip_grad - ) + _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad) + _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad) if mixed_precision: scaler.step(optimizer) scaler.update() @@ -573,10 +483,8 @@ def reg_fn(z): # Step 3. momentum update of target encoder m = next(momentum_scheduler) with torch.no_grad(): - for param_q, param_k in zip( - encoder.parameters(), target_encoder.parameters() - ): - param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) + for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()): + param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) return ( float(loss), @@ -588,25 +496,11 @@ def reg_fn(z): grad_stats_pred, optim_stats, ) - - ( - loss, - loss_jepa, - loss_reg, - _new_lr, - _new_wd, - grad_stats, - grad_stats_pred, - optim_stats, - ), gpu_etime_ms = gpu_timer(train_step) - iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0 + (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000. 
loss_meter.update(loss) - input_var = float( - AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0)) - ) - input_var_min = float( - AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1))) - ) + input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0))) + input_var_min = float(AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1)))) input_var_meter.update(input_var) input_var_min_meter.update(input_var_min) jepa_loss_meter.update(loss_jepa) @@ -625,88 +519,68 @@ def log_stats(): grad_stats.global_norm, grad_stats_pred.global_norm, gpu_etime_ms, - iter_elapsed_time_ms, - ) + iter_elapsed_time_ms) if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): logger.info( - "[%d, %5d] loss: %.3f | p%.3f r%.3f | " - "input_var: %.3f %.3f | " - "masks: %s " - "[wd: %.2e] [lr: %.2e] " - "[mem: %.2e] " - "[gpu: %.1f ms]" - "[wall: %.1f ms]" - % ( - epoch + 1, - itr, - loss_meter.avg, - jepa_loss_meter.avg, - reg_loss_meter.avg, - input_var_meter.avg, - input_var_min_meter.avg, - "[" - + ", ".join(["%.1f" % m.avg for m in mask_meters]) - + "]", - _new_wd, - _new_lr, - torch.cuda.max_memory_allocated() / 1024.0**2, - gpu_time_meter.avg, - wall_time_meter.avg, - ) - ) + '[%d, %5d] loss: %.3f | p%.3f r%.3f | ' + 'input_var: %.3f %.3f | ' + 'masks: %s ' + '[wd: %.2e] [lr: %.2e] ' + '[mem: %.2e] ' + '[gpu: %.1f ms]' + '[wall: %.1f ms]' + % (epoch + 1, itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + '[' + ', '.join(['%.1f' % m.avg for m in mask_meters]) + ']', + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg)) if optim_stats is not None: logger.info( - "[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]" - % ( - epoch + 1, - itr, - optim_stats.get("exp_avg").avg, - optim_stats.get("exp_avg").min, - optim_stats.get("exp_avg").max, - optim_stats.get("exp_avg_sq").avg, - optim_stats.get("exp_avg_sq").min, - optim_stats.get("exp_avg_sq").max, - ) - ) + '[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]' + % (epoch + 1, itr, + optim_stats.get('exp_avg').avg, + optim_stats.get('exp_avg').min, + optim_stats.get('exp_avg').max, + optim_stats.get('exp_avg_sq').avg, + optim_stats.get('exp_avg_sq').min, + optim_stats.get('exp_avg_sq').max)) if grad_stats is not None: logger.info( - "[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" - % ( - epoch + 1, - itr, - grad_stats.first_layer, - grad_stats.last_layer, - grad_stats.min, - grad_stats.max, - grad_stats.global_norm, - ) - ) + '[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm)) if grad_stats_pred is not None: logger.info( - "[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" - % ( - epoch + 1, - itr, - grad_stats_pred.first_layer, - grad_stats_pred.last_layer, - grad_stats_pred.min, - grad_stats_pred.max, - grad_stats_pred.global_norm, - ) - ) - + '[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm)) log_stats() - assert not np.isnan(loss), "loss is nan" + assert not np.isnan(loss), 'loss is nan' # -- Save Checkpoint - logger.info("avg. 
loss %.3f" % loss_meter.avg) + logger.info('avg. loss %.3f' % loss_meter.avg) # -- Save Last if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): save_checkpoint(epoch + 1, latest_path) if save_every_freq > 0 and epoch % save_every_freq == 0: - save_every_file = f"{tag}-e{epoch}.pth.tar" + save_every_file = f'{tag}-e{epoch}.pth.tar' save_every_path = os.path.join(folder, save_every_file) - save_checkpoint(epoch + 1, save_every_path) + save_checkpoint(epoch + 1, save_every_path) \ No newline at end of file diff --git a/app/vjepa/train_with_actions.py b/app/vjepa/train_with_actions.py new file mode 100644 index 0000000..f949095 --- /dev/null +++ b/app/vjepa/train_with_actions.py @@ -0,0 +1,715 @@ +# Copyright (c) NeoCybernetica, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import csv + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] +except Exception: + pass + +import copy +import time +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel + +from src.datasets.data_manager import init_data +from src.masks.random_tube import MaskCollator as TubeMaskCollator +from src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from src.masks.utils import apply_masks +from src.utils.distributed import init_distributed, AllReduce +from src.utils.logging import ( + CSVLogger, + gpu_timer, + get_logger, + grad_logger, + adamw_logger, + AverageMeter, +) +from src.utils.tensors import repeat_interleave_batch, to_batch +from src.models.utils.combine_encodings import ( + combine_encodings_concat, + combine_encodings_add, + AttentionFusion, +) + + +from app.vjepa.utils import ( + load_checkpoint, + init_video_model, + init_opt, +) +from app.vjepa.transforms import make_transforms + + +# -- +log_timings = True +log_freq = 10 +checkpoint_freq = 1 +# -- + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__) + + +def generate_csv_file(data_dir, csv_filename="v-jepa-pretrain.csv"): + csv_filepath = os.path.join(data_dir, csv_filename) + logger.info(f"Generating CSV file: {csv_filepath}") + + valid_folders = [] + for folder_name in os.listdir(data_dir): + folder_path = os.path.join(data_dir, folder_name) + action_filepath = os.path.join(folder_path, "action_data.csv") + if os.path.isdir(folder_path) and os.path.isfile(action_filepath): + valid_folders.append(folder_path) + else: + logger.warning( + f"Skipping folder '{folder_name}' due to missing or invalid action_data.csv" + ) + + with open(csv_filepath, "w", newline="") as csvfile: + writer = csv.writer(csvfile, delimiter=" ") + for folder_path in valid_folders: + writer.writerow([folder_path, 0]) # Write folder path and dummy label (0) + + logger.info(f"CSV file generation complete. 
Found {len(valid_folders)} valid folders.") + + +def main(args, resume_preempt=False): + # First let's go over the folders and generate the + + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + cfgs_meta = args.get("meta") + load_model = cfgs_meta.get("load_checkpoint") or resume_preempt + r_file = cfgs_meta.get("read_checkpoint", None) + seed = cfgs_meta.get("seed", _GLOBAL_SEED) + save_every_freq = cfgs_meta.get("save_every_freq", -1) + skip_batches = cfgs_meta.get("skip_batches", -1) + use_sdpa = cfgs_meta.get("use_sdpa", False) + which_dtype = cfgs_meta.get("dtype") + logger.info(f"{which_dtype=}") + if which_dtype.lower() == "bfloat16": + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == "float16": + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get("mask") + + # -- MODEL + cfgs_model = args.get("model") + model_name = cfgs_model.get("model_name") + pred_depth = cfgs_model.get("pred_depth") + pred_embed_dim = cfgs_model.get("pred_embed_dim") + uniform_power = cfgs_model.get("uniform_power", True) + use_mask_tokens = cfgs_model.get("use_mask_tokens", True) + zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True) + + # -- DATA + cfgs_data = args.get("data") + dataset_type = cfgs_data.get("dataset_type", "egovehicle_imagedataset") + mask_type = cfgs_data.get("mask_type", "multiblock3d") + dataset_paths = cfgs_data.get("datasets", []) + datasets_weights = cfgs_data.get("datasets_weights", None) + if datasets_weights is not None: + assert len(datasets_weights) == len( + dataset_paths + ), "Must have one sampling weight specified for each dataset" + batch_size = cfgs_data.get("batch_size") + num_clips = cfgs_data.get("num_clips") + num_frames = cfgs_data.get("num_frames") + tubelet_size = cfgs_data.get("tubelet_size") + sampling_rate = cfgs_data.get("sampling_rate") + duration = cfgs_data.get("clip_duration", None) + crop_size = cfgs_data.get("crop_size", 224) + patch_size = cfgs_data.get("patch_size") + pin_mem = cfgs_data.get("pin_mem", False) + num_workers = cfgs_data.get("num_workers", 1) + filter_short_videos = cfgs_data.get("filter_short_videos", False) + decode_one_clip = cfgs_data.get("decode_one_clip", True) + log_resource_util_data = cfgs_data.get("log_resource_utilization", False) + + # -- DATA AUGS + cfgs_data_aug = args.get("data_aug") + ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) + rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) + motion_shift = cfgs_data_aug.get("motion_shift", False) + reprob = cfgs_data_aug.get("reprob", 0.0) + use_aa = cfgs_data_aug.get("auto_augment", False) + + # -- LOSS + cfgs_loss = args.get("loss") + loss_exp = cfgs_loss.get("loss_exp") + reg_coeff = cfgs_loss.get("reg_coeff") + + # -- OPTIMIZATION + cfgs_opt = args.get("optimization") + ipe = cfgs_opt.get("ipe", None) + ipe_scale = cfgs_opt.get("ipe_scale", 1.0) + clip_grad = cfgs_opt.get("clip_grad", None) + wd = float(cfgs_opt.get("weight_decay")) + final_wd = float(cfgs_opt.get("final_weight_decay")) + num_epochs = cfgs_opt.get("epochs") + warmup = cfgs_opt.get("warmup") + start_lr = cfgs_opt.get("start_lr") + lr = cfgs_opt.get("lr") + final_lr = cfgs_opt.get("final_lr") + ema = cfgs_opt.get("ema") + betas = cfgs_opt.get("betas", (0.9, 0.999)) + eps = 
cfgs_opt.get("eps", 1.0e-8) + + # -- LOGGING + cfgs_logging = args.get("logging") + folder = cfgs_logging.get("folder") + tag = cfgs_logging.get("write_tag") + + # ----------------------------------------------------------------------- # + # ----------------------------------------------------------------------- # + + # Generate CSV file (only if not already exists) + csv_filename = "v-jepa-pretrain.csv" + # if not os.path.exists(os.path.join(dataset_paths[0], csv_filename)): + generate_csv_file(dataset_paths[0]) + + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + try: + mp.set_start_method("spawn") + except Exception: + pass + + # -- init torch distributed backend + world_size, rank = init_distributed() + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") + + # -- set device + if not torch.cuda.is_available(): + device = torch.device("cpu") + else: + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = os.path.join(folder, f"{tag}_r{rank}.csv") + latest_file = f"{tag}-latest.pth.tar" + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ("%d", "epoch"), + ("%d", "itr"), + ("%.5f", "loss"), + ("%.5f", "loss-jepa"), + ("%.5f", "reg-loss"), + ("%.5f", "enc-grad-norm"), + ("%.5f", "pred-grad-norm"), + ("%d", "gpu-time(ms)"), + ("%d", "wall-time(ms)"), + ) + + # -- init model + encoder, predictor, action_encoder = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + target_encoder = copy.deepcopy(encoder) + + # -- make data transforms + if mask_type == "multiblock3d": + logger.info("Initializing basic multi-block mask") + mask_collator = MB3DMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask, + ) + else: + logger.info("Initializing random tube mask") + mask_collator = TubeMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask, + ) + transform = make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size, + ) + + # -- init data-loaders/samplers + (unsupervised_loader, unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None, + ) + try: + _dlen = len(unsupervised_loader) + except Exception: # Different interface for 
webdataset + _dlen = unsupervised_loader.num_batches + if ipe is None: + ipe = _dlen + logger.info(f"iterations per epoch/dataest length: {ipe}/{_dlen}") + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + encoder=encoder, + predictor=predictor, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps, + ) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs * ipe_scale) + 1) + ) + + start_epoch = 0 + # -- load training checkpoint + if load_model or os.path.exists(latest_path): + ( + encoder, + predictor, + target_encoder, + optimizer, + scaler, + start_epoch, + ) = load_checkpoint( + r_path=load_path, + encoder=encoder, + predictor=predictor, + target_encoder=target_encoder, + opt=optimizer, + scaler=scaler, + ) + for _ in range(start_epoch * ipe): + scheduler.step() + wd_scheduler.step() + next(momentum_scheduler) + mask_collator.step() + + def save_checkpoint(epoch, path): + if rank != 0: + return + save_dict = { + "encoder": encoder.state_dict(), + "predictor": predictor.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else scaler.state_dict(), + "target_encoder": target_encoder.state_dict(), + "epoch": epoch, + "loss": loss_meter.avg, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, + } + try: + torch.save(save_dict, path) + except Exception as e: + logger.info(f"Encountered exception when saving checkpoint: {e}") + + logger.info("Initializing loader...") + loader = iter(unsupervised_loader) + + if skip_batches > 0: + logger.info(f"Skip {skip_batches} batches") + unsupervised_sampler.set_epoch(start_epoch) + for itr in range(skip_batches): + if itr % 10 == 0: + logger.info(f"Skip {itr}/{skip_batches} batches") + try: + udata = next(loader) + except Exception: + loader = iter(unsupervised_loader) + udata = next(loader) + + # -- TRAINING LOOP + for epoch in range(start_epoch, num_epochs): + logger.info("Epoch %d" % (epoch + 1)) + + # -- update distributed-data-loader epoch + unsupervised_sampler.set_epoch(epoch) + + loss_meter = AverageMeter() + input_var_meter = AverageMeter() + input_var_min_meter = AverageMeter() + jepa_loss_meter = AverageMeter() + reg_loss_meter = AverageMeter() + mask_meters = [AverageMeter() for _ in range(len(cfgs_mask))] + gpu_time_meter = AverageMeter() + wall_time_meter = AverageMeter() + + for itr in range(ipe): + itr_start_time = time.time() + + try: + udata, masks_enc, masks_pred = next(loader) + + except StopIteration: + logger.info( + "Exhausted data loaders before completing all planned iterations. Ending epoch early..." + ) + break # Exit the current epoch loop if there are no more data points to process + # except Exception: + # logger.info('Exhausted data loaders. 
Refreshing...') + # loader = iter(unsupervised_loader) + # udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len( + masks_pred + ), "Currently require num encoder masks = num predictor masks" + + def load_images_and_actions(): + # -- images and action labels + images = to_batch([i.to(device, non_blocking=True) for i in udata[0]]) # List of images to batched tensor + action_labels = udata[1] # Extract actions from the second element + + # Convert to numerical format if actions are string labels + unique_actions = sorted(set(action_labels)) + action_to_idx = {action: idx for idx, action in enumerate(unique_actions)} + action_labels = torch.tensor([action_to_idx[a] for a in action_labels]).to(device) + + # -- Encode actions + encoded_actions = action_encoder(action_labels) # Encode the actions + + # ... (load masks as before, but adapt for images) + _masks_enc, _masks_pred = [], [] + for _me, _mp in zip(masks_enc, masks_pred): + _me = _me.to(device, non_blocking=True) + _mp = _mp.to(device, non_blocking=True) + _masks_enc.append(_me) + _masks_pred.append(_mp) + + return (images, _masks_enc, _masks_pred, encoded_actions) # Return encoded actions + + + images, masks_enc, masks_pred, encoded_actions = load_images_and_actions() + + for _i, m in enumerate(mask_meters): + m.update(masks_enc[_i][0].size(-1)) + + def train_step(): + _new_lr = scheduler.step() + _new_wd = wd_scheduler.step() + # -- + + def forward_target(images): + """ + Encodes the target images using the target encoder and returns the embeddings. + + Args: + images (torch.Tensor): A tensor of shape [B, T, C, H, W] representing a batch + of image sequences. + + Returns: + torch.Tensor: A tensor of shape [B, T, D] representing the encoded image embeddings, + where D is the embedding dimension. + """ + with torch.no_grad(): + image_embeddings = target_encoder(images) + # Normalize the embeddings across the feature dimension + normalized_embeddings = F.layer_norm(image_embeddings, (image_embeddings.size(-1),)) + + # Extract the embeddings for the next frames as targets + next_frame_embeddings = normalized_embeddings[:, 1:, :] # Assuming frames are sequential + return next_frame_embeddings + + def forward_context(images, encoded_actions, h): + """ + Encodes context images with the encoder, combines with encoded actions, + and predicts masked regions using the predictor. + + Args: + images (torch.Tensor): A tensor of shape [B, T, C, H, W] representing a batch + of image sequences. + encoded_actions (torch.Tensor): A tensor of shape [B, T, A] representing encoded actions, + where A is the action embedding dimension. + h (torch.Tensor): The hidden state from the target encoder (optional, might not be used in your case). + + Returns: + torch.Tensor: A list of tensors representing the predicted values for the masked regions. + """ + + image_embeddings = encoder(images, masks_enc) + + # Combine image and action embeddings + combined_embeddings = combine_encodings_concat(image_embeddings, encoded_actions) + + # Predict masked regions + predictions = predictor(combined_embeddings, h, masks_enc, masks_pred) + return predictions + + + def loss_fn(z_next, h_next): + loss = 0.0 + # Compute loss between predicted next frames and ground truth next frames + for zi, hi in zip(z_next, h_next): + loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp + loss /= len(h_next) + return loss + + def reg_fn(z): + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len( + z + ) + + # Step 1. 
Forward + loss_jepa, loss_reg = 0.0, 0.0 + with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): + h_next = forward_target(images) + z_next = forward_context(images, h_next, encoded_actions) + loss_jepa = loss_fn(z_next, h_next) # jepa prediction loss + pstd_z = reg_fn(z_next) # predictor variance across patches + loss_reg += torch.mean(F.relu(1.0 - pstd_z)) + loss = loss_jepa + reg_coeff * loss_reg + + # Step 2. Backward & step + _enc_norm, _pred_norm = 0.0, 0.0 + if mixed_precision: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + else: + loss.backward() + if (epoch > warmup) and (clip_grad is not None): + _enc_norm = torch.nn.utils.clip_grad_norm_( + encoder.parameters(), clip_grad + ) + _pred_norm = torch.nn.utils.clip_grad_norm_( + predictor.parameters(), clip_grad + ) + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + grad_stats = grad_logger(encoder.named_parameters()) + grad_stats.global_norm = float(_enc_norm) + grad_stats_pred = grad_logger(predictor.named_parameters()) + grad_stats_pred.global_norm = float(_pred_norm) + optimizer.zero_grad() + optim_stats = adamw_logger(optimizer) + + # Step 3. momentum update of target encoder + m = next(momentum_scheduler) + with torch.no_grad(): + for param_q, param_k in zip( + encoder.parameters(), target_encoder.parameters() + ): + param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) + + return ( + float(loss), + float(loss_jepa), + float(loss_reg), + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ) + + ( + loss, + loss_jepa, + loss_reg, + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0 + loss_meter.update(loss) + input_var = float( + AllReduce.apply(images.view(images.shape[0], -1).var(dim=1).mean(dim=0)) + ) + input_var_min = float( + AllReduce.apply(torch.min(images.view(images.shape[0], -1).var(dim=1))) + ) + input_var_meter.update(input_var) + input_var_min_meter.update(input_var_min) + jepa_loss_meter.update(loss_jepa) + reg_loss_meter.update(loss_reg) + gpu_time_meter.update(gpu_etime_ms) + wall_time_meter.update(iter_elapsed_time_ms) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + loss_jepa, + loss_reg, + grad_stats.global_norm, + grad_stats_pred.global_norm, + gpu_etime_ms, + iter_elapsed_time_ms, + ) + if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): + logger.info( + "[%d, %5d] loss: %.3f | p%.3f r%.3f | " + "input_var: %.3f %.3f | " + "masks: %s " + "[wd: %.2e] [lr: %.2e] " + "[mem: %.2e] " + "[gpu: %.1f ms]" + "[wall: %.1f ms]" + % ( + epoch + 1, + itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + "[" + + ", ".join(["%.1f" % m.avg for m in mask_meters]) + + "]", + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg, + ) + ) + + if optim_stats is not None: + logger.info( + "[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]" + % ( + epoch + 1, + itr, + optim_stats.get("exp_avg").avg, + optim_stats.get("exp_avg").min, + optim_stats.get("exp_avg").max, + optim_stats.get("exp_avg_sq").avg, + optim_stats.get("exp_avg_sq").min, + optim_stats.get("exp_avg_sq").max, + ) + ) + + if grad_stats is not None: + logger.info( + "[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" + % ( + epoch + 1, + itr, + 
grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm, + ) + ) + + if grad_stats_pred is not None: + logger.info( + "[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" + % ( + epoch + 1, + itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm, + ) + ) + + log_stats() + assert not np.isnan(loss), "loss is nan" + + # -- Save Checkpoint + logger.info("avg. loss %.3f" % loss_meter.avg) + # -- Save Last + if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and epoch % save_every_freq == 0: + save_every_file = f"{tag}-e{epoch}.pth.tar" + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/src/datasets/image_dataset.py b/src/datasets/image_dataset.py index d655985..d20754c 100644 --- a/src/datasets/image_dataset.py +++ b/src/datasets/image_dataset.py @@ -90,37 +90,37 @@ def make_imagedataset( class ImageDataset(torch.utils.data.Dataset): def __init__( self, - data_paths, # List of directories containing timestamped image folders + csv_file_path, # List of directories containing timestamped image folders transform=None, shared_transform=None, - ): - self.data_paths = data_paths + ): self.transform = transform self.shared_transform = shared_transform - # Load Image Paths and Labels - self.samples = [] - for data_path in self.data_paths: - timestamped_folders = [ - f - for f in os.listdir(data_path) - if os.path.isdir(os.path.join(data_path, f)) - ] - for folder in timestamped_folders: - folder_path = os.path.join(data_path, folder) - image_files = sorted( - os.listdir(folder_path) - ) # Sort for sequential order - self.samples.extend( - [(folder_path, image_file) for image_file in image_files] - ) # Store (folder_path, image_filename) tuples - - if len(self.samples) == 0: - raise RuntimeError(f"Found 0 image files in the data_paths: {data_paths}") + # Load Image Paths and Labels from CSV + df = pd.read_csv(csv_file_path, header=None, delimiter=" ") + self.samples = [] # List to store (image_path, action_label) tuples + + for _, row in df.iterrows(): + folder_path = row[0] + action_filepath = os.path.join(folder_path, "action_data.csv") + if os.path.exists(action_filepath): + try: + action_df = pd.read_csv(action_filepath) + except pd.errors.EmptyDataError: + logger.warning( + f"Skipping folder '{folder_path}' due to empty action_data.csv" + ) + continue + self.samples.extend(list(action_df[["image_path", "maneuver"]].values)) # Store image paths and action labels + + if not self.samples: + raise RuntimeError( + f"Found 0 image files with corresponding action data in the CSV: {csv_file_path}" + ) def __getitem__(self, index): - folder_path, image_filename = self.samples[index] - image_path = os.path.join(folder_path, image_filename) + image_path, action_label = self.samples[index] # Load Image image = Image.open(image_path) @@ -131,55 +131,7 @@ def __getitem__(self, index): if self.transform is not None: image = self.transform(image) - # Load Action Data - action_filepath = os.path.join(folder_path, "action_data.csv") - action_df = pd.read_csv(action_filepath) - - # Extract Timestamp from Image Filename - image_timestamp = self.extract_timestamp_from_filename(image_filename) - action_label = self.get_action_label_for_timestamp(action_df, image_timestamp) - - return ( - image, - action_label, - ) # Return 
the image and its corresponding action label - - def extract_timestamp_from_filename(self, filename): - timestamp_str = os.path.splitext(filename)[0].split("_")[ - 0 - ] # Get '20240516_175159' - timestamp = datetime.strptime(timestamp_str, "%Y%m%d_%H%M%S") - return timestamp - - def get_action_labels_for_clip(self, action_df, image_timestamps): - action_labels = [] - for timestamp in image_timestamps: - # Find closest action timestamps before and after the image timestamp - before_idx = action_df["timestamp"].searchsorted(timestamp) - 1 - after_idx = before_idx + 1 - - # Handle edge cases (first or last image) - before_idx = max(0, before_idx) - after_idx = min(len(action_df) - 1, after_idx) - - # Get action labels and timestamps - action_before = action_df.iloc[before_idx]["action_name"] - action_after = action_df.iloc[after_idx]["action_name"] - timestamp_before = action_df.iloc[before_idx]["timestamp"] - timestamp_after = action_df.iloc[after_idx]["timestamp"] - - # Linear Interpolation (if needed, can be removed for simple nearest neighbor) - weight_after = (timestamp - timestamp_before) / ( - timestamp_after - timestamp_before - ) - if weight_after < 0.5: # Closer to the previous action - action_label = action_before - else: # Closer to the next action - action_label = action_after - - action_labels.append(action_label) - - return action_labels + return image, action_label # Return the image and its corresponding action label @@ -225,7 +177,7 @@ def __len__(self): def make_egovehicle_imagedataset( - data_paths, + csv_file_path, batch_size, transform=None, shared_transform=None, @@ -237,7 +189,7 @@ def make_egovehicle_imagedataset( pin_mem=True, ): dataset = ImageDataset( - data_paths=data_paths, + csv_file_path=csv_file_path, transform=transform, shared_transform=shared_transform, ) diff --git a/src/utils/tensors.py b/src/utils/tensors.py index 4ec8767..14369a5 100644 --- a/src/utils/tensors.py +++ b/src/utils/tensors.py @@ -14,6 +14,17 @@ logger = getLogger() +def to_batch(images): + """Converts a list of images into a batched tensor. + + Args: + images (list): A list of image tensors, each of shape [C, H, W]. + + Returns: + torch.Tensor: A batched tensor of shape [B, C, H, W], where B is the batch size. + """ + return torch.stack(images, dim=0) + def _no_grad_trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf From 0bc9b12f8e7078d4debe095cea048ffdc51ecbaf Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Thu, 23 May 2024 14:20:15 +0000 Subject: [PATCH 7/8] fix: WIP main_with_actions and training_with_actions --- app/main_with_actions.py | 110 +++++++++++++++----------------- app/vjepa/train_with_actions.py | 20 +----- 2 files changed, 54 insertions(+), 76 deletions(-) diff --git a/app/main_with_actions.py b/app/main_with_actions.py index b111d94..769cc89 100644 --- a/app/main_with_actions.py +++ b/app/main_with_actions.py @@ -1,3 +1,5 @@ +# In app/main_with_actions.py + # Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
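# A minimal usage sketch for the to_batch helper added to src/utils/tensors.py
# above. The shapes are hypothetical (a list of three 224x224 RGB frames); the
# helper simply stacks the list along a new batch dimension.
import torch
from src.utils.tensors import to_batch

frames = [torch.randn(3, 224, 224) for _ in range(3)]
batch = to_batch(frames)  # stacks along dim 0 -> shape [3, 3, 224, 224]
assert batch.shape == (3, 3, 224, 224)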
# @@ -6,78 +8,68 @@ # import argparse -import importlib - -import multiprocessing as mp - import pprint import yaml +import os +import logging from app.scaffold import main as app_main from src.utils.distributed import init_distributed +from app.vjepa.train_with_actions import main as train # Import the main function from train_with_actions.py + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--fname", + type=str, + help="name of config file to load", + default="configs/pretrain/vith16_384.yaml", + ) + parser.add_argument( + "--devices", + type=str, + nargs="+", + default=["cuda:0"], + help="which devices to use on local machine", + ) + args = parser.parse_args() -parser = argparse.ArgumentParser() -parser.add_argument( - "--fname", type=str, help="name of config file to load", default="configs/pretrain/vith16_384.yaml" -) -parser.add_argument( - "--devices", - type=str, - nargs="+", - default=["cuda:0"], - help="which devices to use on local machine", -) - - -def process_main(rank, fname, world_size, devices): - import os - - os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) - - import logging - from src.utils.logging import get_logger - - logger = get_logger(force=True) - if rank == 0: - logger.setLevel(logging.INFO) - else: - logger.setLevel(logging.ERROR) - - logger.info(f"called-params {fname}") + # Initialize logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + logger.info(f"Called parameters: {args.fname}") - # Load config - params = None - with open(fname, "r") as y_file: + # Load configuration from YAML file + with open(args.fname, "r") as y_file: params = yaml.load(y_file, Loader=yaml.FullLoader) - logger.info("loaded params...") + logger.info("Loaded configuration parameters.") + + # Pretty print the configuration parameters + pprint.PrettyPrinter(indent=4).pprint(params) - # Log config - if rank == 0: - pprint.PrettyPrinter(indent=4).pprint(params) - dump = os.path.join(params["logging"]["folder"], "params-pretrain.yaml") - with open(dump, "w") as f: - yaml.dump(params, f) + # Save the configuration parameters to a YAML file + dump_file = os.path.join(params["logging"]["folder"], "params-pretrain.yaml") + os.makedirs(os.path.dirname(dump_file), exist_ok=True) + with open(dump_file, "w") as f: + yaml.dump(params, f) - # Init distributed (access to comm between GPUS on same machine) - world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) + # Initialize distributed training (for single GPU, world_size and rank will be 1 and 0 respectively) + num_gpus = len(args.devices) + rank = 0 # Since you're on a single GPU + world_size, rank = init_distributed(rank_and_world_size=(rank, num_gpus)) # Update for single GPU logger.info(f"Running... 
(rank: {rank}/{world_size})") + + # Setup environment variables for GPU visibility + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.devices[rank].split(":")[-1]) # Launch the app with loaded config - # app_main(params["app"], args=params) - - # Update this line to load your new train_with_actions module: - train_module = importlib.import_module(f"app.{params['app']}.train_with_actions") - - # Launch the app with loaded config (use the imported train_module): - train_module.main(args=params) - + try: + train(args=params, world_size=world_size, rank=rank) + except Exception as e: + logger.error(f"An error occurred during training: {e}") + raise e if __name__ == "__main__": - args = parser.parse_args() - num_gpus = len(args.devices) - mp.set_start_method("spawn") - for rank in range(num_gpus): - mp.Process( - target=process_main, args=(rank, args.fname, num_gpus, args.devices) - ).start() + main() diff --git a/app/vjepa/train_with_actions.py b/app/vjepa/train_with_actions.py index f949095..bb439fc 100644 --- a/app/vjepa/train_with_actions.py +++ b/app/vjepa/train_with_actions.py @@ -94,7 +94,7 @@ def generate_csv_file(data_dir, csv_filename="v-jepa-pretrain.csv"): logger.info(f"CSV file generation complete. Found {len(valid_folders)} valid folders.") -def main(args, resume_preempt=False): +def main(args, world_size=1, rank=0, resume_preempt=False): # First let's go over the folders and generate the # ----------------------------------------------------------------------- # @@ -192,26 +192,12 @@ def main(args, resume_preempt=False): tag = cfgs_logging.get("write_tag") # ----------------------------------------------------------------------- # - # ----------------------------------------------------------------------- # - - # Generate CSV file (only if not already exists) - csv_filename = "v-jepa-pretrain.csv" - # if not os.path.exists(os.path.join(dataset_paths[0], csv_filename)): - generate_csv_file(dataset_paths[0]) - + # ----------------------------------------------------------------------- # np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.benchmark = True - try: - mp.set_start_method("spawn") - except Exception: - pass - - # -- init torch distributed backend - world_size, rank = init_distributed() - logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") - + # -- set device if not torch.cuda.is_available(): device = torch.device("cpu") From 9e7ece1c2333e11adb936b36ede94d71f39b4b59 Mon Sep 17 00:00:00 2001 From: Munir Jojo-Verge Date: Fri, 24 May 2024 20:21:31 +0000 Subject: [PATCH 8/8] fix: Code Style. 
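Note on the exception-logging changes in the hunks below: traceback.format_exc is interpolated without being called, so the f-string renders the function object rather than the formatted traceback text. A minimal sketch of the presumably intended idiom, assuming the goal is to log the full traceback (save_safely is a hypothetical wrapper used only for illustration):

    import logging
    import traceback

    import torch

    logger = logging.getLogger(__name__)

    def save_safely(save_dict, path):
        try:
            torch.save(save_dict, path)
        except Exception:
            # format_exc() must be called; without the parentheses the f-string
            # logs the bound function's repr instead of the traceback string
            logger.info(f"Encountered exception when saving checkpoint: {traceback.format_exc()}")
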
--- app/main_with_actions.py | 3 +- .../test_validate_data_loading_pipeline.py | 74 +++++++++ app/vjepa/train.py | 3 +- app/vjepa/train_with_actions.py | 142 +++++++--------- app/vjepa/transforms.py | 35 ++++ app/vjepa/utils.py | 5 +- configs/pretrain/vith16_384.yaml | 2 +- evals/image_classification_frozen/eval.py | 3 +- evals/video_classification_frozen/eval.py | 3 +- logs_and_checkpoints/jepa_r0.csv | 45 +++++ logs_and_checkpoints/params-pretrain.yaml | 88 ++++++++++ src/datasets/data_manager.py | 16 +- src/datasets/image_dataset.py | 154 ++++++++++++------ src/datasets/video_dataset.py | 2 +- src/masks/multiblock3d.py | 14 ++ src/masks/random_tube.py | 18 +- src/models/utils/patch_embed.py | 26 ++- src/models/vision_transformer.py | 65 +++----- src/utils/distributed.py | 4 +- 19 files changed, 503 insertions(+), 199 deletions(-) create mode 100644 app/vjepa/test_validate_data_loading_pipeline.py create mode 100644 logs_and_checkpoints/jepa_r0.csv create mode 100644 logs_and_checkpoints/params-pretrain.yaml diff --git a/app/main_with_actions.py b/app/main_with_actions.py index 769cc89..3a2b3b8 100644 --- a/app/main_with_actions.py +++ b/app/main_with_actions.py @@ -12,6 +12,7 @@ import yaml import os import logging +import traceback from app.scaffold import main as app_main from src.utils.distributed import init_distributed @@ -68,7 +69,7 @@ def main(): try: train(args=params, world_size=world_size, rank=rank) except Exception as e: - logger.error(f"An error occurred during training: {e}") + logger.error(f"An error occurred during training: {traceback.format_exc}") raise e if __name__ == "__main__": diff --git a/app/vjepa/test_validate_data_loading_pipeline.py b/app/vjepa/test_validate_data_loading_pipeline.py new file mode 100644 index 0000000..c6ffad7 --- /dev/null +++ b/app/vjepa/test_validate_data_loading_pipeline.py @@ -0,0 +1,74 @@ +from torch.utils.data import DataLoader +from torchvision import transforms +import yaml +from src.datasets.image_dataset import ImageDataset, SequentialDriveSampler +from src.masks.random_tube import MaskCollatorWithActions as TubeMaskCollator +from src.masks.multiblock3d import MaskCollatorWithActions as MB3DMaskCollator +from src.utils.logging import ( + get_logger, +) + +logger = get_logger(__name__) + +data_dir = "/home/ncdev/Documents/darwin/data/raw" +filename = "/home/ncdev/Documents/darwin/jepa/configs/pretrain/vith16_384.yaml" + +# Load configuration from YAML file +with open(filename, "r") as y_file: + params = yaml.load(y_file, Loader=yaml.FullLoader) +logger.info("Loaded configuration parameters.") + +def test_data_loader(data_dir, batch_size, mask_collator): + # Define the necessary transforms + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x[:3] if x.size(0) > 3 else x), # Convert to RGB if needed + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + dataset = ImageDataset(data_dir=data_dir, transform=transform) + sampler = SequentialDriveSampler(dataset) + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + collate_fn=mask_collator, + num_workers=4, + pin_memory=True, + drop_last=True, + ) + + for batch_idx, (images, maneuvers, masks_enc, masks_dec) in enumerate(data_loader): + print(f"Batch {batch_idx + 1}") + print(f"Images shape: {images.shape}") + print(f"Maneuvers shape: {maneuvers.shape}") + print(f"Encoder Masks shape: {masks_enc[0].shape}") + print(f"Decoder Masks shape: {masks_dec[0].shape}") + print("---") + + if 
batch_idx == 4: + break + +cfgs_mask = params.get("mask") + +# Test with TubeMaskCollator +print("Testing with TubeMaskCollator") +tube_mask_collator = TubeMaskCollator( + crop_size=224, + num_frames=16, + patch_size=16, + tubelet_size=2, + cfgs_mask=cfgs_mask, +) +test_data_loader(data_dir=data_dir, batch_size=32, mask_collator=tube_mask_collator) + +# Test with MB3DMaskCollator +print("Testing with MB3DMaskCollator") +mb3d_mask_collator = MB3DMaskCollator( + crop_size=224, + num_frames=16, + patch_size=16, + tubelet_size=2, + cfgs_mask=cfgs_mask, +) +test_data_loader(data_dir=data_dir, batch_size=32, mask_collator=mb3d_mask_collator) \ No newline at end of file diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 568fdfe..4002f11 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -20,6 +20,7 @@ import copy import time import numpy as np +import traceback import torch import torch.multiprocessing as mp @@ -343,7 +344,7 @@ def save_checkpoint(epoch, path): try: torch.save(save_dict, path) except Exception as e: - logger.info(f'Encountered exception when saving checkpoint: {e}') + logger.info(f'Encountered exception when saving checkpoint: {traceback.format_exc}') logger.info('Initializing loader...') loader = iter(unsupervised_loader) diff --git a/app/vjepa/train_with_actions.py b/app/vjepa/train_with_actions.py index bb439fc..5190e49 100644 --- a/app/vjepa/train_with_actions.py +++ b/app/vjepa/train_with_actions.py @@ -6,7 +6,7 @@ # import os -import csv + # -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS try: @@ -21,15 +21,19 @@ import copy import time import numpy as np +import traceback import torch import torch.multiprocessing as mp import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel +from torchvision.transforms import ToPILImage + +from einops import rearrange from src.datasets.data_manager import init_data -from src.masks.random_tube import MaskCollator as TubeMaskCollator -from src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from src.masks.random_tube import MaskCollatorWithActions as TubeMaskCollatorWithActions +from src.masks.multiblock3d import MaskCollatorWithActions as MB3DMaskCollator from src.masks.utils import apply_masks from src.utils.distributed import init_distributed, AllReduce from src.utils.logging import ( @@ -53,7 +57,7 @@ init_video_model, init_opt, ) -from app.vjepa.transforms import make_transforms +from app.vjepa.transforms import make_image_transforms # -- @@ -70,30 +74,6 @@ logger = get_logger(__name__) - -def generate_csv_file(data_dir, csv_filename="v-jepa-pretrain.csv"): - csv_filepath = os.path.join(data_dir, csv_filename) - logger.info(f"Generating CSV file: {csv_filepath}") - - valid_folders = [] - for folder_name in os.listdir(data_dir): - folder_path = os.path.join(data_dir, folder_name) - action_filepath = os.path.join(folder_path, "action_data.csv") - if os.path.isdir(folder_path) and os.path.isfile(action_filepath): - valid_folders.append(folder_path) - else: - logger.warning( - f"Skipping folder '{folder_name}' due to missing or invalid action_data.csv" - ) - - with open(csv_filepath, "w", newline="") as csvfile: - writer = csv.writer(csvfile, delimiter=" ") - for folder_path in valid_folders: - writer.writerow([folder_path, 0]) # Write folder path and dummy label (0) - - logger.info(f"CSV file generation complete. 
Found {len(valid_folders)} valid folders.") - - def main(args, world_size=1, rank=0, resume_preempt=False): # First let's go over the folders and generate the @@ -260,20 +240,19 @@ def main(args, world_size=1, rank=0, resume_preempt=False): ) else: logger.info("Initializing random tube mask") - mask_collator = TubeMaskCollator( + mask_collator = TubeMaskCollatorWithActions( crop_size=crop_size, num_frames=num_frames, patch_size=patch_size, tubelet_size=tubelet_size, cfgs_mask=cfgs_mask, ) - transform = make_transforms( + transform = make_image_transforms( random_horizontal_flip=True, random_resize_aspect_ratio=ar_range, random_resize_scale=rr_scale, reprob=reprob, - auto_augment=use_aa, - motion_shift=motion_shift, + auto_augment=use_aa, crop_size=crop_size, ) @@ -377,22 +356,10 @@ def save_checkpoint(epoch, path): try: torch.save(save_dict, path) except Exception as e: - logger.info(f"Encountered exception when saving checkpoint: {e}") + logger.info(f"Encountered exception when saving checkpoint: {traceback.format_exc}") logger.info("Initializing loader...") - loader = iter(unsupervised_loader) - - if skip_batches > 0: - logger.info(f"Skip {skip_batches} batches") - unsupervised_sampler.set_epoch(start_epoch) - for itr in range(skip_batches): - if itr % 10 == 0: - logger.info(f"Skip {itr}/{skip_batches} batches") - try: - udata = next(loader) - except Exception: - loader = iter(unsupervised_loader) - udata = next(loader) + loader = iter(unsupervised_loader) # -- TRAINING LOOP for epoch in range(start_epoch, num_epochs): @@ -414,46 +381,50 @@ def save_checkpoint(epoch, path): itr_start_time = time.time() try: - udata, masks_enc, masks_pred = next(loader) + collated_images, collated_maneuvers, masks_enc, masks_pred = next(loader) except StopIteration: logger.info( "Exhausted data loaders before completing all planned iterations. Ending epoch early..." ) break # Exit the current epoch loop if there are no more data points to process - # except Exception: - # logger.info('Exhausted data loaders. Refreshing...') - # loader = iter(unsupervised_loader) - # udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len( masks_pred ), "Currently require num encoder masks = num predictor masks" def load_images_and_actions(): - # -- images and action labels - images = to_batch([i.to(device, non_blocking=True) for i in udata[0]]) # List of images to batched tensor - action_labels = udata[1] # Extract actions from the second element - - # Convert to numerical format if actions are string labels - unique_actions = sorted(set(action_labels)) - action_to_idx = {action: idx for idx, action in enumerate(unique_actions)} - action_labels = torch.tensor([action_to_idx[a] for a in action_labels]).to(device) - - # -- Encode actions - encoded_actions = action_encoder(action_labels) # Encode the actions - - # ... 
(load masks as before, but adapt for images) - _masks_enc, _masks_pred = [], [] - for _me, _mp in zip(masks_enc, masks_pred): - _me = _me.to(device, non_blocking=True) - _mp = _mp.to(device, non_blocking=True) - _masks_enc.append(_me) - _masks_pred.append(_mp) - - return (images, _masks_enc, _masks_pred, encoded_actions) # Return encoded actions + try: + images = [] + to_pil = ToPILImage() # Create an instance of ToPILImage + + for i in range(len(collated_images)): + image = collated_images[i] + image = to_pil(image) # Convert the PyTorch tensor to a PIL Image + image = transform(image) # Apply the transformation to the PIL image + images.append(image) + + # Stack the transformed images into a single batched tensor + images = torch.stack(images, dim=0).to(device, non_blocking=True) + + # -- Encode actions + encoded_actions = action_encoder(collated_maneuvers) + + # ... (load masks as before) + _masks_enc, _masks_pred = [], [] + for _me, _mp in zip(masks_enc, masks_pred): + _me = _me.to(device, non_blocking=True) + _mp = _mp.to(device, non_blocking=True) + _masks_enc.append(_me) + _masks_pred.append(_mp) + + return images, encoded_actions, _masks_enc, _masks_pred + except Exception as e: + logger.error(f"Error in load_images_and_actions: {str(e)}") + raise e - images, masks_enc, masks_pred, encoded_actions = load_images_and_actions() + images, encoded_actions, masks_enc, masks_pred = load_images_and_actions() for _i, m in enumerate(mask_meters): m.update(masks_enc[_i][0].size(-1)) @@ -486,7 +457,7 @@ def forward_target(images): def forward_context(images, encoded_actions, h): """ - Encodes context images with the encoder, combines with encoded actions, + Encodes context images with the encoder, combines with encoded actions, and predicts masked regions using the predictor. Args: @@ -494,20 +465,23 @@ def forward_context(images, encoded_actions, h): of image sequences. encoded_actions (torch.Tensor): A tensor of shape [B, T, A] representing encoded actions, where A is the action embedding dimension. - h (torch.Tensor): The hidden state from the target encoder (optional, might not be used in your case). + h (torch.Tensor): The hidden state from the target encoder. (Ground truth) Returns: torch.Tensor: A list of tensors representing the predicted values for the masked regions. 
""" - - image_embeddings = encoder(images, masks_enc) - - # Combine image and action embeddings - combined_embeddings = combine_encodings_concat(image_embeddings, encoded_actions) - - # Predict masked regions - predictions = predictor(combined_embeddings, h, masks_enc, masks_pred) - return predictions + try: + image_embeddings = encoder(images, masks_enc) + + # Combine image and action embeddings + combined_embeddings = combine_encodings_concat(image_embeddings, encoded_actions) + + # Predict masked regions + predictions = predictor(combined_embeddings, h, masks_enc, masks_pred) + return predictions + except Exception as e: + logger.error(f"Error in forward_context: {str(e)}") + raise e def loss_fn(z_next, h_next): diff --git a/app/vjepa/transforms.py b/app/vjepa/transforms.py index d5e6d0c..aa24ead 100644 --- a/app/vjepa/transforms.py +++ b/app/vjepa/transforms.py @@ -12,6 +12,41 @@ from src.datasets.utils.video.randerase import RandomErasing +def make_image_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=(3 / 4, 4 / 3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), +): + + transform_list = [ + transforms.RandomResizedCrop( + crop_size, + scale=random_resize_scale, + ratio=random_resize_aspect_ratio, + ), + ] + + if random_horizontal_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + if auto_augment: + transform_list.append(transforms.AutoAugment()) + + transform_list.extend([ + transforms.ToTensor(), + transforms.Normalize(mean=normalize[0], std=normalize[1]), + ]) + + if reprob > 0: + transform_list.append(transforms.RandomErasing(p=reprob)) + + return transforms.Compose(transform_list) + + def make_transforms( random_horizontal_flip=True, random_resize_aspect_ratio=(3 / 4, 4 / 3), diff --git a/app/vjepa/utils.py b/app/vjepa/utils.py index 58e0339..046018a 100644 --- a/app/vjepa/utils.py +++ b/app/vjepa/utils.py @@ -9,6 +9,7 @@ import sys import warnings import yaml +import traceback import torch @@ -35,7 +36,7 @@ def load_checkpoint( try: checkpoint = torch.load(r_path, map_location=torch.device("cpu")) except Exception as e: - logger.info(f"Encountered exception when loading checkpoint {e}") + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 try: @@ -69,7 +70,7 @@ def load_checkpoint( del checkpoint except Exception as e: - logger.info(f"Encountered exception when loading checkpoint {e}") + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 return ( diff --git a/configs/pretrain/vith16_384.yaml b/configs/pretrain/vith16_384.yaml index a4cf73b..9c9646a 100644 --- a/configs/pretrain/vith16_384.yaml +++ b/configs/pretrain/vith16_384.yaml @@ -4,7 +4,7 @@ tasks_per_node: 8 data: dataset_type: egovehicle_imagedataset datasets: - - /home/ncdev/Documents/darwin/data/raw/v-jepa-pretrain.csv + - /home/ncdev/Documents/darwin/data/raw/ # - /your_path_to_ssv2_csv_file_index.csv # - /your_path_to_howto100m_csv_file_index.csv decode_one_clip: true diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 91c57fb..0be356b 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -19,6 +19,7 @@ import logging import pprint +import traceback import numpy as np @@ -344,7 +345,7 @@ def load_checkpoint(device, r_path, classifier, opt, scaler): del checkpoint except Exception as e: - 
logger.info(f"Encountered exception when loading checkpoint {e}") + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 return classifier, opt, scaler, epoch diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index 3093f96..c1e4164 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -19,6 +19,7 @@ import logging import pprint +import traceback import numpy as np @@ -417,7 +418,7 @@ def load_checkpoint(device, r_path, classifier, opt, scaler): del checkpoint except Exception as e: - logger.info(f"Encountered exception when loading checkpoint {e}") + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 return classifier, opt, scaler, epoch diff --git a/logs_and_checkpoints/jepa_r0.csv b/logs_and_checkpoints/jepa_r0.csv new file mode 100644 index 0000000..eb59245 --- /dev/null +++ b/logs_and_checkpoints/jepa_r0.csv @@ -0,0 +1,45 @@ +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) 
+epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) diff --git a/logs_and_checkpoints/params-pretrain.yaml b/logs_and_checkpoints/params-pretrain.yaml new file mode 100644 index 0000000..fdc80bd --- /dev/null +++ b/logs_and_checkpoints/params-pretrain.yaml @@ -0,0 +1,88 @@ +app: vjepa +data: + batch_size: 10 + clip_duration: null + crop_size: 384 + dataset_type: egovehicle_imagedataset + datasets: + - /home/ncdev/Documents/darwin/data/raw/ + decode_one_clip: true + filter_short_videos: false + num_clips: 1 + num_frames: 16 + num_workers: 12 + patch_size: 16 + pin_mem: true + sampling_rate: 4 + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +logging: + folder: /home/ncdev/Documents/darwin/jepa/logs_and_checkpoints + write_tag: jepa +loss: + loss_exp: 1.0 + reg_coeff: 0.0 +mask: +- aspect_ratio: + - 0.75 + - 1.5 + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: false + read_checkpoint: null + seed: 234 + use_sdpa: true +model: + model_name: vit_huge + pred_depth: 12 + pred_embed_dim: 384 + uniform_power: true + use_mask_tokens: true + zero_init_mask_tokens: true +nodes: 30 +optimization: + clip_grad: 10.0 + ema: + - 0.998 + - 1.0 + epochs: 300 + final_lr: 1.0e-06 + final_weight_decay: 0.4 + ipe: 300 + ipe_scale: 1.25 + lr: 0.000625 + start_lr: 0.0002 + warmup: 40 + weight_decay: 0.04 +tasks_per_node: 8 diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py index ac826f0..0adff49 100644 --- a/src/datasets/data_manager.py +++ b/src/datasets/data_manager.py @@ -95,17 +95,17 @@ def init_data( ) elif 
data.lower() == "egovehicle_imagedataset": from src.datasets.image_dataset import make_egovehicle_imagedataset - - dataset, data_loader, dist_sampler = make_egovehicle_imagedataset( - data_paths=root_path, + + dataset, data_loader, dist_sampler = make_egovehicle_imagedataset( + data_dir=root_path[0], batch_size=batch_size, transform=transform, - shared_transform=shared_transform, - rank=rank, - world_size=world_size, - collator=collator, - drop_last=drop_last, + shared_transform=shared_transform, + mask_collator=collator, num_workers=num_workers, + world_size=world_size, + rank=rank, pin_mem=pin_mem, + drop_last=drop_last, ) return (data_loader, dist_sampler) diff --git a/src/datasets/image_dataset.py b/src/datasets/image_dataset.py index d20754c..fe6c7b5 100644 --- a/src/datasets/image_dataset.py +++ b/src/datasets/image_dataset.py @@ -6,6 +6,9 @@ # import os +import PIL +from collections import defaultdict + from logging import getLogger @@ -88,52 +91,65 @@ def make_imagedataset( class ImageDataset(torch.utils.data.Dataset): - def __init__( - self, - csv_file_path, # List of directories containing timestamped image folders - transform=None, - shared_transform=None, - ): + def __init__(self, data_dir, transform=None, shared_transform=None): + self.data_dir = data_dir self.transform = transform self.shared_transform = shared_transform - # Load Image Paths and Labels from CSV - df = pd.read_csv(csv_file_path, header=None, delimiter=" ") - self.samples = [] # List to store (image_path, action_label) tuples + # Load data from drive folders + self.samples = [] + self.drive_data = {} - for _, row in df.iterrows(): - folder_path = row[0] - action_filepath = os.path.join(folder_path, "action_data.csv") - if os.path.exists(action_filepath): - try: - action_df = pd.read_csv(action_filepath) - except pd.errors.EmptyDataError: - logger.warning( - f"Skipping folder '{folder_path}' due to empty action_data.csv" - ) - continue - self.samples.extend(list(action_df[["image_path", "maneuver"]].values)) # Store image paths and action labels + try: + drive_folders = [f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))] + for drive_folder in drive_folders: + drive_path = os.path.join(data_dir, drive_folder) + csv_file = os.path.join(drive_path, "drive_data.csv") - if not self.samples: - raise RuntimeError( - f"Found 0 image files with corresponding action data in the CSV: {csv_file_path}" - ) + if not os.path.exists(csv_file): + logger.warning(f"Skipping drive folder '{drive_folder}' due to missing drive_data.csv file.") + continue + try: + drive_df = pd.read_csv(csv_file) + self.drive_data[drive_folder] = drive_df + drive_samples = [(os.path.join(drive_path, row['path_to_image']), row['maneuverID']) for _, row in drive_df.iterrows()] + self.samples.extend(drive_samples) + except (pd.errors.EmptyDataError, KeyError) as e: + logger.warning(f"Skipping drive folder '{drive_folder}' due to error: {str(e)}") + + if len(self.samples) == 0: + raise RuntimeError(f"No valid drive folders found in the dataset directory: {data_dir}") + + except OSError as e: + raise RuntimeError(f"Error accessing dataset directory: {data_dir}. 
Exception: {str(e)}") + def __getitem__(self, index): - image_path, action_label = self.samples[index] + try: + image_path, maneuver_id = self.samples[index] - # Load Image - image = Image.open(image_path) + # Load image + try: + image = Image.open(image_path).convert("RGB") # Convert to RGB here + except (IOError, PIL.UnidentifiedImageError) as e: + logger.warning(f"Error loading image: {image_path}. Exception: {str(e)}") + raise e - # Apply Transforms - if self.shared_transform is not None: - image = self.shared_transform(image) - if self.transform is not None: - image = self.transform(image) + # Apply transforms + if self.shared_transform is not None: + image = self.shared_transform(image) + if self.transform is not None: + image = self.transform(image) - return image, action_label # Return the image and its corresponding action label + return image, maneuver_id + except IndexError as e: + raise IndexError(f"Index {index} is out of bounds for the dataset.") + def __len__(self): + if not self.samples: + raise RuntimeError("Dataset is empty. No valid samples found.") + return len(self.samples) class SequentialImageSampler(Sampler): def __init__(self, image_dataset, num_replicas=None, rank=None): @@ -175,46 +191,82 @@ def __len__(self): num_samples_per_worker += total_samples % self.num_replicas return num_samples_per_worker +def collate_fn(batch): + images, maneuvers = zip(*batch) + + # Stack images into a single tensor + images = torch.stack(images, dim=0) + + # Convert maneuvers to a tensor + maneuvers = torch.tensor([m for m in maneuvers]) + + return images, maneuvers + +class SequentialDriveSampler(Sampler): + def __init__(self, image_dataset): + self.image_dataset = image_dataset + self.drive_indices = self._get_drive_indices() + + def _get_drive_indices(self): + drive_indices = defaultdict(list) + for idx, (image_path, _) in enumerate(self.image_dataset.samples): + drive_folder = os.path.basename(os.path.dirname(image_path)) + drive_indices[drive_folder].append(idx) + return drive_indices + + def __iter__(self): + for drive_folder, indices in self.drive_indices.items(): + yield from indices + def __len__(self): + return len(self.image_dataset) + def make_egovehicle_imagedataset( - csv_file_path, + data_dir, batch_size, transform=None, shared_transform=None, + mask_collator=None, + num_workers=10, rank=0, world_size=1, - collator=None, - drop_last=True, - num_workers=10, pin_mem=True, + drop_last=True, ): dataset = ImageDataset( - csv_file_path=csv_file_path, + data_dir=data_dir, transform=transform, shared_transform=shared_transform, ) logger.info("ImageDataset created") - # Ensure that each worker gets a subset of folders while maintaining sequential order - sampler = SequentialImageSampler(dataset, num_replicas=world_size, rank=rank) + # sampler = SequentialDriveSampler(dataset) + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) - # Wrap the sampler with DistributedSampler for shuffling at the folder level - dist_sampler = DistributedSampler( - dataset, num_replicas=world_size, rank=rank, shuffle=True - ) + # data_loader = DataLoader( + # dataset, + # batch_size=batch_size, + # sampler=dist_sampler, + # collate_fn=mask_collator, + # num_workers=num_workers, + # pin_memory=pin_mem, + # drop_last=True, + # ) - # DataLoader should use both samplers data_loader = DataLoader( dataset, - batch_sampler=dist_sampler, # Using batch_sampler instead of sampler - collate_fn=collator, - num_workers=num_workers, - 
pin_memory=pin_mem, + collate_fn=mask_collator, + sampler=dist_sampler, + batch_size=batch_size, drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, persistent_workers=num_workers > 0, ) logger.info("ImageDataset data loader created") - return dataset, data_loader, dist_sampler + return dataset, data_loader, dist_sampler diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index 5b6e42a..afdbe4c 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -184,7 +184,7 @@ def split_into_clips(video): buffer = [self.transform(clip) for clip in buffer] # Load Action Data - action_filepath = os.path.join(os.path.dirname(sample), "action_data.csv") + action_filepath = os.path.join(os.path.dirname(sample), "drive_data.csv") action_df = pd.read_csv(action_filepath) action_labels = self.get_action_labels_for_clip(action_df, clip_indices) diff --git a/src/masks/multiblock3d.py b/src/masks/multiblock3d.py index f306677..8b9f048 100644 --- a/src/masks/multiblock3d.py +++ b/src/masks/multiblock3d.py @@ -62,7 +62,21 @@ def __call__(self, batch): return collated_batch, collated_masks_enc, collated_masks_pred +class MaskCollatorWithActions(MaskCollator): + def __call__(self, batch): + images, maneuver_ids = zip(*batch) + # collated_images = torch.utils.data.default_collate(images) + collated_maneuvers = torch.tensor(maneuver_ids) + collated_images = list(images) # Keep images as a list of PIL images + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(len(collated_images)) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + return collated_images, collated_maneuvers, collated_masks_enc, collated_masks_pred + class _MaskGenerator(object): def __init__( diff --git a/src/masks/random_tube.py b/src/masks/random_tube.py index 00fb6a7..9d8a82e 100644 --- a/src/masks/random_tube.py +++ b/src/masks/random_tube.py @@ -35,7 +35,7 @@ def __init__( num_frames=num_frames, spatial_patch_size=patch_size, temporal_patch_size=tubelet_size, - ratio=m.get("ratio"), + ratio=m.get("ratio", 0.9), ) self.mask_generators.append(mask_generator) @@ -56,6 +56,20 @@ def __call__(self, batch): return collated_batch, collated_masks_enc, collated_masks_pred +class MaskCollatorWithActions(MaskCollator): + def __call__(self, batch): + images, maneuver_ids = zip(*batch) + # collated_images = torch.utils.data.default_collate(images) + collated_maneuvers = torch.tensor(maneuver_ids) + collated_images = list(images) # Keep images as a list of PIL images + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(len(collated_images)) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_images, collated_maneuvers, collated_masks_enc, collated_masks_pred class _MaskGenerator(object): @@ -79,7 +93,7 @@ def __init__( self.spatial_patch_size = spatial_patch_size self.temporal_patch_size = temporal_patch_size - self.num_patches_spatial = self.height * self.width + self.num_patches_spatial = self.height * self.width self.ratio = ratio diff --git a/src/models/utils/patch_embed.py b/src/models/utils/patch_embed.py index 6488421..3620715 100644 --- a/src/models/utils/patch_embed.py +++ b/src/models/utils/patch_embed.py @@ -41,15 +41,33 @@ def __init__( super().__init__() self.patch_size = patch_size self.tubelet_size = 
tubelet_size - - self.proj = nn.Conv3d( + self.proj_video = nn.Conv3d( in_channels=in_chans, out_channels=embed_dim, kernel_size=(tubelet_size, patch_size, patch_size), stride=(tubelet_size, patch_size, patch_size), ) + self.proj_image = nn.Conv2d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + ) def forward(self, x, **kwargs): - B, C, T, H, W = x.shape - x = self.proj(x).flatten(2).transpose(1, 2) + if x is None: + return None + + if x.ndim == 5: # Video input + B, C, T, H, W = x.shape + x = self.proj_video(x) + x = x.flatten(2).transpose(1, 2) + elif x.ndim == 4: # Image input + B, C, H, W = x.shape + x = self.proj_image(x) + x = x.flatten(2).transpose(1, 2) + + else: + raise ValueError(f"Unsupported input shape: {x.shape}") + return x diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py index c1fa5bb..7cfd339 100644 --- a/src/models/vision_transformer.py +++ b/src/models/vision_transformer.py @@ -163,6 +163,8 @@ def forward(self, x, masks=None): :param x: input image/video :param masks: indices of patch tokens to mask (remove) """ + if x is None: + raise ValueError("Input tensor x cannot be None") if masks is not None and not isinstance(masks, list): masks = [masks] @@ -171,7 +173,9 @@ def forward(self, x, masks=None): pos_embed = self.pos_embed if pos_embed is not None: pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: x += pos_embed B, N, D = x.shape @@ -195,62 +199,41 @@ def forward(self, x, masks=None): x = self.norm(x) return x - + def interpolate_pos_encoding(self, x, pos_embed): - _, N, dim = pos_embed.shape - if self.is_video: - - # If pos_embed already corret size, just return + if x.dim() == 5: # Video clip [B, C, T, H, W] _, _, T, H, W = x.shape - if H == self.input_size and W == self.input_size and T == self.num_frames: - return pos_embed - - # Convert depth, height, width of input to be measured in patches - # instead of pixels/frames - T = T // self.tubelet_size - H = H // self.patch_size - W = W // self.patch_size - - # Compute the initialized shape of the positional embedding measured - # in patches - N_t = self.num_frames // self.tubelet_size - N_h = N_w = self.input_size // self.patch_size - assert N_h * N_w * N_t == N, "Positional embedding initialized incorrectly" - - # Compute scale factor for spatio-temporal interpolation - scale_factor = (T / N_t, H / N_h, W / N_w) - - pos_embed = nn.functional.interpolate( - pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), - scale_factor=scale_factor, - mode="trilinear", - ) - pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) - return pos_embed - - else: - - # If pos_embed already corret size, just return - _, _, H, W = x.shape + # ... 
(rest of the video interpolation logic remains the same)
+        else:  # Image sequence [B, T, H, W]
+            _, T, H, W = x.shape

+        # If pos_embed already correct size, just return
         if H == self.input_size and W == self.input_size:
+            # Add a temporal dimension to the positional embedding
+            pos_embed = pos_embed.unsqueeze(1).repeat(1, T, 1, 1)
             return pos_embed

         # Compute scale factor for spatial interpolation
         npatch = (H // self.patch_size) * (W // self.patch_size)
+        # Assuming pos_embed was initialized with no temporal dimension,
+        # N should correspond to the number of patches in a single image
+        assert N == npatch, "Input image size doesn't match model's expected size"
         scale_factor = math.sqrt(npatch / N)
+
+        # Repeat positional embedding to account for the temporal dimension
+        pos_embed = pos_embed.unsqueeze(1).repeat(1, T, 1, 1)
+
+        # 2D interpolation of positional embeddings (spatial dimensions)
         pos_embed = nn.functional.interpolate(
-            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
-                0, 3, 1, 2
-            ),
-            scale_factor=scale_factor,
+            pos_embed.reshape(1, T, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 4, 1, 2, 3),
+            scale_factor=(1.0, scale_factor, scale_factor),  # Only interpolate spatial dimensions
             mode="bicubic",
+            align_corners=False,
         )
-        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return pos_embed
+        pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
+        return pos_embed


 def vit_tiny(patch_size=16, **kwargs):
     model = VisionTransformer(
diff --git a/src/utils/distributed.py b/src/utils/distributed.py
index 46cf5dd..a3b7cca 100644
--- a/src/utils/distributed.py
+++ b/src/utils/distributed.py
@@ -6,10 +6,12 @@
 #

 import os
+import traceback

 import torch
 import torch.distributed as dist
+
 from logging import getLogger

 logger = getLogger()
@@ -40,7 +42,7 @@ def init_distributed(port=37123, rank_and_world_size=(None, None)):
         )
     except Exception as e:
         world_size, rank = 1, 0
-        logger.info(f"Rank: {rank}. Distributed training not available {e}")
+        logger.info(f"Rank: {rank}. Distributed training not available {traceback.format_exc()}")
     return world_size, rank
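
A minimal smoke test for the reworked drive-folder image pipeline introduced above. The dataset root is the path from params-pretrain.yaml; the resize size, batch size, and worker count are illustrative assumptions, and the expected drive_data.csv columns (path_to_image, maneuverID) follow ImageDataset.__init__. This is a sketch, not part of the patch:

    from torchvision import transforms

    from src.datasets.image_dataset import make_egovehicle_imagedataset

    # Each drive folder under data_dir is expected to contain a drive_data.csv
    # with 'path_to_image' and 'maneuverID' columns (see ImageDataset above).
    transform = transforms.Compose([
        transforms.Resize((384, 384)),  # assumed; matches crop_size in the config
        transforms.ToTensor(),
    ])

    dataset, loader, sampler = make_egovehicle_imagedataset(
        data_dir="/home/ncdev/Documents/darwin/data/raw/",  # path from params-pretrain.yaml
        batch_size=4,
        transform=transform,
        mask_collator=None,  # None -> default collate; or pass a MaskCollatorWithActions instance
        num_workers=0,
        rank=0,
        world_size=1,
    )

    images, maneuvers = next(iter(loader))
    print(images.shape)     # e.g. torch.Size([4, 3, 384, 384])
    print(maneuvers.shape)  # e.g. torch.Size([4]), assuming maneuverID is numeric

With world_size=1 and rank=0 the DistributedSampler inside make_egovehicle_imagedataset does not require an initialized process group, so this runs on a single machine.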
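
A second illustrative check, for the dual image/video projection added to the patch-embedding module in src/models/utils/patch_embed.py. The class name (PatchEmbed3D) and constructor arguments are assumptions based on the upstream V-JEPA code, not confirmed by this patch:

    import torch
    from src.models.utils.patch_embed import PatchEmbed3D  # assumed class name

    embed = PatchEmbed3D(patch_size=16, tubelet_size=2, in_chans=3, embed_dim=384)

    video = torch.randn(2, 3, 16, 224, 224)  # [B, C, T, H, W] -> proj_video (Conv3d)
    image = torch.randn(2, 3, 224, 224)      # [B, C, H, W]    -> proj_image (Conv2d)

    print(embed(video).shape)  # expected: [2, (16//2) * (224//16)**2, 384] = [2, 1568, 384]
    print(embed(image).shape)  # expected: [2, (224//16)**2, 384]           = [2, 196, 384]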