From 76ff9488223b8de88386d194861a1dea17b5fead Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 12 Feb 2024 12:08:57 +0800 Subject: [PATCH 01/18] WIP --- presto/dataset.py | 4 + presto/eval.py | 382 +++++++++++++++++++++++++--------------------- train.py | 39 ++--- 3 files changed, 222 insertions(+), 203 deletions(-) diff --git a/presto/dataset.py b/presto/dataset.py index a465f87..9344892 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -188,6 +188,10 @@ def target_maize(row_d) -> int: return int(row_d["CROPTYPE_LABEL"] == 1200) +def target_croptype(row_d) -> int: + return int(row_d["CROPTYPE_LABEL"]) + + class WorldCerealLabelledDataset(WorldCerealBase): # 0: no information, 10: could be both annual or perennial FILTER_LABELS = [0, 10] diff --git a/presto/eval.py b/presto/eval.py index 8031d50..f6101f7 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -22,6 +22,7 @@ NORMED_BANDS, WorldCerealInferenceDataset, WorldCerealLabelledDataset, + target_croptype, ) from .presto import Presto, PrestoFineTuningModel, param_groups_lrd from .utils import DEFAULT_SEED, device @@ -40,11 +41,10 @@ class Hyperparams: num_workers: int = 4 -class WorldCerealEval: - name = "WorldCerealCropland" - threshold = 0.5 - num_outputs = 1 - regression = False +class WorldCerealEvalBase: + + num_outputs: int + regression: bool = False def __init__( self, @@ -52,16 +52,13 @@ def __init__( val_data: pd.DataFrame, countries_to_remove: Optional[List[str]] = None, years_to_remove: Optional[List[int]] = None, - spatial_inference_savedir: Optional[Path] = None, seed: int = DEFAULT_SEED, target_function: Optional[Callable[[Dict], int]] = None, filter_function: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, - name: Optional[str] = None, ): self.seed = seed - - if name is not None: - self.name = name + self.countries_to_remove = countries_to_remove + self.years_to_remove = years_to_remove self.target_function = target_function # SAR cannot equal 0.0 since we take the log of it @@ -74,19 +71,53 @@ def __init__( self.val_df = self.val_df[~pd.isna(self.val_df).any(axis=1)] self.val_df = self.val_df[~(self.val_df.loc[:, cols] == 0.0).any(axis=1)] self.val_df = self.val_df.set_index("sample_id") - if filter_function is not None: - self.val_df = filter_function(self.val_df) - self.test_df = self.val_df - self.spatial_inference_savedir = spatial_inference_savedir + def train_ds(self, balance: bool = False): + return WorldCerealLabelledDataset( + self.train_df, + countries_to_remove=self.countries_to_remove, + years_to_remove=self.years_to_remove, + target_function=self.target_function, + balance=balance, + ) - self.countries_to_remove = countries_to_remove - self.years_to_remove = years_to_remove + def val_ds(self, balance: bool = False): + return WorldCerealLabelledDataset( + self.val_ds, + countries_to_remove=self.countries_to_remove, + years_to_remove=self.years_to_remove, + target_function=self.target_function, + balance=balance, + ) - if self.countries_to_remove is not None: - self.name = f"{self.name}_removed_countries_{countries_to_remove}" - if self.years_to_remove is not None: - self.name = f"{self.name}_removed_years_{years_to_remove}" + +class WorldCerealFinetuning(WorldCerealEvalBase): + regression = False + + def __init__( + self, + train_data: pd.DataFrame, + val_data: pd.DataFrame, + countries_to_remove: Optional[List[str]] = None, + years_to_remove: Optional[List[int]] = None, + seed: int = DEFAULT_SEED, + filter_function: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + ): + 
super().__init__( + train_data, + val_data, + countries_to_remove, + years_to_remove, + seed, + target_croptype, + filter_function, + ) + + # we will map the croptype label to increments from 1 to N + mapping = {key: idx for idx, key in enumerate(self.train_df.CROPTYPE_LABEL.unique())} + self.train_df.CROPTYPE_LABEL = self.train_df.CROPTYPE_LABEL.replace(mapping) + self.val_df.CROPTYPE_LABEL = self.val_df.CROPTYPE_LABEL.replace(mapping) + self.num_outputs = len(self.train_df["CROPTYPE_LABEL"].unique()) def _construct_finetuning_model(self, pretrained_model: Presto) -> PrestoFineTuningModel: model = cast(Callable, pretrained_model.construct_finetuning_model)( @@ -94,6 +125,156 @@ def _construct_finetuning_model(self, pretrained_model: Presto) -> PrestoFineTun ) return model + def finetune(self, pretrained_model) -> PrestoFineTuningModel: + hyperparams = Hyperparams() + model = self._construct_finetuning_model(pretrained_model) + + parameters = param_groups_lrd(model) + optimizer = AdamW(parameters, lr=hyperparams.lr) + + loss_fn = nn.CrossEntropyLoss() + + generator = torch.Generator() + generator.manual_seed(self.seed) + train_dl = DataLoader( + self.train_ds(), + batch_size=hyperparams.batch_size, + shuffle=True, + num_workers=hyperparams.num_workers, + generator=generator, + ) + + val_dl = DataLoader( + self.val_ds(), + batch_size=hyperparams.batch_size, + shuffle=False, + num_workers=hyperparams.num_workers, + ) + + train_loss = [] + val_loss = [] + best_loss = None + best_model_dict = None + epochs_since_improvement = 0 + + run = None + try: + import wandb + + run = wandb.run + except ImportError: + pass + + for _ in (pbar := tqdm(range(hyperparams.max_epochs), desc="Finetuning")): + model.train() + epoch_train_loss = 0.0 + for x, y, dw, latlons, month, variable_mask in tqdm( + train_dl, desc="Training", leave=False + ): + x, y, dw, latlons, month, variable_mask = [ + t.to(device) for t in (x, y, dw, latlons, month, variable_mask) + ] + optimizer.zero_grad() + preds = model( + x, + dynamic_world=dw.long(), + mask=variable_mask, + latlons=latlons, + month=month, + ) + loss = loss_fn(preds, y.float()) + epoch_train_loss += loss.item() + loss.backward() + optimizer.step() + train_loss.append(epoch_train_loss / len(train_dl)) + + model.eval() + all_preds, all_y = [], [] + for x, y, dw, latlons, month, variable_mask in val_dl: + x, y, dw, latlons, month, variable_mask = [ + t.to(device) for t in (x, y, dw, latlons, month, variable_mask) + ] + with torch.no_grad(): + preds = model( + x, + dynamic_world=dw.long(), + mask=variable_mask, + latlons=latlons, + month=month, + ) + all_preds.append(preds) + all_y.append(y.float()) + + val_loss.append(loss_fn(torch.cat(all_preds), torch.cat(all_y))) + pbar.set_description(f"Train metric: {train_loss[-1]}, Val metric: {val_loss[-1]}") + + if run is not None: + wandb.log( + { + f"{self.name}_finetuning_val_loss": val_loss[-1], + f"{self.name}_finetuning_train_loss": train_loss[-1], + } + ) + + if best_loss is None: + best_loss = val_loss[-1] + best_model_dict = deepcopy(model.state_dict()) + else: + if val_loss[-1] < best_loss: + best_loss = val_loss[-1] + best_model_dict = deepcopy(model.state_dict()) + epochs_since_improvement = 0 + else: + epochs_since_improvement += 1 + if epochs_since_improvement >= hyperparams.patience: + logger.info("Early stopping!") + break + assert best_model_dict is not None + model.load_state_dict(best_model_dict) + + model.eval() + return model + + +class WorldCerealEval(WorldCerealEvalBase): + name = "WorldCerealCropland" 
+ threshold = 0.5 + num_outputs = 1 + regression = False + + def __init__( + self, + train_data: pd.DataFrame, + val_data: pd.DataFrame, + countries_to_remove: Optional[List[str]] = None, + years_to_remove: Optional[List[int]] = None, + spatial_inference_savedir: Optional[Path] = None, + seed: int = DEFAULT_SEED, + target_function: Optional[Callable[[Dict], int]] = None, + filter_function: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + name: Optional[str] = None, + ): + super().__init__( + train_data, + val_data, + countries_to_remove, + years_to_remove, + seed, + target_function, + filter_function, + ) + + if name is not None: + self.name = name + self.test_df = self.val_df + + self.spatial_inference_savedir = spatial_inference_savedir + + if self.countries_to_remove is not None: + self.name = f"{self.name}_removed_countries_{countries_to_remove}" + if self.years_to_remove is not None: + self.name = f"{self.name}_removed_years_{years_to_remove}" + @torch.no_grad() def finetune_sklearn_model( self, @@ -355,155 +536,21 @@ def partitioned_metrics( **metrics("CatBoost_region", world_attrs.region, catboost_preds), } - def finetune(self, pretrained_model) -> PrestoFineTuningModel: - hyperparams = Hyperparams() - model = self._construct_finetuning_model(pretrained_model) - - parameters = param_groups_lrd(model) - optimizer = AdamW(parameters, lr=hyperparams.lr) - - train_ds = WorldCerealLabelledDataset( - self.train_df, - countries_to_remove=self.countries_to_remove, - years_to_remove=self.years_to_remove, - target_function=self.target_function, - balance=True, - ) - - # should the val set be balanced too? - val_ds = WorldCerealLabelledDataset( - self.val_df, - countries_to_remove=self.countries_to_remove, - years_to_remove=self.years_to_remove, - target_function=self.target_function, - ) - - loss_fn = nn.BCEWithLogitsLoss() - - generator = torch.Generator() - generator.manual_seed(self.seed) - train_dl = DataLoader( - train_ds, - batch_size=hyperparams.batch_size, - shuffle=True, - num_workers=hyperparams.num_workers, - generator=generator, - ) - - val_dl = DataLoader( - val_ds, - batch_size=hyperparams.batch_size, - shuffle=False, - num_workers=hyperparams.num_workers, - ) - - train_loss = [] - val_loss = [] - best_loss = None - best_model_dict = None - epochs_since_improvement = 0 - - run = None - try: - import wandb - - run = wandb.run - except ImportError: - pass - - for _ in (pbar := tqdm(range(hyperparams.max_epochs), desc="Finetuning")): - model.train() - epoch_train_loss = 0.0 - for x, y, dw, latlons, month, variable_mask in tqdm( - train_dl, desc="Training", leave=False - ): - x, y, dw, latlons, month, variable_mask = [ - t.to(device) for t in (x, y, dw, latlons, month, variable_mask) - ] - optimizer.zero_grad() - preds = model( - x, - dynamic_world=dw.long(), - mask=variable_mask, - latlons=latlons, - month=month, - ) - loss = loss_fn(preds.squeeze(-1), y.float()) - epoch_train_loss += loss.item() - loss.backward() - optimizer.step() - train_loss.append(epoch_train_loss / len(train_dl)) - - model.eval() - all_preds, all_y = [], [] - for x, y, dw, latlons, month, variable_mask in val_dl: - x, y, dw, latlons, month, variable_mask = [ - t.to(device) for t in (x, y, dw, latlons, month, variable_mask) - ] - with torch.no_grad(): - preds = model( - x, - dynamic_world=dw.long(), - mask=variable_mask, - latlons=latlons, - month=month, - ) - all_preds.append(preds.squeeze(-1)) - all_y.append(y.float()) - - val_loss.append(loss_fn(torch.cat(all_preds), torch.cat(all_y))) - 
pbar.set_description(f"Train metric: {train_loss[-1]}, Val metric: {val_loss[-1]}") - - if run is not None: - wandb.log( - { - f"{self.name}_finetuning_val_loss": val_loss[-1], - f"{self.name}_finetuning_train_loss": train_loss[-1], - } - ) - - if best_loss is None: - best_loss = val_loss[-1] - best_model_dict = deepcopy(model.state_dict()) - else: - if val_loss[-1] < best_loss: - best_loss = val_loss[-1] - best_model_dict = deepcopy(model.state_dict()) - epochs_since_improvement = 0 - else: - epochs_since_improvement += 1 - if epochs_since_improvement >= hyperparams.patience: - logger.info("Early stopping!") - break - assert best_model_dict is not None - model.load_state_dict(best_model_dict) - - model.eval() - return model - def finetuning_results_sklearn( self, sklearn_model_modes: List[str], finetuned_model: PrestoFineTuningModel ) -> Dict: + for model_mode in sklearn_model_modes: + assert model_mode in ["Regression", "Random Forest", "CatBoostClassifier"] results_dict = {} if len(sklearn_model_modes) > 0: dl = DataLoader( - WorldCerealLabelledDataset( - self.train_df, - countries_to_remove=self.countries_to_remove, - years_to_remove=self.years_to_remove, - target_function=self.target_function, - ), + self.train_ds(), batch_size=2048, shuffle=False, num_workers=4, ) val_dl = DataLoader( - WorldCerealLabelledDataset( - self.val_df, - countries_to_remove=self.countries_to_remove, - years_to_remove=self.years_to_remove, - target_function=self.target_function, - ), + self.val_ds(), batch_size=2048, shuffle=False, num_workers=4, @@ -520,22 +567,3 @@ def finetuning_results_sklearn( if self.spatial_inference_savedir is not None: self.spatial_inference(sklearn_model, finetuned_model) return results_dict - - def finetuning_results( - self, - pretrained_model, - sklearn_model_modes: List[str], - ) -> Tuple[Dict, PrestoFineTuningModel]: - for model_mode in sklearn_model_modes: - assert model_mode in ["Regression", "Random Forest", "CatBoostClassifier"] - - results_dict = {} - # we want to always finetune the model, since the sklearn models - # will use the finetuned model as a base. 
This better reflects - # the deployment scenario for WorldCereal - finetuned_model = self.finetune(pretrained_model) - results_dict.update(self.evaluate(finetuned_model, None)) - if self.spatial_inference_savedir is not None: - self.spatial_inference(finetuned_model, None) - results_dict.update(self.finetuning_results_sklearn(sklearn_model_modes, finetuned_model)) - return results_dict, finetuned_model diff --git a/train.py b/train.py index 2885ed0..3293f03 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from presto.dataops import BANDS_GROUPS_IDX from presto.dataset import WorldCerealMaskedDataset as WorldCerealDataset from presto.dataset import filter_remove_noncrops, target_maize -from presto.eval import WorldCerealEval +from presto.eval import WorldCerealEval, WorldCerealFinetuning from presto.masking import MASK_STRATEGIES, MaskParamsNoDw from presto.presto import ( LossWrapper, @@ -153,10 +153,6 @@ shuffle=False, num_workers=num_workers, ) -validation_task = WorldCerealEval( - train_data=train_df.sample(1000, random_state=DEFAULT_SEED), - val_data=val_df.sample(1000, random_state=DEFAULT_SEED), -) if val_per_n_steps == -1: val_per_n_steps = len(train_dataloader) @@ -293,11 +289,6 @@ } tqdm_epoch.set_postfix(loss=val_eo_loss) - val_task_results, _ = validation_task.finetuning_results( - model, sklearn_model_modes=["Random Forest"] - ) - to_log.update(val_task_results) - if lowest_validation_loss is None or val_eo_loss < lowest_validation_loss: lowest_validation_loss = val_eo_loss best_val_epoch = epoch @@ -330,16 +321,21 @@ logger.info("Running eval with randomly init weights") -model_modes = ["Random Forest", "Regression", "CatBoostClassifier"] -full_eval = WorldCerealEval(train_df, val_df, spatial_inference_savedir=model_logging_dir) -results, finetuned_model = full_eval.finetuning_results(model, sklearn_model_modes=model_modes) -logger.info(json.dumps(results, indent=2)) - +# full finetuning +full_finetuning = WorldCerealFinetuning(train_df, val_df) +finetuned_model = full_finetuning.finetune(model) model_path = model_logging_dir / Path("models") model_path.mkdir(exist_ok=True, parents=True) finetuned_model_path = model_path / "finetuned_model.pt" torch.save(finetuned_model.state_dict(), finetuned_model_path) +model_modes = ["Random Forest", "Regression", "CatBoostClassifier"] +full_eval = WorldCerealEval(train_df, val_df, spatial_inference_savedir=model_logging_dir) +results = full_eval.finetuning_results_sklearn( + sklearn_model_modes=model_modes, finetuned_model=finetuned_model +) +logger.info(json.dumps(results, indent=2)) + full_maize_eval = WorldCerealEval( train_df, val_df, @@ -348,11 +344,10 @@ filter_function=filter_remove_noncrops, name="WorldCerealMaize", ) -maize_results, maize_finetuned_model = full_maize_eval.finetuning_results( - model, sklearn_model_modes=model_modes +maize_results = full_maize_eval.finetuning_results_sklearn( + sklearn_model_modes=model_modes, finetuned_model=finetuned_model ) logger.info(json.dumps(maize_results, indent=2)) -torch.save(maize_finetuned_model.state_dict(), model_path / "maize_finetuned_model.pt") # not saving plots to wandb plot_results(load_world_df(), results, model_logging_dir, show=True, to_wandb=False) @@ -360,13 +355,6 @@ load_world_df(), maize_results, model_logging_dir, show=True, to_wandb=False, prefix="maize_" ) -# this is a bit hacky, but it lets us simulate crop/non-crop finetuning -> maize prediction head -full_maize_eval.name = "WorldCerealCropFinetuningMaizeHead" -crop_to_maize_results = 
full_maize_eval.finetuning_results_sklearn( - sklearn_model_modes=model_modes, finetuned_model=finetuned_model -) -logger.info(json.dumps(crop_to_maize_results, indent=2)) - # missing data experiments country_results = [] for country in ["Latvia", "Brazil", "Togo", "Madagascar"]: @@ -410,7 +398,6 @@ if wandb_enabled: wandb.log(results) wandb.log(maize_results) - wandb.log(crop_to_maize_results) for results in country_results: wandb.log(results) wandb.log(year_results) From dbb14e060cc5e9cc0d371a687c06bc71d6e8c61c Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 12 Feb 2024 14:51:50 +0800 Subject: [PATCH 02/18] Consolidate everything --- presto/eval.py | 9 +-------- tests/test_eval.py | 19 +++++++++++-------- train.py | 16 ++++------------ 3 files changed, 16 insertions(+), 28 deletions(-) diff --git a/presto/eval.py b/presto/eval.py index f6101f7..e489799 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -92,6 +92,7 @@ def val_ds(self, balance: bool = False): class WorldCerealFinetuning(WorldCerealEvalBase): + name = "WorldCerealFineTuning" regression = False def __init__( @@ -329,15 +330,7 @@ def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np. class_weight=class_weight_dict, random_state=self.seed, ), - # Parameters emulate - # https://github.com/WorldCereal/wc-classification/blob/ - # 4a9a839507d9b4f63c378b3b1d164325cbe843d6/src/worldcereal/classification/models.py#L490 "CatBoostClassifier": CatBoostClassifier( - iterations=8000, - depth=8, - learning_rate=0.05, - early_stopping_rounds=20, - l2_leaf_reg=3, random_state=self.seed, class_weights=class_weight_dict, ), diff --git a/tests/test_eval.py b/tests/test_eval.py index a711334..9531575 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -6,7 +6,7 @@ import rioxarray import xarray as xr -from presto.eval import WorldCerealEval +from presto.eval import WorldCerealEval, WorldCerealFinetuning from presto.presto import Presto from presto.utils import data_dir from tests.utils import read_test_file @@ -18,12 +18,15 @@ def test_eval(self): test_data = read_test_file() eval_task = WorldCerealEval(test_data, test_data) - - output, _ = eval_task.finetuning_results(model, ["CatBoostClassifier"]) - # * 283 per model: WorldCereal CatBoost, Presto finetuned, Presto + CatBoost3 - self.assertEqual(len(output), 282 * 3) - self.assertTrue("WorldCerealCropland_CatBoostClassifier_f1" in output) - self.assertTrue("WorldCerealCropland_CatBoostClassifier_f1" in output) + finetuning_task = WorldCerealFinetuning(test_data, test_data) + finetuning_model = finetuning_task.finetune(model) + results = eval_task.finetuning_results_sklearn( + finetuned_model=finetuning_model, sklearn_model_modes=["CatBoostClassifier"] + ) + # * 283 per model: WorldCereal CatBoost, Presto + CatBoost3 + self.assertEqual(len(results), 282 * 2) + self.assertTrue("WorldCerealCropland_CatBoostClassifier_f1" in results) + self.assertTrue("WorldCerealCropland_CatBoostClassifier_f1" in results) def test_spatial_inference( self, @@ -40,7 +43,7 @@ def test_spatial_inference( eval_task = WorldCerealEval( test_data, test_data, spatial_inference_savedir=Path(tmpdirname) ) - finetuned_model = eval_task._construct_finetuning_model(model) + finetuned_model = model.construct_finetuning_model(num_outputs=1) eval_task.spatial_inference(finetuned_model, None) output = xr.open_dataset( Path(tmpdirname) / f"{eval_task.name}_{spatial_data_prefix}_finetuning.nc" diff --git a/train.py b/train.py index 3293f03..c71c6f7 100644 --- a/train.py +++ b/train.py @@ 
-358,6 +358,8 @@ # missing data experiments country_results = [] for country in ["Latvia", "Brazil", "Togo", "Madagascar"]: + finetuning_task = WorldCerealFinetuning(train_df, val_df, countries_to_remove=[country]) + finetuned_model = finetuning_task.finetune(model) for predict_maize in [True, False]: kwargs = { "train_data": train_df, @@ -374,20 +376,11 @@ } ) eval_task = WorldCerealEval(**kwargs) - results, finetuned_model = eval_task.finetuning_results( - model, sklearn_model_modes=model_modes + results = eval_task.finetuning_results_sklearn( + finetuned_model=finetuned_model, sklearn_model_modes=model_modes ) logger.info(json.dumps(results, indent=2)) country_results.append(results) - prefix = "maize" if predict_maize else "" - finetuned_model_path = model_path / f"{prefix}_finetuned_{country}_removed_model.pt" - torch.save(finetuned_model.state_dict(), finetuned_model_path) - -missing_year = WorldCerealEval( - train_df, val_df, years_to_remove=[2021], spatial_inference_savedir=model_logging_dir -) -year_results, _ = missing_year.finetuning_results(model, sklearn_model_modes=model_modes) -logger.info(json.dumps(year_results, indent=2)) all_spatial_preds = list(model_logging_dir.glob("*.nc")) for spatial_preds_path in all_spatial_preds: @@ -400,7 +393,6 @@ wandb.log(maize_results) for results in country_results: wandb.log(results) - wandb.log(year_results) if wandb_enabled and run: run.finish() From b3efb5a23b36c592ac3414f44c11ea484e1faf83 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 12 Feb 2024 15:02:26 +0800 Subject: [PATCH 03/18] s -> f --- presto/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/eval.py b/presto/eval.py index e489799..22aa6bc 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -83,7 +83,7 @@ def train_ds(self, balance: bool = False): def val_ds(self, balance: bool = False): return WorldCerealLabelledDataset( - self.val_ds, + self.val_df, countries_to_remove=self.countries_to_remove, years_to_remove=self.years_to_remove, target_function=self.target_function, From 907a77348e8d781d9a4594914825e713fa96aaf6 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 12 Feb 2024 15:06:50 +0800 Subject: [PATCH 04/18] RuntimeError: expected scalar type Long but found Float --- presto/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto/eval.py b/presto/eval.py index 22aa6bc..183a005 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -183,7 +183,7 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel: latlons=latlons, month=month, ) - loss = loss_fn(preds, y.float()) + loss = loss_fn(preds, y.long()) epoch_train_loss += loss.item() loss.backward() optimizer.step() @@ -204,7 +204,7 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel: month=month, ) all_preds.append(preds) - all_y.append(y.float()) + all_y.append(y.long()) val_loss.append(loss_fn(torch.cat(all_preds), torch.cat(all_y))) pbar.set_description(f"Train metric: {train_loss[-1]}, Val metric: {val_loss[-1]}") From 624bec6424fd374b3fe78123a1a93b74f22306f6 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 12 Feb 2024 20:04:29 +0800 Subject: [PATCH 05/18] More finetuning --- presto/eval.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/presto/eval.py b/presto/eval.py index 183a005..a495a25 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -34,9 +34,9 @@ @dataclass class Hyperparams: - lr: float = 2e-5 - max_epochs: int = 20 - batch_size: int = 64 + lr: float = 3e-4 + max_epochs: int = 100 + 
batch_size: int = 128 patience: int = 3 num_workers: int = 4 From 51b5f198fd43f1d7895049f90a5393bc02502b94 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 13 Feb 2024 15:25:07 +0800 Subject: [PATCH 06/18] Don't change too many things at once --- presto/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto/eval.py b/presto/eval.py index a495a25..8958c87 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -34,9 +34,9 @@ @dataclass class Hyperparams: - lr: float = 3e-4 + lr: float = 2e-5 max_epochs: int = 100 - batch_size: int = 128 + batch_size: int = 64 patience: int = 3 num_workers: int = 4 From e4d29797e27f162f1d57f0bdffdb6426aa34a5cb Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 10:28:08 +0800 Subject: [PATCH 07/18] Pass the valid_month as an additional token --- presto/dataset.py | 11 ++++++++--- presto/eval.py | 29 +++++++++++++++++------------ presto/presto.py | 24 ++++++++++++++++++------ 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/presto/dataset.py b/presto/dataset.py index 9344892..683d258 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -68,13 +68,14 @@ def target_crop(row_d: Dict) -> int: @classmethod def row_to_arrays( cls, row: pd.Series, target_function: Callable[[Dict], int] - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int, int]: # https://stackoverflow.com/questions/45783891/is-there-a-way-to-speed-up-the-pandas-getitem-getitem-axis-and-get-label # This is faster than indexing the series every time! row_d = pd.Series.to_dict(row) latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32) month = datetime.strptime(row_d["start_date"], "%Y-%m-%d").month - 1 + valid_month = datetime.strptime(row_d["valid_date"], "%Y-%m-%d").month - 1 eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS))) # an assumption we make here is that all timesteps for a token @@ -109,6 +110,7 @@ def row_to_arrays( mask.astype(bool), latlon, month, + valid_month, target_function(row_d), ) @@ -154,7 +156,7 @@ def __init__(self, dataframe: pd.DataFrame, mask_params: MaskParamsNoDw): def __getitem__(self, idx): # Get the sample row = self.df.iloc[idx, :] - eo, real_mask_per_token, latlon, month, _ = self.row_to_arrays(row, self.target_crop) + eo, real_mask_per_token, latlon, month, _, _ = self.row_to_arrays(row, self.target_crop) mask_eo, x_eo, y_eo, strat = self.mask_params.mask_data( self.normalize_and_mask(eo), real_mask_per_token ) @@ -250,7 +252,9 @@ def __getitem__(self, idx): # Get the sample df_index = self.indices[idx] row = self.df.iloc[df_index, :] - eo, mask_per_token, latlon, month, target = self.row_to_arrays(row, self.target_function) + eo, mask_per_token, latlon, month, valid_month, target = self.row_to_arrays( + row, self.target_function + ) mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) return ( self.normalize_and_mask(eo), @@ -258,6 +262,7 @@ def __getitem__(self, idx): np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount), latlon, month, + valid_month, mask_per_variable, ) diff --git a/presto/eval.py b/presto/eval.py index 8958c87..24e4c69 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -169,11 +169,11 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel: for _ in (pbar := tqdm(range(hyperparams.max_epochs), desc="Finetuning")): model.train() epoch_train_loss = 0.0 - for x, y, dw, latlons, month, variable_mask in tqdm( + for x, y, dw, latlons, month, valid_month, 
variable_mask in tqdm( train_dl, desc="Training", leave=False ): - x, y, dw, latlons, month, variable_mask = [ - t.to(device) for t in (x, y, dw, latlons, month, variable_mask) + x, y, dw, latlons, month, valid_month, variable_mask = [ + t.to(device) for t in (x, y, dw, latlons, month, valid_month, variable_mask) ] optimizer.zero_grad() preds = model( @@ -182,6 +182,7 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel: mask=variable_mask, latlons=latlons, month=month, + valid_month=valid_month, ) loss = loss_fn(preds, y.long()) epoch_train_loss += loss.item() @@ -191,9 +192,9 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel: model.eval() all_preds, all_y = [], [] - for x, y, dw, latlons, month, variable_mask in val_dl: - x, y, dw, latlons, month, variable_mask = [ - t.to(device) for t in (x, y, dw, latlons, month, variable_mask) + for x, y, dw, latlons, month, valid_month, variable_mask in val_dl: + x, y, dw, latlons, month, valid_month, variable_mask = [ + t.to(device) for t in (x, y, dw, latlons, month, valid_month, variable_mask) ] with torch.no_grad(): preds = model( @@ -202,6 +203,7 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel: mask=variable_mask, latlons=latlons, month=month, + valid_month=valid_month, ) all_preds.append(preds) all_y.append(y.long()) @@ -290,9 +292,9 @@ def finetune_sklearn_model( def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np.ndarray]: encoding_list, target_list = [], [] - for x, y, dw, latlons, month, variable_mask in dl: - x_f, dw_f, latlons_f, month_f, variable_mask_f = [ - t.to(device) for t in (x, dw, latlons, month, variable_mask) + for x, y, dw, latlons, month, valid_month, variable_mask in dl: + x_f, dw_f, latlons_f, month_f, valid_month_f, variable_mask_f = [ + t.to(device) for t in (x, dw, latlons, month, valid_month, variable_mask) ] target_list.append(y) with torch.no_grad(): @@ -303,6 +305,7 @@ def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np. 
mask=variable_mask_f, latlons=latlons_f, month=month_f, + valid_month=valid_month_f, ) .cpu() .numpy() @@ -355,10 +358,10 @@ def _inference_for_dl( test_preds, targets = [], [] - for x, y, dw, latlons, month, variable_mask in dl: + for x, y, dw, latlons, month, valid_month, variable_mask in dl: targets.append(y) - x_f, dw_f, latlons_f, month_f, variable_mask_f = [ - t.to(device) for t in (x, dw, latlons, month, variable_mask) + x_f, dw_f, latlons_f, month_f, valid_month_f, variable_mask_f = [ + t.to(device) for t in (x, dw, latlons, month, valid_month, variable_mask) ] if isinstance(finetuned_model, PrestoFineTuningModel): finetuned_model.eval() @@ -368,6 +371,7 @@ def _inference_for_dl( mask=variable_mask_f, latlons=latlons_f, month=month_f, + valid_month=valid_month_f, ).squeeze(dim=1) preds = torch.sigmoid(preds).cpu().numpy() else: @@ -380,6 +384,7 @@ def _inference_for_dl( mask=variable_mask_f, latlons=latlons_f, month=month_f, + valid_month=valid_month_f, ) .cpu() .numpy() diff --git a/presto/presto.py b/presto/presto.py index 174c240..b294e46 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -333,6 +333,8 @@ def __init__( num_embeddings=len(self.band_groups) + 1, embedding_dim=channel_embedding_size ) + self.valid_month_encoding = nn.Embedding(num_embeddings=12, embedding_dim=embedding_size) + self.initialize_weights() def initialize_weights(self): @@ -382,6 +384,16 @@ def mask_tokens(x, mask): return x, indices, updated_mask + @staticmethod + def add_token(token, token_list, mask, indices): + token_list = torch.cat((token, token_list), dim=1) + mask = torch.cat((torch.zeros(token_list.shape[0])[:, None].to(device), mask), dim=1) + indices = torch.cat( + (torch.zeros(token_list.shape[0])[:, None].to(device).int(), indices + 1), + dim=1, + ) + return token_list, mask, indices + def forward( self, x: torch.Tensor, @@ -389,6 +401,7 @@ def forward( latlons: torch.Tensor, mask: Optional[torch.Tensor] = None, month: Union[torch.Tensor, int] = 0, + valid_month: Optional[torch.Tensor] = None, eval_task: bool = True, ): device = x.device @@ -458,12 +471,11 @@ def forward( # append latlon tokens latlon_tokens = self.latlon_embed(self.cartesian(latlons)).unsqueeze(1) - x = torch.cat((latlon_tokens, x), dim=1) - upd_mask = torch.cat((torch.zeros(x.shape[0])[:, None].to(device), upd_mask), dim=1) - orig_indices = torch.cat( - (torch.zeros(x.shape[0])[:, None].to(device).int(), orig_indices + 1), - dim=1, - ) + x, upd_mask, orig_indices = self.add_token(latlon_tokens, x, upd_mask, orig_indices) + + if valid_month is not None: + val_month_token = self.valid_month_encoding(valid_month) + x, upd_mask, orig_indices = self.add_token(val_month_token, x, upd_mask, orig_indices) # apply Transformer blocks for blk in self.blocks: From b05a16c338a38972ff706ed13eccb0c7efc2373d Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 10:32:52 +0800 Subject: [PATCH 08/18] Update test file to include a valid_date --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils.py b/tests/utils.py index 79e2af5..7a4a177 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,6 +17,7 @@ def read_test_file() -> pd.DataFrame: axis=1, inplace=True, ) + test_df["valid_date"] = test_df["start_date"] test_df["sample_id"] = np.arange(len(test_df)) test_df["year"] = 2021 labels = [99] * len(test_df) # 99 = No cropland From b5230b59ba642aa9a39ee3892db8fb0228aabda6 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 10:36:53 +0800 Subject: [PATCH 09/18] Load_pretrained 
False by default now that we have added this new embedding layer --- presto/presto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/presto.py b/presto/presto.py index b294e46..0622d8a 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -803,7 +803,7 @@ def construct_finetuning_model( @classmethod def load_pretrained( - cls, model_path: Union[str, Path] = default_model_path, strict: bool = True + cls, model_path: Union[str, Path] = default_model_path, strict: bool = False ): model = cls.construct() model.load_state_dict(torch.load(model_path, map_location=device), strict=strict) From aa3c3701f8e9df33805aeee60659044cddf0c6c0 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 16:33:17 +0800 Subject: [PATCH 10/18] Update tests --- presto/eval.py | 53 +++++++++++++++++++++++--------------------- presto/presto.py | 2 ++ tests/test_presto.py | 12 +++++++--- 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/presto/eval.py b/presto/eval.py index 24e4c69..635b889 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -358,37 +358,40 @@ def _inference_for_dl( test_preds, targets = [], [] - for x, y, dw, latlons, month, valid_month, variable_mask in dl: + for b in dl: + try: + x, y, dw, latlons, month, valid_month, variable_mask = b + x_f, dw_f, latlons_f, month_f, valid_month_f, variable_mask_f = [ + t.to(device) for t in (x, dw, latlons, month, valid_month, variable_mask) + ] + input_d = { + "x": x_f, + "dynamic_world": dw_f.long(), + "latlons": latlons_f, + "mask": variable_mask_f, + "month": month_f, + "valid_month": valid_month_f, + } + except ValueError: + x, y, dw, latlons, month, variable_mask = b + x_f, dw_f, latlons_f, month_f, variable_mask_f = [ + t.to(device) for t in (x, dw, latlons, month, variable_mask) + ] + input_d = { + "x": x_f, + "dynamic_world": dw_f.long(), + "latlons": latlons_f, + "mask": variable_mask_f, + "month": month_f, + } targets.append(y) - x_f, dw_f, latlons_f, month_f, valid_month_f, variable_mask_f = [ - t.to(device) for t in (x, dw, latlons, month, valid_month, variable_mask) - ] if isinstance(finetuned_model, PrestoFineTuningModel): finetuned_model.eval() - preds = finetuned_model( - x_f, - dynamic_world=dw_f.long(), - mask=variable_mask_f, - latlons=latlons_f, - month=month_f, - valid_month=valid_month_f, - ).squeeze(dim=1) + preds = finetuned_model(**input_d).squeeze(dim=1) preds = torch.sigmoid(preds).cpu().numpy() else: cast(Presto, pretrained_model).eval() - encodings = ( - cast(Presto, pretrained_model) - .encoder( - x_f, - dynamic_world=dw_f.long(), - mask=variable_mask_f, - latlons=latlons_f, - month=month_f, - valid_month=valid_month_f, - ) - .cpu() - .numpy() - ) + encodings = cast(Presto, pretrained_model).encoder(**input_d).cpu().numpy() preds = finetuned_model.predict_proba(encodings)[:, 1] test_preds.append(preds) diff --git a/presto/presto.py b/presto/presto.py index 0622d8a..d4a1eb7 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -704,6 +704,7 @@ def forward( latlons: torch.Tensor, mask: Optional[torch.Tensor] = None, month: Union[torch.Tensor, int] = 0, + valid_month: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.head( self.encoder( @@ -712,6 +713,7 @@ def forward( latlons=latlons, mask=mask, month=month, + valid_month=valid_month, eval_task=True, ) ) diff --git a/tests/test_presto.py b/tests/test_presto.py index f2effc3..9c53750 100644 --- a/tests/test_presto.py +++ b/tests/test_presto.py @@ -263,7 +263,7 @@ def test_default_loading_behaviour(self): all 3 ways of 
loading the pretrained model are in agreement """ model = Presto.construct() - model.load_state_dict(torch.load(default_model_path, map_location=device)) + model.load_state_dict(torch.load(default_model_path, map_location=device), strict=False) from_function = Presto.load_pretrained() for torch_loaded, pretrain_loaded in zip(model.parameters(), from_function.parameters()): @@ -321,7 +321,11 @@ def test_grads(self): output.backward() for name, param in encoder.named_parameters(): - if ("pos_embed" not in name) and ("month_embed" not in name): + if ( + ("pos_embed" not in name) + and ("month_embed" not in name) + and ("valid_month_encoding" not in name) + ): # the positional encoder is frozen self.assertIsNotNone(param.grad, msg=name) @@ -353,7 +357,9 @@ def test_finetuning_model_outputs_equivalent(self): def test_load_pretrained_works_for_finetuned_model(self): path_to_finetuned_model = data_dir / "finetuned_model.pt" model = Presto.load_pretrained().construct_finetuning_model(num_outputs=1) - model.load_state_dict(torch.load(path_to_finetuned_model, map_location=device)) + model.load_state_dict( + torch.load(path_to_finetuned_model, map_location=device), strict=False + ) model_2 = Presto.load_pretrained(path_to_finetuned_model, strict=False) From a3a4f44c539426078d05b994c5033c2d7f6a7685 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 16:45:26 +0800 Subject: [PATCH 11/18] Update tests --- presto/presto.py | 2 +- tests/test_presto.py | 21 +++++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/presto/presto.py b/presto/presto.py index d4a1eb7..c9b2a0f 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -474,7 +474,7 @@ def forward( x, upd_mask, orig_indices = self.add_token(latlon_tokens, x, upd_mask, orig_indices) if valid_month is not None: - val_month_token = self.valid_month_encoding(valid_month) + val_month_token = self.valid_month_encoding(valid_month).unsqueeze(1) x, upd_mask, orig_indices = self.add_token(val_month_token, x, upd_mask, orig_indices) # apply Transformer blocks diff --git a/tests/test_presto.py b/tests/test_presto.py index 9c53750..74b442c 100644 --- a/tests/test_presto.py +++ b/tests/test_presto.py @@ -266,16 +266,24 @@ def test_default_loading_behaviour(self): model.load_state_dict(torch.load(default_model_path, map_location=device), strict=False) from_function = Presto.load_pretrained() - for torch_loaded, pretrain_loaded in zip(model.parameters(), from_function.parameters()): - self.assertTrue(torch.equal(torch_loaded, pretrain_loaded)) + for (name, torch_loaded), pretrain_loaded in zip( + model.named_parameters(), from_function.parameters() + ): + if "valid_month" not in name: + self.assertTrue(torch.equal(torch_loaded, pretrain_loaded)) path_to_config = config_dir / "default.json" with Path(path_to_config).open("r") as f: model_kwargs = json.load(f) from_config = Presto.construct(**model_kwargs) - from_config.load_state_dict(torch.load(default_model_path, map_location=device)) - for torch_loaded, config_loaded in zip(model.parameters(), from_config.parameters()): - self.assertTrue(torch.equal(torch_loaded, config_loaded)) + from_config.load_state_dict( + torch.load(default_model_path, map_location=device), strict=False + ) + for (name, torch_loaded), config_loaded in zip( + model.named_parameters(), from_config.parameters() + ): + if "valid_month" not in name: + self.assertTrue(torch.equal(torch_loaded, config_loaded)) def test_reconstruct_inputs(self): model = Presto.construct().decoder @@ -364,7 +372,8 @@ def 
test_load_pretrained_works_for_finetuned_model(self): model_2 = Presto.load_pretrained(path_to_finetuned_model, strict=False) for name, param in model.encoder.named_parameters(): - self.assertTrue(param.equal(model_2.encoder.state_dict()[name])) + if "valid_month" not in name: + self.assertTrue(param.equal(model_2.encoder.state_dict()[name])) batch_size = 3 with torch.no_grad(): From 5860314fdb82ff472703d1f8b0a41959f5d50caf Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 17:00:30 +0800 Subject: [PATCH 12/18] strict=False when loading state_dict in train.py --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index c71c6f7..3f9168b 100644 --- a/train.py +++ b/train.py @@ -315,7 +315,7 @@ if best_model_path is not None: logger.info("Loading best model: %s" % best_model_path) - best_model = torch.load(best_model_path, map_location=device) + best_model = torch.load(best_model_path, map_location=device, strict=False) model.load_state_dict(best_model) else: logger.info("Running eval with randomly init weights") From c106805e2c7ab8c95bfd2c31282bb1cadbb38598 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 23 Feb 2024 17:09:59 +0800 Subject: [PATCH 13/18] :facepalm: --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 3f9168b..eeff6d0 100644 --- a/train.py +++ b/train.py @@ -315,8 +315,8 @@ if best_model_path is not None: logger.info("Loading best model: %s" % best_model_path) - best_model = torch.load(best_model_path, map_location=device, strict=False) - model.load_state_dict(best_model) + best_model = torch.load(best_model_path, map_location=device) + model.load_state_dict(best_model, strict=False) else: logger.info("Running eval with randomly init weights") From 5c5496f704fbf66bb4e7562840f7915b560b275c Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 6 Mar 2024 20:17:21 +0800 Subject: [PATCH 14/18] init valid month encoding --- presto/presto.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/presto/presto.py b/presto/presto.py index c9b2a0f..fcab8d0 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -333,7 +333,9 @@ def __init__( num_embeddings=len(self.band_groups) + 1, embedding_dim=channel_embedding_size ) - self.valid_month_encoding = nn.Embedding(num_embeddings=12, embedding_dim=embedding_size) + self.valid_month_encoding = nn.Embedding.from_pretrained( + get_month_encoding_table(embedding_size) + ) self.initialize_weights() From 59b626cc610291633dfc47698c264835a747d5b0 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 6 Mar 2024 21:53:25 +0800 Subject: [PATCH 15/18] More flexibility in how we use the val month --- config/default.json | 4 +- presto/presto.py | 24 ++++- train.py | 224 ++------------------------------------------ 3 files changed, 29 insertions(+), 223 deletions(-) diff --git a/config/default.json b/config/default.json index 1236f53..f4df435 100644 --- a/config/default.json +++ b/config/default.json @@ -7,5 +7,7 @@ "encoder_num_heads": 8, "decoder_embedding_size": 128, "decoder_depth": 2, - "decoder_num_heads": 8 + "decoder_num_heads": 8, + "valid_month_as_token": false, + "valid_month_size": 32 } diff --git a/presto/presto.py b/presto/presto.py index fcab8d0..1b350dc 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -282,11 +282,16 @@ def __init__( mlp_ratio=2, num_heads=8, max_sequence_length=24, + valid_month_as_token: bool = False, + valid_month_size: int = 128, ): super().__init__() self.band_groups 
= BANDS_GROUPS_IDX self.embedding_size = embedding_size + self.valid_month_as_token = valid_month_as_token + if valid_month_as_token: + assert valid_month_size == embedding_size # this is used for the channel embedding self.band_group_to_idx = { @@ -334,7 +339,7 @@ def __init__( ) self.valid_month_encoding = nn.Embedding.from_pretrained( - get_month_encoding_table(embedding_size) + get_month_encoding_table(valid_month_size) ) self.initialize_weights() @@ -477,6 +482,7 @@ def forward( if valid_month is not None: val_month_token = self.valid_month_encoding(valid_month).unsqueeze(1) + if self.valid_month_as_token: x, upd_mask, orig_indices = self.add_token(val_month_token, x, upd_mask, orig_indices) # apply Transformer blocks @@ -488,7 +494,7 @@ def forward( # set masked tokens to 0 x_for_mean = x * (1 - upd_mask.unsqueeze(-1)) x_mean = x_for_mean.sum(dim=1) / torch.sum(1 - upd_mask, -1, keepdim=True) - return self.norm(x_mean) + return torch.cat([self.norm(x_mean), val_month_token], axis=-1) return self.norm(x), orig_indices, upd_mask @@ -772,6 +778,8 @@ def construct( decoder_depth=2, decoder_num_heads=8, max_sequence_length=24, + valid_month_as_token: bool = False, + valid_month_size: int = 128, ): encoder = Encoder( embedding_size=encoder_embedding_size, @@ -781,6 +789,8 @@ def construct( mlp_ratio=mlp_ratio, num_heads=encoder_num_heads, max_sequence_length=max_sequence_length, + valid_month_as_token=valid_month_as_token, + valid_month_size=valid_month_size, ) decoder = Decoder( channel_embeddings=encoder.channel_embed, @@ -807,9 +817,15 @@ def construct_finetuning_model( @classmethod def load_pretrained( - cls, model_path: Union[str, Path] = default_model_path, strict: bool = False + cls, + model_path: Union[str, Path] = default_model_path, + strict: bool = False, + valid_month_as_token: bool = False, + valid_month_size: int = 128, ): - model = cls.construct() + model = cls.construct( + valid_month_as_token=valid_month_as_token, valid_month_size=valid_month_size + ) model.load_state_dict(torch.load(model_path, map_location=device), strict=strict) return model diff --git a/train.py b/train.py index eeff6d0..46c18ec 100644 --- a/train.py +++ b/train.py @@ -3,27 +3,16 @@ import json import logging from pathlib import Path -from typing import Optional, Tuple, cast +from typing import Optional, cast import pandas as pd import torch -import torch.nn as nn import xarray as xr -from torch import optim -from torch.utils.data import DataLoader -from tqdm import tqdm -from presto.dataops import BANDS_GROUPS_IDX from presto.dataset import WorldCerealMaskedDataset as WorldCerealDataset from presto.dataset import filter_remove_noncrops, target_maize from presto.eval import WorldCerealEval, WorldCerealFinetuning -from presto.masking import MASK_STRATEGIES, MaskParamsNoDw -from presto.presto import ( - LossWrapper, - Presto, - adjust_learning_rate, - param_groups_weight_decay, -) +from presto.presto import Presto from presto.utils import ( DEFAULT_SEED, config_dir, @@ -51,26 +40,6 @@ "and /output/ will be written to. 
" "Leave empty to use the directory you are running this file from.", ) -argparser.add_argument("--n_epochs", type=int, default=20) -argparser.add_argument("--max_learning_rate", type=float, default=0.0001) -argparser.add_argument("--min_learning_rate", type=float, default=0.0) -argparser.add_argument("--warmup_epochs", type=int, default=2) -argparser.add_argument("--weight_decay", type=float, default=0.05) -argparser.add_argument("--batch_size", type=int, default=4096) -argparser.add_argument("--val_per_n_steps", type=int, default=-1, help="If -1, val every epoch") -argparser.add_argument( - "--mask_strategies", - type=str, - default=[ - "group_bands", - "random_timesteps", - "chunk_timesteps", - "random_combinations", - ], - nargs="+", - help="`all` will use all available masking strategies (including single bands)", -) -argparser.add_argument("--mask_ratio", type=float, default=0.75) argparser.add_argument("--seed", type=int, default=DEFAULT_SEED) argparser.add_argument("--num_workers", type=int, default=4) argparser.add_argument("--wandb", dest="wandb", action="store_true") @@ -109,30 +78,12 @@ initialize_logging(model_logging_dir) logger.info("Using output dir: %s" % model_logging_dir) -num_epochs = args["n_epochs"] -val_per_n_steps = args["val_per_n_steps"] -max_learning_rate = args["max_learning_rate"] -min_learning_rate = args["min_learning_rate"] -warmup_epochs = args["warmup_epochs"] -weight_decay = args["weight_decay"] -batch_size = args["batch_size"] - -# Default mask strategies and mask_ratio -mask_strategies: Tuple[str, ...] = tuple(args["mask_strategies"]) -if (len(mask_strategies) == 1) and (mask_strategies[0] == "all"): - mask_strategies = MASK_STRATEGIES -mask_ratio: float = args["mask_ratio"] - parquet_file: str = args["parquet_file"] val_samples_file: str = args["val_samples_file"] path_to_config = config_dir / "default.json" model_kwargs = json.load(Path(path_to_config).open("r")) -logger.info("Setting up dataloaders") - -# Load the mask parameters -mask_params = MaskParamsNoDw(mask_strategies, mask_ratio) df = pd.read_parquet(data_dir / parquet_file) if (data_dir / val_samples_file).exists(): @@ -141,26 +92,14 @@ train_df, val_df = WorldCerealDataset.split_df(df, val_sample_ids=val_samples) else: train_df, val_df = WorldCerealDataset.split_df(df) -train_dataloader = DataLoader( - WorldCerealDataset(train_df, mask_params=mask_params), - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, -) -val_dataloader = DataLoader( - WorldCerealDataset(val_df, mask_params=mask_params), - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, -) - -if val_per_n_steps == -1: - val_per_n_steps = len(train_dataloader) logger.info("Setting up model") if warm_start: model_kwargs = json.load(Path(config_dir / "default.json").open("r")) - model = Presto.load_pretrained() + model = Presto.load_pretrained( + valid_month_as_token=model_kwargs["valid_month_as_token"], + valid_month_size=model_kwargs["valid_month_size"], + ) best_model_path: Optional[Path] = default_model_path else: if path_to_config == "": @@ -170,157 +109,6 @@ best_model_path = None model.to(device) -param_groups = param_groups_weight_decay(model, weight_decay) -optimizer = optim.AdamW(param_groups, lr=max_learning_rate, betas=(0.9, 0.95)) -mse = LossWrapper(nn.MSELoss()) - -training_config = { - "model": model.__class__, - "encoder": model.encoder.__class__, - "decoder": model.decoder.__class__, - "optimizer": optimizer.__class__.__name__, - "eo_loss": mse.loss.__class__.__name__, - "device": 
device, - "logging_dir": model_logging_dir, - **args, - **model_kwargs, -} - -if wandb_enabled: - wandb.config.update(training_config) - -lowest_validation_loss = None -best_val_epoch = 0 -training_step = 0 -num_validations = 0 - -with tqdm(range(num_epochs), desc="Epoch") as tqdm_epoch: - for epoch in tqdm_epoch: - # ------------------------ Training ---------------------------------------- - total_eo_train_loss = 0.0 - num_updates_being_captured = 0 - train_size = 0 - model.train() - for epoch_step, b in enumerate(tqdm(train_dataloader, desc="Train", leave=False)): - mask, x, y, start_month = b[0].to(device), b[2].to(device), b[3].to(device), b[6] - dw_mask, x_dw, y_dw = b[1].to(device), b[4].to(device).long(), b[5].to(device).long() - latlons, real_mask = b[7].to(device), b[9].to(device) - # zero the parameter gradients - optimizer.zero_grad() - lr = adjust_learning_rate( - optimizer, - epoch_step / len(train_dataloader) + epoch, - warmup_epochs, - num_epochs, - max_learning_rate, - min_learning_rate, - ) - # Get model outputs and calculate loss - y_pred, dw_pred = model( - x, mask=mask, dynamic_world=x_dw, latlons=latlons, month=start_month - ) - # set all SRTM timesteps except the first one to unmasked, so that - # they will get ignored by the loss function even if the SRTM - # value was masked - mask[:, 1:, BANDS_GROUPS_IDX["SRTM"]] = False - # set the "truly masked" values to unmasked, so they also get ignored in the loss - mask[real_mask] = False - loss = mse(y_pred[mask], y[mask]) - loss.backward() - optimizer.step() - - current_batch_size = len(x) - total_eo_train_loss += loss.item() - num_updates_being_captured += 1 - train_size += current_batch_size - training_step += 1 - - # ------------------------ Validation -------------------------------------- - if training_step % val_per_n_steps == 0: - total_eo_val_loss = 0.0 - num_val_updates_captured = 0 - val_size = 0 - model.eval() - - with torch.no_grad(): - for b in tqdm(val_dataloader, desc="Validate"): - mask, x, y, start_month, real_mask = ( - b[0].to(device), - b[2].to(device), - b[3].to(device), - b[6], - b[9].to(device), - ) - dw_mask, x_dw = b[1].to(device), b[4].to(device).long() - y_dw, latlons = b[5].to(device).long(), b[7].to(device) - # Get model outputs and calculate loss - y_pred, dw_pred = model( - x, mask=mask, dynamic_world=x_dw, latlons=latlons, month=start_month - ) - # set all SRTM timesteps except the first one to unmasked, so that - # they will get ignored by the loss function even if the SRTM - # value was masked - mask[:, 1:, BANDS_GROUPS_IDX["SRTM"]] = False - # set the "truly masked" values to unmasked, so they also get - # ignored in the loss - mask[real_mask] = False - loss = mse(y_pred[mask], y[mask]) - current_batch_size = len(x) - total_eo_val_loss += loss.item() - num_val_updates_captured += 1 - - # ------------------------ Metrics + Logging ------------------------------- - # train_loss now reflects the value against which we calculate gradients - train_eo_loss = total_eo_train_loss / num_updates_being_captured - val_eo_loss = total_eo_val_loss / num_val_updates_captured - - if "train_size" not in training_config and "val_size" not in training_config: - training_config["train_size"] = train_size - training_config["val_size"] = val_size - if wandb_enabled: - wandb.config.update(training_config) - - to_log = { - "train_eo_loss": train_eo_loss, - "val_eo_loss": val_eo_loss, - "training_step": training_step, - "epoch": epoch, - "lr": lr, - } - tqdm_epoch.set_postfix(loss=val_eo_loss) - - if 
lowest_validation_loss is None or val_eo_loss < lowest_validation_loss: - lowest_validation_loss = val_eo_loss - best_val_epoch = epoch - - model_path = model_logging_dir / Path("models") - model_path.mkdir(exist_ok=True, parents=True) - - best_model_path = model_path / f"{model_name}{epoch}.pt" - logger.info(f"Saving best model to: {best_model_path}") - torch.save(model.state_dict(), best_model_path) - - # reset training logging - total_eo_train_loss = 0.0 - num_updates_being_captured = 0 - train_size = 0 - num_validations += 1 - - if wandb_enabled: - wandb.log(to_log) - - model.train() - -logger.info(f"Trained for {num_epochs} epochs, best model at {best_model_path}") - -if best_model_path is not None: - logger.info("Loading best model: %s" % best_model_path) - best_model = torch.load(best_model_path, map_location=device) - model.load_state_dict(best_model, strict=False) -else: - logger.info("Running eval with randomly init weights") - - # full finetuning full_finetuning = WorldCerealFinetuning(train_df, val_df) finetuned_model = full_finetuning.finetune(model) From f12493ff881ada088081892afc1d40e2d9fe4273 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 6 Mar 2024 22:06:42 +0800 Subject: [PATCH 16/18] Fix kwarg --- presto/presto.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/presto/presto.py b/presto/presto.py index 1b350dc..5d94562 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -481,9 +481,11 @@ def forward( x, upd_mask, orig_indices = self.add_token(latlon_tokens, x, upd_mask, orig_indices) if valid_month is not None: - val_month_token = self.valid_month_encoding(valid_month).unsqueeze(1) + val_month_token = self.valid_month_encoding(valid_month) if self.valid_month_as_token: - x, upd_mask, orig_indices = self.add_token(val_month_token, x, upd_mask, orig_indices) + x, upd_mask, orig_indices = self.add_token( + val_month_token.unsqueeze(1), x, upd_mask, orig_indices + ) # apply Transformer blocks for blk in self.blocks: @@ -494,7 +496,7 @@ def forward( # set masked tokens to 0 x_for_mean = x * (1 - upd_mask.unsqueeze(-1)) x_mean = x_for_mean.sum(dim=1) / torch.sum(1 - upd_mask, -1, keepdim=True) - return torch.cat([self.norm(x_mean), val_month_token], axis=-1) + return torch.cat([self.norm(x_mean), val_month_token], dim=-1) return self.norm(x), orig_indices, upd_mask From 99dacbc5d03fe4152629fda450e26cc46453e06f Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 6 Mar 2024 22:13:02 +0800 Subject: [PATCH 17/18] Fix finetuning model, handle the case where valid_month is None --- presto/presto.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/presto/presto.py b/presto/presto.py index 5d94562..e8ec04e 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -292,6 +292,7 @@ def __init__( self.valid_month_as_token = valid_month_as_token if valid_month_as_token: assert valid_month_size == embedding_size + self.valid_month_size = valid_month_size # this is used for the channel embedding self.band_group_to_idx = { @@ -482,10 +483,15 @@ def forward( if valid_month is not None: val_month_token = self.valid_month_encoding(valid_month) - if self.valid_month_as_token: - x, upd_mask, orig_indices = self.add_token( - val_month_token.unsqueeze(1), x, upd_mask, orig_indices - ) + if self.valid_month_as_token: + x, upd_mask, orig_indices = self.add_token( + val_month_token.unsqueeze(1), x, upd_mask, orig_indices + ) + else: + # if it is None, we ignore it as a token but do add it to + # the output embedding + 
valid_month = torch.ones((x.shape[0],), device=x.device).long() + val_month_token = self.valid_month_encoding(valid_month) # apply Transformer blocks for blk in self.blocks: @@ -809,9 +815,13 @@ def construct_finetuning_model( self, num_outputs: int, ) -> PrestoFineTuningModel: + if not self.encoder.valid_month_as_token: + hidden_size = self.encoder.embedding_size + self.encoder.valid_month_size + else: + hidden_size = self.encoder.embedding_size head = FinetuningHead( num_outputs=num_outputs, - hidden_size=self.encoder.embedding_size, + hidden_size=hidden_size, ) model = PrestoFineTuningModel(self.encoder, head).to(self.encoder.pos_embed.device) model.train() From 9318948c4dc8c97f4dd55199a74afdc84adb267f Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 6 Mar 2024 22:18:14 +0800 Subject: [PATCH 18/18] Fix test --- tests/test_presto.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_presto.py b/tests/test_presto.py index 74b442c..64df8c1 100644 --- a/tests/test_presto.py +++ b/tests/test_presto.py @@ -364,7 +364,9 @@ def test_finetuning_model_outputs_equivalent(self): def test_load_pretrained_works_for_finetuned_model(self): path_to_finetuned_model = data_dir / "finetuned_model.pt" - model = Presto.load_pretrained().construct_finetuning_model(num_outputs=1) + model = Presto.load_pretrained(valid_month_as_token=True).construct_finetuning_model( + num_outputs=1 + ) model.load_state_dict( torch.load(path_to_finetuned_model, map_location=device), strict=False )
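
Taken together, these patches split the old finetuning_results flow into a standalone WorldCerealFinetuning task plus WorldCerealEval.finetuning_results_sklearn, and thread an optional valid_month input through the dataset, encoder and finetuning head. The sketch below is a minimal, non-authoritative outline of the resulting workflow as the updated train.py uses it; the parquet paths and the chosen sklearn modes are placeholders, and the valid-month settings mirror config/default.json.

    # Minimal sketch of the post-series workflow (mirrors the updated train.py).
    # Paths are placeholders; the parquet must contain a "valid_date" column,
    # since WorldCerealBase.row_to_arrays now reads it.
    import pandas as pd

    from presto.eval import WorldCerealEval, WorldCerealFinetuning
    from presto.presto import Presto

    train_df = pd.read_parquet("data/worldcereal_train.parquet")  # placeholder path
    val_df = pd.read_parquet("data/worldcereal_val.parquet")      # placeholder path

    # valid_month_as_token / valid_month_size come from config/default.json.
    # With valid_month_as_token=False the valid-month embedding is concatenated
    # to the pooled encoder output rather than added as an extra token, so the
    # finetuning head's hidden size grows by valid_month_size.
    model = Presto.load_pretrained(valid_month_as_token=False, valid_month_size=32)

    # Full multi-class crop-type finetuning of the Presto backbone.
    finetuned_model = WorldCerealFinetuning(train_df, val_df).finetune(model)

    # Binary cropland evaluation reuses the finetuned backbone as the feature
    # extractor for the sklearn / CatBoost heads.
    eval_task = WorldCerealEval(train_df, val_df)
    results = eval_task.finetuning_results_sklearn(
        sklearn_model_modes=["Random Forest", "CatBoostClassifier"],
        finetuned_model=finetuned_model,
    )
    print(results)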