-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFL.py
172 lines (134 loc) · 4.73 KB
/
FL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os # 这句话是导入os模块到当前程序,这个模块提供了一种方便的使用操作系统函数的方法
"""UPLOADING THE DATASETS"""
import sys # 这个模块可供访问由解释器使用或维护的变量和与解释器进行交互的函数
print(
"dataset - sampling - sim_type - seed - n_SGD - lr - decay - p - force - mu"
)
print(sys.argv[1:]) # sys.argv 命令行参数List,第一个元素是程序本身路径,所以从第2个元素开始print
dataset = sys.argv[1]
sampling = sys.argv[2]
sim_type = sys.argv[3]
seed = int(sys.argv[4])
n_SGD = int(sys.argv[5])
lr = float(sys.argv[6])
decay = float(sys.argv[7])
p = float(sys.argv[8])
force = sys.argv[9] == "True"
try:
mu = float(sys.argv[10])
except:
mu = 0.0
"""
如果当try后的语句执行时发生异常,python就跳回到try并执行第一个匹配该异常的except子句,异常处理完毕,
控制流就通过整个try语句(除非在处理异常时又引发新的异常)。
如果在try子句执行时没有发生异常,python将执行else语句后的语句(如果有else的话),然后控制流通过整个try语句。
"""
"""GET THE HYPERPARAMETERS"""
from py_func.hyperparams import get_hyperparams # 如果要从其他python文件中导入函数,就需要加上py_func.hyperparams
n_iter, batch_size, meas_perf_period = get_hyperparams(dataset, n_SGD)
print("number of iterations", n_iter)
print("batch size", batch_size)
print("percentage of sampled clients", p)
print("metric_period", meas_perf_period)
print("regularization term", mu)
"""NAME UNDER WHICH THE EXPERIMENT'S VARIABLES WILL BE SAVED"""
from py_func.hyperparams import get_file_name
file_name = get_file_name(
dataset, sampling, sim_type, seed, n_SGD, lr, decay, p, mu
)
print(file_name)
"""GET THE DATASETS USED FOR THE FL TRAINING"""
from py_func.read_db import get_dataloaders
list_dls_train, list_dls_test = get_dataloaders(dataset, batch_size) # 从read.db函数中调用get_dataloaders,然后读出来list_dls_train, list_dls_test这两个变量
"""NUMBER OF SAMPLED CLIENTS"""
n_sampled = int(p * len(list_dls_train)) # 计算出抽样参与方的数量
print("number fo sampled clients", n_sampled)
"""LOAD THE INTIAL _GLOBAL MODEL"""
from py_func.create_model import load_model # 调用create_model函数中的load_model 就是加载函数的函数
model_0 = load_model(dataset, seed)
print(model_0)
## by zzy begin
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Experiment running on device: {}".format(device))
model_0.to(device=device)
## by zzy end
"""FEDAVG with random sampling"""
if sampling == "random" and (
not os.path.exists(f"saved_exp_info/acc/{file_name}.pkl") or force
): # 如果抽样方式是“random” 并且f"saved_exp_info/acc/{file_name}.pkl"这个文件不存在
from py_func.FedProx import FedProx_sampling_random # 从FedProx这个python文件中导入FedProx_sampling_random函数
FedProx_sampling_random(
model_0,
n_sampled,
list_dls_train,
list_dls_test,
n_iter,
n_SGD,
lr,
file_name,
decay,
meas_perf_period,
mu,
)
"""Run FEDAVG with clustered sampling"""
if (sampling == "clustered_1" or sampling == "clustered_2") and (
not os.path.exists(f"saved_exp_info/acc/{file_name}.pkl") or force
):
from py_func.FedProx import FedProx_clustered_sampling # 针对不同的sampling,导入不同的函数
FedProx_clustered_sampling(
sampling,
model_0,
n_sampled,
list_dls_train,
list_dls_test,
n_iter,
n_SGD,
lr,
file_name,
sim_type,
0,
decay,
meas_perf_period,
mu,
)
"""RUN FEDAVG with perfect sampling for MNIST-shard"""
if (
sampling == "perfect"
and dataset == "MNIST_shard"
and (not os.path.exists(f"saved_exp_info/acc/{file_name}.pkl") or force)
):
from py_func.FedProx import FedProx_sampling_target
FedProx_sampling_target(
model_0,
n_sampled,
list_dls_train,
list_dls_test,
n_iter,
n_SGD,
lr,
file_name,
decay,
mu,
)
"""RUN FEDAVG with its original sampling scheme sampling clients uniformly"""
if sampling == "FedAvg" and (
not os.path.exists(f"saved_exp_info/acc/{file_name}.pkl") or force
):
from py_func.FedProx import FedProx_FedAvg_sampling
FedProx_FedAvg_sampling(
model_0,
n_sampled,
list_dls_train,
list_dls_test,
n_iter,
n_SGD,
lr,
file_name,
decay,
meas_perf_period,
mu,
)
print("EXPERIMENT IS FINISHED") # 实验结束