Skip to content

Commit

Permalink
minor coreg fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
trchudley committed Jan 27, 2025
1 parent 9bf2390 commit 03fd1a8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 235 deletions.
2 changes: 1 addition & 1 deletion batch/batch_download_and_coregister_is2.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@
# check whether coreg worked, and construct filename appropriately
if metadata["coreg_status"] == "failed":
out_fpath = out_fname + ".tif"
if metadata["coreg_status"] == "coregistered":
elif metadata["coreg_status"] == "coregistered":
out_fpath = out_fname + "_coreg.tif"
elif metadata["coreg_status"] == "dz_only":
out_fpath = out_fname + "_coreg_dz.tif"
Expand Down
32 changes: 12 additions & 20 deletions src/pdemtools/_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,19 @@ def coregister_is2(
# ADD ICESAT-2 SPECIFIC TO METADATA
if "request_date_dt" in points_df.columns:

dt_days_max = int(
points_df["request_date_dt"].dt.round("d").dt.days.abs().max()
)
metadata_dict["points_dt_days_max"] = dt_days_max
if metadata_dict["coreg_status"] == "failed":
metadata_dict["points_dt_days_max"] = None
metadata_dict["points_dt_days_count"] = None

dt_days = points_df["request_date_dt"].dt.round("d").dt.days
dt_days_counts_dict = dt_days.value_counts().to_dict()
metadata_dict["points_dt_days_count"] = dt_days_counts_dict
else:
dt_days_max = int(
points_df["request_date_dt"].dt.round("d").dt.days.abs().max()
)
metadata_dict["points_dt_days_max"] = dt_days_max

dt_days = points_df["request_date_dt"].dt.round("d").dt.days
dt_days_counts_dict = dt_days.value_counts().to_dict()
metadata_dict["points_dt_days_count"] = dt_days_counts_dict

metadata_dict["coregistration_type"] = "reference_icesat2"

Expand Down Expand Up @@ -384,19 +389,6 @@ def coregister_dems(

resolution = get_resolution(self._obj)

# new_dem_array, metadata_dict = coregisterdems(
# reference.values,
# self._obj.values,
# reference.x.values,
# reference.y.values,
# stable_mask.values,
# resolution,
# max_horiz_offset=max_horiz_offset,
# rmse_step_thresh=rmse_step_thresh,
# max_iterations=max_iterations,
# verbose=verbose,
# )

new_dem_array, metadata_dict = coregister(
self._obj.values,
reference.values,
Expand Down
215 changes: 1 addition & 214 deletions src/pdemtools/_coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def coregister(
perr = np.full((3, 1), np.nan)
d0 = np.nan
status = "failed"
points_n = None

else:
status = "coregistered"
Expand Down Expand Up @@ -296,220 +297,6 @@ def coregister(
return dem2out, metadata_dict


# def coregisterdems(
# dem1, # Reference DEM
# dem2, # DEM to be coregistered
# x,
# y,
# mask,
# res,
# max_horiz_offset=50,
# rmse_step_thresh=-0.001,
# max_iterations=5,
# verbose=True,
# ):
# """
# Simplified version of Erik Husby's coregisterdems() Python function.

# INPUTS:
# dem_1, dem_2: 2D arrays (of same shape) of dems. dem2 is the dem to be coregistered
# mask: mask of regions to be used in coregistration process (1=VALID FOR COREGISTRATION)

# OUTPUTS:
# trans: the [dz,dx,dy] transformation parameter
# trans_err: 1-sigma errors of trans
# rms: root mean square of the transformation in the vertical from the residuals

# If the registration fails due to lack of overlap, NaNs are returned in p and perr.
# If the registration fails to converge or exceeds the maximum shift, the median
# vertical offset is applied.

# """

# # Verbose print lambda function
# print_verbose = lambda msg: print(msg) if verbose else None

# print_verbose("THIS IS THE OLD COREGISTERDEMS FUNCTION - WE SHOULD MOVE TO A SINGLE COREGISTER FUNCTION")

# # initial trans and RMSE settings
# p = np.zeros((3, 1)) # p is prior iteration trans var
# pn = p.copy() # pn is current iteration trans var
# perr = np.zeros((3, 1)) # perr is prior iteration regression errors
# pnerr = perr.copy() # pnerr is current iteration regression errors
# d0 = np.inf # initial RMSE

# # Edge case markers
# meddz = None
# return_meddz = False
# critical_failure = False

# it = 0
# while True:
# it += 1
# print_verbose(f"Planimetric Correction Iteration {it}")

# print_verbose(f"Offset (z,x,y): {pn[0, 0]:.3f}, {pn[1, 0]:.3f}, {pn[2, 0]:.3f}")
# # print(f"pn: {pn}")

# # Break loop if conditions reached
# if np.any(np.abs(pn[1:]) > max_horiz_offset):
# print(
# f"Maximum horizontal offset ({max_horiz_offset}) exceeded."
# "Consider raising the threshold if offsets are large."
# )
# return_meddz = True
# break

# # Apply offsets
# if pn[1] != 0 and pn[2] != 0:
# dem2n = shift_dem(dem2, pn.T[0], x, y, verbose=verbose).astype("float32")
# else:
# dem2n = dem2 - pn[0].astype("float32")

# # # Calculate slopes - original method from PGC
# # sy, sx = np.gradient(dem2n, res)
# # sx = -sx

# print(type(dem2n))

# # Calculate slope - using Florinsky slope method (p = sx, q = sy)
# sy = q_f(dem2n, res)
# sx = p_f(dem2n, res)
# sy = -sy
# sx = -sx

# # Difference grids.
# dz = dem2n - dem1

# # Mask (in full script, both m1 and m2 are applied)
# dz[mask == 0] = np.nan

# # If no overlap between scenes, break the loop
# if np.all(np.isnan(dz)):
# print("No overlapping data between DEMs")
# critical_failure = True
# break

# # Filter NaNs and outliers.
# n = (
# ~np.isnan(sx)
# & ~np.isnan(sy)
# & (np.abs(dz - np.nanmedian(dz)) <= 3 * np.nanstd(dz))
# )
# n_count = np.count_nonzero(n)

# if n_count < 10:
# print(f"Too few ({n_count}) registration points: 10 required")
# critical_failure = True
# break

# # Get RMSE
# d1 = np.sqrt(np.mean(np.power(dz[n], 2)))
# print_verbose(f"RMSE = {d1}")

# # Keep median dz if first iteration.
# if it == 1:
# meddz = np.median(dz[n])
# meddz_err = np.std(dz[n] / np.sqrt(n_count))
# d00 = np.sqrt(np.mean(np.power(dz[n] - meddz, 2)))

# # Get improvement in RMSE
# rmse_step = d1 - d0 # initial d0 == inf

# # break if rmse above threshold
# if rmse_step > rmse_step_thresh or np.isnan(d0):
# print_verbose(
# f"RMSE step in this iteration ({rmse_step:.5f}) is above threshold "
# f"({rmse_step_thresh}), stopping and returning values of prior iteration."
# )
# # If fails after first registration attempt,
# # set dx and dy to zero and subtract the median offset.
# if it == 2:
# print("Second iteration regression failure")
# return_meddz = True
# break
# elif it == max_iterations:
# print_verbose(f"Maximum number of iterations ({max_iterations}) reached")
# break

# # Keep this adjustment.
# dem2out = dem2n.copy()
# p = pn.copy()
# perr = pnerr.copy()
# d0 = d1

# # Build design matrix.
# X = np.column_stack((np.ones(n_count, dtype=np.float32), sx[n], sy[n]))
# sx, sy = None, None  # release for data management

# # Solve for new adjustment.
# p1 = np.reshape(np.linalg.lstsq(X, dz[n], rcond=None)[0], (-1, 1))

# # Calculate p errors.
# _, R = np.linalg.qr(X)
# RI = np.linalg.lstsq(R, np.identity(3, dtype=np.float32), rcond=None)[0]
# nu = X.shape[0] - X.shape[1] # residual degrees of freedom
# yhat = np.matmul(X, p1) # predicted responses at each data point
# r = dz[n] - yhat.T[0] # residuals
# normr = np.linalg.norm(r)

# dz = None  # release for memory management

# rmse = normr / np.sqrt(nu)
# tval = stats.t.ppf((1 - 0.32 / 2), nu)

# se = rmse * np.sqrt(np.sum(np.square(np.abs(RI)), axis=1, keepdims=True))
# p1err = tval * se

# # Update shifts.
# pn = p + p1
# pnerr = np.sqrt(np.square(perr) + np.square(p1err))

# # END OF LOOP

# if return_meddz:
# print(f"Returning median vertical offset: {meddz:.3f}")
# dem2out = dem2 - meddz
# p = np.array([[meddz, 0, 0]]).T
# perr = np.array([[meddz_err, 0, 0]]).T
# d0 = d00
# status = "dz_only"

# elif critical_failure:
# print("Regression critical failure, returning original DEM, NaN trans, and RMSE")
# dem2out = dem2
# p = np.full((3, 1), np.nan)
# perr = np.full((3, 1), np.nan)
# d0 = np.nan
# status = "failed"

# else:
# status = "coregistered"

# print(f"Final offset (z,x,y): {p[0, 0]:.3f}, {p[1, 0]:.3f}, {p[2, 0]:.3f}")
# print(f"Final RMSE = {d0:.3f}")

# # Construct metadata:
# metadata_dict = {
# "coreg_status": status,
# "x_offset": p[1, 0],
# "y_offset": p[2, 0],
# "z_offset": p[0, 0],
# "x_offset_err": perr[1, 0],
# "y_offset_err": perr[2, 0],
# "z_offset_err": perr[0, 0],
# "rmse": d0,
# }
# # Convert all numerical values to regular Python floats
# metadata_dict = {
# key: float(value) if isinstance(value, (np.float64, np.float32)) else value
# for key, value in metadata_dict.items()
# }

# # Return
# return dem2out, metadata_dict # p.T[0], perr.T[0], d0


def shift_dem(dem, trans, x, y, verbose=True):
"""
Shifts DEM according to translation factors ascertained in coregisterdems function
Expand Down

0 comments on commit 03fd1a8

Please sign in to comment.