empirical_metric.py
import torch


class empirical_metrics_batch:
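    """Estimates empirical gradient statistics from per-batch target gradients and
    averaged per-source-client gradients; the return_* methods turn these statistics
    into per-source-client beta values."""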
    def __init__(self, target_batch_size, source_grads, target_grads):
        self.target_batch_size = target_batch_size  # not needed on the source clients
        self.source_grads = source_grads  # list with one gradient per source client, each averaged over that client's batches; each of size (M,)
        self.target_grads = target_grads  # tensor of size (N, M), where N is the number of target batches and M is the gradient dimension
        self.target_grad = torch.mean(target_grads, dim=0)  # average of self.target_grads, size (M,)
        # the quantities below are filled in by self.compute_quantities() once the three inputs above are set
        self.target_var = None
        self.source_target_var = []  # length = number of source clients
        # self.taus = []  # length = number of source clients
        self.projected_grads_norm_square = []
        self.deltas = []
        self.compute_quantities()

    def compute_quantities(self):
        num_batches, dim = self.target_grads.shape
        # per-coordinate sample variance of the per-batch target gradients
        sample_target_var = torch.sum((self.target_grads - self.target_grad) ** 2) / (num_batches - 1) / dim
        # variance of the averaged target gradient (variance of the mean)
        self.target_var = sample_target_var / num_batches
        # per-coordinate squared norm of the averaged target gradient
        self.target_norm_square = torch.norm(self.target_grad).item() ** 2 / dim
        # per-coordinate source-target discrepancy for each source client,
        # with the per-batch target variance subtracted out
        for source_grad in self.source_grads:
            sample_source_target_var = torch.sum((self.target_grads - source_grad) ** 2) / num_batches / dim
            self.source_target_var.append(max(sample_source_target_var - sample_target_var, 0.))
            # compute tau
            # eps = 0.0001  # room for numerical error
            # diff = torch.norm(self.target_grad - source_grad)
            # cos_rho = (source_grad * self.target_grad).sum() / torch.norm(self.target_grad) / torch.norm(source_grad)
            # sin_rho = (1 - cos_rho ** 2) ** 0.5
            # print(sin_rho)
            # if diff < eps:
            #     tau = 0
            # else:
            #     tau = (torch.norm(self.target_grad) * sin_rho / diff).item()
            # project the per-batch and averaged target gradients onto the subspace orthogonal to source_grad
            projected_grads = self.target_grads - (torch.sum(self.target_grads * source_grad, dim=1) * source_grad.view([-1, 1])).T / torch.norm(source_grad) ** 2
            projected_grad = self.target_grad - torch.sum(self.target_grad * source_grad) * source_grad / torch.norm(source_grad) ** 2
            # per-coordinate sample variance and mean squared norm of the projected per-batch gradients
            projected_grads_var = torch.sum((projected_grads - projected_grad) ** 2) / (num_batches - 1) / dim
            projected_grads_norm_var = torch.mean(torch.norm(projected_grads, dim=1) ** 2) / dim
            self.projected_grads_norm_square.append(max(projected_grads_norm_var - projected_grads_var, 0.))
            # compute delta: the fraction of target batches whose gradient has a positive inner product with source_grad
            inner_products = torch.sum(self.target_grads * source_grad, dim=1)
            delta = torch.sum(inner_products > 0) / num_batches
            # rescale delta toward 1: maps [0, 1] into [1 - 1/num_batches, 1]
            self.deltas.append(1 - (1 - delta.item()) / num_batches)
            # self.taus.append(tau)
            # print(self.deltas, self.taus, self.source_target_var, self.target_var)

    def return_fedda_beta(self):
        # beta per source client: target variance relative to the source-target discrepancy
        return [self.target_var / (self.target_var + s_t_var) for s_t_var in self.source_target_var]

    def return_fedgp_with_thresh_beta(self):
        # beta per source client with thresholding: mixes the projected-gradient term (weight delta)
        # with the target-gradient norm term (weight 1 - delta)
        return [self.target_var / (self.target_var
                                   + self.deltas[idx] * self.projected_grads_norm_square[idx]
                                   + (1 - self.deltas[idx]) * self.target_norm_square)
                for idx in range(len(self.source_grads))]

    def return_fedgp_beta(self):
        # beta per source client: target variance relative to the projected-gradient norm
        return [self.target_var / (self.target_var + self.projected_grads_norm_square[idx]) for idx in range(len(self.source_grads))]
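

# --- usage sketch (not part of the original file) ---
# A minimal, hypothetical example of how the class might be driven: random per-batch
# target gradients and two random averaged source gradients. The shapes, the seed,
# and the batch size below are illustrative assumptions, not the repo's training pipeline.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_batches, dim = 8, 16
    target_grads = torch.randn(num_batches, dim)         # per-batch target gradients, (N, M)
    source_grads = [torch.randn(dim), torch.randn(dim)]  # one averaged gradient per source client, each (M,)
    metrics = empirical_metrics_batch(target_batch_size=32, source_grads=source_grads, target_grads=target_grads)
    print("FedDA betas:", metrics.return_fedda_beta())
    print("FedGP betas:", metrics.return_fedgp_beta())
    print("FedGP (thresholded) betas:", metrics.return_fedgp_with_thresh_beta())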