-
Notifications
You must be signed in to change notification settings - Fork 21
/
opt.py
executable file
·129 lines (121 loc) · 5.55 KB
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import os
from easydict import EasyDict as edict
def get_args():
parser = argparse.ArgumentParser(description="Training options")
parser.add_argument("--data_path", type=str, default="./data",
choices=[
'./datasets/argoverse',
'./datasets/kitti/object/training',
'./datasets/kitti/odometry',
'./datasets/kitti/raw'],
help="Path to the root data directory")
parser.add_argument("--save_path", type=str, default="./models/",
help="Path to save models")
parser.add_argument(
"--load_weights_folder",
type=str,
default="",
help="Path to a pretrained model used for initialization")
parser.add_argument("--model_name", type=str, default="crossView",
help="Model Name with specifications")
parser.add_argument(
"--split",
type=str,
choices=[
"argo",
"3Dobject",
"odometry",
"raw"],
help="Data split for training/validation")
parser.add_argument("--ext", type=str, default="png",
help="File extension of the images")
parser.add_argument("--height", type=int, default=1024,
help="Image height")
parser.add_argument("--width", type=int, default=1024,
help="Image width")
parser.add_argument(
"--type",
type=str,
choices=[
"both",
"static",
"dynamic"],
help="Type of model being trained")
parser.add_argument("--global_seed", type=int, default=0,
help="seed")
parser.add_argument("--batch_size", type=int, default=6,
help="Mini-Batch size")
parser.add_argument("--lr", type=float, default=1e-4, # attention
help="learning rate")
parser.add_argument("--lr_transform", type=float, default=1e-3,
help="learning rate")
parser.add_argument('--lr_steps', default=[50], type=float, nargs="+", # attention
metavar='LRSteps', help='epochs to decay learning rate by 10')
parser.add_argument('--weight_decay', '--wd', default=1e-5, type=float,
metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument("--scheduler_step_size", type=int, default=5,
help="step size for the both schedulers")
parser.add_argument("--static_weight", type=float, default=5.,
help="static weight for calculating loss")
parser.add_argument("--dynamic_weight", type=float, default=15.,
help="dynamic weight for calculating loss")
parser.add_argument("--occ_map_size", type=int, default=256,
help="size of topview occupancy map")
parser.add_argument("--num_class", type=int, default=2,
help="Number of classes")
parser.add_argument("--num_epochs", type=int, default=120,
help="Max number of training epochs")
parser.add_argument("--log_frequency", type=int, default=5,
help="Log files every x epochs")
parser.add_argument("--num_workers", type=int, default=8,
help="Number of cpu workers for dataloaders")
parser.add_argument("--osm_path", type=str, default="./data/osm",
help="OSM path")
parser.add_argument('--log_root', type=str, default=os.getcwd() + '/log')
parser.add_argument('--model_split_save', type=bool, default=True)
configs = edict(vars(parser.parse_args()))
return configs
def get_eval_args(args=None):
    """Build and parse the command-line options for evaluation.

    Args:
        args: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, as before.

    Returns:
        An ``EasyDict`` holding every option defined below, keyed by its
        destination name (e.g. ``configs.pretrained_path``).
    """
    parser = argparse.ArgumentParser(description="Evaluation options")
    parser.add_argument("--data_path", type=str, default="./data",
                        help="Path to the root data directory")
    parser.add_argument("--pretrained_path", type=str, default="./models/",
                        help="Path to the pretrained model")
    parser.add_argument("--osm_path", type=str, default="./data/osm",
                        help="OSM path")
    # No default: this is None unless given on the command line.
    parser.add_argument(
        "--split",
        type=str,
        choices=[
            "argo",
            "3Dobject",
            "odometry",
            "raw"],
        help="Data split for training/validation")
    parser.add_argument("--ext", type=str, default="png",
                        help="File extension of the images")
    parser.add_argument("--height", type=int, default=1024,
                        help="Image height")
    parser.add_argument("--width", type=int, default=1024,
                        help="Image width")
    # No default: this is None unless given on the command line.
    parser.add_argument(
        "--type",
        type=str,
        choices=[
            "both",
            "static",
            "dynamic"],
        help="Type of model being trained")
    parser.add_argument("--occ_map_size", type=int, default=256,
                        help="size of topview occupancy map")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="Number of cpu workers for dataloaders")
    # NOTE(review): no help text; presumably the evaluation output directory.
    parser.add_argument("--out_dir", type=str,
                        default="output")
    parser.add_argument("--model_name", type=str, default="crossView",
                        help="Model Name with specifications")
    parser.add_argument("--num_class", type=int, default=2,
                        help="Number of classes")
    configs = edict(vars(parser.parse_args(args)))
    return configs