-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_models.py
150 lines (131 loc) · 5.01 KB
/
train_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Train the models."""
import copy
import json
import logging
import os
import time
from datetime import datetime
from pathlib import Path

import jax

from jax_canveg import train_model
from jax_canveg.shared_utilities import tune_jax_naninfs_for_debug
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_traceback_filtering", "off")
tune_jax_naninfs_for_debug(False)
# Start logging information
ts = time.time()
time_label = datetime.fromtimestamp(ts).strftime("%Y-%m-%d-%H:%M:%S")
logging.basicConfig(
filename=f"train-Me2-{time_label}.log",
filemode="w",
datefmt="%H:%M:%S",
level=logging.INFO,
format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
)
# Current directory
dir_mother = Path(os.path.dirname(os.path.realpath(__file__)))
################################################################
# General configuration
################################################################
# Template configuration every trained case starts from; the per-case
# overrides defined below are layered on top before training.
f_configs_template = dir_mother / "./test-model/configs.json"
# Overrides for the "model configurations" section of the template.
# Presumably site parameters for the US-Me2 flux tower (per the data
# paths below); units/meanings not shown here — confirm against the
# jax_canveg configuration schema.
model_configs = {
    "time zone": -8,
    "latitude": 44.4523,
    "longitude": -121.5574,
    "stomata type": 0,
    "leaf angle type": 0,
    "canopy height": 18.0,
    "measurement height": 34.0,
    "soil respiration module": 1,
}
# Overrides for the "learning configurations" section of the template.
learning_config = {
    "batch size": 1024,
    "number of epochs": 300,
    # "number of epochs": 2,  # debug toggle: quick smoke-test run
    "output scaler": "standard",
}
# Overrides for the "data" section: forcing/flux CSVs for training and
# testing, relative to each case folder created by the loop below.
data_config = {
    "training forcings": "../../../data/fluxtower/US-Me2/US-Me2-forcings.csv",
    "training fluxes": "../../../data/fluxtower/US-Me2/US-Me2-fluxes.csv",
    "test forcings": "../../../data/fluxtower/US-Me2/US-Me2-forcings-test.csv",
    "test fluxes": "../../../data/fluxtower/US-Me2/US-Me2-fluxes-test.csv",
}
################################################################
# Configurations for
# - canopy layers
# - hybrid model
# - multiobjective optimization
################################################################
# Experiment grid: canopy discretization x model type x loss weight.
canopy_layers = ["1L", "ML"]
model_types = ["PB", "Hybrid"]
# Candidate weights for the multi-objective loss; the loop writes
# [mow, 1 - mow] into the loss-function weights of each case.
multi_optim_le_weight = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# multi_optim_le_weight = [0.0, 0.5, 1.0]  # debug toggle: reduced sweep
# Per-discretization overrides: 1 layer vs 50 layers, each with its own
# precomputed dispersion matrix.
canopy_layers_config = {
    "1L": {
        "number of canopy layers": 1,
        "dispersion matrix": "../../../data/dij/Dij_US-Me2_1L.csv",
    },
    "ML": {
        "number of canopy layers": 50,
        "dispersion matrix": "../../../data/dij/Dij_US-Me2_50L.csv",
    },
}
# Per-model-type overrides: process-based (PB) vs hybrid leaf relative
# humidity module selection.
model_types_config = {
    "PB": {"leaf relative humidity module": 0},
    "Hybrid": {"leaf relative humidity module": 1},
}
################################################################
# Load the default configuration
################################################################
# Parse the template once; each case below derives its own copy from it.
configs = json.loads(f_configs_template.read_text())
################################################################
# Train the models
################################################################
# Sweep the full experiment grid: canopy discretization x model type x
# loss weight. Each case gets its own folder "<type>-<layers>-<weight>"
# holding a generated configs.json and the trained model artifacts.
for cl in canopy_layers:
    cl_config = canopy_layers_config[cl]
    for mt in model_types:
        mt_config = model_types_config[mt]
        for mow in multi_optim_le_weight:
            # Step 0: Return to the script directory before each case
            # (training presumably changes the working directory — the
            # relative data paths in the config suggest so; confirm).
            os.chdir(dir_mother)
            # Step 1: Per-case folder and config path, e.g. "Hybrid-ML-0.5".
            dir_name = dir_mother / f"{mt}-{cl}-{mow}"
            f_configs = dir_name / "configs.json"
            logging.info("")
            logging.info(f"The model: {f_configs}.")
            # Step 2-a: Create the case folder if it does not exist yet.
            if not dir_name.is_dir():
                dir_name.mkdir()
            # Step 2-b: Skip this case when the folder already holds both a
            # configs.json and a trained model (any *.eqx file).
            else:
                check_config, check_model = False, False
                for f in dir_name.glob("**/*"):
                    if f == f_configs:
                        check_config = True
                    if f.suffix == ".eqx":
                        check_model = True
                if check_model and check_config:
                    logging.info(
                        f"The model has been trained in {f_configs}. Continue to the next model."  # noqa: E501
                    )
                    continue
            # Step 3: Build this case's configuration from the template.
            # Bug fix: the original used configs.copy(), a *shallow* copy,
            # so the nested section dicts were shared with — and mutated
            # on — the template across iterations (masked only because
            # every iteration overwrote the same keys). deepcopy isolates
            # each case from the template and from the other cases.
            cfg = copy.deepcopy(configs)
            cfg["model configurations"].update(model_configs)
            cfg["learning configurations"].update(learning_config)
            cfg["data"].update(data_config)
            cfg["model configurations"].update(cl_config)
            cfg["model configurations"].update(mt_config)
            # Multi-objective loss weights: [mow, 1 - mow] — presumably
            # [LE weight, other-flux weight]; confirm ordering in jax_canveg.
            cfg["learning configurations"]["loss function"]["weights"] = [mow, 1 - mow]
            # Step 4: Write the per-case configuration file.
            with open(f_configs, "w") as f:
                json.dump(cfg, f, indent=4)
            # Step 5: Launch training!
            logging.info("Start training!")
            train_model(f_configs)