diff --git a/gtsfm/configs/unified.yaml b/gtsfm/configs/unified.yaml index 8fef52586..28bc1d840 100644 --- a/gtsfm/configs/unified.yaml +++ b/gtsfm/configs/unified.yaml @@ -65,7 +65,7 @@ SceneOptimizer: # comment out to not run view_graph_estimator: _target_: gtsfm.view_graph_estimator.cycle_consistent_rotation_estimator.CycleConsistentRotationViewGraphEstimator - edge_error_aggregation_criterion: MEDIAN_EDGE_ERROR + edge_error_aggregation_criterion: MIN_EDGE_ERROR rot_avg_module: _target_: gtsfm.averaging.rotation.shonan.ShonanRotationAveraging diff --git a/gtsfm/multi_view_optimizer.py b/gtsfm/multi_view_optimizer.py index e6c557bac..75ccfcfa6 100644 --- a/gtsfm/multi_view_optimizer.py +++ b/gtsfm/multi_view_optimizer.py @@ -21,11 +21,15 @@ from gtsfm.common.keypoints import Keypoints from gtsfm.common.pose_prior import PosePrior from gtsfm.common.sfm_track import SfmTrack2d +from gtsfm.common.two_view_estimation_report import TwoViewEstimationReport +from gtsfm.data_association.cpp_dsf_tracks_estimator import CppDsfTracksEstimator from gtsfm.data_association.data_assoc import DataAssociation from gtsfm.evaluation.metrics import GtsfmMetricsGroup +from gtsfm.view_graph_estimator.cycle_consistent_rotation_estimator import ( + CycleConsistentRotationViewGraphEstimator, + EdgeErrorAggregationCriterion, +) from gtsfm.view_graph_estimator.view_graph_estimator_base import ViewGraphEstimatorBase -from gtsfm.data_association.cpp_dsf_tracks_estimator import CppDsfTracksEstimator -from gtsfm.common.two_view_estimation_report import TwoViewEstimationReport class MultiViewOptimizer: @@ -44,6 +48,10 @@ def __init__( self.ba_optimizer = bundle_adjustment_module self._run_view_graph_estimator: bool = self.view_graph_estimator is not None + self.view_graph_estimator_v2 = CycleConsistentRotationViewGraphEstimator( + edge_error_aggregation_criterion=EdgeErrorAggregationCriterion.MEDIAN_EDGE_ERROR + ) + def __repr__(self) -> str: return f""" MultiviewOptimizer: @@ -114,6 +122,21 @@ def create_computation_graph( two_view_reports_dict, debug_output_dir, ) + ( + viewgraph_i2Ri1_graph, + viewgraph_i2Ui1_graph, + viewgraph_v_corr_idxs_graph, + viewgraph_two_view_reports_graph, + viewgraph_estimation_metrics, + ) = self.view_graph_estimator_v2.create_computation_graph( + viewgraph_i2Ri1_graph, + viewgraph_i2Ui1_graph, + all_intrinsics, + viewgraph_v_corr_idxs_graph, + keypoints_list, + viewgraph_two_view_reports_graph, + debug_output_dir / "2", + ) else: viewgraph_i2Ri1_graph = dask.delayed(i2Ri1_dict) viewgraph_i2Ui1_graph = dask.delayed(i2Ui1_dict)