-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathHjj_Read_Input_Cmd.lua
94 lines (76 loc) · 2.54 KB
/
Hjj_Read_Input_Cmd.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
require 'torch'
--- Register and parse the "Validate Agent" options, including the rgn model name.
-- Emits the banner text, then every option in the order listed in `specs`
-- (that order is the order shown in the help output), then parses `arg`.
-- @param cmd command-line builder exposing text/option/parse methods
--   (presumably a torch.CmdLine instance -- confirm at the call site)
-- @param arg argument list forwarded unchanged to cmd:parse
-- @return exactly one value: whatever cmd:parse returns first (the parsed options)
function func_read_validate_rgn_cmd(cmd, arg)
  -- { flag, default, help }; the default's Lua type (string vs number) is significant.
  local specs = {
    { '-data_path', '', 'Training dataset path' },
    { '-name', 'a', 'Name of data output' },
    { '-class', 0, 'The class you want to train' },
    { '-model_name', '0', 'Name of the pretrained model to load' },
    { '-rgn_name', '0', 'Name of the pretrained rgn to load' },
    { '-alpha', 0.2, 'action scalar, default' },
    { '-log_log', './log/v_log', 'log file' },
    { '-max_steps', 50, 'max step for one clip, default ' },
  }

  cmd:text()
  cmd:text('Validate Agent:')
  cmd:text()
  cmd:text('Options:')
  for _, spec in ipairs(specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end
  cmd:text()

  -- Parentheses truncate to a single return value, matching the original.
  return (cmd:parse(arg))
end
--- Register and parse the "Train Agent" command-line options.
-- Emits the banner text, then every option in the order listed in `train_opts`
-- (that order is the order shown in the help output), then parses `arg`.
-- @param cmd command-line builder exposing text/option/parse methods
--   (presumably a torch.CmdLine instance -- confirm at the call site)
-- @param arg argument list forwarded unchanged to cmd:parse
-- @return exactly one value: whatever cmd:parse returns first (the parsed options)
function func_read_training_cmd(cmd, arg)
  -- { flag, default, help }; the default's Lua type (string vs number) is significant.
  local train_opts = {
    { '-data_path', '', 'Training dataset path' },
    { '-name', 'a', 'name the models' },
    { '-class', 0, 'The class you want to train' },
    { '-model_name', '0', 'Name of the pretrained model to load' },
    { '-alpha', 0.2, 'action scalar, default' },
    { '-log_err', './log/training_error.log', 'log training error file' },
    { '-log_log', './log/log', 'log file' },
    { '-batch_size', 200, 'batch size, default' },
    { '-replay_buffer', 2000, 'experience replay memory size, default' },
    { '-lr', 1e-3, 'learning rate, default' },
    { '-epochs', 50, 'epochs, default' },
  }

  cmd:text()
  cmd:text('Train Agent:')
  cmd:text()
  cmd:text('Options:')
  for idx = 1, #train_opts do
    local entry = train_opts[idx]
    cmd:option(entry[1], entry[2], entry[3])
  end
  cmd:text()

  -- Parentheses truncate to a single return value, matching the original.
  return (cmd:parse(arg))
end
--- Register and parse the "Validate Agent" command-line options.
-- Emits the banner text, then every option in the order listed in `flags`
-- (that order is the order shown in the help output), then parses `arg`.
-- @param cmd command-line builder exposing text/option/parse methods
--   (presumably a torch.CmdLine instance -- confirm at the call site)
-- @param arg argument list forwarded unchanged to cmd:parse
-- @return exactly one value: whatever cmd:parse returns first (the parsed options)
function func_read_validate_cmd(cmd, arg)
  -- { flag, default, help }; the default's Lua type (string vs number) is significant.
  local flags = {
    { '-data_path', '', 'Training dataset path' },
    { '-name', 'a', 'Name of data output' },
    { '-class', 0, 'The class you want to train' },
    { '-model_name', '0', 'Name of the pretrained model to load' },
    { '-alpha', 0.2, 'action scalar, default' },
    { '-log_log', './log/v_log', 'log file' },
    { '-max_steps', 50, 'max step for one clip, default ' },
  }

  for _, banner in ipairs({ 'header' }) do
    -- single-pass loop only to group the banner lines visually
    local _ = banner
    cmd:text()
    cmd:text('Validate Agent:')
    cmd:text()
    cmd:text('Options:')
  end
  for _, f in ipairs(flags) do
    cmd:option(f[1], f[2], f[3])
  end
  cmd:text()

  -- Parentheses truncate to a single return value, matching the original.
  return (cmd:parse(arg))
end
--- Select the compute device and log the choice to `file` and stdout.
-- Negative `opt`: no GPU libraries are loaded; `opt` is logged and returned unchanged.
-- Otherwise 'cutorch' and 'cunn' are required. If `opt` is 0 and the GPU_ID
-- environment variable parses as a number, the device becomes GPU_ID + 1.
-- Any positive result is passed to cutorch.setDevice. The device reported by
-- cutorch.getDevice() is then logged as (device - 1) and returned.
-- @param opt numeric device selector (see above)
-- @param file object with a write(self, str) method (presumably an open log file handle)
-- @return cutorch.getDevice() on the GPU path, or `opt` unchanged on the CPU path
function func_set_gpu(opt, file)
  -- Written as `not (opt >= 0)` instead of `opt < 0` so NaN still takes the
  -- CPU path and a nil `opt` fails with the same comparison error as before.
  if not (opt >= 0) then
    local cpu_msg = 'Using CPU code only. GPU device id:' .. opt
    file:write(cpu_msg .. '\n')
    print(cpu_msg)
    return opt
  end

  require 'cutorch'
  require 'cunn'

  local wanted = opt
  if wanted == 0 then
    -- GPU_ID is added to 1 before being handed to cutorch.setDevice.
    local env_id = tonumber(os.getenv('GPU_ID'))
    if env_id then
      wanted = env_id + 1
    end
  end
  if wanted > 0 then
    cutorch.setDevice(wanted)
  end

  local active = cutorch.getDevice()
  local gpu_msg = 'Using GPU device id:' .. (active - 1)
  file:write(gpu_msg .. '\n')
  print(gpu_msg)
  return active
end