Distributed training with PyTorch: PointPillars on Waymo #353

Merged · 27 commits · Sep 16, 2022
19 changes: 10 additions & 9 deletions ml3d/configs/pointpillars_waymo.yml
@@ -2,7 +2,7 @@ dataset:
name: Waymo
dataset_path: # path/to/your/dataset
cache_dir: ./logs/cache
- steps_per_epoch_train: 5000
+ steps_per_epoch_train: 4000

model:
name: PointPillars
@@ -31,7 +31,7 @@ model:
max_voxels: [32000, 32000]

voxel_encoder:
- in_channels: 5
+ in_channels: 4
feat_channels: [64]
voxel_size: *vsize

@@ -43,7 +43,7 @@ model:
in_channels: 64
out_channels: [64, 128, 256]
layer_nums: [3, 5, 5]
- layer_strides: [2, 2, 2]
+ layer_strides: [1, 2, 2]

neck:
in_channels: [64, 128, 256]
@@ -62,17 +62,18 @@ model:
[-74.88, -74.88, 0, 74.88, 74.88, 0],
]
sizes: [
- [2.08, 4.73, 1.77], # car
- [0.84, 1.81, 1.77], # cyclist
- [0.84, 0.91, 1.74] # pedestrian
+ [2.08, 4.73, 1.77], # VEHICLE
+ [0.84, 1.81, 1.77], # CYCLIST
+ [0.84, 0.91, 1.74] # PEDESTRIAN
]
dir_offset: 0.7854
rotations: [0, 1.57]
iou_thr: [[0.4, 0.55], [0.3, 0.5], [0.3, 0.5]]

augment:
PointShuffle: True
- ObjectRangeFilter: True
+ ObjectRangeFilter:
+ point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4]
ObjectSample:
min_points_dict:
VEHICLE: 5
@@ -88,7 +89,7 @@ pipeline:
name: ObjectDetection
test_compute_metric: true
batch_size: 6
- val_batch_size: 1
+ val_batch_size: 6
test_batch_size: 1
save_ckpt_freq: 5
max_epoch: 200
@@ -102,7 +103,7 @@ pipeline:
weight_decay: 0.01

# evaluation properties
- overlaps: [0.5, 0.5, 0.7]
+ overlaps: [0.5, 0.5, 0.5]
difficulties: [0, 1, 2]
summary:
record_for: []
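Taken together, the config changes move Waymo training to 4-channel voxel features, a stride-1 first backbone stage, Waymo-named anchor classes, an explicit range for `ObjectRangeFilter`, matching train/val batch sizes, and a uniform 0.5 evaluation overlap. A minimal sketch of how such a config is typically consumed in Open3D-ML (the dataset path is a placeholder):

```python
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

# Load the updated config and point it at the converted Waymo dataset.
cfg = _ml3d.utils.Config.load_from_file("ml3d/configs/pointpillars_waymo.yml")
cfg.dataset["dataset_path"] = "/path/to/waymo"  # placeholder

dataset = ml3d.datasets.Waymo(**cfg.dataset)
model = ml3d.models.PointPillars(**cfg.model)
pipeline = ml3d.pipelines.ObjectDetection(model, dataset=dataset, **cfg.pipeline)
pipeline.run_train()
```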
2 changes: 1 addition & 1 deletion ml3d/datasets/augment/augmentation.py
@@ -493,7 +493,7 @@ def ObjectSample(self, data, db_boxes_dict, sample_dict):
sampled_points = np.concatenate(
[box.points_inside_box for box in sampled], axis=0)
points = remove_points_in_boxes(points, sampled)
- points = np.concatenate([sampled_points, points], axis=0)
+ points = np.concatenate([sampled_points[:, :4], points], axis=0)

return {
'point': points,
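This one-line change guards against a feature-count mismatch: points sampled from the ground-truth database can carry more per-point columns than the scene cloud, and with `in_channels: 4` above only the first four columns should survive the merge. A self-contained illustration; the shapes are made up for the example:

```python
import numpy as np

# Hypothetical shapes: database points with 6 columns, scene points with 4.
sampled_points = np.random.rand(100, 6).astype(np.float32)
points = np.random.rand(5000, 4).astype(np.float32)

# Slicing to the first 4 columns makes the concatenation well-defined.
merged = np.concatenate([sampled_points[:, :4], points], axis=0)
assert merged.shape == (5100, 4)
```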
2 changes: 1 addition & 1 deletion ml3d/datasets/utils/operations.py
@@ -4,7 +4,7 @@
import math
from scipy.spatial import ConvexHull

- from ...metrics import iou_bev
+ from open3d.ml.contrib import iou_bev_cpu as iou_bev


def create_3D_rotations(axis, angle):
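The import now takes the compiled BEV IoU kernel from `open3d.ml.contrib` instead of the framework-specific metrics module. A hedged usage sketch; the `[x, y, w, h, yaw]` box layout and the float32/(N, 5) shape requirement are my reading of the contrib API rather than something stated in this diff:

```python
import numpy as np
from open3d.ml.contrib import iou_bev_cpu as iou_bev

# Two sets of BEV boxes, assumed layout [x, y, w, h, yaw], dtype float32.
boxes_a = np.array([[0.0, 0.0, 2.0, 4.0, 0.0]], dtype=np.float32)
boxes_b = np.array([[0.5, 0.0, 2.0, 4.0, 0.0],
                    [10.0, 10.0, 2.0, 4.0, 0.0]], dtype=np.float32)

iou = iou_bev(boxes_a, boxes_b)  # (1, 2) matrix of pairwise IoU values
```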
59 changes: 30 additions & 29 deletions ml3d/datasets/waymo.py
@@ -25,7 +25,6 @@ def __init__(self,
name='Waymo',
cache_dir='./logs/cache',
use_cache=False,
- val_split=3,
**kwargs):
"""Initialize the function by passing the dataset and other details.

@@ -34,7 +33,6 @@
name: The name of the dataset (Waymo in this case).
cache_dir: The directory where the cache is stored.
use_cache: Indicates if the dataset should be cached.
- val_split: The split value to get a set of images for training, validation, for testing.

Returns:
class: The corresponding class.
@@ -43,7 +41,6 @@
name=name,
cache_dir=cache_dir,
use_cache=use_cache,
- val_split=val_split,
**kwargs)

cfg = self.cfg
@@ -52,22 +49,27 @@ def __init__(self,
self.dataset_path = cfg.dataset_path
self.num_classes = 4
self.label_to_names = self.get_label_to_names()
+ self.shuffle = kwargs.get('shuffle', False)

self.all_files = sorted(
glob(join(cfg.dataset_path, 'velodyne', '*.bin')))
self.train_files = []
self.val_files = []
+ self.test_files = []

for f in self.all_files:
- idx = Path(f).name.replace('.bin', '')[:3]
- idx = int(idx)
- if idx < cfg.val_split:
+ if 'train' in f:
self.train_files.append(f)
- else:
+ elif 'val' in f:
self.val_files.append(f)

- self.test_files = glob(
- join(cfg.dataset_path, 'testing', 'velodyne', '*.bin'))
+ elif 'test' in f:
+ self.test_files.append(f)
+ else:
+ log.warning(
+ f"Skipping {f}, prefix must be one of train, test or val.")
+ if self.shuffle:
+ log.info("Shuffling training files...")
+ self.rng.shuffle(self.train_files)

@staticmethod
def get_label_to_names():
@@ -90,18 +92,21 @@ def read_lidar(path):
"""Reads lidar data from the path provided.

Returns:
- A data object with lidar information.
+ pc: pointcloud data with shape [N, 6], where
+ the format is xyzRGB.
"""
- assert Path(path).exists()

return np.fromfile(path, dtype=np.float32).reshape(-1, 6)

@staticmethod
def read_label(path, calib):
"""Reads labels of bound boxes.
"""Reads labels of bounding boxes.

Args:
path: The path to the label file.
calib: Calibration as returned by read_calib().

Returns:
- The data objects with bound boxes information.
+ The data objects with bounding boxes information.
"""
if not Path(path).exists():
return None
@@ -131,24 +136,22 @@ def read_calib(path):
Returns:
The camera and the camera image used in calibration.
"""
- assert Path(path).exists()

with open(path, 'r') as f:
lines = f.readlines()
obj = lines[0].strip().split(' ')[1:]
- P0 = np.array(obj, dtype=np.float32)
+ unused_P0 = np.array(obj, dtype=np.float32)

obj = lines[1].strip().split(' ')[1:]
- P1 = np.array(obj, dtype=np.float32)
+ unused_P1 = np.array(obj, dtype=np.float32)

obj = lines[2].strip().split(' ')[1:]
P2 = np.array(obj, dtype=np.float32)

obj = lines[3].strip().split(' ')[1:]
- P3 = np.array(obj, dtype=np.float32)
+ unused_P3 = np.array(obj, dtype=np.float32)

obj = lines[4].strip().split(' ')[1:]
- P4 = np.array(obj, dtype=np.float32)
+ unused_P4 = np.array(obj, dtype=np.float32)

obj = lines[5].strip().split(' ')[1:]
R0 = np.array(obj, dtype=np.float32).reshape(3, 3)
@@ -162,7 +165,7 @@ def read_calib(path):
Tr_velo_to_cam = Waymo._extend_matrix(Tr_velo_to_cam)

world_cam = np.transpose(rect_4x4 @ Tr_velo_to_cam)
- cam_img = np.transpose(P2)
+ cam_img = np.transpose(np.vstack((P2.reshape(3, 4), [0, 0, 0, 1])))

return {'world_cam': world_cam, 'cam_img': cam_img}

@@ -209,7 +212,7 @@ def get_split_list(self, split):
else:
raise ValueError("Invalid split {}".format(split))

- def is_tested():
+ def is_tested(attr):
"""Checks if a datum in the dataset has been tested.

Args:
@@ -219,16 +222,16 @@ def is_tested():
If the datum attribute is tested, then return the path where the
attribute is stored; else, returns false.
"""
- pass
+ raise NotImplementedError()

- def save_test_result():
+ def save_test_result(results, attr):
"""Saves the output of a model.

Args:
results: The output of a model for the datum associated with the attribute passed.
attr: The attributes that correspond to the outputs passed in results.
"""
- pass
+ raise NotImplementedError()


class WaymoSplit():
@@ -273,11 +276,9 @@ def get_attr(self, idx):


class Object3d(BEVBox3D):
"""The class stores details that are object-specific, such as bounding box
coordinates, occlusion and so on.
"""

def __init__(self, center, size, label, calib):
+ # ground truth files doesn't have confidence value.
confidence = float(label[15]) if label.__len__() == 16 else -1.0

world_cam = calib['world_cam']
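The dataset class now infers the split from each file name instead of a numeric `val_split` threshold, raises `NotImplementedError` for the unimplemented hooks, renames the unused calibration matrices, and pads `P2` to 4x4 before building `cam_img`. The selection logic assumes a converted layout along these lines (paths are placeholders):

```python
from glob import glob
from os.path import join
from pathlib import Path

# Assumed layout after conversion: a single velodyne/ folder whose file
# names carry the split, e.g. train_000123.bin, val_000042.bin, test_000007.bin.
dataset_path = "/path/to/waymo"  # placeholder
all_files = sorted(glob(join(dataset_path, "velodyne", "*.bin")))
train_files = [f for f in all_files if "train" in Path(f).name]
val_files = [f for f in all_files if "val" in Path(f).name]
```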
35 changes: 34 additions & 1 deletion ml3d/torch/dataloaders/concat_batcher.py
@@ -4,6 +4,7 @@
import pickle
import torch
import yaml
+ import math
from os import listdir
from os.path import exists, join, isdir

@@ -434,6 +435,22 @@ def to(self, device):
self.feat = [feat.to(device) for feat in self.feat]
self.label = [label.to(device) for label in self.label]

+ @staticmethod
+ def scatter(batch, num_gpu):
+ batch_size = len(batch.batch_lengths)
+
+ new_batch_size = math.ceil(batch_size / num_gpu)
+ batches = [SparseConvUnetBatch([]) for _ in range(num_gpu)]
+ for i in range(num_gpu):
+ start = new_batch_size * i
+ end = min(new_batch_size * (i + 1), batch_size)
+ batches[i].point = batch.point[start:end]
+ batches[i].feat = batch.feat[start:end]
+ batches[i].label = batch.label[start:end]
+ batches[i].batch_lengths = batch.batch_lengths[start:end]
+
+ return [b for b in batches if len(b.point)] # filter empty batch


class PointTransformerBatch:

@@ -486,7 +503,6 @@ def __init__(self, batches):
self.attr = []

for batch in batches:
- self.attr.append(batch['attr'])
data = batch['data']
self.point.append(torch.tensor(data['point'], dtype=torch.float32))
self.labels.append(
@@ -519,6 +535,23 @@ def to(self, device):
if self.bboxes[i] is not None:
self.bboxes[i] = self.bboxes[i].to(device)

+ @staticmethod
+ def scatter(batch, num_gpu):
+ batch_size = len(batch.point)
+
+ new_batch_size = math.ceil(batch_size / num_gpu)
+ batches = [ObjectDetectBatch([]) for _ in range(num_gpu)]
+ for i in range(num_gpu):
+ start = new_batch_size * i
+ end = min(new_batch_size * (i + 1), batch_size)
+ batches[i].point = batch.point[start:end]
+ batches[i].labels = batch.labels[start:end]
+ batches[i].bboxes = batch.bboxes[start:end]
+ batches[i].bbox_objs = batch.bbox_objs[start:end]
+ batches[i].attr = batch.attr[start:end]
+
+ return [b for b in batches if len(b.point)] # filter empty batch


class ConcatBatcher(object):
"""ConcatBatcher for KPConv."""
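Both new `scatter` helpers split one collated batch into per-GPU sub-batches. That is the piece a `DataParallel`-style wrapper needs, because these batches are plain Python objects that PyTorch cannot scatter on its own. A hypothetical sketch of such a wrapper; the class below is illustrative and not part of the diff shown here:

```python
import torch

class BatchDataParallel(torch.nn.DataParallel):
    """Hypothetical DataParallel variant for custom batch objects."""

    def forward(self, batch):
        if len(self.device_ids) <= 1:
            batch.to(self.src_device_obj)
            return self.module(batch)
        # Split the batch with the scatter() helper defined on its class.
        sub_batches = type(batch).scatter(batch, len(self.device_ids))
        for b, dev_id in zip(sub_batches, self.device_ids):
            b.to(torch.device("cuda:{}".format(dev_id)))
        replicas = self.replicate(self.module,
                                  self.device_ids[:len(sub_batches)])
        outputs = self.parallel_apply(replicas,
                                      [(b,) for b in sub_batches], None)
        return self.gather(outputs, self.output_device)
```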
2 changes: 1 addition & 1 deletion ml3d/torch/models/base_model_objdet.py
@@ -25,7 +25,7 @@ def __init__(self, **kwargs):
self.rng = np.random.default_rng(kwargs.get('seed', None))

@abstractmethod
- def loss(self, results, inputs):
+ def get_loss(self, results, inputs):
"""Computes the loss given the network input and outputs.

Args:
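The abstract hook is renamed from `loss` to `get_loss`, and each detection model below is updated to match, so the pipeline can call one method name across models. A rough sketch of the call site; the loop variables are assumptions, not the pipeline's actual code:

```python
# Hypothetical training step using the renamed hook:
results = model(batch)                      # forward pass
loss_dict = model.get_loss(results, batch)  # e.g. cls / bbox / dir terms
loss = sum(loss_dict.values())              # combine into one scalar
loss.backward()
optimizer.step()
optimizer.zero_grad()
```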
2 changes: 1 addition & 1 deletion ml3d/torch/models/point_pillars.py
@@ -137,7 +137,7 @@ def get_optimizer(self, cfg):
optimizer = torch.optim.AdamW(self.parameters(), **cfg)
return optimizer, None

- def loss(self, results, inputs):
+ def get_loss(self, results, inputs):
scores, bboxes, dirs = results
gt_labels = inputs.labels
gt_bboxes = inputs.bboxes
2 changes: 1 addition & 1 deletion ml3d/torch/models/point_rcnn.py
@@ -183,7 +183,7 @@ def step(self):

return optimizer, scheduler

- def loss(self, results, inputs):
+ def get_loss(self, results, inputs):
if self.mode == "RPN":
return self.rpn.loss(results, inputs)
else: