
Distributed training with PyTorch: PointPillars on Waymo #353

Merged: 27 commits into dev from sanskar/multi_gpu, Sep 16, 2022

Conversation

@sanskar107 (Collaborator) commented Aug 12, 2021

Added support for distributed training to the PyTorch ObjectDetection pipeline.

  • Modify the launch script to spawn multiple processes for multi-GPU training (see the sketch after this list).
  • Update the Object Detection pipeline to support multiple GPUs.
  • Train PointPillars on large-scale datasets like Waymo.
  • Train PointRCNN on Waymo.
  • Update the Semantic Segmentation pipeline to support multi-GPU training (update RandLANet and KPConv to avoid using samplers in dataloaders).
  • Train Semantic Segmentation models.
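
A minimal sketch of that launch pattern, assuming one worker process per GPU (the function name and rendezvous defaults are illustrative, not the PR's actual code; the "gloo" backend matches run_pipeline.py):

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def train_worker(rank, world_size):
        # One worker drives one GPU; MASTER_ADDR/PORT are illustrative defaults.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
        # ... build the model, wrap it in DistributedDataParallel, train ...
        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        mp.spawn(train_worker, args=(world_size,), nprocs=world_size)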

Some results for the Vehicle class: green = ground truth, orange = prediction.

[Two result images omitted]


lgtm-com bot commented Aug 12, 2021

This pull request introduces 2 alerts when merging 306de5b into 2b86a83 - view on LGTM.com

new alerts:

  • 2 for Unused import

lgtm-com bot commented Aug 13, 2021

This pull request introduces 1 alert when merging b953254 into bbac251 - view on LGTM.com

new alerts:

  • 1 for Unused import

lgtm-com bot commented Aug 27, 2021

This pull request introduces 1 alert when merging 2779339 into bbac251 - view on LGTM.com

new alerts:

  • 1 for Unused import

@ssheorey linked an issue Sep 8, 2021 that may be closed by this pull request
lgtm-com bot commented Sep 24, 2021

This pull request introduces 1 alert when merging 53479c2 into a371cae - view on LGTM.com

new alerts:

  • 1 for Unused import

@ssheorey (Member) left a comment

Are we planning to switch to DistributedDataParallel with PyTorch in the future (better performance and multi-node capability)? In that case, it may be better to implement that directly, instead of going with DataParallel first.

Can you provide a checklist of the different steps needed to implement full multi-GPU training capability? I think it should be something like:

  • Customize batcher, other code changes
  • PyTorch DataParallel
  • PyTorch DistributedDataParallel
  • Tensorflow...
  • ...

Add documentation for which of the parallelization methods available in Torch / TF are supported.

Reviewable status: 0 of 11 files reviewed, 12 unresolved discussions (waiting on @sanskar107 and @ssheorey)


ml3d/torch/dataloaders/concat_batcher.py, line 439 at r4 (raw file):

    @staticmethod
    def scatter(batch, num_gpu):

Add docstring (what does this do / why / args description).
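
For illustration, a docstring along these lines would cover that (a sketch; the argument semantics are inferred from the surrounding code, not taken from the PR):

    @staticmethod
    def scatter(batch, num_gpu):
        """Split a concatenated batch into sub-batches for data-parallel training.

        Args:
            batch: A batch object produced by this batcher.
            num_gpu: Number of GPUs to scatter the batch across.

        Returns:
            A list of num_gpu sub-batches, one per device.
        """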


ml3d/torch/dataloaders/concat_batcher.py, line 452 at r4 (raw file):

len(b.point)

Add `> 0` (make the condition explicit).


ml3d/torch/dataloaders/concat_batcher.py, line 510 at r4 (raw file):

    @staticmethod
    def scatter(batch, num_gpu):

Add docstring as above.


ml3d/torch/dataloaders/concat_batcher.py, line 524 at r4 (raw file):

len(b.point) > 0

Make the check explicit, as above.


ml3d/torch/pipelines/base_pipeline.py, line 50 at r4 (raw file):

        if device == 'cpu' or not torch.cuda.is_available():
            self.device = torch.device('cpu')
            self.device_ids = [-1]

We use 0 as the device id for the CPU in Open3D.


ml3d/torch/pipelines/dataparallel.py, line 6 at r4 (raw file):

class CustomDataParallel(DataParallel):

We can keep the same name as the torch class, so that users can do:
from ml3d.torch.pipelines import DataParallel # or similar
instead of the default:
from torch.nn.parallel import DataParallel


ml3d/torch/pipelines/dataparallel.py, line 33 at r4 (raw file):

        return self.gather(outputs, self.output_device)

    def customscatter(self, inputs, kwargs, device_ids):

For overriding, keep the same name (scatter).


ml3d/torch/pipelines/dataparallel.py, line 40 at r4 (raw file):

        Agrs:
            inputs: Object of type custom batcher.
            kwargs: Optional keyword arguments.

Add pointer explaining kwargs, like:

Passed to `torch.nn.DataParallel.scatter`.

ml3d/torch/pipelines/dataparallel.py, line 47 at r4 (raw file):

inputs[0]

Is inputs an object or sequence (or either)?


ml3d/utils/builder.py, line 19 at r4 (raw file):

def convert_device_name(framework, device_ids):
    """Convert device to either cpu or cuda."""
    gpu_names = ["gpu", "cuda"]

At some point we should phase out gpu as a synonym for cuda, to prepare for Intel GPUs.


ml3d/utils/builder.py, line 24 at r4 (raw file):

        raise KeyError("the device shoule either "
                       "be cuda or cpu but got {}".format(framework))
    assert type(device_ids) is list

assert -> if + raise TypeError.
Actually, we can skip this check: if device_ids is not iterable, the for loop below will throw a TypeError. Catch and re-raise if you want to provide a better error message.
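
A sketch of the suggested replacement:

    # Instead of: assert type(device_ids) is list
    if not isinstance(device_ids, list):
        raise TypeError("device_ids must be a list, got {}".format(
            type(device_ids).__name__))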


scripts/run_pipeline.py, line 31 at r4 (raw file):

    parser.add_argument('--ckpt_path', help='path to the checkpoint')
    parser.add_argument('--device',
                        help='devices to run the pipeline',

choices=('cpu', 'cuda')
devices -> device type
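
I.e., something like:

    parser.add_argument('--device',
                        help='device type to run the pipeline',
                        choices=('cpu', 'cuda'))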

lgtm-com bot commented Jan 10, 2022

This pull request introduces 1 alert when merging 6139e66 into 203b8c6 - view on LGTM.com

new alerts:

  • 1 for Unused import

lgtm-com bot commented Jan 17, 2022

This pull request introduces 4 alerts when merging d9e1564 into 5baf722 - view on LGTM.com

new alerts:

  • 3 for Unused import
  • 1 for Unused local variable

lgtm-com bot commented Jan 18, 2022

This pull request introduces 3 alerts when merging ef0e440 into 5baf722 - view on LGTM.com

new alerts:

  • 2 for Unused import
  • 1 for Unused local variable

@sanskar107 marked this pull request as draft January 18, 2022 20:29
@sanskar107 marked this pull request as ready for review February 18, 2022 12:24
@sanskar107 requested a review from ssheorey February 18, 2022 12:24
@sanskar107 changed the title from "Support for Multiple CUDA device" to "Distributed trainings for PyTorch" Feb 18, 2022
lgtm-com bot commented Feb 18, 2022

This pull request introduces 3 alerts when merging 10164f9 into 8ddb672 - view on LGTM.com

new alerts:

  • 1 for Use of the return value of a procedure
  • 1 for Unused local variable
  • 1 for Unused import

lgtm-com bot commented Feb 22, 2022

This pull request introduces 3 alerts when merging 6d58cd1 into 8ddb672 - view on LGTM.com

new alerts:

  • 1 for Use of the return value of a procedure
  • 1 for Unused local variable
  • 1 for Unused import

lgtm-com bot commented Feb 22, 2022

This pull request introduces 1 alert when merging 51a16c3 into 8ddb672 - view on LGTM.com

new alerts:

  • 1 for Use of the return value of a procedure

@ssheorey (Member) left a comment

Can you rename this PR to something like "Distributed training with PyTorch: PointPillars on Waymo". For PointRCNN, we will have a separate PR, right?

Reviewed 1 of 10 files at r8.
Reviewable status: 1 of 17 files reviewed, 23 unresolved discussions (waiting on @sanskar107 and @ssheorey)


ml3d/configs/pointpillars_waymo.yml, line 109 at r10 (raw file):

  difficulties: [0, 1, 2]
  summary:
    record_for: [train, valid]

Set to empty before merge.


ml3d/datasets/waymo.py, line 3 at r10 (raw file):

import numpy as np
import os, argparse, pickle, sys
from os.path import exists, join, isfile, dirname, abspath, split

nit: avoid unused imports and multiple imports on one line (see LGTM messages).


ml3d/datasets/waymo.py, line 94 at r10 (raw file):

        Returns:
            A data object with lidar information.

Mention return data format. e.g. each row is XYZRGB?


ml3d/datasets/waymo.py, line 96 at r10 (raw file):

            A data object with lidar information.
        """
        assert Path(path).exists()

Not necessary. np.fromfile() will complain if path is incorrect.


ml3d/datasets/waymo.py, line 105 at r10 (raw file):

        Returns:
            The data objects with bound boxes information.

bounding boxes


ml3d/datasets/waymo.py, line 135 at r10 (raw file):

            The camera and the camera image used in calibration.
        """
        assert Path(path).exists()

Not necessary. open() will complain if path is incorrect.


ml3d/datasets/waymo.py, line 140 at r10 (raw file):

            lines = f.readlines()
        obj = lines[0].strip().split(' ')[1:]
        P0 = np.array(obj, dtype=np.float32)

Name unused variables as unused_P0, etc. Also P1, P3, P4.


ml3d/datasets/waymo.py, line 223 at r10 (raw file):

                attribute is stored; else, returns false.
        """
        pass

Is this unimplemented? If we want to implement this later, returning False would be better since a return value is expected. Also print a warning / raise NotImplementedError() as appropriate.


ml3d/datasets/waymo.py, line 232 at r10 (raw file):

            attr: The attributes that correspond to the outputs passed in results.
        """
        pass

Print warning / NotImplementedError()
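
A sketch of the suggested pattern (the method name and message are hypothetical, shown as it would sit inside the dataset class):

    import warnings

    def is_tested(self, attr):
        """Checks whether a result is stored for this datum (not implemented yet)."""
        warnings.warn("Waymo.is_tested() is not yet implemented; returning False.")
        return False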


ml3d/datasets/waymo.py, line 282 at r10 (raw file):

    def __init__(self, center, size, label, calib):
        confidence = float(label[15]) if label.__len__() == 16 else -1.0

Are there 2 different formats? Add comment explaining this.
nit: Add "extra" data (confidence, occlusion, etc.) explanation to docstring.


scripts/run_pipeline.py, line 232 at r10 (raw file):

    )

    multiprocessing.set_start_method('spawn')

forkserver is faster than spawn and should be preferred on Linux. Does that cause any problems? Also, we probably want to change the default method only on Linux, so check the OS.
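
A sketch of such an OS check (the platform names are standard platform.system() values):

    import multiprocessing
    import platform

    # forkserver is Unix-only; keep the platform default elsewhere.
    if platform.system() in ('Linux', 'Darwin'):
        multiprocessing.set_start_method('forkserver')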

@ssheorey (Member) left a comment

We should add DDP on CPUs in a future PR.

Reviewed 6 of 11 files at r5, 4 of 10 files at r8, 1 of 1 files at r9, 1 of 2 files at r10.
Reviewable status: 13 of 17 files reviewed, 45 unresolved discussions (waiting on @sanskar107 and @ssheorey)


ml3d/torch/pipelines/base_pipeline.py, line 26 at r10 (raw file):

            model: A network model.
            dataset: A dataset, or None for inference model.
            device: 'gpu' or 'cpu'.

'cuda' or 'cpu'


ml3d/torch/pipelines/base_pipeline.py, line 63 at r10 (raw file):

            if distributed:
                raise ValueError(
                    "Distributed training is ON, but CUDA not available.")

Distributed training on CPUs is supported by PyTorch. If there is a blocking issue, we can raise NotImplementedError() here and address it in a future PR.


ml3d/torch/pipelines/dataparallel.py, line 1 at r10 (raw file):

import torch

Is this file still being used?

Code quote:

CustomDataParall

ml3d/torch/pipelines/object_detection.py, line 180 at r10 (raw file):

                                  sampler=valid_sampler)
        # worker_init_fn=lambda x: np.random.seed(x + np.uint32(
        #     torch.utils.data.get_worker_info().seed)))

Remove comment if not needed.


ml3d/torch/pipelines/object_detection.py, line 366 at r10 (raw file):

                    loss = model.module.get_loss(results, data)
                else:
                    loss = model.get_loss(results, data)

Ideally, we want to avoid this `if distributed: model.module.* else: model.*` code pattern. Is it possible to use wrapping or inheritance to use model.get_loss() even in distributed mode?

Code quote:

                if self.distributed:
                    loss = model.module.get_loss(results, data)
                else:
                    loss = model.get_loss(results, data)
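
One way to achieve that (a sketch, not the PR's code) is to subclass DistributedDataParallel and forward unknown attribute lookups to the wrapped module, so call sites can keep writing model.get_loss(...):

    from torch.nn.parallel import DistributedDataParallel

    class CustomDistributedDataParallel(DistributedDataParallel):
        """DDP wrapper that forwards missing attributes to the inner model."""

        def __getattr__(self, name):
            try:
                # nn.Module implements __getattr__; check the wrapper first.
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)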

ml3d/torch/pipelines/object_detection.py, line 374 at r10 (raw file):

                    if model.module.cfg.get('grad_clip_norm', -1) > 0:
                        torch.nn.utils.clip_grad_value_(
                            model.module.parameters(),

See above. Avoid model.module.*


ml3d/torch/pipelines/object_detection.py, line 387 at r10 (raw file):

                    if self.distributed:
                        boxes = model.module.inference_end(results, data)
                    else:

See above. Avoid model.module.*


scripts/collect_bboxes.py, line 5 at r10 (raw file):

import argparse
import pickle
import random

Avoid random. Stick to one of numpy.rng or torch.random for reproducibility.


scripts/collect_bboxes.py, line 99 at r10 (raw file):

    query_pc = range(len(train)) if max_pc >= len(train) else random.sample(
        range(len(train)), max_pc)

Use the NumPy RNG.
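
E.g., with the NumPy generator API (the seed is illustrative; the other names are from the quoted code):

    import numpy as np

    rng = np.random.default_rng(42)
    query_pc = (range(len(train)) if max_pc >= len(train) else
                rng.choice(len(train), size=max_pc, replace=False))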


scripts/preprocess_waymo.py, line 14 at r10 (raw file):

from os.path import join, exists, dirname, abspath
from os import makedirs
import random

Use NumPy RNG everywhere.


scripts/preprocess_waymo.py, line 60 at r10 (raw file):

    """Waymo to KITTI converter.

    This class converts tfrecord files from Waymo dataset to KITTI format.

Can you describe KITTI format here as well, so users don't have to refer to the KITTI code.
Also, add reference to Waymo document / example code explaining the pre-processing (if available).
Add information on tensor shapes and formats, as appropriate (e.g. lidar is (N,6) with each row X, Y, Z, intensity, elongation, timestamp)


scripts/preprocess_waymo.py, line 140 at r10 (raw file):

                    frame.context.stats.location
                    not in self.selected_waymo_locations):
                print("continue")

Remove if debugging code.

Code quote:

print("continue")

scripts/preprocess_waymo.py, line 197 at r10 (raw file):

        # all camera ids are saved as id-1 in the result because
        # camera 0 is unknown in the proto

Document this (and any other) data modification in the class docstring.

Code quote:

        # all camera ids are saved as id-1 in the result because
        # camera 0 is unknown in the proto

scripts/preprocess_waymo.py, line 211 at r10 (raw file):

                'w+') as fp_calib:
            fp_calib.write(calib_context)
            fp_calib.close()

close() is unnecessary with context manager.
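
I.e. (sketch; calib_path stands in for the actual path expression):

    with open(calib_path, 'w+') as fp_calib:
        fp_calib.write(calib_context)
    # The context manager closes the file on exit; no fp_calib.close() needed.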


scripts/preprocess_waymo.py, line 229 at r10 (raw file):

            name = labels.name
            for label in labels.labels:
                # TODO: need a workaround as bbox may not belong to front cam

Is this resolved now?


scripts/preprocess_waymo.py, line 259 at r10 (raw file):

            # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
            #     continue

Remove commented code.

Code quote:

            # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
            #     continue

scripts/preprocess_waymo.py, line 272 at r10 (raw file):

            # pt_ref = self.T_velo_to_front_cam @ \
            #     np.array([x, y, z, 1]).reshape((4, 1))
            # x, y, z, _ = pt_ref.flatten().tolist()

Remove commented code if not required any more.

Code quote:

            # # project bounding box to the virtual reference frame
            # pt_ref = self.T_velo_to_front_cam @ \
            #     np.array([x, y, z, 1]).reshape((4, 1))
            # x, y, z, _ = pt_ref.flatten().tolist()

scripts/preprocess_waymo.py, line 341 at r10 (raw file):

        timestamp = frame.timestamp_micros * np.ones_like(intensity)

        # concatenate x,y,z, intensity, elongation, timestamp (6-dim)

Add this to docstring.


scripts/preprocess_waymo.py, line 364 at r10 (raw file):

        frame_pose = tf.convert_to_tensor(
            value=np.reshape(np.array(frame.pose.transform), [4, 4]))
        # [H, W, 6]

Add input / output shape info to docstring


scripts/preprocess_waymo.py, line 368 at r10 (raw file):

            tf.convert_to_tensor(value=range_image_top_pose.data),
            range_image_top_pose.shape.dims)
        # [H, W, 3, 3]

Add input / output shape info to docstring


scripts/run_pipeline.py, line 155 at r10 (raw file):

    args.cfg_tb = cfg_tb
    args.distributed = framework == 'torch' and args.device != 'cpu' and len(
        args.device_ids) > 1

future PR: allow distributed with CPUs.


scripts/run_pipeline.py, line 181 at r10 (raw file):

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

Allow configuring host, port and backend in the pipeline config file. Check if these are already set by the user and respect prior settings in that case.
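
A sketch of that (the cfg keys are hypothetical; MASTER_ADDR/MASTER_PORT are PyTorch's standard rendezvous variables):

    import os
    import torch.distributed as dist

    def setup_process_group(rank, world_size, cfg):
        # Respect host/port/backend the user already set; fall back otherwise.
        os.environ.setdefault('MASTER_ADDR', cfg.get('host', 'localhost'))
        os.environ.setdefault('MASTER_PORT', str(cfg.get('port', 12355)))
        dist.init_process_group(cfg.get('backend', 'gloo'),
                                rank=rank, world_size=world_size)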

lgtm-com bot commented Mar 25, 2022

This pull request introduces 2 alerts and fixes 6 when merging 1ed11c8 into 8ddb672 - view on LGTM.com

new alerts:

  • 1 for Use of the return value of a procedure
  • 1 for Unused import

fixed alerts:

  • 4 for Unused local variable
  • 2 for Unused import

@sanskar107 changed the title from "Distributed trainings for PyTorch" to "Distributed training with PyTorch: PointPillars on Waymo" Apr 5, 2022
@sanskar107 (Collaborator, Author) left a comment

Reviewable status: 9 of 17 files reviewed, 45 unresolved discussions (waiting on @sanskar107 and @ssheorey)


ml3d/configs/pointpillars_waymo.yml, line 109 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Set to empty before merge.

Done.


ml3d/datasets/waymo.py, line 94 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Mention return data format. e.g. each row is XYZRGB?

Done.


ml3d/datasets/waymo.py, line 96 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Not necessary. np.fromfile() will complain if path is incorrect.

Done.


ml3d/datasets/waymo.py, line 105 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

bounding boxes

Done.


ml3d/datasets/waymo.py, line 135 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Not necessary. open() will complain if path is incorrect.

Done.


ml3d/datasets/waymo.py, line 140 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Name unused variables as unused_P0, etc. Also P1, P3, P4.

Done.


ml3d/datasets/waymo.py, line 223 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Is this unimplemented? If we want to implement this later, return false would be better since a return value is expected. Also print warning / NotImplementedError() as appropriate.

Done.


ml3d/datasets/waymo.py, line 232 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Print warning / NotImplementedError()

Done.


ml3d/datasets/waymo.py, line 282 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Are there 2 different formats? Add comment explaining this.
nit: Add "extra" data (confidence, occlusion, etc.) explanation to docstring.

Done.


ml3d/torch/pipelines/base_pipeline.py, line 50 at r4 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

We use 0 as the device id for the CPU in Open3D.

PyTorch and TensorFlow GPU indices start from 0; we should keep the same convention to avoid confusion.


ml3d/torch/pipelines/base_pipeline.py, line 26 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

'cuda' or 'cpu'

Done.


ml3d/torch/pipelines/base_pipeline.py, line 63 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Distributed training on CPUs is supported by PyTorch. If there is a blocking issue, we can raise NotImplementedError() here and address it in a future PR.

Done.


ml3d/torch/pipelines/object_detection.py, line 180 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Remove comment if not needed.

Done.


ml3d/torch/pipelines/object_detection.py, line 366 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Ideally, we want to avoid this if (distributed): self.module.* else: self.* code pattern. Is it possible to use wrapping or inheritance to use model.get_loss() even in distributed mode?

Done.


ml3d/torch/pipelines/object_detection.py, line 374 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

See above. Avoid model.module.*

Replaced all model.module.* with model.*, except for this model.module.parameters(). This is because parameters() contains the trainable weights, and the distributed wrapper DistributedDataParallel may contain extra parameters, which could be a source of errors.


ml3d/torch/pipelines/object_detection.py, line 387 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

See above. Avoid model.module.*

Done.


scripts/collect_bboxes.py, line 5 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Avoid random. Stick to one of numpy.rng or torch.random for reproducibility.

Done.


scripts/collect_bboxes.py, line 99 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Use the NumPy RNG.

Done.


scripts/preprocess_waymo.py, line 14 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Use NumPy RNG everywhere.

Done.


scripts/preprocess_waymo.py, line 60 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Can you describe KITTI format here as well, so users don't have to refer to the KITTI code.
Also, add reference to Waymo document / example code explaining the pre-processing (if available).
Add information on tensor shapes and formats, as appropriate (e.g. lidar is (N,6) with each row X, Y, Z, intensity, elongation, timestamp)

Done.


scripts/preprocess_waymo.py, line 140 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Remove if debugging code.

Done.


scripts/preprocess_waymo.py, line 197 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Document this (and any other) data modification in the class docstring.

Done.


scripts/preprocess_waymo.py, line 211 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

close() is unnecessary with context manager.

Done.


scripts/preprocess_waymo.py, line 229 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Is this resolved now?

Done.


scripts/preprocess_waymo.py, line 259 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Remove commented code.

Done.


scripts/preprocess_waymo.py, line 272 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Remove commented code if not required any more.

Done.


scripts/preprocess_waymo.py, line 341 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Add this to docstring.

Done.


scripts/preprocess_waymo.py, line 364 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Add input / output shape info to docstring

Done.


scripts/preprocess_waymo.py, line 368 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Add input / output shape info to docstring

Done.


scripts/run_pipeline.py, line 31 at r4 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

choices=('cpu', 'cuda')
devices -> device type

Done.


scripts/run_pipeline.py, line 181 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

Allow configuring host, port and backend in the pipeline config file. Check if these are already set by the user and respect prior settings in that case.

Done.


scripts/run_pipeline.py, line 232 at r10 (raw file):

Previously, ssheorey (Sameer Sheorey) wrote…

forkserver is faster than spawn and should be preferred in Linux. Does that cause any problems? Also, we probably want to change the default method only in Linux so check OS.

Changed to forkserver. This is also needed for macOS.

lgtm-com bot commented Apr 5, 2022

This pull request introduces 2 alerts and fixes 6 when merging eb3b551 into 8ddb672 - view on LGTM.com

new alerts:

  • 1 for Use of the return value of a procedure
  • 1 for Unused import

fixed alerts:

  • 4 for Unused local variable
  • 2 for Unused import

@ssheorey self-requested a review June 13, 2022 15:18
@benjaminum (Collaborator) left a comment

Reviewed 1 of 10 files at r8, 5 of 7 files at r11, 1 of 1 files at r13.
Reviewable status: 15 of 17 files reviewed, 48 unresolved discussions (waiting on @sanskar107 and @ssheorey)


ml3d/datasets/waymo.py line 101 at r13 (raw file):

    @staticmethod
    def read_label(path, calib):

Please add a short description of path and calib or useful information like where to get the inputs, e.g., calib: Calibration as returned by read_calib().


ml3d/datasets/waymo.py line 279 at r13 (raw file):

    """

    def __init__(self, center, size, label, calib):

Do we need documentation here or is the base class docstring sufficient?


ml3d/torch/pipelines/object_detection.py line 405 at r13 (raw file):

            # --------------------- validation
            # if rank == 0 and (epoch % cfg.get("validation_freq", 1)) == 0:

Remove the old commented-out `if`.

lgtm-com bot commented Sep 13, 2022

This pull request introduces 2 alerts and fixes 6 when merging e49f4ec into c581efe - view on LGTM.com

new alerts:

  • 1 for Use of the return value of a procedure
  • 1 for Unused import

fixed alerts:

  • 4 for Unused local variable
  • 2 for Unused import

@benjaminum (Collaborator) left a comment

:lgtm:

Reviewable status: 11 of 17 files reviewed, 48 unresolved discussions (waiting on @benjaminum, @sanskar107, and @ssheorey)

@ssheorey (Member) left a comment

Reviewable status: 11 of 17 files reviewed, 16 unresolved discussions (waiting on @benjaminum, @sanskar107, and @ssheorey)

@ssheorey merged commit 1c45bfe into dev Sep 16, 2022
@ssheorey deleted the sanskar/multi_gpu branch September 16, 2022 18:53
Successfully merging this pull request may close these issues:

  • Train with multi GPU

4 participants