diff --git a/ml3d/configs/pointpillars_waymo.yml b/ml3d/configs/pointpillars_waymo.yml index 43534e834..23900ed15 100644 --- a/ml3d/configs/pointpillars_waymo.yml +++ b/ml3d/configs/pointpillars_waymo.yml @@ -2,7 +2,7 @@ dataset: name: Waymo dataset_path: # path/to/your/dataset cache_dir: ./logs/cache - steps_per_epoch_train: 5000 + steps_per_epoch_train: 4000 model: name: PointPillars @@ -31,7 +31,7 @@ model: max_voxels: [32000, 32000] voxel_encoder: - in_channels: 5 + in_channels: 4 feat_channels: [64] voxel_size: *vsize @@ -43,7 +43,7 @@ model: in_channels: 64 out_channels: [64, 128, 256] layer_nums: [3, 5, 5] - layer_strides: [2, 2, 2] + layer_strides: [1, 2, 2] neck: in_channels: [64, 128, 256] @@ -62,9 +62,9 @@ model: [-74.88, -74.88, 0, 74.88, 74.88, 0], ] sizes: [ - [2.08, 4.73, 1.77], # car - [0.84, 1.81, 1.77], # cyclist - [0.84, 0.91, 1.74] # pedestrian + [2.08, 4.73, 1.77], # VEHICLE + [0.84, 1.81, 1.77], # CYCLIST + [0.84, 0.91, 1.74] # PEDESTRIAN ] dir_offset: 0.7854 rotations: [0, 1.57] @@ -72,7 +72,8 @@ model: augment: PointShuffle: True - ObjectRangeFilter: True + ObjectRangeFilter: + point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4] ObjectSample: min_points_dict: VEHICLE: 5 @@ -88,7 +89,7 @@ pipeline: name: ObjectDetection test_compute_metric: true batch_size: 6 - val_batch_size: 1 + val_batch_size: 6 test_batch_size: 1 save_ckpt_freq: 5 max_epoch: 200 @@ -102,7 +103,7 @@ pipeline: weight_decay: 0.01 # evaluation properties - overlaps: [0.5, 0.5, 0.7] + overlaps: [0.5, 0.5, 0.5] difficulties: [0, 1, 2] summary: record_for: [] diff --git a/ml3d/datasets/augment/augmentation.py b/ml3d/datasets/augment/augmentation.py index fb8cebca4..a7673c47b 100644 --- a/ml3d/datasets/augment/augmentation.py +++ b/ml3d/datasets/augment/augmentation.py @@ -493,7 +493,7 @@ def ObjectSample(self, data, db_boxes_dict, sample_dict): sampled_points = np.concatenate( [box.points_inside_box for box in sampled], axis=0) points = remove_points_in_boxes(points, sampled) - points = np.concatenate([sampled_points, points], axis=0) + points = np.concatenate([sampled_points[:, :4], points], axis=0) return { 'point': points, diff --git a/ml3d/datasets/utils/operations.py b/ml3d/datasets/utils/operations.py index dd211c13c..feb1ae804 100644 --- a/ml3d/datasets/utils/operations.py +++ b/ml3d/datasets/utils/operations.py @@ -4,7 +4,7 @@ import math from scipy.spatial import ConvexHull -from ...metrics import iou_bev +from open3d.ml.contrib import iou_bev_cpu as iou_bev def create_3D_rotations(axis, angle): diff --git a/ml3d/datasets/waymo.py b/ml3d/datasets/waymo.py index 047edfd3f..1dd2036ef 100644 --- a/ml3d/datasets/waymo.py +++ b/ml3d/datasets/waymo.py @@ -25,7 +25,6 @@ def __init__(self, name='Waymo', cache_dir='./logs/cache', use_cache=False, - val_split=3, **kwargs): """Initialize the function by passing the dataset and other details. @@ -34,7 +33,6 @@ def __init__(self, name: The name of the dataset (Waymo in this case). cache_dir: The directory where the cache is stored. use_cache: Indicates if the dataset should be cached. - val_split: The split value to get a set of images for training, validation, for testing. Returns: class: The corresponding class. 
@@ -43,7 +41,6 @@ def __init__(self, name=name, cache_dir=cache_dir, use_cache=use_cache, - val_split=val_split, **kwargs) cfg = self.cfg @@ -52,22 +49,27 @@ def __init__(self, self.dataset_path = cfg.dataset_path self.num_classes = 4 self.label_to_names = self.get_label_to_names() + self.shuffle = kwargs.get('shuffle', False) self.all_files = sorted( glob(join(cfg.dataset_path, 'velodyne', '*.bin'))) self.train_files = [] self.val_files = [] + self.test_files = [] for f in self.all_files: - idx = Path(f).name.replace('.bin', '')[:3] - idx = int(idx) - if idx < cfg.val_split: + if 'train' in f: self.train_files.append(f) - else: + elif 'val' in f: self.val_files.append(f) - - self.test_files = glob( - join(cfg.dataset_path, 'testing', 'velodyne', '*.bin')) + elif 'test' in f: + self.test_files.append(f) + else: + log.warning( + f"Skipping {f}, prefix must be one of train, test or val.") + if self.shuffle: + log.info("Shuffling training files...") + self.rng.shuffle(self.train_files) @staticmethod def get_label_to_names(): @@ -90,18 +92,21 @@ def read_lidar(path): """Reads lidar data from the path provided. Returns: - A data object with lidar information. + pc: pointcloud data with shape [N, 6], where + the format is (x, y, z, intensity, elongation, timestamp). """ - assert Path(path).exists() - return np.fromfile(path, dtype=np.float32).reshape(-1, 6) @staticmethod def read_label(path, calib): - """Reads labels of bound boxes. + """Reads labels of bounding boxes. + + Args: + path: The path to the label file. + calib: Calibration as returned by read_calib(). Returns: - The data objects with bound boxes information. + The data objects with bounding boxes information. """ if not Path(path).exists(): return None @@ -131,24 +136,22 @@ def read_calib(path): Returns: The camera and the camera image used in calibration. """ - assert Path(path).exists() - with open(path, 'r') as f: lines = f.readlines() obj = lines[0].strip().split(' ')[1:] - P0 = np.array(obj, dtype=np.float32) + unused_P0 = np.array(obj, dtype=np.float32) obj = lines[1].strip().split(' ')[1:] - P1 = np.array(obj, dtype=np.float32) + unused_P1 = np.array(obj, dtype=np.float32) obj = lines[2].strip().split(' ')[1:] P2 = np.array(obj, dtype=np.float32) obj = lines[3].strip().split(' ')[1:] - P3 = np.array(obj, dtype=np.float32) + unused_P3 = np.array(obj, dtype=np.float32) obj = lines[4].strip().split(' ')[1:] - P4 = np.array(obj, dtype=np.float32) + unused_P4 = np.array(obj, dtype=np.float32) obj = lines[5].strip().split(' ')[1:] R0 = np.array(obj, dtype=np.float32).reshape(3, 3) @@ -162,7 +165,7 @@ def read_calib(path): Tr_velo_to_cam = Waymo._extend_matrix(Tr_velo_to_cam) world_cam = np.transpose(rect_4x4 @ Tr_velo_to_cam) - cam_img = np.transpose(P2) + cam_img = np.transpose(np.vstack((P2.reshape(3, 4), [0, 0, 0, 1]))) return {'world_cam': world_cam, 'cam_img': cam_img} @@ -209,7 +212,7 @@ def get_split_list(self, split): else: raise ValueError("Invalid split {}".format(split)) - def is_tested(): + def is_tested(attr): """Checks if a datum in the dataset has been tested. Args: attr: The attribute that needs to be checked. Returns: If the datum attribute is tested, then return the path where the attribute is stored; else, returns false. """ - pass + raise NotImplementedError() - def save_test_result(): + def save_test_result(results, attr): """Saves the output of a model. Args: results: The output of a model for the datum associated with the attribute passed. attr: The attributes that correspond to the outputs passed in results.
""" - pass + raise NotImplementedError() class WaymoSplit(): @@ -273,11 +276,9 @@ def get_attr(self, idx): class Object3d(BEVBox3D): - """The class stores details that are object-specific, such as bounding box - coordinates, occlusion and so on. - """ def __init__(self, center, size, label, calib): + # ground truth files doesn't have confidence value. confidence = float(label[15]) if label.__len__() == 16 else -1.0 world_cam = calib['world_cam'] diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 11a4b1cbb..a10684663 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -4,6 +4,7 @@ import pickle import torch import yaml +import math from os import listdir from os.path import exists, join, isdir @@ -434,6 +435,22 @@ def to(self, device): self.feat = [feat.to(device) for feat in self.feat] self.label = [label.to(device) for label in self.label] + @staticmethod + def scatter(batch, num_gpu): + batch_size = len(batch.batch_lengths) + + new_batch_size = math.ceil(batch_size / num_gpu) + batches = [SparseConvUnetBatch([]) for _ in range(num_gpu)] + for i in range(num_gpu): + start = new_batch_size * i + end = min(new_batch_size * (i + 1), batch_size) + batches[i].point = batch.point[start:end] + batches[i].feat = batch.feat[start:end] + batches[i].label = batch.label[start:end] + batches[i].batch_lengths = batch.batch_lengths[start:end] + + return [b for b in batches if len(b.point)] # filter empty batch + class PointTransformerBatch: @@ -486,7 +503,6 @@ def __init__(self, batches): self.attr = [] for batch in batches: - self.attr.append(batch['attr']) data = batch['data'] self.point.append(torch.tensor(data['point'], dtype=torch.float32)) self.labels.append( @@ -519,6 +535,23 @@ def to(self, device): if self.bboxes[i] is not None: self.bboxes[i] = self.bboxes[i].to(device) + @staticmethod + def scatter(batch, num_gpu): + batch_size = len(batch.point) + + new_batch_size = math.ceil(batch_size / num_gpu) + batches = [ObjectDetectBatch([]) for _ in range(num_gpu)] + for i in range(num_gpu): + start = new_batch_size * i + end = min(new_batch_size * (i + 1), batch_size) + batches[i].point = batch.point[start:end] + batches[i].labels = batch.labels[start:end] + batches[i].bboxes = batch.bboxes[start:end] + batches[i].bbox_objs = batch.bbox_objs[start:end] + batches[i].attr = batch.attr[start:end] + + return [b for b in batches if len(b.point)] # filter empty batch + class ConcatBatcher(object): """ConcatBatcher for KPConv.""" diff --git a/ml3d/torch/models/base_model_objdet.py b/ml3d/torch/models/base_model_objdet.py index 995ee3209..3f0717888 100644 --- a/ml3d/torch/models/base_model_objdet.py +++ b/ml3d/torch/models/base_model_objdet.py @@ -25,7 +25,7 @@ def __init__(self, **kwargs): self.rng = np.random.default_rng(kwargs.get('seed', None)) @abstractmethod - def loss(self, results, inputs): + def get_loss(self, results, inputs): """Computes the loss given the network input and outputs. 
Args: diff --git a/ml3d/torch/models/point_pillars.py b/ml3d/torch/models/point_pillars.py index ed08ec09e..a519acc75 100644 --- a/ml3d/torch/models/point_pillars.py +++ b/ml3d/torch/models/point_pillars.py @@ -137,7 +137,7 @@ def get_optimizer(self, cfg): optimizer = torch.optim.AdamW(self.parameters(), **cfg) return optimizer, None - def loss(self, results, inputs): + def get_loss(self, results, inputs): scores, bboxes, dirs = results gt_labels = inputs.labels gt_bboxes = inputs.bboxes diff --git a/ml3d/torch/models/point_rcnn.py b/ml3d/torch/models/point_rcnn.py index dc9cf4300..f367e53ff 100644 --- a/ml3d/torch/models/point_rcnn.py +++ b/ml3d/torch/models/point_rcnn.py @@ -183,7 +183,7 @@ def step(self): return optimizer, scheduler - def loss(self, results, inputs): + def get_loss(self, results, inputs): if self.mode == "RPN": return self.rpn.loss(results, inputs) else: diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index 9823d24ac..f466868b9 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -12,13 +12,19 @@ class BasePipeline(ABC): """Base pipeline class.""" - def __init__(self, model, dataset=None, device='gpu', **kwargs): + def __init__(self, + model, + dataset=None, + device='cuda', + distributed=False, + **kwargs): """Initialize. Args: model: A network model. dataset: A dataset, or None for inference model. - device: 'gpu' or 'cpu'. + device: 'cuda' or 'cpu'. + distributed: Whether to use multiple gpus. kwargs: Returns: @@ -34,18 +40,36 @@ def __init__(self, model, dataset=None, device='gpu', **kwargs): self.dataset = dataset self.rng = np.random.default_rng(kwargs.get('seed', None)) - make_dir(self.cfg.main_log_dir) + self.distributed = distributed + if self.distributed and self.name == "SemanticSegmentation": + raise NotImplementedError( + "Distributed training not implemented for SemanticSegmentation!" 
+ ) + + self.rank = kwargs.get('rank', 0) + dataset_name = dataset.name if dataset is not None else '' self.cfg.logs_dir = join( self.cfg.main_log_dir, model.__class__.__name__ + '_' + dataset_name + '_torch') - make_dir(self.cfg.logs_dir) + + if self.rank == 0: + make_dir(self.cfg.main_log_dir) + make_dir(self.cfg.logs_dir) if device == 'cpu' or not torch.cuda.is_available(): + if distributed: + raise NotImplementedError( + "Distributed training for CPU is not supported yet.") self.device = torch.device('cpu') else: - self.device = torch.device('cuda' if len(device.split(':')) == - 1 else 'cuda:' + device.split(':')[1]) + if distributed: + self.device = torch.device(device) + print(f"Rank : {self.rank} using device : {self.device}") + torch.cuda.set_device(self.device) + else: + self.device = torch.device('cuda') + self.summary = {} self.cfg.setdefault('summary', {}) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 53ddb18b3..593494fea 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -1,12 +1,13 @@ import logging import re +import numpy as np +import torch +import torch.distributed as dist + from datetime import datetime from os.path import exists, join from pathlib import Path - from tqdm import tqdm -import numpy as np -import torch from torch.utils.data import DataLoader from .base_pipeline import BasePipeline @@ -148,8 +149,9 @@ def run_valid(self, epoch=0): log.info("DEVICE : {}".format(device)) log_file_path = join(cfg.logs_dir, 'log_valid_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) + if self.rank == 0: + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) batcher = ConcatBatcher(device, model.cfg.name) @@ -161,16 +163,22 @@ def run_valid(self, epoch=0): shuffle=True, steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_valid', None)) - valid_loader = DataLoader( - valid_split, - batch_size=cfg.val_batch_size, - num_workers=cfg.get('num_workers', 4), - pin_memory=cfg.get('pin_memory', False), - collate_fn=batcher.collate_fn, - worker_init_fn=lambda x: np.random.seed(x + np.uint32( - torch.utils.data.get_worker_info().seed))) - record_summary = 'valid' in cfg.get('summary').get('record_for', []) + if self.distributed: + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_split) + else: + valid_sampler = None + + valid_loader = DataLoader(valid_split, + batch_size=cfg.val_batch_size, + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + collate_fn=batcher.collate_fn, + sampler=valid_sampler) + + record_summary = self.rank == 0 and 'valid' in cfg.get('summary').get( + 'record_for', []) log.info("Started validation") self.valid_losses = {} @@ -181,7 +189,7 @@ def run_valid(self, epoch=0): for data in tqdm(valid_loader, desc='validation'): data.to(device) results = model(data) - loss = model.loss(results, data) + loss = model.get_loss(results, data) for l, v in loss.items(): if l not in self.valid_losses: self.valid_losses[l] = [] @@ -191,12 +199,12 @@ def run_valid(self, epoch=0): boxes = model.inference_end(results, data) pred.extend([BEVBox3D.to_dicts(b) for b in boxes]) gt.extend([BEVBox3D.to_dicts(b) for b in data.bbox_objs]) - # Save only for the first batch - if record_summary and 'valid' not in self.summary: + if record_summary: self.summary['valid'] = self.get_3d_summary(boxes, 
data, epoch, results=results) + record_summary = False # Save only for the first batch sum_loss = 0 desc = "validation - " @@ -211,6 +219,22 @@ def run_valid(self, epoch=0): similar_classes = cfg.get("similar_classes", {}) difficulties = cfg.get("difficulties", [0]) + if self.distributed: + gt_gather = [None for _ in range(dist.get_world_size())] + pred_gather = [None for _ in range(dist.get_world_size())] + + dist.gather_object(gt, gt_gather if self.rank == 0 else None, dst=0) + dist.gather_object(pred, + pred_gather if self.rank == 0 else None, + dst=0) + + if self.rank == 0: + gt = sum(gt_gather, []) + pred = sum(pred_gather, []) + + if self.rank != 0: + return + ap = mAP(pred, gt, model.classes, @@ -249,18 +273,21 @@ def run_train(self): """Run training with train data split.""" torch.manual_seed(self.rng.integers(np.iinfo( np.int32).max)) # Random reproducible seed for torch + rank = self.rank # Rank for distributed training model = self.model device = self.device dataset = self.dataset cfg = self.cfg - log.info("DEVICE : {}".format(device)) - timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + if rank == 0: + log.info("DEVICE : {}".format(device)) + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') - log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) + log_file_path = join(cfg.logs_dir, + 'log_train_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) batcher = ConcatBatcher(device, model.cfg.name) @@ -271,12 +298,20 @@ def run_train(self): use_cache=dataset.cfg.use_cache, steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_train', None)) + + if self.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_split) + else: + train_sampler = None + train_loader = DataLoader( train_split, batch_size=cfg.batch_size, - num_workers=cfg.get('num_workers', 4), + num_workers=cfg.get('num_workers', 0), pin_memory=cfg.get('pin_memory', False), collate_fn=batcher.collate_fn, + sampler=train_sampler, worker_init_fn=lambda x: np.random.seed(x + np.uint32( torch.utils.data.get_worker_info().seed)) ) # numpy expects np.uint32, whereas torch returns np.uint64. 
@@ -295,29 +330,52 @@ def run_train(self): runid + '_' + Path(tensorboard_dir).name) writer = SummaryWriter(self.tensorboard_dir) - self.save_config(writer) - log.info("Writing summary in {}.".format(self.tensorboard_dir)) - record_summary = 'train' in cfg.get('summary').get('record_for', []) + if rank == 0: + self.save_config(writer) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) + + # wrap model for multiple GPUs + if self.distributed: + model.cuda(self.device) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.device]) + model.get_loss = model.module.get_loss + model.cfg = model.module.cfg + model.inference_end = model.module.inference_end + + record_summary = self.rank == 0 and 'train' in cfg.get('summary').get( + 'record_for', []) + + if rank == 0: + log.info("Started training") - log.info("Started training") for epoch in range(start_ep, cfg.max_epoch + 1): log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') - model.train() + if self.distributed: + train_sampler.set_epoch(epoch) + model.train() self.losses = {} process_bar = tqdm(train_loader, desc='training') for data in process_bar: data.to(device) results = model(data) - loss = model.loss(results, data) + loss = model.get_loss(results, data) loss_sum = sum(loss.values()) self.optimizer.zero_grad() loss_sum.backward() - if model.cfg.get('grad_clip_norm', -1) > 0: - torch.nn.utils.clip_grad_value_(model.parameters(), - model.cfg.grad_clip_norm) + if self.distributed: + if model.module.cfg.get('grad_clip_norm', -1) > 0: + torch.nn.utils.clip_grad_value_( + model.module.parameters(), + model.module.cfg.grad_clip_norm) + else: + if model.cfg.get('grad_clip_norm', -1) > 0: + torch.nn.utils.clip_grad_value_( + model.parameters(), model.cfg.grad_clip_norm) + self.optimizer.step() # Record visualization for the last iteration @@ -337,17 +395,22 @@ def run_train(self): process_bar.set_description(desc) process_bar.refresh() + if self.distributed: + dist.barrier() + if self.scheduler is not None: self.scheduler.step() # --------------------- validation - if (epoch % cfg.get("validation_freq", 1)) == 0: + if epoch % cfg.get("validation_freq", 1) == 0: self.run_valid() + if self.distributed: + dist.barrier() - self.save_logs(writer, epoch) - - if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch: - self.save_ckpt(epoch) + if rank == 0: + self.save_logs(writer, epoch) + if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch: + self.save_ckpt(epoch) def get_3d_summary(self, infer_bboxes_batch, @@ -470,7 +533,10 @@ def save_logs(self, writer, epoch): def load_ckpt(self, ckpt_path=None, is_resume=True): train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') - make_dir(train_ckpt_dir) + if self.rank == 0: + make_dir(train_ckpt_dir) + if self.distributed: + dist.barrier() epoch = 0 if ckpt_path is None: diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index 8863aedad..232ce5a30 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -95,7 +95,7 @@ def __init__( scheduler_gamma=0.95, momentum=0.98, main_log_dir='./logs/', - device='gpu', + device='cuda', split='train', train_sum_dir='train_log', **kwargs): diff --git a/ml3d/utils/builder.py b/ml3d/utils/builder.py index d68272c37..f087d501c 100644 --- a/ml3d/utils/builder.py +++ b/ml3d/utils/builder.py @@ -6,17 +6,30 @@ SAMPLER = Registry('sampler') -def convert_device_name(framework): +def build(cfg, registry, args=None): +
return build_from_cfg(cfg, registry, args) + + +def build_network(cfg): + return build(cfg, NETWORK) + + +def convert_device_name(framework, device_ids): """Convert device to either cpu or cuda.""" gpu_names = ["gpu", "cuda"] cpu_names = ["cpu"] if framework not in cpu_names + gpu_names: raise KeyError("the device should either " "be cuda or cpu but got {}".format(framework)) + assert type(device_ids) is list + device_ids_new = [] + for device in device_ids: + device_ids_new.append(int(device)) + if framework in gpu_names: - return "cuda" + return "cuda", device_ids_new else: - return "cpu" + return "cpu", device_ids_new def convert_framework_name(framework): diff --git a/ml3d/utils/config.py b/ml3d/utils/config.py index 3b61d1d4e..41998a2de 100644 --- a/ml3d/utils/config.py +++ b/ml3d/utils/config.py @@ -245,3 +245,9 @@ def __getattr__(self, name): def __getitem__(self, name): return self._cfg_dict.__getitem__(name) + + def __getstate__(self): + return self.cfg_dict + + def __setstate__(self, state): + self.cfg_dict = state diff --git a/scripts/collect_bboxes.py b/scripts/collect_bboxes.py index 95dbfee01..f8c6a5f33 100644 --- a/scripts/collect_bboxes.py +++ b/scripts/collect_bboxes.py @@ -1,10 +1,13 @@ import logging -from os.path import join import argparse import pickle +import numpy as np +import multiprocessing + +from tqdm import tqdm +from os.path import join from open3d.ml.datasets import utils from open3d.ml import datasets -import multiprocessing def parse_args(): @@ -23,10 +26,17 @@ def parse_args(): default="KITTI", required=False) parser.add_argument('--num_cpus', - help='Name of dataset class', + help='Number of processes to use.', type=int, default=multiprocessing.cpu_count(), required=False) + parser.add_argument( + '--max_pc', + help= + 'Boxes from random N pointclouds will be saved.
Default None (save from whole dataset).', + type=int, + default=None, + required=False) args = parser.parse_args() @@ -84,11 +94,19 @@ def process_boxes(i): classname = getattr(datasets, args.dataset_type) dataset = classname(args.dataset_path) train = dataset.get_split('train') + max_pc = len(train) if args.max_pc is None else args.max_pc + + rng = np.random.default_rng() + query_pc = range(len(train)) if max_pc >= len(train) else rng.choice( + range(len(train)), max_pc, replace=False) - print("Found", len(train), "traning samples") - print("This may take a few minutes...") + print(f"Found {len(train)} training samples, using {max_pc}.") + print( + f"Using {args.num_cpus} CPUs. This may take a few minutes...") with multiprocessing.Pool(args.num_cpus) as p: - bboxes = p.map(process_boxes, range(len(train))) + bboxes = list(tqdm(p.imap(process_boxes, query_pc), + total=len(query_pc))) bboxes = [e for l in bboxes for e in l] file = open(join(out_path, 'bboxes.pkl'), 'wb') pickle.dump(bboxes, file) + print(f"Saved {len(bboxes)} boxes.") diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index 6a32dde7d..6b44a61fd 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -8,13 +8,13 @@ import logging import numpy as np import os, sys, glob, pickle -from pathlib import Path -from os.path import join, exists, dirname, abspath -from os import makedirs -import random import argparse import tensorflow as tf import matplotlib.image as mpimg + +from pathlib import Path +from os.path import join, exists, dirname, abspath +from os import makedirs from multiprocessing import Pool from tqdm import tqdm from waymo_open_dataset.utils import range_image_utils, transform_utils @@ -38,10 +38,10 @@ def parse_args(): default=16, type=int) - parser.add_argument('--is_test', - help='True for processing test data (default False)', - default=False, - type=bool) + parser.add_argument('--split', - help='One of {train, val, test} (default train)', + default='train', + type=str) args = parser.parse_args() @@ -58,6 +58,25 @@ class Waymo2KITTI(): """Waymo to KITTI converter. This class converts tfrecord files from Waymo dataset to KITTI format. + KITTI format: (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), + rotation_y(1), score(1, optional)) + type (string): Describes the type of object. + truncated (float): Ranges from 0 (non-truncated) to 1 (truncated). + occluded (int): Integer (0, 1, 2, 3) signifying state: fully visible, partly + occluded, largely occluded, unknown. + alpha (float): Observation angle of object, ranging [-pi..pi]. + bbox (float): 2d bounding box of object in the image. + dimensions (float): 3D object dimensions: h, w, l in meters. + location (float): 3D object location: x,y,z in camera coordinates (in meters). + rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. + score (float): Only for predictions, indicating confidence in detection. + + Conversion writes the following files: + pointcloud(np.float32): pointcloud data with shape [N, 6]. Consists of + (x, y, z, intensity, elongation, timestamp). + images(np.uint8): camera images are saved if `write_image` is True. + calibrations(np.float32): Intrinsic and extrinsic matrices for all cameras. + label(np.float32): Bounding box information in KITTI format. Args: dataset_path (str): Directory to load waymo raw data. @@ -66,9 +85,9 @@ class Waymo2KITTI(): is_test (bool): Whether in the test_mode. Default: False.
""" - def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): + def __init__(self, dataset_path, save_dir='', workers=8, split='train'): - self.write_image = True + self.write_image = False self.filter_empty_3dboxes = True self.filter_no_label_zone_points = True @@ -86,8 +105,8 @@ def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): self.dataset_path = dataset_path self.save_dir = save_dir self.workers = int(workers) - self.is_test = is_test - self.prefix = '' + self.is_test = split == 'test' + self.prefix = split + '_' self.save_track_id = False self.tfrecord_files = sorted( @@ -137,7 +156,6 @@ def process_one(self, file_idx): if (self.selected_waymo_locations is not None and frame.context.stats.location not in self.selected_waymo_locations): - print("continue") continue if self.write_image: @@ -153,8 +171,6 @@ def __len__(self): return len(self.tfrecord_files) def save_image(self, frame, file_idx, frame_idx): - self.prefix = '' - for img in frame.images: img_path = Path(self.image_save_dir + str(img.name - 1)) / ( self.prefix + str(file_idx).zfill(3) + str(frame_idx).zfill(3) + @@ -210,7 +226,6 @@ def save_calib(self, frame, file_idx, frame_idx): f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') as fp_calib: fp_calib.write(calib_context) - fp_calib.close() def save_pose(self, frame, file_idx, frame_idx): pose = np.array(frame.pose.transform).reshape(4, 4) @@ -228,7 +243,6 @@ def save_label(self, frame, file_idx, frame_idx): for labels in frame.projected_lidar_labels: name = labels.name for label in labels.labels: - # TODO: need a workaround as bbox may not belong to front cam bbox = [ label.box.center_x - label.box.length / 2, label.box.center_y - label.box.width / 2, @@ -257,9 +271,6 @@ def save_label(self, frame, file_idx, frame_idx): if my_type not in self.selected_waymo_classes: continue - # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: - # continue - height = obj.box.height width = obj.box.width length = obj.box.length @@ -268,11 +279,6 @@ def save_label(self, frame, file_idx, frame_idx): y = obj.box.center_y z = obj.box.center_z - # # project bounding box to the virtual reference frame - # pt_ref = self.T_velo_to_front_cam @ \ - # np.array([x, y, z, 1]).reshape((4, 1)) - # x, y, z, _ = pt_ref.flatten().tolist() - rotation_y = -obj.box.heading - np.pi / 2 track_id = obj.id @@ -460,6 +466,8 @@ def cart_to_homo(mat): out_path = args.out_path if out_path is None: args.out_path = args.dataset_path + if args.split not in ['train', 'val', 'test']: + raise ValueError("split must be one of {train, val, test}") converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, - args.is_test) + args.split) converter.convert() diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 4fc55e471..68564e1b9 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -5,6 +5,8 @@ import yaml import pprint import os +import torch.distributed as dist +from torch import multiprocessing from pathlib import Path @@ -28,8 +30,12 @@ def parse_args(): parser.add_argument('--dataset_path', help='path to the dataset') parser.add_argument('--ckpt_path', help='path to the checkpoint') parser.add_argument('--device', - help='device to run the pipeline', - default='gpu') + help='devices to run the pipeline', + default='cuda') + parser.add_argument('--device_ids', + nargs='+', + help='cuda device list', + default=['0']) parser.add_argument('--split', help='train or test', default='train') 
parser.add_argument('--mode', help='additional mode', default=None) parser.add_argument('--max_epochs', help='number of epochs', default=None) @@ -37,6 +43,18 @@ def parse_args(): parser.add_argument('--main_log_dir', help='the dir to save logs and models') parser.add_argument('--seed', help='random seed', default=0) + parser.add_argument( + '--host', + help='Host for distributed training, default: localhost', + default='localhost') + parser.add_argument('--port', + help='Port for distributed training, default: 12355', + default='12355') + parser.add_argument( + '--backend', + help= + 'backend for distributed training. One of {nccl, gloo}, default: gloo', + default='gloo') args, unknown = parser.parse_known_args() @@ -63,10 +81,13 @@ def main(): args, extra_dict = parse_args() framework = _ml3d.utils.convert_framework_name(args.framework) - args.device = _ml3d.utils.convert_device_name(args.device) + args.device, args.device_ids = _ml3d.utils.convert_device_name( + args.device, args.device_ids) rng = np.random.default_rng(args.seed) if framework == 'torch': import open3d.ml.torch as ml3d + import torch.multiprocessing as mp + import torch.distributed as dist else: os.environ[ 'TF_CPP_MIN_LOG_LEVEL'] = '1' # Disable INFO messages from tf @@ -100,22 +121,20 @@ def main(): cfg_dict_dataset, cfg_dict_pipeline, cfg_dict_model = \ _ml3d.utils.Config.merge_cfg_file(cfg, args, extra_dict) - cfg_dict_dataset['seed'] = rng - cfg_dict_model['seed'] = rng - cfg_dict_pipeline['seed'] = rng - - dataset = Dataset(cfg_dict_dataset.pop('dataset_path', None), - **cfg_dict_dataset) - if args.mode is not None: cfg_dict_model["mode"] = args.mode - model = Model(**cfg_dict_model) - if args.max_epochs is not None: cfg_dict_pipeline["max_epochs"] = args.max_epochs if args.batch_size is not None: cfg_dict_pipeline["batch_size"] = args.batch_size - pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) + + cfg_dict_dataset['seed'] = rng + cfg_dict_model['seed'] = rng + cfg_dict_pipeline['seed'] = rng + + cfg_dict_pipeline["device"] = args.device + cfg_dict_pipeline["device_ids"] = args.device_ids + else: if (args.pipeline and args.model and args.dataset) is None: raise ValueError("Please specify pipeline, model, and dataset " + @@ -133,31 +152,94 @@ def main(): cfg_dict_model['seed'] = rng cfg_dict_pipeline['seed'] = rng - dataset = Dataset(**cfg_dict_dataset) - model = Model(**cfg_dict_model, mode=args.mode) - pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) - with open(Path(__file__).parent / 'README.md', 'r') as f: readme = f.read() - pipeline.cfg_tb = { + + cfg_tb = { 'readme': readme, 'cmd_line': cmd_line, 'dataset': pprint.pformat(cfg_dict_dataset, indent=2), 'model': pprint.pformat(cfg_dict_model, indent=2), 'pipeline': pprint.pformat(cfg_dict_pipeline, indent=2) } + args.cfg_tb = cfg_tb + args.distributed = framework == 'torch' and args.device != 'cpu' and len( + args.device_ids) > 1 + + if not args.distributed: + dataset = Dataset(**cfg_dict_dataset) + model = Model(**cfg_dict_model, mode=args.mode) + pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) + + pipeline.cfg_tb = cfg_tb + + if args.split == 'test': + pipeline.run_test() + else: + pipeline.run_train() + + else: + mp.spawn(main_worker, + args=(Dataset, Model, Pipeline, cfg_dict_dataset, + cfg_dict_model, cfg_dict_pipeline, args), + nprocs=len(args.device_ids)) + + +def setup(rank, world_size, args): + os.environ['MASTER_ADDR'] = args.host + os.environ['MASTER_PORT'] = args.port + + # initialize the process group +
dist.init_process_group(args.backend, rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, + cfg_dict_model, cfg_dict_pipeline, args): + world_size = len(args.device_ids) + setup(rank, world_size, args) + + cfg_dict_dataset['rank'] = rank + cfg_dict_model['rank'] = rank + cfg_dict_pipeline['rank'] = rank + + rng = np.random.default_rng(args.seed + rank) + cfg_dict_dataset['seed'] = rng + cfg_dict_model['seed'] = rng + cfg_dict_pipeline['seed'] = rng + + device = f"cuda:{args.device_ids[rank]}" + print(f"rank = {rank}, world_size = {world_size}, gpu = {device}") + + cfg_dict_model['device'] = device + cfg_dict_pipeline['device'] = device + + dataset = Dataset(**cfg_dict_dataset) + model = Model(**cfg_dict_model, mode=args.mode) + pipeline = Pipeline(model, + dataset, + distributed=args.distributed, + **cfg_dict_pipeline) + + pipeline.cfg_tb = args.cfg_tb if args.split == 'test': - pipeline.run_test() + if rank == 0: + pipeline.run_test() else: pipeline.run_train() + cleanup() -if __name__ == '__main__': +if __name__ == '__main__': logging.basicConfig( level=logging.INFO, format='%(levelname)s - %(asctime)s - %(module)s - %(message)s', ) - main() + multiprocessing.set_start_method('forkserver') + sys.exit(main())
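Usage note (not part of the patch): with the flags added above, multi-GPU training is launched along the lines of `python scripts/run_pipeline.py torch ... --device cuda --device_ids 0 1` (config and dataset flags omitted); run_pipeline.py then spawns one `main_worker` per entry in `--device_ids`, and the `scatter` helpers added to concat_batcher.py split each collated batch across those workers. Below is a minimal, self-contained sketch of that splitting scheme; `ToyBatch` is a hypothetical stand-in for `ObjectDetectBatch`/`SparseConvUnetBatch`, and only the ceil-split and empty-shard filtering mirror the patch.

import math


class ToyBatch:
    """Hypothetical stand-in for ObjectDetectBatch; one list per field."""

    def __init__(self, point):
        self.point = point

    @staticmethod
    def scatter(batch, num_gpu):
        # Consecutive chunks of ceil(batch_size / num_gpu) samples per device,
        # as in the scatter() methods added in concat_batcher.py; trailing
        # devices may receive a smaller shard or none at all.
        batch_size = len(batch.point)
        chunk = math.ceil(batch_size / num_gpu)
        shards = [
            ToyBatch(batch.point[i * chunk:min((i + 1) * chunk, batch_size)])
            for i in range(num_gpu)
        ]
        return [s for s in shards if len(s.point)]  # filter empty batch


batch = ToyBatch(point=[f'pc{i}' for i in range(5)])
print([len(s.point) for s in ToyBatch.scatter(batch, num_gpu=2)])  # [3, 2]
print([len(s.point) for s in ToyBatch.scatter(batch, num_gpu=4)])  # [2, 2, 1]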