-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add STAFNet Model for Air Quality Prediction #1070
base: develop
Are you sure you want to change the base?
Changes from 11 commits
f79a3f9
68b23d1
d9d2b54
bfa3e69
2d9dc85
57dc7c2
fa1cdee
ab1ae03
d257a49
b43c7f5
757477a
2b46497
711cd36
a79ad1d
af96434
86a9c0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,85 @@ | ||||||||||||||||||||
defaults: | ||||||||||||||||||||
- ppsci_default | ||||||||||||||||||||
- TRAIN: train_default | ||||||||||||||||||||
- TRAIN/ema: ema_default | ||||||||||||||||||||
- TRAIN/swa: swa_default | ||||||||||||||||||||
- EVAL: eval_default | ||||||||||||||||||||
- INFER: infer_default | ||||||||||||||||||||
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | ||||||||||||||||||||
- _self_ | ||||||||||||||||||||
hydra: | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 配置文件开头请加上以下字段: PaddleScience/examples/ldc/conf/ldc_2d_Re3200_piratenet.yaml Lines 1 to 9 in fad6927
|
||||||||||||||||||||
run: | ||||||||||||||||||||
# dynamic output directory according to running time and override name | ||||||||||||||||||||
dir: outputs_chip_heat/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||||||||||||||||||||
job: | ||||||||||||||||||||
name: ${mode} # name of logfile | ||||||||||||||||||||
chdir: false # keep current working directory unchanged | ||||||||||||||||||||
callbacks: | ||||||||||||||||||||
init_callback: | ||||||||||||||||||||
_target_: ppsci.utils.callbacks.InitCallback | ||||||||||||||||||||
sweep: | ||||||||||||||||||||
# output directory for multirun | ||||||||||||||||||||
dir: ${hydra.run.dir} | ||||||||||||||||||||
subdir: ./ | ||||||||||||||||||||
|
||||||||||||||||||||
# general settings | ||||||||||||||||||||
mode: train # running mode: train/eval | ||||||||||||||||||||
seed: 42 | ||||||||||||||||||||
output_dir: ${hydra:run.dir} | ||||||||||||||||||||
log_freq: 20 | ||||||||||||||||||||
# dataset setting | ||||||||||||||||||||
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||
DATASET: | ||||||||||||||||||||
label_keys: ["label"] | ||||||||||||||||||||
data_dir: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
MODEL: | ||||||||||||||||||||
input_keys: ["aq_train_data","mete_train_data",] | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||
output_keys: ["label"] | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||
output_attention: True | ||||||||||||||||||||
seq_len: 72 | ||||||||||||||||||||
pred_len: 48 | ||||||||||||||||||||
aq_gat_node_features: 7 | ||||||||||||||||||||
aq_gat_node_num: 35 | ||||||||||||||||||||
mete_gat_node_features: 7 | ||||||||||||||||||||
mete_gat_node_num: 18 | ||||||||||||||||||||
gat_hidden_dim: 32 | ||||||||||||||||||||
gat_edge_dim: 3 | ||||||||||||||||||||
e_layers: 1 | ||||||||||||||||||||
enc_in: 7 | ||||||||||||||||||||
dec_in: 7 | ||||||||||||||||||||
c_out: 7 | ||||||||||||||||||||
d_model: 16 | ||||||||||||||||||||
embed: "fixed" | ||||||||||||||||||||
freq: "t" | ||||||||||||||||||||
dropout: 0.05 | ||||||||||||||||||||
factor: 3 | ||||||||||||||||||||
n_heads: 4 | ||||||||||||||||||||
d_ff: 32 | ||||||||||||||||||||
num_kernels: 6 | ||||||||||||||||||||
top_k: 4 | ||||||||||||||||||||
|
||||||||||||||||||||
# training settings | ||||||||||||||||||||
TRAIN: | ||||||||||||||||||||
epochs: 100 | ||||||||||||||||||||
iters_per_epoch: 400 | ||||||||||||||||||||
save_freq: 10 | ||||||||||||||||||||
eval_during_train: true | ||||||||||||||||||||
eval_freq: 10 | ||||||||||||||||||||
batch_size: 1 | ||||||||||||||||||||
lr_scheduler: | ||||||||||||||||||||
epochs: ${TRAIN.epochs} | ||||||||||||||||||||
iters_per_epoch: ${TRAIN.iters_per_epoch} | ||||||||||||||||||||
learning_rate: 0.001 | ||||||||||||||||||||
step_size: 10 | ||||||||||||||||||||
gamma: 0.9 | ||||||||||||||||||||
pretrained_model_path: null | ||||||||||||||||||||
checkpoint_path: null | ||||||||||||||||||||
|
||||||||||||||||||||
EVAL: | ||||||||||||||||||||
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||
pretrained_model_path: null | ||||||||||||||||||||
compute_metric_by_batch: false | ||||||||||||||||||||
eval_with_no_grad: true | ||||||||||||||||||||
batch_size: 1 |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,153 @@ | ||||||||||||||||||||||||||||
import ppsci | ||||||||||||||||||||||||||||
from ppsci.utils import logger | ||||||||||||||||||||||||||||
from omegaconf import DictConfig | ||||||||||||||||||||||||||||
import hydra | ||||||||||||||||||||||||||||
import paddle | ||||||||||||||||||||||||||||
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn | ||||||||||||||||||||||||||||
import multiprocessing | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def train(cfg: DictConfig): | ||||||||||||||||||||||||||||
# set model | ||||||||||||||||||||||||||||
model = ppsci.arch.STAFNet(**cfg.MODEL) | ||||||||||||||||||||||||||||
train_dataloader_cfg = { | ||||||||||||||||||||||||||||
"dataset": { | ||||||||||||||||||||||||||||
"name": "STAFNetDataset", | ||||||||||||||||||||||||||||
"file_path": cfg.DATASET.data_dir, | ||||||||||||||||||||||||||||
"input_keys": cfg.MODEL.input_keys, | ||||||||||||||||||||||||||||
"label_keys": cfg.MODEL.output_keys, | ||||||||||||||||||||||||||||
"seq_len": cfg.MODEL.seq_len, | ||||||||||||||||||||||||||||
"pred_len": cfg.MODEL.pred_len, | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 删除多余空行
Suggested change
|
||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||
"batch_size": cfg.TRAIN.batch_size, | ||||||||||||||||||||||||||||
"sampler": { | ||||||||||||||||||||||||||||
"name": "BatchSampler", | ||||||||||||||||||||||||||||
"drop_last": False, | ||||||||||||||||||||||||||||
"shuffle": True, | ||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||
"collate_fn": gat_lstmcollate_fn, | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
eval_dataloader_cfg= { | ||||||||||||||||||||||||||||
"dataset": { | ||||||||||||||||||||||||||||
"name": "STAFNetDataset", | ||||||||||||||||||||||||||||
"file_path": cfg.EVAL.eval_data_path, | ||||||||||||||||||||||||||||
"input_keys": cfg.MODEL.input_keys, | ||||||||||||||||||||||||||||
"label_keys": cfg.MODEL.output_keys, | ||||||||||||||||||||||||||||
"seq_len": cfg.MODEL.seq_len, | ||||||||||||||||||||||||||||
"pred_len": cfg.MODEL.pred_len, | ||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||
"batch_size": cfg.TRAIN.batch_size, | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该是EVAL? |
||||||||||||||||||||||||||||
"sampler": { | ||||||||||||||||||||||||||||
"name": "BatchSampler", | ||||||||||||||||||||||||||||
"drop_last": False, | ||||||||||||||||||||||||||||
"shuffle": True, | ||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||
Comment on lines
+22
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个"sampler"字段是否可以删掉?eval应该不需要shuffle |
||||||||||||||||||||||||||||
"collate_fn": gat_lstmcollate_fn, | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
sup_constraint = ppsci.constraint.SupervisedConstraint( | ||||||||||||||||||||||||||||
train_dataloader_cfg, | ||||||||||||||||||||||||||||
loss=ppsci.loss.MSELoss("mean"), | ||||||||||||||||||||||||||||
name="STAFNet_Sup", | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
constraint = {sup_constraint.name: sup_constraint} | ||||||||||||||||||||||||||||
sup_validator = ppsci.validate.SupervisedValidator( | ||||||||||||||||||||||||||||
eval_dataloader_cfg, | ||||||||||||||||||||||||||||
loss=ppsci.loss.MSELoss("mean"), | ||||||||||||||||||||||||||||
metric={"MSE": ppsci.metric.MSE()}, | ||||||||||||||||||||||||||||
name="Sup_Validator", | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
validator = {sup_validator.name: sup_validator} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# set optimizer | ||||||||||||||||||||||||||||
lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)() | ||||||||||||||||||||||||||||
LEARNING_RATE = cfg.TRAIN.lr_scheduler.learning_rate | ||||||||||||||||||||||||||||
optimizer = ppsci.optimizer.Adam(LEARNING_RATE)(model) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在paddle里,如果你的学习率是lr_scheduler,那么就需要把optimizer的learning_rate设置为lr_scheduler,而不是初始学习率 |
||||||||||||||||||||||||||||
output_dir = cfg.output_dir | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
ITERS_PER_EPOCH = len(sup_constraint.data_loader) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# initialize solver | ||||||||||||||||||||||||||||
solver = ppsci.solver.Solver( | ||||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||||
constraint, | ||||||||||||||||||||||||||||
output_dir, | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
optimizer, | ||||||||||||||||||||||||||||
lr_scheduler, | ||||||||||||||||||||||||||||
cfg.TRAIN.epochs, | ||||||||||||||||||||||||||||
ITERS_PER_EPOCH, | ||||||||||||||||||||||||||||
eval_during_train=cfg.TRAIN.eval_during_train, | ||||||||||||||||||||||||||||
seed=cfg.seed, | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
validator=validator, | ||||||||||||||||||||||||||||
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, | ||||||||||||||||||||||||||||
eval_with_no_grad=cfg.EVAL.eval_with_no_grad, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# train model | ||||||||||||||||||||||||||||
solver.train() | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def evaluate(cfg: DictConfig): | ||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||
Validate after training an epoch | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
:param epoch: Integer, current training epoch. | ||||||||||||||||||||||||||||
:return: A log that contains information about validation | ||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 删除注释 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
model = ppsci.arch.STAFNet(**cfg.MODEL) | ||||||||||||||||||||||||||||
eval_dataloader_cfg= { | ||||||||||||||||||||||||||||
"dataset": { | ||||||||||||||||||||||||||||
"name": "STAFNetDataset", | ||||||||||||||||||||||||||||
"file_path": cfg.EVAL.eval_data_path, | ||||||||||||||||||||||||||||
"input_keys": cfg.MODEL.input_keys, | ||||||||||||||||||||||||||||
"label_keys": cfg.MODEL.output_keys, | ||||||||||||||||||||||||||||
"seq_len": cfg.MODEL.seq_len, | ||||||||||||||||||||||||||||
"pred_len": cfg.MODEL.pred_len, | ||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||
"batch_size": cfg.TRAIN.batch_size, | ||||||||||||||||||||||||||||
"sampler": { | ||||||||||||||||||||||||||||
"name": "BatchSampler", | ||||||||||||||||||||||||||||
"drop_last": False, | ||||||||||||||||||||||||||||
"shuffle": True, | ||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||
Comment on lines
+74
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
"collate_fn": gat_lstmcollate_fn, | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
sup_validator = ppsci.validate.SupervisedValidator( | ||||||||||||||||||||||||||||
eval_dataloader_cfg, | ||||||||||||||||||||||||||||
loss=ppsci.loss.MSELoss("mean"), | ||||||||||||||||||||||||||||
metric={"MSE": ppsci.metric.MSE()}, | ||||||||||||||||||||||||||||
name="Sup_Validator", | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
validator = {sup_validator.name: sup_validator} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# initialize solver | ||||||||||||||||||||||||||||
solver = ppsci.solver.Solver( | ||||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||||
validator=validator, | ||||||||||||||||||||||||||||
cfg=cfg, | ||||||||||||||||||||||||||||
pretrained_model_path=cfg.EVAL.pretrained_model_path, | ||||||||||||||||||||||||||||
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, | ||||||||||||||||||||||||||||
eval_with_no_grad=cfg.EVAL.eval_with_no_grad, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# evaluate model | ||||||||||||||||||||||||||||
solver.eval() | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
@hydra.main(version_base=None, config_path="./conf", config_name="stafnet.yaml") | ||||||||||||||||||||||||||||
def main(cfg: DictConfig): | ||||||||||||||||||||||||||||
if cfg.mode == "train": | ||||||||||||||||||||||||||||
train(cfg) | ||||||||||||||||||||||||||||
elif cfg.mode == "eval": | ||||||||||||||||||||||||||||
evaluate(cfg) | ||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||||
# set random seed for reproducibility | ||||||||||||||||||||||||||||
ppsci.utils.misc.set_random_seed(42) | ||||||||||||||||||||||||||||
# set output directory | ||||||||||||||||||||||||||||
OUTPUT_DIR = "./output_example" | ||||||||||||||||||||||||||||
# initialize logger | ||||||||||||||||||||||||||||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个可以删除,output_dir会由 PaddleScience/ppsci/utils/callbacks.py Lines 90 to 96 in fad6927
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
multiprocessing.set_start_method("spawn") | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这句代码是什么作用?paddle的多卡训练不需要这样吧? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我这边如果不加 multiprocessing.set_start_method("spawn"),会出现cuda error(3) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件建议使用vscode的yaml插件格式化一下,或者提交前用pre-commit格式化:https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#1