import torch
from openfgl.flcore.base import BaseClient
from openfgl.flcore.feddc.feddc_config import config
class FedDCClient(BaseClient):
    """
    Client for Federated Learning with Drift Decoupling and Correction (FedDC).

    Extends BaseClient. Each round the client starts from the global weights,
    trains with a loss that adds a drift penalty and a gradient-correction
    term to the task loss, and then records how far it moved from the global
    weights.

    Attributes:
        local_drift (list): One tensor per model parameter holding the
            accumulated sum of this client's per-round updates.
        last_update (list): One tensor per model parameter holding the most
            recent update (local weights minus global weights after training).
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initializes the FedDCClient.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device)
        # Both buffers start at zero and mirror the shapes of the model parameters.
        self.local_drift = [torch.zeros_like(p, requires_grad=False) for p in self.task.model.parameters()]
        self.last_update = [torch.zeros_like(p, requires_grad=False) for p in self.task.model.parameters()]

    def get_custom_loss_fn(self):
        """
        Builds the FedDC training loss.

        In round 0 the returned function is just the task loss. In later
        rounds it adds:
          * ``feddc_alpha / 2 * sum ||drift + local - global||^2`` and
          * ``1 / (lr * num_epochs) * sum <local, last_update - avg_update>``,
        where ``global`` and ``avg_update`` are read from the server entry of
        the message pool.

        Returns:
            custom_loss_fn (function): Callable ``(embedding, logits, label, mask)``
            returning the scalar loss.
        """
        def custom_loss_fn(embedding, logits, label, mask):
            task_loss = self.task.default_loss_fn(logits[mask], label[mask])
            if self.message_pool["round"] == 0:
                # No server statistics to correct against in the first round.
                return task_loss

            server_msg = self.message_pool["server"]
            loss_drift = 0
            loss_grad = 0
            for local_state, global_state, drift_param, update_param, avg_param in zip(
                self.task.model.parameters(),
                server_msg["weight"],
                self.local_drift,
                self.last_update,
                server_msg["avg_update"],
            ):
                loss_drift += torch.sum(torch.pow(drift_param + local_state - global_state, 2))
                # The difference must be parenthesised: without it avg_param is
                # only a constant offset and contributes no gradient.
                loss_grad += torch.sum(local_state * (update_param - avg_param))

            drift_weight = config["feddc_alpha"] / 2
            grad_weight = 1 / (self.args.lr * self.args.num_epochs)
            return task_loss + drift_weight * loss_drift + grad_weight * loss_grad

        return custom_loss_fn

    def execute(self):
        """
        Runs one round of local training.

        Copies the server's global weights into the local model, trains with
        the FedDC loss, then stores the resulting update (local minus global)
        in ``last_update`` and adds that same update to ``local_drift``.
        """
        with torch.no_grad():
            for local_param, global_param in zip(self.task.model.parameters(), self.message_pool["server"]["weight"]):
                local_param.data.copy_(global_param)

        self.task.loss_fn = self.get_custom_loss_fn()
        self.task.train()

        with torch.no_grad():
            for it, (local_state, global_state) in enumerate(
                zip(self.task.model.parameters(), self.message_pool["server"]["weight"])
            ):
                # Compute the fresh update first so the drift accumulates this
                # round's movement rather than the previous round's.
                update = local_state.detach() - global_state.detach()
                self.last_update[it] = update
                self.local_drift[it] += update

    def send_message(self):
        """
        Publishes this client's results to the message pool.

        The entry ``client_<id>`` contains the number of samples, the current
        model parameters, the last update and the accumulated local drift.
        """
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "weight": list(self.task.model.parameters()),
            "last_update": self.last_update,
            "local_drift": self.local_drift
        }