-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
hid_net #200
hid_net #200
Conversation
# os.environ['TL_BACKEND'] = 'torch' | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分修改的代码不需要提交
@@ -7,8 +7,8 @@ | |||
from ggl_build_extension import PyCudaExtension, PyCPUExtension | |||
|
|||
# TODO will depend on different host | |||
WITH_CUDA = False | |||
# WITH_CUDA = True | |||
#WITH_CUDA = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分的修改代码也不需要提交
@@ -18,8 +18,9 @@ | |||
from .to_dense_adj import to_dense_adj | |||
from .smiles import from_smiles | |||
from .shortest_path import shortest_path_distance, batched_shortest_path_distance | |||
|
|||
__all__ = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分也可以删除
@@ -0,0 +1,228 @@ | |||
import os | |||
import sys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要添加设置运行后端的代码
metrics.reset() | ||
return rst | ||
|
||
def main(args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分代码可以放在 if __name__=='__main__':
中
|
||
val_accs = [] | ||
test_accs = [] | ||
|
||
|
||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果只运行一次的话,不需要定义 val_accs
, test_accs
, accs
等变量,同时,记得删除多余的空行,删除之后更美观一些
val_accs.append(val_acc) | ||
|
||
|
||
test_preds=tlx.gather(out,data['test_mask']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在每个epoch中,只需要跑模型在验证集上的结果即可,模型训练完成之后,可以在测试集上进行测试,这部分代码可以参考其他模型的实现代码。
parser = argparse.ArgumentParser() | ||
parser.add_argument('--times', type=int, default=3, help='config times') | ||
parser.add_argument('--seed', type=int, default=9, help='random seed') | ||
parser.add_argument('--repeat', type=int, default=150, help='repeat time') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议把超参数名称设置为:n_epoch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
而且在超参数中,好像有些超参数有些并没有用到,后续可以删除。
Description
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes