From b3c648a8b531d377e3bf2a2087f94fb5cbe83864 Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Thu, 3 Oct 2024 15:52:38 +0200 Subject: [PATCH 1/9] added dekadal masking strategy --- presto/dataset.py | 20 ++++++++++++++++++++ presto/masking.py | 3 +-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/presto/dataset.py b/presto/dataset.py index 5d9a7d7..1d1d423 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -377,7 +377,27 @@ def __getitem__(self, idx): strat, real_mask_per_variable, ) + + +class WorldCerealMasked10DDataset(WorldCerealMaskedDataset): + NUM_TIMESTEPS = 36 + + @classmethod + def get_month_array(cls, row: pd.Series) -> np.ndarray: + start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( + row.end_date, "%Y-%m-%d" + ) + + # Calculate the step size for 10-day intervals and create a list of dates + step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1)) + date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] + + # Ensure last date is not beyond the end date + if date_vector[-1] > end_date: + date_vector[-1] = end_date + return np.array([d.month - 1 for d in date_vector]) + def filter_remove_noncrops(df: pd.DataFrame) -> pd.DataFrame: labels_to_exclude = [ diff --git a/presto/masking.py b/presto/masking.py index 196e836..5b1425c 100644 --- a/presto/masking.py +++ b/presto/masking.py @@ -10,7 +10,6 @@ BANDS_GROUPS_IDX, NUM_TIMESTEPS, SRTM_INDEX, - TIMESTEPS_IDX, ) MASK_STRATEGIES = ( @@ -113,7 +112,7 @@ def random_masking(mask, num_tokens_to_mask: int): # -1 for SRTM timesteps_to_mask = int(num_tokens_to_mask / (len(BANDS_GROUPS_IDX) - 1)) max_tokens_masked = (len(BANDS_GROUPS_IDX) - 1) * timesteps_to_mask - timesteps = sample(TIMESTEPS_IDX, k=timesteps_to_mask) + timesteps = sample(range(len(num_timesteps)), k=timesteps_to_mask) if timesteps_to_mask > 0: num_tokens_to_mask -= int(max_tokens_masked - sum(sum(mask[timesteps]))) mask[timesteps] = True From 
5b6d3e3c6f3a19220c6f2350f61f02f6a833e47d Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Thu, 3 Oct 2024 15:59:41 +0200 Subject: [PATCH 2/9] dekadal masking import --- train_self_supervised.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_self_supervised.py b/train_self_supervised.py index 374e98e..47b5a0d 100644 --- a/train_self_supervised.py +++ b/train_self_supervised.py @@ -15,7 +15,11 @@ # import xarray as xr from presto.dataops import BANDS_GROUPS_IDX, NODATAVALUE -from presto.dataset import WorldCerealBase, WorldCerealMaskedDataset +from presto.dataset import ( + WorldCerealBase, + WorldCerealMaskedDataset, + WorldCerealMasked10DDataset +) from presto.masking import MASK_STRATEGIES, MaskParamsNoDw from presto.presto import ( LossWrapper, @@ -170,7 +174,7 @@ # Load the mask parameters mask_params = MaskParamsNoDw(mask_strategies, mask_ratio, num_timesteps=36 if dekadal else 12) -masked_ds = WorldCerealMasked10DDataset if dekadal else WorldCerealMaskedDataset +masked_ds = WorldCerealMasked10DDataset if dekadal else WorldCerealMaskedDataset train_dataloader = DataLoader( masked_ds(train_df, mask_params=mask_params, is_ssl=True), From 3a3c0b690c0c9cf5a137416feae78e493897c148 Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Wed, 23 Oct 2024 15:16:53 +0200 Subject: [PATCH 3/9] bug fix --- presto/masking.py | 9 ++------- presto/presto.py | 7 +++++-- presto/utils.py | 43 ++++++++++++++++++------------------------- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/presto/masking.py b/presto/masking.py index a96bd95..7388361 100644 --- a/presto/masking.py +++ b/presto/masking.py @@ -5,12 +5,7 @@ import numpy as np -from .dataops import ( - BAND_EXPANSION, - BANDS_GROUPS_IDX, - NUM_TIMESTEPS, - SRTM_INDEX, -) +from .dataops import BAND_EXPANSION, BANDS_GROUPS_IDX, NUM_TIMESTEPS, SRTM_INDEX MASK_STRATEGIES = ( "group_bands", @@ -105,7 +100,7 @@ def random_masking(mask, num_tokens_to_mask: int): # -1 for SRTM timesteps_to_mask = int(num_tokens_to_mask / 
(len(BANDS_GROUPS_IDX) - 1)) max_tokens_masked = (len(BANDS_GROUPS_IDX) - 1) * timesteps_to_mask - timesteps = sample(range(len(num_timesteps)), k=timesteps_to_mask) + timesteps = sample(range(num_timesteps), k=timesteps_to_mask) if timesteps_to_mask > 0: num_tokens_to_mask -= int(max_tokens_masked - sum(sum(mask[timesteps]))) mask[timesteps] = True diff --git a/presto/presto.py b/presto/presto.py index ed3cd85..c76d465 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -860,8 +860,8 @@ def load_pretrained( valid_month_size=valid_month_size, ) - if dekadal: - model = extend_to_dekadal(model) + # if dekadal: + # model = extend_to_dekadal(model) if is_finetuned: model = model.construct_finetuning_model(num_outputs) @@ -872,6 +872,9 @@ def load_pretrained( else: model.load_state_dict(torch.load(model_path, map_location=device), strict=strict) + if dekadal: + model = extend_to_dekadal(model) + return model @classmethod diff --git a/presto/utils.py b/presto/utils.py index 8715b56..912752e 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -99,7 +99,7 @@ def get_class_mappings() -> Dict: return CLASS_MAPPINGS -def process_parquet(df: pd.DataFrame) -> pd.DataFrame: +def process_parquet(df: pd.DataFrame, num_ts: int = 12) -> pd.DataFrame: """ This function takes in a DataFrame with S1, S2 and ERA5 observations and their respective dates in long format and returns it in wide format. 
@@ -131,7 +131,7 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: - computing the number of available timesteps in the timeseries that takes into account updated start_date and end_date; available_timesteps holds the absolute number of timesteps that for which observations are - available; it cannot be less than NUM_TIMESTEPS; if this is the case, + available; it cannot be less than num_ts; if this is the case, sample is considered faulty and is removed from the dataset - post-processing with prep_dataframe function @@ -202,6 +202,8 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: "valid_date", "location_id", "ref_id", + "valid_position", + "available_timesteps", ] bands10m = ["OPTICAL-B02", "OPTICAL-B03", "OPTICAL-B04", "OPTICAL-B08"] @@ -217,18 +219,17 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: ] bands100m = ["METEO-precipitation_flux", "METEO-temperature_mean"] - df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - ( - df["start_date"].dt.year * 12 + df["start_date"].dt.month - ) - df["valid_position"] = (df["valid_date"].dt.year * 12 + df["valid_date"].dt.month) - ( - df["start_date"].dt.year * 12 + df["start_date"].dt.month + df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int) + df["valid_date_ts_diff_days"] = (df["valid_date"] - df["timestamp"]).dt.days.abs() + valid_position = ( + df.set_index("timestamp_ind").groupby("sample_id")["valid_date_ts_diff_days"].idxmin() ) + df["valid_position"] = df["sample_id"].map(valid_position) df["valid_position_diff"] = df["timestamp_ind"] - df["valid_position"] # save the initial start_date for later df["initial_start_date"] = df["start_date"].copy() index_columns.append("initial_start_date") - # define samples where valid_date is outside the range of the actual extractions # and remove them from the dataset latest_obs_position = df.groupby(["sample_id"])[ @@ -295,8 +296,9 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: 
df["end_date"] = df["sample_id"].map(new_end_date) # reinitialize timestep_ind - df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - ( - df["start_date"].dt.year * 12 + df["start_date"].dt.month + df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int) + df["available_timesteps"] = df["sample_id"].map( + df.groupby("sample_id")["timestamp"].nunique().astype(int) ) # check for missing timestamps in the middle of timeseries @@ -335,22 +337,13 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: f"{xx}-100m" if any(band in xx for band in bands100m) else xx for xx in df_pivot.columns ] # type: ignore - df_pivot["valid_position"] = ( - df_pivot["valid_date"].dt.year * 12 + df_pivot["valid_date"].dt.month - ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) - df_pivot["available_timesteps"] = ( - (df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month) - - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) - + 1 - ) - min_center_point = np.maximum( - NUM_TIMESTEPS // 2, - df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2, + num_ts // 2, + df_pivot["valid_position"] + MIN_EDGE_BUFFER - num_ts // 2, ) max_center_point = np.minimum( - df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2, - df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2, + df_pivot["available_timesteps"] - num_ts // 2, + df_pivot["valid_position"] - MIN_EDGE_BUFFER + num_ts // 2, ) faulty_samples = min_center_point > max_center_point @@ -358,11 +351,11 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.") df_pivot = df_pivot[~faulty_samples] - samples_with_too_few_ts = df_pivot["available_timesteps"] < NUM_TIMESTEPS + samples_with_too_few_ts = df_pivot["available_timesteps"] < num_ts if samples_with_too_few_ts.sum() > 0: logger.warning( f"Dropping {samples_with_too_few_ts.sum()} samples with 
\ -number of available timesteps less than {NUM_TIMESTEPS}." +number of available timesteps less than {num_ts}." ) df_pivot = df_pivot[~samples_with_too_few_ts] From db35cf9115b1fba9dda3de0e9246b98363fd1a2b Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Thu, 24 Oct 2024 10:34:45 +0200 Subject: [PATCH 4/9] refactoring and bug fixes --- presto/dataset.py | 150 ++++++++++++++++++++++++++++++---------------- presto/masking.py | 4 +- 2 files changed, 102 insertions(+), 52 deletions(-) diff --git a/presto/dataset.py b/presto/dataset.py index e9c5ba1..faaa3bc 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -20,6 +20,7 @@ NDVI_INDEX, NODATAVALUE, NORMED_BANDS, + NUM_TIMESTEPS, S1_S2_ERA5_SRTM, S2_RGB_INDEX, DynamicWorld2020_2021, @@ -38,7 +39,6 @@ class WorldCerealBase(Dataset): # _NODATAVALUE = 65535 - NUM_TIMESTEPS = 12 BAND_MAPPING = { "OPTICAL-B02-ts{}-10m": "B2", "OPTICAL-B03-ts{}-10m": "B3", @@ -57,26 +57,27 @@ class WorldCerealBase(Dataset): } STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"} - def __init__(self, dataframe: pd.DataFrame): + def __init__(self, dataframe: pd.DataFrame, num_timesteps: int = NUM_TIMESTEPS): self.df = dataframe + self.num_timesteps = num_timesteps def __len__(self): return self.df.shape[0] - @classmethod def get_timestep_positions( - cls, row_d: Dict, augment: bool = False, is_ssl: bool = False + self, row_d: Dict, augment: bool = False, is_ssl: bool = False ) -> List[int]: available_timesteps = int(row_d["available_timesteps"]) if is_ssl: - if available_timesteps == cls.NUM_TIMESTEPS: - valid_position = int(cls.NUM_TIMESTEPS // 2) + if available_timesteps == self.num_timesteps: + valid_position = int(self.num_timesteps // 2) else: valid_position = int( np.random.choice( range( - cls.NUM_TIMESTEPS // 2, (available_timesteps - cls.NUM_TIMESTEPS // 2) + self.num_timesteps // 2, + (available_timesteps - self.num_timesteps // 2), ), 1, ) @@ -86,11 +87,11 @@ def get_timestep_positions( valid_position = 
int(row_d["valid_position"]) if not augment: # check if the valid position is too close to the start_date and force shifting it - if valid_position < cls.NUM_TIMESTEPS // 2: - center_point = cls.NUM_TIMESTEPS // 2 + if valid_position < self.num_timesteps // 2: + center_point = self.num_timesteps // 2 # or too close to the end_date - elif valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2): - center_point = available_timesteps - cls.NUM_TIMESTEPS // 2 + elif valid_position > (available_timesteps - self.num_timesteps // 2): + center_point = available_timesteps - self.num_timesteps // 2 else: # Center the timesteps around the valid position center_point = valid_position @@ -99,35 +100,34 @@ def get_timestep_positions( # well includes the valid position min_center_point = max( - cls.NUM_TIMESTEPS // 2, - valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2, + self.num_timesteps // 2, + valid_position + MIN_EDGE_BUFFER - self.num_timesteps // 2, ) max_center_point = min( - available_timesteps - cls.NUM_TIMESTEPS // 2, - valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2, + available_timesteps - self.num_timesteps // 2, + valid_position - MIN_EDGE_BUFFER + self.num_timesteps // 2, ) center_point = np.random.randint( min_center_point, max_center_point + 1 ) # max_center_point included - last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2) - first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS) + last_timestep = min(available_timesteps, center_point + self.num_timesteps // 2) + first_timestep = max(0, last_timestep - self.num_timesteps) timestep_positions = list(range(first_timestep, last_timestep)) - if len(timestep_positions) != cls.NUM_TIMESTEPS: + if len(timestep_positions) != self.num_timesteps: raise ValueError( f"Acquired timestep positions do not have correct length: \ -required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}" +required {self.num_timesteps}, got {len(timestep_positions)}" ) assert ( 
valid_position in timestep_positions ), f"Valid position {valid_position} not in timestep positions {timestep_positions}" return timestep_positions - @classmethod def row_to_arrays( - cls, + self, row: pd.Series, task_type: str = "cropland", croptype_list: List = [], @@ -140,11 +140,11 @@ def row_to_arrays( latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32) - timestep_positions = cls.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl) + timestep_positions = self.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl) - if cls.NUM_TIMESTEPS == 12: + if self.num_timesteps == 12: initial_start_date_position = pd.to_datetime(row_d["start_date"]).month - elif cls.NUM_TIMESTEPS > 12: + elif self.num_timesteps > 12: # get the correct index of the start_date based on NUM_TIMESTEPS` # e.g. if NUM_TIMESTEPS is 36 (dekadal setup), we should take the correct # 10-day interval that the start_date falls into @@ -153,18 +153,18 @@ def row_to_arrays( # should also be changed accordingly year = pd.to_datetime(row_d["start_date"]).year year_dates = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31") - bins = pd.cut(year_dates, bins=cls.NUM_TIMESTEPS, labels=False) + bins = pd.cut(year_dates, bins=self.num_timesteps, labels=False) initial_start_date_position = bins[ np.where(year_dates == pd.to_datetime(row_d["start_date"]))[0][0] ] else: raise ValueError( - f"NUM_TIMESTEPS must be at least 12. Currently it is {cls.NUM_TIMESTEPS}" + f"NUM_TIMESTEPS must be at least 12. Currently it is {self.num_timesteps}" ) # make sure that month for encoding gets shifted according to # the selected timestep positions. Also ensure circular indexing - month = (initial_start_date_position - 1 + timestep_positions[0]) % cls.NUM_TIMESTEPS + month = (initial_start_date_position - 1 + timestep_positions[0]) % self.num_timesteps # adding workaround for compatibility between Phase I and Phase II datasets. 
# (in Phase II, the relevant attribute name was changed to valid_time) @@ -176,12 +176,12 @@ def row_to_arrays( else: logger.error("Dataset does not contain neither valid_date, nor valid_time attribute.") - eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS))) + eo_data = np.zeros((self.num_timesteps, len(BANDS))) # an assumption we make here is that all timesteps for a token # have the same masking - mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX))) + mask = np.zeros((self.num_timesteps, len(BANDS_GROUPS_IDX))) - for df_val, presto_val in cls.BAND_MAPPING.items(): + for df_val, presto_val in self.BAND_MAPPING.items(): values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions]) # this occurs for the DEM values in one point in Fiji values = np.nan_to_num(values, nan=NODATAVALUE) @@ -199,7 +199,7 @@ def row_to_arrays( mask[:, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid eo_data[:, BANDS.index(presto_val)] = values * idx_valid - for df_val, presto_val in cls.STATIC_BAND_MAPPING.items(): + for df_val, presto_val in self.STATIC_BAND_MAPPING.items(): # this occurs for the DEM values in one point in Fiji values = np.nan_to_num(row_d[df_val], nan=NODATAVALUE) idx_valid = values != NODATAVALUE @@ -210,7 +210,7 @@ def row_to_arrays( # or nir mask, and adjust the NDVI mask accordingly mask[:, NDVI_INDEX] = np.logical_or(mask[:, S2_RGB_INDEX], mask[:, S2_NIR_10m_INDEX]) - return (cls.check(eo_data), mask.astype(bool), latlon, month, valid_month) + return (self.check(eo_data), mask.astype(bool), latlon, month, valid_month) def __getitem__(self, idx): # Get the sample @@ -220,7 +220,7 @@ def __getitem__(self, idx): mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) return ( self.normalize_and_mask(eo), - np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount), + np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount), latlon, month, valid_month, @@ -348,19 +348,24 @@ def split_df( class 
WorldCerealMaskedDataset(WorldCerealBase): + def __init__( self, dataframe: pd.DataFrame, + num_timesteps: int, mask_params: MaskParamsNoDw, task_type: str = "cropland", croptype_list: List = [], is_ssl: bool = True, + is_dekadal: bool = False, ): - super().__init__(dataframe) + super().__init__(dataframe, num_timesteps) + self.num_timesteps = num_timesteps self.mask_params = mask_params self.task_type = task_type self.croptype_list = croptype_list self.is_ssl = is_ssl + self.is_dekadal = is_dekadal def __getitem__(self, idx): # Get the sample @@ -370,13 +375,15 @@ def __getitem__(self, idx): ) mask_eo, x_eo, y_eo, strat = self.mask_params.mask_data( - self.normalize_and_mask(eo), real_mask_per_token + self.normalize_and_mask(eo), + real_mask_per_token, ) real_mask_per_variable = np.repeat(real_mask_per_token, BAND_EXPANSION, axis=1) - dynamic_world = np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount) - mask_dw = np.full(self.NUM_TIMESTEPS, True) + dynamic_world = np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount) + mask_dw = np.full(self.num_timesteps, True) y_dw = dynamic_world.copy() + return MaskedExample( mask_eo, mask_dw, @@ -384,32 +391,77 @@ def __getitem__(self, idx): y_eo, dynamic_world, y_dw, - month, + self.get_month_array(row) if self.is_dekadal else month, latlon, strat, real_mask_per_variable, ) - - -class WorldCerealMasked10DDataset(WorldCerealMaskedDataset): - NUM_TIMESTEPS = 36 - @classmethod - def get_month_array(cls, row: pd.Series) -> np.ndarray: + def get_month_array(self, row: pd.Series) -> np.ndarray: start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( row.end_date, "%Y-%m-%d" ) # Calculate the step size for 10-day intervals and create a list of dates - step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1)) - date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] + step = int((end_date - start_date).days / (self.num_timesteps - 
1)) + date_vector = [start_date + timedelta(days=i * step) for i in range(self.num_timesteps)] # Ensure last date is not beyond the end date if date_vector[-1] > end_date: date_vector[-1] = end_date return np.array([d.month - 1 for d in date_vector]) - + + +# class WorldCerealMasked10DDataset(WorldCerealMaskedDataset): +# NUM_TIMESTEPS = 36 + +# def __getitem__(self, idx): +# # Get the sample +# row = self.df.iloc[idx, :] +# eo, real_mask_per_token, latlon, _, _ = self.row_to_arrays( +# row, task_type=self.task_type, croptype_list=self.croptype_list, is_ssl=self.is_ssl +# ) + +# mask_eo, x_eo, y_eo, strat = self.mask_params.mask_data( +# self.normalize_and_mask(eo), +# real_mask_per_token, +# num_timesteps=self.NUM_TIMESTEPS, +# ) +# real_mask_per_variable = np.repeat(real_mask_per_token, BAND_EXPANSION, axis=1) + +# dynamic_world = np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount) +# mask_dw = np.full(self.NUM_TIMESTEPS, True) +# y_dw = dynamic_world.copy() +# return MaskedExample( +# mask_eo, +# mask_dw, +# x_eo, +# y_eo, +# dynamic_world, +# y_dw, +# self.get_month_array(row), ######### +# latlon, +# strat, +# real_mask_per_variable, +# ) + +# @classmethod +# def get_month_array(cls, row: pd.Series) -> np.ndarray: +# start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( +# row.end_date, "%Y-%m-%d" +# ) + +# # Calculate the step size for 10-day intervals and create a list of dates +# step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1)) +# date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] + +# # Ensure last date is not beyond the end date +# if date_vector[-1] > end_date: +# date_vector[-1] = end_date + +# return np.array([d.month - 1 for d in date_vector]) + def filter_remove_noncrops(df: pd.DataFrame) -> pd.DataFrame: labels_to_exclude = [ @@ -838,9 +890,7 @@ def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray: return inarr @classmethod - def 
nc_to_arrays( - cls, filepath: Path - ) -> Tuple[ + def nc_to_arrays(cls, filepath: Path) -> Tuple[ np.ndarray, np.ndarray, np.ndarray, diff --git a/presto/masking.py b/presto/masking.py index 7388361..1b3ecc8 100644 --- a/presto/masking.py +++ b/presto/masking.py @@ -139,14 +139,14 @@ def __post_init__(self): "random_combinations", ] - def mask_data(self, eo_data: np.ndarray, mask: np.ndarray, num_timesteps: int = NUM_TIMESTEPS): + def mask_data(self, eo_data: np.ndarray, mask: np.ndarray): strategy = choice(self.strategies) mask = make_mask_no_dw( strategy=strategy, mask_ratio=self.ratio, existing_mask=mask, - num_timesteps=num_timesteps, + num_timesteps=self.num_timesteps, ) x = eo_data * ~mask From ad2179e6f4859639596b06b5252a63614578263f Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Thu, 24 Oct 2024 14:24:01 +0200 Subject: [PATCH 5/9] num_token_to_mask modified to avoid < 0 --- presto/masking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/masking.py b/presto/masking.py index 1b3ecc8..5b277b4 100644 --- a/presto/masking.py +++ b/presto/masking.py @@ -47,7 +47,7 @@ def make_mask_no_dw( num_tokens_to_mask = int( ((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio - ) - sum(sum(mask)) + ) # - sum(sum(mask)) assert num_tokens_to_mask > 0 def mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio): From 8a2c8322a912ff39f996b44e3c06f7501e7bc1a3 Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Thu, 24 Oct 2024 15:20:17 +0200 Subject: [PATCH 6/9] bug fixes and updates --- train_self_supervised.py | 52 +++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/train_self_supervised.py b/train_self_supervised.py index 32cd816..0a6e968 100644 --- a/train_self_supervised.py +++ b/train_self_supervised.py @@ -15,10 +15,9 @@ # import xarray as xr from presto.dataops import BANDS_GROUPS_IDX, NODATAVALUE -from presto.dataset import ( +from presto.dataset import ( # 
WorldCerealMasked10DDataset, WorldCerealBase, WorldCerealMaskedDataset, - WorldCerealMasked10DDataset ) from presto.masking import MASK_STRATEGIES, MaskParamsNoDw from presto.presto import ( @@ -40,6 +39,8 @@ timestamp_dirname, ) +torch.multiprocessing.set_sharing_strategy("file_system") + logger = logging.getLogger("__main__") argparser = argparse.ArgumentParser() @@ -53,6 +54,11 @@ "and /output/ will be written to. " "Leave empty to use the directory you are running this file from.", ) +argparser.add_argument( + "--parquet_file", + type=str, + default="/home/vito/millig/projects/TAP/worldcereal/data/worldcereal_training_data.parquet", +) argparser.add_argument("--seed", type=int, default=DEFAULT_SEED) argparser.add_argument("--num_workers", type=int, default=64) argparser.add_argument("--wandb", dest="wandb", action="store_true") @@ -85,6 +91,7 @@ default="random", choices=["random", "spatial", "temporal", "seasonal"], ) +argparser.add_argument("--dekadal", dest="dekadal", action="store_true") argparser.add_argument("--train_only_samples_file", type=str, default="train_only_samples.csv") argparser.add_argument("--warm_start", dest="warm_start", action="store_true") @@ -116,6 +123,7 @@ # presto_model_description: str = args["presto_model_description"] test_type: str = args["test_type"] assert test_type in ["random", "spatial", "temporal", "seasonal"] +dekadal = args["dekadal"] seed_everything(seed) output_parent_dir = Path(args["output_dir"]) if args["output_dir"] else Path(__file__).parent @@ -128,6 +136,7 @@ entity=wandb_org, project="presto-worldcereal", dir=output_parent_dir, + name=model_name, ) run_id = cast(wandb.sdk.wandb_run.Run, run).id @@ -136,15 +145,11 @@ initialize_logging(model_logging_dir) logger.info("Using output dir: %s" % model_logging_dir) -# parquet_file: str = args["parquet_file"] -parquet_file = "/home/vito/butskoc/presto-worldcereal/data/long_parquet/\ -worldcereal_training_data.parquet" +parquet_file: str = args["parquet_file"] 
train_only_samples_file: str = args["train_only_samples_file"] -dekadal = False compositing_window = "30D" -if "10d" in parquet_file: - dekadal = True +if dekadal: compositing_window = "10D" path_to_config = config_dir / "default.json" @@ -153,7 +158,7 @@ logger.info("Loading data") -files = sorted(glob(f"{parquet_file}/**/*.parquet")) +files = sorted(glob(f"{parquet_file}/**/*.parquet"))[:10] df_list = [] for f in tqdm(files): _data = pd.read_parquet(f, engine="fastparquet") @@ -172,20 +177,34 @@ logger.info(f"Preparing train and val splits for {test_type} test") val_samples_df = pd.read_csv(data_dir / "test_splits" / val_samples_file) -train_df, val_df = WorldCerealBase.split_df(df, val_sample_ids=val_samples_df.sample_id.tolist()) +train_df, val_df = WorldCerealBase.split_df( + df, val_size=0.1 +) # val_sample_ids=val_samples_df.sample_id.tolist()) # Load the mask parameters mask_params = MaskParamsNoDw(mask_strategies, mask_ratio, num_timesteps=36 if dekadal else 12) -masked_ds = WorldCerealMasked10DDataset if dekadal else WorldCerealMaskedDataset +# masked_ds = WorldCerealMasked10DDataset if dekadal else WorldCerealMaskedDataset train_dataloader = DataLoader( - masked_ds(train_df, mask_params=mask_params, is_ssl=True), + WorldCerealMaskedDataset( + train_df, + num_timesteps=36 if dekadal else 12, + mask_params=mask_params, + is_ssl=True, + is_dekadal=dekadal, + ), batch_size=batch_size, shuffle=True, num_workers=num_workers, ) val_dataloader = DataLoader( - masked_ds(val_df, mask_params=mask_params, is_ssl=True), + WorldCerealMaskedDataset( + val_df, + num_timesteps=36 if dekadal else 12, + mask_params=mask_params, + is_ssl=True, + is_dekadal=dekadal, + ), batch_size=batch_size, shuffle=False, num_workers=num_workers, @@ -210,10 +229,11 @@ model_kwargs = json.load(Path(path_to_config).open("r")) model = Presto.construct(**model_kwargs) best_model_path = None + # moved into the else block ################## + if dekadal: + logger.info("extending model to 
dekadal architecture") + model = extend_to_dekadal(model) -if dekadal: - logger.info("extending model to dekadal architecture") - model = extend_to_dekadal(model) model.to(device) # print(f"model pos embed shape {model.encoder.pos_embed.shape}") # correctly reinitialized From 1130d5b4cf9237229141c6d3e3b85220c2d3729a Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Wed, 20 Nov 2024 11:44:52 +0100 Subject: [PATCH 7/9] code refactoring --- presto/dataset.py | 164 ++++++++++++++------------------------- presto/eval.py | 20 ++--- presto/masking.py | 7 +- presto/utils.py | 9 ++- train_self_supervised.py | 13 ++-- 5 files changed, 85 insertions(+), 128 deletions(-) diff --git a/presto/dataset.py b/presto/dataset.py index faaa3bc..77a7f94 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -126,6 +126,7 @@ def get_timestep_positions( ), f"Valid position {valid_position} not in timestep positions {timestep_positions}" return timestep_positions + def row_to_arrays( self, row: pd.Series, @@ -211,7 +212,7 @@ def row_to_arrays( mask[:, NDVI_INDEX] = np.logical_or(mask[:, S2_RGB_INDEX], mask[:, S2_NIR_10m_INDEX]) return (self.check(eo_data), mask.astype(bool), latlon, month, valid_month) - + def __getitem__(self, idx): # Get the sample row = self.df.iloc[idx, :] @@ -227,6 +228,21 @@ def __getitem__(self, idx): mask_per_variable, ) + def get_month_array(self, row: pd.Series) -> np.ndarray: + start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( + row.end_date, "%Y-%m-%d" + ) + + # Calculate the step size for 10-day intervals and create a list of dates + step = int((end_date - start_date).days / (self.num_timesteps - 1)) + date_vector = [start_date + timedelta(days=i * step) for i in range(self.num_timesteps)] + + # Ensure last date is not beyond the end date + if date_vector[-1] > end_date: + date_vector[-1] = end_date + + return np.array([d.month - 1 for d in date_vector]) + @classmethod def normalize_and_mask(cls, eo: np.ndarray): # 
TODO: this can be removed @@ -397,71 +413,6 @@ def __getitem__(self, idx): real_mask_per_variable, ) - def get_month_array(self, row: pd.Series) -> np.ndarray: - start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( - row.end_date, "%Y-%m-%d" - ) - - # Calculate the step size for 10-day intervals and create a list of dates - step = int((end_date - start_date).days / (self.num_timesteps - 1)) - date_vector = [start_date + timedelta(days=i * step) for i in range(self.num_timesteps)] - - # Ensure last date is not beyond the end date - if date_vector[-1] > end_date: - date_vector[-1] = end_date - - return np.array([d.month - 1 for d in date_vector]) - - -# class WorldCerealMasked10DDataset(WorldCerealMaskedDataset): -# NUM_TIMESTEPS = 36 - -# def __getitem__(self, idx): -# # Get the sample -# row = self.df.iloc[idx, :] -# eo, real_mask_per_token, latlon, _, _ = self.row_to_arrays( -# row, task_type=self.task_type, croptype_list=self.croptype_list, is_ssl=self.is_ssl -# ) - -# mask_eo, x_eo, y_eo, strat = self.mask_params.mask_data( -# self.normalize_and_mask(eo), -# real_mask_per_token, -# num_timesteps=self.NUM_TIMESTEPS, -# ) -# real_mask_per_variable = np.repeat(real_mask_per_token, BAND_EXPANSION, axis=1) - -# dynamic_world = np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount) -# mask_dw = np.full(self.NUM_TIMESTEPS, True) -# y_dw = dynamic_world.copy() -# return MaskedExample( -# mask_eo, -# mask_dw, -# x_eo, -# y_eo, -# dynamic_world, -# y_dw, -# self.get_month_array(row), ######### -# latlon, -# strat, -# real_mask_per_variable, -# ) - -# @classmethod -# def get_month_array(cls, row: pd.Series) -> np.ndarray: -# start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( -# row.end_date, "%Y-%m-%d" -# ) - -# # Calculate the step size for 10-day intervals and create a list of dates -# step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1)) -# date_vector = [start_date + 
timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] - -# # Ensure last date is not beyond the end date -# if date_vector[-1] > end_date: -# date_vector[-1] = end_date - -# return np.array([d.month - 1 for d in date_vector]) - def filter_remove_noncrops(df: pd.DataFrame) -> pd.DataFrame: labels_to_exclude = [ @@ -497,6 +448,7 @@ class WorldCerealLabelledDataset(WorldCerealBase): def __init__( self, dataframe: pd.DataFrame, + num_timesteps: int = NUM_TIMESTEPS, countries_to_remove: Optional[List[str]] = None, years_to_remove: Optional[List[int]] = None, balance: bool = False, @@ -504,6 +456,7 @@ def __init__( croptype_list: List = [], return_hierarchical_labels: bool = False, augment: bool = False, + is_dekadal: bool = False, mask_ratio: float = 0.0, ): dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)] @@ -524,6 +477,7 @@ def __init__( self.croptype_list = croptype_list self.return_hierarchical_labels = return_hierarchical_labels self.augment = augment + self.is_dekadal = is_dekadal if augment: logger.info( "Augmentation is enabled. \ @@ -540,7 +494,7 @@ def __init__( mask_ratio, ) - super().__init__(dataframe) + super().__init__(dataframe, num_timesteps) if balance: if self.task_type == "cropland": logger.info("Balancing is enabled. 
Underrepresented class will be upsampled.") @@ -654,9 +608,9 @@ def __getitem__(self, idx): return ( normed_eo, target, - np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount), + np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount), latlon, - month, + self.get_month_array(row) if self.is_dekadal else month, valid_month, mask_per_variable, ) @@ -680,48 +634,48 @@ def class_weights(self) -> np.ndarray: return self._class_weights -class WorldCerealLabelled10DDataset(WorldCerealLabelledDataset): +# class WorldCerealLabelled10DDataset(WorldCerealLabelledDataset): - NUM_TIMESTEPS = 36 +# NUM_TIMESTEPS = 36 - @classmethod - def get_month_array(cls, row: pd.Series) -> np.ndarray: - start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( - row.end_date, "%Y-%m-%d" - ) +# @classmethod +# def get_month_array(cls, row: pd.Series) -> np.ndarray: +# start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( +# row.end_date, "%Y-%m-%d" +# ) - # Calculate the step size for 10-day intervals and create a list of dates - step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1)) - date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] +# # Calculate the step size for 10-day intervals and create a list of dates +# step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1)) +# date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] - # Ensure last date is not beyond the end date - if date_vector[-1] > end_date: - date_vector[-1] = end_date +# # Ensure last date is not beyond the end date +# if date_vector[-1] > end_date: +# date_vector[-1] = end_date - return np.array([d.month - 1 for d in date_vector]) +# return np.array([d.month - 1 for d in date_vector]) - def __getitem__(self, idx): - # Get the sample - df_index = self.indices[idx] - row = self.df.iloc[df_index, :] - eo, mask_per_token, latlon, _, valid_month = 
self.row_to_arrays( - row, self.task_type, self.croptype_list, self.augment - ) - target = self.target_crop( - row, self.task_type, self.croptype_list, self.return_hierarchical_labels - ) - if self.mask_ratio > 0: - mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token) - mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) - return ( - self.normalize_and_mask(eo), - target, - np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount), - latlon, - self.get_month_array(row), - valid_month, - mask_per_variable, - ) +# def __getitem__(self, idx): +# # Get the sample +# df_index = self.indices[idx] +# row = self.df.iloc[df_index, :] +# eo, mask_per_token, latlon, _, valid_month = self.row_to_arrays( +# row, self.task_type, self.croptype_list, self.augment +# ) +# target = self.target_crop( +# row, self.task_type, self.croptype_list, self.return_hierarchical_labels +# ) +# if self.mask_ratio > 0: +# mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token) +# mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) +# return ( +# self.normalize_and_mask(eo), +# target, +# np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount), +# latlon, +# self.get_month_array(row), +# valid_month, +# mask_per_variable, +# ) class WorldCerealInferenceDataset(Dataset): diff --git a/presto/eval.py b/presto/eval.py index d2af74a..5e2dc09 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -21,7 +21,6 @@ from .dataset import ( NORMED_BANDS, WorldCerealInferenceDataset, - WorldCerealLabelled10DDataset, WorldCerealLabelledDataset, ) from .hierarchical_classification import CatBoostClassifierWrapper @@ -58,6 +57,7 @@ def __init__( self, train_data: pd.DataFrame, test_data: pd.DataFrame, + num_timesteps: int = 12, countries_to_remove: Optional[List[str]] = None, years_to_remove: Optional[List[int]] = None, spatial_inference_savedir: Optional[Path] = None, @@ -67,7 +67,6 @@ def __init__( val_size: float = 0.2, 
dekadal: bool = False, task_type: str = "cropland", - num_outputs: int = 1, croptype_list: List = [], finetune_classes: str = "CROPTYPE0", downstream_classes: str = "CROPTYPE9", @@ -78,6 +77,7 @@ def __init__( ): self.seed = seed self.task_type = task_type + self.num_timesteps = num_timesteps self.name = f"WorldCereal{task_type.title()}" train_data, val_data = WorldCerealLabelledDataset.split_df(train_data, val_size=val_size) @@ -112,12 +112,12 @@ def __init__( self.num_outputs = len(train_classes) # use classes obtained from train to trim val and test classes - self.val_df.loc[ - ~self.val_df[class_column].isin(train_classes), class_column - ] = "other_crop" - self.test_df.loc[ - ~self.test_df[class_column].isin(train_classes), class_column - ] = "other_crop" + self.val_df.loc[~self.val_df[class_column].isin(train_classes), class_column] = ( + "other_crop" + ) + self.test_df.loc[~self.test_df[class_column].isin(train_classes), class_column] = ( + "other_crop" + ) # create one-hot representation from obtained labels # one-hot is needed for finetuning, @@ -151,7 +151,7 @@ def __init__( self.dekadal = dekadal self.balance = balance - self.ds_class = WorldCerealLabelled10DDataset if dekadal else WorldCerealLabelledDataset + self.ds_class = WorldCerealLabelledDataset self.train_masking = train_masking self.augment = augment self.use_valid_month = use_valid_month @@ -252,6 +252,7 @@ def dataloader_to_encodings_and_targets( dl = DataLoader( self.ds_class( self.train_df, + self.num_timesteps, countries_to_remove=self.countries_to_remove, years_to_remove=self.years_to_remove, task_type=self.task_type, @@ -266,6 +267,7 @@ def dataloader_to_encodings_and_targets( val_dl = DataLoader( self.ds_class( self.val_df, + self.num_timesteps, countries_to_remove=self.countries_to_remove, years_to_remove=self.years_to_remove, task_type=self.task_type, diff --git a/presto/masking.py b/presto/masking.py index 5b277b4..cd95454 100644 --- a/presto/masking.py +++ b/presto/masking.py @@ 
-47,8 +47,11 @@ def make_mask_no_dw( num_tokens_to_mask = int( ((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio - ) # - sum(sum(mask)) - assert num_tokens_to_mask > 0 + ) - sum(sum(mask)) + # assert num_tokens_to_mask > 0 + if num_tokens_to_mask <= 0: + mask[:, SRTM_INDEX] = srtm_mask + return np.repeat(mask, BAND_EXPANSION, axis=1) def mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio): should_flip = random() < mask_ratio diff --git a/presto/utils.py b/presto/utils.py index 912752e..e83e811 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -36,7 +36,8 @@ data_dir = Path(__file__).parent.parent / "data" config_dir = Path(__file__).parent.parent / "config" default_model_path = data_dir / "default_model.pt" -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") DEFAULT_SEED: int = 42 @@ -191,8 +192,8 @@ def process_parquet(df: pd.DataFrame, num_ts: int = 12) -> pd.DataFrame: "DEM-alt-20m", "DEM-slo-20m", "LANDCOVER_LABEL", - "POTAPOV-LABEL-10m", - "WORLDCOVER-LABEL-10m", + # "POTAPOV-LABEL-10m", + # "WORLDCOVER-LABEL-10m", "aez_zoneid", "end_date", "lat", @@ -200,7 +201,7 @@ def process_parquet(df: pd.DataFrame, num_ts: int = 12) -> pd.DataFrame: "start_date", "sample_id", "valid_date", - "location_id", + # "location_id", "ref_id", "valid_position", "available_timesteps", diff --git a/train_self_supervised.py b/train_self_supervised.py index 0a6e968..3788228 100644 --- a/train_self_supervised.py +++ b/train_self_supervised.py @@ -60,7 +60,7 @@ default="/home/vito/millig/projects/TAP/worldcereal/data/worldcereal_training_data.parquet", ) argparser.add_argument("--seed", type=int, default=DEFAULT_SEED) -argparser.add_argument("--num_workers", type=int, default=64) +argparser.add_argument("--num_workers", type=int, default=8) argparser.add_argument("--wandb", dest="wandb", action="store_true") 
argparser.add_argument("--wandb_org", type=str, default="nasa-harvest") @@ -140,7 +140,7 @@ ) run_id = cast(wandb.sdk.wandb_run.Run, run).id -model_logging_dir = output_parent_dir / "output" / timestamp_dirname(run_id) +model_logging_dir = output_parent_dir / "models" / model_name / timestamp_dirname(run_id) model_logging_dir.mkdir(exist_ok=True, parents=True) initialize_logging(model_logging_dir) logger.info("Using output dir: %s" % model_logging_dir) @@ -158,7 +158,7 @@ logger.info("Loading data") -files = sorted(glob(f"{parquet_file}/**/*.parquet"))[:10] +files = sorted(glob(f"{parquet_file}/**/*.parquet")) # [:10] df_list = [] for f in tqdm(files): _data = pd.read_parquet(f, engine="fastparquet") @@ -178,12 +178,11 @@ val_samples_df = pd.read_csv(data_dir / "test_splits" / val_samples_file) train_df, val_df = WorldCerealBase.split_df( - df, val_size=0.1 -) # val_sample_ids=val_samples_df.sample_id.tolist()) + df, val_sample_ids=val_samples_df.sample_id.tolist() +) # val_size=0.1 # Load the mask parameters mask_params = MaskParamsNoDw(mask_strategies, mask_ratio, num_timesteps=36 if dekadal else 12) -# masked_ds = WorldCerealMasked10DDataset if dekadal else WorldCerealMaskedDataset train_dataloader = DataLoader( WorldCerealMaskedDataset( @@ -229,13 +228,11 @@ model_kwargs = json.load(Path(path_to_config).open("r")) model = Presto.construct(**model_kwargs) best_model_path = None - # moved into the else block ################## if dekadal: logger.info("extending model to dekadal architecture") model = extend_to_dekadal(model) model.to(device) -# print(f"model pos embed shape {model.encoder.pos_embed.shape}") # correctly reinitialized param_groups = param_groups_weight_decay(model, weight_decay) optimizer = optim.AdamW(param_groups, lr=max_learning_rate, betas=(0.9, 0.95)) From 52a3b484d1b4087c27db01700607587d5eb80686 Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Mon, 25 Nov 2024 14:05:24 +0100 Subject: [PATCH 8/9] ignore slurm output files --- .gitignore | 4 
+++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index c242713..421f925 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,6 @@ scrap output/* imgs/* # don't track catboost training info -catboost_info \ No newline at end of file +catboost_info +# ignore slurm files +slurm* \ No newline at end of file From 9b66db7ac9939a98f2e079a78e13e7b2c1a4e941 Mon Sep 17 00:00:00 2001 From: Giorgia Milli Date: Mon, 25 Nov 2024 14:11:17 +0100 Subject: [PATCH 9/9] code refactoring --- presto/dataset.py | 52 +++-------------------------------------------- 1 file changed, 3 insertions(+), 49 deletions(-) diff --git a/presto/dataset.py b/presto/dataset.py index 77a7f94..507a8b1 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -126,7 +126,6 @@ def get_timestep_positions( ), f"Valid position {valid_position} not in timestep positions {timestep_positions}" return timestep_positions - def row_to_arrays( self, row: pd.Series, @@ -212,7 +211,7 @@ def row_to_arrays( mask[:, NDVI_INDEX] = np.logical_or(mask[:, S2_RGB_INDEX], mask[:, S2_NIR_10m_INDEX]) return (self.check(eo_data), mask.astype(bool), latlon, month, valid_month) - + def __getitem__(self, idx): # Get the sample row = self.df.iloc[idx, :] @@ -242,7 +241,7 @@ def get_month_array(self, row: pd.Series) -> np.ndarray: date_vector[-1] = end_date return np.array([d.month - 1 for d in date_vector]) - + @classmethod def normalize_and_mask(cls, eo: np.ndarray): # TODO: this can be removed @@ -634,53 +633,8 @@ def class_weights(self) -> np.ndarray: return self._class_weights -# class WorldCerealLabelled10DDataset(WorldCerealLabelledDataset): - -# NUM_TIMESTEPS = 36 - -# @classmethod -# def get_month_array(cls, row: pd.Series) -> np.ndarray: -# start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime( -# row.end_date, "%Y-%m-%d" -# ) - -# # Calculate the step size for 10-day intervals and create a list of dates -# step = int((end_date - start_date).days / 
(cls.NUM_TIMESTEPS - 1)) -# date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)] - -# # Ensure last date is not beyond the end date -# if date_vector[-1] > end_date: -# date_vector[-1] = end_date - -# return np.array([d.month - 1 for d in date_vector]) - -# def __getitem__(self, idx): -# # Get the sample -# df_index = self.indices[idx] -# row = self.df.iloc[df_index, :] -# eo, mask_per_token, latlon, _, valid_month = self.row_to_arrays( -# row, self.task_type, self.croptype_list, self.augment -# ) -# target = self.target_crop( -# row, self.task_type, self.croptype_list, self.return_hierarchical_labels -# ) -# if self.mask_ratio > 0: -# mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token) -# mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) -# return ( -# self.normalize_and_mask(eo), -# target, -# np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount), -# latlon, -# self.get_month_array(row), -# valid_month, -# mask_per_variable, -# ) - - class WorldCerealInferenceDataset(Dataset): - # _NODATAVALUE = 65535 - # Y = "worldcereal_cropland" + BAND_MAPPING = { "B02": "B2", "B03": "B3",