---
# model_config_vit.yaml
# Debug switch. When set to true, the model is trained on a small subset of
# fake data, which is useful for debugging the model and the training pipeline.
DEBUG_MODE: false
# Settings grouped under wandb_config (presumably consumed by Weights & Biases
# logging; confirm against the training code).
wandb_config:
  project: "SdPNet"
  group: "XL"
# Model architecture hyperparameters. All keys below belong to model_config.
model_config:
  embedding_dim: 768
  num_blocks: 6
  n_head: 8
  activation: "gelu"
  embedding_activation: "none"
  conv_kernel_size: 3
  patch_size: 14
  ffn_dropout: 0.2
  attn_dropout: 0.2
  output_classes: 1000
  conv_block_num: 2
  ff_multiplication_factor: 4
  # NOTE(review): presumably measured in patches rather than pixels
  # (224 / patch_size 14 = 16); confirm against the model code.
  max_image_size: [16, 16]
  max_num_registers: 5
  conv_first: false
  head_output_from_register: true
  simple_mlp_output: false
  output_head_bias: false
  normalize_qv: true
  stochastic_depth_p: [0.0, 0.0]
  # NOTE(review): key is spelled "deptwise" (likely meant "depthwise"); left
  # unchanged because the consumer reads this exact key.
  mixer_deptwise_bias: false
  mixer_ffn_bias: false
  conv_embedding: false
  conv_embedding_kernel_size: 5  # not active if conv_embedding is false
# Training-loop settings. All keys below belong to trainer_config.
trainer_config:
  compile_model: false
  snapshot_dir: "model"
  snapshot_name: "model_1.pt"
  save_every: 1
  total_epochs: 350  # when changing this, do not forget to change T_0 below
  gradient_accumulation_steps: 1
  report_every_epoch: 1
  use_cross_entropy: true  # true for cross entropy, false for BCE loss
  label_smoothing: 0.1
  ema_decay: 0.999
# Optimizer and learning-rate scheduler settings. Each scheduler gets its own
# nested mapping so that repeated parameter names (e.g. total_iters) do not
# collide as duplicate keys.
optimizer_scheduler_config:
  optimizer: "ADAMW"
  optimizer_config:
    # 0.0005 * batch_size * gradient_accumulation_steps * GPU / 512
    lr: 0.0015
    weight_decay: 0.05
  scheduler_config:
    cosine:
      T_0: 350
      # Written with an explicit decimal point: YAML 1.1 loaders such as
      # PyYAML resolve a bare "1e-5" as a string, not a float.
      eta_min: 1.0e-5
    constant_scheduler:
      factor: 0.001
      total_iters: 2
    linear_scheduler:
      total_iters: 5
      start_factor: 0.001
# Dataset and data-loading settings. Train and validation loader options live
# in separate nested mappings so their identically named keys do not collide.
data:
  Num_Classes: 1000
  train_image_size: [224, 224]
  val_image_size: [320, 320]
  val_crop_size: [224, 224]
  train_data_details:
    # lr formula: 0.005 * batch_size * grad_accumulation_step * gpu / 512
    # NOTE(review): the comment on optimizer_config lr uses 0.0005 instead of
    # 0.005; confirm which coefficient is intended.
    batch_size: 16
    num_workers: 12
    pin_memory: true
    persistent_workers: true
    prefetch_factor: 4
    drop_last: true
  val_data_details:
    batch_size: 128
    num_workers: 4
    pin_memory: true
    drop_last: true