-
Notifications
You must be signed in to change notification settings - Fork 3
/
remote.py
104 lines (65 loc) · 2.65 KB
/
remote.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#########################
# Define all remote operations built on the 'ssh' command
#########################
import os, subprocess, sys
import collections
# NOTE(review): presumably an SSH target in ``user@host`` form. The literal
# below looks like an e-mail-obfuscation placeholder rather than a usable
# address -- confirm the real value before relying on it.
ali_ssh = '[email protected]'
def riseml_path_to_ssh_path(s):
    """Convert a RiseML-side path into the corresponding path on the ssh host.

    Args:
        s: path string as seen from RiseML.

    Returns:
        The path prefixed with ``/riseml``, with every ``/workspace/``
        segment rewritten to ``/workspace/engine/`` (``engine`` being the
        user name component).
    """
    prefixed = '/riseml' + s
    return prefixed.replace('/workspace/', '/workspace/engine/')
def mkdirs(path):
    """Create ``path`` (including missing parents) if it does not exist yet.

    Args:
        path: local directory path to create.

    Raises:
        OSError: if the directory cannot be created and ``path`` still does
            not exist afterwards.
    """
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except OSError:
        # Another process may have created the path between the existence
        # check and makedirs(); only a path that is still missing is an error.
        if not os.path.exists(path):
            raise
def cmd(cmd_str):
    """Run a shell command and return its standard output as a list of lines.

    Args:
        cmd_str: command line, executed through the shell.

    Returns:
        The output with leading/trailing whitespace stripped, split on
        newlines. A command that prints nothing yields ``['']``, not ``[]``.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    # universal_newlines=True makes check_output return text on both
    # Python 2 and Python 3; without it Python 3 returns bytes and the
    # split('\n') below raises TypeError.
    output = subprocess.check_output(cmd_str, shell=True,
                                     universal_newlines=True)
    return output.strip().split('\n')
def ssh_find(ssh, find_dir, find_name, other_cmd=''):
    """Run ``find`` on a remote host through ssh.

    Args:
        ssh: ssh host to connect to.
        find_dir: directory on the host to search in.
        find_name: value passed to ``find -name``.
        other_cmd: extra text appended to the ``find`` command line.

    Returns:
        The output lines of the remote ``find`` command, as returned by
        ``cmd``.
    """
    # NOTE(review): the arguments are interpolated into the shell command
    # without quoting -- callers must pass shell-safe values.
    template = 'ssh {} "find {} -name {} {}"'
    remote_find = template.format(ssh, find_dir, find_name, other_cmd)
    return cmd(remote_find)
def ssh_find_modify_time(host, find_dir, find_name, time_tag='Modify'):
    """
    Get a timestamp for every remote file matching a name pattern.

    Runs ``find ... -exec stat`` on the host and parses the ``stat`` output.

    Args:
        host: ssh host
        find_dir: host dir
        find_name: file name
        time_tag: Modify / Access / Change
    Returns:
        an ordered dict whose keys are file names and values are time
        strings; empty when nothing matched
    """
    # '\\;' is the same two characters as the former '\;' literal, without
    # relying on an invalid escape sequence.
    cmd_out = ssh_find(host, find_dir, find_name, other_cmd='-exec stat {} \\;')
    res = collections.OrderedDict()
    time_prefix = time_tag + ':'
    file_name = None
    for line in cmd_out:
        fields = line.split(None, 1)
        if len(fields) < 2:
            # Blank line (e.g. find matched nothing) or a label with no
            # value: nothing to record.
            continue
        label, value = fields
        if label.startswith('File:'):
            file_name = value.strip("'")
        elif label.startswith(time_prefix) and file_name is not None:
            # A time line belongs to the most recent 'File:' line seen.
            res[file_name] = value
    return res
def ssh_find_the_latest_file(host, find_dir, find_name, time_tag='Modify'):
    """
    Pick the remote file with the greatest timestamp string.

    Args:
        host: ssh host
        find_dir: host dir
        find_name: file name
        time_tag: Modify / Access / Change
    Returns:
        file name, file time
    Raises:
        IndexError: if no file was found
    """
    stamped = list(ssh_find_modify_time(host, find_dir, find_name, time_tag).items())
    # Timestamps are compared as plain strings, newest first.
    stamped.sort(key=lambda pair: pair[1], reverse=True)
    newest = stamped[0]
    return newest
def download(host, remote_path, local_path):
    """Copy a remote path to a local location with ``rsync -av``.

    Args:
        host: ssh host to copy from.
        remote_path: path on the host.
        local_path: local destination passed to rsync.

    Returns:
        ``local_path`` joined with the last component of ``remote_path``.
        The rsync exit status is not checked.
    """
    rsync_cmd = 'rsync -av {}:{} {}'.format(host, remote_path, local_path)
    os.system(rsync_cmd)
    remote_basename = os.path.basename(remote_path)
    return os.path.join(local_path, remote_basename)
def download_ckpt(host, ckpt_dir, ckpt_name='*.ckpt', local_dir='.'):
    """Download the newest checkpoint matching ``ckpt_name`` from a host.

    The newest ``<ckpt_name>.index`` file under ``ckpt_dir`` is located, and
    every remote file sharing its prefix (the path minus the trailing
    ``index``) is copied into ``local_dir`` with rsync.

    Args:
        host: ssh host to copy from.
        ckpt_dir: directory on the host to search.
        ckpt_name: checkpoint name pattern, without the ``.index`` suffix.
        local_dir: local destination directory; created if missing.
    """
    mkdirs(local_dir)
    index_pattern = ckpt_name + '.index'
    index_path, _unused_time = ssh_find_the_latest_file(host, ckpt_dir, index_pattern)
    print('download the latest ckpt: %s' % index_path)
    # Drop the trailing 'index' so the glob matches all files of the ckpt.
    suffix_len = len('index')
    remote_glob = index_path[:-suffix_len] + '*'
    os.system('rsync -av {}:{} {}'.format(host, remote_glob, local_dir))
def download_ckpt_based_checkpoint(host, checkpoint_path, local_dir='.'):
    """Download the checkpoint named by a remote ``checkpoint`` file.

    The first line of the remote file is read; its second
    whitespace-separated field, with surrounding double quotes removed, is
    taken as the checkpoint path. Only its base name is used, and it is
    searched for in the directory that contains ``checkpoint_path``.

    Args:
        host: ssh host to copy from.
        checkpoint_path: path of the checkpoint file on the host.
        local_dir: local destination directory; created if missing.

    Raises:
        subprocess.CalledProcessError: if reading the remote file fails.
        IndexError: if the first line does not have two fields.
    """
    mkdirs(local_dir)
    first_line = cmd('ssh {} cat {}'.format(host, checkpoint_path))[0]
    ckpt_path = first_line.split(None, 1)[1].strip('"\n')
    ckpt_dir = os.path.dirname(checkpoint_path)
    ckpt_name = os.path.split(ckpt_path)[-1]
    try:
        download_ckpt(host, ckpt_dir, ckpt_name, local_dir)
    except (subprocess.CalledProcessError, IndexError):
        # A failing remote command raises CalledProcessError; a search that
        # matches no file surfaces as IndexError. Both mean the checkpoint
        # could not be found.
        print('cannot find the ckpt: %s/%s' % (ckpt_dir, ckpt_name))