diff --git a/fedlab/contrib/algorithm/qfedavg.py b/fedlab/contrib/algorithm/qfedavg.py index 1f9c9e4a..0c33ced3 100644 --- a/fedlab/contrib/algorithm/qfedavg.py +++ b/fedlab/contrib/algorithm/qfedavg.py @@ -2,6 +2,7 @@ from .basic_server import SyncServerHandler from .basic_client import SGDClientTrainer +from .basic_client import SGDSerialClientTrainer ################## @@ -40,7 +41,7 @@ def uplink_package(self): def setup_optim(self, epochs, batch_size, lr, q): super().setup_optim(epochs, batch_size, lr) self.q = q - + def train(self, model_parameters, train_loader) -> None: """Client trains its local model on local dataset. Args: @@ -72,11 +73,12 @@ def train(self, model_parameters, train_loader) -> None: ret_loss + 1e-10, self.q - 1) * grad.norm( )**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q) + class qFedAvgSerialClientTrainer(SGDSerialClientTrainer): def setup_optim(self, epochs, batch_size, lr, q): super().setup_optim(epochs, batch_size, lr) self.q = q - + def train(self, model_parameters, train_loader) -> None: """Client trains its local model on local dataset. Args: