From dbeebfec267d26b1af5e5f81dc6d4b6e13ec567c Mon Sep 17 00:00:00 2001 From: senselessdev1 Date: Sat, 12 Aug 2023 16:25:28 -0400 Subject: [PATCH] fix formatting and docstrings --- gtsfm/bundle/bundle_adjustment.py | 39 ++++++++++++++----------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/gtsfm/bundle/bundle_adjustment.py b/gtsfm/bundle/bundle_adjustment.py index 57f0794be..ebbaafb92 100644 --- a/gtsfm/bundle/bundle_adjustment.py +++ b/gtsfm/bundle/bundle_adjustment.py @@ -289,7 +289,7 @@ def __cameras_to_model( return sorted(list(cameras)) - def run_ba_step( + def run_ba_at_threshold( self, initial_data: GtsfmData, absolute_pose_priors: List[Optional[PosePrior]], @@ -297,7 +297,7 @@ def run_ba_step( reproj_error_thresh: Optional[float], verbose: bool = True, ) -> Tuple[GtsfmData, GtsfmData, List[bool], float]: - """TODO""" + """Run bundle adjustment and filter the resulting tracks by reprojection error.""" cameras_to_model = self.__cameras_to_model(initial_data, absolute_pose_priors, relative_pose_priors) graph = self.__construct_factor_graph( cameras_to_model=cameras_to_model, @@ -311,31 +311,24 @@ def run_ba_step( # Print error. final_error = graph.error(result_values) if verbose: - logger.info(f"initial error: {graph.error(initial_values):.2f}") - logger.info(f"final error: {final_error:.2f}") + logger.info("initial error: %.2f", graph.error(initial_values)) + logger.info("final error: %.2f", final_error) - # Construct the results. + # Convert the `Values` results to a `GtsfmData` instance. optimized_data = values_to_gtsfm_data(result_values, initial_data, self._shared_calib) # Filter landmarks by reprojection error. if reproj_error_thresh is not None: if verbose: - logger.info( - "[Result] Number of tracks before filtering: %d", - optimized_data.number_tracks(), - ) + logger.info("[Result] Number of tracks before filtering: %d", optimized_data.number_tracks()) filtered_result, valid_mask = optimized_data.filter_landmarks(reproj_error_thresh) if verbose: - logger.info( - "[Result] Number of tracks after filtering: %d", - filtered_result.number_tracks(), - ) + logger.info("[Result] Number of tracks after filtering: %d", filtered_result.number_tracks()) else: valid_mask = [True] * optimized_data.number_tracks() filtered_result = optimized_data - # Set intermediate result as initial condition for next step. return optimized_data, filtered_result, valid_mask, final_error def run_ba( @@ -361,7 +354,8 @@ def run_ba( """ num_ba_steps = len(self._reproj_error_thresholds) for step, reproj_error_thresh in enumerate(self._reproj_error_thresholds): - (optimized_data, filtered_result, valid_mask, final_error,) = self.run_ba_step( + # Use intermediate result as initial condition for next step. + (optimized_data, filtered_result, valid_mask, final_error) = self.run_ba_at_threshold( initial_data, absolute_pose_priors, relative_pose_priors, @@ -460,20 +454,21 @@ def _run_ba_instrumented( num_ba_steps = len(self._reproj_error_thresholds) for step, reproj_error_thresh in enumerate(self._reproj_error_thresholds): step_start_time = time.time() - (optimized_data, filtered_result, valid_mask, final_error,) = self.run_ba_step( - initial_data, - absolute_pose_priors, - relative_pose_priors, - reproj_error_thresh, - verbose, + (optimized_data, filtered_result, valid_mask, final_error) = self.run_ba_at_threshold( + initial_data=initial_data, + absolute_pose_priors=absolute_pose_priors, + relative_pose_priors=relative_pose_priors, + reproj_error_thresh=reproj_error_thresh, + verbose=verbose, ) step_times.append(time.time() - step_start_time) # Print intermediate results. if num_ba_steps > 1: logger.info( - "[BA Step %d/%d] Error: %.2f, Number of tracks: %d" + "[BA Stage @ thresh=%.2f px %d/%d] Error: %.2f, Number of tracks: %d" % ( + reproj_error_thresh, step + 1, num_ba_steps, final_error,