# local_config.yml
############
# Training #
############
# Dataloader kwargs
# NOTE(review): every line in this copy starts at column 0, so the nesting of
# the keys below cannot be read off the text. Presumably the keys down to
# `detection_kwargs` are children of `data:` -- confirm against the loader
# and restore the indentation before parsing.
data:
# Name of the training distribution and its keyword arguments.
# Presumably `loc` and `scale` are children of `train_dist_kwargs` (they look
# like scipy.stats-style parameters for 'norm') -- TODO confirm.
train_dist_name: 'norm'
train_dist_kwargs:
loc: 0.01
scale: 0.04
# Input data directory; an absolute path under one user's home directory, so
# it presumably has to be changed on any other machine.
in_dir: '/home/jwp/stage/sl/n2j/n2j/data'
# `train_hp` and `n_train` both hold two entries, `val_hp` and `n_val` one
# each -- presumably one count per `*_hp` entry; verify against the loader.
train_hp: [10327, 9559]
val_hp: [10450]
n_train: [10000, 10000]
# Final effective training set size
n_subsample_train: 1000
n_val: [10000]
# Final effective val set size
n_subsample_val: 1000
# Batch sizes for training and validation, and the dataloader worker count
# (meaning inferred from the key names -- confirm).
batch_size: 1000
val_batch_size: 1000
num_workers: 4
# Global (graph-level) target; final_gamma1, final_gamma2 also available
sub_target: ['final_kappa']
# Local (node-level) target
sub_target_local: ['stellar_mass', 'redshift']
# Features available; do not modify (determined at data generation time)
features: ['galaxy_id', 'ra', 'dec', 'redshift',
'ra_true', 'dec_true', 'redshift_true',
'bulge_to_total_ratio_i',
'ellipticity_1_true', 'ellipticity_2_true',
'ellipticity_1_bulge_true', 'ellipticity_1_disk_true',
'ellipticity_2_bulge_true', 'ellipticity_2_disk_true',
'shear1', 'shear2', 'convergence',
'size_bulge_true', 'size_disk_true', 'size_true',
'mag_u_lsst', 'mag_g_lsst', 'mag_r_lsst',
'mag_i_lsst', 'mag_z_lsst', 'mag_Y_lsst']
# Features to use as input
# Every name listed here also appears in `features` above.
sub_features: ['ra_true', 'dec_true',
'mag_u_lsst', 'mag_g_lsst', 'mag_r_lsst',
'mag_i_lsst', 'mag_z_lsst', 'mag_Y_lsst']
# NOTE(review): presumably `mag` is a child of `noise_kwargs`, and
# `override_kwargs` / `depth` are children of `mag` -- the lost indentation
# makes this a guess; confirm.
noise_kwargs:
mag:
override_kwargs: null
depth: 5
# NOTE(review): unclear from this copy whether `detection_kwargs` is a sibling
# of `noise_kwargs` or nested inside it -- confirm. `ref_features` and
# `max_vals` each hold one entry, presumably paired element-wise.
detection_kwargs:
ref_features: ['mag_i_lsst']
max_vals: [22.0]
# Optimizer kwargs
# NOTE(review): indentation is missing in this copy; the keys below are
# presumably children of `optimization:` -- confirm before parsing.
optimization:
early_stop_memory: 50
weight_local_loss: 0.0338987896
# Presumably `lr` and `weight_decay` are children of `optim_kwargs` and are
# forwarded to the optimizer constructor -- TODO confirm.
optim_kwargs:
lr: 0.001
weight_decay: 0.0001
# The key names below match parameters of PyTorch's ReduceLROnPlateau
# scheduler; presumably they are forwarded to it as children of
# `lr_scheduler_kwargs` -- verify against the training code.
lr_scheduler_kwargs:
patience: 5
factor: 0.5
min_lr: 0.0000001
verbose: True
# Model kwargs
# NOTE(review): indentation is missing in this copy; the keys below are
# presumably all children of `model:` -- confirm before parsing.
model:
# All four dimension settings share the same value (50).
dim_local: 50
dim_global: 50
dim_hidden: 50
dim_pre_aggr: 50
n_iter: 5
n_out_layers: 5
dropout: 0.006667011687
global_flow: False
# Trainer attributes
# NOTE(review): indentation is missing in this copy; `device_type`,
# `checkpoint_dir` and `seed` are presumably children of `trainer:`.
trainer:
device_type: 'cuda'
# Relative path (no leading slash).
checkpoint_dir: 'results/E1'
seed: 1028
# NOTE(review): cannot tell from this copy whether `n_epochs` is nested under
# `trainer:` or is a top-level key -- confirm against the training script.
n_epochs: 5
# If you want to resume training from a checkpoint
# `checkpoint_path` is null here; presumably null means training starts from
# scratch -- confirm. It is presumably a child of `resume_from:` (indentation
# is missing in this copy). A second key named `checkpoint_path` appears in
# the inference section further down.
resume_from:
checkpoint_path: null
#############
# Inference #
#############
# Inference manager settings.
# NOTE(review): indentation is missing in this copy; the three keys below are
# presumably children of `inference_manager:` -- confirm before parsing.
inference_manager:
device_type: 'cuda'
# Relative path (no leading slash).
out_dir: 'results/E1/inference'
seed: 1028
# Test-set settings.
# NOTE(review): indentation is missing in this copy; the keys down to `idx`
# are presumably children of `test_data:`, and whether `test_data:` itself is
# top-level or nested under `inference_manager:` cannot be told from here.
test_data:
batch_size: 100
# `test_hp` and `n_test` each hold one entry, presumably paired element-wise.
test_hp: [10326]
n_test: [50000]
n_subsample_test: 100
# Presumably `loc` and `scale` are children of `dist_kwargs` -- TODO confirm.
dist_name: 'norm'
dist_kwargs:
loc: 0.01
scale: 0.04
# `idx` lists the consecutive integers 0 through 99 (100 entries), the same
# count as `n_subsample_test` -- presumably indices into the subsampled test
# set; confirm.
idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
91, 92, 93, 94, 95, 96, 97, 98, 99]
# Summary-statistic thresholds.
# NOTE(review): indentation is missing in this copy; presumably `thresholds`
# is a child of `summary_stats:` and `N` / `N_inv_dist` are children of
# `thresholds` -- confirm before parsing.
summary_stats:
thresholds:
# Six thresholds, each double the previous one.
N: [2, 4, 8, 16, 32, 64]
# NOTE(review): this list doubles at each step except 256 -> 1024, which
# skips 512 -- confirm that is intentional.
N_inv_dist: [32, 64, 128, 256, 1024, 2048]
# Replace with your own
# Absolute, machine-specific path to a saved model file. The value contains
# ':' and '=' and is single-quoted, so it parses as a plain string.
# NOTE(review): a key with the same name appears under `resume_from:` earlier
# in the file; with indentation missing in this copy, make sure the two do not
# end up at the same nesting level (duplicate keys) when it is restored.
checkpoint_path: '/home/jwp/stage/sl/n2j/tuning/T8/N2JNet_epoch=82_08-08-2021_03:24.mdl'
# Presumably toggles the MCMC step configured by `extra_mcmc_kwargs` -- confirm.
run_mcmc: False
# Extra keyword arguments for the MCMC run.
# NOTE(review): indentation is missing in this copy; the six keys below are
# presumably children of `extra_mcmc_kwargs:` -- confirm before parsing.
# `run_mcmc` is False in this file, so these values are presumably unused
# unless that flag is switched on -- verify against the inference script.
extra_mcmc_kwargs:
n_run: 200
n_burn: 20
n_walkers: 20
plot_chain: True
clear: True
n_cores: 4