diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d578018 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +/saved +/logs +/data +/.vscode +.python-version +__pycache__/ +*.out \ No newline at end of file diff --git a/conf/augmentation/volumentations_aug.yaml b/conf/augmentation/volumentations_aug.yaml new file mode 100644 index 0000000..5a79185 --- /dev/null +++ b/conf/augmentation/volumentations_aug.yaml @@ -0,0 +1,53 @@ +# pi = 3.14159265358979 +# pi/2 = 1.57079632679489 +# pi/3 = 1.04719755119659 +# pi/6 = 0.52359877559829 +# pi/12 = 0.26179938779914 +# pi/24 = 0.13089969389957 + +__version__: 0.1.6 +transform: + __class_fullname__: volumentations.core.composition.Compose + additional_targets: {} + p: 1.0 + transforms: + - __class_fullname__: volumentations.augmentations.transforms.Scale3d + always_apply: true + p: 0.5 + scale_limit: + - - -0.1 + - 0.1 + - - -0.1 + - 0.1 + - - -0.1 + - 0.1 + - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d + always_apply: true + axis: + - 0 + - 0 + - 1 + p: 0.5 + rotation_limit: + - -3.141592653589793 + - 3.141592653589793 + - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d + always_apply: true + axis: + - 0 + - 1 + - 0 + p: 0.5 + rotation_limit: + - -0.13089969389957 + - 0.13089969389957 + - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d + always_apply: true + axis: + - 1 + - 0 + - 0 + p: 0.5 + rotation_limit: + - -0.13089969389957 + - 0.13089969389957 diff --git a/conf/callbacks/callbacks_panoptic.yaml b/conf/callbacks/callbacks_panoptic.yaml new file mode 100644 index 0000000..c5011ec --- /dev/null +++ b/conf/callbacks/callbacks_panoptic.yaml @@ -0,0 +1,11 @@ +# @package _group_ +- _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_mean_lstq + save_last: true + save_top_k: 1 + mode: max + dirpath: ${general.save_dir} + filename: "{epoch}-{val_mean_lstq:.3f}" + every_n_epochs: 1 + +- _target_: pytorch_lightning.callbacks.LearningRateMonitor diff --git a/conf/config_panoptic_4d.yaml b/conf/config_panoptic_4d.yaml new file mode 100644 index 0000000..382aed3 --- /dev/null +++ b/conf/config_panoptic_4d.yaml @@ -0,0 +1,35 @@ +general: + mode: "train" + seed: null + ckpt_path: null + project_name: mask4d + workspace: kadiryilmaz + instance_population: 10 + dbscan_eps: null + data_percent: 1.0 + experiment_name: ${now:%Y-%m-%d_%H%M%S} + save_dir: saved/${general.experiment_name} + gpus: 1 + +defaults: + - data: kitti + - data/data_loaders: simple_loader + - data/datasets: semantic_kitti + - data/collation_functions: voxelize_collate + - logging: full + - model: mask4d + - optimizer: adamw + - scheduler: onecyclelr + - trainer: trainer30 + - callbacks: callbacks_panoptic + - matcher: hungarian_matcher + - loss: set_criterion + - metric: lstq + +hydra: + run: + dir: saved/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: saved/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + # dir: ${general.save_dir} + subdir: ${hydra.job.num}_${hydra.job.id} diff --git a/conf/data/collation_functions/voxelize_collate.yaml b/conf/data/collation_functions/voxelize_collate.yaml new file mode 100644 index 0000000..7302146 --- /dev/null +++ b/conf/data/collation_functions/voxelize_collate.yaml @@ -0,0 +1,16 @@ +# @package data + +train_collation: + _target_: datasets.utils.VoxelizeCollate + ignore_label: ${data.ignore_label} + voxel_size: ${data.voxel_size} + +validation_collation: + _target_: datasets.utils.VoxelizeCollate + ignore_label: ${data.ignore_label} + voxel_size: ${data.voxel_size} + +test_collation: + _target_: datasets.utils.VoxelizeCollate + ignore_label: ${data.ignore_label} + voxel_size: ${data.voxel_size} diff --git a/conf/data/data_loaders/simple_loader.yaml b/conf/data/data_loaders/simple_loader.yaml new file mode 100644 index 0000000..39996e1 --- /dev/null +++ b/conf/data/data_loaders/simple_loader.yaml @@ -0,0 +1,22 @@ +# @package data + +train_dataloader: + _target_: torch.utils.data.DataLoader + shuffle: true + pin_memory: ${data.pin_memory} + num_workers: ${data.num_workers} + batch_size: ${data.batch_size} + +validation_dataloader: + _target_: torch.utils.data.DataLoader + shuffle: false + pin_memory: ${data.pin_memory} + num_workers: ${data.num_workers} + batch_size: ${data.test_batch_size} + +test_dataloader: + _target_: torch.utils.data.DataLoader + shuffle: false + pin_memory: ${data.pin_memory} + num_workers: ${data.num_workers} + batch_size: ${data.test_batch_size} diff --git a/conf/data/datasets/semantic_kitti.yaml b/conf/data/datasets/semantic_kitti.yaml new file mode 100644 index 0000000..4a94622 --- /dev/null +++ b/conf/data/datasets/semantic_kitti.yaml @@ -0,0 +1,33 @@ +# @package data +train_dataset: + _target_: datasets.lidar.LidarDataset + data_dir: data/semantic_kitti + mode: ${data.train_mode} + add_distance: ${data.add_distance} + sweep: ${data.sweep} + instance_population: ${data.instance_population} + data_percent: ${data.data_percent} + ignore_label: ${data.ignore_label} + volume_augmentations_path: conf/augmentation/volumentations_aug.yaml + +validation_dataset: + _target_: datasets.lidar.LidarDataset + data_dir: data/semantic_kitti + mode: ${data.validation_mode} + add_distance: ${data.add_distance} + sweep: ${data.sweep} + instance_population: 0 + data_percent: 1.0 + ignore_label: ${data.ignore_label} + volume_augmentations_path: null + +test_dataset: + _target_: datasets.lidar.LidarDataset + data_dir: data/semantic_kitti + mode: ${data.test_mode} + add_distance: ${data.add_distance} + sweep: ${data.sweep} + instance_population: 0 + data_percent: 1.0 + ignore_label: ${data.ignore_label} + volume_augmentations_path: null diff --git a/conf/data/kitti.yaml b/conf/data/kitti.yaml new file mode 100644 index 0000000..d3d79f6 --- /dev/null +++ b/conf/data/kitti.yaml @@ -0,0 +1,32 @@ +# @package _group_ + +# these parameters are inherited by datasets, data_loaders and collators +# but they might be overwritten + +# splits +train_mode: train +validation_mode: validation +test_mode: test + +# dataset +ignore_label: 255 +add_distance: true +in_channels: 2 +num_labels: 19 +instance_population: ${general.instance_population} +data_percent: ${general.data_percent} +sweep: 2 +min_stuff_cls_id: 9 +min_points: 50 +class_names: ['car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', +'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', +'trunk', 'terrain', 'pole', 'traffic-sign'] + +# data loader +pin_memory: true +num_workers: 4 +batch_size: 4 +test_batch_size: 2 + +# collation +voxel_size: 0.05 diff --git a/conf/logging/full.yaml b/conf/logging/full.yaml new file mode 100644 index 0000000..970b2da --- /dev/null +++ b/conf/logging/full.yaml @@ -0,0 +1,7 @@ +# @package _group_ +- _target_: pytorch_lightning.loggers.WandbLogger + project: ${general.project_name} + name: ${general.experiment_name} + save_dir: ${general.save_dir} + entity: "kadiryilmaz193" + id: ${general.experiment_name} diff --git a/conf/loss/set_criterion.yaml b/conf/loss/set_criterion.yaml new file mode 100644 index 0000000..61f4fae --- /dev/null +++ b/conf/loss/set_criterion.yaml @@ -0,0 +1,8 @@ +# @package _group_ +_target_: models.criterion.SetCriterion +num_classes: ${data.num_labels} +eos_coef: 0.1 +losses: + - "labels" + - "masks" + - "bboxs" diff --git a/conf/matcher/hungarian_matcher.yaml b/conf/matcher/hungarian_matcher.yaml new file mode 100644 index 0000000..e8e17a1 --- /dev/null +++ b/conf/matcher/hungarian_matcher.yaml @@ -0,0 +1,6 @@ +# @package _group_ +_target_: models.matcher.HungarianMatcher +cost_class: 2. +cost_mask: 5. +cost_dice: 2. +cost_box: 5. diff --git a/conf/metric/lstq.yaml b/conf/metric/lstq.yaml new file mode 100644 index 0000000..8b8456c --- /dev/null +++ b/conf/metric/lstq.yaml @@ -0,0 +1,5 @@ +# @package _group_ +_target_: models.metrics.Panoptic4DEval +n_classes: ${data.num_labels} +min_stuff_cls_id: ${data.min_stuff_cls_id} +min_points: ${data.min_points} \ No newline at end of file diff --git a/conf/model/mask4d.yaml b/conf/model/mask4d.yaml new file mode 100644 index 0000000..a1836ae --- /dev/null +++ b/conf/model/mask4d.yaml @@ -0,0 +1,22 @@ +# @package _group_ +_target_: models.Mask4D + +# backbone +backbone: + _target_: models.Res16UNet34C + config: + dialations: [ 1, 1, 1, 1 ] + conv1_kernel_size: 5 + bn_momentum: 0.02 + in_channels: ${data.in_channels} + out_channels: ${data.num_labels} + +# transformer parameters +num_queries: 100 +num_heads: 8 +num_decoders: 3 +num_levels: 4 +sample_sizes: [4000, 8000, 16000, 32000] +mask_dim: 128 +dim_feedforward: 1024 +num_labels: ${data.num_labels} diff --git a/conf/optimizer/adamw.yaml b/conf/optimizer/adamw.yaml new file mode 100644 index 0000000..e2e8c3e --- /dev/null +++ b/conf/optimizer/adamw.yaml @@ -0,0 +1,3 @@ +# @package _group_ +_target_: torch.optim.AdamW +lr: 0.0002 \ No newline at end of file diff --git a/conf/scheduler/onecyclelr.yaml b/conf/scheduler/onecyclelr.yaml new file mode 100644 index 0000000..0d992cb --- /dev/null +++ b/conf/scheduler/onecyclelr.yaml @@ -0,0 +1,10 @@ +# @package _group_ +scheduler: + _target_: torch.optim.lr_scheduler.OneCycleLR + max_lr: ${optimizer.lr} + epochs: ${trainer.max_epochs} + # need to set to number because of tensorboard logger + steps_per_epoch: -1 + +pytorch_lightning_params: + interval: step diff --git a/conf/semantic-kitti.yaml b/conf/semantic-kitti.yaml new file mode 100644 index 0000000..6281065 --- /dev/null +++ b/conf/semantic-kitti.yaml @@ -0,0 +1,211 @@ +# This file is covered by the LICENSE file in the root of this project. +labels: + 0 : "unlabeled" + 1 : "outlier" + 10: "car" + 11: "bicycle" + 13: "bus" + 15: "motorcycle" + 16: "on-rails" + 18: "truck" + 20: "other-vehicle" + 30: "person" + 31: "bicyclist" + 32: "motorcyclist" + 40: "road" + 44: "parking" + 48: "sidewalk" + 49: "other-ground" + 50: "building" + 51: "fence" + 52: "other-structure" + 60: "lane-marking" + 70: "vegetation" + 71: "trunk" + 72: "terrain" + 80: "pole" + 81: "traffic-sign" + 99: "other-object" + 252: "moving-car" + 253: "moving-bicyclist" + 254: "moving-person" + 255: "moving-motorcyclist" + 256: "moving-on-rails" + 257: "moving-bus" + 258: "moving-truck" + 259: "moving-other-vehicle" +color_map: # bgr + 0 : [0, 0, 0] + 1 : [0, 0, 255] + 10: [245, 150, 100] + 11: [245, 230, 100] + 13: [250, 80, 100] + 15: [150, 60, 30] + 16: [255, 0, 0] + 18: [180, 30, 80] + 20: [255, 0, 0] + 30: [30, 30, 255] + 31: [200, 40, 255] + 32: [90, 30, 150] + 40: [255, 0, 255] + 44: [255, 150, 255] + 48: [75, 0, 75] + 49: [75, 0, 175] + 50: [0, 200, 255] + 51: [50, 120, 255] + 52: [0, 150, 255] + 60: [170, 255, 150] + 70: [0, 175, 0] + 71: [0, 60, 135] + 72: [80, 240, 150] + 80: [150, 240, 255] + 81: [0, 0, 255] + 99: [255, 255, 50] + 252: [245, 150, 100] + 256: [255, 0, 0] + 253: [200, 40, 255] + 254: [30, 30, 255] + 255: [90, 30, 150] + 257: [250, 80, 100] + 258: [180, 30, 80] + 259: [255, 0, 0] +content: # as a ratio with the total number of points + 0: 0.018889854628292943 + 1: 0.0002937197336781505 + 10: 0.040818519255974316 + 11: 0.00016609538710764618 + 13: 2.7879693665067774e-05 + 15: 0.00039838616015114444 + 16: 0.0 + 18: 0.0020633612104619787 + 20: 0.0016218197275284021 + 30: 0.00017698551338515307 + 31: 1.1065903904919655e-08 + 32: 5.532951952459828e-09 + 40: 0.1987493871255525 + 44: 0.014717169549888214 + 48: 0.14392298360372 + 49: 0.0039048553037472045 + 50: 0.1326861944777486 + 51: 0.0723592229456223 + 52: 0.002395131480328884 + 60: 4.7084144280367186e-05 + 70: 0.26681502148037506 + 71: 0.006035012012626033 + 72: 0.07814222006271769 + 80: 0.002855498193863172 + 81: 0.0006155958086189918 + 99: 0.009923127583046915 + 252: 0.001789309418528068 + 253: 0.00012709999297008662 + 254: 0.00016059776092534436 + 255: 3.745553104802113e-05 + 256: 0.0 + 257: 0.00011351574470342043 + 258: 0.00010157861367183268 + 259: 4.3840131989471124e-05 +# classes that are indistinguishable from single scan or inconsistent in +# ground truth are mapped to their closest equivalent +learning_map: + 0 : 0 # "unlabeled" + 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped + 10: 1 # "car" + 11: 2 # "bicycle" + 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped + 15: 3 # "motorcycle" + 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped + 18: 4 # "truck" + 20: 5 # "other-vehicle" + 30: 6 # "person" + 31: 7 # "bicyclist" + 32: 8 # "motorcyclist" + 40: 9 # "road" + 44: 10 # "parking" + 48: 11 # "sidewalk" + 49: 12 # "other-ground" + 50: 13 # "building" + 51: 14 # "fence" + 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped + 60: 9 # "lane-marking" to "road" ---------------------------------mapped + 70: 15 # "vegetation" + 71: 16 # "trunk" + 72: 17 # "terrain" + 80: 18 # "pole" + 81: 19 # "traffic-sign" + 99: 0 # "other-object" to "unlabeled" ----------------------------mapped + 252: 1 # "moving-car" to "car" ------------------------------------mapped + 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped + 254: 6 # "moving-person" to "person" ------------------------------mapped + 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped + 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped + 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped + 258: 4 # "moving-truck" to "truck" --------------------------------mapped + 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped +learning_map_inv: # inverse of previous map + 0: 0 # "unlabeled", and others ignored + 1: 10 # "car" + 2: 11 # "bicycle" + 3: 15 # "motorcycle" + 4: 18 # "truck" + 5: 20 # "other-vehicle" + 6: 30 # "person" + 7: 31 # "bicyclist" + 8: 32 # "motorcyclist" + 9: 40 # "road" + 10: 44 # "parking" + 11: 48 # "sidewalk" + 12: 49 # "other-ground" + 13: 50 # "building" + 14: 51 # "fence" + 15: 70 # "vegetation" + 16: 71 # "trunk" + 17: 72 # "terrain" + 18: 80 # "pole" + 19: 81 # "traffic-sign" +learning_ignore: # Ignore classes + 0: True # "unlabeled", and others ignored + 1: False # "car" + 2: False # "bicycle" + 3: False # "motorcycle" + 4: False # "truck" + 5: False # "other-vehicle" + 6: False # "person" + 7: False # "bicyclist" + 8: False # "motorcyclist" + 9: False # "road" + 10: False # "parking" + 11: False # "sidewalk" + 12: False # "other-ground" + 13: False # "building" + 14: False # "fence" + 15: False # "vegetation" + 16: False # "trunk" + 17: False # "terrain" + 18: False # "pole" + 19: False # "traffic-sign" +split: # sequence numbers + train: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 9 + - 10 + valid: + - 8 + test: + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 diff --git a/conf/trainer/trainer30.yaml b/conf/trainer/trainer30.yaml new file mode 100644 index 0000000..80c10be --- /dev/null +++ b/conf/trainer/trainer30.yaml @@ -0,0 +1,4 @@ +# @package _group_ +max_epochs: 30 +check_val_every_n_epoch: 5 +num_sanity_val_steps: 2 diff --git a/datasets/lidar.py b/datasets/lidar.py new file mode 100644 index 0000000..c09ddaa --- /dev/null +++ b/datasets/lidar.py @@ -0,0 +1,201 @@ +import numpy as np +import volumentations as V +import yaml +from torch.utils.data import Dataset +from pathlib import Path +from typing import List, Optional, Union +from random import random, choice, uniform + + +class LidarDataset(Dataset): + def __init__( + self, + data_dir: Optional[str] = "data/processed/semantic_kitti", + mode: Optional[str] = "train", + add_distance: Optional[bool] = False, + data_percent: Optional[float] = 1.0, + ignore_label: Optional[Union[int, List[int]]] = 255, + volume_augmentations_path: Optional[str] = None, + instance_population: Optional[int] = 0, + sweep: Optional[int] = 1, + ): + self.mode = mode + self.data_dir = data_dir + self.ignore_label = ignore_label + self.add_distance = add_distance + self.instance_population = instance_population + self.sweep = sweep + self.config = self._load_yaml("conf/semantic-kitti.yaml") + + # loading database file + database_path = Path(self.data_dir) + if not (database_path / f"{mode}_database.yaml").exists(): + print(f"generate {database_path}/{mode}_database.yaml first") + exit() + self.data = self._load_yaml(database_path / f"{mode}_database.yaml") + + self.label_info = self._select_correct_labels(self.config["learning_ignore"]) + # augmentations + self.volume_augmentations = V.NoOp() + if volume_augmentations_path is not None: + self.volume_augmentations = V.load( + volume_augmentations_path, data_format="yaml" + ) + # reformulating in sweeps + data = [[]] + last_scene = self.data[0]["scene"] + for x in self.data: + if x["scene"] == last_scene: + data[-1].append(x) + else: + last_scene = x["scene"] + data.append([x]) + for i in range(len(data)): + data[i] = list(self.chunks(data[i], sweep)) + self.data = [val for sublist in data for val in sublist] + + if instance_population > 0: + self.instance_data = self._load_yaml(database_path / f"{mode}_instances_database.yaml") + + if data_percent < 1.0: + self.data = self.data[: int(len(self.data) * data_percent)] + + def chunks(self, lst, n): + if "train" in self.mode or n==1: + for i in range(len(lst) - n + 1): + yield lst[i : i + n] + else: + for i in range(0, len(lst) - n + 1, n-1): + yield lst[i : i + n] + if i != len(lst) - n: + yield lst[i + n - 1: ] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx: int): + coordinates_list = [] + features_list = [] + labels_list = [] + acc_num_points = [0] + for time, scan in enumerate(self.data[idx]): + points = np.fromfile(scan["filepath"], dtype=np.float32).reshape(-1, 4) + coordinates = points[:, :3] + # rotate and translate + pose = np.array(scan["pose"]).T + coordinates = coordinates @ pose[:3, :3] + pose[3, :3] + coordinates_list.append(coordinates) + acc_num_points.append(acc_num_points[-1] + len(coordinates)) + features = points[:, 3:4] + time_array = np.ones((features.shape[0], 1)) * time + features = np.hstack((time_array, features)) + features_list.append(features) + if "test" in self.mode: + labels = np.zeros_like(features).astype(np.int64) + labels_list.append(labels) + else: + panoptic_label = np.fromfile(scan["label_filepath"], dtype=np.uint32) + semantic_label = panoptic_label & 0xFFFF + semantic_label = np.vectorize(self.config["learning_map"].__getitem__)(semantic_label) + labels = np.hstack((semantic_label[:, None], panoptic_label[:, None])) + labels_list.append(labels) + + coordinates = np.vstack(coordinates_list) + features = np.vstack(features_list) + labels = np.vstack(labels_list) + + if "train" in self.mode and self.instance_population > 0: + max_instance_id = np.amax(labels[:, 1]) + pc_center = coordinates.mean(axis=0) + instance_c, instance_f, instance_l = self.populate_instances( + max_instance_id, pc_center, self.instance_population) + coordinates = np.vstack((coordinates, instance_c)) + features = np.vstack((features, instance_f)) + labels = np.vstack((labels, instance_l)) + + if self.add_distance: + center_coordinate = coordinates.mean(0) + features = np.hstack( + ( + features, + np.linalg.norm(coordinates - center_coordinate, axis=1)[ + :, np.newaxis + ], + ) + ) + + # volume and image augmentations for train + if "train" in self.mode: + coordinates -= coordinates.mean(0) + if 0.5 > random(): + coordinates += ( + np.random.uniform(coordinates.min(0), coordinates.max(0)) / 2 + ) + aug = self.volume_augmentations(points=coordinates) + coordinates = aug["points"] + + features = np.hstack((coordinates, features)) + + labels[:, 0] = np.vectorize(self.label_info.__getitem__)(labels[:, 0]) + + return { + "num_points": acc_num_points, + "coordinates": coordinates, + "features": features, + "labels": labels, + "sequence": scan["scene"] + } + + @staticmethod + def _load_yaml(filepath): + with open(filepath) as f: + file = yaml.safe_load(f) + return file + + def _select_correct_labels(self, learning_ignore): + count = 0 + label_info = dict() + for k, v, in learning_ignore.items(): + if v: + label_info[k] = self.ignore_label + else: + label_info[k] = count + count += 1 + return label_info + + def _remap_model_output(self, output): + inv_map = {v: k for k, v in self.label_info.items()} + output = np.vectorize(inv_map.__getitem__)(output) + return output + + def populate_instances(self, max_instance_id, pc_center, instance_population): + coordinates_list = [] + features_list = [] + labels_list = [] + for _ in range(instance_population): + instance_dict = choice(self.instance_data) + idx = np.random.randint(len(instance_dict["filepaths"])) + instance_list = [] + for time in range(self.sweep): + if idx < len(instance_dict["filepaths"]): + filepath = instance_dict["filepaths"][idx] + instance = np.load(filepath) + time_array = np.ones((instance.shape[0], 1)) * time + instance = np.hstack((instance[:, :3], time_array , instance[:, 3:4])) + instance_list.append(instance) + idx = idx + 1 + instances = np.vstack(instance_list) + coordinates = instances[:, :3] - instances[:, :3].mean(0) + coordinates += pc_center + np.array([uniform(-10, 10), uniform(-10, 10), uniform(-1, 1)]) + features = instances[:, 3:] + semantic_label = instance_dict["semantic_label"] + labels = np.zeros_like(features, dtype=np.int64) + labels[:, 0] = semantic_label + max_instance_id = max_instance_id + 1 + labels[:, 1] = max_instance_id + aug = self.volume_augmentations(points=coordinates) + coordinates = aug["points"] + coordinates_list.append(coordinates) + features_list.append(features) + labels_list.append(labels) + return np.vstack(coordinates_list), np.vstack(features_list), np.vstack(labels_list) diff --git a/datasets/preprocessing/semantic_kitti_preprocessing.py b/datasets/preprocessing/semantic_kitti_preprocessing.py new file mode 100644 index 0000000..bd4258a --- /dev/null +++ b/datasets/preprocessing/semantic_kitti_preprocessing.py @@ -0,0 +1,253 @@ +import re +import numpy as np +import yaml +from pathlib import Path +from natsort import natsorted +from loguru import logger +from tqdm import tqdm +from fire import Fire + + +class SemanticKittiPreprocessing: + def __init__( + self, + data_dir: str = "/globalwork/data/SemanticKITTI/dataset", + save_dir: str = "/globalwork/yilmaz/data/processed/semantic_kitti", + modes: tuple = ("train", "validation", "test"), + git_repo: str = "./third_party/semantic-kitti-api", + ): + self.data_dir = Path(data_dir) + self.save_dir = Path(save_dir) + self.modes = modes + git_repo = Path(git_repo) + + if not self.data_dir.exists(): + logger.error("Data folder doesn't exist") + raise FileNotFoundError + if self.save_dir.exists() is False: + self.save_dir.mkdir(parents=True, exist_ok=True) + + self.files = {} + for data_type in self.modes: + self.files.update({data_type: []}) + + self.config = self._load_yaml(git_repo / "config" / "semantic-kitti.yaml") + self.create_label_database(git_repo / "config" / "semantic-kitti.yaml") + self.pose = dict() + + for mode in self.modes: + scene_mode = "valid" if mode == "validation" else mode + self.pose[mode] = dict() + for scene in sorted(self.config["split"][scene_mode]): + filepaths = list(self.data_dir.glob(f"*/{scene:02}/velodyne/*bin")) + filepaths = [str(file) for file in filepaths] + self.files[mode].extend(natsorted(filepaths)) + calibration = parse_calibration( + Path(filepaths[0]).parent.parent / "calib.txt" + ) + self.pose[mode].update( + { + scene: parse_poses( + Path(filepaths[0]).parent.parent / "poses.txt", calibration, + ), + } + ) + + def preprocess(self): + for mode in self.modes: + database = [] + for filepath in tqdm(self.files[mode], unit="file"): + filebase = self.process_file(filepath, mode) + database.append(filebase) + self.save_database(database, mode) + self.joint_database() + + def make_instance_database(self): + train_database = self._load_yaml(self.save_dir / "train_database.yaml") + instance_database = {} + for sample in tqdm(train_database): + instances = self.extract_instance_from_file(sample) + for instance in instances: + scene = instance["scene"] + panoptic_label = instance["panoptic_label"] + unique_identifier = f"{scene}_{panoptic_label}" + if unique_identifier in instance_database: + instance_database[unique_identifier]["filepaths"].append(instance["instance_filepath"]) + else: + instance_database[unique_identifier] = { + "semantic_label": instance["semantic_label"], + "filepaths": [instance["instance_filepath"]] + } + self.save_database(list(instance_database.values()), "train_instances") + + validation_database = self._load_yaml(self.save_dir / "validation_database.yaml") + for sample in tqdm(validation_database): + instances = self.extract_instance_from_file(sample) + for instance in instances: + scene = instance["scene"] + panoptic_label = instance["panoptic_label"] + unique_identifier = f"{scene}_{panoptic_label}" + if unique_identifier in instance_database: + instance_database[unique_identifier]["filepaths"].append(instance["instance_filepath"]) + else: + instance_database[unique_identifier] = { + "semantic_label": instance["semantic_label"], + "filepaths": [instance["instance_filepath"]] + } + self.save_database(list(instance_database.values()), "trainval_instances") + + def extract_instance_from_file(self, sample): + points = np.fromfile(sample["filepath"], dtype=np.float32).reshape(-1, 4) + pose = np.array(sample["pose"]).T + points[:, :3] = points[:, :3] @ pose[:3, :3] + pose[3, :3] + label = np.fromfile(sample["label_filepath"], dtype=np.uint32) + scene, sub_scene = re.search(r"(\d{2}).*(\d{6})", sample["filepath"]).group(1, 2) + file_instances = [] + for panoptic_label in np.unique(label): + semantic_label = panoptic_label & 0xFFFF + semantic_label = np.vectorize(self.config["learning_map"].__getitem__)(semantic_label) + if np.isin(semantic_label, range(1,9)): + instance_mask = label == panoptic_label + instance_points = points[instance_mask, :] + filename = f"{scene}_{panoptic_label:010d}_{sub_scene}.npy" + instance_filepath = self.save_dir / "instances" / filename + instance = { + "scene": scene, + "sub_scene": sub_scene, + "panoptic_label": f"{panoptic_label:010d}", + "instance_filepath": str(instance_filepath), + "semantic_label": semantic_label.item(), + } + if not instance_filepath.parent.exists(): + instance_filepath.parent.mkdir(parents=True, exist_ok=True) + np.save(instance_filepath, instance_points.astype(np.float32)) + file_instances.append(instance) + return file_instances + + def save_database(self, database, mode): + for element in database: + self._dict_to_yaml(element) + self._save_yaml(self.save_dir / (mode + "_database.yaml"), database) + + def joint_database(self, train_modes=["train", "validation"]): + joint_db = [] + for mode in train_modes: + joint_db.extend(self._load_yaml(self.save_dir / (mode + "_database.yaml"))) + self._save_yaml(self.save_dir / "trainval_database.yaml", joint_db) + + @classmethod + def _save_yaml(cls, path, file): + with open(path, "w") as f: + yaml.safe_dump(file, f, default_style=None, default_flow_style=False) + + @classmethod + def _dict_to_yaml(cls, dictionary): + if not isinstance(dictionary, dict): + return + for k, v in dictionary.items(): + if isinstance(v, dict): + cls._dict_to_yaml(v) + if isinstance(v, np.ndarray): + dictionary[k] = v.tolist() + if isinstance(v, Path): + dictionary[k] = str(v) + + @classmethod + def _load_yaml(cls, filepath): + with open(filepath) as f: + file = yaml.safe_load(f) + return file + + def create_label_database(self, config_file): + if (self.save_dir / "label_database.yaml").exists(): + return self._load_yaml(self.save_dir / "label_database.yaml") + config = self._load_yaml(config_file) + label_database = {} + for key, old_key in config["learning_map_inv"].items(): + label_database.update( + { + key: { + "name": config["labels"][old_key], + "color": config["color_map"][old_key][::-1], + "validation": not config["learning_ignore"][key], + } + } + ) + + self._save_yaml(self.save_dir / "label_database.yaml", label_database) + return label_database + + def process_file(self, filepath, mode): + scene, sub_scene = re.search(r"(\d{2}).*(\d{6})", filepath).group(1, 2) + sample = { + "filepath": filepath, + "scene": int(scene), + "pose": self.pose[mode][int(scene)][int(sub_scene)].tolist(), + } + + if mode in ["train", "validation"]: + # getting label info + label_filepath = filepath.replace("velodyne", "labels").replace( + "bin", "label" + ) + sample["label_filepath"] = label_filepath + return sample + +def parse_calibration(filename): + """ read calibration file with given filename + Returns + ------- + dict + Calibration matrices as 4x4 numpy arrays. + """ + calib = {} + + with open(filename) as calib_file: + for line in calib_file: + key, content = line.strip().split(":") + values = [float(v) for v in content.strip().split()] + + pose = np.zeros((4, 4)) + pose[0, 0:4] = values[0:4] + pose[1, 0:4] = values[4:8] + pose[2, 0:4] = values[8:12] + pose[3, 3] = 1.0 + + calib[key] = pose + return calib + +def parse_poses(filename, calibration): + """ read poses file with per-scan poses from given filename + Returns + ------- + list + list of poses as 4x4 numpy arrays. + """ + + poses = [] + + Tr = calibration["Tr"] + Tr_inv = np.linalg.inv(Tr) + + with open(filename) as file: + for line in file: + values = [float(v) for v in line.strip().split()] + + pose = np.zeros((4, 4)) + pose[0, 0:4] = values[0:4] + pose[1, 0:4] = values[4:8] + pose[2, 0:4] = values[8:12] + pose[3, 3] = 1.0 + + poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) + + return poses + +if __name__ == "__main__": + Fire(SemanticKittiPreprocessing) + +# if __name__ == "__main__": + # preprocess_cls = SemanticKittiPreprocessing() + # preprocess_cls.preprocess_sequential() + # preprocess_cls.make_instance_database_sequential() + \ No newline at end of file diff --git a/datasets/utils.py b/datasets/utils.py new file mode 100644 index 0000000..4bc67a1 --- /dev/null +++ b/datasets/utils.py @@ -0,0 +1,110 @@ +import MinkowskiEngine as ME +import numpy as np +import torch + + +class VoxelizeCollate: + def __init__( + self, + ignore_label=255, + voxel_size=1, + ): + self.voxel_size = voxel_size + self.ignore_label = ignore_label + + def __call__(self, batch): + (coordinates, features, labels, original_labels, inverse_maps, num_points, sequences) = ( + [], + [], + [], + [], + [], + [], + [] + ) + + for sample in batch: + original_labels.append(sample["labels"]) + num_points.append(sample["num_points"]) + sequences.append(sample["sequence"]) + sample_c, sample_f, sample_l, inverse_map = voxelize(sample["coordinates"], + sample["features"], + sample["labels"], + self.voxel_size) + inverse_maps.append(inverse_map) + coordinates.append(sample_c) + features.append(sample_f) + labels.append(sample_l) + + # Concatenate all lists + target = generate_target(features, labels, self.ignore_label) + coordinates, features = ME.utils.sparse_collate(coordinates, features) + raw_coordinates = features[:, :4] + features = features[:, 4:] + + return (NoGpu(coordinates, features, raw_coordinates, original_labels, inverse_maps, num_points, sequences), target) + +def voxelize(coordinates, features, labels, voxel_size): + if coordinates.shape[1] == 4: + voxel_size = np.array([voxel_size, voxel_size, voxel_size, 1]) + sample_c, sample_f, unique_map, inverse_map = ME.utils.sparse_quantize(coordinates=coordinates, + features=features, + return_index=True, + return_inverse=True, + quantization_size=voxel_size) + sample_c = sample_c + sample_f = torch.from_numpy(sample_f).float() + sample_l = torch.from_numpy(labels[unique_map]) + return sample_c, sample_f, sample_l, inverse_map + +def generate_target(features, labels, ignore_label): + target = [] + + for feat, lb in zip(features, labels): + raw_coords = feat[:, :3] + raw_coords = (raw_coords - raw_coords.min(0)[0]) / (raw_coords.max(0)[0]-raw_coords.min(0)[0]) + mask_labels = [] + binary_masks = [] + bboxs = [] + + panoptic_labels = lb[:, 1].unique() + for panoptic_label in panoptic_labels: + mask = lb[:, 1] == panoptic_label + + if panoptic_label == 0: + continue + + sem_labels = lb[mask, 0] + if sem_labels[0] != ignore_label: + mask_labels.append(sem_labels[0]) + binary_masks.append(mask) + mask_coords = raw_coords[mask, :] + bboxs.append(torch.hstack(( + mask_coords.mean(0), + mask_coords.max(0)[0]-mask_coords.min(0)[0], + ))) + + if len(mask_labels) != 0: + mask_labels = torch.stack(mask_labels) + binary_masks = torch.stack(binary_masks) + bboxs = torch.stack(bboxs) + target.append({ + 'labels': mask_labels, + 'masks': binary_masks, + 'bboxs': bboxs + }) + + return target + + +class NoGpu: + def __init__(self, coordinates, features, raw_coordinates, original_labels=None, inverse_maps=None, num_points=None, + sequences=None): + """ helper class to prevent gpu loading on lightning """ + self.coordinates = coordinates + self.features = features + self.raw_coordinates = raw_coordinates + self.original_labels = original_labels + self.inverse_maps = inverse_maps + self.num_points = num_points + self.sequences = sequences \ No newline at end of file diff --git a/main_panoptic.py b/main_panoptic.py new file mode 100644 index 0000000..fd63c53 --- /dev/null +++ b/main_panoptic.py @@ -0,0 +1,101 @@ +import logging +import os +import hydra +import torch +from dotenv import load_dotenv +from omegaconf import DictConfig, OmegaConf +from trainer.trainer import PanopticSegmentation +from utils.utils import flatten_dict, RegularCheckpointing +from pytorch_lightning import Trainer, seed_everything + + +def get_parameters(cfg: DictConfig): + logger = logging.getLogger(__name__) + load_dotenv(".env") + + # parsing input parameters + seed_everything(cfg.general.seed) + + # getting basic configuration + if cfg.general.get("gpus", None) is None: + cfg.general.gpus = os.environ.get("CUDA_VISIBLE_DEVICES", None) + loggers = [] + + if not os.path.exists(cfg.general.save_dir): + os.makedirs(cfg.general.save_dir) + else: + print("EXPERIMENT ALREADY EXIST") + cfg.general.ckpt_path = f"{cfg.general.save_dir}/last-epoch.ckpt" + + for log in cfg.logging: + print(log) + loggers.append(hydra.utils.instantiate(log)) + loggers[-1].log_hyperparams( + flatten_dict(OmegaConf.to_container(cfg, resolve=True)) + ) + + model = PanopticSegmentation(cfg) + + logger.info(flatten_dict(OmegaConf.to_container(cfg, resolve=True))) + return cfg, model, loggers + + +@hydra.main(config_path="conf", config_name="config_panoptic_4d.yaml") +def train(cfg: DictConfig): + os.chdir(hydra.utils.get_original_cwd()) + cfg, model, loggers = get_parameters(cfg) + callbacks = [] + for cb in cfg.callbacks: + callbacks.append(hydra.utils.instantiate(cb)) + + callbacks.append(RegularCheckpointing()) + # torch.use_deterministic_algorithms(True) + runner = Trainer( + logger=loggers, + accelerator='gpu', + devices=1, + callbacks=callbacks, + default_root_dir=str(cfg.general.save_dir), + **cfg.trainer, + ) + runner.fit(model, ckpt_path=cfg.general.ckpt_path) + + +@hydra.main(config_path="conf", config_name="config_panoptic_4d.yaml") +def validate(cfg: DictConfig): + # because hydra wants to change dir for some reason + os.chdir(hydra.utils.get_original_cwd()) + cfg, model, loggers = get_parameters(cfg) + runner = Trainer( + logger=loggers, + accelerator='gpu', + devices=1, + default_root_dir=str(cfg.general.save_dir), + ) + runner.validate(model=model, ckpt_path=cfg.general.ckpt_path) + + +@hydra.main(config_path="conf", config_name="config_panoptic_4d.yaml") +def test(cfg: DictConfig): + # because hydra wants to change dir for some reason + os.chdir(hydra.utils.get_original_cwd()) + cfg, model, loggers = get_parameters(cfg) + runner = Trainer( + logger=loggers, + accelerator='gpu', + devices=1, + default_root_dir=str(cfg.general.save_dir), + ) + runner.test(model=model, ckpt_path=cfg.general.ckpt_path) + +@hydra.main(config_path="conf", config_name="config_panoptic_4d.yaml") +def main(cfg: DictConfig): + if cfg['general']['mode'] == "train": + train(cfg) + elif cfg['general']['mode'] == "validate": + validate(cfg) + else: + test(cfg) + +if __name__ == "__main__": + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..369c63f --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,35 @@ +import models.resunet as resunet +import models.res16unet as res16unet +from models.res16unet import Res16UNet34C, STRes16UNet34C +from models.mask4d import Mask4D + +MODELS = [] + + +def add_models(module): + MODELS.extend([getattr(module, a) for a in dir(module) if "Net" in a]) + + +add_models(resunet) +add_models(res16unet) + + +def get_models(): + """Returns a tuple of sample models.""" + return MODELS + + +def load_model(name): + """Creates and returns an instance of the model given its class name.""" + # Find the model class from its name + all_models = get_models() + mdict = {model.__name__: model for model in all_models} + if name not in mdict: + print("Invalid model index. Options are:") + # Display a list of valid model names + for model in all_models: + print(f"\t* {model.__name__}") + return None + NetClass = mdict[name] + + return NetClass diff --git a/models/criterion.py b/models/criterion.py new file mode 100644 index 0000000..26c9625 --- /dev/null +++ b/models/criterion.py @@ -0,0 +1,158 @@ +import torch +import torch.nn.functional as F +from torch import nn + +def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float): + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +dice_loss_jit = torch.jit.script( + dice_loss +) # type: torch.jit.ScriptModule + + +def sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float): + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + return loss.mean(1).sum() / num_masks + + +sigmoid_ce_loss_jit = torch.jit.script( + sigmoid_ce_loss +) # type: torch.jit.ScriptModule + + +def box_loss(inputs: torch.Tensor, targets: torch.Tensor, num_bboxs: float): + loss = F.l1_loss(inputs, targets, reduction="none") + return loss.mean(1).sum() / num_bboxs + + +box_loss_jit = torch.jit.script( + box_loss +) # type: torch.jit.ScriptModule + + +class SetCriterion(nn.Module): + """This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): + """Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(num_classes + 1) + empty_weight[-1] = self.eos_coef + + self.register_buffer("empty_weight", empty_weight) + + def loss_labels(self, outputs, targets, indices): + src_logits = outputs["pred_logits"].float() + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, ignore_index=255) + losses = {"loss_ce": loss_ce} + return losses + + def loss_masks(self, outputs, targets, indices): + loss_masks = [] + loss_dices = [] + + for batch_id, (map_id, target_id) in enumerate(indices): + map = outputs["pred_masks"][batch_id][:, map_id].T + target_mask = targets[batch_id]["masks"][target_id].float() + num_masks = target_mask.shape[0] + + loss_masks.append(sigmoid_ce_loss_jit(map, target_mask, num_masks)) + loss_dices.append(dice_loss_jit(map, target_mask, num_masks)) + return { + "loss_mask": torch.sum(torch.stack(loss_masks)), + "loss_dice": torch.sum(torch.stack(loss_dices)) + } + + def loss_bboxs(self, outputs, targets, indices): + loss_box = torch.zeros(1, device=outputs["pred_bboxs"].device) + for batch_id, (map_id, target_id) in enumerate(indices): + pred_bboxs = outputs["pred_bboxs"][batch_id, map_id, :] + target_bboxs = targets[batch_id]["bboxs"][target_id] + target_classes = targets[batch_id]["labels"][target_id] + keep_things = target_classes < 8 + if torch.any(keep_things): + target_bboxs = target_bboxs[keep_things] + pred_bboxs = pred_bboxs[keep_things] + num_bboxs = target_bboxs.shape[0] + loss_box += box_loss_jit(pred_bboxs, target_bboxs, num_bboxs) + return { + "loss_box": loss_box, + } + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices): + loss_map = { + 'labels': self.loss_labels, + 'masks': self.loss_masks, + 'bboxs': self.loss_bboxs + } + return loss_map[loss](outputs, targets, indices) + + def forward(self, outputs, targets): + """This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "aux_outputs" in outputs: + for i, aux_outputs in enumerate(outputs["aux_outputs"]): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + l_dict = self.get_loss(loss, aux_outputs, targets, indices) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses diff --git a/models/mask4d.py b/models/mask4d.py new file mode 100644 index 0000000..32debbc --- /dev/null +++ b/models/mask4d.py @@ -0,0 +1,273 @@ +import torch +import hydra +import torch.nn as nn +import MinkowskiEngine.MinkowskiOps as me +from MinkowskiEngine.MinkowskiPooling import MinkowskiAvgPooling +from models.modules.common import conv +from models.position_embedding import PositionEmbeddingCoordsSine +from third_party.pointnet2.pointnet2_utils import furthest_point_sample +from models.modules.helpers_3detr import GenericMLP +from torch.cuda.amp import autocast +from models.modules.attention import CrossAttentionLayer, SelfAttentionLayer, FFNLayer + + +class Mask4D(nn.Module): + def __init__(self, backbone, num_queries, num_heads, num_decoders, num_levels, sample_sizes, + mask_dim, dim_feedforward, num_labels): + super().__init__() + self.backbone = hydra.utils.instantiate(backbone) + self.num_queries = num_queries + self.num_heads = num_heads + self.num_decoders = num_decoders + self.num_levels = num_levels + self.sample_sizes = sample_sizes + sizes = self.backbone.PLANES[-5:] + + self.point_features_head = conv( + self.backbone.PLANES[7], mask_dim, kernel_size=1, stride=1, bias=True, D=3 + ) + + self.query_projection = GenericMLP( + input_dim=mask_dim, + hidden_dims=[mask_dim], + output_dim=mask_dim, + use_conv=True, + output_use_activation=True, + hidden_use_bias=True, + ) + + self.mask_embed_head = nn.Sequential( + nn.Linear(mask_dim, mask_dim), + nn.ReLU(), + nn.Linear(mask_dim, mask_dim) + ) + + self.bbox_embed_head = nn.Sequential( + nn.Linear(mask_dim, mask_dim), + nn.ReLU(), + nn.Linear(mask_dim, mask_dim), + nn.ReLU(), + nn.Linear(mask_dim, 6), + nn.Sigmoid() + ) + + self.class_embed_head = nn.Linear(mask_dim, num_labels + 1) + self.pos_enc = PositionEmbeddingCoordsSine(d_pos=mask_dim) + self.temporal_pos_enc = PositionEmbeddingCoordsSine(d_in=1, d_pos=mask_dim) + self.pooling = MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=3) + + self.cross_attention = nn.ModuleList() + self.self_attention = nn.ModuleList() + self.ffn_attention = nn.ModuleList() + self.lin_squeeze = nn.ModuleList() + + for hlevel in range(self.num_levels): + self.cross_attention.append( + CrossAttentionLayer( + d_model=mask_dim, + nhead=self.num_heads, + ) + ) + self.lin_squeeze.append(nn.Linear(sizes[hlevel], mask_dim)) + self.self_attention.append( + SelfAttentionLayer( + d_model=mask_dim, + nhead=self.num_heads, + ) + ) + self.ffn_attention.append( + FFNLayer( + d_model=mask_dim, + dim_feedforward=dim_feedforward, + ) + ) + + self.decoder_norm = nn.LayerNorm(mask_dim) + + def forward(self, x, raw_coordinates=None, is_eval=False): + device = x.device + all_features = self.backbone(x) + point_features = self.point_features_head(all_features[-1]) + + with torch.no_grad(): + coordinates = me.SparseTensor(features=raw_coordinates, coordinates=x.C, device=device) + pos_encodings_pcd = self.get_pos_encs(coordinates) + + sampled_coords = [] + mins = [] + maxs = [] + for coords, feats in zip(x.decomposed_coordinates, coordinates.decomposed_features): + fps_idx = furthest_point_sample(coords[None, ...].float(), self.num_queries).squeeze(0).long() + sampled_coords.append(feats[fps_idx, :3]) + mins.append(feats[:, :3].min(dim=0)[0]) + maxs.append(feats[:, :3].max(dim=0)[0]) + + sampled_coords = torch.stack(sampled_coords) + mins = torch.stack(mins) + maxs = torch.stack(maxs) + + query_pos = self.pos_enc(sampled_coords.float(),input_range=[mins, maxs]) # Batch, Dim, queries + query_pos = self.query_projection(query_pos) + + queries = torch.zeros_like(query_pos).permute((0, 2, 1)) + query_pos = query_pos.permute((2, 0, 1)) + + predictions_class = [] + predictions_bbox = [] + predictions_mask = [] + + for _ in range(self.num_decoders): + for hlevel in range(self.num_levels): + output_class, outputs_bbox, outputs_mask, attn_mask = self.mask_module(queries, + point_features, + self.num_levels - hlevel) + + decomposed_feat = all_features[hlevel].decomposed_features + decomposed_attn = attn_mask.decomposed_features + + pcd_sizes = [pcd.shape[0] for pcd in decomposed_feat] + curr_sample_size = max(pcd_sizes) + + if not is_eval: + curr_sample_size = min(curr_sample_size, self.sample_sizes[hlevel]) + + rand_idx, mask_idx = self.get_random_samples(pcd_sizes, curr_sample_size, device) + + batched_feat = torch.stack([ + feat[idx, :] for feat, idx in zip(decomposed_feat, rand_idx) + ]) + + batched_attn = torch.stack([ + attn[idx, :] for attn, idx in zip(decomposed_attn, rand_idx) + ]) + + batched_pos_enc = torch.stack([ + pos_enc[idx, :] for pos_enc, idx in zip(pos_encodings_pcd[hlevel], rand_idx) + ]) + + batched_attn.permute((0, 2, 1))[batched_attn.sum(1) == curr_sample_size] = False + + m = torch.stack(mask_idx) + batched_attn = torch.logical_or(batched_attn, m[..., None]) + + src_pcd = self.lin_squeeze[hlevel](batched_feat.permute((1, 0, 2))) + + output = self.cross_attention[hlevel]( + queries.permute((1, 0, 2)), + src_pcd, + memory_mask=batched_attn.repeat_interleave(self.num_heads, dim=0).permute((0, 2, 1)), + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=batched_pos_enc.permute((1, 0, 2)), + query_pos=query_pos + ) + + output = self.self_attention[hlevel]( + output, tgt_mask=None, + tgt_key_padding_mask=None, + query_pos=query_pos + ) + + # FFN + queries = self.ffn_attention[hlevel](output).permute((1, 0, 2)) + + predictions_class.append(output_class) + predictions_bbox.append(outputs_bbox) + predictions_mask.append(outputs_mask) + + output_class, outputs_bbox, outputs_mask = self.mask_module(queries, point_features) + predictions_class.append(output_class) + predictions_bbox.append(outputs_bbox) + predictions_mask.append(outputs_mask) + + return { + 'pred_logits': predictions_class[-1], + 'pred_bboxs': predictions_bbox[-1], + 'pred_masks': predictions_mask[-1], + 'aux_outputs': self._set_aux_loss( + predictions_class, predictions_bbox, predictions_mask + ) + } + + def mask_module(self, query_feat, point_features, num_pooling_steps=0): + query_feat = self.decoder_norm(query_feat) + mask_embed = self.mask_embed_head(query_feat) + outputs_class = self.class_embed_head(query_feat) + outputs_bbox = self.bbox_embed_head(query_feat) + + output_masks = [] + + for feat, embed in zip(point_features.decomposed_features, mask_embed): + output_masks.append(feat @ embed.T) + + output_masks = torch.cat(output_masks) + outputs_mask = me.SparseTensor(features=output_masks, + coordinate_manager=point_features.coordinate_manager, + coordinate_map_key=point_features.coordinate_map_key) + + if num_pooling_steps != 0: + attn_mask = outputs_mask + for _ in range(num_pooling_steps): + attn_mask = self.pooling(attn_mask.float()) + + attn_mask = me.SparseTensor(features=(attn_mask.F.detach().sigmoid() < 0.5), + coordinate_manager=attn_mask.coordinate_manager, + coordinate_map_key=attn_mask.coordinate_map_key) + + return outputs_class, outputs_bbox, outputs_mask.decomposed_features, attn_mask + + return outputs_class, outputs_bbox, outputs_mask.decomposed_features + + def get_pos_encs(self, coordinates): + pos_encodings_pcd = [] + + for _ in range(self.num_levels + 1): + pos_encodings_pcd.append([]) + + for coords_batch in coordinates.decomposed_features: + scene_min = coords_batch.min(dim=0)[0][None, ...] + scene_max = coords_batch.max(dim=0)[0][None, ...] + + with autocast(enabled=False): + tmp = self.pos_enc(coords_batch[None, :, :3].float(), + input_range=[scene_min[:, :3], scene_max[:, :3]]) + tmp += self.temporal_pos_enc(coords_batch[None, :, 3].float(), + input_range=[scene_min[:, 3:4], scene_max[:, 3:4]]) + + pos_encodings_pcd[-1].append(tmp.squeeze(0).permute((1, 0))) + + coordinates = self.pooling(coordinates) + + pos_encodings_pcd.reverse() + + return pos_encodings_pcd + + def get_random_samples(self, pcd_sizes, curr_sample_size, device): + rand_idx = [] + mask_idx = [] + for pcd_size in pcd_sizes: + if pcd_size <= curr_sample_size: + # we do not need to sample + # take all points and pad the rest with zeroes and mask it + idx = torch.zeros(curr_sample_size, dtype=torch.long, device=device) + midx = torch.ones(curr_sample_size, dtype=torch.bool, device=device) + idx[:pcd_size] = torch.arange(pcd_size, device=device) + midx[:pcd_size] = False # attend to first points + else: + # we have more points in pcd as we like to sample + # take a subset (no padding or masking needed) + idx = torch.randperm(pcd_size, device=device)[:curr_sample_size] + midx = torch.zeros(curr_sample_size, dtype=torch.bool, device=device) + + rand_idx.append(idx) + mask_idx.append(midx) + return rand_idx, mask_idx + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_bbox, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + {"pred_logits": a, "pred_bboxs": b, "pred_masks": c} + for a, b, c in zip(outputs_class[:-1], outputs_bbox[:-1], outputs_seg_masks[:-1]) + ] diff --git a/models/matcher.py b/models/matcher.py new file mode 100644 index 0000000..dc90b33 --- /dev/null +++ b/models/matcher.py @@ -0,0 +1,144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from torch import nn +from torch.cuda.amp import autocast + + +def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +batch_dice_loss_jit = torch.jit.script( + batch_dice_loss +) # type: torch.jit.ScriptModule + + +def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + +batch_sigmoid_ce_loss_jit = torch.jit.script( + batch_sigmoid_ce_loss +) # type: torch.jit.ScriptModule + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, + cost_box: float = 1): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost + cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + self.cost_box = cost_box + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + indices = [] + + # Iterate through batch size + for b in range(bs): + + out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes] + tgt_ids = targets[b]["labels"] + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -out_prob[:, tgt_ids] + + out_mask = outputs['pred_masks'][b].T # [num_queries, H_pred, W_pred] + # gt masks are already padded when preparing target + tgt_mask = targets[b]["masks"].to(out_mask) + + with autocast(enabled=False): + out_mask = out_mask.float() + tgt_mask = tgt_mask.float() + # Compute the focal loss between masks + cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) + + # Compute the dice loss betwen masks + cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) + + # Final cost matrix + C = ( + self.cost_mask * cost_mask + + self.cost_class * cost_class + + self.cost_dice * cost_dice + ) + C = C.reshape(num_queries, -1).cpu() + + indices.append(linear_sum_assignment(C)) + + return [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] diff --git a/models/metrics/__init__.py b/models/metrics/__init__.py new file mode 100644 index 0000000..94f2156 --- /dev/null +++ b/models/metrics/__init__.py @@ -0,0 +1 @@ +from .panoptic_quality import Panoptic4DEval \ No newline at end of file diff --git a/models/metrics/panoptic_quality.py b/models/metrics/panoptic_quality.py new file mode 100644 index 0000000..6e09ca6 --- /dev/null +++ b/models/metrics/panoptic_quality.py @@ -0,0 +1,146 @@ +import numpy as np +import math + + +class Panoptic4DEval: + def __init__(self, n_classes, min_stuff_cls_id, ignore=0, offset=2**32, min_points=50): + self.n_classes = n_classes + 1 + self.ignore = ignore + self.include = np.array([n for n in range(self.n_classes) if n != self.ignore], dtype=np.int64) + self.min_stuff_cls_id = min_stuff_cls_id + self.reset() + self.offset = offset # largest number of instances in a given scan + self.min_points = min_points # smallest number of points to consider instances in gt + self.eps = 1e-15 + + def reset(self): + # iou stuff + self.px_iou_conf_matrix = np.zeros((self.n_classes, self.n_classes), dtype=np.int64) + self.sequences = [] + self.preds = {} + self.gts = {} + self.intersects = {} + + def addBatchSemIoU(self, x_sem, y_sem): + # idxs are labels and predictions + idxs = np.stack([x_sem, y_sem], axis=0) + + # make confusion matrix (cols = gt, rows = pred) + np.add.at(self.px_iou_conf_matrix, tuple(idxs), 1) + + def getSemIoUStats(self): + conf = self.px_iou_conf_matrix.copy().astype(np.double) + conf[:, self.ignore] = 0 + + # get the clean stats + tp = conf.diagonal() + fp = conf.sum(axis=1) - tp + fn = conf.sum(axis=0) - tp + return tp, fp, fn + + def getSemIoU(self): + tp, fp, fn = self.getSemIoUStats() + intersection = tp + union = tp + fp + fn + union = np.maximum(union, self.eps) + iou = intersection[self.include].astype(np.double) / union[self.include].astype(np.double) + iou_mean = iou.mean() + + return iou_mean, iou + + def update_dict_stat(self, stat_dict, unique_ids, unique_cnts): + for uniqueid, counts in zip(unique_ids, unique_cnts): + if uniqueid == 1: continue # 1 -- no instance + if uniqueid in stat_dict: + stat_dict[uniqueid] += counts + else: + stat_dict[uniqueid] = counts + + def addBatchPanoptic4D(self, seq, x_sem_row, x_inst_row, y_sem_row, y_inst_row): + if seq not in self.sequences: + self.sequences.append(seq) + self.preds[seq] = {} + self.gts[seq] = [{} for i in range(self.n_classes)] + self.intersects[seq] = [{} for i in range(self.n_classes)] + + # make sure instances are not zeros (it messes with my approach) + x_inst_row = x_inst_row + 1 + y_inst_row = y_inst_row + 1 + + preds = self.preds[seq] + # generate the areas for each unique instance prediction (i.e., set1) + unique_pred, counts_pred = np.unique(x_inst_row, return_counts=True) + self.update_dict_stat(preds, unique_pred, counts_pred) + + for cl in self.include: + # Per-class accumulated stats + cl_gts = self.gts[seq][cl] + cl_intersects = self.intersects[seq][cl] + + # get a binary class mask (filter acc. to semantic class!) + y_inst_in_cl_mask = y_sem_row == cl + + # get instance points in class (mask-out everything but _this_ class) + y_inst_in_cl = y_inst_row * y_inst_in_cl_mask.astype(np.int64) + + # generate the areas for each unique instance gt_np (i.e., set2) + unique_gt, counts_gt = np.unique(y_inst_in_cl[y_inst_in_cl > 0], return_counts=True) + self.update_dict_stat(cl_gts, unique_gt[counts_gt>self.min_points], counts_gt[counts_gt>self.min_points]) + y_inst_in_cl[np.isin(y_inst_in_cl, unique_gt[counts_gt<=self.min_points])] = 0 + + # generate intersection using offset + offset_combo = x_inst_row[y_inst_in_cl > 0] + self.offset * y_inst_in_cl[y_inst_in_cl > 0] + unique_combo, counts_combo = np.unique(offset_combo, return_counts=True) + + self.update_dict_stat(cl_intersects, unique_combo, counts_combo) + + def getPQ4D(self): + pan_aq = np.zeros(self.n_classes, dtype=np.double) + pan_aq_ovr = 0.0 + num_tubes = [0] * self.n_classes + + for seq in self.sequences: + preds = self.preds[seq] + for cl in range(self.n_classes): + cl_gts = self.gts[seq][cl] + cl_intersects = self.intersects[seq][cl] + outer_sum_iou = 0.0 + for gt_id, gt_size in cl_gts.items(): + num_tubes[cl] += 1 + inner_sum_iou = 0.0 + for pr_id, pr_size in preds.items(): + TPA_key = pr_id + self.offset * gt_id + if TPA_key in cl_intersects: + TPA_ovr = cl_intersects[TPA_key] + inner_sum_iou += TPA_ovr * (TPA_ovr / (gt_size + pr_size - TPA_ovr)) + outer_sum_iou += float(inner_sum_iou) / float(gt_size) + pan_aq[cl] += outer_sum_iou + pan_aq_ovr += outer_sum_iou + + AQ_overall = np.sum(pan_aq_ovr)/ np.sum(num_tubes[1:self.min_stuff_cls_id]) + AQ = pan_aq / np.maximum(num_tubes, self.eps) + + iou_mean, iou = self.getSemIoU() + + PQ4D = math.sqrt(AQ_overall*iou_mean) + return PQ4D, AQ_overall, AQ[self.include], iou_mean, iou + + def addBatch(self, x_sem, x_inst, y_sem, y_inst, indices, seq): # x=preds, y=targets + x_sem = x_sem[indices] + x_inst = x_inst[indices] + y_sem = y_sem[indices] + y_inst = y_inst[indices] + + # only interested in points that are outside the void area (not in excluded classes) + gt_not_in_excl_mask = y_sem != self.ignore + # remove all other points + x_sem = x_sem[gt_not_in_excl_mask] + y_sem = y_sem[gt_not_in_excl_mask] + x_inst = x_inst[gt_not_in_excl_mask] + y_inst = y_inst[gt_not_in_excl_mask] + + # add to IoU calculation (for checking purposes) + self.addBatchSemIoU(x_sem, y_sem) + + # now do the panoptic stuff + self.addBatchPanoptic4D(seq, x_sem, x_inst, y_sem, y_inst) \ No newline at end of file diff --git a/models/model.py b/models/model.py new file mode 100644 index 0000000..9d2dfa2 --- /dev/null +++ b/models/model.py @@ -0,0 +1,27 @@ +from MinkowskiEngine import MinkowskiNetwork + + +class Model(MinkowskiNetwork): + """ + Base network for all sparse convnet + + By default, all networks are segmentation networks. + """ + + OUT_PIXEL_DIST = -1 + + def __init__(self, in_channels, out_channels, config, D, **kwargs): + super().__init__(D) + self.in_channels = in_channels + self.out_channels = out_channels + self.config = config + + +class HighDimensionalModel(Model): + """ + Base network for all spatio (temporal) chromatic sparse convnet + """ + + def __init__(self, in_channels, out_channels, config, D, **kwargs): + assert D > 4, "Num dimension smaller than 5" + super().__init__(in_channels, out_channels, config, D, **kwargs) diff --git a/models/modules/__init__.py b/models/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/modules/attention.py b/models/modules/attention.py new file mode 100644 index 0000000..b3b7382 --- /dev/null +++ b/models/modules/attention.py @@ -0,0 +1,115 @@ +import torch.nn as nn +from torch.nn import functional as F + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, activation="relu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask = None, + tgt_key_padding_mask= None, + query_pos = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, activation="relu"): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, memory, + memory_mask = None, + memory_key_padding_mask = None, + pos = None, + query_pos = None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu"): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + diff --git a/models/modules/common.py b/models/modules/common.py new file mode 100644 index 0000000..1678209 --- /dev/null +++ b/models/modules/common.py @@ -0,0 +1,258 @@ +import sys + +if sys.version_info[:2] >= (3, 8): + from collections.abc import Sequence +else: + from collections import Sequence + +from enum import Enum + +import torch.nn as nn +import MinkowskiEngine as ME + + +class NormType(Enum): + BATCH_NORM = 0 + INSTANCE_NORM = 1 + INSTANCE_BATCH_NORM = 2 + + +def get_norm(norm_type, n_channels, D, bn_momentum=0.1): + if norm_type == NormType.BATCH_NORM: + return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum) + elif norm_type == NormType.INSTANCE_NORM: + return ME.MinkowskiInstanceNorm(n_channels) + elif norm_type == NormType.INSTANCE_BATCH_NORM: + return nn.Sequential( + ME.MinkowskiInstanceNorm(n_channels), + ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum), + ) + else: + raise ValueError(f"Norm type: {norm_type} not supported") + + +class ConvType(Enum): + """ + Define the kernel region type + """ + + HYPERCUBE = 0, "HYPERCUBE" + SPATIAL_HYPERCUBE = 1, "SPATIAL_HYPERCUBE" + SPATIO_TEMPORAL_HYPERCUBE = 2, "SPATIO_TEMPORAL_HYPERCUBE" + HYPERCROSS = 3, "HYPERCROSS" + SPATIAL_HYPERCROSS = 4, "SPATIAL_HYPERCROSS" + SPATIO_TEMPORAL_HYPERCROSS = 5, "SPATIO_TEMPORAL_HYPERCROSS" + SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, "SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS " + + def __new__(cls, value, name): + member = object.__new__(cls) + member._value_ = value + member.fullname = name + return member + + def __int__(self): + return self.value + + +# Covert the ConvType var to a RegionType var +conv_to_region_type = { + # kernel_size = [k, k, k, 1] + ConvType.HYPERCUBE: ME.RegionType.HYPER_CUBE, + ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, + ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, + ConvType.HYPERCROSS: ME.RegionType.HYPER_CROSS, + ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, + ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, + ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CUBE, # JONAS CHANGE from HYBRID +} + +# int_to_region_type = {m.value: m for m in ME.RegionType} +int_to_region_type = {m: ME.RegionType(m) for m in range(3)} + + +def convert_region_type(region_type): + """ + Convert the integer region_type to the corresponding RegionType enum object. + """ + return int_to_region_type[region_type] + + +def convert_conv_type(conv_type, kernel_size, D): + assert isinstance(conv_type, ConvType), "conv_type must be of ConvType" + region_type = conv_to_region_type[conv_type] + axis_types = None + if conv_type == ConvType.SPATIAL_HYPERCUBE: + # No temporal convolution + if isinstance(kernel_size, Sequence): + kernel_size = kernel_size[:3] + else: + kernel_size = [ + kernel_size, + ] * 3 + if D == 4: + kernel_size.append(1) + elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE: + # conv_type conversion already handled + assert D == 4 + elif conv_type == ConvType.HYPERCUBE: + # conv_type conversion already handled + pass + elif conv_type == ConvType.SPATIAL_HYPERCROSS: + if isinstance(kernel_size, Sequence): + kernel_size = kernel_size[:3] + else: + kernel_size = [ + kernel_size, + ] * 3 + if D == 4: + kernel_size.append(1) + elif conv_type == ConvType.HYPERCROSS: + # conv_type conversion already handled + pass + elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS: + # conv_type conversion already handled + assert D == 4 + elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: + # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim + axis_types = [ + ME.RegionType.HYPER_CUBE, + ] * 3 + if D == 4: + axis_types.append(ME.RegionType.HYPER_CROSS) + return region_type, axis_types, kernel_size + + +def conv( + in_planes, + out_planes, + kernel_size, + stride=1, + dilation=1, + bias=False, + conv_type=ConvType.HYPERCUBE, + D=-1, +): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + stride, + dilation, + region_type=region_type, + axis_types=None, # axis_types JONAS + dimension=D, + ) + + return ME.MinkowskiConvolution( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=bias, + kernel_generator=kernel_generator, + dimension=D, + ) + + +def conv_tr( + in_planes, + out_planes, + kernel_size, + upsample_stride=1, + dilation=1, + bias=False, + conv_type=ConvType.HYPERCUBE, + D=-1, +): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + upsample_stride, + dilation, + region_type=region_type, + axis_types=axis_types, + dimension=D, + ) + + return ME.MinkowskiConvolutionTranspose( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=kernel_size, + stride=upsample_stride, + dilation=dilation, + bias=bias, + kernel_generator=kernel_generator, + dimension=D, + ) + + +def avg_pool( + kernel_size, + stride=1, + dilation=1, + conv_type=ConvType.HYPERCUBE, + in_coords_key=None, + D=-1, +): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + stride, + dilation, + region_type=region_type, + axis_types=axis_types, + dimension=D, + ) + + return ME.MinkowskiAvgPooling( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + kernel_generator=kernel_generator, + dimension=D, + ) + + +def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + stride, + dilation, + region_type=region_type, + axis_types=axis_types, + dimension=D, + ) + + return ME.MinkowskiAvgUnpooling( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + kernel_generator=kernel_generator, + dimension=D, + ) + + +def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + stride, + dilation, + region_type=region_type, + axis_types=axis_types, + dimension=D, + ) + + return ME.MinkowskiSumPooling( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + kernel_generator=kernel_generator, + dimension=D, + ) diff --git a/models/modules/helpers_3detr.py b/models/modules/helpers_3detr.py new file mode 100644 index 0000000..fe73bf7 --- /dev/null +++ b/models/modules/helpers_3detr.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch.nn as nn +from functools import partial +import copy + + +class BatchNormDim1Swap(nn.BatchNorm1d): + """ + Used for nn.Transformer that uses a HW x N x C rep + """ + + def forward(self, x): + """ + x: HW x N x C + permute to N x C x HW + Apply BN on C + permute back + """ + hw, n, c = x.shape + x = x.permute(1, 2, 0) + x = super(BatchNormDim1Swap, self).forward(x) + # x: n x c x hw -> hw x n x c + x = x.permute(2, 0, 1) + return x + + +NORM_DICT = { + "bn": BatchNormDim1Swap, + "bn1d": nn.BatchNorm1d, + "id": nn.Identity, + "ln": nn.LayerNorm, +} + +ACTIVATION_DICT = { + "relu": nn.ReLU, + "gelu": nn.GELU, + "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1), +} + +WEIGHT_INIT_DICT = { + "xavier_uniform": nn.init.xavier_uniform_, +} + + +class GenericMLP(nn.Module): + def __init__( + self, + input_dim, + hidden_dims, + output_dim, + norm_fn_name=None, + activation="relu", + use_conv=False, + dropout=None, + hidden_use_bias=False, + output_use_bias=True, + output_use_activation=False, + output_use_norm=False, + weight_init_name=None, + ): + super().__init__() + activation = ACTIVATION_DICT[activation] + norm = None + if norm_fn_name is not None: + norm = NORM_DICT[norm_fn_name] + if norm_fn_name == "ln" and use_conv: + norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm + + if dropout is not None: + if not isinstance(dropout, list): + dropout = [dropout for _ in range(len(hidden_dims))] + + layers = [] + prev_dim = input_dim + for idx, x in enumerate(hidden_dims): + if use_conv: + layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias) + else: + layer = nn.Linear(prev_dim, x, bias=hidden_use_bias) + layers.append(layer) + if norm: + layers.append(norm(x)) + layers.append(activation()) + if dropout is not None: + layers.append(nn.Dropout(p=dropout[idx])) + prev_dim = x + if use_conv: + layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias) + else: + layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias) + layers.append(layer) + + if output_use_norm: + layers.append(norm(output_dim)) + + if output_use_activation: + layers.append(activation()) + + self.layers = nn.Sequential(*layers) + + if weight_init_name is not None: + self.do_weight_init(weight_init_name) + + def do_weight_init(self, weight_init_name): + func = WEIGHT_INIT_DICT[weight_init_name] + for (_, param) in self.named_parameters(): + if param.dim() > 1: # skips batchnorm/layernorm + func(param) + + def forward(self, x): + output = self.layers(x) + return output + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) \ No newline at end of file diff --git a/models/modules/resnet_block.py b/models/modules/resnet_block.py new file mode 100644 index 0000000..0ffafaa --- /dev/null +++ b/models/modules/resnet_block.py @@ -0,0 +1,149 @@ +import torch.nn as nn +from MinkowskiEngine import MinkowskiReLU + +from models.modules.common import ConvType, NormType, conv, get_norm + + +class BasicBlockBase(nn.Module): + expansion = 1 + NORM_TYPE = NormType.BATCH_NORM + + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + conv_type=ConvType.HYPERCUBE, + bn_momentum=0.1, + D=3, + ): + super().__init__() + + self.conv1 = conv( + inplanes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + conv_type=conv_type, + D=D, + ) + self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) + self.conv2 = conv( + planes, + planes, + kernel_size=3, + stride=1, + dilation=dilation, + bias=False, + conv_type=conv_type, + D=D, + ) + self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) + self.relu = MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicBlock(BasicBlockBase): + NORM_TYPE = NormType.BATCH_NORM + + +class BasicBlockIN(BasicBlockBase): + NORM_TYPE = NormType.INSTANCE_NORM + + +class BasicBlockINBN(BasicBlockBase): + NORM_TYPE = NormType.INSTANCE_BATCH_NORM + + +class BottleneckBase(nn.Module): + expansion = 4 + NORM_TYPE = NormType.BATCH_NORM + + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + conv_type=ConvType.HYPERCUBE, + bn_momentum=0.1, + D=3, + ): + super().__init__() + self.conv1 = conv(inplanes, planes, kernel_size=1, D=D) + self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) + + self.conv2 = conv( + planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + conv_type=conv_type, + D=D, + ) + self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) + + self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D) + self.norm3 = get_norm( + self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum + ) + + self.relu = MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(BottleneckBase): + NORM_TYPE = NormType.BATCH_NORM + + +class BottleneckIN(BottleneckBase): + NORM_TYPE = NormType.INSTANCE_NORM + + +class BottleneckINBN(BottleneckBase): + NORM_TYPE = NormType.INSTANCE_BATCH_NORM diff --git a/models/position_embedding.py b/models/position_embedding.py new file mode 100644 index 0000000..8ad0417 --- /dev/null +++ b/models/position_embedding.py @@ -0,0 +1,65 @@ +import torch +from torch import nn +import numpy as np + + +def shift_scale_points(pred_xyz, input_range): + """ + pred_xyz: B x N x 3 + input_range: [[B x 3], [B x 3]] - min and max XYZ coords + dst_range: [[B x 3], [B x 3]] - min and max XYZ coords + """ + dst_range = [ + torch.zeros_like(input_range[0], device=input_range[0].device), + torch.ones_like(input_range[0], device=input_range[0].device), + ] + + src_diff = input_range[1][:, None, :] - input_range[0][:, None, :] + dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :] + prop_xyz = ( + ((pred_xyz - input_range[0][:, None, :]) * dst_diff) / src_diff + ) + dst_range[0][:, None, :] + return prop_xyz + + +class PositionEmbeddingCoordsSine(nn.Module): + def __init__( + self, + d_in=3, + d_pos=None, + normalize=True, + ): + super().__init__() + self.d_in = d_in + self.d_pos = d_pos + self.normalize = normalize + + # define a gaussian matrix input_ch -> output_ch + B = torch.empty((d_in, d_pos // 2)).normal_() + self.register_buffer("gauss_B", B) + + @torch.no_grad() + def forward(self, xyz, num_channels=None, input_range=None): + # xyz is batch x npoints x 3 + if num_channels is None: + num_channels = self.gauss_B.shape[1] * 2 + + bsize, npoints = xyz.shape[0], xyz.shape[1] + d_out = num_channels // 2 + + # clone coords so that shift/scale operations do not affect original tensor + orig_xyz = xyz + xyz = orig_xyz.clone() + + if self.normalize: + xyz = shift_scale_points(xyz, input_range=input_range) + + xyz *= 2 * np.pi + xyz_proj = torch.mm(xyz.view(-1, self.d_in), self.gauss_B[:, :d_out]).view( + bsize, npoints, d_out + ) + final_embeds = [xyz_proj.sin(), xyz_proj.cos()] + + # return batch x d_pos x npoints embedding + final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) + return final_embeds diff --git a/models/res16unet.py b/models/res16unet.py new file mode 100644 index 0000000..8651702 --- /dev/null +++ b/models/res16unet.py @@ -0,0 +1,425 @@ +import MinkowskiEngine.MinkowskiOps as me +from MinkowskiEngine import MinkowskiReLU + +from models.resnet import ResNetBase, get_norm +from models.modules.common import ConvType, NormType, conv, conv_tr +from models.modules.resnet_block import BasicBlock, Bottleneck + + +class Res16UNetBase(ResNetBase): + BLOCK = None + PLANES = (32, 64, 128, 256, 256, 256, 256, 256) + DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + INIT_DIM = 32 + OUT_PIXEL_DIST = 1 + NORM_TYPE = NormType.BATCH_NORM + NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + # To use the model, must call initialize_coords before forward pass. + # Once data is processed, call clear to reset the model before calling initialize_coords + def __init__(self, in_channels, out_channels, config, D=3, **kwargs): + super().__init__(in_channels, out_channels, config, D) + + def network_initialization(self, in_channels, out_channels, config, D): + # Setup net_metadata + dilations = self.DILATIONS + bn_momentum = config.bn_momentum + + def space_n_time_m(n, m): + return n if D == 3 else [n, n, n, m] + + if D == 4: + self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) + + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv0p1s1 = conv( + in_channels, + self.inplanes, + kernel_size=space_n_time_m(config.conv1_kernel_size, 1), + stride=1, + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + + self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + + self.conv1p1s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block1 = self._make_layer( + self.BLOCK, + self.PLANES[0], + self.LAYERS[0], + dilation=dilations[0], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv2p2s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block2 = self._make_layer( + self.BLOCK, + self.PLANES[1], + self.LAYERS[1], + dilation=dilations[1], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv3p4s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block3 = self._make_layer( + self.BLOCK, + self.PLANES[2], + self.LAYERS[2], + dilation=dilations[2], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv4p8s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block4 = self._make_layer( + self.BLOCK, + self.PLANES[3], + self.LAYERS[3], + dilation=dilations[3], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr4p16s2 = conv_tr( + self.inplanes, + self.PLANES[4], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr4 = get_norm( + self.NORM_TYPE, self.PLANES[4], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion + self.block5 = self._make_layer( + self.BLOCK, + self.PLANES[4], + self.LAYERS[4], + dilation=dilations[4], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr5p8s2 = conv_tr( + self.inplanes, + self.PLANES[5], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr5 = get_norm( + self.NORM_TYPE, self.PLANES[5], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion + self.block6 = self._make_layer( + self.BLOCK, + self.PLANES[5], + self.LAYERS[5], + dilation=dilations[5], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr6p4s2 = conv_tr( + self.inplanes, + self.PLANES[6], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr6 = get_norm( + self.NORM_TYPE, self.PLANES[6], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion + self.block7 = self._make_layer( + self.BLOCK, + self.PLANES[6], + self.LAYERS[6], + dilation=dilations[6], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr7p2s2 = conv_tr( + self.inplanes, + self.PLANES[7], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr7 = get_norm( + self.NORM_TYPE, self.PLANES[7], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[7] + self.INIT_DIM + self.block8 = self._make_layer( + self.BLOCK, + self.PLANES[7], + self.LAYERS[7], + dilation=dilations[7], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + # self.final = conv( + # self.PLANES[7], out_channels, kernel_size=1, stride=1, bias=True, D=D + # ) + self.relu = MinkowskiReLU(inplace=True) + + def forward(self, x): + feature_maps = [] + + out = self.conv0p1s1(x) + out = self.bn0(out) + out_p1 = self.relu(out) + + out = self.conv1p1s2(out_p1) + out = self.bn1(out) + out = self.relu(out) + out_b1p2 = self.block1(out) + + out = self.conv2p2s2(out_b1p2) + out = self.bn2(out) + out = self.relu(out) + out_b2p4 = self.block2(out) + + out = self.conv3p4s2(out_b2p4) + out = self.bn3(out) + out = self.relu(out) + out_b3p8 = self.block3(out) + + # pixel_dist=16 + out = self.conv4p8s2(out_b3p8) + out = self.bn4(out) + out = self.relu(out) + out = self.block4(out) + + feature_maps.append(out) + + # pixel_dist=8 + out = self.convtr4p16s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = me.cat(out, out_b3p8) + out = self.block5(out) + + feature_maps.append(out) + + # pixel_dist=4 + out = self.convtr5p8s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = me.cat(out, out_b2p4) + out = self.block6(out) + + feature_maps.append(out) + + # pixel_dist=2 + out = self.convtr6p4s2(out) + out = self.bntr6(out) + out = self.relu(out) + + out = me.cat(out, out_b1p2) + out = self.block7(out) + + feature_maps.append(out) + + # pixel_dist=1 + out = self.convtr7p2s2(out) + out = self.bntr7(out) + out = self.relu(out) + + out = me.cat(out, out_p1) + out = self.block8(out) + + feature_maps.append(out) + + return feature_maps + + +class Res16UNet14(Res16UNetBase): + BLOCK = BasicBlock + LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) + + +class Res16UNet18(Res16UNetBase): + BLOCK = BasicBlock + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + + +class Res16UNet34(Res16UNetBase): + BLOCK = BasicBlock + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +class Res16UNet50(Res16UNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +class Res16UNet101(Res16UNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) + + +class Res16UNet14A(Res16UNet14): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +class Res16UNet14A2(Res16UNet14A): + LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) + + +class Res16UNet14B(Res16UNet14): + PLANES = (32, 64, 128, 256, 128, 128, 128, 128) + + +class Res16UNet14B2(Res16UNet14B): + LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) + + +class Res16UNet14B3(Res16UNet14B): + LAYERS = (2, 2, 2, 2, 1, 1, 1, 1) + + +class Res16UNet14C(Res16UNet14): + PLANES = (32, 64, 128, 256, 192, 192, 128, 128) + + +class Res16UNet14D(Res16UNet14): + PLANES = (32, 64, 128, 256, 384, 384, 384, 384) + + +class Res16UNet18A(Res16UNet18): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +class Res16UNet18B(Res16UNet18): + PLANES = (32, 64, 128, 256, 128, 128, 128, 128) + + +class Res16UNet18D(Res16UNet18): + PLANES = (32, 64, 128, 256, 384, 384, 384, 384) + + +class Res16UNet34A(Res16UNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 64) + + +class Res16UNet34B(Res16UNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 32) + + +class Res16UNet34C(Res16UNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) + +class Custom30M(Res16UNet34): + PLANES = (32, 64, 128, 256, 128, 64, 64, 32) + +class Res16UNet34D(Res16UNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 128) + + +class STRes16UNetBase(Res16UNetBase): + + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + def __init__(self, in_channels, out_channels, config, D=4, **kwargs): + super().__init__(in_channels, out_channels, config, D, **kwargs) + + +class STRes16UNet14(STRes16UNetBase, Res16UNet14): + pass + + +class STRes16UNet14A(STRes16UNetBase, Res16UNet14A): + pass + + +class STRes16UNet18(STRes16UNetBase, Res16UNet18): + pass + + +class STRes16UNet34(STRes16UNetBase, Res16UNet34): + pass + + +class STRes16UNet34C(STRes16UNetBase, Res16UNet34C): + pass + + +class STRes16UNet50(STRes16UNetBase, Res16UNet50): + pass + + +class STRes16UNet101(STRes16UNetBase, Res16UNet101): + pass + + +class STRes16UNet18A(STRes16UNet18): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +class STResTesseract16UNetBase(STRes16UNetBase): + pass + #CONV_TYPE = ConvType.HYPERCUBE + + +class STResTesseract16UNet18A(STRes16UNet18A, STResTesseract16UNetBase): + pass diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000..5208c1f --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,240 @@ +import torch.nn as nn +import MinkowskiEngine as ME + +from models.model import Model +from models.modules.common import ConvType, NormType, conv, get_norm, sum_pool +from models.modules.resnet_block import BasicBlock, Bottleneck + + +class ResNetBase(Model): + BLOCK = None + LAYERS = () + INIT_DIM = 64 + PLANES = (64, 128, 256, 512) + OUT_PIXEL_DIST = 32 + HAS_LAST_BLOCK = False + CONV_TYPE = ConvType.HYPERCUBE + + def __init__(self, in_channels, out_channels, config, D=3, **kwargs): + assert self.BLOCK is not None + assert self.OUT_PIXEL_DIST > 0 + + super().__init__(in_channels, out_channels, config, D, **kwargs) + + self.network_initialization(in_channels, out_channels, config, D) + self.weight_initialization() + + def network_initialization(self, in_channels, out_channels, config, D): + def space_n_time_m(n, m): + return n if D == 3 else [n, n, n, m] + + if D == 4: + self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) + + dilations = config.dilations + bn_momentum = config.bn_momentum + self.inplanes = self.INIT_DIM + self.conv1 = conv( + in_channels, + self.inplanes, + kernel_size=space_n_time_m(config.conv1_kernel_size, 1), + stride=1, + D=D, + ) + + self.bn1 = get_norm( + NormType.BATCH_NORM, self.inplanes, D=self.D, bn_momentum=bn_momentum + ) + self.relu = ME.MinkowskiReLU(inplace=True) + self.pool = sum_pool( + kernel_size=space_n_time_m(2, 1), stride=space_n_time_m(2, 1), D=D + ) + + self.layer1 = self._make_layer( + self.BLOCK, + self.PLANES[0], + self.LAYERS[0], + stride=space_n_time_m(2, 1), + dilation=space_n_time_m(dilations[0], 1), + ) + self.layer2 = self._make_layer( + self.BLOCK, + self.PLANES[1], + self.LAYERS[1], + stride=space_n_time_m(2, 1), + dilation=space_n_time_m(dilations[1], 1), + ) + self.layer3 = self._make_layer( + self.BLOCK, + self.PLANES[2], + self.LAYERS[2], + stride=space_n_time_m(2, 1), + dilation=space_n_time_m(dilations[2], 1), + ) + self.layer4 = self._make_layer( + self.BLOCK, + self.PLANES[3], + self.LAYERS[3], + stride=space_n_time_m(2, 1), + dilation=space_n_time_m(dilations[3], 1), + ) + + self.final = conv( + self.PLANES[3] * self.BLOCK.expansion, + out_channels, + kernel_size=1, + bias=True, + D=D, + ) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, ME.MinkowskiBatchNorm): + nn.init.constant_(m.bn.weight, 1) + nn.init.constant_(m.bn.bias, 0) + + def _make_layer( + self, + block, + planes, + blocks, + stride=1, + dilation=1, + norm_type=NormType.BATCH_NORM, + bn_momentum=0.1, + ): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + D=self.D, + ), + get_norm( + norm_type, + planes * block.expansion, + D=self.D, + bn_momentum=bn_momentum, + ), + ) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride=stride, + dilation=dilation, + downsample=downsample, + conv_type=self.CONV_TYPE, + D=self.D, + ) + ) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + stride=1, + dilation=dilation, + conv_type=self.CONV_TYPE, + D=self.D, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.pool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.final(x) + return x + + +class ResNet14(ResNetBase): + BLOCK = BasicBlock + LAYERS = (1, 1, 1, 1) + + +class ResNet18(ResNetBase): + BLOCK = BasicBlock + LAYERS = (2, 2, 2, 2) + + +class ResNet34(ResNetBase): + BLOCK = BasicBlock + LAYERS = (3, 4, 6, 3) + + +class ResNet50(ResNetBase): + BLOCK = Bottleneck + LAYERS = (3, 4, 6, 3) + + +class ResNet101(ResNetBase): + BLOCK = Bottleneck + LAYERS = (3, 4, 23, 3) + + +class STResNetBase(ResNetBase): + + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + def __init__(self, in_channels, out_channels, config, D=4, **kwargs): + super().__init__(in_channels, out_channels, config, D, **kwargs) + + +class STResNet14(STResNetBase, ResNet14): + pass + + +class STResNet18(STResNetBase, ResNet18): + pass + + +class STResNet34(STResNetBase, ResNet34): + pass + + +class STResNet50(STResNetBase, ResNet50): + pass + + +class STResNet101(STResNetBase, ResNet101): + pass + + +class STResTesseractNetBase(STResNetBase): + CONV_TYPE = ConvType.HYPERCUBE + + +class STResTesseractNet14(STResTesseractNetBase, STResNet14): + pass + + +class STResTesseractNet18(STResTesseractNetBase, STResNet18): + pass + + +class STResTesseractNet34(STResTesseractNetBase, STResNet34): + pass + + +class STResTesseractNet50(STResTesseractNetBase, STResNet50): + pass + + +class STResTesseractNet101(STResTesseractNetBase, STResNet101): + pass diff --git a/models/resunet.py b/models/resunet.py new file mode 100644 index 0000000..a4b21d9 --- /dev/null +++ b/models/resunet.py @@ -0,0 +1,593 @@ +import torch.nn as nn +import MinkowskiEngine as ME +import MinkowskiEngine.MinkowskiOps as me +from MinkowskiEngine import MinkowskiReLU + +from models.resnet import ResNetBase, get_norm +from models.modules.common import ConvType, NormType, conv, conv_tr +from models.modules.resnet_block import BasicBlock, Bottleneck, BasicBlockINBN + + +class MinkUNetBase(ResNetBase): + BLOCK = None + PLANES = (64, 128, 256, 512, 256, 128, 128) + DILATIONS = (1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2) + INIT_DIM = 64 + OUT_PIXEL_DIST = 1 + NORM_TYPE = NormType.BATCH_NORM + NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + # To use the model, must call initialize_coords before forward pass. + # Once data is processed, call clear to reset the model before calling initialize_coords + def __init__(self, in_channels, out_channels, config, D=3, **kwargs): + super().__init__(in_channels, out_channels, config, D) + + def network_initialization(self, in_channels, out_channels, config, D): + # Setup net_metadata + dilations = self.DILATIONS + bn_momentum = config.bn_momentum + + def space_n_time_m(n, m): + return n if D == 3 else [n, n, n, m] + + if D == 4: + self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) + + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv1p1s1 = conv( + in_channels, + self.inplanes, + kernel_size=space_n_time_m(config.conv1_kernel_size, 1), + stride=1, + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + + self.bn1 = get_norm(self.NORM_TYPE, self.PLANES[0], D, bn_momentum=bn_momentum) + self.block1 = self._make_layer( + self.BLOCK, + self.PLANES[0], + self.LAYERS[0], + dilation=dilations[0], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv2p1s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block2 = self._make_layer( + self.BLOCK, + self.PLANES[1], + self.LAYERS[1], + dilation=dilations[1], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv3p2s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block3 = self._make_layer( + self.BLOCK, + self.PLANES[2], + self.LAYERS[2], + dilation=dilations[2], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv4p4s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block4 = self._make_layer( + self.BLOCK, + self.PLANES[3], + self.LAYERS[3], + dilation=dilations[3], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr4p8s2 = conv_tr( + self.inplanes, + self.PLANES[4], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr4 = get_norm( + self.NORM_TYPE, self.PLANES[4], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion + self.block5 = self._make_layer( + self.BLOCK, + self.PLANES[4], + self.LAYERS[4], + dilation=dilations[4], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr5p4s2 = conv_tr( + self.inplanes, + self.PLANES[5], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr5 = get_norm( + self.NORM_TYPE, self.PLANES[5], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion + self.block6 = self._make_layer( + self.BLOCK, + self.PLANES[5], + self.LAYERS[5], + dilation=dilations[5], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.convtr6p2s2 = conv_tr( + self.inplanes, + self.PLANES[6], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr6 = get_norm( + self.NORM_TYPE, self.PLANES[6], D, bn_momentum=bn_momentum + ) + self.relu = MinkowskiReLU(inplace=True) + + self.final = nn.Sequential( + conv( + self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion, + 512, + kernel_size=1, + stride=1, + dilation=1, + bias=False, + D=D, + ), + ME.MinkowskiBatchNorm(512), + ME.MinkowskiReLU(), + conv( + 512, out_channels, kernel_size=1, stride=1, dilation=1, bias=True, D=D + ), + ) + + def forward(self, x): + out = self.conv1p1s1(x) + out = self.bn1(out) + out = self.relu(out) + + out_b1p1 = self.block1(out) + + out = self.conv2p1s2(out_b1p1) + out = self.bn2(out) + out = self.relu(out) + + out_b2p2 = self.block2(out) + + out = self.conv3p2s2(out_b2p2) + out = self.bn3(out) + out = self.relu(out) + + out_b3p4 = self.block3(out) + + out = self.conv4p4s2(out_b3p4) + out = self.bn4(out) + out = self.relu(out) + + # pixel_dist=8 + out = self.block4(out) + + out = self.convtr4p8s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = me.cat(out, out_b3p4) + out = self.block5(out) + + out = self.convtr5p4s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = me.cat(out, out_b2p2) + out = self.block6(out) + + out = self.convtr6p2s2(out) + out = self.bntr6(out) + out = self.relu(out) + + out = me.cat(out, out_b1p1) + return self.final(out) + + +class ResUNet14(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (1, 1, 1, 1, 1, 1) + + +class ResUNet18(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (2, 2, 2, 2, 2, 2) + + +class ResUNet18INBN(ResUNet18): + NORM_TYPE = NormType.INSTANCE_BATCH_NORM + BLOCK = BasicBlockINBN + + +class ResUNet34(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (3, 4, 6, 3, 2, 2) + + +class ResUNet50(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (3, 4, 6, 3, 2, 2) + + +class ResUNet101(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (3, 4, 23, 3, 2, 2) + + +class ResUNet14D(ResUNet14): + PLANES = (64, 128, 256, 512, 512, 512, 512) + + +class ResUNet18D(ResUNet18): + PLANES = (64, 128, 256, 512, 512, 512, 512) + + +class ResUNet34D(ResUNet34): + PLANES = (64, 128, 256, 512, 512, 512, 512) + + +class ResUNet34E(ResUNet34): + INIT_DIM = 32 + PLANES = (32, 64, 128, 256, 128, 64, 64) + + +class ResUNet34F(ResUNet34): + INIT_DIM = 32 + PLANES = (32, 64, 128, 256, 128, 64, 32) + + +class MinkUNetHyper(MinkUNetBase): + BLOCK = None + PLANES = (64, 128, 256, 512, 256, 128, 128) + DILATIONS = (1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2) + INIT_DIM = 64 + OUT_PIXEL_DIST = 1 + NORM_TYPE = NormType.BATCH_NORM + NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + # To use the model, must call initialize_coords before forward pass. + # Once data is processed, call clear to reset the model before calling initialize_coords + def __init__(self, in_channels, out_channels, config, D=3, **kwargs): + super(MinkUNetBase, self).__init__(in_channels, out_channels, config, D) + + def network_initialization(self, in_channels, out_channels, config, D): + # Setup net_metadata + dilations = self.DILATIONS + bn_momentum = config.bn_momentum + + def space_n_time_m(n, m): + return n if D == 3 else [n, n, n, m] + + if D == 4: + self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) + + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv1p1s1 = conv( + in_channels, + self.inplanes, + kernel_size=space_n_time_m(config.conv1_kernel_size, 1), + stride=1, + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + + self.bn1 = get_norm(self.NORM_TYPE, self.PLANES[0], D, bn_momentum=bn_momentum) + self.block1 = self._make_layer( + self.BLOCK, + self.PLANES[0], + self.LAYERS[0], + dilation=dilations[0], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv2p1s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block2 = self._make_layer( + self.BLOCK, + self.PLANES[1], + self.LAYERS[1], + dilation=dilations[1], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv3p2s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block3 = self._make_layer( + self.BLOCK, + self.PLANES[2], + self.LAYERS[2], + dilation=dilations[2], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + + self.conv4p4s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block4 = self._make_layer( + self.BLOCK, + self.PLANES[3], + self.LAYERS[3], + dilation=dilations[3], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.pool_tr4 = ME.MinkowskiPoolingTranspose( + kernel_size=8, stride=8, dimension=D + ) + _ = self.inplanes + self.convtr4p8s2 = conv_tr( + self.inplanes, + self.PLANES[4], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr4 = get_norm( + self.NORM_TYPE, self.PLANES[4], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion + self.block5 = self._make_layer( + self.BLOCK, + self.PLANES[4], + self.LAYERS[4], + dilation=dilations[4], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.pool_tr5 = ME.MinkowskiPoolingTranspose( + kernel_size=4, stride=4, dimension=D + ) + out_pool5 = self.inplanes + self.convtr5p4s2 = conv_tr( + self.inplanes, + self.PLANES[5], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr5 = get_norm( + self.NORM_TYPE, self.PLANES[5], D, bn_momentum=bn_momentum + ) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion + self.block6 = self._make_layer( + self.BLOCK, + self.PLANES[5], + self.LAYERS[5], + dilation=dilations[5], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum, + ) + self.pool_tr6 = ME.MinkowskiPoolingTranspose( + kernel_size=2, stride=2, dimension=D + ) + out_pool6 = self.inplanes + self.convtr6p2s2 = conv_tr( + self.inplanes, + self.PLANES[6], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D, + ) + self.bntr6 = get_norm( + self.NORM_TYPE, self.PLANES[6], D, bn_momentum=bn_momentum + ) + + self.relu = MinkowskiReLU(inplace=True) + + self.final = nn.Sequential( + conv( + out_pool5 + + out_pool6 + + self.PLANES[6] + + self.PLANES[0] * self.BLOCK.expansion, + 512, + kernel_size=1, + bias=False, + D=D, + ), + ME.MinkowskiBatchNorm(512), + ME.MinkowskiReLU(), + conv(512, out_channels, kernel_size=1, bias=True, D=D), + ) + + def forward(self, x): + out = self.conv1p1s1(x) + out = self.bn1(out) + out = self.relu(out) + + out_b1p1 = self.block1(out) + + out = self.conv2p1s2(out_b1p1) + out = self.bn2(out) + out = self.relu(out) + + out_b2p2 = self.block2(out) + + out = self.conv3p2s2(out_b2p2) + out = self.bn3(out) + out = self.relu(out) + + out_b3p4 = self.block3(out) + + out = self.conv4p4s2(out_b3p4) + out = self.bn4(out) + out = self.relu(out) + + # pixel_dist=8 + out = self.block4(out) + + out = self.convtr4p8s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = me.cat(out, out_b3p4) + out = self.block5(out) + out_5 = self.pool_tr5(out) + + out = self.convtr5p4s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = me.cat(out, out_b2p2) + out = self.block6(out) + out_6 = self.pool_tr6(out) + + out = self.convtr6p2s2(out) + out = self.bntr6(out) + out = self.relu(out) + + out = me.cat(out, out_b1p1, out_6, out_5) + return self.final(out) + + +class MinkUNetHyper14INBN(MinkUNetHyper): + NORM_TYPE = NormType.INSTANCE_BATCH_NORM + BLOCK = BasicBlockINBN + + +class STMinkUNetBase(MinkUNetBase): + + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + def __init__(self, in_channels, out_channels, config, D=4, **kwargs): + super().__init__(in_channels, out_channels, config, D, **kwargs) + + +class STResUNet14(STMinkUNetBase, ResUNet14): + pass + + +class STResUNet18(STMinkUNetBase, ResUNet18): + pass + + +class STResUNet34(STMinkUNetBase, ResUNet34): + pass + + +class STResUNet50(STMinkUNetBase, ResUNet50): + pass + + +class STResUNet101(STMinkUNetBase, ResUNet101): + pass + + +class STResTesseractUNetBase(STMinkUNetBase): + CONV_TYPE = ConvType.HYPERCUBE + + +class STResTesseractUNet14(STResTesseractUNetBase, ResUNet14): + pass + + +class STResTesseractUNet18(STResTesseractUNetBase, ResUNet18): + pass + + +class STResTesseractUNet34(STResTesseractUNetBase, ResUNet34): + pass + + +class STResTesseractUNet50(STResTesseractUNetBase, ResUNet50): + pass + + +class STResTesseractUNet101(STResTesseractUNetBase, ResUNet101): + pass diff --git a/scripts/preprocess_kitti.sh b/scripts/preprocess_kitti.sh new file mode 100755 index 0000000..97927b9 --- /dev/null +++ b/scripts/preprocess_kitti.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +srun python -m datasets.preprocessing.semantic_kitti_preprocessing preprocess \ +--data_dir="/globalwork/data/SemanticKITTI/dataset" \ +--save_dir="/globalwork/yilmaz/data/processed/semantic_kitti" + +srun python -m datasets.preprocessing.semantic_kitti_preprocessing make_instance_database \ +--data_dir="/globalwork/data/SemanticKITTI/dataset" \ +--save_dir="/globalwork/yilmaz/data/processed/semantic_kitti" \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100755 index 0000000..394f745 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,10 @@ +#!/bin/bash +export OMP_NUM_THREADS=12 # speeds up MinkowskiEngine +export CUDA_LAUNCH_BLOCKING=1 +export HYDRA_FULL_ERROR=1 + +python main_panoptic.py \ +general.project_name="kitti_semantic" \ +general.mode="test" \ +general.instance_population=0 \ +general.ckpt_path='/home/yilmaz/fix/Mask4D/saved/2023-09-30_153832/last-epoch.ckpt' \ \ No newline at end of file diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100755 index 0000000..5d18fac --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,8 @@ +#!/bin/bash +export OMP_NUM_THREADS=12 # speeds up MinkowskiEngine +export CUDA_LAUNCH_BLOCKING=1 +export HYDRA_FULL_ERROR=1 + +# TRAIN +python main_panoptic.py \ +general.project_name="kitti_semantic" \ \ No newline at end of file diff --git a/scripts/val.sh b/scripts/val.sh new file mode 100755 index 0000000..7bbdd0d --- /dev/null +++ b/scripts/val.sh @@ -0,0 +1,10 @@ +#!/bin/bash +export OMP_NUM_THREADS=12 # speeds up MinkowskiEngine +export CUDA_LAUNCH_BLOCKING=1 +export HYDRA_FULL_ERROR=1 + +python main_panoptic.py \ +general.project_name="kitti_semantic" \ +general.mode="validate" \ +general.instance_population=0 \ +general.ckpt_path='/home/yilmaz/fix/Mask4D/saved/2023-09-30_153832/last-epoch.ckpt' \ \ No newline at end of file diff --git a/third_party/pointnet2/_ext_src/include/ball_query.h b/third_party/pointnet2/_ext_src/include/ball_query.h new file mode 100644 index 0000000..b4feff8 --- /dev/null +++ b/third_party/pointnet2/_ext_src/include/ball_query.h @@ -0,0 +1,7 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#pragma once +#include + +at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, + const int nsample); diff --git a/third_party/pointnet2/_ext_src/include/cuda_utils.h b/third_party/pointnet2/_ext_src/include/cuda_utils.h new file mode 100644 index 0000000..f746526 --- /dev/null +++ b/third_party/pointnet2/_ext_src/include/cuda_utils.h @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include +#include + +#include +#include + +#include + +#define TOTAL_THREADS 512 + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = + max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } while (0) + +#endif diff --git a/third_party/pointnet2/_ext_src/include/group_points.h b/third_party/pointnet2/_ext_src/include/group_points.h new file mode 100644 index 0000000..97be802 --- /dev/null +++ b/third_party/pointnet2/_ext_src/include/group_points.h @@ -0,0 +1,8 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#pragma once +#include + +at::Tensor group_points(at::Tensor points, at::Tensor idx); +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); diff --git a/third_party/pointnet2/_ext_src/include/interpolate.h b/third_party/pointnet2/_ext_src/include/interpolate.h new file mode 100644 index 0000000..e7fb792 --- /dev/null +++ b/third_party/pointnet2/_ext_src/include/interpolate.h @@ -0,0 +1,12 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#pragma once + +#include +#include + +std::vector three_nn(at::Tensor unknowns, at::Tensor knows); +at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, + at::Tensor weight); +at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, + at::Tensor weight, const int m); diff --git a/third_party/pointnet2/_ext_src/include/sampling.h b/third_party/pointnet2/_ext_src/include/sampling.h new file mode 100644 index 0000000..7de473e --- /dev/null +++ b/third_party/pointnet2/_ext_src/include/sampling.h @@ -0,0 +1,9 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#pragma once +#include + +at::Tensor gather_points(at::Tensor points, at::Tensor idx); +at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); +at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); diff --git a/third_party/pointnet2/_ext_src/include/utils.h b/third_party/pointnet2/_ext_src/include/utils.h new file mode 100644 index 0000000..815dabb --- /dev/null +++ b/third_party/pointnet2/_ext_src/include/utils.h @@ -0,0 +1,28 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#pragma once +#include +#include + +#define CHECK_CUDA(x) \ + do { \ + AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ + } while (0) + +#define CHECK_CONTIGUOUS(x) \ + do { \ + AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ + } while (0) + +#define CHECK_IS_INT(x) \ + do { \ + AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor"); \ + } while (0) + +#define CHECK_IS_FLOAT(x) \ + do { \ + AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor"); \ + } while (0) diff --git a/third_party/pointnet2/_ext_src/src/ball_query.cpp b/third_party/pointnet2/_ext_src/src/ball_query.cpp new file mode 100644 index 0000000..7dd77d5 --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/ball_query.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include "ball_query.h" +#include "utils.h" + +void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, + int nsample, const float *new_xyz, + const float *xyz, int *idx); + +at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, + const int nsample) { + CHECK_CONTIGUOUS(new_xyz); + CHECK_CONTIGUOUS(xyz); + CHECK_IS_FLOAT(new_xyz); + CHECK_IS_FLOAT(xyz); + + if (new_xyz.is_cuda()) { + CHECK_CUDA(xyz); + } + + at::Tensor idx = + torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, + at::device(new_xyz.device()).dtype(at::ScalarType::Int)); + + if (new_xyz.is_cuda()) { + query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), + radius, nsample, new_xyz.data(), + xyz.data(), idx.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return idx; +} diff --git a/third_party/pointnet2/_ext_src/src/ball_query_gpu.cu b/third_party/pointnet2/_ext_src/src/ball_query_gpu.cu new file mode 100644 index 0000000..cee88cb --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/ball_query_gpu.cu @@ -0,0 +1,57 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include +#include +#include + +#include "cuda_utils.h" + +// input: new_xyz(b, m, 3) xyz(b, n, 3) +// output: idx(b, m, nsample) +__global__ void query_ball_point_kernel(int b, int n, int m, float radius, + int nsample, + const float *__restrict__ new_xyz, + const float *__restrict__ xyz, + int *__restrict__ idx) { + int batch_index = blockIdx.x; + xyz += batch_index * n * 3; + new_xyz += batch_index * m * 3; + idx += m * nsample * batch_index; + + int index = threadIdx.x; + int stride = blockDim.x; + + float radius2 = radius * radius; + for (int j = index; j < m; j += stride) { + float new_x = new_xyz[j * 3 + 0]; + float new_y = new_xyz[j * 3 + 1]; + float new_z = new_xyz[j * 3 + 2]; + for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + + (new_z - z) * (new_z - z); + if (d2 < radius2) { + if (cnt == 0) { + for (int l = 0; l < nsample; ++l) { + idx[j * nsample + l] = k; + } + } + idx[j * nsample + cnt] = k; + ++cnt; + } + } + } +} + +void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, + int nsample, const float *new_xyz, + const float *xyz, int *idx) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + query_ball_point_kernel<<>>( + b, n, m, radius, nsample, new_xyz, xyz, idx); + + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pointnet2/_ext_src/src/bindings.cpp b/third_party/pointnet2/_ext_src/src/bindings.cpp new file mode 100644 index 0000000..58d6c2d --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/bindings.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include "ball_query.h" +#include "group_points.h" +#include "interpolate.h" +#include "sampling.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gather_points", &gather_points); + m.def("gather_points_grad", &gather_points_grad); + m.def("furthest_point_sampling", &furthest_point_sampling); + + m.def("three_nn", &three_nn); + m.def("three_interpolate", &three_interpolate); + m.def("three_interpolate_grad", &three_interpolate_grad); + + m.def("ball_query", &ball_query); + + m.def("group_points", &group_points); + m.def("group_points_grad", &group_points_grad); +} diff --git a/third_party/pointnet2/_ext_src/src/group_points.cpp b/third_party/pointnet2/_ext_src/src/group_points.cpp new file mode 100644 index 0000000..22998dd --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/group_points.cpp @@ -0,0 +1,63 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include "group_points.h" +#include "utils.h" + +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, + float *out); + +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + int nsample, const float *grad_out, + const int *idx, float *grad_points); + +at::Tensor group_points(at::Tensor points, at::Tensor idx) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), + idx.size(1), idx.size(2), points.data(), + idx.data(), output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} + +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), n}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + group_points_grad_kernel_wrapper( + grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), + grad_out.data(), idx.data(), output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/third_party/pointnet2/_ext_src/src/group_points_gpu.cu b/third_party/pointnet2/_ext_src/src/group_points_gpu.cu new file mode 100644 index 0000000..e36672e --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/group_points_gpu.cu @@ -0,0 +1,78 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include +#include + +#include "cuda_utils.h" + +// input: points(b, c, n) idx(b, npoints, nsample) +// output: out(b, c, npoints, nsample) +__global__ void group_points_kernel(int b, int c, int n, int npoints, + int nsample, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + int batch_index = blockIdx.x; + points += batch_index * n * c; + idx += batch_index * npoints * nsample; + out += batch_index * npoints * nsample * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; + for (int k = 0; k < nsample; ++k) { + int ii = idx[j * nsample + k]; + out[(l * npoints + j) * nsample + k] = points[l * n + ii]; + } + } +} + +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, + float *out) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + group_points_kernel<<>>( + b, c, n, npoints, nsample, points, idx, out); + + CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) +// output: grad_points(b, c, n) +__global__ void group_points_grad_kernel(int b, int c, int n, int npoints, + int nsample, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + int batch_index = blockIdx.x; + grad_out += batch_index * npoints * nsample * c; + idx += batch_index * npoints * nsample; + grad_points += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; + for (int k = 0; k < nsample; ++k) { + int ii = idx[j * nsample + k]; + atomicAdd(grad_points + l * n + ii, + grad_out[(l * npoints + j) * nsample + k]); + } + } +} + +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + int nsample, const float *grad_out, + const int *idx, float *grad_points) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + group_points_grad_kernel<<>>( + b, c, n, npoints, nsample, grad_out, idx, grad_points); + + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pointnet2/_ext_src/src/interpolate.cpp b/third_party/pointnet2/_ext_src/src/interpolate.cpp new file mode 100644 index 0000000..4b680c5 --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/interpolate.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#include "interpolate.h" +#include "utils.h" + +void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx); +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, + const float *points, const int *idx, + const float *weight, float *out); +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, + const float *grad_out, + const int *idx, const float *weight, + float *grad_points); + +std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { + CHECK_CONTIGUOUS(unknowns); + CHECK_CONTIGUOUS(knows); + CHECK_IS_FLOAT(unknowns); + CHECK_IS_FLOAT(knows); + + if (unknowns.is_cuda()) { + CHECK_CUDA(knows); + } + + at::Tensor idx = + torch::zeros({unknowns.size(0), unknowns.size(1), 3}, + at::device(unknowns.device()).dtype(at::ScalarType::Int)); + at::Tensor dist2 = + torch::zeros({unknowns.size(0), unknowns.size(1), 3}, + at::device(unknowns.device()).dtype(at::ScalarType::Float)); + + if (unknowns.is_cuda()) { + three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), + unknowns.data(), knows.data(), + dist2.data(), idx.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return {dist2, idx}; +} + +at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, + at::Tensor weight) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_CONTIGUOUS(weight); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + CHECK_IS_FLOAT(weight); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + CHECK_CUDA(weight); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + three_interpolate_kernel_wrapper( + points.size(0), points.size(1), points.size(2), idx.size(1), + points.data(), idx.data(), weight.data(), + output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} +at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, + at::Tensor weight, const int m) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_CONTIGUOUS(weight); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + CHECK_IS_FLOAT(weight); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + CHECK_CUDA(weight); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), m}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + three_interpolate_grad_kernel_wrapper( + grad_out.size(0), grad_out.size(1), grad_out.size(2), m, + grad_out.data(), idx.data(), weight.data(), + output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/third_party/pointnet2/_ext_src/src/interpolate_gpu.cu b/third_party/pointnet2/_ext_src/src/interpolate_gpu.cu new file mode 100644 index 0000000..b4c5644 --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/interpolate_gpu.cu @@ -0,0 +1,157 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include +#include +#include + +#include "cuda_utils.h" + +// input: unknown(b, n, 3) known(b, m, 3) +// output: dist2(b, n, 3), idx(b, n, 3) +__global__ void three_nn_kernel(int b, int n, int m, + const float *__restrict__ unknown, + const float *__restrict__ known, + float *__restrict__ dist2, + int *__restrict__ idx) { + int batch_index = blockIdx.x; + unknown += batch_index * n * 3; + known += batch_index * m * 3; + dist2 += batch_index * n * 3; + idx += batch_index * n * 3; + + int index = threadIdx.x; + int stride = blockDim.x; + for (int j = index; j < n; j += stride) { + float ux = unknown[j * 3 + 0]; + float uy = unknown[j * 3 + 1]; + float uz = unknown[j * 3 + 2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; + besti3 = besti2; + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + } else if (d < best2) { + best3 = best2; + besti3 = besti2; + best2 = d; + besti2 = k; + } else if (d < best3) { + best3 = d; + besti3 = k; + } + } + dist2[j * 3 + 0] = best1; + dist2[j * 3 + 1] = best2; + dist2[j * 3 + 2] = best3; + + idx[j * 3 + 0] = besti1; + idx[j * 3 + 1] = besti2; + idx[j * 3 + 2] = besti3; + } +} + +void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_nn_kernel<<>>(b, n, m, unknown, known, + dist2, idx); + + CUDA_CHECK_ERRORS(); +} + +// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) +// output: out(b, c, n) +__global__ void three_interpolate_kernel(int b, int c, int m, int n, + const float *__restrict__ points, + const int *__restrict__ idx, + const float *__restrict__ weight, + float *__restrict__ out) { + int batch_index = blockIdx.x; + points += batch_index * m * c; + + idx += batch_index * n * 3; + weight += batch_index * n * 3; + + out += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weight[j * 3 + 0]; + float w2 = weight[j * 3 + 1]; + float w3 = weight[j * 3 + 2]; + + int i1 = idx[j * 3 + 0]; + int i2 = idx[j * 3 + 1]; + int i3 = idx[j * 3 + 2]; + + out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + + points[l * m + i3] * w3; + } +} + +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, + const float *points, const int *idx, + const float *weight, float *out) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_kernel<<>>( + b, c, m, n, points, idx, weight, out); + + CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) +// output: grad_points(b, c, m) + +__global__ void three_interpolate_grad_kernel( + int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, const float *__restrict__ weight, + float *__restrict__ grad_points) { + int batch_index = blockIdx.x; + grad_out += batch_index * n * c; + idx += batch_index * n * 3; + weight += batch_index * n * 3; + grad_points += batch_index * m * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weight[j * 3 + 0]; + float w2 = weight[j * 3 + 1]; + float w3 = weight[j * 3 + 2]; + + int i1 = idx[j * 3 + 0]; + int i2 = idx[j * 3 + 1]; + int i3 = idx[j * 3 + 2]; + + atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); + atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); + atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); + } +} + +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, + const float *grad_out, + const int *idx, const float *weight, + float *grad_points) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_grad_kernel<<>>( + b, c, n, m, grad_out, idx, weight, grad_points); + + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pointnet2/_ext_src/src/sampling.cpp b/third_party/pointnet2/_ext_src/src/sampling.cpp new file mode 100644 index 0000000..de55822 --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/sampling.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#include "sampling.h" +#include "utils.h" + +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, + const float *points, const int *idx, + float *out); +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points); + +void furthest_point_sampling_kernel_wrapper(int b, int n, int m, + const float *dataset, float *temp, + int *idxs); + +at::Tensor gather_points(at::Tensor points, at::Tensor idx) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), + idx.size(1), points.data(), + idx.data(), output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} + +at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, + const int n) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), n}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, + idx.size(1), grad_out.data(), + idx.data(), output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} +at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(points); + + at::Tensor output = + torch::zeros({points.size(0), nsamples}, + at::device(points.device()).dtype(at::ScalarType::Int)); + + at::Tensor tmp = + torch::full({points.size(0), points.size(1)}, 1e10, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + furthest_point_sampling_kernel_wrapper( + points.size(0), points.size(1), nsamples, points.data(), + tmp.data(), output.data()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/third_party/pointnet2/_ext_src/src/sampling_gpu.cu b/third_party/pointnet2/_ext_src/src/sampling_gpu.cu new file mode 100644 index 0000000..d2b3707 --- /dev/null +++ b/third_party/pointnet2/_ext_src/src/sampling_gpu.cu @@ -0,0 +1,232 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + + +#include +#include + +#include "cuda_utils.h" + +// input: points(b, c, n) idx(b, m) +// output: out(b, c, m) +__global__ void gather_points_kernel(int b, int c, int n, int m, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; + } + } + } +} + +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, + const float *points, const int *idx, + float *out) { + gather_points_kernel<<>>(b, c, n, npoints, + points, idx, out); + + CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, m) idx(b, m) +// output: grad_points(b, c, n) +__global__ void gather_points_grad_kernel(int b, int c, int n, int m, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + atomicAdd(grad_points + (i * c + l) * n + a, + grad_out[(i * c + l) * m + j]); + } + } + } +} + +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points) { + gather_points_grad_kernel<<>>( + b, c, n, npoints, grad_out, idx, grad_points); + + CUDA_CHECK_ERRORS(); +} + +__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, + int idx1, int idx2) { + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? i2 : i1; +} + +// Input dataset: (b, n, 3), tmp: (b, n) +// Ouput idxs (b, m) +template +__global__ void furthest_point_sampling_kernel( + int b, int n, int m, const float *__restrict__ dataset, + float *__restrict__ temp, int *__restrict__ idxs) { + if (m <= 0) return; + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int batch_index = blockIdx.x; + dataset += batch_index * n * 3; + temp += batch_index * n; + idxs += batch_index * m; + + int tid = threadIdx.x; + const int stride = block_size; + + int old = 0; + if (threadIdx.x == 0) idxs[0] = old; + + __syncthreads(); + for (int j = 1; j < m; j++) { + int besti = 0; + float best = -1; + float x1 = dataset[old * 3 + 0]; + float y1 = dataset[old * 3 + 1]; + float z1 = dataset[old * 3 + 2]; + for (int k = tid; k < n; k += stride) { + float x2, y2, z2; + x2 = dataset[k * 3 + 0]; + y2 = dataset[k * 3 + 1]; + z2 = dataset[k * 3 + 2]; + float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + if (mag <= 1e-3) continue; + + float d = + (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + + float d2 = min(d, temp[k]); + temp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) idxs[j] = old; + } +} + +void furthest_point_sampling_kernel_wrapper(int b, int n, int m, + const float *dataset, float *temp, + int *idxs) { + unsigned int n_threads = opt_n_threads(n); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (n_threads) { + case 512: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 256: + furthest_point_sampling_kernel<256> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 128: + furthest_point_sampling_kernel<128> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 64: + furthest_point_sampling_kernel<64> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 32: + furthest_point_sampling_kernel<32> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 16: + furthest_point_sampling_kernel<16> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 8: + furthest_point_sampling_kernel<8> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 4: + furthest_point_sampling_kernel<4> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 2: + furthest_point_sampling_kernel<2> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 1: + furthest_point_sampling_kernel<1> + <<>>(b, n, m, dataset, temp, idxs); + break; + default: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + } + + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pointnet2/pointnet2_modules.py b/third_party/pointnet2/pointnet2_modules.py new file mode 100644 index 0000000..c525504 --- /dev/null +++ b/third_party/pointnet2/pointnet2_modules.py @@ -0,0 +1,514 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +''' Pointnet2 layers. +Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch +Extended with the following: +1. Uniform sampling in each local region (sample_uniformly) +2. Return sampled points indices to support votenet. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(BASE_DIR) + +import pointnet2_utils +import pytorch_utils as pt_utils +from typing import List + + +class _PointnetSAModuleBase(nn.Module): + + def __init__(self): + super().__init__() + self.npoint = None + self.groupers = None + self.mlps = None + + def forward(self, xyz: torch.Tensor, + features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): + r""" + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor of the xyz coordinates of the features + features : torch.Tensor + (B, N, C) tensor of the descriptors of the the features + + Returns + ------- + new_xyz : torch.Tensor + (B, npoint, 3) tensor of the new features' xyz + new_features : torch.Tensor + (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors + """ + + new_features_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + new_xyz = pointnet2_utils.gather_operation( + xyz_flipped, + pointnet2_utils.furthest_point_sample(xyz, self.npoint) + ).transpose(1, 2).contiguous() if self.npoint is not None else None + + for i in range(len(self.groupers)): + new_features = self.groupers[i]( + xyz, new_xyz, features + ) # (B, C, npoint, nsample) + + new_features = self.mlps[i]( + new_features + ) # (B, mlp[-1], npoint, nsample) + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + + new_features_list.append(new_features) + + return new_xyz, torch.cat(new_features_list, dim=1) + + +class PointnetSAModuleMSG(_PointnetSAModuleBase): + r"""Pointnet set abstrction layer with multiscale grouping + + Parameters + ---------- + npoint : int + Number of features + radii : list of float32 + list of radii to group with + nsamples : list of int32 + Number of samples in each ball query + mlps : list of list of int32 + Spec of the pointnet before the global max_pool for each scale + bn : bool + Use batchnorm + """ + + def __init__( + self, + *, + npoint: int, + radii: List[float], + nsamples: List[int], + mlps: List[List[int]], + bn: bool = True, + use_xyz: bool = True, + sample_uniformly: bool = False + ): + super().__init__() + + assert len(radii) == len(nsamples) == len(mlps) + + self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, sample_uniformly=sample_uniformly) + if npoint is not None else pointnet2_utils.GroupAll(use_xyz) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) + + +class PointnetSAModule(PointnetSAModuleMSG): + r"""Pointnet set abstrction layer + + Parameters + ---------- + npoint : int + Number of features + radius : float + Radius of ball + nsample : int + Number of samples in the ball query + mlp : list + Spec of the pointnet before the global max_pool + bn : bool + Use batchnorm + """ + + def __init__( + self, + *, + mlp: List[int], + npoint: int = None, + radius: float = None, + nsample: int = None, + bn: bool = True, + use_xyz: bool = True + ): + super().__init__( + mlps=[mlp], + npoint=npoint, + radii=[radius], + nsamples=[nsample], + bn=bn, + use_xyz=use_xyz + ) + + +class PointnetSAModuleVotes(nn.Module): + ''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG + with extra support for returning point indices for getting their GT votes ''' + + def __init__( + self, + *, + mlp: List[int], + npoint: int = None, + radius: float = None, + nsample: int = None, + bn: bool = True, + use_xyz: bool = True, + pooling: str = 'max', + sigma: float = None, # for RBF pooling + normalize_xyz: bool = False, # noramlize local XYZ with radius + sample_uniformly: bool = False, + ret_unique_cnt: bool = False + ): + super().__init__() + self.npoint = npoint + self.radius = radius + self.nsample = nsample + self.pooling = pooling + self.mlp_module = None + self.use_xyz = use_xyz + self.sigma = sigma + if self.sigma is None: + self.sigma = self.radius/2 + self.normalize_xyz = normalize_xyz + self.ret_unique_cnt = ret_unique_cnt + + if npoint is not None: + self.grouper = pointnet2_utils.QueryAndGroup(radius, nsample, + use_xyz=use_xyz, ret_grouped_xyz=True, normalize_xyz=normalize_xyz, + sample_uniformly=sample_uniformly, ret_unique_cnt=ret_unique_cnt) + else: + self.grouper = pointnet2_utils.GroupAll(use_xyz, ret_grouped_xyz=True) + + mlp_spec = mlp + if use_xyz and len(mlp_spec)>0: + mlp_spec[0] += 3 + self.mlp_module = pt_utils.SharedMLP(mlp_spec, bn=bn) + + + def forward(self, xyz: torch.Tensor, + features: torch.Tensor = None, + inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): + r""" + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor of the xyz coordinates of the features + features : torch.Tensor + (B, C, N) tensor of the descriptors of the the features + inds : torch.Tensor + (B, npoint) tensor that stores index to the xyz points (values in 0-N-1) + + Returns + ------- + new_xyz : torch.Tensor + (B, npoint, 3) tensor of the new features' xyz + new_features : torch.Tensor + (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors + inds: torch.Tensor + (B, npoint) tensor of the inds + """ + + xyz_flipped = xyz.transpose(1, 2).contiguous() + if inds is None: + inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint) + else: + assert(inds.shape[1] == self.npoint) + new_xyz = pointnet2_utils.gather_operation( + xyz_flipped, inds + ).transpose(1, 2).contiguous() if self.npoint is not None else None + + if not self.ret_unique_cnt: + grouped_features, grouped_xyz = self.grouper( + xyz, new_xyz, features + ) # (B, C, npoint, nsample) + else: + grouped_features, grouped_xyz, unique_cnt = self.grouper( + xyz, new_xyz, features + ) # (B, C, npoint, nsample), (B,3,npoint,nsample), (B,npoint) + + new_features = self.mlp_module( + grouped_features + ) # (B, mlp[-1], npoint, nsample) + if self.pooling == 'max': + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + elif self.pooling == 'avg': + new_features = F.avg_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + elif self.pooling == 'rbf': + # Use radial basis function kernel for weighted sum of features (normalized by nsample and sigma) + # Ref: https://en.wikipedia.org/wiki/Radial_basis_function_kernel + rbf = torch.exp(-1 * grouped_xyz.pow(2).sum(1,keepdim=False) / (self.sigma**2) / 2) # (B, npoint, nsample) + new_features = torch.sum(new_features * rbf.unsqueeze(1), -1, keepdim=True) / float(self.nsample) # (B, mlp[-1], npoint, 1) + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + + if not self.ret_unique_cnt: + return new_xyz, new_features, inds + else: + return new_xyz, new_features, inds, unique_cnt + +class PointnetSAModuleMSGVotes(nn.Module): + ''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG + with extra support for returning point indices for getting their GT votes ''' + + def __init__( + self, + *, + mlps: List[List[int]], + npoint: int, + radii: List[float], + nsamples: List[int], + bn: bool = True, + use_xyz: bool = True, + sample_uniformly: bool = False + ): + super().__init__() + + assert(len(mlps) == len(nsamples) == len(radii)) + + self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, sample_uniformly=sample_uniformly) + if npoint is not None else pointnet2_utils.GroupAll(use_xyz) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) + + def forward(self, xyz: torch.Tensor, + features: torch.Tensor = None, inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): + r""" + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor of the xyz coordinates of the features + features : torch.Tensor + (B, C, C) tensor of the descriptors of the the features + inds : torch.Tensor + (B, npoint) tensor that stores index to the xyz points (values in 0-N-1) + + Returns + ------- + new_xyz : torch.Tensor + (B, npoint, 3) tensor of the new features' xyz + new_features : torch.Tensor + (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors + inds: torch.Tensor + (B, npoint) tensor of the inds + """ + new_features_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + if inds is None: + inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint) + new_xyz = pointnet2_utils.gather_operation( + xyz_flipped, inds + ).transpose(1, 2).contiguous() if self.npoint is not None else None + + for i in range(len(self.groupers)): + new_features = self.groupers[i]( + xyz, new_xyz, features + ) # (B, C, npoint, nsample) + new_features = self.mlps[i]( + new_features + ) # (B, mlp[-1], npoint, nsample) + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + + new_features_list.append(new_features) + + return new_xyz, torch.cat(new_features_list, dim=1), inds + + +class PointnetFPModule(nn.Module): + r"""Propigates the features of one set to another + + Parameters + ---------- + mlp : list + Pointnet module parameters + bn : bool + Use batchnorm + """ + + def __init__(self, *, mlp: List[int], bn: bool = True): + super().__init__() + self.mlp = pt_utils.SharedMLP(mlp, bn=bn) + + def forward( + self, unknown: torch.Tensor, known: torch.Tensor, + unknow_feats: torch.Tensor, known_feats: torch.Tensor + ) -> torch.Tensor: + r""" + Parameters + ---------- + unknown : torch.Tensor + (B, n, 3) tensor of the xyz positions of the unknown features + known : torch.Tensor + (B, m, 3) tensor of the xyz positions of the known features + unknow_feats : torch.Tensor + (B, C1, n) tensor of the features to be propigated to + known_feats : torch.Tensor + (B, C2, m) tensor of features to be propigated + + Returns + ------- + new_features : torch.Tensor + (B, mlp[-1], n) tensor of the features of the unknown features + """ + + if known is not None: + dist, idx = pointnet2_utils.three_nn(unknown, known) + dist_recip = 1.0 / (dist + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + + interpolated_feats = pointnet2_utils.three_interpolate( + known_feats, idx, weight + ) + else: + interpolated_feats = known_feats.expand( + *known_feats.size()[0:2], unknown.size(1) + ) + + if unknow_feats is not None: + new_features = torch.cat([interpolated_feats, unknow_feats], + dim=1) #(B, C2 + C1, n) + else: + new_features = interpolated_feats + + new_features = new_features.unsqueeze(-1) + new_features = self.mlp(new_features) + + return new_features.squeeze(-1) + +class PointnetLFPModuleMSG(nn.Module): + ''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG + learnable feature propagation layer.''' + + def __init__( + self, + *, + mlps: List[List[int]], + radii: List[float], + nsamples: List[int], + post_mlp: List[int], + bn: bool = True, + use_xyz: bool = True, + sample_uniformly: bool = False + ): + super().__init__() + + assert(len(mlps) == len(nsamples) == len(radii)) + + self.post_mlp = pt_utils.SharedMLP(post_mlp, bn=bn) + + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, + sample_uniformly=sample_uniformly) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) + + def forward(self, xyz2: torch.Tensor, xyz1: torch.Tensor, + features2: torch.Tensor, features1: torch.Tensor) -> torch.Tensor: + r""" Propagate features from xyz1 to xyz2. + Parameters + ---------- + xyz2 : torch.Tensor + (B, N2, 3) tensor of the xyz coordinates of the features + xyz1 : torch.Tensor + (B, N1, 3) tensor of the xyz coordinates of the features + features2 : torch.Tensor + (B, C2, N2) tensor of the descriptors of the the features + features1 : torch.Tensor + (B, C1, N1) tensor of the descriptors of the the features + + Returns + ------- + new_features1 : torch.Tensor + (B, \sum_k(mlps[k][-1]), N1) tensor of the new_features descriptors + """ + new_features_list = [] + + for i in range(len(self.groupers)): + new_features = self.groupers[i]( + xyz1, xyz2, features1 + ) # (B, C1, N2, nsample) + new_features = self.mlps[i]( + new_features + ) # (B, mlp[-1], N2, nsample) + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], N2, 1) + new_features = new_features.squeeze(-1) # (B, mlp[-1], N2) + + if features2 is not None: + new_features = torch.cat([new_features, features2], + dim=1) #(B, mlp[-1] + C2, N2) + + new_features = new_features.unsqueeze(-1) + new_features = self.post_mlp(new_features) + + new_features_list.append(new_features) + + return torch.cat(new_features_list, dim=1).squeeze(-1) + + +if __name__ == "__main__": + from torch.autograd import Variable + torch.manual_seed(1) + torch.cuda.manual_seed_all(1) + xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True) + xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True) + + test_module = PointnetSAModuleMSG( + npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]] + ) + test_module.cuda() + print(test_module(xyz, xyz_feats)) + + for _ in range(1): + _, new_features = test_module(xyz, xyz_feats) + new_features.backward( + torch.cuda.FloatTensor(*new_features.size()).fill_(1) + ) + print(new_features) + print(xyz.grad) diff --git a/third_party/pointnet2/pointnet2_test.py b/third_party/pointnet2/pointnet2_test.py new file mode 100644 index 0000000..0ab26ce --- /dev/null +++ b/third_party/pointnet2/pointnet2_test.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +''' Testing customized ops. ''' + +import torch +from torch.autograd import gradcheck +import numpy as np + +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(BASE_DIR) +import pointnet2_utils + +def test_interpolation_grad(): + batch_size = 1 + feat_dim = 2 + m = 4 + feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() + + def interpolate_func(inputs): + idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() + weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() + interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) + return interpolated_feats + + assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) + +if __name__=='__main__': + test_interpolation_grad() diff --git a/third_party/pointnet2/pointnet2_utils.py b/third_party/pointnet2/pointnet2_utils.py new file mode 100644 index 0000000..3227f21 --- /dev/null +++ b/third_party/pointnet2/pointnet2_utils.py @@ -0,0 +1,422 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' +from __future__ import ( + division, + absolute_import, + with_statement, + print_function, + unicode_literals, +) +import torch +from torch.autograd import Function +import torch.nn as nn +import third_party.pointnet2.pytorch_utils as pt_utils +import sys + +try: + import builtins +except: + import __builtin__ as builtins + +try: + import pointnet2._ext as _ext +except ImportError: + if not getattr(builtins, "__POINTNET2_SETUP__", False): + raise ImportError( + "Could not import _ext module.\n" + "Please see the setup instructions in the README: " + "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" + ) + +if False: + # Workaround for type hints without depending on the `typing` module + from typing import * + + +class RandomDropout(nn.Module): + def __init__(self, p=0.5, inplace=False): + super(RandomDropout, self).__init__() + self.p = p + self.inplace = inplace + + def forward(self, X): + theta = torch.Tensor(1).uniform_(0, self.p)[0] + return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) + + +class FurthestPointSampling(Function): + @staticmethod + def forward(ctx, xyz, npoint): + # type: (Any, torch.Tensor, int) -> torch.Tensor + r""" + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance + + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor where N > npoint + npoint : int32 + number of features in the sampled set + + Returns + ------- + torch.Tensor + (B, npoint) tensor containing the set + """ + fps_inds = _ext.furthest_point_sampling(xyz, npoint) + ctx.mark_non_differentiable(fps_inds) + return fps_inds + + @staticmethod + def backward(xyz, a=None): + return None, None + + +furthest_point_sample = FurthestPointSampling.apply + + +class GatherOperation(Function): + @staticmethod + def forward(ctx, features, idx): + # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + features : torch.Tensor + (B, C, N) tensor + + idx : torch.Tensor + (B, npoint) tensor of the features to gather + + Returns + ------- + torch.Tensor + (B, C, npoint) tensor + """ + + _, C, N = features.size() + + ctx.for_backwards = (idx, C, N) + + return _ext.gather_points(features, idx) + + @staticmethod + def backward(ctx, grad_out): + idx, C, N = ctx.for_backwards + + grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) + return grad_features, None + + +gather_operation = GatherOperation.apply + + +class ThreeNN(Function): + @staticmethod + def forward(ctx, unknown, known): + # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + r""" + Find the three nearest neighbors of unknown in known + Parameters + ---------- + unknown : torch.Tensor + (B, n, 3) tensor of known features + known : torch.Tensor + (B, m, 3) tensor of unknown features + + Returns + ------- + dist : torch.Tensor + (B, n, 3) l2 distance to the three nearest neighbors + idx : torch.Tensor + (B, n, 3) index of 3 nearest neighbors + """ + dist2, idx = _ext.three_nn(unknown, known) + + return torch.sqrt(dist2), idx + + @staticmethod + def backward(ctx, a=None, b=None): + return None, None + + +three_nn = ThreeNN.apply + + +class ThreeInterpolate(Function): + @staticmethod + def forward(ctx, features, idx, weight): + # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor + r""" + Performs weight linear interpolation on 3 features + Parameters + ---------- + features : torch.Tensor + (B, c, m) Features descriptors to be interpolated from + idx : torch.Tensor + (B, n, 3) three nearest neighbors of the target features in features + weight : torch.Tensor + (B, n, 3) weights + + Returns + ------- + torch.Tensor + (B, c, n) tensor of the interpolated features + """ + B, c, m = features.size() + n = idx.size(1) + + ctx.three_interpolate_for_backward = (idx, weight, m) + + return _ext.three_interpolate(features, idx, weight) + + @staticmethod + def backward(ctx, grad_out): + # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + r""" + Parameters + ---------- + grad_out : torch.Tensor + (B, c, n) tensor with gradients of ouputs + + Returns + ------- + grad_features : torch.Tensor + (B, c, m) tensor with gradients of features + + None + + None + """ + idx, weight, m = ctx.three_interpolate_for_backward + + grad_features = _ext.three_interpolate_grad( + grad_out.contiguous(), idx, weight, m + ) + + return grad_features, None, None + + +three_interpolate = ThreeInterpolate.apply + + +class GroupingOperation(Function): + @staticmethod + def forward(ctx, features, idx): + # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + features : torch.Tensor + (B, C, N) tensor of features to group + idx : torch.Tensor + (B, npoint, nsample) tensor containing the indicies of features to group with + + Returns + ------- + torch.Tensor + (B, C, npoint, nsample) tensor + """ + B, nfeatures, nsample = idx.size() + _, C, N = features.size() + + ctx.for_backwards = (idx, N) + + return _ext.group_points(features, idx) + + @staticmethod + def backward(ctx, grad_out): + # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] + r""" + + Parameters + ---------- + grad_out : torch.Tensor + (B, C, npoint, nsample) tensor of the gradients of the output from forward + + Returns + ------- + torch.Tensor + (B, C, N) gradient of the features + None + """ + idx, N = ctx.for_backwards + + grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) + + return grad_features, None + + +grouping_operation = GroupingOperation.apply + + +class BallQuery(Function): + @staticmethod + def forward(ctx, radius, nsample, xyz, new_xyz): + # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + radius : float + radius of the balls + nsample : int + maximum number of features in the balls + xyz : torch.Tensor + (B, N, 3) xyz coordinates of the features + new_xyz : torch.Tensor + (B, npoint, 3) centers of the ball query + + Returns + ------- + torch.Tensor + (B, npoint, nsample) tensor with the indicies of the features that form the query balls + """ + inds = _ext.ball_query(new_xyz, xyz, radius, nsample) + ctx.mark_non_differentiable(inds) + return inds + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None + + +ball_query = BallQuery.apply + + +class QueryAndGroup(nn.Module): + r""" + Groups with a ball query of radius + + Parameters + --------- + radius : float32 + Radius of ball + nsample : int32 + Maximum number of features to gather in the ball + """ + + def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): + # type: (QueryAndGroup, float, int, bool) -> None + super(QueryAndGroup, self).__init__() + self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz + self.ret_grouped_xyz = ret_grouped_xyz + self.normalize_xyz = normalize_xyz + self.sample_uniformly = sample_uniformly + self.ret_unique_cnt = ret_unique_cnt + if self.ret_unique_cnt: + assert(self.sample_uniformly) + + def forward(self, xyz, new_xyz, features=None): + # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] + r""" + Parameters + ---------- + xyz : torch.Tensor + xyz coordinates of the features (B, N, 3) + new_xyz : torch.Tensor + centriods (B, npoint, 3) + features : torch.Tensor + Descriptors of the features (B, C, N) + + Returns + ------- + new_features : torch.Tensor + (B, 3 + C, npoint, nsample) tensor + """ + idx = ball_query(self.radius, self.nsample, xyz, new_xyz) + + if self.sample_uniformly: + unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) + for i_batch in range(idx.shape[0]): + for i_region in range(idx.shape[1]): + unique_ind = torch.unique(idx[i_batch, i_region, :]) + num_unique = unique_ind.shape[0] + unique_cnt[i_batch, i_region] = num_unique + sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) + all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) + idx[i_batch, i_region, :] = all_ind + + + xyz_trans = xyz.transpose(1, 2).contiguous() + grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) + if self.normalize_xyz: + grouped_xyz /= self.radius + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], dim=1 + ) # (B, C + 3, npoint, nsample) + else: + new_features = grouped_features + else: + assert ( + self.use_xyz + ), "Cannot have not features and not use xyz as a feature!" + new_features = grouped_xyz + + ret = [new_features] + if self.ret_grouped_xyz: + ret.append(grouped_xyz) + if self.ret_unique_cnt: + ret.append(unique_cnt) + if len(ret) == 1: + return ret[0] + else: + return tuple(ret) + + +class GroupAll(nn.Module): + r""" + Groups all features + + Parameters + --------- + """ + + def __init__(self, use_xyz=True, ret_grouped_xyz=False): + # type: (GroupAll, bool) -> None + super(GroupAll, self).__init__() + self.use_xyz = use_xyz + + def forward(self, xyz, new_xyz, features=None): + # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] + r""" + Parameters + ---------- + xyz : torch.Tensor + xyz coordinates of the features (B, N, 3) + new_xyz : torch.Tensor + Ignored + features : torch.Tensor + Descriptors of the features (B, C, N) + + Returns + ------- + new_features : torch.Tensor + (B, C + 3, 1, N) tensor + """ + + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], dim=1 + ) # (B, 3 + C, 1, N) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + if self.ret_grouped_xyz: + return new_features, grouped_xyz + else: + return new_features diff --git a/third_party/pointnet2/pytorch_utils.py b/third_party/pointnet2/pytorch_utils.py new file mode 100644 index 0000000..4eb8572 --- /dev/null +++ b/third_party/pointnet2/pytorch_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' +import torch +import torch.nn as nn +from typing import List, Tuple + +class SharedMLP(nn.Sequential): + + def __init__( + self, + args: List[int], + *, + bn: bool = False, + activation=nn.ReLU(inplace=True), + preact: bool = False, + first: bool = False, + name: str = "" + ): + super().__init__() + + for i in range(len(args) - 1): + self.add_module( + name + 'layer{}'.format(i), + Conv2d( + args[i], + args[i + 1], + bn=(not first or not preact or (i != 0)) and bn, + activation=activation + if (not first or not preact or (i != 0)) else None, + preact=preact + ) + ) + + +class _BNBase(nn.Sequential): + + def __init__(self, in_size, batch_norm=None, name=""): + super().__init__() + self.add_module(name + "bn", batch_norm(in_size)) + + nn.init.constant_(self[0].weight, 1.0) + nn.init.constant_(self[0].bias, 0) + + +class BatchNorm1d(_BNBase): + + def __init__(self, in_size: int, *, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) + + +class BatchNorm2d(_BNBase): + + def __init__(self, in_size: int, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) + + +class BatchNorm3d(_BNBase): + + def __init__(self, in_size: int, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) + + +class _ConvBase(nn.Sequential): + + def __init__( + self, + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=None, + batch_norm=None, + bias=True, + preact=False, + name="" + ): + super().__init__() + + bias = bias and (not bn) + conv_unit = conv( + in_size, + out_size, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias + ) + init(conv_unit.weight) + if bias: + nn.init.constant_(conv_unit.bias, 0) + + if bn: + if not preact: + bn_unit = batch_norm(out_size) + else: + bn_unit = batch_norm(in_size) + + if preact: + if bn: + self.add_module(name + 'bn', bn_unit) + + if activation is not None: + self.add_module(name + 'activation', activation) + + self.add_module(name + 'conv', conv_unit) + + if not preact: + if bn: + self.add_module(name + 'bn', bn_unit) + + if activation is not None: + self.add_module(name + 'activation', activation) + + +class Conv1d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "" + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv1d, + batch_norm=BatchNorm1d, + bias=bias, + preact=preact, + name=name + ) + + +class Conv2d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "" + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv2d, + batch_norm=BatchNorm2d, + bias=bias, + preact=preact, + name=name + ) + + +class Conv3d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: Tuple[int, int, int] = (1, 1, 1), + stride: Tuple[int, int, int] = (1, 1, 1), + padding: Tuple[int, int, int] = (0, 0, 0), + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "" + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv3d, + batch_norm=BatchNorm3d, + bias=bias, + preact=preact, + name=name + ) + + +class FC(nn.Sequential): + + def __init__( + self, + in_size: int, + out_size: int, + *, + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=None, + preact: bool = False, + name: str = "" + ): + super().__init__() + + fc = nn.Linear(in_size, out_size, bias=not bn) + if init is not None: + init(fc.weight) + if not bn: + nn.init.constant_(fc.bias, 0) + + if preact: + if bn: + self.add_module(name + 'bn', BatchNorm1d(in_size)) + + if activation is not None: + self.add_module(name + 'activation', activation) + + self.add_module(name + 'fc', fc) + + if not preact: + if bn: + self.add_module(name + 'bn', BatchNorm1d(out_size)) + + if activation is not None: + self.add_module(name + 'activation', activation) + +def set_bn_momentum_default(bn_momentum): + + def fn(m): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + m.momentum = bn_momentum + + return fn + + +class BNMomentumScheduler(object): + + def __init__( + self, model, bn_lambda, last_epoch=-1, + setter=set_bn_momentum_default + ): + if not isinstance(model, nn.Module): + raise RuntimeError( + "Class '{}' is not a PyTorch nn Module".format( + type(model).__name__ + ) + ) + + self.model = model + self.setter = setter + self.lmbd = bn_lambda + + self.step(last_epoch + 1) + self.last_epoch = last_epoch + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + + self.last_epoch = epoch + self.model.apply(self.setter(self.lmbd(epoch))) + + diff --git a/third_party/pointnet2/setup.py b/third_party/pointnet2/setup.py new file mode 100644 index 0000000..ae3b5bf --- /dev/null +++ b/third_party/pointnet2/setup.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import glob +import os.path as osp + +this_dir = osp.dirname(osp.abspath(__file__)) + +_ext_src_root = "_ext_src" +_ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( + "{}/src/*.cu".format(_ext_src_root) +) +_ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) + +setup( + name='pointnet2', + ext_modules=[ + CUDAExtension( + name='pointnet2._ext', + sources=_ext_sources, + extra_compile_args={ + "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], + "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], + }, + include_dirs=[osp.join(this_dir, _ext_src_root, "include")], + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/trainer/trainer.py b/trainer/trainer.py new file mode 100644 index 0000000..87baabf --- /dev/null +++ b/trainer/trainer.py @@ -0,0 +1,305 @@ +import statistics +import hydra +import MinkowskiEngine as ME +import numpy as np +import pytorch_lightning as pl +import torch +from sklearn.cluster import DBSCAN +from contextlib import nullcontext +from collections import defaultdict +from utils.utils import associate_instances, save_predictions + + +class PanopticSegmentation(pl.LightningModule): + def __init__(self, config): + super().__init__() + + self.config = config + self.save_hyperparameters() + # model + self.model = hydra.utils.instantiate(config.model) + self.optional_freeze = nullcontext + + matcher = hydra.utils.instantiate(config.matcher) + weight_dict = {"loss_ce": matcher.cost_class, + "loss_mask": matcher.cost_mask, + "loss_dice": matcher.cost_dice, + "loss_box": matcher.cost_box} + + aux_weight_dict = {} + for i in range(self.model.num_levels * self.model.num_decoders): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + self.criterion = hydra.utils.instantiate(config.loss, matcher=matcher, weight_dict=weight_dict) + # metrics + self.class_evaluator = hydra.utils.instantiate(config.metric) + self.last_seq = None + + def forward(self, x, raw_coordinates=None, is_eval=False): + with self.optional_freeze(): + x = self.model(x, raw_coordinates=raw_coordinates, is_eval=is_eval) + return x + + def training_step(self, batch, batch_idx): + data, target = batch + raw_coordinates = data.raw_coordinates + data = ME.SparseTensor(coordinates=data.coordinates, features=data.features, device=self.device) + + output = self.forward(data, raw_coordinates=raw_coordinates) + losses = self.criterion(output, target) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + + logs = {f"train_{k}": v.detach().cpu().item() for k,v in losses.items()} + + logs['train_mean_loss_ce'] = statistics.mean( + [item for item in [v for k, v in logs.items() if "loss_ce" in k]]) + + logs['train_mean_loss_mask'] = statistics.mean( + [item for item in [v for k, v in logs.items() if "loss_mask" in k]]) + + logs['train_mean_loss_dice'] = statistics.mean( + [item for item in [v for k, v in logs.items() if "loss_dice" in k]]) + + logs['train_mean_loss_box'] = statistics.mean( + [item for item in [v for k, v in logs.items() if "loss_box" in k]]) + + self.log_dict(logs) + return sum(losses.values()) + + def validation_step(self, batch, batch_idx): + data, target = batch + inverse_maps = data.inverse_maps + original_labels = data.original_labels + raw_coordinates = data.raw_coordinates + num_points = data.num_points + sequences = data.sequences + + data = ME.SparseTensor(coordinates=data.coordinates, features=data.features, device=self.device) + output = self.forward(data, raw_coordinates=raw_coordinates, is_eval=True) + losses = self.criterion(output, target) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + + pred_logits = output['pred_logits'] + pred_logits = torch.functional.F.softmax(pred_logits, dim=-1)[..., :-1] + pred_masks= output['pred_masks'] + offset_coords_idx = 0 + + for logit, mask, map, label, n_point, seq in zip(pred_logits, pred_masks, inverse_maps, original_labels, num_points, sequences): + if seq != self.last_seq: + self.last_seq = seq + self.previous_instances = None + self.max_instance_id = self.config.model.num_queries + self.scene = 0 + + class_confidence, classes = torch.max(logit.detach().cpu(), dim=1) + foreground_confidence = mask.detach().cpu().float().sigmoid() + confidence = class_confidence[None, ...] * foreground_confidence + confidence = confidence[map].numpy() + + ins_preds = np.argmax(confidence, axis=1) + sem_preds = classes[ins_preds].numpy() + 1 + ins_preds += 1 + ins_preds[np.isin(sem_preds, range(1, self.config.data.min_stuff_cls_id), invert=True)] = 0 + sem_labels = self.validation_dataset._remap_model_output(label[:, 0]) + ins_labels = label[:, 1] >> 16 + + db_max_instance_id = self.config.model.num_queries + if self.config.general.dbscan_eps is not None: + curr_coords_idx = mask.shape[0] + curr_coords = raw_coordinates[offset_coords_idx:curr_coords_idx + offset_coords_idx, :3] + curr_coords = curr_coords[map].detach().cpu().numpy() + offset_coords_idx += curr_coords_idx + + ins_ids = np.unique(ins_preds) + for ins_id in ins_ids: + if ins_id != 0: + instance_mask = ins_preds == ins_id + clusters = DBSCAN(eps=self.config.general.dbscan_eps, min_samples=1, n_jobs=-1).fit(curr_coords[instance_mask]).labels_ + new_mask = np.zeros(ins_preds.shape, dtype=np.int64) + new_mask[instance_mask] = clusters + 1 + for cluster_id in np.unique(new_mask): + if cluster_id != 0: + db_max_instance_id += 1 + ins_preds[new_mask == cluster_id] = db_max_instance_id + + self.max_instance_id = max(db_max_instance_id, self.max_instance_id) + for i in range(len(n_point) - 1): + indices = range(n_point[i], n_point[i+1]) + if i == 0 and self.previous_instances is not None: + current_instances = ins_preds[indices] + associations = associate_instances(self.previous_instances, current_instances) + for id in np.unique(ins_preds): + if associations.get(id) is None: + self.max_instance_id += 1 + associations[id] = self.max_instance_id + ins_preds = np.vectorize(associations.__getitem__)(ins_preds) + else: + self.class_evaluator.addBatch(sem_preds, ins_preds, sem_labels, ins_labels, indices, seq) + if i > 0: + self.previous_instances = ins_preds[indices] + + return {f"val_{k}": v.detach().cpu().item() for k, v in losses.items()} + + def test_step(self, batch, batch_idx): + data, _ = batch + inverse_maps = data.inverse_maps + raw_coordinates = data.raw_coordinates + num_points = data.num_points + sequences = data.sequences + + data = ME.SparseTensor(coordinates=data.coordinates, features=data.features, device=self.device) + output = self.forward(data, raw_coordinates=raw_coordinates, is_eval=True) + + pred_logits = output['pred_logits'] + pred_logits = torch.functional.F.softmax(pred_logits, dim=-1)[..., :-1] + pred_masks= output['pred_masks'] + + offset_coords_idx = 0 + + for logit, mask, map, n_point, seq in zip(pred_logits, pred_masks, inverse_maps, num_points, sequences): + if seq != self.last_seq: + self.last_seq = seq + self.previous_instances = None + self.max_instance_id = self.config.model.num_queries + self.scene = 0 + class_confidence, classes = torch.max(logit.detach().cpu(), dim=1) + foreground_confidence = mask.detach().cpu().float().sigmoid() + confidence = class_confidence[None, ...] * foreground_confidence + confidence = confidence[map].numpy() + + ins_preds = np.argmax(confidence, axis=1) + sem_preds = classes[ins_preds].numpy() + 1 + ins_preds += 1 + ins_preds[np.isin(sem_preds, range(1, self.config.data.min_stuff_cls_id), invert=True)] = 0 + + db_max_instance_id = self.config.model.num_queries + if self.config.general.dbscan_eps is not None: + curr_coords_idx = mask.shape[0] + curr_coords = raw_coordinates[offset_coords_idx:curr_coords_idx + offset_coords_idx, :3] + curr_coords = curr_coords[map].detach().cpu().numpy() + offset_coords_idx += curr_coords_idx + + ins_ids = np.unique(ins_preds) + for ins_id in ins_ids: + if ins_id != 0: + instance_mask = ins_preds == ins_id + clusters = DBSCAN(eps=self.config.general.dbscan_eps, min_samples=1, n_jobs=-1).fit(curr_coords[instance_mask]).labels_ + new_mask = np.zeros(ins_preds.shape, dtype=np.int64) + new_mask[instance_mask] = clusters + 1 + for cluster_id in np.unique(new_mask): + if cluster_id != 0: + db_max_instance_id += 1 + ins_preds[new_mask == cluster_id] = db_max_instance_id + + self.max_instance_id = max(db_max_instance_id, self.max_instance_id) + for i in range(len(n_point) - 1): + indices = range(n_point[i], n_point[i+1]) + if i == 0 and self.previous_instances is not None: + current_instances = ins_preds[indices] + associations = associate_instances(self.previous_instances, current_instances) + for id in np.unique(ins_preds): + if associations.get(id) is None: + self.max_instance_id += 1 + associations[id] = self.max_instance_id + ins_preds = np.vectorize(associations.__getitem__)(ins_preds) + else: + save_predictions(sem_preds[indices], ins_preds[indices], f"{seq:02}", f"{self.scene:06}") + self.scene += 1 + if i > 0: + self.previous_instances = ins_preds[indices] + + return {} + + def training_epoch_end(self, outputs): + train_loss = sum([out["loss"].cpu().item() for out in outputs]) / len(outputs) + results = {"train_loss_mean": train_loss} + self.log_dict(results) + + def validation_epoch_end(self, outputs): + self.last_seq = None + class_names = self.config.data.class_names + lstq, aq, all_aq, iou, all_iou = self.class_evaluator.getPQ4D() + self.class_evaluator.reset() + results = {} + results["val_mean_aq"] = aq + results["val_mean_iou"] = iou + results["val_mean_lstq"] = lstq + for i, (aq, iou) in enumerate(zip(all_aq, all_iou)): + results[f"val_{class_names[i]}_aq"] = aq.item() + results[f"val_{class_names[i]}_iou"] = iou.item() + self.log_dict(results) + + dd = defaultdict(list) + for output in outputs: + for key, val in output.items(): + dd[key].append(val) + + dd = {k: statistics.mean(v) for k, v in dd.items()} + + dd['val_mean_loss_ce'] = statistics.mean([item for item in [v for k,v in dd.items() if "loss_ce" in k]]) + dd['val_mean_loss_mask'] = statistics.mean([item for item in [v for k,v in dd.items() if "loss_mask" in k]]) + dd['val_mean_loss_dice'] = statistics.mean([item for item in [v for k,v in dd.items() if "loss_dice" in k]]) + + self.log_dict(dd) + + def test_epoch_end(self, outputs): + return {} + + def configure_optimizers(self): + optimizer = hydra.utils.instantiate( + self.config.optimizer, params=self.parameters() + ) + if "steps_per_epoch" in self.config.scheduler.scheduler.keys(): + self.config.scheduler.scheduler.steps_per_epoch = len( + self.train_dataloader() + ) + lr_scheduler = hydra.utils.instantiate( + self.config.scheduler.scheduler, optimizer=optimizer + ) + scheduler_config = {"scheduler": lr_scheduler} + scheduler_config.update(self.config.scheduler.pytorch_lightning_params) + return [optimizer], [scheduler_config] + + def prepare_data(self): + self.train_dataset = hydra.utils.instantiate(self.config.data.train_dataset) + self.validation_dataset = hydra.utils.instantiate( + self.config.data.validation_dataset + ) + self.test_dataset = hydra.utils.instantiate(self.config.data.test_dataset) + + def train_dataloader(self): + c_fn = hydra.utils.instantiate(self.config.data.train_collation) + return hydra.utils.instantiate( + self.config.data.train_dataloader, + self.train_dataset, + collate_fn=c_fn, + ) + + def val_dataloader(self): + c_fn = hydra.utils.instantiate(self.config.data.validation_collation) + return hydra.utils.instantiate( + self.config.data.validation_dataloader, + self.validation_dataset, + collate_fn=c_fn, + ) + + def test_dataloader(self): + c_fn = hydra.utils.instantiate(self.config.data.test_collation) + return hydra.utils.instantiate( + self.config.data.test_dataloader, + self.test_dataset, + collate_fn=c_fn, + ) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/pointops2/__init__.py b/utils/pointops2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/pointops2/functions/__init__.py b/utils/pointops2/functions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/pointops2/functions/pointops.py b/utils/pointops2/functions/pointops.py new file mode 100644 index 0000000..52aaa2c --- /dev/null +++ b/utils/pointops2/functions/pointops.py @@ -0,0 +1,829 @@ +from typing import Tuple + +import torch +from torch.autograd import Function +import torch.nn as nn + +import pointops2_cuda as pointops_cuda +import time + +class FurthestSampling(Function): + @staticmethod + def forward(ctx, xyz, offset, new_offset): + """ + input: xyz: (n, 3), offset: (b), new_offset: (b) + output: idx: (m) + """ + assert xyz.is_contiguous() + n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] + for i in range(1, b): + n_max = max(offset[i] - offset[i-1], n_max) + idx = torch.cuda.IntTensor(new_offset[b-1].item()).zero_() + tmp = torch.cuda.FloatTensor(n).fill_(1e10) + pointops_cuda.furthestsampling_cuda(b, n_max, xyz, offset, new_offset, tmp, idx) + del tmp + return idx + +furthestsampling = FurthestSampling.apply + + +class KNNQuery(Function): + @staticmethod + def forward(ctx, nsample, xyz, new_xyz, offset, new_offset): + """ + input: xyz: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) + output: idx: (m, nsample), dist2: (m, nsample) + """ + if new_xyz is None: new_xyz = xyz + assert xyz.is_contiguous() and new_xyz.is_contiguous() + m = new_xyz.shape[0] + idx = torch.cuda.IntTensor(m, nsample).zero_() + dist2 = torch.cuda.FloatTensor(m, nsample).zero_() + pointops_cuda.knnquery_cuda(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2) + return idx, torch.sqrt(dist2) + +knnquery = KNNQuery.apply + + +class Grouping(Function): + @staticmethod + def forward(ctx, input, idx): + """ + input: input: (n, c), idx : (m, nsample) + output: (m, nsample, c) + """ + assert input.is_contiguous() and idx.is_contiguous() + m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1] + output = torch.cuda.FloatTensor(m, nsample, c) + pointops_cuda.grouping_forward_cuda(m, nsample, c, input, idx, output) + ctx.n = n + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (m, c, nsample) + output: (n, c), None + """ + n = ctx.n + idx, = ctx.saved_tensors + m, nsample, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input) + return grad_input, None + +grouping = Grouping.apply + +class AttentionStep1(Function): + @staticmethod + def forward(ctx, q, k, index0, index1): + """ + input: q: (N, h, C//h), k: (N, h, C//h), index0: (M), index1: (M) + output: output: [N, h, C//h] + """ + assert q.is_contiguous() and k.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() + + N_q, h, C_div_h = q.shape + N_k = k.shape[0] + M = index0.shape[0] + C = int(C_div_h * h) + + output = torch.cuda.FloatTensor(M, h).zero_() + pointops_cuda.attention_step1_forward_cuda(N_k, M, h, C, q, k, index0, index1, output) + ctx.N_q = N_q + ctx.N_k = N_k + ctx.C = C + ctx.save_for_backward(q, k, index0, index1) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: (N, h, C//h) + output: (M, h), (N, h, C//h), None, None + """ + + N_q = ctx.N_q + N_k = ctx.N_k + C = ctx.C + q, k, index0, index1 = ctx.saved_tensors + M, h = grad_output.shape + + grad_output = grad_output.contiguous() + # print("grad_output.is_contiguous(): ", grad_output.is_contiguous()) + assert q.is_contiguous() and k.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_q = torch.cuda.FloatTensor(N_q, h, C//h).zero_() + grad_k = torch.cuda.FloatTensor(N_k, h, C//h).zero_() + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.attention_step1_backward_cuda(N_q, M, h, C, grad_output, index0, index1, q, k, grad_q, grad_k) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v7: {}".format(end - start)) + # # input() + + return grad_q, grad_k, None, None + +attention_step1 = AttentionStep1.apply + +class AttentionStep1_v2(Function): + @staticmethod + def forward(ctx, q, k, index1, index0_offsets, n_max): + """ + input: q: (N, h, C//h), k: (N, h, C//h), index0: (M), index1: (M) + output: output: [N, h, C//h] + """ + assert q.is_contiguous() and k.is_contiguous() and index0_offsets.is_contiguous() and index1.is_contiguous() + assert n_max <= 1024 + + N_q, h, C_div_h = q.shape + N_k = k.shape[0] + M = index1.shape[0] + C = int(C_div_h * h) + + output = torch.cuda.FloatTensor(M, h).zero_() + pointops_cuda.attention_step1_forward_cuda_v2(N_k, M, h, C, n_max, q, k, index0_offsets, index1, output) + ctx.N_q = N_q + ctx.N_k = N_k + ctx.C = C + ctx.n_max = n_max + ctx.save_for_backward(q, k, index0_offsets, index1) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: (N, h, C//h) + output: (M, h), (N, h, C//h), None, None + """ + + N_q = ctx.N_q + N_k = ctx.N_k + C = ctx.C + n_max = ctx.n_max + q, k, index0_offsets, index1 = ctx.saved_tensors + M, h = grad_output.shape + + grad_output = grad_output.contiguous() + # print("grad_output.is_contiguous(): ", grad_output.is_contiguous()) + assert q.is_contiguous() and k.is_contiguous() and index0_offsets.is_contiguous() and index1.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_q = torch.cuda.FloatTensor(N_q, h, C//h).zero_() + grad_k = torch.cuda.FloatTensor(N_k, h, C//h).zero_() + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.attention_step1_backward_cuda_v2(N_q, M, h, C, n_max, grad_output, index0_offsets, index1, q, k, grad_q, grad_k) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v7: {}".format(end - start)) + # # input() + + return grad_q, grad_k, None, None, None + +attention_step1_v2 = AttentionStep1_v2.apply + + + +class AttentionStep2(Function): + @staticmethod + def forward(ctx, attn, v, index0, index1): + """ + input: attn: (M, h), v: (N, h, C//h), index0: (M), index1: (M) + output: output: [N, h, C//h] + """ + assert attn.is_contiguous() and v.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() + + M, h = attn.shape + N_q = index0.max().item() + 1 + N_v, h, C_div_h = v.shape + C = int(C_div_h * h) + + output = torch.cuda.FloatTensor(N_q, h, C//h).zero_() + pointops_cuda.attention_step2_forward_cuda(N_q, M, h, C, attn, v, index0, index1, output) + ctx.M = M + + # print("attn[:5,:5]: ", attn[:5, :5]) + + ctx.save_for_backward(attn, v, index0, index1) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: (N, h, C//h) + output: (M, h), (N, h, C//h), None, None + """ + M = ctx.M + attn, v, index0, index1 = ctx.saved_tensors + N_v = v.shape[0] + N_q, h, C_div_h = grad_output.shape + C = h * C_div_h + + grad_output = grad_output.contiguous() + # print("grad_output.is_contiguous(): ", grad_output.is_contiguous()) + assert attn.is_contiguous() and v.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_attn = torch.cuda.FloatTensor(M, h).zero_() + grad_v = torch.cuda.FloatTensor(N_v, h, C//h).zero_() + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.attention_step2_backward_cuda(N_q, M, h, C, grad_output, index0, index1, attn, v, grad_attn, grad_v) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v8: {}".format(end - start)) + # # input() + + return grad_attn, grad_v, None, None + +attention_step2 = AttentionStep2.apply + + +class AttentionStep2_v2(Function): + @staticmethod + def forward(ctx, attn, v, index0, index1): + """ + input: attn: (M, h), v: (N, h, C//h), index0: (M), index1: (M) + output: output: [L, h, C//h] + """ + assert attn.is_contiguous() and v.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() + + L = int(index0.max().item()) + 1 + + M, h = attn.shape + N, h, C_div_h = v.shape + C = int(C_div_h * h) + + output = torch.cuda.FloatTensor(L, h, C//h).zero_() + pointops_cuda.attention_step2_forward_cuda(N, M, h, C, attn, v, index0, index1, output) + ctx.M = M + + # print("attn[:5,:5]: ", attn[:5, :5]) + + ctx.save_for_backward(attn, v, index0, index1) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: (L, h, C//h) + output: (M, h), (N, h, C//h), None, None + """ + M = ctx.M + attn, v, index0, index1 = ctx.saved_tensors + L, h, C_div_h = grad_output.shape + N = v.shape[0] + C = h * C_div_h + + grad_output = grad_output.contiguous() + # print("grad_output.is_contiguous(): ", grad_output.is_contiguous()) + assert attn.is_contiguous() and v.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_attn = torch.cuda.FloatTensor(M, h).zero_() + grad_v = torch.cuda.FloatTensor(N, h, C//h).zero_() + + pointops_cuda.attention_step2_backward_cuda(N, M, h, C, grad_output, index0, index1, attn, v, grad_attn, grad_v) + return grad_attn, grad_v, None, None + +attention_step2_v2 = AttentionStep2_v2.apply + +class DotProdWithIdx(Function): + @staticmethod + def forward(ctx, q, index, table, rel_idx): + """ + input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3) + output: output: [M, h] + """ + assert q.is_contiguous() and index.is_contiguous() and table.is_contiguous() and rel_idx.is_contiguous() + + N, h, hdim = q.shape + M = index.shape[0] + + output = torch.cuda.FloatTensor(M, h).zero_() + pointops_cuda.dot_prod_with_idx_forward_cuda(N, M, h, hdim, q, index, table, rel_idx, output) + ctx.save_for_backward(q, index, table, rel_idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: [M, h] + output: (N, h, hdim), None, (L, h, hdim, 3), None + """ + q, index, table, rel_idx = ctx.saved_tensors + M, h = grad_output.shape + N, _, hdim = q.shape + L = table.shape[0] + + grad_output = grad_output.contiguous() + assert q.is_contiguous() and index.is_contiguous() and table.is_contiguous() and rel_idx.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_q = torch.cuda.FloatTensor(N, h, hdim).zero_() + grad_table = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.dot_prod_with_idx_backward_cuda(N, M, h, hdim, grad_output, q, index, table, rel_idx, grad_q, grad_table) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v9: {}".format(end - start)) + # # input() + + return grad_q, None, grad_table, None + +dot_prod_with_idx = DotProdWithIdx.apply + +class DotProdWithIdx_v2(Function): + @staticmethod + def forward(ctx, q, index_q, k, index_k, table_q, table_k, rel_idx): + """ + input: q: (N, h, hdim), index_q: (M), k: (N, h, hdim), index_k: (M), table_q: (L, h, hdim, 3), table_k: (L, h, hdim, 3), rel_idx: (M, 3) + output: output: [M, h] + """ + assert q.is_contiguous() and index_q.is_contiguous() and k.is_contiguous() and index_k.is_contiguous() and table_q.is_contiguous() and table_k.is_contiguous() and rel_idx.is_contiguous() + + N, h, hdim = q.shape + M = index_q.shape[0] + L = table_q.shape[0] + assert table_k.shape[0] == L and index_k.shape[0] == M + + # obtain the mapping from block_idx to m_idx + rel_idx_merge = rel_idx[:, 0] + rel_idx[:, 1] * L + rel_idx[:, 2] * (L ** 2) #[M, ] + sorted_values, sort_indices = torch.sort(rel_idx_merge) + _, counts = torch.unique_consecutive(sorted_values, return_counts=True) + rel_idx_offsets = torch.cumsum(counts, dim=-1) #[T,] + rel_idx_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), rel_idx_offsets], 0) #[T+1,] + n_max = counts.max() + T = counts.shape[0] + + # print("M: {}, L: {}, n_max: {}, T: {}".format(M, L, n_max, T)) + # print("rel_idx_merge.shape: {}, sorted_values.shape: {}".format(rel_idx_merge.shape, sorted_values.shape)) + # print("counts.shape: {}".format(counts.shape)) + + output = torch.cuda.FloatTensor(M, h).zero_() + # pointops_cuda.dot_prod_with_idx_forward_cuda(N, M, h, hdim, q, index, table, rel_idx, output) + pointops_cuda.dot_prod_with_idx_forward_cuda_v2(N, M, h, hdim, n_max, T, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets.int(), sort_indices.int(), output) + + ctx.n_max = n_max + ctx.T = T + ctx.save_for_backward(q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: [M, h] + output: (N, h, hdim), None, (L, h, hdim, 3), None + """ + q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices = ctx.saved_tensors + M, h = grad_output.shape + N, _, hdim = q.shape + L = table_q.shape[0] + T, n_max = ctx.T, ctx.n_max + + grad_output = grad_output.contiguous() + assert q.is_contiguous() and index_q.is_contiguous() and k.is_contiguous() and index_k.is_contiguous() and table_q.is_contiguous() and table_k.is_contiguous() and rel_idx.is_contiguous() and rel_idx_offsets.is_contiguous() and sort_indices.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_q = torch.cuda.FloatTensor(N, h, hdim).zero_() + grad_table_q = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + grad_k = torch.cuda.FloatTensor(N, h, hdim).zero_() + grad_table_k = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.dot_prod_with_idx_backward_cuda_v2(N, M, h, hdim, n_max, T, grad_output, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets.int(), sort_indices.int(), grad_q, grad_k, grad_table_q, grad_table_k) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v9: {}".format(end - start)) + # # input() + return grad_q, None, grad_k, None, grad_table_q, grad_table_k, None + +dot_prod_with_idx_v2 = DotProdWithIdx_v2.apply + + +class DotProdWithIdx_v3(Function): + @staticmethod + def forward(ctx, q, index_q_offsets, n_max, k, index_k, table_q, table_k, rel_idx): + """ + input: q: (N, h, hdim), index_q: (M), k: (N, h, hdim), index_k: (M), table_q: (L, h, hdim, 3), table_k: (L, h, hdim, 3), rel_idx: (M, 3) + output: output: [M, h] + """ + assert q.is_contiguous() and index_q_offsets.is_contiguous() and k.is_contiguous() and index_k.is_contiguous() and table_q.is_contiguous() and table_k.is_contiguous() and rel_idx.is_contiguous() + + N, h, hdim = q.shape + M = index_k.shape[0] + L = table_q.shape[0] + assert table_k.shape[0] == L + + # # obtain the mapping from block_idx to m_idx + # rel_idx_merge = rel_idx[:, 0] + rel_idx[:, 1] * L + rel_idx[:, 2] * (L ** 2) #[M, ] + # sorted_values, sort_indices = torch.sort(rel_idx_merge) + # _, counts = torch.unique_consecutive(sorted_values, return_counts=True) + # rel_idx_offsets = torch.cumsum(counts, dim=-1) #[T,] + # rel_idx_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), rel_idx_offsets], 0) #[T+1,] + # n_max = counts.max() + # T = counts.shape[0] + + # print("M: {}, L: {}, n_max: {}, T: {}".format(M, L, n_max, T)) + # print("rel_idx_merge.shape: {}, sorted_values.shape: {}".format(rel_idx_merge.shape, sorted_values.shape)) + # print("counts.shape: {}".format(counts.shape)) + + # print("M: {}, L: {}, n_max: {}".format(M, L, n_max)) + + output = torch.cuda.FloatTensor(M, h).zero_() + # pointops_cuda.dot_prod_with_idx_forward_cuda(N, M, h, hdim, q, index, table, rel_idx, output) + pointops_cuda.dot_prod_with_idx_forward_cuda_v3(N, M, h, hdim, n_max, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output) + + ctx.n_max = n_max + # ctx.T = T + ctx.save_for_backward(q, index_q_offsets, k, index_k, table_q, table_k, rel_idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: [M, h] + output: (N, h, hdim), None, (L, h, hdim, 3), None + """ + q, index_q_offsets, k, index_k, table_q, table_k, rel_idx = ctx.saved_tensors + M, h = grad_output.shape + N, _, hdim = q.shape + L = table_q.shape[0] + n_max = ctx.n_max + + grad_output = grad_output.contiguous() + assert q.is_contiguous() and index_q_offsets.is_contiguous() and k.is_contiguous() and index_k.is_contiguous() and table_q.is_contiguous() and table_k.is_contiguous() and rel_idx.is_contiguous() and grad_output.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_q = torch.cuda.FloatTensor(N, h, hdim).zero_() + grad_table_q = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + grad_k = torch.cuda.FloatTensor(N, h, hdim).zero_() + grad_table_k = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.dot_prod_with_idx_backward_cuda_v3(N, M, h, hdim, n_max, grad_output, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, grad_q, grad_k, grad_table_q, grad_table_k) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v9: {}".format(end - start)) + # # input() + return grad_q, None, None, grad_k, None, grad_table_q, grad_table_k, None + +dot_prod_with_idx_v3 = DotProdWithIdx_v3.apply + +class AttentionStep2WithRelPosValue(Function): + @staticmethod + def forward(ctx, attn, v, index0, index1, table, rel_idx): + """ + input: attn: (M, h), v: (N, h, hdim), index0: (M), index1: (M), table: (L, h, hdim, 3), rel_idx: (M, 3) + output: output: [N, h, hdim] + """ + assert attn.is_contiguous() and v.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() and table.is_contiguous() and rel_idx.is_contiguous() + + M, h = attn.shape + N_v, h, hdim = v.shape + N_q = index0.max().item() + 1 + + output = torch.cuda.FloatTensor(N_q, h, hdim).zero_() + pointops_cuda.attention_step2_with_rel_pos_value_forward_cuda(N_q, M, h, hdim, attn, v, index0, index1, table, rel_idx, output) + + # print("attn[:5,:5]: ", attn[:5, :5]) + + ctx.save_for_backward(attn, v, index0, index1, table, rel_idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: (N, h, C//h) + output: (M, h), (N, h, C//h), None, None, (L, h, hdim, 3), None + """ + attn, v, index0, index1, table, rel_idx = ctx.saved_tensors + N_q, h, hdim = grad_output.shape + N_v = v.shape[0] + M = attn.shape[0] + L = table.shape[0] + + grad_output = grad_output.contiguous() + # print("grad_output.is_contiguous(): ", grad_output.is_contiguous()) + assert attn.is_contiguous() and v.is_contiguous() and index0.is_contiguous() and index1.is_contiguous() and grad_output.is_contiguous() and table.is_contiguous() and rel_idx.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0.shape, index1.shape)) + + grad_attn = torch.cuda.FloatTensor(M, h).zero_() + grad_v = torch.cuda.FloatTensor(N_v, h, hdim).zero_() + grad_table = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + + # print("attn.shape: {}, grad_attn.shape: {}".format(attn.shape, grad_attn.shape)) + # print("v.shape: {}, grad_v.shape: {}".format(v.shape, grad_v.shape)) + # print("table.shape: {}, grad_table.shape: {}".format(table.shape, grad_table.shape)) + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.attention_step2_with_rel_pos_value_backward_cuda(N_q, M, h, hdim, grad_output, index0, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v10: {}".format(end - start)) + # # input() + return grad_attn, grad_v, None, None, grad_table, None + +attention_step2_with_rel_pos_value = AttentionStep2WithRelPosValue.apply + + +class AttentionStep2WithRelPosValue_v2(Function): + @staticmethod + def forward(ctx, attn, v, index0_offsets, n_max, index1, table, rel_idx): + """ + input: attn: (M, h), v: (N, h, hdim), index0_offsets: (M), index1: (M), table: (L, h, hdim, 3), rel_idx: (M, 3) + output: output: [N, h, hdim] + """ + assert attn.is_contiguous() and v.is_contiguous() and index0_offsets.is_contiguous() and index1.is_contiguous() and table.is_contiguous() and rel_idx.is_contiguous() + + M, h = attn.shape + N, h, hdim = v.shape + # N_q = int(index0_offsets.max().item()) + 1 + + output = torch.cuda.FloatTensor(N, h, hdim).zero_() + pointops_cuda.attention_step2_with_rel_pos_value_forward_cuda_v2(N, M, h, hdim, n_max, attn, v, index0_offsets, index1, table, rel_idx, output) + + # print("attn[:5,:5]: ", attn[:5, :5]) + + ctx.n_max = n_max + ctx.save_for_backward(attn, v, index0_offsets, index1, table, rel_idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_output: (N, h, C//h) + output: (M, h), (N, h, C//h), None, None, (L, h, hdim, 3), None + """ + n_max = ctx.n_max + attn, v, index0_offsets, index1, table, rel_idx = ctx.saved_tensors + N, h, hdim = grad_output.shape + N = v.shape[0] + M = attn.shape[0] + L = table.shape[0] + + # grad_output = grad_output.contiguous() + # print("grad_output.is_contiguous(): ", grad_output.is_contiguous()) + assert attn.is_contiguous() and v.is_contiguous() and index0_offsets.is_contiguous() and index1.is_contiguous() and grad_output.is_contiguous() and table.is_contiguous() and rel_idx.is_contiguous() + + # print("back: attn[:5,:5]: ", attn[:5, :5]) + + # print("attn.shape: {} v.shape: {}, index0_offsets.shape: {}, index1.shape: {}".format(attn.shape, v.shape, index0_offsets.shape, index1.shape)) + + grad_attn = torch.cuda.FloatTensor(M, h).zero_() + grad_v = torch.cuda.FloatTensor(N, h, hdim).zero_() + grad_table = torch.cuda.FloatTensor(L, h, hdim, 3).zero_() + + # print("attn.shape: {}, grad_attn.shape: {}".format(attn.shape, grad_attn.shape)) + # print("v.shape: {}, grad_v.shape: {}".format(v.shape, grad_v.shape)) + # print("table.shape: {}, grad_table.shape: {}".format(table.shape, grad_table.shape)) + + # torch.cuda.synchronize() + # start = time.time() + + pointops_cuda.attention_step2_with_rel_pos_value_backward_cuda_v2(N, M, h, hdim, n_max, grad_output, index0_offsets, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table) + + # torch.cuda.synchronize() + # end = time.time() + # print("time v10: {}".format(end - start)) + + return grad_attn, grad_v, None, None, None, grad_table, None + +attention_step2_with_rel_pos_value_v2 = AttentionStep2WithRelPosValue_v2.apply + +def queryandgroup(nsample, xyz, new_xyz, feat, idx, offset, new_offset, use_xyz=True, return_indx=False): + """ + input: xyz: (n, 3), new_xyz: (m, 3), feat: (n, c), idx: (m, nsample), offset: (b), new_offset: (b) + output: new_feat: (m, c+3, nsample), grouped_idx: (m, nsample) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + if new_xyz is None: + new_xyz = xyz + if idx is None: + idx, _ = knnquery(nsample, xyz, new_xyz, offset, new_offset) # (m, nsample) + + n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1] + grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) # (m, nsample, 3) + #grouped_xyz = grouping(xyz, idx) # (m, nsample, 3) + # 相对位置 + grouped_xyz -= new_xyz.unsqueeze(1) # (m, nsample, 3) + grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c) # (m, nsample, c) + #grouped_feat = grouping(feat, idx) # (m, nsample, c) + if use_xyz: + if return_indx: + return torch.cat((grouped_xyz, grouped_feat), -1), idx # (m, nsample, 3+c) + else: + return torch.cat((grouped_xyz, grouped_feat), -1) + else: + if return_indx: + return grouped_feat, idx + else: + return grouped_feat + + +def Divide2Patch(nsample, xyz, offset, return_offset=False, anchor_scale=None): + # nsample: 16 xyz: (n, 3) offset: (b) + downsample_scale = anchor_scale or nsample + new_offset, count = [offset[0].item() // downsample_scale], offset[0].item() // downsample_scale + for i in range(1, offset.shape[0]): + count += (offset[i].item() - offset[i-1].item()) // downsample_scale + new_offset.append(count) + # print("donw sample scale:", downsample_scale,"offset:", offset, "newoffset:", new_offset) + new_offset = torch.cuda.IntTensor(new_offset) + idx = furthestsampling(xyz, offset, new_offset) # (m) + new_xyz = xyz[idx.long()] + p_idx, _ = knnquery(nsample, xyz, new_xyz, offset, new_offset) # (m, nsample) + if return_offset: + return p_idx, new_offset + else: + return p_idx + +class Subtraction(Function): + @staticmethod + def forward(ctx, input1, input2, idx): + """ + input: input1: (n, c), input2: (n, c), idx: (n, nsample) + output: (n, nsample, c) + """ + assert input1.is_contiguous() and input2.is_contiguous() + n, c = input1.shape; nsample = idx.shape[-1] + output = torch.cuda.FloatTensor(n, nsample, c).zero_() + pointops_cuda.subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output) + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, nsample, c) + output: grad_input1: (n, c), grad_input2: (n, c) + """ + idx, = ctx.saved_tensors + n, nsample, c = grad_output.shape + grad_input1 = torch.cuda.FloatTensor(n, c).zero_() + grad_input2 = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.subtraction_backward_cuda(n, nsample, c, idx, grad_output, grad_input1, grad_input2) + return grad_input1, grad_input2, None + +subtraction = Subtraction.apply + + +class Aggregation(Function): + @staticmethod + def forward(ctx, input, position, weight, idx): + """ + input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample) + output: (n, c) + """ + assert input.is_contiguous() and position.is_contiguous() and weight.is_contiguous() + n, nsample, c = position.shape; w_c = weight.shape[-1] + output = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.aggregation_forward_cuda(n, nsample, c, w_c, input, position, weight, idx, output) + ctx.save_for_backward(input, position, weight, idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, c) + output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c') + """ + input, position, weight, idx = ctx.saved_tensors + n, nsample, c = position.shape; w_c = weight.shape[-1] + grad_input = torch.cuda.FloatTensor(n, c).zero_() + grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_() + grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_() + pointops_cuda.aggregation_backward_cuda(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight) + return grad_input, grad_position, grad_weight, None + +aggregation = Aggregation.apply + + +def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, 3), (n, 3) + dist_recip = 1.0 / (dist + 1e-8) # (n, 3) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, 3) + + new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() + for i in range(k): + new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) + return new_feat + + +def interpolation_v2(xyz, new_xyz, feat, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + + idx, _ = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, 3), (n, 3) + + # print("e3: idx.shape: {}, idx[:5]: {}".format(idx.shape, idx[:5])) + + dist = torch.sqrt(((new_xyz.unsqueeze(1) - xyz[idx.long()]) ** 2).sum(-1) + 1e-8) + + # print("e4: dist.shape: {}, dist[:5]: {}".format(dist.shape, dist[:5])) + # print("((_-dist)**2).max(): ", ((_-dist)**2).max()) + # input() + + dist_recip = 1.0 / (dist + 1e-8) # (n, 3) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, 3) + + new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() + for i in range(k): + new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) + return new_feat + + +class Interpolation(Function): + @staticmethod + def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous() + idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, k), (n, k) + dist_recip = 1.0 / (dist + 1e-8) # (n, k) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, k) + + n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0] + output = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.interpolation_forward_cuda(n, c, k, input, idx, weight, output) + ctx.m, ctx.k = m, k + ctx.save_for_backward(idx, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + m, k = ctx.m, ctx.k + idx, weight = ctx.saved_tensors + n, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(m, c).zero_() + pointops_cuda.interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input) + return None, None, grad_input, None, None, None + +interpolation2 = Interpolation.apply diff --git a/utils/pointops2/functions/pointops2.py b/utils/pointops2/functions/pointops2.py new file mode 100644 index 0000000..33af8e7 --- /dev/null +++ b/utils/pointops2/functions/pointops2.py @@ -0,0 +1,214 @@ +from typing import Tuple + +import torch +from torch.autograd import Function +import torch.nn as nn + +import pointops2_cuda as pointops_cuda + + +class FurthestSampling(Function): + @staticmethod + def forward(ctx, xyz, offset, new_offset): + """ + input: xyz: (n, 3), offset: (b), new_offset: (b) + output: idx: (m) + """ + assert xyz.is_contiguous() + n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] + for i in range(1, b): + n_max = max(offset[i] - offset[i-1], n_max) + idx = torch.cuda.IntTensor(new_offset[b-1].item()).zero_() + tmp = torch.cuda.FloatTensor(n).fill_(1e10) + pointops_cuda.furthestsampling_cuda(b, n_max, xyz, offset, new_offset, tmp, idx) + del tmp + return idx + +furthestsampling = FurthestSampling.apply + + +class KNNQuery(Function): + @staticmethod + def forward(ctx, nsample, xyz, new_xyz, offset, new_offset): + """ + input: xyz: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) + output: idx: (m, nsample), dist2: (m, nsample) + """ + if new_xyz is None: new_xyz = xyz + assert xyz.is_contiguous() and new_xyz.is_contiguous() + m = new_xyz.shape[0] + idx = torch.cuda.IntTensor(m, nsample).zero_() + dist2 = torch.cuda.FloatTensor(m, nsample).zero_() + pointops_cuda.knnquery_cuda(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2) + return idx, torch.sqrt(dist2) + +knnquery = KNNQuery.apply + + +class Grouping(Function): + @staticmethod + def forward(ctx, input, idx): + """ + input: input: (n, c), idx : (m, nsample) + output: (m, nsample, c) + """ + assert input.is_contiguous() and idx.is_contiguous() + m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1] + output = torch.cuda.FloatTensor(m, nsample, c) + pointops_cuda.grouping_forward_cuda(m, nsample, c, input, idx, output) + ctx.n = n + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (m, c, nsample) + output: (n, c), None + """ + n = ctx.n + idx, = ctx.saved_tensors + m, nsample, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input) + return grad_input, None + +grouping = Grouping.apply + + +def queryandgroup(nsample, xyz, new_xyz, feat, idx, offset, new_offset, use_xyz=True): + """ + input: xyz: (n, 3), new_xyz: (m, 3), feat: (n, c), idx: (m, nsample), offset: (b), new_offset: (b) + output: new_feat: (m, c+3, nsample), grouped_idx: (m, nsample) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + if new_xyz is None: + new_xyz = xyz + if idx is None: + idx, _ = knnquery(nsample, xyz, new_xyz, offset, new_offset) # (m, nsample) + + n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1] + grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) # (m, nsample, 3) + #grouped_xyz = grouping(xyz, idx) # (m, nsample, 3) + grouped_xyz -= new_xyz.unsqueeze(1) # (m, nsample, 3) + grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c) # (m, nsample, c) + #grouped_feat = grouping(feat, idx) # (m, nsample, c) + + if use_xyz: + return torch.cat((grouped_xyz, grouped_feat), -1) # (m, nsample, 3+c) + else: + return grouped_feat + + +class Subtraction(Function): + @staticmethod + def forward(ctx, input1, input2, idx): + """ + input: input1: (n, c), input2: (n, c), idx: (n, nsample) + output: (n, nsample, c) + """ + assert input1.is_contiguous() and input2.is_contiguous() + n, c = input1.shape; nsample = idx.shape[-1] + output = torch.cuda.FloatTensor(n, nsample, c).zero_() + pointops_cuda.subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output) + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, nsample, c) + output: grad_input1: (n, c), grad_input2: (n, c) + """ + idx, = ctx.saved_tensors + n, nsample, c = grad_output.shape + grad_input1 = torch.cuda.FloatTensor(n, c).zero_() + grad_input2 = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.subtraction_backward_cuda(n, nsample, c, idx, grad_output, grad_input1, grad_input2) + return grad_input1, grad_input2, None + +subtraction = Subtraction.apply + + +class Aggregation(Function): + @staticmethod + def forward(ctx, input, position, weight, idx): + """ + input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample) + output: (n, c) + """ + assert input.is_contiguous() and position.is_contiguous() and weight.is_contiguous() + n, nsample, c = position.shape; w_c = weight.shape[-1] + output = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.aggregation_forward_cuda(n, nsample, c, w_c, input, position, weight, idx, output) + ctx.save_for_backward(input, position, weight, idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, c) + output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c') + """ + input, position, weight, idx = ctx.saved_tensors + n, nsample, c = position.shape; w_c = weight.shape[-1] + grad_input = torch.cuda.FloatTensor(n, c).zero_() + grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_() + grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_() + pointops_cuda.aggregation_backward_cuda(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight) + return grad_input, grad_position, grad_weight, None + +aggregation = Aggregation.apply + + +def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, 3), (n, 3) + dist_recip = 1.0 / (dist + 1e-8) # (n, 3) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, 3) + + new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() + for i in range(k): + new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) + return new_feat + + +class Interpolation(Function): + @staticmethod + def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous() + idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, k), (n, k) + dist_recip = 1.0 / (dist + 1e-8) # (n, k) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, k) + + n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0] + output = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.interpolation_forward_cuda(n, c, k, input, idx, weight, output) + ctx.m, ctx.k = m, k + ctx.save_for_backward(idx, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + m, k = ctx.m, ctx.k + idx, weight = ctx.saved_tensors + n, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(m, c).zero_() + pointops_cuda.interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input) + return None, None, grad_input, None, None, None + +interpolation2 = Interpolation.apply diff --git a/utils/pointops2/functions/pointops_ablation.py b/utils/pointops2/functions/pointops_ablation.py new file mode 100644 index 0000000..ac01326 --- /dev/null +++ b/utils/pointops2/functions/pointops_ablation.py @@ -0,0 +1,215 @@ +from typing import Tuple + +import torch +from torch.autograd import Function +import torch.nn as nn + +import pointops2_cuda as pointops_cuda + + +class FurthestSampling(Function): + @staticmethod + def forward(ctx, xyz, offset, new_offset): + """ + input: xyz: (n, 3), offset: (b), new_offset: (b) + output: idx: (m) + """ + assert xyz.is_contiguous() + n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] + for i in range(1, b): + n_max = max(offset[i] - offset[i-1], n_max) + idx = torch.cuda.IntTensor(new_offset[b-1].item()).zero_() + tmp = torch.cuda.FloatTensor(n).fill_(1e10) + pointops_cuda.furthestsampling_cuda(b, n_max, xyz, offset, new_offset, tmp, idx) + del tmp + return idx + +furthestsampling = FurthestSampling.apply + + +class KNNQuery(Function): + @staticmethod + def forward(ctx, nsample, xyz, new_xyz, offset, new_offset): + """ + input: xyz: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) + output: idx: (m, nsample), dist2: (m, nsample) + """ + if new_xyz is None: new_xyz = xyz + assert xyz.is_contiguous() and new_xyz.is_contiguous() + m = new_xyz.shape[0] + idx = torch.cuda.IntTensor(m, nsample).zero_() + dist2 = torch.cuda.FloatTensor(m, nsample).zero_() + pointops_cuda.knnquery_cuda(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2) + return idx, torch.sqrt(dist2) + +knnquery = KNNQuery.apply + + +class Grouping(Function): + @staticmethod + def forward(ctx, input, idx): + """ + input: input: (n, c), idx : (m, nsample) + output: (m, nsample, c) + """ + assert input.is_contiguous() and idx.is_contiguous() + m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1] + output = torch.cuda.FloatTensor(m, nsample, c) + pointops_cuda.grouping_forward_cuda(m, nsample, c, input, idx, output) + ctx.n = n + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (m, c, nsample) + output: (n, c), None + """ + n = ctx.n + idx, = ctx.saved_tensors + m, nsample, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input) + return grad_input, None + +grouping = Grouping.apply + + +def queryandgroup(nsample, xyz, new_xyz, feat, idx, offset, new_offset, use_xyz=True, relative=True): + """ + input: xyz: (n, 3), new_xyz: (m, 3), feat: (n, c), idx: (m, nsample), offset: (b), new_offset: (b) + output: new_feat: (m, c+3, nsample), grouped_idx: (m, nsample) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + if new_xyz is None: + new_xyz = xyz + if idx is None: + idx, _ = knnquery(nsample, xyz, new_xyz, offset, new_offset) # (m, nsample) + + n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1] + grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) # (m, nsample, 3) + #grouped_xyz = grouping(xyz, idx) # (m, nsample, 3) + if relative: + grouped_xyz -= new_xyz.unsqueeze(1) # (m, nsample, 3) + grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c) # (m, nsample, c) + #grouped_feat = grouping(feat, idx) # (m, nsample, c) + + if use_xyz: + return torch.cat((grouped_xyz, grouped_feat), -1) # (m, nsample, 3+c) + else: + return grouped_feat + + +class Subtraction(Function): + @staticmethod + def forward(ctx, input1, input2, idx): + """ + input: input1: (n, c), input2: (n, c), idx: (n, nsample) + output: (n, nsample, c) + """ + assert input1.is_contiguous() and input2.is_contiguous() + n, c = input1.shape; nsample = idx.shape[-1] + output = torch.cuda.FloatTensor(n, nsample, c).zero_() + pointops_cuda.subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output) + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, nsample, c) + output: grad_input1: (n, c), grad_input2: (n, c) + """ + idx, = ctx.saved_tensors + n, nsample, c = grad_output.shape + grad_input1 = torch.cuda.FloatTensor(n, c).zero_() + grad_input2 = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.subtraction_backward_cuda(n, nsample, c, idx, grad_output, grad_input1, grad_input2) + return grad_input1, grad_input2, None + +subtraction = Subtraction.apply + + +class Aggregation(Function): + @staticmethod + def forward(ctx, input, position, weight, idx): + """ + input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample) + output: (n, c) + """ + assert input.is_contiguous() and position.is_contiguous() and weight.is_contiguous() + n, nsample, c = position.shape; w_c = weight.shape[-1] + output = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.aggregation_forward_cuda(n, nsample, c, w_c, input, position, weight, idx, output) + ctx.save_for_backward(input, position, weight, idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, c) + output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c') + """ + input, position, weight, idx = ctx.saved_tensors + n, nsample, c = position.shape; w_c = weight.shape[-1] + grad_input = torch.cuda.FloatTensor(n, c).zero_() + grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_() + grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_() + pointops_cuda.aggregation_backward_cuda(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight) + return grad_input, grad_position, grad_weight, None + +aggregation = Aggregation.apply + + +def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, 3), (n, 3) + dist_recip = 1.0 / (dist + 1e-8) # (n, 3) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, 3) + + new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() + for i in range(k): + new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) + return new_feat + + +class Interpolation(Function): + @staticmethod + def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3): + """ + input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous() + idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, k), (n, k) + dist_recip = 1.0 / (dist + 1e-8) # (n, k) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, k) + + n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0] + output = torch.cuda.FloatTensor(n, c).zero_() + pointops_cuda.interpolation_forward_cuda(n, c, k, input, idx, weight, output) + ctx.m, ctx.k = m, k + ctx.save_for_backward(idx, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + m, k = ctx.m, ctx.k + idx, weight = ctx.saved_tensors + n, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(m, c).zero_() + pointops_cuda.interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input) + return None, None, grad_input, None, None, None + +interpolation2 = Interpolation.apply diff --git a/utils/pointops2/functions/test_attention_op_step1.py b/utils/pointops2/functions/test_attention_op_step1.py new file mode 100644 index 0000000..6515fdc --- /dev/null +++ b/utils/pointops2/functions/test_attention_op_step1.py @@ -0,0 +1,78 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 800000 +N = 35000 +C = 96 +h = 6 +query = torch.rand(N, h, C//h).cuda() +key = torch.rand(N, h, C//h).cuda() + +index_0 = torch.rand(M) +index_0[index_0 < 0] = 0 +index_0 = (index_0*N).long().cuda() + +index_1 = torch.rand(M) +index_1[index_1 < 0] = 0 +index_1 = (index_1*N).long().cuda() + +query.requires_grad = True +key.requires_grad = True + +# rearrange index for acceleration +index_0, indices = torch.sort(index_0) #[M,] +index_1 = index_1[indices] #[M,] +index_0_counts = index_0.bincount() + +print("index_0_counts.shape: ", index_0_counts.shape) + +n_max = index_0_counts.max() +index_0_offsets = index_0_counts.cumsum(dim=-1) #[N] + +print("v1 index_0_offsets.shape: ", index_0_offsets.shape) + +index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] + +# print("index_0[:100]: ", index_0[:100]) +print("n_max: ", n_max) +print("index_0_offsets.shape: ", index_0_offsets.shape) +# input() + +print("index_0_offsets[:100]: ", index_0_offsets[:100]) +print("index_1[300:320]: ", index_1[300:320]) + + +attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) +# loss = attn_flat.sum() +# loss.backward() +print("attn_flat.shape: {}, attn_flat[300:320,:10]: {}".format(attn_flat.shape, attn_flat[300:320,:10])) +# print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# input() + +print("query.is_contiguous(): ", query.is_contiguous()) +print("key.is_contiguous(): ", key.is_contiguous()) +print("index_0.is_contiguous(): ", index_0.is_contiguous()) +print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(), n_max) +# loss = attn_flat_v2.sum() +# loss.backward() +print("attn_flat_v2.shape: {}, attn_flat_v2[300:320,:10]: {}".format(attn_flat_v2.shape, attn_flat_v2[300:320,:10])) +# print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# input() + +mask = attn_flat_v2.sum(-1) != 0 +print("mask.sum(): ", mask.sum()) +print("attn_flat_v2[mask] - attn_flat[mask]: ", ((attn_flat_v2[mask] - attn_flat[mask])**2).max()) + + +print("((attn_flat-attn_flat_v2)**2 < 1e-8).all(): ", ((attn_flat-attn_flat_v2)**2 < 1e-8).all()) + +selected = 10000 +print("torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0): ", torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0)) + diff --git a/utils/pointops2/functions/test_attention_op_step1_v2.py b/utils/pointops2/functions/test_attention_op_step1_v2.py new file mode 100644 index 0000000..b8bd582 --- /dev/null +++ b/utils/pointops2/functions/test_attention_op_step1_v2.py @@ -0,0 +1,97 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 800000 +N = 35000 +C = 96 +h = 6 +query = torch.rand(N, h, C//h).cuda() +key = torch.rand(N, h, C//h).cuda() + +index_0 = torch.rand(M) +index_0[index_0 < 0] = 0 +index_0 = (index_0*N).long().cuda() + +index_1 = torch.rand(M) +index_1[index_1 < 0] = 0 +index_1 = (index_1*N).long().cuda() + +query.requires_grad = True +key.requires_grad = True + + +attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) +loss = attn_flat.sum() +loss.backward() +print("attn_flat.shape: {}, attn_flat[:20,:10]: {}".format(attn_flat.shape, attn_flat[:20,:10])) +print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +input() + + + +# rearrange index for acceleration +index_0, indices = torch.sort(index_0) #[M,] +index_1 = index_1[indices] #[M,] +index_0_counts = index_0.bincount() + +print("index_0_counts.shape: ", index_0_counts.shape) + +n_max = index_0_counts.max() +index_0_offsets = index_0_counts.cumsum(dim=-1) #[N] + +print("v1 index_0_offsets.shape: ", index_0_offsets.shape) + +index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] + +# print("index_0[:100]: ", index_0[:100]) +print("n_max: ", n_max) +print("index_0_offsets.shape: ", index_0_offsets.shape) +# input() + +print("index_0_offsets[:100]: ", index_0_offsets[:100]) +print("index_1[:20]: ", index_1[:20]) + + +attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) +# loss = attn_flat.sum() +# loss.backward() +# # attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) +# # loss = attn_flat.sum() +# # loss.backward() +# print("attn_flat.shape: {}, attn_flat[:20,:10]: {}".format(attn_flat.shape, attn_flat[:20,:10])) +# print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# input() + +print("query.is_contiguous(): ", query.is_contiguous()) +print("key.is_contiguous(): ", key.is_contiguous()) +print("index_0.is_contiguous(): ", index_0.is_contiguous()) +print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(), n_max) +loss = attn_flat_v2.sum() +loss.backward() + +# attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(), n_max) +# loss = attn_flat_v2.sum() +# loss.backward() + +print("attn_flat_v2.shape: {}, attn_flat_v2[:20,:10]: {}".format(attn_flat_v2.shape, attn_flat_v2[:20,:10])) +print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# input() + +# mask = attn_flat_v2.sum(-1) != 0 +# print("mask.sum(): ", mask.sum()) +# print("attn_flat_v2[mask] - attn_flat[mask]: ", ((attn_flat_v2[mask] - attn_flat[mask])**2).max()) + + +print("((attn_flat-attn_flat_v2)**2 < 1e-8).all(): ", ((attn_flat-attn_flat_v2)**2 < 1e-8).all()) + +selected = 10000 +print("torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0): ", torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0)) + diff --git a/utils/pointops2/functions/test_attention_op_step2.py b/utils/pointops2/functions/test_attention_op_step2.py new file mode 100644 index 0000000..2146d5a --- /dev/null +++ b/utils/pointops2/functions/test_attention_op_step2.py @@ -0,0 +1,55 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 800000 +N = 35000 +C = 96 +h = 6 +softmax_attn_flat = torch.rand(M, h).cuda() +value = torch.rand(N, h, C//h).cuda() + +index_0 = torch.rand(M) +index_0[index_0 < 0] = 0 +index_0 = (index_0*N).long().cuda() + +index_1 = torch.rand(M) +index_1[index_1 < 0] = 0 +index_1 = (index_1*N).long().cuda() + +softmax_attn_flat.requires_grad = True +value.requires_grad = True + +# value_flat = value[index_1] #[M, num_heads, C // num_heads] +# x = (softmax_attn_flat.unsqueeze(-1) * value_flat).reshape(M, C) +# x = scatter_sum(src=x, index=index_0, dim=0, dim_size=N) #[N, C] +# loss = x.sum() +# loss.backward() + +# print("x.shape: {}, x[:5,:10]: {}".format(x.shape, x[:5,:10])) +# print("softmax_attn_flat.grad[:5, :10]: ", softmax_attn_flat.grad[:5, :10]) +# print("value.grad[:5, :3, :5]: ", value.grad[:5, :3, :5]) +# input() + +print("softmax_attn_flat.is_contiguous(): ", softmax_attn_flat.is_contiguous()) +print("value.is_contiguous(): ", value.is_contiguous()) +print("index_0.is_contiguous(): ", index_0.is_contiguous()) +print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +x_v2 = pointops.attention_step2(softmax_attn_flat.float(), value.float(), index_0.int(), index_1.int()) +x_v2 = x_v2.view(N, C) +loss = x_v2.sum() +loss.backward() + +print("x_v2.shape: {}, x_v2[:5,:10]: {}".format(x_v2.shape, x_v2[:5,:10])) + +print("softmax_attn_flat.grad[:5, :10]: ", softmax_attn_flat.grad[:5, :10]) +print("value.grad[:5, :3, :5]: ", value.grad[:5, :3, :5]) +input() + +print("((x-x_v2)**2 < 1e-8).all(): ", ((x-x_v2)**2 < 1e-8).all()) + +print("torch.max((x-x_v2)**2): ", torch.max((x-x_v2)**2)) + diff --git a/utils/pointops2/functions/test_relative_pos_encoding_op_step1.py b/utils/pointops2/functions/test_relative_pos_encoding_op_step1.py new file mode 100644 index 0000000..2ff8b34 --- /dev/null +++ b/utils/pointops2/functions/test_relative_pos_encoding_op_step1.py @@ -0,0 +1,56 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 80000 +N = 3500 +hdim = 16 +h = 6 +L = 31 +query = torch.rand(N, h, hdim).cuda() +table = torch.rand(L, h, hdim, 3).cuda() + +index = torch.rand(M) +index[index < 0] = 0 +index = (index*N).long().cuda() + +rel_index = torch.rand(M, 3) +rel_index[rel_index < 0] = 0 +rel_index = (rel_index*L).long().cuda() + +query.requires_grad = True +table.requires_grad = True + +# query_flat = query[index] #[M, h, hdim] +# table_x, table_y, table_z = table[:,:,:,0], table[:,:,:,1], table[:,:,:,2] #[L, h, hdim] +# rel_index_x, rel_index_y, rel_index_z = rel_index[:,0], rel_index[:,1], rel_index[:,2] #[M] +# rel_pos_encoding = table_x[rel_index_x] + table_y[rel_index_y] + table_z[rel_index_z] #[M, h, hdim] +# output = (query_flat * rel_pos_encoding).sum(-1) #[M, h] +# loss = output.mean() +# loss.backward() + +# print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10])) +# print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) +# input() + +# print("query.is_contiguous(): ", query.is_contiguous()) +# print("key.is_contiguous(): ", key.is_contiguous()) +# print("index_0.is_contiguous(): ", index_0.is_contiguous()) +# print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +output_v2 = pointops.dot_prod_with_idx(query, index.int(), table, rel_index.int()) +loss = output_v2.mean() +loss.backward() + +print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10])) +print("v2: query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +print("v2: table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) +input() + +# print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) + +# print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2)) + diff --git a/utils/pointops2/functions/test_relative_pos_encoding_op_step1_v2.py b/utils/pointops2/functions/test_relative_pos_encoding_op_step1_v2.py new file mode 100644 index 0000000..3117274 --- /dev/null +++ b/utils/pointops2/functions/test_relative_pos_encoding_op_step1_v2.py @@ -0,0 +1,64 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 80000 +N = 3500 +hdim = 16 +h = 6 +L = 31 +query = torch.rand(N, h, hdim).cuda() +table_q = torch.rand(L, h, hdim, 3).cuda() +key = torch.rand(N, h, hdim).cuda() +table_k = torch.rand(L, h, hdim, 3).cuda() + +index_q = torch.rand(M) +index_q[index_q < 0] = 0 +index_q = (index_q*N).long().cuda() + +index_k = torch.rand(M) +index_k[index_k < 0] = 0 +index_k = (index_k*N).long().cuda() + +rel_index = torch.rand(M, 3) +rel_index[rel_index < 0] = 0 +rel_index = (rel_index*L).long().cuda() + +query.requires_grad = True +table_q.requires_grad = True +key.requires_grad = True +table_k.requires_grad = True + +output1 = pointops.dot_prod_with_idx(query, index_q.int(), table_q, rel_index.int()) +output2 = pointops.dot_prod_with_idx(key, index_k.int(), table_k, rel_index.int()) +output = output1 + output2 +# loss = output.mean() +# loss.backward() + +# print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10])) +# print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) +# print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# print("table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) +# input() + +# print("query.is_contiguous(): ", query.is_contiguous()) +# print("key.is_contiguous(): ", key.is_contiguous()) +# print("index_0.is_contiguous(): ", index_0.is_contiguous()) +# print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +output_v2 = pointops.dot_prod_with_idx_v2(query, index_q.int(), key, index_k.int(), table_q, table_k, rel_index.int()) +loss = output_v2.mean() +loss.backward() + +print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10])) +print("v2 query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +print("v2 table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) +print("v2 key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +print("v2 table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) +# input() + +print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) + diff --git a/utils/pointops2/functions/test_relative_pos_encoding_op_step1_v3.py b/utils/pointops2/functions/test_relative_pos_encoding_op_step1_v3.py new file mode 100644 index 0000000..ea68061 --- /dev/null +++ b/utils/pointops2/functions/test_relative_pos_encoding_op_step1_v3.py @@ -0,0 +1,90 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 80000 +N = 3500 +# M = 80 +# N = 5 +hdim = 16 +h = 6 +L = 31 +query = torch.rand(N, h, hdim).cuda() +table_q = torch.rand(L, h, hdim, 3).cuda() +key = torch.rand(N, h, hdim).cuda() +table_k = torch.rand(L, h, hdim, 3).cuda() + +index_q = torch.rand(M) +index_q[index_q < 0] = 0 +index_q = (index_q*N).long().cuda() + +index_k = torch.rand(M) +index_k[index_k < 0] = 0 +index_k = (index_k*N).long().cuda() + +rel_index = torch.rand(M, 3) +rel_index[rel_index < 0] = 0 +rel_index = (rel_index*L).long().cuda() + + +# rearrange index for acceleration +index_q, indices = torch.sort(index_q) #[M,] +index_k = index_k[indices] #[M,] +rel_index = rel_index[indices] +index_q_counts = index_q.bincount() + +print("index_q_counts.shape: ", index_q_counts.shape) + +n_max = index_q_counts.max() +index_q_offsets = index_q_counts.cumsum(dim=-1) #[N] + +print("v1 index_q_offsets.shape: ", index_q_offsets.shape) + +index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1] + +# print("index_q[:100]: ", index_q[:100]) +print("n_max: ", n_max) +print("index_q_offsets.shape: ", index_q_offsets.shape) +# input() + +print("index_q_offsets[:100]: ", index_q_offsets[:100]) +print("index_k[:20]: ", index_k[:20]) + +query.requires_grad = True +table_q.requires_grad = True +key.requires_grad = True +table_k.requires_grad = True + +output1 = pointops.dot_prod_with_idx(query, index_q.int(), table_q, rel_index.int()) +output2 = pointops.dot_prod_with_idx(key, index_k.int(), table_k, rel_index.int()) +output = output1 + output2 +loss = output.mean() +loss.backward() + +# print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10])) +# print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) +# print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# print("table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) +# input() + +# print("query.is_contiguous(): ", query.is_contiguous()) +# print("key.is_contiguous(): ", key.is_contiguous()) +# print("index_q.is_contiguous(): ", index_q.is_contiguous()) +# print("index_k.is_contiguous(): ", index_k.is_contiguous()) + +output_v2 = pointops.dot_prod_with_idx_v3(query, index_q_offsets.int(), n_max, key, index_k.int(), table_q, table_k, rel_index.int()) +# loss = output_v2.mean() +# loss.backward() + +# print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10])) +# print("v2 query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) +# print("v2 table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) +# print("v2 key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) +# print("v2 table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) +# input() + +print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) + diff --git a/utils/pointops2/functions/test_relative_pos_encoding_op_step2.py b/utils/pointops2/functions/test_relative_pos_encoding_op_step2.py new file mode 100644 index 0000000..eae7d9f --- /dev/null +++ b/utils/pointops2/functions/test_relative_pos_encoding_op_step2.py @@ -0,0 +1,66 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 80000 +N = 3500 +hdim = 16 +h = 6 +L = 31 +attn = torch.rand(M, h).cuda() +v = torch.rand(N, h, hdim).cuda() +table = torch.rand(L, h, hdim, 3).cuda() + +index_0 = torch.rand(M) +index_0[index_0 < 0] = 0 +index_0 = (index_0*N).long().cuda() + +index_1 = torch.rand(M) +index_1[index_1 < 0] = 0 +index_1 = (index_1*N).long().cuda() + +rel_index = torch.rand(M, 3) +rel_index[rel_index < 0] = 0 +rel_index = (rel_index*L).long().cuda() + +attn.requires_grad = True +v.requires_grad = True +table.requires_grad = True + +v_flat = v[index_1] #[M, h, hdim] +table_x, table_y, table_z = table[:,:,:,0], table[:,:,:,1], table[:,:,:,2] #[L, h, hdim] +rel_index_x, rel_index_y, rel_index_z = rel_index[:,0], rel_index[:,1], rel_index[:,2] #[M] +rel_pos_encoding = table_x[rel_index_x] + table_y[rel_index_y] + table_z[rel_index_z] #[M, h, hdim] +v_flat_new = v_flat + rel_pos_encoding #[M, h, hdim] +output = attn.unsqueeze(-1) * v_flat_new #[M, h, hdim] +output = scatter_sum(src=output, index=index_0, dim=0, dim_size=N) #[N, h, hdim] +loss = output.mean() +loss.backward() + +print("output.shape: {}, output[:5,:10,:5]: {}".format(output.shape, output[:5,:10, :5])) +print("attn.grad[:5, :3]: ", attn.grad[:5, :3]) +print("v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) +print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) +input() + +# print("query.is_contiguous(): ", query.is_contiguous()) +# print("key.is_contiguous(): ", key.is_contiguous()) +# print("index_0.is_contiguous(): ", index_0.is_contiguous()) +# print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +# output_v2 = pointops.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_1.int(), table, rel_index.int()) +# loss = output_v2.mean() +# loss.backward() + +# print("output_v2.shape: {}, output_v2[:5,:10,:5]: {}".format(output_v2.shape, output_v2[:5,:10,:5])) +# print("v2 attn.grad[:5, :3]: ", attn.grad[:5, :3]) +# print("v2 v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) +# print("v2 table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) +# input() + +# print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) + +# print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2)) + diff --git a/utils/pointops2/functions/test_relative_pos_encoding_op_step2_v2.py b/utils/pointops2/functions/test_relative_pos_encoding_op_step2_v2.py new file mode 100644 index 0000000..9c2ee11 --- /dev/null +++ b/utils/pointops2/functions/test_relative_pos_encoding_op_step2_v2.py @@ -0,0 +1,92 @@ +import torch +import pointops +from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum + +torch.manual_seed(1) + +M = 80000 +N = 3500 +hdim = 16 +h = 6 +L = 31 +attn = torch.rand(M, h).cuda() +v = torch.rand(N, h, hdim).cuda() +table = torch.rand(L, h, hdim, 3).cuda() + +index_0 = torch.rand(M) +index_0[index_0 < 0] = 0 +index_0 = (index_0*N).long().cuda() + +index_1 = torch.rand(M) +index_1[index_1 < 0] = 0 +index_1 = (index_1*N).long().cuda() + +rel_index = torch.rand(M, 3) +rel_index[rel_index < 0] = 0 +rel_index = (rel_index*L).long().cuda() + + +# rearrange index for acceleration +index_0, indices = torch.sort(index_0) #[M,] +index_1 = index_1[indices] #[M,] +rel_index = rel_index[indices] +index_0_counts = index_0.bincount() + +print("index_0_counts.shape: ", index_0_counts.shape) + +n_max = index_0_counts.max() +index_0_offsets = index_0_counts.cumsum(dim=-1) #[N] + +print("v1 index_0_offsets.shape: ", index_0_offsets.shape) + +index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] + + +attn.requires_grad = True +v.requires_grad = True +table.requires_grad = True + + +output = pointops.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_1.int(), table, rel_index.int()) +loss = output.mean() +loss.backward() + +print("output.shape: {}, output[:5,:10,:5]: {}".format(output.shape, output[:5,:10, :5])) +print("attn.grad[:5, :3]: ", attn.grad[:5, :3]) +print("v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) +print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) +# input() + +attn_grad = attn.grad.clone() +v_grad = v.grad.clone() +table_grad = table.grad.clone() + +attn.grad.zero_() +v.grad.zero_() +table.grad.zero_() + +# print("query.is_contiguous(): ", query.is_contiguous()) +# print("key.is_contiguous(): ", key.is_contiguous()) +# print("index_0.is_contiguous(): ", index_0.is_contiguous()) +# print("index_1.is_contiguous(): ", index_1.is_contiguous()) + +output_v2 = pointops.attention_step2_with_rel_pos_value_v2(attn, v, index_0_offsets.int(), n_max, index_1.int(), table, rel_index.int()) +loss = output_v2.mean() +loss.backward() + +print("output_v2.shape: {}, output_v2[:5,:10,:5]: {}".format(output_v2.shape, output_v2[:5,:10,:5])) +print("v2 attn.grad[:5, :3]: ", attn.grad[:5, :3]) +print("v2 v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) +print("v2 table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) +# input() + +print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) + +print("((attn_grad-attn.grad)**2).max(): ", ((attn_grad-attn.grad)**2).max()) + +print("((v_grad-v.grad)**2).max(): ", ((v_grad-v.grad)**2).max()) + +print("((table_grad-table.grad)**2).max(): ", ((table_grad-table.grad)**2).max()) + +# print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2)) + diff --git a/utils/pointops2/setup.py b/utils/pointops2/setup.py new file mode 100644 index 0000000..03f5d6e --- /dev/null +++ b/utils/pointops2/setup.py @@ -0,0 +1,42 @@ +#python3 setup.py install +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os +from distutils.sysconfig import get_config_vars + +(opt,) = get_config_vars('OPT') +os.environ['OPT'] = " ".join( + flag for flag in opt.split() if flag != '-Wstrict-prototypes' +) + +setup( + name='pointops2', + ext_modules=[ + CUDAExtension('pointops2_cuda', [ + 'src/pointops_api.cpp', + 'src/knnquery/knnquery_cuda.cpp', + 'src/knnquery/knnquery_cuda_kernel.cu', + 'src/sampling/sampling_cuda.cpp', + 'src/sampling/sampling_cuda_kernel.cu', + 'src/grouping/grouping_cuda.cpp', + 'src/grouping/grouping_cuda_kernel.cu', + 'src/interpolation/interpolation_cuda.cpp', + 'src/interpolation/interpolation_cuda_kernel.cu', + 'src/subtraction/subtraction_cuda.cpp', + 'src/subtraction/subtraction_cuda_kernel.cu', + 'src/aggregation/aggregation_cuda.cpp', + 'src/aggregation/aggregation_cuda_kernel.cu', + 'src/attention/attention_cuda.cpp', + 'src/attention/attention_cuda_kernel.cu', + 'src/rpe/relative_pos_encoding_cuda.cpp', + 'src/rpe/relative_pos_encoding_cuda_kernel.cu', + 'src/attention_v2/attention_cuda_v2.cpp', + 'src/attention_v2/attention_cuda_kernel_v2.cu', + 'src/rpe_v2/relative_pos_encoding_cuda_v2.cpp', + 'src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.cu', + ], + extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} + ) + ], + cmdclass={'build_ext': BuildExtension} +) diff --git a/utils/pointops2/src/__init__.py b/utils/pointops2/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/pointops2/src/aggregation/aggregation_cuda.cpp b/utils/pointops2/src/aggregation/aggregation_cuda.cpp new file mode 100644 index 0000000..d5ad9cd --- /dev/null +++ b/utils/pointops2/src/aggregation/aggregation_cuda.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include "aggregation_cuda_kernel.h" + + +void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) +{ + const float *input = input_tensor.data_ptr(); + const float *position = position_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + aggregation_forward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, output); +} + +void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor) +{ + const float *input = input_tensor.data_ptr(); + const float *position = position_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + const float *grad_output = grad_output_tensor.data_ptr(); + float *grad_input = grad_input_tensor.data_ptr(); + float *grad_position = grad_position_tensor.data_ptr(); + float *grad_weight = grad_weight_tensor.data_ptr(); + aggregation_backward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); +} diff --git a/utils/pointops2/src/aggregation/aggregation_cuda_kernel.cu b/utils/pointops2/src/aggregation/aggregation_cuda_kernel.cu new file mode 100644 index 0000000..8339bb7 --- /dev/null +++ b/utils/pointops2/src/aggregation/aggregation_cuda_kernel.cu @@ -0,0 +1,53 @@ +#include "../cuda_utils.h" +#include "aggregation_cuda_kernel.h" + + +__global__ void aggregation_forward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { + // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + const int c_idx = index % c; + const int n_idx = index / c; + const int w_c_idx = c_idx % w_c; + for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) + { + int idx_idx = n_idx * nsample + nsample_idx; + int input_idx = idx[idx_idx] * c + c_idx; + int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; + int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; + output[index] += (input[input_idx] + position[position_idx]) * weight[weight_idx]; + } +} + +__global__ void aggregation_backward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { + // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + const int c_idx = index % c; + const int n_idx = index / c; + const int w_c_idx = c_idx % w_c; + for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) + { + int idx_idx = n_idx * nsample + nsample_idx; + int input_idx = idx[idx_idx] * c + c_idx; + int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; + int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; + atomicAdd(grad_input + input_idx, grad_output[index] * weight[weight_idx]); + grad_position[position_idx] = grad_output[index] * weight[weight_idx]; + atomicAdd(grad_weight + weight_idx, grad_output[index] * (input[input_idx] + position[position_idx])); + } +} + +void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { + // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + aggregation_forward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, output); +} + +void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { + // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + aggregation_backward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); +} diff --git a/utils/pointops2/src/aggregation/aggregation_cuda_kernel.h b/utils/pointops2/src/aggregation/aggregation_cuda_kernel.h new file mode 100644 index 0000000..5211a96 --- /dev/null +++ b/utils/pointops2/src/aggregation/aggregation_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef _AGGREGATION_CUDA_KERNEL +#define _AGGREGATION_CUDA_KERNEL +#include +#include +#include + +void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); +void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output); +void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/attention/attention_cuda.cpp b/utils/pointops2/src/attention/attention_cuda.cpp new file mode 100644 index 0000000..8d2c725 --- /dev/null +++ b/utils/pointops2/src/attention/attention_cuda.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include +#include "attention_cuda_kernel.h" + +void attention_step1_forward_cuda(int N, int M, int h, int C, at::Tensor q_tensor, at::Tensor k_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor) +{ + const float *q = q_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + float *attn = attn_tensor.data_ptr(); + attention_step1_forward_cuda_launcher(N, M, h, C, q, k, index0, index1, attn); +} + +void attention_step1_backward_cuda(int N, int M, int h, int C, at::Tensor grad_out_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor q_tensor, at::Tensor k_tensor, + at::Tensor grad_q_tensor, at::Tensor grad_k_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *q = q_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + float *grad_q = grad_q_tensor.data_ptr(); + float *grad_k = grad_k_tensor.data_ptr(); + attention_step1_backward_cuda_launcher(N, M, h, C, grad_out, index0, index1, q, k, grad_q, grad_k); +} + +void attention_step2_forward_cuda(int N, int M, int h, int C, at::Tensor attn_tensor, at::Tensor v_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor output_tensor) +{ + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + attention_step2_forward_cuda_launcher(N, M, h, C, attn, v, index0, index1, output); +} + + +void attention_step2_backward_cuda(int N, int M, int h, int C, at::Tensor grad_out_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, + at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + float *grad_attn = grad_attn_tensor.data_ptr(); + float *grad_v = grad_v_tensor.data_ptr(); + attention_step2_backward_cuda_launcher(N, M, h, C, grad_out, index0, index1, attn, v, grad_attn, grad_v); +} diff --git a/utils/pointops2/src/attention/attention_cuda_kernel.cu b/utils/pointops2/src/attention/attention_cuda_kernel.cu new file mode 100644 index 0000000..f71ad62 --- /dev/null +++ b/utils/pointops2/src/attention/attention_cuda_kernel.cu @@ -0,0 +1,103 @@ +#include "../cuda_utils.h" +#include "attention_cuda_kernel.h" + + +__global__ void attention_step1_forward_cuda_kernel( // M, h, C//h + int N, int M, int h, int C, const float *q, const float *k, + const int *index0, const int *index1, float *attn) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int m_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + int idx0 = index0[m_idx]; + int idx1 = index1[m_idx]; + float val = q[idx0*C+h_idx*C/h+c_idx] * k[idx1*C+h_idx*C/h+c_idx]; + atomicAdd(attn+m_idx*h+h_idx, val); +} + +__global__ void attention_step1_backward_cuda_kernel( // M, h, C//h + int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, const float *q, const float *k, + float *grad_q, float *grad_k) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int m_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + int idx0 = index0[m_idx]; + int idx1 = index1[m_idx]; + int grad_out_idx = m_idx*h+h_idx; + int q_idx = idx0*C+h_idx*C/h+c_idx; + int k_idx = idx1*C+h_idx*C/h+c_idx; + atomicAdd(grad_q+q_idx, grad_out[grad_out_idx] * k[k_idx]); + atomicAdd(grad_k+k_idx, grad_out[grad_out_idx] * q[q_idx]); +} + +void attention_step1_forward_cuda_launcher(int N, int M, int h, int C, const float *q, const float *k, + const int *index0, const int *index1, float *attn) { + // input: attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, ) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + dim3 threads(THREADS_PER_BLOCK); + attention_step1_forward_cuda_kernel<<>>(N, M, h, C, q, k, index0, index1, attn); +} + +void attention_step1_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, + const float *q, const float *k, float *grad_q, float *grad_k) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + dim3 threads(THREADS_PER_BLOCK); + attention_step1_backward_cuda_kernel<<>>(N, M, h, C, grad_out, index0, index1, q, k, grad_q, grad_k); +} + +__global__ void attention_step2_forward_cuda_kernel( // M, h, C//h + int N, int M, int h, int C, const float *attn, const float *v, + const int *index0, const int *index1, float *output) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int m_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + int idx1 = index1[m_idx]; + float val = attn[m_idx*h+h_idx] * v[idx1*C+h_idx*C/h+c_idx]; + int idx0 = index0[m_idx]; + atomicAdd(output+idx0*C+h_idx*C/h+c_idx, val); +} + +__global__ void attention_step2_backward_cuda_kernel( // M, h, C//h + int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, const float *attn, const float *v, + float *grad_attn, float *grad_v) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int m_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + int idx0 = index0[m_idx]; + int idx1 = index1[m_idx]; + int grad_out_idx = idx0*C+h_idx*C/h+c_idx; + atomicAdd(grad_attn+m_idx*h+h_idx, grad_out[grad_out_idx] * v[idx1*C+h_idx*C/h+c_idx]); + atomicAdd(grad_v+idx1*C+h_idx*C/h+c_idx, grad_out[grad_out_idx] * attn[m_idx*h+h_idx]); +} + +void attention_step2_forward_cuda_launcher(int N, int M, int h, int C, const float *attn, const float *v, + const int *index0, const int *index1, float *output) { + // input: attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, ) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + dim3 threads(THREADS_PER_BLOCK); + attention_step2_forward_cuda_kernel<<>>(N, M, h, C, attn, v, index0, index1, output); +} + +void attention_step2_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, + const float *attn, const float *v, float *grad_attn, float *grad_v) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + dim3 threads(THREADS_PER_BLOCK); + attention_step2_backward_cuda_kernel<<>>(N, M, h, C, grad_out, index0, index1, attn, v, grad_attn, grad_v); +} diff --git a/utils/pointops2/src/attention/attention_cuda_kernel.h b/utils/pointops2/src/attention/attention_cuda_kernel.h new file mode 100644 index 0000000..cbd99b9 --- /dev/null +++ b/utils/pointops2/src/attention/attention_cuda_kernel.h @@ -0,0 +1,26 @@ +#ifndef _ATTENTION_CUDA_KERNEL +#define _ATTENTION_CUDA_KERNEL +#include +#include +#include + +void attention_step1_forward_cuda(int N, int M, int h, int C, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor); +void attention_step1_backward_cuda(int N, int M, int h, int C, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor); + +void attention_step2_forward_cuda(int N, int M, int h, int C, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor output_tensor); +void attention_step2_backward_cuda(int N, int M, int h, int C, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void attention_step1_forward_cuda_launcher(int N, int M, int h, int C, const float *q, const float *k, const int *index0, const int *index1, float *attn); +void attention_step1_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, const float *q, const float *k, float *grad_q, float *grad_k); + +void attention_step2_forward_cuda_launcher(int N, int M, int h, int C, const float *attn, const float *v, const int *index0, const int *index1, float *output); +void attention_step2_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, const float *attn, const float *v, float *grad_attn, float *grad_v); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/attention_v2/attention_cuda_kernel_v2.cu b/utils/pointops2/src/attention_v2/attention_cuda_kernel_v2.cu new file mode 100644 index 0000000..2e5343f --- /dev/null +++ b/utils/pointops2/src/attention_v2/attention_cuda_kernel_v2.cu @@ -0,0 +1,193 @@ +#include "../cuda_utils.h" +#include "attention_cuda_kernel_v2.h" + + +template +__global__ void attention_step1_forward_cuda_kernel_v2( // M, h, C//h + int N, int M, int h, const float *q, const float *k, + const int *index0_offsets, const int *index1, float *attn) { + + int h_idx = blockIdx.y; + int q_idx = blockIdx.x; + int n_idx = threadIdx.x; + int C = h * d; + // if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + __shared__ float query_vec[d]; + __shared__ int start, end; + + // if(n_idx == 0){ + // printf("blockDim.x: %d\n", blockDim.x); + // } + + if (n_idx == 0){ + start = index0_offsets[q_idx]; + end = index0_offsets[q_idx+1]; + // printf("start: %d, end: %d, blockDim.x: %d\n", start, end, blockDim.x); + } + for(int i = n_idx; i < d; i += blockDim.x) + query_vec[i] = q[q_idx*C + h_idx*d + i]; + + __syncthreads(); + + int m_idx = start + n_idx; + if(m_idx >= end) + return; + + float sum = 0; + for(int i = 0; i < d; i++){ + int k_idx = index1[m_idx]; + float key = k[k_idx * C + h_idx * d + i]; + sum += query_vec[i] * key; + } + attn[m_idx*h + h_idx] = sum; + // int idx0 = index0[m_idx]; + // int idx1 = index1[m_idx]; + // float val = q[idx0*C+h_idx*C/h+c_idx] * k[idx1*C+h_idx*C/h+c_idx]; + // atomicAdd(attn+m_idx*h+h_idx, val); +} + +template +__global__ void attention_step1_backward_cuda_kernel_v2( // M, h, C//h + int N, int M, int h, const float *grad_out, const int *index0_offsets, const int *index1, const float *q, const float *k, + float *grad_q, float *grad_k) { + + int h_idx = blockIdx.y; + int q_idx = blockIdx.x; + int n_idx = threadIdx.x; + int C = d * h; + + __shared__ float query_vec[d]; + __shared__ int start, end; + + if (n_idx == 0){ + start = index0_offsets[q_idx]; + end = index0_offsets[q_idx+1]; + } + for(int i = n_idx; i < d; i += blockDim.x) + query_vec[i] = q[q_idx*C + h_idx*d + i]; + + __shared__ float gradient_new[d]; + for(int i = n_idx; i < d; i += blockDim.x) + gradient_new[i] = 0; + + __syncthreads(); + + int m_idx = start + n_idx; + if(m_idx < end){ + float gradient = grad_out[m_idx*h + h_idx]; + for(int i = 0; i < d; i++){ + int k_idx = index1[m_idx]; + atomicAdd(&gradient_new[i], gradient * k[k_idx*C + h_idx*d + i]); + atomicAdd(grad_k + k_idx*C + h_idx*d + i, gradient * query_vec[i]); + } + } + __syncthreads(); + + for(int i = n_idx; i < d; i += blockDim.x) + grad_q[q_idx*C + h_idx*d + i] = gradient_new[i]; +} + +void attention_step1_forward_cuda_launcher_v2(int N, int M, int h, int C, const unsigned int n_max, + const float *q, const float *k, const int *index0_offsets, const int *index1, float *attn) { + // input: attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, ) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(N, h); + unsigned int n_threads = opt_n_threads(n_max); + + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + // n_threads = n_threads > 1024 ? 512 : n_threads; + + // printf("n_max: %d, n_threads: %d\n", n_max, n_threads); + + // dim3 threads(THREADS_PER_BLOCK); + // attention_step1_forward_cuda_kernel_v2<<>>(N, M, h, C, q, k, index0, index1, attn); + + switch (C / h) { + case 16: + attention_step1_forward_cuda_kernel_v2<16><<>>(N, M, h, q, k, index0_offsets, index1, attn); + break; + case 32: + attention_step1_forward_cuda_kernel_v2<32><<>>(N, M, h, q, k, index0_offsets, index1, attn); + break; + default: + throw "d != 16 and d != 32"; + } +} + +void attention_step1_backward_cuda_launcher_v2(int N, int M, int h, int C, const unsigned int n_max, + const float *grad_out, const int *index0_offsets, const int *index1, const float *q, const float *k, float *grad_q, float *grad_k) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + // dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + // dim3 threads(THREADS_PER_BLOCK); + dim3 blocks(N, h); + unsigned int n_threads = opt_n_threads(n_max); + // attention_step1_backward_cuda_kernel_v2<<>>(N, M, h, C/h, grad_out, index0_offsets, index1, q, k, grad_q, grad_k); + + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + // n_threads = n_threads > 1024 ? 512 : n_threads; + + // printf("n_max: %d, n_threads: %d\n", n_max, n_threads); + + switch (C / h) { + case 16: + attention_step1_backward_cuda_kernel_v2<16><<>>(N, M, h, grad_out, index0_offsets, index1, q, k, grad_q, grad_k); + break; + case 32: + attention_step1_backward_cuda_kernel_v2<32><<>>(N, M, h, grad_out, index0_offsets, index1, q, k, grad_q, grad_k); + break; + default: + throw "d != 16 and d != 32"; + } + +} + +__global__ void attention_step2_forward_cuda_kernel_v2( // M, h, C//h + int N, int M, int h, int C, const float *attn, const float *v, + const int *index0, const int *index1, float *output) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int m_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + int idx1 = index1[m_idx]; + float val = attn[m_idx*h+h_idx] * v[idx1*C+h_idx*C/h+c_idx]; + int idx0 = index0[m_idx]; + atomicAdd(output+idx0*C+h_idx*C/h+c_idx, val); +} + +__global__ void attention_step2_backward_cuda_kernel_v2( // M, h, C//h + int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, const float *attn, const float *v, + float *grad_attn, float *grad_v) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int m_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (m_idx >= M || h_idx >= h || c_idx >= C / h) return; + + int idx0 = index0[m_idx]; + int idx1 = index1[m_idx]; + int grad_out_idx = idx0*C+h_idx*C/h+c_idx; + atomicAdd(grad_attn+m_idx*h+h_idx, grad_out[grad_out_idx] * v[idx1*C+h_idx*C/h+c_idx]); + atomicAdd(grad_v+idx1*C+h_idx*C/h+c_idx, grad_out[grad_out_idx] * attn[m_idx*h+h_idx]); +} + +void attention_step2_forward_cuda_launcher_v2(int N, int M, int h, int C, const float *attn, const float *v, + const int *index0, const int *index1, float *output) { + // input: attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, ) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + dim3 threads(THREADS_PER_BLOCK); + attention_step2_forward_cuda_kernel_v2<<>>(N, M, h, C, attn, v, index0, index1, output); +} + +void attention_step2_backward_cuda_launcher_v2(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, + const float *attn, const float *v, float *grad_attn, float *grad_v) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + //dim3 blocks(DIVUP(C/h, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h); + dim3 threads(THREADS_PER_BLOCK); + attention_step2_backward_cuda_kernel_v2<<>>(N, M, h, C, grad_out, index0, index1, attn, v, grad_attn, grad_v); +} diff --git a/utils/pointops2/src/attention_v2/attention_cuda_kernel_v2.h b/utils/pointops2/src/attention_v2/attention_cuda_kernel_v2.h new file mode 100644 index 0000000..d7e7f04 --- /dev/null +++ b/utils/pointops2/src/attention_v2/attention_cuda_kernel_v2.h @@ -0,0 +1,26 @@ +#ifndef _ATTENTION_V2_CUDA_KERNEL +#define _ATTENTION_V2_CUDA_KERNEL +#include +#include +#include + +void attention_step1_forward_cuda_v2(int N, int M, int h, int C, const unsigned int n_max, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor attn_tensor); +void attention_step1_backward_cuda_v2(int N, int M, int h, int C, const unsigned int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor); + +void attention_step2_forward_cuda_v2(int N, int M, int h, int C, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor output_tensor); +void attention_step2_backward_cuda_v2(int N, int M, int h, int C, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void attention_step1_forward_cuda_launcher_v2(int N, int M, int h, int C, const unsigned int n_max, const float *q, const float *k, const int *index0_offsets, const int *index1, float *attn); +void attention_step1_backward_cuda_launcher_v2(int N, int M, int h, int C, const unsigned int n_max, const float *grad_out, const int *index0_offsets, const int *index1, const float *q, const float *k, float *grad_q, float *grad_k); + +void attention_step2_forward_cuda_launcher_v2(int N, int M, int h, int C, const float *attn, const float *v, const int *index0, const int *index1, float *output); +void attention_step2_backward_cuda_launcher_v2(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1, const float *attn, const float *v, float *grad_attn, float *grad_v); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/attention_v2/attention_cuda_v2.cpp b/utils/pointops2/src/attention_v2/attention_cuda_v2.cpp new file mode 100644 index 0000000..311adaf --- /dev/null +++ b/utils/pointops2/src/attention_v2/attention_cuda_v2.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include +#include "attention_cuda_kernel_v2.h" + +void attention_step1_forward_cuda_v2(int N, int M, int h, int C, const unsigned int n_max, at::Tensor q_tensor, at::Tensor k_tensor, + at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor attn_tensor) +{ + const float *q = q_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + const int *index0_offsets = index0_tensor_offsets.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + float *attn = attn_tensor.data_ptr(); + attention_step1_forward_cuda_launcher_v2(N, M, h, C, n_max, q, k, index0_offsets, index1, attn); +} + +void attention_step1_backward_cuda_v2(int N, int M, int h, int C, const unsigned int n_max, at::Tensor grad_out_tensor, + at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor q_tensor, at::Tensor k_tensor, + at::Tensor grad_q_tensor, at::Tensor grad_k_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const int *index0_offsets = index0_tensor_offsets.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *q = q_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + float *grad_q = grad_q_tensor.data_ptr(); + float *grad_k = grad_k_tensor.data_ptr(); + attention_step1_backward_cuda_launcher_v2(N, M, h, C, n_max, grad_out, index0_offsets, index1, q, k, grad_q, grad_k); +} + +void attention_step2_forward_cuda_v2(int N, int M, int h, int C, at::Tensor attn_tensor, at::Tensor v_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor output_tensor) +{ + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + attention_step2_forward_cuda_launcher_v2(N, M, h, C, attn, v, index0, index1, output); +} + + +void attention_step2_backward_cuda_v2(int N, int M, int h, int C, at::Tensor grad_out_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, + at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + float *grad_attn = grad_attn_tensor.data_ptr(); + float *grad_v = grad_v_tensor.data_ptr(); + attention_step2_backward_cuda_launcher_v2(N, M, h, C, grad_out, index0, index1, attn, v, grad_attn, grad_v); +} diff --git a/utils/pointops2/src/cuda_utils.h b/utils/pointops2/src/cuda_utils.h new file mode 100644 index 0000000..e67749c --- /dev/null +++ b/utils/pointops2/src/cuda_utils.h @@ -0,0 +1,23 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include + +#define TOTAL_THREADS 1024 +#define THREADS_PER_BLOCK 256 +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + return block_config; +} + +#endif diff --git a/utils/pointops2/src/grouping/grouping_cuda.cpp b/utils/pointops2/src/grouping/grouping_cuda.cpp new file mode 100644 index 0000000..a00d313 --- /dev/null +++ b/utils/pointops2/src/grouping/grouping_cuda.cpp @@ -0,0 +1,22 @@ +#include +#include +#include +#include +#include "grouping_cuda_kernel.h" + + +void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) +{ + const float *input = input_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + grouping_forward_cuda_launcher(m, nsample, c, input, idx, output); +} + +void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor) +{ + const float *grad_output = grad_output_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *grad_input = grad_input_tensor.data_ptr(); + grouping_backward_cuda_launcher(m, nsample, c, grad_output, idx, grad_input); +} diff --git a/utils/pointops2/src/grouping/grouping_cuda_kernel.cu b/utils/pointops2/src/grouping/grouping_cuda_kernel.cu new file mode 100644 index 0000000..58ec0a2 --- /dev/null +++ b/utils/pointops2/src/grouping/grouping_cuda_kernel.cu @@ -0,0 +1,40 @@ +#include "../cuda_utils.h" +#include "grouping_cuda_kernel.h" + + +__global__ void grouping_forward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ input, const int *__restrict__ idx, float *__restrict__ output) { + // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= m * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index / c) % nsample; + const int m_idx = index / nsample / c; + const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; + output[index] = input[input_idx]; +} + +__global__ void grouping_backward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ grad_output, const int *__restrict__ idx, float *__restrict__ grad_input) { + // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= m * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index / c) % nsample; + const int m_idx = index / nsample / c; + const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; + atomicAdd(grad_input + input_idx, grad_output[index]); +} + +void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output) { + // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) + dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + grouping_forward_cuda_kernel<<>>(m, nsample, c, input, idx, output); +} + +void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input) +{ + // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) + dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + grouping_backward_cuda_kernel<<>>(m, nsample, c, grad_output, idx, grad_input); +} diff --git a/utils/pointops2/src/grouping/grouping_cuda_kernel.h b/utils/pointops2/src/grouping/grouping_cuda_kernel.h new file mode 100644 index 0000000..3db4aaa --- /dev/null +++ b/utils/pointops2/src/grouping/grouping_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef _GROUPING_CUDA_KERNEL +#define _GROUPING_CUDA_KERNEL +#include +#include +#include + +void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); +void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output); +void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/interpolation/interpolation_cuda.cpp b/utils/pointops2/src/interpolation/interpolation_cuda.cpp new file mode 100644 index 0000000..a73c02b --- /dev/null +++ b/utils/pointops2/src/interpolation/interpolation_cuda.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include +#include "interpolation_cuda_kernel.h" + + +void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor) +{ + const float *input = input_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + interpolation_forward_cuda_launcher(n, c, k, input, idx, weight, output); +} + +void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor) +{ + const float *grad_output = grad_output_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + float *grad_input = grad_input_tensor.data_ptr(); + interpolation_backward_cuda_launcher(n, c, k, grad_output, idx, weight, grad_input); +} diff --git a/utils/pointops2/src/interpolation/interpolation_cuda_kernel.cu b/utils/pointops2/src/interpolation/interpolation_cuda_kernel.cu new file mode 100644 index 0000000..f560d8c --- /dev/null +++ b/utils/pointops2/src/interpolation/interpolation_cuda_kernel.cu @@ -0,0 +1,47 @@ +#include "../cuda_utils.h" +#include "interpolation_cuda_kernel.h" + + +__global__ void interpolation_forward_cuda_kernel(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) +{ + // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + int c_idx = index % c; + int n_idx = index / c; + for (int i = 0; i < k; i++) + { + int idx_idx = n_idx * k + i; + int input_idx = idx[idx_idx] * c + c_idx; + output[index] += input[input_idx] * weight[idx_idx]; + } +} + +__global__ void interpolation_backward_cuda_kernel(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) +{ + // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + int c_idx = index % c; + int n_idx = index / c; + for (int i = 0; i < k; i++) + { + int idx_idx = n_idx * k + i; + int input_idx = idx[idx_idx] * c + c_idx; + atomicAdd(grad_input + input_idx, grad_output[index] * weight[idx_idx]); + } +} + +void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) { + // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + interpolation_forward_cuda_kernel<<>>(n, c, k, input, idx, weight, output); +} + +void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) { + // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + interpolation_backward_cuda_kernel<<>>(n, c, k, grad_output, idx, weight, grad_input); +} diff --git a/utils/pointops2/src/interpolation/interpolation_cuda_kernel.h b/utils/pointops2/src/interpolation/interpolation_cuda_kernel.h new file mode 100644 index 0000000..309e5dd --- /dev/null +++ b/utils/pointops2/src/interpolation/interpolation_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef _INTERPOLATION_CUDA_KERNEL +#define _INTERPOLATION_CUDA_KERNEL +#include +#include +#include + +void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor); +void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output); +void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/knnquery/knnquery_cuda.cpp b/utils/pointops2/src/knnquery/knnquery_cuda.cpp new file mode 100644 index 0000000..568f136 --- /dev/null +++ b/utils/pointops2/src/knnquery/knnquery_cuda.cpp @@ -0,0 +1,17 @@ +#include +#include +#include +#include +#include "knnquery_cuda_kernel.h" + + +void knnquery_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor) +{ + const float *xyz = xyz_tensor.data_ptr(); + const float *new_xyz = new_xyz_tensor.data_ptr(); + const int *offset = offset_tensor.data_ptr(); + const int *new_offset = new_offset_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + float *dist2 = dist2_tensor.data_ptr(); + knnquery_cuda_launcher(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); +} diff --git a/utils/pointops2/src/knnquery/knnquery_cuda_kernel.cu b/utils/pointops2/src/knnquery/knnquery_cuda_kernel.cu new file mode 100644 index 0000000..83762bc --- /dev/null +++ b/utils/pointops2/src/knnquery/knnquery_cuda_kernel.cu @@ -0,0 +1,116 @@ +#include "../cuda_utils.h" +#include "knnquery_cuda_kernel.h" + + +__device__ void swap_float(float *x, float *y) +{ + float tmp = *x; + *x = *y; + *y = tmp; +} + + +__device__ void swap_int(int *x, int *y) +{ + int tmp = *x; + *x = *y; + *y = tmp; +} + + +__device__ void reheap(float *dist, int *idx, int k) +{ + int root = 0; + int child = root * 2 + 1; + while (child < k) + { + if(child + 1 < k && dist[child+1] > dist[child]) + child++; + if(dist[root] > dist[child]) + return; + swap_float(&dist[root], &dist[child]); + swap_int(&idx[root], &idx[child]); + root = child; + child = root * 2 + 1; + } +} + + +__device__ void heap_sort(float *dist, int *idx, int k) +{ + int i; + for (i = k - 1; i > 0; i--) + { + swap_float(&dist[0], &dist[i]); + swap_int(&idx[0], &idx[i]); + reheap(dist, idx, i); + } +} + + +__device__ int get_bt_idx(int idx, const int *offset) +{ + int i = 0; + while (1) + { + if (idx < offset[i]) + break; + else + i++; + } + return i; +} + + +__global__ void knnquery_cuda_kernel(int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, const int *__restrict__ offset, const int *__restrict__ new_offset, int *__restrict__ idx, float *__restrict__ dist2) { + // input: xyz (n, 3) new_xyz (m, 3) + // output: idx (m, nsample) dist2 (m, nsample) + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= m) return; + + new_xyz += pt_idx * 3; + idx += pt_idx * nsample; + dist2 += pt_idx * nsample; + int bt_idx = get_bt_idx(pt_idx, new_offset); + int start; + if (bt_idx == 0) + start = 0; + else + start = offset[bt_idx - 1]; + int end = offset[bt_idx]; + + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + float best_dist[100]; + int best_idx[100]; + for(int i = 0; i < nsample; i++){ + best_dist[i] = 1e10; + best_idx[i] = start; + } + for(int i = start; i < end; i++){ + float x = xyz[i * 3 + 0]; + float y = xyz[i * 3 + 1]; + float z = xyz[i * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + if (d2 < best_dist[0]){ + best_dist[0] = d2; + best_idx[0] = i; + reheap(best_dist, best_idx, nsample); + } + } + heap_sort(best_dist, best_idx, nsample); + for(int i = 0; i < nsample; i++){ + idx[i] = best_idx[i]; + dist2[i] = best_dist[i]; + } +} + + +void knnquery_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2) { + // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + knnquery_cuda_kernel<<>>(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); +} diff --git a/utils/pointops2/src/knnquery/knnquery_cuda_kernel.h b/utils/pointops2/src/knnquery/knnquery_cuda_kernel.h new file mode 100644 index 0000000..3c0aedf --- /dev/null +++ b/utils/pointops2/src/knnquery/knnquery_cuda_kernel.h @@ -0,0 +1,18 @@ +#ifndef _KNNQUERY_CUDA_KERNEL +#define _KNNQUERY_CUDA_KERNEL +#include +#include +#include + +void knnquery_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void knnquery_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/pointops_api.cpp b/utils/pointops2/src/pointops_api.cpp new file mode 100644 index 0000000..812789f --- /dev/null +++ b/utils/pointops2/src/pointops_api.cpp @@ -0,0 +1,45 @@ +#include +#include + +#include "knnquery/knnquery_cuda_kernel.h" +#include "sampling/sampling_cuda_kernel.h" +#include "grouping/grouping_cuda_kernel.h" +#include "interpolation/interpolation_cuda_kernel.h" +#include "aggregation/aggregation_cuda_kernel.h" +#include "subtraction/subtraction_cuda_kernel.h" +#include "attention/attention_cuda_kernel.h" +#include "rpe/relative_pos_encoding_cuda_kernel.h" +#include "attention_v2/attention_cuda_kernel_v2.h" +#include "rpe_v2/relative_pos_encoding_cuda_kernel_v2.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("knnquery_cuda", &knnquery_cuda, "knnquery_cuda"); + m.def("furthestsampling_cuda", &furthestsampling_cuda, "furthestsampling_cuda"); + m.def("grouping_forward_cuda", &grouping_forward_cuda, "grouping_forward_cuda"); + m.def("grouping_backward_cuda", &grouping_backward_cuda, "grouping_backward_cuda"); + m.def("interpolation_forward_cuda", &interpolation_forward_cuda, "interpolation_forward_cuda"); + m.def("interpolation_backward_cuda", &interpolation_backward_cuda, "interpolation_backward_cuda"); + m.def("subtraction_forward_cuda", &subtraction_forward_cuda, "subtraction_forward_cuda"); + m.def("subtraction_backward_cuda", &subtraction_backward_cuda, "subtraction_backward_cuda"); + m.def("aggregation_forward_cuda", &aggregation_forward_cuda, "aggregation_forward_cuda"); + m.def("aggregation_backward_cuda", &aggregation_backward_cuda, "aggregation_backward_cuda"); + m.def("attention_step1_forward_cuda", &attention_step1_forward_cuda, "attention_step1_forward_cuda"); + m.def("attention_step1_backward_cuda", &attention_step1_backward_cuda, "attention_step1_backward_cuda"); + m.def("attention_step2_forward_cuda", &attention_step2_forward_cuda, "attention_step2_forward_cuda"); + m.def("attention_step2_backward_cuda", &attention_step2_backward_cuda, "attention_step2_backward_cuda"); + m.def("dot_prod_with_idx_forward_cuda", &dot_prod_with_idx_forward_cuda, "dot_prod_with_idx_forward_cuda"); + m.def("dot_prod_with_idx_backward_cuda", &dot_prod_with_idx_backward_cuda, "dot_prod_with_idx_backward_cuda"); + m.def("attention_step2_with_rel_pos_value_forward_cuda", &attention_step2_with_rel_pos_value_forward_cuda, "attention_step2_with_rel_pos_value_forward_cuda"); + m.def("attention_step2_with_rel_pos_value_backward_cuda", &attention_step2_with_rel_pos_value_backward_cuda, "attention_step2_with_rel_pos_value_backward_cuda"); + m.def("attention_step1_forward_cuda_v2", &attention_step1_forward_cuda_v2, "attention_step1_forward_cuda_v2"); + m.def("attention_step1_backward_cuda_v2", &attention_step1_backward_cuda_v2, "attention_step1_backward_cuda_v2"); + m.def("attention_step2_forward_cuda_v2", &attention_step2_forward_cuda_v2, "attention_step2_forward_cuda_v2"); + m.def("attention_step2_backward_cuda_v2", &attention_step2_backward_cuda_v2, "attention_step2_backward_cuda_v2"); + m.def("dot_prod_with_idx_forward_cuda_v2", &dot_prod_with_idx_forward_cuda_v2, "dot_prod_with_idx_forward_cuda_v2"); + m.def("dot_prod_with_idx_backward_cuda_v2", &dot_prod_with_idx_backward_cuda_v2, "dot_prod_with_idx_backward_cuda_v2"); + m.def("attention_step2_with_rel_pos_value_forward_cuda_v2", &attention_step2_with_rel_pos_value_forward_cuda_v2, "attention_step2_with_rel_pos_value_forward_cuda_v2"); + m.def("attention_step2_with_rel_pos_value_backward_cuda_v2", &attention_step2_with_rel_pos_value_backward_cuda_v2, "attention_step2_with_rel_pos_value_backward_cuda_v2"); + m.def("dot_prod_with_idx_forward_cuda_v3", &dot_prod_with_idx_forward_cuda_v3, "dot_prod_with_idx_forward_cuda_v3"); + m.def("dot_prod_with_idx_backward_cuda_v3", &dot_prod_with_idx_backward_cuda_v3, "dot_prod_with_idx_backward_cuda_v3"); + } diff --git a/utils/pointops2/src/rpe/relative_pos_encoding_cuda.cpp b/utils/pointops2/src/rpe/relative_pos_encoding_cuda.cpp new file mode 100644 index 0000000..634ebb0 --- /dev/null +++ b/utils/pointops2/src/rpe/relative_pos_encoding_cuda.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include +#include "relative_pos_encoding_cuda_kernel.h" + +void dot_prod_with_idx_forward_cuda(int N, int M, int h, int hdim, at::Tensor q_tensor, at::Tensor index_tensor, + at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) +{ + const float *q = q_tensor.data_ptr(); + const float *table = table_tensor.data_ptr(); + const int *index = index_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + dot_prod_with_idx_forward_cuda_launcher(N, M, h, hdim, q, index, table, rel_idx, output); +} + +void dot_prod_with_idx_backward_cuda(int N, int M, int h, int hdim, at::Tensor grad_out_tensor, + at::Tensor q_tensor, at::Tensor index_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, + at::Tensor grad_q_tensor, at::Tensor grad_table_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const float *q = q_tensor.data_ptr(); + const int *index = index_tensor.data_ptr(); + const float *table = table_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *grad_q = grad_q_tensor.data_ptr(); + float *grad_table = grad_table_tensor.data_ptr(); + dot_prod_with_idx_backward_cuda_launcher(N, M, h, hdim, grad_out, q, index, table, rel_idx, grad_q, grad_table); +} + +void attention_step2_with_rel_pos_value_forward_cuda(int N, int M, int h, int hdim, at::Tensor attn_tensor, at::Tensor v_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) +{ + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *table = table_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + attention_step2_with_rel_pos_value_forward_cuda_launcher(N, M, h, hdim, attn, v, index0, index1, table, rel_idx, output); +} + +void attention_step2_with_rel_pos_value_backward_cuda(int N, int M, int h, int hdim, at::Tensor grad_out_tensor, + at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, + at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const int *index0 = index0_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + const float *table = table_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *grad_attn = grad_attn_tensor.data_ptr(); + float *grad_v = grad_v_tensor.data_ptr(); + float *grad_table = grad_table_tensor.data_ptr(); + attention_step2_with_rel_pos_value_backward_cuda_launcher(N, M, h, hdim, grad_out, index0, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table); +} diff --git a/utils/pointops2/src/rpe/relative_pos_encoding_cuda_kernel.cu b/utils/pointops2/src/rpe/relative_pos_encoding_cuda_kernel.cu new file mode 100644 index 0000000..b8fd8f4 --- /dev/null +++ b/utils/pointops2/src/rpe/relative_pos_encoding_cuda_kernel.cu @@ -0,0 +1,134 @@ +#include "../cuda_utils.h" +#include "relative_pos_encoding_cuda_kernel.h" + + +__global__ void dot_prod_with_idx_forward_cuda_kernel( // M, h, hdim + int N, int M, int h, int hdim, const float *q, const int *index, + const float *table, const int *rel_idx, float *output) { + // input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3), output: (M, h) + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_idx >= M*3 || h_idx >= h || c_idx >= hdim) return; + + int dim = thread_idx % 3; + int m_idx = thread_idx / 3; + + int q_idx = index[m_idx]; + int rel_idx_dim = rel_idx[thread_idx]; + float rel_table_val = table[rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim]; + float val = q[q_idx*h*hdim+h_idx*hdim+c_idx] * rel_table_val; + atomicAdd(output+m_idx*h+h_idx, val); +} + +__global__ void dot_prod_with_idx_backward_cuda_kernel( // M, h, hdim + int N, int M, int h, int hdim, const float *grad_out, const float *q, const int *index, + const float *table, const int *rel_idx, float *grad_q, float *grad_table) { + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_idx >= M*3 || h_idx >= h || c_idx >= hdim) return; + + int dim = thread_idx % 3; + int m_idx = thread_idx / 3; + + int q_idx = index[m_idx]; + int rel_idx_dim = rel_idx[thread_idx]; + int grad_out_idx = m_idx*h+h_idx; + float grad_out_value = grad_out[grad_out_idx]; + + float rel_table_val = table[rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim]; + atomicAdd(grad_q+q_idx*h*hdim+h_idx*hdim+c_idx, grad_out_value * rel_table_val); + + float q_value = q[q_idx*h*hdim+h_idx*hdim+c_idx]; + atomicAdd(grad_table+rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim, grad_out_value * q_value); +} + +void dot_prod_with_idx_forward_cuda_launcher(int N, int M, int h, int hdim, const float *q, const int *index, + const float *table, const int *rel_idx, float *output) { + // input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + dim3 threads(THREADS_PER_BLOCK); + dot_prod_with_idx_forward_cuda_kernel<<>>(N, M, h, hdim, q, index, table, rel_idx, output); +} + +void dot_prod_with_idx_backward_cuda_launcher(int N, int M, int h, int hdim, const float *grad_out, + const float *q, const int *index, const float *table, const int *rel_idx, float *grad_q, float *grad_table) { + // input: grad_out: (M, h), output: grad_q: (N, h, hdim), grad_table: (L, h, hdim, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + dim3 threads(THREADS_PER_BLOCK); + dot_prod_with_idx_backward_cuda_kernel<<>>(N, M, h, hdim, grad_out, q, index, table, rel_idx, grad_q, grad_table); +} + +__global__ void attention_step2_with_rel_pos_value_forward_cuda_kernel( // M, h, hdim + int N, int M, int h, int hdim, const float *attn, const float *v, + const int *index0, const int *index1, const float *table, const int *rel_idx, float *output) { + // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_idx >= M*3 || h_idx >= h || c_idx >= hdim) return; + + int dim = thread_idx % 3; + int m_idx = thread_idx / 3; + + int idx1 = index1[m_idx]; + + int rel_idx_dim = rel_idx[thread_idx]; + float table_val = table[rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim]; + + float val = attn[m_idx*h+h_idx] * (v[idx1*h*hdim+h_idx*hdim+c_idx] / 3.0 + table_val); + + int idx0 = index0[m_idx]; + atomicAdd(output+idx0*h*hdim+h_idx*hdim+c_idx, val); +} + + +__global__ void attention_step2_with_rel_pos_value_backward_cuda_kernel( // M, h, hdim + int N, int M, int h, int hdim, const float *grad_out, const int *index0, const int *index1, const float *attn, const float *v, const float *table, + const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table) { + // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) + + int c_idx = blockIdx.z; + int h_idx = blockIdx.y; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_idx >= M*3 || h_idx >= h || c_idx >= hdim) return; + + int dim = thread_idx % 3; + int m_idx = thread_idx / 3; + + int idx0 = index0[m_idx]; + int idx1 = index1[m_idx]; + int grad_out_idx = idx0*h*hdim+h_idx*hdim+c_idx; + + int rel_idx_dim = rel_idx[thread_idx]; + float table_val = table[rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim]; + float grad_out_value = grad_out[grad_out_idx]; + + atomicAdd(grad_attn+m_idx*h+h_idx, grad_out_value * (v[idx1*h*hdim+h_idx*hdim+c_idx]/3 + table_val)); + atomicAdd(grad_v+idx1*h*hdim+h_idx*hdim+c_idx, grad_out_value * attn[m_idx*h+h_idx]/3); + atomicAdd(grad_table+rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim, grad_out_value * attn[m_idx*h+h_idx]); +} + +void attention_step2_with_rel_pos_value_forward_cuda_launcher(int N, int M, int h, int hdim, const float *attn, const float *v, const int *index0, + const int *index1, const float *table, const int *rel_idx, float *output) { + // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + dim3 threads(THREADS_PER_BLOCK); + attention_step2_with_rel_pos_value_forward_cuda_kernel<<>>(N, M, h, hdim, attn, v, index0, index1, table, rel_idx, output); +} + +void attention_step2_with_rel_pos_value_backward_cuda_launcher(int N, int M, int h, int hdim, const float *grad_out, const int *index0, + const int *index1, const float *attn, const float *v, const float *table, const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + dim3 threads(THREADS_PER_BLOCK); + attention_step2_with_rel_pos_value_backward_cuda_kernel<<>>(N, M, h, hdim, grad_out, index0, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table); +} diff --git a/utils/pointops2/src/rpe/relative_pos_encoding_cuda_kernel.h b/utils/pointops2/src/rpe/relative_pos_encoding_cuda_kernel.h new file mode 100644 index 0000000..cafc7b6 --- /dev/null +++ b/utils/pointops2/src/rpe/relative_pos_encoding_cuda_kernel.h @@ -0,0 +1,26 @@ +#ifndef _RPE_CUDA_KERNEL +#define _RPE_CUDA_KERNEL +#include +#include +#include + +void dot_prod_with_idx_forward_cuda(int N, int M, int h, int hdim, at::Tensor q_tensor, at::Tensor index_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); +void dot_prod_with_idx_backward_cuda(int N, int M, int h, int hdim, at::Tensor grad_out_tensor, at::Tensor q_tensor, at::Tensor index_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor, at::Tensor grad_table_tensor); + +void attention_step2_with_rel_pos_value_forward_cuda(int N, int M, int h, int hdim, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); +void attention_step2_with_rel_pos_value_backward_cuda(int N, int M, int h, int hdim, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void dot_prod_with_idx_forward_cuda_launcher(int N, int M, int h, int hdim, const float *q, const int *index, const float *table, const int *rel_idx, float *output); +void dot_prod_with_idx_backward_cuda_launcher(int N, int M, int h, int hdim, const float *grad_out, const float *q, const int *index, const float *table, const int *rel_idx, float *grad_q, float *grad_table); + +void attention_step2_with_rel_pos_value_forward_cuda_launcher(int N, int M, int h, int hdim, const float *attn, const float *v, const int *index0, const int *index1, const float *table, const int *rel_idx, float *output); +void attention_step2_with_rel_pos_value_backward_cuda_launcher(int N, int M, int h, int hdim, const float *grad_out, const int *index0, const int *index1, const float *attn, const float *v, const float *table, const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.cu b/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.cu new file mode 100644 index 0000000..628d8e3 --- /dev/null +++ b/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.cu @@ -0,0 +1,525 @@ +#include "../cuda_utils.h" +#include "relative_pos_encoding_cuda_kernel_v2.h" + + +// N, M, h, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, output + +template +__global__ void dot_prod_with_idx_forward_cuda_kernel_v2( // M, h, hdim + int N, int M, int h, const float *q, const int *index_q, const float *k, const int *index_k, + const float *table_q, const float *table_k, const int *rel_idx, const int *rel_idx_offsets, + const int *sort_indices, float *output) { + // input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3), output: (M, h) + + int h_idx = blockIdx.y; + int t_idx = blockIdx.x; + int n_idx = threadIdx.x; + int C = h*d; + + __shared__ int start, end; + if(n_idx == 0){ + start = rel_idx_offsets[t_idx]; + end = rel_idx_offsets[t_idx+1]; + // printf("e2: start: %d, end: %d\n", start, end); + } + + __syncthreads(); + + int m_idx_prev = start + n_idx; + // if(m_idx_prev >= end) + // return; + + __shared__ int m_idx; + if(n_idx == 0) + m_idx = sort_indices[m_idx_prev]; + + __syncthreads(); + + __shared__ int rel_idx_vec[3]; + if(n_idx < 3) + rel_idx_vec[n_idx] = rel_idx[m_idx*3 + n_idx]; + + __syncthreads(); + + __shared__ float table_q_vec[d]; + __shared__ float table_k_vec[d]; + + for(int i = n_idx; i < 2*d; i += blockDim.x){ + if (i < d){ + int ind0 = rel_idx_vec[0] * C * 3 + h_idx * d * 3 + i * 3 + 0; + int ind1 = rel_idx_vec[1] * C * 3 + h_idx * d * 3 + i * 3 + 1; + int ind2 = rel_idx_vec[2] * C * 3 + h_idx * d * 3 + i * 3 + 2; + table_q_vec[i] = table_q[ind0] + table_q[ind1] + table_q[ind2]; + } else{ + int ind0 = rel_idx_vec[0] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 0; + int ind1 = rel_idx_vec[1] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 1; + int ind2 = rel_idx_vec[2] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 2; + table_k_vec[i-d] = table_k[ind0] + table_k[ind1] + table_k[ind2]; + } + } + + __syncthreads(); + + for(int i = m_idx_prev; i < end; i += blockDim.x){ + float sum = 0; + int m_idx_i = sort_indices[i]; + int q_idx = index_q[m_idx_i]; + int k_idx = index_k[m_idx_i]; + for(int j = 0; j < d; j++){ + sum += q[q_idx*C + h_idx*d + j] * table_q_vec[j]; + sum += k[k_idx*C + h_idx*d + j] * table_k_vec[j]; + } + output[m_idx_i*h + h_idx] = sum; + } +} + +// N, M, h, hdim, grad_out, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, grad_q, grad_k, grad_table_q, grad_table_k + +template +__global__ void dot_prod_with_idx_backward_cuda_kernel_v2( // M, h, hdim + int N, int M, int h, const float *grad_out, const float *q, const int *index_q, + const float *k, const int *index_k, const float *table_q, const float *table_k, + const int *rel_idx, const int *rel_idx_offsets, const int *sort_indices, float *grad_q, + float *grad_k, float *grad_table_q, float *grad_table_k) { + + int h_idx = blockIdx.y; + int t_idx = blockIdx.x; + int n_idx = threadIdx.x; + int C = h*d; + + __shared__ int start, end; + if(n_idx == 0){ + start = rel_idx_offsets[t_idx]; + end = rel_idx_offsets[t_idx+1]; + } + + __syncthreads(); + + int m_idx_prev = start + n_idx; + // if(m_idx_prev >= end) + // return; + + __shared__ int m_idx; + if(n_idx == 0) + m_idx = sort_indices[m_idx_prev]; + + __syncthreads(); + + __shared__ int rel_idx_vec[3]; + if(n_idx < 3) + rel_idx_vec[n_idx] = rel_idx[m_idx*3 + n_idx]; + + __syncthreads(); + + __shared__ float table_q_vec[d]; + __shared__ float table_k_vec[d]; + + for(int i = n_idx; i < 2*d; i += blockDim.x){ + if (i < d){ + int ind0 = rel_idx_vec[0] * C * 3 + h_idx * d * 3 + i * 3 + 0; + int ind1 = rel_idx_vec[1] * C * 3 + h_idx * d * 3 + i * 3 + 1; + int ind2 = rel_idx_vec[2] * C * 3 + h_idx * d * 3 + i * 3 + 2; + table_q_vec[i] = table_q[ind0] + table_q[ind1] + table_q[ind2]; + } else{ + int ind0 = rel_idx_vec[0] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 0; + int ind1 = rel_idx_vec[1] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 1; + int ind2 = rel_idx_vec[2] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 2; + table_k_vec[i-d] = table_k[ind0] + table_k[ind1] + table_k[ind2]; + } + } + + __shared__ float gradient_q[d]; + __shared__ float gradient_k[d]; + for(int i = n_idx; i < d; i += blockDim.x){ + gradient_q[i] = 0; + gradient_k[i] = 0; + } + + __syncthreads(); + + for(int i = m_idx_prev; i < end; i += blockDim.x){ + int m_idx_i = sort_indices[i]; + int q_idx = index_q[m_idx_i]; + int k_idx = index_k[m_idx_i]; + float grad_out_i = grad_out[m_idx_i*h+h_idx]; + for(int j = 0; j < d; j++){ + atomicAdd(&gradient_q[j], q[q_idx*C + h_idx*d + j] * grad_out_i); + atomicAdd(&gradient_k[j], k[k_idx*C + h_idx*d + j] * grad_out_i); + atomicAdd(grad_q + q_idx*C + h_idx*d + j, table_q_vec[j] * grad_out_i); + atomicAdd(grad_k + k_idx*C + h_idx*d + j, table_k_vec[j] * grad_out_i); + } + } + + __syncthreads(); + + for(int i = n_idx; i < d*2; i += blockDim.x){ + if(i < d){ + atomicAdd(grad_table_q + rel_idx_vec[0] * C * 3 + h_idx * d * 3 + i * 3, gradient_q[i]); + atomicAdd(grad_table_q + rel_idx_vec[1] * C * 3 + h_idx * d * 3 + i * 3 + 1, gradient_q[i]); + atomicAdd(grad_table_q + rel_idx_vec[2] * C * 3 + h_idx * d * 3 + i * 3 + 2, gradient_q[i]); + }else{ + atomicAdd(grad_table_k + rel_idx_vec[0] * C * 3 + h_idx * d * 3 + (i-d) * 3, gradient_k[i-d]); + atomicAdd(grad_table_k + rel_idx_vec[1] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 1, gradient_k[i-d]); + atomicAdd(grad_table_k + rel_idx_vec[2] * C * 3 + h_idx * d * 3 + (i-d) * 3 + 2, gradient_k[i-d]); + } + } + + // int c_idx = blockIdx.z; + // int h_idx = blockIdx.y; + // int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + // if (thread_idx >= M*3 || h_idx >= h || c_idx >= hdim) return; + + // int dim = thread_idx % 3; + // int m_idx = thread_idx / 3; + + // int q_idx = index[m_idx]; + // int rel_idx_dim = rel_idx[thread_idx]; + // int grad_out_idx = m_idx*h+h_idx; + // float grad_out_value = grad_out[grad_out_idx]; + + // float rel_table_val = table[rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim]; + // atomicAdd(grad_q+q_idx*h*hdim+h_idx*hdim+c_idx, grad_out_value * rel_table_val); + + // float q_value = q[q_idx*h*hdim+h_idx*hdim+c_idx]; + // atomicAdd(grad_table+rel_idx_dim*h*hdim*3+h_idx*hdim*3+c_idx*3+dim, grad_out_value * q_value); +} + +void dot_prod_with_idx_forward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, int T, const float *q, + const int *index_q, const float *k, const int *index_k, const float *table_q, const float *table_k, + const int *rel_idx, const int *rel_idx_offsets, const int *sort_indices, float *output) +{ + // input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + dim3 blocks(T, h); + // dim3 threads(THREADS_PER_BLOCK); + + unsigned int n_threads = opt_n_threads(n_max); + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + n_threads = n_threads > 1024 ? 512 : n_threads; + + // printf("e1: T: %d, h: %d, n_threads: %d\n", T, h, n_threads); + + switch (hdim) { + case 16: + dot_prod_with_idx_forward_cuda_kernel_v2<16><<>>(N, M, h, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, output); + break; + case 32: + dot_prod_with_idx_forward_cuda_kernel_v2<32><<>>(N, M, h, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, output); + break; + default: + throw "d != 16 and d != 32"; + } +} + +void dot_prod_with_idx_backward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, int T, + const float *grad_out, const float *q, const int *index_q, const float *k, const int *index_k, + const float *table_q, const float *table_k, const int *rel_idx, const int *rel_idx_offsets, const int *sort_indices, + float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k) +{ + // input: grad_out: (M, h), output: grad_q: (N, h, hdim), grad_table: (L, h, hdim, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + // dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + // dim3 threads(THREADS_PER_BLOCK); + + dim3 blocks(T, h); + // dim3 threads(THREADS_PER_BLOCK); + + unsigned int n_threads = opt_n_threads(n_max); + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + n_threads = n_threads > 1024 ? 512 : n_threads; + + switch (hdim) { + case 16: + dot_prod_with_idx_backward_cuda_kernel_v2<16><<>>(N, M, h, grad_out, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, grad_q, grad_k, grad_table_q, grad_table_k); + break; + case 32: + dot_prod_with_idx_backward_cuda_kernel_v2<32><<>>(N, M, h, grad_out, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, grad_q, grad_k, grad_table_q, grad_table_k); + break; + default: + throw "d != 16 and d != 32"; + } +} + + + +template +__global__ void dot_prod_with_idx_forward_cuda_kernel_v3( // M, h, hdim + int N, int M, int h, const float *q, const int *index_q_offsets, const float *k, const int *index_k, + const float *table_q, const float *table_k, const int *rel_idx, float *output) { + // input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3), output: (M, h) + int q_idx = blockIdx.x; + int h_idx = blockIdx.y; + int n_idx = threadIdx.x; + int C = h*d; + + __shared__ float query_vec[d]; + __shared__ int start, end; + if (n_idx == 0){ + start = index_q_offsets[q_idx]; + end = index_q_offsets[q_idx+1]; + } + for(int i = n_idx; i < d; i += blockDim.x) + query_vec[i] = q[q_idx*C + h_idx*d + i]; + + __syncthreads(); + + int m_idx = start + n_idx; + if(m_idx >= end) + return; + + int k_idx = index_k[m_idx]; + int r_idx1 = rel_idx[m_idx*3], r_idx2 = rel_idx[m_idx*3+1], r_idx3 = rel_idx[m_idx*3+2]; + float sum = 0; + for(int i = 0; i < d; i++){ + float table_q_scalar_i = table_q[r_idx1*C*3+h_idx*d*3+i*3] + table_q[r_idx2*C*3+h_idx*d*3+i*3+1] + table_q[r_idx3*C*3+h_idx*d*3+i*3+2]; + sum += query_vec[i] * table_q_scalar_i; + float table_k_scalar_i = table_k[r_idx1*C*3+h_idx*d*3+i*3] + table_k[r_idx2*C*3+h_idx*d*3+i*3+1] + table_k[r_idx3*C*3+h_idx*d*3+i*3+2]; + sum += k[k_idx*C+h_idx*d+i] * table_k_scalar_i; + } + output[m_idx*h + h_idx] = sum; + +} + +// N, M, h, hdim, grad_out, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, grad_q, grad_k, grad_table_q, grad_table_k + +template +__global__ void dot_prod_with_idx_backward_cuda_kernel_v3( // M, h, hdim + int N, int M, int h, const float *grad_out, const float *q, const int *index_q_offsets, + const float *k, const int *index_k, const float *table_q, const float *table_k, + const int *rel_idx, float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k) { + + int q_idx = blockIdx.x; + int h_idx = blockIdx.y; + int n_idx = threadIdx.x; + int C = h*d; + + __shared__ float query_vec[d]; + __shared__ int start, end; + if (n_idx == 0){ + start = index_q_offsets[q_idx]; + end = index_q_offsets[q_idx+1]; + } + for(int i = n_idx; i < d; i += blockDim.x) + query_vec[i] = q[q_idx*C + h_idx*d + i]; + + __shared__ float gradients_q[d]; + for(int i = n_idx; i < d; i += blockDim.x){ + gradients_q[i] = 0; + } + + __syncthreads(); + + int m_idx = start + n_idx; + + if(m_idx < end){ + int k_idx = index_k[m_idx]; + int r_idx1 = rel_idx[m_idx*3], r_idx2 = rel_idx[m_idx*3+1], r_idx3 = rel_idx[m_idx*3+2]; + float gradient = grad_out[m_idx*h + h_idx]; + for(int i = 0; i < d; i++){ + float table_q_scalar_i = table_q[r_idx1*C*3+h_idx*d*3+i*3] + table_q[r_idx2*C*3+h_idx*d*3+i*3+1] + table_q[r_idx3*C*3+h_idx*d*3+i*3+2]; + float table_k_scalar_i = table_k[r_idx1*C*3+h_idx*d*3+i*3] + table_k[r_idx2*C*3+h_idx*d*3+i*3+1] + table_k[r_idx3*C*3+h_idx*d*3+i*3+2]; + float q_scalar_i = query_vec[i]; + float k_scalar_i = k[k_idx*C+h_idx*d+i]; + atomicAdd(&gradients_q[i], table_q_scalar_i * gradient); + atomicAdd(grad_k+k_idx*C+h_idx*d+i, table_k_scalar_i * gradient); + atomicAdd(grad_table_q+r_idx1*C*3+h_idx*d*3+i*3, q_scalar_i * gradient); + atomicAdd(grad_table_q+r_idx2*C*3+h_idx*d*3+i*3+1, q_scalar_i * gradient); + atomicAdd(grad_table_q+r_idx3*C*3+h_idx*d*3+i*3+2, q_scalar_i * gradient); + atomicAdd(grad_table_k+r_idx1*C*3+h_idx*d*3+i*3, k_scalar_i * gradient); + atomicAdd(grad_table_k+r_idx2*C*3+h_idx*d*3+i*3+1, k_scalar_i * gradient); + atomicAdd(grad_table_k+r_idx3*C*3+h_idx*d*3+i*3+2, k_scalar_i * gradient); + } + } + __syncthreads(); + + for(int i = n_idx; i < d; i += blockDim.x){ + grad_q[q_idx*C+h_idx*d+i] = gradients_q[i]; + } +} + +void dot_prod_with_idx_forward_cuda_launcher_v3(int N, int M, int h, int hdim, int n_max, const float *q, + const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, + const int *rel_idx, float *output) +{ + // input: q: (N, h, hdim), index: (M), table: (L, h, hdim, 3), rel_idx: (M, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + dim3 blocks(N, h); + // dim3 threads(THREADS_PER_BLOCK); + + unsigned int n_threads = opt_n_threads(n_max); + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + + // printf("e1: h: %d, n_max: %d, n_threads: %d\n", h, n_max, n_threads); + + switch (hdim) { + case 16: + dot_prod_with_idx_forward_cuda_kernel_v3<16><<>>(N, M, h, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output); + break; + case 32: + dot_prod_with_idx_forward_cuda_kernel_v3<32><<>>(N, M, h, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output); + break; + default: + throw "d != 16 and d != 32"; + } +} + +void dot_prod_with_idx_backward_cuda_launcher_v3(int N, int M, int h, int hdim, int n_max, + const float *grad_out, const float *q, const int *index_q_offsets, const float *k, const int *index_k, + const float *table_q, const float *table_k, const int *rel_idx, + float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k) +{ + // input: grad_out: (M, h), output: grad_q: (N, h, hdim), grad_table: (L, h, hdim, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + // dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + // dim3 threads(THREADS_PER_BLOCK); + + dim3 blocks(N, h); + // dim3 threads(THREADS_PER_BLOCK); + + unsigned int n_threads = opt_n_threads(n_max); + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + + switch (hdim) { + case 16: + dot_prod_with_idx_backward_cuda_kernel_v3<16><<>>(N, M, h, grad_out, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, grad_q, grad_k, grad_table_q, grad_table_k); + break; + case 32: + dot_prod_with_idx_backward_cuda_kernel_v3<32><<>>(N, M, h, grad_out, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, grad_q, grad_k, grad_table_q, grad_table_k); + break; + default: + throw "d != 16 and d != 32"; + } +} + + +template +__global__ void attention_step2_with_rel_pos_value_forward_cuda_kernel_v2( // M, h, hdim + int N, int M, int h, const float *attn, const float *v, + const int *index0_offsets, const int *index1, const float *table, const int *rel_idx, float *output) { + // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) + + int q_idx = blockIdx.x; + int h_idx = blockIdx.y; + int n_idx = threadIdx.x; + + int C = h*d; + + __shared__ int start, end; + __shared__ float result[d]; + + if (n_idx == 0){ + start = index0_offsets[q_idx]; + end = index0_offsets[q_idx+1]; + } + for (int i = n_idx; i < d; i += blockDim.x){ + result[i] = 0; + } + + __syncthreads(); + + int m_idx = start + n_idx; + if (m_idx < end){ + float attn_scalar = attn[m_idx*h + h_idx]; + int r_idx1 = rel_idx[m_idx*3], r_idx2 = rel_idx[m_idx*3+1], r_idx3 = rel_idx[m_idx*3+2]; + for(int i = 0; i < d; i ++){ + int v_idx = index1[m_idx]; + float table_scaler_i = table[r_idx1*C*3+h_idx*d*3+i*3] + table[r_idx2*C*3+h_idx*d*3+i*3+1] + table[r_idx3*C*3+h_idx*d*3+i*3+2]; + float value_scaler_i = v[v_idx*C + h_idx*d + i]; + atomicAdd(&result[i], (table_scaler_i + value_scaler_i) * attn_scalar); + } + } + + __syncthreads(); + + for (int i = n_idx; i < d; i += blockDim.x) + output[q_idx*C + h_idx*d + i] = result[i]; +} + + +template +__global__ void attention_step2_with_rel_pos_value_backward_cuda_kernel_v2( // M, h, hdim + int N, int M, int h, const float *grad_out, const int *index0_offsets, const int *index1, const float *attn, const float *v, const float *table, + const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table) { + // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) + + int q_idx = blockIdx.x; + int h_idx = blockIdx.y; + int n_idx = threadIdx.x; + + int C = h*d; + + __shared__ int start, end; + __shared__ float gradients[d]; + + if (n_idx == 0){ + start = index0_offsets[q_idx]; + end = index0_offsets[q_idx+1]; + } + for (int i = n_idx; i < d; i += blockDim.x){ + gradients[i] = grad_out[q_idx*C + h_idx*d + i]; + } + + __syncthreads(); + + int m_idx = start + n_idx; + if (m_idx < end){ + int v_idx = index1[m_idx]; + int r_idx1 = rel_idx[m_idx*3], r_idx2 = rel_idx[m_idx*3+1], r_idx3 = rel_idx[m_idx*3+2]; + float attn_scalar = attn[m_idx*h + h_idx]; + float grad_attn_sum = 0; + for (int i = 0; i < d; i++){ + float grad_out_scaler_i = gradients[i]; + float table_scaler_i = table[r_idx1*C*3+h_idx*d*3+i*3] + table[r_idx2*C*3+h_idx*d*3+i*3+1] + table[r_idx3*C*3+h_idx*d*3+i*3+2]; + float value_scaler_i = v[v_idx*C + h_idx*d + i]; + grad_attn_sum += (table_scaler_i + value_scaler_i) * grad_out_scaler_i; + atomicAdd(grad_v + v_idx*C + h_idx*d + i, attn_scalar * grad_out_scaler_i); + atomicAdd(grad_table + r_idx1*C*3 + h_idx*d*3 + i*3, attn_scalar * grad_out_scaler_i); + atomicAdd(grad_table + r_idx2*C*3 + h_idx*d*3 + i*3 + 1, attn_scalar * grad_out_scaler_i); + atomicAdd(grad_table + r_idx3*C*3 + h_idx*d*3 + i*3 + 2, attn_scalar * grad_out_scaler_i); + } + grad_attn[m_idx*h + h_idx] = grad_attn_sum; + } +} + +void attention_step2_with_rel_pos_value_forward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, + const int *index1, const float *table, const int *rel_idx, float *output) { + // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + // dim3 blocks(DIVUP(M*3, THREADS_PER_BLOCK), h, hdim); + // dim3 threads(THREADS_PER_BLOCK); + dim3 blocks(N, h); + unsigned int n_threads = opt_n_threads(n_max); + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + + switch (hdim) { + case 16: + attention_step2_with_rel_pos_value_forward_cuda_kernel_v2<16><<>>(N, M, h, attn, v, index0_offsets, index1, table, rel_idx, output); + break; + case 32: + attention_step2_with_rel_pos_value_forward_cuda_kernel_v2<32><<>>(N, M, h, attn, v, index0_offsets, index1, table, rel_idx, output); + break; + default: + throw "d != 16 and d != 32"; + } +} + +void attention_step2_with_rel_pos_value_backward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, const float *grad_out, const int *index0_offsets, + const int *index1, const float *attn, const float *v, const float *table, const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + //dim3 blocks(DIVUP(hdim, THREADS_PER_BLOCK), h, M); + + dim3 blocks(N, h); + unsigned int n_threads = opt_n_threads(n_max); + n_threads = n_threads == n_max ? n_threads : n_threads * 2; + + switch (hdim) { + case 16: + attention_step2_with_rel_pos_value_backward_cuda_kernel_v2<16><<>>(N, M, h, grad_out, index0_offsets, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table); + break; + case 32: + attention_step2_with_rel_pos_value_backward_cuda_kernel_v2<32><<>>(N, M, h, grad_out, index0_offsets, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table); + break; + default: + throw "d != 16 and d != 32"; + } +} diff --git a/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.h b/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.h new file mode 100644 index 0000000..648b152 --- /dev/null +++ b/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_kernel_v2.h @@ -0,0 +1,32 @@ +#ifndef _RPE_V2_CUDA_KERNEL +#define _RPE_V2_CUDA_KERNEL +#include +#include +#include + +void dot_prod_with_idx_forward_cuda_v2(int N, int M, int h, int hdim, int n_max, int T, at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor rel_idx_offsets_tensor, at::Tensor sort_indices_tensor, at::Tensor output_tensor); +void dot_prod_with_idx_backward_cuda_v2(int N, int M, int h, int hdim, int n_max, int T, at::Tensor grad_out_tensor, at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor rel_idx_offsets_tensor, at::Tensor sort_indices_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor); + +void dot_prod_with_idx_forward_cuda_v3(int N, int M, int h, int hdim, int n_max, at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); +void dot_prod_with_idx_backward_cuda_v3(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor); + +void attention_step2_with_rel_pos_value_forward_cuda_v2(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); +void attention_step2_with_rel_pos_value_backward_cuda_v2(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void dot_prod_with_idx_forward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, int T, const float *q, const int *index_q, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, const int *rel_idx_offsets, const int *sort_indices, float *output); +void dot_prod_with_idx_backward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, int T, const float *grad_out, const float *q, const int *index_q, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, const int *rel_idx_offsets, const int *sort_indices, float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k); + +void dot_prod_with_idx_forward_cuda_launcher_v3(int N, int M, int h, int hdim, int n_max, const float *q, const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *output); +void dot_prod_with_idx_backward_cuda_launcher_v3(int N, int M, int h, int hdim, int n_max, const float *grad_out, const float *q, const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k); + +void attention_step2_with_rel_pos_value_forward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, const int *index1, const float *table, const int *rel_idx, float *output); +void attention_step2_with_rel_pos_value_backward_cuda_launcher_v2(int N, int M, int h, int hdim, int n_max, const float *grad_out, const int *index0_offsets, const int *index1, const float *attn, const float *v, const float *table, const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_v2.cpp b/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_v2.cpp new file mode 100644 index 0000000..0a4c96a --- /dev/null +++ b/utils/pointops2/src/rpe_v2/relative_pos_encoding_cuda_v2.cpp @@ -0,0 +1,111 @@ +#include +#include +#include +#include +#include "relative_pos_encoding_cuda_kernel_v2.h" + +void dot_prod_with_idx_forward_cuda_v2(int N, int M, int h, int hdim, int n_max, int T, at::Tensor q_tensor, + at::Tensor index_q_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, + at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor rel_idx_offsets_tensor, at::Tensor sort_indices_tensor, at::Tensor output_tensor) +{ + const float *q = q_tensor.data_ptr(); + const int *index_q = index_q_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + const int *index_k = index_k_tensor.data_ptr(); + const float *table_q = table_q_tensor.data_ptr(); + const float *table_k = table_k_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + const int *rel_idx_offsets = rel_idx_offsets_tensor.data_ptr(); + const int *sort_indices = sort_indices_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + dot_prod_with_idx_forward_cuda_launcher_v2(N, M, h, hdim, n_max, T, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, output); +} + +void dot_prod_with_idx_backward_cuda_v2(int N, int M, int h, int hdim, int n_max, int T, at::Tensor grad_out_tensor, + at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, + at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor rel_idx_offsets_tensor, + at::Tensor sort_indices_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const float *q = q_tensor.data_ptr(); + const int *index_q = index_q_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + const int *index_k = index_k_tensor.data_ptr(); + const float *table_q = table_q_tensor.data_ptr(); + const float *table_k = table_k_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + const int *rel_idx_offsets = rel_idx_offsets_tensor.data_ptr(); + const int *sort_indices = sort_indices_tensor.data_ptr(); + float *grad_q = grad_q_tensor.data_ptr(); + float *grad_k = grad_k_tensor.data_ptr(); + float *grad_table_q = grad_table_q_tensor.data_ptr(); + float *grad_table_k = grad_table_k_tensor.data_ptr(); + dot_prod_with_idx_backward_cuda_launcher_v2(N, M, h, hdim, n_max, T, grad_out, q, index_q, k, index_k, table_q, table_k, rel_idx, rel_idx_offsets, sort_indices, grad_q, grad_k, grad_table_q, grad_table_k); +} + + +void dot_prod_with_idx_forward_cuda_v3(int N, int M, int h, int hdim, int n_max, at::Tensor q_tensor, + at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, + at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) +{ + const float *q = q_tensor.data_ptr(); + const int *index_q_offsets = index_q_offsets_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + const int *index_k = index_k_tensor.data_ptr(); + const float *table_q = table_q_tensor.data_ptr(); + const float *table_k = table_k_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + dot_prod_with_idx_forward_cuda_launcher_v3(N, M, h, hdim, n_max, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output); +} + +void dot_prod_with_idx_backward_cuda_v3(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, + at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, + at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor, + at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const float *q = q_tensor.data_ptr(); + const int *index_q_offsets = index_q_offsets_tensor.data_ptr(); + const float *k = k_tensor.data_ptr(); + const int *index_k = index_k_tensor.data_ptr(); + const float *table_q = table_q_tensor.data_ptr(); + const float *table_k = table_k_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *grad_q = grad_q_tensor.data_ptr(); + float *grad_k = grad_k_tensor.data_ptr(); + float *grad_table_q = grad_table_q_tensor.data_ptr(); + float *grad_table_k = grad_table_k_tensor.data_ptr(); + dot_prod_with_idx_backward_cuda_launcher_v3(N, M, h, hdim, n_max, grad_out, q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, grad_q, grad_k, grad_table_q, grad_table_k); +} + + +void attention_step2_with_rel_pos_value_forward_cuda_v2(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, + at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) +{ + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + const int *index0_offsets = index0_offsets_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *table = table_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + attention_step2_with_rel_pos_value_forward_cuda_launcher_v2(N, M, h, hdim, n_max, attn, v, index0_offsets, index1, table, rel_idx, output); +} + +void attention_step2_with_rel_pos_value_backward_cuda_v2(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, + at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, + at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor) +{ + const float *grad_out = grad_out_tensor.data_ptr(); + const int *index0_offsets = index0_offsets_tensor.data_ptr(); + const int *index1 = index1_tensor.data_ptr(); + const float *attn = attn_tensor.data_ptr(); + const float *v = v_tensor.data_ptr(); + const float *table = table_tensor.data_ptr(); + const int *rel_idx = rel_idx_tensor.data_ptr(); + float *grad_attn = grad_attn_tensor.data_ptr(); + float *grad_v = grad_v_tensor.data_ptr(); + float *grad_table = grad_table_tensor.data_ptr(); + attention_step2_with_rel_pos_value_backward_cuda_launcher_v2(N, M, h, hdim, n_max, grad_out, index0_offsets, index1, attn, v, table, rel_idx, grad_attn, grad_v, grad_table); +} diff --git a/utils/pointops2/src/sampling/sampling_cuda.cpp b/utils/pointops2/src/sampling/sampling_cuda.cpp new file mode 100644 index 0000000..7b2622a --- /dev/null +++ b/utils/pointops2/src/sampling/sampling_cuda.cpp @@ -0,0 +1,16 @@ +#include +#include +#include +#include +#include "sampling_cuda_kernel.h" + + +void furthestsampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor) +{ + const float *xyz = xyz_tensor.data_ptr(); + const int *offset = offset_tensor.data_ptr(); + const int *new_offset = new_offset_tensor.data_ptr(); + float *tmp = tmp_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + furthestsampling_cuda_launcher(b, n, xyz, offset, new_offset, tmp, idx); +} diff --git a/utils/pointops2/src/sampling/sampling_cuda_kernel.cu b/utils/pointops2/src/sampling/sampling_cuda_kernel.cu new file mode 100644 index 0000000..d2c70b5 --- /dev/null +++ b/utils/pointops2/src/sampling/sampling_cuda_kernel.cu @@ -0,0 +1,171 @@ +#include "../cuda_utils.h" +#include "sampling_cuda_kernel.h" + + +__device__ void __update(float *dists, int *dists_i, int idx1, int idx2) { + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? i2 : i1; +} + +// input xyz: (n, 3), tmp: (b, n_max) +// ouput idx (m) +template +__global__ void furthestsampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) +{ + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int bid = blockIdx.x; + int start_n, end_n, start_m, end_m, old; + if (bid == 0) { + start_n = 0; + end_n = offset[0]; + start_m = 0; + end_m = new_offset[0]; + old = 0; + } + else { + start_n = offset[bid - 1]; + end_n = offset[bid]; + start_m = new_offset[bid - 1]; + end_m = new_offset[bid]; + old = offset[bid - 1]; + } + + const int stride = block_size; + int tid = threadIdx.x; + if (tid == 0) idx[start_m] = start_n; + + __syncthreads(); + for (int j = start_m + 1; j < end_m; j++) + { + int besti = start_n; + float best = -1; + float x1 = xyz[old * 3 + 0]; + float y1 = xyz[old * 3 + 1]; + float z1 = xyz[old * 3 + 2]; + for (int k = start_n + tid; k < end_n; k += stride) + { + float x2 = xyz[k * 3 + 0]; + float y2 = xyz[k * 3 + 1]; + float z2 = xyz[k * 3 + 2]; + float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + float d2 = min(d, tmp[k]); + tmp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 1024) { + if (tid < 512) { + __update(dists, dists_i, tid, tid + 512); + } + __syncthreads(); + } + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) + idx[j] = old; + } +} + +void furthestsampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) +{ + unsigned int n_threads = opt_n_threads(n); + switch (n_threads) { + case 1024: + furthestsampling_cuda_kernel<1024><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 512: + furthestsampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 256: + furthestsampling_cuda_kernel<256><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 128: + furthestsampling_cuda_kernel<128><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 64: + furthestsampling_cuda_kernel<64><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 32: + furthestsampling_cuda_kernel<32><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 16: + furthestsampling_cuda_kernel<16><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 8: + furthestsampling_cuda_kernel<8><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 4: + furthestsampling_cuda_kernel<4><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 2: + furthestsampling_cuda_kernel<2><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 1: + furthestsampling_cuda_kernel<1><<>>(xyz, offset, new_offset, tmp, idx); + break; + default: + furthestsampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); + } +} diff --git a/utils/pointops2/src/sampling/sampling_cuda_kernel.h b/utils/pointops2/src/sampling/sampling_cuda_kernel.h new file mode 100644 index 0000000..c903f63 --- /dev/null +++ b/utils/pointops2/src/sampling/sampling_cuda_kernel.h @@ -0,0 +1,18 @@ +#ifndef _SAMPLING_CUDA_KERNEL +#define _SAMPLING_CUDA_KERNEL +#include +#include +#include + +void furthestsampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void furthestsampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/pointops2/src/subtraction/subtraction_cuda.cpp b/utils/pointops2/src/subtraction/subtraction_cuda.cpp new file mode 100644 index 0000000..fa38dc5 --- /dev/null +++ b/utils/pointops2/src/subtraction/subtraction_cuda.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include +#include "subtraction_cuda_kernel.h" + + +void subtraction_forward_cuda(int n, int nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) +{ + const float *input1 = input1_tensor.data_ptr(); + const float *input2 = input2_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + subtraction_forward_cuda_launcher(n, nsample, c, input1, input2, idx, output); +} + +void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor) +{ + const int *idx = idx_tensor.data_ptr(); + const float *grad_output = grad_output_tensor.data_ptr(); + float *grad_input1 = grad_input1_tensor.data_ptr(); + float *grad_input2 = grad_input2_tensor.data_ptr(); + subtraction_backward_cuda_launcher(n, nsample, c, idx, grad_output, grad_input1, grad_input2); +} diff --git a/utils/pointops2/src/subtraction/subtraction_cuda_kernel.cu b/utils/pointops2/src/subtraction/subtraction_cuda_kernel.cu new file mode 100644 index 0000000..9b8d4f7 --- /dev/null +++ b/utils/pointops2/src/subtraction/subtraction_cuda_kernel.cu @@ -0,0 +1,44 @@ +#include "../cuda_utils.h" +#include "subtraction_cuda_kernel.h" + + +__global__ void subtraction_forward_cuda_kernel(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) { + // input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index / c) % nsample; + const int n_idx = index / nsample / c; + const int idx_idx = n_idx * nsample + nsample_idx; + const int input1_idx = n_idx * c + c_idx; + const int input2_idx = idx[idx_idx] * c + c_idx; + output[index] = input1[input1_idx] - input2[input2_idx]; +} + +__global__ void subtraction_backward_cuda_kernel(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index / c) % nsample; + const int n_idx = index / nsample / c; + const int idx_idx = n_idx * nsample + nsample_idx; + const int input1_idx = n_idx * c + c_idx; + const int input2_idx = idx[idx_idx] * c + c_idx; + atomicAdd(grad_input1 + input1_idx, grad_output[index]); + atomicAdd(grad_input2 + input2_idx, -grad_output[index]); +} + +void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) { + // input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c) + dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + subtraction_forward_cuda_kernel<<>>(n, nsample, c, input1, input2, idx, output); +} + +void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2) { + // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) + dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + subtraction_backward_cuda_kernel<<>>(n, nsample, c, idx, grad_output, grad_input1, grad_input2); +} diff --git a/utils/pointops2/src/subtraction/subtraction_cuda_kernel.h b/utils/pointops2/src/subtraction/subtraction_cuda_kernel.h new file mode 100644 index 0000000..856133d --- /dev/null +++ b/utils/pointops2/src/subtraction/subtraction_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef _SUBTRACTION_CUDA_KERNEL +#define _SUBTRACTION_CUDA_KERNEL +#include +#include +#include + +void subtraction_forward_cuda(int n, int nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); +void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output); +void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..cd13b77 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,101 @@ +import numpy as np +import torch +from scipy.optimize import linear_sum_assignment +import sys +import pytorch_lightning as pl +from pathlib import Path +import os +if sys.version_info[:2] >= (3, 8): + from collections.abc import MutableMapping +else: + from collections import MutableMapping + +def flatten_dict(d, parent_key="", sep="_"): + """ + https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys + """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +class RegularCheckpointing(pl.Callback): + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + general = pl_module.config.general + trainer.save_checkpoint(f"{general.save_dir}/last-epoch.ckpt") + print("Checkpoint created") + + +def associate_instances(previous_ins_label, current_ins_label): + previous_instance_ids, c_p = np.unique(previous_ins_label[previous_ins_label != 0], return_counts=True) + current_instance_ids, c_c = np.unique(current_ins_label[current_ins_label!=0], return_counts=True) + + associations = { + 0: 0 + } + + large_previous_instance_ids = [] + large_current_instance_ids = [] + for id, count in zip(previous_instance_ids, c_p): + if count > 25: + large_previous_instance_ids.append(id) + for id, count in zip(current_instance_ids, c_c): + if count > 50: + large_current_instance_ids.append(id) + + p_n = len(large_previous_instance_ids) + c_n = len(large_current_instance_ids) + + association_costs = torch.zeros(p_n, c_n) + for i, p_id in enumerate(large_previous_instance_ids): + for j, c_id in enumerate(large_current_instance_ids): + intersection = np.sum( (previous_ins_label==p_id) & (current_ins_label == c_id) ) + union = np.sum(previous_ins_label==p_id) + np.sum(current_ins_label == c_id) - intersection + iou = intersection/union + cost = 1 - iou + association_costs[i, j] = cost + + idxes_1, idxes_2 = linear_sum_assignment(association_costs) + + for i1, i2 in zip(idxes_1, idxes_2): + if association_costs[i1][i2] < 1.0: + associations[large_current_instance_ids[i2]] = large_previous_instance_ids[i1] + return associations + + +def save_predictions(sem_preds, ins_preds, seq_name, sweep_name): + filename = Path("/globalwork/yilmaz/submission/sequences") / seq_name / "predictions" + # assert not filename.exists(), "Path exists" + filename.mkdir(parents=True, exist_ok=True) + learning_map_inv = { + 1: 10, # "car" + 2: 11, # "bicycle" + 3: 15, # "motorcycle" + 4: 18, # "truck" + 5: 20, # "other-vehicle" + 6: 30, # "person" + 7: 31, # "bicyclist" + 8: 32, # "motorcyclist" + 9: 40, # "road" + 10: 44, # "parking" + 11: 48, # "sidewalk" + 12: 49, # "other-ground" + 13: 50, # "building" + 14: 51, # "fence" + 15: 70, # "vegetation" + 16: 71, # "trunk" + 17: 72, # "terrain" + 18: 80, # "pole" + 19: 81, # "traffic-sign" + } + sem_preds = np.vectorize(learning_map_inv.__getitem__)(sem_preds) + panoptic_preds = (ins_preds << 16) + sem_preds + file_path = str(filename / sweep_name) + ".label" + if not os.path.exists(file_path): + with open(file_path, "wb") as f: + f.write(panoptic_preds.astype(np.uint32).tobytes())