run_dist_train_sage_sup.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import yaml
import argparse
import paramiko
import click

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description='Run DistRandomSampler benchmarks.')
  parser.add_argument('--config', type=str, default='dist_train_sage_sup_config.yml',
                      help='path to the configuration file for benchmarks')
  parser.add_argument('--epochs', type=int, default=1,
                      help='number of epochs to run')
  parser.add_argument('--batch_size', type=int, default=2048,
                      help='batch size for sampling')
  parser.add_argument('--master_addr', type=str, default='0.0.0.0',
                      help='master IP address for synchronization across all training nodes')
  parser.add_argument('--master_port', type=str, default='11345',
                      help='port for synchronization across all training nodes')
  args = parser.parse_args()
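  # Example invocation (a sketch only; the address and port values below are
  # placeholders, not project defaults):
  #   python run_dist_train_sage_sup.py --config dist_train_sage_sup_config.yml \
  #     --epochs 3 --batch_size 2048 --master_addr 10.0.0.1 --master_port 11345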
  # Load the benchmark configuration describing the cluster layout and model I/O sizes.
  with open(args.config, 'r') as f:
    config = yaml.safe_load(f)
  dataset = config['dataset']
  ip_list, port_list, username_list = config['nodes'], config['ports'], config['usernames']
  dst_path_list = config['dst_paths']
  node_ranks = config['node_ranks']
  num_nodes = len(node_ranks)
  visible_devices = config['visible_devices']
  python_bins = config['python_bins']
  # One training process is spawned per visible GPU on each node.
  num_cores = len(visible_devices[0].split(','))
  in_channel = str(config['in_channel'])
  out_channel = str(config['out_channel'])
  dataset_path = "../../data/"
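  # For reference, a rough sketch of what dist_train_sage_sup_config.yml could
  # look like, inferred only from the keys read above (all values below are
  # placeholders, not shipped defaults):
  #   dataset: ogbn-products
  #   nodes: ['10.0.0.1', '10.0.0.2']
  #   ports: [22, 22]
  #   usernames: ['user0', 'user1']
  #   dst_paths: ['/home/user0/graphlearn-for-pytorch', '/home/user1/graphlearn-for-pytorch']
  #   node_ranks: [0, 1]
  #   visible_devices: ['0,1', '0,1']
  #   python_bins: ['python', 'python']
  #   in_channel: 100
  #   out_channel: 47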
  # Prompt interactively for the SSH password of each node.
  passwd_dict = {}
  for username, ip in zip(username_list, ip_list):
    passwd_dict[ip+username] = click.prompt('passwd for '+username+'@'+ip,
                                            hide_input=True)
  # Connect to every node over SSH and launch its training process inside a
  # detached tmux session.
  for username, ip, port, dst, noderk, device, pythonbin in zip(
      username_list,
      ip_list,
      port_list,
      dst_path_list,
      node_ranks,
      visible_devices,
      python_bins,
  ):
    trans = paramiko.Transport((ip, port))
    trans.connect(username=username, password=passwd_dict[ip+username])
    # Attach the pre-authenticated transport to an SSHClient so exec_command can be used.
    ssh = paramiko.SSHClient()
    ssh._transport = trans
    to_dist_dir = 'cd '+dst+'/examples/distributed/ '
    # Build the per-node training command; the closing quote of the tmux command
    # is appended below when the full command line is assembled.
    exec_example = (
      "tmux new -d 'CUDA_VISIBLE_DEVICES="+device+" "+pythonbin
      + " dist_train_sage_supervised.py --dataset="+dataset
      + " --dataset_root_dir=../../data/"+dataset
      + " --in_channel="+in_channel+" --out_channel="+out_channel
      + " --node_rank="+str(noderk)+" --num_dataset_partitions="+str(num_nodes)
      + " --num_nodes="+str(num_nodes)+" --num_training_procs="+str(num_cores)
      + " --master_addr="+args.master_addr
      + " --training_pg_master_port="+args.master_port
      + " --train_loader_master_port="+str(int(args.master_port)+1)
      + " --test_loader_master_port="+str(int(args.master_port)+2)
      + " --batch_size="+str(args.batch_size)+" --epochs="+str(args.epochs)
    )
    print(to_dist_dir + ' && ' + exec_example + " '")
    stdin, stdout, stderr = ssh.exec_command(to_dist_dir+' && '+exec_example+" '", bufsize=1)
    print(stdout.read().decode())
    print(stderr.read().decode())
    ssh.close()
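  # For a single node, the assembled remote command looks roughly like the
  # following (hosts, paths, channels, and ports are placeholders):
  #   cd /home/user0/graphlearn-for-pytorch/examples/distributed/  && tmux new -d \
  #     'CUDA_VISIBLE_DEVICES=0,1 python dist_train_sage_supervised.py \
  #      --dataset=ogbn-products --dataset_root_dir=../../data/ogbn-products \
  #      --in_channel=100 --out_channel=47 --node_rank=0 --num_dataset_partitions=2 \
  #      --num_nodes=2 --num_training_procs=2 --master_addr=10.0.0.1 \
  #      --training_pg_master_port=11345 --train_loader_master_port=11346 \
  #      --test_loader_master_port=11347 --batch_size=2048 --epochs=1'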