Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hid_net #200

Closed
wants to merge 6 commits into from
Closed

hid_net #200

wants to merge 6 commits into from

Conversation

1920309095
Copy link
Contributor

Description

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature]])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • Related issue is referred in this PR

Changes

# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分修改的代码不需要提交

@@ -7,8 +7,8 @@
from ggl_build_extension import PyCudaExtension, PyCPUExtension

# TODO will depend on different host
WITH_CUDA = False
# WITH_CUDA = True
#WITH_CUDA = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分的修改代码也不需要提交

@@ -18,8 +18,9 @@
from .to_dense_adj import to_dense_adj
from .smiles import from_smiles
from .shortest_path import shortest_path_distance, batched_shortest_path_distance

__all__ = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分也可以删除

@@ -0,0 +1,228 @@
import os
import sys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要添加设置运行后端的代码

metrics.reset()
return rst

def main(args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分代码可以放在 if __name__=='__main__':

Comment on lines +136 to +143

val_accs = []
test_accs = []





Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果只运行一次的话,不需要定义 val_accs, test_accs, accs 等变量,同时,记得删除多余的空行,删除之后更美观一些

val_accs.append(val_acc)


test_preds=tlx.gather(out,data['test_mask'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在每个epoch中,只需要跑模型在验证集上的结果即可,模型训练完成之后,可以在测试集上进行测试,这部分代码可以参考其他模型的实现代码。

parser = argparse.ArgumentParser()
parser.add_argument('--times', type=int, default=3, help='config times')
parser.add_argument('--seed', type=int, default=9, help='random seed')
parser.add_argument('--repeat', type=int, default=150, help='repeat time')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议把超参数名称设置为:n_epoch

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

而且在超参数中,好像有些超参数有些并没有用到,后续可以删除。

@gyzhou2000 gyzhou2000 closed this May 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants