We wrap the data in a dict and pass it through the model. Some key items include:

- `part_pcs`: point clouds sampled from each object part, usually of shape `[batch_size, num_parts, num_points, 3]`. To enable batching, we pad all shapes to a pre-defined number of parts (usually 20) with zeros.
- `part_trans`: ground-truth translation of each part. Shape `[batch_size, num_parts, 3]`, padded with zeros.
- `part_quat`: ground-truth rotation (quaternion) of each part. Shape `[batch_size, num_parts, 4]`, padded with zeros. Note that we load rotations as quaternions for ease of dataloading. See the Rotation Representation section for more details.
- `part_valids`: binary mask indicating padded parts. Shape `[batch_size, num_parts]`, where 1 marks a real part and 0 a padded one.

For other data items, see the comments in the dataset files for more details.
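To make the format concrete, here is a toy `data_dict` in the layout described above (the batch/part/point counts are arbitrary; real dicts come from the dataloaders):

```python
import torch

B, P, N = 2, 20, 1000  # batch size, padded number of parts, points per part
data_dict = {
    'part_pcs': torch.zeros(B, P, N, 3),   # per-part point clouds, zero-padded
    'part_trans': torch.zeros(B, P, 3),    # GT translations, zero-padded
    'part_quat': torch.zeros(B, P, 4),     # GT (w, x, y, z) quaternions, zero-padded
    'part_valids': torch.zeros(B, P),      # 1 = real part, 0 = padding
}
data_dict['part_valids'][:, :4] = 1.       # e.g. each shape here has 4 real parts
```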
Shape assembly models usually consist of a point cloud feature extractor (e.g. PointNet), a relationship reasoning module (e.g. GNNs), and a pose predictor (usually implemented as MLPs). See `model` for details about the baselines supported in this codebase.
We implement a `BaseModel` class as a subclass of PyTorch-Lightning's `LightningModule`, which supports general methods such as `training/validation/test_step()` and `*_epoch_end()`. It also implements general loss computation, metrics calculation, and visualization during training. See `base_model.py`. Below we detail some core methods we implement for all assembly models. All the assembly models inherit from `BaseModel`.
In general, you only need to implement three methods for a new model (see the sketch after this list):

- `__init__()`: initialize all the model components, such as feature extractors and pose predictors
- `forward()`: the input to this function is the `data_dict` from the dataloader, which contains part point clouds and other items specified by you. The model needs to leverage these inputs to predict two items:
    - `rot`: rotation of each part. Shape `[batch_size, num_parts, 4]` if using quaternions, or `[batch_size, num_parts, 3, 3]` if using rotation matrices
    - `trans`: translation of each part. Shape `[batch_size, num_parts, 3]`

  Once the output dictionary contains these two values, the loss, metrics, and visualization code in `BaseModel` can run smoothly
- `_loss_function()`: this function applies pre-/post-processing of the model input-output and the loss computation. For example, you can specify the inputs to `model.forward()` by constructing a `forward_dict` from `data_dict`, or reuse features computed on previous samples
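Below is a minimal sketch of such a model. The `PointNetEncoder` module is a hypothetical placeholder, and the `_loss_function()` signature/return format are simplified assumptions; only the `rot`/`trans` output keys are what `BaseModel` requires:

```python
import torch.nn as nn
import torch.nn.functional as F

class MyModel(BaseModel):  # BaseModel from base_model.py

    def __init__(self, cfg):
        super().__init__(cfg)
        self.encoder = PointNetEncoder(feat_dim=128)   # hypothetical extractor
        self.pose_predictor = nn.Linear(128, 4 + 3)    # quaternion + translation

    def forward(self, data_dict):
        pcs = data_dict['part_pcs']                    # [B, P, N, 3]
        B, P = pcs.shape[:2]
        feats = self.encoder(pcs.flatten(0, 1))        # [B * P, 128]
        pose = self.pose_predictor(feats).unflatten(0, (B, P))  # [B, P, 7]
        rot = F.normalize(pose[..., :4], dim=-1)       # unit quaternion, [B, P, 4]
        return {'rot': rot, 'trans': pose[..., 4:]}    # trans: [B, P, 3]

    def _loss_function(self, data_dict):               # simplified signature
        out_dict = self.forward(data_dict)
        trans_loss = F.mse_loss(out_dict['trans'], data_dict['part_trans'])
        return {'trans_loss': trans_loss}
```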
Common loss terms include:

- MSE between predicted and ground-truth translations
- Cosine loss for rotations, i.e. `|<q1, q2> - 1|_2` for quaternions or `|R1^T @ R2 - I|_2` for rotation matrices
- L2/Chamfer distance between point clouds transformed by the predicted and ground-truth rotations and translations
Since there are multiple plausible assembly solutions for a set of parts, we adopt the MoN loss sampling mechanism from DGL. See Section 3.4 of their paper for more details.
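A minimal sketch of the MoN idea (the `noise_dim` argument and `compute_loss` helper are hypothetical): run the model with several random noise vectors and back-propagate only the minimum loss among the samples:

```python
import torch

def mon_loss(model, data_dict, num_samples=5, noise_dim=32):
    B = data_dict['part_pcs'].shape[0]
    losses = []
    for _ in range(num_samples):
        noise = torch.randn(B, noise_dim)           # one random sample
        out_dict = model(data_dict, noise)          # model consumes the noise
        losses.append(compute_loss(out_dict, data_dict))  # hypothetical, [B]
    # optimize only the best sample per shape
    return torch.stack(losses, dim=0).min(dim=0).values.mean()
```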
Besides, since there are often geometrically equivalent parts in a shape (e.g. the 4 legs of a chair), we perform a matching step to minimize the loss. This is similar to the bipartite matching used in DETR. See the `_match_parts()` method of the `BaseModel` class.
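A minimal sketch of the matching idea, using the Hungarian algorithm from `scipy` (the real `_match_parts()` additionally handles batching and only matches within groups of equivalent parts):

```python
import torch
from scipy.optimize import linear_sum_assignment

def match_parts(cost: torch.Tensor) -> torch.Tensor:
    """Given a [P, P] pairwise cost (e.g. Chamfer distance) between predicted
    and GT parts, find the assignment minimizing the total cost."""
    row_idx, col_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(col_idx, device=cost.device)  # GT index per pred part
```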
Note that in the geometric assembly setting, there are usually no geometrically equivalent parts, so we do not need to perform the GT matching step.
Remark: it is actually very hard to define a canonical pose for objects under the geometric assembly setting, due to e.g. the symmetry of a bottle/vase. See the `dev` branch for our experimental features for solving this issue.
For semantic assembly, we adopt Shape Chamfer Distance (SCD), Part Accuracy (PA) and Connectivity Accuracy (CA). Please refer to Section 4.3 of the paper for more details.
For geometric assembly, we adopt SCD and PA, as well as MSE/RMSE/MAE between translations and rotations. Please refer to Section 6.1 of the paper for more details.
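As an illustration, Part Accuracy could be sketched as below; the `per_part_chamfer` helper and the distance threshold are assumptions based on prior work, not this codebase's exact implementation:

```python
import torch

def part_accuracy(pred_pts, gt_pts, valids, thresh=0.01):
    """Fraction of real parts whose per-part Chamfer distance is below `thresh`.
    pred_pts/gt_pts: [B, P, N, 3] part point clouds transformed by the
    predicted/GT poses; valids: [B, P] binary mask over real parts."""
    dist = per_part_chamfer(pred_pts, gt_pts)   # hypothetical helper, [B, P]
    correct = (dist < thresh).float() * valids
    return correct.sum(-1) / valids.sum(-1)     # per-shape PA, shape [B]
```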
Remark: as discussed above, these metrics are sometimes problematic due to the symmetry ambiguity. See the `dev` branch for experimental metrics that are robust under this setting.
- We use the real-part-first (w, x, y, z) quaternion format in this codebase, following PyTorch3D, while `scipy` uses the real-part-last format. Please be careful when using the code
- For ease of data batching, we always represent rotations as quaternions in the dataloaders. However, to build a compatible interface for util functions and model input-output, we wrap the predicted rotations in a `Rotation3D` class, which supports common format conversions and tensor operations. See `rotation.py` for detailed definitions
- Rotation representations we support (change `_C.rot_type` under the `model` field to use a different representation; see the sketch after this list):
    - Quaternion (`quat`), by default
    - 6D representation (rotation matrix, `rmat`): see the CVPR'19 paper. The predicted 6-dim tensor is reshaped to `(2, 3)`, and the third row is obtained via a cross product. Then, the 3 vectors are stacked along the `-2`-th dim. In a `Rotation3D` object, the 6D representation is converted to a 3x3 rotation matrix
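To make these conventions concrete, here is a minimal sketch of the quaternion reordering and the 6D-to-matrix conversion; the `Rotation3D` class in `rotation.py` is the authoritative implementation:

```python
import torch
import torch.nn.functional as F

def quat_wxyz_to_xyzw(quat: torch.Tensor) -> torch.Tensor:
    """Reorder a (w, x, y, z) quaternion (our convention) to scipy's (x, y, z, w)."""
    return quat[..., [1, 2, 3, 0]]

def rot6d_to_rmat(rot6d: torch.Tensor) -> torch.Tensor:
    """Convert a [..., 6] prediction to a [..., 3, 3] rotation matrix."""
    x_raw, y_raw = rot6d[..., :3], rot6d[..., 3:]  # the reshaped (2, 3) rows
    x = F.normalize(x_raw, dim=-1)
    # Gram-Schmidt: make the second row orthogonal to the first
    y = F.normalize(y_raw - (x * y_raw).sum(-1, keepdim=True) * x, dim=-1)
    z = torch.cross(x, y, dim=-1)                  # third row via cross product
    return torch.stack([x, y, z], dim=-2)          # stack along the -2-th dim
```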