Skip to content

Commit

Permalink
fixed ACC, add ml ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
juannat7 committed Aug 27, 2024
1 parent 72e631e commit 62d08ab
Show file tree
Hide file tree
Showing 41 changed files with 394 additions and 162 deletions.
18 changes: 18 additions & 0 deletions chaosbench/configs/resnet_ensemble_s2s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model_args:
model_name: 'resnet_ensemble_s2s'
input_size: 60
output_size: 60
learning_rate: 0.01
num_workers: 12
epochs: 500
t_max: 500
only_headline: False

data_args:
batch_size: 32
train_years: [1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015]
val_years: [2016, 2017, 2018, 2019, 2020, 2021]
n_step: 3
lead_time: 1
land_vars: []
ocean_vars: []
4 changes: 2 additions & 2 deletions chaosbench/configs/resnet_s2s.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ data_args:
val_years: [2016, 2017, 2018, 2019, 2020, 2021]
n_step: 1
lead_time: 1
land_vars: ['skt', 'src', 'stl1', 'stl2', 'stl3', 'swvl1', 'swvl2', 'swvl3']
ocean_vars: ['somxl010', 'somxl030', 'sosaline', 'sossheig', 'sosstsst']
land_vars: []
ocean_vars: []
18 changes: 18 additions & 0 deletions chaosbench/configs/unet_ensemble_s2s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model_args:
model_name: 'unet_ensemble_s2s'
input_size: 60
output_size: 60
learning_rate: 0.01
num_workers: 12
epochs: 500
t_max: 500
only_headline: False

data_args:
batch_size: 32
train_years: [1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015]
val_years: [2016, 2017, 2018, 2019, 2020, 2021]
n_step: 3
lead_time: 1
land_vars: []
ocean_vars: []
26 changes: 13 additions & 13 deletions chaosbench/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def __init__(self,

# Retrieve climatology
self.normalization_file = {
'era5': Path(config.DATA_DIR) / 'climatology' / 'climatology_era5.zarr',
'lra5': Path(config.DATA_DIR) / 'climatology' / 'climatology_lra5.zarr',
'oras5': Path(config.DATA_DIR) / 'climatology' / 'climatology_oras5.zarr'
'era5': Path(config.DATA_DIR) / 'climatology' / 'climatology_era5_spatial.zarr',
'lra5': Path(config.DATA_DIR) / 'climatology' / 'climatology_lra5_spatial.zarr',
'oras5': Path(config.DATA_DIR) / 'climatology' / 'climatology_oras5_spatial.zarr'
}

self.normalization_mean = {
Expand All @@ -180,19 +180,19 @@ def __init__(self,
'oras5': xr.open_dataset(self.normalization_file['oras5'], engine='zarr')['mean'],
}

def forward(self, predictions, targets, param, source):
def forward(self, predictions, targets, doys, param, source):

# Retrieve mean climatology
climatology = torch.tensor(self.normalization_mean[source].sel(param=param).values).to(config.device)

# Compute only valid values
valid_mask = ~torch.isnan(predictions) & ~torch.isnan(targets)
predictions, targets = predictions[valid_mask], targets[valid_mask]

climatology = torch.tensor(self.normalization_mean[source].sel(doy=doys, param=param).values).to(config.device)

# Compute anomalies
anomalies_targets = targets - climatology
anomalies_predictions = predictions - climatology


# Compute only valid values
valid_mask = ~torch.isnan(anomalies_predictions) & ~torch.isnan(anomalies_targets)
anomalies_predictions, anomalies_targets = anomalies_predictions[valid_mask], anomalies_targets[valid_mask]

if self.lat_adjusted:
anomalies_targets = self.weights.size(1) * (self.weights / torch.sum(self.weights)) * anomalies_targets
anomalies_predictions = self.weights.size(1) * (self.weights / torch.sum(self.weights)) * anomalies_predictions
Expand Down Expand Up @@ -628,8 +628,8 @@ def forward(self, predictions, targets, doys, param, source):
opts = dict(device=predictions.device, dtype=predictions.dtype)

# Get climatology
clima_mean = self.normalization[source]['mean'].sel(doy=doys, param=param).values
clima_sigma = self.normalization[source]['sigma'].sel(doy=doys, param=param).values
clima_mean = torch.tensor(self.normalization[source]['mean'].sel(doy=doys, param=param).values).to(config.device)
clima_sigma = torch.tensor(self.normalization[source]['sigma'].sel(doy=doys, param=param).values).to(config.device)

if self.lat_adjusted:
predictions = self.weights.size(1) * (self.weights / torch.sum(self.weights)) * predictions
Expand Down
1 change: 0 additions & 1 deletion chaosbench/models/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def __init__(
self.out_conv = nn.Conv2d(64, output_size, 1)



def forward(self, x):
IS_MERGED = False # To handle legacy code where the inputs are separated by pressure level

Expand Down
10 changes: 4 additions & 6 deletions chaosbench/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def __init__(

elif 'unet' in self.model_args['model_name']:
self.model = cnn.UNet(input_size = input_size,
output_size = output_size)
output_size = output_size)

elif 'resnet' in self.model_args['model_name']:
self.model = cnn.ResNet(input_size = input_size,
output_size = output_size)
output_size = output_size)

elif 'vae' in self.model_args['model_name']:
self.model = ae.VAE(input_size = input_size,
Expand Down Expand Up @@ -158,14 +158,12 @@ def setup(self, stage=None):
n_step=self.data_args['n_step'],
lead_time=self.data_args['lead_time'],
land_vars=self.data_args['land_vars'],
ocean_vars=self.data_args['ocean_vars']
)
ocean_vars=self.data_args['ocean_vars'])
self.val_dataset = dataset.S2SObsDataset(years=self.data_args['train_years'],
n_step=self.data_args['n_step'],
lead_time=self.data_args['lead_time'],
land_vars=self.data_args['land_vars'],
ocean_vars=self.data_args['ocean_vars']
)
ocean_vars=self.data_args['ocean_vars'])


