diff --git a/.gitignore b/.gitignore
index c242713..421f925 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,4 +22,6 @@ scrap
 output/*
 imgs/*
 # don't track catboost training info
-catboost_info
\ No newline at end of file
+catboost_info
+# ignore slurm files
+slurm*
\ No newline at end of file
diff --git a/presto/dataset.py b/presto/dataset.py
index affe16f..507a8b1 100644
--- a/presto/dataset.py
+++ b/presto/dataset.py
@@ -20,6 +20,7 @@
     NDVI_INDEX,
     NODATAVALUE,
     NORMED_BANDS,
+    NUM_TIMESTEPS,
     S1_S2_ERA5_SRTM,
     S2_RGB_INDEX,
     DynamicWorld2020_2021,
@@ -38,7 +39,6 @@ class WorldCerealBase(Dataset):
     # _NODATAVALUE = 65535
-    NUM_TIMESTEPS = 12
     BAND_MAPPING = {
         "OPTICAL-B02-ts{}-10m": "B2",
         "OPTICAL-B03-ts{}-10m": "B3",
@@ -57,26 +57,27 @@ class WorldCerealBase(Dataset):
     }
     STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"}

-    def __init__(self, dataframe: pd.DataFrame):
+    def __init__(self, dataframe: pd.DataFrame, num_timesteps: int = NUM_TIMESTEPS):
         self.df = dataframe
+        self.num_timesteps = num_timesteps

     def __len__(self):
         return self.df.shape[0]

-    @classmethod
     def get_timestep_positions(
-        cls, row_d: Dict, augment: bool = False, is_ssl: bool = False
+        self, row_d: Dict, augment: bool = False, is_ssl: bool = False
     ) -> List[int]:
         available_timesteps = int(row_d["available_timesteps"])

         if is_ssl:
-            if available_timesteps == cls.NUM_TIMESTEPS:
-                valid_position = int(cls.NUM_TIMESTEPS // 2)
+            if available_timesteps == self.num_timesteps:
+                valid_position = int(self.num_timesteps // 2)
             else:
                 valid_position = int(
                     np.random.choice(
                         range(
-                            cls.NUM_TIMESTEPS // 2, (available_timesteps - cls.NUM_TIMESTEPS // 2)
+                            self.num_timesteps // 2,
+                            (available_timesteps - self.num_timesteps // 2),
                         ),
                         1,
                     )
@@ -86,11 +87,11 @@ def get_timestep_positions(
             valid_position = int(row_d["valid_position"])
             if not augment:
                 # check if the valid position is too close to the start_date and force shifting it
-                if valid_position < cls.NUM_TIMESTEPS // 2:
-                    center_point = cls.NUM_TIMESTEPS // 2
+                if valid_position < self.num_timesteps // 2:
+                    center_point = self.num_timesteps // 2
                 # or too close to the end_date
-                elif valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2):
-                    center_point = available_timesteps - cls.NUM_TIMESTEPS // 2
+                elif valid_position > (available_timesteps - self.num_timesteps // 2):
+                    center_point = available_timesteps - self.num_timesteps // 2
                 else:
                     # Center the timesteps around the valid position
                     center_point = valid_position
@@ -99,35 +100,34 @@
                 # well includes the valid position
                 min_center_point = max(
-                    cls.NUM_TIMESTEPS // 2,
-                    valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2,
+                    self.num_timesteps // 2,
+                    valid_position + MIN_EDGE_BUFFER - self.num_timesteps // 2,
                 )
                 max_center_point = min(
-                    available_timesteps - cls.NUM_TIMESTEPS // 2,
-                    valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2,
+                    available_timesteps - self.num_timesteps // 2,
+                    valid_position - MIN_EDGE_BUFFER + self.num_timesteps // 2,
                 )
                 center_point = np.random.randint(
                     min_center_point, max_center_point + 1
                 )  # max_center_point included

-        last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2)
-        first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS)
+        last_timestep = min(available_timesteps, center_point + self.num_timesteps // 2)
+        first_timestep = max(0, last_timestep - self.num_timesteps)
         timestep_positions = list(range(first_timestep, last_timestep))

-        if len(timestep_positions) != cls.NUM_TIMESTEPS:
+        if len(timestep_positions) != self.num_timesteps:
             raise ValueError(
                 f"Acquired timestep positions do not have correct length: \
-required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}"
+required {self.num_timesteps}, got {len(timestep_positions)}"
             )
         assert (
             valid_position in timestep_positions
         ), f"Valid position {valid_position} not in timestep positions {timestep_positions}"
         return timestep_positions

-    @classmethod
     def row_to_arrays(
-        cls,
+        self,
         row: pd.Series,
         task_type: str = "cropland",
         croptype_list: List = [],
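Reviewer note: the window-selection rule above is easiest to verify in isolation. A minimal, self-contained sketch (function name and values are illustrative, not part of this diff):

```python
def pick_window(available_timesteps: int, valid_position: int, num_timesteps: int):
    # Clamp the window center so that a num_timesteps-long window both fits
    # inside [0, available_timesteps) and still contains valid_position.
    if valid_position < num_timesteps // 2:
        center = num_timesteps // 2
    elif valid_position > available_timesteps - num_timesteps // 2:
        center = available_timesteps - num_timesteps // 2
    else:
        center = valid_position
    last = min(available_timesteps, center + num_timesteps // 2)
    first = max(0, last - num_timesteps)
    return list(range(first, last))

window = pick_window(available_timesteps=24, valid_position=2, num_timesteps=12)
assert len(window) == 12 and 2 in window  # valid position always kept in-window
```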
@@ -140,11 +140,11 @@ def row_to_arrays(

         latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32)

-        timestep_positions = cls.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl)
+        timestep_positions = self.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl)

-        if cls.NUM_TIMESTEPS == 12:
+        if self.num_timesteps == 12:
             initial_start_date_position = pd.to_datetime(row_d["start_date"]).month
-        elif cls.NUM_TIMESTEPS > 12:
+        elif self.num_timesteps > 12:
             # get the correct index of the start_date based on NUM_TIMESTEPS`
             # e.g. if NUM_TIMESTEPS is 36 (dekadal setup), we should take the correct
             # 10-day interval that the start_date falls into
@@ -153,18 +153,18 @@
             # should also be changed accordingly
             year = pd.to_datetime(row_d["start_date"]).year
             year_dates = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31")
-            bins = pd.cut(year_dates, bins=cls.NUM_TIMESTEPS, labels=False)
+            bins = pd.cut(year_dates, bins=self.num_timesteps, labels=False)
             initial_start_date_position = bins[
                 np.where(year_dates == pd.to_datetime(row_d["start_date"]))[0][0]
             ]
         else:
             raise ValueError(
-                f"NUM_TIMESTEPS must be at least 12. Currently it is {cls.NUM_TIMESTEPS}"
+                f"NUM_TIMESTEPS must be at least 12. Currently it is {self.num_timesteps}"
             )

         # make sure that month for encoding gets shifted according to
         # the selected timestep positions. Also ensure circular indexing
-        month = (initial_start_date_position - 1 + timestep_positions[0]) % cls.NUM_TIMESTEPS
+        month = (initial_start_date_position - 1 + timestep_positions[0]) % self.num_timesteps

         # adding workaround for compatibility between Phase I and Phase II datasets.
         # (in Phase II, the relevant attribute name was changed to valid_time)
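Reviewer note: the `pd.cut` branch maps a start_date to its 10-day-interval index when `num_timesteps` is 36. A self-contained sketch of that logic (dates are hypothetical):

```python
import numpy as np
import pandas as pd

num_timesteps = 36  # dekadal setup
start_date = pd.to_datetime("2021-05-17")
year_dates = pd.date_range(start="2021-01-01", end="2021-12-31")
bins = pd.cut(year_dates, bins=num_timesteps, labels=False)  # 36 equal-width bins
initial_start_date_position = bins[np.where(year_dates == start_date)[0][0]]
print(initial_start_date_position)  # 13: the ~10-day interval containing May 17
```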
@@ -176,12 +176,12 @@ def row_to_arrays(
         else:
             logger.error("Dataset does not contain neither valid_date, nor valid_time attribute.")

-        eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS)))
+        eo_data = np.zeros((self.num_timesteps, len(BANDS)))
         # an assumption we make here is that all timesteps for a token
         # have the same masking
-        mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX)))
+        mask = np.zeros((self.num_timesteps, len(BANDS_GROUPS_IDX)))

-        for df_val, presto_val in cls.BAND_MAPPING.items():
+        for df_val, presto_val in self.BAND_MAPPING.items():
             values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions])
             # this occurs for the DEM values in one point in Fiji
             values = np.nan_to_num(values, nan=NODATAVALUE)
@@ -199,7 +199,7 @@
             mask[:, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid
             eo_data[:, BANDS.index(presto_val)] = values * idx_valid

-        for df_val, presto_val in cls.STATIC_BAND_MAPPING.items():
+        for df_val, presto_val in self.STATIC_BAND_MAPPING.items():
             # this occurs for the DEM values in one point in Fiji
             values = np.nan_to_num(row_d[df_val], nan=NODATAVALUE)
             idx_valid = values != NODATAVALUE
@@ -210,7 +210,7 @@
         # or nir mask, and adjust the NDVI mask accordingly
         mask[:, NDVI_INDEX] = np.logical_or(mask[:, S2_RGB_INDEX], mask[:, S2_NIR_10m_INDEX])

-        return (cls.check(eo_data), mask.astype(bool), latlon, month, valid_month)
+        return (self.check(eo_data), mask.astype(bool), latlon, month, valid_month)

     def __getitem__(self, idx):
         # Get the sample
@@ -220,13 +220,28 @@ def __getitem__(self, idx):
         mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
         return (
             self.normalize_and_mask(eo),
-            np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount),
+            np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount),
             latlon,
             month,
             valid_month,
             mask_per_variable,
         )

+    def get_month_array(self, row: pd.Series) -> np.ndarray:
+        start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime(
+            row.end_date, "%Y-%m-%d"
+        )
+
+        # Calculate the step size for 10-day intervals and create a list of dates
+        step = int((end_date - start_date).days / (self.num_timesteps - 1))
+        date_vector = [start_date + timedelta(days=i * step) for i in range(self.num_timesteps)]
+
+        # Ensure last date is not beyond the end date
+        if date_vector[-1] > end_date:
+            date_vector[-1] = end_date
+
+        return np.array([d.month - 1 for d in date_vector])
+
     @classmethod
     def normalize_and_mask(cls, eo: np.ndarray):
         # TODO: this can be removed
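Reviewer note: `get_month_array` is now an instance method driven by `self.num_timesteps` (it was a classmethod on the 10D subclass removed below). What it produces, sketched with hypothetical dates:

```python
from datetime import datetime, timedelta

num_timesteps = 36
start_date, end_date = datetime(2021, 1, 1), datetime(2021, 12, 31)

step = int((end_date - start_date).days / (num_timesteps - 1))  # ~10 days
date_vector = [start_date + timedelta(days=i * step) for i in range(num_timesteps)]
if date_vector[-1] > end_date:
    date_vector[-1] = end_date  # clamp the final composite to the period end

months = [d.month - 1 for d in date_vector]  # 0-indexed month per timestep
print(months[:5])  # [0, 0, 0, 0, 1]
```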
@@ -348,19 +363,24 @@ def split_df(


 class WorldCerealMaskedDataset(WorldCerealBase):
+
     def __init__(
         self,
         dataframe: pd.DataFrame,
+        num_timesteps: int,
         mask_params: MaskParamsNoDw,
         task_type: str = "cropland",
         croptype_list: List = [],
         is_ssl: bool = True,
+        is_dekadal: bool = False,
     ):
-        super().__init__(dataframe)
+        super().__init__(dataframe, num_timesteps)
+        self.num_timesteps = num_timesteps
         self.mask_params = mask_params
         self.task_type = task_type
         self.croptype_list = croptype_list
         self.is_ssl = is_ssl
+        self.is_dekadal = is_dekadal

     def __getitem__(self, idx):
         # Get the sample
@@ -370,13 +390,15 @@ def __getitem__(self, idx):
         )

         mask_eo, x_eo, y_eo, strat = self.mask_params.mask_data(
-            self.normalize_and_mask(eo), real_mask_per_token
+            self.normalize_and_mask(eo),
+            real_mask_per_token,
         )
         real_mask_per_variable = np.repeat(real_mask_per_token, BAND_EXPANSION, axis=1)

-        dynamic_world = np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount)
-        mask_dw = np.full(self.NUM_TIMESTEPS, True)
+        dynamic_world = np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount)
+        mask_dw = np.full(self.num_timesteps, True)
         y_dw = dynamic_world.copy()
+
         return MaskedExample(
             mask_eo,
             mask_dw,
@@ -384,7 +406,7 @@
             y_eo,
             dynamic_world,
             y_dw,
-            month,
+            self.get_month_array(row) if self.is_dekadal else month,
             latlon,
             strat,
             real_mask_per_variable,
@@ -425,6 +447,7 @@ class WorldCerealLabelledDataset(WorldCerealBase):
     def __init__(
         self,
         dataframe: pd.DataFrame,
+        num_timesteps: int = NUM_TIMESTEPS,
         countries_to_remove: Optional[List[str]] = None,
         years_to_remove: Optional[List[int]] = None,
         balance: bool = False,
@@ -432,6 +455,7 @@ def __init__(
         croptype_list: List = [],
         return_hierarchical_labels: bool = False,
         augment: bool = False,
+        is_dekadal: bool = False,
         mask_ratio: float = 0.0,
     ):
         dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)]
@@ -452,6 +476,7 @@ def __init__(
         self.croptype_list = croptype_list
         self.return_hierarchical_labels = return_hierarchical_labels
         self.augment = augment
+        self.is_dekadal = is_dekadal
         if augment:
             logger.info(
                 "Augmentation is enabled. \
@@ -468,7 +493,7 @@ def __init__(
                 mask_ratio,
             )

-        super().__init__(dataframe)
+        super().__init__(dataframe, num_timesteps)
         if balance:
             if self.task_type == "cropland":
                 logger.info("Balancing is enabled. Underrepresented class will be upsampled.")
@@ -582,9 +607,9 @@ def __getitem__(self, idx):
         return (
             normed_eo,
             target,
-            np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount),
+            np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount),
             latlon,
-            month,
+            self.get_month_array(row) if self.is_dekadal else month,
             valid_month,
             mask_per_variable,
         )
@@ -608,53 +633,8 @@ def class_weights(self) -> np.ndarray:
         return self._class_weights


-class WorldCerealLabelled10DDataset(WorldCerealLabelledDataset):
-
-    NUM_TIMESTEPS = 36
-
-    @classmethod
-    def get_month_array(cls, row: pd.Series) -> np.ndarray:
-        start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime(
-            row.end_date, "%Y-%m-%d"
-        )
-
-        # Calculate the step size for 10-day intervals and create a list of dates
-        step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1))
-        date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)]
-
-        # Ensure last date is not beyond the end date
-        if date_vector[-1] > end_date:
-            date_vector[-1] = end_date
-
-        return np.array([d.month - 1 for d in date_vector])
-
-    def __getitem__(self, idx):
-        # Get the sample
-        df_index = self.indices[idx]
-        row = self.df.iloc[df_index, :]
-        eo, mask_per_token, latlon, _, valid_month = self.row_to_arrays(
-            row, self.task_type, self.croptype_list, self.augment
-        )
-        target = self.target_crop(
-            row, self.task_type, self.croptype_list, self.return_hierarchical_labels
-        )
-        if self.mask_ratio > 0:
-            mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token)
-        mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
-        return (
-            self.normalize_and_mask(eo),
-            target,
-            np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount),
-            latlon,
-            self.get_month_array(row),
-            valid_month,
-            mask_per_variable,
-        )
-
-
 class WorldCerealInferenceDataset(Dataset):
-    # _NODATAVALUE = 65535
-    # Y = "worldcereal_cropland"
+
     BAND_MAPPING = {
         "B02": "B2",
         "B03": "B3",
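Reviewer note: with `num_timesteps` now a constructor argument, the dekadal case no longer needs the `WorldCerealLabelled10DDataset` subclass deleted above. Hypothetical usage of the parameterized class:

```python
# monthly (12 timesteps) and dekadal (36 timesteps) variants of the same class
monthly_ds = WorldCerealLabelledDataset(df, num_timesteps=12)
dekadal_ds = WorldCerealLabelledDataset(df, num_timesteps=36, is_dekadal=True)
# is_dekadal=True switches the returned month encoding from the scalar start
# month to the per-timestep array produced by get_month_array(row).
```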
@@ -818,9 +798,7 @@ def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray:
         return inarr

     @classmethod
-    def nc_to_arrays(
-        cls, filepath: Path
-    ) -> Tuple[
+    def nc_to_arrays(cls, filepath: Path) -> Tuple[
         np.ndarray,
         np.ndarray,
         np.ndarray,
diff --git a/presto/eval.py b/presto/eval.py
index d2af74a..5e2dc09 100644
--- a/presto/eval.py
+++ b/presto/eval.py
@@ -21,7 +21,6 @@
 from .dataset import (
     NORMED_BANDS,
     WorldCerealInferenceDataset,
-    WorldCerealLabelled10DDataset,
     WorldCerealLabelledDataset,
 )
 from .hierarchical_classification import CatBoostClassifierWrapper
@@ -58,6 +57,7 @@ def __init__(
         self,
         train_data: pd.DataFrame,
         test_data: pd.DataFrame,
+        num_timesteps: int = 12,
         countries_to_remove: Optional[List[str]] = None,
         years_to_remove: Optional[List[int]] = None,
         spatial_inference_savedir: Optional[Path] = None,
@@ -67,7 +67,6 @@ def __init__(
         val_size: float = 0.2,
         dekadal: bool = False,
         task_type: str = "cropland",
-        num_outputs: int = 1,
         croptype_list: List = [],
         finetune_classes: str = "CROPTYPE0",
         downstream_classes: str = "CROPTYPE9",
@@ -78,6 +77,7 @@ def __init__(
     ):
         self.seed = seed
         self.task_type = task_type
+        self.num_timesteps = num_timesteps
         self.name = f"WorldCereal{task_type.title()}"

         train_data, val_data = WorldCerealLabelledDataset.split_df(train_data, val_size=val_size)
@@ -112,12 +112,12 @@ def __init__(
             self.num_outputs = len(train_classes)

             # use classes obtained from train to trim val and test classes
-            self.val_df.loc[
-                ~self.val_df[class_column].isin(train_classes), class_column
-            ] = "other_crop"
-            self.test_df.loc[
-                ~self.test_df[class_column].isin(train_classes), class_column
-            ] = "other_crop"
+            self.val_df.loc[~self.val_df[class_column].isin(train_classes), class_column] = (
+                "other_crop"
+            )
+            self.test_df.loc[~self.test_df[class_column].isin(train_classes), class_column] = (
+                "other_crop"
+            )

             # create one-hot representation from obtained labels
             # one-hot is needed for finetuning,
@@ -151,7 +151,7 @@ def __init__(
         self.dekadal = dekadal
         self.balance = balance
-        self.ds_class = WorldCerealLabelled10DDataset if dekadal else WorldCerealLabelledDataset
+        self.ds_class = WorldCerealLabelledDataset
         self.train_masking = train_masking
         self.augment = augment
         self.use_valid_month = use_valid_month
@@ -252,6 +252,7 @@ def dataloader_to_encodings_and_targets(
         dl = DataLoader(
             self.ds_class(
                 self.train_df,
+                self.num_timesteps,
                 countries_to_remove=self.countries_to_remove,
                 years_to_remove=self.years_to_remove,
                 task_type=self.task_type,
@@ -266,6 +267,7 @@
         val_dl = DataLoader(
             self.ds_class(
                 self.val_df,
+                self.num_timesteps,
                 countries_to_remove=self.countries_to_remove,
                 years_to_remove=self.years_to_remove,
                 task_type=self.task_type,
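Reviewer note: `num_timesteps` is the second positional parameter of `WorldCerealLabelledDataset`, so the bare `self.num_timesteps` added to the DataLoader calls above binds as intended; keyword form for clarity (values hypothetical):

```python
ds = WorldCerealLabelledDataset(
    train_df,
    num_timesteps=36,  # hypothetical dekadal value; defaults to NUM_TIMESTEPS (12)
    task_type="cropland",
)
```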
diff --git a/presto/masking.py b/presto/masking.py
index ba569b9..cd95454 100644
--- a/presto/masking.py
+++ b/presto/masking.py
@@ -5,13 +5,7 @@

 import numpy as np

-from .dataops import (
-    BAND_EXPANSION,
-    BANDS_GROUPS_IDX,
-    NUM_TIMESTEPS,
-    SRTM_INDEX,
-    TIMESTEPS_IDX,
-)
+from .dataops import BAND_EXPANSION, BANDS_GROUPS_IDX, NUM_TIMESTEPS, SRTM_INDEX

 MASK_STRATEGIES = (
     "group_bands",
@@ -54,7 +48,10 @@ def make_mask_no_dw(
     num_tokens_to_mask = int(
         ((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio
     ) - sum(sum(mask))
-    assert num_tokens_to_mask > 0
+    # assert num_tokens_to_mask > 0
+    if num_tokens_to_mask <= 0:
+        mask[:, SRTM_INDEX] = srtm_mask
+        return np.repeat(mask, BAND_EXPANSION, axis=1)

     def mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio):
         should_flip = random() < mask_ratio
@@ -106,7 +103,7 @@ def random_masking(mask, num_tokens_to_mask: int):
         # -1 for SRTM
         timesteps_to_mask = int(num_tokens_to_mask / (len(BANDS_GROUPS_IDX) - 1))
         max_tokens_masked = (len(BANDS_GROUPS_IDX) - 1) * timesteps_to_mask
-        timesteps = sample(TIMESTEPS_IDX, k=timesteps_to_mask)
+        timesteps = sample(range(num_timesteps), k=timesteps_to_mask)
         if timesteps_to_mask > 0:
             num_tokens_to_mask -= int(max_tokens_masked - sum(sum(mask[timesteps])))
             mask[timesteps] = True
@@ -145,14 +142,14 @@ def __post_init__(self):
             "random_combinations",
         ]

-    def mask_data(self, eo_data: np.ndarray, mask: np.ndarray, num_timesteps: int = NUM_TIMESTEPS):
+    def mask_data(self, eo_data: np.ndarray, mask: np.ndarray):
         strategy = choice(self.strategies)
         mask = make_mask_no_dw(
             strategy=strategy,
             mask_ratio=self.ratio,
             existing_mask=mask,
-            num_timesteps=num_timesteps,
+            num_timesteps=self.num_timesteps,
         )
         x = eo_data * ~mask
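Reviewer note: `mask_data` no longer takes `num_timesteps` per call; it is fixed on `MaskParamsNoDw` at construction time, mirroring the dataset. And when the existing mask already covers the requested ratio, `make_mask_no_dw` now returns the band-expanded mask instead of tripping the old assert. Hypothetical usage:

```python
from presto.masking import MASK_STRATEGIES, MaskParamsNoDw

# ratio and num_timesteps values are illustrative
mask_params = MaskParamsNoDw(MASK_STRATEGIES, 0.75, num_timesteps=36)
# mask_eo, x_eo, y_eo, strategy = mask_params.mask_data(normed_eo, existing_mask)
```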
diff --git a/presto/presto.py b/presto/presto.py
index ed3cd85..c76d465 100644
--- a/presto/presto.py
+++ b/presto/presto.py
@@ -860,8 +860,8 @@ def load_pretrained(
             valid_month_size=valid_month_size,
         )

-        if dekadal:
-            model = extend_to_dekadal(model)
+        # if dekadal:
+        #     model = extend_to_dekadal(model)

         if is_finetuned:
             model = model.construct_finetuning_model(num_outputs)
@@ -872,6 +872,9 @@ def load_pretrained(
         else:
             model.load_state_dict(torch.load(model_path, map_location=device), strict=strict)

+        if dekadal:
+            model = extend_to_dekadal(model)
+
         return model

     @classmethod
diff --git a/presto/utils.py b/presto/utils.py
index 8715b56..e83e811 100644
--- a/presto/utils.py
+++ b/presto/utils.py
@@ -36,7 +36,8 @@
 data_dir = Path(__file__).parent.parent / "data"
 config_dir = Path(__file__).parent.parent / "config"
 default_model_path = data_dir / "default_model.pt"
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")

 DEFAULT_SEED: int = 42

@@ -99,7 +100,7 @@ def get_class_mappings() -> Dict:
     return CLASS_MAPPINGS


-def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
+def process_parquet(df: pd.DataFrame, num_ts: int = 12) -> pd.DataFrame:
     """
     This function takes in a DataFrame with S1, S2 and ERA5 observations and their
     respective dates in long format and returns it in wide format.
@@ -131,7 +132,7 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
     - computing the number of available timesteps in the timeseries
       that takes into account updated start_date and end_date; available_timesteps
       holds the absolute number of timesteps that for which observations are
-      available; it cannot be less than NUM_TIMESTEPS; if this is the case,
+      available; it cannot be less than num_ts; if this is the case,
       sample is considered faulty and is removed from the dataset
     - post-processing with prep_dataframe function
@@ -191,8 +192,8 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
         "DEM-alt-20m",
         "DEM-slo-20m",
         "LANDCOVER_LABEL",
-        "POTAPOV-LABEL-10m",
-        "WORLDCOVER-LABEL-10m",
+        # "POTAPOV-LABEL-10m",
+        # "WORLDCOVER-LABEL-10m",
         "aez_zoneid",
         "end_date",
         "lat",
@@ -200,8 +201,10 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
         "start_date",
         "sample_id",
         "valid_date",
-        "location_id",
+        # "location_id",
         "ref_id",
+        "valid_position",
+        "available_timesteps",
     ]

     bands10m = ["OPTICAL-B02", "OPTICAL-B03", "OPTICAL-B04", "OPTICAL-B08"]
@@ -217,18 +220,17 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
     ]
     bands100m = ["METEO-precipitation_flux", "METEO-temperature_mean"]

-    df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - (
-        df["start_date"].dt.year * 12 + df["start_date"].dt.month
-    )
-    df["valid_position"] = (df["valid_date"].dt.year * 12 + df["valid_date"].dt.month) - (
-        df["start_date"].dt.year * 12 + df["start_date"].dt.month
+    df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int)
+    df["valid_date_ts_diff_days"] = (df["valid_date"] - df["timestamp"]).dt.days.abs()
+    valid_position = (
+        df.set_index("timestamp_ind").groupby("sample_id")["valid_date_ts_diff_days"].idxmin()
     )
+    df["valid_position"] = df["sample_id"].map(valid_position)
     df["valid_position_diff"] = df["timestamp_ind"] - df["valid_position"]

     # save the initial start_date for later
     df["initial_start_date"] = df["start_date"].copy()
     index_columns.append("initial_start_date")
-
     # define samples where valid_date is outside the range of the actual extractions
     # and remove them from the dataset
     latest_obs_position = df.groupby(["sample_id"])[
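Reviewer note: the reworked indices rank the actual acquisition timestamps per sample instead of counting calendar months, so the same code serves monthly and 10-day composites. A toy check of the new logic:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "sample_id": ["a", "a", "a"],
        "timestamp": pd.to_datetime(["2021-01-01", "2021-01-11", "2021-01-21"]),
        "valid_date": pd.to_datetime(["2021-01-12"] * 3),
    }
)
df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int)
df["valid_date_ts_diff_days"] = (df["valid_date"] - df["timestamp"]).dt.days.abs()
# idxmin over the timestamp_ind index picks the observation closest to valid_date
valid_position = (
    df.set_index("timestamp_ind").groupby("sample_id")["valid_date_ts_diff_days"].idxmin()
)
df["valid_position"] = df["sample_id"].map(valid_position)
print(df["valid_position"].iloc[0])  # 2 -> the Jan 11 observation is closest
```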
@@ -295,8 +297,9 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
     df["end_date"] = df["sample_id"].map(new_end_date)

     # reinitialize timestep_ind
-    df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - (
-        df["start_date"].dt.year * 12 + df["start_date"].dt.month
+    df["timestamp_ind"] = df.groupby("sample_id")["timestamp"].rank().astype(int)
+    df["available_timesteps"] = df["sample_id"].map(
+        df.groupby("sample_id")["timestamp"].nunique().astype(int)
     )

     # check for missing timestamps in the middle of timeseries
@@ -335,22 +338,13 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
         f"{xx}-100m" if any(band in xx for band in bands100m) else xx
         for xx in df_pivot.columns
     ]  # type: ignore

-    df_pivot["valid_position"] = (
-        df_pivot["valid_date"].dt.year * 12 + df_pivot["valid_date"].dt.month
-    ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
-    df_pivot["available_timesteps"] = (
-        (df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month)
-        - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
-        + 1
-    )
-
     min_center_point = np.maximum(
-        NUM_TIMESTEPS // 2,
-        df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,
+        num_ts // 2,
+        df_pivot["valid_position"] + MIN_EDGE_BUFFER - num_ts // 2,
     )
     max_center_point = np.minimum(
-        df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2,
-        df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2,
+        df_pivot["available_timesteps"] - num_ts // 2,
+        df_pivot["valid_position"] - MIN_EDGE_BUFFER + num_ts // 2,
     )

     faulty_samples = min_center_point > max_center_point
@@ -358,11 +352,11 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
         logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.")
     df_pivot = df_pivot[~faulty_samples]

-    samples_with_too_few_ts = df_pivot["available_timesteps"] < NUM_TIMESTEPS
+    samples_with_too_few_ts = df_pivot["available_timesteps"] < num_ts
     if samples_with_too_few_ts.sum() > 0:
         logger.warning(
             f"Dropping {samples_with_too_few_ts.sum()} samples with \
-number of available timesteps less than {NUM_TIMESTEPS}."
+number of available timesteps less than {num_ts}."
         )
     df_pivot = df_pivot[~samples_with_too_few_ts]
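Reviewer note: a sample is feasible only if some window center satisfies both edge constraints; a toy check mirroring the vectorized code above (the MIN_EDGE_BUFFER value is assumed for illustration):

```python
import numpy as np

num_ts, MIN_EDGE_BUFFER = 12, 2  # buffer value assumed, not taken from the diff
available_timesteps = np.array([14, 12])
valid_position = np.array([7, 1])

min_center = np.maximum(num_ts // 2, valid_position + MIN_EDGE_BUFFER - num_ts // 2)
max_center = np.minimum(
    available_timesteps - num_ts // 2,
    valid_position - MIN_EDGE_BUFFER + num_ts // 2,
)
print(min_center > max_center)  # [False  True]: second sample is dropped as faulty
```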
diff --git a/train_self_supervised.py b/train_self_supervised.py
index a16f31f..3788228 100644
--- a/train_self_supervised.py
+++ b/train_self_supervised.py
@@ -15,7 +15,10 @@
 # import xarray as xr
 from presto.dataops import BANDS_GROUPS_IDX, NODATAVALUE
-from presto.dataset import WorldCerealBase, WorldCerealMaskedDataset
+from presto.dataset import (  # WorldCerealMasked10DDataset,
+    WorldCerealBase,
+    WorldCerealMaskedDataset,
+)
 from presto.masking import MASK_STRATEGIES, MaskParamsNoDw
 from presto.presto import (
     LossWrapper,
@@ -36,6 +39,8 @@
     timestamp_dirname,
 )

+torch.multiprocessing.set_sharing_strategy("file_system")
+
 logger = logging.getLogger("__main__")

 argparser = argparse.ArgumentParser()
@@ -49,8 +54,13 @@
     "and /output/ will be written to. "
     "Leave empty to use the directory you are running this file from.",
 )
+argparser.add_argument(
+    "--parquet_file",
+    type=str,
+    default="/home/vito/millig/projects/TAP/worldcereal/data/worldcereal_training_data.parquet",
+)
 argparser.add_argument("--seed", type=int, default=DEFAULT_SEED)
-argparser.add_argument("--num_workers", type=int, default=64)
+argparser.add_argument("--num_workers", type=int, default=8)
 argparser.add_argument("--wandb", dest="wandb", action="store_true")
 argparser.add_argument("--wandb_org", type=str, default="nasa-harvest")
@@ -81,6 +91,7 @@
     default="random",
     choices=["random", "spatial", "temporal", "seasonal"],
 )
+argparser.add_argument("--dekadal", dest="dekadal", action="store_true")
 argparser.add_argument("--train_only_samples_file", type=str, default="train_only_samples.csv")

 argparser.add_argument("--warm_start", dest="warm_start", action="store_true")
@@ -112,6 +123,7 @@
 # presto_model_description: str = args["presto_model_description"]
 test_type: str = args["test_type"]
 assert test_type in ["random", "spatial", "temporal", "seasonal"]
+dekadal = args["dekadal"]

 seed_everything(seed)
 output_parent_dir = Path(args["output_dir"]) if args["output_dir"] else Path(__file__).parent
@@ -124,23 +136,20 @@
         entity=wandb_org,
         project="presto-worldcereal",
         dir=output_parent_dir,
+        name=model_name,
     )
     run_id = cast(wandb.sdk.wandb_run.Run, run).id

-model_logging_dir = output_parent_dir / "output" / timestamp_dirname(run_id)
+model_logging_dir = output_parent_dir / "models" / model_name / timestamp_dirname(run_id)
 model_logging_dir.mkdir(exist_ok=True, parents=True)
 initialize_logging(model_logging_dir)
 logger.info("Using output dir: %s" % model_logging_dir)

-# parquet_file: str = args["parquet_file"]
-parquet_file = "/home/vito/butskoc/presto-worldcereal/data/long_parquet/\
-worldcereal_training_data.parquet"
+parquet_file: str = args["parquet_file"]
 train_only_samples_file: str = args["train_only_samples_file"]

-dekadal = False
 compositing_window = "30D"
-if "10d" in parquet_file:
-    dekadal = True
+if dekadal:
     compositing_window = "10D"

 path_to_config = config_dir / "default.json"
@@ -149,7 +158,7 @@

 logger.info("Loading data")

-files = sorted(glob(f"{parquet_file}/**/*.parquet"))
+files = sorted(glob(f"{parquet_file}/**/*.parquet"))  # [:10]
 df_list = []
 for f in tqdm(files):
     _data = pd.read_parquet(f, engine="fastparquet")
@@ -168,20 +177,33 @@

 logger.info(f"Preparing train and val splits for {test_type} test")
 val_samples_df = pd.read_csv(data_dir / "test_splits" / val_samples_file)
-train_df, val_df = WorldCerealBase.split_df(df, val_sample_ids=val_samples_df.sample_id.tolist())
+train_df, val_df = WorldCerealBase.split_df(
+    df, val_sample_ids=val_samples_df.sample_id.tolist()
+)  # val_size=0.1

 # Load the mask parameters
 mask_params = MaskParamsNoDw(mask_strategies, mask_ratio, num_timesteps=36 if dekadal else 12)
-masked_ds = WorldCerealMaskedDataset

 train_dataloader = DataLoader(
-    masked_ds(train_df, mask_params=mask_params, is_ssl=True),
+    WorldCerealMaskedDataset(
+        train_df,
+        num_timesteps=36 if dekadal else 12,
+        mask_params=mask_params,
+        is_ssl=True,
+        is_dekadal=dekadal,
+    ),
     batch_size=batch_size,
     shuffle=True,
     num_workers=num_workers,
 )
 val_dataloader = DataLoader(
-    masked_ds(val_df, mask_params=mask_params, is_ssl=True),
+    WorldCerealMaskedDataset(
+        val_df,
+        num_timesteps=36 if dekadal else 12,
+        mask_params=mask_params,
+        is_ssl=True,
+        is_dekadal=dekadal,
+    ),
     batch_size=batch_size,
     shuffle=False,
     num_workers=num_workers,
 )
@@ -206,12 +228,11 @@
 model_kwargs = json.load(Path(path_to_config).open("r"))
 model = Presto.construct(**model_kwargs)
 best_model_path = None
+if dekadal:
+    logger.info("extending model to dekadal architecture")
+    model = extend_to_dekadal(model)

-if dekadal:
-    logger.info("extending model to dekadal architecture")
-    model = extend_to_dekadal(model)
 model.to(device)
-# print(f"model pos embed shape {model.encoder.pos_embed.shape}")  # correctly reinitialized
 param_groups = param_groups_weight_decay(model, weight_decay)
 optimizer = optim.AdamW(param_groups, lr=max_learning_rate, betas=(0.9, 0.95))
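Reviewer note: dekadal mode is now an explicit `--dekadal` switch instead of being inferred from the parquet path; the wiring, condensed (values hypothetical):

```python
dekadal = True  # set via the new --dekadal CLI flag
num_timesteps = 36 if dekadal else 12
compositing_window = "10D" if dekadal else "30D"
# extend_to_dekadal(model) is applied before training; in load_pretrained it now
# runs after the pretrained weights are loaded (see the presto.py hunk above).
```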