-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_dnns.py
146 lines (119 loc) · 4.43 KB
/
train_dnns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Train DNNs."""
import os
from pathlib import Path
import numpy as np
import logging
import time
from datetime import datetime
# import jax
# import equinox as eqx
import jax.numpy as jnp
import optax
from jax_canveg.subjects import get_met_forcings, get_obs
from jax_canveg.shared_utilities.plot import get_time
from jax_canveg.shared_utilities.dnn import MLP, train_dnn
from jax_canveg.shared_utilities.optim import mse, weighted_loss
##################################################
# General configurations
##################################################
# Directory containing this script (symlinks resolved); model outputs are
# saved beneath it.
dir_mother = Path(os.path.dirname(os.path.realpath(__file__)))
# Site to train on. `key` is not referenced elsewhere in this script.
site, key = "US-Bi1", "dl"
# Data files are resolved relative to this script rather than the current
# working directory, so the script behaves the same no matter where it is
# launched from (outputs already live under `dir_mother`). The paths are kept
# as `str`, the same type as before.
dir_data = dir_mother / ".." / ".." / "data" / "fluxtower" / site
f_forcing_train = str(dir_data / f"{site}-forcings.csv")
f_obs_train = str(dir_data / f"{site}-fluxes.csv")
f_forcing_test = str(dir_data / f"{site}-forcings-test.csv")
f_obs_test = str(dir_data / f"{site}-fluxes-test.csv")
# Forcing variables used as DNN inputs, one input column per name.
in_varns = [
    "T_air", "rglobal", "eair", "wind", "CO2",
    "P_kPa", "ustar", "soilmoisture", "lai",
]
# DNN hyperparameters
batch_size = 64
# batch_size = 1024
initial_lr = 2e-1  # starting value of the learning-rate schedule
nsteps = 300  # passed to train_dnn as `nsteps`
# nsteps = 10
seed = 5678
scaler_type = 'standard'
model_type = MLP
# Keyword arguments for `model_type`; presumably forwarded to the MLP
# constructor by train_dnn (TODO confirm). `out_size` is 2 because the
# targets are the two columns LE and Fco2.
model_args = {
    "depth": 2, "width_size": 6, "model_seed": seed,
    "out_size": 2, "hidden_activation": "tanh",
    "final_activation": "identity", "in_size": len(in_varns),
}
# Start logging information
ts = time.time()
# Filename-safe timestamp: ':' is not allowed in Windows file names (and is
# awkward on macOS), so the time fields are separated with '-' instead.
time_label = datetime.fromtimestamp(ts).strftime("%Y-%m-%d-%H-%M-%S")
logging.basicConfig(
    # Relative path: the log is written to the current working directory.
    filename=f"train-{site}-{time_label}.log",
    filemode="w",
    datefmt="%H:%M:%S",
    level=logging.INFO,
    # Zero-pad milliseconds, as logging's default formatter does ('%s,%03d').
    # This way 5 ms renders as ",005" rather than the misleading ",5".
    format="%(asctime)s,%(msecs)03d %(name)s %(levelname)s %(message)s",
)
##################################################
# Read data
##################################################
def _stack_vars(record, names):
    """Return the named attributes of ``record`` stacked as columns.

    Row ``i`` holds the ``i``-th entry of every attribute, and column ``j``
    corresponds to ``names[j]``.
    """
    columns = [getattr(record, name) for name in names]
    return np.array(columns).T


# Training split: forcings -> inputs (one column per entry of `in_varns`)
met_train, n_time_train = get_met_forcings(f_forcing_train)
timesteps_train = get_time(met_train)
x_train = _stack_vars(met_train, in_varns)
# Test split, built the same way
met_test, n_time_test = get_met_forcings(f_forcing_test)
timesteps_test = get_time(met_test)
x_test = _stack_vars(met_test, in_varns)
# Observations for both splits
obs_train, obs_test = get_obs(f_obs_train), get_obs(f_obs_test)
# Targets: two columns, LE then Fco2
# y_train, y_test = obs_train.LE, obs_test.LE
y_train = _stack_vars(obs_train, ("LE", "Fco2"))
y_test = _stack_vars(obs_test, ("LE", "Fco2"))
##################################################
# Initialize optimizer and scheduler
##################################################
# Step-wise decay: once the step count passes each boundary (50, 100, 200),
# the learning rate is scaled by a further factor of 0.1, so it goes
# initial_lr -> initial_lr / 10 -> initial_lr / 100 -> initial_lr / 1000.
_lr_decay = {boundary: 0.1 for boundary in (50, 100, 200)}
scheduler = optax.piecewise_constant_schedule(
    initial_lr, boundaries_and_scales=_lr_decay
)
# scheduler = optax.constant_schedule(initial_lr)
# Adam driven by the schedule above
optim = optax.adam(learning_rate=scheduler)
##################################################
# Train DNNs
##################################################
# Loss weight w for the first target (LE); the second (Fco2) gets 1 - w.
# The list is explicit, not generated, so directory names stay exactly
# "DNN_LE-GPP-0.3" etc. (np.linspace would produce 0.30000000000000004).
w_set = [0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# w_set = [0., 0.5, 1.0]


def _make_loss_func(weights):
    """Return ``loss_func(y, pred_y)`` bound to this specific ``weights``.

    Defining the loss directly in the loop body would make it read the
    module-global ``weights``, which the next iteration rebinds (the
    late-binding closure pitfall). A factory freezes the value per iteration.

    Presumably ``weights[i]`` scales the MSE of output column ``i`` inside
    ``weighted_loss``. TODO confirm against its implementation.
    """
    def loss_func(y, pred_y):
        return weighted_loss(y, pred_y, mse, weights)
    return loss_func


for w in w_set:
    # Get the saving dir.
    # NOTE(review): the name says "GPP" but the second target is Fco2; the
    # name is left unchanged because downstream code may depend on it.
    dir_save = dir_mother / f"DNN_LE-GPP-{w}"
    # Per-output weights [LE, Fco2]
    weights = jnp.array([w, 1-w])
    loss_func = _make_loss_func(weights)
    # Train the model
    train_dnn(
        dir_save, model_type, model_args,
        x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test,
        batch_size=batch_size, nsteps=nsteps, scaler_type=scaler_type,
        optim=optim, loss_func=loss_func, save_log_local=False
    )
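# ##################################################
# # Train a DNN for LE only
# # NOTE(review): stale, kept for reference only. It uses `mlp_configs` and
# # `scaler`, which this script does not define (the active call passes
# # `scaler_type`), and it sets out_size=2 although the target here is LE only.
# ##################################################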
# # Get the observed outputs
# y_train, y_test = obs_train.LE, obs_test.LE
# # Get the saving dir
# dir_save = dir_mother / "DNN_LE"
# # Train the models
# model_args = mlp_configs.copy()
# model_args['in_size'] = len(in_varns)
# model_args['out_size'] = 2
# train_dnn(
# dir_save, model_type, model_args,
# x_train = x_train, y_train=y_train, x_test=x_test, y_test=y_test,
# batch_size=batch_size, nsteps=nsteps, scaler=scaler, optim=optim,
# loss_func=mse, save_log_local=True
# )
##################################################
# Train a DNN for NEE only
# (placeholder: not implemented)
##################################################
##################################################
# Train a DNN for both LE and NEE
# (placeholder: the weighted loop over w_set already trains on LE and Fco2 jointly)
##################################################