From 076fdd3bd6c57735ff5ef2cdb8884f74e9907ba9 Mon Sep 17 00:00:00 2001 From: cristianrubioa Date: Thu, 1 Feb 2024 22:07:34 -0500 Subject: [PATCH] :hammer: Update frame_image_processor code --- src/frame_image_processor.py | 478 +++++++++++++++++++++++++++-------- 1 file changed, 368 insertions(+), 110 deletions(-) diff --git a/src/frame_image_processor.py b/src/frame_image_processor.py index a8a2ce0..27690b2 100644 --- a/src/frame_image_processor.py +++ b/src/frame_image_processor.py @@ -38,7 +38,21 @@ } -def data_for_line_plot(): +def data_for_line_plot(json_data_filename: str, vis_params: dict) -> dict: + """ + Extract and organize data from a JSON file for visualization in a line plot. + + Args: + - json_data_filename: str + File path to the JSON dataset file. + + - vis_params: dict + Dictionary containing visualization parameters. + + Return: + dict: Updated "vis_params" dictionary containing extracted data + from JSON dataset. + """ # Data for line plot json_keys = [] dates = [] @@ -46,18 +60,17 @@ def data_for_line_plot(): snow_cover_percentages = [] cloud_presence = [] - JSON_DATA_FILENAME = os.path.join( - settings.IMAGES_DATASET.DATASET_PATH, - settings.IMAGES_DATASET.DATASET_METADATA_FILE_TAGS, - ) - # Read JSON data from file - with open(JSON_DATA_FILENAME) as file: + with open(json_data_filename) as file: json_data = json.load(file) + # Read JSON dataset values for key, values in json_data.items(): - json_keys.append(key) + # Get date from JSON key date_str = key.split("_")[0] + + # Add data to memory + json_keys.append(key) dates.append(datetime.strptime(date_str, "%Y%m%d")) temperatures.append(values["temperature_roi"]) snow_cover_percentages.append(values["snow_cover_per"]) @@ -66,8 +79,8 @@ def data_for_line_plot(): my_dict = { "date": { "all": dates, - "start": dates[0] - timedelta(days=90), - "end": dates[-1] + timedelta(days=90), + "min": dates[0] - timedelta(days=90), + "max": dates[-1] + timedelta(days=90), }, "temperature": { "all": temperatures, @@ -86,58 +99,242 @@ def data_for_line_plot(): "all": json_keys, }, } - VISUALIZATION_PARAMS["JSON_DATA"] = my_dict - return VISUALIZATION_PARAMS + # Add read data to visualization params + vis_params["JSON_DATA"] = my_dict + + return vis_params + + +def read_landsat_band(image_path, normalize=False, fill_value=0): + """ + Reads a Landsat band from the given file path and returns its raster data + as a numpy array and metadata as a dictionary. It can optionally normalize + the band data. + + Args: + - image_path: str + Image path and filename to read + + - normalize: bool + Whether to normalize band values or not + + - fill_value: int/float + Invalid band value to fill with np.nan + + Return: + - tuple: returns the raster image and the metadata + """ + # Read GeoTIF band + with rasterio.open(image_path) as src: + # Read raster data + raster = src.read().astype(np.float32) + raster[raster == fill_value] = np.nan + + # Read Metadata + metadata = src.meta + + # Scaled values in range [0, 1] + if normalize: + raster -= np.nanmin(raster) + raster /= np.nanmax(raster) - np.nanmin(raster) + + return raster.squeeze(), metadata + + +def compute_true_color(band_paths, normalize=True, fill_value=0): + """ + Creates a True Color Image (TCI) from RGB corresponding satelite bands + + Args: + - band_paths: list(str) + Path to band files in RGB sequence + + - normalize: bool + Whether to scale values in range [0, 1] + """ + # RGB images data + image_true_color = [] + + for path in band_paths: + # Read band data + image, meta = read_landsat_band( + image_path=path, normalize=normalize, fill_value=fill_value + ) + + image_true_color.append(image) + + return np.stack(image_true_color, axis=-1) + + +def compute_landsat_temperature(image_path, band_factors, celcius=True, fill_value=0): + """ + Compute temperature values from Landsat 8/9 Collection-2 Level-2 data. + + Reads the surface temperature band data from the specified Landsat thermal + band image and applies the provided scale factors and offset to convert + the data to temperature values in Kelvin. Optionally, the values can be + further converted to Celsius. + + Args: + - image_path: str + Path to Landsat thermal band. + + - band_factors: dict + Landsat 8/9 Level-2 product scale factors and offset. + Must have "SCALE_FACTOR" and "ADDITIVE_OFFSET" keys. + + - celcius: bool, optional + Whether to convert temperature values to Celsius. + Defaults to True. + + - fill_value: int/float, optional + Invalid band value to fill with np.nan. + Defaults to 0. + + Returns: + - ndarray: Array of temperature values in °K or °C, based on the + "celcius" argument. + """ + + # Read surface temperature band data + st_band, meta = read_landsat_band( + image_path=image_path, normalize=False, fill_value=fill_value + ) + + # Get scale and offset factors + scale_factor = band_factors["SCALE_FACTOR"] + offset = band_factors["ADDITIVE_OFFSET"] + + # Convert uint to floating point values in °K + temperature = (st_band * scale_factor) + offset + + # Convert values to °C + if celcius: + temperature -= settings.IMAGES_DATASET.CELCIUS_SCALER_FACTOR + + return temperature + + +def compute_ndsi(green_path, swir_path, fill_value=0): + """ + Compute Normalized Differenced Snow Index (NDSI) using Landsat green and + shortwave infrared bands. The result is an array of NDSI values, where + invalid divisions are ignored. + + Args: + - green_path: str + Path to Landsat green band + + - swir_path: str + Path to Landsat shortwave infrared band + + - fill_value: int/float, optional + Invalid band value to fill with np.nan. Defaults to 0. + + Returns: + - ndarray: Array of NDSI values. + """ + # Read green band data + green, meta = read_landsat_band( + image_path=green_path, normalize=False, fill_value=fill_value + ) + + # Read shortwave infrared band data + swir, meta = read_landsat_band( + image_path=swir_path, normalize=False, fill_value=fill_value + ) + + # Compute NDSI (ignoring invalid divisions) + np.seterr(divide="ignore", invalid="ignore") + return (green - swir) / (green + swir) def get_images_to_show(image_filename): - # Bands information in memory - bands_data = {} - - # Read needed bands - band_names = ["SR_B4", "SR_B3", "SR_B2", "SR_B6", "ST_B10"] - for band_name in band_names: - # Get band path - path = os.path.join( + """ + Generate and return various processed images for visualization. + + Parameters: + - image_filename: str + Base filename format for Landsat image bands. + + Returns: + dict: A dictionary containing different processed images: + + - "Color": True Color Image (TCI) computed from Red, Green, + and Blue bands. + + - "Temperature": Temperature image computed from the thermal band. + + - "NDSI": Normalized Difference Snow Index (NDSI) computed from + Green and Shortwave Infrared bands. + """ + # True Color Image (TCI) + # --------------------------------------------------------------------------- + # Get paths for RGB image + true_color_paths = [ + os.path.join( settings.IMAGES_DATASET.ROI_CROPPED_DATASET_PATH, image_filename.format(band=band_name), ) - # Read image -> read_landsat_band() - with rasterio.open(path) as src: - image = src.read().astype(np.float32) - image[image == 0] = np.nan - - # Save band in memory - bands_data[band_name] = image.squeeze() - - red = bands_data["SR_B4"] - green = bands_data["SR_B3"] - blue = bands_data["SR_B2"] - swir1 = bands_data["SR_B6"] - st_b10 = bands_data["ST_B10"] - - # Stack bands read in RGB sequence and normalize - nan_mask = np.logical_or(~np.isfinite(red), ~np.isfinite(green), ~np.isfinite(blue)) - image = np.stack([red, green, blue], axis=-1) - image[nan_mask, 0] = np.nan - image[nan_mask, 1] = np.nan - image[nan_mask, 2] = np.nan - image_true_color = (image - np.nanmin(image)) / ( - np.nanmax(image) - np.nanmin(image) + for band_name in ["SR_B4", "SR_B3", "SR_B2"] + ] + + # Compute TCI + image_true_color = compute_true_color( + band_paths=true_color_paths, normalize=True, fill_value=0 ) - image_true_color = np.nan_to_num(image_true_color, nan=1) - # Convert int values to °K and then convert to °C - image_temperature = ( - (st_b10 * settings.IMAGES_DATASET.L2SP_TEMPERATURE_SCALE_FACTOR) - + settings.IMAGES_DATASET.L2SP_TEMPERATURE_ADDITIVE_OFFSET - - settings.IMAGES_DATASET.CELCIUS_SCALER_FACTOR + # NaN mask for RGB bands + nan_mask = np.logical_or( + ~np.isfinite(image_true_color[:, :, 0]), + ~np.isfinite(image_true_color[:, :, 1]), + ~np.isfinite(image_true_color[:, :, 2]), + ) + + # Fill RGB bands based on mask + image_true_color[nan_mask, 0] = np.nan + image_true_color[nan_mask, 1] = np.nan + image_true_color[nan_mask, 2] = np.nan + # --------------------------------------------------------------------------- + + # Temperature image + # --------------------------------------------------------------------------- + # Get path for thermal band + thermal_path = os.path.join( + settings.IMAGES_DATASET.ROI_CROPPED_DATASET_PATH, + image_filename.format(band="ST_B10"), + ) + + # Compute the temperature in °C + image_temperature = compute_landsat_temperature( + image_path=thermal_path, + band_factors={ + "SCALE_FACTOR": settings.IMAGES_DATASET.L2SP_TEMPERATURE_SCALE_FACTOR, + "ADDITIVE_OFFSET": settings.IMAGES_DATASET.L2SP_TEMPERATURE_ADDITIVE_OFFSET, + }, + celcius=True, + fill_value=0, + ) + # --------------------------------------------------------------------------- + + # NDSI image + # --------------------------------------------------------------------------- + # Get green band path + green_path = os.path.join( + settings.IMAGES_DATASET.ROI_CROPPED_DATASET_PATH, + image_filename.format(band="SR_B3"), + ) + + # Get shortwave infrared band path + swir_path = os.path.join( + settings.IMAGES_DATASET.ROI_CROPPED_DATASET_PATH, + image_filename.format(band="SR_B6"), ) # Compute NDSI - np.seterr(divide="ignore", invalid="ignore") - image_ndsi = (green - swir1) / (green + swir1) + image_ndsi = compute_ndsi(green_path=green_path, swir_path=swir_path, fill_value=0) + # --------------------------------------------------------------------------- return { "Color": image_true_color, @@ -146,19 +343,47 @@ def get_images_to_show(image_filename): } -def display_images(images, axes, figure): +def display_images(images, vis_params, axes, figure): + """ + Display multiple images on specified axes with given visualization + parameters. + + Args: + - images: dict + A dictionary containing image names as keys and corresponding + image data as values. Must be three. + + - vis_params: dict + Visualization parameters for each image, including: + axis number, min, max, cmap, and title. The primary keys + must match with "images" argument keys. + + - axes: list + List of matplotlib axes where images will be displayed. + + - figure: matplotlib.pyplot.figure + Current matplotlib figure object where the images are + displayed + + Raises: + - ValueError: If the number of images does not match the number of axes. + + Returns: + None + """ + # Check input args if len(images) != len(axes): error_message = "Number of images must be the same of axes." raise ValueError(error_message) - # Show images + # Display images for image_name, image in images.items(): # Get visualization parameters - axis_num = VISUALIZATION_PARAMS[image_name]["axis"] - vmin = VISUALIZATION_PARAMS[image_name]["min"] - vmax = VISUALIZATION_PARAMS[image_name]["max"] - cmap = VISUALIZATION_PARAMS[image_name]["cmap"] - title = VISUALIZATION_PARAMS[image_name]["title"] + axis_num = vis_params[image_name]["axis"] + vmin = vis_params[image_name]["min"] + vmax = vis_params[image_name]["max"] + cmap = vis_params[image_name]["cmap"] + title = vis_params[image_name]["title"] if (vmin is None) or (vmax is None): vmin = np.nanmin(image) @@ -177,46 +402,73 @@ def display_images(images, axes, figure): figure.colorbar(image_plot, ax=ax, orientation="vertical") -def display_line_plot(index, visualization_params, ax=None): - # Config plot axis +def display_line_plot_by_index(index, json_data, ax=None): + """ + Display a line plot with temperature, snow cover percentage, and cloud + presence up to a specified index. + + Args: + - index: int + Index up to which data will be displayed on the line plot. + + - json_data: dict + Dictionary containing historical temperature, + snow cover, date, and cloud presence data. + + Each primary value is a dictionary with a mandatory + key: "all" and two optional keys: "min", "max" + + - ax: matplotlib.axes.Axes, optional + Matplotlib axes for plotting. If None, a new subplot is created. + + Returns: + None + """ + # Get axis to plot if ax is None: fig, ax1 = plt.subplots() else: ax1 = ax - # Temperature plot (left y-axis) + # Get plot data from JSON + temperature_data = json_data["temperature"] + date_data = json_data["date"] + snow_data = json_data["snow"] + cloud_data = json_data["cloud"] + + # Create temperature line plot (left y-axis) + (temp_line,) = ax1.plot( + date_data["all"][0 : index + 1], + temperature_data["all"][0 : index + 1], + color="tab:red", + label="Temperature", + marker="o", + ) + + # Config temperature line plot ax1.set_xlabel("Date") - ax1.set_ylabel("Mean temperature (°C)", color="tab:red") + ax1.set_ylabel("Mean temperature [°C]", color="tab:red") ax1.tick_params(axis="y", labelcolor="tab:red") ax1.set_ylim( - bottom=visualization_params["JSON_DATA"]["temperature"]["min"], - top=visualization_params["JSON_DATA"]["temperature"]["max"], + bottom=temperature_data["min"], + top=temperature_data["max"], ) ax1.set_xlim( - visualization_params["JSON_DATA"]["date"]["start"], - visualization_params["JSON_DATA"]["date"]["end"], - ) - (temp_line,) = ax1.plot( - visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1], - visualization_params["JSON_DATA"]["temperature"]["all"][0 : index + 1], - color="tab:red", - label="Temperature", - marker="o", + left=date_data["min"], + right=date_data["max"], ) # Add trendline for temperature slope_temp, intercept_temp, _, _, _ = linregress( - range(len(visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1])), - visualization_params["JSON_DATA"]["temperature"]["all"][0 : index + 1], + range(len(date_data["all"][0 : index + 1])), + temperature_data["all"][0 : index + 1], ) trendline_temp = [ slope_temp * i + intercept_temp - for i in range( - len(visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1]) - ) + for i in range(len(date_data["all"][0 : index + 1])) ] (trendline1,) = ax1.plot( - visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1], + date_data["all"][0 : index + 1], trendline_temp, linestyle="dashed", color="tab:red", @@ -225,33 +477,33 @@ def display_line_plot(index, visualization_params, ax=None): # Snow cover percentage plot (right y-axis) ax2 = ax1.twinx() - ax2.set_ylabel("Snow Cover (%)", color="tab:blue") - ax2.tick_params(axis="y", labelcolor="tab:blue") - ax2.set_ylim( - bottom=visualization_params["JSON_DATA"]["snow"]["min"], - top=visualization_params["JSON_DATA"]["snow"]["max"], - ) (snow_line,) = ax2.plot( - visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1], - visualization_params["JSON_DATA"]["snow"]["all"][0 : index + 1], + date_data["all"][0 : index + 1], + snow_data["all"][0 : index + 1], color="tab:blue", label="Snow Cover", marker="o", ) + # Config snow cover line plot + ax2.set_ylabel("Snow Cover [%]", color="tab:blue") + ax2.tick_params(axis="y", labelcolor="tab:blue") + ax2.set_ylim( + bottom=snow_data["min"], + top=snow_data["max"], + ) + # Add trendline for snow cover percentage slope_snow, intercept_snow, _, _, _ = linregress( - range(len(visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1])), - visualization_params["JSON_DATA"]["snow"]["all"][0 : index + 1], + range(len(date_data["all"][0 : index + 1])), + snow_data["all"][0 : index + 1], ) trendline_snow = [ slope_snow * i + intercept_snow - for i in range( - len(visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1]) - ) + for i in range(len(date_data["all"][0 : index + 1])) ] (trendline2,) = ax2.plot( - visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1], + date_data["all"][0 : index + 1], trendline_snow, linestyle="dashed", color="tab:blue", @@ -259,12 +511,10 @@ def display_line_plot(index, visualization_params, ax=None): ) # Vertical dotted line for cloud presence - for i, cloud in enumerate( - visualization_params["JSON_DATA"]["cloud"]["all"][0 : index + 1] - ): + for i, cloud in enumerate(cloud_data["all"][0 : index + 1]): if cloud: ax1.axvline( - x=visualization_params["JSON_DATA"]["date"]["all"][0 : index + 1][i], + x=date_data["all"][0 : index + 1][i], linestyle="dotted", color="gray", ) @@ -293,9 +543,17 @@ def display_line_plot(index, visualization_params, ax=None): def main_image_frame_visualization(): - visualization_params = data_for_line_plot() - for index in range(len(visualization_params["JSON_DATA"]["date"]["all"])): - fig = plt.figure(figsize=(11, 8)) + json_data_filename = os.path.join( + settings.IMAGES_DATASET.DATASET_PATH, + settings.IMAGES_DATASET.DATASET_METADATA_FILE_TAGS, + ) + visualization_params = data_for_line_plot(json_data_filename, VISUALIZATION_PARAMS) + json_data = visualization_params["JSON_DATA"] + key_list = json_data["key"]["all"] + date_values = json_data["date"]["all"] + total_samples = len(key_list) + for index in range(total_samples): + fig = plt.figure(figsize=(12, 10)) gs = GridSpec(2, 3, figure=fig) # create sub plots as grid @@ -304,27 +562,27 @@ def main_image_frame_visualization(): ax3 = fig.add_subplot(gs[0, 2]) ax4 = fig.add_subplot(gs[1, :]) - image_filename = ( - visualization_params["JSON_DATA"]["key"]["all"][index] - + "_{band}_CROPPED.TIF" - ) + image_filename = key_list[index] + "_{band}_CROPPED.TIF" images_dict = get_images_to_show(image_filename) # Display images at row 0 - display_images(images=images_dict, axes=[ax1, ax2, ax3], figure=fig) + display_images( + images=images_dict, + vis_params=visualization_params, + axes=[ax1, ax2, ax3], + figure=fig, + ) # Display line plot - display_line_plot(index, visualization_params, ax=ax4) - subtitle_date = visualization_params["JSON_DATA"]["date"]["all"][ - index - ].strftime("%B %d, %Y") - plt.suptitle( + display_line_plot_by_index(index, json_data, ax=ax4) + subtitle_date = date_values[index].strftime("%B %d, %Y") + ax4.set_title( subtitle_date, fontsize=14, fontweight="bold", color="white", # x=0.05, - y=0.98, + y=1.05, bbox={ "facecolor": "#ff5555", "edgecolor": "gray",