# example.py
from climatenet.utils.data import ClimateDatasetLabeled, ClimateDataset
from climatenet.models import CGNet
from climatenet.utils.utils import Config
from climatenet.track_events import track_events
from climatenet.analyze_events import analyze_events
from climatenet.visualize_events import visualize_events
from os import path
# --- Data locations -------------------------------------------------------
# Placeholder strings: point these at real directories before running.
inference_path = 'PATH_TO_INFERENCE_SET'
train_path = 'PATH_TO_TRAINING_SET'

# --- Model ----------------------------------------------------------------
# The same Config object (loaded from config.json) is handed to the model
# and to every dataset below.
config = Config('config.json')
cgnet = CGNet(config)

# --- Datasets -------------------------------------------------------------
# The labeled splits are read from the 'train' and 'test' subdirectories of
# train_path; the inference set is read from inference_path directly.
train, test = [
    ClimateDatasetLabeled(path.join(train_path, split), config)
    for split in ('train', 'test')
]
inference = ClimateDataset(inference_path, config)

# --- Fit, score, persist --------------------------------------------------
cgnet.train(train)
cgnet.evaluate(test)
cgnet.save_model('trained_cgnet')
# To reuse a previously saved model, call:
#   cgnet.load_model('trained_cgnet')

# --- Inference and event post-processing ----------------------------------
# Per-class masks, encoded as 1==TC, 2==AR.
class_masks = cgnet.predict(inference)
# Masks carrying event IDs, derived from the class masks.
event_masks = track_events(class_masks)

# NOTE(review): 'results/' and 'pngs/' look like output directories for the
# analysis and the rendered images respectively -- confirm against
# analyze_events / visualize_events.
analyze_events(event_masks, class_masks, 'results/')
visualize_events(event_masks, inference, 'pngs/')