-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathnode.jinja.yaml
29 lines (29 loc) · 1.37 KB
/
node.jinja.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Node configuration template.
# Templated values are filled in at render time from the `task`, `tpe`,
# `replication` and `experiment_path` variables supplied by the renderer.

# Run identity and output location
replication: {{ replication }}
replication_id: {{ task.get_net_param('replication') }}
rng_seed: {{ task.get_net_param('seed') }}
output_path: {{ experiment_path }}

# Network, dataset and loss (read through task.get_net_param)
dataset_name: {{ task.get_net_param('dataset').value }}
net_name: {{ task.get_net_param('network').value }}
loss_function: {{ task.get_net_param('loss_function').value }}

# Learning parameters (read through task.get_learn_param)
cuda: {{ task.get_learn_param('cuda') }}
rounds: {{ task.get_learn_param('rounds') }}
epochs: {{ task.get_learn_param('epochs_per_round') }}
clients_per_round: {{ task.get_learn_param('clients_per_round') }}
aggregation: {{ task.get_learn_param('aggregation').value }}

# Hyper-parameters for the given `tpe` (read through task.get_hyper_param)
batch_size: {{ task.get_hyper_param(tpe, 'bs') }}
test_batch_size: {{ task.get_hyper_param(tpe, 'test_bs') }}

# Optimizer for the given `tpe`
optimizer: {{ task.get_optimizer_param(tpe, 'type').value }}
optimizer_args: {{ task.get_optimizer_args(tpe) }}
lr: {{ task.get_optimizer_param(tpe, 'lr') }}
momentum: {{ task.get_optimizer_param(tpe, 'momentum') }}

# Learning-rate scheduler for the given `tpe`
scheduler_step_size: {{ task.get_scheduler_param(tpe, 'scheduler_step_size') }}
scheduler_gamma: {{ task.get_scheduler_param(tpe, 'scheduler_gamma') }}
min_lr: {{ task.get_scheduler_param(tpe, 'min_lr') }}

# Data sampler for the given `tpe`
data_sampler: {{ task.get_sampler_param(tpe, 'type').value }}
data_sampler_args: {{ task.get_sampler_args(tpe) }}
shuffle: {{ task.get_sampler_param(tpe, 'shuffle') }}

# Fixed settings (not templated; identical in every rendered file)
log_interval: 10
distributed: true
single_machine: false
real_time: true
save_data_append: true