diff --git a/presto/dataset.py b/presto/dataset.py
index affe16f..a411c2c 100644
--- a/presto/dataset.py
+++ b/presto/dataset.py
@@ -143,7 +143,7 @@ def row_to_arrays(
         timestep_positions = cls.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl)

         if cls.NUM_TIMESTEPS == 12:
-            initial_start_date_position = pd.to_datetime(row_d["start_date"]).month
+            initial_start_date_position = datetime.strptime(row_d["start_date"], "%Y-%m-%d").month
         elif cls.NUM_TIMESTEPS > 12:
             # get the correct index of the start_date based on NUM_TIMESTEPS
             # e.g. if NUM_TIMESTEPS is 36 (dekadal setup), we should take the correct
@@ -151,6 +151,7 @@ def row_to_arrays(
             # TODO: 1) this needs to go into a separate function
             #       2) definition of valid_position and timestep_ind
             #          should also be changed accordingly
+            year = pd.to_datetime(row_d["start_date"]).year
             year_dates = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31")
             bins = pd.cut(year_dates, bins=cls.NUM_TIMESTEPS, labels=False)

diff --git a/presto/utils.py b/presto/utils.py
index 8715b56..25d90ec 100644
--- a/presto/utils.py
+++ b/presto/utils.py
@@ -99,42 +99,43 @@ def get_class_mappings() -> Dict:
     return CLASS_MAPPINGS


-def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
+def process_parquet(
+    df: pd.DataFrame,
+    use_valid_time: bool = True,
+    num_timesteps: int = NUM_TIMESTEPS,
+    min_edge_buffer: int = MIN_EDGE_BUFFER,
+) -> pd.DataFrame:
     """
     This function takes in a DataFrame with S1, S2 and ERA5 observations and their
     respective dates in long format and returns it in wide format.

     Each row of the input DataFrame should represent a unique combination
-    of sample_id and timestamp, also containing start_date and valid_date columns.
-    start_date is the first date of the timeseries.
-    valid_date is the date for which the crop of the sample is valid
-    (prefrerably it is located around
-    the center of the agricultural season, but not necessarily).
-    timestamp is the date of the observation.
+    of sample_id and timestamp, where timestamp is the date of the observation.
     This function performs the following operations:
-    - computing relative position of the timestamp (timestamp_ind variable)
-      and valid_date (valid_position variable) in the timeseries;
-    - filtering out samples were valid date is outside the range of the actual extractions
-    - adding dummy timesteps filled with NODATA values before the start_date or after
-      the end_date for samples where valid_date is close to the edge of the timeseries;
-      this closeness is defined by the globally defined parameter MIN_EDGE_BUFFER
-    - reinitializing the start_date, end_date and timestamp_ind to take into account
-      newly added timesteps
-    - checking for missing timesteps in the middle of the timeseries and adding them
-      with NODATA values
-    - pivoting the DataFrame to wide format with columns for each band
+    - initializing the start_date and end_date as the first and last available observation;
+    - computing relative position of the timestamp (timestamp_ind variable) in the timeseries;
+    - checking for missing timesteps in the middle of the timeseries
+      and filling them with NODATAVALUE
+    - pivoting the DataFrame to wide format with columns for each feature column
       and timesteps as suffixes
     - assigning the correct suffixes to the band names
-    - computing the final valid_date position in the timeseries that takes
-      into account updated start_date
-    - computing the number of available timesteps in the timeseries that
-      takes into account updated start_date and end_date; available_timesteps
-      holds the absolute number of timesteps that for which observations are
-      available; it cannot be less than NUM_TIMESTEPS; if this is the case,
+    - computing the number of available timesteps in the timeseries;
+      it represents the absolute number of timesteps for which observations are
+      available; it cannot be less than num_timesteps; if this is the case,
       sample is considered faulty and is removed from the dataset
     - post-processing with prep_dataframe function

+    Args:
+        df (pd.DataFrame): Input dataframe containing EO data and the following required
+            attributes: ["sample_id", "timestamp", "DEM-alt-20m", "DEM-slo-20m", "lat", "lon"].
+        use_valid_time (bool): If True, the function will use the valid_time column to check
+            whether valid_time lies within the range of available observations,
+            with a buffer of min_edge_buffer timesteps.
+            Samples where this is not the case are removed from the dataset.
+            If False, the function will not use the valid_time column
+            and will not perform this check.
+
     Returns
     -------
     pd.DataFrame
@@ -142,14 +143,18 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:

     Raises
     ------
+    AttributeError
+        error is raised if DataFrame does not contain the required columns
     ValueError
        error is raised if pivot results in an empty DataFrame
     """

-    # add dummy value + rename stuff for compatibility with existing functions
-    df["OPTICAL-B8A"] = NODATAVALUE
+    static_features = ["DEM-alt-20m", "DEM-slo-20m", "lat", "lon"]
+    required_columns = ["sample_id", "timestamp"] + static_features
+    if not all([col in df.columns for col in required_columns]):
+        missing_columns = [col for col in required_columns if col not in df.columns]
+        raise AttributeError(f"DataFrame must contain the following columns: {missing_columns}")

-    # TODO: this needs to go away once the transition to new data is complete
     df.rename(
         columns={
             "S1-SIGMA0-VV": "SAR-VV",
@@ -165,45 +170,13 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
             "S2-L2A-B12": "OPTICAL-B12",
             "AGERA5-precipitation-flux": "METEO-precipitation_flux",
             "AGERA5-temperature-mean": "METEO-temperature_mean",
+            # since the openEO output has the attribute "valid_time",
+            # we need the following line for compatibility with earlier datasets
+            "valid_date": "valid_time",
         },
         inplace=True,
     )

-    # should these definitions be here? or better in the dataops.py?
-    feature_columns = [
-        "METEO-precipitation_flux",
-        "METEO-temperature_mean",
-        "SAR-VH",
-        "SAR-VV",
-        "OPTICAL-B02",
-        "OPTICAL-B03",
-        "OPTICAL-B04",
-        "OPTICAL-B08",
-        "OPTICAL-B8A",
-        "OPTICAL-B05",
-        "OPTICAL-B06",
-        "OPTICAL-B07",
-        "OPTICAL-B11",
-        "OPTICAL-B12",
-    ]
-    index_columns = [
-        "CROPTYPE_LABEL",
-        "DEM-alt-20m",
-        "DEM-slo-20m",
-        "LANDCOVER_LABEL",
-        "POTAPOV-LABEL-10m",
-        "WORLDCOVER-LABEL-10m",
-        "aez_zoneid",
-        "end_date",
-        "lat",
-        "lon",
-        "start_date",
-        "sample_id",
-        "valid_date",
-        "location_id",
-        "ref_id",
-    ]
-
     bands10m = ["OPTICAL-B02", "OPTICAL-B03", "OPTICAL-B04", "OPTICAL-B08"]
     bands20m = [
         "SAR-VH",
@@ -217,103 +190,106 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
     ]
     bands100m = ["METEO-precipitation_flux", "METEO-temperature_mean"]

-    df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - (
-        df["start_date"].dt.year * 12 + df["start_date"].dt.month
-    )
-    df["valid_position"] = (df["valid_date"].dt.year * 12 + df["valid_date"].dt.month) - (
-        df["start_date"].dt.year * 12 + df["start_date"].dt.month
-    )
-    df["valid_position_diff"] = df["timestamp_ind"] - df["valid_position"]
-
-    # save the initial start_date for later
-    df["initial_start_date"] = df["start_date"].copy()
-    index_columns.append("initial_start_date")
-
-    # define samples where valid_date is outside the range of the actual extractions
-    # and remove them from the dataset
-    latest_obs_position = df.groupby(["sample_id"])[
-        ["valid_position", "timestamp_ind", "valid_position_diff"]
-    ].max()
-    df["is_last_available_ts"] = (
-        df["sample_id"].map(latest_obs_position["timestamp_ind"]) == df["timestamp_ind"]
-    )
-    samples_after_end_date = latest_obs_position[
-        (latest_obs_position["valid_position"] > latest_obs_position["timestamp_ind"])
-    ].index
-    samples_before_start_date = latest_obs_position[
-        (latest_obs_position["valid_position"] < 0)
-    ].index
-
-    if len(samples_after_end_date) > 0 or len(samples_before_start_date) > 0:
-        logger.warning(
-            f"""\
-Dataset {df["ref_id"].iloc[0]}: removing {len(samples_after_end_date)} \
-samples with valid_date after the end_date \
-and {len(samples_before_start_date)} samples with valid_date
-before the start_date"""
+    feature_columns = bands10m + bands20m + bands100m
+    # for index columns we need to include all columns that are not feature columns
+    index_columns = [col for col in df.columns if col not in feature_columns]
+    index_columns.remove("timestamp")
+
+    # check that all feature columns are present in the DataFrame
+    # or initialize them with NODATAVALUE
+    for feature_col in feature_columns:
+        if feature_col not in df.columns:
+            df[feature_col] = NODATAVALUE
+
+    df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int) - 1
+
+    # Assign start_date and end_date as the minimum and maximum available timestamp
+    df["start_date"] = df["sample_id"].map(df.groupby(["sample_id"])["timestamp"].min())
+    df["end_date"] = df["sample_id"].map(df.groupby(["sample_id"])["timestamp"].max())
+    index_columns.extend(["start_date", "end_date"])
+
+    if use_valid_time:
+        df["valid_time_ts_diff_days"] = (df["valid_time"] - df["timestamp"]).dt.days.abs()
+        valid_position = (
+            df.set_index("timestamp_ind").groupby("sample_id")["valid_time_ts_diff_days"].idxmin()
+        )
+        df["valid_position"] = df["sample_id"].map(valid_position)
+        index_columns.append("valid_position")
+
+        df["valid_position_diff"] = df["timestamp_ind"] - df["valid_position"]
+
+        # define samples where valid_time is outside the range of the actual extractions
+        # and remove them from the dataset
+        latest_obs_position = df.groupby(["sample_id"])[
+            ["valid_position", "timestamp_ind", "valid_position_diff"]
+        ].max()
+        df["is_last_available_ts"] = (
+            df["sample_id"].map(latest_obs_position["timestamp_ind"]) == df["timestamp_ind"]
        )
-    df = df[~df["sample_id"].isin(samples_before_start_date)]
-    df = df[~df["sample_id"].isin(samples_after_end_date)]
-
-    # add timesteps before the start_date where needed
-    intermediate_dummy_df = pd.DataFrame()
-    for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
-        samples_to_add_ts_before_start = latest_obs_position[
-            (MIN_EDGE_BUFFER - latest_obs_position["valid_position"]) >= -n_ts_to_add
+        samples_after_end_date = latest_obs_position[
+            (latest_obs_position["valid_position"] > latest_obs_position["timestamp_ind"])
        ].index
-        dummy_df = df[
-            (df["sample_id"].isin(samples_to_add_ts_before_start)) & (df["timestamp_ind"] == 0)
-        ].copy()
-        dummy_df["timestamp"] = dummy_df["timestamp"] - pd.DateOffset(
-            months=n_ts_to_add
-        )  # type: ignore
-        dummy_df[feature_columns] = NODATAVALUE
-        intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
-    df = pd.concat([df, intermediate_dummy_df])
-
-    # add timesteps after the end_date where needed
-    intermediate_dummy_df = pd.DataFrame()
-    for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
-        samples_to_add_ts_after_end = latest_obs_position[
-            (MIN_EDGE_BUFFER - latest_obs_position["valid_position_diff"]) >= n_ts_to_add
+        samples_before_start_date = latest_obs_position[
+            (latest_obs_position["valid_position"] < 0)
        ].index
-        dummy_df = df[
-            (df["sample_id"].isin(samples_to_add_ts_after_end)) & (df["is_last_available_ts"])
-        ].copy()
-        dummy_df["timestamp"] = dummy_df["timestamp"] + pd.DateOffset(
-            months=n_ts_to_add
-        )  # type: ignore
-        dummy_df[feature_columns] = NODATAVALUE
-        intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
-    df = pd.concat([df, intermediate_dummy_df])
-
-    # Now reassign start_date to the minimum timestamp
-    new_start_date = df.groupby(["sample_id"])["timestamp"].min()
-    df["start_date"] = df["sample_id"].map(new_start_date)
-
-    # Also reassign end_date to the maximum timestamp
-    new_end_date = df.groupby(["sample_id"])["timestamp"].max()
= df.groupby(["sample_id"])["timestamp"].max() - df["end_date"] = df["sample_id"].map(new_end_date) - - # reinitialize timestep_ind - df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - ( - df["start_date"].dt.year * 12 + df["start_date"].dt.month - ) - # check for missing timestamps in the middle of timeseries - # and create corresponding columns with NODATAVALUE - missing_timestamps = [ - xx for xx in range(df["timestamp_ind"].max()) if xx not in df["timestamp_ind"].unique() - ] - present_timestamps = [ - xx for xx in range(df["timestamp_ind"].max()) if xx not in missing_timestamps - ] - for missing_timestamp in missing_timestamps: - dummy_df = df[df["timestamp_ind"] == np.random.choice(present_timestamps)].copy() - dummy_df["timestamp_ind"] = missing_timestamp - dummy_df[feature_columns] = NODATAVALUE - df = pd.concat([df, dummy_df]) + if len(samples_after_end_date) > 0 or len(samples_before_start_date) > 0: + logger.warning( + f"""\ + Removing {len(samples_after_end_date)} \ + samples with valid_time after the end_date \ + and {len(samples_before_start_date)} samples with valid_time before the start_date""" + ) + df = df[~df["sample_id"].isin(samples_before_start_date)] + df = df[~df["sample_id"].isin(samples_after_end_date)] + + # compute average distance between observations + # and use it as an approximation for frequency + obs_timestamps = pd.Series(df["timestamp"].unique()).sort_values() + avg_distance = int(obs_timestamps.diff().abs().dt.days.mean()) + + # add timesteps before the start_date where needed + intermediate_dummy_df = pd.DataFrame() + for n_ts_to_add in range(1, min_edge_buffer + 1): + samples_to_add_ts_before_start = latest_obs_position[ + (min_edge_buffer - latest_obs_position["valid_position"]) >= -n_ts_to_add + ].index + dummy_df = df[ + (df["sample_id"].isin(samples_to_add_ts_before_start)) & (df["timestamp_ind"] == 0) + ].copy() + dummy_df["timestamp"] = dummy_df["timestamp"] - pd.DateOffset( + days=(n_ts_to_add * avg_distance) + ) # type: ignore + dummy_df[feature_columns] = NODATAVALUE + intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df]) + df = pd.concat([df, intermediate_dummy_df]) + + # add timesteps after the end_date where needed + intermediate_dummy_df = pd.DataFrame() + for n_ts_to_add in range(1, min_edge_buffer + 1): + samples_to_add_ts_after_end = latest_obs_position[ + (min_edge_buffer - latest_obs_position["valid_position_diff"]) >= n_ts_to_add + ].index + dummy_df = df[ + (df["sample_id"].isin(samples_to_add_ts_after_end)) & (df["is_last_available_ts"]) + ].copy() + dummy_df["timestamp"] = dummy_df["timestamp"] + pd.DateOffset( + months=(n_ts_to_add * avg_distance) + ) # type: ignore + dummy_df[feature_columns] = NODATAVALUE + intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df]) + df = pd.concat([df, intermediate_dummy_df]) + + # reinitialize timestep_ind + df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int) - 1 + + df["available_timesteps"] = df["sample_id"].map( + df.groupby("sample_id")["timestamp"].nunique().astype(int) + ) + index_columns.append("available_timesteps") # finally pivot the dataframe + index_columns = list(set(index_columns)) df_pivot = df.pivot(index=index_columns, columns="timestamp_ind", values=feature_columns) df_pivot = df_pivot.fillna(NODATAVALUE) @@ -335,42 +311,34 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: f"{xx}-100m" if any(band in xx for band in bands100m) else xx for xx in df_pivot.columns ] # type: ignore - 
df_pivot["valid_position"] = ( - df_pivot["valid_date"].dt.year * 12 + df_pivot["valid_date"].dt.month - ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) - df_pivot["available_timesteps"] = ( - (df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month) - - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) - + 1 - ) + if use_valid_time: + df_pivot["year"] = df_pivot["valid_time"].dt.year + df_pivot["valid_time"] = df_pivot["valid_time"].dt.date.astype(str) - min_center_point = np.maximum( - NUM_TIMESTEPS // 2, - df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2, - ) - max_center_point = np.minimum( - df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2, - df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2, - ) + min_center_point = np.maximum( + num_timesteps // 2, + df_pivot["valid_position"] + min_edge_buffer - num_timesteps // 2, + ) + max_center_point = np.minimum( + df_pivot["available_timesteps"] - num_timesteps // 2, + df_pivot["valid_position"] - min_edge_buffer + num_timesteps // 2, + ) - faulty_samples = min_center_point > max_center_point - if faulty_samples.sum() > 0: - logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.") - df_pivot = df_pivot[~faulty_samples] + faulty_samples = min_center_point > max_center_point + if faulty_samples.sum() > 0: + logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.") + df_pivot = df_pivot[~faulty_samples] - samples_with_too_few_ts = df_pivot["available_timesteps"] < NUM_TIMESTEPS + samples_with_too_few_ts = df_pivot["available_timesteps"] < num_timesteps if samples_with_too_few_ts.sum() > 0: logger.warning( f"Dropping {samples_with_too_few_ts.sum()} samples with \ -number of available timesteps less than {NUM_TIMESTEPS}." +number of available timesteps less than {num_timesteps}." ) - df_pivot = df_pivot[~samples_with_too_few_ts] - - df_pivot["year"] = df_pivot["valid_date"].dt.year + df_pivot = df_pivot[~samples_with_too_few_ts] df_pivot["start_date"] = df_pivot["start_date"].dt.date.astype(str) df_pivot["end_date"] = df_pivot["end_date"].dt.date.astype(str) - df_pivot["valid_date"] = df_pivot["valid_date"].dt.date.astype(str) df_pivot = prep_dataframe(df_pivot) @@ -716,7 +684,7 @@ def prep_dataframe( # SAR cannot equal 0.0 since we take the log of it cols = [f"SAR-{s}-ts{t}-20m" for s in ["VV", "VH"] for t in range(36 if dekadal else 12)] - df = df.drop_duplicates(subset=["sample_id", "lat", "lon", "end_date"]) + df = df.drop_duplicates(subset=["sample_id", "lat", "lon", "start_date"]) df = df[~pd.isna(df).any(axis=1)] df = df[~(df.loc[:, cols] == 0.0).any(axis=1)] df = df.set_index("sample_id")