Skip to content

Commit

Permalink
fix formatting and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
senselessdev1 committed Aug 12, 2023
1 parent 2a3670e commit dbeebfe
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions gtsfm/bundle/bundle_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,15 @@ def __cameras_to_model(

return sorted(list(cameras))

def run_ba_step(
def run_ba_at_threshold(
self,
initial_data: GtsfmData,
absolute_pose_priors: List[Optional[PosePrior]],
relative_pose_priors: Dict[Tuple[int, int], PosePrior],
reproj_error_thresh: Optional[float],
verbose: bool = True,
) -> Tuple[GtsfmData, GtsfmData, List[bool], float]:
"""TODO"""
"""Run bundle adjustment and filter the resulting tracks by reprojection error."""
cameras_to_model = self.__cameras_to_model(initial_data, absolute_pose_priors, relative_pose_priors)
graph = self.__construct_factor_graph(
cameras_to_model=cameras_to_model,
Expand All @@ -311,31 +311,24 @@ def run_ba_step(
# Print error.
final_error = graph.error(result_values)
if verbose:
logger.info(f"initial error: {graph.error(initial_values):.2f}")
logger.info(f"final error: {final_error:.2f}")
logger.info("initial error: %.2f", graph.error(initial_values))
logger.info("final error: %.2f", final_error)

# Construct the results.
# Convert the `Values` results to a `GtsfmData` instance.
optimized_data = values_to_gtsfm_data(result_values, initial_data, self._shared_calib)

# Filter landmarks by reprojection error.
if reproj_error_thresh is not None:
if verbose:
logger.info(
"[Result] Number of tracks before filtering: %d",
optimized_data.number_tracks(),
)
logger.info("[Result] Number of tracks before filtering: %d", optimized_data.number_tracks())
filtered_result, valid_mask = optimized_data.filter_landmarks(reproj_error_thresh)
if verbose:
logger.info(
"[Result] Number of tracks after filtering: %d",
filtered_result.number_tracks(),
)
logger.info("[Result] Number of tracks after filtering: %d", filtered_result.number_tracks())

else:
valid_mask = [True] * optimized_data.number_tracks()
filtered_result = optimized_data

# Set intermediate result as initial condition for next step.
return optimized_data, filtered_result, valid_mask, final_error

def run_ba(
Expand All @@ -361,7 +354,8 @@ def run_ba(
"""
num_ba_steps = len(self._reproj_error_thresholds)
for step, reproj_error_thresh in enumerate(self._reproj_error_thresholds):
(optimized_data, filtered_result, valid_mask, final_error,) = self.run_ba_step(
# Use intermediate result as initial condition for next step.
(optimized_data, filtered_result, valid_mask, final_error) = self.run_ba_at_threshold(
initial_data,
absolute_pose_priors,
relative_pose_priors,
Expand Down Expand Up @@ -460,20 +454,21 @@ def _run_ba_instrumented(
num_ba_steps = len(self._reproj_error_thresholds)
for step, reproj_error_thresh in enumerate(self._reproj_error_thresholds):
step_start_time = time.time()
(optimized_data, filtered_result, valid_mask, final_error,) = self.run_ba_step(
initial_data,
absolute_pose_priors,
relative_pose_priors,
reproj_error_thresh,
verbose,
(optimized_data, filtered_result, valid_mask, final_error) = self.run_ba_at_threshold(
initial_data=initial_data,
absolute_pose_priors=absolute_pose_priors,
relative_pose_priors=relative_pose_priors,
reproj_error_thresh=reproj_error_thresh,
verbose=verbose,
)
step_times.append(time.time() - step_start_time)

# Print intermediate results.
if num_ba_steps > 1:
logger.info(
"[BA Step %d/%d] Error: %.2f, Number of tracks: %d"
"[BA Stage @ thresh=%.2f px %d/%d] Error: %.2f, Number of tracks: %d"
% (
reproj_error_thresh,
step + 1,
num_ba_steps,
final_error,
Expand Down

0 comments on commit dbeebfe

Please sign in to comment.