-
Notifications
You must be signed in to change notification settings - Fork 187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add STAFNet Model for Air Quality Prediction #1070
base: develop
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
整体项目请使用pre-commit格式化一边
@@ -0,0 +1,136 @@ | |||
hydra: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
配置文件开头请加上以下字段:
PaddleScience/examples/ldc/conf/ldc_2d_Re3200_piratenet.yaml
Lines 1 to 9 in fad6927
defaults: | |
- ppsci_default | |
- TRAIN: train_default | |
- TRAIN/ema: ema_default | |
- TRAIN/swa: swa_default | |
- EVAL: eval_default | |
- INFER: infer_default | |
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
- _self_ |
examples/demo/conf/stafnet.yaml
Outdated
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以删了
examples/demo/conf/stafnet.yaml
Outdated
STAFNet_DATA_PATH: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" # | ||
DATASET: | ||
label_keys: ["label"] | ||
data_dir: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" | ||
STAFNet_DATA_args: { | ||
"data_dir": "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl", | ||
"batch_size": 1, | ||
"shuffle": True, | ||
"num_workers": 0, | ||
"training": True | ||
} | ||
|
||
|
||
|
||
# "data_dir": "data/2020-2023_new/train_data.pkl", | ||
# "batch_size": 32, | ||
# "shuffle": True, | ||
# "num_workers": 0, | ||
# "training": True | ||
# model settings | ||
# MODEL: # | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议改为相对路径,以./data/...
开头即可
examples/demo/conf/stafnet.yaml
Outdated
# "data_dir": "data/2020-2023_new/train_data.pkl", | ||
# "batch_size": 32, | ||
# "shuffle": True, | ||
# "num_workers": 0, | ||
# "training": True | ||
# model settings | ||
# MODEL: # |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个注释如果没用的可以删除
examples/demo/conf/stafnet.yaml
Outdated
# configs: { | ||
# "task_name": "forecast", | ||
# "output_attention": False, | ||
# "seq_len": 72, | ||
# "label_len": 24, | ||
# "pred_len": 48, | ||
|
||
# "aq_gat_node_features" : 7, | ||
# "aq_gat_node_num": 35, | ||
|
||
# "mete_gat_node_features" : 7, | ||
# "mete_gat_node_num": 18, | ||
|
||
# "gat_hidden_dim": 32, | ||
# "gat_edge_dim": 3, | ||
# "gat_embed_dim": 32, | ||
|
||
# "e_layers": 1, | ||
# "enc_in": 7, | ||
# "dec_in": 7, | ||
# "c_out": 7, | ||
# "d_model": 16 , | ||
# "embed": "fixed", | ||
# "freq": "t", | ||
# "dropout": 0.05, | ||
# "factor": 3, | ||
# "n_heads": 4, | ||
|
||
# "d_ff": 32 , | ||
# "num_kernels": 6, | ||
# "top_k": 4 | ||
# } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,如果没用可以删除
examples/demo/demo.py
Outdated
|
||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
避免连续空行
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(42) | ||
# set output directory | ||
OUTPUT_DIR = "./output_example" | ||
# initialize logger | ||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以删除,output_dir会由ppsci.utils.callbacks.InitCallback
自动创建:
PaddleScience/ppsci/utils/callbacks.py
Lines 90 to 96 in fad6927
logger.init_logger( | |
"ppsci", | |
osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log") | |
if full_cfg.output_dir and full_cfg.mode not in ["export", "infer"] | |
else None, | |
full_cfg.log_level, | |
) |
from typing import Tuple | ||
|
||
class Inception_Block_V1(paddle.nn.Layer): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
冗余的空行请删除,下同
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
# dataset setting | ||
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 这里的路径是否能改成相对路径?比如
./dataset/train_data.pkl
,其余的路径字段也是,建议改为相对路径,并去掉用户名 - STAFNet_DATA_PATH是否应该放到DATASET字段下?
|
||
|
||
MODEL: | ||
input_keys: ["aq_train_data","mete_train_data",] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_keys: ["aq_train_data","mete_train_data",] | |
input_keys: [aq_train_data, mete_train_data] |
|
||
MODEL: | ||
input_keys: ["aq_train_data","mete_train_data",] | ||
output_keys: ["label"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output_keys: ["label"] | |
output_keys: [label] |
checkpoint_path: null | ||
|
||
EVAL: | ||
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" | |
eval_data_path: ./dataset/val_data.pkl |
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # | ||
DATASET: | ||
label_keys: ["label"] | ||
data_dir: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- data_dir为什么是具体文件路径而不是某个文件夹路径?
- 此处的路径是否跟STAFNet_DATA_PATH重复了?
cfg.TRAIN.epochs, | ||
ITERS_PER_EPOCH, | ||
eval_during_train=cfg.TRAIN.eval_during_train, | ||
seed=cfg.seed, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seed=cfg.seed, |
""" | ||
Validate after training an epoch | ||
|
||
:param epoch: Integer, current training epoch. | ||
:return: A log that contains information about validation | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
""" | |
Validate after training an epoch | |
:param epoch: Integer, current training epoch. | |
:return: A log that contains information about validation | |
""" |
"sampler": { | ||
"name": "BatchSampler", | ||
"drop_last": False, | ||
"shuffle": True, | ||
}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"sampler": { | |
"name": "BatchSampler", | |
"drop_last": False, | |
"shuffle": True, | |
}, |
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(42) | ||
# set output directory | ||
OUTPUT_DIR = "./output_example" | ||
# initialize logger | ||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# set random seed for reproducibility | |
ppsci.utils.misc.set_random_seed(42) | |
# set output directory | |
OUTPUT_DIR = "./output_example" | |
# initialize logger | |
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
OUTPUT_DIR = "./output_example" | ||
# initialize logger | ||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") | ||
multiprocessing.set_start_method("spawn") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这句代码是什么作用?paddle的多卡训练不需要这样吧?
PR types
PR changes
Describe