diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0235eab1..b2a496f6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,11 +32,14 @@ jobs: - name: Test with pytest run: python -m pytest - # Add Black formatter - - name: Install Black formatter - run: python -m pip install black - - - name: Check code formatting with Black - run: black . --check --diff - + linter_name: + name: Run Black formatter + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Run Black formatter + uses: rickstaa/action-black@v1 + with: + black_args: ". --check" diff --git a/pyproject.toml b/pyproject.toml index 58fb232c..542d4272 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ release = [ test = [ "pytest>=7.0.0,<8.1", + "pytest-mock", "pytest-cov", "pytest-emoji", "pytest-raises", diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py index 61650f1d..5648bff9 100644 --- a/src/wf_psf/data/training_preprocessing.py +++ b/src/wf_psf/data/training_preprocessing.py @@ -9,90 +9,170 @@ import numpy as np import wf_psf.utils.utils as utils import tensorflow as tf -import tensorflow_addons as tfa import os -class TrainingDataHandler: - """Training Data Handler. +class DataHandler: + """Data Handler. - A class to manage training data. + This class manages loading and processing of training and testing data for use in machine learning models. + It provides methods to access and preprocess the data. Parameters ---------- - training_data_params: Recursive Namespace object + data_type: str + A string indicating type of data ("train" or "test"). + data_params: Recursive Namespace object Recursive Namespace object containing training data parameters + simPSF: PSFSimulator + An instance of the PSFSimulator class for simulating a PSF. + n_bins_lambda: int + The number of bins in wavelength. + init_flag: bool, optional + A flag indicating whether to perform initialization steps upon object creation. + If True (default), the dataset is loaded and processed. If False, initialization + steps are skipped, and manual initialization is required. + + Attributes + ---------- + dataset: dict + A dictionary containing the loaded dataset, including positions and stars/noisy_stars. simPSF: object - PSFSimulator instance + An instance of the SimPSFToolkit class for simulating PSF. n_bins_lambda: int - Number of bins in wavelength + The number of bins in wavelength. + sed_data: tf.Tensor + A TensorFlow tensor containing the SED data for training/testing. + init_flag: bool, optional + A flag used to control initialization steps. If True, initialization is performed + upon object creation. + """ - def __init__(self, training_data_params, simPSF, n_bins_lambda): - self.training_data_params = training_data_params - self.train_dataset = np.load( - os.path.join( - self.training_data_params.data_dir, self.training_data_params.file - ), + def __init__(self, data_type, data_params, simPSF, n_bins_lambda, init_flag=True): + self.data_params = data_params.__dict__[data_type] + self.dataset = None + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.sed_data = None + self.initialize(init_flag) + + def load_dataset(self): + """Load dataset. + + Load the dataset based on the specified data type. 
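+
+        Example
+        -------
+        A minimal usage sketch (illustrative only; ``data_params`` and ``simPSF``
+        come from the caller's configuration and are not defined here):
+
+        >>> handler = DataHandler("train", data_params, simPSF, n_bins_lambda=8, init_flag=False)
+        >>> handler.load_dataset()   # fills handler.dataset with positions and (noisy) stars
+        >>> handler.process_sed_data()   # builds handler.sed_data from the dataset SEDs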
+ + """ + self.dataset = np.load( + os.path.join(self.data_params.data_dir, self.data_params.file), allow_pickle=True, )[()] - self.train_dataset["positions"] = tf.convert_to_tensor( - self.train_dataset["positions"], dtype=tf.float32 + self.dataset["positions"] = tf.convert_to_tensor( + self.dataset["positions"], dtype=tf.float32 ) - self.train_dataset["noisy_stars"] = tf.convert_to_tensor( - self.train_dataset["noisy_stars"], dtype=tf.float32 - ) - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda + if "train" in self.data_params.file: + self.dataset["noisy_stars"] = tf.convert_to_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + elif "test" in self.data_params.file: + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + + def process_sed_data(self): + """Process SED Data. + + A method to generate and process SED data. + + """ self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 ) - for _sed in self.train_dataset["SEDs"] + for _sed in self.dataset["SEDs"] ] self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) + def initialize(self, init_flag): + """Initialize. + + Initialize the DataHandler instance by loading and processing the dataset, + if the init_flag is True. -class TestDataHandler: - """Test Data Handler. + Parameters + ---------- + init_flag : bool + A flag indicating whether to perform initialization steps. If True, + the dataset is loaded and processed. If False, initialization steps + are skipped. - A class to handle test data for model validation. + """ + if init_flag: + self.load_dataset() + self.process_sed_data() + + +def get_obs_positions(data): + """Get observed positions from the provided dataset. + + This method concatenates the positions of the stars from both the training + and test datasets to obtain the observed positions. Parameters ---------- - test_data_params: Recursive Namespace object - Recursive Namespace object containing test data parameters - simPSF: object - PSFSimulator instance - n_bins_lambda: int - Number of bins in wavelength + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + + Notes + ----- + The observed positions are obtained by concatenating the positions of stars + from both the training and test datasets along the 0th axis. """ + obs_positions = np.concatenate( + ( + data.training_data.dataset["positions"], + data.test_data.dataset["positions"], + ), + axis=0, + ) + return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - def __init__(self, test_data_params, simPSF, n_bins_lambda): - self.test_data_params = test_data_params - self.test_dataset = np.load( - os.path.join(self.test_data_params.data_dir, self.test_data_params.file), - allow_pickle=True, - )[()] - self.test_dataset["stars"] = tf.convert_to_tensor( - self.test_dataset["stars"], dtype=tf.float32 - ) - self.test_dataset["positions"] = tf.convert_to_tensor( - self.test_dataset["positions"], dtype=tf.float32 - ) - # Prepare validation data inputs - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda +def get_zernike_prior(data): + """Get Zernike priors from the provided dataset. 
-        self.sed_data = [
-            utils.generate_SED_elems_in_tensorflow(
-                _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64
-            )
-            for _sed in self.test_dataset["SEDs"]
-        ]
-        self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32)
-        self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1])
+    This method concatenates the Zernike priors from both the training
+    and test datasets.
+
+    Parameters
+    ----------
+    data : DataConfigHandler
+        Object containing training and test datasets.
+
+    Returns
+    -------
+    tf.Tensor
+        Tensor containing the concatenated Zernike priors.
+
+    Notes
+    -----
+    The Zernike priors are obtained by concatenating the Zernike priors
+    from both the training and test datasets along the 0th axis.
+
+    """
+    zernike_prior = np.concatenate(
+        (
+            data.training_data.dataset["zernike_prior"],
+            data.test_data.dataset["zernike_prior"],
+        ),
+        axis=0,
+    )
+    return tf.convert_to_tensor(zernike_prior, dtype=tf.float32)
diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index 83e7ee8b..9f5a1039 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -3,7 +3,7 @@
 import matplotlib.pyplot as plt
 import galsim as gs
 import wf_psf.utils.utils as utils
-from wf_psf.psf_models.tf_psf_field import build_PSF_model
+from wf_psf.psf_models.psf_models import build_PSF_model
 from wf_psf.psf_models import tf_psf_field as psf_field
 from wf_psf.sims import psf_simulator as psf_simulator
 import logging
@@ -13,7 +13,7 @@
 def compute_poly_metric(
     tf_semiparam_field,
-    GT_tf_semiparam_field,
+    gt_tf_semiparam_field,
     simPSF_np,
     tf_pos,
     tf_SEDs,
@@ -25,7 +25,7 @@
     """Calculate metrics for polychromatic reconstructions.

     The ``tf_semiparam_field`` should be the model to evaluate, and the
-    ``GT_tf_semiparam_field`` should be loaded with the ground truth PSF field.
+    ``gt_tf_semiparam_field`` should be loaded with the ground truth PSF field.

     Relative values returned in [%] (so multiplied by 100).

@@ -33,8 +33,8 @@
     ----------
     tf_semiparam_field: PSF field object
         Trained model to evaluate.
-    GT_tf_semiparam_field: PSF field object
-        Ground truth model to produce GT observations at any position
+    gt_tf_semiparam_field: PSF field object
+        Ground truth model to produce gt observations at any position
         and wavelength.
     simPSF_np: PSF simulator object
         Simulation object to be used by ``generate_packed_elems`` function.
@@ -51,7 +51,7 @@
     dataset_dict: dict
         Dictionary containing the dataset information. If provided, and if the `'stars'` key
         is present, the noiseless stars from the dataset are used to compute the metrics.
-        Otherwise, the stars are generated from the GT model.
+        Otherwise, the stars are generated from the gt model.
         Default is `None`.
Returns @@ -78,15 +78,15 @@ def compute_poly_metric( # Model prediction preds = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) - # GT data preparation + # gt data preparation if dataset_dict is None or "stars" not in dataset_dict: - logger.info("Regenerating GT stars from model.") - # Change interpolation parameters for the GT simPSF + logger.info("Regenerating gt stars from model.") + # Change interpolation parameters for the gt simPSF interp_pts_per_bin = simPSF_np.SED_interp_pts_per_bin simPSF_np.SED_interp_pts_per_bin = 0 SED_sigma = simPSF_np.SED_sigma simPSF_np.SED_sigma = 0 - # Generate SED data list for GT model + # Generate SED data list for gt model packed_SED_data = [ utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_gt) for _sed in tf_SEDs @@ -95,24 +95,24 @@ def compute_poly_metric( tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1]) pred_inputs = [tf_pos, tf_packed_SED_data] - # GT model prediction - GT_preds = GT_tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) + # Ground Truth model prediction + gt_preds = gt_tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) else: - logger.info("Using GT stars from dataset.") - GT_preds = dataset_dict["stars"] + logger.info("Using Ground Truth stars from dataset.") + gt_preds = dataset_dict["stars"] # Calculate residuals - residuals = np.sqrt(np.mean((GT_preds - preds) ** 2, axis=(1, 2))) - GT_star_mean = np.sqrt(np.mean((GT_preds) ** 2, axis=(1, 2))) + residuals = np.sqrt(np.mean((gt_preds - preds) ** 2, axis=(1, 2))) + gt_star_mean = np.sqrt(np.mean((gt_preds) ** 2, axis=(1, 2))) # RMSE calculations rmse = np.mean(residuals) - rel_rmse = 100.0 * np.mean(residuals / GT_star_mean) + rel_rmse = 100.0 * np.mean(residuals / gt_star_mean) # STD calculations std_rmse = np.std(residuals) - std_rel_rmse = 100.0 * np.std(residuals / GT_star_mean) + std_rel_rmse = 100.0 * np.std(residuals / gt_star_mean) # Print RMSE values logger.info("Absolute RMSE:\t %.4e \t +/- %.4e" % (rmse, std_rmse)) @@ -123,7 +123,7 @@ def compute_poly_metric( def compute_mono_metric( tf_semiparam_field, - GT_tf_semiparam_field, + gt_tf_semiparam_field, simPSF_np, tf_pos, lambda_list, @@ -132,7 +132,7 @@ def compute_mono_metric( """Calculate metrics for monochromatic reconstructions. The ``tf_semiparam_field`` should be the model to evaluate, and the - ``GT_tf_semiparam_field`` should be loaded with the ground truth PSF field. + ``gt_tf_semiparam_field`` should be loaded with the ground truth PSF field. Relative values returned in [%] (so multiplied by 100). @@ -140,8 +140,8 @@ def compute_mono_metric( ---------- tf_semiparam_field: PSF field object Trained model to evaluate. - GT_tf_semiparam_field: PSF field object - Ground truth model to produce GT observations at any position + gt_tf_semiparam_field: PSF field object + Ground truth model to produce gt observations at any position and wavelength. 
simPSF_np: PSF simulator object Simulation object capable of calculating ``phase_N`` values from @@ -181,7 +181,7 @@ def compute_mono_metric( phase_N = simPSF_np.feasible_N(lambda_obs) residuals = np.zeros((total_samples)) - GT_star_mean = np.zeros((total_samples)) + gt_star_mean = np.zeros((total_samples)) # Total number of epochs n_epochs = int(np.ceil(total_samples / batch_size)) @@ -196,7 +196,7 @@ def compute_mono_metric( batch_pos = tf_pos[ep_low_lim:ep_up_lim, :] # Estimate the monochromatic PSFs - GT_mono_psf = GT_tf_semiparam_field.predict_mono_psfs( + gt_mono_psf = gt_tf_semiparam_field.predict_mono_psfs( input_positions=batch_pos, lambda_obs=lambda_obs, phase_N=phase_N ) @@ -204,13 +204,13 @@ def compute_mono_metric( input_positions=batch_pos, lambda_obs=lambda_obs, phase_N=phase_N ) - num_pixels = GT_mono_psf.shape[1] * GT_mono_psf.shape[2] + num_pixels = gt_mono_psf.shape[1] * gt_mono_psf.shape[2] residuals[ep_low_lim:ep_up_lim] = ( - np.sum((GT_mono_psf - model_mono_psf) ** 2, axis=(1, 2)) / num_pixels + np.sum((gt_mono_psf - model_mono_psf) ** 2, axis=(1, 2)) / num_pixels ) - GT_star_mean[ep_low_lim:ep_up_lim] = ( - np.sum((GT_mono_psf) ** 2, axis=(1, 2)) / num_pixels + gt_star_mean[ep_low_lim:ep_up_lim] = ( + np.sum((gt_mono_psf) ** 2, axis=(1, 2)) / num_pixels ) # Increase lower limit @@ -218,20 +218,20 @@ def compute_mono_metric( # Calculate residuals residuals = np.sqrt(residuals) - GT_star_mean = np.sqrt(GT_star_mean) + gt_star_mean = np.sqrt(gt_star_mean) # RMSE calculations rmse_lda.append(np.mean(residuals)) - rel_rmse_lda.append(100.0 * np.mean(residuals / GT_star_mean)) + rel_rmse_lda.append(100.0 * np.mean(residuals / gt_star_mean)) # STD calculations std_rmse_lda.append(np.std(residuals)) - std_rel_rmse_lda.append(100.0 * np.std(residuals / GT_star_mean)) + std_rel_rmse_lda.append(100.0 * np.std(residuals / gt_star_mean)) return rmse_lda, rel_rmse_lda, std_rmse_lda, std_rel_rmse_lda -def compute_opd_metrics(tf_semiparam_field, GT_tf_semiparam_field, pos, batch_size=16): +def compute_opd_metrics(tf_semiparam_field, gt_tf_semiparam_field, pos, batch_size=16): """Compute the OPD metrics. Need to handle a batch size to avoid Out-Of-Memory errors with @@ -246,8 +246,8 @@ def compute_opd_metrics(tf_semiparam_field, GT_tf_semiparam_field, pos, batch_si ---------- tf_semiparam_field: PSF field object Trained model to evaluate. - GT_tf_semiparam_field: PSF field object - Ground truth model to produce GT observations at any position + gt_tf_semiparam_field: PSF field object + Ground truth model to produce gt observations at any position and wavelength. pos: numpy.ndarray [batch x 2] Positions at where to predict the OPD maps. 
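The per-star metric accumulated by ``compute_opd_metrics`` below can be summarized by a short NumPy sketch (illustrative; it assumes mean-subtracted, obscured ``opd_batch`` and ``gt_opd_batch`` arrays as prepared in the function body):

import numpy as np

def batch_masked_opd_rmse(opd_batch, gt_opd_batch, np_obscurations):
    # Restrict the comparison to pixels inside the pupil obscuration mask.
    obsc_mask = np_obscurations > 0
    nb_mask_elems = np.sum(obsc_mask)
    res_opd = np.sqrt(np.array([
        np.sum((im1[obsc_mask] - im2[obsc_mask]) ** 2) / nb_mask_elems
        for im1, im2 in zip(opd_batch, gt_opd_batch)
    ]))
    gt_opd_mean = np.sqrt(np.array([
        np.sum(im2[obsc_mask] ** 2) / nb_mask_elems for im2 in gt_opd_batch
    ]))
    # Absolute and relative (in %) per-star OPD RMSE.
    return res_opd, 100.0 * res_opd / gt_opd_mean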
@@ -267,13 +267,13 @@ def compute_opd_metrics(tf_semiparam_field, GT_tf_semiparam_field, pos, batch_si """ # Get OPD obscurations - np_obscurations = np.real(GT_tf_semiparam_field.obscurations.numpy()) + np_obscurations = np.real(gt_tf_semiparam_field.obscurations.numpy()) # Define total number of samples n_samples = pos.shape[0] # Initialise batch variables opd_batch = None - GT_opd_batch = None + gt_opd_batch = None counter = 0 # Initialise result lists rmse_vals = np.zeros(n_samples) @@ -290,13 +290,13 @@ def compute_opd_metrics(tf_semiparam_field, GT_tf_semiparam_field, pos, batch_si batch_pos = pos[counter:end_sample, :] # We calculate a batch of OPDs opd_batch = tf_semiparam_field.predict_opd(batch_pos).numpy() - GT_opd_batch = GT_tf_semiparam_field.predict_opd(batch_pos).numpy() + gt_opd_batch = gt_tf_semiparam_field.predict_opd(batch_pos).numpy() # Remove the mean of the OPD opd_batch -= np.mean(opd_batch, axis=(1, 2)).reshape(-1, 1, 1) - GT_opd_batch -= np.mean(GT_opd_batch, axis=(1, 2)).reshape(-1, 1, 1) + gt_opd_batch -= np.mean(gt_opd_batch, axis=(1, 2)).reshape(-1, 1, 1) # Obscure the OPDs opd_batch *= np_obscurations - GT_opd_batch *= np_obscurations + gt_opd_batch *= np_obscurations # Generate obscuration mask obsc_mask = np_obscurations > 0 nb_mask_elems = np.sum(obsc_mask) @@ -305,18 +305,18 @@ def compute_opd_metrics(tf_semiparam_field, GT_tf_semiparam_field, pos, batch_si np.array( [ np.sum((im1[obsc_mask] - im2[obsc_mask]) ** 2) / nb_mask_elems - for im1, im2 in zip(opd_batch, GT_opd_batch) + for im1, im2 in zip(opd_batch, gt_opd_batch) ] ) ) - GT_opd_mean = np.sqrt( + gt_opd_mean = np.sqrt( np.array( - [np.sum(im2[obsc_mask] ** 2) / nb_mask_elems for im2 in GT_opd_batch] + [np.sum(im2[obsc_mask] ** 2) / nb_mask_elems for im2 in gt_opd_batch] ) ) # RMSE calculations rmse_vals[counter:end_sample] = res_opd - rel_rmse_vals[counter:end_sample] = 100.0 * (res_opd / GT_opd_mean) + rel_rmse_vals[counter:end_sample] = 100.0 * (res_opd / gt_opd_mean) # Add the results to the lists counter += batch_size @@ -336,7 +336,7 @@ def compute_opd_metrics(tf_semiparam_field, GT_tf_semiparam_field, pos, batch_si def compute_shape_metrics( tf_semiparam_field, - GT_tf_semiparam_field, + gt_tf_semiparam_field, simPSF_np, SEDs, tf_pos, @@ -357,8 +357,8 @@ def compute_shape_metrics( ---------- tf_semiparam_field: PSF field object Trained model to evaluate. - GT_tf_semiparam_field: PSF field object - Ground truth model to produce GT observations at any position + gt_tf_semiparam_field: PSF field object + Ground truth model to produce gt observations at any position and wavelength. simPSF_np: SEDs: numpy.ndarray [batch x SED_samples x 2] @@ -389,7 +389,7 @@ def compute_shape_metrics( dataset_dict: dict Dictionary containing the dataset information. If provided, and if the `'super_res_stars'` key is present, the noiseless super resolved stars from the dataset are used to compute - the metrics. Otherwise, the stars are generated from the GT model. + the metrics. Otherwise, the stars are generated from the gt model. Default is `None`. 
Returns @@ -401,16 +401,16 @@ def compute_shape_metrics( # Save original output_Q and output_dim original_out_Q = tf_semiparam_field.output_Q original_out_dim = tf_semiparam_field.output_dim - GT_original_out_Q = GT_tf_semiparam_field.output_Q - GT_original_out_dim = GT_tf_semiparam_field.output_dim + gt_original_out_Q = gt_tf_semiparam_field.output_Q + gt_original_out_dim = gt_tf_semiparam_field.output_dim # Set the required output_Q and output_dim parameters in the models tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim) - GT_tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim) + gt_tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim) # Need to compile the models again tf_semiparam_field = build_PSF_model(tf_semiparam_field) - GT_tf_semiparam_field = build_PSF_model(GT_tf_semiparam_field) + gt_tf_semiparam_field = build_PSF_model(gt_tf_semiparam_field) # Generate SED data list packed_SED_data = [ @@ -425,19 +425,19 @@ def compute_shape_metrics( # PSF model predictions = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) - # GT data preparation + # gt data preparation if ( dataset_dict is None or "super_res_stars" not in dataset_dict or "SR_stars" not in dataset_dict ): - logger.info("Generating GT super resolved stars from the GT model.") - # Change interpolation parameters for the GT simPSF + logger.info("Generating gt super resolved stars from the gt model.") + # Change interpolation parameters for the gt simPSF interp_pts_per_bin = simPSF_np.SED_interp_pts_per_bin simPSF_np.SED_interp_pts_per_bin = 0 SED_sigma = simPSF_np.SED_sigma simPSF_np.SED_sigma = 0 - # Generate SED data list for GT model + # Generate SED data list for gt model packed_SED_data = [ utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_gt) for _sed in SEDs @@ -449,32 +449,32 @@ def compute_shape_metrics( pred_inputs = [tf_pos, tf_packed_SED_data] # Ground Truth model - GT_predictions = GT_tf_semiparam_field.predict( + gt_predictions = gt_tf_semiparam_field.predict( x=pred_inputs, batch_size=batch_size ) else: logger.info("Using super resolved stars from dataset.") if "super_res_stars" in dataset_dict: - GT_predictions = dataset_dict["super_res_stars"] + gt_predictions = dataset_dict["super_res_stars"] elif "SR_stars" in dataset_dict: - GT_predictions = dataset_dict["SR_stars"] + gt_predictions = dataset_dict["SR_stars"] # Calculate residuals - residuals = np.sqrt(np.mean((GT_predictions - predictions) ** 2, axis=(1, 2))) - GT_star_mean = np.sqrt(np.mean((GT_predictions) ** 2, axis=(1, 2))) + residuals = np.sqrt(np.mean((gt_predictions - predictions) ** 2, axis=(1, 2))) + gt_star_mean = np.sqrt(np.mean((gt_predictions) ** 2, axis=(1, 2))) # Pixel RMSE for each star if opt_stars_rel_pix_rmse: - stars_rel_pix_rmse = 100.0 * residuals / GT_star_mean + stars_rel_pix_rmse = 100.0 * residuals / gt_star_mean # RMSE calculations pix_rmse = np.mean(residuals) - rel_pix_rmse = 100.0 * np.mean(residuals / GT_star_mean) + rel_pix_rmse = 100.0 * np.mean(residuals / gt_star_mean) # STD calculations pix_rmse_std = np.std(residuals) - rel_pix_rmse_std = 100.0 * np.std(residuals / GT_star_mean) + rel_pix_rmse_std = 100.0 * np.std(residuals / gt_star_mean) # Print pixel RMSE values logger.info( @@ -491,40 +491,40 @@ def compute_shape_metrics( ] # Measure shapes of the reconstructions - GT_pred_moments = [ + gt_pred_moments = [ gs.hsm.FindAdaptiveMom(gs.Image(_pred), strict=False) - for _pred in GT_predictions + for _pred in gt_predictions ] 
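+    # NOTE: gs.hsm.FindAdaptiveMom returns a GalSim ShapeData object; with
+    # strict=False, measurement failures are flagged via moments_status != 0
+    # instead of raising. observed_shape holds the (g1, g2) ellipticity and
+    # moments_sigma the adaptive size, so R2 below is taken as 2 * moments_sigma**2.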
pred_e1_HSM, pred_e2_HSM, pred_R2_HSM = [], [], [] - GT_pred_e1_HSM, GT_pred_e2_HSM, GT_pred_R2_HSM = [], [], [] + gt_pred_e1_HSM, gt_pred_e2_HSM, gt_pred_R2_HSM = [], [], [] - for it in range(len(GT_pred_moments)): + for it in range(len(gt_pred_moments)): if ( pred_moments[it].moments_status == 0 - and GT_pred_moments[it].moments_status == 0 + and gt_pred_moments[it].moments_status == 0 ): pred_e1_HSM.append(pred_moments[it].observed_shape.g1) pred_e2_HSM.append(pred_moments[it].observed_shape.g2) pred_R2_HSM.append(2 * (pred_moments[it].moments_sigma ** 2)) - GT_pred_e1_HSM.append(GT_pred_moments[it].observed_shape.g1) - GT_pred_e2_HSM.append(GT_pred_moments[it].observed_shape.g2) - GT_pred_R2_HSM.append(2 * (GT_pred_moments[it].moments_sigma ** 2)) + gt_pred_e1_HSM.append(gt_pred_moments[it].observed_shape.g1) + gt_pred_e2_HSM.append(gt_pred_moments[it].observed_shape.g2) + gt_pred_R2_HSM.append(2 * (gt_pred_moments[it].moments_sigma ** 2)) pred_e1_HSM = np.array(pred_e1_HSM) pred_e2_HSM = np.array(pred_e2_HSM) pred_R2_HSM = np.array(pred_R2_HSM) - GT_pred_e1_HSM = np.array(GT_pred_e1_HSM) - GT_pred_e2_HSM = np.array(GT_pred_e2_HSM) - GT_pred_R2_HSM = np.array(GT_pred_R2_HSM) + gt_pred_e1_HSM = np.array(gt_pred_e1_HSM) + gt_pred_e2_HSM = np.array(gt_pred_e2_HSM) + gt_pred_R2_HSM = np.array(gt_pred_R2_HSM) # Calculate metrics # e1 - e1_res = GT_pred_e1_HSM - pred_e1_HSM - e1_res_rel = (GT_pred_e1_HSM - pred_e1_HSM) / GT_pred_e1_HSM + e1_res = gt_pred_e1_HSM - pred_e1_HSM + e1_res_rel = (gt_pred_e1_HSM - pred_e1_HSM) / gt_pred_e1_HSM rmse_e1 = np.sqrt(np.mean(e1_res**2)) rel_rmse_e1 = 100.0 * np.sqrt(np.mean(e1_res_rel**2)) @@ -532,8 +532,8 @@ def compute_shape_metrics( std_rel_rmse_e1 = 100.0 * np.std(e1_res_rel) # e2 - e2_res = GT_pred_e2_HSM - pred_e2_HSM - e2_res_rel = (GT_pred_e2_HSM - pred_e2_HSM) / GT_pred_e2_HSM + e2_res = gt_pred_e2_HSM - pred_e2_HSM + e2_res_rel = (gt_pred_e2_HSM - pred_e2_HSM) / gt_pred_e2_HSM rmse_e2 = np.sqrt(np.mean(e2_res**2)) rel_rmse_e2 = 100.0 * np.sqrt(np.mean(e2_res_rel**2)) @@ -541,10 +541,10 @@ def compute_shape_metrics( std_rel_rmse_e2 = 100.0 * np.std(e2_res_rel) # R2 - R2_res = GT_pred_R2_HSM - pred_R2_HSM + R2_res = gt_pred_R2_HSM - pred_R2_HSM - rmse_R2_meanR2 = np.sqrt(np.mean(R2_res**2)) / np.mean(GT_pred_R2_HSM) - std_rmse_R2_meanR2 = np.std(R2_res / GT_pred_R2_HSM) + rmse_R2_meanR2 = np.sqrt(np.mean(R2_res**2)) / np.mean(gt_pred_R2_HSM) + std_rmse_R2_meanR2 = np.std(R2_res / gt_pred_R2_HSM) # Print shape/size errors logger.info("\nsigma(e1) RMSE =\t\t %.4e \t +/- %.4e " % (rmse_e1, std_rmse_e1)) @@ -564,32 +564,32 @@ def compute_shape_metrics( ) # Print number of stars - logger.info("\nTotal number of stars: \t\t%d" % (len(GT_pred_moments))) + logger.info("\nTotal number of stars: \t\t%d" % (len(gt_pred_moments))) logger.info( "Problematic number of stars: \t%d" - % (len(GT_pred_moments) - GT_pred_e1_HSM.shape[0]) + % (len(gt_pred_moments) - gt_pred_e1_HSM.shape[0]) ) # Re-et the original output_Q and output_dim parameters in the models tf_semiparam_field.set_output_Q( output_Q=original_out_Q, output_dim=original_out_dim ) - GT_tf_semiparam_field.set_output_Q( - output_Q=GT_original_out_Q, output_dim=GT_original_out_dim + gt_tf_semiparam_field.set_output_Q( + output_Q=gt_original_out_Q, output_dim=gt_original_out_dim ) # Need to compile the models again tf_semiparam_field = build_PSF_model(tf_semiparam_field) - GT_tf_semiparam_field = build_PSF_model(GT_tf_semiparam_field) + gt_tf_semiparam_field = build_PSF_model(gt_tf_semiparam_field) # 
Moment results result_dict = { "pred_e1_HSM": pred_e1_HSM, "pred_e2_HSM": pred_e2_HSM, "pred_R2_HSM": pred_R2_HSM, - "GT_pred_e1_HSM": GT_pred_e1_HSM, - "GT_ped_e2_HSM": GT_pred_e2_HSM, - "GT_pred_R2_HSM": GT_pred_R2_HSM, + "gt_pred_e1_HSM": gt_pred_e1_HSM, + "gt_ped_e2_HSM": gt_pred_e2_HSM, + "gt_pred_R2_HSM": gt_pred_R2_HSM, "rmse_e1": rmse_e1, "std_rmse_e1": std_rmse_e1, "rel_rmse_e1": rel_rmse_e1, @@ -615,7 +615,7 @@ def compute_shape_metrics( return result_dict -def gen_GT_wf_model(test_wf_file_path, pred_output_Q=1, pred_output_dim=64): +def gen_gt_wf_model(test_wf_file_path, pred_output_Q=1, pred_output_dim=64): r"""Generate the ground truth model and output test PSF ar required resolution. If `pred_output_Q=1` the resolution will be 3 times the one of Euclid. @@ -630,7 +630,7 @@ def gen_GT_wf_model(test_wf_file_path, pred_output_Q=1, pred_output_dim=64): tf_test_pos = tf.convert_to_tensor(wf_test_pos, dtype=tf.float32) wf_test_SEDs = wf_test_dataset["SEDs"] - # Generate GT model + # Generate gt model batch_size = 16 # Generate Zernike maps @@ -673,7 +673,7 @@ def gen_GT_wf_model(test_wf_file_path, pred_output_Q=1, pred_output_dim=64): tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32) # Initialize the model - GT_tf_semiparam_field = psf_field.TF_SemiParam_field( + gt_tf_semiparam_field = psf_field.TF_SemiParam_field( zernike_maps=tf_zernike_cube, obscurations=tf_obscurations, batch_size=batch_size, @@ -687,18 +687,18 @@ def gen_GT_wf_model(test_wf_file_path, pred_output_Q=1, pred_output_dim=64): ) # For the Ground truth model - GT_tf_semiparam_field.tf_poly_Z_field.assign_coeff_matrix(wf_test_C_poly) - _ = GT_tf_semiparam_field.tf_np_poly_opd.alpha_mat.assign( - tf.zeros_like(GT_tf_semiparam_field.tf_np_poly_opd.alpha_mat) + gt_tf_semiparam_field.tf_poly_Z_field.assign_coeff_matrix(wf_test_C_poly) + _ = gt_tf_semiparam_field.tf_np_poly_opd.alpha_mat.assign( + tf.zeros_like(gt_tf_semiparam_field.tf_np_poly_opd.alpha_mat) ) # Set required output_Q - GT_tf_semiparam_field.set_output_Q( + gt_tf_semiparam_field.set_output_Q( output_Q=pred_output_Q, output_dim=pred_output_dim ) - GT_tf_semiparam_field = psf_field.build_PSF_model(GT_tf_semiparam_field) + gt_tf_semiparam_field = psf_field.build_PSF_model(gt_tf_semiparam_field) packed_SED_data = [ utils.generate_packed_elems(_sed, simPSF_np, n_bins=wf_test_params["n_bins"]) @@ -711,9 +711,9 @@ def gen_GT_wf_model(test_wf_file_path, pred_output_Q=1, pred_output_dim=64): pred_inputs = [tf_test_pos, tf_packed_SED_data] # Ground Truth model - GT_predictions = GT_tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) + gt_predictions = gt_tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) - return GT_predictions, wf_test_pos, GT_tf_semiparam_field + return gt_predictions, wf_test_pos, gt_tf_semiparam_field ## Below this line there are DEPRECATED functions @@ -779,7 +779,7 @@ def compute_metrics( def compute_opd_metrics_mccd( - tf_semiparam_field, GT_tf_semiparam_field, test_pos, train_pos + tf_semiparam_field, gt_tf_semiparam_field, test_pos, train_pos ): """Compute the OPD metrics.""" @@ -794,19 +794,19 @@ def compute_opd_metrics_mccd( # OPD prediction opd_pred = tf.math.add(P_opd_pred, NP_opd_pred) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(test_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(test_pos) + gt_opd_maps = 
gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate absolute RMSE values test_opd_rmse = np.sqrt(np.mean(res_opd**2)) # Calculate relative RMSE values relative_test_opd_rmse = test_opd_rmse / np.sqrt( - np.mean((GT_opd_maps.numpy() * np_obscurations) ** 2) + np.mean((gt_opd_maps.numpy() * np_obscurations) ** 2) ) # Print RMSE values @@ -824,19 +824,19 @@ def compute_opd_metrics_mccd( # OPD prediction opd_pred = tf.math.add(P_opd_pred, NP_opd_pred) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(train_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(train_pos) + gt_opd_maps = gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate RMSE values train_opd_rmse = np.sqrt(np.mean(res_opd**2)) # Calculate relative RMSE values relative_train_opd_rmse = train_opd_rmse / np.sqrt( - np.mean((GT_opd_maps.numpy() * np_obscurations) ** 2) + np.mean((gt_opd_maps.numpy() * np_obscurations) ** 2) ) # Print RMSE values @@ -849,7 +849,7 @@ def compute_opd_metrics_mccd( def compute_opd_metrics_polymodel( - tf_semiparam_field, GT_tf_semiparam_field, test_pos, train_pos + tf_semiparam_field, gt_tf_semiparam_field, test_pos, train_pos ): """Compute the OPD metrics.""" @@ -864,19 +864,19 @@ def compute_opd_metrics_polymodel( # OPD prediction opd_pred = tf.math.add(P_opd_pred, NP_opd_pred) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(test_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(test_pos) + gt_opd_maps = gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate RMSE values test_opd_rmse = np.sqrt(np.mean(res_opd**2)) # Calculate relative RMSE values relative_test_opd_rmse = test_opd_rmse / np.sqrt( - np.mean((GT_opd_maps.numpy() * np_obscurations) ** 2) + np.mean((gt_opd_maps.numpy() * np_obscurations) ** 2) ) # Print RMSE values @@ -894,19 +894,19 @@ def compute_opd_metrics_polymodel( # OPD prediction opd_pred = tf.math.add(P_opd_pred, NP_opd_pred) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(train_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(train_pos) + gt_opd_maps = gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate RMSE values train_opd_rmse = np.sqrt(np.mean(res_opd**2)) # Calculate relative RMSE values relative_train_opd_rmse = train_opd_rmse / np.sqrt( - np.mean((GT_opd_maps.numpy() * np_obscurations) ** 2) + np.mean((gt_opd_maps.numpy() * np_obscurations) ** 2) ) # Pritn RMSE values @@ -919,7 +919,7 @@ def compute_opd_metrics_polymodel( def 
compute_opd_metrics_param_model( - tf_semiparam_field, GT_tf_semiparam_field, test_pos, train_pos + tf_semiparam_field, gt_tf_semiparam_field, test_pos, train_pos ): """Compute the OPD metrics.""" @@ -930,19 +930,19 @@ def compute_opd_metrics_param_model( zernike_coeffs = tf_semiparam_field.tf_poly_Z_field(test_pos) opd_pred = tf_semiparam_field.tf_zernike_OPD(zernike_coeffs) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(test_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(test_pos) + gt_opd_maps = gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate absolute RMSE values test_opd_rmse = np.sqrt(np.mean(res_opd**2)) # Calculate relative RMSE values relative_test_opd_rmse = test_opd_rmse / np.sqrt( - np.mean((GT_opd_maps.numpy() * np_obscurations) ** 2) + np.mean((gt_opd_maps.numpy() * np_obscurations) ** 2) ) # Print RMSE values @@ -956,19 +956,19 @@ def compute_opd_metrics_param_model( zernike_coeffs = tf_semiparam_field.tf_poly_Z_field(train_pos) opd_pred = tf_semiparam_field.tf_zernike_OPD(zernike_coeffs) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(train_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(train_pos) + gt_opd_maps = gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate RMSE values train_opd_rmse = np.sqrt(np.mean(res_opd**2)) # Calculate relative RMSE values relative_train_opd_rmse = train_opd_rmse / np.sqrt( - np.mean((GT_opd_maps.numpy() * np_obscurations) ** 2) + np.mean((gt_opd_maps.numpy() * np_obscurations) ** 2) ) # Print RMSE values @@ -980,7 +980,7 @@ def compute_opd_metrics_param_model( return test_opd_rmse, train_opd_rmse -def compute_one_opd_rmse(GT_tf_semiparam_field, tf_semiparam_field, pos, is_poly=False): +def compute_one_opd_rmse(gt_tf_semiparam_field, tf_semiparam_field, pos, is_poly=False): """Compute the OPD map for one position.""" np_obscurations = np.real(tf_semiparam_field.obscurations.numpy()) @@ -999,12 +999,12 @@ def compute_one_opd_rmse(GT_tf_semiparam_field, tf_semiparam_field, pos, is_poly # OPD prediction opd_pred = tf.math.add(P_opd_pred, NP_opd_pred) - # GT model - GT_zernike_coeffs = GT_tf_semiparam_field.tf_poly_Z_field(tf_pos) - GT_opd_maps = GT_tf_semiparam_field.tf_zernike_OPD(GT_zernike_coeffs) + # gt model + gt_zernike_coeffs = gt_tf_semiparam_field.tf_poly_Z_field(tf_pos) + gt_opd_maps = gt_tf_semiparam_field.tf_zernike_OPD(gt_zernike_coeffs) # Compute residual and obscure the OPD - res_opd = (GT_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations + res_opd = (gt_opd_maps.numpy() - opd_pred.numpy()) * np_obscurations # Calculate RMSE values opd_rmse = np.sqrt(np.mean(res_opd**2)) @@ -1051,7 +1051,7 @@ def plot_function(mesh_pos, residual, tf_train_pos, tf_test_pos, title="Error"): def plot_residual_maps( - GT_tf_semiparam_field, + gt_tf_semiparam_field, tf_semiparam_field, simPSF_np, train_SEDs, @@ -1101,21 +1101,21 @@ def plot_residual_maps( # Predict mesh stars model_mesh_preds = 
tf_semiparam_field.predict(x=mesh_pred_inputs, batch_size=16) - GT_mesh_preds = GT_tf_semiparam_field.predict(x=mesh_pred_inputs, batch_size=16) + gt_mesh_preds = gt_tf_semiparam_field.predict(x=mesh_pred_inputs, batch_size=16) # Calculate pixel RMSE for each star pix_rmse = np.array( [ - np.sqrt(np.mean((_GT_pred - _model_pred) ** 2)) - for _GT_pred, _model_pred in zip(GT_mesh_preds, model_mesh_preds) + np.sqrt(np.mean((_gt_pred - _model_pred) ** 2)) + for _gt_pred, _model_pred in zip(gt_mesh_preds, model_mesh_preds) ] ) relative_pix_rmse = np.array( [ - np.sqrt(np.mean((_GT_pred - _model_pred) ** 2)) - / np.sqrt(np.mean((_GT_pred) ** 2)) - for _GT_pred, _model_pred in zip(GT_mesh_preds, model_mesh_preds) + np.sqrt(np.mean((_gt_pred - _model_pred) ** 2)) + / np.sqrt(np.mean((_gt_pred) ** 2)) + for _gt_pred, _model_pred in zip(gt_mesh_preds, model_mesh_preds) ] ) @@ -1136,7 +1136,7 @@ def plot_residual_maps( opd_rmse = np.array( [ compute_one_opd_rmse( - GT_tf_semiparam_field, tf_semiparam_field, _pos.reshape(1, -1), is_poly + gt_tf_semiparam_field, tf_semiparam_field, _pos.reshape(1, -1), is_poly ) for _pos in mesh_pos ] diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 49b20a8a..4623c4f4 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -13,7 +13,7 @@ import tensorflow as tf import tensorflow_addons as tfa import wf_psf.data.training_preprocessing as training_preprocessing -from wf_psf.data.training_preprocessing import TrainingDataHandler, TestDataHandler +from wf_psf.data.training_preprocessing import DataHandler from wf_psf.psf_models import psf_models from wf_psf.metrics import metrics as wf_metrics import os @@ -23,7 +23,42 @@ logger = logging.getLogger(__name__) -def ground_truth_psf_model(metrics_params, coeff_matrix): +def create_ground_truth_psf_model(metrics_params, coeff_matrix): + """Create a Ground Truth PSF Model for metrics evaluation. + + This function creates a Ground Truth PSF Model instance specifically designed + for metrics evaluation purposes. It uses the provided configuration parameters + and coefficient matrix to initialize the model. + + Parameters + ---------- + metrics_params : RecursiveNamespace + Object storing the metric configuration parameters, including the model + parameters for the Ground Truth PSF Model. + coeff_matrix : Tensor or None + Coefficient matrix defining the parametric PSF field model. This matrix + is used to initialize the polynomial Zernike field of the PSF model. + + Returns + ------- + psf_model : Object + Class instance of the Ground Truth SemiParametric PSF model. + + Notes + ----- + The provided coefficient matrix initializes the polynomial Zernike field of + the Ground Truth PSF Model. The function also resets the alpha matrix of the + non-parametric polychromatic OPD to zeros. + + Example + ------- + metrics_params = load_metrics_params() + coeff_matrix = load_coeff_matrix() + ground_truth_model = create_ground_truth_psf_model(metrics_params, coeff_matrix) + + # Use the ground_truth_model for metrics evaluation + metrics_results = evaluate_metrics(ground_truth_model, ...) 
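+    (The names ``load_metrics_params``, ``load_coeff_matrix`` and
+    ``evaluate_metrics`` above are illustrative placeholders for the caller's
+    own loading and evaluation code, not functions provided by wf_psf.)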
+ """ psf_model = psf_models.get_psf_model( metrics_params.ground_truth_model.model_params, metrics_params.metrics_hparams, @@ -81,7 +116,7 @@ def evaluate_metrics_polychromatic_lowres(self, psf_model, simPSF, dataset): rmse, rel_rmse, std_rmse, std_rel_rmse = wf_metrics.compute_poly_metric( tf_semiparam_field=psf_model, - GT_tf_semiparam_field=ground_truth_psf_model( + gt_tf_semiparam_field=create_ground_truth_psf_model( self.metrics_params, dataset["C_poly"] ), simPSF_np=simPSF, @@ -131,7 +166,7 @@ def evaluate_metrics_mono_rmse(self, psf_model, simPSF, dataset): std_rel_rmse_lda, ) = wf_metrics.compute_mono_metric( tf_semiparam_field=psf_model, - GT_tf_semiparam_field=ground_truth_psf_model( + gt_tf_semiparam_field=create_ground_truth_psf_model( self.metrics_params, dataset["C_poly"] ), simPSF_np=simPSF, @@ -176,7 +211,7 @@ def evaluate_metrics_opd(self, psf_model, simPSF, dataset): rel_rmse_std_opd, ) = wf_metrics.compute_opd_metrics( tf_semiparam_field=psf_model, - GT_tf_semiparam_field=ground_truth_psf_model( + gt_tf_semiparam_field=create_ground_truth_psf_model( self.metrics_params, dataset["C_poly"] ), pos=dataset["positions"], @@ -215,7 +250,7 @@ def evaluate_metrics_shape(self, psf_model, simPSF, dataset): shape_results = wf_metrics.compute_shape_metrics( tf_semiparam_field=psf_model, - GT_tf_semiparam_field=ground_truth_psf_model( + gt_tf_semiparam_field=create_ground_truth_psf_model( self.metrics_params, dataset["C_poly"] ), simPSF_np=simPSF, @@ -235,8 +270,7 @@ def evaluate_metrics_shape(self, psf_model, simPSF, dataset): def evaluate_model( metrics_params, trained_model_params, - training_data, - test_data, + data, psf_model, weights_path, metrics_output, @@ -251,10 +285,8 @@ def evaluate_model( Recursive Namespace object containing metrics input parameters trained_model_params: Recursive Namespace object Recursive Namespace object containing trained model input parameters - training_data: object - TrainingDataHandler object - test_data: object - TestDataHandler object + data: DataHandler object + DataHandler object containing training and test data psf_model: object PSF model object weights_path: str @@ -277,7 +309,7 @@ def evaluate_model( ## Prepare models # Prepare np input - simPSF_np = training_data.simPSF + simPSF_np = data.training_data.simPSF ## Load the model's weights try: @@ -292,13 +324,13 @@ def evaluate_model( # Polychromatic star reconstructions poly_metric = metrics_handler.evaluate_metrics_polychromatic_lowres( - psf_model, simPSF_np, test_data.test_dataset + psf_model, simPSF_np, data.test_data.dataset ) # Monochromatic star reconstructions if metrics_params.eval_mono_metric_rmse: mono_metric = metrics_handler.evaluate_metrics_mono_rmse( - psf_model, simPSF_np, test_data.test_dataset + psf_model, simPSF_np, data.test_data.dataset ) else: mono_metric = None @@ -306,7 +338,7 @@ def evaluate_model( # OPD metrics if metrics_params.eval_opd_metric_rmse: opd_metric = metrics_handler.evaluate_metrics_opd( - psf_model, simPSF_np, test_data.test_dataset + psf_model, simPSF_np, data.test_data.dataset ) else: opd_metric = None @@ -316,7 +348,7 @@ def evaluate_model( "Computing polychromatic high-resolution metrics and shape metrics." 
) shape_results_dict = metrics_handler.evaluate_metrics_shape( - psf_model, simPSF_np, test_data.test_dataset + psf_model, simPSF_np, data.test_data.dataset ) # Save metrics test_metrics = { @@ -333,13 +365,13 @@ def evaluate_model( logger.info("Computing polychromatic metrics at low resolution.") train_poly_metric = metrics_handler.evaluate_metrics_polychromatic_lowres( - psf_model, simPSF_np, training_data.train_dataset + psf_model, simPSF_np, data.training_data.dataset ) # Monochromatic star reconstructions turn into a class if metrics_params.eval_mono_metric_rmse: train_mono_metric = metrics_handler.evaluate_metrics_mono_rmse( - psf_model, simPSF_np, training_data.train_dataset + psf_model, simPSF_np, data.training_data.dataset ) else: train_mono_metric = None @@ -347,7 +379,7 @@ def evaluate_model( # OPD metrics turn into a class if metrics_params.eval_opd_metric_rmse: train_opd_metric = metrics_handler.evaluate_metrics_opd( - psf_model, simPSF_np, training_data.train_dataset + psf_model, simPSF_np, data.training_data.dataset ) else: train_opd_metric = None @@ -355,7 +387,7 @@ def evaluate_model( # Shape metrics turn into a class if metrics_params.eval_train_shape_sr_metric_rmse: train_shape_results_dict = metrics_handler.evaluate_metrics_shape( - psf_model, simPSF_np, training_data.train_dataset + psf_model, simPSF_np, data.training_data.dataset ) else: train_shape_results_dict = None diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/psf_model_parametric.py index 3d7a0310..0816765d 100644 --- a/src/wf_psf/psf_models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/psf_model_parametric.py @@ -12,20 +12,17 @@ from tensorflow.python.keras.engine import data_adapter from wf_psf.psf_models.psf_models import register_psfclass from wf_psf.psf_models.tf_layers import ( - TF_poly_Z_field, - TF_zernike_OPD, - TF_batch_poly_PSF, + TFPolynomialZernikeField, + TFZernikeOPD, + TFBatchPolychromaticPSF, + TFBatchMonochromaticPSF, + TFNonParametricPolynomialVariationsOPD, + TFPhysicalLayer, ) -from wf_psf.psf_models.tf_layers import ( - TF_NP_poly_OPD, - TF_batch_mono_PSF, - TF_physical_layer, -) -from wf_psf.utils.utils import PI_zernikes @register_psfclass -class TF_PSF_field_model(tf.keras.Model): +class TFParametricPSFFieldModel(tf.keras.Model): """Parametric PSF field model! Fully parametric model based on the Zernike polynomial basis. 
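The renames in this file (``TF_poly_Z_field`` → ``TFPolynomialZernikeField``, ``TF_zernike_OPD`` → ``TFZernikeOPD``, ``TF_batch_poly_PSF`` → ``TFBatchPolychromaticPSF``, ``TF_batch_mono_PSF`` → ``TFBatchMonochromaticPSF``) break any external code that imports the old names. A hypothetical deprecation shim, not part of this PR, could keep the old names importable during a transition:

# Hypothetical back-compat aliases (not included in this PR).
from wf_psf.psf_models.tf_layers import (
    TFPolynomialZernikeField,
    TFZernikeOPD,
    TFBatchPolychromaticPSF,
    TFBatchMonochromaticPSF,
)

# Old snake_case names kept as deprecated aliases of the new classes.
TF_poly_Z_field = TFPolynomialZernikeField
TF_zernike_OPD = TFZernikeOPD
TF_batch_poly_PSF = TFBatchPolychromaticPSF
TF_batch_mono_PSF = TFBatchMonochromaticPSF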
@@ -109,7 +106,7 @@ def __init__(
         self.l2_param = l2_param

         # Initialize the first layer
-        self.tf_poly_Z_field = TF_poly_Z_field(
+        self.tf_poly_Z_field = TFPolynomialZernikeField(
             x_lims=self.x_lims,
             y_lims=self.y_lims,
             n_zernikes=self.n_zernikes,
@@ -117,10 +114,10 @@
         )

         # Initialize the zernike to OPD layer
-        self.tf_zernike_OPD = TF_zernike_OPD(zernike_maps=zernike_maps)
+        self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=zernike_maps)

         # Initialize the batch opd to batch polychromatic PSF layer
-        self.tf_batch_poly_PSF = TF_batch_poly_PSF(
+        self.tf_batch_poly_PSF = TFBatchPolychromaticPSF(
             obscurations=self.obscurations,
             output_Q=self.output_Q,
             output_dim=self.output_dim,
@@ -154,7 +151,7 @@
         if output_dim is not None:
             self.output_dim = output_dim
         # Reinitialize the PSF batch poly generator
-        self.tf_batch_poly_PSF = TF_batch_poly_PSF(
+        self.tf_batch_poly_PSF = TFBatchPolychromaticPSF(
             obscurations=self.obscurations,
             output_Q=self.output_Q,
             output_dim=self.output_dim,
@@ -175,7 +172,7 @@
         """
         # Initialise the monochromatic PSF batch calculator
-        tf_batch_mono_psf = TF_batch_mono_PSF(
+        tf_batch_mono_psf = TFBatchMonochromaticPSF(
             obscurations=self.obscurations,
             output_Q=self.output_Q,
             output_dim=self.output_dim,
diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/psf_model_physical_polychromatic.py
new file mode 100644
index 00000000..4b5778b0
--- /dev/null
+++ b/src/wf_psf/psf_models/psf_model_physical_polychromatic.py
@@ -0,0 +1,641 @@
+"""PSF Model Physical Semi-Parametric Polychromatic.
+
+A module which defines the classes and methods
+to manage the parameters of the physical polychromatic PSF model.
+
+:Authors: Tobias Liaudat and Jennifer Pollack
+
+"""
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.keras.engine import data_adapter
+from wf_psf.psf_models import psf_models as psfm
+from wf_psf.psf_models import tf_layers as tfl
+from wf_psf.utils.utils import zernike_generator
+from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior
+from wf_psf.psf_models.tf_layers import (
+    TFPolynomialZernikeField,
+    TFZernikeOPD,
+    TFBatchPolychromaticPSF,
+    TFBatchMonochromaticPSF,
+    TFNonParametricPolynomialVariationsOPD,
+    TFPhysicalLayer,
+)
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+@psfm.register_psfclass
+class TFPhysicalPolychromaticFieldFactory(psfm.PSFModelBaseFactory):
+    """Factory class for the TensorFlow Physical Polychromatic PSF Field Model.
+
+    This factory class is responsible for instantiating instances of the
+    TensorFlow Physical Polychromatic PSF Field Model. It is registered with
+    the PSF model factory registry.
+
+    Attributes
+    ----------
+    ids : tuple
+        A tuple containing identifiers for the factory class.
+
+    Methods
+    -------
+    get_model_instance(model_params, training_params, data, coeff_mat=None)
+        Instantiates an instance of the TensorFlow Physical Polychromatic Field class with the provided parameters.
+    """
+
+    ids = ("physical_poly",)
+
+    def get_model_instance(self, model_params, training_params, data, coeff_mat=None):
+        return TFPhysicalPolychromaticField(
+            model_params, training_params, data, coeff_mat
+        )
+
+
+class TFPhysicalPolychromaticField(tf.keras.Model):
+    """TensorFlow Physical Polychromatic PSF Field class.
+
+    This class represents a polychromatic PSF field model with a physical layer,
+    which is part of a larger PSF modeling framework.
+
+    Parameters
+    ----------
+    model_params : RecursiveNamespace
+        A RecursiveNamespace object containing parameters for this PSF model class
+    training_params : RecursiveNamespace
+        A RecursiveNamespace object containing training hyperparameters for this PSF model class
+    data : DataConfigHandler object
+        A DataConfigHandler object containing training and test datasets
+    coeff_mat : Tensor or None
+        Initialization of the coefficient matrix defining the parametric PSF field model
+
+    Attributes
+    ----------
+    ids : tuple
+        A tuple storing the string identifier of this PSF model class
+
+    """
+
+    ids = ("physical_poly",)
+
+    def __init__(self, model_params, training_params, data, coeff_mat=None):
+        """Initialize the TFPhysicalPolychromaticField instance.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+        training_params : RecursiveNamespace
+            Object containing training hyperparameters for this PSF model class.
+        data : DataConfigHandler
+            Object containing training and test datasets and the Zernike prior.
+        coeff_mat : Tensor or None
+            Coefficient matrix defining the parametric PSF field model.
+        """
+        super().__init__()
+        self._initialize_parameters_and_layers(
+            model_params, training_params, data, coeff_mat
+        )
+
+    def _initialize_parameters_and_layers(
+        self, model_params, training_params, data, coeff_mat=None
+    ):
+        """Initialize the parameters of the PSF model.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+        training_params : RecursiveNamespace
+            Object containing training hyperparameters for this PSF model class.
+        data : DataConfigHandler
+            Object containing training and test datasets.
+        coeff_mat : Tensor or None
+            Coefficient matrix defining the parametric PSF field model.
+        """
+        self.output_Q = model_params.output_Q
+        self.obs_pos = get_obs_positions(data)
+        self.l2_param = model_params.param_hparams.l2_param
+
+        self._initialize_zernike_parameters(model_params, data)
+        self._initialize_layers(model_params, training_params)
+
+        # Initialize the model parameters with non-default value
+        if coeff_mat is not None:
+            self.assign_coeff_matrix(coeff_mat)
+
+    def _initialize_zernike_parameters(self, model_params, data):
+        """Initialize the Zernike parameters.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+        data : DataConfigHandler
+            Object containing the training and test datasets, from which the
+            Zernike prior is retrieved.
+
+        """
+        self.zks_prior = get_zernike_prior(data)
+        self.n_zks_total = max(
+            model_params.param_hparams.n_zernikes,
+            tf.cast(tf.shape(self.zks_prior)[1], tf.int32),
+        )
+        self.zernike_maps = psfm.generate_zernike_maps_3d(
+            self.n_zks_total, model_params.pupil_diameter
+        )
+        # Initialize the Zernike-to-OPD layer; it is used by the
+        # compute_zernikes/predict_* methods below.
+        self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps)
+
+    def _initialize_layers(self, model_params, training_params):
+        """Initialize the layers of the PSF model.
+
+        This method initializes the layers of the PSF model, including the physical
+        layer, polynomial Zernike field, batch polychromatic layer, and non-parametric
+        OPD layer.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+        training_params : RecursiveNamespace
+            Object containing training hyperparameters for this PSF model class.
+
+        """
+        self._initialize_physical_layer(model_params)
+        self._initialize_polynomial_Z_field(model_params)
+        self._initialize_batch_polychromatic_layer(model_params, training_params)
+        self._initialize_nonparametric_opd_layer(model_params, training_params)
+
+    def _initialize_physical_layer(self, model_params):
+        """Initialize the physical layer of the PSF model.
+
+        This method initializes the physical layer of the PSF model using parameters
+        specified in the `model_params` object.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+
+        """
+        self.tf_physical_layer = TFPhysicalLayer(
+            self.obs_pos,
+            self.zks_prior,
+            interpolation_type=model_params.interpolation_type,
+            interpolation_args=model_params.interpolation_args,
+        )
+
+    def _initialize_polynomial_Z_field(self, model_params):
+        """Initialize the polynomial Zernike field of the PSF model.
+
+        This method initializes the polynomial Zernike field of the PSF model using
+        parameters specified in the `model_params` object.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+
+        """
+        self.tf_poly_Z_field = TFPolynomialZernikeField(
+            x_lims=model_params.x_lims,
+            y_lims=model_params.y_lims,
+            random_seed=model_params.param_hparams.random_seed,
+            n_zernikes=model_params.param_hparams.n_zernikes,
+            d_max=model_params.param_hparams.d_max,
+        )
+
+    def _initialize_batch_polychromatic_layer(self, model_params, training_params):
+        """Initialize the batch polychromatic PSF layer.
+
+        This method initializes the batch OPD to batch polychromatic PSF layer
+        using the provided `model_params` and `training_params`.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+        training_params : RecursiveNamespace
+            Object containing training hyperparameters for this PSF model class.
+
+        """
+        self.batch_size = training_params.batch_size
+        self.obscurations = psfm.tf_obscurations(model_params.pupil_diameter)
+        self.output_dim = model_params.output_dim
+
+        self.tf_batch_poly_PSF = TFBatchPolychromaticPSF(
+            obscurations=self.obscurations,
+            output_Q=self.output_Q,
+            output_dim=self.output_dim,
+        )
+
+    def _initialize_nonparametric_opd_layer(self, model_params, training_params):
+        """Initialize the non-parametric OPD layer.
+
+        This method initializes the non-parametric OPD layer using the provided
+        `model_params` and `training_params`.
+
+        Parameters
+        ----------
+        model_params : RecursiveNamespace
+            Object containing parameters for this PSF model class.
+        training_params : RecursiveNamespace
+            Object containing training hyperparameters for this PSF model class.
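+
+        Notes
+        -----
+        The OPD map side length is not a free parameter here: it is read off
+        the Zernike maps themselves, ``opd_dim = tf.shape(self.zernike_maps)[1]``,
+        so the non-parametric OPD grid always matches the parametric one.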
+ + """ + # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam + # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() + + self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + d_max=model_params.nonparam_hparams.d_max_nonparam, + opd_dim=tf.shape(self.zernike_maps)[1].numpy(), + ) + + def get_coeff_matrix(self): + """Get coefficient matrix.""" + return self.tf_poly_Z_field.get_coeff_matrix() + + def assign_coeff_matrix(self, coeff_mat): + """Assigns the coefficient matrix defining the parametric PSF field model. + + This method assigns the coefficient matrix to the parametric PSF field model, + allowing for customization and modification of the PSF field. + + Parameters + ---------- + coeff_mat : Tensor or None + The coefficient matrix defining the parametric PSF field model. + If None, the default coefficient matrix will be used. + + + """ + self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) + + def set_zero_nonparam(self): + """Set the non-parametric part of the OPD (Optical Path Difference) to zero. + + This method sets the non-parametric component of the Optical Path Difference (OPD) + to zero, effectively removing its contribution from the overall PSF (Point Spread Function). + + """ + self.tf_np_poly_opd.set_alpha_zero() + + def set_nonzero_nonparam(self): + """Set the non-parametric part to non-zero values. + + This method sets the non-parametric component of the Optical Path Difference (OPD) + to non-zero values, allowing it to contribute to the overall PSF (Point Spread Function). + + """ + self.tf_np_poly_opd.set_alpha_identity() + + def set_trainable_layers(self, param_bool=True, nonparam_bool=True): + """Set the layers to be trainable. + + A method to set layers to be trainable. + + Parameters + ---------- + param_bool: bool + Boolean flag for parametric model layers + + nonparam_bool: bool + Boolean flag for non-parametric model layers + + """ + self.tf_np_poly_opd.trainable = nonparam_bool + self.tf_poly_Z_field.trainable = param_bool + + def pad_zernikes(self, zk_param, zk_prior): + """Pad the Zernike coefficients to match the maximum length. + + Pad the input Zernike coefficient tensors to match the length of the + maximum number of Zernike coefficients among the parametric and prior parts. + + Parameters + ---------- + zk_param: tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior: tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + + Returns + ------- + padded_zk_param: tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior: tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. 
+        """
+        # Calculate the number of Zernikes to pad for parametric and prior parts
+        pad_num_param = tf.cast(
+            self.n_zks_total - tf.shape(zk_param)[1].numpy(), dtype=tf.int32
+        )
+        pad_num_prior = tf.cast(
+            self.n_zks_total - tf.shape(zk_prior)[1].numpy(), dtype=tf.int32
+        )
+
+        if pad_num_param != 0:
+            # Pad the Zernike coefficients for parametric and prior parts
+            padding_param = [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]
+            padded_zk_param = tf.pad(zk_param, padding_param)
+        else:
+            padded_zk_param = zk_param
+
+        if pad_num_prior != 0:
+            padding_prior = [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]
+            padded_zk_prior = tf.pad(zk_prior, padding_prior)
+        else:
+            padded_zk_prior = zk_prior
+
+        # Assert that the shapes are correct
+        if padded_zk_param.shape != padded_zk_prior.shape:
+            raise ValueError(
+                f"Shapes of padded tensors {padded_zk_param.shape} and {padded_zk_prior.shape} do not match."
+            )
+
+        return padded_zk_param, padded_zk_prior
+
+    def predict_step(self, data, evaluate_step=False):
+        """Predict (inference) step.
+
+        A method that enables a special type of interpolation for the
+        physical layer, different from the one used during training.
+
+        Parameters
+        ----------
+        data : tuple or Tensor
+            Input data batch. If `evaluate_step` is False, the data is first
+            unpacked with the Keras data adapter; otherwise it is used directly
+            as `(input_positions, packed_SEDs)`.
+        evaluate_step : bool
+            Boolean flag indicating whether this is an evaluation step; if True,
+            `data` is used directly without unpacking.
+
+        Returns
+        -------
+        poly_psfs : Tensor
+            Batch of polychromatic PSFs computed by the TFBatchPolychromaticPSF layer.
+
+        """
+        if evaluate_step:
+            input_data = data
+        else:
+            # Format input data
+            data = data_adapter.expand_1d(data)
+            input_data, _, _ = data_adapter.unpack_x_y_sample_weight(data)
+
+        # Unpack inputs
+        input_positions = input_data[0]
+        packed_SEDs = input_data[1]
+
+        # Compute zernikes from parametric model and physical layer
+        zks_coeffs = self.predict_zernikes(input_positions)
+        # Propagate to obtain the OPD
+        param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
+        # Calculate the non parametric part
+        nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
+        # Add the estimations
+        opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
+        # Compute the polychromatic PSFs
+        poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs])
+
+        return poly_psfs
+
+    def predict_mono_psfs(self, input_positions, lambda_obs, phase_N):
+        """Predict a set of monochromatic Point Spread Functions (PSFs) at desired positions.
+
+        This method calculates monochromatic PSFs based on the provided input positions,
+        observed wavelength, and required wavefront dimension.
+
+        Parameters
+        ----------
+        input_positions : Tensor [batch_dim, 2]
+            Positions at which to compute the PSFs.
+        lambda_obs : float
+            Observed wavelength in micrometers (um).
+        phase_N : int
+            Required wavefront dimension. This should be calculated using a PSFSimulator
+            instance. Example:
+            ```
+            simPSF_np = PSFSimulator(...)
+            phase_N = simPSF_np.feasible_N(lambda_obs)
+            ```
+
+        Returns
+        -------
+        mono_psf_batch : Tensor
+            Batch of monochromatic PSFs.
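+
+        Examples
+        --------
+        # A minimal usage sketch, assuming `model` is a trained instance,
+        # `simPSF_np` is a PSFSimulator instance and `positions` is a
+        # Tensor of shape [batch_dim, 2]
+        lambda_obs = 0.7
+        phase_N = simPSF_np.feasible_N(lambda_obs)
+        mono_psfs = model.predict_mono_psfs(positions, lambda_obs, phase_N)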
+
+        """
+
+        # Initialise the monochromatic PSF batch calculator
+        tf_batch_mono_psf = TFBatchMonochromaticPSF(
+            obscurations=self.obscurations,
+            output_Q=self.output_Q,
+            output_dim=self.output_dim,
+        )
+        # Set the lambda_obs and the phase_N parameters
+        tf_batch_mono_psf.set_lambda_phaseN(phase_N, lambda_obs)
+
+        # Predict zernikes from parametric model and physical layer
+        zks_coeffs = self.predict_zernikes(input_positions)
+        # Propagate to obtain the OPD
+        param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
+        # Calculate the non parametric part
+        nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
+        # Add the estimations
+        opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
+
+        # Compute the monochromatic PSFs
+        mono_psf_batch = tf_batch_mono_psf(opd_maps)
+
+        return mono_psf_batch
+
+    def predict_opd(self, input_positions):
+        """Predict the OPD at some positions.
+
+        Parameters
+        ----------
+        input_positions: Tensor [batch_dim, 2]
+            Positions to predict the OPD.
+
+        Returns
+        -------
+        opd_maps : Tensor [batch, opd_dim, opd_dim]
+            OPD at requested positions.
+
+        """
+        # Predict zernikes from parametric model and physical layer
+        zks_coeffs = self.predict_zernikes(input_positions)
+        # Propagate to obtain the OPD
+        param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
+        # Calculate the non parametric part
+        nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
+        # Add the estimations
+        opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
+
+        return opd_maps
+
+    def compute_zernikes(self, input_positions):
+        """Compute Zernike coefficients at a batch of positions.
+
+        This method computes the Zernike coefficients for a batch of input positions
+        using both the parametric model and the physical layer.
+
+        Parameters
+        ----------
+        input_positions: Tensor [batch_dim, 2]
+            Positions for which to compute the Zernike coefficients.
+
+        Returns
+        -------
+        zernike_coefficients : Tensor [batch, n_zks_total, 1, 1]
+            Computed Zernike coefficients for the input positions.
+
+        Notes
+        -----
+        This method combines the predictions from both the parametric model and
+        the physical layer to obtain the final Zernike coefficients.
+
+        """
+        # Calculate parametric part
+        zernike_params = self.tf_poly_Z_field(input_positions)
+
+        # Calculate the physical layer
+        zernike_prior = self.tf_physical_layer.call(input_positions)
+
+        # Pad and sum the zernike coefficients
+        padded_zernike_params, padded_zernike_prior = self.pad_zernikes(
+            zernike_params, zernike_prior
+        )
+        zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior)
+
+        return zernike_coeffs
+
+    def predict_zernikes(self, input_positions):
+        """Predict Zernike coefficients at a batch of positions.
+
+        This method predicts the Zernike coefficients for a batch of input positions
+        using both the parametric model and the physical layer. During training,
+        the prediction from the physical layer is typically not used.
+
+        Parameters
+        ----------
+        input_positions: Tensor [batch_dim, 2]
+            Positions for which to predict the Zernike coefficients.
+
+        Returns
+        -------
+        zernike_coefficients : Tensor [batch, n_zks_total, 1, 1]
+            Predicted Zernike coefficients for the input positions.
+
+        Notes
+        -----
+        At training time, the prediction from the physical layer may not be utilized,
+        as the model might be trained to rely solely on the parametric part.
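+
+        Examples
+        --------
+        # A minimal sketch, assuming `model` is a trained instance and
+        # `positions` is a Tensor of shape [batch_dim, 2]
+        zks = model.predict_zernikes(positions)
+        # zks has shape [batch_dim, n_zks_total, 1, 1]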
+
+        """
+        # Calculate parametric part
+        zernike_params = self.tf_poly_Z_field(input_positions)
+
+        # Calculate the prediction from the physical layer
+        physical_layer_prediction = self.tf_physical_layer.predict(input_positions)
+
+        # Pad and sum the Zernike coefficients
+        padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes(
+            zernike_params, physical_layer_prediction
+        )
+        zernike_coefficients = tf.math.add(
+            padded_zernike_params, padded_physical_layer_prediction
+        )
+
+        return zernike_coefficients
+
+    def call(self, inputs, training=True):
+        """Define the PSF (Point Spread Function) field forward model.
+
+        This method defines the forward model of the PSF field, which involves several steps:
+        1. Transforming input positions into Zernike coefficients.
+        2. Converting Zernike coefficients into Optical Path Difference (OPD) maps.
+        3. Combining OPD maps with Spectral Energy Distribution (SED) information to generate
+        polychromatic PSFs.
+
+        Parameters
+        ----------
+        inputs : list
+            List containing input data required for PSF computation. It should contain two
+            elements:
+            - input_positions: Tensor [batch_dim, 2]
+                Positions at which to compute the PSFs.
+            - packed_SEDs: Tensor [batch_dim, ...]
+                Packed Spectral Energy Distributions (SEDs) for the corresponding positions.
+        training : bool, optional
+            Indicates whether the model is being trained or used for inference. Defaults to True.
+
+        Returns
+        -------
+        poly_psfs : Tensor
+            Polychromatic PSFs generated by the forward model.
+
+        Notes
+        -----
+        - The `input_positions` tensor should have a shape of [batch_dim, 2], where each row
+        represents the x and y coordinates of a position.
+        - The `packed_SEDs` tensor should have a shape of [batch_dim, ...], containing the SED
+        information for each position.
+        - During training, this method computes the Zernike coefficients from the input positions
+        and calculates the corresponding OPD maps. Additionally, it adds an L2 loss term based on
+        the parametric OPD maps.
+        - During inference, this method generates predictions using precomputed OPD maps or by
+        propagating through the forward model.
+
+        Examples
+        --------
+        # Usage during training
+        inputs = [input_positions, packed_SEDs]
+        poly_psfs = psf_model(inputs)
+
+        # Usage during inference
+        inputs = [input_positions, packed_SEDs]
+        poly_psfs = psf_model(inputs, training=False)
+        """
+        # Unpack inputs
+        input_positions = inputs[0]
+        packed_SEDs = inputs[1]
+
+        # For the training
+        if training:
+            # Compute zernikes from parametric model and physical layer
+            zks_coeffs = self.compute_zernikes(input_positions)
+
+            # Propagate to obtain the OPD
+            param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
+
+            # Add l2 loss on the parametric OPD
+            self.add_loss(
+                self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps))
+            )
+
+            # Calculate the non parametric part
+            nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
+
+            # Add the estimations
+            opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
+
+            # Compute the polychromatic PSFs
+            poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs])
+        # For the inference
+        else:
+            # Compute predictions
+            poly_psfs = self.predict_step(inputs, evaluate_step=True)
+
+        return poly_psfs
diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/psf_model_semiparametric.py
index 017acfdf..b68f7e98 100644
--- a/src/wf_psf/psf_models/psf_model_semiparametric.py
+++ b/src/wf_psf/psf_models/psf_model_semiparametric.py
@@ -12,10 +12,10 @@
 from tensorflow.python.keras.engine import data_adapter
 from wf_psf.psf_models import psf_models as psfm
 from wf_psf.psf_models import tf_layers as tfl
-from wf_psf.utils.utils import PI_zernikes, zernike_generator
+from wf_psf.utils.utils import decompose_tf_obscured_opd_basis
 from wf_psf.psf_models.tf_layers import (
-    TF_batch_poly_PSF,
-    TF_batch_mono_PSF,
+    TFBatchPolychromaticPSF,
+    TFBatchMonochromaticPSF,
 )
 import logging
@@ -24,15 +24,58 @@
 
 
 @psfm.register_psfclass
-class TF_SemiParam_field(tf.keras.Model):
+class SemiParamFieldFactory(psfm.PSFModelBaseFactory):
+    """Factory class for the SemiParametric PSF Field Model.
+
+    This factory class is responsible for instantiating the SemiParametric PSF Field Model.
+    It is registered with the PSF model factory registry.
+
+    Attributes
+    ----------
+    ids: tuple
+        A tuple containing identifiers for the factory class.
+
+    Methods
+    -------
+    get_model_instance(model_params, training_params, data=None, coeff_mat=None)
+        Instantiates an instance of the SemiParametric PSF Field Model with the provided parameters.
+    """
+
+    ids = ("poly",)
+
+    def get_model_instance(
+        self, model_params, training_params, data=None, coeff_mat=None
+    ):
+        """Get Model Instance.
+
+        This method creates an instance of the SemiParametric PSF Field Model using the provided parameters.
+
+        Parameters
+        ----------
+        model_params : object
+            Parameters for configuring the PSF model.
+        training_params : object
+            Parameters for training the PSF model.
+        data : object or None, optional
+            Data used for training the PSF model.
+        coeff_mat : object or None, optional
+            Coefficient matrix defining the parametric PSF field model.
+
+        Returns
+        -------
+        PSF model instance
+            An instance of the SemiParametric PSF Field Model.
+        """
+        return TFSemiParametricField(model_params, training_params, coeff_mat)
+
+
+class TFSemiParametricField(tf.keras.Model):
     """PSF field forward model.
 
     Semi parametric model based on the Zernike polynomial basis. 
Parameters ---------- - ids: tuple - A tuple storing the string attribute of the PSF model class model_params: Recursive Namespace Recursive Namespace object containing parameters for this PSF model class training_params: Recursive Namespace @@ -42,8 +85,6 @@ class TF_SemiParam_field(tf.keras.Model): """ - ids = ("poly",) - def __init__(self, model_params, training_params, coeff_mat=None): super().__init__() @@ -61,11 +102,12 @@ def __init__(self, model_params, training_params, coeff_mat=None): self.d_max = model_params.param_hparams.d_max self.x_lims = model_params.x_lims self.y_lims = model_params.y_lims + self.zernike_maps = psfm.generate_zernike_maps_3d( + self.n_zernikes, self.pupil_diam + ) # Inputs: TF_NP_poly_OPD self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam - self.zernike_maps = psfm.tf_zernike_cube(self.n_zernikes, self.pupil_diam) - self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() # Inputs: TF_batch_poly_PSF @@ -93,7 +135,7 @@ def __init__(self, model_params, training_params, coeff_mat=None): ) # Initialize the first layer - self.tf_poly_Z_field = tfl.TF_poly_Z_field( + self.tf_poly_Z_field = tfl.TFPolynomialZernikeField( x_lims=self.x_lims, y_lims=self.y_lims, random_seed=self.random_seed, @@ -102,10 +144,10 @@ def __init__(self, model_params, training_params, coeff_mat=None): ) # Initialize the zernike to OPD layer - self.tf_zernike_OPD = tfl.TF_zernike_OPD(zernike_maps=self.zernike_maps) + self.tf_zernike_OPD = tfl.TFZernikeOPD(zernike_maps=self.zernike_maps) # Initialize the non-parametric (np) layer - self.tf_np_poly_opd = tfl.TF_NP_poly_OPD( + self.tf_np_poly_opd = tfl.TFNonParametricPolynomialVariationsOPD( x_lims=self.x_lims, y_lims=self.y_lims, random_seed=self.random_seed, @@ -114,7 +156,7 @@ def __init__(self, model_params, training_params, coeff_mat=None): ) # Initialize the batch opd to batch polychromatic PSF layer - self.tf_batch_poly_PSF = tfl.TF_batch_poly_PSF( + self.tf_batch_poly_PSF = tfl.TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, @@ -204,8 +246,8 @@ def set_output_Q(self, output_Q, output_dim=None): if output_dim is not None: self.output_dim = output_dim - # Reinitialize the PSF batch poly generator - self.tf_batch_poly_PSF = TF_batch_poly_PSF( + # Reinitialize the PSF batch polychromatic generator + self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, @@ -228,7 +270,7 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): """ # Initialise the monochromatic PSF batch calculator - tf_batch_mono_psf = TF_batch_mono_PSF( + tf_batch_mono_psf = TFBatchMonochromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, @@ -256,13 +298,13 @@ def predict_opd(self, input_positions): Parameters ---------- - input_positions: Tensor(batch_dim x 2) - Positions to predict the OPD. + input_positions : tf.Tensor + Positions to predict the OPD. Tensor dimensions are (batch_dim, 2) Returns ------- - opd_maps : Tensor [batch x opd_dim x opd_dim] - OPD at requested positions. + opd_maps : tf.Tensor + OPD at requested positions. 
Tensor dimensions are (batch, opd_dim, opd_dim)
 
         """
         # Calculate parametric part
@@ -275,70 +317,91 @@ def predict_opd(self, input_positions):
 
         return opd_maps
 
-    def assign_S_mat(self, S_mat):
+    def assign_S_mat(self, s_mat):
         """Assign DD features matrix."""
-        self.tf_np_poly_opd.assign_S_mat(S_mat)
+        self.tf_np_poly_opd.assign_S_mat(s_mat)
+
+    def project_DD_features(self, tf_zernike_cube=None):
+        """Project data-driven features.
 
-    def project_DD_features(self, tf_zernike_cube):
-        """
         Project non-parametric wavefront onto first n_z Zernikes and transfer
-        their parameters to the parametric model.
+        their parameters to the parametric model. This method updates the value
+        of the S matrix in the non-parametric layer `self.tf_np_poly_opd`.
+
+        Parameters
+        ----------
+        tf_zernike_cube : tf.Tensor, optional
+            Zernike maps used for the projection. If None (default), the maps
+            stored in the Zernike to OPD layer are used.
 
         """
-        # Compute Zernike norm for projections
-        n_pix_zernike = PI_zernikes(tf_zernike_cube[0, :, :], tf_zernike_cube[0, :, :])
+        # If no Zernike maps are provided, use the ones from the
+        # Zernike to OPD layer
+        if tf_zernike_cube is None:
+            tf_zernike_cube = self.tf_zernike_OPD.zernike_maps
+
+        # Number of monomials in the parametric part -> n_poly(d_max)
+        n_poly_param = self.tf_poly_Z_field.coeff_mat.shape[1]
+
        # Multiply Alpha matrix with DD features matrix S
        inter_res_v2 = tf.tensordot(
-            self.tf_np_poly_opd.alpha_mat[: self.tf_poly_Z_field.coeff_mat.shape[1], :],
+            self.tf_np_poly_opd.alpha_mat[:n_poly_param, :],
            self.tf_np_poly_opd.S_mat,
            axes=1,
        )
        # Project over first n_z Zernikes
-        # TO DO: Clean up
        delta_C_poly = tf.constant(
-            np.array(
-                [
+            np.transpose(
+                np.array(
                    [
-                        PI_zernikes(
-                            tf_zernike_cube[i, :, :],
-                            inter_res_v2[j, :, :],
-                            n_pix_zernike,
+                        decompose_tf_obscured_opd_basis(
+                            tf_opd=inter_res_v2[j, :, :],
+                            tf_obscurations=self.obscurations,
+                            tf_zk_basis=tf_zernike_cube,
+                            n_zernike=self.n_zernikes,
+                            iters=40,
                        )
-                        for j in range(self.tf_poly_Z_field.coeff_mat.shape[1])
+                        for j in range(n_poly_param)
                    ]
-                    for i in range(self.n_zernikes)
-                ]
+                )
            ),
            dtype=tf.float32,
        )
        old_C_poly = self.tf_poly_Z_field.coeff_mat
+
        # Corrected parametric coeff matrix
        new_C_poly = old_C_poly + delta_C_poly
        self.assign_coeff_matrix(new_C_poly)
 
        # Remove extracted features from non-parametric model
        # Mix DD features with matrix alpha
-        S_tilde = tf.tensordot(
+        s_tilde = tf.tensordot(
            self.tf_np_poly_opd.alpha_mat, self.tf_np_poly_opd.S_mat, axes=1
        )
-        # TO DO: Clean Up
-        # Get beta tilde as the protection of the first n_param_poly_terms (6 for d_max=2) onto the first n_zernikes.
+
+        # Get beta tilde as the projection of the first n_param_poly_terms (6 for d_max=2) onto the first n_zernikes.
        beta_tilde_inner = np.array(
            [
-                [
-                    PI_zernikes(tf_zernike_cube[j, :, :], S_tilde_slice, n_pix_zernike)
-                    for j in range(self.n_zernikes)
-                ]
-                for S_tilde_slice in S_tilde[
-                    : self.tf_poly_Z_field.coeff_mat.shape[1], :, :
-                ]
+                decompose_tf_obscured_opd_basis(
+                    tf_opd=s_tilde_slice,
+                    tf_obscurations=self.obscurations,
+                    tf_zk_basis=tf_zernike_cube,
+                    n_zernike=self.n_zernikes,
+                    iters=40,
+                )
+                for s_tilde_slice in s_tilde[:n_poly_param, :, :]
            ]
        )
-        # Only pad in the first dimension so we get a matrix of size (d_max_nonparam_terms)x(n_zernikes) --> 21x15 or 21x45.
+        # Only pad in the first dimension so we get a
+        # matrix of size (d_max_nonparam_terms)x(n_zernikes) --> 21x15 or 21x45. 
beta_tilde = np.pad( beta_tilde_inner, - [(0, S_tilde.shape[0] - beta_tilde_inner.shape[0]), (0, 0)], + [(0, s_tilde.shape[0] - beta_tilde_inner.shape[0]), (0, 0)], mode="constant", ) @@ -346,15 +409,14 @@ def project_DD_features(self, tf_zernike_cube): beta = tf.constant( np.linalg.inv(self.tf_np_poly_opd.alpha_mat) @ beta_tilde, dtype=tf.float32 ) - # To do: Clarify comment or delete. # Get the projection for the unmixed features # Now since beta.shape[1]=n_zernikes we can take the whole beta matrix. - S_mat_projected = tf.tensordot(beta, tf_zernike_cube, axes=[1, 0]) + s_mat_projected = tf.tensordot(beta, tf_zernike_cube, axes=[1, 0]) # Subtract the projection from the DD features - S_new = self.tf_np_poly_opd.S_mat - S_mat_projected - self.assign_S_mat(S_new) + s_new = self.tf_np_poly_opd.S_mat - s_mat_projected + self.assign_S_mat(s_new) def call(self, inputs): """Define the PSF field forward model. diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py index 4fcbc8cf..262faaaa 100644 --- a/src/wf_psf/psf_models/psf_models.py +++ b/src/wf_psf/psf_models/psf_models.py @@ -10,19 +10,27 @@ import numpy as np import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.utils.utils import PI_zernikes, zernike_generator from wf_psf.sims.psf_simulator import PSFSimulator +from wf_psf.utils.utils import zernike_generator import glob from sys import exit import logging logger = logging.getLogger(__name__) -PSF_CLASS = {} +PSF_FACTORY = {} -class PsfModelError(Exception): - """PSF Model Parameter Error exception class for specific error scenarios.""" +class PSFModelError(Exception): + """PSF Model Parameter Error exception class. + + This exception class is used to handle errors related to PSF (Point Spread Function) model parameters. + + Parameters + ---------- + message : str, optional + Error message to be raised. Defaults to "An error with your PSF model parameter settings occurred." + """ def __init__( self, message="An error with your PSF model parameter settings occurred." @@ -31,27 +39,64 @@ def __init__( super().__init__(self.message) -def register_psfclass(psf_class): - """Register PSF Class. +class PSFModelBaseFactory: + """Base factory class for PSF models. - A wrapper function to register all PSF model classes - in a dictionary. + This class serves as the base factory for instantiating PSF (Point Spread Function) models. + Subclasses should implement the `get_model_instance` method to provide specific PSF model instances. - Parameters + Attributes ---------- - psf_class: type - PSF Class + None - Returns + Methods ------- - psf_class: type - PSF class + get_model_instance(model_params, training_params, data=None, coeff_matrix=None) + Instantiates a PSF model with the provided parameters. + Notes + ----- + Subclasses of `PSFModelBaseFactory` should override the `get_model_instance` method to provide + implementation-specific logic for instantiating PSF model instances. """ - for id in psf_class.ids: - PSF_CLASS[id] = psf_class - return psf_class + def get_model_instance( + self, model_params, training_params, data=None, coeff_matrix=None + ): + """Instantiate a PSF model instance. + + Parameters + ---------- + model_params: object + Parameters for configuring the PSF model. + training_params: object + Parameters for training the PSF model. + data: object or None, optional + Data used for training the PSF model. + coeff_matrix: object or None, optional + Coefficient matrix defining the PSF model. 
+
+        Returns
+        -------
+        PSF model instance
+            An instance of the PSF model.
+        """
+        pass
+
+
+def register_psfclass(psf_factory_class):
+    """Register PSF Factory Class.
+
+    A function to register a PSF factory class in a dictionary.
+
+    Parameters
+    ----------
+    psf_factory_class: type
+        PSF Factory Class
+
+    Returns
+    -------
+    psf_factory_class: type
+        The registered PSF Factory Class
+
+    """
+    for id in psf_factory_class.ids:
+        PSF_FACTORY[id] = psf_factory_class
+
+    return psf_factory_class
 
 
 def set_psf_model(model_name):
@@ -73,14 +118,14 @@ def set_psf_model(model_name):
 
     """
     try:
-        psf_class = PSF_CLASS[model_name]
+        psf_factory_class = PSF_FACTORY[model_name]
     except KeyError as e:
         logger.exception(e)
-        raise PsfModelError("PSF model entered is invalid. Check your config settings.")
-    return psf_class
+        raise PSFModelError("PSF model entered is invalid. Check your config settings.")
+    return psf_factory_class
 
 
-def get_psf_model(model_params, training_hparams, *coeff_matrix):
+def get_psf_model(*psf_model_params):
     """Get PSF Model Class Instance.
 
     A function to instantiate a
@@ -88,24 +133,56 @@ def get_psf_model(model_params, training_hparams, *coeff_matrix):
 
     Parameters
     ----------
-    model_name: str
-        Short name of PSF model
-    model_params: type
-        Recursive Namespace object
-    training_hparams: type
-        Recursive Namespace object
-    coeff_matrix: Tensor or None, optional
-        Initialization of the coefficient matrix defining the parametric psf field model
+    *psf_model_params : tuple
+        Positional arguments representing the parameters required to instantiate the PSF model.
 
     Returns
     -------
-    psf_class: class instance
-        PSF model class instance
+    PSF model class instance
+        An instance of the PSF model class based on the provided parameters.
+
     """
-    psf_class = set_psf_model(model_params.model_name)
-    return psf_class(model_params, training_hparams, *coeff_matrix)
+    model_name = psf_model_params[0].model_name
+    # set_psf_model raises PSFModelError if the model name is invalid
+    psf_factory_class = set_psf_model(model_name)
+
+    return psf_factory_class().get_model_instance(*psf_model_params)
+
+
+def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None):
+    """Define the model-compilation parameters.
+
+    Specifically, the loss function, the optimizer and the metrics.
+    """
+    # Define model loss function
+    if loss is None:
+        loss = tf.keras.losses.MeanSquaredError()
+
+    # Define optimizer function
+    if optimizer is None:
+        optimizer = tf.keras.optimizers.Adam(
+            learning_rate=1e-2, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False
+        )
+
+    # Define metric functions
+    if metrics is None:
+        metrics = [tf.keras.metrics.MeanSquaredError()]
+
+    # Compile the model
+    model_inst.compile(
+        optimizer=optimizer,
+        loss=loss,
+        metrics=metrics,
+        loss_weights=None,
+        weighted_metrics=None,
+        run_eagerly=False,
+    )
+
+    return model_inst
 
 
 def get_psf_model_weights_filepath(weights_filepath):
@@ -130,27 +207,31 @@ def get_psf_model_weights_filepath(weights_filepath):
         logger.exception(
             "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file."
         )
-        raise PsfModelError("PSF model weights error.")
+        raise PSFModelError("PSF model weights error.")
 
 
-def tf_zernike_cube(n_zernikes, pupil_diam):
-    """Tensor Flow Zernike Cube.
+def generate_zernike_maps_3d(n_zernikes, pupil_diam):
+    """Generate 3D Zernike Maps.
 
-    A function to generate Zernike maps on
-    a three-dimensional tensor.
+    This function generates Zernike maps on a three-dimensional tensor. 
Parameters ---------- - n_zernikes: int - Number of Zernike polynomials - pupil_diam: float - Size of the pupil diameter + n_zernikes : int + The number of Zernike polynomials. + pupil_diam : float + The diameter of the pupil. Returns ------- - Zernike map tensor - TensorFlow EagerTensor type - + tf.Tensor + A TensorFlow EagerTensor containing the Zernike map tensor. + + Notes + ----- + The Zernike maps are generated using the specified number of Zernike + polynomials and the size of the pupil diameter. The resulting tensor + contains the Zernike maps in a three-dimensional format. """ # Prepare the inputs # Generate Zernike maps @@ -191,9 +272,6 @@ def tf_obscurations(pupil_diam, N_filter=2): ) return tf.convert_to_tensor(obscurations, dtype=tf.complex64) - ## Generate initializations -- This looks like it could be moved to PSF model package - # Prepare np input - def simPSF(model_params): """Simulated PSF model. diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_layers.py index 225edb33..09082af4 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_layers.py @@ -1,7 +1,7 @@ import numpy as np import tensorflow as tf import tensorflow_addons as tfa -from wf_psf.psf_models.tf_modules import TF_mono_PSF +from wf_psf.psf_models.tf_modules import TFMonochromaticPSF from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils import logging @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) -class TF_poly_Z_field(tf.keras.layers.Layer): +class TFPolynomialZernikeField(tf.keras.layers.Layer): """Calculate the zernike coefficients for a given position. This module implements a polynomial model of Zernike @@ -90,7 +90,7 @@ def call(self, positions): return zernikes_coeffs[:, :, tf.newaxis, tf.newaxis] -class TF_zernike_OPD(tf.keras.layers.Layer): +class TFZernikeOPD(tf.keras.layers.Layer): """Turn zernike coefficients into an OPD. Will use all of the Zernike maps provided. @@ -122,7 +122,7 @@ def call(self, z_coeffs): return tf.math.reduce_sum(tf.math.multiply(self.zernike_maps, z_coeffs), axis=1) -class TF_batch_poly_PSF(tf.keras.layers.Layer): +class TFBatchPolychromaticPSF(tf.keras.layers.Layer): """Calculate a polychromatic PSF from an OPD and stored SED values. The calculation of the packed values with the respective SED is done @@ -159,7 +159,7 @@ def __init__(self, obscurations, output_Q, output_dim=64, name="TF_batch_poly_PS self.current_opd = None - def calculate_mono_PSF(self, packed_elems): + def calculate_monochromatic_PSF(self, packed_elems): """Calculate monochromatic PSF from packed elements. 
packed_elems[0]: phase_N @@ -172,7 +172,7 @@ def calculate_mono_PSF(self, packed_elems): SED_norm_val = packed_elems[2] # Build the monochromatic PSF generator - tf_mono_psf_gen = TF_mono_PSF( + tf_monochromatic_psf_gen = TFMonochromaticPSF( phase_N, lambda_obs, self.obscurations, @@ -181,22 +181,21 @@ def calculate_mono_PSF(self, packed_elems): ) # Calculate the PSF - mono_psf = tf_mono_psf_gen.__call__(self.current_opd) - mono_psf = tf.squeeze(mono_psf, axis=0) - # mono_psf = tf.reshape(mono_psf, shape=(mono_psf.shape[1],mono_psf.shape[2])) + monochromatic_psf = tf_monochromatic_psf_gen.__call__(self.current_opd) + monochromatic_psf = tf.squeeze(monochromatic_psf, axis=0) # Multiply with the respective normalized SED and return - return tf.math.scalar_mul(SED_norm_val, mono_psf) + return tf.math.scalar_mul(SED_norm_val, monochromatic_psf) - def calculate_poly_PSF(self, packed_elems): + def calculate_polychromatic_PSF(self, packed_elems): """Calculate a polychromatic PSF.""" self.current_opd = packed_elems[0][tf.newaxis, :, :] SED_pack_data = packed_elems[1] - def _calculate_poly_PSF(elems_to_unpack): + def _calculate_polychromatic_PSF(elems_to_unpack): return tf.map_fn( - self.calculate_mono_PSF, + self.calculate_monochromatic_PSF, elems_to_unpack, parallel_iterations=10, fn_output_signature=tf.float32, @@ -208,13 +207,13 @@ def _calculate_poly_PSF(elems_to_unpack): # poly_psf = tf.math.reduce_sum(stacked_psfs, axis=0) # return poly_psf - stack_psf = _calculate_poly_PSF(SED_pack_data) - poly_psf = tf.math.reduce_sum(stack_psf, axis=0) + stack_psf = _calculate_polychromatic_PSF(SED_pack_data) + polychromatic_psf = tf.math.reduce_sum(stack_psf, axis=0) - return poly_psf + return polychromatic_psf def call(self, inputs): - """Calculate the batch poly PSFs.""" + """Calculate the batch polychromatic PSFs.""" # Unpack Inputs opd_batch = inputs[0] @@ -222,19 +221,19 @@ def call(self, inputs): def _calculate_PSF_batch(elems_to_unpack): return tf.map_fn( - self.calculate_poly_PSF, + self.calculate_polychromatic_PSF, elems_to_unpack, parallel_iterations=10, fn_output_signature=tf.float32, swap_memory=True, ) - psf_poly_batch = _calculate_PSF_batch((opd_batch, packed_SED_data)) + psf_polychromatic_batch = _calculate_PSF_batch((opd_batch, packed_SED_data)) - return psf_poly_batch + return psf_polychromatic_batch -class TF_batch_mono_PSF(tf.keras.layers.Layer): +class TFBatchMonochromaticPSF(tf.keras.layers.Layer): """Calculate a monochromatic PSF from a batch of OPDs. 
The calculation of the ``phase_N`` variable is done
@@ -255,7 +254,7 @@ class TF_batch_mono_PSF(tf.keras.layers.Layer):
 
     """
 
-    def __init__(self, obscurations, output_Q, output_dim=64, name="TF_batch_mono_PSF"):
+    def __init__(self, obscurations, output_Q, output_dim=64, name="TFBatchMonochromaticPSF"):
         super().__init__(name=name)
 
         self.output_Q = output_Q
@@ -268,7 +267,7 @@ def __init__(self, obscurations, output_Q, output_dim=64, name="TF_batch_mono_PS
 
         self.current_opd = None
 
-    def calculate_mono_PSF(self, current_opd):
+    def calculate_monochromatic_PSF(self, current_opd):
         """Calculate monochromatic PSF from OPD info."""
         # Calculate the PSF
         mono_psf = self.tf_mono_psf_gen.__call__(current_opd[tf.newaxis, :, :])
@@ -278,7 +277,7 @@ def calculate_mono_PSF(self, current_opd):
 
     def init_mono_PSF(self):
         """Initialise or restart the PSF generator."""
-        self.tf_mono_psf_gen = TF_mono_PSF(
+        self.tf_mono_psf_gen = TFMonochromaticPSF(
             self.phase_N,
             self.lambda_obs,
             self.obscurations,
@@ -306,7 +305,7 @@ def call(self, opd_batch):
 
         def _calculate_PSF_batch(elems_to_unpack):
             return tf.map_fn(
-                self.calculate_mono_PSF,
+                self.calculate_monochromatic_PSF,
                 elems_to_unpack,
                 parallel_iterations=10,
                 fn_output_signature=tf.float32,
@@ -318,7 +317,7 @@ def _calculate_PSF_batch(elems_to_unpack):
 
         return mono_psf_batch
 
-class TF_NP_poly_OPD(tf.keras.layers.Layer):
+class TFNonParametricPolynomialVariationsOPD(tf.keras.layers.Layer):
     """Non-parametric OPD generation with polynomial variations.
 
 
@@ -380,6 +379,10 @@ def init_vars(self):
             initial_value=tf.eye(self.n_poly), trainable=True, dtype=tf.float32
         )
 
+        # Update random seed for next call
+        if self.random_seed is not None:
+            self.random_seed += 1
+
     def set_alpha_zero(self):
         """Set alpha matrix to zero."""
         self.alpha_mat.assign(tf.zeros_like(self.alpha_mat, dtype=tf.float32))
@@ -418,7 +421,7 @@ def call(self, positions):
 
         return tf.tensordot(inter_res, self.S_mat, axes=1)
 
-class TF_NP_MCCD_OPD_v2(tf.keras.layers.Layer):
+class TFNonParametricMCCDOPDv2(tf.keras.layers.Layer):
     """Non-parametric OPD generation with hybrid-MCCD variations.
 
 
@@ -528,6 +531,10 @@ def init_vars(self):
             dtype=tf.float32,
         )
 
+        # Update random seed for next call
+        if self.random_seed is not None:
+            self.random_seed += 1
+
     def set_alpha_zero(self):
         """Set alpha matrix to zero."""
         self.alpha_poly.assign(tf.zeros_like(self.alpha_poly, dtype=tf.float32))
@@ -632,7 +639,7 @@ def calc_index(idx_pos):
 
         return tf.math.add(contribution_poly, contribution_graph)
 
-class TF_NP_GRAPH_OPD(tf.keras.layers.Layer):
+class TFNonParametricGraphOPD(tf.keras.layers.Layer):
     """Non-parametric OPD generation with only graph-cosntraint variations.
 
 
@@ -723,6 +730,10 @@ def init_vars(self):
             dtype=tf.float32,
         )
 
+        # Update random seed for next call
+        if self.random_seed is not None:
+            self.random_seed += 1
+
     def set_alpha_zero(self):
         """Set alpha matrix to zero."""
         self.alpha_graph.assign(tf.zeros_like(self.alpha_graph, dtype=tf.float32))
@@ -807,8 +818,8 @@ def calc_index(idx_pos):
 
         return contribution_graph
 
-class TF_physical_layer(tf.keras.layers.Layer):
-    """Store and calculate the zernike coefficients for a given position.
+class TFPhysicalLayer(tf.keras.layers.Layer):
+    """Store and calculate the zernike coefficients for a given position.
 
     This layer gives the Zernike contribution of the physical layer.
     It is fixed and not trainable. 
@@ -913,19 +924,38 @@ def interpolate_independent_Zk(self, positions): return interp_zks[:, :, tf.newaxis, tf.newaxis] def call(self, positions): - """Calculate the prior zernike coefficients for a given position. + """Calculate the prior Zernike coefficients for a batch of positions. - The position polynomial matrix and the coefficients should be - set before calling this function. + This method calculates the Zernike coefficients for a batch of input positions + based on the pre-computed Zernike coefficients for observed positions. Parameters ---------- - positions: Tensor(batch, 2) - First element is x-axis, second is y-axis. + positions : tf.Tensor + Tensor of shape (batch_size, 2) representing the positions. + The first element represents the x-axis, and the second element represents the y-axis. Returns ------- - zernikes_coeffs: Tensor(batch, n_zernikes, 1, 1) + zernike_coeffs : tf.Tensor + Tensor of shape (batch_size, n_zernikes, 1, 1) containing the prior Zernike coefficients + corresponding to the input positions. + + Notes + ----- + The method retrieves the Zernike coefficients for each input position from the pre-computed + Zernike coefficients stored for observed positions. It matches each input position with + the closest observed position and retrieves the corresponding Zernike coefficients. + + Before calling this method, ensure that the position polynomial matrix and the + corresponding Zernike coefficients have been precomputed and set for the layer. + + + Raises + ------ + ValueError + If the shape of the input `positions` tensor is not compatible. + """ def calc_index(idx_pos): @@ -937,133 +967,3 @@ def calc_index(idx_pos): batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0) return batch_zks[:, :, tf.newaxis, tf.newaxis] - - -# --- # -# Deprecated # -class OLD_TF_batch_poly_PSF(tf.keras.layers.Layer): - """Calculate a polychromatic PSF from an OPD and stored SED values. - - The calculation of the packed values with the respective SED is done - with the PSFSimulator class but outside the TF class. - - - - obscurations: Tensor(pupil_len, pupil_len) - Obscurations to apply to the wavefront. - - packed_SED_data: Tensor(batch_size, 3, n_bins_lda) - - Comes from: tf.convert_to_tensor(list(list(Tensor,Tensor,Tensor))) - Where each inner list consist of a packed_elem: - - packed_elems: Tuple of tensors - Contains three 1D tensors with the parameters needed for - the calculation of one monochromatic PSF. - - packed_elems[0]: phase_N - packed_elems[1]: lambda_obs - packed_elems[2]: SED_norm_val - The SED data is constant in a FoV. - - psf_batch: Tensor(batch_size, output_dim, output_dim) - Tensor containing the psfs that will be updated each - time a calculation is required. - - """ - - def __init__( - self, obscurations, psf_batch, output_dim=64, name="TF_batch_poly_PSF" - ): - super().__init__(name=name) - - self.obscurations = obscurations - self.output_dim = output_dim - self.psf_batch = psf_batch - - self.current_opd = None - - def set_psf_batch(self, psf_batch): - """Set poly PSF batch.""" - self.psf_batch = psf_batch - - def calculate_mono_PSF(self, packed_elems): - """Calculate monochromatic PSF from packed elements. 
- - packed_elems[0]: phase_N - packed_elems[1]: lambda_obs - packed_elems[2]: SED_norm_val - """ - # Unpack elements - phase_N = packed_elems[0] - lambda_obs = packed_elems[1] - SED_norm_val = packed_elems[2] - - # Build the monochromatic PSF generator - tf_mono_psf_gen = TF_mono_PSF( - phase_N, lambda_obs, self.obscurations, output_dim=self.output_dim - ) - - # Calculate the PSF - mono_psf = tf_mono_psf_gen.__call__(self.current_opd) - - # Multiply with the respective normalized SED and return - return tf.math.scalar_mul(SED_norm_val, mono_psf) - - def calculate_poly_PSF(self, packed_elems): - """Calculate a polychromatic PSF.""" - - logger.info("TF_batch_poly_PSF: calculate_poly_PSF: packed_elems.type") - logger.info(packed_elems.dtype) - - def _calculate_poly_PSF(elems_to_unpack): - return tf.map_fn( - self.calculate_mono_PSF, - elems_to_unpack, - parallel_iterations=10, - fn_output_signature=tf.float32, - swap_memory=True, - ) - - # Readability - # stacked_psfs = _calculate_poly_PSF(packed_elems) - # poly_psf = tf.math.reduce_sum(stacked_psfs, axis=0) - # return poly_psf - - return tf.math.reduce_sum(_calculate_poly_PSF(packed_elems), axis=0) - - def call(self, inputs): - """Calculate the batch poly PSFs.""" - - # Unpack Inputs - opd_batch = inputs[0] - packed_SED_data = inputs[1] - - batch_num = opd_batch.shape[0] - - it = tf.constant(0) - while_condition = lambda it: tf.less(it, batch_num) - - def while_body(it): - # Extract the required data of _it_ - packed_elems = packed_SED_data[it] - self.current_opd = opd_batch[it][tf.newaxis, :, :] - - # Calculate the _it_ poly PSF - poly_psf = self.calculate_poly_PSF(packed_elems) - - # Update the poly PSF tensor with the result - # Slice update of a tensor - # See tf doc of _tensor_scatter_nd_update_ to understand - indices = tf.reshape(it, shape=(1, 1)) - # self.psf_batch = tf.tensor_scatter_nd_update(self.psf_batch, indices, poly_psf) - - # increment i - return [tf.add(it, 1)] - - # Loop over the PSF batches - r = tf.while_loop( - while_condition, while_body, [it], swap_memory=True, parallel_iterations=10 - ) - - return self.psf_batch diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules.py index 2598fd13..7d678554 100644 --- a/src/wf_psf/psf_models/tf_modules.py +++ b/src/wf_psf/psf_models/tf_modules.py @@ -2,7 +2,7 @@ import tensorflow as tf -class TF_fft_diffract(tf.Module): +class TFFftDiffract(tf.Module): """Diffract the wavefront into a monochromatic PSF. Parameters @@ -101,7 +101,7 @@ def __call__(self, input_phase): return norm_psf -class TF_build_phase(tf.Module): +class TFBuildPhase(tf.Module): """Build complex phase map from OPD map.""" def __init__(self, phase_N, lambda_obs, obscurations, name=None): @@ -157,7 +157,7 @@ def __call__(self, opd): return padded_phase -class TF_zernike_OPD(tf.Module): +class TFZernikeOPD(tf.Module): """Turn zernike coefficients into an OPD. Will use all of the Zernike maps provided. @@ -185,7 +185,7 @@ def __call__(self, z_coeffs): return opd -class TF_Zernike_mono_PSF(tf.Module): +class TFZernikeMonochromaticPSF(tf.Module): """Build a monochromatic PSF from zernike coefficients. Following a Zernike model. 
@@ -196,9 +196,9 @@ def __init__(
     ):
         super().__init__(name=name)
 
-        self.tf_build_opd_zernike = TF_zernike_OPD(zernike_maps)
-        self.tf_build_phase = TF_build_phase(phase_N, lambda_obs, obscurations)
-        self.tf_fft_diffract = TF_fft_diffract(output_dim)
+        self.tf_build_opd_zernike = TFZernikeOPD(zernike_maps)
+        self.tf_build_phase = TFBuildPhase(phase_N, lambda_obs, obscurations)
+        self.tf_fft_diffract = TFFftDiffract(output_dim)
 
     def __call__(self, z_coeffs):
         opd = self.tf_build_opd_zernike.__call__(z_coeffs)
@@ -208,7 +208,7 @@ def __call__(self, z_coeffs):
         return psf
 
-class TF_mono_PSF(tf.Module):
+class TFMonochromaticPSF(tf.Module):
     """Calculate a monochromatic PSF from an OPD map."""
 
     def __init__(
@@ -217,98 +217,11 @@ def __init__(
         super().__init__(name=name)
 
         self.output_Q = output_Q
-        self.tf_build_phase = TF_build_phase(phase_N, lambda_obs, obscurations)
-        self.tf_fft_diffract = TF_fft_diffract(output_dim, output_Q=self.output_Q)
+        self.tf_build_phase = TFBuildPhase(phase_N, lambda_obs, obscurations)
+        self.tf_fft_diffract = TFFftDiffract(output_dim, output_Q=self.output_Q)
 
     def __call__(self, opd):
         phase = self.tf_build_phase.__call__(opd)
         psf = self.tf_fft_diffract.__call__(phase)
 
         return tf.cast(psf, dtype=opd.dtype)
-
-
-# class TF_poly_PSF(tf.Module):
-#     """Calculate a polychromatic PSF from an OPD and stored SED values.
-
-#     The calculation of the packed values with the respective SED is done
-#     with the PSFSimulator class but outside the TF class.
-
-#     packed_elems: Tuple of tensors
-#         Contains three 1D tensors with the parameters needed for
-#         the calculation of each monochromatic PSF.
-
-#         packed_elems[0]: phase_N
-#         packed_elems[1]: lambda_obs
-#         packed_elems[2]: SED_norm_val
-#     """
-#     def __init__(self, obscurations, packed_elems, output_dim=64, zernike_maps=None, name=None):
-#         super().__init__(name=name)
-
-#         self.obscurations = obscurations
-#         self.output_dim = output_dim
-#         self.packed_elems = packed_elems
-#         self.zernike_maps = zernike_maps
-
-#         self.opd = None
-
-#     def set_packed_elems(self, new_packed_elems):
-#         """Set packed elements."""
-#         self.packed_elems = new_packed_elems
-
-#     def set_zernike_maps(self, zernike_maps):
-#         """Set Zernike maps."""
-#         self.zernike_maps = zernike_maps
-
-#     def calculate_from_zernikes(self, z_coeffs):
-#         """Calculate polychromatic PSFs from zernike coefficients.
-
-#         Zernike maps required.
-#         """
-#         tf_zernike_opd_gen = TF_zernike_OPD(self.zernike_maps)
-#         # For readability
-#         # opd = tf_zernike_opd_gen.__call__(z_coeffs)
-#         # poly_psf = self.__call__(opd)
-#         # return poly_psf
-
-#         return self.__call__(tf_zernike_opd_gen.__call__(z_coeffs))
-
-#     def calculate_mono_PSF(self, packed_elems):
-#         """Calculate monochromatic PSF from packed elements. 
- -# packed_elems[0]: phase_N -# packed_elems[1]: lambda_obs -# packed_elems[2]: SED_norm_val -# """ -# # Unpack elements -# phase_N = packed_elems[0] -# lambda_obs = packed_elems[1] -# SED_norm_val = packed_elems[2] - -# # Build the monochromatic PSF generator -# tf_mono_psf_gen = TF_mono_PSF(phase_N, lambda_obs, self.obscurations, output_dim=self.output_dim) - -# # Calculate the PSF -# mono_psf = tf_mono_psf_gen.__call__(self.opd) - -# # Multiply with the respective normalized SED and return -# return tf.math.scalar_mul(SED_norm_val, mono_psf) - -# def __call__(self, opd): - -# # Save the OPD that will be shared by all the monochromatic PSFs -# self.opd = opd - -# # Use tf.function for parallelization over GPU -# # Not allowed since the dynamic padding for the diffraction does not -# # work in the @tf.function context -# # @tf.function -# def calculate_poly_PSF(elems_to_unpack): -# return tf.map_fn(self.calculate_mono_PSF, -# elems_to_unpack, -# parallel_iterations=10, -# fn_output_signature=tf.float32) - -# stacked_psfs = calculate_poly_PSF(packed_elems) -# poly_psf = tf.math.reduce_sum(stacked_psfs, axis=0) - -# return poly_psf diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_psf_field.py index 37e96df8..e6a63f64 100644 --- a/src/wf_psf/psf_models/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_psf_field.py @@ -2,971 +2,15 @@ import tensorflow as tf from tensorflow.python.keras.engine import data_adapter from wf_psf.psf_models.tf_layers import ( - TF_poly_Z_field, - TF_zernike_OPD, - TF_batch_poly_PSF, + TFZernikeOPD, + TFBatchPolychromaticPSF, + TFBatchMonochromaticPSF, + TFPhysicalLayer, ) -from wf_psf.psf_models.tf_layers import ( - TF_NP_poly_OPD, - TF_batch_mono_PSF, - TF_physical_layer, -) -from wf_psf.utils.utils import PI_zernikes - - -class TF_PSF_field_model(tf.keras.Model): - """Parametric PSF field model! - - Fully parametric model based on the Zernike polynomial basis. - - Parameters - ---------- - zernike_maps: Tensor(n_batch, opd_dim, opd_dim) - Zernike polynomial maps. - obscurations: Tensor(opd_dim, opd_dim) - Predefined obscurations of the phase. - batch_size: int - Batch size. - output_Q: float - Oversampling used. This should match the oversampling Q used to generate - the diffraction zero padding that is found in the input `packed_SEDs`. - We call this other Q the `input_Q`. - In that case, we replicate the original sampling of the model used to - calculate the input `packed_SEDs`. - The final oversampling of the generated PSFs with respect to the - original instrument sampling depend on the division `input_Q/output_Q`. - It is not recommended to use `output_Q < 1`. - Although it works with float values it is better to use integer values. - l2_param: float - Parameter going with the l2 loss on the opd. If it is `0.` the loss - is not added. Default is `0.`. - output_dim: int - Output dimension of the PSF stamps. - n_zernikes: int - Order of the Zernike polynomial for the parametric model. - d_max: int - Maximum degree of the polynomial for the Zernike coefficient variations. - x_lims: [float, float] - Limits for the x coordinate of the PSF field. - y_lims: [float, float] - Limits for the x coordinate of the PSF field. - coeff_mat: Tensor or None - Initialization of the coefficient matrix defining the parametric psf - field model. 
- - """ - - def __init__( - self, - zernike_maps, - obscurations, - batch_size, - output_Q, - l2_param=0.0, - output_dim=64, - n_zernikes=45, - d_max=2, - x_lims=[0, 1e3], - y_lims=[0, 1e3], - coeff_mat=None, - name="TF_PSF_field_model", - ): - super(TF_PSF_field_model, self).__init__() - - self.output_Q = output_Q - - # Inputs: TF_poly_Z_field - self.n_zernikes = n_zernikes - self.d_max = d_max - self.x_lims = x_lims - self.y_lims = y_lims - - # Inputs: TF_zernike_OPD - # They are not stored as they are memory-heavy - # zernike_maps =[] - - # Inputs: TF_batch_poly_PSF - self.batch_size = batch_size - self.obscurations = obscurations - self.output_dim = output_dim - - # Inputs: Loss - self.l2_param = l2_param - - # Initialize the first layer - self.tf_poly_Z_field = TF_poly_Z_field( - x_lims=self.x_lims, - y_lims=self.y_lims, - n_zernikes=self.n_zernikes, - d_max=self.d_max, - ) - - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TF_zernike_OPD(zernike_maps=zernike_maps) - - # Initialize the batch opd to batch polychromatic PSF layer - self.tf_batch_poly_PSF = TF_batch_poly_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - - # Initialize the model parameters with non-default value - if coeff_mat is not None: - self.assign_coeff_matrix(coeff_mat) - - # # Depending on the parameter we define the forward model - # # This is, we add or not the L2 loss to the OPD. - # if self.l2_param == 0.: - # self.call = self.call_basic - # else: - # self.call = self.call_l2_opd_loss - - def get_coeff_matrix(self): - """Get coefficient matrix.""" - return self.tf_poly_Z_field.get_coeff_matrix() - - def assign_coeff_matrix(self, coeff_mat): - """Assign coefficient matrix.""" - self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) - - def set_output_Q(self, output_Q, output_dim=None): - """Set the value of the output_Q parameter. - Useful for generating/predicting PSFs at a different sampling wrt the - observation sampling. - """ - self.output_Q = output_Q - if output_dim is not None: - self.output_dim = output_dim - # Reinitialize the PSF batch poly generator - self.tf_batch_poly_PSF = TF_batch_poly_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - - def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): - """Predict a set of monochromatic PSF at desired positions. - - input_positions: Tensor(batch_dim x 2) - - lambda_obs: float - Observed wavelength in um. - - phase_N: int - Required wavefront dimension. Should be calculated with as: - ``simPSF_np = wf_psf.sims.psf_simulator.PSFSimulator(...)`` - ``phase_N = simPSF_np.feasible_N(lambda_obs)`` - """ - - # Initialise the monochromatic PSF batch calculator - tf_batch_mono_psf = TF_batch_mono_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - # Set the lambda_obs and the phase_N parameters - tf_batch_mono_psf.set_lambda_phaseN(phase_N, lambda_obs) - - # Continue the OPD maps - zernike_coeffs = self.tf_poly_Z_field(input_positions) - opd_maps = self.tf_zernike_OPD(zernike_coeffs) - - # Compute the monochromatic PSFs - mono_psf_batch = tf_batch_mono_psf(opd_maps) - - return mono_psf_batch - - def predict_opd(self, input_positions): - """Predict the OPD at some positions. - - Parameters - ---------- - input_positions: Tensor(batch_dim x 2) - Positions to predict the OPD. - - Returns - ------- - opd_maps : Tensor [batch x opd_dim x opd_dim] - OPD at requested positions. 
- - """ - # Continue the OPD maps - zernike_coeffs = self.tf_poly_Z_field(input_positions) - opd_maps = self.tf_zernike_OPD(zernike_coeffs) - - return opd_maps - - def call(self, inputs): - """Define the PSF field forward model. - - [1] From positions to Zernike coefficients - [2] From Zernike coefficients to OPD maps - [3] From OPD maps and SED info to polychromatic PSFs - - OPD: Optical Path Differences - """ - # Unpack inputs - input_positions = inputs[0] - packed_SEDs = inputs[1] - - # Continue the forward model - zernike_coeffs = self.tf_poly_Z_field(input_positions) - opd_maps = self.tf_zernike_OPD(zernike_coeffs) - # Add l2 loss on the OPD - self.add_loss(self.l2_param * tf.math.reduce_sum(tf.math.square(opd_maps))) - poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - - return poly_psfs - - -class TF_SemiParam_field(tf.keras.Model): - """PSF field forward model! - - Semi parametric model based on the Zernike polynomial basis. The - - Parameters - ---------- - zernike_maps: Tensor(n_batch, opd_dim, opd_dim) - Zernike polynomial maps. - obscurations: Tensor(opd_dim, opd_dim) - Predefined obscurations of the phase. - batch_size: int - Batch sizet - output_Q: float - Oversampling used. This should match the oversampling Q used to generate - the diffraction zero padding that is found in the input `packed_SEDs`. - We call this other Q the `input_Q`. - In that case, we replicate the original sampling of the model used to - calculate the input `packed_SEDs`. - The final oversampling of the generated PSFs with respect to the - original instrument sampling depend on the division `input_Q/output_Q`. - It is not recommended to use `output_Q < 1`. - Although it works with float values it is better to use integer values. - d_max_nonparam: int - Maximum degree of the polynomial for the non-parametric variations. - l2_param: float - Parameter going with the l2 loss on the opd. If it is `0.` the loss - is not added. Default is `0.`. - output_dim: int - Output dimension of the PSF stamps. - n_zernikes: int - Order of the Zernike polynomial for the parametric model. - d_max: int - Maximum degree of the polynomial for the Zernike coefficient variations. - x_lims: [float, float] - Limits for the x coordinate of the PSF field. - y_lims: [float, float] - Limits for the x coordinate of the PSF field. - coeff_mat: Tensor or None - Initialization of the coefficient matrix defining the parametric psf - field model. 
- - """ - - def __init__( - self, - zernike_maps, - obscurations, - batch_size, - output_Q, - d_max_nonparam=3, - l2_param=0.0, - output_dim=64, - n_zernikes=45, - d_max=2, - x_lims=[0, 1e3], - y_lims=[0, 1e3], - coeff_mat=None, - name="TF_SemiParam_field", - ): - super(TF_SemiParam_field, self).__init__() - - # Inputs: oversampling used - self.output_Q = output_Q - - # Inputs: TF_poly_Z_field - self.n_zernikes = n_zernikes - self.d_max = d_max - self.x_lims = x_lims - self.y_lims = y_lims - - # Inputs: TF_NP_poly_OPD - self.d_max_nonparam = d_max_nonparam - self.opd_dim = tf.shape(zernike_maps)[1].numpy() - - # Inputs: TF_zernike_OPD - # They are not stored as they are memory-heavy - # zernike_maps =[] - - # Inputs: TF_batch_poly_PSF - self.batch_size = batch_size - self.obscurations = obscurations - self.output_dim = output_dim - - # Inputs: Loss - self.l2_param = l2_param - - # Initialize the first layer - self.tf_poly_Z_field = TF_poly_Z_field( - x_lims=self.x_lims, - y_lims=self.y_lims, - n_zernikes=self.n_zernikes, - d_max=self.d_max, - ) - - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TF_zernike_OPD(zernike_maps=zernike_maps) - - # Initialize the non-parametric layer - self.tf_np_poly_opd = TF_NP_poly_OPD( - x_lims=self.x_lims, - y_lims=self.y_lims, - d_max=self.d_max_nonparam, - opd_dim=self.opd_dim, - ) - - # Initialize the batch opd to batch polychromatic PSF layer - self.tf_batch_poly_PSF = TF_batch_poly_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - - # Initialize the model parameters with non-default value - if coeff_mat is not None: - self.assign_coeff_matrix(coeff_mat) - - # # Depending on the parameter we define the forward model - # # This is, we add or not the L2 loss to the OPD. - # if self.l2_param == 0.: - # self.call = self.call_basic - # else: - # self.call = self.call_l2_opd_loss - - def get_coeff_matrix(self): - """Get coefficient matrix.""" - return self.tf_poly_Z_field.get_coeff_matrix() - - def assign_coeff_matrix(self, coeff_mat): - """Assign coefficient matrix.""" - self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) - - def set_zero_nonparam(self): - """Set to zero the non-parametric part.""" - self.tf_np_poly_opd.set_alpha_zero() - - def set_nonzero_nonparam(self): - """Set to non-zero the non-parametric part.""" - self.tf_np_poly_opd.set_alpha_identity() - - def set_trainable_layers(self, param_bool=True, nonparam_bool=True): - """Set the layers to be trainable or not.""" - self.tf_np_poly_opd.trainable = nonparam_bool - self.tf_poly_Z_field.trainable = param_bool - - def set_output_Q(self, output_Q, output_dim=None): - """Set the value of the output_Q parameter. - Useful for generating/predicting PSFs at a different sampling wrt the - observation sampling. - """ - self.output_Q = output_Q - if output_dim is not None: - self.output_dim = output_dim - - # Reinitialize the PSF batch poly generator - self.tf_batch_poly_PSF = TF_batch_poly_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - - def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): - """Predict a set of monochromatic PSF at desired positions. - - input_positions: Tensor(batch_dim x 2) - - lambda_obs: float - Observed wavelength in um. - - phase_N: int - Required wavefront dimension. 
Should be calculated with as: - ``simPSF_np =wf_psf.sims.psf_simulator.PSFSimulator(...)`` - ``phase_N = simPSF_np.feasible_N(lambda_obs)`` - """ - - # Initialise the monochromatic PSF batch calculator - tf_batch_mono_psf = TF_batch_mono_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - # Set the lambda_obs and the phase_N parameters - tf_batch_mono_psf.set_lambda_phaseN(phase_N, lambda_obs) - - # Calculate parametric part - zernike_coeffs = self.tf_poly_Z_field(input_positions) - param_opd_maps = self.tf_zernike_OPD(zernike_coeffs) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - - # Compute the monochromatic PSFs - mono_psf_batch = tf_batch_mono_psf(opd_maps) - - return mono_psf_batch - - def predict_opd(self, input_positions): - """Predict the OPD at some positions. - - Parameters - ---------- - input_positions: Tensor(batch_dim x 2) - Positions to predict the OPD. - - Returns - ------- - opd_maps : Tensor [batch x opd_dim x opd_dim] - OPD at requested positions. - - """ - # Calculate parametric part - zernike_coeffs = self.tf_poly_Z_field(input_positions) - param_opd_maps = self.tf_zernike_OPD(zernike_coeffs) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - - return opd_maps - - def assign_S_mat(self, S_mat): - """Assign DD features matrix.""" - self.tf_np_poly_opd.assign_S_mat(S_mat) - - def project_DD_features(self, tf_zernike_cube): - """ - Project non-parametric wavefront onto first n_z Zernikes and transfer - their parameters to the parametric model. - - """ - # Compute Zernike norm for projections - n_pix_zernike = PI_zernikes(tf_zernike_cube[0, :, :], tf_zernike_cube[0, :, :]) - # Multiply Alpha matrix with DD features matrix S - inter_res_v2 = tf.tensordot( - self.tf_np_poly_opd.alpha_mat[: self.tf_poly_Z_field.coeff_mat.shape[1], :], - self.tf_np_poly_opd.S_mat, - axes=1, - ) - # Project over first n_z Zernikes - delta_C_poly = tf.constant( - np.array( - [ - [ - PI_zernikes( - tf_zernike_cube[i, :, :], - inter_res_v2[j, :, :], - n_pix_zernike, - ) - for j in range(self.tf_poly_Z_field.coeff_mat.shape[1]) - ] - for i in range(self.n_zernikes) - ] - ), - dtype=tf.float32, - ) - old_C_poly = self.tf_poly_Z_field.coeff_mat - # Corrected parametric coeff matrix - new_C_poly = old_C_poly + delta_C_poly - self.assign_coeff_matrix(new_C_poly) - - # Remove extracted features from non-parametric model - # Mix DD features with matrix alpha - S_tilde = tf.tensordot( - self.tf_np_poly_opd.alpha_mat, self.tf_np_poly_opd.S_mat, axes=1 - ) - # Get beta tilde as the proyection of the first n_param_poly_terms (6 for d_max=2) onto the first n_zernikes. - beta_tilde_inner = np.array( - [ - [ - PI_zernikes(tf_zernike_cube[j, :, :], S_tilde_slice, n_pix_zernike) - for j in range(self.n_zernikes) - ] - for S_tilde_slice in S_tilde[ - : self.tf_poly_Z_field.coeff_mat.shape[1], :, : - ] - ] - ) - - # Only pad in the firs dimention so we get a matrix of size (d_max_nonparam_terms)x(n_zernikes) --> 21x15 or 21x45. 
- beta_tilde = np.pad( - beta_tilde_inner, - [(0, S_tilde.shape[0] - beta_tilde_inner.shape[0]), (0, 0)], - mode="constant", - ) - - # Unmix beta tilde with the inverse of alpha - beta = tf.constant( - np.linalg.inv(self.tf_np_poly_opd.alpha_mat) @ beta_tilde, dtype=tf.float32 - ) - # Get the projection for the unmixed features - - # Now since beta.shape[1]=n_zernikes we can take the whole beta matrix. - S_mat_projected = tf.tensordot(beta, tf_zernike_cube, axes=[1, 0]) - - # Subtract the projection from the DD features - S_new = self.tf_np_poly_opd.S_mat - S_mat_projected - self.assign_S_mat(S_new) - - def call(self, inputs): - """Define the PSF field forward model. - - [1] From positions to Zernike coefficients - [2] From Zernike coefficients to OPD maps - [3] From OPD maps and SED info to polychromatic PSFs - - OPD: Optical Path Differences - """ - # Unpack inputs - input_positions = inputs[0] - packed_SEDs = inputs[1] - - # Forward model - # Calculate parametric part - zernike_coeffs = self.tf_poly_Z_field(input_positions) - param_opd_maps = self.tf_zernike_OPD(zernike_coeffs) - # Add l2 loss on the parametric OPD - self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) - ) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - # Compute the polychromatic PSFs - poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - - return poly_psfs - - -class TF_physical_poly_field(tf.keras.Model): - """PSF field forward model with a physical layer - - WaveDiff-original with a physical layer - - Parameters - ---------- - zernike_maps: Tensor(n_batch, opd_dim, opd_dim) - Zernike polynomial maps. - obscurations: Tensor(opd_dim, opd_dim) - Predefined obscurations of the phase. - batch_size: int - Batch size - obs_pos: Tensor(n_stars, 2) - The positions of all the stars - zks_prior: Tensor(n_stars, n_zks) - The Zernike coeffients of the prior for all the stars - output_Q: float - Oversampling used. This should match the oversampling Q used to generate - the diffraction zero padding that is found in the input `packed_SEDs`. - We call this other Q the `input_Q`. - In that case, we replicate the original sampling of the model used to - calculate the input `packed_SEDs`. - The final oversampling of the generated PSFs with respect to the - original instrument sampling depend on the division `input_Q/output_Q`. - It is not recommended to use `output_Q < 1`. - Although it works with float values it is better to use integer values. - d_max_nonparam: int - Maximum degree of the polynomial for the non-parametric variations. - l2_param: float - Parameter going with the l2 loss on the opd. If it is `0.` the loss - is not added. Default is `0.`. - output_dim: int - Output dimension of the PSF stamps. - n_zks_param: int - Order of the Zernike polynomial for the parametric model. - d_max: int - Maximum degree of the polynomial for the Zernike coefficient variations. - x_lims: [float, float] - Limits for the x coordinate of the PSF field. - y_lims: [float, float] - Limits for the x coordinate of the PSF field. - coeff_mat: Tensor or None - Initialization of the coefficient matrix defining the parametric psf - field model. - interpolation_type: str - Option for the interpolation type of the physical layer. - Default is no interpolation. - interpolation_args: dict - Additional arguments for the interpolation. 
- - """ - - def __init__( - self, - zernike_maps, - obscurations, - batch_size, - obs_pos, - zks_prior, - output_Q, - d_max_nonparam=3, - l2_param=0.0, - output_dim=64, - n_zks_param=45, - d_max=2, - x_lims=[0, 1e3], - y_lims=[0, 1e3], - coeff_mat=None, - interpolation_type="none", - interpolation_args=None, - name="TF_physical_poly_field", - ): - super(TF_physical_poly_field, self).__init__(name=name) - - # Inputs: oversampling used - self.output_Q = output_Q - self.n_zks_total = tf.shape(zernike_maps)[0].numpy() - - # Inputs: TF_poly_Z_field - self.n_zks_param = n_zks_param - self.d_max = d_max - self.x_lims = x_lims - self.y_lims = y_lims - - # Inputs: TF_physical_layer - self.obs_pos = obs_pos - self.zks_prior = zks_prior - self.n_zks_prior = tf.shape(zks_prior)[1].numpy() - self.interpolation_type = interpolation_type - self.interpolation_args = interpolation_args - - # Inputs: TF_NP_poly_OPD - self.d_max_nonparam = d_max_nonparam - self.opd_dim = tf.shape(zernike_maps)[1].numpy() - - # Check if the Zernike maps are enough - if (self.n_zks_prior > self.n_zks_total) or ( - self.n_zks_param > self.n_zks_total - ): - raise ValueError("The number of Zernike maps is not enough.") - - # Inputs: TF_zernike_OPD - # They are not stored as they are memory-intensive - # zernike_maps =[] - - # Inputs: TF_batch_poly_PSF - self.batch_size = batch_size - self.obscurations = obscurations - self.output_dim = output_dim - - # Inputs: Loss - self.l2_param = l2_param - - # Initialize the first layer - self.tf_poly_Z_field = TF_poly_Z_field( - x_lims=self.x_lims, - y_lims=self.y_lims, - n_zernikes=self.n_zks_param, - d_max=self.d_max, - ) - # Initialize the physical layer - self.tf_physical_layer = TF_physical_layer( - self.obs_pos, - self.zks_prior, - interpolation_type=self.interpolation_type, - interpolation_args=self.interpolation_args, - ) - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TF_zernike_OPD(zernike_maps=zernike_maps) - - # Initialize the non-parametric layer - self.tf_np_poly_opd = TF_NP_poly_OPD( - x_lims=self.x_lims, - y_lims=self.y_lims, - d_max=self.d_max_nonparam, - opd_dim=self.opd_dim, - ) - # Initialize the batch opd to batch polychromatic PSF layer - self.tf_batch_poly_PSF = TF_batch_poly_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - # Initialize the model parameters with non-default value - if coeff_mat is not None: - self.assign_coeff_matrix(coeff_mat) - - def get_coeff_matrix(self): - """Get coefficient matrix.""" - return self.tf_poly_Z_field.get_coeff_matrix() - - def assign_coeff_matrix(self, coeff_mat): - """Assign coefficient matrix.""" - self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) - - def set_zero_nonparam(self): - """Set to zero the non-parametric part.""" - self.tf_np_poly_opd.set_alpha_zero() - - def set_nonzero_nonparam(self): - """Set to non-zero the non-parametric part.""" - self.tf_np_poly_opd.set_alpha_identity() - - def set_trainable_layers(self, param_bool=True, nonparam_bool=True): - """Set the layers to be trainable or not.""" - self.tf_np_poly_opd.trainable = nonparam_bool - self.tf_poly_Z_field.trainable = param_bool - - def set_output_Q(self, output_Q, output_dim=None): - """Set the value of the output_Q parameter. - Useful for generating/predicting PSFs at a different sampling wrt the - observation sampling. 
- """ - self.output_Q = output_Q - if output_dim is not None: - self.output_dim = output_dim - - # Reinitialize the PSF batch poly generator - self.tf_batch_poly_PSF = TF_batch_poly_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - - def zks_pad(self, zk_param, zk_prior): - """Pad the zernike coefficients with zeros to have the same length. - - Pad them to have `n_zks_total` length. - - Parameters - ---------- - zk_param: Tensor [batch, n_zks_param, 1, 1] - Zernike coefficients for the parametric part - zk_prior: Tensor [batch, n_zks_prior, 1, 1] - Zernike coefficients for the prior part - - Returns - ------- - zk_param: Tensor [batch, n_zks_total, 1, 1] - Zernike coefficients for the parametric part - zk_prior: Tensor [batch, n_zks_total, 1, 1] - Zernike coefficients for the prior part - - """ - # Calculate the number of zernikes to pad - pad_num = tf.cast(self.n_zks_total - self.n_zks_param, dtype=tf.int32) - # Pad the zernike coefficients - padding = [ - (0, 0), - (0, pad_num), - (0, 0), - (0, 0), - ] - padded_zk_param = tf.pad(zk_param, padding) - - # Calculate the number of zernikes to pad - pad_num = tf.cast(self.n_zks_total - self.n_zks_prior, dtype=tf.int32) - # Pad the zernike coefficients - padding = [ - (0, 0), - (0, pad_num), - (0, 0), - (0, 0), - ] - padded_zk_prior = tf.pad(zk_prior, padding) - - return padded_zk_param, padded_zk_prior - - def predict_step(self, data, evaluate_step=False): - r"""Custom predict (inference) step. - - It is needed as the physical layer requires a special - interpolation (different from training). - - """ - if evaluate_step: - input_data = data - else: - # Format input data - data = data_adapter.expand_1d(data) - input_data, _, _ = data_adapter.unpack_x_y_sample_weight(data) - - # Unpack inputs - input_positions = input_data[0] - packed_SEDs = input_data[1] - - # Compute zernikes from parametric model and physical layer - zks_coeffs = self.predict_zernikes(input_positions) - # Propagate to obtain the OPD - param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - # Compute the polychromatic PSFs - poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - - return poly_psfs - - def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): - """Predict a set of monochromatic PSF at desired positions. - - Parameters - ---------- - input_positions: Tensor [batch_dim, 2] - Positions at which to compute the PSF - lambda_obs: float - Observed wavelength in um. - phase_N: int - Required wavefront dimension. 
Should be calculated with as: - ``simPSF_np = wf_psf.sims.psf_simulator.PSFSimulator(...)`` - ``phase_N = simPSF_np.feasible_N(lambda_obs)`` - - """ - - # Initialise the monochromatic PSF batch calculator - tf_batch_mono_psf = TF_batch_mono_PSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) - # Set the lambda_obs and the phase_N parameters - tf_batch_mono_psf.set_lambda_phaseN(phase_N, lambda_obs) - - # Compute zernikes from parametric model and physical layer - zks_coeffs = self.predict_zernikes(input_positions) - # Propagate to obtain the OPD - param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - - # Compute the monochromatic PSFs - mono_psf_batch = tf_batch_mono_psf(opd_maps) - - return mono_psf_batch - - def predict_opd(self, input_positions): - """Predict the OPD at some positions. - - Parameters - ---------- - input_positions: Tensor [batch_dim, 2] - Positions to predict the OPD. - - Returns - ------- - opd_maps : Tensor [batch, opd_dim, opd_dim] - OPD at requested positions. - - """ - # Compute zernikes from parametric model and physical layer - zks_coeffs = self.predict_zernikes(input_positions) - # Propagate to obtain the OPD - param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - return opd_maps - - def compute_zernikes(self, input_positions): - """Compute Zernike coefficients at a batch of positions - - This includes the parametric model and the physical layer - - Parameters - ---------- - input_positions: Tensor [batch_dim, 2] - Positions to compute the Zernikes. - - Returns - ------- - zks_coeffs : Tensor [batch, n_zks_total, 1, 1] - Zernikes at requested positions - - """ - # Calculate parametric part - zks_params = self.tf_poly_Z_field(input_positions) - # Calculate the physical layer - zks_prior = self.tf_physical_layer.call(input_positions) - # Pad and sum the zernike coefficients - padded_zk_param, padded_zk_prior = self.zks_pad(zks_params, zks_prior) - zks_coeffs = tf.math.add(padded_zk_param, padded_zk_prior) - - return zks_coeffs - - def predict_zernikes(self, input_positions): - """Predict Zernike coefficients at a batch of positions - - This includes the parametric model and the physical layer. - The prediction of the physical layer to positions is not used - at training time. - - Parameters - ---------- - input_positions: Tensor [batch_dim, 2] - Positions to compute the Zernikes. - - Returns - ------- - zks_coeffs : Tensor [batch, n_zks_total, 1, 1] - Zernikes at requested positions - - """ - # Calculate parametric part - zks_params = self.tf_poly_Z_field(input_positions) - # Calculate the physical layer - zks_prior = self.tf_physical_layer.predict(input_positions) - # Pad and sum the zernike coefficients - padded_zk_param, padded_zk_prior = self.zks_pad(zks_params, zks_prior) - zks_coeffs = tf.math.add(padded_zk_param, padded_zk_prior) - - return zks_coeffs - - def call(self, inputs, training=True): - """Define the PSF field forward model. 
- - [1] From positions to Zernike coefficients - [2] From Zernike coefficients to OPD maps - [3] From OPD maps and SED info to polychromatic PSFs - - OPD: Optical Path Differences - """ - # Unpack inputs - input_positions = inputs[0] - packed_SEDs = inputs[1] - - # For the training - if training: - # Compute zernikes from parametric model and physical layer - zks_coeffs = self.compute_zernikes(input_positions) - # Propagate to obtain the OPD - param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD - self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) - ) - # Calculate the non parametric part - nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - # Compute the polychromatic PSFs - poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - # For the inference - else: - # Compute predictions - poly_psfs = self.predict_step(inputs, evaluate_step=True) - - return poly_psfs - -class TF_GT_physical_field(tf.keras.Model): - """Ground truth PSF field forward model with a physical layer +class TFGroundTruthPhysicalField(tf.keras.Model): + """Ground Truth PSF field forward model with a physical layer Ground truth PSF field used for evaluation purposes. @@ -1009,7 +53,7 @@ def __init__( output_dim=64, name="TF_GT_physical_field", ): - super(TF_GT_physical_field, self).__init__() + super(TFGroundTruthPhysicalField, self).__init__() # Inputs: oversampling used self.output_Q = output_Q @@ -1034,16 +78,16 @@ def __init__( self.output_dim = output_dim # Initialize the physical layer - self.tf_physical_layer = TF_physical_layer( + self.tf_physical_layer = TFPhysicalLayer( self.obs_pos, self.zks_prior, interpolation_type="none", ) # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TF_zernike_OPD(zernike_maps=zernike_maps) + self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=zernike_maps) # Initialize the batch opd to batch polychromatic PSF layer - self.tf_batch_poly_PSF = TF_batch_poly_PSF( + self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, @@ -1059,7 +103,7 @@ def set_output_Q(self, output_Q, output_dim=None): self.output_dim = output_dim # Reinitialize the PSF batch poly generator - self.tf_batch_poly_PSF = TF_batch_poly_PSF( + self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, @@ -1109,7 +153,7 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): """ # Initialise the monochromatic PSF batch calculator - tf_batch_mono_psf = TF_batch_mono_PSF( + tf_batch_mono_psf = TFBatchMonochromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, @@ -1209,35 +253,3 @@ def call(self, inputs, training=True): poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) return poly_psfs - - -def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None): - """Define the model-compilation parameters. - - Specially the loss function, the optimizer and the metrics. 
- """ - # Define model loss function - if loss is None: - loss = tf.keras.losses.MeanSquaredError() - - # Define optimizer function - if optimizer is None: - optimizer = tf.keras.optimizers.Adam( - learning_rate=1e-2, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False - ) - - # Define metric functions - if metrics is None: - metrics = [tf.keras.metrics.MeanSquaredError()] - - # Compile the model - model_inst.compile( - optimizer=optimizer, - loss=loss, - metrics=metrics, - loss_weights=None, - weighted_metrics=None, - run_eagerly=False, - ) - - return model_inst diff --git a/src/wf_psf/sims/psf_simulator.py b/src/wf_psf/sims/psf_simulator.py index 1bebd3ff..e1760098 100644 --- a/src/wf_psf/sims/psf_simulator.py +++ b/src/wf_psf/sims/psf_simulator.py @@ -6,7 +6,7 @@ import matplotlib as mpl from matplotlib.colors import ListedColormap, LinearSegmentedColormap from mpl_toolkits.axes_grid1 import make_axes_locatable -from wf_psf.utils.utils import PI_zernikes, zernike_generator +from wf_psf.utils.utils import zernike_generator try: from cv2 import resize, INTER_AREA @@ -123,7 +123,6 @@ def __init__( self.rand_seed = rand_seed self.plot_opt = plot_opt self.zernike_maps = zernike_generator(self.max_order, self.pupil_diameter) - # self.zernike_maps = zernike_maps self.max_wfe_rms = max_wfe_rms # In [um] self.output_dim = output_dim # In pixels per dimension self.verbose = verbose @@ -155,7 +154,7 @@ def __init__( self.obscurations = np.ones((pupil_diameter, pupil_diameter)) @staticmethod - def _OLD_fft_diffraction_op(wf, pupil_mask, pad_factor=2, match_shapes=True): + def _old_fft_diffraction_op(wf, pupil_mask, pad_factor=2, match_shapes=True): """Perform a fft-based diffraction. Parameters diff --git a/src/wf_psf/sims/spatial_varying_psf.py b/src/wf_psf/sims/spatial_varying_psf.py new file mode 100644 index 00000000..820e8399 --- /dev/null +++ b/src/wf_psf/sims/spatial_varying_psf.py @@ -0,0 +1,808 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib as mpl +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +from mpl_toolkits.axes_grid1 import make_axes_locatable +import logging + +logger = logging.getLogger(__name__) + + +class MeshHelper: + """Mesh Helper. + + A utility class for generating mesh grids. + + """ + + @staticmethod + def build_mesh(x_lims, y_lims, grid_points=None, grid_size=None, endpoint=True): + """Build Mesh. + + A method to build a mesh. + + Parameters + ---------- + x_lims: list + A list representing the lower and upper limits along the x-axis. + y_lims: list + A list representing the lower and upper limits along the y-axis. + grid_points: list or None, optional + List defining the size of each axis grid for constructing the mesh grid. + If provided and `grid_size` is also provided, `grid_size` will + override this parameter. (default is None) + grid_size: int or None, optional + Number of points to generate for each axis of the grid. If None and `grid_points` + is not provided, the default grid size is used. (default is None) + endpoint: bool, optional + If True, `stop` is the last sample. Otherwise, it is not included. (default is True). + + Returns + ------- + tuple + A tuple containing two 2-dimensional arrays for x- and y-coordinate axes. + + """ + if grid_size is None: + if grid_points is None: + raise ValueError( + "At least one of 'grid_points' or 'grid_size' must be provided." 
+ )
+ num_x, num_y = grid_points
+ else:
+ num_x = grid_size
+ num_y = grid_size
+
+ # Choose the anchor points on a regular grid
+ x = np.linspace(x_lims[0], x_lims[1], num=num_x, endpoint=endpoint)
+ y = np.linspace(y_lims[0], y_lims[1], num=num_y, endpoint=endpoint)
+
+ # Build mesh
+ return np.meshgrid(x, y)
+
+
+class CoordinateHelper:
+ """Coordinate Helper.
+
+ A utility class for handling coordinate operations.
+
+ """
+
+ @staticmethod
+ def scale_positions(x, y, x_lims, y_lims):
+ """Scale Positions.
+
+ A method to scale x- and y-positions.
+
+ Parameters
+ ----------
+ x: numpy.ndarray
+ A 1-dimensional numpy ndarray denoting a vector of x positions.
+ y: numpy.ndarray
+ A 1-dimensional numpy ndarray denoting a vector of y positions.
+ x_lims: list
+ A list representing the lower and upper limits along the x-axis.
+ y_lims: list
+ A list representing the lower and upper limits along the y-axis.
+
+ Returns
+ -------
+ scaled_x: numpy.ndarray
+ Scaled x positions.
+ scaled_y: numpy.ndarray
+ Scaled y positions.
+
+ """
+ # Scale positions to the square [-1,1] x [-1,1]
+ scaled_x = (x - x_lims[0]) / (x_lims[1] - x_lims[0])
+ scaled_x = (scaled_x - 0.5) * 2
+ scaled_y = (y - y_lims[0]) / (y_lims[1] - y_lims[0])
+ scaled_y = (scaled_y - 0.5) * 2
+
+ return scaled_x, scaled_y
+
+ @staticmethod
+ def calculate_shift(x_lims, y_lims, grid_points):
+ """Calculate Shift.
+
+ A method to calculate the step size for shifting positions
+ based on the specified coordinate limits and grid points.
+
+ Parameters
+ ----------
+ x_lims: list
+ A list representing the lower and upper limits along the x-axis.
+ y_lims: list
+ A list representing the lower and upper limits along the y-axis.
+ grid_points: list
+ List defining the size of each axis grid.
+
+ Returns
+ -------
+ x_step: float
+ The step size along the x-axis.
+ y_step: float
+ The step size along the y-axis.
+ """
+ x_step = (x_lims[1] - x_lims[0]) / grid_points[0]
+ y_step = (y_lims[1] - y_lims[0]) / grid_points[1]
+
+ return x_step, y_step
+
+ @staticmethod
+ def add_random_shift_to_positions(
+ xv_grid, yv_grid, grid_points, x_lims, y_lims, seed=None
+ ):
+ """Add Random Shift to Positions.
+
+ Add random shifts to positions within each grid cell,
+ ensuring no overlap between neighboring cells.
+
+ Parameters
+ ----------
+ xv_grid: numpy.ndarray
+ Grid of x-coordinates.
+ yv_grid: numpy.ndarray
+ Grid of y-coordinates.
+ grid_points: list
+ A list defining the size of each axis grid.
+ x_lims: list
+ Lower and upper limits along the x-axis.
+ y_lims: list
+ Lower and upper limits along the y-axis.
+ seed: int
+ Seed for random number generation.
+
+ Returns
+ -------
+ xv_s: numpy.ndarray
+ Positions with added random shifts along the x-axis.
+ yv_s: numpy.ndarray
+ Positions with added random shifts along the y-axis.
+
+ """
+ ## Random position shift
+ # The random shift is drawn inside a square centred on each
+ # grid position, so that the squares of neighbouring positions
+ # do not overlap.
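+ # Added commentary (illustrative, hedged): with x_lims = y_lims = [0, 1000]
+ # and grid_points = [4, 4], calculate_shift returns x_step = y_step = 250.0,
+ # so each uniform draw in [0, 1) below is mapped to [-125, 125) around its
+ # anchor point. Every shifted position stays inside a 250 x 250 square
+ # centred on its anchor, and neighbouring squares cannot overlap.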
+ np.random.seed(seed)
+ xv_rand = np.random.rand(grid_points[0], grid_points[1])
+ yv_rand = np.random.rand(grid_points[0], grid_points[1])
+ # Calculate the shift length
+ x_step, y_step = CoordinateHelper.calculate_shift(x_lims, y_lims, grid_points)
+ # Center and scale shifts
+ xv_rand = (xv_rand - 0.5) * x_step
+ yv_rand = (yv_rand - 0.5) * y_step
+ # Add the shift to the grid values
+ xv = xv_grid + xv_rand.T
+ yv = yv_grid + yv_rand.T
+
+ xv_s, yv_s = CoordinateHelper.check_and_adjust_coordinate_limits(
+ xv.flatten(), yv.flatten(), x_lims, y_lims
+ )
+
+ return xv_s, yv_s
+
+ @staticmethod
+ def check_and_adjust_coordinate_limits(x, y, x_lims, y_lims):
+ """Check and adjust coordinate limits.
+
+ A method to clip position coordinates so that they lie within
+ the ranges of x_lims and y_lims, respectively.
+
+ Parameters
+ ----------
+ x: numpy.ndarray
+ A 1-dimensional numpy ndarray containing positions along the x-axis.
+ y: numpy.ndarray
+ A 1-dimensional numpy ndarray containing positions along the y-axis.
+ x_lims: list
+ Lower and upper limits along the x-axis.
+ y_lims: list
+ Lower and upper limits along the y-axis.
+
+ Returns
+ -------
+ x: numpy.ndarray
+ A numpy.ndarray containing adjusted positions along the x-axis within the specified limits.
+ y: numpy.ndarray
+ A numpy.ndarray containing adjusted positions along the y-axis within the specified limits.
+
+ """
+ x[x > x_lims[1]] = x_lims[1]
+ x[x < x_lims[0]] = x_lims[0]
+ y[y > y_lims[1]] = y_lims[1]
+ y[y < y_lims[0]] = y_lims[0]
+
+ return x, y
+
+ @staticmethod
+ def check_position_coordinate_limits(xv, yv, x_lims, y_lims, verbose):
+ """Check Position Coordinate Limits.
+
+ This method checks whether the given position coordinates (xv, yv) are within the specified
+ limits (x_lims, y_lims). It logs a warning if any coordinate is outside the limits.
+
+ Parameters
+ ----------
+ xv: numpy.ndarray
+ The x coordinates to be checked.
+ yv: numpy.ndarray
+ The y coordinates to be checked.
+ x_lims: tuple
+ A tuple (min, max) specifying the lower and upper limits for x coordinates.
+ y_lims: tuple
+ A tuple (min, max) specifying the lower and upper limits for y coordinates.
+ verbose: bool
+ If True, log warning messages when coordinates are outside the limits.
+
+ Returns
+ -------
+ None
+
+ """
+ x_check = np.sum(xv >= x_lims[1] * 1.1) + np.sum(xv <= x_lims[0] * 1.1)
+ y_check = np.sum(yv >= y_lims[1] * 1.1) + np.sum(yv <= y_lims[0] * 1.1)
+
+ if verbose and x_check > 0:
+ logger.warning(
+ "WARNING! x value is outside the limits [%f, %f]"
+ % (x_lims[0], x_lims[1])
+ )
+
+ if verbose and y_check > 0:
+ logger.warning(
+ "WARNING! y value is outside the limits [%f, %f]"
+ % (y_lims[0], y_lims[1])
+ )
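+
+
+# Hedged usage sketch for PolynomialMatrixHelper (defined below); values are
+# illustrative only. For d_max=2 the monomials are evaluated in the order
+# [1, x, y, x**2, x*y, y**2] on positions scaled to [-1, 1]:
+#     Pi = PolynomialMatrixHelper.generate_polynomial_matrix(
+#         np.array([0.0, 500.0]), np.array([0.0, 500.0]), [0, 1e3], [0, 1e3], 2
+#     )
+# Pi has shape (6, 2): (2 + 1) * (2 + 2) // 2 = 6 monomial rows, one column
+# per input position.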
+
+
+class PolynomialMatrixHelper:
+ """PolynomialMatrixHelper.
+
+ Helper class with methods for generating polynomial matrices of positions.
+
+ """
+
+ @staticmethod
+ def generate_polynomial_matrix(x, y, x_lims, y_lims, d_max):
+ """Generate polynomial matrix of positions.
+
+ This method constructs a polynomial matrix representing spatial variations
+ in a two-dimensional field. The polynomial matrix is generated based on the
+ given x and y positions, considering a maximum polynomial degree specified
+ by d_max.
+
+ Parameters
+ ----------
+ x: numpy.ndarray
+ A 1-dimensional numpy ndarray denoting a vector of x positions.
+ y: numpy.ndarray
+ A 1-dimensional numpy ndarray denoting a vector of y positions.
+ x_lims: list
+ Lower and upper limits along the x-axis.
+ y_lims: list
+ Lower and upper limits along the y-axis.
+ d_max: int
+ The maximum polynomial degree for the spatial variation of the field.
+
+ Returns
+ -------
+ Pi: numpy.ndarray
+ A 2-dimensional polynomial matrix representing the spatial variations.
+ """
+ n_mono = (d_max + 1) * (d_max + 2) // 2
+ if np.isscalar(x):
+ Pi = np.zeros((n_mono, 1))
+ else:
+ Pi = np.zeros((n_mono, x.shape[0]))
+
+ # Scale positions to the square [-1,1] x [-1,1]
+ scaled_x, scaled_y = CoordinateHelper.scale_positions(x, y, x_lims, y_lims)
+
+ for d in range(d_max + 1):
+ row_idx = d * (d + 1) // 2
+ for p in range(d + 1):
+ Pi[row_idx + p, :] = scaled_x ** (d - p) * scaled_y**p
+
+ return Pi
+
+
+class ZernikeHelper:
+ """ZernikeHelper.
+
+ Helper class for generating Zernike polynomials.
+
+ """
+
+ @staticmethod
+ def initialize_Z_matrix(max_order, size, seed=None):
+ """Initialize Zernike Matrix.
+
+ This method initializes a Zernike matrix whose size is determined by
+ the maximum order of Zernike polynomials and the length of the position vector
+ along the x-coordinate axis. The matrix is populated with random values sampled
+ from a standard normal distribution.
+
+ Parameters
+ ----------
+ max_order: int
+ The maximum order of Zernike polynomials to be used in the simulation.
+ size: int
+ An integer representing the size of the position vector.
+ seed: int
+ Seed for random number generation.
+
+ Returns
+ -------
+ numpy.ndarray
+ An array of shape (max_order, size) containing randomly generated values
+ from a standard normal distribution to initialize the Zernike matrix.
+
+ """
+ np.random.seed(seed)
+ return np.random.randn(max_order, size)
+
+ @staticmethod
+ def normalize_Z_matrix(Z, lim_max_wfe_rms):
+ """Normalize Zernike Matrix.
+
+ This method performs normalization on the Zernike matrix. It calculates
+ normalization weights as the square root of the sum of squares of the
+ Zernike matrix along the second axis. Each row of the matrix is then
+ divided by its corresponding normalization weight, scaled by the maximum
+ allowed Wave Front Error (WFE) Root-Mean-Square (RMS) error.
+
+ Parameters
+ ----------
+ Z: numpy.ndarray
+ A numpy ndarray representing the Zernike matrix.
+ lim_max_wfe_rms: float
+ The maximum allowed Wave Front Error (WFE) Root-Mean-Square (RMS) value.
+
+ Returns
+ -------
+ Z: numpy.ndarray
+ The normalized Zernike matrix after applying the normalization process.
+
+ """
+ norm_weights = np.sqrt(np.sum(Z**2, axis=1))
+ Z /= norm_weights.reshape((-1, 1)) / lim_max_wfe_rms
+ return Z
+
+ @staticmethod
+ def initialize_normalized_zernike_matrix(
+ max_order, size, lim_max_wfe_rms, seed=None
+ ):
+ """Initialize Normalized Zernike Matrix.
+
+ This method initializes a normalized Zernike matrix.
+
+ Parameters
+ ----------
+ max_order: int
+ The maximum order of Zernike polynomials to be used in the simulation.
+ size: int
+ An integer representing the size of the position vector.
+ lim_max_wfe_rms: float
+ The maximum allowed Wave Front Error (WFE) Root-Mean-Square (RMS) value.
+ seed: int
+ Seed for random number generation.
+
+ Returns
+ -------
+ numpy.ndarray
+ A normalized Zernike matrix.
+
+ """
+ return ZernikeHelper.normalize_Z_matrix(
+ ZernikeHelper.initialize_Z_matrix(max_order, size, seed), lim_max_wfe_rms
+ )
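+
+ # Added commentary (hedged): normalize_Z_matrix divides each row z_k of Z by
+ # ||z_k||_2 / lim_max_wfe_rms, so after initialize_normalized_zernike_matrix
+ # every Zernike mode satisfies np.sqrt(np.sum(z_k**2)) == lim_max_wfe_rms,
+ # i.e. all modes start from the same WFE RMS budget before SpatialVaryingPSF
+ # rescales them.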
+
+ @staticmethod
+ def generate_zernike_polynomials(xv, yv, x_lims, y_lims, d_max, polynomial_coeffs):
+ """Generate Zernike Polynomials.
+
+ This method (formerly named ``zernike_poly_gen``) calculates Zernike
+ polynomial coefficients based on the given x and y positions,
+ considering a maximum polynomial degree specified by d_max and a set of polynomial
+ coefficients.
+
+ Parameters
+ ----------
+ xv: np.ndarray (dim,)
+ x positions.
+ yv: np.ndarray (dim,)
+ y positions.
+ x_lims: list
+ Lower and upper limits along the x-axis.
+ y_lims: list
+ Lower and upper limits along the y-axis.
+ d_max: int
+ The maximum polynomial degree for the spatial variation of the field.
+ polynomial_coeffs: numpy.ndarray
+ An array containing the polynomial coefficients.
+
+ Returns
+ -------
+ numpy.ndarray
+ A 2-dimensional numpy ndarray of shape (max_order, n_positions) representing
+ the spatial polynomials evaluated at the given positions and weighted by the
+ polynomial coefficients.
+
+ """
+ # Generate the polynomial matrix
+ Pi_samples = PolynomialMatrixHelper.generate_polynomial_matrix(
+ xv, yv, x_lims, y_lims, d_max
+ )
+
+ return polynomial_coeffs @ Pi_samples
+
+ @staticmethod
+ def calculate_zernike(
+ xv, yv, x_lims, y_lims, d_max, polynomial_coeffs, verbose=False
+ ):
+ """Calculate Zernikes for specific positions.
+
+ This method computes Zernike polynomials for given positions (xv, yv).
+ The positions (xv, yv) should lie within the specified limits:
+ [x_lims[0], x_lims[1]] along the x-axis and [y_lims[0], y_lims[1]] along the y-axis.
+ The positions are rescaled internally to the range [-1, +1] along both axes
+ before the polynomials are evaluated.
+
+ Parameters
+ ----------
+ xv: numpy.ndarray
+ Array containing positions along the x-axis.
+ yv: numpy.ndarray
+ Array containing positions along the y-axis.
+ x_lims: list
+ Lower and upper limits along the x-axis.
+ y_lims: list
+ Lower and upper limits along the y-axis.
+ d_max: int
+ The maximum polynomial degree for the spatial variation of the field.
+ polynomial_coeffs: numpy.ndarray
+ An array containing the polynomial coefficients.
+ verbose: bool
+ Flag to indicate whether to log warning messages when positions are outside the specified limits.
+
+ Returns
+ -------
+ numpy.ndarray
+ Array containing the computed Zernike polynomials for the given positions.
+
+ """
+ # Check limits
+ CoordinateHelper.check_position_coordinate_limits(
+ xv, yv, x_lims, y_lims, verbose
+ )
+
+ # Return Zernikes
+ # The position scaling is done inside generate_zernike_polynomials
+ return ZernikeHelper.generate_zernike_polynomials(
+ xv, yv, x_lims, y_lims, d_max, polynomial_coeffs
+ )
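+
+
+# Hedged end-to-end sketch of the helpers above (illustrative values only):
+#     xv = np.array([100.0, 900.0])
+#     yv = np.array([200.0, 800.0])
+#     Z = ZernikeHelper.initialize_normalized_zernike_matrix(45, xv.size, 0.1, seed=0)
+#     Pi = PolynomialMatrixHelper.generate_polynomial_matrix(xv, yv, [0, 1e3], [0, 1e3], 2)
+#     coeffs = Z @ np.linalg.pinv(Pi)  # least-squares polynomial fit
+#     zks = ZernikeHelper.calculate_zernike(xv, yv, [0, 1e3], [0, 1e3], 2, coeffs)
+# zks has shape (45, 2): one 45-coefficient Zernike vector per input position.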
+
+
+class SpatialVaryingPSF(object):
+ """Spatial Varying PSF.
+
+ Generate a PSF field with polynomial variations of the Zernike coefficients.
+
+ Parameters
+ ----------
+ psf_simulator: PSFSimulator object
+ Class instance of the PSFSimulator.
+ d_max: int
+ Integer representing the maximum polynomial degree for the FOV spatial variation of the WFE.
+ grid_points: list
+ List defining the size of each axis grid for constructing the (constrained random realisation) polynomial coefficient matrix.
+ grid_size: int or None, optional
+ Number of points to generate for the grid. If None, the value from the
+ grid_points attribute will be used. (default is None)
+ max_order: int
+ The maximum order of Zernike polynomials to be used in the simulation.
+ x_lims: list
+ A list representing the lower and upper limits along the x-axis.
+ y_lims: list
+ A list representing the lower and upper limits along the y-axis.
+ n_bins: int
+ An integer representing the number of equidistant bins used to partition the passband when computing polychromatic PSFs.
+ lim_max_wfe_rms: float
+ The maximum allowed Wave Front Error (WFE) Root-Mean-Square (RMS) value.
+ If None, the PSF simulator's maximum WFE RMS is used.
+ verbose: bool
+ A flag to determine whether to log warning messages.
+ seed: int, optional
+ Seed for random number generation. (default is None)
+
+ """
+
+ def __init__(
+ self,
+ psf_simulator,
+ d_max=2,
+ grid_points=[4, 4],
+ grid_size=None,
+ max_order=45,
+ x_lims=[0, 1e3],
+ y_lims=[0, 1e3],
+ n_bins=35,
+ lim_max_wfe_rms=None,
+ verbose=False,
+ seed=None,
+ ):
+ # Input attributes
+ self.psf_simulator = psf_simulator
+ self.max_order = max_order
+ self.d_max = d_max
+ self.x_lims = x_lims
+ self.y_lims = y_lims
+ self.grid_points = grid_points
+ self.grid_size = grid_size
+ self.n_bins = n_bins
+ self.verbose = verbose
+ self.seed = seed
+ self._lim_max_wfe_rms = lim_max_wfe_rms
+
+ # Class attributes
+ self.polynomial_coeffs = None
+ self.WFE_RMS = None
+
+ # Build coefficient polynomial matrix
+ self.build_polynomial_coeffs()
+
+ @property
+ def lim_max_wfe_rms(self):
+ """Get the upper limit for Wave Front Error (WFE) Root-Mean-Square (RMS).
+
+ If the custom upper limit `lim_max_wfe_rms` is not set, this property returns
+ the maximum WFE RMS value from the PSF simulator. Otherwise, it returns the
+ custom upper limit.
+
+ Returns
+ -------
+ float
+ The upper limit for Wave Front Error (WFE) Root-Mean-Square (RMS).
+ """
+ if self._lim_max_wfe_rms is None:
+ return self.psf_simulator.max_wfe_rms
+ else:
+ return self._lim_max_wfe_rms
+
+ @lim_max_wfe_rms.setter
+ def lim_max_wfe_rms(self, value):
+ """Set the upper limit for the Wave Front Error (WFE) Root-Mean-Square (RMS).
+
+ This setter method allows you to specify a custom upper limit for the
+ Wave Front Error (WFE) Root-Mean-Square (RMS). Once set, this custom limit
+ will be used instead of the default limit from the PSF simulator.
+
+ Parameters
+ ----------
+ value: float
+ The new upper limit value to be set.
+ """
+ self._lim_max_wfe_rms = value
+
+ def estimate_polynomial_coeffs(self, xv, yv, Z):
+ """Estimate polynomial coefficients using least squares.
+
+ This method estimates the polynomial coefficients using the least-squares
+ method based on the provided positions along the x and y axes.
+
+ Parameters
+ ----------
+ xv: numpy.ndarray
+ A 1-dimensional numpy ndarray representing positions along the x-axis.
+ yv: numpy.ndarray
+ A 1-dimensional numpy ndarray representing positions along the y-axis.
+ Z: numpy.ndarray
+ A 2-dimensional numpy ndarray of shape (max_order, n_positions) representing the Zernike coefficients.
+
+ Returns
+ -------
+ numpy.ndarray
+ The estimated polynomial coefficient matrix of shape (max_order, n_monomials).
+
+ """
+ Pi = PolynomialMatrixHelper.generate_polynomial_matrix(
+ xv, yv, self.x_lims, self.y_lims, self.d_max
+ )
+
+ return Z @ np.linalg.pinv(Pi)
+
+ def calculate_wfe_rms(self, xv, yv, polynomial_coeffs):
+ """Calculate the Wave Front Error (WFE) Root-Mean-Square (RMS).
+
+ This method calculates the WFE RMS for specific positions using the provided
+ x and y coordinates and polynomial coefficients.
+
+ Parameters
+ ----------
+ xv: numpy.ndarray
+ A 1-dimensional numpy ndarray representing positions along the x-axis.
+ yv: numpy.ndarray
+ A 1-dimensional numpy ndarray representing positions along the y-axis.
+ polynomial_coeffs: numpy.ndarray
+ A numpy ndarray containing the polynomial coefficients.
+
+ Returns
+ -------
+ numpy.ndarray
+ An array containing the WFE RMS values for the provided positions.
+ """
+
+ Z = ZernikeHelper.generate_zernike_polynomials(
+ xv, yv, self.x_lims, self.y_lims, self.d_max, polynomial_coeffs
+ )
+ return np.sqrt(np.sum(Z**2, axis=0))
+
+ def build_polynomial_coeffs(self):
+ """Build polynomial coefficients for spatial variation.
+
+ This method constructs polynomial coefficients for spatial variation by following these steps:
+ 1. Build a mesh based on the provided limits and grid points.
+ 2. 
Apply random position shifts to the mesh coordinates. + 3. Estimate polynomial coefficients using the shifted positions. + 4. Choose anchor points on a regular grid and calculate the Wave Front Error (WFE) Root-Mean-Square (RMS) + on this new grid. + 5. Scale the polynomial coefficients to ensure that the mean WFE RMS over the field of view is 80% of the + maximum allowed WFE RMS per position. + 6. Recalculate the Zernike coefficients using the scaled polynomial coefficients. + 7. Calculate and save the WFE RMS map of the polynomial coefficient values. + + Returns + ------- + None + """ + + # Build mesh + xv_grid, yv_grid = MeshHelper.build_mesh( + self.x_lims, self.y_lims, self.grid_points + ) + + # Apply random position shifts + xv, yv = CoordinateHelper.add_random_shift_to_positions( + xv_grid, yv_grid, self.grid_points, self.x_lims, self.y_lims, self.seed + ) + + # Generate normalized Z matrix + Z = ZernikeHelper.initialize_normalized_zernike_matrix( + self.max_order, len(xv), self.lim_max_wfe_rms, self.seed + ) + + # Generate Polynomial coefficients for each Zernike + self.polynomial_coeffs = self.estimate_polynomial_coeffs(xv, yv, Z) + + ## Sampling the space + # Choose the anchor points on a regular grid + xv_grid, yv_grid = MeshHelper.build_mesh( + self.x_lims, self.y_lims, self.grid_points, self.grid_size, endpoint=True + ) + + ## Renormalize and check that the WFE RMS has a max value near the expected one + # Calculate the WFE_RMS on the new grid + xv = xv_grid.flatten() + yv = yv_grid.flatten() + + calc_wfe = self.calculate_wfe_rms(xv, yv, self.polynomial_coeffs) + + # Due to the polynomial behaviour, set the mean WFE_RMS over the field of view to be 80% of + # the maximum allowed WFE_RMS per position. + scale_factor = (0.8 * self.lim_max_wfe_rms) / np.mean(calc_wfe) + self.polynomial_coeffs *= scale_factor + + # Scale the Z coefficients + scaled_Z_estimate = ZernikeHelper.generate_zernike_polynomials( + xv, yv, self.x_lims, self.y_lims, self.d_max, self.polynomial_coeffs + ) + + # Calculate and save the WFE_RMS map of the polynomial coefficients values. + self.WFE_RMS = self.calculate_wfe_rms(xv, yv, self.polynomial_coeffs).reshape( + xv_grid.shape + ) + + def plot_WFE_RMS(self, save_img=False, save_name="WFE_field_meshdim"): + """Plot the Wave Front Error (WFE) Root-Mean-Square (RMS) map. + + This method generates a plot of the WFE RMS map for the Point Spread Function (PSF) field. The plot + visualizes the distribution of WFE RMS values across the field of view. + + Parameters + ---------- + save_img: bool, optional + Flag indicating whether to save the plot as an image file. Default is False. + save_name: str, optional + Name of the image file to save. Default is 'WFE_field_meshdim'. + + Returns + ------- + None + """ + fig = plt.figure(figsize=(8, 8)) + ax1 = fig.add_subplot(111) + im1 = ax1.imshow(self.WFE_RMS, interpolation="None") + divider = make_axes_locatable(ax1) + cax = divider.append_axes("right", size="5%", pad=0.05) + fig.colorbar(im1, cax=cax, orientation="vertical") + ax1.set_title("PSF field WFE RMS [um]") + ax1.set_xlabel("x axis") + ax1.set_ylabel("y axis") + if save_img: + plt.savefig("./" + save_name + ".pdf", bbox_inches="tight") + plt.show() + + def get_monochromatic_PSF(self, xv, yv, lambda_obs=0.725): + """Calculate the monochromatic Point Spread Function (PSF) at a specific position and wavelength. + + This method calculates the monochromatic PSF for a given position and wavelength. 
It utilizes the + Zernike coefficients of the specific field to generate the PSF using the PSF toolkit generator. + + Parameters + ---------- + xv: numpy.ndarray + 1-dimensional numpy array containing the x positions. + yv: numpy.ndarray + 1-dimensional numpy array containing the y positions. + lambda_obs: float, optional + Wavelength of observation for which the PSF is calculated. Default is 0.725 micrometers. + + Returns + ------- + numpy.ndarray + The generated monochromatic PSF. + + Notes + ----- + The PSF generator's Zernike coefficients are set based on the provided positions before generating the PSF. + + """ + # Calculate the specific field's zernike coeffs + zernikes = ZernikeHelper.calculate_zernike( + xv, yv, self.x_lims, self.y_lims, self.d_max, self.polynomial_coeffs + ) + + # Set the Z coefficients to the PSF toolkit generator + self.psf_simulator.set_z_coeffs(zernikes) + # Generate the monochromatic psf + self.psf_simulator.generate_mono_PSF(lambda_obs=lambda_obs, regen_sample=False) + # Return the generated PSF + return self.psf_simulator.get_psf() + + def get_polychromatic_PSF(self, xv, yv, SED): + """Calculate the polychromatic Point Spread Function (PSF) for a specific position and Spectral Energy Distribution (SED). + + This method calculates the polychromatic PSF for a given position and SED. It utilizes the Zernike coefficients + of the specific field to generate the PSF using the PSF Simulator generator. + + Parameters + ---------- + xv: numpy.ndarray + 1-dimensional numpy array containing the x positions. + yv: numpy.ndarray + 1-dimensional numpy array containing the y positions. + SED: array_like + Spectral Energy Distribution (SED) describing the relative intensity of light at different wavelengths. + + Returns + ------- + tuple + A tuple containing: + - polychromatic_psf : numpy.ndarray + The generated polychromatic PSF. + - zernikes : numpy.ndarray + The Zernike coefficients corresponding to the specific field. + - opd : numpy.ndarray + The Optical Path Difference (OPD) corresponding to the generated PSF. + + Notes + ----- + The PSF generator's Zernike coefficients are set based on the provided positions before generating the PSF. + The SED is used to compute the polychromatic PSF by integrating the monochromatic PSFs over the spectral range. 
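+
+ Examples
+ --------
+ A minimal, hedged sketch; ``psf_field`` is an assumed, already-built
+ SpatialVaryingPSF instance and ``sed`` an SED array accepted by its
+ PSF simulator:
+
+ >>> psf, zks, opd = psf_field.get_polychromatic_PSF(
+ ... np.array([100.0]), np.array([200.0]), sed
+ ... )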
+ + """ + # Calculate the specific field's zernike coeffs + zernikes = ZernikeHelper.calculate_zernike( + xv, yv, self.x_lims, self.y_lims, self.d_max, self.polynomial_coeffs + ) + + # Set the Z coefficients to the PSF Simulator generator + self.psf_simulator.set_z_coeffs(zernikes) + polychromatic_psf = self.psf_simulator.generate_poly_PSF( + SED, n_bins=self.n_bins + ) + opd = self.psf_simulator.opd + + return polychromatic_psf, zernikes, opd diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 7b8dc462..7c68bb95 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import TrainingDataHandler, TestDataHandler +from wf_psf.data.training_preprocessing import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", @@ -66,48 +66,50 @@ ), ) -data = RecursiveNamespace( - training=RecursiveNamespace( - data_dir="data", - file="coherent_euclid_dataset/train_Euclid_res_200_TrainStars_id_001.npy", - stars=None, - positions=None, - SEDS=None, - zernike_coef=None, - C_poly=None, - params=RecursiveNamespace( - d_max=2, - max_order=45, - x_lims=[0, 1000.0], - y_lims=[0, 1000.0], - grid_points=[4, 4], - n_bins=20, - max_wfe_rms=0.1, - oversampling_rate=3.0, - output_Q=3.0, - output_dim=32, - LP_filter_length=2, - pupil_diameter=256, - euclid_obsc=True, - n_stars=200, +data_conf = RecursiveNamespace( + data=RecursiveNamespace( + training=RecursiveNamespace( + data_dir="data", + file="coherent_euclid_dataset/train_Euclid_res_200_TrainStars_id_001.npy", + stars=None, + positions=None, + SEDS=None, + zernike_coef=None, + C_poly=None, + params=RecursiveNamespace( + d_max=2, + max_order=45, + x_lims=[0, 1000.0], + y_lims=[0, 1000.0], + grid_points=[4, 4], + n_bins=20, + max_wfe_rms=0.1, + oversampling_rate=3.0, + output_Q=3.0, + output_dim=32, + LP_filter_length=2, + pupil_diameter=256, + euclid_obsc=True, + n_stars=200, + ), ), - ), - test=RecursiveNamespace( - data_dir="data", - file="coherent_euclid_dataset/test_Euclid_res_id_001.npy", - stars=None, - noisy_stars=None, - positions=None, - SEDS=None, - zernike_coef=None, - C_poly=None, - parameters=RecursiveNamespace( - d_max=2, - max_order=45, - x_lims=[0, 1000.0], - y_lims=[0, 1000.0], - grid_points=[4, 4], - max_wfe_rms=0.1, + test=RecursiveNamespace( + data_dir="data", + file="coherent_euclid_dataset/test_Euclid_res_id_001.npy", + stars=None, + noisy_stars=None, + positions=None, + SEDS=None, + zernike_coef=None, + C_poly=None, + parameters=RecursiveNamespace( + d_max=2, + max_order=45, + x_lims=[0, 1000.0], + y_lims=[0, 1000.0], + grid_points=[4, 4], + max_wfe_rms=0.1, + ), ), ), ) @@ -120,8 +122,9 @@ def training_params(): @pytest.fixture(scope="module") def training_data(): - return TrainingDataHandler( - data.training, + return DataHandler( + "training", + data_conf.data, psf_models.simPSF(training_config.model_params), training_config.model_params.n_bins_lda, ) @@ -129,8 +132,9 @@ def training_data(): @pytest.fixture(scope="module") def test_data(): - return TestDataHandler( - data.test, + return DataHandler( + "test", + data_conf.data, psf_models.simPSF(training_config.model_params), training_config.model_params.n_bins_lda, ) @@ -138,7 +142,7 @@ def test_data(): @pytest.fixture(scope="module") def test_dataset(test_data): - return test_data.test_dataset + return 
test_data.dataset @pytest.fixture(scope="module") diff --git a/src/wf_psf/tests/data/validation/main_random_seed/config/data_config.yaml b/src/wf_psf/tests/data/validation/main_random_seed/config/data_config.yaml index 115abf0e..956ad179 100644 --- a/src/wf_psf/tests/data/validation/main_random_seed/config/data_config.yaml +++ b/src/wf_psf/tests/data/validation/main_random_seed/config/data_config.yaml @@ -26,7 +26,7 @@ data: euclid_obsc: true n_stars: 200 test: - data_dir: data/coherent_euclid_dataset/ + data_dir: wf-psf/data/coherent_euclid_dataset/ file: test_Euclid_res_id_001.npy # If test dataset file not provided produce a new one stars: null diff --git a/src/wf_psf/tests/psf_models_test.py b/src/wf_psf/tests/psf_models_test.py deleted file mode 100644 index fbbb1191..00000000 --- a/src/wf_psf/tests/psf_models_test.py +++ /dev/null @@ -1,23 +0,0 @@ -"""UNIT TESTS FOR PACKAGE MODULE: PSF MODELS. - -This module contains unit tests for the wf_psf.psf_models psf_models module. - -:Author: Jennifer Pollack - - -""" - -import pytest -from wf_psf.psf_models import psf_models -from wf_psf.utils.io import FileIOHandler -import os - - -def test_get_psf_model_weights_filepath(): - weights_filepath = "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - - ans = psf_models.get_psf_model_weights_filepath(weights_filepath) - assert ( - ans - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint_callback_poly_sample_w_bis1_2k_cycle2" - ) diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py new file mode 100644 index 00000000..03e6bc6d --- /dev/null +++ b/src/wf_psf/tests/test_data/conftest.py @@ -0,0 +1,87 @@ +"""FIXTURES FOR GENERATING TESTS FOR WF-PSF MODULES: CONFTEST. + +This module contains fixtures to use in unit tests for +various wf_psf packages. 
+ +:Author: Jennifer Pollack + + +""" +import pytest +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.training.train import TrainingParamsHandler +from wf_psf.psf_models import psf_models +from wf_psf.data.training_preprocessing import DataHandler + +training_config = RecursiveNamespace( + id_name="_sample_w_bis1_2k", + data_config="data_config.yaml", + metrics_config="metrics_config.yaml", + model_params=RecursiveNamespace( + model_name="poly", + n_bins_lda=8, + output_Q=3, + oversampling_rate=3, + output_dim=32, + pupil_diameter=256, + use_sample_weights=True, + interpolation_type="None", + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + param_hparams=RecursiveNamespace( + random_seed=3877572, + l2_param=0.0, + n_zernikes=15, + d_max=2, + save_optim_history_param=True, + ), + nonparam_hparams=RecursiveNamespace( + d_max_nonparam=5, + num_graph_features=10, + l1_rate=1e-08, + project_dd_features=False, + reset_dd_features=False, + save_optim_history_nonparam=True, + ), + ), + training_hparams=RecursiveNamespace( + n_epochs_params=[2, 2], + n_epochs_non_params=[2, 2], + batch_size=32, + multi_cycle_params=RecursiveNamespace( + total_cycles=2, + cycle_def="complete", + save_all_cycles=True, + saved_cycle="cycle2", + learning_rate_params=[1.0e-2, 1.0e-2], + learning_rate_non_params=[1.0e-1, 1.0e-1], + n_epochs_params=[2, 2], + n_epochs_non_params=[2, 2], + ), + ), +) + +data = RecursiveNamespace( + train=RecursiveNamespace( + data_dir="data", + file="coherent_euclid_dataset/train_Euclid_res_200_TrainStars_id_001.npy", + ), + test=RecursiveNamespace( + data_dir="data", + file="coherent_euclid_dataset/test_Euclid_res_id_001.npy", + ), +) + + +@pytest.fixture(scope="module", params=[data]) +def data_params(): + return data + + +@pytest.fixture(scope="module", params=[training_config]) +def simPSF(): + return psf_models.simPSF(training_config.model_params) diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/training_preprocessing_test.py new file mode 100644 index 00000000..e60efed1 --- /dev/null +++ b/src/wf_psf/tests/test_data/training_preprocessing_test.py @@ -0,0 +1,192 @@ +import pytest +import os +import numpy as np +import tensorflow as tf +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.data.training_preprocessing import ( + DataHandler, + get_obs_positions, + get_zernike_prior, +) +from wf_psf.psf_models import psf_models + + +def test_initialize_load_dataset(data_params, simPSF): + # Test loading dataset without initialization + data_handler = DataHandler( + "train", data_params, simPSF, n_bins_lambda=10, init_flag=False + ) + assert data_handler.dataset is None # Dataset should not be loaded + + # Test loading dataset with initialization + data_handler = DataHandler( + "train", data_params, simPSF, n_bins_lambda=10, init_flag=True + ) + assert data_handler.dataset is not None # Dataset should be loaded + + +def test_initialize_process_sed_data(data_params, simPSF): + # Test processing SED data without initialization + data_handler = DataHandler( + "train", data_params, simPSF, n_bins_lambda=10, init_flag=False + ) + assert data_handler.sed_data is None # SED data should not be processed + + # Test processing SED data with initialization + data_handler = DataHandler( + "train", data_params, simPSF, n_bins_lambda=10, init_flag=True + ) + assert data_handler.sed_data is not None # SED data should be 
processed + + +def test_load_train_dataset(tmp_path, data_params, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "train_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "noisy_stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace( + train=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + ) + + n_bins_lambda = 10 + data_handler = DataHandler("train", data_params, simPSF, n_bins_lambda, False) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal( + data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] + ) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_load_test_dataset(tmp_path, data_params, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "test_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace( + test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + ) + + n_bins_lambda = 10 + data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, False) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_process_sed_data(data_params, simPSF): + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "noisy_stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + # Initialize DataHandler instance + n_bins_lambda = 4 + data_handler = DataHandler("train", data_params, simPSF, n_bins_lambda, False) + + data_handler.dataset = mock_dataset + data_handler.process_sed_data() + # Assertions + assert isinstance(data_handler.sed_data, tf.Tensor) + assert data_handler.sed_data.dtype == tf.float32 + assert data_handler.sed_data.shape == ( + len(data_handler.dataset["positions"]), + n_bins_lambda, + len(["feasible_N", "feasible_wv", "SED_norm"]), + ) + + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + ): + self.training_data = MockDataset(training_positions, training_zernike_priors) + self.test_data = MockDataset(test_positions, test_zernike_priors) + + +class MockDataset: + def __init__(self, positions, zernike_priors): + self.dataset = {"positions": positions, "zernike_prior": zernike_priors} + + +@pytest.fixture +def mock_data(): + training_positions = np.array([[1, 2], [3, 4]]) + test_positions = np.array([[5, 6], [7, 8]]) + training_zernike_priors = np.array([[0.1, 0.2], 
[0.3, 0.4]]) + test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) + return MockData( + training_positions, test_positions, training_zernike_priors, test_zernike_priors + ) + + +def test_get_obs_positions(mock_data): + observed_positions = get_obs_positions(mock_data) + expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) + + +def test_get_zernike_prior(mock_data): + zernike_priors = get_zernike_prior(mock_data) + expected_shape = ( + 4, + 2, + ) # Assuming 2 Zernike priors for each dataset (training and test) + assert zernike_priors.shape == expected_shape + + +def test_get_zernike_prior_dtype(mock_data): + zernike_priors = get_zernike_prior(mock_data) + assert zernike_priors.dtype == np.float32 + + +def test_get_zernike_prior_concatenation(mock_data): + zernike_priors = get_zernike_prior(mock_data) + expected_zernike_priors = tf.convert_to_tensor( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + ) + + assert np.array_equal(zernike_priors, expected_zernike_priors) + + +def test_get_zernike_prior_empty_data(): + empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) + zernike_priors = get_zernike_prior(empty_data) + assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py new file mode 100644 index 00000000..7c68bb95 --- /dev/null +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -0,0 +1,153 @@ +"""FIXTURES FOR GENERATING TESTS FOR WF-PSF MODULES: CONFTEST. + +This module contains fixtures to use in unit tests for +various wf_psf packages. + +:Author: Jennifer Pollack + + +""" + +import pytest +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.training.train import TrainingParamsHandler +from wf_psf.psf_models import psf_models +from wf_psf.data.training_preprocessing import DataHandler + +training_config = RecursiveNamespace( + id_name="_sample_w_bis1_2k", + data_config="data_config.yaml", + metrics_config="metrics_config.yaml", + model_params=RecursiveNamespace( + model_name="poly", + n_bins_lda=8, + output_Q=3, + oversampling_rate=3, + output_dim=32, + pupil_diameter=256, + use_sample_weights=True, + interpolation_type="None", + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + param_hparams=RecursiveNamespace( + random_seed=3877572, + l2_param=0.0, + n_zernikes=15, + d_max=2, + save_optim_history_param=True, + ), + nonparam_hparams=RecursiveNamespace( + d_max_nonparam=5, + num_graph_features=10, + l1_rate=1e-08, + project_dd_features=False, + reset_dd_features=False, + save_optim_history_nonparam=True, + ), + ), + training_hparams=RecursiveNamespace( + n_epochs_params=[2, 2], + n_epochs_non_params=[2, 2], + batch_size=32, + multi_cycle_params=RecursiveNamespace( + total_cycles=2, + cycle_def="complete", + save_all_cycles=True, + saved_cycle="cycle2", + learning_rate_params=[1.0e-2, 1.0e-2], + learning_rate_non_params=[1.0e-1, 1.0e-1], + n_epochs_params=[2, 2], + n_epochs_non_params=[2, 2], + ), + ), +) + +data_conf = RecursiveNamespace( + data=RecursiveNamespace( + training=RecursiveNamespace( + data_dir="data", + file="coherent_euclid_dataset/train_Euclid_res_200_TrainStars_id_001.npy", + stars=None, + positions=None, + SEDS=None, + zernike_coef=None, + C_poly=None, + params=RecursiveNamespace( + 
d_max=2, + max_order=45, + x_lims=[0, 1000.0], + y_lims=[0, 1000.0], + grid_points=[4, 4], + n_bins=20, + max_wfe_rms=0.1, + oversampling_rate=3.0, + output_Q=3.0, + output_dim=32, + LP_filter_length=2, + pupil_diameter=256, + euclid_obsc=True, + n_stars=200, + ), + ), + test=RecursiveNamespace( + data_dir="data", + file="coherent_euclid_dataset/test_Euclid_res_id_001.npy", + stars=None, + noisy_stars=None, + positions=None, + SEDS=None, + zernike_coef=None, + C_poly=None, + parameters=RecursiveNamespace( + d_max=2, + max_order=45, + x_lims=[0, 1000.0], + y_lims=[0, 1000.0], + grid_points=[4, 4], + max_wfe_rms=0.1, + ), + ), + ), +) + + +@pytest.fixture(scope="module", params=[training_config]) +def training_params(): + return TrainingParamsHandler(training_config) + + +@pytest.fixture(scope="module") +def training_data(): + return DataHandler( + "training", + data_conf.data, + psf_models.simPSF(training_config.model_params), + training_config.model_params.n_bins_lda, + ) + + +@pytest.fixture(scope="module") +def test_data(): + return DataHandler( + "test", + data_conf.data, + psf_models.simPSF(training_config.model_params), + training_config.model_params.n_bins_lda, + ) + + +@pytest.fixture(scope="module") +def test_dataset(test_data): + return test_data.dataset + + +@pytest.fixture(scope="module") +def psf_model(): + return psf_models.get_psf_model( + training_config.model_params, + training_config.training_hparams, + ) diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py new file mode 100644 index 00000000..3baa9c23 --- /dev/null +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -0,0 +1,317 @@ +"""UNIT TESTS FOR PACKAGE MODULE: psf_model_physical_polychromatic. + +This module contains unit tests for the wf_psf.psf_models.psf_model_physical_polychromatic module. 
+ +:Author: Jennifer Pollack + +""" + +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.psf_models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, +) +from wf_psf.utils.configs_handler import DataConfigHandler + + +@pytest.fixture +def zks_prior(): + # Define your zks_prior data here + zks_prior_data = [ + [1, 2, 3, 4], + [4, 5, 6, 7], + [7, 8, 9, 8], + [10, 11, 12, 13], + ] + return tf.convert_to_tensor(zks_prior_data, dtype=tf.float32) + + +@pytest.fixture +def mock_data(mocker): + mock_instance = mocker.Mock(spec=DataConfigHandler) + # Configure the mock data object to have the necessary attributes + mock_instance.training_data = mocker.Mock() + mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.test_data = mocker.Mock() + mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} + return mock_instance + + +@pytest.fixture +def mock_model_params(mocker): + model_params_mock = mocker.MagicMock() + model_params_mock.param_hparams.n_zernikes = 10 + model_params_mock.pupil_diameter = 256 + return model_params_mock + + +def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): + # Create mock objects for model_params, training_params + # model_params_mock = mocker.MagicMock() + mock_training_params = mocker.Mock() + + # Mock internal methods called during initialization + mocker.patch( + "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + return_value=zks_prior, + ) + + mocker.patch( + "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True + ) + + # Create TFPhysicalPolychromaticField instance + field_instance = TFPhysicalPolychromaticField( + mock_model_params, mock_training_params, mock_data + ) + + mocker.patch.object(field_instance, "_initialize_zernike_parameters") + mocker.patch.object(field_instance, "_initialize_layers") + mocker.patch.object(field_instance, "assign_coeff_matrix") + + # Call the method being tested + field_instance._initialize_parameters_and_layers( + mock_model_params, mock_training_params, mock_data + ) + + # Check if internal methods were called with the correct arguments + field_instance._initialize_zernike_parameters.assert_called_once_with( + mock_model_params, mock_data + ) + field_instance._initialize_layers.assert_called_once_with( + mock_model_params, mock_training_params + ) + field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test + + +def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): + # Create training params mock object + mock_training_params = mocker.Mock() + + # Mock internal methods called during initialization + mocker.patch( + "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + return_value=zks_prior, + ) + + # Create TFPhysicalPolychromaticField instance + field_instance = TFPhysicalPolychromaticField( + mock_model_params, mock_training_params, mock_data + ) + + # Assert that the attributes are set correctly + # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes + assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) + assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes + assert isinstance( + field_instance.zernike_maps, tf.Tensor + ) # Check if the returned value is a TensorFlow tensor + assert ( + field_instance.zernike_maps.dtype == tf.float32 + ) # Check if the data type of the tensor is float32 
+
+    # Expected shape of the tensor based on the input parameters
+    expected_shape = (
+        field_instance.n_zks_total,
+        mock_model_params.pupil_diameter,
+        mock_model_params.pupil_diameter,
+    )
+    assert field_instance.zernike_maps.shape == expected_shape
+
+    # Modify model_params to simulate zks_prior > n_zernikes
+    mock_model_params.param_hparams.n_zernikes = 2
+
+    # Call the method again to initialize the parameters
+    field_instance._initialize_zernike_parameters(mock_model_params, mock_data)
+
+    assert field_instance.n_zks_total == tf.cast(
+        tf.shape(field_instance.zks_prior)[1], tf.int32
+    )
+    # Expected shape of the tensor based on the input parameters
+    expected_shape = (
+        field_instance.n_zks_total,
+        mock_model_params.pupil_diameter,
+        mock_model_params.pupil_diameter,
+    )
+    assert field_instance.zernike_maps.shape == expected_shape
+
+
+def test_initialize_physical_layer_mocking(
+    mocker, mock_model_params, mock_data, zks_prior
+):
+    # Create training params mock object
+    mock_training_params = mocker.Mock()
+
+    # Mock internal methods called during initialization
+    mocker.patch(
+        "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior",
+        return_value=zks_prior,
+    )
+
+    # Create a mock for the TFPhysicalLayer class
+    mock_physical_layer_class = mocker.patch(
+        "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer"
+    )
+
+    # Create TFPhysicalPolychromaticField instance
+    field_instance = TFPhysicalPolychromaticField(
+        mock_model_params, mock_training_params, mock_data
+    )
+
+    # Assert that the TFPhysicalLayer class was called with the expected arguments
+    mock_physical_layer_class.assert_called_once_with(
+        field_instance.obs_pos,
+        field_instance.zks_prior,
+        interpolation_type=mock_model_params.interpolation_type,
+        interpolation_args=mock_model_params.interpolation_args,
+    )
+
+
+@pytest.fixture
+def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior):
+    # Create training params mock object
+    mock_training_params = mocker.Mock()
+
+    # Mock internal methods called during initialization
+    mocker.patch(
+        "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior",
+        return_value=zks_prior,
+    )
+
+    # Patch the TFPhysicalLayer class so that no real physical layer is built
+    mocker.patch(
+        "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer"
+    )
+
+    # Create TFPhysicalPolychromaticField instance
+    psf_field_instance = TFPhysicalPolychromaticField(
+        mock_model_params, mock_training_params, mock_data
+    )
+    return psf_field_instance
+
+
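+# The pad_zernikes tests below use tensors shaped (batch, n_zernike, x, y).
+# The padding behaviour assumed from the assertions is that pad_zernikes
+# zero-pads the shorter input along axis 1 (the Zernike axis) until both
+# tensors reach n_zks_total; e.g. a prior of shape (1, 2, 1, 1) padded
+# against a parametric tensor of shape (1, 4, 1, 1) becomes
+# [[[[1]], [[2]], [[0]], [[0]]]], as exercised in test_compute_zernikes.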
+def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance):
+    # Define input tensors with the same number of Zernikes
+    zk_param = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+    zk_prior = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+
+    # Reshape the tensors to have the desired shapes
+    zk_param = tf.reshape(zk_param, (1, 2, 1, 1))  # Reshape zk_param to (1, 2, 1, 1)
+    zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1))  # Reshape zk_prior to (1, 2, 1, 1)
+
+    # Reset n_zks_total attribute
+    physical_layer_instance.n_zks_total = max(
+        tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()
+    )
+    # Call the method under test
+    padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes(
+        zk_param, zk_prior
+    )
+
+    # Assert that the padded tensors have the correct shapes
+    assert padded_zk_param.shape == (1, 2, 1, 1)
+    assert padded_zk_prior.shape == (1, 2, 1, 1)
+
+
+def test_pad_zernikes_prior_greater_than_param(physical_layer_instance):
+    zk_param = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+    zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]])  # Shape: (1, 5, 1, 1)
+
+    # Reshape the tensors to have the desired shapes
+    zk_param = tf.reshape(zk_param, (1, 2, 1, 1))  # Reshape zk_param to (1, 2, 1, 1)
+    zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1))  # Reshape zk_prior to (1, 5, 1, 1)
+
+    # Reset n_zks_total attribute
+    physical_layer_instance.n_zks_total = max(
+        tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()
+    )
+
+    # Call the method under test
+    padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes(
+        zk_param, zk_prior
+    )
+
+    # Assert that the padded tensors have the correct shapes
+    assert padded_zk_param.shape == (1, 5, 1, 1)
+    assert padded_zk_prior.shape == (1, 5, 1, 1)
+
+
+def test_pad_zernikes_shapes_mismatch(physical_layer_instance):
+    zk_param = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+    zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]])  # Shape: (1, 5, 1, 1)
+
+    # Reset n_zks_total attribute
+    physical_layer_instance.n_zks_total = max(
+        tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()
+    )
+
+    # Call the method under test and expect a ValueError
+    with pytest.raises(ValueError):
+        physical_layer_instance.pad_zernikes(zk_param, zk_prior)
+
+
+def test_pad_zernikes_param_greater_than_prior(physical_layer_instance):
+    zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]])  # Shape: (1, 4, 1, 1)
+    zk_prior = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+
+    # Reshape the tensors to have the desired shapes
+    zk_param = tf.reshape(zk_param, (1, 4, 1, 1))  # Reshape zk_param to (1, 4, 1, 1)
+    zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1))  # Reshape zk_prior to (1, 2, 1, 1)
+
+    # Reset n_zks_total attribute
+    physical_layer_instance.n_zks_total = max(
+        tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()
+    )
+
+    # Call the method under test
+    padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes(
+        zk_param, zk_prior
+    )
+
+    # Assert that the padded tensors have the correct shapes
+    assert padded_zk_param.shape == (1, 4, 1, 1)
+    assert padded_zk_prior.shape == (1, 4, 1, 1)
+
+
+def test_compute_zernikes(mocker, physical_layer_instance):
+    # Mock padded tensors
+    padded_zk_param = tf.constant(
+        [[[[10]], [[20]], [[30]], [[40]]]]
+    )  # Shape: (1, 4, 1, 1)
+    padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]])  # Shape: (1, 4, 1, 1)
+
+    # Reset n_zks_total attribute
+    physical_layer_instance.n_zks_total = 4  # Assuming a specific value for simplicity
+
+    # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call
+    padded_zernike_param = tf.constant(
+        [[[[10]], [[20]], [[30]], [[40]]]]
+    )  # Shape: (1, 4, 1, 1)
+    padded_zernike_prior = tf.constant(
+        [[[[1]], [[2]], [[0]], [[0]]]]
+    )  # Shape: (1, 4, 1, 1)
+
+    mocker.patch.object(
+        physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param
+    )
+    mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior)
+    mocker.patch.object(
+        physical_layer_instance,
+        "pad_zernikes",
+        return_value=(padded_zernike_param, padded_zernike_prior),
+    )
+
+    # Call the method under test
+    zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]]))
+
+    # Define the expected values
+    expected_values = tf.constant(
+        [[[[11]], [[22]], [[30]], [[40]]]]
+    )  # Shape: (1, 4, 1, 1)
+
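+    # The expected coefficients are the elementwise sum of the padded
+    # parametric and prior contributions: 11 = 10 + 1 and 22 = 20 + 2, while
+    # the last two entries keep the parametric values because the prior was
+    # zero-padded there.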
+    # Assert that the shapes are equal
+    assert zernike_coeffs.shape == expected_values.shape
+
+    # Assert that the tensor values are equal
+    assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values))
diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py
new file mode 100644
index 00000000..86b87c70
--- /dev/null
+++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py
@@ -0,0 +1,51 @@
+"""UNIT TESTS FOR PACKAGE MODULE: PSF MODELS.
+
+This module contains unit tests for the wf_psf.psf_models psf_models module.
+
+:Author: Jennifer Pollack 
+
+
+"""
+
+import pytest
+from wf_psf.psf_models import psf_models
+from wf_psf.utils.io import FileIOHandler
+import tensorflow as tf
+import numpy as np
+import os
+
+
+def test_get_psf_model_weights_filepath():
+    weights_filepath = "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*"
+
+    ans = psf_models.get_psf_model_weights_filepath(weights_filepath)
+    assert (
+        ans
+        == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint_callback_poly_sample_w_bis1_2k_cycle2"
+    )
+
+
+def test_generate_zernike_maps_3d():
+    # Define test parameters
+    n_zernikes = 5
+    pupil_diam = 10
+
+    # Call the function to generate Zernike maps
+    zernike_maps = psf_models.generate_zernike_maps_3d(n_zernikes, pupil_diam)
+
+    # Assertions to verify properties of the returned tensor
+    assert isinstance(
+        zernike_maps, tf.Tensor
+    )  # Check if the returned value is a TensorFlow tensor
+    assert (
+        zernike_maps.dtype == tf.float32
+    )  # Check if the data type of the tensor is float32
+
+    # Expected shape of the tensor based on the input parameters
+    expected_shape = (n_zernikes, pupil_diam, pupil_diam)
+    assert (
+        zernike_maps.shape == expected_shape
+    )  # Check if the shape of the tensor matches the expected shape
+
+    # Value-range check disabled: Zernike maps take values outside [0, 1],
+    # so this assertion fails.
+    # assert tf.reduce_all(tf.logical_and(zernike_maps >= 0, zernike_maps <= 1))
diff --git a/src/wf_psf/tests/test_sims/conftest.py b/src/wf_psf/tests/test_sims/conftest.py
new file mode 100644
index 00000000..bd6427f5
--- /dev/null
+++ b/src/wf_psf/tests/test_sims/conftest.py
@@ -0,0 +1,73 @@
+"""FIXTURES FOR GENERATING TESTS FOR WF-PSF MODULES: CONFTEST.
+
+This module contains fixtures to use in unit tests for wf_psf.sims package.
+ +:Author: Jennifer Pollack + + +""" +import pytest +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.sims.spatial_varying_psf import SpatialVaryingPSF, MeshHelper +import numpy as np + + +class PSF_Simulator: + max_wfe_rms = 55 + + +psf_params = RecursiveNamespace( + grid_points=[2, 2], + grid_size=4, + max_order=2, + x_lims=[0, 2], + y_lims=[0, 2], + psf_simulator=PSF_Simulator(), + d_max=1, + n_bins=2, + lim_max_wfe_rms=2, + verbose=False, + seed=930293, +) + + +@pytest.fixture(scope="module", params=[psf_params]) +def spatial_varying_psf(): + return SpatialVaryingPSF( + psf_params.psf_simulator, + psf_params.d_max, + psf_params.grid_points, + psf_params.grid_size, + psf_params.max_order, + psf_params.x_lims, + psf_params.y_lims, + psf_params.n_bins, + psf_params.lim_max_wfe_rms, + psf_params.verbose, + psf_params.seed, + ) + + +@pytest.fixture +def example_limits_and_grid(): + x_lims = [-5, 15] + y_lims = [4, 10] + grid_points = [5, 10] + return x_lims, y_lims, grid_points + + +@pytest.fixture +def example_limits_bounds(): + x_lims = [0, 10] + y_lims = [0, 10] + x = np.random.rand(5) * max(x_lims) + y = np.random.rand(5) * max(y_lims) + + return x, y, x_lims, y_lims + + +@pytest.fixture +def xv_and_yv_grid(example_limits_and_grid): + x_lims, y_lims, grid_points = example_limits_and_grid + xv_grid, yv_grid = MeshHelper.build_mesh(x_lims, y_lims, grid_points) + return xv_grid, yv_grid diff --git a/src/wf_psf/tests/test_sims/spatial_varying_psf_test.py b/src/wf_psf/tests/test_sims/spatial_varying_psf_test.py new file mode 100644 index 00000000..e5c10a41 --- /dev/null +++ b/src/wf_psf/tests/test_sims/spatial_varying_psf_test.py @@ -0,0 +1,365 @@ +"""UNIT TESTS FOR PACKAGE MODULE: Sims. + +This module contains unit tests for the wf_psf.sims.spatial_varying_psf module. 
+ +:Author: Jennifer Pollack + +""" +import pytest +import numpy as np +from wf_psf.sims.spatial_varying_psf import ( + MeshHelper, + CoordinateHelper, + ZernikeHelper, + PolynomialMatrixHelper, + SpatialVaryingPSF, +) +import os +import logging + + +@pytest.fixture +def mock_x_lims(): + return [0, 1] + + +@pytest.fixture +def mock_y_lims(): + return [0, 1] + + +@pytest.fixture +def mock_grid_points(): + return [10, 20] + + +@pytest.fixture +def mock_grid_size(): + return 15 + + +def test_build_mesh_with_grid_points(mock_x_lims, mock_y_lims, mock_grid_points): + x_grid, y_grid = MeshHelper.build_mesh( + mock_x_lims, mock_y_lims, grid_points=mock_grid_points + ) + assert x_grid.shape == (mock_grid_points[1], mock_grid_points[0]) + assert y_grid.shape == (mock_grid_points[1], mock_grid_points[0]) + + +def test_build_mesh_with_grid_size(mock_x_lims, mock_y_lims, mock_grid_size): + x_grid, y_grid = MeshHelper.build_mesh( + mock_x_lims, mock_y_lims, grid_size=mock_grid_size + ) + assert x_grid.shape == (mock_grid_size, mock_grid_size) + assert y_grid.shape == (mock_grid_size, mock_grid_size) + + +def test_build_mesh_with_both_params( + mock_x_lims, mock_y_lims, mock_grid_points, mock_grid_size +): + x_grid, y_grid = MeshHelper.build_mesh( + mock_x_lims, + mock_y_lims, + grid_points=mock_grid_points, + grid_size=mock_grid_size, + ) + assert x_grid.shape == (mock_grid_size, mock_grid_size) + assert y_grid.shape == (mock_grid_size, mock_grid_size) + + +def test_build_mesh_with_no_params(mock_x_lims, mock_y_lims): + with pytest.raises(ValueError): + MeshHelper.build_mesh(mock_x_lims, mock_y_lims) + + +def test_build_mesh_default_params(): + # Test case: Default parameters + x_lims = [0, 1] + y_lims = [0, 1] + grid_points = [3, 3] # 3x3 grid + mesh_x, mesh_y = MeshHelper.build_mesh(x_lims, y_lims, grid_points) + assert mesh_x.shape == (3, 3), "Mesh grid shape should be (3, 3)" + assert mesh_y.shape == (3, 3), "Mesh grid shape should be (3, 3)" + assert np.allclose( + mesh_x, np.array([[0, 0.5, 1], [0, 0.5, 1], [0, 0.5, 1]]) + ), "Mesh x coordinates are incorrect" + assert np.allclose( + mesh_y, np.array([[0, 0, 0], [0.5, 0.5, 0.5], [1, 1, 1]]) + ), "Mesh y coordinates are incorrect" + + +def test_build_mesh_custom_params(): + # Test case: Custom parameters + x_lims = [0, 2] + y_lims = [1, 3] + grid_points = [2, 4] # 2x4 grid + mesh_x, mesh_y = MeshHelper.build_mesh(x_lims, y_lims, grid_points) + assert mesh_x.shape == (4, 2), "Mesh grid shape should be (4, 2)" + assert mesh_y.shape == (4, 2), "Mesh grid shape should be (4, 2)" + assert np.allclose( + mesh_x, np.array([[0, 2], [0, 2], [0, 2], [0, 2]]) + ), "Mesh x coordinates are incorrect" + assert np.allclose( + mesh_y, + np.array( + [[1.0, 1.0], [1.66666667, 1.66666667], [2.33333333, 2.33333333], [3.0, 3.0]] + ), + ), "Mesh y coordinates are incorrect" + + +def test_build_mesh_grid_size_parameter(): + # Test case: Custom number of points + x_lims = [0, 1] + y_lims = [0, 1] + grid_points = [3, 3] # 3x3 grid + grid_size = 2 # 2 points in each direction + mesh_x, mesh_y = MeshHelper.build_mesh( + x_lims, y_lims, grid_points, grid_size=grid_size + ) + assert mesh_x.shape == (2, 2), "Mesh grid shape should be (2, 2)" + assert mesh_y.shape == (2, 2), "Mesh grid shape should be (2, 2)" + assert np.allclose( + mesh_x, np.array([[0, 1], [0, 1]]) + ), "Mesh x coordinates are incorrect" + assert np.allclose( + mesh_y, np.array([[0, 0], [1, 1]]) + ), "Mesh y coordinates are incorrect" + + +def test_bounds_of_scaled_positions(example_limits_bounds): + """Test 
Bounds of Scaled Positions. + + This unit test checks whether the elements of each array + for x and y are within the range [-1, 1]. + + """ + x, y, x_lims, y_lims = example_limits_bounds + x_scale, y_scale = CoordinateHelper.scale_positions(x, y, x_lims, y_lims) + assert np.logical_and(x_scale >= -1, x_scale <= 1).all() + assert np.logical_and(y_scale >= -1, y_scale <= 1).all() + + +def test_correctness_of_shift(example_limits_and_grid): + x_lims, y_lims, grid_points = example_limits_and_grid + + x_step, y_step = CoordinateHelper.calculate_shift(x_lims, y_lims, grid_points) + assert x_step == 4 + assert y_step == 0.6 + + +def test_add_random_shift_to_positions(example_limits_and_grid, xv_and_yv_grid): + x_lims, y_lims, grid_points = example_limits_and_grid + xv_grid, yv_grid = xv_and_yv_grid + seed = 3838284 + + ( + shifted_x, + shifted_y, + ) = CoordinateHelper.add_random_shift_to_positions( + xv_grid, yv_grid, grid_points, x_lims, y_lims, seed + ) + assert np.logical_and(shifted_x >= x_lims[0], shifted_x <= x_lims[1]).all() + assert np.logical_and(shifted_y >= y_lims[0], shifted_y <= y_lims[1]).all() + + +def test_check_and_adjust_coordinate_limits(example_limits_bounds): + x, y, x_lims, y_lims = example_limits_bounds + ( + adjusted_x, + adjusted_y, + ) = CoordinateHelper.check_and_adjust_coordinate_limits(x, y, x_lims, y_lims) + + assert np.logical_and(adjusted_x >= x_lims[0], adjusted_x <= x_lims[1]).all() + assert np.logical_and(adjusted_y >= y_lims[0], adjusted_y <= y_lims[1]).all() + + +def test_check_position_coordinate_limits_within_limits(caplog): + # Test case: xv and yv within the limits + xv = np.array([0.5, 1.5, 2.5]) + yv = np.array([1.0, 2.0, 3.0]) + x_lims = [0, 3] + y_lims = [0, 4] + with caplog.at_level(logging.DEBUG): + CoordinateHelper.check_position_coordinate_limits( + xv, yv, x_lims, y_lims, verbose=True + ) + # Check if log messages are captured + assert not any( + record.levelname == "INFO" and "WARNING!" in record.message + for record in caplog.records + ), "No warning message should be logged for coordinates within limits" + + +def test_check_position_coordinate_limits_outside_limits(caplog): + # Test case: xv and yv outside the limits + xv = np.array([-0.5, 3.5, 2.5]) + yv = np.array([1.0, 4.0, 3.0]) + x_lims = [0, 3] + y_lims = [0, 3] + with caplog.at_level(logging.DEBUG): + CoordinateHelper.check_position_coordinate_limits( + xv, yv, x_lims, y_lims, verbose=True + ) + + # Check if log messages are captured + for record in caplog.records: + assert record.levelname == "WARNING" + assert "x value" in caplog.text + + +def test_check_position_coordinate_limits_no_verbose(caplog): + # Test case: No verbose output + xv = np.array([-0.5, 3.5, 2.5]) + yv = np.array([1.0, 4.0, 3.0]) + x_lims = [0, 3] + y_lims = [0, 3] + with caplog.at_level(logging.DEBUG): + CoordinateHelper.check_position_coordinate_limits( + xv, yv, x_lims, y_lims, verbose=False + ) + # Check if log messages are captured + assert not any( + record.levelname == "INFO" and "WARNING!" 
in record.message
+        for record in caplog.records
+    ), "No warning message should be logged for coordinates within limits"
+
+
+def test_bounds_polynomial_matrix_coefficients(example_limits_bounds):
+    x, y, x_lims, y_lims = example_limits_bounds
+    d_max = 2
+    Pi = PolynomialMatrixHelper.generate_polynomial_matrix(x, y, x_lims, y_lims, d_max)
+    assert np.logical_and(Pi >= -1, Pi <= 1).all()
+
+
+def test_shape_initialize_Z_matrix():
+    max_order = 5
+    length = 10
+    Z = ZernikeHelper.initialize_Z_matrix(max_order, length)
+
+    assert np.shape(Z) == (max_order, length)
+
+
+def test_empty_input_vector_initialize_Z_matrix():
+    z_matrix = ZernikeHelper.initialize_Z_matrix(10, len(np.array([])))
+    assert z_matrix.shape == (10, 0)
+
+
+def test_zero_max_order_initialize_Z_matrix():
+    z_matrix = ZernikeHelper.initialize_Z_matrix(0, 3)
+    assert z_matrix.shape == (0, 3)
+
+
+def test_large_max_order_initialize_Z_matrix():
+    z_matrix = ZernikeHelper.initialize_Z_matrix(1000, 3)
+
+    assert z_matrix.shape == (1000, 3)
+
+
+def test_basic_functionality_initialize_Z_matrix():
+    Z_matrix = ZernikeHelper.initialize_Z_matrix(2, 4, 930293)
+    expected_Z = np.array(
+        [
+            [-1.07491132, 0.55969984, 1.4472938, 1.72761226],
+            [-0.69694839, -0.14165003, -1.97485039, -0.20909905],
+        ]
+    )
+    np.testing.assert_allclose(Z_matrix, expected_Z)
+
+
+def test_basic_functionality_normalize_Z_matrix():
+    Z = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+    normalized_Z = ZernikeHelper.normalize_Z_matrix(Z, 10)
+    expected_normalized_Z = np.array(
+        [[2.67261242, 5.34522484, 8.01783726], [4.55842306, 5.69802882, 6.83763459]]
+    )
+    np.testing.assert_allclose(normalized_Z, expected_normalized_Z)
+
+
+@pytest.mark.parametrize(
+    "Z, max_limit",
+    [
+        (np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 8),
+        (np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]), 10),
+        (np.array([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]), 100),
+        # Add more Z array values and max_limit values as needed
+    ],
+)
+def test_bounds_normalize_Z_matrix(Z, max_limit):
+    normalized_Z = ZernikeHelper.normalize_Z_matrix(Z, max_limit)
+
+    assert np.logical_and(normalized_Z >= -max_limit, normalized_Z <= max_limit).all()
+
+
+def test_basic_functionality_initialize_normalized_zernike_matrix():
+    Z = ZernikeHelper.initialize_normalized_zernike_matrix(2, 4, 2, 930293)
+    expected_initialize_normalized_Z = np.array(
+        [
+            [-0.84013338, 0.43745238, 1.13118153, 1.35027393],
+            [-0.66080324, -0.13430377, -1.87243067, -0.19825475],
+        ]
+    )
+    np.testing.assert_allclose(Z, expected_initialize_normalized_Z)
+
+
+def test_generate_zernike_polynomials():
+    # Define input parameters
+    xv = np.array([0.0, 0.1, 0.2, 0.3, 0.4])
+    yv = np.array([0.0, 0.1, 0.2, 0.3, 0.4])
+    x_lims = [0.0, 1.0]
+    y_lims = [0.0, 1.0]
+    d_max = 2
+    polynomial_coeffs = np.array([np.arange(6)] * 10)
+
+    # Call the function
+    result = ZernikeHelper.generate_zernike_polynomials(
+        xv, yv, x_lims, y_lims, d_max, polynomial_coeffs
+    )
+
+    # Define expected output shape
+    expected_shape = (10, 5)
+
+    # Assert the shape of the result
+    assert result.shape == expected_shape
+
+
+def test_WFE_RMS_build_polynomial_coeffs(spatial_varying_psf):
+    expected_WFE_RMS = np.array(
+        [
+            [1.59860658, 0.98645338, 0.37585926, 0.24511776],
+            [1.72689883, 1.23068974, 0.89311116, 0.91177326],
+            [2.1686159, 1.83035052, 1.65774055, 1.70195683],
+            [2.77807291, 2.54555448, 2.44798666, 2.50121219],
+        ]
+    )
+
+    np.testing.assert_allclose(spatial_varying_psf.WFE_RMS, expected_WFE_RMS)
+
+
+def test_polynomial_coeffs_build_polynomial_coeffs(spatial_varying_psf):
+    expected_polynomial_coeffs = np.array(
+ [[0.85217187, 0.42910555, 1.12535278], [-1.05870064, 0.81260699, -0.43522383]] + ) + np.testing.assert_allclose( + spatial_varying_psf.polynomial_coeffs, expected_polynomial_coeffs + ) + + +def test_calculate_zernikes(spatial_varying_psf): + xv = np.array([0.1599547, 2.0, 0.0, 2.0]) + yv = np.array([0.0, 0.0, 1.56821882, 1.8144926]) + expected_zernikes = np.array( + [ + [-0.63364901, 0.15592463, 1.06251295, 2.19786893], + [-1.3061035, 0.18913017, -2.11861001, -0.60058024], + ] + ) + + zernikes = ZernikeHelper.calculate_zernike( + xv, + yv, + spatial_varying_psf.x_lims, + spatial_varying_psf.y_lims, + spatial_varying_psf.d_max, + spatial_varying_psf.polynomial_coeffs, + ) + + np.testing.assert_allclose(zernikes, expected_zernikes) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 5e8db118..9209c842 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -9,12 +9,90 @@ import pytest from wf_psf.utils import configs_handler +from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler +from pytest_mock import mocker +from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler +from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.training.train import train import os +@pytest.fixture +def mock_training_model_params(): + return RecursiveNamespace(n_bins_lda=10) # Replace 10 with the desired value + + +@pytest.fixture +def mock_data_read_conf(mocker): + return mocker.patch( + "wf_psf.utils.configs_handler.read_conf", + return_value=RecursiveNamespace( + data=RecursiveNamespace( + training=RecursiveNamespace( + data_dir="/path/to/train_data", file="train_data.npy" + ), + test=RecursiveNamespace( + data_dir="/path/to/test_data", + file="test_data.npy", + ), + ), + ), + ) + + +@pytest.fixture +def mock_training_conf(): + return RecursiveNamespace( + training=RecursiveNamespace( + id_name="_test_", + data_config="data_config.yaml", + metrics_config=None, + model_params=RecursiveNamespace( + model_name="poly", + param_hparams=RecursiveNamespace( + random_seed=3877572, + ), + nonparam_hparams=RecursiveNamespace( + d_max_nonparam=5, + ), + ), + training_hparams=RecursiveNamespace(n_epochs_params=[2, 2]), + ), + ) + + +@pytest.fixture +def mock_data_conf(mocker): + # Create a mock object + data_conf = mocker.Mock() + + # Set attributes on the mock object + data_conf.training_data = "value1" + data_conf.test_data = "value2" + + return data_conf + + +@pytest.fixture +def mock_training_config_handler(mocker, mock_training_conf, mock_data_conf): + # Create a mock instance of TrainingConfigHandler + mock_instance = mocker.Mock(spec=TrainingConfigHandler) + + # Set attributes of the mock instance as needed + mock_instance.training_conf = mock_training_conf + mock_instance.data_conf = mock_data_conf + mock_instance.data_conf.training_data = mock_data_conf.training_data + mock_instance.data_conf.test_data = mock_data_conf.test_data + mock_instance.checkpoint_dir = "/mock/checkpoint/dir" + mock_instance.optimizer_dir = "/mock/optimizer/dir" + mock_instance.psf_model_dir = "/mock/psf/model/dir" + + return mock_instance + + @configs_handler.register_configclass -class TestClass: +class RegisterConfigClass: ids = ("test_conf",) def __init__(self, config_params, file_handler): @@ -23,12 +101,12 @@ def __init__(self, config_params, file_handler): def test_register_configclass(): - assert 
configs_handler.CONFIG_CLASS["test_conf"] == TestClass + assert configs_handler.CONFIG_CLASS["test_conf"] == RegisterConfigClass def test_set_run_config(): config_class = configs_handler.set_run_config("test_conf") - assert config_class == TestClass + assert config_class == RegisterConfigClass config_class = configs_handler.set_run_config("training_conf") assert config_class == configs_handler.TrainingConfigHandler @@ -49,7 +127,130 @@ def test_get_run_config(path_to_repo_dir, path_to_tmp_output_dir, path_to_config "test_conf", "fake_config.yaml", test_file_handler ) - assert type(config_class) is TestClass + assert type(config_class) is RegisterConfigClass + + +def test_data_config_handler_init( + mock_training_model_params, mock_data_read_conf, mocker +): + # Mock read_conf function + mock_data_read_conf() + + # Mock SimPSF instance + mock_simPSF_instance = mocker.Mock(name="SimPSFToolkit") + mocker.patch( + "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance + ) + + # Patch the initialize and load_dataset methods inside DataHandler + mocker.patch.object(DataHandler, "initialize") + + # Create DataConfigHandler instance + data_config_handler = DataConfigHandler( + "/path/to/data_config.yaml", mock_training_model_params + ) + + # Check that attributes are set correctly + assert isinstance(data_config_handler.data_conf, RecursiveNamespace) + assert isinstance(data_config_handler.simPSF, object) + assert ( + data_config_handler.training_data.n_bins_lambda + == mock_training_model_params.n_bins_lda + ) + assert ( + data_config_handler.test_data.n_bins_lambda + == mock_training_model_params.n_bins_lda + ) + + +def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler): + # Mock read_conf function + mocker.patch( + "wf_psf.utils.configs_handler.read_conf", return_value=mock_training_conf + ) + + # Mock data_conf instance + mock_data_conf = mocker.patch("wf_psf.utils.configs_handler.DataConfigHandler") + + # Mock SimPSF instance + mock_simPSF_instance = mocker.Mock(name="SimPSFToolkit") + mocker.patch( + "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance + ) + + # Initialize TrainingConfigHandler with the mock_file_handler + training_config_handler = TrainingConfigHandler( + "/path/to/training_config.yaml", mock_file_handler + ) + + # Assertions + mock_file_handler.copy_conffile_to_output_dir.assert_called_once_with( + training_config_handler.training_conf.training.data_config + ) + mock_file_handler.get_checkpoint_dir.assert_called_once_with( + mock_file_handler._run_output_dir + ) + mock_file_handler.get_optimizer_dir.assert_called_once_with( + mock_file_handler._run_output_dir + ) + mock_file_handler.get_psf_model_dir.assert_called_once_with( + mock_file_handler._run_output_dir + ) + assert training_config_handler.training_conf == mock_training_conf + assert training_config_handler.file_handler == mock_file_handler + assert ( + training_config_handler.file_handler.repodir_path + == mock_file_handler.repodir_path + ) + + mock_data_conf.assert_called_once_with( + os.path.join( + mock_file_handler.config_path, + training_config_handler.training_conf.training.data_config, + ), + training_config_handler.training_conf.training.model_params, + ) + assert training_config_handler.data_conf == mock_data_conf.return_value + + +def test_run_method_calls_train_with_correct_arguments( + mocker, mock_training_conf, mock_data_conf +): + # Patch the TrainingConfigHandler.__init__() method + mocker.patch( + 
"wf_psf.utils.configs_handler.TrainingConfigHandler.__init__", return_value=None + ) + mock_th = TrainingConfigHandler(None, None) + # Set attributes of the mock_th + mock_th.training_conf = mock_training_conf + mock_th.data_conf = mock_data_conf + mock_th.data_conf.training_data = mock_data_conf.training_data + mock_th.data_conf.test_data = mock_data_conf.test_data + mock_th.checkpoint_dir = "/mock/checkpoint/dir" + mock_th.optimizer_dir = "/mock/optimizer/dir" + mock_th.psf_model_dir = "/mock/psf/model/dir" + + # Patch the train.train() function + mock_train_function = mocker.patch("wf_psf.training.train.train") + + # Create a spy for the run method + spy = mocker.spy(mock_th, "run") + + # Call the run method + mock_th.run() + + # Assert that run() is called once + spy.assert_called_once() + + # Assert that train.train() is called with the correct arguments + mock_train_function.assert_called_once_with( + mock_th.training_conf.training, + mock_th.data_conf.training_data, + mock_th.data_conf.test_data, + mock_th.checkpoint_dir, + mock_th.optimizer_dir, + mock_th.psf_model_dir, + ) def test_MetricsConfigHandler_weights_basename_filepath( diff --git a/src/wf_psf/tests/test_utils/conftest.py b/src/wf_psf/tests/test_utils/conftest.py index 7d7a5a34..b1de69dd 100644 --- a/src/wf_psf/tests/test_utils/conftest.py +++ b/src/wf_psf/tests/test_utils/conftest.py @@ -10,9 +10,17 @@ import pytest import os +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.utils.io import FileIOHandler cwd = os.getcwd() +training_config = RecursiveNamespace( + id_name="_sample_w_bis1_2k", + data_config="data_config.yaml", + metrics_config="metrics_config.yaml", +) + @pytest.fixture(scope="class") def path_to_repo_dir(): @@ -32,3 +40,56 @@ def path_to_tmp_output_dir(tmp_path): @pytest.fixture def path_to_config_dir(path_to_test_dir): return os.path.join(path_to_test_dir, "data") + + +@pytest.fixture +def mock_file_handler(mocker, tmp_path): + # Create a temporary directory + temp_dir = tmp_path / "temp_dir" + temp_dir.mkdir() + + # Create a mock FileIOHandler instance + mock_fh = FileIOHandler( + repodir_path="/path/to/repo", + output_path="/path/to/output", + config_path=str(temp_dir), + ) + + # Mock the methods of FileIOHandler as needed + mocker.patch.object( + mock_fh, "get_checkpoint_dir", return_value="/path/to/checkpoints" + ) + mocker.patch.object(mock_fh, "get_optimizer_dir", return_value="/path/to/optimizer") + mocker.patch.object(mock_fh, "get_psf_model_dir", return_value="/path/to/psf_model") + mocker.patch.object(mock_fh, "copy_conffile_to_output_dir") + + return mock_fh + + +@pytest.fixture() +def mock_config_dir(tmp_path): + # Use os.path.join to construct the file path + mock_data_conf_dir = tmp_path / "tmp_config_dir" + mock_data_conf_dir.mkdir() + return mock_data_conf_dir + + +@pytest.fixture(scope="function") +def mock_data_config(mock_config_dir): + # Create a mock data configuration + mock_data_conf_content = """ + data: + training: + data_dir: data/mock_dataset/ + file: train_data.npy + test: + data_dir: data/mock_dataset/ + file: test_data.npy + """ + + mock_data_conf_path = mock_config_dir / "data_config.yaml" + + # Write the mock training configuration to a file + mock_data_conf_path.write_text(mock_data_conf_content) + + return mock_data_conf_path diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py new file mode 100644 index 00000000..99d2cf27 --- /dev/null +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -0,0 +1,108 @@ 
+"""UNIT TESTS FOR PACKAGE MODULE: UTILS. + +This module contains unit tests for the wf_psf.utils utils module. + +:Author: Tobias Liaudat + + +""" + +import pytest +import tensorflow as tf +import numpy as np +from wf_psf.utils.utils import ( + zernike_generator, + compute_unobscured_zernike_projection, + decompose_tf_obscured_opd_basis, +) +from wf_psf.sims.psf_simulator import PSFSimulator + + +def test_unobscured_zernike_projection(): + n_zernikes = 20 + wfe_dim = 256 + tol = 1e-1 + + # Create zernike basis + zernikes = zernike_generator(n_zernikes=n_zernikes, wfe_dim=wfe_dim) + np_zernike_cube = np.zeros( + (len(zernikes), zernikes[0].shape[0], zernikes[0].shape[1]) + ) + for it in range(len(zernikes)): + np_zernike_cube[it, :, :] = zernikes[it] + np_zernike_cube[np.isnan(np_zernike_cube)] = 0 + tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32) + + # Create random zernike coefficient array + zk_array = np.random.randn(1, n_zernikes, 1, 1) + tf_zk_array = tf.convert_to_tensor(zk_array, dtype=tf.float32) + + # Compute OPD + tf_unobscured_opd = tf.math.reduce_sum( + tf.math.multiply(tf_zernike_cube, tf_zk_array), axis=1 + ) + + # Compute normalisation factor + norm_factor = compute_unobscured_zernike_projection( + tf_zernike_cube[0, :, :], tf_zernike_cube[0, :, :] + ) + + # Compute projections for each zernike + estimated_zk_array = np.array( + [ + compute_unobscured_zernike_projection( + tf_unobscured_opd, tf_zernike_cube[j, :, :], norm_factor=norm_factor + ) + for j in range(n_zernikes) + ] + ) + + rmse_error = np.linalg.norm(estimated_zk_array - zk_array[0, :, 0, 0]) + + assert rmse_error < tol + + +def test_tf_decompose_obscured_opd_basis(): + n_zernikes = 20 + wfe_dim = 256 + tol = 1e-5 + + # Create zernike basis + zernikes = zernike_generator(n_zernikes=n_zernikes, wfe_dim=wfe_dim) + np_zernike_cube = np.zeros( + (len(zernikes), zernikes[0].shape[0], zernikes[0].shape[1]) + ) + for it in range(len(zernikes)): + np_zernike_cube[it, :, :] = zernikes[it] + np_zernike_cube[np.isnan(np_zernike_cube)] = 0 + tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32) + + # Create obscurations + obscurations = PSFSimulator.generate_pupil_obscurations(N_pix=wfe_dim, N_filter=2) + tf_obscurations = tf.convert_to_tensor(obscurations, dtype=tf.float32) + + # Create random zernike coefficient array + zk_array = np.random.randn(1, n_zernikes, 1, 1) + tf_zk_array = tf.convert_to_tensor(zk_array, dtype=tf.float32) + + # Compute OPD + tf_unobscured_opd = tf.math.reduce_sum( + tf.math.multiply(tf_zernike_cube, tf_zk_array), axis=1 + ) + # Obscure the OPD + tf_obscured_opd = tf.math.multiply( + tf_unobscured_opd, tf.expand_dims(tf_obscurations, axis=0) + ) + + # Compute zernike array from OPD + obsc_coeffs = decompose_tf_obscured_opd_basis( + tf_opd=tf_obscured_opd, + tf_obscurations=tf_obscurations, + tf_zk_basis=tf_zernike_cube, + n_zernike=n_zernikes, + iters=100, + ) + + rmse_error = np.linalg.norm(obsc_coeffs - zk_array[0, :, 0, 0]) + + assert rmse_error < tol diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index a0da107d..5ca5b2bb 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -308,13 +308,16 @@ def train( # If projected learning is enabled project DD_features. 
if psf_model.project_dd_features: # need to change this - psf_model.project_DD_features( - psf_model.zernike_maps - ) # make this a callable function - logger.info("Project non-param DD features onto param model: done!") - if psf_model.reset_dd_features: - psf_model.tf_np_poly_opd.init_vars() - logger.info("DD features reset to random initialisation.") + if current_cycle > 1: + psf_model.project_DD_features( + psf_model.zernike_maps + ) # make this a callable function + logger.info( + "Projected non-parametric DD features onto the parametric model." + ) + if psf_model.reset_dd_features: + psf_model.tf_np_poly_opd.init_vars() + logger.info("DataDriven features were reset to random initialisation.") # Prepare the saving callback # Prepare to save the model as a callback @@ -344,16 +347,16 @@ def train( psf_model, # training data inputs=[ - training_data.train_dataset["positions"], + training_data.dataset["positions"], training_data.sed_data, ], - outputs=training_data.train_dataset["noisy_stars"], + outputs=training_data.dataset["noisy_stars"], validation_data=( [ - test_data.test_dataset["positions"], + test_data.dataset["positions"], test_data.sed_data, ], - test_data.test_dataset["stars"], + test_data.dataset["stars"], ), batch_size=training_handler.training_hparams.batch_size, learning_rate_param=training_handler.learning_rate_params[ @@ -397,13 +400,13 @@ def train( # Save optimisation history in the saving dict if psf_model.save_optim_history_param: - saving_optim_hist["param_cycle{}".format(current_cycle)] = ( - hist_param.history - ) + saving_optim_hist[ + "param_cycle{}".format(current_cycle) + ] = hist_param.history if psf_model.save_optim_history_nonparam: - saving_optim_hist["nonparam_cycle{}".format(current_cycle)] = ( - hist_non_param.history - ) + saving_optim_hist[ + "nonparam_cycle{}".format(current_cycle) + ] = hist_non_param.history # Save last cycle if no cycles were saved if not training_handler.multi_cycle_params.save_all_cycles: diff --git a/src/wf_psf/training/train_utils.py b/src/wf_psf/training/train_utils.py index 6866a59c..8c1036ac 100644 --- a/src/wf_psf/training/train_utils.py +++ b/src/wf_psf/training/train_utils.py @@ -1,6 +1,6 @@ import numpy as np import tensorflow as tf -from wf_psf.psf_models.tf_psf_field import build_PSF_model +from wf_psf.psf_models.psf_models import build_PSF_model from wf_psf.utils.utils import NoiseEstimator import logging diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 3ee7aaf4..466b6cc4 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -9,7 +9,7 @@ import numpy as np from wf_psf.utils.read_config import read_conf -from wf_psf.data.training_preprocessing import TrainingDataHandler, TestDataHandler +from wf_psf.data.training_preprocessing import DataHandler from wf_psf.training import train from wf_psf.psf_models import psf_models from wf_psf.metrics.metrics_interface import evaluate_model @@ -50,10 +50,10 @@ def register_configclass(config_class): def set_run_config(config_name): - """Set Config Class. + """Set Run Configuration Class. - A function to select the class of - a configuration from CONFIG_CLASS dictionary. + A function to retrieve the appropriate configuration + class based on the provided config name. 
Parameters
 ----------
 
@@ -70,26 +70,25 @@
     config_id = [id for id in CONFIG_CLASS.keys() if re.search(id, config_name)][0]
         config_class = CONFIG_CLASS[config_id]
     except KeyError as e:
-        logger.exception("Config name entered is invalid. Check your config settings.")
+        logger.exception("Invalid config name. Check your config settings.")
         exit()
 
     return config_class
 
 
-def get_run_config(run_config, config_params, file_handler):
-    """Get Run Configuration.
+def get_run_config(run_config_name, *config_params):
+    """Get Run Configuration Instance.
 
-    A function to get the configuration
-    for a wf-psf run.
+    A function to retrieve an instance of
+    the appropriate configuration class for
+    a WF-PSF run.
 
     Parameters
    ----------
-    run_config: str
-        Name of the type of run configuraton
-    config_params: str
-        Path of the run configuration file
-    file_handler: object
-        A class instance of FileIOHandler
+    run_config_name: str
+        Name of the run configuration
+    *config_params: str
+        Run configuration parameters used for class instantiation.
 
     Returns
     -------
@@ -97,9 +96,9 @@
     A class instance of the selected configuration class.
 
     """
-    config_class = set_run_config(run_config)
+    config_class = set_run_config(run_config_name)
 
-    return config_class(config_params, file_handler)
+    return config_class(*config_params)
 
 
 class ConfigParameterError(Exception):
@@ -136,13 +135,15 @@
         exit()
 
         self.simPSF = psf_models.simPSF(training_model_params)
-        self.training_data = TrainingDataHandler(
-            self.data_conf.data.training,
+        self.training_data = DataHandler(
+            "training",
+            self.data_conf.data,
             self.simPSF,
             training_model_params.n_bins_lda,
         )
-        self.test_data = TestDataHandler(
-            self.data_conf.data.test,
+        self.test_data = DataHandler(
+            "test",
+            self.data_conf.data,
             self.simPSF,
             training_model_params.n_bins_lda,
         )
@@ -197,6 +198,7 @@ def run(self):
         input configuration.
""" + train.train( self.training_conf.training, self.data_conf.training_data, @@ -413,9 +415,9 @@ def call_plot_config_handler_run(self, model_metrics): ) # Update metrics_confs dict with latest result - plots_config_handler.metrics_confs[self._file_handler.workdir] = ( - self.metrics_conf - ) + plots_config_handler.metrics_confs[ + self._file_handler.workdir + ] = self.metrics_conf # Update metric results dict with latest result plots_config_handler.list_of_metrics_dict[self._file_handler.workdir] = [ @@ -441,8 +443,7 @@ def run(self): model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, - self.data_conf.training_data, - self.data_conf.test_data, + self.data_conf, self.psf_model, self.weights_path, self.metrics_dir, diff --git a/src/wf_psf/utils/io.py b/src/wf_psf/utils/io.py index bb822448..b49c9a47 100644 --- a/src/wf_psf/utils/io.py +++ b/src/wf_psf/utils/io.py @@ -135,7 +135,6 @@ def _setup_logging(self): self._log_files, logfile, ) - logging.config.fileConfig( os.path.join(self.repodir_path, "config/logging.conf"), defaults={"filename": logfile}, diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index dbb0328d..a1ed0bb0 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -6,7 +6,7 @@ try: from cv2 import resize, INTER_AREA -except: +except ModuleNotFoundError: print("Problem importing opencv..") import sys @@ -128,7 +128,7 @@ def calc_poly_position_mat(pos, x_lims, y_lims, d_max): poly_list = [] for d in range(d_max + 1): - row_idx = d * (d + 1) // 2 + # row_idx = d * (d + 1) // 2 for p in range(d + 1): poly_list.append(scaled_pos_x ** (d - p) * scaled_pos_y**p) @@ -485,14 +485,87 @@ def load_multi_cycle_params_click(args): return args -def PI_zernikes(tf_z1, tf_z2, norm_factor=None): - """Compute internal product between zernikes and OPDs +def compute_unobscured_zernike_projection(tf_z1, tf_z2, norm_factor=None): + """Compute a zernike projection for unobscured wavefronts (OPDs). - Defined such that Zernikes are orthonormal to each other + Compute internal product between zernikes and OPDs. - First one should compute: norm_factor = PI_zernikes(tf_zernike,tf_zernike) - for futur calls: PI_zernikes(OPD,tf_zernike_k, norm_factor) + Defined such that Zernikes are orthonormal to each other. + + First one should compute: norm_factor = unobscured_zernike_projection(tf_zernike,tf_zernike) + for futur calls: unobscured_zernike_projection(OPD,tf_zernike_k, norm_factor) + + If the OPD has obscurations, or is not an unobscured circular aperture, + the Zernike polynomials are no longer orthonormal. Therefore, you should consider + using the function `tf_decompose_obscured_opd_basis` that takes into account the + obscurations in the projection. """ if norm_factor is None: norm_factor = 1 return np.sum((tf.math.multiply(tf_z1, tf_z2)).numpy()) / (norm_factor) + + +def decompose_tf_obscured_opd_basis( + tf_opd, tf_obscurations, tf_zk_basis, n_zernike, iters=20 +): + """Decompose obscured OPD into a basis using an iterative algorithm. + + Tensorflow implementation. + + Parameters + ---------- + tf_opd : tf.Tensor + Input OPD that requires to be decomposed on `tf_zk_basis`. The tensor shape is (opd_dim, opd_dim). + tf_obscurations : tf.Tensor + Tensor with the obscuration map. The tensor shape is (opd_dim, opd_dim). + tf_zk_basis : tf.Tensor + Zernike polynomial maps. The tensor shape is (n_batch, opd_dim, opd_dim) + n_zernike : int + Number of Zernike polynomials to project on. 
+def decompose_tf_obscured_opd_basis(
+    tf_opd, tf_obscurations, tf_zk_basis, n_zernike, iters=20
+):
+    """Decompose an obscured OPD into a basis using an iterative algorithm.
+
+    Tensorflow implementation.
+
+    Parameters
+    ----------
+    tf_opd : tf.Tensor
+        Input OPD to decompose on `tf_zk_basis`. The tensor shape is (opd_dim, opd_dim).
+    tf_obscurations : tf.Tensor
+        Tensor with the obscuration map. The tensor shape is (opd_dim, opd_dim).
+    tf_zk_basis : tf.Tensor
+        Zernike polynomial maps. The tensor shape is (n_batch, opd_dim, opd_dim).
+    n_zernike : int
+        Number of Zernike polynomials to project on.
+    iters : int
+        Number of iterations of the algorithm.
+
+    Returns
+    -------
+    obsc_coeffs : np.ndarray
+        Array of size `n_zernike` with the projected Zernike coefficients.
+
+    Raises
+    ------
+    ValueError
+        If `n_zernike` is greater than tf_zk_basis.shape[0].
+
+    """
+    if n_zernike > tf_zk_basis.shape[0]:
+        raise ValueError(
+            "Number of Zernike polynomials to project (n_zernike) exceeds the available Zernike elements in the provided basis (tf_zk_basis). Please ensure that n_zernike is less than or equal to the number of Zernike elements in tf_zk_basis."
+        )
+    # Clone input OPD
+    input_tf_opd = tf.identity(tf_opd)
+    # Clone obscurations and keep the real part
+    input_tf_obscurations = tf.math.real(tf.identity(tf_obscurations))
+    # Compute normalisation factor
+    ngood = tf.math.reduce_sum(input_tf_obscurations, axis=None, keepdims=False).numpy()
+
+    obsc_coeffs = np.zeros(n_zernike)
+    new_coeffs = np.zeros(n_zernike)
+
+    for _ in range(iters):
+        # Project the residual OPD onto each basis element
+        for i, b in enumerate(tf_zk_basis):
+            this_coeff = (
+                tf.math.reduce_sum(
+                    tf.math.multiply(input_tf_opd, b), axis=None, keepdims=False
+                ).numpy()
+                / ngood
+            )
+            new_coeffs[i] = this_coeff
+
+        # Subtract the newly explained (obscured) component from the residual
+        for i, b in enumerate(tf_zk_basis):
+            input_tf_opd = input_tf_opd - tf.math.multiply(
+                new_coeffs[i] * b, input_tf_obscurations
+            )
+
+        obsc_coeffs += new_coeffs
+        new_coeffs = np.zeros(n_zernike)
+
+    return obsc_coeffs
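+
+
+# A minimal usage sketch, mirroring the unit test in utils_test.py; the names
+# `tf_obscured_opd`, `tf_obscurations` and `tf_zernike_cube` are assumed to be
+# tensors of shapes (opd_dim, opd_dim), (opd_dim, opd_dim) and
+# (n_zernikes, opd_dim, opd_dim), respectively:
+#
+#     obsc_coeffs = decompose_tf_obscured_opd_basis(
+#         tf_opd=tf_obscured_opd,
+#         tf_obscurations=tf_obscurations,
+#         tf_zk_basis=tf_zernike_cube,
+#         n_zernike=n_zernikes,
+#         iters=100,
+#     )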