def train_dataloader(self):
Expand Down
Binary file added docs/center_acc.pdf
Binary file not shown.
Binary file modified docs/center_bias.pdf
Binary file not shown.
Binary file added docs/center_ens_acc.pdf
Binary file not shown.
Binary file modified docs/center_ens_bias.pdf
Binary file not shown.
Binary file modified docs/center_ens_rmse.pdf
Binary file not shown.
Binary file modified docs/center_ens_sdiv.pdf
Binary file not shown.
Binary file modified docs/center_ens_ssim.pdf
Binary file not shown.
Binary file added docs/center_ratio_acc.pdf
Binary file not shown.
Binary file modified docs/center_ratio_bias.pdf
Binary file not shown.
Binary file modified docs/center_ratio_rmse.pdf
Binary file not shown.
Binary file modified docs/center_ratio_sdiv.pdf
Binary file not shown.
Binary file modified docs/center_ratio_ssim.pdf
Binary file not shown.
Binary file modified docs/center_rmse.pdf
Binary file not shown.
Binary file modified docs/center_sdiv.pdf
Binary file not shown.
Binary file modified docs/center_ssim.pdf
Binary file not shown.
Binary file added docs/ml_probs_crps.pdf
Binary file not shown.
Binary file added docs/ml_probs_crpss.pdf
Binary file not shown.
Binary file added docs/ml_probs_spread.pdf
Binary file not shown.
Binary file added docs/ml_probs_ssr.pdf
Binary file not shown.
Binary file added docs/ml_ratio_acc.pdf
Binary file not shown.
Binary file added docs/ml_ratio_bias.pdf
Binary file not shown.
Binary file added docs/ml_ratio_rmse.pdf
Binary file not shown.
Binary file added docs/ml_ratio_sdiv.pdf
Binary file not shown.
Binary file added docs/ml_ratio_ssim.pdf
Binary file not shown.
Binary file added docs/sota_acc.pdf
Binary file not shown.
Binary file modified docs/sota_bias.pdf
Binary file not shown.
Binary file modified docs/sota_rmse.pdf
Binary file not shown.
Binary file modified docs/sota_sdiv.pdf
Binary file not shown.
Binary file modified docs/sota_ssim.pdf
Binary file not shown.
35 changes: 24 additions & 11 deletions eval_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(args):
"""
assert args.task_num in [1, 2]

print(f'Evaluating reanlysis against {args.model_name}...')
print(f'Evaluating reanalysis against {args.model_name}...')

#########################################
####### Evaluation initialization #######
Expand Down Expand Up @@ -99,14 +99,19 @@ def main(args):
ckpt_filepath = list(ckpt_filepath.glob('*.ckpt'))[0]
baseline = model.S2SBenchmarkModel(model_args=model_args, data_args=data_args)
baseline = baseline.load_from_checkpoint(ckpt_filepath)
baselines.append(copy.deepcopy(baseline))
baselines.append(copy.deepcopy(baseline.eval()))

## Prepare input/output dataset
lra5_vars, oras5_vars = baseline.hparams.get('land_vars', []), baseline.hparams.get('ocean_vars', [])
input_dataset = dataset.S2SObsDataset(years=args.eval_years, n_step=config.N_STEPS-1, land_vars=lra5_vars, ocean_vars=oras5_vars)
input_dataset = dataset.S2SObsDataset(
years=args.eval_years, n_step=config.N_STEPS-1, land_vars=lra5_vars, ocean_vars=oras5_vars
)
input_dataloader = DataLoader(input_dataset, batch_size=BATCH_SIZE, shuffle=False)

output_dataset = dataset.S2SObsDataset(years=args.eval_years, n_step=config.N_STEPS-1, land_vars=config.LRA5_PARAMS, ocean_vars=config.ORAS5_PARAMS, is_normalized=False)
output_dataset = dataset.S2SObsDataset(
years=args.eval_years, n_step=config.N_STEPS-1,
land_vars=config.LRA5_PARAMS, ocean_vars=config.ORAS5_PARAMS, is_normalized=False
)
output_dataloader = DataLoader(output_dataset, batch_size=BATCH_SIZE, shuffle=False)


Expand All @@ -115,10 +120,16 @@ def main(args):
IS_EXTERNAL = True
PARAM_LIST = {'era5': config.CLIMAX_VARS if args.task_num == 1 else config.HEADLINE_VARS, 'lra5': args.lra5, 'oras5': args.oras5}

input_dataset = dataset.S2SObsDataset(years=args.eval_years, n_step=config.N_STEPS-1, land_vars=config.LRA5_PARAMS, ocean_vars=config.ORAS5_PARAMS, is_normalized=False)
input_dataset = dataset.S2SObsDataset(
years=args.eval_years, n_step=config.N_STEPS-1,
land_vars=config.LRA5_PARAMS, ocean_vars=config.ORAS5_PARAMS, is_normalized=False
)
input_dataloader = DataLoader(input_dataset, batch_size=BATCH_SIZE, shuffle=False)

output_dataset = dataset.S2SObsDataset(years=args.eval_years, n_step=config.N_STEPS-1, land_vars=config.LRA5_PARAMS, ocean_vars=config.ORAS5_PARAMS, is_normalized=False)
output_dataset = dataset.S2SObsDataset(
years=args.eval_years, n_step=config.N_STEPS-1,
land_vars=config.LRA5_PARAMS, ocean_vars=config.ORAS5_PARAMS, is_normalized=False
)
output_dataloader = DataLoader(output_dataset, batch_size=BATCH_SIZE, shuffle=False)

## List external prediction
Expand All @@ -136,7 +147,7 @@ def main(args):

all_preds = np.array(all_preds)

##################### Initialize criteria ######################
##################### Initialize criteria #####################
RMSE = criterion.RMSE()
Bias = criterion.Bias()
ACC = criterion.ACC()
Expand All @@ -157,7 +168,10 @@ def main(args):
for input_batch, output_batch in tqdm(zip(input_dataloader, output_dataloader), total=len(input_dataloader)):

_, preds_x, preds_y = input_batch
_, truth_x, truth_y = output_batch
timestamps, truth_x, truth_y = output_batch

# Pre-processing (e.g., get day-of-years for climatology-related metrics...)
doys = utils.get_doys_from_timestep(timestamps)

assert preds_y.size(1) == truth_y.size(1)
N_STEPS = truth_y.size(1)
Expand Down Expand Up @@ -210,7 +224,7 @@ def main(args):
bias = Bias(unique_preds, unique_labels).cpu().numpy()

################################## Criterion 3: ACC ######################################
acc = ACC(unique_preds, unique_labels, param, param_class).cpu().numpy()
acc = ACC(unique_preds, unique_labels, doys[:, delta-1], param, param_class).cpu().numpy()

################################## Criterion 4: SSIM ######################################
ssim = SSIM(unique_preds, unique_labels).cpu().numpy()
Expand Down Expand Up @@ -240,8 +254,7 @@ def main(args):

all_param_idx += 1
param_idx = param_idx + 1 if param_exist else param_idx



all_rmse.append(step_rmse)
all_bias.append(step_bias)
all_acc.append(step_acc)
Expand Down
Loading

0 comments on commit 62d08ab

Please sign in to comment.