import torch
import torch.nn as nn
from openfgl.flcore.base import BaseClient
from openfgl.flcore.fedtgp.fedtgp_config import config
class FedTGPClient(BaseClient):
    """
    Client-side logic for 'FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced
    Contrastive Learning for Data and Model Heterogeneity in Federated Learning'.

    Each round the client installs a custom loss (task loss plus a prototype-alignment
    term), trains locally, recomputes one prototype per class from its training nodes,
    and sends those prototypes to the server.

    Attributes:
        local_prototype (dict): Maps each class index in
            ``range(self.task.num_global_classes)`` to this client's prototype tensor
            for that class (a zero vector when the client has no training node of it).
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initialize the FedTGPClient.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device)
        self.local_prototype = {}

    def execute(self):
        """
        Run one round of local work.

        Installs the prototype-aware loss on the task, trains the local model, and then
        refreshes ``self.local_prototype`` from the trained model's embeddings.
        """
        self.task.loss_fn = self.get_custom_loss_fn()
        self.task.train()
        self.update_local_prototype()

    def get_custom_loss_fn(self):
        """
        Build the loss used by the task during local training and evaluation.

        The returned function adds ``config["fedtgp_lambda"]`` times a prototype term to
        the task's default loss. The prototype term is the sum, over classes, of the MSE
        between the embeddings of the masked nodes of that class and the server's
        global prototype for that class.

        Returns:
            custom_loss_fn (function): ``(embedding, logits, label, mask) -> loss``.
                Only the default loss is returned in round 0, or when the number of
                labels differs from this client's ``num_samples`` (evaluation on global
                data).
        """
        def custom_loss_fn(embedding, logits, label, mask):
            base_loss = self.task.default_loss_fn(logits[mask], label[mask])

            # First round (presumably no global prototypes exist yet) or evaluation on
            # global data (label count differs from the local sample count).
            if self.message_pool["round"] == 0 or self.task.num_samples != label.shape[0]:
                return base_loss

            global_prototype = self.message_pool["server"]["global_prototype"]
            mse = nn.MSELoss()
            loss_fedtgp = 0
            for class_i in range(self.task.num_global_classes):
                # Use the same node selection as the classification term, so val/test
                # losses are computed over val/test nodes rather than training nodes.
                selected_idx = mask & (label == class_i)
                if not selected_idx.any():
                    continue
                class_embedding = embedding[selected_idx]
                target = global_prototype[class_i].expand_as(class_embedding)
                loss_fedtgp += mse(class_embedding, target)

            return base_loss + config["fedtgp_lambda"] * loss_fedtgp

        return custom_loss_fn

    def update_local_prototype(self):
        """
        Recompute ``self.local_prototype`` from the current local model.

        For each class, the prototype is the mean embedding of this client's training
        nodes with that label. A class with no such node gets a zero vector whose width
        and dtype match the embeddings.
        """
        with torch.no_grad():
            embedding = self.task.evaluate(mute=True)["embedding"]
            # Move the labels to the device once instead of once per class.
            labels = self.task.data.y.to(self.device)
            train_mask = self.task.train_mask
            for class_i in range(self.task.num_global_classes):
                selected_idx = train_mask & (labels == class_i)
                if selected_idx.any():
                    self.local_prototype[class_i] = torch.mean(embedding[selected_idx], dim=0)
                else:
                    # Size the placeholder like a per-node embedding row so it matches
                    # the other prototypes sent to the server.
                    self.local_prototype[class_i] = torch.zeros(
                        embedding.shape[-1], dtype=embedding.dtype, device=self.device
                    )

    def send_message(self):
        """
        Publish this client's sample count and local prototypes to the message pool
        under the key ``f"client_{client_id}"``.
        """
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "local_prototype": self.local_prototype,
        }