-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_ctrl_transduct.py
60 lines (52 loc) · 2.26 KB
/
train_ctrl_transduct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
from utils import *
import torch.nn.functional as F
from ctrl_agent_transduct import CTRL
from utils_graphsaint import DataGraphSAINT
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` makes argparse call ``bool()`` on the raw string, which is
    True for every non-empty value, so ``--normalize_features False`` could
    never switch the option off. This converter reads the spelled-out value.

    Args:
        value: Raw string from the command line (a bool is passed through).

    Returns:
        True or False according to the token.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under the old ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line options for one run; the parsed namespace is printed and later
# handed to the CTRL agent as-is.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=1, help='gpu id')
parser.add_argument('--dataset', type=str, default='citeseer')
parser.add_argument('--dis_metric', type=str, default='norm')
parser.add_argument('--epochs', type=int, default=600)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--lr_adj', type=float, default=1e-4)
parser.add_argument('--lr_feat', type=float, default=1e-4)
parser.add_argument('--lr_model', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.0)
# Accepts true/false, yes/no, 1/0 (case-insensitive) instead of bool(str).
parser.add_argument('--normalize_features', type=_str2bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--reduction_rate', type=float, default=1)
parser.add_argument('--seed', type=int, default=1, help='Random seed.')
parser.add_argument('--alpha', type=float, default=0, help='regularization term.')
parser.add_argument('--debug', type=int, default=0)
parser.add_argument('--sgc', type=int, default=1)
parser.add_argument('--inner', type=int, default=0)
parser.add_argument('--outer', type=int, default=20)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--one_step', type=int, default=0)
parser.add_argument('--beta', type=float, default=0.7, help='weight')
parser.add_argument('--init_way', type=str, default='K-means')
args = parser.parse_args()
print(args)
# Make the GPU chosen with --gpu_id the current CUDA device.
torch.cuda.set_device(args.gpu_id)

# Seed the Python, NumPy, torch CPU and torch CUDA generators, in that order,
# with the same --seed value.
for _seed_rng in (random.seed, np.random.seed,
                  torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(args.seed)
# Dataset names served by the GraphSAINT loader; any other name goes through
# get_dataset() followed by Transd2Ind().
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv', 'ogbn-arxiv-xrt']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

# Derive the device string from --gpu_id so the agent runs on the same GPU
# that torch.cuda.set_device() selected; a hard-coded 'cuda:1' ignored the
# flag for every non-default GPU id.
agent = CTRL(data, args, device='cuda:{}'.format(args.gpu_id))
agent.train()